diff mbox series

[v2,2/5] mm: abstract the vma_merge()/split_vma() pattern for mprotect() et al.

Message ID ade506aa09184dc06d57785fe90a6076682556ca.1696884493.git.lstoakes@gmail.com (mailing list archive)
State New
Headers show
Series Abstract vma_merge() and split_vma() | expand

Commit Message

Lorenzo Stoakes Oct. 9, 2023, 8:53 p.m. UTC
mprotect() and other functions which change VMA parameters over a range
each employ a pattern of:-

1. Attempt to merge the range with adjacent VMAs.
2. If this fails, and the range spans a subset of the VMA, split it
   accordingly.

This is open-coded and duplicated in each case. Also in each case most of
the parameters passed to vma_merge() remain the same.

Create a new function, vma_modify(), which abstracts this operation,
accepting only those parameters which can be changed.

To avoid the mess of invoking each function call with unnecessary
parameters, create inline wrapper functions for each of the modify
operations, parameterised only by what is required to perform the action.

Note that the userfaultfd_release() case works even though it does not
split VMAs - since start is set to vma->vm_start and end is set to
vma->vm_end, the split logic does not trigger.

In addition, since we calculate pgoff to be equal to vma->vm_pgoff + (start
- vma->vm_start) >> PAGE_SHIFT, and start - vma->vm_start will be 0 in this
instance, this invocation will remain unchanged.

Signed-off-by: Lorenzo Stoakes <lstoakes@gmail.com>
---
 fs/userfaultfd.c   | 69 +++++++++++++++-------------------------------
 include/linux/mm.h | 60 ++++++++++++++++++++++++++++++++++++++++
 mm/madvise.c       | 32 ++++++---------------
 mm/mempolicy.c     | 22 +++------------
 mm/mlock.c         | 27 +++++-------------
 mm/mmap.c          | 45 ++++++++++++++++++++++++++++++
 mm/mprotect.c      | 35 +++++++----------------
 7 files changed, 157 insertions(+), 133 deletions(-)

Comments

Vlastimil Babka Oct. 10, 2023, 7:12 a.m. UTC | #1
On 10/9/23 22:53, Lorenzo Stoakes wrote:
> mprotect() and other functions which change VMA parameters over a range
> each employ a pattern of:-
> 
> 1. Attempt to merge the range with adjacent VMAs.
> 2. If this fails, and the range spans a subset of the VMA, split it
>    accordingly.
> 
> This is open-coded and duplicated in each case. Also in each case most of
> the parameters passed to vma_merge() remain the same.
> 
> Create a new function, vma_modify(), which abstracts this operation,
> accepting only those parameters which can be changed.
> 
> To avoid the mess of invoking each function call with unnecessary
> parameters, create inline wrapper functions for each of the modify
> operations, parameterised only by what is required to perform the action.
> 
> Note that the userfaultfd_release() case works even though it does not
> split VMAs - since start is set to vma->vm_start and end is set to
> vma->vm_end, the split logic does not trigger.
> 
> In addition, since we calculate pgoff to be equal to vma->vm_pgoff + (start
> - vma->vm_start) >> PAGE_SHIFT, and start - vma->vm_start will be 0 in this
> instance, this invocation will remain unchanged.
> 
> Signed-off-by: Lorenzo Stoakes <lstoakes@gmail.com>

Reviewed-by: Vlastimil Babka <vbabka@suse.cz>

some nits below:

> --- a/mm/mmap.c
> +++ b/mm/mmap.c
> @@ -2437,6 +2437,51 @@ int split_vma(struct vma_iterator *vmi, struct vm_area_struct *vma,
>  	return __split_vma(vmi, vma, addr, new_below);
>  }
>  
> +/*
> + * We are about to modify one or multiple of a VMA's flags, policy, userfaultfd
> + * context and anonymous VMA name within the range [start, end).
> + *
> + * As a result, we might be able to merge the newly modified VMA range with an
> + * adjacent VMA with identical properties.
> + *
> + * If no merge is possible and the range does not span the entirety of the VMA,
> + * we then need to split the VMA to accommodate the change.
> + */

This could describe the return value too? It's not entirely trivial.
But I also wonder if we could just return 'vma' for the split_vma() cases
and the callers could simply stop distinguishing whether there was a merge
or split, and their code would become even simpler?
It seems to me most callers don't care, except mprotect, see below...

> +struct vm_area_struct *vma_modify(struct vma_iterator *vmi,
> +				  struct vm_area_struct *prev,
> +				  struct vm_area_struct *vma,
> +				  unsigned long start, unsigned long end,
> +				  unsigned long vm_flags,
> +				  struct mempolicy *policy,
> +				  struct vm_userfaultfd_ctx uffd_ctx,
> +				  struct anon_vma_name *anon_name)
> +{
> +	pgoff_t pgoff = vma->vm_pgoff + ((start - vma->vm_start) >> PAGE_SHIFT);
> +	struct vm_area_struct *merged;
> +
> +	merged = vma_merge(vmi, vma->vm_mm, prev, start, end, vm_flags,
> +			   vma->anon_vma, vma->vm_file, pgoff, policy,
> +			   uffd_ctx, anon_name);
> +	if (merged)
> +		return merged;
> +
> +	if (vma->vm_start < start) {
> +		int err = split_vma(vmi, vma, start, 1);
> +
> +		if (err)
> +			return ERR_PTR(err);
> +	}
> +
> +	if (vma->vm_end > end) {
> +		int err = split_vma(vmi, vma, end, 0);
> +
> +		if (err)
> +			return ERR_PTR(err);
> +	}
> +
> +	return NULL;
> +}
> +
>  /*
>   * do_vmi_align_munmap() - munmap the aligned region from @start to @end.
>   * @vmi: The vma iterator
> diff --git a/mm/mprotect.c b/mm/mprotect.c
> index b94fbb45d5c7..6f85d99682ab 100644
> --- a/mm/mprotect.c
> +++ b/mm/mprotect.c
> @@ -581,7 +581,7 @@ mprotect_fixup(struct vma_iterator *vmi, struct mmu_gather *tlb,
>  	long nrpages = (end - start) >> PAGE_SHIFT;
>  	unsigned int mm_cp_flags = 0;
>  	unsigned long charged = 0;
> -	pgoff_t pgoff;
> +	struct vm_area_struct *merged;
>  	int error;
>  
>  	if (newflags == oldflags) {
> @@ -625,34 +625,19 @@ mprotect_fixup(struct vma_iterator *vmi, struct mmu_gather *tlb,
>  		}
>  	}
>  
> -	/*
> -	 * First try to merge with previous and/or next vma.
> -	 */
> -	pgoff = vma->vm_pgoff + ((start - vma->vm_start) >> PAGE_SHIFT);
> -	*pprev = vma_merge(vmi, mm, *pprev, start, end, newflags,
> -			   vma->anon_vma, vma->vm_file, pgoff, vma_policy(vma),
> -			   vma->vm_userfaultfd_ctx, anon_vma_name(vma));
> -	if (*pprev) {
> -		vma = *pprev;
> -		VM_WARN_ON((vma->vm_flags ^ newflags) & ~VM_SOFTDIRTY);
> -		goto success;
> +	merged = vma_modify_flags(vmi, *pprev, vma, start, end, newflags);
> +	if (IS_ERR(merged)) {
> +		error = PTR_ERR(merged);
> +		goto fail;
>  	}
>  
> -	*pprev = vma;
> -
> -	if (start != vma->vm_start) {
> -		error = split_vma(vmi, vma, start, 1);
> -		if (error)
> -			goto fail;
> -	}
> -
> -	if (end != vma->vm_end) {
> -		error = split_vma(vmi, vma, end, 0);
> -		if (error)
> -			goto fail;
> +	if (merged) {
> +		vma = *pprev = merged;
> +		VM_WARN_ON((vma->vm_flags ^ newflags) & ~VM_SOFTDIRTY);

This VM_WARN_ON() is AFAICS the only piece of code that cares about merged
vs split. Would it be ok to call it for the split vma cases as well, or
maybe remove it?

> +	} else {
> +		*pprev = vma;
>  	}
>  
> -success:
>  	/*
>  	 * vm_flags and vm_page_prot are protected by the mmap_lock
>  	 * held in write mode.
Lorenzo Stoakes Oct. 10, 2023, 6:11 p.m. UTC | #2
On Tue, Oct 10, 2023 at 09:12:21AM +0200, Vlastimil Babka wrote:
> On 10/9/23 22:53, Lorenzo Stoakes wrote:
> > mprotect() and other functions which change VMA parameters over a range
> > each employ a pattern of:-
> >
> > 1. Attempt to merge the range with adjacent VMAs.
> > 2. If this fails, and the range spans a subset of the VMA, split it
> >    accordingly.
> >
> > This is open-coded and duplicated in each case. Also in each case most of
> > the parameters passed to vma_merge() remain the same.
> >
> > Create a new function, vma_modify(), which abstracts this operation,
> > accepting only those parameters which can be changed.
> >
> > To avoid the mess of invoking each function call with unnecessary
> > parameters, create inline wrapper functions for each of the modify
> > operations, parameterised only by what is required to perform the action.
> >
> > Note that the userfaultfd_release() case works even though it does not
> > split VMAs - since start is set to vma->vm_start and end is set to
> > vma->vm_end, the split logic does not trigger.
> >
> > In addition, since we calculate pgoff to be equal to vma->vm_pgoff + (start
> > - vma->vm_start) >> PAGE_SHIFT, and start - vma->vm_start will be 0 in this
> > instance, this invocation will remain unchanged.
> >
> > Signed-off-by: Lorenzo Stoakes <lstoakes@gmail.com>
>
> Reviewed-by: Vlastimil Babka <vbabka@suse.cz>
>
> some nits below:
>
> > --- a/mm/mmap.c
> > +++ b/mm/mmap.c
> > @@ -2437,6 +2437,51 @@ int split_vma(struct vma_iterator *vmi, struct vm_area_struct *vma,
> >  	return __split_vma(vmi, vma, addr, new_below);
> >  }
> >
> > +/*
> > + * We are about to modify one or multiple of a VMA's flags, policy, userfaultfd
> > + * context and anonymous VMA name within the range [start, end).
> > + *
> > + * As a result, we might be able to merge the newly modified VMA range with an
> > + * adjacent VMA with identical properties.
> > + *
> > + * If no merge is possible and the range does not span the entirety of the VMA,
> > + * we then need to split the VMA to accommodate the change.
> > + */
>
> This could describe the return value too? It's not entirely trivial.
> But I also wonder if we could just return 'vma' for the split_vma() cases
> and the callers could simply stop distinguishing whether there was a merge
> or split, and their code would become even simpler?
> It seems to me most callers don't care, except mprotect, see below...

What a great idea, thanks! I have worked through and implemented this and
it does indeed work and simplify things even further, cheers!

>
> > +struct vm_area_struct *vma_modify(struct vma_iterator *vmi,
> > +				  struct vm_area_struct *prev,
> > +				  struct vm_area_struct *vma,
> > +				  unsigned long start, unsigned long end,
> > +				  unsigned long vm_flags,
> > +				  struct mempolicy *policy,
> > +				  struct vm_userfaultfd_ctx uffd_ctx,
> > +				  struct anon_vma_name *anon_name)
> > +{
> > +	pgoff_t pgoff = vma->vm_pgoff + ((start - vma->vm_start) >> PAGE_SHIFT);
> > +	struct vm_area_struct *merged;
> > +
> > +	merged = vma_merge(vmi, vma->vm_mm, prev, start, end, vm_flags,
> > +			   vma->anon_vma, vma->vm_file, pgoff, policy,
> > +			   uffd_ctx, anon_name);
> > +	if (merged)
> > +		return merged;
> > +
> > +	if (vma->vm_start < start) {
> > +		int err = split_vma(vmi, vma, start, 1);
> > +
> > +		if (err)
> > +			return ERR_PTR(err);
> > +	}
> > +
> > +	if (vma->vm_end > end) {
> > +		int err = split_vma(vmi, vma, end, 0);
> > +
> > +		if (err)
> > +			return ERR_PTR(err);
> > +	}
> > +
> > +	return NULL;
> > +}
> > +
> >  /*
> >   * do_vmi_align_munmap() - munmap the aligned region from @start to @end.
> >   * @vmi: The vma iterator
> > diff --git a/mm/mprotect.c b/mm/mprotect.c
> > index b94fbb45d5c7..6f85d99682ab 100644
> > --- a/mm/mprotect.c
> > +++ b/mm/mprotect.c
> > @@ -581,7 +581,7 @@ mprotect_fixup(struct vma_iterator *vmi, struct mmu_gather *tlb,
> >  	long nrpages = (end - start) >> PAGE_SHIFT;
> >  	unsigned int mm_cp_flags = 0;
> >  	unsigned long charged = 0;
> > -	pgoff_t pgoff;
> > +	struct vm_area_struct *merged;
> >  	int error;
> >
> >  	if (newflags == oldflags) {
> > @@ -625,34 +625,19 @@ mprotect_fixup(struct vma_iterator *vmi, struct mmu_gather *tlb,
> >  		}
> >  	}
> >
> > -	/*
> > -	 * First try to merge with previous and/or next vma.
> > -	 */
> > -	pgoff = vma->vm_pgoff + ((start - vma->vm_start) >> PAGE_SHIFT);
> > -	*pprev = vma_merge(vmi, mm, *pprev, start, end, newflags,
> > -			   vma->anon_vma, vma->vm_file, pgoff, vma_policy(vma),
> > -			   vma->vm_userfaultfd_ctx, anon_vma_name(vma));
> > -	if (*pprev) {
> > -		vma = *pprev;
> > -		VM_WARN_ON((vma->vm_flags ^ newflags) & ~VM_SOFTDIRTY);
> > -		goto success;
> > +	merged = vma_modify_flags(vmi, *pprev, vma, start, end, newflags);
> > +	if (IS_ERR(merged)) {
> > +		error = PTR_ERR(merged);
> > +		goto fail;
> >  	}
> >
> > -	*pprev = vma;
> > -
> > -	if (start != vma->vm_start) {
> > -		error = split_vma(vmi, vma, start, 1);
> > -		if (error)
> > -			goto fail;
> > -	}
> > -
> > -	if (end != vma->vm_end) {
> > -		error = split_vma(vmi, vma, end, 0);
> > -		if (error)
> > -			goto fail;
> > +	if (merged) {
> > +		vma = *pprev = merged;
> > +		VM_WARN_ON((vma->vm_flags ^ newflags) & ~VM_SOFTDIRTY);
>
> This VM_WARN_ON() is AFAICS the only piece of code that cares about merged
> vs split. Would it be ok to call it for the split vma cases as well, or
> maybe remove it?

This is simply asserting a fundamental requirement of vma_merge() in
general, i.e. that the flags of what was merged match those of the VMA that
is being merged.

This is already checked in the VMA merge implementation, so this feels
super redundant, so I think we're good to simply remove it.

>
> > +	} else {
> > +		*pprev = vma;
> >  	}
> >
> > -success:
> >  	/*
> >  	 * vm_flags and vm_page_prot are protected by the mmap_lock
> >  	 * held in write mode.
>
Liam R. Howlett Oct. 11, 2023, 2:14 a.m. UTC | #3
* Lorenzo Stoakes <lstoakes@gmail.com> [231009 16:53]:
> mprotect() and other functions which change VMA parameters over a range
> each employ a pattern of:-
> 
> 1. Attempt to merge the range with adjacent VMAs.
> 2. If this fails, and the range spans a subset of the VMA, split it
>    accordingly.
> 
> This is open-coded and duplicated in each case. Also in each case most of
> the parameters passed to vma_merge() remain the same.
> 
> Create a new function, vma_modify(), which abstracts this operation,
> accepting only those parameters which can be changed.
> 
> To avoid the mess of invoking each function call with unnecessary
> parameters, create inline wrapper functions for each of the modify
> operations, parameterised only by what is required to perform the action.
> 
> Note that the userfaultfd_release() case works even though it does not
> split VMAs - since start is set to vma->vm_start and end is set to
> vma->vm_end, the split logic does not trigger.
> 
> In addition, since we calculate pgoff to be equal to vma->vm_pgoff + (start
> - vma->vm_start) >> PAGE_SHIFT, and start - vma->vm_start will be 0 in this
> instance, this invocation will remain unchanged.
> 
> Signed-off-by: Lorenzo Stoakes <lstoakes@gmail.com>
> ---
>  fs/userfaultfd.c   | 69 +++++++++++++++-------------------------------
>  include/linux/mm.h | 60 ++++++++++++++++++++++++++++++++++++++++
>  mm/madvise.c       | 32 ++++++---------------
>  mm/mempolicy.c     | 22 +++------------
>  mm/mlock.c         | 27 +++++-------------
>  mm/mmap.c          | 45 ++++++++++++++++++++++++++++++
>  mm/mprotect.c      | 35 +++++++----------------
>  7 files changed, 157 insertions(+), 133 deletions(-)
> 
> diff --git a/fs/userfaultfd.c b/fs/userfaultfd.c
> index a7c6ef764e63..ba44a67a0a34 100644
> --- a/fs/userfaultfd.c
> +++ b/fs/userfaultfd.c
> @@ -927,11 +927,10 @@ static int userfaultfd_release(struct inode *inode, struct file *file)
>  			continue;
>  		}
>  		new_flags = vma->vm_flags & ~__VM_UFFD_FLAGS;
> -		prev = vma_merge(&vmi, mm, prev, vma->vm_start, vma->vm_end,
> -				 new_flags, vma->anon_vma,
> -				 vma->vm_file, vma->vm_pgoff,
> -				 vma_policy(vma),
> -				 NULL_VM_UFFD_CTX, anon_vma_name(vma));
> +		prev = vma_modify_flags_uffd(&vmi, prev, vma, vma->vm_start,
> +					     vma->vm_end, new_flags,
> +					     NULL_VM_UFFD_CTX);
> +
>  		if (prev) {
>  			vma = prev;
>  		} else {
> @@ -1331,7 +1330,6 @@ static int userfaultfd_register(struct userfaultfd_ctx *ctx,
>  	unsigned long start, end, vma_end;
>  	struct vma_iterator vmi;
>  	bool wp_async = userfaultfd_wp_async_ctx(ctx);
> -	pgoff_t pgoff;
>  
>  	user_uffdio_register = (struct uffdio_register __user *) arg;
>  
> @@ -1484,28 +1482,17 @@ static int userfaultfd_register(struct userfaultfd_ctx *ctx,
>  		vma_end = min(end, vma->vm_end);
>  
>  		new_flags = (vma->vm_flags & ~__VM_UFFD_FLAGS) | vm_flags;
> -		pgoff = vma->vm_pgoff + ((start - vma->vm_start) >> PAGE_SHIFT);
> -		prev = vma_merge(&vmi, mm, prev, start, vma_end, new_flags,
> -				 vma->anon_vma, vma->vm_file, pgoff,
> -				 vma_policy(vma),
> -				 ((struct vm_userfaultfd_ctx){ ctx }),
> -				 anon_vma_name(vma));
> -		if (prev) {
> -			/* vma_merge() invalidated the mas */
> -			vma = prev;
> -			goto next;
> -		}
> -		if (vma->vm_start < start) {
> -			ret = split_vma(&vmi, vma, start, 1);
> -			if (ret)
> -				break;
> -		}
> -		if (vma->vm_end > end) {
> -			ret = split_vma(&vmi, vma, end, 0);
> -			if (ret)
> -				break;
> +		prev = vma_modify_flags_uffd(&vmi, prev, vma, start, vma_end,
> +					     new_flags,
> +					     (struct vm_userfaultfd_ctx){ctx});
> +		if (IS_ERR(prev)) {
> +			ret = PTR_ERR(prev);
> +			break;
>  		}
> -	next:
> +
> +		if (prev)
> +			vma = prev; /* vma_merge() invalidated the mas */

This is a stale comment.  The maple state is in the vma iterator, which
is passed through.  I missed this on the vma iterator conversion.

> +
>  		/*
>  		 * In the vma_merge() successful mprotect-like case 8:
>  		 * the next vma was merged into the current one and
> @@ -1568,7 +1555,6 @@ static int userfaultfd_unregister(struct userfaultfd_ctx *ctx,
>  	const void __user *buf = (void __user *)arg;
>  	struct vma_iterator vmi;
>  	bool wp_async = userfaultfd_wp_async_ctx(ctx);
> -	pgoff_t pgoff;
>  
>  	ret = -EFAULT;
>  	if (copy_from_user(&uffdio_unregister, buf, sizeof(uffdio_unregister)))
> @@ -1671,26 +1657,15 @@ static int userfaultfd_unregister(struct userfaultfd_ctx *ctx,
>  			uffd_wp_range(vma, start, vma_end - start, false);
>  
>  		new_flags = vma->vm_flags & ~__VM_UFFD_FLAGS;
> -		pgoff = vma->vm_pgoff + ((start - vma->vm_start) >> PAGE_SHIFT);
> -		prev = vma_merge(&vmi, mm, prev, start, vma_end, new_flags,
> -				 vma->anon_vma, vma->vm_file, pgoff,
> -				 vma_policy(vma),
> -				 NULL_VM_UFFD_CTX, anon_vma_name(vma));
> -		if (prev) {
> -			vma = prev;
> -			goto next;
> -		}
> -		if (vma->vm_start < start) {
> -			ret = split_vma(&vmi, vma, start, 1);
> -			if (ret)
> -				break;
> -		}
> -		if (vma->vm_end > end) {
> -			ret = split_vma(&vmi, vma, end, 0);
> -			if (ret)
> -				break;
> +		prev = vma_modify_flags_uffd(&vmi, prev, vma, start, vma_end,
> +					     new_flags, NULL_VM_UFFD_CTX);
> +		if (IS_ERR(prev)) {
> +			ret = PTR_ERR(prev);
> +			break;
>  		}
> -	next:
> +
> +		if (prev)
> +			vma = prev;
>  		/*
>  		 * In the vma_merge() successful mprotect-like case 8:
>  		 * the next vma was merged into the current one and
> diff --git a/include/linux/mm.h b/include/linux/mm.h
> index a7b667786cde..83ee1f35febe 100644
> --- a/include/linux/mm.h
> +++ b/include/linux/mm.h
> @@ -3253,6 +3253,66 @@ extern struct vm_area_struct *copy_vma(struct vm_area_struct **,
>  	unsigned long addr, unsigned long len, pgoff_t pgoff,
>  	bool *need_rmap_locks);
>  extern void exit_mmap(struct mm_struct *);
> +struct vm_area_struct *vma_modify(struct vma_iterator *vmi,
> +				  struct vm_area_struct *prev,
> +				  struct vm_area_struct *vma,
> +				  unsigned long start, unsigned long end,
> +				  unsigned long vm_flags,
> +				  struct mempolicy *policy,
> +				  struct vm_userfaultfd_ctx uffd_ctx,
> +				  struct anon_vma_name *anon_name);
> +
> +/* We are about to modify the VMA's flags. */
> +static inline struct vm_area_struct
> +*vma_modify_flags(struct vma_iterator *vmi,
> +		  struct vm_area_struct *prev,
> +		  struct vm_area_struct *vma,
> +		  unsigned long start, unsigned long end,
> +		  unsigned long new_flags)
> +{
> +	return vma_modify(vmi, prev, vma, start, end, new_flags,
> +			  vma_policy(vma), vma->vm_userfaultfd_ctx,
> +			  anon_vma_name(vma));
> +}
> +
> +/* We are about to modify the VMA's flags and/or anon_name. */
> +static inline struct vm_area_struct
> +*vma_modify_flags_name(struct vma_iterator *vmi,
> +		       struct vm_area_struct *prev,
> +		       struct vm_area_struct *vma,
> +		       unsigned long start,
> +		       unsigned long end,
> +		       unsigned long new_flags,
> +		       struct anon_vma_name *new_name)
> +{
> +	return vma_modify(vmi, prev, vma, start, end, new_flags,
> +			  vma_policy(vma), vma->vm_userfaultfd_ctx, new_name);
> +}
> +
> +/* We are about to modify the VMA's memory policy. */
> +static inline struct vm_area_struct
> +*vma_modify_policy(struct vma_iterator *vmi,
> +		   struct vm_area_struct *prev,
> +		   struct vm_area_struct *vma,
> +		   unsigned long start, unsigned long end,
> +		   struct mempolicy *new_pol)
> +{
> +	return vma_modify(vmi, prev, vma, start, end, vma->vm_flags,
> +			  new_pol, vma->vm_userfaultfd_ctx, anon_vma_name(vma));
> +}
> +
> +/* We are about to modify the VMA's flags and/or uffd context. */
> +static inline struct vm_area_struct
> +*vma_modify_flags_uffd(struct vma_iterator *vmi,
> +		       struct vm_area_struct *prev,
> +		       struct vm_area_struct *vma,
> +		       unsigned long start, unsigned long end,
> +		       unsigned long new_flags,
> +		       struct vm_userfaultfd_ctx new_ctx)
> +{
> +	return vma_modify(vmi, prev, vma, start, end, new_flags,
> +			  vma_policy(vma), new_ctx, anon_vma_name(vma));
> +}
>  
>  static inline int check_data_rlimit(unsigned long rlim,
>  				    unsigned long new,
> diff --git a/mm/madvise.c b/mm/madvise.c
> index a4a20de50494..801d3c1bb7b3 100644
> --- a/mm/madvise.c
> +++ b/mm/madvise.c
> @@ -141,7 +141,7 @@ static int madvise_update_vma(struct vm_area_struct *vma,
>  {
>  	struct mm_struct *mm = vma->vm_mm;
>  	int error;
> -	pgoff_t pgoff;
> +	struct vm_area_struct *merged;
>  	VMA_ITERATOR(vmi, mm, start);
>  
>  	if (new_flags == vma->vm_flags && anon_vma_name_eq(anon_vma_name(vma), anon_name)) {
> @@ -149,30 +149,16 @@ static int madvise_update_vma(struct vm_area_struct *vma,
>  		return 0;
>  	}
>  
> -	pgoff = vma->vm_pgoff + ((start - vma->vm_start) >> PAGE_SHIFT);
> -	*prev = vma_merge(&vmi, mm, *prev, start, end, new_flags,
> -			  vma->anon_vma, vma->vm_file, pgoff, vma_policy(vma),
> -			  vma->vm_userfaultfd_ctx, anon_name);
> -	if (*prev) {
> -		vma = *prev;
> -		goto success;
> -	}
> -
> -	*prev = vma;
> -
> -	if (start != vma->vm_start) {
> -		error = split_vma(&vmi, vma, start, 1);
> -		if (error)
> -			return error;
> -	}
> +	merged = vma_modify_flags_name(&vmi, *prev, vma, start, end, new_flags,
> +				       anon_name);
> +	if (IS_ERR(merged))
> +		return PTR_ERR(merged);
>  
> -	if (end != vma->vm_end) {
> -		error = split_vma(&vmi, vma, end, 0);
> -		if (error)
> -			return error;
> -	}
> +	if (merged)
> +		vma = *prev = merged;
> +	else
> +		*prev = vma;
>  
> -success:
>  	/* vm_flags is protected by the mmap_lock held in write mode. */
>  	vma_start_write(vma);
>  	vm_flags_reset(vma, new_flags);
> diff --git a/mm/mempolicy.c b/mm/mempolicy.c
> index b01922e88548..6b2e99db6dd5 100644
> --- a/mm/mempolicy.c
> +++ b/mm/mempolicy.c
> @@ -786,8 +786,6 @@ static int mbind_range(struct vma_iterator *vmi, struct vm_area_struct *vma,
>  {
>  	struct vm_area_struct *merged;
>  	unsigned long vmstart, vmend;
> -	pgoff_t pgoff;
> -	int err;
>  
>  	vmend = min(end, vma->vm_end);
>  	if (start > vma->vm_start) {
> @@ -802,27 +800,15 @@ static int mbind_range(struct vma_iterator *vmi, struct vm_area_struct *vma,
>  		return 0;
>  	}
>  
> -	pgoff = vma->vm_pgoff + ((vmstart - vma->vm_start) >> PAGE_SHIFT);
> -	merged = vma_merge(vmi, vma->vm_mm, *prev, vmstart, vmend, vma->vm_flags,
> -			 vma->anon_vma, vma->vm_file, pgoff, new_pol,
> -			 vma->vm_userfaultfd_ctx, anon_vma_name(vma));
> +	merged =  vma_modify_policy(vmi, *prev, vma, vmstart, vmend, new_pol);
> +	if (IS_ERR(merged))
> +		return PTR_ERR(merged);
> +
>  	if (merged) {
>  		*prev = merged;
>  		return vma_replace_policy(merged, new_pol);
>  	}
>  
> -	if (vma->vm_start != vmstart) {
> -		err = split_vma(vmi, vma, vmstart, 1);
> -		if (err)
> -			return err;
> -	}
> -
> -	if (vma->vm_end != vmend) {
> -		err = split_vma(vmi, vma, vmend, 0);
> -		if (err)
> -			return err;
> -	}
> -
>  	*prev = vma;
>  	return vma_replace_policy(vma, new_pol);
>  }
> diff --git a/mm/mlock.c b/mm/mlock.c
> index 42b6865f8f82..ae83a33c387e 100644
> --- a/mm/mlock.c
> +++ b/mm/mlock.c
> @@ -476,10 +476,10 @@ static int mlock_fixup(struct vma_iterator *vmi, struct vm_area_struct *vma,
>  	       unsigned long end, vm_flags_t newflags)
>  {
>  	struct mm_struct *mm = vma->vm_mm;
> -	pgoff_t pgoff;
>  	int nr_pages;
>  	int ret = 0;
>  	vm_flags_t oldflags = vma->vm_flags;
> +	struct vm_area_struct *merged;
>  
>  	if (newflags == oldflags || (oldflags & VM_SPECIAL) ||
>  	    is_vm_hugetlb_page(vma) || vma == get_gate_vma(current->mm) ||
> @@ -487,28 +487,15 @@ static int mlock_fixup(struct vma_iterator *vmi, struct vm_area_struct *vma,
>  		/* don't set VM_LOCKED or VM_LOCKONFAULT and don't count */
>  		goto out;
>  
> -	pgoff = vma->vm_pgoff + ((start - vma->vm_start) >> PAGE_SHIFT);
> -	*prev = vma_merge(vmi, mm, *prev, start, end, newflags,
> -			vma->anon_vma, vma->vm_file, pgoff, vma_policy(vma),
> -			vma->vm_userfaultfd_ctx, anon_vma_name(vma));
> -	if (*prev) {
> -		vma = *prev;
> -		goto success;
> -	}
> -
> -	if (start != vma->vm_start) {
> -		ret = split_vma(vmi, vma, start, 1);
> -		if (ret)
> -			goto out;
> +	merged = vma_modify_flags(vmi, *prev, vma, start, end, newflags);
> +	if (IS_ERR(merged)) {
> +		ret = PTR_ERR(merged);
> +		goto out;
>  	}
>  
> -	if (end != vma->vm_end) {
> -		ret = split_vma(vmi, vma, end, 0);
> -		if (ret)
> -			goto out;
> -	}
> +	if (merged)
> +		vma = *prev = merged;
>  
> -success:
>  	/*
>  	 * Keep track of amount of locked VM.
>  	 */
> diff --git a/mm/mmap.c b/mm/mmap.c
> index 673429ee8a9e..22d968affc07 100644
> --- a/mm/mmap.c
> +++ b/mm/mmap.c
> @@ -2437,6 +2437,51 @@ int split_vma(struct vma_iterator *vmi, struct vm_area_struct *vma,
>  	return __split_vma(vmi, vma, addr, new_below);
>  }
>  
> +/*
> + * We are about to modify one or multiple of a VMA's flags, policy, userfaultfd
> + * context and anonymous VMA name within the range [start, end).
> + *
> + * As a result, we might be able to merge the newly modified VMA range with an
> + * adjacent VMA with identical properties.
> + *
> + * If no merge is possible and the range does not span the entirety of the VMA,
> + * we then need to split the VMA to accommodate the change.
> + */
> +struct vm_area_struct *vma_modify(struct vma_iterator *vmi,
> +				  struct vm_area_struct *prev,
> +				  struct vm_area_struct *vma,
> +				  unsigned long start, unsigned long end,
> +				  unsigned long vm_flags,
> +				  struct mempolicy *policy,
> +				  struct vm_userfaultfd_ctx uffd_ctx,
> +				  struct anon_vma_name *anon_name)
> +{
> +	pgoff_t pgoff = vma->vm_pgoff + ((start - vma->vm_start) >> PAGE_SHIFT);
> +	struct vm_area_struct *merged;
> +
> +	merged = vma_merge(vmi, vma->vm_mm, prev, start, end, vm_flags,
> +			   vma->anon_vma, vma->vm_file, pgoff, policy,
> +			   uffd_ctx, anon_name);
> +	if (merged)
> +		return merged;
> +
> +	if (vma->vm_start < start) {
> +		int err = split_vma(vmi, vma, start, 1);
> +
> +		if (err)
> +			return ERR_PTR(err);
> +	}
> +
> +	if (vma->vm_end > end) {
> +		int err = split_vma(vmi, vma, end, 0);
> +
> +		if (err)
> +			return ERR_PTR(err);
> +	}
> +
> +	return NULL;
> +}
> +
>  /*
>   * do_vmi_align_munmap() - munmap the aligned region from @start to @end.
>   * @vmi: The vma iterator
> diff --git a/mm/mprotect.c b/mm/mprotect.c
> index b94fbb45d5c7..6f85d99682ab 100644
> --- a/mm/mprotect.c
> +++ b/mm/mprotect.c
> @@ -581,7 +581,7 @@ mprotect_fixup(struct vma_iterator *vmi, struct mmu_gather *tlb,
>  	long nrpages = (end - start) >> PAGE_SHIFT;
>  	unsigned int mm_cp_flags = 0;
>  	unsigned long charged = 0;
> -	pgoff_t pgoff;
> +	struct vm_area_struct *merged;
>  	int error;
>  
>  	if (newflags == oldflags) {
> @@ -625,34 +625,19 @@ mprotect_fixup(struct vma_iterator *vmi, struct mmu_gather *tlb,
>  		}
>  	}
>  
> -	/*
> -	 * First try to merge with previous and/or next vma.
> -	 */
> -	pgoff = vma->vm_pgoff + ((start - vma->vm_start) >> PAGE_SHIFT);
> -	*pprev = vma_merge(vmi, mm, *pprev, start, end, newflags,
> -			   vma->anon_vma, vma->vm_file, pgoff, vma_policy(vma),
> -			   vma->vm_userfaultfd_ctx, anon_vma_name(vma));
> -	if (*pprev) {
> -		vma = *pprev;
> -		VM_WARN_ON((vma->vm_flags ^ newflags) & ~VM_SOFTDIRTY);
> -		goto success;
> +	merged = vma_modify_flags(vmi, *pprev, vma, start, end, newflags);
> +	if (IS_ERR(merged)) {
> +		error = PTR_ERR(merged);
> +		goto fail;
>  	}
>  
> -	*pprev = vma;
> -
> -	if (start != vma->vm_start) {
> -		error = split_vma(vmi, vma, start, 1);
> -		if (error)
> -			goto fail;
> -	}
> -
> -	if (end != vma->vm_end) {
> -		error = split_vma(vmi, vma, end, 0);
> -		if (error)
> -			goto fail;
> +	if (merged) {
> +		vma = *pprev = merged;
> +		VM_WARN_ON((vma->vm_flags ^ newflags) & ~VM_SOFTDIRTY);
> +	} else {
> +		*pprev = vma;
>  	}
>  
> -success:
>  	/*
>  	 * vm_flags and vm_page_prot are protected by the mmap_lock
>  	 * held in write mode.
> -- 
> 2.42.0
>
Lorenzo Stoakes Oct. 11, 2023, 6:34 a.m. UTC | #4
On Tue, Oct 10, 2023 at 10:14:52PM -0400, Liam R. Howlett wrote:
> * Lorenzo Stoakes <lstoakes@gmail.com> [231009 16:53]:
> > mprotect() and other functions which change VMA parameters over a range
> > each employ a pattern of:-
> >
> > 1. Attempt to merge the range with adjacent VMAs.
> > 2. If this fails, and the range spans a subset of the VMA, split it
> >    accordingly.
> >
> > This is open-coded and duplicated in each case. Also in each case most of
> > the parameters passed to vma_merge() remain the same.
> >
> > Create a new function, vma_modify(), which abstracts this operation,
> > accepting only those parameters which can be changed.
> >
> > To avoid the mess of invoking each function call with unnecessary
> > parameters, create inline wrapper functions for each of the modify
> > operations, parameterised only by what is required to perform the action.
> >
> > Note that the userfaultfd_release() case works even though it does not
> > split VMAs - since start is set to vma->vm_start and end is set to
> > vma->vm_end, the split logic does not trigger.
> >
> > In addition, since we calculate pgoff to be equal to vma->vm_pgoff + (start
> > - vma->vm_start) >> PAGE_SHIFT, and start - vma->vm_start will be 0 in this
> > instance, this invocation will remain unchanged.
> >
> > Signed-off-by: Lorenzo Stoakes <lstoakes@gmail.com>
> > ---
> >  fs/userfaultfd.c   | 69 +++++++++++++++-------------------------------
> >  include/linux/mm.h | 60 ++++++++++++++++++++++++++++++++++++++++
> >  mm/madvise.c       | 32 ++++++---------------
> >  mm/mempolicy.c     | 22 +++------------
> >  mm/mlock.c         | 27 +++++-------------
> >  mm/mmap.c          | 45 ++++++++++++++++++++++++++++++
> >  mm/mprotect.c      | 35 +++++++----------------
> >  7 files changed, 157 insertions(+), 133 deletions(-)
> >
> > diff --git a/fs/userfaultfd.c b/fs/userfaultfd.c
> > index a7c6ef764e63..ba44a67a0a34 100644
> > --- a/fs/userfaultfd.c
> > +++ b/fs/userfaultfd.c
> > @@ -927,11 +927,10 @@ static int userfaultfd_release(struct inode *inode, struct file *file)
> >  			continue;
> >  		}
> >  		new_flags = vma->vm_flags & ~__VM_UFFD_FLAGS;
> > -		prev = vma_merge(&vmi, mm, prev, vma->vm_start, vma->vm_end,
> > -				 new_flags, vma->anon_vma,
> > -				 vma->vm_file, vma->vm_pgoff,
> > -				 vma_policy(vma),
> > -				 NULL_VM_UFFD_CTX, anon_vma_name(vma));
> > +		prev = vma_modify_flags_uffd(&vmi, prev, vma, vma->vm_start,
> > +					     vma->vm_end, new_flags,
> > +					     NULL_VM_UFFD_CTX);
> > +
> >  		if (prev) {
> >  			vma = prev;
> >  		} else {
> > @@ -1331,7 +1330,6 @@ static int userfaultfd_register(struct userfaultfd_ctx *ctx,
> >  	unsigned long start, end, vma_end;
> >  	struct vma_iterator vmi;
> >  	bool wp_async = userfaultfd_wp_async_ctx(ctx);
> > -	pgoff_t pgoff;
> >
> >  	user_uffdio_register = (struct uffdio_register __user *) arg;
> >
> > @@ -1484,28 +1482,17 @@ static int userfaultfd_register(struct userfaultfd_ctx *ctx,
> >  		vma_end = min(end, vma->vm_end);
> >
> >  		new_flags = (vma->vm_flags & ~__VM_UFFD_FLAGS) | vm_flags;
> > -		pgoff = vma->vm_pgoff + ((start - vma->vm_start) >> PAGE_SHIFT);
> > -		prev = vma_merge(&vmi, mm, prev, start, vma_end, new_flags,
> > -				 vma->anon_vma, vma->vm_file, pgoff,
> > -				 vma_policy(vma),
> > -				 ((struct vm_userfaultfd_ctx){ ctx }),
> > -				 anon_vma_name(vma));
> > -		if (prev) {
> > -			/* vma_merge() invalidated the mas */
> > -			vma = prev;
> > -			goto next;
> > -		}
> > -		if (vma->vm_start < start) {
> > -			ret = split_vma(&vmi, vma, start, 1);
> > -			if (ret)
> > -				break;
> > -		}
> > -		if (vma->vm_end > end) {
> > -			ret = split_vma(&vmi, vma, end, 0);
> > -			if (ret)
> > -				break;
> > +		prev = vma_modify_flags_uffd(&vmi, prev, vma, start, vma_end,
> > +					     new_flags,
> > +					     (struct vm_userfaultfd_ctx){ctx});
> > +		if (IS_ERR(prev)) {
> > +			ret = PTR_ERR(prev);
> > +			break;
> >  		}
> > -	next:
> > +
> > +		if (prev)
> > +			vma = prev; /* vma_merge() invalidated the mas */
>
> This is a stale comment.  The maple state is in the vma iterator, which
> is passed through.  I missed this on the vma iterator conversion.

Ack, this was coincidentally removed in v3 so this is already resolved.
diff mbox series

Patch

diff --git a/fs/userfaultfd.c b/fs/userfaultfd.c
index a7c6ef764e63..ba44a67a0a34 100644
--- a/fs/userfaultfd.c
+++ b/fs/userfaultfd.c
@@ -927,11 +927,10 @@  static int userfaultfd_release(struct inode *inode, struct file *file)
 			continue;
 		}
 		new_flags = vma->vm_flags & ~__VM_UFFD_FLAGS;
-		prev = vma_merge(&vmi, mm, prev, vma->vm_start, vma->vm_end,
-				 new_flags, vma->anon_vma,
-				 vma->vm_file, vma->vm_pgoff,
-				 vma_policy(vma),
-				 NULL_VM_UFFD_CTX, anon_vma_name(vma));
+		prev = vma_modify_flags_uffd(&vmi, prev, vma, vma->vm_start,
+					     vma->vm_end, new_flags,
+					     NULL_VM_UFFD_CTX);
+
 		if (prev) {
 			vma = prev;
 		} else {
@@ -1331,7 +1330,6 @@  static int userfaultfd_register(struct userfaultfd_ctx *ctx,
 	unsigned long start, end, vma_end;
 	struct vma_iterator vmi;
 	bool wp_async = userfaultfd_wp_async_ctx(ctx);
-	pgoff_t pgoff;
 
 	user_uffdio_register = (struct uffdio_register __user *) arg;
 
@@ -1484,28 +1482,17 @@  static int userfaultfd_register(struct userfaultfd_ctx *ctx,
 		vma_end = min(end, vma->vm_end);
 
 		new_flags = (vma->vm_flags & ~__VM_UFFD_FLAGS) | vm_flags;
-		pgoff = vma->vm_pgoff + ((start - vma->vm_start) >> PAGE_SHIFT);
-		prev = vma_merge(&vmi, mm, prev, start, vma_end, new_flags,
-				 vma->anon_vma, vma->vm_file, pgoff,
-				 vma_policy(vma),
-				 ((struct vm_userfaultfd_ctx){ ctx }),
-				 anon_vma_name(vma));
-		if (prev) {
-			/* vma_merge() invalidated the mas */
-			vma = prev;
-			goto next;
-		}
-		if (vma->vm_start < start) {
-			ret = split_vma(&vmi, vma, start, 1);
-			if (ret)
-				break;
-		}
-		if (vma->vm_end > end) {
-			ret = split_vma(&vmi, vma, end, 0);
-			if (ret)
-				break;
+		prev = vma_modify_flags_uffd(&vmi, prev, vma, start, vma_end,
+					     new_flags,
+					     (struct vm_userfaultfd_ctx){ctx});
+		if (IS_ERR(prev)) {
+			ret = PTR_ERR(prev);
+			break;
 		}
-	next:
+
+		if (prev)
+			vma = prev; /* vma_merge() invalidated the mas */
+
 		/*
 		 * In the vma_merge() successful mprotect-like case 8:
 		 * the next vma was merged into the current one and
@@ -1568,7 +1555,6 @@  static int userfaultfd_unregister(struct userfaultfd_ctx *ctx,
 	const void __user *buf = (void __user *)arg;
 	struct vma_iterator vmi;
 	bool wp_async = userfaultfd_wp_async_ctx(ctx);
-	pgoff_t pgoff;
 
 	ret = -EFAULT;
 	if (copy_from_user(&uffdio_unregister, buf, sizeof(uffdio_unregister)))
@@ -1671,26 +1657,15 @@  static int userfaultfd_unregister(struct userfaultfd_ctx *ctx,
 			uffd_wp_range(vma, start, vma_end - start, false);
 
 		new_flags = vma->vm_flags & ~__VM_UFFD_FLAGS;
-		pgoff = vma->vm_pgoff + ((start - vma->vm_start) >> PAGE_SHIFT);
-		prev = vma_merge(&vmi, mm, prev, start, vma_end, new_flags,
-				 vma->anon_vma, vma->vm_file, pgoff,
-				 vma_policy(vma),
-				 NULL_VM_UFFD_CTX, anon_vma_name(vma));
-		if (prev) {
-			vma = prev;
-			goto next;
-		}
-		if (vma->vm_start < start) {
-			ret = split_vma(&vmi, vma, start, 1);
-			if (ret)
-				break;
-		}
-		if (vma->vm_end > end) {
-			ret = split_vma(&vmi, vma, end, 0);
-			if (ret)
-				break;
+		prev = vma_modify_flags_uffd(&vmi, prev, vma, start, vma_end,
+					     new_flags, NULL_VM_UFFD_CTX);
+		if (IS_ERR(prev)) {
+			ret = PTR_ERR(prev);
+			break;
 		}
-	next:
+
+		if (prev)
+			vma = prev;
 		/*
 		 * In the vma_merge() successful mprotect-like case 8:
 		 * the next vma was merged into the current one and
diff --git a/include/linux/mm.h b/include/linux/mm.h
index a7b667786cde..83ee1f35febe 100644
--- a/include/linux/mm.h
+++ b/include/linux/mm.h
@@ -3253,6 +3253,66 @@  extern struct vm_area_struct *copy_vma(struct vm_area_struct **,
 	unsigned long addr, unsigned long len, pgoff_t pgoff,
 	bool *need_rmap_locks);
 extern void exit_mmap(struct mm_struct *);
+struct vm_area_struct *vma_modify(struct vma_iterator *vmi,
+				  struct vm_area_struct *prev,
+				  struct vm_area_struct *vma,
+				  unsigned long start, unsigned long end,
+				  unsigned long vm_flags,
+				  struct mempolicy *policy,
+				  struct vm_userfaultfd_ctx uffd_ctx,
+				  struct anon_vma_name *anon_name);
+
+/* We are about to modify the VMA's flags. */
+static inline struct vm_area_struct
+*vma_modify_flags(struct vma_iterator *vmi,
+		  struct vm_area_struct *prev,
+		  struct vm_area_struct *vma,
+		  unsigned long start, unsigned long end,
+		  unsigned long new_flags)
+{
+	return vma_modify(vmi, prev, vma, start, end, new_flags,
+			  vma_policy(vma), vma->vm_userfaultfd_ctx,
+			  anon_vma_name(vma));
+}
+
+/* We are about to modify the VMA's flags and/or anon_name. */
+static inline struct vm_area_struct
+*vma_modify_flags_name(struct vma_iterator *vmi,
+		       struct vm_area_struct *prev,
+		       struct vm_area_struct *vma,
+		       unsigned long start,
+		       unsigned long end,
+		       unsigned long new_flags,
+		       struct anon_vma_name *new_name)
+{
+	return vma_modify(vmi, prev, vma, start, end, new_flags,
+			  vma_policy(vma), vma->vm_userfaultfd_ctx, new_name);
+}
+
+/* We are about to modify the VMA's memory policy. */
+static inline struct vm_area_struct
+*vma_modify_policy(struct vma_iterator *vmi,
+		   struct vm_area_struct *prev,
+		   struct vm_area_struct *vma,
+		   unsigned long start, unsigned long end,
+		   struct mempolicy *new_pol)
+{
+	return vma_modify(vmi, prev, vma, start, end, vma->vm_flags,
+			  new_pol, vma->vm_userfaultfd_ctx, anon_vma_name(vma));
+}
+
+/* We are about to modify the VMA's flags and/or uffd context. */
+static inline struct vm_area_struct
+*vma_modify_flags_uffd(struct vma_iterator *vmi,
+		       struct vm_area_struct *prev,
+		       struct vm_area_struct *vma,
+		       unsigned long start, unsigned long end,
+		       unsigned long new_flags,
+		       struct vm_userfaultfd_ctx new_ctx)
+{
+	return vma_modify(vmi, prev, vma, start, end, new_flags,
+			  vma_policy(vma), new_ctx, anon_vma_name(vma));
+}
 
 static inline int check_data_rlimit(unsigned long rlim,
 				    unsigned long new,
diff --git a/mm/madvise.c b/mm/madvise.c
index a4a20de50494..801d3c1bb7b3 100644
--- a/mm/madvise.c
+++ b/mm/madvise.c
@@ -141,7 +141,7 @@  static int madvise_update_vma(struct vm_area_struct *vma,
 {
 	struct mm_struct *mm = vma->vm_mm;
 	int error;
-	pgoff_t pgoff;
+	struct vm_area_struct *merged;
 	VMA_ITERATOR(vmi, mm, start);
 
 	if (new_flags == vma->vm_flags && anon_vma_name_eq(anon_vma_name(vma), anon_name)) {
@@ -149,30 +149,16 @@  static int madvise_update_vma(struct vm_area_struct *vma,
 		return 0;
 	}
 
-	pgoff = vma->vm_pgoff + ((start - vma->vm_start) >> PAGE_SHIFT);
-	*prev = vma_merge(&vmi, mm, *prev, start, end, new_flags,
-			  vma->anon_vma, vma->vm_file, pgoff, vma_policy(vma),
-			  vma->vm_userfaultfd_ctx, anon_name);
-	if (*prev) {
-		vma = *prev;
-		goto success;
-	}
-
-	*prev = vma;
-
-	if (start != vma->vm_start) {
-		error = split_vma(&vmi, vma, start, 1);
-		if (error)
-			return error;
-	}
+	merged = vma_modify_flags_name(&vmi, *prev, vma, start, end, new_flags,
+				       anon_name);
+	if (IS_ERR(merged))
+		return PTR_ERR(merged);
 
-	if (end != vma->vm_end) {
-		error = split_vma(&vmi, vma, end, 0);
-		if (error)
-			return error;
-	}
+	if (merged)
+		vma = *prev = merged;
+	else
+		*prev = vma;
 
-success:
 	/* vm_flags is protected by the mmap_lock held in write mode. */
 	vma_start_write(vma);
 	vm_flags_reset(vma, new_flags);
diff --git a/mm/mempolicy.c b/mm/mempolicy.c
index b01922e88548..6b2e99db6dd5 100644
--- a/mm/mempolicy.c
+++ b/mm/mempolicy.c
@@ -786,8 +786,6 @@  static int mbind_range(struct vma_iterator *vmi, struct vm_area_struct *vma,
 {
 	struct vm_area_struct *merged;
 	unsigned long vmstart, vmend;
-	pgoff_t pgoff;
-	int err;
 
 	vmend = min(end, vma->vm_end);
 	if (start > vma->vm_start) {
@@ -802,27 +800,15 @@  static int mbind_range(struct vma_iterator *vmi, struct vm_area_struct *vma,
 		return 0;
 	}
 
-	pgoff = vma->vm_pgoff + ((vmstart - vma->vm_start) >> PAGE_SHIFT);
-	merged = vma_merge(vmi, vma->vm_mm, *prev, vmstart, vmend, vma->vm_flags,
-			 vma->anon_vma, vma->vm_file, pgoff, new_pol,
-			 vma->vm_userfaultfd_ctx, anon_vma_name(vma));
+	merged =  vma_modify_policy(vmi, *prev, vma, vmstart, vmend, new_pol);
+	if (IS_ERR(merged))
+		return PTR_ERR(merged);
+
 	if (merged) {
 		*prev = merged;
 		return vma_replace_policy(merged, new_pol);
 	}
 
-	if (vma->vm_start != vmstart) {
-		err = split_vma(vmi, vma, vmstart, 1);
-		if (err)
-			return err;
-	}
-
-	if (vma->vm_end != vmend) {
-		err = split_vma(vmi, vma, vmend, 0);
-		if (err)
-			return err;
-	}
-
 	*prev = vma;
 	return vma_replace_policy(vma, new_pol);
 }
diff --git a/mm/mlock.c b/mm/mlock.c
index 42b6865f8f82..ae83a33c387e 100644
--- a/mm/mlock.c
+++ b/mm/mlock.c
@@ -476,10 +476,10 @@  static int mlock_fixup(struct vma_iterator *vmi, struct vm_area_struct *vma,
 	       unsigned long end, vm_flags_t newflags)
 {
 	struct mm_struct *mm = vma->vm_mm;
-	pgoff_t pgoff;
 	int nr_pages;
 	int ret = 0;
 	vm_flags_t oldflags = vma->vm_flags;
+	struct vm_area_struct *merged;
 
 	if (newflags == oldflags || (oldflags & VM_SPECIAL) ||
 	    is_vm_hugetlb_page(vma) || vma == get_gate_vma(current->mm) ||
@@ -487,28 +487,15 @@  static int mlock_fixup(struct vma_iterator *vmi, struct vm_area_struct *vma,
 		/* don't set VM_LOCKED or VM_LOCKONFAULT and don't count */
 		goto out;
 
-	pgoff = vma->vm_pgoff + ((start - vma->vm_start) >> PAGE_SHIFT);
-	*prev = vma_merge(vmi, mm, *prev, start, end, newflags,
-			vma->anon_vma, vma->vm_file, pgoff, vma_policy(vma),
-			vma->vm_userfaultfd_ctx, anon_vma_name(vma));
-	if (*prev) {
-		vma = *prev;
-		goto success;
-	}
-
-	if (start != vma->vm_start) {
-		ret = split_vma(vmi, vma, start, 1);
-		if (ret)
-			goto out;
+	merged = vma_modify_flags(vmi, *prev, vma, start, end, newflags);
+	if (IS_ERR(merged)) {
+		ret = PTR_ERR(merged);
+		goto out;
 	}
 
-	if (end != vma->vm_end) {
-		ret = split_vma(vmi, vma, end, 0);
-		if (ret)
-			goto out;
-	}
+	if (merged)
+		vma = *prev = merged;
 
-success:
 	/*
 	 * Keep track of amount of locked VM.
 	 */
diff --git a/mm/mmap.c b/mm/mmap.c
index 673429ee8a9e..22d968affc07 100644
--- a/mm/mmap.c
+++ b/mm/mmap.c
@@ -2437,6 +2437,51 @@  int split_vma(struct vma_iterator *vmi, struct vm_area_struct *vma,
 	return __split_vma(vmi, vma, addr, new_below);
 }
 
+/*
+ * We are about to modify one or multiple of a VMA's flags, policy, userfaultfd
+ * context and anonymous VMA name within the range [start, end).
+ *
+ * As a result, we might be able to merge the newly modified VMA range with an
+ * adjacent VMA with identical properties.
+ *
+ * If no merge is possible and the range does not span the entirety of the VMA,
+ * we then need to split the VMA to accommodate the change.
+ */
+struct vm_area_struct *vma_modify(struct vma_iterator *vmi,
+				  struct vm_area_struct *prev,
+				  struct vm_area_struct *vma,
+				  unsigned long start, unsigned long end,
+				  unsigned long vm_flags,
+				  struct mempolicy *policy,
+				  struct vm_userfaultfd_ctx uffd_ctx,
+				  struct anon_vma_name *anon_name)
+{
+	pgoff_t pgoff = vma->vm_pgoff + ((start - vma->vm_start) >> PAGE_SHIFT);
+	struct vm_area_struct *merged;
+
+	merged = vma_merge(vmi, vma->vm_mm, prev, start, end, vm_flags,
+			   vma->anon_vma, vma->vm_file, pgoff, policy,
+			   uffd_ctx, anon_name);
+	if (merged)
+		return merged;
+
+	if (vma->vm_start < start) {
+		int err = split_vma(vmi, vma, start, 1);
+
+		if (err)
+			return ERR_PTR(err);
+	}
+
+	if (vma->vm_end > end) {
+		int err = split_vma(vmi, vma, end, 0);
+
+		if (err)
+			return ERR_PTR(err);
+	}
+
+	return NULL;
+}
+
 /*
  * do_vmi_align_munmap() - munmap the aligned region from @start to @end.
  * @vmi: The vma iterator
diff --git a/mm/mprotect.c b/mm/mprotect.c
index b94fbb45d5c7..6f85d99682ab 100644
--- a/mm/mprotect.c
+++ b/mm/mprotect.c
@@ -581,7 +581,7 @@  mprotect_fixup(struct vma_iterator *vmi, struct mmu_gather *tlb,
 	long nrpages = (end - start) >> PAGE_SHIFT;
 	unsigned int mm_cp_flags = 0;
 	unsigned long charged = 0;
-	pgoff_t pgoff;
+	struct vm_area_struct *merged;
 	int error;
 
 	if (newflags == oldflags) {
@@ -625,34 +625,19 @@  mprotect_fixup(struct vma_iterator *vmi, struct mmu_gather *tlb,
 		}
 	}
 
-	/*
-	 * First try to merge with previous and/or next vma.
-	 */
-	pgoff = vma->vm_pgoff + ((start - vma->vm_start) >> PAGE_SHIFT);
-	*pprev = vma_merge(vmi, mm, *pprev, start, end, newflags,
-			   vma->anon_vma, vma->vm_file, pgoff, vma_policy(vma),
-			   vma->vm_userfaultfd_ctx, anon_vma_name(vma));
-	if (*pprev) {
-		vma = *pprev;
-		VM_WARN_ON((vma->vm_flags ^ newflags) & ~VM_SOFTDIRTY);
-		goto success;
+	merged = vma_modify_flags(vmi, *pprev, vma, start, end, newflags);
+	if (IS_ERR(merged)) {
+		error = PTR_ERR(merged);
+		goto fail;
 	}
 
-	*pprev = vma;
-
-	if (start != vma->vm_start) {
-		error = split_vma(vmi, vma, start, 1);
-		if (error)
-			goto fail;
-	}
-
-	if (end != vma->vm_end) {
-		error = split_vma(vmi, vma, end, 0);
-		if (error)
-			goto fail;
+	if (merged) {
+		vma = *pprev = merged;
+		VM_WARN_ON((vma->vm_flags ^ newflags) & ~VM_SOFTDIRTY);
+	} else {
+		*pprev = vma;
 	}
 
-success:
 	/*
 	 * vm_flags and vm_page_prot are protected by the mmap_lock
 	 * held in write mode.