diff mbox series

[v5,4/4] mm/gup: clean up code in fault_in_xxx() functions

Message ID Z/sbv3EmLXWgEE7+@MiWiFi-R3L-srv (mailing list archive)
State New
Headers show
Series None | expand

Commit Message

Baoquan He April 13, 2025, 2:04 a.m. UTC
The code style in fault_in_readable() and fault_in_writeable() is a
little inconsistent with fault_in_safe_writeable(). fault_in_readable()
and fault_in_writeable() use the passed-in 'uaddr' as the loop cursor,
while fault_in_safe_writeable() uses the local variable 'start' as the
loop cursor. This inconsistency may mislead people reading or changing
this code.

Define an explicit loop cursor and use a for loop to simplify the code in
all three functions. This cleanup makes their code style consistent and
improves readability.

Signed-off-by: Baoquan He <bhe@redhat.com>
---
v4->v5:
Address minor concerns from David:
- Remove a blank line added to fault_in_writeable() in v4;
- Put the loop cursor initialization 'cur = start;' into the for loop
  initialization part.
  
 mm/gup.c | 62 ++++++++++++++++++++++----------------------------------
 1 file changed, 24 insertions(+), 38 deletions(-)

Comments

David Hildenbrand April 13, 2025, 8:09 p.m. UTC | #1
On 13.04.25 04:04, Baoquan He wrote:
> The code style in fault_in_readable() and fault_in_writable() is a
> little inconsistent with fault_in_safe_writeable(). In fault_in_readable()
> and fault_in_writable(), it uses 'uaddr' passed in as loop cursor. While
> in fault_in_safe_writeable(), local variable 'start' is used as loop
> cursor. This may mislead people when reading code or making change in
> these codes.
> 
> Here define explicit loop cursor and use for loop to simplify codes in
> these three functions. These cleanup can make them be consistent in
> code style and improve readability.
> 
> Signed-off-by: Baoquan He <bhe@redhat.com>
> ---

Hopefully we don't introduce anything unexpected ... do we have some 
unit test that could make us feel better, especially regarding end < start?

If not, could we add one based on some feature that ends up calling at 
least one of these functions?

Acked-by: David Hildenbrand <david@redhat.com>
Baoquan He April 14, 2025, 3:44 a.m. UTC | #2
On 04/13/25 at 10:09pm, David Hildenbrand wrote:
> On 13.04.25 04:04, Baoquan He wrote:
> > The code style in fault_in_readable() and fault_in_writable() is a
> > little inconsistent with fault_in_safe_writeable(). In fault_in_readable()
> > and fault_in_writable(), it uses 'uaddr' passed in as loop cursor. While
> > in fault_in_safe_writeable(), local variable 'start' is used as loop
> > cursor. This may mislead people when reading code or making change in
> > these codes.
> > 
> > Here define explicit loop cursor and use for loop to simplify codes in
> > these three functions. These cleanup can make them be consistent in
> > code style and improve readability.
> > 
> > Signed-off-by: Baoquan He <bhe@redhat.com>
> > ---
> 
> Hopefully we don't introduce anything unexpected ... do we have some unit
> test that could make use feel better, especially regarding end < start?
> 
> If not, could we add one based on some feature that ends up calling at least
> one of these functions?

There seems to be no existing test case. GUP has selftests, but no KUnit
tests. I will see if I can add one, though it may not be easy.

> 
> Acked-by: David Hildenbrand <david@redhat.com>

Thanks.
diff mbox series

Patch

diff --git a/mm/gup.c b/mm/gup.c
index 77a5bc622567..f32168339390 100644
--- a/mm/gup.c
+++ b/mm/gup.c
@@ -2113,28 +2113,22 @@  static long __get_user_pages_locked(struct mm_struct *mm, unsigned long start,
  */
 size_t fault_in_writeable(char __user *uaddr, size_t size)
 {
-	char __user *start = uaddr, *end;
+	const unsigned long start = (unsigned long)uaddr;
+	const unsigned long end = (start + size < start) ? ULONG_MAX : start + size;
+	unsigned long cur;
 
 	if (unlikely(size == 0))
 		return 0;
 	if (!user_write_access_begin(uaddr, size))
 		return size;
-	if (!PAGE_ALIGNED(uaddr)) {
-		unsafe_put_user(0, uaddr, out);
-		uaddr = (char __user *)PAGE_ALIGN((unsigned long)uaddr);
-	}
-	end = (char __user *)PAGE_ALIGN((unsigned long)start + size);
-	if (unlikely(end < start))
-		end = NULL;
-	while (uaddr != end) {
-		unsafe_put_user(0, uaddr, out);
-		uaddr += PAGE_SIZE;
-	}
 
+	/* If the range wraps, end saturates; stop once cur overflows to 0. */
+	for (cur = start; cur && cur < end; cur = PAGE_ALIGN_DOWN(cur + PAGE_SIZE))
+		unsafe_put_user(0, (char __user *)cur, out);
 out:
 	user_write_access_end();
-	if (size > uaddr - start)
-		return size - (uaddr - start);
+	if (size > cur - start)
+		return size - (cur - start);
 	return 0;
 }
 EXPORT_SYMBOL(fault_in_writeable);
@@ -2188,26 +2182,24 @@  EXPORT_SYMBOL(fault_in_subpage_writeable);
  */
 size_t fault_in_safe_writeable(const char __user *uaddr, size_t size)
 {
-	unsigned long start = (unsigned long)uaddr, end;
+	const unsigned long start = (unsigned long)uaddr;
+	const unsigned long end = (start + size < start) ? ULONG_MAX : start + size;
+	unsigned long cur;
 	struct mm_struct *mm = current->mm;
 	bool unlocked = false;
 
 	if (unlikely(size == 0))
 		return 0;
-	end = PAGE_ALIGN(start + size);
-	if (end < start)
-		end = 0;
 
 	mmap_read_lock(mm);
-	do {
-		if (fixup_user_fault(mm, start, FAULT_FLAG_WRITE, &unlocked))
+	/* If the range wraps, end saturates; stop once cur overflows to 0. */
+	for (cur = start; cur && cur < end; cur = PAGE_ALIGN_DOWN(cur + PAGE_SIZE))
+		if (fixup_user_fault(mm, cur, FAULT_FLAG_WRITE, &unlocked))
 			break;
-		start = (start + PAGE_SIZE) & PAGE_MASK;
-	} while (start != end);
 	mmap_read_unlock(mm);
 
-	if (size > start - (unsigned long)uaddr)
-		return size - (start - (unsigned long)uaddr);
+	if (size > cur - start)
+		return size - (cur - start);
 	return 0;
 }
 EXPORT_SYMBOL(fault_in_safe_writeable);
@@ -2222,30 +2214,24 @@  EXPORT_SYMBOL(fault_in_safe_writeable);
  */
 size_t fault_in_readable(const char __user *uaddr, size_t size)
 {
-	const char __user *start = uaddr, *end;
+	const unsigned long start = (unsigned long)uaddr;
+	const unsigned long end = (start + size < start) ? ULONG_MAX : start + size;
+	unsigned long cur;
 	volatile char c;
 
 	if (unlikely(size == 0))
 		return 0;
 	if (!user_read_access_begin(uaddr, size))
 		return size;
-	if (!PAGE_ALIGNED(uaddr)) {
-		unsafe_get_user(c, uaddr, out);
-		uaddr = (const char __user *)PAGE_ALIGN((unsigned long)uaddr);
-	}
-	end = (const char __user *)PAGE_ALIGN((unsigned long)start + size);
-	if (unlikely(end < start))
-		end = NULL;
-	while (uaddr != end) {
-		unsafe_get_user(c, uaddr, out);
-		uaddr += PAGE_SIZE;
-	}
 
+	/* If the range wraps, end saturates; stop once cur overflows to 0. */
+	for (cur = start; cur && cur < end; cur = PAGE_ALIGN_DOWN(cur + PAGE_SIZE))
+		unsafe_get_user(c, (const char __user *)cur, out);
 out:
 	user_read_access_end();
 	(void)c;
-	if (size > uaddr - start)
-		return size - (uaddr - start);
+	if (size > cur - start)
+		return size - (cur - start);
 	return 0;
 }
 EXPORT_SYMBOL(fault_in_readable);