diff mbox series

[v3,1/3] mm/gup: fix wrongly calculated returned value in fault_in_safe_writeable()

Message ID 20250407030306.411977-2-bhe@redhat.com (mailing list archive)
State New
Headers show
Series mm/gup: Minor fix, cleanup and improvements | expand

Commit Message

Baoquan He April 7, 2025, 3:03 a.m. UTC
Unlike fault_in_readable() and fault_in_writeable(), which advance
'uaddr', fault_in_safe_writeable() advances the local variable 'start'
page by page until the whole address range has been handled. However,
it then mistakenly calculates the size of the handled range as
'uaddr - start' instead of 'start - uaddr'.

Fix this bug in fault_in_safe_writeable(), and also adjust
fault_in_readable() and fault_in_writeable() to loop using the local
variable 'start', so that the code in all three functions is consistent.

Signed-off-by: Baoquan He <bhe@redhat.com>
---
 mm/gup.c | 40 ++++++++++++++++++++--------------------
 1 file changed, 20 insertions(+), 20 deletions(-)

Comments

Oscar Salvador April 8, 2025, 9:40 a.m. UTC | #1
On Mon, Apr 07, 2025 at 11:03:04AM +0800, Baoquan He wrote:
> Not like fault_in_readable() or fault_in_writeable(), in
> fault_in_safe_writeable() local variable 'start' is increased page
> by page to loop till the whole address range is handled. However,
> it mistakenly calcalates the size of handled range with 'uaddr - start'.
                ^^ calculates
> 
> Here fix the code bug in fault_in_safe_writeable(), and also adjusting
> the codes in fault_in_readable() and fault_in_writeable() to use local
> variable 'start' to loop so that codes in these three functions are
> consistent.
> 
> Signed-off-by: Baoquan He <bhe@redhat.com>

The fix for the bug in fault_in_safe_writeable() looks good to me.
But I think that David suggested the other way around wrt. the uaddr and
start variables in those three functions? I think he had in mind that
fault_in_safe_writeable() follows the lead of fault_in_writeable() and
fault_in_readable().

Other than that looks good to me.
David Hildenbrand April 8, 2025, 9:52 a.m. UTC | #2
On 07.04.25 05:03, Baoquan He wrote:
> Not like fault_in_readable() or fault_in_writeable(), in
> fault_in_safe_writeable() local variable 'start' is increased page
> by page to loop till the whole address range is handled. However,
> it mistakenly calcalates the size of handled range with 'uaddr - start'.
> 
> Here fix the code bug in fault_in_safe_writeable(), and also adjusting
> the codes in fault_in_readable() and fault_in_writeable() to use local
> variable 'start' to loop so that codes in these three functions are
> consistent.
> 

I probably phrased it poorly in my other reply: the confusing part (to 
me) is adjusting "start". Maybe we should have unsigned long start,end,cur;

Maybe we should really split the "fix" from the cleanups, and tag the 
fix with a Fixes:.

I was wondering if these functions could be simplified a bit. But the 
overflow handling is a bit nasty.
David Hildenbrand April 8, 2025, 10 a.m. UTC | #3
On 08.04.25 11:52, David Hildenbrand wrote:
> On 07.04.25 05:03, Baoquan He wrote:
>> Not like fault_in_readable() or fault_in_writeable(), in
>> fault_in_safe_writeable() local variable 'start' is increased page
>> by page to loop till the whole address range is handled. However,
>> it mistakenly calcalates the size of handled range with 'uaddr - start'.
>>
>> Here fix the code bug in fault_in_safe_writeable(), and also adjusting
>> the codes in fault_in_readable() and fault_in_writeable() to use local
>> variable 'start' to loop so that codes in these three functions are
>> consistent.
>>
> 
> I probably phrased it poorly in my other reply: the confusing part (to
> me) is adjusting "start". Maybe we should have unsigned long start,end,cur;
> 
> Maybe we should really split the "fix" from the cleanups, and tag the
> fix with a Fixes:.
> 
> I was wondering if these functions could be simplified a bit. But the
> overflow handling is a bit nasty.

FWIW, maybe the following could work and clarify things. Just a thought.


diff --git a/mm/gup.c b/mm/gup.c
index 92351e2fa876b..7a3f78a209f8b 100644
--- a/mm/gup.c
+++ b/mm/gup.c
@@ -2223,30 +2223,23 @@ EXPORT_SYMBOL(fault_in_safe_writeable);
   */
  size_t fault_in_readable(const char __user *uaddr, size_t size)
  {
-       const char __user *start = uaddr, *end;
+       const unsigned long start = (unsigned long)uaddr;
+       const unsigned long end = start + size;
+       unsigned long cur;
         volatile char c;
  
         if (unlikely(size == 0))
                 return 0;
         if (!user_read_access_begin(uaddr, size))
                 return size;
-       if (!PAGE_ALIGNED(uaddr)) {
-               unsafe_get_user(c, uaddr, out);
-               uaddr = (const char __user *)PAGE_ALIGN((unsigned long)uaddr);
-       }
-       end = (const char __user *)PAGE_ALIGN((unsigned long)start + size);
-       if (unlikely(end < start))
-               end = NULL;
-       while (uaddr != end) {
-               unsafe_get_user(c, uaddr, out);
-               uaddr += PAGE_SIZE;
-       }
-
+       /* Stop once we overflow to 0. */
+       for (cur = start; cur && cur < end; cur = PAGE_ALIGN_DOWN(cur + PAGE_SIZE))
+               unsafe_get_user(c, (const char __user *)cur, out);
  out:
         user_read_access_end();
         (void)c;
-       if (size > uaddr - start)
-               return size - (uaddr - start);
+       if (size > cur - start)
+               return size - (cur - start);
         return 0;
  }
  EXPORT_SYMBOL(fault_in_readable);
Baoquan He April 8, 2025, 2:59 p.m. UTC | #4
On 04/08/25 at 12:00pm, David Hildenbrand wrote:
> On 08.04.25 11:52, David Hildenbrand wrote:
> > On 07.04.25 05:03, Baoquan He wrote:
> > > Not like fault_in_readable() or fault_in_writeable(), in
> > > fault_in_safe_writeable() local variable 'start' is increased page
> > > by page to loop till the whole address range is handled. However,
> > > it mistakenly calcalates the size of handled range with 'uaddr - start'.
> > > 
> > > Here fix the code bug in fault_in_safe_writeable(), and also adjusting
> > > the codes in fault_in_readable() and fault_in_writeable() to use local
> > > variable 'start' to loop so that codes in these three functions are
> > > consistent.
> > > 
> > 
> > I probably phrased it poorly in my other reply: the confusing part (to
> > me) is adjusting "start". Maybe we should have unsigned long start,end,cur;
> > 
> > Maybe we should really split the "fix" from the cleanups, and tag the
> > fix with a Fixes:.

> > 
> > I was wondering if these functions could be simplified a bit. But the
> > overflow handling is a bit nasty.
> 
> FWIW, maybe the following could work and clarify things. Just a thought.

The code simplification looks great to me. I will make one patch that only
contains the bug fix, tagged with Fixes:, so that it's easier to backport
to stable kernels, and another patch like the one below to refactor the code in
fault_in_readable/writeable/safe_writeable(). Thanks for the suggestion.

> 
> 
> diff --git a/mm/gup.c b/mm/gup.c
> index 92351e2fa876b..7a3f78a209f8b 100644
> --- a/mm/gup.c
> +++ b/mm/gup.c
> @@ -2223,30 +2223,23 @@ EXPORT_SYMBOL(fault_in_safe_writeable);
>   */
>  size_t fault_in_readable(const char __user *uaddr, size_t size)
>  {
> -       const char __user *start = uaddr, *end;
> +       const unsigned long start = (unsigned long)uaddr;
> +       const unsigned long end = start + size;
> +       unsigned long cur;
>         volatile char c;
>         if (unlikely(size == 0))
>                 return 0;
>         if (!user_read_access_begin(uaddr, size))
>                 return size;
> -       if (!PAGE_ALIGNED(uaddr)) {
> -               unsafe_get_user(c, uaddr, out);
> -               uaddr = (const char __user *)PAGE_ALIGN((unsigned long)uaddr);
> -       }
> -       end = (const char __user *)PAGE_ALIGN((unsigned long)start + size);
> -       if (unlikely(end < start))
> -               end = NULL;
> -       while (uaddr != end) {
> -               unsafe_get_user(c, uaddr, out);
> -               uaddr += PAGE_SIZE;
> -       }
> -
> -out:
> +       /* Stop once we overflow to 0. */
> +       for (cur = start; cur && cur < end; cur = PAGE_ALIGN_DOWN(cur + PAGE_SIZE))
> +               unsafe_get_user(c, (const char __user *)cur, out);
>         user_read_access_end();
>         (void)c;
> -       if (size > uaddr - start)
> -               return size - (uaddr - start);
> +out:
> +       if (size > cur - start)
> +               return size - (cur - start);
>         return 0;
>  }
>  EXPORT_SYMBOL(fault_in_readable);
> 
> 
> -- 
> Cheers,
> 
> David / dhildenb
>
Baoquan He April 8, 2025, 3:01 p.m. UTC | #5
On 04/08/25 at 11:40am, Oscar Salvador wrote:
> On Mon, Apr 07, 2025 at 11:03:04AM +0800, Baoquan He wrote:
> > Not like fault_in_readable() or fault_in_writeable(), in
> > fault_in_safe_writeable() local variable 'start' is increased page
> > by page to loop till the whole address range is handled. However,
> > it mistakenly calcalates the size of handled range with 'uaddr - start'.
>                 ^^ calculates

Will fix, thanks.
> > 
> > Here fix the code bug in fault_in_safe_writeable(), and also adjusting
> > the codes in fault_in_readable() and fault_in_writeable() to use local
> > variable 'start' to loop so that codes in these three functions are
> > consistent.
> > 
> > Signed-off-by: Baoquan He <bhe@redhat.com>
> 
> The fix for the bug in fault_in_safe_writeable() looks good to me.
> But I think that David suggested the other way around wrt. uaddr and
> start variables in those three functions? I think he had in mind that
> fault_in_safe_writeable() follows fault_in_safe_writeable() and
> fault_in_readable() lead.

Right, will follow the way he suggested in another sub-thread, thanks
for careful reviewing.

> 
> Other than that looks good to me.
> 
> 
> -- 
> Oscar Salvador
> SUSE Labs
>
diff mbox series

Patch

diff --git a/mm/gup.c b/mm/gup.c
index 92351e2fa876..67a7de9e4f80 100644
--- a/mm/gup.c
+++ b/mm/gup.c
@@ -2114,7 +2114,7 @@  static long __get_user_pages_locked(struct mm_struct *mm, unsigned long start,
  */
 size_t fault_in_writeable(char __user *uaddr, size_t size)
 {
-	char __user *start = uaddr, *end;
+	unsigned long start = (unsigned long)uaddr, end;
 
 	if (unlikely(size == 0))
 		return 0;
@@ -2122,20 +2122,20 @@  size_t fault_in_writeable(char __user *uaddr, size_t size)
 		return size;
 	if (!PAGE_ALIGNED(uaddr)) {
 		unsafe_put_user(0, uaddr, out);
-		uaddr = (char __user *)PAGE_ALIGN((unsigned long)uaddr);
+		start = PAGE_ALIGN((unsigned long)uaddr);
 	}
-	end = (char __user *)PAGE_ALIGN((unsigned long)start + size);
+	end = PAGE_ALIGN((unsigned long)uaddr + size);
 	if (unlikely(end < start))
-		end = NULL;
-	while (uaddr != end) {
-		unsafe_put_user(0, uaddr, out);
-		uaddr += PAGE_SIZE;
+		end = 0;
+	while (start != end) {
+		unsafe_put_user(0, (char __user *)start, out);
+		start += PAGE_SIZE;
 	}
 
 out:
 	user_write_access_end();
-	if (size > uaddr - start)
-		return size - (uaddr - start);
+	if (size > start - (unsigned long)uaddr)
+		return size - (start - (unsigned long)uaddr);
 	return 0;
 }
 EXPORT_SYMBOL(fault_in_writeable);
@@ -2207,8 +2207,8 @@  size_t fault_in_safe_writeable(const char __user *uaddr, size_t size)
 	} while (start != end);
 	mmap_read_unlock(mm);
 
-	if (size > (unsigned long)uaddr - start)
-		return size - ((unsigned long)uaddr - start);
+	if (size > start - (unsigned long)uaddr)
+		return size - (start - (unsigned long)uaddr);
 	return 0;
 }
 EXPORT_SYMBOL(fault_in_safe_writeable);
@@ -2223,7 +2223,7 @@  EXPORT_SYMBOL(fault_in_safe_writeable);
  */
 size_t fault_in_readable(const char __user *uaddr, size_t size)
 {
-	const char __user *start = uaddr, *end;
+	unsigned long start = (unsigned long)uaddr, end;
 	volatile char c;
 
 	if (unlikely(size == 0))
@@ -2232,21 +2232,21 @@  size_t fault_in_readable(const char __user *uaddr, size_t size)
 		return size;
 	if (!PAGE_ALIGNED(uaddr)) {
 		unsafe_get_user(c, uaddr, out);
-		uaddr = (const char __user *)PAGE_ALIGN((unsigned long)uaddr);
+		start = PAGE_ALIGN((unsigned long)uaddr);
 	}
-	end = (const char __user *)PAGE_ALIGN((unsigned long)start + size);
+	end = PAGE_ALIGN((unsigned long)uaddr + size);
 	if (unlikely(end < start))
-		end = NULL;
-	while (uaddr != end) {
-		unsafe_get_user(c, uaddr, out);
-		uaddr += PAGE_SIZE;
+		end = 0;
+	while (start != end) {
+		unsafe_get_user(c, (const char __user *)start, out);
+		start += PAGE_SIZE;
 	}
 
 out:
 	user_read_access_end();
 	(void)c;
-	if (size > uaddr - start)
-		return size - (uaddr - start);
+	if (size > start - (unsigned long)uaddr)
+		return size - (start - (unsigned long)uaddr);
 	return 0;
 }
 EXPORT_SYMBOL(fault_in_readable);