diff mbox series

[06/14] mm: teach the mm about range locking

Message ID 20190521045242.24378-7-dave@stgolabs.net (mailing list archive)
State New, archived
Headers show
Series mmap_sem range locking | expand

Commit Message

Davidlohr Bueso May 21, 2019, 4:52 a.m. UTC
Conversion is straightforward, mmap_sem is used within the
the same function context most of the time, and we already
have vmf updated. No changes in semantics.

Signed-off-by: Davidlohr Bueso <dbueso@suse.de>
---
 include/linux/mm.h     |  8 +++---
 mm/filemap.c           |  8 +++---
 mm/frame_vector.c      |  4 +--
 mm/gup.c               | 21 +++++++--------
 mm/hmm.c               |  3 ++-
 mm/khugepaged.c        | 54 +++++++++++++++++++++------------------
 mm/ksm.c               | 42 +++++++++++++++++-------------
 mm/madvise.c           | 36 ++++++++++++++------------
 mm/memcontrol.c        | 10 +++++---
 mm/memory.c            | 10 +++++---
 mm/mempolicy.c         | 25 ++++++++++--------
 mm/migrate.c           | 10 +++++---
 mm/mincore.c           |  6 +++--
 mm/mlock.c             | 20 +++++++++------
 mm/mmap.c              | 69 ++++++++++++++++++++++++++++----------------------
 mm/mmu_notifier.c      |  9 ++++---
 mm/mprotect.c          | 15 ++++++-----
 mm/mremap.c            |  9 ++++---
 mm/msync.c             |  9 ++++---
 mm/nommu.c             | 25 ++++++++++--------
 mm/oom_kill.c          |  5 ++--
 mm/process_vm_access.c |  4 +--
 mm/shmem.c             |  2 +-
 mm/swapfile.c          |  5 ++--
 mm/userfaultfd.c       | 21 ++++++++-------
 mm/util.c              | 10 +++++---
 26 files changed, 252 insertions(+), 188 deletions(-)
diff mbox series

Patch

diff --git a/include/linux/mm.h b/include/linux/mm.h
index 044e428b1905..8bf3e2542047 100644
--- a/include/linux/mm.h
+++ b/include/linux/mm.h
@@ -1459,6 +1459,7 @@  void unmap_vmas(struct mmu_gather *tlb, struct vm_area_struct *start_vma,
  *             right now." 1 means "skip the current vma."
  * @mm:        mm_struct representing the target process of page table walk
  * @vma:       vma currently walked (NULL if walking outside vmas)
+ * @mmrange:   mm address space range locking
  * @private:   private data for callbacks' usage
  *
  * (see the comment on walk_page_range() for more details)
@@ -2358,8 +2359,8 @@  static inline int check_data_rlimit(unsigned long rlim,
 	return 0;
 }
 
-extern int mm_take_all_locks(struct mm_struct *mm);
-extern void mm_drop_all_locks(struct mm_struct *mm);
+extern int mm_take_all_locks(struct mm_struct *mm, struct range_lock *mmrange);
+extern void mm_drop_all_locks(struct mm_struct *mm, struct range_lock *mmrange);
 
 extern void set_mm_exe_file(struct mm_struct *mm, struct file *new_exe_file);
 extern struct file *get_mm_exe_file(struct mm_struct *mm);
@@ -2389,7 +2390,8 @@  extern unsigned long do_mmap(struct file *file, unsigned long addr,
 	vm_flags_t vm_flags, unsigned long pgoff, unsigned long *populate,
 	struct list_head *uf);
 extern int __do_munmap(struct mm_struct *, unsigned long, size_t,
-		       struct list_head *uf, bool downgrade);
+		       struct list_head *uf, bool downgrade,
+		       struct range_lock *);
 extern int do_munmap(struct mm_struct *, unsigned long, size_t,
 		     struct list_head *uf);
 
diff --git a/mm/filemap.c b/mm/filemap.c
index 959022841bab..71f0d8a18f40 100644
--- a/mm/filemap.c
+++ b/mm/filemap.c
@@ -1388,7 +1388,7 @@  int __lock_page_or_retry(struct page *page, struct mm_struct *mm,
 		if (flags & FAULT_FLAG_RETRY_NOWAIT)
 			return 0;
 
-		up_read(&mm->mmap_sem);
+		mm_read_unlock(mm, mmrange);
 		if (flags & FAULT_FLAG_KILLABLE)
 			wait_on_page_locked_killable(page);
 		else
@@ -1400,7 +1400,7 @@  int __lock_page_or_retry(struct page *page, struct mm_struct *mm,
 
 			ret = __lock_page_killable(page);
 			if (ret) {
-				up_read(&mm->mmap_sem);
+				mm_read_unlock(mm, mmrange);
 				return 0;
 			}
 		} else
@@ -2317,7 +2317,7 @@  static struct file *maybe_unlock_mmap_for_io(struct vm_fault *vmf,
 	if ((flags & (FAULT_FLAG_ALLOW_RETRY | FAULT_FLAG_RETRY_NOWAIT)) ==
 	    FAULT_FLAG_ALLOW_RETRY) {
 		fpin = get_file(vmf->vma->vm_file);
-		up_read(&vmf->vma->vm_mm->mmap_sem);
+		mm_read_unlock(vmf->vma->vm_mm, vmf->lockrange);
 	}
 	return fpin;
 }
@@ -2357,7 +2357,7 @@  static int lock_page_maybe_drop_mmap(struct vm_fault *vmf, struct page *page,
 			 * mmap_sem here and return 0 if we don't have a fpin.
 			 */
 			if (*fpin == NULL)
-				up_read(&vmf->vma->vm_mm->mmap_sem);
+				mm_read_unlock(vmf->vma->vm_mm, vmf->lockrange);
 			return 0;
 		}
 	} else
diff --git a/mm/frame_vector.c b/mm/frame_vector.c
index 4e1a577cbb79..ef33d21b3f39 100644
--- a/mm/frame_vector.c
+++ b/mm/frame_vector.c
@@ -47,7 +47,7 @@  int get_vaddr_frames(unsigned long start, unsigned int nr_frames,
 	if (WARN_ON_ONCE(nr_frames > vec->nr_allocated))
 		nr_frames = vec->nr_allocated;
 
-	down_read(&mm->mmap_sem);
+	mm_read_lock(mm, &mmrange);
 	locked = 1;
 	vma = find_vma_intersection(mm, start, start + 1);
 	if (!vma) {
@@ -102,7 +102,7 @@  int get_vaddr_frames(unsigned long start, unsigned int nr_frames,
 	} while (vma && vma->vm_flags & (VM_IO | VM_PFNMAP));
 out:
 	if (locked)
-		up_read(&mm->mmap_sem);
+		mm_read_unlock(mm, &mmrange);
 	if (!ret)
 		ret = -EFAULT;
 	if (ret > 0)
diff --git a/mm/gup.c b/mm/gup.c
index cf8fa037ce27..70b546a01682 100644
--- a/mm/gup.c
+++ b/mm/gup.c
@@ -990,7 +990,7 @@  int fixup_user_fault(struct task_struct *tsk, struct mm_struct *mm,
 	}
 
 	if (ret & VM_FAULT_RETRY) {
-		down_read(&mm->mmap_sem);
+		mm_read_lock(mm, mmrange);
 		if (!(fault_flags & FAULT_FLAG_TRIED)) {
 			*unlocked = true;
 			fault_flags &= ~FAULT_FLAG_ALLOW_RETRY;
@@ -1077,7 +1077,7 @@  static __always_inline long __get_user_pages_locked(struct task_struct *tsk,
 		 */
 		*locked = 1;
 		lock_dropped = true;
-		down_read(&mm->mmap_sem);
+		mm_read_lock(mm, mmrange);
 		ret = __get_user_pages(tsk, mm, start, 1, flags | FOLL_TRIED,
 				       pages, NULL, NULL, NULL);
 		if (ret != 1) {
@@ -1098,7 +1098,7 @@  static __always_inline long __get_user_pages_locked(struct task_struct *tsk,
 		 * We must let the caller know we temporarily dropped the lock
 		 * and so the critical section protected by it was lost.
 		 */
-		up_read(&mm->mmap_sem);
+		mm_read_unlock(mm, mmrange);
 		*locked = 0;
 	}
 	return pages_done;
@@ -1176,11 +1176,11 @@  long get_user_pages_unlocked(unsigned long start, unsigned long nr_pages,
 	if (WARN_ON_ONCE(gup_flags & FOLL_LONGTERM))
 		return -EINVAL;
 
-	down_read(&mm->mmap_sem);
+	mm_read_lock(mm, &mmrange);
 	ret = __get_user_pages_locked(current, mm, start, nr_pages, pages, NULL,
 				      &locked, gup_flags | FOLL_TOUCH, &mmrange);
 	if (locked)
-		up_read(&mm->mmap_sem);
+		mm_read_unlock(mm, &mmrange);
 	return ret;
 }
 EXPORT_SYMBOL(get_user_pages_unlocked);
@@ -1543,7 +1543,7 @@  long populate_vma_page_range(struct vm_area_struct *vma,
 	VM_BUG_ON(end   & ~PAGE_MASK);
 	VM_BUG_ON_VMA(start < vma->vm_start, vma);
 	VM_BUG_ON_VMA(end   > vma->vm_end, vma);
-	VM_BUG_ON_MM(!rwsem_is_locked(&mm->mmap_sem), mm);
+	VM_BUG_ON_MM(!mm_is_locked(mm, mmrange), mm);
 
 	gup_flags = FOLL_TOUCH | FOLL_POPULATE | FOLL_MLOCK;
 	if (vma->vm_flags & VM_LOCKONFAULT)
@@ -1596,7 +1596,7 @@  int __mm_populate(unsigned long start, unsigned long len, int ignore_errors)
 		 */
 		if (!locked) {
 			locked = 1;
-			down_read(&mm->mmap_sem);
+			mm_read_lock(mm, &mmrange);
 			vma = find_vma(mm, nstart);
 		} else if (nstart >= vma->vm_end)
 			vma = vma->vm_next;
@@ -1628,7 +1628,7 @@  int __mm_populate(unsigned long start, unsigned long len, int ignore_errors)
 		ret = 0;
 	}
 	if (locked)
-		up_read(&mm->mmap_sem);
+		mm_read_unlock(mm, &mmrange);
 	return ret;	/* 0 or negative error code */
 }
 
@@ -2189,17 +2189,18 @@  static int __gup_longterm_unlocked(unsigned long start, int nr_pages,
 				   unsigned int gup_flags, struct page **pages)
 {
 	int ret;
+	DEFINE_RANGE_LOCK_FULL(mmrange);
 
 	/*
 	 * FIXME: FOLL_LONGTERM does not work with
 	 * get_user_pages_unlocked() (see comments in that function)
 	 */
 	if (gup_flags & FOLL_LONGTERM) {
-		down_read(&current->mm->mmap_sem);
+		mm_read_lock(current->mm, &mmrange);
 		ret = __gup_longterm_locked(current, current->mm,
 					    start, nr_pages,
 					    pages, NULL, gup_flags);
-		up_read(&current->mm->mmap_sem);
+		mm_read_unlock(current->mm, &mmrange);
 	} else {
 		ret = get_user_pages_unlocked(start, nr_pages,
 					      pages, gup_flags);
diff --git a/mm/hmm.c b/mm/hmm.c
index 723109ac6bdc..a79a07f7ccc1 100644
--- a/mm/hmm.c
+++ b/mm/hmm.c
@@ -1118,7 +1118,8 @@  long hmm_range_fault(struct hmm_range *range, bool block)
 	do {
 		/* If range is no longer valid force retry. */
 		if (!range->valid) {
-			up_read(&hmm->mm->mmap_sem);
+			/*** BROKEN mmrange, we don't care about hmm (for now) */
+			mm_read_unlock(hmm->mm, NULL);
 			return -EAGAIN;
 		}
 
diff --git a/mm/khugepaged.c b/mm/khugepaged.c
index 3eefcb8f797d..13d8e29f4674 100644
--- a/mm/khugepaged.c
+++ b/mm/khugepaged.c
@@ -488,6 +488,8 @@  void __khugepaged_exit(struct mm_struct *mm)
 		free_mm_slot(mm_slot);
 		mmdrop(mm);
 	} else if (mm_slot) {
+		DEFINE_RANGE_LOCK_FULL(mmrange);
+
 		/*
 		 * This is required to serialize against
 		 * khugepaged_test_exit() (which is guaranteed to run
@@ -496,8 +498,8 @@  void __khugepaged_exit(struct mm_struct *mm)
 		 * khugepaged has finished working on the pagetables
 		 * under the mmap_sem.
 		 */
-		down_write(&mm->mmap_sem);
-		up_write(&mm->mmap_sem);
+		mm_write_lock(mm, &mmrange);
+		mm_write_unlock(mm, &mmrange);
 	}
 }
 
@@ -908,7 +910,7 @@  static bool __collapse_huge_page_swapin(struct mm_struct *mm,
 
 		/* do_swap_page returns VM_FAULT_RETRY with released mmap_sem */
 		if (ret & VM_FAULT_RETRY) {
-			down_read(&mm->mmap_sem);
+			mm_read_lock(mm, mmrange);
 			if (hugepage_vma_revalidate(mm, address, &vmf.vma)) {
 				/* vma is no longer available, don't continue to swapin */
 				trace_mm_collapse_huge_page_swapin(mm, swapped_in, referenced, 0);
@@ -961,7 +963,7 @@  static void collapse_huge_page(struct mm_struct *mm,
 	 * sync compaction, and we do not need to hold the mmap_sem during
 	 * that. We will recheck the vma after taking it again in write mode.
 	 */
-	up_read(&mm->mmap_sem);
+	mm_read_unlock(mm, mmrange);
 	new_page = khugepaged_alloc_page(hpage, gfp, node);
 	if (!new_page) {
 		result = SCAN_ALLOC_HUGE_PAGE_FAIL;
@@ -973,11 +975,11 @@  static void collapse_huge_page(struct mm_struct *mm,
 		goto out_nolock;
 	}
 
-	down_read(&mm->mmap_sem);
+	mm_read_lock(mm, mmrange);
 	result = hugepage_vma_revalidate(mm, address, &vma);
 	if (result) {
 		mem_cgroup_cancel_charge(new_page, memcg, true);
-		up_read(&mm->mmap_sem);
+		mm_read_unlock(mm, mmrange);
 		goto out_nolock;
 	}
 
@@ -985,7 +987,7 @@  static void collapse_huge_page(struct mm_struct *mm,
 	if (!pmd) {
 		result = SCAN_PMD_NULL;
 		mem_cgroup_cancel_charge(new_page, memcg, true);
-		up_read(&mm->mmap_sem);
+		mm_read_unlock(mm, mmrange);
 		goto out_nolock;
 	}
 
@@ -997,17 +999,17 @@  static void collapse_huge_page(struct mm_struct *mm,
 	if (!__collapse_huge_page_swapin(mm, vma, address, pmd,
 					 referenced, mmrange)) {
 		mem_cgroup_cancel_charge(new_page, memcg, true);
-		up_read(&mm->mmap_sem);
+		mm_read_unlock(mm, mmrange);
 		goto out_nolock;
 	}
 
-	up_read(&mm->mmap_sem);
+	mm_read_unlock(mm, mmrange);
 	/*
 	 * Prevent all access to pagetables with the exception of
 	 * gup_fast later handled by the ptep_clear_flush and the VM
 	 * handled by the anon_vma lock + PG_lock.
 	 */
-	down_write(&mm->mmap_sem);
+	mm_write_lock(mm, mmrange);
 	result = hugepage_vma_revalidate(mm, address, &vma);
 	if (result)
 		goto out;
@@ -1091,7 +1093,7 @@  static void collapse_huge_page(struct mm_struct *mm,
 	khugepaged_pages_collapsed++;
 	result = SCAN_SUCCEED;
 out_up_write:
-	up_write(&mm->mmap_sem);
+	mm_write_unlock(mm, mmrange);
 out_nolock:
 	trace_mm_collapse_huge_page(mm, isolated, result);
 	return;
@@ -1250,7 +1252,8 @@  static void collect_mm_slot(struct mm_slot *mm_slot)
 }
 
 #if defined(CONFIG_SHMEM) && defined(CONFIG_TRANSPARENT_HUGE_PAGECACHE)
-static void retract_page_tables(struct address_space *mapping, pgoff_t pgoff)
+static void retract_page_tables(struct address_space *mapping, pgoff_t pgoff,
+				struct range_lock *mmrange)
 {
 	struct vm_area_struct *vma;
 	unsigned long addr;
@@ -1275,12 +1278,12 @@  static void retract_page_tables(struct address_space *mapping, pgoff_t pgoff)
 		 * re-fault. Not ideal, but it's more important to not disturb
 		 * the system too much.
 		 */
-		if (down_write_trylock(&vma->vm_mm->mmap_sem)) {
+		if (mm_write_trylock(vma->vm_mm, mmrange)) {
 			spinlock_t *ptl = pmd_lock(vma->vm_mm, pmd);
 			/* assume page table is clear */
 			_pmd = pmdp_collapse_flush(vma, addr, pmd);
 			spin_unlock(ptl);
-			up_write(&vma->vm_mm->mmap_sem);
+			mm_write_unlock(vma->vm_mm, mmrange);
 			mm_dec_nr_ptes(vma->vm_mm);
 			pte_free(vma->vm_mm, pmd_pgtable(_pmd));
 		}
@@ -1307,8 +1310,9 @@  static void retract_page_tables(struct address_space *mapping, pgoff_t pgoff)
  *    + unlock and free huge page;
  */
 static void collapse_shmem(struct mm_struct *mm,
-		struct address_space *mapping, pgoff_t start,
-		struct page **hpage, int node)
+			   struct address_space *mapping, pgoff_t start,
+			   struct page **hpage, int node,
+			   struct range_lock *mmrange)
 {
 	gfp_t gfp;
 	struct page *new_page;
@@ -1515,7 +1519,7 @@  static void collapse_shmem(struct mm_struct *mm,
 		/*
 		 * Remove pte page tables, so we can re-fault the page as huge.
 		 */
-		retract_page_tables(mapping, start);
+		retract_page_tables(mapping, start, mmrange);
 		*hpage = NULL;
 
 		khugepaged_pages_collapsed++;
@@ -1566,8 +1570,9 @@  static void collapse_shmem(struct mm_struct *mm,
 }
 
 static void khugepaged_scan_shmem(struct mm_struct *mm,
-		struct address_space *mapping,
-		pgoff_t start, struct page **hpage)
+				  struct address_space *mapping,
+				  pgoff_t start, struct page **hpage,
+				  struct range_lock *mmrange)
 {
 	struct page *page = NULL;
 	XA_STATE(xas, &mapping->i_pages, start);
@@ -1633,7 +1638,8 @@  static void khugepaged_scan_shmem(struct mm_struct *mm,
 			result = SCAN_EXCEED_NONE_PTE;
 		} else {
 			node = khugepaged_find_target_node();
-			collapse_shmem(mm, mapping, start, hpage, node);
+			collapse_shmem(mm, mapping, start, hpage,
+				       node, mmrange);
 		}
 	}
 
@@ -1678,7 +1684,7 @@  static unsigned int khugepaged_scan_mm_slot(unsigned int pages,
 	 * the next mm on the list.
 	 */
 	vma = NULL;
-	if (unlikely(!down_read_trylock(&mm->mmap_sem)))
+	if (unlikely(!mm_read_trylock(mm, &mmrange)))
 		goto breakouterloop_mmap_sem;
 	if (likely(!khugepaged_test_exit(mm)))
 		vma = find_vma(mm, khugepaged_scan.address);
@@ -1723,10 +1729,10 @@  static unsigned int khugepaged_scan_mm_slot(unsigned int pages,
 				if (!shmem_huge_enabled(vma))
 					goto skip;
 				file = get_file(vma->vm_file);
-				up_read(&mm->mmap_sem);
+				mm_read_unlock(mm, &mmrange);
 				ret = 1;
 				khugepaged_scan_shmem(mm, file->f_mapping,
-						pgoff, hpage);
+						      pgoff, hpage, &mmrange);
 				fput(file);
 			} else {
 				ret = khugepaged_scan_pmd(mm, vma,
@@ -1744,7 +1750,7 @@  static unsigned int khugepaged_scan_mm_slot(unsigned int pages,
 		}
 	}
 breakouterloop:
-	up_read(&mm->mmap_sem); /* exit_mmap will destroy ptes after this */
+	mm_read_unlock(mm, &mmrange); /* exit_mmap will destroy ptes after this */
 breakouterloop_mmap_sem:
 
 	spin_lock(&khugepaged_mm_lock);
diff --git a/mm/ksm.c b/mm/ksm.c
index ccc9737311eb..7f9826ea7dba 100644
--- a/mm/ksm.c
+++ b/mm/ksm.c
@@ -537,6 +537,7 @@  static void break_cow(struct rmap_item *rmap_item)
 	struct mm_struct *mm = rmap_item->mm;
 	unsigned long addr = rmap_item->address;
 	struct vm_area_struct *vma;
+	DEFINE_RANGE_LOCK_FULL(mmrange);
 
 	/*
 	 * It is not an accident that whenever we want to break COW
@@ -544,11 +545,11 @@  static void break_cow(struct rmap_item *rmap_item)
 	 */
 	put_anon_vma(rmap_item->anon_vma);
 
-	down_read(&mm->mmap_sem);
+	mm_read_lock(mm, &mmrange);
 	vma = find_mergeable_vma(mm, addr);
 	if (vma)
 		break_ksm(vma, addr);
-	up_read(&mm->mmap_sem);
+	mm_read_unlock(mm, &mmrange);
 }
 
 static struct page *get_mergeable_page(struct rmap_item *rmap_item)
@@ -557,8 +558,9 @@  static struct page *get_mergeable_page(struct rmap_item *rmap_item)
 	unsigned long addr = rmap_item->address;
 	struct vm_area_struct *vma;
 	struct page *page;
+	DEFINE_RANGE_LOCK_FULL(mmrange);
 
-	down_read(&mm->mmap_sem);
+	mm_read_lock(mm, &mmrange);
 	vma = find_mergeable_vma(mm, addr);
 	if (!vma)
 		goto out;
@@ -574,7 +576,7 @@  static struct page *get_mergeable_page(struct rmap_item *rmap_item)
 out:
 		page = NULL;
 	}
-	up_read(&mm->mmap_sem);
+	mm_read_unlock(mm, &mmrange);
 	return page;
 }
 
@@ -969,6 +971,7 @@  static int unmerge_and_remove_all_rmap_items(void)
 	struct mm_struct *mm;
 	struct vm_area_struct *vma;
 	int err = 0;
+	DEFINE_RANGE_LOCK_FULL(mmrange);
 
 	spin_lock(&ksm_mmlist_lock);
 	ksm_scan.mm_slot = list_entry(ksm_mm_head.mm_list.next,
@@ -978,7 +981,7 @@  static int unmerge_and_remove_all_rmap_items(void)
 	for (mm_slot = ksm_scan.mm_slot;
 			mm_slot != &ksm_mm_head; mm_slot = ksm_scan.mm_slot) {
 		mm = mm_slot->mm;
-		down_read(&mm->mmap_sem);
+		mm_read_lock(mm, &mmrange);
 		for (vma = mm->mmap; vma; vma = vma->vm_next) {
 			if (ksm_test_exit(mm))
 				break;
@@ -991,7 +994,7 @@  static int unmerge_and_remove_all_rmap_items(void)
 		}
 
 		remove_trailing_rmap_items(mm_slot, &mm_slot->rmap_list);
-		up_read(&mm->mmap_sem);
+		mm_read_unlock(mm, &mmrange);
 
 		spin_lock(&ksm_mmlist_lock);
 		ksm_scan.mm_slot = list_entry(mm_slot->mm_list.next,
@@ -1014,7 +1017,7 @@  static int unmerge_and_remove_all_rmap_items(void)
 	return 0;
 
 error:
-	up_read(&mm->mmap_sem);
+	mm_read_unlock(mm, &mmrange);
 	spin_lock(&ksm_mmlist_lock);
 	ksm_scan.mm_slot = &ksm_mm_head;
 	spin_unlock(&ksm_mmlist_lock);
@@ -1299,8 +1302,9 @@  static int try_to_merge_with_ksm_page(struct rmap_item *rmap_item,
 	struct mm_struct *mm = rmap_item->mm;
 	struct vm_area_struct *vma;
 	int err = -EFAULT;
+	DEFINE_RANGE_LOCK_FULL(mmrange);
 
-	down_read(&mm->mmap_sem);
+	mm_read_lock(mm, &mmrange);
 	vma = find_mergeable_vma(mm, rmap_item->address);
 	if (!vma)
 		goto out;
@@ -1316,7 +1320,7 @@  static int try_to_merge_with_ksm_page(struct rmap_item *rmap_item,
 	rmap_item->anon_vma = vma->anon_vma;
 	get_anon_vma(vma->anon_vma);
 out:
-	up_read(&mm->mmap_sem);
+	mm_read_unlock(mm, &mmrange);
 	return err;
 }
 
@@ -2129,12 +2133,13 @@  static void cmp_and_merge_page(struct page *page, struct rmap_item *rmap_item)
 	 */
 	if (ksm_use_zero_pages && (checksum == zero_checksum)) {
 		struct vm_area_struct *vma;
+		DEFINE_RANGE_LOCK_FULL(mmrange);
 
-		down_read(&mm->mmap_sem);
+		mm_read_lock(mm, &mmrange);
 		vma = find_mergeable_vma(mm, rmap_item->address);
 		err = try_to_merge_one_page(vma, page,
 					    ZERO_PAGE(rmap_item->address));
-		up_read(&mm->mmap_sem);
+		mm_read_unlock(mm, &mmrange);
 		/*
 		 * In case of failure, the page was not really empty, so we
 		 * need to continue. Otherwise we're done.
@@ -2240,6 +2245,7 @@  static struct rmap_item *scan_get_next_rmap_item(struct page **page)
 	struct vm_area_struct *vma;
 	struct rmap_item *rmap_item;
 	int nid;
+	DEFINE_RANGE_LOCK_FULL(mmrange);
 
 	if (list_empty(&ksm_mm_head.mm_list))
 		return NULL;
@@ -2297,7 +2303,7 @@  static struct rmap_item *scan_get_next_rmap_item(struct page **page)
 	}
 
 	mm = slot->mm;
-	down_read(&mm->mmap_sem);
+	mm_read_lock(mm, &mmrange);
 	if (ksm_test_exit(mm))
 		vma = NULL;
 	else
@@ -2331,7 +2337,7 @@  static struct rmap_item *scan_get_next_rmap_item(struct page **page)
 					ksm_scan.address += PAGE_SIZE;
 				} else
 					put_page(*page);
-				up_read(&mm->mmap_sem);
+				mm_read_unlock(mm, &mmrange);
 				return rmap_item;
 			}
 			put_page(*page);
@@ -2369,10 +2375,10 @@  static struct rmap_item *scan_get_next_rmap_item(struct page **page)
 
 		free_mm_slot(slot);
 		clear_bit(MMF_VM_MERGEABLE, &mm->flags);
-		up_read(&mm->mmap_sem);
+		mm_read_unlock(mm, &mmrange);
 		mmdrop(mm);
 	} else {
-		up_read(&mm->mmap_sem);
+		mm_read_unlock(mm, &mmrange);
 		/*
 		 * up_read(&mm->mmap_sem) first because after
 		 * spin_unlock(&ksm_mmlist_lock) run, the "mm" may
@@ -2571,8 +2577,10 @@  void __ksm_exit(struct mm_struct *mm)
 		clear_bit(MMF_VM_MERGEABLE, &mm->flags);
 		mmdrop(mm);
 	} else if (mm_slot) {
-		down_write(&mm->mmap_sem);
-		up_write(&mm->mmap_sem);
+		DEFINE_RANGE_LOCK_FULL(mmrange);
+
+		mm_write_lock(mm, &mmrange);
+		mm_write_unlock(mm, &mmrange);
 	}
 }
 
diff --git a/mm/madvise.c b/mm/madvise.c
index 628022e674a7..78a3f86d9c52 100644
--- a/mm/madvise.c
+++ b/mm/madvise.c
@@ -516,16 +516,16 @@  static long madvise_dontneed_single_vma(struct vm_area_struct *vma,
 static long madvise_dontneed_free(struct vm_area_struct *vma,
 				  struct vm_area_struct **prev,
 				  unsigned long start, unsigned long end,
-				  int behavior)
+				  int behavior, struct range_lock *mmrange)
 {
 	*prev = vma;
 	if (!can_madv_dontneed_vma(vma))
 		return -EINVAL;
 
-	if (!userfaultfd_remove(vma, start, end)) {
+	if (!userfaultfd_remove(vma, start, end, mmrange)) {
 		*prev = NULL; /* mmap_sem has been dropped, prev is stale */
 
-		down_read(&current->mm->mmap_sem);
+		mm_read_lock(current->mm, mmrange);
 		vma = find_vma(current->mm, start);
 		if (!vma)
 			return -ENOMEM;
@@ -574,8 +574,9 @@  static long madvise_dontneed_free(struct vm_area_struct *vma,
  * This is effectively punching a hole into the middle of a file.
  */
 static long madvise_remove(struct vm_area_struct *vma,
-				struct vm_area_struct **prev,
-				unsigned long start, unsigned long end)
+			   struct vm_area_struct **prev,
+			   unsigned long start, unsigned long end,
+			   struct range_lock *mmrange)
 {
 	loff_t offset;
 	int error;
@@ -605,15 +606,15 @@  static long madvise_remove(struct vm_area_struct *vma,
 	 * mmap_sem.
 	 */
 	get_file(f);
-	if (userfaultfd_remove(vma, start, end)) {
+	if (userfaultfd_remove(vma, start, end, mmrange)) {
 		/* mmap_sem was not released by userfaultfd_remove() */
-		up_read(&current->mm->mmap_sem);
+		mm_read_unlock(current->mm, mmrange);
 	}
 	error = vfs_fallocate(f,
 				FALLOC_FL_PUNCH_HOLE | FALLOC_FL_KEEP_SIZE,
 				offset, end - start);
 	fput(f);
-	down_read(&current->mm->mmap_sem);
+	mm_read_lock(current->mm, mmrange);
 	return error;
 }
 
@@ -688,16 +689,18 @@  static int madvise_inject_error(int behavior,
 
 static long
 madvise_vma(struct vm_area_struct *vma, struct vm_area_struct **prev,
-		unsigned long start, unsigned long end, int behavior)
+	    unsigned long start, unsigned long end, int behavior,
+	    struct range_lock *mmrange)
 {
 	switch (behavior) {
 	case MADV_REMOVE:
-		return madvise_remove(vma, prev, start, end);
+		return madvise_remove(vma, prev, start, end, mmrange);
 	case MADV_WILLNEED:
 		return madvise_willneed(vma, prev, start, end);
 	case MADV_FREE:
 	case MADV_DONTNEED:
-		return madvise_dontneed_free(vma, prev, start, end, behavior);
+		return madvise_dontneed_free(vma, prev, start, end,
+					     behavior, mmrange);
 	default:
 		return madvise_behavior(vma, prev, start, end, behavior);
 	}
@@ -809,6 +812,7 @@  SYSCALL_DEFINE3(madvise, unsigned long, start, size_t, len_in, int, behavior)
 	int write;
 	size_t len;
 	struct blk_plug plug;
+	DEFINE_RANGE_LOCK_FULL(mmrange);
 
 	if (!madvise_behavior_valid(behavior))
 		return error;
@@ -836,10 +840,10 @@  SYSCALL_DEFINE3(madvise, unsigned long, start, size_t, len_in, int, behavior)
 
 	write = madvise_need_mmap_write(behavior);
 	if (write) {
-		if (down_write_killable(&current->mm->mmap_sem))
+		if (mm_write_lock_killable(current->mm, &mmrange))
 			return -EINTR;
 	} else {
-		down_read(&current->mm->mmap_sem);
+		mm_read_lock(current->mm, &mmrange);
 	}
 
 	/*
@@ -872,7 +876,7 @@  SYSCALL_DEFINE3(madvise, unsigned long, start, size_t, len_in, int, behavior)
 			tmp = end;
 
 		/* Here vma->vm_start <= start < tmp <= (end|vma->vm_end). */
-		error = madvise_vma(vma, &prev, start, tmp, behavior);
+		error = madvise_vma(vma, &prev, start, tmp, behavior, &mmrange);
 		if (error)
 			goto out;
 		start = tmp;
@@ -889,9 +893,9 @@  SYSCALL_DEFINE3(madvise, unsigned long, start, size_t, len_in, int, behavior)
 out:
 	blk_finish_plug(&plug);
 	if (write)
-		up_write(&current->mm->mmap_sem);
+		mm_write_unlock(current->mm, &mmrange);
 	else
-		up_read(&current->mm->mmap_sem);
+		mm_read_unlock(current->mm, &mmrange);
 
 	return error;
 }
diff --git a/mm/memcontrol.c b/mm/memcontrol.c
index 2535e54e7989..c822cea99570 100644
--- a/mm/memcontrol.c
+++ b/mm/memcontrol.c
@@ -5139,10 +5139,11 @@  static unsigned long mem_cgroup_count_precharge(struct mm_struct *mm)
 		.pmd_entry = mem_cgroup_count_precharge_pte_range,
 		.mm = mm,
 	};
-	down_read(&mm->mmap_sem);
+	DEFINE_RANGE_LOCK_FULL(mmrange);
+	mm_read_lock(mm, &mmrange);
 	walk_page_range(0, mm->highest_vm_end,
 			&mem_cgroup_count_precharge_walk);
-	up_read(&mm->mmap_sem);
+	mm_read_unlock(mm, &mmrange);
 
 	precharge = mc.precharge;
 	mc.precharge = 0;
@@ -5412,6 +5413,7 @@  static int mem_cgroup_move_charge_pte_range(pmd_t *pmd,
 
 static void mem_cgroup_move_charge(void)
 {
+	DEFINE_RANGE_LOCK_FULL(mmrange);
 	struct mm_walk mem_cgroup_move_charge_walk = {
 		.pmd_entry = mem_cgroup_move_charge_pte_range,
 		.mm = mc.mm,
@@ -5426,7 +5428,7 @@  static void mem_cgroup_move_charge(void)
 	atomic_inc(&mc.from->moving_account);
 	synchronize_rcu();
 retry:
-	if (unlikely(!down_read_trylock(&mc.mm->mmap_sem))) {
+	if (unlikely(!mm_read_trylock(mc.mm, &mmrange))) {
 		/*
 		 * Someone who are holding the mmap_sem might be waiting in
 		 * waitq. So we cancel all extra charges, wake up all waiters,
@@ -5444,7 +5446,7 @@  static void mem_cgroup_move_charge(void)
 	 */
 	walk_page_range(0, mc.mm->highest_vm_end, &mem_cgroup_move_charge_walk);
 
-	up_read(&mc.mm->mmap_sem);
+	mm_read_unlock(mc.mm, &mmrange);
 	atomic_dec(&mc.from->moving_account);
 }
 
diff --git a/mm/memory.c b/mm/memory.c
index 73971f859035..8a5f52978893 100644
--- a/mm/memory.c
+++ b/mm/memory.c
@@ -4347,8 +4347,9 @@  int __access_remote_vm(struct task_struct *tsk, struct mm_struct *mm,
 	struct vm_area_struct *vma;
 	void *old_buf = buf;
 	int write = gup_flags & FOLL_WRITE;
+	DEFINE_RANGE_LOCK_FULL(mmrange);
 
-	down_read(&mm->mmap_sem);
+	mm_read_lock(mm, &mmrange);
 	/* ignore errors, just check how much was successfully transferred */
 	while (len) {
 		int bytes, ret, offset;
@@ -4397,7 +4398,7 @@  int __access_remote_vm(struct task_struct *tsk, struct mm_struct *mm,
 		buf += bytes;
 		addr += bytes;
 	}
-	up_read(&mm->mmap_sem);
+	mm_read_unlock(mm, &mmrange);
 
 	return buf - old_buf;
 }
@@ -4450,11 +4451,12 @@  void print_vma_addr(char *prefix, unsigned long ip)
 {
 	struct mm_struct *mm = current->mm;
 	struct vm_area_struct *vma;
+	DEFINE_RANGE_LOCK_FULL(mmrange);
 
 	/*
 	 * we might be running from an atomic context so we cannot sleep
 	 */
-	if (!down_read_trylock(&mm->mmap_sem))
+	if (!mm_read_trylock(mm, &mmrange))
 		return;
 
 	vma = find_vma(mm, ip);
@@ -4473,7 +4475,7 @@  void print_vma_addr(char *prefix, unsigned long ip)
 			free_page((unsigned long)buf);
 		}
 	}
-	up_read(&mm->mmap_sem);
+	mm_read_unlock(mm, &mmrange);
 }
 
 #if defined(CONFIG_PROVE_LOCKING) || defined(CONFIG_DEBUG_ATOMIC_SLEEP)
diff --git a/mm/mempolicy.c b/mm/mempolicy.c
index 975793cc1d71..8bf8861e0c73 100644
--- a/mm/mempolicy.c
+++ b/mm/mempolicy.c
@@ -378,11 +378,12 @@  void mpol_rebind_task(struct task_struct *tsk, const nodemask_t *new)
 void mpol_rebind_mm(struct mm_struct *mm, nodemask_t *new)
 {
 	struct vm_area_struct *vma;
+	DEFINE_RANGE_LOCK_FULL(mmrange);
 
-	down_write(&mm->mmap_sem);
+	mm_write_lock(mm, &mmrange);
 	for (vma = mm->mmap; vma; vma = vma->vm_next)
 		mpol_rebind_policy(vma->vm_policy, new);
-	up_write(&mm->mmap_sem);
+	mm_write_unlock(mm, &mmrange);
 }
 
 static const struct mempolicy_operations mpol_ops[MPOL_MAX] = {
@@ -837,7 +838,7 @@  static int lookup_node(struct mm_struct *mm, unsigned long addr,
 		put_page(p);
 	}
 	if (locked)
-		up_read(&mm->mmap_sem);
+		mm_read_unlock(mm, mmrange);
 	return err;
 }
 
@@ -871,10 +872,10 @@  static long do_get_mempolicy(int *policy, nodemask_t *nmask,
 		 * vma/shared policy at addr is NULL.  We
 		 * want to return MPOL_DEFAULT in this case.
 		 */
-		down_read(&mm->mmap_sem);
+		mm_read_lock(mm, &mmrange);
 		vma = find_vma_intersection(mm, addr, addr+1);
 		if (!vma) {
-			up_read(&mm->mmap_sem);
+			mm_read_unlock(mm, &mmrange);
 			return -EFAULT;
 		}
 		if (vma->vm_ops && vma->vm_ops->get_policy)
@@ -933,7 +934,7 @@  static long do_get_mempolicy(int *policy, nodemask_t *nmask,
  out:
 	mpol_cond_put(pol);
 	if (vma)
-		up_read(&mm->mmap_sem);
+		mm_read_unlock(mm, &mmrange);
 	if (pol_refcount)
 		mpol_put(pol_refcount);
 	return err;
@@ -1026,12 +1027,13 @@  int do_migrate_pages(struct mm_struct *mm, const nodemask_t *from,
 	int busy = 0;
 	int err;
 	nodemask_t tmp;
+	DEFINE_RANGE_LOCK_FULL(mmrange);
 
 	err = migrate_prep();
 	if (err)
 		return err;
 
-	down_read(&mm->mmap_sem);
+	mm_read_lock(mm, &mmrange);
 
 	/*
 	 * Find a 'source' bit set in 'tmp' whose corresponding 'dest'
@@ -1112,7 +1114,7 @@  int do_migrate_pages(struct mm_struct *mm, const nodemask_t *from,
 		if (err < 0)
 			break;
 	}
-	up_read(&mm->mmap_sem);
+	mm_read_unlock(mm, &mmrange);
 	if (err < 0)
 		return err;
 	return busy;
@@ -1186,6 +1188,7 @@  static long do_mbind(unsigned long start, unsigned long len,
 	unsigned long end;
 	int err;
 	LIST_HEAD(pagelist);
+	DEFINE_RANGE_LOCK_FULL(mmrange);
 
 	if (flags & ~(unsigned long)MPOL_MF_VALID)
 		return -EINVAL;
@@ -1233,12 +1236,12 @@  static long do_mbind(unsigned long start, unsigned long len,
 	{
 		NODEMASK_SCRATCH(scratch);
 		if (scratch) {
-			down_write(&mm->mmap_sem);
+			mm_write_lock(mm, &mmrange);
 			task_lock(current);
 			err = mpol_set_nodemask(new, nmask, scratch);
 			task_unlock(current);
 			if (err)
-				up_write(&mm->mmap_sem);
+				mm_write_unlock(mm, &mmrange);
 		} else
 			err = -ENOMEM;
 		NODEMASK_SCRATCH_FREE(scratch);
@@ -1267,7 +1270,7 @@  static long do_mbind(unsigned long start, unsigned long len,
 	} else
 		putback_movable_pages(&pagelist);
 
-	up_write(&mm->mmap_sem);
+	mm_write_unlock(mm, &mmrange);
  mpol_out:
 	mpol_put(new);
 	return err;
diff --git a/mm/migrate.c b/mm/migrate.c
index f2ecc2855a12..3a268b316e4e 100644
--- a/mm/migrate.c
+++ b/mm/migrate.c
@@ -1531,8 +1531,9 @@  static int add_page_for_migration(struct mm_struct *mm, unsigned long addr,
 	struct page *page;
 	unsigned int follflags;
 	int err;
+	DEFINE_RANGE_LOCK_FULL(mmrange);
 
-	down_read(&mm->mmap_sem);
+	mm_read_lock(mm, &mmrange);
 	err = -EFAULT;
 	vma = find_vma(mm, addr);
 	if (!vma || addr < vma->vm_start || !vma_migratable(vma))
@@ -1585,7 +1586,7 @@  static int add_page_for_migration(struct mm_struct *mm, unsigned long addr,
 	 */
 	put_page(page);
 out:
-	up_read(&mm->mmap_sem);
+	mm_read_unlock(mm, &mmrange);
 	return err;
 }
 
@@ -1686,8 +1687,9 @@  static void do_pages_stat_array(struct mm_struct *mm, unsigned long nr_pages,
 				const void __user **pages, int *status)
 {
 	unsigned long i;
+	DEFINE_RANGE_LOCK_FULL(mmrange);
 
-	down_read(&mm->mmap_sem);
+	mm_read_lock(mm, &mmrange);
 
 	for (i = 0; i < nr_pages; i++) {
 		unsigned long addr = (unsigned long)(*pages);
@@ -1714,7 +1716,7 @@  static void do_pages_stat_array(struct mm_struct *mm, unsigned long nr_pages,
 		status++;
 	}
 
-	up_read(&mm->mmap_sem);
+	mm_read_unlock(mm, &mmrange);
 }
 
 /*
diff --git a/mm/mincore.c b/mm/mincore.c
index c3f058bd0faf..c1d3a9cd2ba3 100644
--- a/mm/mincore.c
+++ b/mm/mincore.c
@@ -270,13 +270,15 @@  SYSCALL_DEFINE3(mincore, unsigned long, start, size_t, len,
 
 	retval = 0;
 	while (pages) {
+		DEFINE_RANGE_LOCK_FULL(mmrange);
+
 		/*
 		 * Do at most PAGE_SIZE entries per iteration, due to
 		 * the temporary buffer size.
 		 */
-		down_read(&current->mm->mmap_sem);
+		mm_read_lock(current->mm, &mmrange);
 		retval = do_mincore(start, min(pages, PAGE_SIZE), tmp);
-		up_read(&current->mm->mmap_sem);
+		mm_read_unlock(current->mm, &mmrange);
 
 		if (retval <= 0)
 			break;
diff --git a/mm/mlock.c b/mm/mlock.c
index e492a155c51a..c5b5dbd92a3a 100644
--- a/mm/mlock.c
+++ b/mm/mlock.c
@@ -670,6 +670,7 @@  static int count_mm_mlocked_page_nr(struct mm_struct *mm,
 
 static __must_check int do_mlock(unsigned long start, size_t len, vm_flags_t flags)
 {
+	DEFINE_RANGE_LOCK_FULL(mmrange);
 	unsigned long locked;
 	unsigned long lock_limit;
 	int error = -ENOMEM;
@@ -684,7 +685,7 @@  static __must_check int do_mlock(unsigned long start, size_t len, vm_flags_t fla
 	lock_limit >>= PAGE_SHIFT;
 	locked = len >> PAGE_SHIFT;
 
-	if (down_write_killable(&current->mm->mmap_sem))
+	if (mm_write_lock_killable(current->mm, &mmrange))
 		return -EINTR;
 
 	locked += atomic64_read(&current->mm->locked_vm);
@@ -703,7 +704,7 @@  static __must_check int do_mlock(unsigned long start, size_t len, vm_flags_t fla
 	if ((locked <= lock_limit) || capable(CAP_IPC_LOCK))
 		error = apply_vma_lock_flags(start, len, flags);
 
-	up_write(&current->mm->mmap_sem);
+	mm_write_unlock(current->mm, &mmrange);
 	if (error)
 		return error;
 
@@ -733,15 +734,16 @@  SYSCALL_DEFINE3(mlock2, unsigned long, start, size_t, len, int, flags)
 
 SYSCALL_DEFINE2(munlock, unsigned long, start, size_t, len)
 {
+	DEFINE_RANGE_LOCK_FULL(mmrange);
 	int ret;
 
 	len = PAGE_ALIGN(len + (offset_in_page(start)));
 	start &= PAGE_MASK;
 
-	if (down_write_killable(&current->mm->mmap_sem))
+	if (mm_write_lock_killable(current->mm, &mmrange))
 		return -EINTR;
 	ret = apply_vma_lock_flags(start, len, 0);
-	up_write(&current->mm->mmap_sem);
+	mm_write_unlock(current->mm, &mmrange);
 
 	return ret;
 }
@@ -794,6 +796,7 @@  static int apply_mlockall_flags(int flags)
 
 SYSCALL_DEFINE1(mlockall, int, flags)
 {
+	DEFINE_RANGE_LOCK_FULL(mmrange);
 	unsigned long lock_limit;
 	int ret;
 
@@ -806,14 +809,14 @@  SYSCALL_DEFINE1(mlockall, int, flags)
 	lock_limit = rlimit(RLIMIT_MEMLOCK);
 	lock_limit >>= PAGE_SHIFT;
 
-	if (down_write_killable(&current->mm->mmap_sem))
+	if (mm_write_lock_killable(current->mm, &mmrange))
 		return -EINTR;
 
 	ret = -ENOMEM;
 	if (!(flags & MCL_CURRENT) || (current->mm->total_vm <= lock_limit) ||
 	    capable(CAP_IPC_LOCK))
 		ret = apply_mlockall_flags(flags);
-	up_write(&current->mm->mmap_sem);
+	mm_write_unlock(current->mm, &mmrange);
 	if (!ret && (flags & MCL_CURRENT))
 		mm_populate(0, TASK_SIZE);
 
@@ -822,12 +825,13 @@  SYSCALL_DEFINE1(mlockall, int, flags)
 
 SYSCALL_DEFINE0(munlockall)
 {
+	DEFINE_RANGE_LOCK_FULL(mmrange);
 	int ret;
 
-	if (down_write_killable(&current->mm->mmap_sem))
+	if (mm_write_lock_killable(current->mm, &mmrange))
 		return -EINTR;
 	ret = apply_mlockall_flags(0);
-	up_write(&current->mm->mmap_sem);
+	mm_write_unlock(current->mm, &mmrange);
 	return ret;
 }
 
diff --git a/mm/mmap.c b/mm/mmap.c
index a03ded49f9eb..2eecdeb5fcd6 100644
--- a/mm/mmap.c
+++ b/mm/mmap.c
@@ -198,9 +198,10 @@  SYSCALL_DEFINE1(brk, unsigned long, brk)
 	unsigned long min_brk;
 	bool populate;
 	bool downgraded = false;
+	DEFINE_RANGE_LOCK_FULL(mmrange);
 	LIST_HEAD(uf);
 
-	if (down_write_killable(&mm->mmap_sem))
+	if (mm_write_lock_killable(mm, &mmrange))
 		return -EINTR;
 
 	origbrk = mm->brk;
@@ -251,7 +252,7 @@  SYSCALL_DEFINE1(brk, unsigned long, brk)
 		 * mm->brk will be restored from origbrk.
 		 */
 		mm->brk = brk;
-		ret = __do_munmap(mm, newbrk, oldbrk-newbrk, &uf, true);
+		ret = __do_munmap(mm, newbrk, oldbrk-newbrk, &uf, true, &mmrange);
 		if (ret < 0) {
 			mm->brk = origbrk;
 			goto out;
@@ -274,9 +275,9 @@  SYSCALL_DEFINE1(brk, unsigned long, brk)
 success:
 	populate = newbrk > oldbrk && (mm->def_flags & VM_LOCKED) != 0;
 	if (downgraded)
-		up_read(&mm->mmap_sem);
+		mm_read_unlock(mm, &mmrange);
 	else
-		up_write(&mm->mmap_sem);
+		mm_write_unlock(mm, &mmrange);
 	userfaultfd_unmap_complete(mm, &uf);
 	if (populate)
 		mm_populate(oldbrk, newbrk - oldbrk);
@@ -284,7 +285,7 @@  SYSCALL_DEFINE1(brk, unsigned long, brk)
 
 out:
 	retval = origbrk;
-	up_write(&mm->mmap_sem);
+	mm_write_unlock(mm, &mmrange);
 	return retval;
 }
 
@@ -2726,7 +2727,8 @@  int split_vma(struct mm_struct *mm, struct vm_area_struct *vma,
  * Jeremy Fitzhardinge <jeremy@goop.org>
  */
 int __do_munmap(struct mm_struct *mm, unsigned long start, size_t len,
-		struct list_head *uf, bool downgrade)
+		struct list_head *uf, bool downgrade,
+		struct range_lock *mmrange)
 {
 	unsigned long end;
 	struct vm_area_struct *vma, *prev, *last;
@@ -2824,7 +2826,7 @@  int __do_munmap(struct mm_struct *mm, unsigned long start, size_t len,
 	detach_vmas_to_be_unmapped(mm, vma, prev, end);
 
 	if (downgrade)
-		downgrade_write(&mm->mmap_sem);
+		mm_downgrade_write(mm, mmrange);
 
 	unmap_region(mm, vma, prev, start, end);
 
@@ -2837,7 +2839,7 @@  int __do_munmap(struct mm_struct *mm, unsigned long start, size_t len,
 int do_munmap(struct mm_struct *mm, unsigned long start, size_t len,
 	      struct list_head *uf)
 {
-	return __do_munmap(mm, start, len, uf, false);
+	return __do_munmap(mm, start, len, uf, false, NULL);
 }
 
 static int __vm_munmap(unsigned long start, size_t len, bool downgrade)
@@ -2845,21 +2847,22 @@  static int __vm_munmap(unsigned long start, size_t len, bool downgrade)
 	int ret;
 	struct mm_struct *mm = current->mm;
 	LIST_HEAD(uf);
+	DEFINE_RANGE_LOCK_FULL(mmrange);
 
-	if (down_write_killable(&mm->mmap_sem))
+	if (mm_write_lock_killable(mm, &mmrange))
 		return -EINTR;
 
-	ret = __do_munmap(mm, start, len, &uf, downgrade);
+	ret = __do_munmap(mm, start, len, &uf, downgrade, &mmrange);
 	/*
 	 * Returning 1 indicates mmap_sem is downgraded.
 	 * But 1 is not legal return value of vm_munmap() and munmap(), reset
 	 * it to 0 before return.
 	 */
 	if (ret == 1) {
-		up_read(&mm->mmap_sem);
+		mm_read_unlock(mm, &mmrange);
 		ret = 0;
 	} else
-		up_write(&mm->mmap_sem);
+		mm_write_unlock(mm, &mmrange);
 
 	userfaultfd_unmap_complete(mm, &uf);
 	return ret;
@@ -2884,6 +2887,7 @@  SYSCALL_DEFINE2(munmap, unsigned long, addr, size_t, len)
 SYSCALL_DEFINE5(remap_file_pages, unsigned long, start, unsigned long, size,
 		unsigned long, prot, unsigned long, pgoff, unsigned long, flags)
 {
+	DEFINE_RANGE_LOCK_FULL(mmrange);
 
 	struct mm_struct *mm = current->mm;
 	struct vm_area_struct *vma;
@@ -2906,7 +2910,7 @@  SYSCALL_DEFINE5(remap_file_pages, unsigned long, start, unsigned long, size,
 	if (pgoff + (size >> PAGE_SHIFT) < pgoff)
 		return ret;
 
-	if (down_write_killable(&mm->mmap_sem))
+	if (mm_write_lock_killable(mm, &mmrange))
 		return -EINTR;
 
 	vma = find_vma(mm, start);
@@ -2969,7 +2973,7 @@  SYSCALL_DEFINE5(remap_file_pages, unsigned long, start, unsigned long, size,
 			prot, flags, pgoff, &populate, NULL);
 	fput(file);
 out:
-	up_write(&mm->mmap_sem);
+	mm_write_unlock(mm, &mmrange);
 	if (populate)
 		mm_populate(ret, populate);
 	if (!IS_ERR_VALUE(ret))
@@ -3056,6 +3060,7 @@  static int do_brk_flags(unsigned long addr, unsigned long len, unsigned long fla
 
 int vm_brk_flags(unsigned long addr, unsigned long request, unsigned long flags)
 {
+	DEFINE_RANGE_LOCK_FULL(mmrange);
 	struct mm_struct *mm = current->mm;
 	unsigned long len;
 	int ret;
@@ -3068,12 +3073,12 @@  int vm_brk_flags(unsigned long addr, unsigned long request, unsigned long flags)
 	if (!len)
 		return 0;
 
-	if (down_write_killable(&mm->mmap_sem))
+	if (mm_write_lock_killable(mm, &mmrange))
 		return -EINTR;
 
 	ret = do_brk_flags(addr, len, flags, &uf);
 	populate = ((mm->def_flags & VM_LOCKED) != 0);
-	up_write(&mm->mmap_sem);
+	mm_write_unlock(mm, &mmrange);
 	userfaultfd_unmap_complete(mm, &uf);
 	if (populate && !ret)
 		mm_populate(addr, len);
@@ -3098,6 +3103,8 @@  void exit_mmap(struct mm_struct *mm)
 	mmu_notifier_release(mm);
 
 	if (unlikely(mm_is_oom_victim(mm))) {
+		DEFINE_RANGE_LOCK_FULL(mmrange);
+
 		/*
 		 * Manually reap the mm to free as much memory as possible.
 		 * Then, as the oom reaper does, set MMF_OOM_SKIP to disregard
@@ -3117,8 +3124,8 @@  void exit_mmap(struct mm_struct *mm)
 		(void)__oom_reap_task_mm(mm);
 
 		set_bit(MMF_OOM_SKIP, &mm->flags);
-		down_write(&mm->mmap_sem);
-		up_write(&mm->mmap_sem);
+		mm_write_lock(mm, &mmrange);
+		mm_write_unlock(mm, &mmrange);
 	}
 
 	if (atomic64_read(&mm->locked_vm)) {
@@ -3459,14 +3466,15 @@  int install_special_mapping(struct mm_struct *mm,
 
 static DEFINE_MUTEX(mm_all_locks_mutex);
 
-static void vm_lock_anon_vma(struct mm_struct *mm, struct anon_vma *anon_vma)
+static void vm_lock_anon_vma(struct mm_struct *mm, struct anon_vma *anon_vma,
+			     struct range_lock *mmrange)
 {
 	if (!test_bit(0, (unsigned long *) &anon_vma->root->rb_root.rb_root.rb_node)) {
 		/*
 		 * The LSB of head.next can't change from under us
 		 * because we hold the mm_all_locks_mutex.
 		 */
-		down_write(&mm->mmap_sem);
+		mm_write_lock(mm, mmrange);
 		/*
 		 * We can safely modify head.next after taking the
 		 * anon_vma->root->rwsem. If some other vma in this mm shares
@@ -3482,7 +3490,8 @@  static void vm_lock_anon_vma(struct mm_struct *mm, struct anon_vma *anon_vma)
 	}
 }
 
-static void vm_lock_mapping(struct mm_struct *mm, struct address_space *mapping)
+static void vm_lock_mapping(struct mm_struct *mm, struct address_space *mapping,
+			    struct range_lock *mmrange)
 {
 	if (!test_bit(AS_MM_ALL_LOCKS, &mapping->flags)) {
 		/*
@@ -3496,7 +3505,7 @@  static void vm_lock_mapping(struct mm_struct *mm, struct address_space *mapping)
 		 */
 		if (test_and_set_bit(AS_MM_ALL_LOCKS, &mapping->flags))
 			BUG();
-		down_write(&mm->mmap_sem);
+		mm_write_lock(mm, mmrange);
 	}
 }
 
@@ -3537,12 +3546,12 @@  static void vm_lock_mapping(struct mm_struct *mm, struct address_space *mapping)
  *
  * mm_take_all_locks() can fail if it's interrupted by signals.
  */
-int mm_take_all_locks(struct mm_struct *mm)
+int mm_take_all_locks(struct mm_struct *mm, struct range_lock *mmrange)
 {
 	struct vm_area_struct *vma;
 	struct anon_vma_chain *avc;
 
-	BUG_ON(down_read_trylock(&mm->mmap_sem));
+	BUG_ON(mm_read_trylock(mm, mmrange));
 
 	mutex_lock(&mm_all_locks_mutex);
 
@@ -3551,7 +3560,7 @@  int mm_take_all_locks(struct mm_struct *mm)
 			goto out_unlock;
 		if (vma->vm_file && vma->vm_file->f_mapping &&
 				is_vm_hugetlb_page(vma))
-			vm_lock_mapping(mm, vma->vm_file->f_mapping);
+			vm_lock_mapping(mm, vma->vm_file->f_mapping, mmrange);
 	}
 
 	for (vma = mm->mmap; vma; vma = vma->vm_next) {
@@ -3559,7 +3568,7 @@  int mm_take_all_locks(struct mm_struct *mm)
 			goto out_unlock;
 		if (vma->vm_file && vma->vm_file->f_mapping &&
 				!is_vm_hugetlb_page(vma))
-			vm_lock_mapping(mm, vma->vm_file->f_mapping);
+			vm_lock_mapping(mm, vma->vm_file->f_mapping, mmrange);
 	}
 
 	for (vma = mm->mmap; vma; vma = vma->vm_next) {
@@ -3567,13 +3576,13 @@  int mm_take_all_locks(struct mm_struct *mm)
 			goto out_unlock;
 		if (vma->anon_vma)
 			list_for_each_entry(avc, &vma->anon_vma_chain, same_vma)
-				vm_lock_anon_vma(mm, avc->anon_vma);
+				vm_lock_anon_vma(mm, avc->anon_vma, mmrange);
 	}
 
 	return 0;
 
 out_unlock:
-	mm_drop_all_locks(mm);
+	mm_drop_all_locks(mm, mmrange);
 	return -EINTR;
 }
 
@@ -3617,12 +3626,12 @@  static void vm_unlock_mapping(struct address_space *mapping)
  * The mmap_sem cannot be released by the caller until
  * mm_drop_all_locks() returns.
  */
-void mm_drop_all_locks(struct mm_struct *mm)
+void mm_drop_all_locks(struct mm_struct *mm, struct range_lock *mmrange)
 {
 	struct vm_area_struct *vma;
 	struct anon_vma_chain *avc;
 
-	BUG_ON(down_read_trylock(&mm->mmap_sem));
+	BUG_ON(mm_read_trylock(mm, mmrange));
 	BUG_ON(!mutex_is_locked(&mm_all_locks_mutex));
 
 	for (vma = mm->mmap; vma; vma = vma->vm_next) {
diff --git a/mm/mmu_notifier.c b/mm/mmu_notifier.c
index ee36068077b6..028eaed031e1 100644
--- a/mm/mmu_notifier.c
+++ b/mm/mmu_notifier.c
@@ -244,6 +244,7 @@  static int do_mmu_notifier_register(struct mmu_notifier *mn,
 {
 	struct mmu_notifier_mm *mmu_notifier_mm;
 	int ret;
+	DEFINE_RANGE_LOCK_FULL(mmrange);
 
 	BUG_ON(atomic_read(&mm->mm_users) <= 0);
 
@@ -253,8 +254,8 @@  static int do_mmu_notifier_register(struct mmu_notifier *mn,
 		goto out;
 
 	if (take_mmap_sem)
-		down_write(&mm->mmap_sem);
-	ret = mm_take_all_locks(mm);
+		mm_write_lock(mm, &mmrange);
+	ret = mm_take_all_locks(mm, &mmrange);
 	if (unlikely(ret))
 		goto out_clean;
 
@@ -279,10 +280,10 @@  static int do_mmu_notifier_register(struct mmu_notifier *mn,
 	hlist_add_head(&mn->hlist, &mm->mmu_notifier_mm->list);
 	spin_unlock(&mm->mmu_notifier_mm->lock);
 
-	mm_drop_all_locks(mm);
+	mm_drop_all_locks(mm, &mmrange);
 out_clean:
 	if (take_mmap_sem)
-		up_write(&mm->mmap_sem);
+		mm_write_unlock(mm, &mmrange);
 	kfree(mmu_notifier_mm);
 out:
 	BUG_ON(atomic_read(&mm->mm_users) <= 0);
diff --git a/mm/mprotect.c b/mm/mprotect.c
index 36c517c6a5b1..443b033f240c 100644
--- a/mm/mprotect.c
+++ b/mm/mprotect.c
@@ -458,6 +458,7 @@  mprotect_fixup(struct vm_area_struct *vma, struct vm_area_struct **pprev,
 static int do_mprotect_pkey(unsigned long start, size_t len,
 		unsigned long prot, int pkey)
 {
+	DEFINE_RANGE_LOCK_FULL(mmrange);
 	unsigned long nstart, end, tmp, reqprot;
 	struct vm_area_struct *vma, *prev;
 	int error = -EINVAL;
@@ -482,7 +483,7 @@  static int do_mprotect_pkey(unsigned long start, size_t len,
 
 	reqprot = prot;
 
-	if (down_write_killable(&current->mm->mmap_sem))
+	if (mm_write_lock_killable(current->mm, &mmrange))
 		return -EINTR;
 
 	/*
@@ -572,7 +573,7 @@  static int do_mprotect_pkey(unsigned long start, size_t len,
 		prot = reqprot;
 	}
 out:
-	up_write(&current->mm->mmap_sem);
+	mm_write_unlock(current->mm, &mmrange);
 	return error;
 }
 
@@ -594,6 +595,7 @@  SYSCALL_DEFINE2(pkey_alloc, unsigned long, flags, unsigned long, init_val)
 {
 	int pkey;
 	int ret;
+	DEFINE_RANGE_LOCK_FULL(mmrange);
 
 	/* No flags supported yet. */
 	if (flags)
@@ -602,7 +604,7 @@  SYSCALL_DEFINE2(pkey_alloc, unsigned long, flags, unsigned long, init_val)
 	if (init_val & ~PKEY_ACCESS_MASK)
 		return -EINVAL;
 
-	down_write(&current->mm->mmap_sem);
+	mm_write_lock(current->mm, &mmrange);
 	pkey = mm_pkey_alloc(current->mm);
 
 	ret = -ENOSPC;
@@ -616,17 +618,18 @@  SYSCALL_DEFINE2(pkey_alloc, unsigned long, flags, unsigned long, init_val)
 	}
 	ret = pkey;
 out:
-	up_write(&current->mm->mmap_sem);
+	mm_write_unlock(current->mm, &mmrange);
 	return ret;
 }
 
 SYSCALL_DEFINE1(pkey_free, int, pkey)
 {
 	int ret;
+	DEFINE_RANGE_LOCK_FULL(mmrange);
 
-	down_write(&current->mm->mmap_sem);
+	mm_write_lock(current->mm, &mmrange);
 	ret = mm_pkey_free(current->mm, pkey);
-	up_write(&current->mm->mmap_sem);
+	mm_write_unlock(current->mm, &mmrange);
 
 	/*
 	 * We could provie warnings or errors if any VMA still
diff --git a/mm/mremap.c b/mm/mremap.c
index 37b5b2ad91be..9009210aea97 100644
--- a/mm/mremap.c
+++ b/mm/mremap.c
@@ -603,6 +603,7 @@  SYSCALL_DEFINE5(mremap, unsigned long, addr, unsigned long, old_len,
 	bool locked = false;
 	bool downgraded = false;
 	struct vm_userfaultfd_ctx uf = NULL_VM_UFFD_CTX;
+	DEFINE_RANGE_LOCK_FULL(mmrange);
 	LIST_HEAD(uf_unmap_early);
 	LIST_HEAD(uf_unmap);
 
@@ -626,7 +627,7 @@  SYSCALL_DEFINE5(mremap, unsigned long, addr, unsigned long, old_len,
 	if (!new_len)
 		return ret;
 
-	if (down_write_killable(&current->mm->mmap_sem))
+	if (mm_write_lock_killable(current->mm, &mmrange))
 		return -EINTR;
 
 	if (flags & MREMAP_FIXED) {
@@ -645,7 +646,7 @@  SYSCALL_DEFINE5(mremap, unsigned long, addr, unsigned long, old_len,
 		int retval;
 
 		retval = __do_munmap(mm, addr+new_len, old_len - new_len,
-				  &uf_unmap, true);
+				     &uf_unmap, true, &mmrange);
 		if (retval < 0 && old_len != new_len) {
 			ret = retval;
 			goto out;
@@ -717,9 +718,9 @@  SYSCALL_DEFINE5(mremap, unsigned long, addr, unsigned long, old_len,
 		locked = 0;
 	}
 	if (downgraded)
-		up_read(&current->mm->mmap_sem);
+		mm_read_unlock(current->mm, &mmrange);
 	else
-		up_write(&current->mm->mmap_sem);
+		mm_write_unlock(current->mm, &mmrange);
 	if (locked && new_len > old_len)
 		mm_populate(new_addr + old_len, new_len - old_len);
 	userfaultfd_unmap_complete(mm, &uf_unmap_early);
diff --git a/mm/msync.c b/mm/msync.c
index ef30a429623a..2524b4708e78 100644
--- a/mm/msync.c
+++ b/mm/msync.c
@@ -36,6 +36,7 @@  SYSCALL_DEFINE3(msync, unsigned long, start, size_t, len, int, flags)
 	struct vm_area_struct *vma;
 	int unmapped_error = 0;
 	int error = -EINVAL;
+	DEFINE_RANGE_LOCK_FULL(mmrange);
 
 	if (flags & ~(MS_ASYNC | MS_INVALIDATE | MS_SYNC))
 		goto out;
@@ -55,7 +56,7 @@  SYSCALL_DEFINE3(msync, unsigned long, start, size_t, len, int, flags)
 	 * If the interval [start,end) covers some unmapped address ranges,
 	 * just ignore them, but return -ENOMEM at the end.
 	 */
-	down_read(&mm->mmap_sem);
+	mm_read_lock(mm, &mmrange);
 	vma = find_vma(mm, start);
 	for (;;) {
 		struct file *file;
@@ -86,12 +87,12 @@  SYSCALL_DEFINE3(msync, unsigned long, start, size_t, len, int, flags)
 		if ((flags & MS_SYNC) && file &&
 				(vma->vm_flags & VM_SHARED)) {
 			get_file(file);
-			up_read(&mm->mmap_sem);
+			mm_read_unlock(mm, &mmrange);
 			error = vfs_fsync_range(file, fstart, fend, 1);
 			fput(file);
 			if (error || start >= end)
 				goto out;
-			down_read(&mm->mmap_sem);
+			mm_read_lock(mm, &mmrange);
 			vma = find_vma(mm, start);
 		} else {
 			if (start >= end) {
@@ -102,7 +103,7 @@  SYSCALL_DEFINE3(msync, unsigned long, start, size_t, len, int, flags)
 		}
 	}
 out_unlock:
-	up_read(&mm->mmap_sem);
+	mm_read_unlock(mm, &mmrange);
 out:
 	return error ? : unmapped_error;
 }
diff --git a/mm/nommu.c b/mm/nommu.c
index b492fd1fcf9f..b454b0004fd2 100644
--- a/mm/nommu.c
+++ b/mm/nommu.c
@@ -183,10 +183,11 @@  static long __get_user_pages_unlocked(struct task_struct *tsk,
 			unsigned int gup_flags)
 {
 	long ret;
-	down_read(&mm->mmap_sem);
+	DEFINE_RANGE_LOCK_FULL(mmrange);
+	mm_read_lock(mm, &mmrange);
 	ret = __get_user_pages(tsk, mm, start, nr_pages, gup_flags, pages,
 				NULL, NULL);
-	up_read(&mm->mmap_sem);
+	mm_read_unlock(mm, &mmrange);
 	return ret;
 }
 
@@ -249,12 +250,13 @@  void *vmalloc_user(unsigned long size)
 	ret = __vmalloc(size, GFP_KERNEL | __GFP_ZERO, PAGE_KERNEL);
 	if (ret) {
 		struct vm_area_struct *vma;
+		DEFINE_RANGE_LOCK_FULL(mmrange);
 
-		down_write(&current->mm->mmap_sem);
+		mm_write_lock(current->mm, &mmrange);
 		vma = find_vma(current->mm, (unsigned long)ret);
 		if (vma)
 			vma->vm_flags |= VM_USERMAP;
-		up_write(&current->mm->mmap_sem);
+		mm_write_unlock(current->mm, &mmrange);
 	}
 
 	return ret;
@@ -1627,10 +1629,11 @@  int vm_munmap(unsigned long addr, size_t len)
 {
 	struct mm_struct *mm = current->mm;
 	int ret;
+	DEFINE_RANGE_LOCK_FULL(mmrange);
 
-	down_write(&mm->mmap_sem);
+	mm_write_lock(mm, &mmrange);
 	ret = do_munmap(mm, addr, len, NULL);
-	up_write(&mm->mmap_sem);
+	mm_write_unlock(mm, &mmrange);
 	return ret;
 }
 EXPORT_SYMBOL(vm_munmap);
@@ -1716,10 +1719,11 @@  SYSCALL_DEFINE5(mremap, unsigned long, addr, unsigned long, old_len,
 		unsigned long, new_addr)
 {
 	unsigned long ret;
+	DEFINE_RANGE_LOCK_FULL(mmrange);
 
-	down_write(&current->mm->mmap_sem);
+	mm_write_lock(current->mm, &mmrange);
 	ret = do_mremap(addr, old_len, new_len, flags, new_addr);
-	up_write(&current->mm->mmap_sem);
+	mm_write_unlock(current->mm, &mmrange);
 	return ret;
 }
 
@@ -1790,8 +1794,9 @@  int __access_remote_vm(struct task_struct *tsk, struct mm_struct *mm,
 {
 	struct vm_area_struct *vma;
 	int write = gup_flags & FOLL_WRITE;
+	DEFINE_RANGE_LOCK_FULL(mmrange);
 
-	down_read(&mm->mmap_sem);
+	mm_read_lock(mm, &mmrange);
 
 	/* the access must start within one of the target process's mappings */
 	vma = find_vma(mm, addr);
@@ -1813,7 +1818,7 @@  int __access_remote_vm(struct task_struct *tsk, struct mm_struct *mm,
 		len = 0;
 	}
 
-	up_read(&mm->mmap_sem);
+	mm_read_unlock(mm, &mmrange);
 
 	return len;
 }
diff --git a/mm/oom_kill.c b/mm/oom_kill.c
index 539c91d0b26a..a8e3e6279718 100644
--- a/mm/oom_kill.c
+++ b/mm/oom_kill.c
@@ -558,8 +558,9 @@  bool __oom_reap_task_mm(struct mm_struct *mm)
 static bool oom_reap_task_mm(struct task_struct *tsk, struct mm_struct *mm)
 {
 	bool ret = true;
+	DEFINE_RANGE_LOCK_FULL(mmrange);
 
-	if (!down_read_trylock(&mm->mmap_sem)) {
+	if (!mm_read_trylock(mm, &mmrange)) {
 		trace_skip_task_reaping(tsk->pid);
 		return false;
 	}
@@ -590,7 +591,7 @@  static bool oom_reap_task_mm(struct task_struct *tsk, struct mm_struct *mm)
 out_finish:
 	trace_finish_task_reaping(tsk->pid);
 out_unlock:
-	up_read(&mm->mmap_sem);
+	mm_read_unlock(mm, &mmrange);
 
 	return ret;
 }
diff --git a/mm/process_vm_access.c b/mm/process_vm_access.c
index ff6772b86195..aaccb8972f83 100644
--- a/mm/process_vm_access.c
+++ b/mm/process_vm_access.c
@@ -110,12 +110,12 @@  static int process_vm_rw_single_vec(unsigned long addr,
 		 * access remotely because task/mm might not
 		 * current/current->mm
 		 */
-		down_read(&mm->mmap_sem);
+		mm_read_lock(mm, &mmrange);
 		pages = get_user_pages_remote(task, mm, pa, pages, flags,
 					      process_pages, NULL, &locked,
 					      &mmrange);
 		if (locked)
-			up_read(&mm->mmap_sem);
+			mm_read_unlock(mm, &mmrange);
 		if (pages <= 0)
 			return -EFAULT;
 
diff --git a/mm/shmem.c b/mm/shmem.c
index 1bb3b8dc8bb2..bae06efb293d 100644
--- a/mm/shmem.c
+++ b/mm/shmem.c
@@ -2012,7 +2012,7 @@  static vm_fault_t shmem_fault(struct vm_fault *vmf)
 			if ((vmf->flags & FAULT_FLAG_ALLOW_RETRY) &&
 			   !(vmf->flags & FAULT_FLAG_RETRY_NOWAIT)) {
 				/* It's polite to up mmap_sem if we can */
-				up_read(&vma->vm_mm->mmap_sem);
+				mm_read_unlock(vma->vm_mm, vmf->lockrange);
 				ret = VM_FAULT_RETRY;
 			}
 
diff --git a/mm/swapfile.c b/mm/swapfile.c
index be36f6fe2f8c..dabe7d5391d1 100644
--- a/mm/swapfile.c
+++ b/mm/swapfile.c
@@ -1972,8 +1972,9 @@  static int unuse_mm(struct mm_struct *mm, unsigned int type,
 {
 	struct vm_area_struct *vma;
 	int ret = 0;
+	DEFINE_RANGE_LOCK_FULL(mmrange);
 
-	down_read(&mm->mmap_sem);
+	mm_read_lock(mm, &mmrange);
 	for (vma = mm->mmap; vma; vma = vma->vm_next) {
 		if (vma->anon_vma) {
 			ret = unuse_vma(vma, type, frontswap,
@@ -1983,7 +1984,7 @@  static int unuse_mm(struct mm_struct *mm, unsigned int type,
 		}
 		cond_resched();
 	}
-	up_read(&mm->mmap_sem);
+	mm_read_unlock(mm, &mmrange);
 	return ret;
 }
 
diff --git a/mm/userfaultfd.c b/mm/userfaultfd.c
index 9932d5755e4c..06daedcd06e6 100644
--- a/mm/userfaultfd.c
+++ b/mm/userfaultfd.c
@@ -177,7 +177,8 @@  static __always_inline ssize_t __mcopy_atomic_hugetlb(struct mm_struct *dst_mm,
 					      unsigned long dst_start,
 					      unsigned long src_start,
 					      unsigned long len,
-					      bool zeropage)
+					      bool zeropage,
+					      struct range_lock *mmrange)
 {
 	int vm_alloc_shared = dst_vma->vm_flags & VM_SHARED;
 	int vm_shared = dst_vma->vm_flags & VM_SHARED;
@@ -199,7 +200,7 @@  static __always_inline ssize_t __mcopy_atomic_hugetlb(struct mm_struct *dst_mm,
 	 * feature is not supported.
 	 */
 	if (zeropage) {
-		up_read(&dst_mm->mmap_sem);
+		mm_read_unlock(dst_mm, mmrange);
 		return -EINVAL;
 	}
 
@@ -297,7 +298,7 @@  static __always_inline ssize_t __mcopy_atomic_hugetlb(struct mm_struct *dst_mm,
 		cond_resched();
 
 		if (unlikely(err == -ENOENT)) {
-			up_read(&dst_mm->mmap_sem);
+			mm_read_unlock(dst_mm, mmrange);
 			BUG_ON(!page);
 
 			err = copy_huge_page_from_user(page,
@@ -307,7 +308,7 @@  static __always_inline ssize_t __mcopy_atomic_hugetlb(struct mm_struct *dst_mm,
 				err = -EFAULT;
 				goto out;
 			}
-			down_read(&dst_mm->mmap_sem);
+			mm_read_lock(dst_mm, mmrange);
 
 			dst_vma = NULL;
 			goto retry;
@@ -327,7 +328,7 @@  static __always_inline ssize_t __mcopy_atomic_hugetlb(struct mm_struct *dst_mm,
 	}
 
 out_unlock:
-	up_read(&dst_mm->mmap_sem);
+	mm_read_unlock(dst_mm, mmrange);
 out:
 	if (page) {
 		/*
@@ -445,6 +446,7 @@  static __always_inline ssize_t __mcopy_atomic(struct mm_struct *dst_mm,
 	unsigned long src_addr, dst_addr;
 	long copied;
 	struct page *page;
+	DEFINE_RANGE_LOCK_FULL(mmrange);
 
 	/*
 	 * Sanitize the command parameters:
@@ -461,7 +463,7 @@  static __always_inline ssize_t __mcopy_atomic(struct mm_struct *dst_mm,
 	copied = 0;
 	page = NULL;
 retry:
-	down_read(&dst_mm->mmap_sem);
+	mm_read_lock(dst_mm, &mmrange);
 
 	/*
 	 * If memory mappings are changing because of non-cooperative
@@ -506,7 +508,8 @@  static __always_inline ssize_t __mcopy_atomic(struct mm_struct *dst_mm,
 	 */
 	if (is_vm_hugetlb_page(dst_vma))
 		return  __mcopy_atomic_hugetlb(dst_mm, dst_vma, dst_start,
-						src_start, len, zeropage);
+					       src_start, len, zeropage,
+					       &mmrange);
 
 	if (!vma_is_anonymous(dst_vma) && !vma_is_shmem(dst_vma))
 		goto out_unlock;
@@ -562,7 +565,7 @@  static __always_inline ssize_t __mcopy_atomic(struct mm_struct *dst_mm,
 		if (unlikely(err == -ENOENT)) {
 			void *page_kaddr;
 
-			up_read(&dst_mm->mmap_sem);
+			mm_read_unlock(dst_mm, &mmrange);
 			BUG_ON(!page);
 
 			page_kaddr = kmap(page);
@@ -591,7 +594,7 @@  static __always_inline ssize_t __mcopy_atomic(struct mm_struct *dst_mm,
 	}
 
 out_unlock:
-	up_read(&dst_mm->mmap_sem);
+	mm_read_unlock(dst_mm, &mmrange);
 out:
 	if (page)
 		put_page(page);
diff --git a/mm/util.c b/mm/util.c
index e2e4f8c3fa12..c410c17ddea7 100644
--- a/mm/util.c
+++ b/mm/util.c
@@ -350,6 +350,7 @@  unsigned long vm_mmap_pgoff(struct file *file, unsigned long addr,
 	unsigned long len, unsigned long prot,
 	unsigned long flag, unsigned long pgoff)
 {
+	DEFINE_RANGE_LOCK_FULL(mmrange);
 	unsigned long ret;
 	struct mm_struct *mm = current->mm;
 	unsigned long populate;
@@ -357,11 +358,11 @@  unsigned long vm_mmap_pgoff(struct file *file, unsigned long addr,
 
 	ret = security_mmap_file(file, prot, flag);
 	if (!ret) {
-		if (down_write_killable(&mm->mmap_sem))
+		if (mm_write_lock_killable(mm, &mmrange))
 			return -EINTR;
 		ret = do_mmap_pgoff(file, addr, len, prot, flag, pgoff,
 				    &populate, &uf);
-		up_write(&mm->mmap_sem);
+		mm_write_unlock(mm, &mmrange);
 		userfaultfd_unmap_complete(mm, &uf);
 		if (populate)
 			mm_populate(ret, populate);
@@ -711,18 +712,19 @@  int get_cmdline(struct task_struct *task, char *buffer, int buflen)
 	int res = 0;
 	unsigned int len;
 	struct mm_struct *mm = get_task_mm(task);
+	DEFINE_RANGE_LOCK_FULL(mmrange);
 	unsigned long arg_start, arg_end, env_start, env_end;
 	if (!mm)
 		goto out;
 	if (!mm->arg_end)
 		goto out_mm;	/* Shh! No looking before we're done */
 
-	down_read(&mm->mmap_sem);
+	mm_read_lock(mm, &mmrange);
 	arg_start = mm->arg_start;
 	arg_end = mm->arg_end;
 	env_start = mm->env_start;
 	env_end = mm->env_end;
-	up_read(&mm->mmap_sem);
+	mm_read_unlock(mm, &mmrange);
 
 	len = arg_end - arg_start;