
[v6,14/20] mm/mmap: Avoid zeroing vma tree in mmap_region()

Message ID: 20240820235730.2852400-15-Liam.Howlett@oracle.com (mailing list archive)
State: New
Series: Avoid MAP_FIXED gap exposure

Commit Message

Liam R. Howlett Aug. 20, 2024, 11:57 p.m. UTC
From: "Liam R. Howlett" <Liam.Howlett@Oracle.com>

Instead of zeroing the vma tree and then overwriting the area, let the
area be overwritten and then clean up the gathered vmas using
vms_complete_munmap_vmas().

To ensure locking is downgraded correctly, the mm is set regardless of
MAP_FIXED or not (NULL vma).

If a driver is mapping over an existing vma, then clear the ptes before
the call_mmap() invocation.  This is done using the vms_clean_up_area()
helper.  If there is a close vm_ops, that must also be called to ensure
any cleanup is done before mapping over the area.  This also means that
calling open has been added to the abort of an unmap operation, for now.

Temporarily keep track of the number of pages that will be removed and
reduce the charged amount.

This also drops the validate_mm() call in the vma_expand() function.
It is necessary to drop the validate as it would fail since the mm
map_count would be incorrect during a vma expansion, prior to the
cleanup from vms_complete_munmap_vmas().

Clean up the error handling of the vms_gather_munmap_vmas() by calling
the verification within the function.

Signed-off-by: Liam R. Howlett <Liam.Howlett@Oracle.com>
---
 mm/mmap.c | 62 ++++++++++++++++++++++++++-----------------------------
 mm/vma.c  | 54 +++++++++++++++++++++++++++++++++++++-----------
 mm/vma.h  | 22 ++++++++++++++------
 3 files changed, 87 insertions(+), 51 deletions(-)
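
For orientation, the heart of this change is a reordering of mmap_region():
existing vmas in the range are only gathered up front, and the actual
teardown is deferred until the new mapping is in place.  A condensed sketch
of the resulting flow (error paths and vma merging elided; the diff below is
authoritative):

	nr_pages = count_vma_pages_range(mm, addr, end, &nr_accounted);
	if (!may_expand_vm(mm, vm_flags, (len >> PAGE_SHIFT) - nr_pages))
		return -ENOMEM;

	vma = vma_find(&vmi, end);
	if (vma)	/* MAP_FIXED over existing mappings */
		if (vms_gather_munmap_vmas(&vms, &mas_detach))
			return -ENOMEM;	/* gathered, not yet removed */

	if (accountable_mapping(file, vm_flags)) {
		charged = (len >> PAGE_SHIFT) - nr_accounted; /* avoid double charging */
		...
	}

	if (file) {
		/* Clear stale PTEs and call vm_ops->close() before mapping over */
		vms_clean_up_area(&vms, &mas_detach, true);
		error = call_mmap(file, vma);
		...
	}
	...
	/* The new vma is in the tree; now finish removing the old mappings */
	vms_complete_munmap_vmas(&vms, &mas_detach);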

Comments

Lorenzo Stoakes Aug. 21, 2024, 11:02 a.m. UTC | #1
On Tue, Aug 20, 2024 at 07:57:23PM GMT, Liam R. Howlett wrote:
> From: "Liam R. Howlett" <Liam.Howlett@Oracle.com>
>
> Instead of zeroing the vma tree and then overwriting the area, let the
> area be overwritten and then clean up the gathered vmas using
> vms_complete_munmap_vmas().
>
> To ensure locking is downgraded correctly, the mm is set regardless of
> MAP_FIXED or not (NULL vma).
>
> If a driver is mapping over an existing vma, then clear the ptes before
> the call_mmap() invocation.  This is done using the vms_clean_up_area()
> helper.  If there is a close vm_ops, that must also be called to ensure
> any cleanup is done before mapping over the area.  This also means that
> calling open has been added to the abort of an unmap operation, for now.

Might be worth explicitly expanding this to say that this isn't a permanent
solution because of asymmetric vm_ops->open() / close().

>
> Temporarily keep track of the number of pages that will be removed and
> reduce the charged amount.
>
> This also drops the validate_mm() call in the vma_expand() function.
> It is necessary to drop the validate as it would fail since the mm
> map_count would be incorrect during a vma expansion, prior to the
> cleanup from vms_complete_munmap_vmas().
>
> Clean up the error handling of the vms_gather_munmap_vmas() by calling
> the verification within the function.
>
> Signed-off-by: Liam R. Howlett <Liam.Howlett@Oracle.com>

Broadly looks good, some nits and questions below, but generally:

Reviewed-by: Lorenzo Stoakes <lorenzo.stoakes@oracle.com>

> ---
>  mm/mmap.c | 62 ++++++++++++++++++++++++++-----------------------------
>  mm/vma.c  | 54 +++++++++++++++++++++++++++++++++++++-----------
>  mm/vma.h  | 22 ++++++++++++++------
>  3 files changed, 87 insertions(+), 51 deletions(-)
>
> diff --git a/mm/mmap.c b/mm/mmap.c
> index 71b2bad717b6..6550d9470d3a 100644
> --- a/mm/mmap.c
> +++ b/mm/mmap.c
> @@ -1373,23 +1373,19 @@ unsigned long mmap_region(struct file *file, unsigned long addr,
>  	unsigned long merge_start = addr, merge_end = end;
>  	bool writable_file_mapping = false;
>  	pgoff_t vm_pgoff;
> -	int error;
> +	int error = -ENOMEM;
>  	VMA_ITERATOR(vmi, mm, addr);
> +	unsigned long nr_pages, nr_accounted;
>
> -	/* Check against address space limit. */
> -	if (!may_expand_vm(mm, vm_flags, len >> PAGE_SHIFT)) {
> -		unsigned long nr_pages;
> +	nr_pages = count_vma_pages_range(mm, addr, end, &nr_accounted);
>
> -		/*
> -		 * MAP_FIXED may remove pages of mappings that intersects with
> -		 * requested mapping. Account for the pages it would unmap.
> -		 */
> -		nr_pages = count_vma_pages_range(mm, addr, end);
> -
> -		if (!may_expand_vm(mm, vm_flags,
> -					(len >> PAGE_SHIFT) - nr_pages))
> -			return -ENOMEM;
> -	}
> +	/*
> +	 * Check against address space limit.
> +	 * MAP_FIXED may remove pages of mappings that intersects with requested
> +	 * mapping. Account for the pages it would unmap.
> +	 */
> +	if (!may_expand_vm(mm, vm_flags, (len >> PAGE_SHIFT) - nr_pages))
> +		return -ENOMEM;
>
>  	/* Find the first overlapping VMA */
>  	vma = vma_find(&vmi, end);
> @@ -1400,14 +1396,8 @@ unsigned long mmap_region(struct file *file, unsigned long addr,
>  		mas_init(&mas_detach, &mt_detach, /* addr = */ 0);
>  		/* Prepare to unmap any existing mapping in the area */
>  		if (vms_gather_munmap_vmas(&vms, &mas_detach))
> -			goto gather_failed;
> -
> -		/* Remove any existing mappings from the vma tree */
> -		if (vma_iter_clear_gfp(&vmi, addr, end, GFP_KERNEL))
> -			goto clear_tree_failed;
> +			return -ENOMEM;
>
> -		/* Unmap any existing mapping in the area */
> -		vms_complete_munmap_vmas(&vms, &mas_detach);
>  		next = vms.next;
>  		prev = vms.prev;
>  		vma = NULL;
> @@ -1423,8 +1413,10 @@ unsigned long mmap_region(struct file *file, unsigned long addr,
>  	 */
>  	if (accountable_mapping(file, vm_flags)) {
>  		charged = len >> PAGE_SHIFT;
> +		charged -= nr_accounted;
>  		if (security_vm_enough_memory_mm(mm, charged))
> -			return -ENOMEM;
> +			goto abort_munmap;
> +		vms.nr_accounted = 0;
>  		vm_flags |= VM_ACCOUNT;
>  	}
>
> @@ -1473,10 +1465,8 @@ unsigned long mmap_region(struct file *file, unsigned long addr,
>  	 * not unmapped, but the maps are removed from the list.
>  	 */
>  	vma = vm_area_alloc(mm);
> -	if (!vma) {
> -		error = -ENOMEM;
> +	if (!vma)
>  		goto unacct_error;
> -	}
>
>  	vma_iter_config(&vmi, addr, end);
>  	vma_set_range(vma, addr, end, pgoff);
> @@ -1485,6 +1475,11 @@ unsigned long mmap_region(struct file *file, unsigned long addr,
>
>  	if (file) {
>  		vma->vm_file = get_file(file);
> +		/*
> +		 * call_mmap() may map PTE, so ensure there are no existing PTEs
> +		 * call the vm_ops close function if one exists.

Super-nit, but maybe add an 'and' here.

> +		 */
> +		vms_clean_up_area(&vms, &mas_detach, true);

I hate that we have to do this. These kind of hooks are the devil's works...

>  		error = call_mmap(file, vma);
>  		if (error)
>  			goto unmap_and_free_vma;
> @@ -1575,6 +1570,9 @@ unsigned long mmap_region(struct file *file, unsigned long addr,
>  expanded:
>  	perf_event_mmap(vma);
>
> +	/* Unmap any existing mapping in the area */
> +	vms_complete_munmap_vmas(&vms, &mas_detach);
> +
>  	vm_stat_account(mm, vm_flags, len >> PAGE_SHIFT);
>  	if (vm_flags & VM_LOCKED) {
>  		if ((vm_flags & VM_SPECIAL) || vma_is_dax(vma) ||
> @@ -1603,7 +1601,7 @@ unsigned long mmap_region(struct file *file, unsigned long addr,
>  	return addr;
>
>  close_and_free_vma:
> -	if (file && vma->vm_ops && vma->vm_ops->close)
> +	if (file && !vms.closed_vm_ops && vma->vm_ops && vma->vm_ops->close)
>  		vma->vm_ops->close(vma);
>
>  	if (file || vma->vm_file) {
> @@ -1622,14 +1620,12 @@ unsigned long mmap_region(struct file *file, unsigned long addr,
>  unacct_error:
>  	if (charged)
>  		vm_unacct_memory(charged);
> -	validate_mm(mm);
> -	return error;
>
> -clear_tree_failed:
> -	abort_munmap_vmas(&mas_detach);
> -gather_failed:
> +abort_munmap:
> +	if (vms.nr_pages)
> +		abort_munmap_vmas(&mas_detach, vms.closed_vm_ops);
>  	validate_mm(mm);
> -	return -ENOMEM;
> +	return error;
>  }
>
>  static int __vm_munmap(unsigned long start, size_t len, bool unlock)
> @@ -1959,7 +1955,7 @@ void exit_mmap(struct mm_struct *mm)
>  	do {
>  		if (vma->vm_flags & VM_ACCOUNT)
>  			nr_accounted += vma_pages(vma);
> -		remove_vma(vma, true);
> +		remove_vma(vma, /* unreachable = */ true, /* closed = */ false);
>  		count++;
>  		cond_resched();
>  		vma = vma_next(&vmi);
> diff --git a/mm/vma.c b/mm/vma.c
> index 7104c2c080bb..5b33f7460ab7 100644
> --- a/mm/vma.c
> +++ b/mm/vma.c
> @@ -136,10 +136,10 @@ can_vma_merge_after(struct vm_area_struct *vma, unsigned long vm_flags,
>  /*
>   * Close a vm structure and free it.
>   */
> -void remove_vma(struct vm_area_struct *vma, bool unreachable)
> +void remove_vma(struct vm_area_struct *vma, bool unreachable, bool closed)
>  {
>  	might_sleep();
> -	if (vma->vm_ops && vma->vm_ops->close)
> +	if (!closed && vma->vm_ops && vma->vm_ops->close)
>  		vma->vm_ops->close(vma);
>  	if (vma->vm_file)
>  		fput(vma->vm_file);
> @@ -521,7 +521,6 @@ int vma_expand(struct vma_iterator *vmi, struct vm_area_struct *vma,
>  	vma_iter_store(vmi, vma);
>
>  	vma_complete(&vp, vmi, vma->vm_mm);
> -	validate_mm(vma->vm_mm);
>  	return 0;
>
>  nomem:
> @@ -645,11 +644,14 @@ void vma_complete(struct vma_prepare *vp,
>  		uprobe_mmap(vp->insert);
>  }
>
> -static void vms_complete_pte_clear(struct vma_munmap_struct *vms,
> -		struct ma_state *mas_detach, bool mm_wr_locked)
> +static inline void vms_clear_ptes(struct vma_munmap_struct *vms,
> +		    struct ma_state *mas_detach, bool mm_wr_locked)
>  {
>  	struct mmu_gather tlb;
>
> +	if (!vms->clear_ptes) /* Nothing to do */
> +		return;
> +
>  	/*
>  	 * We can free page tables without write-locking mmap_lock because VMAs
>  	 * were isolated before we downgraded mmap_lock.
> @@ -658,11 +660,31 @@ static void vms_complete_pte_clear(struct vma_munmap_struct *vms,
>  	lru_add_drain();
>  	tlb_gather_mmu(&tlb, vms->mm);
>  	update_hiwater_rss(vms->mm);
> -	unmap_vmas(&tlb, mas_detach, vms->vma, vms->start, vms->end, vms->vma_count, mm_wr_locked);
> +	unmap_vmas(&tlb, mas_detach, vms->vma, vms->start, vms->end,
> +		   vms->vma_count, mm_wr_locked);
> +
>  	mas_set(mas_detach, 1);
>  	/* start and end may be different if there is no prev or next vma. */
> -	free_pgtables(&tlb, mas_detach, vms->vma, vms->unmap_start, vms->unmap_end, mm_wr_locked);
> +	free_pgtables(&tlb, mas_detach, vms->vma, vms->unmap_start,
> +		      vms->unmap_end, mm_wr_locked);
>  	tlb_finish_mmu(&tlb);
> +	vms->clear_ptes = false;
> +}
> +
> +void vms_clean_up_area(struct vma_munmap_struct *vms,
> +		struct ma_state *mas_detach, bool mm_wr_locked)

The only invocation of this function has mm_wr_locked set, is this
parameter necessary?

> +{
> +	struct vm_area_struct *vma;
> +
> +	if (!vms->nr_pages)
> +		return;
> +
> +	vms_clear_ptes(vms, mas_detach, mm_wr_locked);
> +	mas_set(mas_detach, 0);
> +	mas_for_each(mas_detach, vma, ULONG_MAX)
> +		if (vma->vm_ops && vma->vm_ops->close)
> +			vma->vm_ops->close(vma);
> +	vms->closed_vm_ops = true;
>  }
>
>  /*
> @@ -686,7 +708,10 @@ void vms_complete_munmap_vmas(struct vma_munmap_struct *vms,
>  	if (vms->unlock)
>  		mmap_write_downgrade(mm);
>
> -	vms_complete_pte_clear(vms, mas_detach, !vms->unlock);
> +	if (!vms->nr_pages)
> +		return;
> +
> +	vms_clear_ptes(vms, mas_detach, !vms->unlock);
>  	/* Update high watermark before we lower total_vm */
>  	update_hiwater_vm(mm);
>  	/* Stat accounting */
> @@ -697,7 +722,7 @@ void vms_complete_munmap_vmas(struct vma_munmap_struct *vms,
>  	/* Remove and clean up vmas */
>  	mas_set(mas_detach, 0);
>  	mas_for_each(mas_detach, vma, ULONG_MAX)
> -		remove_vma(vma, false);
> +		remove_vma(vma, /* = */ false, vms->closed_vm_ops);
>
>  	vm_unacct_memory(vms->nr_accounted);
>  	validate_mm(mm);
> @@ -849,13 +874,14 @@ int vms_gather_munmap_vmas(struct vma_munmap_struct *vms,
>  	while (vma_iter_addr(vms->vmi) > vms->start)
>  		vma_iter_prev_range(vms->vmi);
>
> +	vms->clear_ptes = true;
>  	return 0;
>
>  userfaultfd_error:
>  munmap_gather_failed:
>  end_split_failed:
>  modify_vma_failed:
> -	abort_munmap_vmas(mas_detach);
> +	abort_munmap_vmas(mas_detach, /* closed = */ false);
>  start_split_failed:
>  map_count_exceeded:
>  	return error;
> @@ -900,7 +926,7 @@ int do_vmi_align_munmap(struct vma_iterator *vmi, struct vm_area_struct *vma,
>  	return 0;
>
>  clear_tree_failed:
> -	abort_munmap_vmas(&mas_detach);
> +	abort_munmap_vmas(&mas_detach, /* closed = */ false);
>  gather_failed:
>  	validate_mm(mm);
>  	return error;
> @@ -1618,17 +1644,21 @@ bool vma_wants_writenotify(struct vm_area_struct *vma, pgprot_t vm_page_prot)
>  }
>
>  unsigned long count_vma_pages_range(struct mm_struct *mm,
> -				    unsigned long addr, unsigned long end)
> +		unsigned long addr, unsigned long end,
> +		unsigned long *nr_accounted)
>  {
>  	VMA_ITERATOR(vmi, mm, addr);
>  	struct vm_area_struct *vma;
>  	unsigned long nr_pages = 0;
>
> +	*nr_accounted = 0;
>  	for_each_vma_range(vmi, vma, end) {
>  		unsigned long vm_start = max(addr, vma->vm_start);
>  		unsigned long vm_end = min(end, vma->vm_end);
>
>  		nr_pages += PHYS_PFN(vm_end - vm_start);
> +		if (vma->vm_flags & VM_ACCOUNT)
> +			*nr_accounted += PHYS_PFN(vm_end - vm_start);

Nitty, but maybe:
		...
  		unsigned long pages = PHYS_PFN(vm_end - vm_start);

 		nr_pages += pages;
		if (vma->vm_flags & VM_ACCOUNT)
			*nr_accounted += pages;

>  	}
>
>  	return nr_pages;
> diff --git a/mm/vma.h b/mm/vma.h
> index 6028fdf79257..756dd42a6ec4 100644
> --- a/mm/vma.h
> +++ b/mm/vma.h
> @@ -48,6 +48,8 @@ struct vma_munmap_struct {
>  	unsigned long stack_vm;
>  	unsigned long data_vm;
>  	bool unlock;                    /* Unlock after the munmap */
> +	bool clear_ptes;                /* If there are outstanding PTE to be cleared */
> +	bool closed_vm_ops;		/* call_mmap() was encountered, so vmas may be closed */
>  };
>
>  #ifdef CONFIG_DEBUG_VM_MAPLE_TREE
> @@ -95,14 +97,13 @@ static inline void init_vma_munmap(struct vma_munmap_struct *vms,
>  		unsigned long start, unsigned long end, struct list_head *uf,
>  		bool unlock)
>  {
> +	vms->mm = current->mm;

I'm guessing there's no circumstances under which we'd be looking at a
remote mm_struct?

This does sort of beg the question as to why we're bothering to store the
field if we can't just grab it from current->mm? Perhaps because the cache
line for the start of vms will be populated and current's containing ->mm
may not?

>  	vms->vmi = vmi;
>  	vms->vma = vma;
>  	if (vma) {
> -		vms->mm = vma->vm_mm;
>  		vms->start = start;
>  		vms->end = end;
>  	} else {
> -		vms->mm = NULL;

I guess as well there's no drawback to having an otherwise empty vms have a
populated mm?

>  		vms->start = vms->end = 0;
>  	}
>  	vms->unlock = unlock;
> @@ -112,6 +113,8 @@ static inline void init_vma_munmap(struct vma_munmap_struct *vms,
>  	vms->exec_vm = vms->stack_vm = vms->data_vm = 0;
>  	vms->unmap_start = FIRST_USER_ADDRESS;
>  	vms->unmap_end = USER_PGTABLES_CEILING;
> +	vms->clear_ptes = false;
> +	vms->closed_vm_ops = false;
>  }
>
>  int vms_gather_munmap_vmas(struct vma_munmap_struct *vms,
> @@ -120,18 +123,24 @@ int vms_gather_munmap_vmas(struct vma_munmap_struct *vms,
>  void vms_complete_munmap_vmas(struct vma_munmap_struct *vms,
>  		struct ma_state *mas_detach);
>
> +void vms_clean_up_area(struct vma_munmap_struct *vms,
> +		struct ma_state *mas_detach, bool mm_wr_locked);
> +
>  /*
>   * abort_munmap_vmas - Undo any munmap work and free resources
>   *
>   * Reattach any detached vmas and free up the maple tree used to track the vmas.
>   */
> -static inline void abort_munmap_vmas(struct ma_state *mas_detach)
> +static inline void abort_munmap_vmas(struct ma_state *mas_detach, bool closed)
>  {
>  	struct vm_area_struct *vma;
>
>  	mas_set(mas_detach, 0);
> -	mas_for_each(mas_detach, vma, ULONG_MAX)
> +	mas_for_each(mas_detach, vma, ULONG_MAX) {
>  		vma_mark_detached(vma, false);
> +		if (closed && vma->vm_ops && vma->vm_ops->open)
> +			vma->vm_ops->open(vma);
> +	}

Hang on, I thought we eliminated this approach? OK I see you change this in
the next commit.

Not necessarily a huge fan of having a commit in the tree that's broken for
(hideous, asymmetric) drivers + such but I guess it's okay given we address
it immediately and it helps document the thinking process + split up the
code.

>
>  	__mt_destroy(mas_detach->tree);
>  }
> @@ -145,7 +154,7 @@ int do_vmi_munmap(struct vma_iterator *vmi, struct mm_struct *mm,
>  		  unsigned long start, size_t len, struct list_head *uf,
>  		  bool unlock);
>
> -void remove_vma(struct vm_area_struct *vma, bool unreachable);
> +void remove_vma(struct vm_area_struct *vma, bool unreachable, bool closed);
>
>  void unmap_region(struct ma_state *mas, struct vm_area_struct *vma,
>  		struct vm_area_struct *prev, struct vm_area_struct *next);
> @@ -259,7 +268,8 @@ bool vma_wants_writenotify(struct vm_area_struct *vma, pgprot_t vm_page_prot);
>  int mm_take_all_locks(struct mm_struct *mm);
>  void mm_drop_all_locks(struct mm_struct *mm);
>  unsigned long count_vma_pages_range(struct mm_struct *mm,
> -				    unsigned long addr, unsigned long end);
> +				    unsigned long addr, unsigned long end,
> +				    unsigned long *nr_accounted);
>
>  static inline bool vma_wants_manual_pte_write_upgrade(struct vm_area_struct *vma)
>  {
> --
> 2.43.0
>
Liam R. Howlett Aug. 21, 2024, 3:09 p.m. UTC | #2
* Lorenzo Stoakes <lorenzo.stoakes@oracle.com> [240821 07:02]:
> On Tue, Aug 20, 2024 at 07:57:23PM GMT, Liam R. Howlett wrote:
> > From: "Liam R. Howlett" <Liam.Howlett@Oracle.com>
> >
> > Instead of zeroing the vma tree and then overwriting the area, let the
> > area be overwritten and then clean up the gathered vmas using
> > vms_complete_munmap_vmas().
> >
> > To ensure locking is downgraded correctly, the mm is set regardless of
> > MAP_FIXED or not (NULL vma).
> >
> > If a driver is mapping over an existing vma, then clear the ptes before
> > the call_mmap() invocation.  This is done using the vms_clean_up_area()
> > helper.  If there is a close vm_ops, that must also be called to ensure
> > any cleanup is done before mapping over the area.  This also means that
> > calling open has been added to the abort of an unmap operation, for now.
> 
> Might be worth explicitly expanding this to say that this isn't a permanent
> solution because of asymmetric vm_ops->open() / close().

Yes, I will expand it.

> 
> >
> > Temporarily keep track of the number of pages that will be removed and
> > reduce the charged amount.
> >
> > This also drops the validate_mm() call in the vma_expand() function.
> > It is necessary to drop the validate as it would fail since the mm
> > map_count would be incorrect during a vma expansion, prior to the
> > cleanup from vms_complete_munmap_vmas().
> >
> > Clean up the error handling of the vms_gather_munmap_vmas() by calling
> > the verification within the function.
> >
> > Signed-off-by: Liam R. Howlett <Liam.Howlett@Oracle.com>
> 
> Broadly looks good, some nits and questions below, but generally:
> 
> Reviewed-by: Lorenzo Stoakes <lorenzo.stoakes@oracle.com>
> 
> > ---
> >  mm/mmap.c | 62 ++++++++++++++++++++++++++-----------------------------
> >  mm/vma.c  | 54 +++++++++++++++++++++++++++++++++++++-----------
> >  mm/vma.h  | 22 ++++++++++++++------
> >  3 files changed, 87 insertions(+), 51 deletions(-)
> >
> > diff --git a/mm/mmap.c b/mm/mmap.c
> > index 71b2bad717b6..6550d9470d3a 100644
> > --- a/mm/mmap.c
> > +++ b/mm/mmap.c
> > @@ -1373,23 +1373,19 @@ unsigned long mmap_region(struct file *file, unsigned long addr,
> >  	unsigned long merge_start = addr, merge_end = end;
> >  	bool writable_file_mapping = false;
> >  	pgoff_t vm_pgoff;
> > -	int error;
> > +	int error = -ENOMEM;
> >  	VMA_ITERATOR(vmi, mm, addr);
> > +	unsigned long nr_pages, nr_accounted;
> >
> > -	/* Check against address space limit. */
> > -	if (!may_expand_vm(mm, vm_flags, len >> PAGE_SHIFT)) {
> > -		unsigned long nr_pages;
> > +	nr_pages = count_vma_pages_range(mm, addr, end, &nr_accounted);
> >
> > -		/*
> > -		 * MAP_FIXED may remove pages of mappings that intersects with
> > -		 * requested mapping. Account for the pages it would unmap.
> > -		 */
> > -		nr_pages = count_vma_pages_range(mm, addr, end);
> > -
> > -		if (!may_expand_vm(mm, vm_flags,
> > -					(len >> PAGE_SHIFT) - nr_pages))
> > -			return -ENOMEM;
> > -	}
> > +	/*
> > +	 * Check against address space limit.
> > +	 * MAP_FIXED may remove pages of mappings that intersects with requested
> > +	 * mapping. Account for the pages it would unmap.
> > +	 */
> > +	if (!may_expand_vm(mm, vm_flags, (len >> PAGE_SHIFT) - nr_pages))
> > +		return -ENOMEM;
> >
> >  	/* Find the first overlapping VMA */
> >  	vma = vma_find(&vmi, end);
> > @@ -1400,14 +1396,8 @@ unsigned long mmap_region(struct file *file, unsigned long addr,
> >  		mas_init(&mas_detach, &mt_detach, /* addr = */ 0);
> >  		/* Prepare to unmap any existing mapping in the area */
> >  		if (vms_gather_munmap_vmas(&vms, &mas_detach))
> > -			goto gather_failed;
> > -
> > -		/* Remove any existing mappings from the vma tree */
> > -		if (vma_iter_clear_gfp(&vmi, addr, end, GFP_KERNEL))
> > -			goto clear_tree_failed;
> > +			return -ENOMEM;
> >
> > -		/* Unmap any existing mapping in the area */
> > -		vms_complete_munmap_vmas(&vms, &mas_detach);
> >  		next = vms.next;
> >  		prev = vms.prev;
> >  		vma = NULL;
> > @@ -1423,8 +1413,10 @@ unsigned long mmap_region(struct file *file, unsigned long addr,
> >  	 */
> >  	if (accountable_mapping(file, vm_flags)) {
> >  		charged = len >> PAGE_SHIFT;
> > +		charged -= nr_accounted;
> >  		if (security_vm_enough_memory_mm(mm, charged))
> > -			return -ENOMEM;
> > +			goto abort_munmap;
> > +		vms.nr_accounted = 0;
> >  		vm_flags |= VM_ACCOUNT;
> >  	}
> >
> > @@ -1473,10 +1465,8 @@ unsigned long mmap_region(struct file *file, unsigned long addr,
> >  	 * not unmapped, but the maps are removed from the list.
> >  	 */
> >  	vma = vm_area_alloc(mm);
> > -	if (!vma) {
> > -		error = -ENOMEM;
> > +	if (!vma)
> >  		goto unacct_error;
> > -	}
> >
> >  	vma_iter_config(&vmi, addr, end);
> >  	vma_set_range(vma, addr, end, pgoff);
> > @@ -1485,6 +1475,11 @@ unsigned long mmap_region(struct file *file, unsigned long addr,
> >
> >  	if (file) {
> >  		vma->vm_file = get_file(file);
> > +		/*
> > +		 * call_mmap() may map PTE, so ensure there are no existing PTEs
> > +		 * call the vm_ops close function if one exists.
> 
> Super-nit, but maybe add an 'and' here.

I'm re-spinning anyways - thanks.

> 
> > +		 */
> > +		vms_clean_up_area(&vms, &mas_detach, true);
> 
> I hate that we have to do this. These kind of hooks are the devil's works...
> 
> >  		error = call_mmap(file, vma);
> >  		if (error)
> >  			goto unmap_and_free_vma;
> > @@ -1575,6 +1570,9 @@ unsigned long mmap_region(struct file *file, unsigned long addr,
> >  expanded:
> >  	perf_event_mmap(vma);
> >
> > +	/* Unmap any existing mapping in the area */
> > +	vms_complete_munmap_vmas(&vms, &mas_detach);
> > +
> >  	vm_stat_account(mm, vm_flags, len >> PAGE_SHIFT);
> >  	if (vm_flags & VM_LOCKED) {
> >  		if ((vm_flags & VM_SPECIAL) || vma_is_dax(vma) ||
> > @@ -1603,7 +1601,7 @@ unsigned long mmap_region(struct file *file, unsigned long addr,
> >  	return addr;
> >
> >  close_and_free_vma:
> > -	if (file && vma->vm_ops && vma->vm_ops->close)
> > +	if (file && !vms.closed_vm_ops && vma->vm_ops && vma->vm_ops->close)
> >  		vma->vm_ops->close(vma);
> >
> >  	if (file || vma->vm_file) {
> > @@ -1622,14 +1620,12 @@ unsigned long mmap_region(struct file *file, unsigned long addr,
> >  unacct_error:
> >  	if (charged)
> >  		vm_unacct_memory(charged);
> > -	validate_mm(mm);
> > -	return error;
> >
> > -clear_tree_failed:
> > -	abort_munmap_vmas(&mas_detach);
> > -gather_failed:
> > +abort_munmap:
> > +	if (vms.nr_pages)
> > +		abort_munmap_vmas(&mas_detach, vms.closed_vm_ops);
> >  	validate_mm(mm);
> > -	return -ENOMEM;
> > +	return error;
> >  }
> >
> >  static int __vm_munmap(unsigned long start, size_t len, bool unlock)
> > @@ -1959,7 +1955,7 @@ void exit_mmap(struct mm_struct *mm)
> >  	do {
> >  		if (vma->vm_flags & VM_ACCOUNT)
> >  			nr_accounted += vma_pages(vma);
> > -		remove_vma(vma, true);
> > +		remove_vma(vma, /* unreachable = */ true, /* closed = */ false);
> >  		count++;
> >  		cond_resched();
> >  		vma = vma_next(&vmi);
> > diff --git a/mm/vma.c b/mm/vma.c
> > index 7104c2c080bb..5b33f7460ab7 100644
> > --- a/mm/vma.c
> > +++ b/mm/vma.c
> > @@ -136,10 +136,10 @@ can_vma_merge_after(struct vm_area_struct *vma, unsigned long vm_flags,
> >  /*
> >   * Close a vm structure and free it.
> >   */
> > -void remove_vma(struct vm_area_struct *vma, bool unreachable)
> > +void remove_vma(struct vm_area_struct *vma, bool unreachable, bool closed)
> >  {
> >  	might_sleep();
> > -	if (vma->vm_ops && vma->vm_ops->close)
> > +	if (!closed && vma->vm_ops && vma->vm_ops->close)
> >  		vma->vm_ops->close(vma);
> >  	if (vma->vm_file)
> >  		fput(vma->vm_file);
> > @@ -521,7 +521,6 @@ int vma_expand(struct vma_iterator *vmi, struct vm_area_struct *vma,
> >  	vma_iter_store(vmi, vma);
> >
> >  	vma_complete(&vp, vmi, vma->vm_mm);
> > -	validate_mm(vma->vm_mm);
> >  	return 0;
> >
> >  nomem:
> > @@ -645,11 +644,14 @@ void vma_complete(struct vma_prepare *vp,
> >  		uprobe_mmap(vp->insert);
> >  }
> >
> > -static void vms_complete_pte_clear(struct vma_munmap_struct *vms,
> > -		struct ma_state *mas_detach, bool mm_wr_locked)
> > +static inline void vms_clear_ptes(struct vma_munmap_struct *vms,
> > +		    struct ma_state *mas_detach, bool mm_wr_locked)
> >  {
> >  	struct mmu_gather tlb;
> >
> > +	if (!vms->clear_ptes) /* Nothing to do */
> > +		return;
> > +
> >  	/*
> >  	 * We can free page tables without write-locking mmap_lock because VMAs
> >  	 * were isolated before we downgraded mmap_lock.
> > @@ -658,11 +660,31 @@ static void vms_complete_pte_clear(struct vma_munmap_struct *vms,
> >  	lru_add_drain();
> >  	tlb_gather_mmu(&tlb, vms->mm);
> >  	update_hiwater_rss(vms->mm);
> > -	unmap_vmas(&tlb, mas_detach, vms->vma, vms->start, vms->end, vms->vma_count, mm_wr_locked);
> > +	unmap_vmas(&tlb, mas_detach, vms->vma, vms->start, vms->end,
> > +		   vms->vma_count, mm_wr_locked);
> > +
> >  	mas_set(mas_detach, 1);
> >  	/* start and end may be different if there is no prev or next vma. */
> > -	free_pgtables(&tlb, mas_detach, vms->vma, vms->unmap_start, vms->unmap_end, mm_wr_locked);
> > +	free_pgtables(&tlb, mas_detach, vms->vma, vms->unmap_start,
> > +		      vms->unmap_end, mm_wr_locked);
> >  	tlb_finish_mmu(&tlb);
> > +	vms->clear_ptes = false;
> > +}
> > +
> > +void vms_clean_up_area(struct vma_munmap_struct *vms,
> > +		struct ma_state *mas_detach, bool mm_wr_locked)
> 
> The only invocation of this function has mm_wr_locked set, is this
> parameter necessary?

I'll remove this, and replace the in-function pass-through with the
constant true.
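
Roughly, the respun helper would then look something like this (a sketch
based on the v6 body above, not the actual respin):

	void vms_clean_up_area(struct vma_munmap_struct *vms,
			struct ma_state *mas_detach)
	{
		struct vm_area_struct *vma;

		if (!vms->nr_pages)
			return;

		/* mmap_lock is held for writing at the only call site */
		vms_clear_ptes(vms, mas_detach, /* mm_wr_locked = */ true);
		mas_set(mas_detach, 0);
		mas_for_each(mas_detach, vma, ULONG_MAX)
			if (vma->vm_ops && vma->vm_ops->close)
				vma->vm_ops->close(vma);
		vms->closed_vm_ops = true;
	}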

> 
> > +{
> > +	struct vm_area_struct *vma;
> > +
> > +	if (!vms->nr_pages)
> > +		return;
> > +
> > +	vms_clear_ptes(vms, mas_detach, mm_wr_locked);
> > +	mas_set(mas_detach, 0);
> > +	mas_for_each(mas_detach, vma, ULONG_MAX)
> > +		if (vma->vm_ops && vma->vm_ops->close)
> > +			vma->vm_ops->close(vma);
> > +	vms->closed_vm_ops = true;
> >  }
> >
> >  /*
> > @@ -686,7 +708,10 @@ void vms_complete_munmap_vmas(struct vma_munmap_struct *vms,
> >  	if (vms->unlock)
> >  		mmap_write_downgrade(mm);
> >
> > -	vms_complete_pte_clear(vms, mas_detach, !vms->unlock);
> > +	if (!vms->nr_pages)
> > +		return;
> > +
> > +	vms_clear_ptes(vms, mas_detach, !vms->unlock);
> >  	/* Update high watermark before we lower total_vm */
> >  	update_hiwater_vm(mm);
> >  	/* Stat accounting */
> > @@ -697,7 +722,7 @@ void vms_complete_munmap_vmas(struct vma_munmap_struct *vms,
> >  	/* Remove and clean up vmas */
> >  	mas_set(mas_detach, 0);
> >  	mas_for_each(mas_detach, vma, ULONG_MAX)
> > -		remove_vma(vma, false);
> > +		remove_vma(vma, /* = */ false, vms->closed_vm_ops);
> >
> >  	vm_unacct_memory(vms->nr_accounted);
> >  	validate_mm(mm);
> > @@ -849,13 +874,14 @@ int vms_gather_munmap_vmas(struct vma_munmap_struct *vms,
> >  	while (vma_iter_addr(vms->vmi) > vms->start)
> >  		vma_iter_prev_range(vms->vmi);
> >
> > +	vms->clear_ptes = true;
> >  	return 0;
> >
> >  userfaultfd_error:
> >  munmap_gather_failed:
> >  end_split_failed:
> >  modify_vma_failed:
> > -	abort_munmap_vmas(mas_detach);
> > +	abort_munmap_vmas(mas_detach, /* closed = */ false);
> >  start_split_failed:
> >  map_count_exceeded:
> >  	return error;
> > @@ -900,7 +926,7 @@ int do_vmi_align_munmap(struct vma_iterator *vmi, struct vm_area_struct *vma,
> >  	return 0;
> >
> >  clear_tree_failed:
> > -	abort_munmap_vmas(&mas_detach);
> > +	abort_munmap_vmas(&mas_detach, /* closed = */ false);
> >  gather_failed:
> >  	validate_mm(mm);
> >  	return error;
> > @@ -1618,17 +1644,21 @@ bool vma_wants_writenotify(struct vm_area_struct *vma, pgprot_t vm_page_prot)
> >  }
> >
> >  unsigned long count_vma_pages_range(struct mm_struct *mm,
> > -				    unsigned long addr, unsigned long end)
> > +		unsigned long addr, unsigned long end,
> > +		unsigned long *nr_accounted)
> >  {
> >  	VMA_ITERATOR(vmi, mm, addr);
> >  	struct vm_area_struct *vma;
> >  	unsigned long nr_pages = 0;
> >
> > +	*nr_accounted = 0;
> >  	for_each_vma_range(vmi, vma, end) {
> >  		unsigned long vm_start = max(addr, vma->vm_start);
> >  		unsigned long vm_end = min(end, vma->vm_end);
> >
> >  		nr_pages += PHYS_PFN(vm_end - vm_start);
> > +		if (vma->vm_flags & VM_ACCOUNT)
> > +			*nr_accounted += PHYS_PFN(vm_end - vm_start);
> 
> Nitty, but maybe:
> 		...
>   		unsigned long pages = PHYS_PFN(vm_end - vm_start);
> 
>  		nr_pages += pages;
> 		if (vma->vm_flags & VM_ACCOUNT)
> 			*nr_accounted += pages;
> 

This is a temporary state as count_vma_pages_range() is removed shortly
after.

> >  	}
> >
> >  	return nr_pages;
> > diff --git a/mm/vma.h b/mm/vma.h
> > index 6028fdf79257..756dd42a6ec4 100644
> > --- a/mm/vma.h
> > +++ b/mm/vma.h
> > @@ -48,6 +48,8 @@ struct vma_munmap_struct {
> >  	unsigned long stack_vm;
> >  	unsigned long data_vm;
> >  	bool unlock;                    /* Unlock after the munmap */
> > +	bool clear_ptes;                /* If there are outstanding PTE to be cleared */
> > +	bool closed_vm_ops;		/* call_mmap() was encountered, so vmas may be closed */
> >  };
> >
> >  #ifdef CONFIG_DEBUG_VM_MAPLE_TREE
> > @@ -95,14 +97,13 @@ static inline void init_vma_munmap(struct vma_munmap_struct *vms,
> >  		unsigned long start, unsigned long end, struct list_head *uf,
> >  		bool unlock)
> >  {
> > +	vms->mm = current->mm;
> 
> I'm guessing there's no circumstances under which we'd be looking at a
> remote mm_struct?

Yes, we always unmap things from the mm that contains them.

> 
> This does sort of beg the question as to why we're bothering to store the
> field if we can't just grab it from current->mm? Perhaps because the cache
> line for the start of vms will be populated and current's containing ->mm
> may not?

If I don't do this here, then vms_clear_ptes(),
vms_gather_munmap_vmas(), and vms_complete_munmap_vmas() will need this.
It's not critical, but it means we're doing it twice instead of putting
it in the struct that is already used in those functions.

Actually, I think I will do this so that I can reduce the cachelines of
the vma_munmap_struct by this reduction and rearranging.

I can use the vms->vma->vm_mm for vms_clear_ptes() and the gather case.
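
For example (an illustrative sketch only, assuming the mm field is dropped
from vma_munmap_struct in a later revision):

	/* in vms_clear_ptes(): */
	struct mm_struct *mm = vms->vma->vm_mm;	/* rather than vms->mm */

	lru_add_drain();
	tlb_gather_mmu(&tlb, mm);
	update_hiwater_rss(mm);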

> 
> >  	vms->vmi = vmi;
> >  	vms->vma = vma;
> >  	if (vma) {
> > -		vms->mm = vma->vm_mm;
> >  		vms->start = start;
> >  		vms->end = end;
> >  	} else {
> > -		vms->mm = NULL;
> 
> I guess as well there's no drawback to having an otherwise empty vms have a
> populated mm?

It's actually needed because we may need to downgrade the mm lock, so
having it set makes vms_complete_munmap_vmas() cleaner.

But I'm removing it, so it doesn't really matter.

> 
> >  		vms->start = vms->end = 0;
> >  	}
> >  	vms->unlock = unlock;
> > @@ -112,6 +113,8 @@ static inline void init_vma_munmap(struct vma_munmap_struct *vms,
> >  	vms->exec_vm = vms->stack_vm = vms->data_vm = 0;
> >  	vms->unmap_start = FIRST_USER_ADDRESS;
> >  	vms->unmap_end = USER_PGTABLES_CEILING;
> > +	vms->clear_ptes = false;
> > +	vms->closed_vm_ops = false;
> >  }
> >
> >  int vms_gather_munmap_vmas(struct vma_munmap_struct *vms,
> > @@ -120,18 +123,24 @@ int vms_gather_munmap_vmas(struct vma_munmap_struct *vms,
> >  void vms_complete_munmap_vmas(struct vma_munmap_struct *vms,
> >  		struct ma_state *mas_detach);
> >
> > +void vms_clean_up_area(struct vma_munmap_struct *vms,
> > +		struct ma_state *mas_detach, bool mm_wr_locked);
> > +
> >  /*
> >   * abort_munmap_vmas - Undo any munmap work and free resources
> >   *
> >   * Reattach any detached vmas and free up the maple tree used to track the vmas.
> >   */
> > -static inline void abort_munmap_vmas(struct ma_state *mas_detach)
> > +static inline void abort_munmap_vmas(struct ma_state *mas_detach, bool closed)
> >  {
> >  	struct vm_area_struct *vma;
> >
> >  	mas_set(mas_detach, 0);
> > -	mas_for_each(mas_detach, vma, ULONG_MAX)
> > +	mas_for_each(mas_detach, vma, ULONG_MAX) {
> >  		vma_mark_detached(vma, false);
> > +		if (closed && vma->vm_ops && vma->vm_ops->open)
> > +			vma->vm_ops->open(vma);
> > +	}
> 
> Hang on, I thought we eliminated this approach? OK I see you change this in
> the next commit.
> 
> Not necessarily a huge fan of having a commit in the tree that's broken for
> (hideous, asymmetric) drivers + such but I guess it's okay given we address
> it immediately and it helps document the thinking process + split up the
> code.

Yes, this is really to show where things would happen - and for this to
be an issue with bisection, one would have to test the failure paths and
hit a case where the failure would cause issues if the vma was closed
and re-opened.  I think this is sufficiently rare.

> 
> >
> >  	__mt_destroy(mas_detach->tree);
> >  }
> > @@ -145,7 +154,7 @@ int do_vmi_munmap(struct vma_iterator *vmi, struct mm_struct *mm,
> >  		  unsigned long start, size_t len, struct list_head *uf,
> >  		  bool unlock);
> >
> > -void remove_vma(struct vm_area_struct *vma, bool unreachable);
> > +void remove_vma(struct vm_area_struct *vma, bool unreachable, bool closed);
> >
> >  void unmap_region(struct ma_state *mas, struct vm_area_struct *vma,
> >  		struct vm_area_struct *prev, struct vm_area_struct *next);
> > @@ -259,7 +268,8 @@ bool vma_wants_writenotify(struct vm_area_struct *vma, pgprot_t vm_page_prot);
> >  int mm_take_all_locks(struct mm_struct *mm);
> >  void mm_drop_all_locks(struct mm_struct *mm);
> >  unsigned long count_vma_pages_range(struct mm_struct *mm,
> > -				    unsigned long addr, unsigned long end);
> > +				    unsigned long addr, unsigned long end,
> > +				    unsigned long *nr_accounted);
> >
> >  static inline bool vma_wants_manual_pte_write_upgrade(struct vm_area_struct *vma)
> >  {
> > --
> > 2.43.0
> >

Patch

diff --git a/mm/mmap.c b/mm/mmap.c
index 71b2bad717b6..6550d9470d3a 100644
--- a/mm/mmap.c
+++ b/mm/mmap.c
@@ -1373,23 +1373,19 @@  unsigned long mmap_region(struct file *file, unsigned long addr,
 	unsigned long merge_start = addr, merge_end = end;
 	bool writable_file_mapping = false;
 	pgoff_t vm_pgoff;
-	int error;
+	int error = -ENOMEM;
 	VMA_ITERATOR(vmi, mm, addr);
+	unsigned long nr_pages, nr_accounted;
 
-	/* Check against address space limit. */
-	if (!may_expand_vm(mm, vm_flags, len >> PAGE_SHIFT)) {
-		unsigned long nr_pages;
+	nr_pages = count_vma_pages_range(mm, addr, end, &nr_accounted);
 
-		/*
-		 * MAP_FIXED may remove pages of mappings that intersects with
-		 * requested mapping. Account for the pages it would unmap.
-		 */
-		nr_pages = count_vma_pages_range(mm, addr, end);
-
-		if (!may_expand_vm(mm, vm_flags,
-					(len >> PAGE_SHIFT) - nr_pages))
-			return -ENOMEM;
-	}
+	/*
+	 * Check against address space limit.
+	 * MAP_FIXED may remove pages of mappings that intersects with requested
+	 * mapping. Account for the pages it would unmap.
+	 */
+	if (!may_expand_vm(mm, vm_flags, (len >> PAGE_SHIFT) - nr_pages))
+		return -ENOMEM;
 
 	/* Find the first overlapping VMA */
 	vma = vma_find(&vmi, end);
@@ -1400,14 +1396,8 @@  unsigned long mmap_region(struct file *file, unsigned long addr,
 		mas_init(&mas_detach, &mt_detach, /* addr = */ 0);
 		/* Prepare to unmap any existing mapping in the area */
 		if (vms_gather_munmap_vmas(&vms, &mas_detach))
-			goto gather_failed;
-
-		/* Remove any existing mappings from the vma tree */
-		if (vma_iter_clear_gfp(&vmi, addr, end, GFP_KERNEL))
-			goto clear_tree_failed;
+			return -ENOMEM;
 
-		/* Unmap any existing mapping in the area */
-		vms_complete_munmap_vmas(&vms, &mas_detach);
 		next = vms.next;
 		prev = vms.prev;
 		vma = NULL;
@@ -1423,8 +1413,10 @@  unsigned long mmap_region(struct file *file, unsigned long addr,
 	 */
 	if (accountable_mapping(file, vm_flags)) {
 		charged = len >> PAGE_SHIFT;
+		charged -= nr_accounted;
 		if (security_vm_enough_memory_mm(mm, charged))
-			return -ENOMEM;
+			goto abort_munmap;
+		vms.nr_accounted = 0;
 		vm_flags |= VM_ACCOUNT;
 	}
 
@@ -1473,10 +1465,8 @@  unsigned long mmap_region(struct file *file, unsigned long addr,
 	 * not unmapped, but the maps are removed from the list.
 	 */
 	vma = vm_area_alloc(mm);
-	if (!vma) {
-		error = -ENOMEM;
+	if (!vma)
 		goto unacct_error;
-	}
 
 	vma_iter_config(&vmi, addr, end);
 	vma_set_range(vma, addr, end, pgoff);
@@ -1485,6 +1475,11 @@  unsigned long mmap_region(struct file *file, unsigned long addr,
 
 	if (file) {
 		vma->vm_file = get_file(file);
+		/*
+		 * call_mmap() may map PTE, so ensure there are no existing PTEs
+		 * call the vm_ops close function if one exists.
+		 */
+		vms_clean_up_area(&vms, &mas_detach, true);
 		error = call_mmap(file, vma);
 		if (error)
 			goto unmap_and_free_vma;
@@ -1575,6 +1570,9 @@  unsigned long mmap_region(struct file *file, unsigned long addr,
 expanded:
 	perf_event_mmap(vma);
 
+	/* Unmap any existing mapping in the area */
+	vms_complete_munmap_vmas(&vms, &mas_detach);
+
 	vm_stat_account(mm, vm_flags, len >> PAGE_SHIFT);
 	if (vm_flags & VM_LOCKED) {
 		if ((vm_flags & VM_SPECIAL) || vma_is_dax(vma) ||
@@ -1603,7 +1601,7 @@  unsigned long mmap_region(struct file *file, unsigned long addr,
 	return addr;
 
 close_and_free_vma:
-	if (file && vma->vm_ops && vma->vm_ops->close)
+	if (file && !vms.closed_vm_ops && vma->vm_ops && vma->vm_ops->close)
 		vma->vm_ops->close(vma);
 
 	if (file || vma->vm_file) {
@@ -1622,14 +1620,12 @@  unsigned long mmap_region(struct file *file, unsigned long addr,
 unacct_error:
 	if (charged)
 		vm_unacct_memory(charged);
-	validate_mm(mm);
-	return error;
 
-clear_tree_failed:
-	abort_munmap_vmas(&mas_detach);
-gather_failed:
+abort_munmap:
+	if (vms.nr_pages)
+		abort_munmap_vmas(&mas_detach, vms.closed_vm_ops);
 	validate_mm(mm);
-	return -ENOMEM;
+	return error;
 }
 
 static int __vm_munmap(unsigned long start, size_t len, bool unlock)
@@ -1959,7 +1955,7 @@  void exit_mmap(struct mm_struct *mm)
 	do {
 		if (vma->vm_flags & VM_ACCOUNT)
 			nr_accounted += vma_pages(vma);
-		remove_vma(vma, true);
+		remove_vma(vma, /* unreachable = */ true, /* closed = */ false);
 		count++;
 		cond_resched();
 		vma = vma_next(&vmi);
diff --git a/mm/vma.c b/mm/vma.c
index 7104c2c080bb..5b33f7460ab7 100644
--- a/mm/vma.c
+++ b/mm/vma.c
@@ -136,10 +136,10 @@  can_vma_merge_after(struct vm_area_struct *vma, unsigned long vm_flags,
 /*
  * Close a vm structure and free it.
  */
-void remove_vma(struct vm_area_struct *vma, bool unreachable)
+void remove_vma(struct vm_area_struct *vma, bool unreachable, bool closed)
 {
 	might_sleep();
-	if (vma->vm_ops && vma->vm_ops->close)
+	if (!closed && vma->vm_ops && vma->vm_ops->close)
 		vma->vm_ops->close(vma);
 	if (vma->vm_file)
 		fput(vma->vm_file);
@@ -521,7 +521,6 @@  int vma_expand(struct vma_iterator *vmi, struct vm_area_struct *vma,
 	vma_iter_store(vmi, vma);
 
 	vma_complete(&vp, vmi, vma->vm_mm);
-	validate_mm(vma->vm_mm);
 	return 0;
 
 nomem:
@@ -645,11 +644,14 @@  void vma_complete(struct vma_prepare *vp,
 		uprobe_mmap(vp->insert);
 }
 
-static void vms_complete_pte_clear(struct vma_munmap_struct *vms,
-		struct ma_state *mas_detach, bool mm_wr_locked)
+static inline void vms_clear_ptes(struct vma_munmap_struct *vms,
+		    struct ma_state *mas_detach, bool mm_wr_locked)
 {
 	struct mmu_gather tlb;
 
+	if (!vms->clear_ptes) /* Nothing to do */
+		return;
+
 	/*
 	 * We can free page tables without write-locking mmap_lock because VMAs
 	 * were isolated before we downgraded mmap_lock.
@@ -658,11 +660,31 @@  static void vms_complete_pte_clear(struct vma_munmap_struct *vms,
 	lru_add_drain();
 	tlb_gather_mmu(&tlb, vms->mm);
 	update_hiwater_rss(vms->mm);
-	unmap_vmas(&tlb, mas_detach, vms->vma, vms->start, vms->end, vms->vma_count, mm_wr_locked);
+	unmap_vmas(&tlb, mas_detach, vms->vma, vms->start, vms->end,
+		   vms->vma_count, mm_wr_locked);
+
 	mas_set(mas_detach, 1);
 	/* start and end may be different if there is no prev or next vma. */
-	free_pgtables(&tlb, mas_detach, vms->vma, vms->unmap_start, vms->unmap_end, mm_wr_locked);
+	free_pgtables(&tlb, mas_detach, vms->vma, vms->unmap_start,
+		      vms->unmap_end, mm_wr_locked);
 	tlb_finish_mmu(&tlb);
+	vms->clear_ptes = false;
+}
+
+void vms_clean_up_area(struct vma_munmap_struct *vms,
+		struct ma_state *mas_detach, bool mm_wr_locked)
+{
+	struct vm_area_struct *vma;
+
+	if (!vms->nr_pages)
+		return;
+
+	vms_clear_ptes(vms, mas_detach, mm_wr_locked);
+	mas_set(mas_detach, 0);
+	mas_for_each(mas_detach, vma, ULONG_MAX)
+		if (vma->vm_ops && vma->vm_ops->close)
+			vma->vm_ops->close(vma);
+	vms->closed_vm_ops = true;
 }
 
 /*
@@ -686,7 +708,10 @@  void vms_complete_munmap_vmas(struct vma_munmap_struct *vms,
 	if (vms->unlock)
 		mmap_write_downgrade(mm);
 
-	vms_complete_pte_clear(vms, mas_detach, !vms->unlock);
+	if (!vms->nr_pages)
+		return;
+
+	vms_clear_ptes(vms, mas_detach, !vms->unlock);
 	/* Update high watermark before we lower total_vm */
 	update_hiwater_vm(mm);
 	/* Stat accounting */
@@ -697,7 +722,7 @@  void vms_complete_munmap_vmas(struct vma_munmap_struct *vms,
 	/* Remove and clean up vmas */
 	mas_set(mas_detach, 0);
 	mas_for_each(mas_detach, vma, ULONG_MAX)
-		remove_vma(vma, false);
+		remove_vma(vma, /* = */ false, vms->closed_vm_ops);
 
 	vm_unacct_memory(vms->nr_accounted);
 	validate_mm(mm);
@@ -849,13 +874,14 @@  int vms_gather_munmap_vmas(struct vma_munmap_struct *vms,
 	while (vma_iter_addr(vms->vmi) > vms->start)
 		vma_iter_prev_range(vms->vmi);
 
+	vms->clear_ptes = true;
 	return 0;
 
 userfaultfd_error:
 munmap_gather_failed:
 end_split_failed:
 modify_vma_failed:
-	abort_munmap_vmas(mas_detach);
+	abort_munmap_vmas(mas_detach, /* closed = */ false);
 start_split_failed:
 map_count_exceeded:
 	return error;
@@ -900,7 +926,7 @@  int do_vmi_align_munmap(struct vma_iterator *vmi, struct vm_area_struct *vma,
 	return 0;
 
 clear_tree_failed:
-	abort_munmap_vmas(&mas_detach);
+	abort_munmap_vmas(&mas_detach, /* closed = */ false);
 gather_failed:
 	validate_mm(mm);
 	return error;
@@ -1618,17 +1644,21 @@  bool vma_wants_writenotify(struct vm_area_struct *vma, pgprot_t vm_page_prot)
 }
 
 unsigned long count_vma_pages_range(struct mm_struct *mm,
-				    unsigned long addr, unsigned long end)
+		unsigned long addr, unsigned long end,
+		unsigned long *nr_accounted)
 {
 	VMA_ITERATOR(vmi, mm, addr);
 	struct vm_area_struct *vma;
 	unsigned long nr_pages = 0;
 
+	*nr_accounted = 0;
 	for_each_vma_range(vmi, vma, end) {
 		unsigned long vm_start = max(addr, vma->vm_start);
 		unsigned long vm_end = min(end, vma->vm_end);
 
 		nr_pages += PHYS_PFN(vm_end - vm_start);
+		if (vma->vm_flags & VM_ACCOUNT)
+			*nr_accounted += PHYS_PFN(vm_end - vm_start);
 	}
 
 	return nr_pages;
diff --git a/mm/vma.h b/mm/vma.h
index 6028fdf79257..756dd42a6ec4 100644
--- a/mm/vma.h
+++ b/mm/vma.h
@@ -48,6 +48,8 @@  struct vma_munmap_struct {
 	unsigned long stack_vm;
 	unsigned long data_vm;
 	bool unlock;                    /* Unlock after the munmap */
+	bool clear_ptes;                /* If there are outstanding PTE to be cleared */
+	bool closed_vm_ops;		/* call_mmap() was encountered, so vmas may be closed */
 };
 
 #ifdef CONFIG_DEBUG_VM_MAPLE_TREE
@@ -95,14 +97,13 @@  static inline void init_vma_munmap(struct vma_munmap_struct *vms,
 		unsigned long start, unsigned long end, struct list_head *uf,
 		bool unlock)
 {
+	vms->mm = current->mm;
 	vms->vmi = vmi;
 	vms->vma = vma;
 	if (vma) {
-		vms->mm = vma->vm_mm;
 		vms->start = start;
 		vms->end = end;
 	} else {
-		vms->mm = NULL;
 		vms->start = vms->end = 0;
 	}
 	vms->unlock = unlock;
@@ -112,6 +113,8 @@  static inline void init_vma_munmap(struct vma_munmap_struct *vms,
 	vms->exec_vm = vms->stack_vm = vms->data_vm = 0;
 	vms->unmap_start = FIRST_USER_ADDRESS;
 	vms->unmap_end = USER_PGTABLES_CEILING;
+	vms->clear_ptes = false;
+	vms->closed_vm_ops = false;
 }
 
 int vms_gather_munmap_vmas(struct vma_munmap_struct *vms,
@@ -120,18 +123,24 @@  int vms_gather_munmap_vmas(struct vma_munmap_struct *vms,
 void vms_complete_munmap_vmas(struct vma_munmap_struct *vms,
 		struct ma_state *mas_detach);
 
+void vms_clean_up_area(struct vma_munmap_struct *vms,
+		struct ma_state *mas_detach, bool mm_wr_locked);
+
 /*
  * abort_munmap_vmas - Undo any munmap work and free resources
  *
  * Reattach any detached vmas and free up the maple tree used to track the vmas.
  */
-static inline void abort_munmap_vmas(struct ma_state *mas_detach)
+static inline void abort_munmap_vmas(struct ma_state *mas_detach, bool closed)
 {
 	struct vm_area_struct *vma;
 
 	mas_set(mas_detach, 0);
-	mas_for_each(mas_detach, vma, ULONG_MAX)
+	mas_for_each(mas_detach, vma, ULONG_MAX) {
 		vma_mark_detached(vma, false);
+		if (closed && vma->vm_ops && vma->vm_ops->open)
+			vma->vm_ops->open(vma);
+	}
 
 	__mt_destroy(mas_detach->tree);
 }
@@ -145,7 +154,7 @@  int do_vmi_munmap(struct vma_iterator *vmi, struct mm_struct *mm,
 		  unsigned long start, size_t len, struct list_head *uf,
 		  bool unlock);
 
-void remove_vma(struct vm_area_struct *vma, bool unreachable);
+void remove_vma(struct vm_area_struct *vma, bool unreachable, bool closed);
 
 void unmap_region(struct ma_state *mas, struct vm_area_struct *vma,
 		struct vm_area_struct *prev, struct vm_area_struct *next);
@@ -259,7 +268,8 @@  bool vma_wants_writenotify(struct vm_area_struct *vma, pgprot_t vm_page_prot);
 int mm_take_all_locks(struct mm_struct *mm);
 void mm_drop_all_locks(struct mm_struct *mm);
 unsigned long count_vma_pages_range(struct mm_struct *mm,
-				    unsigned long addr, unsigned long end);
+				    unsigned long addr, unsigned long end,
+				    unsigned long *nr_accounted);
 
 static inline bool vma_wants_manual_pte_write_upgrade(struct vm_area_struct *vma)
 {