[2/2] mm: use READ/WRITE_ONCE to access anonymous vmas vm_start/vm_end/vm_pgoff

Message ID 20190301035550.1124-3-aarcange@redhat.com (mailing list archive)
State New, archived
Series RFC: READ/WRITE_ONCE vma/mm cleanups

Commit Message

Andrea Arcangeli March 1, 2019, 3:55 a.m. UTC
This converts the updates to vm_start/vm_end/vm_pgoff of anonymous
vmas, performed under mmap_sem held for reading, the rmap lock held
for writing and the PT lock, to use WRITE_ONCE().

This also converts some of the accesses under mmap_sem for reading
that are concurrent with the aforementioned WRITE_ONCE()s to use
READ_ONCE().

Signed-off-by: Andrea Arcangeli <aarcange@redhat.com>
---
 mm/gup.c      | 23 +++++++++++++----------
 mm/internal.h |  3 ++-
 mm/memory.c   |  2 +-
 mm/mmap.c     | 16 ++++++++--------
 mm/rmap.c     |  3 ++-
 mm/vmacache.c |  3 ++-
 6 files changed, 28 insertions(+), 22 deletions(-)
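
Below is a minimal userspace sketch (not part of the patch, and not the
kernel's real definitions) of the access pattern the patch establishes:
the writer updates vm_start/vm_end/vm_pgoff with WRITE_ONCE() while a
reader holding only mmap_sem for reading samples them with READ_ONCE(),
so the compiler cannot tear, fuse or re-load the accesses. The struct,
the simplified macros and the hard-coded page shift are illustrative
stand-ins only.

	#include <stdio.h>

	/* Simplified stand-ins for the kernel macros: force one volatile access. */
	#define WRITE_ONCE(x, val)	(*(volatile typeof(x) *)&(x) = (val))
	#define READ_ONCE(x)		(*(volatile typeof(x) *)&(x))

	/* Cut-down vma with only the fields this patch annotates. */
	struct vma_stub {
		unsigned long vm_start;
		unsigned long vm_end;
		unsigned long vm_pgoff;
	};

	/* Writer side: what expand_downwards() does while holding its locks. */
	static void grow_down(struct vma_stub *vma, unsigned long new_start)
	{
		/* 12 stands in for PAGE_SHIFT here. */
		unsigned long grow = (vma->vm_start - new_start) >> 12;

		WRITE_ONCE(vma->vm_start, new_start);
		WRITE_ONCE(vma->vm_pgoff, vma->vm_pgoff - grow);
	}

	/* Reader side: a find_vma()-style range check done without the writer's locks. */
	static int contains(struct vma_stub *vma, unsigned long addr)
	{
		return READ_ONCE(vma->vm_start) <= addr &&
		       READ_ONCE(vma->vm_end) > addr;
	}

	int main(void)
	{
		struct vma_stub vma = { 0x200000, 0x300000, 0x200 };

		grow_down(&vma, 0x100000);
		printf("contains 0x180000: %d\n", contains(&vma, 0x180000));
		return 0;
	}

The single-threaded main() only demonstrates the calling convention; the
annotations matter when grow_down() and contains() run concurrently, as
the commit message describes for the real kernel paths.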

Patch

diff --git a/mm/gup.c b/mm/gup.c
index 75029649baca..5cac5c462b40 100644
--- a/mm/gup.c
+++ b/mm/gup.c
@@ -699,7 +699,7 @@  static long __get_user_pages(struct task_struct *tsk, struct mm_struct *mm,
 		unsigned int page_increm;
 
 		/* first iteration or cross vma bound */
-		if (!vma || start >= vma->vm_end) {
+		if (!vma || start >= READ_ONCE(vma->vm_end)) {
 			vma = find_extend_vma(mm, start);
 			if (!vma && in_gate_area(mm, start)) {
 				ret = get_gate_page(mm, start & PAGE_MASK,
@@ -850,7 +850,7 @@  int fixup_user_fault(struct task_struct *tsk, struct mm_struct *mm,
 
 retry:
 	vma = find_extend_vma(mm, address);
-	if (!vma || address < vma->vm_start)
+	if (!vma || address < READ_ONCE(vma->vm_start))
 		return -EFAULT;
 
 	if (!vma_permits_fault(vma, fault_flags))
@@ -1218,8 +1218,8 @@  long populate_vma_page_range(struct vm_area_struct *vma,
 
 	VM_BUG_ON(start & ~PAGE_MASK);
 	VM_BUG_ON(end   & ~PAGE_MASK);
-	VM_BUG_ON_VMA(start < vma->vm_start, vma);
-	VM_BUG_ON_VMA(end   > vma->vm_end, vma);
+	VM_BUG_ON_VMA(start < READ_ONCE(vma->vm_start), vma);
+	VM_BUG_ON_VMA(end   > READ_ONCE(vma->vm_end), vma);
 	VM_BUG_ON_MM(!rwsem_is_locked(&mm->mmap_sem), mm);
 
 	gup_flags = FOLL_TOUCH | FOLL_POPULATE | FOLL_MLOCK;
@@ -1258,7 +1258,7 @@  long populate_vma_page_range(struct vm_area_struct *vma,
 int __mm_populate(unsigned long start, unsigned long len, int ignore_errors)
 {
 	struct mm_struct *mm = current->mm;
-	unsigned long end, nstart, nend;
+	unsigned long end, nstart, nend, vma_start, vma_end;
 	struct vm_area_struct *vma = NULL;
 	int locked = 0;
 	long ret = 0;
@@ -1274,19 +1274,22 @@  int __mm_populate(unsigned long start, unsigned long len, int ignore_errors)
 			locked = 1;
 			down_read(&mm->mmap_sem);
 			vma = find_vma(mm, nstart);
-		} else if (nstart >= vma->vm_end)
+		} else if (nstart >= vma_end)
 			vma = vma->vm_next;
-		if (!vma || vma->vm_start >= end)
+		if (!vma)
 			break;
+		vma_start = READ_ONCE(vma->vm_start);
+		if (vma_start >= end)
+			break;
+		vma_end = READ_ONCE(vma->vm_end);
 		/*
 		 * Set [nstart; nend) to intersection of desired address
 		 * range with the first VMA. Also, skip undesirable VMA types.
 		 */
-		nend = min(end, vma->vm_end);
+		nend = min(end, vma_end);
 		if (vma->vm_flags & (VM_IO | VM_PFNMAP))
 			continue;
-		if (nstart < vma->vm_start)
-			nstart = vma->vm_start;
+		nstart = max(nstart, vma_start);
 		/*
 		 * Now fault in a range of pages. populate_vma_page_range()
 		 * double checks the vma flags, so that it won't mlock pages
diff --git a/mm/internal.h b/mm/internal.h
index f4a7bb02decf..839dbcf3c7ed 100644
--- a/mm/internal.h
+++ b/mm/internal.h
@@ -337,7 +337,8 @@  static inline unsigned long
 __vma_address(struct page *page, struct vm_area_struct *vma)
 {
 	pgoff_t pgoff = page_to_pgoff(page);
-	return vma->vm_start + ((pgoff - vma->vm_pgoff) << PAGE_SHIFT);
+	return READ_ONCE(vma->vm_start) +
+		((pgoff - READ_ONCE(vma->vm_pgoff)) << PAGE_SHIFT);
 }
 
 static inline unsigned long
diff --git a/mm/memory.c b/mm/memory.c
index 896d8aa08c0a..b76b659a026d 100644
--- a/mm/memory.c
+++ b/mm/memory.c
@@ -4257,7 +4257,7 @@  int __access_remote_vm(struct task_struct *tsk, struct mm_struct *mm,
 			 * we can access using slightly different code.
 			 */
 			vma = find_vma(mm, addr);
-			if (!vma || vma->vm_start > addr)
+			if (!vma || READ_ONCE(vma->vm_start) > addr)
 				break;
 			if (vma->vm_ops && vma->vm_ops->access)
 				ret = vma->vm_ops->access(vma, addr, buf,
diff --git a/mm/mmap.c b/mm/mmap.c
index f901065c4c64..9b84617c11c6 100644
--- a/mm/mmap.c
+++ b/mm/mmap.c
@@ -2240,9 +2240,9 @@  struct vm_area_struct *find_vma(struct mm_struct *mm, unsigned long addr)
 
 		tmp = rb_entry(rb_node, struct vm_area_struct, vm_rb);
 
-		if (tmp->vm_end > addr) {
+		if (READ_ONCE(tmp->vm_end) > addr) {
 			vma = tmp;
-			if (tmp->vm_start <= addr)
+			if (READ_ONCE(tmp->vm_start) <= addr)
 				break;
 			rb_node = rb_node->rb_left;
 		} else
@@ -2399,7 +2399,7 @@  int expand_upwards(struct vm_area_struct *vma, unsigned long address)
 					mm->locked_vm += grow;
 				vm_stat_account(mm, vma->vm_flags, grow);
 				anon_vma_interval_tree_pre_update_vma(vma);
-				vma->vm_end = address;
+				WRITE_ONCE(vma->vm_end, address);
 				anon_vma_interval_tree_post_update_vma(vma);
 				if (vma->vm_next)
 					vma_gap_update(vma->vm_next);
@@ -2480,8 +2480,8 @@  int expand_downwards(struct vm_area_struct *vma,
 					mm->locked_vm += grow;
 				vm_stat_account(mm, vma->vm_flags, grow);
 				anon_vma_interval_tree_pre_update_vma(vma);
-				vma->vm_start = address;
-				vma->vm_pgoff -= grow;
+				WRITE_ONCE(vma->vm_start, address);
+				WRITE_ONCE(vma->vm_pgoff, vma->vm_pgoff - grow);
 				anon_vma_interval_tree_post_update_vma(vma);
 				vma_gap_update(vma);
 				spin_unlock(&mm->page_table_lock);
@@ -2530,7 +2530,7 @@  find_extend_vma(struct mm_struct *mm, unsigned long addr)
 	if (!prev || expand_stack(prev, addr))
 		return NULL;
 	if (prev->vm_flags & VM_LOCKED)
-		populate_vma_page_range(prev, addr, prev->vm_end, NULL);
+		populate_vma_page_range(prev, addr, READ_ONCE(prev->vm_end), NULL);
 	return prev;
 }
 #else
@@ -2549,11 +2549,11 @@  find_extend_vma(struct mm_struct *mm, unsigned long addr)
 	vma = find_vma(mm, addr);
 	if (!vma)
 		return NULL;
-	if (vma->vm_start <= addr)
+	start = READ_ONCE(vma->vm_start);
+	if (start <= addr)
 		return vma;
 	if (!(vma->vm_flags & VM_GROWSDOWN))
 		return NULL;
-	start = vma->vm_start;
 	if (expand_stack(vma, addr))
 		return NULL;
 	if (vma->vm_flags & VM_LOCKED)
diff --git a/mm/rmap.c b/mm/rmap.c
index 0454ecc29537..d8d06bb87381 100644
--- a/mm/rmap.c
+++ b/mm/rmap.c
@@ -702,7 +702,8 @@  unsigned long page_address_in_vma(struct page *page, struct vm_area_struct *vma)
 	} else
 		return -EFAULT;
 	address = __vma_address(page, vma);
-	if (unlikely(address < vma->vm_start || address >= vma->vm_end))
+	if (unlikely(address < READ_ONCE(vma->vm_start) ||
+		     address >= READ_ONCE(vma->vm_end)))
 		return -EFAULT;
 	return address;
 }
diff --git a/mm/vmacache.c b/mm/vmacache.c
index cdc32a3b02fa..655554c85bdb 100644
--- a/mm/vmacache.c
+++ b/mm/vmacache.c
@@ -77,7 +77,8 @@  struct vm_area_struct *vmacache_find(struct mm_struct *mm, unsigned long addr)
 			if (WARN_ON_ONCE(vma->vm_mm != mm))
 				break;
 #endif
-			if (vma->vm_start <= addr && vma->vm_end > addr) {
+			if (READ_ONCE(vma->vm_start) <= addr &&
+			    READ_ONCE(vma->vm_end) > addr) {
 				count_vm_vmacache_event(VMACACHE_FIND_HITS);
 				return vma;
 			}