diff mbox series

mm/mmap: simplify vma_merge()

Message ID 20240118082312.2801992-1-yajun.deng@linux.dev (mailing list archive)
State New
Headers show
Series mm/mmap: simplify vma_merge() | expand

Commit Message

Yajun Deng Jan. 18, 2024, 8:23 a.m. UTC
These vma_merge() callers will pass mm, anon_vma and file, they all from
vma. There is no need to pass three parameters at the same time.

We will find the current vma in vma_merge(). If we pass the original vma
to vma_merge(), the current vma is actually the original vma or NULL.
So we didn't need to find the current vma with find_vma_intersection().

Pass vma to vma_merge(), and add a check to make sure the current vma
is an existing vma.

Signed-off-by: Yajun Deng <yajun.deng@linux.dev>
---
 mm/mmap.c | 37 +++++++++++++++++--------------------
 1 file changed, 17 insertions(+), 20 deletions(-)

Comments

Liam R. Howlett Jan. 23, 2024, 4:08 p.m. UTC | #1
Adding to the Cc list, because it's vma_merge().

* Yajun Deng <yajun.deng@linux.dev> [240118 03:23]:
> These vma_merge() callers will pass mm, anon_vma and file, they all from
> vma. There is no need to pass three parameters at the same time.
> 
> We will find the current vma in vma_merge().

It sounds like you are adding a search for current to vma_merger(), but
you are removing that part in your patch, so it's odd to say this here.

>If we pass the original vma
> to vma_merge(), the current vma is actually the original vma or NULL.

What do you mean original vma?  The source of the anon_vma, vm_mm, etc?
If so, the 'original' vma could be prev (shifting boundaries in case 4
and 5 in the comments).  I think "vma that was the source of the
arguments" would be more clear than "original vma".

> So we didn't need to find the current vma with find_vma_intersection().
> 
> Pass vma to vma_merge(), and add a check to make sure the current vma
> is an existing vma.

How could it not be an existing vma?  It is dereferenced, so it exists.
Do you mean a vma in the vma tree?

I think this is all to say that we can pass through the vma to figure
out if curr == NULL, or if it's vma directly.

> 
> Signed-off-by: Yajun Deng <yajun.deng@linux.dev>
> ---
>  mm/mmap.c | 37 +++++++++++++++++--------------------
>  1 file changed, 17 insertions(+), 20 deletions(-)
> 
> diff --git a/mm/mmap.c b/mm/mmap.c
> index 49d25172eac8..7e00ae4f39e3 100644
> --- a/mm/mmap.c
> +++ b/mm/mmap.c
> @@ -860,14 +860,16 @@ can_vma_merge_after(struct vm_area_struct *vma, unsigned long vm_flags,
>   *      area is returned, or the function will return NULL
>   */
>  static struct vm_area_struct
> -*vma_merge(struct vma_iterator *vmi, struct mm_struct *mm,
> -	   struct vm_area_struct *prev, unsigned long addr, unsigned long end,
> -	   unsigned long vm_flags, struct anon_vma *anon_vma, struct file *file,
> -	   pgoff_t pgoff, struct mempolicy *policy,
> +*vma_merge(struct vma_iterator *vmi, struct vm_area_struct *prev,
> +	   struct vm_area_struct *curr, unsigned long addr, unsigned long end,
> +	   unsigned long vm_flags, pgoff_t pgoff, struct mempolicy *policy,
>  	   struct vm_userfaultfd_ctx vm_userfaultfd_ctx,
>  	   struct anon_vma_name *anon_name)
>  {
> -	struct vm_area_struct *curr, *next, *res;
> +	struct mm_struct *mm = curr->vm_mm;
> +	struct anon_vma *anon_vma = curr->anon_vma;
> +	struct file *file = curr->vm_file;
> +	struct vm_area_struct *next = NULL, *res;
>  	struct vm_area_struct *vma, *adjust, *remove, *remove2;
>  	struct vm_area_struct *anon_dup = NULL;
>  	struct vma_prepare vp;
> @@ -889,13 +891,12 @@ static struct vm_area_struct
>  		return NULL;
>  
>  	/* Does the input range span an existing VMA? (cases 5 - 8) */
> -	curr = find_vma_intersection(mm, prev ? prev->vm_end : 0, end);
> +	if (prev == curr || addr != curr->vm_start || end > curr->vm_end)
> +		curr = NULL;

It would be nice to have comments about what cases this logic covers,
because reverse engineering it is a pain.  And we have to do it every
time a change occurs in the function, even when we are the ones who
wrote the statement.  I think we can all agree that this function is
painful, but it's improving and thanks for joining.

>  
>  	if (!curr ||			/* cases 1 - 4 */
>  	    end == curr->vm_end)	/* cases 6 - 8, adjacent VMA */
> -		next = vma_lookup(mm, end);
> -	else
> -		next = NULL;		/* case 5 */
> +		next = vma_lookup(mm, end); /* NULL case 5 */

Ah, maybe put the comment about case 5 being null on a different line.
I thought you were saying the vma_lookup() will return NULL, not that it
was initialised as NULL above.  Change the wording to something like
"case 5 set to NULL above" or "case 5 remains NULL".

>  
>  	if (prev) {
>  		vma_start = prev->vm_start;
> @@ -919,7 +920,6 @@ static struct vm_area_struct
>  
>  	/* Verify some invariant that must be enforced by the caller. */
>  	VM_WARN_ON(prev && addr <= prev->vm_start);
> -	VM_WARN_ON(curr && (addr != curr->vm_start || end > curr->vm_end));

Why did you drop this?  I understand you moved basically all of it to an
if statement above, but it's still true, right?  Considering the
trickiness of the function I'd like to keep it if there's no one who
feels strongly about it.

>  	VM_WARN_ON(addr >= end);
>  
...

To increase the chances of actually finding an issue, I would suggest
splitting this into two patches:

1. Just passing through vma.
2. The logic changes to remove that find_vma_intersection() call.

By the way, what are the performance benefits to this change?  It's not
without its own risks - this function has caused subtle bugs that
persisted for several releases in the past and it'd be nice to know what
we are gaining for the risk.

Thanks,
Liam
Yajun Deng Jan. 24, 2024, 3:45 a.m. UTC | #2
On 2024/1/24 00:08, Liam R. Howlett wrote:
> Adding to the Cc list, because it's vma_merge().
>
> * Yajun Deng <yajun.deng@linux.dev> [240118 03:23]:
>> These vma_merge() callers will pass mm, anon_vma and file, they all from
>> vma. There is no need to pass three parameters at the same time.
>>
>> We will find the current vma in vma_merge().
> It sounds like you are adding a search for current to vma_merger(), but
> you are removing that part in your patch, so it's odd to say this here.
>

Okay.

>> If we pass the original vma
>> to vma_merge(), the current vma is actually the original vma or NULL.
> What do you mean original vma?  The source of the anon_vma, vm_mm, etc?
> If so, the 'original' vma could be prev (shifting boundaries in case 4
> and 5 in the comments).  I think "vma that was the source of the
> arguments" would be more clear than "original vma".
>

Okay.

>> So we didn't need to find the current vma with find_vma_intersection().
>>
>> Pass vma to vma_merge(), and add a check to make sure the current vma
>> is an existing vma.
> How could it not be an existing vma?  It is dereferenced, so it exists.
> Do you mean a vma in the vma tree?
It means the current vma is NULL or not.
> I think this is all to say that we can pass through the vma to figure
> out if curr == NULL, or if it's vma directly.
>
Okay.
>> Signed-off-by: Yajun Deng <yajun.deng@linux.dev>
>> ---
>>   mm/mmap.c | 37 +++++++++++++++++--------------------
>>   1 file changed, 17 insertions(+), 20 deletions(-)
>>
>> diff --git a/mm/mmap.c b/mm/mmap.c
>> index 49d25172eac8..7e00ae4f39e3 100644
>> --- a/mm/mmap.c
>> +++ b/mm/mmap.c
>> @@ -860,14 +860,16 @@ can_vma_merge_after(struct vm_area_struct *vma, unsigned long vm_flags,
>>    *      area is returned, or the function will return NULL
>>    */
>>   static struct vm_area_struct
>> -*vma_merge(struct vma_iterator *vmi, struct mm_struct *mm,
>> -	   struct vm_area_struct *prev, unsigned long addr, unsigned long end,
>> -	   unsigned long vm_flags, struct anon_vma *anon_vma, struct file *file,
>> -	   pgoff_t pgoff, struct mempolicy *policy,
>> +*vma_merge(struct vma_iterator *vmi, struct vm_area_struct *prev,
>> +	   struct vm_area_struct *curr, unsigned long addr, unsigned long end,
>> +	   unsigned long vm_flags, pgoff_t pgoff, struct mempolicy *policy,
>>   	   struct vm_userfaultfd_ctx vm_userfaultfd_ctx,
>>   	   struct anon_vma_name *anon_name)
>>   {
>> -	struct vm_area_struct *curr, *next, *res;
>> +	struct mm_struct *mm = curr->vm_mm;
>> +	struct anon_vma *anon_vma = curr->anon_vma;
>> +	struct file *file = curr->vm_file;
>> +	struct vm_area_struct *next = NULL, *res;
>>   	struct vm_area_struct *vma, *adjust, *remove, *remove2;
>>   	struct vm_area_struct *anon_dup = NULL;
>>   	struct vma_prepare vp;
>> @@ -889,13 +891,12 @@ static struct vm_area_struct
>>   		return NULL;
>>   
>>   	/* Does the input range span an existing VMA? (cases 5 - 8) */
>> -	curr = find_vma_intersection(mm, prev ? prev->vm_end : 0, end);
>> +	if (prev == curr || addr != curr->vm_start || end > curr->vm_end)
>> +		curr = NULL;
> It would be nice to have comments about what cases this logic covers,
> because reverse engineering it is a pain.  And we have to do it every
> time a change occurs in the function, even when we are the ones who
> wrote the statement.  I think we can all agree that this function is
> painful, but it's improving and thanks for joining.


Okay.

>>   
>>   	if (!curr ||			/* cases 1 - 4 */
>>   	    end == curr->vm_end)	/* cases 6 - 8, adjacent VMA */
>> -		next = vma_lookup(mm, end);
>> -	else
>> -		next = NULL;		/* case 5 */
>> +		next = vma_lookup(mm, end); /* NULL case 5 */
> Ah, maybe put the comment about case 5 being null on a different line.
> I thought you were saying the vma_lookup() will return NULL, not that it
> was initialised as NULL above.  Change the wording to something like
> "case 5 set to NULL above" or "case 5 remains NULL".
>

Okay.

>>   
>>   	if (prev) {
>>   		vma_start = prev->vm_start;
>> @@ -919,7 +920,6 @@ static struct vm_area_struct
>>   
>>   	/* Verify some invariant that must be enforced by the caller. */
>>   	VM_WARN_ON(prev && addr <= prev->vm_start);
>> -	VM_WARN_ON(curr && (addr != curr->vm_start || end > curr->vm_end));
> Why did you drop this?  I understand you moved basically all of it to an
> if statement above, but it's still true, right?  Considering the
> trickiness of the function I'd like to keep it if there's no one who
> feels strongly about it.


I don't think we need this. We move this to the front of the function, 
addr, end and curr won't be

changed until then.

>>   	VM_WARN_ON(addr >= end);
>>   
> ...
>
> To increase the chances of actually finding an issue, I would suggest
> splitting this into two patches:
>
> 1. Just passing through vma.
> 2. The logic changes to remove that find_vma_intersection() call.
Okay.
> By the way, what are the performance benefits to this change?  It's not
> without its own risks - this function has caused subtle bugs that
> persisted for several releases in the past and it'd be nice to know what
> we are gaining for the risk.


No, I just found out that the current vma is the source vma. So we don't 
need to find the current

vma with find_vma_intersection().

I think we can add some case about vma_merge() to the LTP project. It 
currently has 5 test cases

about vma, but it doesn't seem to detect the risk of vma_merge().

Link: http://linux-test-project.github.io/

>
> Thanks,
> Liam
diff mbox series

Patch

diff --git a/mm/mmap.c b/mm/mmap.c
index 49d25172eac8..7e00ae4f39e3 100644
--- a/mm/mmap.c
+++ b/mm/mmap.c
@@ -860,14 +860,16 @@  can_vma_merge_after(struct vm_area_struct *vma, unsigned long vm_flags,
  *      area is returned, or the function will return NULL
  */
 static struct vm_area_struct
-*vma_merge(struct vma_iterator *vmi, struct mm_struct *mm,
-	   struct vm_area_struct *prev, unsigned long addr, unsigned long end,
-	   unsigned long vm_flags, struct anon_vma *anon_vma, struct file *file,
-	   pgoff_t pgoff, struct mempolicy *policy,
+*vma_merge(struct vma_iterator *vmi, struct vm_area_struct *prev,
+	   struct vm_area_struct *curr, unsigned long addr, unsigned long end,
+	   unsigned long vm_flags, pgoff_t pgoff, struct mempolicy *policy,
 	   struct vm_userfaultfd_ctx vm_userfaultfd_ctx,
 	   struct anon_vma_name *anon_name)
 {
-	struct vm_area_struct *curr, *next, *res;
+	struct mm_struct *mm = curr->vm_mm;
+	struct anon_vma *anon_vma = curr->anon_vma;
+	struct file *file = curr->vm_file;
+	struct vm_area_struct *next = NULL, *res;
 	struct vm_area_struct *vma, *adjust, *remove, *remove2;
 	struct vm_area_struct *anon_dup = NULL;
 	struct vma_prepare vp;
@@ -889,13 +891,12 @@  static struct vm_area_struct
 		return NULL;
 
 	/* Does the input range span an existing VMA? (cases 5 - 8) */
-	curr = find_vma_intersection(mm, prev ? prev->vm_end : 0, end);
+	if (prev == curr || addr != curr->vm_start || end > curr->vm_end)
+		curr = NULL;
 
 	if (!curr ||			/* cases 1 - 4 */
 	    end == curr->vm_end)	/* cases 6 - 8, adjacent VMA */
-		next = vma_lookup(mm, end);
-	else
-		next = NULL;		/* case 5 */
+		next = vma_lookup(mm, end); /* NULL case 5 */
 
 	if (prev) {
 		vma_start = prev->vm_start;
@@ -919,7 +920,6 @@  static struct vm_area_struct
 
 	/* Verify some invariant that must be enforced by the caller. */
 	VM_WARN_ON(prev && addr <= prev->vm_start);
-	VM_WARN_ON(curr && (addr != curr->vm_start || end > curr->vm_end));
 	VM_WARN_ON(addr >= end);
 
 	if (!merge_prev && !merge_next)
@@ -2424,9 +2424,8 @@  struct vm_area_struct *vma_modify(struct vma_iterator *vmi,
 	pgoff_t pgoff = vma->vm_pgoff + ((start - vma->vm_start) >> PAGE_SHIFT);
 	struct vm_area_struct *merged;
 
-	merged = vma_merge(vmi, vma->vm_mm, prev, start, end, vm_flags,
-			   vma->anon_vma, vma->vm_file, pgoff, policy,
-			   uffd_ctx, anon_name);
+	merged = vma_merge(vmi, prev, vma, start, end, vm_flags,
+			   pgoff, policy, uffd_ctx, anon_name);
 	if (merged)
 		return merged;
 
@@ -2456,9 +2455,8 @@  static struct vm_area_struct
 		   struct vm_area_struct *vma, unsigned long start,
 		   unsigned long end, pgoff_t pgoff)
 {
-	return vma_merge(vmi, vma->vm_mm, prev, start, end, vma->vm_flags,
-			 vma->anon_vma, vma->vm_file, pgoff, vma_policy(vma),
-			 vma->vm_userfaultfd_ctx, anon_vma_name(vma));
+	return vma_merge(vmi, prev, vma, start, end, vma->vm_flags, pgoff,
+			 vma_policy(vma), vma->vm_userfaultfd_ctx, anon_vma_name(vma));
 }
 
 /*
@@ -2472,10 +2470,9 @@  struct vm_area_struct *vma_merge_extend(struct vma_iterator *vmi,
 	pgoff_t pgoff = vma->vm_pgoff + vma_pages(vma);
 
 	/* vma is specified as prev, so case 1 or 2 will apply. */
-	return vma_merge(vmi, vma->vm_mm, vma, vma->vm_end, vma->vm_end + delta,
-			 vma->vm_flags, vma->anon_vma, vma->vm_file, pgoff,
-			 vma_policy(vma), vma->vm_userfaultfd_ctx,
-			 anon_vma_name(vma));
+	return vma_merge(vmi, vma, vma, vma->vm_end, vma->vm_end + delta,
+			 vma->vm_flags, pgoff, vma_policy(vma),
+			 vma->vm_userfaultfd_ctx, anon_vma_name(vma));
 }
 
 /*