diff mbox series

[hotfix,6.12,7/8] mm: refactor __mmap_region()

Message ID 125b6ebddc7ae8790b8b10b47906c2d39e68f3d9.1729628198.git.lorenzo.stoakes@oracle.com (mailing list archive)
State New
Headers show
Series fix error handling in mmap_region() and refactor | expand

Commit Message

Lorenzo Stoakes Oct. 22, 2024, 8:40 p.m. UTC
We have seen bugs and resource leaks arise from the complexity of the
__mmap_region() function. This, and the generally deeply fragile error
handling logic and complexity which makes understanding the function
difficult make it highly desirable to refactor it into something readable.

Achieve this by separating the function into smaller logical parts which
are easier to understand and follow, and which importantly very
significantly simplify the error handling.

Note that we now call vms_abort_munmap_vmas() in more error paths than we
used to, however in cases where no abort need occur, vms->nr_pages will be
equal to zero and we simply exit this function without doing more than we
would have done previously.

Importantly, the invocation of the driver mmap hook via mmap_file() now has
very simple and obvious handling (this was previously the most problematic
part of the mmap() operation).

Signed-off-by: Lorenzo Stoakes <lorenzo.stoakes@oracle.com>
---
 mm/vma.c | 380 +++++++++++++++++++++++++++++++++++--------------------
 1 file changed, 240 insertions(+), 140 deletions(-)

--
2.47.0

Comments

Vlastimil Babka Oct. 23, 2024, 2:38 p.m. UTC | #1
On 10/22/24 22:40, Lorenzo Stoakes wrote:
> We have seen bugs and resource leaks arise from the complexity of the
> __mmap_region() function. This, and the generally deeply fragile error
> handling logic and complexity which makes understanding the function
> difficult make it highly desirable to refactor it into something readable.
> 
> Achieve this by separating the function into smaller logical parts which
> are easier to understand and follow, and which importantly very
> significantly simplify the error handling.
> 
> Note that we now call vms_abort_munmap_vmas() in more error paths than we
> used to, however in cases where no abort need occur, vms->nr_pages will be
> equal to zero and we simply exit this function without doing more than we
> would have done previously.
> 
> Importantly, the invocation of the driver mmap hook via mmap_file() now has
> very simple and obvious handling (this was previously the most problematic
> part of the mmap() operation).
> 
> Signed-off-by: Lorenzo Stoakes <lorenzo.stoakes@oracle.com>
> ---
>  mm/vma.c | 380 +++++++++++++++++++++++++++++++++++--------------------
>  1 file changed, 240 insertions(+), 140 deletions(-)
> 
> diff --git a/mm/vma.c b/mm/vma.c
> index 7617f9d50d62..a271e2b406ab 100644
> --- a/mm/vma.c
> +++ b/mm/vma.c
> @@ -7,6 +7,31 @@
>  #include "vma_internal.h"
>  #include "vma.h"
> 
> +struct mmap_state {
> +	struct mm_struct *mm;
> +	struct vma_iterator *vmi;
> +	struct vma_merge_struct *vmg;
> +	struct list_head *uf;
> +
> +	struct vma_munmap_struct vms;
> +	struct ma_state mas_detach;
> +	struct maple_tree mt_detach;
> +
> +	unsigned long flags;
> +	unsigned long pglen;
> +	unsigned long charged;
> +};
> +
> +#define MMAP_STATE(name, mm_, vmi_, vmg_, uf_, flags_, len_)	\
> +	struct mmap_state name = {				\
> +		.mm = mm_,					\
> +		.vmi = vmi_,					\
> +		.vmg = vmg_,					\
> +		.uf = uf_,					\
> +		.flags = flags_,				\
> +		.pglen = PHYS_PFN(len_),			\
> +	}
> +
>  static inline bool is_mergeable_vma(struct vma_merge_struct *vmg, bool merge_next)
>  {
>  	struct vm_area_struct *vma = merge_next ? vmg->next : vmg->prev;
> @@ -2169,189 +2194,247 @@ static void vms_abort_munmap_vmas(struct vma_munmap_struct *vms,
>  	vms_complete_munmap_vmas(vms, mas_detach);
>  }
> 
> -unsigned long __mmap_region(struct file *file, unsigned long addr,
> -		unsigned long len, vm_flags_t vm_flags, unsigned long pgoff,
> -		struct list_head *uf)
> +/*
> + * __mmap_prepare() - Prepare to gather any overlapping VMAs that need to be
> + *                    unmapped once the map operation is completed, check limits,
> + *                    account mapping and clean up any pre-existing VMAs.
> + *
> + * @map: Mapping state.
> + *
> + * Returns: 0 on success, error code otherwise.
> + */
> +static int __mmap_prepare(struct mmap_state *map)
>  {
> -	struct mm_struct *mm = current->mm;
> -	struct vm_area_struct *vma = NULL;
> -	pgoff_t pglen = PHYS_PFN(len);
> -	unsigned long charged = 0;
> -	struct vma_munmap_struct vms;
> -	struct ma_state mas_detach;
> -	struct maple_tree mt_detach;
> -	unsigned long end = addr + len;
>  	int error;
> -	VMA_ITERATOR(vmi, mm, addr);
> -	VMG_STATE(vmg, mm, &vmi, addr, end, vm_flags, pgoff);
> -
> -	vmg.file = file;
> -	/* Find the first overlapping VMA */
> -	vma = vma_find(&vmi, end);
> -	init_vma_munmap(&vms, &vmi, vma, addr, end, uf, /* unlock = */ false);
> -	if (vma) {
> -		mt_init_flags(&mt_detach, vmi.mas.tree->ma_flags & MT_FLAGS_LOCK_MASK);
> -		mt_on_stack(mt_detach);
> -		mas_init(&mas_detach, &mt_detach, /* addr = */ 0);
> +	struct vma_iterator *vmi = map->vmi;
> +	struct vma_merge_struct *vmg = map->vmg;
> +	struct vma_munmap_struct *vms = &map->vms;
> +
> +	/* Find the first overlapping VMA and initialise unmap state. */
> +	vms->vma = vma_find(vmi, vmg->end);
> +	init_vma_munmap(vms, vmi, vms->vma, vmg->start, vmg->end, map->uf,
> +			/* unlock = */ false);
> +
> +	/* OK, we have overlapping VMAs - prepare to unmap them. */
> +	if (vms->vma) {
> +		mt_init_flags(&map->mt_detach, vmi->mas.tree->ma_flags & MT_FLAGS_LOCK_MASK);
> +		mt_on_stack(map->mt_detach);
> +		mas_init(&map->mas_detach, &map->mt_detach, /* addr = */ 0);
>  		/* Prepare to unmap any existing mapping in the area */
> -		error = vms_gather_munmap_vmas(&vms, &mas_detach);
> +		error = vms_gather_munmap_vmas(vms, &map->mas_detach);
>  		if (error)
> -			goto gather_failed;
> +			return error;

So this assumes vms_abort_munmap_vmas() will rely on the "vms->nr_pages will
be equal to zero" mentioned in commit log. But AFAICS
vms_gather_munmap_vmas() can fail in Nth iteration of its
for_each_vma_range() after some iterations already increased nr_pages and it
will do a reattach_vmas() and return the error and we just return the error
here.
I think either here or maybe in vms_gather_munmap_vmas() itself a reset of
vms->nr_pages to zero on error should happen for the vms_abort_munmap_vmas()
to be a no-op?

> 
> -		vmg.next = vms.next;
> -		vmg.prev = vms.prev;
> -		vma = NULL;
> +		vmg->next = vms->next;
> +		vmg->prev = vms->prev;
>  	} else {
> -		vmg.next = vma_iter_next_rewind(&vmi, &vmg.prev);
> +		vmg->next = vma_iter_next_rewind(vmi, &vmg->prev);
>  	}
> 
>  	/* Check against address space limit. */
> -	if (!may_expand_vm(mm, vm_flags, pglen - vms.nr_pages)) {
> -		error = -ENOMEM;
> -		goto abort_munmap;
> -	}
> +	if (!may_expand_vm(map->mm, map->flags, map->pglen - vms->nr_pages))
> +		return -ENOMEM;
> 
> -	/*
> -	 * Private writable mapping: check memory availability
> -	 */
> -	if (accountable_mapping(file, vm_flags)) {
> -		charged = pglen;
> -		charged -= vms.nr_accounted;
> -		if (charged) {
> -			error = security_vm_enough_memory_mm(mm, charged);
> +	/* Private writable mapping: check memory availability. */
> +	if (accountable_mapping(vmg->file, map->flags)) {
> +		map->charged = map->pglen;
> +		map->charged -= vms->nr_accounted;
> +		if (map->charged) {
> +			error = security_vm_enough_memory_mm(map->mm, map->charged);
>  			if (error)
> -				goto abort_munmap;
> +				return error;
>  		}
> 
> -		vms.nr_accounted = 0;
> -		vm_flags |= VM_ACCOUNT;
> -		vmg.flags = vm_flags;
> +		vms->nr_accounted = 0;
> +		map->flags |= VM_ACCOUNT;
>  	}
> 
>  	/*
> -	 * clear PTEs while the vma is still in the tree so that rmap
> +	 * Clear PTEs while the vma is still in the tree so that rmap
>  	 * cannot race with the freeing later in the truncate scenario.
>  	 * This is also needed for mmap_file(), which is why vm_ops
>  	 * close function is called.
>  	 */
> -	vms_clean_up_area(&vms, &mas_detach);
> -	vma = vma_merge_new_range(&vmg);
> -	if (vma)
> -		goto expanded;
> +	vms_clean_up_area(vms, &map->mas_detach);
> +
> +	return 0;
> +}
> +
> +static int __mmap_new_file_vma(struct mmap_state *map, struct vm_area_struct *vma,
> +			       struct vm_area_struct **mergep)
> +{
> +	struct vma_iterator *vmi = map->vmi;
> +	struct vma_merge_struct *vmg = map->vmg;
> +	int error;
> +
> +	vma->vm_file = get_file(vmg->file);
> +	error = mmap_file(vma->vm_file, vma);
> +	if (error) {
> +		fput(vma->vm_file);
> +		vma->vm_file = NULL;
> +
> +		vma_iter_set(vmi, vma->vm_end);
> +		/* Undo any partial mapping done by a device driver. */
> +		unmap_region(&vmi->mas, vma, vmg->prev, vmg->next);
> +
> +		return error;
> +	}
> +
> +	/* Drivers cannot alter the address of the VMA. */
> +	WARN_ON_ONCE(vmg->start != vma->vm_start);
> +	/*
> +	 * Drivers should not permit writability when previously it was
> +	 * disallowed.
> +	 */
> +	VM_WARN_ON_ONCE(map->flags != vma->vm_flags &&
> +			!(map->flags & VM_MAYWRITE) &&
> +			(vma->vm_flags & VM_MAYWRITE));
> +
> +	vma_iter_config(vmi, vmg->start, vmg->end);
> +	/*
> +	 * If flags changed after mmap_file(), we should try merge
> +	 * vma again as we may succeed this time.
> +	 */
> +	if (unlikely(map->flags != vma->vm_flags && vmg->prev)) {
> +		struct vm_area_struct *merge;
> +
> +		vmg->flags = vma->vm_flags;
> +		/* If this fails, state is reset ready for a reattempt. */
> +		merge = vma_merge_new_range(vmg);
> +
> +		if (merge) {
> +			/*
> +			 * ->mmap() can change vma->vm_file and fput
> +			 * the original file. So fput the vma->vm_file
> +			 * here or we would add an extra fput for file
> +			 * and cause general protection fault
> +			 * ultimately.
> +			 */
> +			fput(vma->vm_file);
> +			vm_area_free(vma);

This frees the vma.

> +			vma_iter_free(vmi);
> +			*mergep = merge;
> +		} else {
> +			vma_iter_config(vmi, vmg->start, vmg->end);
> +		}
> +	}
> +
> +	map->flags = vma->vm_flags;

So this is use-after-free.

Maybe pass only a single struct vm_area_struct **vmap parameter to this
function, and in case of merge, change both vma and *vmap to it?.

Although I can see it's all moot after 8/8. Still let's not introduce a
temporary UAF step.

> +	return 0;
> +}
> +
> +/*
> + * __mmap_new_vma() - Allocate a new VMA for the region, as merging was not
> + *                    possible.
> + *
> + *                    An exception to this is if the mapping is file-backed, and
> + *                    the underlying driver changes the VMA flags, permitting a
> + *                    subsequent merge of the VMA, in which case the returned
> + *                    VMA is one that was merged on a second attempt.
> + *
> + * @map:  Mapping state.
> + * @vmap: Output pointer for the new VMA.
> + *
> + * Returns: Zero on success, or an error.
> + */
> +static int __mmap_new_vma(struct mmap_state *map, struct vm_area_struct **vmap)
> +{
> +	struct vma_iterator *vmi = map->vmi;
> +	struct vma_merge_struct *vmg = map->vmg;
> +	struct vm_area_struct *merge = NULL;
> +	int error = 0;
> +	struct vm_area_struct *vma;
> +
>  	/*
>  	 * Determine the object being mapped and call the appropriate
>  	 * specific mapper. the address has already been validated, but
>  	 * not unmapped, but the maps are removed from the list.
>  	 */
> -	vma = vm_area_alloc(mm);
> -	if (!vma) {
> -		error = -ENOMEM;
> -		goto unacct_error;
> -	}
> +	vma = vm_area_alloc(map->mm);
> +	if (!vma)
> +		return -ENOMEM;
> 
> -	vma_iter_config(&vmi, addr, end);
> -	vma_set_range(vma, addr, end, pgoff);
> -	vm_flags_init(vma, vm_flags);
> -	vma->vm_page_prot = vm_get_page_prot(vm_flags);
> +	vma_iter_config(vmi, vmg->start, vmg->end);
> +	vma_set_range(vma, vmg->start, vmg->end, vmg->pgoff);
> +	vm_flags_init(vma, map->flags);
> +	vma->vm_page_prot = vm_get_page_prot(map->flags);
> 
> -	if (vma_iter_prealloc(&vmi, vma)) {
> +	if (vma_iter_prealloc(vmi, vma)) {
>  		error = -ENOMEM;
>  		goto free_vma;
>  	}
> 
> -	if (file) {
> -		vma->vm_file = get_file(file);
> -		error = mmap_file(file, vma);
> -		if (error)
> -			goto unmap_and_free_file_vma;
> -
> -		/* Drivers cannot alter the address of the VMA. */
> -		WARN_ON_ONCE(addr != vma->vm_start);
> -		/*
> -		 * Drivers should not permit writability when previously it was
> -		 * disallowed.
> -		 */
> -		VM_WARN_ON_ONCE(vm_flags != vma->vm_flags &&
> -				!(vm_flags & VM_MAYWRITE) &&
> -				(vma->vm_flags & VM_MAYWRITE));
> -
> -		vma_iter_config(&vmi, addr, end);
> -		/*
> -		 * If vm_flags changed after mmap_file(), we should try merge
> -		 * vma again as we may succeed this time.
> -		 */
> -		if (unlikely(vm_flags != vma->vm_flags && vmg.prev)) {
> -			struct vm_area_struct *merge;
> -
> -			vmg.flags = vma->vm_flags;
> -			/* If this fails, state is reset ready for a reattempt. */
> -			merge = vma_merge_new_range(&vmg);
> -
> -			if (merge) {
> -				/*
> -				 * ->mmap() can change vma->vm_file and fput
> -				 * the original file. So fput the vma->vm_file
> -				 * here or we would add an extra fput for file
> -				 * and cause general protection fault
> -				 * ultimately.
> -				 */
> -				fput(vma->vm_file);
> -				vm_area_free(vma);
> -				vma_iter_free(&vmi);
> -				vma = merge;
> -				/* Update vm_flags to pick up the change. */
> -				vm_flags = vma->vm_flags;
> -				goto file_expanded;
> -			}
> -			vma_iter_config(&vmi, addr, end);
> -		}
> -
> -		vm_flags = vma->vm_flags;
> -	} else if (vm_flags & VM_SHARED) {
> +	if (vmg->file)
> +		error = __mmap_new_file_vma(map, vma, &merge);
> +	else if (map->flags & VM_SHARED)
>  		error = shmem_zero_setup(vma);
> -		if (error)
> -			goto free_iter_vma;
> -	} else {
> +	else
>  		vma_set_anonymous(vma);
> -	}
> +
> +	if (error)
> +		goto free_iter_vma;
> +
> +	if (merge)
> +		goto file_expanded;
> 
>  #ifdef CONFIG_SPARC64
>  	/* TODO: Fix SPARC ADI! */
> -	WARN_ON_ONCE(!arch_validate_flags(vm_flags));
> +	WARN_ON_ONCE(!arch_validate_flags(map->flags));
>  #endif
> 
>  	/* Lock the VMA since it is modified after insertion into VMA tree */
>  	vma_start_write(vma);
> -	vma_iter_store(&vmi, vma);
> -	mm->map_count++;
> +	vma_iter_store(vmi, vma);
> +	map->mm->map_count++;
>  	vma_link_file(vma);
> 
>  	/*
>  	 * vma_merge_new_range() calls khugepaged_enter_vma() too, the below
>  	 * call covers the non-merge case.
>  	 */
> -	khugepaged_enter_vma(vma, vma->vm_flags);
> +	khugepaged_enter_vma(vma, map->flags);
> 
>  file_expanded:
> -	file = vma->vm_file;
>  	ksm_add_vma(vma);
> -expanded:
> +
> +	*vmap = vma;
> +	return 0;
> +
> +free_iter_vma:
> +	vma_iter_free(vmi);
> +free_vma:
> +	vm_area_free(vma);
> +	return error;
> +}
> +
> +/*
> + * __mmap_complete() - Unmap any VMAs we overlap, account memory mapping
> + *                     statistics, handle locking and finalise the VMA.
> + *
> + * @map: Mapping state.
> + * @vma: Merged or newly allocated VMA for the mmap()'d region.
> + */
> +static void __mmap_complete(struct mmap_state *map, struct vm_area_struct *vma)
> +{
> +	struct mm_struct *mm = map->mm;
> +	unsigned long vm_flags = vma->vm_flags;
> +
>  	perf_event_mmap(vma);
> 
> -	/* Unmap any existing mapping in the area */
> -	vms_complete_munmap_vmas(&vms, &mas_detach);
> +	/* Unmap any existing mapping in the area. */
> +	vms_complete_munmap_vmas(&map->vms, &map->mas_detach);
> 
> -	vm_stat_account(mm, vm_flags, pglen);
> +	vm_stat_account(mm, vma->vm_flags, map->pglen);
>  	if (vm_flags & VM_LOCKED) {
>  		if ((vm_flags & VM_SPECIAL) || vma_is_dax(vma) ||
>  					is_vm_hugetlb_page(vma) ||
> -					vma == get_gate_vma(current->mm))
> +					vma == get_gate_vma(mm))
>  			vm_flags_clear(vma, VM_LOCKED_MASK);
>  		else
> -			mm->locked_vm += pglen;
> +			mm->locked_vm += map->pglen;
>  	}
> 
> -	if (file)
> +	if (vma->vm_file)
>  		uprobe_mmap(vma);
> 
>  	/*
> @@ -2364,26 +2447,43 @@ unsigned long __mmap_region(struct file *file, unsigned long addr,
>  	vm_flags_set(vma, VM_SOFTDIRTY);
> 
>  	vma_set_page_prot(vma);
> +}
> 
> -	return addr;
> +unsigned long __mmap_region(struct file *file, unsigned long addr,
> +		unsigned long len, vm_flags_t vm_flags, unsigned long pgoff,
> +		struct list_head *uf)
> +{
> +	struct mm_struct *mm = current->mm;
> +	struct vm_area_struct *vma;
> +	int error;
> +	VMA_ITERATOR(vmi, mm, addr);
> +	VMG_STATE(vmg, mm, &vmi, addr, addr + len, vm_flags, pgoff);
> +	MMAP_STATE(map, mm, &vmi, &vmg, uf, vm_flags, len);
> 
> -unmap_and_free_file_vma:
> -	fput(vma->vm_file);
> -	vma->vm_file = NULL;
> +	vmg.file = file;
> 
> -	vma_iter_set(&vmi, vma->vm_end);
> -	/* Undo any partial mapping done by a device driver. */
> -	unmap_region(&vmi.mas, vma, vmg.prev, vmg.next);
> -free_iter_vma:
> -	vma_iter_free(&vmi);
> -free_vma:
> -	vm_area_free(vma);
> -unacct_error:
> -	if (charged)
> -		vm_unacct_memory(charged);
> +	error = __mmap_prepare(&map);
> +	if (error)
> +		goto abort_munmap;
> +
> +	/* Attempt to merge with adjacent VMAs... */
> +	vmg.flags = map.flags;
> +	vma = vma_merge_new_range(&vmg);
> +	if (!vma) {
> +		/* ...but if we can't, allocate a new VMA. */
> +		error = __mmap_new_vma(&map, &vma);
> +		if (error)
> +			goto unacct_error;
> +	}
> +
> +	__mmap_complete(&map, vma);
> 
> +	return addr;
> +
> +unacct_error:
> +	if (map.charged)
> +		vm_unacct_memory(map.charged);
>  abort_munmap:
> -	vms_abort_munmap_vmas(&vms, &mas_detach);
> -gather_failed:
> +	vms_abort_munmap_vmas(&map.vms, &map.mas_detach);
>  	return error;
>  }
> --
> 2.47.0
Liam R. Howlett Oct. 23, 2024, 3:21 p.m. UTC | #2
* Vlastimil Babka <vbabka@suse.cz> [241023 10:39]:
> On 10/22/24 22:40, Lorenzo Stoakes wrote:
> > We have seen bugs and resource leaks arise from the complexity of the
> > __mmap_region() function. This, and the generally deeply fragile error
> > handling logic and complexity which makes understanding the function
> > difficult make it highly desirable to refactor it into something readable.
> > 
> > Achieve this by separating the function into smaller logical parts which
> > are easier to understand and follow, and which importantly very
> > significantly simplify the error handling.
> > 
> > Note that we now call vms_abort_munmap_vmas() in more error paths than we
> > used to, however in cases where no abort need occur, vms->nr_pages will be
> > equal to zero and we simply exit this function without doing more than we
> > would have done previously.
> > 
> > Importantly, the invocation of the driver mmap hook via mmap_file() now has
> > very simple and obvious handling (this was previously the most problematic
> > part of the mmap() operation).
> > 
> > Signed-off-by: Lorenzo Stoakes <lorenzo.stoakes@oracle.com>
> > ---
> >  mm/vma.c | 380 +++++++++++++++++++++++++++++++++++--------------------
> >  1 file changed, 240 insertions(+), 140 deletions(-)
> > 
> > diff --git a/mm/vma.c b/mm/vma.c
> > index 7617f9d50d62..a271e2b406ab 100644
> > --- a/mm/vma.c
> > +++ b/mm/vma.c
> > @@ -7,6 +7,31 @@
> >  #include "vma_internal.h"
> >  #include "vma.h"
> > 
> > +struct mmap_state {
> > +	struct mm_struct *mm;
> > +	struct vma_iterator *vmi;
> > +	struct vma_merge_struct *vmg;
> > +	struct list_head *uf;
> > +
> > +	struct vma_munmap_struct vms;
> > +	struct ma_state mas_detach;
> > +	struct maple_tree mt_detach;
> > +
> > +	unsigned long flags;
> > +	unsigned long pglen;
> > +	unsigned long charged;
> > +};
> > +
> > +#define MMAP_STATE(name, mm_, vmi_, vmg_, uf_, flags_, len_)	\
> > +	struct mmap_state name = {				\
> > +		.mm = mm_,					\
> > +		.vmi = vmi_,					\
> > +		.vmg = vmg_,					\
> > +		.uf = uf_,					\
> > +		.flags = flags_,				\
> > +		.pglen = PHYS_PFN(len_),			\
> > +	}
> > +
> >  static inline bool is_mergeable_vma(struct vma_merge_struct *vmg, bool merge_next)
> >  {
> >  	struct vm_area_struct *vma = merge_next ? vmg->next : vmg->prev;
> > @@ -2169,189 +2194,247 @@ static void vms_abort_munmap_vmas(struct vma_munmap_struct *vms,
> >  	vms_complete_munmap_vmas(vms, mas_detach);
> >  }
> > 
> > -unsigned long __mmap_region(struct file *file, unsigned long addr,
> > -		unsigned long len, vm_flags_t vm_flags, unsigned long pgoff,
> > -		struct list_head *uf)
> > +/*
> > + * __mmap_prepare() - Prepare to gather any overlapping VMAs that need to be
> > + *                    unmapped once the map operation is completed, check limits,
> > + *                    account mapping and clean up any pre-existing VMAs.
> > + *
> > + * @map: Mapping state.
> > + *
> > + * Returns: 0 on success, error code otherwise.
> > + */
> > +static int __mmap_prepare(struct mmap_state *map)
> >  {
> > -	struct mm_struct *mm = current->mm;
> > -	struct vm_area_struct *vma = NULL;
> > -	pgoff_t pglen = PHYS_PFN(len);
> > -	unsigned long charged = 0;
> > -	struct vma_munmap_struct vms;
> > -	struct ma_state mas_detach;
> > -	struct maple_tree mt_detach;
> > -	unsigned long end = addr + len;
> >  	int error;
> > -	VMA_ITERATOR(vmi, mm, addr);
> > -	VMG_STATE(vmg, mm, &vmi, addr, end, vm_flags, pgoff);
> > -
> > -	vmg.file = file;
> > -	/* Find the first overlapping VMA */
> > -	vma = vma_find(&vmi, end);
> > -	init_vma_munmap(&vms, &vmi, vma, addr, end, uf, /* unlock = */ false);
> > -	if (vma) {
> > -		mt_init_flags(&mt_detach, vmi.mas.tree->ma_flags & MT_FLAGS_LOCK_MASK);
> > -		mt_on_stack(mt_detach);
> > -		mas_init(&mas_detach, &mt_detach, /* addr = */ 0);
> > +	struct vma_iterator *vmi = map->vmi;
> > +	struct vma_merge_struct *vmg = map->vmg;
> > +	struct vma_munmap_struct *vms = &map->vms;
> > +
> > +	/* Find the first overlapping VMA and initialise unmap state. */
> > +	vms->vma = vma_find(vmi, vmg->end);
> > +	init_vma_munmap(vms, vmi, vms->vma, vmg->start, vmg->end, map->uf,
> > +			/* unlock = */ false);
> > +
> > +	/* OK, we have overlapping VMAs - prepare to unmap them. */
> > +	if (vms->vma) {
> > +		mt_init_flags(&map->mt_detach, vmi->mas.tree->ma_flags & MT_FLAGS_LOCK_MASK);
> > +		mt_on_stack(map->mt_detach);
> > +		mas_init(&map->mas_detach, &map->mt_detach, /* addr = */ 0);
> >  		/* Prepare to unmap any existing mapping in the area */
> > -		error = vms_gather_munmap_vmas(&vms, &mas_detach);
> > +		error = vms_gather_munmap_vmas(vms, &map->mas_detach);
> >  		if (error)
> > -			goto gather_failed;
> > +			return error;
> 
> So this assumes vms_abort_munmap_vmas() will rely on the "vms->nr_pages will
> be equal to zero" mentioned in commit log. But AFAICS
> vms_gather_munmap_vmas() can fail in Nth iteration of its
> for_each_vma_range() after some iterations already increased nr_pages and it
> will do a reattach_vmas() and return the error and we just return the error
> here.
> I think either here or maybe in vms_gather_munmap_vmas() itself a reset of
> vms->nr_pages to zero on error should happen for the vms_abort_munmap_vmas()
> to be a no-op?

Probably in reattach_vmas()?

> 
> > 
> > -		vmg.next = vms.next;
> > -		vmg.prev = vms.prev;
> > -		vma = NULL;
> > +		vmg->next = vms->next;
> > +		vmg->prev = vms->prev;
> >  	} else {
> > -		vmg.next = vma_iter_next_rewind(&vmi, &vmg.prev);
> > +		vmg->next = vma_iter_next_rewind(vmi, &vmg->prev);
> >  	}
> > 
> >  	/* Check against address space limit. */
> > -	if (!may_expand_vm(mm, vm_flags, pglen - vms.nr_pages)) {
> > -		error = -ENOMEM;
> > -		goto abort_munmap;
> > -	}
> > +	if (!may_expand_vm(map->mm, map->flags, map->pglen - vms->nr_pages))
> > +		return -ENOMEM;
> > 
> > -	/*
> > -	 * Private writable mapping: check memory availability
> > -	 */
> > -	if (accountable_mapping(file, vm_flags)) {
> > -		charged = pglen;
> > -		charged -= vms.nr_accounted;
> > -		if (charged) {
> > -			error = security_vm_enough_memory_mm(mm, charged);
> > +	/* Private writable mapping: check memory availability. */
> > +	if (accountable_mapping(vmg->file, map->flags)) {
> > +		map->charged = map->pglen;
> > +		map->charged -= vms->nr_accounted;
> > +		if (map->charged) {
> > +			error = security_vm_enough_memory_mm(map->mm, map->charged);
> >  			if (error)
> > -				goto abort_munmap;
> > +				return error;
> >  		}
> > 
> > -		vms.nr_accounted = 0;
> > -		vm_flags |= VM_ACCOUNT;
> > -		vmg.flags = vm_flags;
> > +		vms->nr_accounted = 0;
> > +		map->flags |= VM_ACCOUNT;
> >  	}
> > 
> >  	/*
> > -	 * clear PTEs while the vma is still in the tree so that rmap
> > +	 * Clear PTEs while the vma is still in the tree so that rmap
> >  	 * cannot race with the freeing later in the truncate scenario.
> >  	 * This is also needed for mmap_file(), which is why vm_ops
> >  	 * close function is called.
> >  	 */
> > -	vms_clean_up_area(&vms, &mas_detach);
> > -	vma = vma_merge_new_range(&vmg);
> > -	if (vma)
> > -		goto expanded;
> > +	vms_clean_up_area(vms, &map->mas_detach);
> > +
> > +	return 0;
> > +}
> > +
> > +static int __mmap_new_file_vma(struct mmap_state *map, struct vm_area_struct *vma,
> > +			       struct vm_area_struct **mergep)
> > +{
> > +	struct vma_iterator *vmi = map->vmi;
> > +	struct vma_merge_struct *vmg = map->vmg;
> > +	int error;
> > +
> > +	vma->vm_file = get_file(vmg->file);
> > +	error = mmap_file(vma->vm_file, vma);
> > +	if (error) {
> > +		fput(vma->vm_file);
> > +		vma->vm_file = NULL;
> > +
> > +		vma_iter_set(vmi, vma->vm_end);
> > +		/* Undo any partial mapping done by a device driver. */
> > +		unmap_region(&vmi->mas, vma, vmg->prev, vmg->next);
> > +
> > +		return error;
> > +	}
> > +
> > +	/* Drivers cannot alter the address of the VMA. */
> > +	WARN_ON_ONCE(vmg->start != vma->vm_start);
> > +	/*
> > +	 * Drivers should not permit writability when previously it was
> > +	 * disallowed.
> > +	 */
> > +	VM_WARN_ON_ONCE(map->flags != vma->vm_flags &&
> > +			!(map->flags & VM_MAYWRITE) &&
> > +			(vma->vm_flags & VM_MAYWRITE));
> > +
> > +	vma_iter_config(vmi, vmg->start, vmg->end);
> > +	/*
> > +	 * If flags changed after mmap_file(), we should try merge
> > +	 * vma again as we may succeed this time.
> > +	 */
> > +	if (unlikely(map->flags != vma->vm_flags && vmg->prev)) {
> > +		struct vm_area_struct *merge;
> > +
> > +		vmg->flags = vma->vm_flags;
> > +		/* If this fails, state is reset ready for a reattempt. */
> > +		merge = vma_merge_new_range(vmg);
> > +
> > +		if (merge) {
> > +			/*
> > +			 * ->mmap() can change vma->vm_file and fput
> > +			 * the original file. So fput the vma->vm_file
> > +			 * here or we would add an extra fput for file
> > +			 * and cause general protection fault
> > +			 * ultimately.
> > +			 */
> > +			fput(vma->vm_file);
> > +			vm_area_free(vma);
> 
> This frees the vma.
> 
> > +			vma_iter_free(vmi);
> > +			*mergep = merge;
> > +		} else {
> > +			vma_iter_config(vmi, vmg->start, vmg->end);
> > +		}
> > +	}
> > +
> > +	map->flags = vma->vm_flags;
> 
> So this is use-after-free.
> 
> Maybe pass only a single struct vm_area_struct **vmap parameter to this
> function, and in case of merge, change both vma and *vmap to it?.
> 
> Although I can see it's all moot after 8/8. Still let's not introduce a
> temporary UAF step.
> 
> > +	return 0;
> > +}
> > +
> > +/*
> > + * __mmap_new_vma() - Allocate a new VMA for the region, as merging was not
> > + *                    possible.
> > + *
> > + *                    An exception to this is if the mapping is file-backed, and
> > + *                    the underlying driver changes the VMA flags, permitting a
> > + *                    subsequent merge of the VMA, in which case the returned
> > + *                    VMA is one that was merged on a second attempt.
> > + *
> > + * @map:  Mapping state.
> > + * @vmap: Output pointer for the new VMA.
> > + *
> > + * Returns: Zero on success, or an error.
> > + */
> > +static int __mmap_new_vma(struct mmap_state *map, struct vm_area_struct **vmap)
> > +{
> > +	struct vma_iterator *vmi = map->vmi;
> > +	struct vma_merge_struct *vmg = map->vmg;
> > +	struct vm_area_struct *merge = NULL;
> > +	int error = 0;
> > +	struct vm_area_struct *vma;
> > +
> >  	/*
> >  	 * Determine the object being mapped and call the appropriate
> >  	 * specific mapper. the address has already been validated, but
> >  	 * not unmapped, but the maps are removed from the list.
> >  	 */
> > -	vma = vm_area_alloc(mm);
> > -	if (!vma) {
> > -		error = -ENOMEM;
> > -		goto unacct_error;
> > -	}
> > +	vma = vm_area_alloc(map->mm);
> > +	if (!vma)
> > +		return -ENOMEM;
> > 
> > -	vma_iter_config(&vmi, addr, end);
> > -	vma_set_range(vma, addr, end, pgoff);
> > -	vm_flags_init(vma, vm_flags);
> > -	vma->vm_page_prot = vm_get_page_prot(vm_flags);
> > +	vma_iter_config(vmi, vmg->start, vmg->end);
> > +	vma_set_range(vma, vmg->start, vmg->end, vmg->pgoff);
> > +	vm_flags_init(vma, map->flags);
> > +	vma->vm_page_prot = vm_get_page_prot(map->flags);
> > 
> > -	if (vma_iter_prealloc(&vmi, vma)) {
> > +	if (vma_iter_prealloc(vmi, vma)) {
> >  		error = -ENOMEM;
> >  		goto free_vma;
> >  	}
> > 
> > -	if (file) {
> > -		vma->vm_file = get_file(file);
> > -		error = mmap_file(file, vma);
> > -		if (error)
> > -			goto unmap_and_free_file_vma;
> > -
> > -		/* Drivers cannot alter the address of the VMA. */
> > -		WARN_ON_ONCE(addr != vma->vm_start);
> > -		/*
> > -		 * Drivers should not permit writability when previously it was
> > -		 * disallowed.
> > -		 */
> > -		VM_WARN_ON_ONCE(vm_flags != vma->vm_flags &&
> > -				!(vm_flags & VM_MAYWRITE) &&
> > -				(vma->vm_flags & VM_MAYWRITE));
> > -
> > -		vma_iter_config(&vmi, addr, end);
> > -		/*
> > -		 * If vm_flags changed after mmap_file(), we should try merge
> > -		 * vma again as we may succeed this time.
> > -		 */
> > -		if (unlikely(vm_flags != vma->vm_flags && vmg.prev)) {
> > -			struct vm_area_struct *merge;
> > -
> > -			vmg.flags = vma->vm_flags;
> > -			/* If this fails, state is reset ready for a reattempt. */
> > -			merge = vma_merge_new_range(&vmg);
> > -
> > -			if (merge) {
> > -				/*
> > -				 * ->mmap() can change vma->vm_file and fput
> > -				 * the original file. So fput the vma->vm_file
> > -				 * here or we would add an extra fput for file
> > -				 * and cause general protection fault
> > -				 * ultimately.
> > -				 */
> > -				fput(vma->vm_file);
> > -				vm_area_free(vma);
> > -				vma_iter_free(&vmi);
> > -				vma = merge;
> > -				/* Update vm_flags to pick up the change. */
> > -				vm_flags = vma->vm_flags;
> > -				goto file_expanded;
> > -			}
> > -			vma_iter_config(&vmi, addr, end);
> > -		}
> > -
> > -		vm_flags = vma->vm_flags;
> > -	} else if (vm_flags & VM_SHARED) {
> > +	if (vmg->file)
> > +		error = __mmap_new_file_vma(map, vma, &merge);
> > +	else if (map->flags & VM_SHARED)
> >  		error = shmem_zero_setup(vma);
> > -		if (error)
> > -			goto free_iter_vma;
> > -	} else {
> > +	else
> >  		vma_set_anonymous(vma);
> > -	}
> > +
> > +	if (error)
> > +		goto free_iter_vma;
> > +
> > +	if (merge)
> > +		goto file_expanded;
> > 
> >  #ifdef CONFIG_SPARC64
> >  	/* TODO: Fix SPARC ADI! */
> > -	WARN_ON_ONCE(!arch_validate_flags(vm_flags));
> > +	WARN_ON_ONCE(!arch_validate_flags(map->flags));
> >  #endif
> > 
> >  	/* Lock the VMA since it is modified after insertion into VMA tree */
> >  	vma_start_write(vma);
> > -	vma_iter_store(&vmi, vma);
> > -	mm->map_count++;
> > +	vma_iter_store(vmi, vma);
> > +	map->mm->map_count++;
> >  	vma_link_file(vma);
> > 
> >  	/*
> >  	 * vma_merge_new_range() calls khugepaged_enter_vma() too, the below
> >  	 * call covers the non-merge case.
> >  	 */
> > -	khugepaged_enter_vma(vma, vma->vm_flags);
> > +	khugepaged_enter_vma(vma, map->flags);
> > 
> >  file_expanded:
> > -	file = vma->vm_file;
> >  	ksm_add_vma(vma);
> > -expanded:
> > +
> > +	*vmap = vma;
> > +	return 0;
> > +
> > +free_iter_vma:
> > +	vma_iter_free(vmi);
> > +free_vma:
> > +	vm_area_free(vma);
> > +	return error;
> > +}
> > +
> > +/*
> > + * __mmap_complete() - Unmap any VMAs we overlap, account memory mapping
> > + *                     statistics, handle locking and finalise the VMA.
> > + *
> > + * @map: Mapping state.
> > + * @vma: Merged or newly allocated VMA for the mmap()'d region.
> > + */
> > +static void __mmap_complete(struct mmap_state *map, struct vm_area_struct *vma)
> > +{
> > +	struct mm_struct *mm = map->mm;
> > +	unsigned long vm_flags = vma->vm_flags;
> > +
> >  	perf_event_mmap(vma);
> > 
> > -	/* Unmap any existing mapping in the area */
> > -	vms_complete_munmap_vmas(&vms, &mas_detach);
> > +	/* Unmap any existing mapping in the area. */
> > +	vms_complete_munmap_vmas(&map->vms, &map->mas_detach);
> > 
> > -	vm_stat_account(mm, vm_flags, pglen);
> > +	vm_stat_account(mm, vma->vm_flags, map->pglen);
> >  	if (vm_flags & VM_LOCKED) {
> >  		if ((vm_flags & VM_SPECIAL) || vma_is_dax(vma) ||
> >  					is_vm_hugetlb_page(vma) ||
> > -					vma == get_gate_vma(current->mm))
> > +					vma == get_gate_vma(mm))
> >  			vm_flags_clear(vma, VM_LOCKED_MASK);
> >  		else
> > -			mm->locked_vm += pglen;
> > +			mm->locked_vm += map->pglen;
> >  	}
> > 
> > -	if (file)
> > +	if (vma->vm_file)
> >  		uprobe_mmap(vma);
> > 
> >  	/*
> > @@ -2364,26 +2447,43 @@ unsigned long __mmap_region(struct file *file, unsigned long addr,
> >  	vm_flags_set(vma, VM_SOFTDIRTY);
> > 
> >  	vma_set_page_prot(vma);
> > +}
> > 
> > -	return addr;
> > +unsigned long __mmap_region(struct file *file, unsigned long addr,
> > +		unsigned long len, vm_flags_t vm_flags, unsigned long pgoff,
> > +		struct list_head *uf)
> > +{
> > +	struct mm_struct *mm = current->mm;
> > +	struct vm_area_struct *vma;
> > +	int error;
> > +	VMA_ITERATOR(vmi, mm, addr);
> > +	VMG_STATE(vmg, mm, &vmi, addr, addr + len, vm_flags, pgoff);
> > +	MMAP_STATE(map, mm, &vmi, &vmg, uf, vm_flags, len);
> > 
> > -unmap_and_free_file_vma:
> > -	fput(vma->vm_file);
> > -	vma->vm_file = NULL;
> > +	vmg.file = file;
> > 
> > -	vma_iter_set(&vmi, vma->vm_end);
> > -	/* Undo any partial mapping done by a device driver. */
> > -	unmap_region(&vmi.mas, vma, vmg.prev, vmg.next);
> > -free_iter_vma:
> > -	vma_iter_free(&vmi);
> > -free_vma:
> > -	vm_area_free(vma);
> > -unacct_error:
> > -	if (charged)
> > -		vm_unacct_memory(charged);
> > +	error = __mmap_prepare(&map);
> > +	if (error)
> > +		goto abort_munmap;
> > +
> > +	/* Attempt to merge with adjacent VMAs... */
> > +	vmg.flags = map.flags;
> > +	vma = vma_merge_new_range(&vmg);
> > +	if (!vma) {
> > +		/* ...but if we can't, allocate a new VMA. */
> > +		error = __mmap_new_vma(&map, &vma);
> > +		if (error)
> > +			goto unacct_error;
> > +	}
> > +
> > +	__mmap_complete(&map, vma);
> > 
> > +	return addr;
> > +
> > +unacct_error:
> > +	if (map.charged)
> > +		vm_unacct_memory(map.charged);
> >  abort_munmap:
> > -	vms_abort_munmap_vmas(&vms, &mas_detach);
> > -gather_failed:
> > +	vms_abort_munmap_vmas(&map.vms, &map.mas_detach);
> >  	return error;
> >  }
> > --
> > 2.47.0
>
Liam R. Howlett Oct. 23, 2024, 5:19 p.m. UTC | #3
* Lorenzo Stoakes <lorenzo.stoakes@oracle.com> [241022 16:41]:
> We have seen bugs and resource leaks arise from the complexity of the
> __mmap_region() function. This, and the generally deeply fragile error
> handling logic and complexity which makes understanding the function
> difficult make it highly desirable to refactor it into something readable.
> 
> Achieve this by separating the function into smaller logical parts which
> are easier to understand and follow, and which importantly very
> significantly simplify the error handling.
> 
> Note that we now call vms_abort_munmap_vmas() in more error paths than we
> used to, however in cases where no abort need occur, vms->nr_pages will be
> equal to zero and we simply exit this function without doing more than we
> would have done previously.
> 
> Importantly, the invocation of the driver mmap hook via mmap_file() now has
> very simple and obvious handling (this was previously the most problematic
> part of the mmap() operation).
> 
> Signed-off-by: Lorenzo Stoakes <lorenzo.stoakes@oracle.com>
> ---
>  mm/vma.c | 380 +++++++++++++++++++++++++++++++++++--------------------
>  1 file changed, 240 insertions(+), 140 deletions(-)
> 
> diff --git a/mm/vma.c b/mm/vma.c
> index 7617f9d50d62..a271e2b406ab 100644
> --- a/mm/vma.c
> +++ b/mm/vma.c
> @@ -7,6 +7,31 @@
>  #include "vma_internal.h"
>  #include "vma.h"
> 
> +struct mmap_state {
> +	struct mm_struct *mm;
> +	struct vma_iterator *vmi;
> +	struct vma_merge_struct *vmg;
> +	struct list_head *uf;
> +
> +	struct vma_munmap_struct vms;
> +	struct ma_state mas_detach;
> +	struct maple_tree mt_detach;
> +
> +	unsigned long flags;
> +	unsigned long pglen;
> +	unsigned long charged;
> +};
> +
> +#define MMAP_STATE(name, mm_, vmi_, vmg_, uf_, flags_, len_)	\
> +	struct mmap_state name = {				\
> +		.mm = mm_,					\
> +		.vmi = vmi_,					\
> +		.vmg = vmg_,					\
> +		.uf = uf_,					\
> +		.flags = flags_,				\
> +		.pglen = PHYS_PFN(len_),			\
> +	}
> +
>  static inline bool is_mergeable_vma(struct vma_merge_struct *vmg, bool merge_next)
>  {
>  	struct vm_area_struct *vma = merge_next ? vmg->next : vmg->prev;
> @@ -2169,189 +2194,247 @@ static void vms_abort_munmap_vmas(struct vma_munmap_struct *vms,
>  	vms_complete_munmap_vmas(vms, mas_detach);
>  }
> 
> -unsigned long __mmap_region(struct file *file, unsigned long addr,
> -		unsigned long len, vm_flags_t vm_flags, unsigned long pgoff,
> -		struct list_head *uf)
> +/*
> + * __mmap_prepare() - Prepare to gather any overlapping VMAs that need to be
> + *                    unmapped once the map operation is completed, check limits,
> + *                    account mapping and clean up any pre-existing VMAs.
> + *

nit: formatting seems wrong here?

> + * @map: Mapping state.
> + *
> + * Returns: 0 on success, error code otherwise.
> + */
> +static int __mmap_prepare(struct mmap_state *map)
>  {
> -	struct mm_struct *mm = current->mm;
> -	struct vm_area_struct *vma = NULL;
> -	pgoff_t pglen = PHYS_PFN(len);
> -	unsigned long charged = 0;
> -	struct vma_munmap_struct vms;
> -	struct ma_state mas_detach;
> -	struct maple_tree mt_detach;
> -	unsigned long end = addr + len;
>  	int error;
> -	VMA_ITERATOR(vmi, mm, addr);
> -	VMG_STATE(vmg, mm, &vmi, addr, end, vm_flags, pgoff);
> -
> -	vmg.file = file;
> -	/* Find the first overlapping VMA */
> -	vma = vma_find(&vmi, end);
> -	init_vma_munmap(&vms, &vmi, vma, addr, end, uf, /* unlock = */ false);
> -	if (vma) {
> -		mt_init_flags(&mt_detach, vmi.mas.tree->ma_flags & MT_FLAGS_LOCK_MASK);
> -		mt_on_stack(mt_detach);
> -		mas_init(&mas_detach, &mt_detach, /* addr = */ 0);
> +	struct vma_iterator *vmi = map->vmi;
> +	struct vma_merge_struct *vmg = map->vmg;
> +	struct vma_munmap_struct *vms = &map->vms;
> +
> +	/* Find the first overlapping VMA and initialise unmap state. */
> +	vms->vma = vma_find(vmi, vmg->end);
> +	init_vma_munmap(vms, vmi, vms->vma, vmg->start, vmg->end, map->uf,
> +			/* unlock = */ false);
> +
> +	/* OK, we have overlapping VMAs - prepare to unmap them. */
> +	if (vms->vma) {
> +		mt_init_flags(&map->mt_detach, vmi->mas.tree->ma_flags & MT_FLAGS_LOCK_MASK);

Nit: line is too long.

> +		mt_on_stack(map->mt_detach);
> +		mas_init(&map->mas_detach, &map->mt_detach, /* addr = */ 0);
>  		/* Prepare to unmap any existing mapping in the area */
> -		error = vms_gather_munmap_vmas(&vms, &mas_detach);
> +		error = vms_gather_munmap_vmas(vms, &map->mas_detach);
>  		if (error)
> -			goto gather_failed;
> +			return error;

As Vlastimil pointed out, there is an issue just returning the error.

> 
> -		vmg.next = vms.next;
> -		vmg.prev = vms.prev;
> -		vma = NULL;
> +		vmg->next = vms->next;
> +		vmg->prev = vms->prev;
>  	} else {
> -		vmg.next = vma_iter_next_rewind(&vmi, &vmg.prev);
> +		vmg->next = vma_iter_next_rewind(vmi, &vmg->prev);
>  	}
> 
>  	/* Check against address space limit. */
> -	if (!may_expand_vm(mm, vm_flags, pglen - vms.nr_pages)) {
> -		error = -ENOMEM;
> -		goto abort_munmap;
> -	}
> +	if (!may_expand_vm(map->mm, map->flags, map->pglen - vms->nr_pages))
> +		return -ENOMEM;
> 
> -	/*
> -	 * Private writable mapping: check memory availability
> -	 */
> -	if (accountable_mapping(file, vm_flags)) {
> -		charged = pglen;
> -		charged -= vms.nr_accounted;
> -		if (charged) {
> -			error = security_vm_enough_memory_mm(mm, charged);
> +	/* Private writable mapping: check memory availability. */
> +	if (accountable_mapping(vmg->file, map->flags)) {
> +		map->charged = map->pglen;
> +		map->charged -= vms->nr_accounted;
> +		if (map->charged) {
> +			error = security_vm_enough_memory_mm(map->mm, map->charged);
>  			if (error)
> -				goto abort_munmap;
> +				return error;
>  		}
> 
> -		vms.nr_accounted = 0;
> -		vm_flags |= VM_ACCOUNT;
> -		vmg.flags = vm_flags;
> +		vms->nr_accounted = 0;
> +		map->flags |= VM_ACCOUNT;
>  	}
> 
>  	/*
> -	 * clear PTEs while the vma is still in the tree so that rmap
> +	 * Clear PTEs while the vma is still in the tree so that rmap
>  	 * cannot race with the freeing later in the truncate scenario.
>  	 * This is also needed for mmap_file(), which is why vm_ops
>  	 * close function is called.
>  	 */
> -	vms_clean_up_area(&vms, &mas_detach);
> -	vma = vma_merge_new_range(&vmg);
> -	if (vma)
> -		goto expanded;
> +	vms_clean_up_area(vms, &map->mas_detach);
> +
> +	return 0;
> +}
> +
> +static int __mmap_new_file_vma(struct mmap_state *map, struct vm_area_struct *vma,
> +			       struct vm_area_struct **mergep)
> +{
> +	struct vma_iterator *vmi = map->vmi;
> +	struct vma_merge_struct *vmg = map->vmg;
> +	int error;
> +
> +	vma->vm_file = get_file(vmg->file);
> +	error = mmap_file(vma->vm_file, vma);
> +	if (error) {
> +		fput(vma->vm_file);
> +		vma->vm_file = NULL;
> +
> +		vma_iter_set(vmi, vma->vm_end);
> +		/* Undo any partial mapping done by a device driver. */
> +		unmap_region(&vmi->mas, vma, vmg->prev, vmg->next);
> +
> +		return error;
> +	}
> +
> +	/* Drivers cannot alter the address of the VMA. */
> +	WARN_ON_ONCE(vmg->start != vma->vm_start);
> +	/*
> +	 * Drivers should not permit writability when previously it was
> +	 * disallowed.
> +	 */
> +	VM_WARN_ON_ONCE(map->flags != vma->vm_flags &&
> +			!(map->flags & VM_MAYWRITE) &&
> +			(vma->vm_flags & VM_MAYWRITE));
> +
> +	vma_iter_config(vmi, vmg->start, vmg->end);
> +	/*
> +	 * If flags changed after mmap_file(), we should try merge
> +	 * vma again as we may succeed this time.
> +	 */
> +	if (unlikely(map->flags != vma->vm_flags && vmg->prev)) {
> +		struct vm_area_struct *merge;
> +
> +		vmg->flags = vma->vm_flags;
> +		/* If this fails, state is reset ready for a reattempt. */
> +		merge = vma_merge_new_range(vmg);
> +
> +		if (merge) {
> +			/*
> +			 * ->mmap() can change vma->vm_file and fput
> +			 * the original file. So fput the vma->vm_file
> +			 * here or we would add an extra fput for file
> +			 * and cause general protection fault
> +			 * ultimately.
> +			 */
> +			fput(vma->vm_file);
> +			vm_area_free(vma);
> +			vma_iter_free(vmi);
> +			*mergep = merge;
> +		} else {
> +			vma_iter_config(vmi, vmg->start, vmg->end);
> +		}
> +	}
> +
> +	map->flags = vma->vm_flags;
> +	return 0;
> +}
> +
> +/*
> + * __mmap_new_vma() - Allocate a new VMA for the region, as merging was not
> + *                    possible.
> + *
> + *                    An exception to this is if the mapping is file-backed, and
> + *                    the underlying driver changes the VMA flags, permitting a
> + *                    subsequent merge of the VMA, in which case the returned
> + *                    VMA is one that was merged on a second attempt.

It seems all the descriptions have indented lines.

> + *
> + * @map:  Mapping state.
> + * @vmap: Output pointer for the new VMA.
> + *
> + * Returns: Zero on success, or an error.
> + */
> +static int __mmap_new_vma(struct mmap_state *map, struct vm_area_struct **vmap)
> +{
> +	struct vma_iterator *vmi = map->vmi;
> +	struct vma_merge_struct *vmg = map->vmg;
> +	struct vm_area_struct *merge = NULL;
> +	int error = 0;
> +	struct vm_area_struct *vma;
> +
>  	/*
>  	 * Determine the object being mapped and call the appropriate
>  	 * specific mapper. the address has already been validated, but
>  	 * not unmapped, but the maps are removed from the list.
>  	 */
> -	vma = vm_area_alloc(mm);
> -	if (!vma) {
> -		error = -ENOMEM;
> -		goto unacct_error;
> -	}
> +	vma = vm_area_alloc(map->mm);
> +	if (!vma)
> +		return -ENOMEM;
> 
> -	vma_iter_config(&vmi, addr, end);
> -	vma_set_range(vma, addr, end, pgoff);
> -	vm_flags_init(vma, vm_flags);
> -	vma->vm_page_prot = vm_get_page_prot(vm_flags);
> +	vma_iter_config(vmi, vmg->start, vmg->end);

This function is only called from __mmap_region() after an
__mmap_prepare() and vma_merge_new_range().  The state should be fine,
so maybe this could be WARN_ONs?  Although, it's probably safer to just
leave it.

> +	vma_set_range(vma, vmg->start, vmg->end, vmg->pgoff);
> +	vm_flags_init(vma, map->flags);
> +	vma->vm_page_prot = vm_get_page_prot(map->flags);
> 
> -	if (vma_iter_prealloc(&vmi, vma)) {
> +	if (vma_iter_prealloc(vmi, vma)) {
>  		error = -ENOMEM;
>  		goto free_vma;
>  	}
> 
> -	if (file) {
> -		vma->vm_file = get_file(file);
> -		error = mmap_file(file, vma);
> -		if (error)
> -			goto unmap_and_free_file_vma;
> -
> -		/* Drivers cannot alter the address of the VMA. */
> -		WARN_ON_ONCE(addr != vma->vm_start);
> -		/*
> -		 * Drivers should not permit writability when previously it was
> -		 * disallowed.
> -		 */
> -		VM_WARN_ON_ONCE(vm_flags != vma->vm_flags &&
> -				!(vm_flags & VM_MAYWRITE) &&
> -				(vma->vm_flags & VM_MAYWRITE));
> -
> -		vma_iter_config(&vmi, addr, end);
> -		/*
> -		 * If vm_flags changed after mmap_file(), we should try merge
> -		 * vma again as we may succeed this time.
> -		 */
> -		if (unlikely(vm_flags != vma->vm_flags && vmg.prev)) {
> -			struct vm_area_struct *merge;
> -
> -			vmg.flags = vma->vm_flags;
> -			/* If this fails, state is reset ready for a reattempt. */
> -			merge = vma_merge_new_range(&vmg);
> -
> -			if (merge) {
> -				/*
> -				 * ->mmap() can change vma->vm_file and fput
> -				 * the original file. So fput the vma->vm_file
> -				 * here or we would add an extra fput for file
> -				 * and cause general protection fault
> -				 * ultimately.
> -				 */
> -				fput(vma->vm_file);
> -				vm_area_free(vma);
> -				vma_iter_free(&vmi);
> -				vma = merge;
> -				/* Update vm_flags to pick up the change. */
> -				vm_flags = vma->vm_flags;
> -				goto file_expanded;
> -			}
> -			vma_iter_config(&vmi, addr, end);
> -		}
> -
> -		vm_flags = vma->vm_flags;
> -	} else if (vm_flags & VM_SHARED) {
> +	if (vmg->file)
> +		error = __mmap_new_file_vma(map, vma, &merge);
> +	else if (map->flags & VM_SHARED)
>  		error = shmem_zero_setup(vma);
> -		if (error)
> -			goto free_iter_vma;
> -	} else {
> +	else
>  		vma_set_anonymous(vma);
> -	}
> +
> +	if (error)
> +		goto free_iter_vma;
> +
> +	if (merge)
> +		goto file_expanded;
> 
>  #ifdef CONFIG_SPARC64
>  	/* TODO: Fix SPARC ADI! */
> -	WARN_ON_ONCE(!arch_validate_flags(vm_flags));
> +	WARN_ON_ONCE(!arch_validate_flags(map->flags));
>  #endif
> 
>  	/* Lock the VMA since it is modified after insertion into VMA tree */
>  	vma_start_write(vma);
> -	vma_iter_store(&vmi, vma);
> -	mm->map_count++;
> +	vma_iter_store(vmi, vma);
> +	map->mm->map_count++;
>  	vma_link_file(vma);
> 
>  	/*
>  	 * vma_merge_new_range() calls khugepaged_enter_vma() too, the below
>  	 * call covers the non-merge case.
>  	 */
> -	khugepaged_enter_vma(vma, vma->vm_flags);
> +	khugepaged_enter_vma(vma, map->flags);
> 
>  file_expanded:
> -	file = vma->vm_file;
>  	ksm_add_vma(vma);

__mmap_new_file_vma() may free the vma.  I assume this is what you
mentioned elsewhere about UAF.

> -expanded:
> +

Extra whitespace.

> +	*vmap = vma;
> +	return 0;
> +
> +free_iter_vma:
> +	vma_iter_free(vmi);
> +free_vma:
> +	vm_area_free(vma);
> +	return error;
> +}
> +
> +/*
> + * __mmap_complete() - Unmap any VMAs we overlap, account memory mapping
> + *                     statistics, handle locking and finalise the VMA.
> + *
> + * @map: Mapping state.
> + * @vma: Merged or newly allocated VMA for the mmap()'d region.
> + */
> +static void __mmap_complete(struct mmap_state *map, struct vm_area_struct *vma)
> +{
> +	struct mm_struct *mm = map->mm;
> +	unsigned long vm_flags = vma->vm_flags;
> +
>  	perf_event_mmap(vma);
> 
> -	/* Unmap any existing mapping in the area */
> -	vms_complete_munmap_vmas(&vms, &mas_detach);
> +	/* Unmap any existing mapping in the area. */
> +	vms_complete_munmap_vmas(&map->vms, &map->mas_detach);
> 
> -	vm_stat_account(mm, vm_flags, pglen);
> +	vm_stat_account(mm, vma->vm_flags, map->pglen);
>  	if (vm_flags & VM_LOCKED) {
>  		if ((vm_flags & VM_SPECIAL) || vma_is_dax(vma) ||
>  					is_vm_hugetlb_page(vma) ||
> -					vma == get_gate_vma(current->mm))
> +					vma == get_gate_vma(mm))
>  			vm_flags_clear(vma, VM_LOCKED_MASK);
>  		else
> -			mm->locked_vm += pglen;
> +			mm->locked_vm += map->pglen;
>  	}
> 
> -	if (file)
> +	if (vma->vm_file)
>  		uprobe_mmap(vma);
> 
>  	/*
> @@ -2364,26 +2447,43 @@ unsigned long __mmap_region(struct file *file, unsigned long addr,
>  	vm_flags_set(vma, VM_SOFTDIRTY);
> 
>  	vma_set_page_prot(vma);
> +}
> 
> -	return addr;
> +unsigned long __mmap_region(struct file *file, unsigned long addr,
> +		unsigned long len, vm_flags_t vm_flags, unsigned long pgoff,
> +		struct list_head *uf)
> +{
> +	struct mm_struct *mm = current->mm;
> +	struct vm_area_struct *vma;
> +	int error;
> +	VMA_ITERATOR(vmi, mm, addr);
> +	VMG_STATE(vmg, mm, &vmi, addr, addr + len, vm_flags, pgoff);
> +	MMAP_STATE(map, mm, &vmi, &vmg, uf, vm_flags, len);
> 
> -unmap_and_free_file_vma:
> -	fput(vma->vm_file);
> -	vma->vm_file = NULL;
> +	vmg.file = file;
> 
> -	vma_iter_set(&vmi, vma->vm_end);
> -	/* Undo any partial mapping done by a device driver. */
> -	unmap_region(&vmi.mas, vma, vmg.prev, vmg.next);
> -free_iter_vma:
> -	vma_iter_free(&vmi);
> -free_vma:
> -	vm_area_free(vma);
> -unacct_error:
> -	if (charged)
> -		vm_unacct_memory(charged);
> +	error = __mmap_prepare(&map);
> +	if (error)
> +		goto abort_munmap;
> +
> +	/* Attempt to merge with adjacent VMAs... */
> +	vmg.flags = map.flags;
> +	vma = vma_merge_new_range(&vmg);
> +	if (!vma) {
> +		/* ...but if we can't, allocate a new VMA. */
> +		error = __mmap_new_vma(&map, &vma);
> +		if (error)
> +			goto unacct_error;
> +	}
> +
> +	__mmap_complete(&map, vma);
> 
> +	return addr;
> +
> +unacct_error:
> +	if (map.charged)
> +		vm_unacct_memory(map.charged);

So this is effectively undoing __mmap_prepare()'s accounting. I don't
have a better label for it, but it's not obvious by the label that the
accounting was done in __mmap_prepare().

>  abort_munmap:
> -	vms_abort_munmap_vmas(&vms, &mas_detach);
> -gather_failed:
> +	vms_abort_munmap_vmas(&map.vms, &map.mas_detach);
>  	return error;
>  }
> --
> 2.47.0
Lorenzo Stoakes Oct. 23, 2024, 5:30 p.m. UTC | #4
On Wed, Oct 23, 2024 at 04:38:46PM +0200, Vlastimil Babka wrote:
> On 10/22/24 22:40, Lorenzo Stoakes wrote:
> > We have seen bugs and resource leaks arise from the complexity of the
> > __mmap_region() function. This, and the generally deeply fragile error
> > handling logic and complexity which makes understanding the function
> > difficult make it highly desirable to refactor it into something readable.
> >
> > Achieve this by separating the function into smaller logical parts which
> > are easier to understand and follow, and which importantly very
> > significantly simplify the error handling.
> >
> > Note that we now call vms_abort_munmap_vmas() in more error paths than we
> > used to, however in cases where no abort need occur, vms->nr_pages will be
> > equal to zero and we simply exit this function without doing more than we
> > would have done previously.
> >
> > Importantly, the invocation of the driver mmap hook via mmap_file() now has
> > very simple and obvious handling (this was previously the most problematic
> > part of the mmap() operation).
> >
> > Signed-off-by: Lorenzo Stoakes <lorenzo.stoakes@oracle.com>
> > ---
> >  mm/vma.c | 380 +++++++++++++++++++++++++++++++++++--------------------
> >  1 file changed, 240 insertions(+), 140 deletions(-)
> >
> > diff --git a/mm/vma.c b/mm/vma.c
> > index 7617f9d50d62..a271e2b406ab 100644
> > --- a/mm/vma.c
> > +++ b/mm/vma.c
> > @@ -7,6 +7,31 @@
> >  #include "vma_internal.h"
> >  #include "vma.h"
> >
> > +struct mmap_state {
> > +	struct mm_struct *mm;
> > +	struct vma_iterator *vmi;
> > +	struct vma_merge_struct *vmg;
> > +	struct list_head *uf;
> > +
> > +	struct vma_munmap_struct vms;
> > +	struct ma_state mas_detach;
> > +	struct maple_tree mt_detach;
> > +
> > +	unsigned long flags;
> > +	unsigned long pglen;
> > +	unsigned long charged;
> > +};
> > +
> > +#define MMAP_STATE(name, mm_, vmi_, vmg_, uf_, flags_, len_)	\
> > +	struct mmap_state name = {				\
> > +		.mm = mm_,					\
> > +		.vmi = vmi_,					\
> > +		.vmg = vmg_,					\
> > +		.uf = uf_,					\
> > +		.flags = flags_,				\
> > +		.pglen = PHYS_PFN(len_),			\
> > +	}
> > +
> >  static inline bool is_mergeable_vma(struct vma_merge_struct *vmg, bool merge_next)
> >  {
> >  	struct vm_area_struct *vma = merge_next ? vmg->next : vmg->prev;
> > @@ -2169,189 +2194,247 @@ static void vms_abort_munmap_vmas(struct vma_munmap_struct *vms,
> >  	vms_complete_munmap_vmas(vms, mas_detach);
> >  }
> >
> > -unsigned long __mmap_region(struct file *file, unsigned long addr,
> > -		unsigned long len, vm_flags_t vm_flags, unsigned long pgoff,
> > -		struct list_head *uf)
> > +/*
> > + * __mmap_prepare() - Prepare to gather any overlapping VMAs that need to be
> > + *                    unmapped once the map operation is completed, check limits,
> > + *                    account mapping and clean up any pre-existing VMAs.
> > + *
> > + * @map: Mapping state.
> > + *
> > + * Returns: 0 on success, error code otherwise.
> > + */
> > +static int __mmap_prepare(struct mmap_state *map)
> >  {
> > -	struct mm_struct *mm = current->mm;
> > -	struct vm_area_struct *vma = NULL;
> > -	pgoff_t pglen = PHYS_PFN(len);
> > -	unsigned long charged = 0;
> > -	struct vma_munmap_struct vms;
> > -	struct ma_state mas_detach;
> > -	struct maple_tree mt_detach;
> > -	unsigned long end = addr + len;
> >  	int error;
> > -	VMA_ITERATOR(vmi, mm, addr);
> > -	VMG_STATE(vmg, mm, &vmi, addr, end, vm_flags, pgoff);
> > -
> > -	vmg.file = file;
> > -	/* Find the first overlapping VMA */
> > -	vma = vma_find(&vmi, end);
> > -	init_vma_munmap(&vms, &vmi, vma, addr, end, uf, /* unlock = */ false);
> > -	if (vma) {
> > -		mt_init_flags(&mt_detach, vmi.mas.tree->ma_flags & MT_FLAGS_LOCK_MASK);
> > -		mt_on_stack(mt_detach);
> > -		mas_init(&mas_detach, &mt_detach, /* addr = */ 0);
> > +	struct vma_iterator *vmi = map->vmi;
> > +	struct vma_merge_struct *vmg = map->vmg;
> > +	struct vma_munmap_struct *vms = &map->vms;
> > +
> > +	/* Find the first overlapping VMA and initialise unmap state. */
> > +	vms->vma = vma_find(vmi, vmg->end);
> > +	init_vma_munmap(vms, vmi, vms->vma, vmg->start, vmg->end, map->uf,
> > +			/* unlock = */ false);
> > +
> > +	/* OK, we have overlapping VMAs - prepare to unmap them. */
> > +	if (vms->vma) {
> > +		mt_init_flags(&map->mt_detach, vmi->mas.tree->ma_flags & MT_FLAGS_LOCK_MASK);
> > +		mt_on_stack(map->mt_detach);
> > +		mas_init(&map->mas_detach, &map->mt_detach, /* addr = */ 0);
> >  		/* Prepare to unmap any existing mapping in the area */
> > -		error = vms_gather_munmap_vmas(&vms, &mas_detach);
> > +		error = vms_gather_munmap_vmas(vms, &map->mas_detach);
> >  		if (error)
> > -			goto gather_failed;
> > +			return error;
>
> So this assumes vms_abort_munmap_vmas() will rely on the "vms->nr_pages will
> be equal to zero" mentioned in commit log. But AFAICS
> vms_gather_munmap_vmas() can fail in Nth iteration of its
> for_each_vma_range() after some iterations already increased nr_pages and it
> will do a reattach_vmas() and return the error and we just return the error
> here.
> I think either here or maybe in vms_gather_munmap_vmas() itself a reset of
> vms->nr_pages to zero on error should happen for the vms_abort_munmap_vmas()
> to be a no-op?

Ugh yup, I had wrongly assumed this would not be the case but there we go,
makes the point as to what this whole series is about... will fix.

>
> >
> > -		vmg.next = vms.next;
> > -		vmg.prev = vms.prev;
> > -		vma = NULL;
> > +		vmg->next = vms->next;
> > +		vmg->prev = vms->prev;
> >  	} else {
> > -		vmg.next = vma_iter_next_rewind(&vmi, &vmg.prev);
> > +		vmg->next = vma_iter_next_rewind(vmi, &vmg->prev);
> >  	}
> >
> >  	/* Check against address space limit. */
> > -	if (!may_expand_vm(mm, vm_flags, pglen - vms.nr_pages)) {
> > -		error = -ENOMEM;
> > -		goto abort_munmap;
> > -	}
> > +	if (!may_expand_vm(map->mm, map->flags, map->pglen - vms->nr_pages))
> > +		return -ENOMEM;
> >
> > -	/*
> > -	 * Private writable mapping: check memory availability
> > -	 */
> > -	if (accountable_mapping(file, vm_flags)) {
> > -		charged = pglen;
> > -		charged -= vms.nr_accounted;
> > -		if (charged) {
> > -			error = security_vm_enough_memory_mm(mm, charged);
> > +	/* Private writable mapping: check memory availability. */
> > +	if (accountable_mapping(vmg->file, map->flags)) {
> > +		map->charged = map->pglen;
> > +		map->charged -= vms->nr_accounted;
> > +		if (map->charged) {
> > +			error = security_vm_enough_memory_mm(map->mm, map->charged);
> >  			if (error)
> > -				goto abort_munmap;
> > +				return error;
> >  		}
> >
> > -		vms.nr_accounted = 0;
> > -		vm_flags |= VM_ACCOUNT;
> > -		vmg.flags = vm_flags;
> > +		vms->nr_accounted = 0;
> > +		map->flags |= VM_ACCOUNT;
> >  	}
> >
> >  	/*
> > -	 * clear PTEs while the vma is still in the tree so that rmap
> > +	 * Clear PTEs while the vma is still in the tree so that rmap
> >  	 * cannot race with the freeing later in the truncate scenario.
> >  	 * This is also needed for mmap_file(), which is why vm_ops
> >  	 * close function is called.
> >  	 */
> > -	vms_clean_up_area(&vms, &mas_detach);
> > -	vma = vma_merge_new_range(&vmg);
> > -	if (vma)
> > -		goto expanded;
> > +	vms_clean_up_area(vms, &map->mas_detach);
> > +
> > +	return 0;
> > +}
> > +
> > +static int __mmap_new_file_vma(struct mmap_state *map, struct vm_area_struct *vma,
> > +			       struct vm_area_struct **mergep)
> > +{
> > +	struct vma_iterator *vmi = map->vmi;
> > +	struct vma_merge_struct *vmg = map->vmg;
> > +	int error;
> > +
> > +	vma->vm_file = get_file(vmg->file);
> > +	error = mmap_file(vma->vm_file, vma);
> > +	if (error) {
> > +		fput(vma->vm_file);
> > +		vma->vm_file = NULL;
> > +
> > +		vma_iter_set(vmi, vma->vm_end);
> > +		/* Undo any partial mapping done by a device driver. */
> > +		unmap_region(&vmi->mas, vma, vmg->prev, vmg->next);
> > +
> > +		return error;
> > +	}
> > +
> > +	/* Drivers cannot alter the address of the VMA. */
> > +	WARN_ON_ONCE(vmg->start != vma->vm_start);
> > +	/*
> > +	 * Drivers should not permit writability when previously it was
> > +	 * disallowed.
> > +	 */
> > +	VM_WARN_ON_ONCE(map->flags != vma->vm_flags &&
> > +			!(map->flags & VM_MAYWRITE) &&
> > +			(vma->vm_flags & VM_MAYWRITE));
> > +
> > +	vma_iter_config(vmi, vmg->start, vmg->end);
> > +	/*
> > +	 * If flags changed after mmap_file(), we should try merge
> > +	 * vma again as we may succeed this time.
> > +	 */
> > +	if (unlikely(map->flags != vma->vm_flags && vmg->prev)) {
> > +		struct vm_area_struct *merge;
> > +
> > +		vmg->flags = vma->vm_flags;
> > +		/* If this fails, state is reset ready for a reattempt. */
> > +		merge = vma_merge_new_range(vmg);
> > +
> > +		if (merge) {
> > +			/*
> > +			 * ->mmap() can change vma->vm_file and fput
> > +			 * the original file. So fput the vma->vm_file
> > +			 * here or we would add an extra fput for file
> > +			 * and cause general protection fault
> > +			 * ultimately.
> > +			 */
> > +			fput(vma->vm_file);
> > +			vm_area_free(vma);
>
> This frees the vma.
>
> > +			vma_iter_free(vmi);
> > +			*mergep = merge;
> > +		} else {
> > +			vma_iter_config(vmi, vmg->start, vmg->end);
> > +		}
> > +	}
> > +
> > +	map->flags = vma->vm_flags;
>
> So this is use-after-free.
>
> Maybe pass only a single struct vm_area_struct **vmap parameter to this
> function, and in case of merge, change both vma and *vmap to it?.
>
> Although I can see it's all moot after 8/8. Still let's not introduce a
> temporary UAF step.

Even more vom. Will fix. The irony is I 'fixed' this code to something
'neater' by referencing vma here, rather stupidly.

You can tell it's a rarely used path since I ran a full suite of tests and
didn't hit it.

Again, speaks to the point of this series in general...

>
> > +	return 0;
> > +}
> > +
> > +/*
> > + * __mmap_new_vma() - Allocate a new VMA for the region, as merging was not
> > + *                    possible.
> > + *
> > + *                    An exception to this is if the mapping is file-backed, and
> > + *                    the underlying driver changes the VMA flags, permitting a
> > + *                    subsequent merge of the VMA, in which case the returned
> > + *                    VMA is one that was merged on a second attempt.
> > + *
> > + * @map:  Mapping state.
> > + * @vmap: Output pointer for the new VMA.
> > + *
> > + * Returns: Zero on success, or an error.
> > + */
> > +static int __mmap_new_vma(struct mmap_state *map, struct vm_area_struct **vmap)
> > +{
> > +	struct vma_iterator *vmi = map->vmi;
> > +	struct vma_merge_struct *vmg = map->vmg;
> > +	struct vm_area_struct *merge = NULL;
> > +	int error = 0;
> > +	struct vm_area_struct *vma;
> > +
> >  	/*
> >  	 * Determine the object being mapped and call the appropriate
> >  	 * specific mapper. the address has already been validated, but
> >  	 * not unmapped, but the maps are removed from the list.
> >  	 */
> > -	vma = vm_area_alloc(mm);
> > -	if (!vma) {
> > -		error = -ENOMEM;
> > -		goto unacct_error;
> > -	}
> > +	vma = vm_area_alloc(map->mm);
> > +	if (!vma)
> > +		return -ENOMEM;
> >
> > -	vma_iter_config(&vmi, addr, end);
> > -	vma_set_range(vma, addr, end, pgoff);
> > -	vm_flags_init(vma, vm_flags);
> > -	vma->vm_page_prot = vm_get_page_prot(vm_flags);
> > +	vma_iter_config(vmi, vmg->start, vmg->end);
> > +	vma_set_range(vma, vmg->start, vmg->end, vmg->pgoff);
> > +	vm_flags_init(vma, map->flags);
> > +	vma->vm_page_prot = vm_get_page_prot(map->flags);
> >
> > -	if (vma_iter_prealloc(&vmi, vma)) {
> > +	if (vma_iter_prealloc(vmi, vma)) {
> >  		error = -ENOMEM;
> >  		goto free_vma;
> >  	}
> >
> > -	if (file) {
> > -		vma->vm_file = get_file(file);
> > -		error = mmap_file(file, vma);
> > -		if (error)
> > -			goto unmap_and_free_file_vma;
> > -
> > -		/* Drivers cannot alter the address of the VMA. */
> > -		WARN_ON_ONCE(addr != vma->vm_start);
> > -		/*
> > -		 * Drivers should not permit writability when previously it was
> > -		 * disallowed.
> > -		 */
> > -		VM_WARN_ON_ONCE(vm_flags != vma->vm_flags &&
> > -				!(vm_flags & VM_MAYWRITE) &&
> > -				(vma->vm_flags & VM_MAYWRITE));
> > -
> > -		vma_iter_config(&vmi, addr, end);
> > -		/*
> > -		 * If vm_flags changed after mmap_file(), we should try merge
> > -		 * vma again as we may succeed this time.
> > -		 */
> > -		if (unlikely(vm_flags != vma->vm_flags && vmg.prev)) {
> > -			struct vm_area_struct *merge;
> > -
> > -			vmg.flags = vma->vm_flags;
> > -			/* If this fails, state is reset ready for a reattempt. */
> > -			merge = vma_merge_new_range(&vmg);
> > -
> > -			if (merge) {
> > -				/*
> > -				 * ->mmap() can change vma->vm_file and fput
> > -				 * the original file. So fput the vma->vm_file
> > -				 * here or we would add an extra fput for file
> > -				 * and cause general protection fault
> > -				 * ultimately.
> > -				 */
> > -				fput(vma->vm_file);
> > -				vm_area_free(vma);
> > -				vma_iter_free(&vmi);
> > -				vma = merge;
> > -				/* Update vm_flags to pick up the change. */
> > -				vm_flags = vma->vm_flags;
> > -				goto file_expanded;
> > -			}
> > -			vma_iter_config(&vmi, addr, end);
> > -		}
> > -
> > -		vm_flags = vma->vm_flags;
> > -	} else if (vm_flags & VM_SHARED) {
> > +	if (vmg->file)
> > +		error = __mmap_new_file_vma(map, vma, &merge);
> > +	else if (map->flags & VM_SHARED)
> >  		error = shmem_zero_setup(vma);
> > -		if (error)
> > -			goto free_iter_vma;
> > -	} else {
> > +	else
> >  		vma_set_anonymous(vma);
> > -	}
> > +
> > +	if (error)
> > +		goto free_iter_vma;
> > +
> > +	if (merge)
> > +		goto file_expanded;
> >
> >  #ifdef CONFIG_SPARC64
> >  	/* TODO: Fix SPARC ADI! */
> > -	WARN_ON_ONCE(!arch_validate_flags(vm_flags));
> > +	WARN_ON_ONCE(!arch_validate_flags(map->flags));
> >  #endif
> >
> >  	/* Lock the VMA since it is modified after insertion into VMA tree */
> >  	vma_start_write(vma);
> > -	vma_iter_store(&vmi, vma);
> > -	mm->map_count++;
> > +	vma_iter_store(vmi, vma);
> > +	map->mm->map_count++;
> >  	vma_link_file(vma);
> >
> >  	/*
> >  	 * vma_merge_new_range() calls khugepaged_enter_vma() too, the below
> >  	 * call covers the non-merge case.
> >  	 */
> > -	khugepaged_enter_vma(vma, vma->vm_flags);
> > +	khugepaged_enter_vma(vma, map->flags);
> >
> >  file_expanded:
> > -	file = vma->vm_file;
> >  	ksm_add_vma(vma);
> > -expanded:
> > +
> > +	*vmap = vma;
> > +	return 0;
> > +
> > +free_iter_vma:
> > +	vma_iter_free(vmi);
> > +free_vma:
> > +	vm_area_free(vma);
> > +	return error;
> > +}
> > +
> > +/*
> > + * __mmap_complete() - Unmap any VMAs we overlap, account memory mapping
> > + *                     statistics, handle locking and finalise the VMA.
> > + *
> > + * @map: Mapping state.
> > + * @vma: Merged or newly allocated VMA for the mmap()'d region.
> > + */
> > +static void __mmap_complete(struct mmap_state *map, struct vm_area_struct *vma)
> > +{
> > +	struct mm_struct *mm = map->mm;
> > +	unsigned long vm_flags = vma->vm_flags;
> > +
> >  	perf_event_mmap(vma);
> >
> > -	/* Unmap any existing mapping in the area */
> > -	vms_complete_munmap_vmas(&vms, &mas_detach);
> > +	/* Unmap any existing mapping in the area. */
> > +	vms_complete_munmap_vmas(&map->vms, &map->mas_detach);
> >
> > -	vm_stat_account(mm, vm_flags, pglen);
> > +	vm_stat_account(mm, vma->vm_flags, map->pglen);
> >  	if (vm_flags & VM_LOCKED) {
> >  		if ((vm_flags & VM_SPECIAL) || vma_is_dax(vma) ||
> >  					is_vm_hugetlb_page(vma) ||
> > -					vma == get_gate_vma(current->mm))
> > +					vma == get_gate_vma(mm))
> >  			vm_flags_clear(vma, VM_LOCKED_MASK);
> >  		else
> > -			mm->locked_vm += pglen;
> > +			mm->locked_vm += map->pglen;
> >  	}
> >
> > -	if (file)
> > +	if (vma->vm_file)
> >  		uprobe_mmap(vma);
> >
> >  	/*
> > @@ -2364,26 +2447,43 @@ unsigned long __mmap_region(struct file *file, unsigned long addr,
> >  	vm_flags_set(vma, VM_SOFTDIRTY);
> >
> >  	vma_set_page_prot(vma);
> > +}
> >
> > -	return addr;
> > +unsigned long __mmap_region(struct file *file, unsigned long addr,
> > +		unsigned long len, vm_flags_t vm_flags, unsigned long pgoff,
> > +		struct list_head *uf)
> > +{
> > +	struct mm_struct *mm = current->mm;
> > +	struct vm_area_struct *vma;
> > +	int error;
> > +	VMA_ITERATOR(vmi, mm, addr);
> > +	VMG_STATE(vmg, mm, &vmi, addr, addr + len, vm_flags, pgoff);
> > +	MMAP_STATE(map, mm, &vmi, &vmg, uf, vm_flags, len);
> >
> > -unmap_and_free_file_vma:
> > -	fput(vma->vm_file);
> > -	vma->vm_file = NULL;
> > +	vmg.file = file;
> >
> > -	vma_iter_set(&vmi, vma->vm_end);
> > -	/* Undo any partial mapping done by a device driver. */
> > -	unmap_region(&vmi.mas, vma, vmg.prev, vmg.next);
> > -free_iter_vma:
> > -	vma_iter_free(&vmi);
> > -free_vma:
> > -	vm_area_free(vma);
> > -unacct_error:
> > -	if (charged)
> > -		vm_unacct_memory(charged);
> > +	error = __mmap_prepare(&map);
> > +	if (error)
> > +		goto abort_munmap;
> > +
> > +	/* Attempt to merge with adjacent VMAs... */
> > +	vmg.flags = map.flags;
> > +	vma = vma_merge_new_range(&vmg);
> > +	if (!vma) {
> > +		/* ...but if we can't, allocate a new VMA. */
> > +		error = __mmap_new_vma(&map, &vma);
> > +		if (error)
> > +			goto unacct_error;
> > +	}
> > +
> > +	__mmap_complete(&map, vma);
> >
> > +	return addr;
> > +
> > +unacct_error:
> > +	if (map.charged)
> > +		vm_unacct_memory(map.charged);
> >  abort_munmap:
> > -	vms_abort_munmap_vmas(&vms, &mas_detach);
> > -gather_failed:
> > +	vms_abort_munmap_vmas(&map.vms, &map.mas_detach);
> >  	return error;
> >  }
> > --
> > 2.47.0
>
Lorenzo Stoakes Oct. 23, 2024, 5:39 p.m. UTC | #5
On Wed, Oct 23, 2024 at 11:21:54AM -0400, Liam R. Howlett wrote:
> * Vlastimil Babka <vbabka@suse.cz> [241023 10:39]:
> > On 10/22/24 22:40, Lorenzo Stoakes wrote:
> > > We have seen bugs and resource leaks arise from the complexity of the
> > > __mmap_region() function. This, and the generally deeply fragile error
> > > handling logic and complexity which makes understanding the function
> > > difficult make it highly desirable to refactor it into something readable.
> > >
> > > Achieve this by separating the function into smaller logical parts which
> > > are easier to understand and follow, and which importantly very
> > > significantly simplify the error handling.
> > >
> > > Note that we now call vms_abort_munmap_vmas() in more error paths than we
> > > used to, however in cases where no abort need occur, vms->nr_pages will be
> > > equal to zero and we simply exit this function without doing more than we
> > > would have done previously.
> > >
> > > Importantly, the invocation of the driver mmap hook via mmap_file() now has
> > > very simple and obvious handling (this was previously the most problematic
> > > part of the mmap() operation).
> > >
> > > Signed-off-by: Lorenzo Stoakes <lorenzo.stoakes@oracle.com>
> > > ---
> > >  mm/vma.c | 380 +++++++++++++++++++++++++++++++++++--------------------
> > >  1 file changed, 240 insertions(+), 140 deletions(-)
> > >
> > > diff --git a/mm/vma.c b/mm/vma.c
> > > index 7617f9d50d62..a271e2b406ab 100644
> > > --- a/mm/vma.c
> > > +++ b/mm/vma.c
> > > @@ -7,6 +7,31 @@
> > >  #include "vma_internal.h"
> > >  #include "vma.h"
> > >
> > > +struct mmap_state {
> > > +	struct mm_struct *mm;
> > > +	struct vma_iterator *vmi;
> > > +	struct vma_merge_struct *vmg;
> > > +	struct list_head *uf;
> > > +
> > > +	struct vma_munmap_struct vms;
> > > +	struct ma_state mas_detach;
> > > +	struct maple_tree mt_detach;
> > > +
> > > +	unsigned long flags;
> > > +	unsigned long pglen;
> > > +	unsigned long charged;
> > > +};
> > > +
> > > +#define MMAP_STATE(name, mm_, vmi_, vmg_, uf_, flags_, len_)	\
> > > +	struct mmap_state name = {				\
> > > +		.mm = mm_,					\
> > > +		.vmi = vmi_,					\
> > > +		.vmg = vmg_,					\
> > > +		.uf = uf_,					\
> > > +		.flags = flags_,				\
> > > +		.pglen = PHYS_PFN(len_),			\
> > > +	}
> > > +
> > >  static inline bool is_mergeable_vma(struct vma_merge_struct *vmg, bool merge_next)
> > >  {
> > >  	struct vm_area_struct *vma = merge_next ? vmg->next : vmg->prev;
> > > @@ -2169,189 +2194,247 @@ static void vms_abort_munmap_vmas(struct vma_munmap_struct *vms,
> > >  	vms_complete_munmap_vmas(vms, mas_detach);
> > >  }
> > >
> > > -unsigned long __mmap_region(struct file *file, unsigned long addr,
> > > -		unsigned long len, vm_flags_t vm_flags, unsigned long pgoff,
> > > -		struct list_head *uf)
> > > +/*
> > > + * __mmap_prepare() - Prepare to gather any overlapping VMAs that need to be
> > > + *                    unmapped once the map operation is completed, check limits,
> > > + *                    account mapping and clean up any pre-existing VMAs.
> > > + *
> > > + * @map: Mapping state.
> > > + *
> > > + * Returns: 0 on success, error code otherwise.
> > > + */
> > > +static int __mmap_prepare(struct mmap_state *map)
> > >  {
> > > -	struct mm_struct *mm = current->mm;
> > > -	struct vm_area_struct *vma = NULL;
> > > -	pgoff_t pglen = PHYS_PFN(len);
> > > -	unsigned long charged = 0;
> > > -	struct vma_munmap_struct vms;
> > > -	struct ma_state mas_detach;
> > > -	struct maple_tree mt_detach;
> > > -	unsigned long end = addr + len;
> > >  	int error;
> > > -	VMA_ITERATOR(vmi, mm, addr);
> > > -	VMG_STATE(vmg, mm, &vmi, addr, end, vm_flags, pgoff);
> > > -
> > > -	vmg.file = file;
> > > -	/* Find the first overlapping VMA */
> > > -	vma = vma_find(&vmi, end);
> > > -	init_vma_munmap(&vms, &vmi, vma, addr, end, uf, /* unlock = */ false);
> > > -	if (vma) {
> > > -		mt_init_flags(&mt_detach, vmi.mas.tree->ma_flags & MT_FLAGS_LOCK_MASK);
> > > -		mt_on_stack(mt_detach);
> > > -		mas_init(&mas_detach, &mt_detach, /* addr = */ 0);
> > > +	struct vma_iterator *vmi = map->vmi;
> > > +	struct vma_merge_struct *vmg = map->vmg;
> > > +	struct vma_munmap_struct *vms = &map->vms;
> > > +
> > > +	/* Find the first overlapping VMA and initialise unmap state. */
> > > +	vms->vma = vma_find(vmi, vmg->end);
> > > +	init_vma_munmap(vms, vmi, vms->vma, vmg->start, vmg->end, map->uf,
> > > +			/* unlock = */ false);
> > > +
> > > +	/* OK, we have overlapping VMAs - prepare to unmap them. */
> > > +	if (vms->vma) {
> > > +		mt_init_flags(&map->mt_detach, vmi->mas.tree->ma_flags & MT_FLAGS_LOCK_MASK);
> > > +		mt_on_stack(map->mt_detach);
> > > +		mas_init(&map->mas_detach, &map->mt_detach, /* addr = */ 0);
> > >  		/* Prepare to unmap any existing mapping in the area */
> > > -		error = vms_gather_munmap_vmas(&vms, &mas_detach);
> > > +		error = vms_gather_munmap_vmas(vms, &map->mas_detach);
> > >  		if (error)
> > > -			goto gather_failed;
> > > +			return error;
> >
> > So this assumes vms_abort_munmap_vmas() will rely on the "vms->nr_pages will
> > be equal to zero" mentioned in commit log. But AFAICS
> > vms_gather_munmap_vmas() can fail in Nth iteration of its
> > for_each_vma_range() after some iterations already increased nr_pages and it
> > will do a reattach_vmas() and return the error and we just return the error
> > here.
> > I think either here or maybe in vms_gather_munmap_vmas() itself a reset of
> > vms->nr_pages to zero on error should happen for the vms_abort_munmap_vmas()
> > to be a no-op?
>
> Probably in reattach_vmas()?

Hm, but that only accepts a mas and seems redundant elsewhere... am going for
simply resetting nr_pages for now and maybe we can revisit if needs be?

>
> >
> > >
> > > -		vmg.next = vms.next;
> > > -		vmg.prev = vms.prev;
> > > -		vma = NULL;
> > > +		vmg->next = vms->next;
> > > +		vmg->prev = vms->prev;
> > >  	} else {
> > > -		vmg.next = vma_iter_next_rewind(&vmi, &vmg.prev);
> > > +		vmg->next = vma_iter_next_rewind(vmi, &vmg->prev);
> > >  	}
> > >
> > >  	/* Check against address space limit. */
> > > -	if (!may_expand_vm(mm, vm_flags, pglen - vms.nr_pages)) {
> > > -		error = -ENOMEM;
> > > -		goto abort_munmap;
> > > -	}
> > > +	if (!may_expand_vm(map->mm, map->flags, map->pglen - vms->nr_pages))
> > > +		return -ENOMEM;
> > >
> > > -	/*
> > > -	 * Private writable mapping: check memory availability
> > > -	 */
> > > -	if (accountable_mapping(file, vm_flags)) {
> > > -		charged = pglen;
> > > -		charged -= vms.nr_accounted;
> > > -		if (charged) {
> > > -			error = security_vm_enough_memory_mm(mm, charged);
> > > +	/* Private writable mapping: check memory availability. */
> > > +	if (accountable_mapping(vmg->file, map->flags)) {
> > > +		map->charged = map->pglen;
> > > +		map->charged -= vms->nr_accounted;
> > > +		if (map->charged) {
> > > +			error = security_vm_enough_memory_mm(map->mm, map->charged);
> > >  			if (error)
> > > -				goto abort_munmap;
> > > +				return error;
> > >  		}
> > >
> > > -		vms.nr_accounted = 0;
> > > -		vm_flags |= VM_ACCOUNT;
> > > -		vmg.flags = vm_flags;
> > > +		vms->nr_accounted = 0;
> > > +		map->flags |= VM_ACCOUNT;
> > >  	}
> > >
> > >  	/*
> > > -	 * clear PTEs while the vma is still in the tree so that rmap
> > > +	 * Clear PTEs while the vma is still in the tree so that rmap
> > >  	 * cannot race with the freeing later in the truncate scenario.
> > >  	 * This is also needed for mmap_file(), which is why vm_ops
> > >  	 * close function is called.
> > >  	 */
> > > -	vms_clean_up_area(&vms, &mas_detach);
> > > -	vma = vma_merge_new_range(&vmg);
> > > -	if (vma)
> > > -		goto expanded;
> > > +	vms_clean_up_area(vms, &map->mas_detach);
> > > +
> > > +	return 0;
> > > +}
> > > +
> > > +static int __mmap_new_file_vma(struct mmap_state *map, struct vm_area_struct *vma,
> > > +			       struct vm_area_struct **mergep)
> > > +{
> > > +	struct vma_iterator *vmi = map->vmi;
> > > +	struct vma_merge_struct *vmg = map->vmg;
> > > +	int error;
> > > +
> > > +	vma->vm_file = get_file(vmg->file);
> > > +	error = mmap_file(vma->vm_file, vma);
> > > +	if (error) {
> > > +		fput(vma->vm_file);
> > > +		vma->vm_file = NULL;
> > > +
> > > +		vma_iter_set(vmi, vma->vm_end);
> > > +		/* Undo any partial mapping done by a device driver. */
> > > +		unmap_region(&vmi->mas, vma, vmg->prev, vmg->next);
> > > +
> > > +		return error;
> > > +	}
> > > +
> > > +	/* Drivers cannot alter the address of the VMA. */
> > > +	WARN_ON_ONCE(vmg->start != vma->vm_start);
> > > +	/*
> > > +	 * Drivers should not permit writability when previously it was
> > > +	 * disallowed.
> > > +	 */
> > > +	VM_WARN_ON_ONCE(map->flags != vma->vm_flags &&
> > > +			!(map->flags & VM_MAYWRITE) &&
> > > +			(vma->vm_flags & VM_MAYWRITE));
> > > +
> > > +	vma_iter_config(vmi, vmg->start, vmg->end);
> > > +	/*
> > > +	 * If flags changed after mmap_file(), we should try merge
> > > +	 * vma again as we may succeed this time.
> > > +	 */
> > > +	if (unlikely(map->flags != vma->vm_flags && vmg->prev)) {
> > > +		struct vm_area_struct *merge;
> > > +
> > > +		vmg->flags = vma->vm_flags;
> > > +		/* If this fails, state is reset ready for a reattempt. */
> > > +		merge = vma_merge_new_range(vmg);
> > > +
> > > +		if (merge) {
> > > +			/*
> > > +			 * ->mmap() can change vma->vm_file and fput
> > > +			 * the original file. So fput the vma->vm_file
> > > +			 * here or we would add an extra fput for file
> > > +			 * and cause general protection fault
> > > +			 * ultimately.
> > > +			 */
> > > +			fput(vma->vm_file);
> > > +			vm_area_free(vma);
> >
> > This frees the vma.
> >
> > > +			vma_iter_free(vmi);
> > > +			*mergep = merge;
> > > +		} else {
> > > +			vma_iter_config(vmi, vmg->start, vmg->end);
> > > +		}
> > > +	}
> > > +
> > > +	map->flags = vma->vm_flags;
> >
> > So this is use-after-free.
> >
> > Maybe pass only a single struct vm_area_struct **vmap parameter to this
> > function, and in case of merge, change both vma and *vmap to it?.
> >
> > Although I can see it's all moot after 8/8. Still let's not introduce a
> > temporary UAF step.
> >
> > > +	return 0;
> > > +}
> > > +
> > > +/*
> > > + * __mmap_new_vma() - Allocate a new VMA for the region, as merging was not
> > > + *                    possible.
> > > + *
> > > + *                    An exception to this is if the mapping is file-backed, and
> > > + *                    the underlying driver changes the VMA flags, permitting a
> > > + *                    subsequent merge of the VMA, in which case the returned
> > > + *                    VMA is one that was merged on a second attempt.
> > > + *
> > > + * @map:  Mapping state.
> > > + * @vmap: Output pointer for the new VMA.
> > > + *
> > > + * Returns: Zero on success, or an error.
> > > + */
> > > +static int __mmap_new_vma(struct mmap_state *map, struct vm_area_struct **vmap)
> > > +{
> > > +	struct vma_iterator *vmi = map->vmi;
> > > +	struct vma_merge_struct *vmg = map->vmg;
> > > +	struct vm_area_struct *merge = NULL;
> > > +	int error = 0;
> > > +	struct vm_area_struct *vma;
> > > +
> > >  	/*
> > >  	 * Determine the object being mapped and call the appropriate
> > >  	 * specific mapper. the address has already been validated, but
> > >  	 * not unmapped, but the maps are removed from the list.
> > >  	 */
> > > -	vma = vm_area_alloc(mm);
> > > -	if (!vma) {
> > > -		error = -ENOMEM;
> > > -		goto unacct_error;
> > > -	}
> > > +	vma = vm_area_alloc(map->mm);
> > > +	if (!vma)
> > > +		return -ENOMEM;
> > >
> > > -	vma_iter_config(&vmi, addr, end);
> > > -	vma_set_range(vma, addr, end, pgoff);
> > > -	vm_flags_init(vma, vm_flags);
> > > -	vma->vm_page_prot = vm_get_page_prot(vm_flags);
> > > +	vma_iter_config(vmi, vmg->start, vmg->end);
> > > +	vma_set_range(vma, vmg->start, vmg->end, vmg->pgoff);
> > > +	vm_flags_init(vma, map->flags);
> > > +	vma->vm_page_prot = vm_get_page_prot(map->flags);
> > >
> > > -	if (vma_iter_prealloc(&vmi, vma)) {
> > > +	if (vma_iter_prealloc(vmi, vma)) {
> > >  		error = -ENOMEM;
> > >  		goto free_vma;
> > >  	}
> > >
> > > -	if (file) {
> > > -		vma->vm_file = get_file(file);
> > > -		error = mmap_file(file, vma);
> > > -		if (error)
> > > -			goto unmap_and_free_file_vma;
> > > -
> > > -		/* Drivers cannot alter the address of the VMA. */
> > > -		WARN_ON_ONCE(addr != vma->vm_start);
> > > -		/*
> > > -		 * Drivers should not permit writability when previously it was
> > > -		 * disallowed.
> > > -		 */
> > > -		VM_WARN_ON_ONCE(vm_flags != vma->vm_flags &&
> > > -				!(vm_flags & VM_MAYWRITE) &&
> > > -				(vma->vm_flags & VM_MAYWRITE));
> > > -
> > > -		vma_iter_config(&vmi, addr, end);
> > > -		/*
> > > -		 * If vm_flags changed after mmap_file(), we should try merge
> > > -		 * vma again as we may succeed this time.
> > > -		 */
> > > -		if (unlikely(vm_flags != vma->vm_flags && vmg.prev)) {
> > > -			struct vm_area_struct *merge;
> > > -
> > > -			vmg.flags = vma->vm_flags;
> > > -			/* If this fails, state is reset ready for a reattempt. */
> > > -			merge = vma_merge_new_range(&vmg);
> > > -
> > > -			if (merge) {
> > > -				/*
> > > -				 * ->mmap() can change vma->vm_file and fput
> > > -				 * the original file. So fput the vma->vm_file
> > > -				 * here or we would add an extra fput for file
> > > -				 * and cause general protection fault
> > > -				 * ultimately.
> > > -				 */
> > > -				fput(vma->vm_file);
> > > -				vm_area_free(vma);
> > > -				vma_iter_free(&vmi);
> > > -				vma = merge;
> > > -				/* Update vm_flags to pick up the change. */
> > > -				vm_flags = vma->vm_flags;
> > > -				goto file_expanded;
> > > -			}
> > > -			vma_iter_config(&vmi, addr, end);
> > > -		}
> > > -
> > > -		vm_flags = vma->vm_flags;
> > > -	} else if (vm_flags & VM_SHARED) {
> > > +	if (vmg->file)
> > > +		error = __mmap_new_file_vma(map, vma, &merge);
> > > +	else if (map->flags & VM_SHARED)
> > >  		error = shmem_zero_setup(vma);
> > > -		if (error)
> > > -			goto free_iter_vma;
> > > -	} else {
> > > +	else
> > >  		vma_set_anonymous(vma);
> > > -	}
> > > +
> > > +	if (error)
> > > +		goto free_iter_vma;
> > > +
> > > +	if (merge)
> > > +		goto file_expanded;
> > >
> > >  #ifdef CONFIG_SPARC64
> > >  	/* TODO: Fix SPARC ADI! */
> > > -	WARN_ON_ONCE(!arch_validate_flags(vm_flags));
> > > +	WARN_ON_ONCE(!arch_validate_flags(map->flags));
> > >  #endif
> > >
> > >  	/* Lock the VMA since it is modified after insertion into VMA tree */
> > >  	vma_start_write(vma);
> > > -	vma_iter_store(&vmi, vma);
> > > -	mm->map_count++;
> > > +	vma_iter_store(vmi, vma);
> > > +	map->mm->map_count++;
> > >  	vma_link_file(vma);
> > >
> > >  	/*
> > >  	 * vma_merge_new_range() calls khugepaged_enter_vma() too, the below
> > >  	 * call covers the non-merge case.
> > >  	 */
> > > -	khugepaged_enter_vma(vma, vma->vm_flags);
> > > +	khugepaged_enter_vma(vma, map->flags);
> > >
> > >  file_expanded:
> > > -	file = vma->vm_file;
> > >  	ksm_add_vma(vma);
> > > -expanded:
> > > +
> > > +	*vmap = vma;
> > > +	return 0;
> > > +
> > > +free_iter_vma:
> > > +	vma_iter_free(vmi);
> > > +free_vma:
> > > +	vm_area_free(vma);
> > > +	return error;
> > > +}
> > > +
> > > +/*
> > > + * __mmap_complete() - Unmap any VMAs we overlap, account memory mapping
> > > + *                     statistics, handle locking and finalise the VMA.
> > > + *
> > > + * @map: Mapping state.
> > > + * @vma: Merged or newly allocated VMA for the mmap()'d region.
> > > + */
> > > +static void __mmap_complete(struct mmap_state *map, struct vm_area_struct *vma)
> > > +{
> > > +	struct mm_struct *mm = map->mm;
> > > +	unsigned long vm_flags = vma->vm_flags;
> > > +
> > >  	perf_event_mmap(vma);
> > >
> > > -	/* Unmap any existing mapping in the area */
> > > -	vms_complete_munmap_vmas(&vms, &mas_detach);
> > > +	/* Unmap any existing mapping in the area. */
> > > +	vms_complete_munmap_vmas(&map->vms, &map->mas_detach);
> > >
> > > -	vm_stat_account(mm, vm_flags, pglen);
> > > +	vm_stat_account(mm, vma->vm_flags, map->pglen);
> > >  	if (vm_flags & VM_LOCKED) {
> > >  		if ((vm_flags & VM_SPECIAL) || vma_is_dax(vma) ||
> > >  					is_vm_hugetlb_page(vma) ||
> > > -					vma == get_gate_vma(current->mm))
> > > +					vma == get_gate_vma(mm))
> > >  			vm_flags_clear(vma, VM_LOCKED_MASK);
> > >  		else
> > > -			mm->locked_vm += pglen;
> > > +			mm->locked_vm += map->pglen;
> > >  	}
> > >
> > > -	if (file)
> > > +	if (vma->vm_file)
> > >  		uprobe_mmap(vma);
> > >
> > >  	/*
> > > @@ -2364,26 +2447,43 @@ unsigned long __mmap_region(struct file *file, unsigned long addr,
> > >  	vm_flags_set(vma, VM_SOFTDIRTY);
> > >
> > >  	vma_set_page_prot(vma);
> > > +}
> > >
> > > -	return addr;
> > > +unsigned long __mmap_region(struct file *file, unsigned long addr,
> > > +		unsigned long len, vm_flags_t vm_flags, unsigned long pgoff,
> > > +		struct list_head *uf)
> > > +{
> > > +	struct mm_struct *mm = current->mm;
> > > +	struct vm_area_struct *vma;
> > > +	int error;
> > > +	VMA_ITERATOR(vmi, mm, addr);
> > > +	VMG_STATE(vmg, mm, &vmi, addr, addr + len, vm_flags, pgoff);
> > > +	MMAP_STATE(map, mm, &vmi, &vmg, uf, vm_flags, len);
> > >
> > > -unmap_and_free_file_vma:
> > > -	fput(vma->vm_file);
> > > -	vma->vm_file = NULL;
> > > +	vmg.file = file;
> > >
> > > -	vma_iter_set(&vmi, vma->vm_end);
> > > -	/* Undo any partial mapping done by a device driver. */
> > > -	unmap_region(&vmi.mas, vma, vmg.prev, vmg.next);
> > > -free_iter_vma:
> > > -	vma_iter_free(&vmi);
> > > -free_vma:
> > > -	vm_area_free(vma);
> > > -unacct_error:
> > > -	if (charged)
> > > -		vm_unacct_memory(charged);
> > > +	error = __mmap_prepare(&map);
> > > +	if (error)
> > > +		goto abort_munmap;
> > > +
> > > +	/* Attempt to merge with adjacent VMAs... */
> > > +	vmg.flags = map.flags;
> > > +	vma = vma_merge_new_range(&vmg);
> > > +	if (!vma) {
> > > +		/* ...but if we can't, allocate a new VMA. */
> > > +		error = __mmap_new_vma(&map, &vma);
> > > +		if (error)
> > > +			goto unacct_error;
> > > +	}
> > > +
> > > +	__mmap_complete(&map, vma);
> > >
> > > +	return addr;
> > > +
> > > +unacct_error:
> > > +	if (map.charged)
> > > +		vm_unacct_memory(map.charged);
> > >  abort_munmap:
> > > -	vms_abort_munmap_vmas(&vms, &mas_detach);
> > > -gather_failed:
> > > +	vms_abort_munmap_vmas(&map.vms, &map.mas_detach);
> > >  	return error;
> > >  }
> > > --
> > > 2.47.0
> >
Lorenzo Stoakes Oct. 23, 2024, 5:52 p.m. UTC | #6
On Wed, Oct 23, 2024 at 01:19:35PM -0400, Liam R. Howlett wrote:
> * Lorenzo Stoakes <lorenzo.stoakes@oracle.com> [241022 16:41]:
> > We have seen bugs and resource leaks arise from the complexity of the
> > __mmap_region() function. This, and the generally deeply fragile error
> > handling logic and complexity which makes understanding the function
> > difficult make it highly desirable to refactor it into something readable.
> >
> > Achieve this by separating the function into smaller logical parts which
> > are easier to understand and follow, and which importantly very
> > significantly simplify the error handling.
> >
> > Note that we now call vms_abort_munmap_vmas() in more error paths than we
> > used to, however in cases where no abort need occur, vms->nr_pages will be
> > equal to zero and we simply exit this function without doing more than we
> > would have done previously.
> >
> > Importantly, the invocation of the driver mmap hook via mmap_file() now has
> > very simple and obvious handling (this was previously the most problematic
> > part of the mmap() operation).
> >
> > Signed-off-by: Lorenzo Stoakes <lorenzo.stoakes@oracle.com>
> > ---
> >  mm/vma.c | 380 +++++++++++++++++++++++++++++++++++--------------------
> >  1 file changed, 240 insertions(+), 140 deletions(-)
> >
> > diff --git a/mm/vma.c b/mm/vma.c
> > index 7617f9d50d62..a271e2b406ab 100644
> > --- a/mm/vma.c
> > +++ b/mm/vma.c
> > @@ -7,6 +7,31 @@
> >  #include "vma_internal.h"
> >  #include "vma.h"
> >
> > +struct mmap_state {
> > +	struct mm_struct *mm;
> > +	struct vma_iterator *vmi;
> > +	struct vma_merge_struct *vmg;
> > +	struct list_head *uf;
> > +
> > +	struct vma_munmap_struct vms;
> > +	struct ma_state mas_detach;
> > +	struct maple_tree mt_detach;
> > +
> > +	unsigned long flags;
> > +	unsigned long pglen;
> > +	unsigned long charged;
> > +};
> > +
> > +#define MMAP_STATE(name, mm_, vmi_, vmg_, uf_, flags_, len_)	\
> > +	struct mmap_state name = {				\
> > +		.mm = mm_,					\
> > +		.vmi = vmi_,					\
> > +		.vmg = vmg_,					\
> > +		.uf = uf_,					\
> > +		.flags = flags_,				\
> > +		.pglen = PHYS_PFN(len_),			\
> > +	}
> > +
> >  static inline bool is_mergeable_vma(struct vma_merge_struct *vmg, bool merge_next)
> >  {
> >  	struct vm_area_struct *vma = merge_next ? vmg->next : vmg->prev;
> > @@ -2169,189 +2194,247 @@ static void vms_abort_munmap_vmas(struct vma_munmap_struct *vms,
> >  	vms_complete_munmap_vmas(vms, mas_detach);
> >  }
> >
> > -unsigned long __mmap_region(struct file *file, unsigned long addr,
> > -		unsigned long len, vm_flags_t vm_flags, unsigned long pgoff,
> > -		struct list_head *uf)
> > +/*
> > + * __mmap_prepare() - Prepare to gather any overlapping VMAs that need to be
> > + *                    unmapped once the map operation is completed, check limits,
> > + *                    account mapping and clean up any pre-existing VMAs.
> > + *
>
> nit: formatting seems wrong here?

But I like it this way :( will change though.

>
> > + * @map: Mapping state.
> > + *
> > + * Returns: 0 on success, error code otherwise.
> > + */
> > +static int __mmap_prepare(struct mmap_state *map)
> >  {
> > -	struct mm_struct *mm = current->mm;
> > -	struct vm_area_struct *vma = NULL;
> > -	pgoff_t pglen = PHYS_PFN(len);
> > -	unsigned long charged = 0;
> > -	struct vma_munmap_struct vms;
> > -	struct ma_state mas_detach;
> > -	struct maple_tree mt_detach;
> > -	unsigned long end = addr + len;
> >  	int error;
> > -	VMA_ITERATOR(vmi, mm, addr);
> > -	VMG_STATE(vmg, mm, &vmi, addr, end, vm_flags, pgoff);
> > -
> > -	vmg.file = file;
> > -	/* Find the first overlapping VMA */
> > -	vma = vma_find(&vmi, end);
> > -	init_vma_munmap(&vms, &vmi, vma, addr, end, uf, /* unlock = */ false);
> > -	if (vma) {
> > -		mt_init_flags(&mt_detach, vmi.mas.tree->ma_flags & MT_FLAGS_LOCK_MASK);
> > -		mt_on_stack(mt_detach);
> > -		mas_init(&mas_detach, &mt_detach, /* addr = */ 0);
> > +	struct vma_iterator *vmi = map->vmi;
> > +	struct vma_merge_struct *vmg = map->vmg;
> > +	struct vma_munmap_struct *vms = &map->vms;
> > +
> > +	/* Find the first overlapping VMA and initialise unmap state. */
> > +	vms->vma = vma_find(vmi, vmg->end);
> > +	init_vma_munmap(vms, vmi, vms->vma, vmg->start, vmg->end, map->uf,
> > +			/* unlock = */ false);
> > +
> > +	/* OK, we have overlapping VMAs - prepare to unmap them. */
> > +	if (vms->vma) {
> > +		mt_init_flags(&map->mt_detach, vmi->mas.tree->ma_flags & MT_FLAGS_LOCK_MASK);
>
> Nit: line is too long.

Yeah think this is possibly pre-existing but will fix either way.

>
> > +		mt_on_stack(map->mt_detach);
> > +		mas_init(&map->mas_detach, &map->mt_detach, /* addr = */ 0);
> >  		/* Prepare to unmap any existing mapping in the area */
> > -		error = vms_gather_munmap_vmas(&vms, &mas_detach);
> > +		error = vms_gather_munmap_vmas(vms, &map->mas_detach);
> >  		if (error)
> > -			goto gather_failed;
> > +			return error;
>
> As Vlastimil pointed out, there is an issue just returning the error.

Yeah have addressed it, thanks!

>
> >
> > -		vmg.next = vms.next;
> > -		vmg.prev = vms.prev;
> > -		vma = NULL;
> > +		vmg->next = vms->next;
> > +		vmg->prev = vms->prev;
> >  	} else {
> > -		vmg.next = vma_iter_next_rewind(&vmi, &vmg.prev);
> > +		vmg->next = vma_iter_next_rewind(vmi, &vmg->prev);
> >  	}
> >
> >  	/* Check against address space limit. */
> > -	if (!may_expand_vm(mm, vm_flags, pglen - vms.nr_pages)) {
> > -		error = -ENOMEM;
> > -		goto abort_munmap;
> > -	}
> > +	if (!may_expand_vm(map->mm, map->flags, map->pglen - vms->nr_pages))
> > +		return -ENOMEM;
> >
> > -	/*
> > -	 * Private writable mapping: check memory availability
> > -	 */
> > -	if (accountable_mapping(file, vm_flags)) {
> > -		charged = pglen;
> > -		charged -= vms.nr_accounted;
> > -		if (charged) {
> > -			error = security_vm_enough_memory_mm(mm, charged);
> > +	/* Private writable mapping: check memory availability. */
> > +	if (accountable_mapping(vmg->file, map->flags)) {
> > +		map->charged = map->pglen;
> > +		map->charged -= vms->nr_accounted;
> > +		if (map->charged) {
> > +			error = security_vm_enough_memory_mm(map->mm, map->charged);
> >  			if (error)
> > -				goto abort_munmap;
> > +				return error;
> >  		}
> >
> > -		vms.nr_accounted = 0;
> > -		vm_flags |= VM_ACCOUNT;
> > -		vmg.flags = vm_flags;
> > +		vms->nr_accounted = 0;
> > +		map->flags |= VM_ACCOUNT;
> >  	}
> >
> >  	/*
> > -	 * clear PTEs while the vma is still in the tree so that rmap
> > +	 * Clear PTEs while the vma is still in the tree so that rmap
> >  	 * cannot race with the freeing later in the truncate scenario.
> >  	 * This is also needed for mmap_file(), which is why vm_ops
> >  	 * close function is called.
> >  	 */
> > -	vms_clean_up_area(&vms, &mas_detach);
> > -	vma = vma_merge_new_range(&vmg);
> > -	if (vma)
> > -		goto expanded;
> > +	vms_clean_up_area(vms, &map->mas_detach);
> > +
> > +	return 0;
> > +}
> > +
> > +static int __mmap_new_file_vma(struct mmap_state *map, struct vm_area_struct *vma,
> > +			       struct vm_area_struct **mergep)
> > +{
> > +	struct vma_iterator *vmi = map->vmi;
> > +	struct vma_merge_struct *vmg = map->vmg;
> > +	int error;
> > +
> > +	vma->vm_file = get_file(vmg->file);
> > +	error = mmap_file(vma->vm_file, vma);
> > +	if (error) {
> > +		fput(vma->vm_file);
> > +		vma->vm_file = NULL;
> > +
> > +		vma_iter_set(vmi, vma->vm_end);
> > +		/* Undo any partial mapping done by a device driver. */
> > +		unmap_region(&vmi->mas, vma, vmg->prev, vmg->next);
> > +
> > +		return error;
> > +	}
> > +
> > +	/* Drivers cannot alter the address of the VMA. */
> > +	WARN_ON_ONCE(vmg->start != vma->vm_start);
> > +	/*
> > +	 * Drivers should not permit writability when previously it was
> > +	 * disallowed.
> > +	 */
> > +	VM_WARN_ON_ONCE(map->flags != vma->vm_flags &&
> > +			!(map->flags & VM_MAYWRITE) &&
> > +			(vma->vm_flags & VM_MAYWRITE));
> > +
> > +	vma_iter_config(vmi, vmg->start, vmg->end);
> > +	/*
> > +	 * If flags changed after mmap_file(), we should try merge
> > +	 * vma again as we may succeed this time.
> > +	 */
> > +	if (unlikely(map->flags != vma->vm_flags && vmg->prev)) {
> > +		struct vm_area_struct *merge;
> > +
> > +		vmg->flags = vma->vm_flags;
> > +		/* If this fails, state is reset ready for a reattempt. */
> > +		merge = vma_merge_new_range(vmg);
> > +
> > +		if (merge) {
> > +			/*
> > +			 * ->mmap() can change vma->vm_file and fput
> > +			 * the original file. So fput the vma->vm_file
> > +			 * here or we would add an extra fput for file
> > +			 * and cause general protection fault
> > +			 * ultimately.
> > +			 */
> > +			fput(vma->vm_file);
> > +			vm_area_free(vma);
> > +			vma_iter_free(vmi);
> > +			*mergep = merge;
> > +		} else {
> > +			vma_iter_config(vmi, vmg->start, vmg->end);
> > +		}
> > +	}
> > +
> > +	map->flags = vma->vm_flags;
> > +	return 0;
> > +}
> > +
> > +/*
> > + * __mmap_new_vma() - Allocate a new VMA for the region, as merging was not
> > + *                    possible.
> > + *
> > + *                    An exception to this is if the mapping is file-backed, and
> > + *                    the underlying driver changes the VMA flags, permitting a
> > + *                    subsequent merge of the VMA, in which case the returned
> > + *                    VMA is one that was merged on a second attempt.
>
> It seems all the descriptions have indented lines.

I like it that way :( will change :'(

>
> > + *
> > + * @map:  Mapping state.
> > + * @vmap: Output pointer for the new VMA.
> > + *
> > + * Returns: Zero on success, or an error.
> > + */
> > +static int __mmap_new_vma(struct mmap_state *map, struct vm_area_struct **vmap)
> > +{
> > +	struct vma_iterator *vmi = map->vmi;
> > +	struct vma_merge_struct *vmg = map->vmg;
> > +	struct vm_area_struct *merge = NULL;
> > +	int error = 0;
> > +	struct vm_area_struct *vma;
> > +
> >  	/*
> >  	 * Determine the object being mapped and call the appropriate
> >  	 * specific mapper. the address has already been validated, but
> >  	 * not unmapped, but the maps are removed from the list.
> >  	 */
> > -	vma = vm_area_alloc(mm);
> > -	if (!vma) {
> > -		error = -ENOMEM;
> > -		goto unacct_error;
> > -	}
> > +	vma = vm_area_alloc(map->mm);
> > +	if (!vma)
> > +		return -ENOMEM;
> >
> > -	vma_iter_config(&vmi, addr, end);
> > -	vma_set_range(vma, addr, end, pgoff);
> > -	vm_flags_init(vma, vm_flags);
> > -	vma->vm_page_prot = vm_get_page_prot(vm_flags);
> > +	vma_iter_config(vmi, vmg->start, vmg->end);
>
> This function is only called from __mmap_region() after an
> __mmap_prepare() and vma_merge_new_range().  The state should be fine,
> so maybe this could be WARN_ONs?  Although, it's probably safer to just
> leave it.

Yeah perhaps one for the laterbase?

>
> > +	vma_set_range(vma, vmg->start, vmg->end, vmg->pgoff);
> > +	vm_flags_init(vma, map->flags);
> > +	vma->vm_page_prot = vm_get_page_prot(map->flags);
> >
> > -	if (vma_iter_prealloc(&vmi, vma)) {
> > +	if (vma_iter_prealloc(vmi, vma)) {
> >  		error = -ENOMEM;
> >  		goto free_vma;
> >  	}
> >
> > -	if (file) {
> > -		vma->vm_file = get_file(file);
> > -		error = mmap_file(file, vma);
> > -		if (error)
> > -			goto unmap_and_free_file_vma;
> > -
> > -		/* Drivers cannot alter the address of the VMA. */
> > -		WARN_ON_ONCE(addr != vma->vm_start);
> > -		/*
> > -		 * Drivers should not permit writability when previously it was
> > -		 * disallowed.
> > -		 */
> > -		VM_WARN_ON_ONCE(vm_flags != vma->vm_flags &&
> > -				!(vm_flags & VM_MAYWRITE) &&
> > -				(vma->vm_flags & VM_MAYWRITE));
> > -
> > -		vma_iter_config(&vmi, addr, end);
> > -		/*
> > -		 * If vm_flags changed after mmap_file(), we should try merge
> > -		 * vma again as we may succeed this time.
> > -		 */
> > -		if (unlikely(vm_flags != vma->vm_flags && vmg.prev)) {
> > -			struct vm_area_struct *merge;
> > -
> > -			vmg.flags = vma->vm_flags;
> > -			/* If this fails, state is reset ready for a reattempt. */
> > -			merge = vma_merge_new_range(&vmg);
> > -
> > -			if (merge) {
> > -				/*
> > -				 * ->mmap() can change vma->vm_file and fput
> > -				 * the original file. So fput the vma->vm_file
> > -				 * here or we would add an extra fput for file
> > -				 * and cause general protection fault
> > -				 * ultimately.
> > -				 */
> > -				fput(vma->vm_file);
> > -				vm_area_free(vma);
> > -				vma_iter_free(&vmi);
> > -				vma = merge;
> > -				/* Update vm_flags to pick up the change. */
> > -				vm_flags = vma->vm_flags;
> > -				goto file_expanded;
> > -			}
> > -			vma_iter_config(&vmi, addr, end);
> > -		}
> > -
> > -		vm_flags = vma->vm_flags;
> > -	} else if (vm_flags & VM_SHARED) {
> > +	if (vmg->file)
> > +		error = __mmap_new_file_vma(map, vma, &merge);
> > +	else if (map->flags & VM_SHARED)
> >  		error = shmem_zero_setup(vma);
> > -		if (error)
> > -			goto free_iter_vma;
> > -	} else {
> > +	else
> >  		vma_set_anonymous(vma);
> > -	}
> > +
> > +	if (error)
> > +		goto free_iter_vma;
> > +
> > +	if (merge)
> > +		goto file_expanded;
> >
> >  #ifdef CONFIG_SPARC64
> >  	/* TODO: Fix SPARC ADI! */
> > -	WARN_ON_ONCE(!arch_validate_flags(vm_flags));
> > +	WARN_ON_ONCE(!arch_validate_flags(map->flags));
> >  #endif
> >
> >  	/* Lock the VMA since it is modified after insertion into VMA tree */
> >  	vma_start_write(vma);
> > -	vma_iter_store(&vmi, vma);
> > -	mm->map_count++;
> > +	vma_iter_store(vmi, vma);
> > +	map->mm->map_count++;
> >  	vma_link_file(vma);
> >
> >  	/*
> >  	 * vma_merge_new_range() calls khugepaged_enter_vma() too, the below
> >  	 * call covers the non-merge case.
> >  	 */
> > -	khugepaged_enter_vma(vma, vma->vm_flags);
> > +	khugepaged_enter_vma(vma, map->flags);
> >
> >  file_expanded:
> > -	file = vma->vm_file;
> >  	ksm_add_vma(vma);
>
> __mmap_new_file_vma() may free the vma.  I assume this is what you
> mentioned elsewhere about UAF.

Yeah no this is a new one, let me totally rework this thing and make it
operate on a single VMA and return a merged boolean and reduce the window
between 'pointer freed and pointer reassgined to literally ZERO LINES.

Sorry, I was probably feverish when I did this bit...

>
> > -expanded:
> > +
>
> Extra whitespace.

You hate that whitespace :(

>
> > +	*vmap = vma;
> > +	return 0;
> > +
> > +free_iter_vma:
> > +	vma_iter_free(vmi);
> > +free_vma:
> > +	vm_area_free(vma);
> > +	return error;
> > +}
> > +
> > +/*
> > + * __mmap_complete() - Unmap any VMAs we overlap, account memory mapping
> > + *                     statistics, handle locking and finalise the VMA.
> > + *
> > + * @map: Mapping state.
> > + * @vma: Merged or newly allocated VMA for the mmap()'d region.
> > + */
> > +static void __mmap_complete(struct mmap_state *map, struct vm_area_struct *vma)
> > +{
> > +	struct mm_struct *mm = map->mm;
> > +	unsigned long vm_flags = vma->vm_flags;
> > +
> >  	perf_event_mmap(vma);
> >
> > -	/* Unmap any existing mapping in the area */
> > -	vms_complete_munmap_vmas(&vms, &mas_detach);
> > +	/* Unmap any existing mapping in the area. */
> > +	vms_complete_munmap_vmas(&map->vms, &map->mas_detach);
> >
> > -	vm_stat_account(mm, vm_flags, pglen);
> > +	vm_stat_account(mm, vma->vm_flags, map->pglen);
> >  	if (vm_flags & VM_LOCKED) {
> >  		if ((vm_flags & VM_SPECIAL) || vma_is_dax(vma) ||
> >  					is_vm_hugetlb_page(vma) ||
> > -					vma == get_gate_vma(current->mm))
> > +					vma == get_gate_vma(mm))
> >  			vm_flags_clear(vma, VM_LOCKED_MASK);
> >  		else
> > -			mm->locked_vm += pglen;
> > +			mm->locked_vm += map->pglen;
> >  	}
> >
> > -	if (file)
> > +	if (vma->vm_file)
> >  		uprobe_mmap(vma);
> >
> >  	/*
> > @@ -2364,26 +2447,43 @@ unsigned long __mmap_region(struct file *file, unsigned long addr,
> >  	vm_flags_set(vma, VM_SOFTDIRTY);
> >
> >  	vma_set_page_prot(vma);
> > +}
> >
> > -	return addr;
> > +unsigned long __mmap_region(struct file *file, unsigned long addr,
> > +		unsigned long len, vm_flags_t vm_flags, unsigned long pgoff,
> > +		struct list_head *uf)
> > +{
> > +	struct mm_struct *mm = current->mm;
> > +	struct vm_area_struct *vma;
> > +	int error;
> > +	VMA_ITERATOR(vmi, mm, addr);
> > +	VMG_STATE(vmg, mm, &vmi, addr, addr + len, vm_flags, pgoff);
> > +	MMAP_STATE(map, mm, &vmi, &vmg, uf, vm_flags, len);
> >
> > -unmap_and_free_file_vma:
> > -	fput(vma->vm_file);
> > -	vma->vm_file = NULL;
> > +	vmg.file = file;
> >
> > -	vma_iter_set(&vmi, vma->vm_end);
> > -	/* Undo any partial mapping done by a device driver. */
> > -	unmap_region(&vmi.mas, vma, vmg.prev, vmg.next);
> > -free_iter_vma:
> > -	vma_iter_free(&vmi);
> > -free_vma:
> > -	vm_area_free(vma);
> > -unacct_error:
> > -	if (charged)
> > -		vm_unacct_memory(charged);
> > +	error = __mmap_prepare(&map);
> > +	if (error)
> > +		goto abort_munmap;
> > +
> > +	/* Attempt to merge with adjacent VMAs... */
> > +	vmg.flags = map.flags;
> > +	vma = vma_merge_new_range(&vmg);
> > +	if (!vma) {
> > +		/* ...but if we can't, allocate a new VMA. */
> > +		error = __mmap_new_vma(&map, &vma);
> > +		if (error)
> > +			goto unacct_error;
> > +	}
> > +
> > +	__mmap_complete(&map, vma);
> >
> > +	return addr;
> > +
> > +unacct_error:
> > +	if (map.charged)
> > +		vm_unacct_memory(map.charged);
>
> So this is effectively undoing __mmap_prepare()'s accounting. I don't
> have a better label for it, but it's not obvious by the label that the
> accounting was done in __mmap_prepare().

There's a comment in the description of the function. I'll add a comment
here too to be clear about it.

>
> >  abort_munmap:
> > -	vms_abort_munmap_vmas(&vms, &mas_detach);
> > -gather_failed:
> > +	vms_abort_munmap_vmas(&map.vms, &map.mas_detach);
> >  	return error;
> >  }
> > --
> > 2.47.0
Liam R. Howlett Oct. 23, 2024, 6:12 p.m. UTC | #7
* Lorenzo Stoakes <lorenzo.stoakes@oracle.com> [241023 13:39]:
> On Wed, Oct 23, 2024 at 11:21:54AM -0400, Liam R. Howlett wrote:
> > * Vlastimil Babka <vbabka@suse.cz> [241023 10:39]:
> > > On 10/22/24 22:40, Lorenzo Stoakes wrote:
> > > > We have seen bugs and resource leaks arise from the complexity of the
> > > > __mmap_region() function. This, and the generally deeply fragile error
> > > > handling logic and complexity which makes understanding the function
> > > > difficult make it highly desirable to refactor it into something readable.
> > > >
> > > > Achieve this by separating the function into smaller logical parts which
> > > > are easier to understand and follow, and which importantly very
> > > > significantly simplify the error handling.
> > > >
> > > > Note that we now call vms_abort_munmap_vmas() in more error paths than we
> > > > used to, however in cases where no abort need occur, vms->nr_pages will be
> > > > equal to zero and we simply exit this function without doing more than we
> > > > would have done previously.
> > > >
> > > > Importantly, the invocation of the driver mmap hook via mmap_file() now has
> > > > very simple and obvious handling (this was previously the most problematic
> > > > part of the mmap() operation).
> > > >
> > > > Signed-off-by: Lorenzo Stoakes <lorenzo.stoakes@oracle.com>
> > > > ---
> > > >  mm/vma.c | 380 +++++++++++++++++++++++++++++++++++--------------------
> > > >  1 file changed, 240 insertions(+), 140 deletions(-)
> > > >
> > > > diff --git a/mm/vma.c b/mm/vma.c
> > > > index 7617f9d50d62..a271e2b406ab 100644
> > > > --- a/mm/vma.c
> > > > +++ b/mm/vma.c
> > > > @@ -7,6 +7,31 @@
> > > >  #include "vma_internal.h"
> > > >  #include "vma.h"
> > > >
> > > > +struct mmap_state {
> > > > +	struct mm_struct *mm;
> > > > +	struct vma_iterator *vmi;
> > > > +	struct vma_merge_struct *vmg;
> > > > +	struct list_head *uf;
> > > > +
> > > > +	struct vma_munmap_struct vms;
> > > > +	struct ma_state mas_detach;
> > > > +	struct maple_tree mt_detach;
> > > > +
> > > > +	unsigned long flags;
> > > > +	unsigned long pglen;
> > > > +	unsigned long charged;
> > > > +};
> > > > +
> > > > +#define MMAP_STATE(name, mm_, vmi_, vmg_, uf_, flags_, len_)	\
> > > > +	struct mmap_state name = {				\
> > > > +		.mm = mm_,					\
> > > > +		.vmi = vmi_,					\
> > > > +		.vmg = vmg_,					\
> > > > +		.uf = uf_,					\
> > > > +		.flags = flags_,				\
> > > > +		.pglen = PHYS_PFN(len_),			\
> > > > +	}
> > > > +
> > > >  static inline bool is_mergeable_vma(struct vma_merge_struct *vmg, bool merge_next)
> > > >  {
> > > >  	struct vm_area_struct *vma = merge_next ? vmg->next : vmg->prev;
> > > > @@ -2169,189 +2194,247 @@ static void vms_abort_munmap_vmas(struct vma_munmap_struct *vms,
> > > >  	vms_complete_munmap_vmas(vms, mas_detach);
> > > >  }
> > > >
> > > > -unsigned long __mmap_region(struct file *file, unsigned long addr,
> > > > -		unsigned long len, vm_flags_t vm_flags, unsigned long pgoff,
> > > > -		struct list_head *uf)
> > > > +/*
> > > > + * __mmap_prepare() - Prepare to gather any overlapping VMAs that need to be
> > > > + *                    unmapped once the map operation is completed, check limits,
> > > > + *                    account mapping and clean up any pre-existing VMAs.
> > > > + *
> > > > + * @map: Mapping state.
> > > > + *
> > > > + * Returns: 0 on success, error code otherwise.
> > > > + */
> > > > +static int __mmap_prepare(struct mmap_state *map)
> > > >  {
> > > > -	struct mm_struct *mm = current->mm;
> > > > -	struct vm_area_struct *vma = NULL;
> > > > -	pgoff_t pglen = PHYS_PFN(len);
> > > > -	unsigned long charged = 0;
> > > > -	struct vma_munmap_struct vms;
> > > > -	struct ma_state mas_detach;
> > > > -	struct maple_tree mt_detach;
> > > > -	unsigned long end = addr + len;
> > > >  	int error;
> > > > -	VMA_ITERATOR(vmi, mm, addr);
> > > > -	VMG_STATE(vmg, mm, &vmi, addr, end, vm_flags, pgoff);
> > > > -
> > > > -	vmg.file = file;
> > > > -	/* Find the first overlapping VMA */
> > > > -	vma = vma_find(&vmi, end);
> > > > -	init_vma_munmap(&vms, &vmi, vma, addr, end, uf, /* unlock = */ false);
> > > > -	if (vma) {
> > > > -		mt_init_flags(&mt_detach, vmi.mas.tree->ma_flags & MT_FLAGS_LOCK_MASK);
> > > > -		mt_on_stack(mt_detach);
> > > > -		mas_init(&mas_detach, &mt_detach, /* addr = */ 0);
> > > > +	struct vma_iterator *vmi = map->vmi;
> > > > +	struct vma_merge_struct *vmg = map->vmg;
> > > > +	struct vma_munmap_struct *vms = &map->vms;
> > > > +
> > > > +	/* Find the first overlapping VMA and initialise unmap state. */
> > > > +	vms->vma = vma_find(vmi, vmg->end);
> > > > +	init_vma_munmap(vms, vmi, vms->vma, vmg->start, vmg->end, map->uf,
> > > > +			/* unlock = */ false);
> > > > +
> > > > +	/* OK, we have overlapping VMAs - prepare to unmap them. */
> > > > +	if (vms->vma) {
> > > > +		mt_init_flags(&map->mt_detach, vmi->mas.tree->ma_flags & MT_FLAGS_LOCK_MASK);
> > > > +		mt_on_stack(map->mt_detach);
> > > > +		mas_init(&map->mas_detach, &map->mt_detach, /* addr = */ 0);
> > > >  		/* Prepare to unmap any existing mapping in the area */
> > > > -		error = vms_gather_munmap_vmas(&vms, &mas_detach);
> > > > +		error = vms_gather_munmap_vmas(vms, &map->mas_detach);
> > > >  		if (error)
> > > > -			goto gather_failed;
> > > > +			return error;
> > >
> > > So this assumes vms_abort_munmap_vmas() will rely on the "vms->nr_pages will
> > > be equal to zero" mentioned in commit log. But AFAICS
> > > vms_gather_munmap_vmas() can fail in Nth iteration of its
> > > for_each_vma_range() after some iterations already increased nr_pages and it
> > > will do a reattach_vmas() and return the error and we just return the error
> > > here.
> > > I think either here or maybe in vms_gather_munmap_vmas() itself a reset of
> > > vms->nr_pages to zero on error should happen for the vms_abort_munmap_vmas()
> > > to be a no-op?
> >
> > Probably in reattach_vmas()?
> 
> Hm, but that only accepts a mas and seems redundant elsewhere... am going for
> simply resetting nr_pages for now and maybe we can revisit if needs be?

Okay.
diff mbox series

Patch

diff --git a/mm/vma.c b/mm/vma.c
index 7617f9d50d62..a271e2b406ab 100644
--- a/mm/vma.c
+++ b/mm/vma.c
@@ -7,6 +7,31 @@ 
 #include "vma_internal.h"
 #include "vma.h"

+struct mmap_state {
+	struct mm_struct *mm;
+	struct vma_iterator *vmi;
+	struct vma_merge_struct *vmg;
+	struct list_head *uf;
+
+	struct vma_munmap_struct vms;
+	struct ma_state mas_detach;
+	struct maple_tree mt_detach;
+
+	unsigned long flags;
+	unsigned long pglen;
+	unsigned long charged;
+};
+
+#define MMAP_STATE(name, mm_, vmi_, vmg_, uf_, flags_, len_)	\
+	struct mmap_state name = {				\
+		.mm = mm_,					\
+		.vmi = vmi_,					\
+		.vmg = vmg_,					\
+		.uf = uf_,					\
+		.flags = flags_,				\
+		.pglen = PHYS_PFN(len_),			\
+	}
+
 static inline bool is_mergeable_vma(struct vma_merge_struct *vmg, bool merge_next)
 {
 	struct vm_area_struct *vma = merge_next ? vmg->next : vmg->prev;
@@ -2169,189 +2194,247 @@  static void vms_abort_munmap_vmas(struct vma_munmap_struct *vms,
 	vms_complete_munmap_vmas(vms, mas_detach);
 }

-unsigned long __mmap_region(struct file *file, unsigned long addr,
-		unsigned long len, vm_flags_t vm_flags, unsigned long pgoff,
-		struct list_head *uf)
+/*
+ * __mmap_prepare() - Prepare to gather any overlapping VMAs that need to be
+ *                    unmapped once the map operation is completed, check limits,
+ *                    account mapping and clean up any pre-existing VMAs.
+ *
+ * @map: Mapping state.
+ *
+ * Returns: 0 on success, error code otherwise.
+ */
+static int __mmap_prepare(struct mmap_state *map)
 {
-	struct mm_struct *mm = current->mm;
-	struct vm_area_struct *vma = NULL;
-	pgoff_t pglen = PHYS_PFN(len);
-	unsigned long charged = 0;
-	struct vma_munmap_struct vms;
-	struct ma_state mas_detach;
-	struct maple_tree mt_detach;
-	unsigned long end = addr + len;
 	int error;
-	VMA_ITERATOR(vmi, mm, addr);
-	VMG_STATE(vmg, mm, &vmi, addr, end, vm_flags, pgoff);
-
-	vmg.file = file;
-	/* Find the first overlapping VMA */
-	vma = vma_find(&vmi, end);
-	init_vma_munmap(&vms, &vmi, vma, addr, end, uf, /* unlock = */ false);
-	if (vma) {
-		mt_init_flags(&mt_detach, vmi.mas.tree->ma_flags & MT_FLAGS_LOCK_MASK);
-		mt_on_stack(mt_detach);
-		mas_init(&mas_detach, &mt_detach, /* addr = */ 0);
+	struct vma_iterator *vmi = map->vmi;
+	struct vma_merge_struct *vmg = map->vmg;
+	struct vma_munmap_struct *vms = &map->vms;
+
+	/* Find the first overlapping VMA and initialise unmap state. */
+	vms->vma = vma_find(vmi, vmg->end);
+	init_vma_munmap(vms, vmi, vms->vma, vmg->start, vmg->end, map->uf,
+			/* unlock = */ false);
+
+	/* OK, we have overlapping VMAs - prepare to unmap them. */
+	if (vms->vma) {
+		mt_init_flags(&map->mt_detach, vmi->mas.tree->ma_flags & MT_FLAGS_LOCK_MASK);
+		mt_on_stack(map->mt_detach);
+		mas_init(&map->mas_detach, &map->mt_detach, /* addr = */ 0);
 		/* Prepare to unmap any existing mapping in the area */
-		error = vms_gather_munmap_vmas(&vms, &mas_detach);
+		error = vms_gather_munmap_vmas(vms, &map->mas_detach);
 		if (error)
-			goto gather_failed;
+			return error;

-		vmg.next = vms.next;
-		vmg.prev = vms.prev;
-		vma = NULL;
+		vmg->next = vms->next;
+		vmg->prev = vms->prev;
 	} else {
-		vmg.next = vma_iter_next_rewind(&vmi, &vmg.prev);
+		vmg->next = vma_iter_next_rewind(vmi, &vmg->prev);
 	}

 	/* Check against address space limit. */
-	if (!may_expand_vm(mm, vm_flags, pglen - vms.nr_pages)) {
-		error = -ENOMEM;
-		goto abort_munmap;
-	}
+	if (!may_expand_vm(map->mm, map->flags, map->pglen - vms->nr_pages))
+		return -ENOMEM;

-	/*
-	 * Private writable mapping: check memory availability
-	 */
-	if (accountable_mapping(file, vm_flags)) {
-		charged = pglen;
-		charged -= vms.nr_accounted;
-		if (charged) {
-			error = security_vm_enough_memory_mm(mm, charged);
+	/* Private writable mapping: check memory availability. */
+	if (accountable_mapping(vmg->file, map->flags)) {
+		map->charged = map->pglen;
+		map->charged -= vms->nr_accounted;
+		if (map->charged) {
+			error = security_vm_enough_memory_mm(map->mm, map->charged);
 			if (error)
-				goto abort_munmap;
+				return error;
 		}

-		vms.nr_accounted = 0;
-		vm_flags |= VM_ACCOUNT;
-		vmg.flags = vm_flags;
+		vms->nr_accounted = 0;
+		map->flags |= VM_ACCOUNT;
 	}

 	/*
-	 * clear PTEs while the vma is still in the tree so that rmap
+	 * Clear PTEs while the vma is still in the tree so that rmap
 	 * cannot race with the freeing later in the truncate scenario.
 	 * This is also needed for mmap_file(), which is why vm_ops
 	 * close function is called.
 	 */
-	vms_clean_up_area(&vms, &mas_detach);
-	vma = vma_merge_new_range(&vmg);
-	if (vma)
-		goto expanded;
+	vms_clean_up_area(vms, &map->mas_detach);
+
+	return 0;
+}
+
+static int __mmap_new_file_vma(struct mmap_state *map, struct vm_area_struct *vma,
+			       struct vm_area_struct **mergep)
+{
+	struct vma_iterator *vmi = map->vmi;
+	struct vma_merge_struct *vmg = map->vmg;
+	int error;
+
+	vma->vm_file = get_file(vmg->file);
+	error = mmap_file(vma->vm_file, vma);
+	if (error) {
+		fput(vma->vm_file);
+		vma->vm_file = NULL;
+
+		vma_iter_set(vmi, vma->vm_end);
+		/* Undo any partial mapping done by a device driver. */
+		unmap_region(&vmi->mas, vma, vmg->prev, vmg->next);
+
+		return error;
+	}
+
+	/* Drivers cannot alter the address of the VMA. */
+	WARN_ON_ONCE(vmg->start != vma->vm_start);
+	/*
+	 * Drivers should not permit writability when previously it was
+	 * disallowed.
+	 */
+	VM_WARN_ON_ONCE(map->flags != vma->vm_flags &&
+			!(map->flags & VM_MAYWRITE) &&
+			(vma->vm_flags & VM_MAYWRITE));
+
+	vma_iter_config(vmi, vmg->start, vmg->end);
+	/*
+	 * If flags changed after mmap_file(), we should try merge
+	 * vma again as we may succeed this time.
+	 */
+	if (unlikely(map->flags != vma->vm_flags && vmg->prev)) {
+		struct vm_area_struct *merge;
+
+		vmg->flags = vma->vm_flags;
+		/* If this fails, state is reset ready for a reattempt. */
+		merge = vma_merge_new_range(vmg);
+
+		if (merge) {
+			/*
+			 * ->mmap() can change vma->vm_file and fput
+			 * the original file. So fput the vma->vm_file
+			 * here or we would add an extra fput for file
+			 * and cause general protection fault
+			 * ultimately.
+			 */
+			fput(vma->vm_file);
+			vm_area_free(vma);
+			vma_iter_free(vmi);
+			*mergep = merge;
+		} else {
+			vma_iter_config(vmi, vmg->start, vmg->end);
+		}
+	}
+
+	map->flags = vma->vm_flags;
+	return 0;
+}
+
+/*
+ * __mmap_new_vma() - Allocate a new VMA for the region, as merging was not
+ *                    possible.
+ *
+ *                    An exception to this is if the mapping is file-backed, and
+ *                    the underlying driver changes the VMA flags, permitting a
+ *                    subsequent merge of the VMA, in which case the returned
+ *                    VMA is one that was merged on a second attempt.
+ *
+ * @map:  Mapping state.
+ * @vmap: Output pointer for the new VMA.
+ *
+ * Returns: Zero on success, or an error.
+ */
+static int __mmap_new_vma(struct mmap_state *map, struct vm_area_struct **vmap)
+{
+	struct vma_iterator *vmi = map->vmi;
+	struct vma_merge_struct *vmg = map->vmg;
+	struct vm_area_struct *merge = NULL;
+	int error = 0;
+	struct vm_area_struct *vma;
+
 	/*
 	 * Determine the object being mapped and call the appropriate
 	 * specific mapper. the address has already been validated, but
 	 * not unmapped, but the maps are removed from the list.
 	 */
-	vma = vm_area_alloc(mm);
-	if (!vma) {
-		error = -ENOMEM;
-		goto unacct_error;
-	}
+	vma = vm_area_alloc(map->mm);
+	if (!vma)
+		return -ENOMEM;

-	vma_iter_config(&vmi, addr, end);
-	vma_set_range(vma, addr, end, pgoff);
-	vm_flags_init(vma, vm_flags);
-	vma->vm_page_prot = vm_get_page_prot(vm_flags);
+	vma_iter_config(vmi, vmg->start, vmg->end);
+	vma_set_range(vma, vmg->start, vmg->end, vmg->pgoff);
+	vm_flags_init(vma, map->flags);
+	vma->vm_page_prot = vm_get_page_prot(map->flags);

-	if (vma_iter_prealloc(&vmi, vma)) {
+	if (vma_iter_prealloc(vmi, vma)) {
 		error = -ENOMEM;
 		goto free_vma;
 	}

-	if (file) {
-		vma->vm_file = get_file(file);
-		error = mmap_file(file, vma);
-		if (error)
-			goto unmap_and_free_file_vma;
-
-		/* Drivers cannot alter the address of the VMA. */
-		WARN_ON_ONCE(addr != vma->vm_start);
-		/*
-		 * Drivers should not permit writability when previously it was
-		 * disallowed.
-		 */
-		VM_WARN_ON_ONCE(vm_flags != vma->vm_flags &&
-				!(vm_flags & VM_MAYWRITE) &&
-				(vma->vm_flags & VM_MAYWRITE));
-
-		vma_iter_config(&vmi, addr, end);
-		/*
-		 * If vm_flags changed after mmap_file(), we should try merge
-		 * vma again as we may succeed this time.
-		 */
-		if (unlikely(vm_flags != vma->vm_flags && vmg.prev)) {
-			struct vm_area_struct *merge;
-
-			vmg.flags = vma->vm_flags;
-			/* If this fails, state is reset ready for a reattempt. */
-			merge = vma_merge_new_range(&vmg);
-
-			if (merge) {
-				/*
-				 * ->mmap() can change vma->vm_file and fput
-				 * the original file. So fput the vma->vm_file
-				 * here or we would add an extra fput for file
-				 * and cause general protection fault
-				 * ultimately.
-				 */
-				fput(vma->vm_file);
-				vm_area_free(vma);
-				vma_iter_free(&vmi);
-				vma = merge;
-				/* Update vm_flags to pick up the change. */
-				vm_flags = vma->vm_flags;
-				goto file_expanded;
-			}
-			vma_iter_config(&vmi, addr, end);
-		}
-
-		vm_flags = vma->vm_flags;
-	} else if (vm_flags & VM_SHARED) {
+	if (vmg->file)
+		error = __mmap_new_file_vma(map, vma, &merge);
+	else if (map->flags & VM_SHARED)
 		error = shmem_zero_setup(vma);
-		if (error)
-			goto free_iter_vma;
-	} else {
+	else
 		vma_set_anonymous(vma);
-	}
+
+	if (error)
+		goto free_iter_vma;
+
+	if (merge)
+		goto file_expanded;

 #ifdef CONFIG_SPARC64
 	/* TODO: Fix SPARC ADI! */
-	WARN_ON_ONCE(!arch_validate_flags(vm_flags));
+	WARN_ON_ONCE(!arch_validate_flags(map->flags));
 #endif

 	/* Lock the VMA since it is modified after insertion into VMA tree */
 	vma_start_write(vma);
-	vma_iter_store(&vmi, vma);
-	mm->map_count++;
+	vma_iter_store(vmi, vma);
+	map->mm->map_count++;
 	vma_link_file(vma);

 	/*
 	 * vma_merge_new_range() calls khugepaged_enter_vma() too, the below
 	 * call covers the non-merge case.
 	 */
-	khugepaged_enter_vma(vma, vma->vm_flags);
+	khugepaged_enter_vma(vma, map->flags);

 file_expanded:
-	file = vma->vm_file;
 	ksm_add_vma(vma);
-expanded:
+
+	*vmap = vma;
+	return 0;
+
+free_iter_vma:
+	vma_iter_free(vmi);
+free_vma:
+	vm_area_free(vma);
+	return error;
+}
+
+/*
+ * __mmap_complete() - Unmap any VMAs we overlap, account memory mapping
+ *                     statistics, handle locking and finalise the VMA.
+ *
+ * @map: Mapping state.
+ * @vma: Merged or newly allocated VMA for the mmap()'d region.
+ */
+static void __mmap_complete(struct mmap_state *map, struct vm_area_struct *vma)
+{
+	struct mm_struct *mm = map->mm;
+	unsigned long vm_flags = vma->vm_flags;
+
 	perf_event_mmap(vma);

-	/* Unmap any existing mapping in the area */
-	vms_complete_munmap_vmas(&vms, &mas_detach);
+	/* Unmap any existing mapping in the area. */
+	vms_complete_munmap_vmas(&map->vms, &map->mas_detach);

-	vm_stat_account(mm, vm_flags, pglen);
+	vm_stat_account(mm, vma->vm_flags, map->pglen);
 	if (vm_flags & VM_LOCKED) {
 		if ((vm_flags & VM_SPECIAL) || vma_is_dax(vma) ||
 					is_vm_hugetlb_page(vma) ||
-					vma == get_gate_vma(current->mm))
+					vma == get_gate_vma(mm))
 			vm_flags_clear(vma, VM_LOCKED_MASK);
 		else
-			mm->locked_vm += pglen;
+			mm->locked_vm += map->pglen;
 	}

-	if (file)
+	if (vma->vm_file)
 		uprobe_mmap(vma);

 	/*
@@ -2364,26 +2447,43 @@  unsigned long __mmap_region(struct file *file, unsigned long addr,
 	vm_flags_set(vma, VM_SOFTDIRTY);

 	vma_set_page_prot(vma);
+}

-	return addr;
+unsigned long __mmap_region(struct file *file, unsigned long addr,
+		unsigned long len, vm_flags_t vm_flags, unsigned long pgoff,
+		struct list_head *uf)
+{
+	struct mm_struct *mm = current->mm;
+	struct vm_area_struct *vma;
+	int error;
+	VMA_ITERATOR(vmi, mm, addr);
+	VMG_STATE(vmg, mm, &vmi, addr, addr + len, vm_flags, pgoff);
+	MMAP_STATE(map, mm, &vmi, &vmg, uf, vm_flags, len);

-unmap_and_free_file_vma:
-	fput(vma->vm_file);
-	vma->vm_file = NULL;
+	vmg.file = file;

-	vma_iter_set(&vmi, vma->vm_end);
-	/* Undo any partial mapping done by a device driver. */
-	unmap_region(&vmi.mas, vma, vmg.prev, vmg.next);
-free_iter_vma:
-	vma_iter_free(&vmi);
-free_vma:
-	vm_area_free(vma);
-unacct_error:
-	if (charged)
-		vm_unacct_memory(charged);
+	error = __mmap_prepare(&map);
+	if (error)
+		goto abort_munmap;
+
+	/* Attempt to merge with adjacent VMAs... */
+	vmg.flags = map.flags;
+	vma = vma_merge_new_range(&vmg);
+	if (!vma) {
+		/* ...but if we can't, allocate a new VMA. */
+		error = __mmap_new_vma(&map, &vma);
+		if (error)
+			goto unacct_error;
+	}
+
+	__mmap_complete(&map, vma);

+	return addr;
+
+unacct_error:
+	if (map.charged)
+		vm_unacct_memory(map.charged);
 abort_munmap:
-	vms_abort_munmap_vmas(&vms, &mas_detach);
-gather_failed:
+	vms_abort_munmap_vmas(&map.vms, &map.mas_detach);
 	return error;
 }