
[1/6] mm: enable page walking API to lock vmas during the walk

Message ID 20230731171233.1098105-2-surenb@google.com (mailing list archive)
State New
Series make vma locking more obvious

Commit Message

Suren Baghdasaryan July 31, 2023, 5:12 p.m. UTC
walk_page_range() and friends often operate under write-locked mmap_lock.
With the introduction of vma locks, the vmas have to be locked as well
during such walks to prevent concurrent page faults in these areas.
Add an additional parameter to the walk_page_range() family of functions
to indicate which walks should lock the vmas before operating on them.

Cc: stable@vger.kernel.org # 6.4.x
Suggested-by: Linus Torvalds <torvalds@linuxfoundation.org>
Suggested-by: Jann Horn <jannh@google.com>
Signed-off-by: Suren Baghdasaryan <surenb@google.com>
---
 arch/powerpc/mm/book3s64/subpage_prot.c |  2 +-
 arch/riscv/mm/pageattr.c                |  4 ++--
 arch/s390/mm/gmap.c                     | 10 +++++-----
 fs/proc/task_mmu.c                      | 10 +++++-----
 include/linux/pagewalk.h                |  6 +++---
 mm/damon/vaddr.c                        |  4 ++--
 mm/hmm.c                                |  2 +-
 mm/ksm.c                                | 16 ++++++++--------
 mm/madvise.c                            |  8 ++++----
 mm/memcontrol.c                         |  4 ++--
 mm/memory-failure.c                     |  2 +-
 mm/mempolicy.c                          | 12 ++++--------
 mm/migrate_device.c                     |  2 +-
 mm/mincore.c                            |  2 +-
 mm/mlock.c                              |  2 +-
 mm/mprotect.c                           |  2 +-
 mm/pagewalk.c                           | 13 ++++++++++---
 mm/vmscan.c                             |  3 ++-
 18 files changed, 54 insertions(+), 50 deletions(-)
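
For orientation before the discussion below, a minimal sketch of how the
flag added in this version is meant to be used (illustrative only; the real
call sites are in the diff at the end, and some_walk_ops is a placeholder
name). A walk performed under mmap_write_lock() passes true so each vma is
write-locked before it is visited, while a walk under mmap_read_lock()
passes false:

	/* caller holds the mmap_lock for write: lock each vma during the walk */
	mmap_write_lock(mm);
	err = walk_page_range(mm, start, end, &some_walk_ops, true, private);
	mmap_write_unlock(mm);

	/* caller holds the mmap_lock for read: no per-vma write locking */
	mmap_read_lock(mm);
	err = walk_page_range(mm, start, end, &some_walk_ops, false, private);
	mmap_read_unlock(mm);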

Comments

Linus Torvalds July 31, 2023, 6:02 p.m. UTC | #1
On Mon, 31 Jul 2023 at 10:12, Suren Baghdasaryan <surenb@google.com> wrote:
>
> -               walk_page_vma(vma, &subpage_walk_ops, NULL);
> +               walk_page_vma(vma, &subpage_walk_ops, true, NULL);

Rather than add a new argument to the walk_page_*() functions, I
really think you should just add the locking rule to the 'const struct
mm_walk_ops' structure.

The locking rule goes along with the rules for what you are actually
doing, after all. Plus it would actually make it all much more legible
when it's not just some random "true/false" argument, but an actual

        .write_lock = 1

in the ops definition.

Yes, yes, that might mean that some ops might need duplicating in case
you really have a walk that sometimes takes the lock, and sometimes
doesn't, but that is odd to begin with.

The only such case I found from a quick look was the very strange
queue_pages_range() case. Is it really true that do_mbind() needs the
write-lock, but do_migrate_pages() does not?

And if they really are that different maybe they should have different walk_ops?

Maybe there were other cases that I didn't notice.
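
As a rough illustration of the "different walk_ops" idea (the .write_lock
field here is only the placeholder floated above, and the callback names
simply mirror the existing queue_pages_walk_ops), the two callers could use
two ops that share callbacks and differ only in the locking rule:

	static const struct mm_walk_ops queue_pages_walk_ops = {
		.hugetlb_entry		= queue_pages_hugetlb,
		.pmd_entry		= queue_pages_pte_range,
		.test_walk		= queue_pages_test_walk,
	};

	/* same callbacks, but the walk write-locks each vma (do_mbind() path) */
	static const struct mm_walk_ops queue_pages_lock_vma_walk_ops = {
		.hugetlb_entry		= queue_pages_hugetlb,
		.pmd_entry		= queue_pages_pte_range,
		.test_walk		= queue_pages_test_walk,
		.write_lock		= 1,
	};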

>                 error = walk_page_range(current->mm, start, end,
> -                               &prot_none_walk_ops, &new_pgprot);
> +                               &prot_none_walk_ops, true, &new_pgprot);

This looks odd. You're adding vma locking to a place that didn't do it before.

Yes, the mmap semaphore is held for writing, but this particular walk
doesn't need it as far as I can tell.

In fact, this feels like that walker should maybe *verify* that it's
held for writing, but not try to write it again?

Maybe the "lock_vma" flag should be a tri-state:

 - lock for reading (no-op per vma), verify that the mmap sem is held
for reading

 - lock for reading (no-op per vma), but with mmap sem held for
writing (this kind of "check before doing changes" walker)

 - lock for writing (with mmap sem obviously needs to be held for writing)

>         mmap_assert_locked(walk.mm);
> +       if (lock_vma)
> +               vma_start_write(vma);

So I think this should also be tightened up, and something like

        switch (ops->locking) {
        case WRLOCK:
                vma_start_write(vma);
                fallthrough;
        case WRLOCK_VERIFY:
                mmap_assert_write_locked(mm);
                break;
        case RDLOCK:
                mmap_assert_locked(walk.mm);
        }

because we shouldn't have a 'vma_start_write()' without holding the
mmap sem for *writing*, and the above would also allow that
mprotect_fixup() "walk to see if we can merge, verify that it was
already locked" thing.

Hmm?

NOTE! The above names are just completely made up. I don't think it
should actually be some "WRLOCK" enum. There are probably much better
names. Take the above as a "maybe something kind of in this direction"
rather than "do it exactly like this".

            Linus
Suren Baghdasaryan July 31, 2023, 7:30 p.m. UTC | #2
On Mon, Jul 31, 2023 at 11:02 AM Linus Torvalds
<torvalds@linux-foundation.org> wrote:
>
> On Mon, 31 Jul 2023 at 10:12, Suren Baghdasaryan <surenb@google.com> wrote:
> >
> > -               walk_page_vma(vma, &subpage_walk_ops, NULL);
> > +               walk_page_vma(vma, &subpage_walk_ops, true, NULL);
>
> Rather than add a new argument to the walk_page_*() functions, I
> really think you should just add the locking rule to the 'const struct
> mm_walk_ops' structure.
>
> The locking rule goes along with the rules for what you are actually
> doing, after all. Plus it would actually make it all much more legible
> when it's not just some random "true/false" argument, but an actual
>
>         .write_lock = 1
>
> in the ops definition.

Yeah, I was thinking about that but thought a flag like this in a pure
"ops" struct would be frowned upon. If this is acceptable then it
makes it much cleaner.

>
> Yes, yes, that might mean that some ops might need duplicating in case
> you really have a walk that sometimes takes the lock, and sometimes
> doesn't, but that is odd to begin with.
>
> The only such case I found from a quick look was the very strange
> queue_pages_range() case. Is it really true that do_mbind() needs the
> write-lock, but do_migrate_pages() does not?
>
> And if they really are that different maybe they should have different walk_ops?

Makes sense to me.

>
> Maybe there were other cases that I didn't notice.
>
> >                 error = walk_page_range(current->mm, start, end,
> > -                               &prot_none_walk_ops, &new_pgprot);
> > +                               &prot_none_walk_ops, true, &new_pgprot);
>
> This looks odd. You're adding vma locking to a place that didn't do it before.
>
> Yes, the mmap semaphore is held for writing, but this particular walk
> doesn't need it as far as I can tell.

Yes you are correct. Locking a vma in this case seems unnecessary.

>
> In fact, this feels like that walker should maybe *verify* that it's
> held for writing, but not try to write it again?

In this particular case, does this walk even require the vma to be
write locked? Looks like it's simply checking the ptes. And if so,
walk_page_range() already has mmap_assert_locked(walk.mm) at the
beginning to ensure the tree is stable. Do we need anything else here?

>
> Maybe the "lock_vma" flag should be a tri-state:
>
>  - lock for reading (no-op per vma), verify that the mmap sem is held
> for reading
>
>  - lock for reading (no-op per vma), but with mmap sem held for
> writing (this kind of "check before doing changes" walker)
>
>  - lock for writing (with mmap sem obviously needs to be held for writing)
>
> >         mmap_assert_locked(walk.mm);
> > +       if (lock_vma)
> > +               vma_start_write(vma);
>
> So I think this should also be tightened up, and something like
>
>         switch (ops->locking) {
>         case WRLOCK:
>                 vma_start_write(vma);
>                 fallthrough;
>         case WRLOCK_VERIFY:
>                 mmap_assert_write_locked(mm);
>                 break;
>         case RDLOCK:
>                 mmap_assert_locked(walk.mm);
>         }

I got the idea but a couple of modifications, if I may.
walk_page_range() already does mmap_assert_locked() at the beginning,
so we can change it to:

if (ops->locking == RDLOCK)
        mmap_assert_locked(walk.mm);
else
        mmap_assert_write_locked(mm);

and during the walk:

        switch (ops->locking) {
        case WRLOCK:
                 vma_start_write(vma);
                 break;
#ifdef CONFIG_PER_VMA_LOCK
        case WRLOCK_VERIFY:
                 vma_assert_write_locked(vma);
                 break;
#endif
         }

The above CONFIG_PER_VMA_LOCK #ifdef is ugly, but with !CONFIG_PER_VMA_LOCK
vma_assert_write_locked() becomes mmap_assert_write_locked(), and we
already checked that, so it's unnecessary.
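
For context, this is roughly the behaviour being described: under
!CONFIG_PER_VMA_LOCK the per-vma helpers collapse onto the mmap_lock ones,
so the verify case adds nothing beyond the entry assertion. A paraphrased
sketch of that intent (not an exact copy of mm.h):

	#else /* !CONFIG_PER_VMA_LOCK */

	/* write-locking a vma is subsumed by holding mmap_lock for write */
	static inline void vma_start_write(struct vm_area_struct *vma) {}

	static inline void vma_assert_write_locked(struct vm_area_struct *vma)
	{
		mmap_assert_write_locked(vma->vm_mm);
	}

	#endif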

>
> because we shouldn't have a 'vma_start_write()' without holding the
> mmap sem for *writing*, and the above would also allow that
> mprotect_fixup() "walk to see if we can merge, verify that it was
> already locked" thing.
>
> Hmm?
>
> NOTE! The above names are just completely made up. I don't think it
> should actually be some "WRLOCK" enum. There are probably much better
> names. Take the above as a "maybe something kind of in this direction"
> rather than "do it exactly like this".

I'm not great with names... Maybe just add a PGWALK_ prefix like this:

PGWALK_RDLOCK
PGWALK_WRLOCK
PGWALK_WRLOCK_VERIFY

?

>
>             Linus
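
Putting the two snippets from this subthread together, a sketch (not the
final patch; the field name and comments are guesses) of what the enum, the
ops field and the per-vma step of the walker could look like, ignoring the
CONFIG_PER_VMA_LOCK #ifdef discussed above:

	/* include/linux/pagewalk.h (sketch) */
	enum page_walk_lock {
		/* mmap_lock held at least for read; vmas are not write-locked */
		PGWALK_RDLOCK = 0,
		/* mmap_lock held for write; write-lock each vma before walking it */
		PGWALK_WRLOCK = 1,
		/* mmap_lock held for write; vma expected to be already write-locked */
		PGWALK_WRLOCK_VERIFY = 2,
	};

	struct mm_walk_ops {
		/* ... existing callbacks elided ... */
		enum page_walk_lock walk_lock;
	};

	/* mm/pagewalk.c, before operating on a vma (sketch) */
	switch (ops->walk_lock) {
	case PGWALK_WRLOCK:
		vma_start_write(vma);
		break;
	case PGWALK_WRLOCK_VERIFY:
		vma_assert_write_locked(vma);
		break;
	case PGWALK_RDLOCK:
		/* already covered by mmap_assert_locked() at walk entry */
		break;
	}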
Linus Torvalds July 31, 2023, 7:33 p.m. UTC | #3
On Mon, 31 Jul 2023 at 12:31, Suren Baghdasaryan <surenb@google.com> wrote:
>
> I got the idea but a couple of modifications, if I may.

Ack, sounds sane to me.

             Linus
Suren Baghdasaryan July 31, 2023, 8:24 p.m. UTC | #4
On Mon, Jul 31, 2023 at 12:33 PM Linus Torvalds
<torvalds@linux-foundation.org> wrote:
>
> On Mon, 31 Jul 2023 at 12:31, Suren Baghdasaryan <surenb@google.com> wrote:
> >
> > I got the idea but a couple of modifications, if I may.
>
> Ack, sounds sane to me.

Ok, I'll wait for more feedback today and will post an update tomorrow. Thanks!

>
>              Linus
Suren Baghdasaryan Aug. 1, 2023, 8:28 p.m. UTC | #5
On Mon, Jul 31, 2023 at 1:24 PM Suren Baghdasaryan <surenb@google.com> wrote:
>
> On Mon, Jul 31, 2023 at 12:33 PM Linus Torvalds
> <torvalds@linux-foundation.org> wrote:
> >
> > On Mon, 31 Jul 2023 at 12:31, Suren Baghdasaryan <surenb@google.com> wrote:
> > >
> > > I got the idea but a couple of modifications, if I may.
> >
> > Ack, sounds sane to me.
>
> Ok, I'll wait for more feedback today and will post an update tomorrow. Thanks!

I have the new patchset ready but I see 3 places where we walk the
pages after mmap_write_lock() while *I think* we can tolerate
concurrent page faults (don't need to lock the vmas):

s390_enable_sie()
break_ksm()
clear_refs_write()

In all these walks we take the PTL when modifying the page table entries,
which is why I think we can skip locking the vma, but maybe I'm missing
something? Could someone please check these 3 cases and confirm or
deny my claim?
Thanks,
Suren.

>
> >
> >              Linus
Peter Xu Aug. 1, 2023, 9:34 p.m. UTC | #6
On Tue, Aug 01, 2023 at 01:28:56PM -0700, Suren Baghdasaryan wrote:
> I have the new patchset ready but I see 3 places where we walk the
> pages after mmap_write_lock() while *I think* we can tolerate
> concurrent page faults (don't need to lock the vmas):
> 
> s390_enable_sie()
> break_ksm()
> clear_refs_write()

This one doesn't look right to be listed - tlb flushing is postponed until
after the pgtable lock is released, so I assume the same issue can happen
as in fork(): we can have race conditions that corrupt data if, e.g., thread
A is writing through a writable (unflushed) tlb alongside thread B CoWing.

It'll indeed be nice to know whether break_ksm() can avoid that lock_vma
parameter across quite a few function jumps. I don't yet see an immediate
issue with this one..  No idea on s390_enable_sie(), but to make it simple
and safe I'd simply leave it with the write vma lock to match the mmap
write lock.

Thanks,
Suren Baghdasaryan Aug. 1, 2023, 9:46 p.m. UTC | #7
On Tue, Aug 1, 2023 at 2:34 PM Peter Xu <peterx@redhat.com> wrote:
>
> On Tue, Aug 01, 2023 at 01:28:56PM -0700, Suren Baghdasaryan wrote:
> > I have the new patchset ready but I see 3 places where we walk the
> > pages after mmap_write_lock() while *I think* we can tolerate
> > concurrent page faults (don't need to lock the vmas):
> >
> > s390_enable_sie()
> > break_ksm()
> > clear_refs_write()
>
> This one doesn't look right to be listed - tlb flushing is postponed until
> after the pgtable lock is released, so I assume the same issue can happen
> as in fork(): we can have race conditions that corrupt data if, e.g., thread
> A is writing through a writable (unflushed) tlb alongside thread B CoWing.

Ah, good point.

>
> It'll indeed be nice to know whether break_ksm() can avoid that lock_vma
> parameter across quite a few function jumps. I don't yet see an immediate
> issue with this one..  No idea on s390_enable_sie(), but to make it simple
> and safe I'd simply leave it with the write vma lock to match the mmap
> write lock.

Thanks for taking a look, Peter!

Ok, let me keep all three of them with vma locking in place to be safe
and will post v2 for further discussion.
Thanks,
Suren.

>
> Thanks,
>
> --
> Peter Xu
>
Suren Baghdasaryan Aug. 1, 2023, 10:13 p.m. UTC | #8
On Tue, Aug 1, 2023 at 2:46 PM Suren Baghdasaryan <surenb@google.com> wrote:
>
> On Tue, Aug 1, 2023 at 2:34 PM Peter Xu <peterx@redhat.com> wrote:
> >
> > On Tue, Aug 01, 2023 at 01:28:56PM -0700, Suren Baghdasaryan wrote:
> > > I have the new patchset ready but I see 3 places where we walk the
> > > pages after mmap_write_lock() while *I think* we can tolerate
> > > concurrent page faults (don't need to lock the vmas):
> > >
> > > s390_enable_sie()
> > > break_ksm()
> > > clear_refs_write()
> >
> > This one doesn't look right to be listed - tlb flushing is postponed until
> > after the pgtable lock is released, so I assume the same issue can happen
> > as in fork(): we can have race conditions that corrupt data if, e.g., thread
> > A is writing through a writable (unflushed) tlb alongside thread B CoWing.
>
> Ah, good point.
>
> >
> > It'll indeed be nice to know whether break_ksm() can avoid that lock_vma
> > parameter across quite a few function jumps. I don't yet see an immediate
> > issue with this one..  No idea on s390_enable_sie(), but to make it simple
> > and safe I'd simply leave it with the write vma lock to match the mmap
> > write lock.
>
> Thanks for taking a look, Peter!
>
> Ok, let me keep all three of them with vma locking in place to be safe
> and will post v2 for further discussion.

v2 posted at https://lore.kernel.org/all/20230801220733.1987762-1-surenb@google.com/

> Thanks,
> Suren.
>
> >
> > Thanks,
> >
> > --
> > Peter Xu
> >

Patch

diff --git a/arch/powerpc/mm/book3s64/subpage_prot.c b/arch/powerpc/mm/book3s64/subpage_prot.c
index 0dc85556dec5..177e5c646d9c 100644
--- a/arch/powerpc/mm/book3s64/subpage_prot.c
+++ b/arch/powerpc/mm/book3s64/subpage_prot.c
@@ -159,7 +159,7 @@  static void subpage_mark_vma_nohuge(struct mm_struct *mm, unsigned long addr,
 	 */
 	for_each_vma_range(vmi, vma, addr + len) {
 		vm_flags_set(vma, VM_NOHUGEPAGE);
-		walk_page_vma(vma, &subpage_walk_ops, NULL);
+		walk_page_vma(vma, &subpage_walk_ops, true, NULL);
 	}
 }
 #else
diff --git a/arch/riscv/mm/pageattr.c b/arch/riscv/mm/pageattr.c
index ea3d61de065b..95207994cbf0 100644
--- a/arch/riscv/mm/pageattr.c
+++ b/arch/riscv/mm/pageattr.c
@@ -167,7 +167,7 @@  int set_direct_map_invalid_noflush(struct page *page)
 	};
 
 	mmap_read_lock(&init_mm);
-	ret = walk_page_range(&init_mm, start, end, &pageattr_ops, &masks);
+	ret = walk_page_range(&init_mm, start, end, &pageattr_ops, false, &masks);
 	mmap_read_unlock(&init_mm);
 
 	return ret;
@@ -184,7 +184,7 @@  int set_direct_map_default_noflush(struct page *page)
 	};
 
 	mmap_read_lock(&init_mm);
-	ret = walk_page_range(&init_mm, start, end, &pageattr_ops, &masks);
+	ret = walk_page_range(&init_mm, start, end, &pageattr_ops, false, &masks);
 	mmap_read_unlock(&init_mm);
 
 	return ret;
diff --git a/arch/s390/mm/gmap.c b/arch/s390/mm/gmap.c
index 9c8af31be970..16a58c860c74 100644
--- a/arch/s390/mm/gmap.c
+++ b/arch/s390/mm/gmap.c
@@ -2523,7 +2523,7 @@  static inline void thp_split_mm(struct mm_struct *mm)
 
 	for_each_vma(vmi, vma) {
 		vm_flags_mod(vma, VM_NOHUGEPAGE, VM_HUGEPAGE);
-		walk_page_vma(vma, &thp_split_walk_ops, NULL);
+		walk_page_vma(vma, &thp_split_walk_ops, true, NULL);
 	}
 	mm->def_flags |= VM_NOHUGEPAGE;
 }
@@ -2584,7 +2584,7 @@  int s390_enable_sie(void)
 	mm->context.has_pgste = 1;
 	/* split thp mappings and disable thp for future mappings */
 	thp_split_mm(mm);
-	walk_page_range(mm, 0, TASK_SIZE, &zap_zero_walk_ops, NULL);
+	walk_page_range(mm, 0, TASK_SIZE, &zap_zero_walk_ops, true, NULL);
 	mmap_write_unlock(mm);
 	return 0;
 }
@@ -2672,7 +2672,7 @@  int s390_enable_skey(void)
 		mm->context.uses_skeys = 0;
 		goto out_up;
 	}
-	walk_page_range(mm, 0, TASK_SIZE, &enable_skey_walk_ops, NULL);
+	walk_page_range(mm, 0, TASK_SIZE, &enable_skey_walk_ops, true, NULL);
 
 out_up:
 	mmap_write_unlock(mm);
@@ -2697,7 +2697,7 @@  static const struct mm_walk_ops reset_cmma_walk_ops = {
 void s390_reset_cmma(struct mm_struct *mm)
 {
 	mmap_write_lock(mm);
-	walk_page_range(mm, 0, TASK_SIZE, &reset_cmma_walk_ops, NULL);
+	walk_page_range(mm, 0, TASK_SIZE, &reset_cmma_walk_ops, true, NULL);
 	mmap_write_unlock(mm);
 }
 EXPORT_SYMBOL_GPL(s390_reset_cmma);
@@ -2771,7 +2771,7 @@  int __s390_uv_destroy_range(struct mm_struct *mm, unsigned long start,
 	while (r > 0) {
 		state.count = 0;
 		mmap_read_lock(mm);
-		r = walk_page_range(mm, state.next, end, &gather_pages_ops, &state);
+		r = walk_page_range(mm, state.next, end, &gather_pages_ops, false, &state);
 		mmap_read_unlock(mm);
 		cond_resched();
 		s390_uv_destroy_pfns(state.count, state.pfns);
diff --git a/fs/proc/task_mmu.c b/fs/proc/task_mmu.c
index 507cd4e59d07..f0d0f2959f91 100644
--- a/fs/proc/task_mmu.c
+++ b/fs/proc/task_mmu.c
@@ -804,9 +804,9 @@  static void smap_gather_stats(struct vm_area_struct *vma,
 
 	/* mmap_lock is held in m_start */
 	if (!start)
-		walk_page_vma(vma, ops, mss);
+		walk_page_vma(vma, ops, false, mss);
 	else
-		walk_page_range(vma->vm_mm, start, vma->vm_end, ops, mss);
+		walk_page_range(vma->vm_mm, start, vma->vm_end, ops, false, mss);
 }
 
 #define SEQ_PUT_DEC(str, val) \
@@ -1307,7 +1307,7 @@  static ssize_t clear_refs_write(struct file *file, const char __user *buf,
 						0, mm, 0, -1UL);
 			mmu_notifier_invalidate_range_start(&range);
 		}
-		walk_page_range(mm, 0, -1, &clear_refs_walk_ops, &cp);
+		walk_page_range(mm, 0, -1, &clear_refs_walk_ops, true, &cp);
 		if (type == CLEAR_REFS_SOFT_DIRTY) {
 			mmu_notifier_invalidate_range_end(&range);
 			flush_tlb_mm(mm);
@@ -1720,7 +1720,7 @@  static ssize_t pagemap_read(struct file *file, char __user *buf,
 		ret = mmap_read_lock_killable(mm);
 		if (ret)
 			goto out_free;
-		ret = walk_page_range(mm, start_vaddr, end, &pagemap_ops, &pm);
+		ret = walk_page_range(mm, start_vaddr, end, &pagemap_ops, false, &pm);
 		mmap_read_unlock(mm);
 		start_vaddr = end;
 
@@ -1981,7 +1981,7 @@  static int show_numa_map(struct seq_file *m, void *v)
 		seq_puts(m, " huge");
 
 	/* mmap_lock is held by m_start */
-	walk_page_vma(vma, &show_numa_ops, md);
+	walk_page_vma(vma, &show_numa_ops, false, md);
 
 	if (!md->pages)
 		goto out;
diff --git a/include/linux/pagewalk.h b/include/linux/pagewalk.h
index 27a6df448ee5..69656ec44049 100644
--- a/include/linux/pagewalk.h
+++ b/include/linux/pagewalk.h
@@ -105,16 +105,16 @@  struct mm_walk {
 
 int walk_page_range(struct mm_struct *mm, unsigned long start,
 		unsigned long end, const struct mm_walk_ops *ops,
-		void *private);
+		bool lock_vma, void *private);
 int walk_page_range_novma(struct mm_struct *mm, unsigned long start,
 			  unsigned long end, const struct mm_walk_ops *ops,
 			  pgd_t *pgd,
 			  void *private);
 int walk_page_range_vma(struct vm_area_struct *vma, unsigned long start,
 			unsigned long end, const struct mm_walk_ops *ops,
-			void *private);
+			bool lock_vma, void *private);
 int walk_page_vma(struct vm_area_struct *vma, const struct mm_walk_ops *ops,
-		void *private);
+		  bool lock_vma, void *private);
 int walk_page_mapping(struct address_space *mapping, pgoff_t first_index,
 		      pgoff_t nr, const struct mm_walk_ops *ops,
 		      void *private);
diff --git a/mm/damon/vaddr.c b/mm/damon/vaddr.c
index 2fcc9731528a..54f50b1aefe4 100644
--- a/mm/damon/vaddr.c
+++ b/mm/damon/vaddr.c
@@ -391,7 +391,7 @@  static const struct mm_walk_ops damon_mkold_ops = {
 static void damon_va_mkold(struct mm_struct *mm, unsigned long addr)
 {
 	mmap_read_lock(mm);
-	walk_page_range(mm, addr, addr + 1, &damon_mkold_ops, NULL);
+	walk_page_range(mm, addr, addr + 1, &damon_mkold_ops, false, NULL);
 	mmap_read_unlock(mm);
 }
 
@@ -536,7 +536,7 @@  static bool damon_va_young(struct mm_struct *mm, unsigned long addr,
 	};
 
 	mmap_read_lock(mm);
-	walk_page_range(mm, addr, addr + 1, &damon_young_ops, &arg);
+	walk_page_range(mm, addr, addr + 1, &damon_young_ops, false, &arg);
 	mmap_read_unlock(mm);
 	return arg.young;
 }
diff --git a/mm/hmm.c b/mm/hmm.c
index 855e25e59d8f..f94f5e268e40 100644
--- a/mm/hmm.c
+++ b/mm/hmm.c
@@ -600,7 +600,7 @@  int hmm_range_fault(struct hmm_range *range)
 					     range->notifier_seq))
 			return -EBUSY;
 		ret = walk_page_range(mm, hmm_vma_walk.last, range->end,
-				      &hmm_walk_ops, &hmm_vma_walk);
+				      &hmm_walk_ops, false, &hmm_vma_walk);
 		/*
 		 * When -EBUSY is returned the loop restarts with
 		 * hmm_vma_walk.last set to an address that has not been stored
diff --git a/mm/ksm.c b/mm/ksm.c
index ba266359da55..494a1f3fcb97 100644
--- a/mm/ksm.c
+++ b/mm/ksm.c
@@ -470,7 +470,7 @@  static const struct mm_walk_ops break_ksm_ops = {
  * of the process that owns 'vma'.  We also do not want to enforce
  * protection keys here anyway.
  */
-static int break_ksm(struct vm_area_struct *vma, unsigned long addr)
+static int break_ksm(struct vm_area_struct *vma, unsigned long addr, bool lock_vma)
 {
 	vm_fault_t ret = 0;
 
@@ -479,7 +479,7 @@  static int break_ksm(struct vm_area_struct *vma, unsigned long addr)
 
 		cond_resched();
 		ksm_page = walk_page_range_vma(vma, addr, addr + 1,
-					       &break_ksm_ops, NULL);
+					       &break_ksm_ops, lock_vma, NULL);
 		if (WARN_ON_ONCE(ksm_page < 0))
 			return ksm_page;
 		if (!ksm_page)
@@ -565,7 +565,7 @@  static void break_cow(struct ksm_rmap_item *rmap_item)
 	mmap_read_lock(mm);
 	vma = find_mergeable_vma(mm, addr);
 	if (vma)
-		break_ksm(vma, addr);
+		break_ksm(vma, addr, false);
 	mmap_read_unlock(mm);
 }
 
@@ -871,7 +871,7 @@  static void remove_trailing_rmap_items(struct ksm_rmap_item **rmap_list)
  * in cmp_and_merge_page on one of the rmap_items we would be removing.
  */
 static int unmerge_ksm_pages(struct vm_area_struct *vma,
-			     unsigned long start, unsigned long end)
+			     unsigned long start, unsigned long end, bool lock_vma)
 {
 	unsigned long addr;
 	int err = 0;
@@ -882,7 +882,7 @@  static int unmerge_ksm_pages(struct vm_area_struct *vma,
 		if (signal_pending(current))
 			err = -ERESTARTSYS;
 		else
-			err = break_ksm(vma, addr);
+			err = break_ksm(vma, addr, lock_vma);
 	}
 	return err;
 }
@@ -1029,7 +1029,7 @@  static int unmerge_and_remove_all_rmap_items(void)
 			if (!(vma->vm_flags & VM_MERGEABLE) || !vma->anon_vma)
 				continue;
 			err = unmerge_ksm_pages(vma,
-						vma->vm_start, vma->vm_end);
+						vma->vm_start, vma->vm_end, false);
 			if (err)
 				goto error;
 		}
@@ -2530,7 +2530,7 @@  static int __ksm_del_vma(struct vm_area_struct *vma)
 		return 0;
 
 	if (vma->anon_vma) {
-		err = unmerge_ksm_pages(vma, vma->vm_start, vma->vm_end);
+		err = unmerge_ksm_pages(vma, vma->vm_start, vma->vm_end, true);
 		if (err)
 			return err;
 	}
@@ -2668,7 +2668,7 @@  int ksm_madvise(struct vm_area_struct *vma, unsigned long start,
 			return 0;		/* just ignore the advice */
 
 		if (vma->anon_vma) {
-			err = unmerge_ksm_pages(vma, start, end);
+			err = unmerge_ksm_pages(vma, start, end, true);
 			if (err)
 				return err;
 		}
diff --git a/mm/madvise.c b/mm/madvise.c
index 886f06066622..0e484111a1d2 100644
--- a/mm/madvise.c
+++ b/mm/madvise.c
@@ -287,7 +287,7 @@  static long madvise_willneed(struct vm_area_struct *vma,
 	*prev = vma;
 #ifdef CONFIG_SWAP
 	if (!file) {
-		walk_page_range(vma->vm_mm, start, end, &swapin_walk_ops, vma);
+		walk_page_range(vma->vm_mm, start, end, &swapin_walk_ops, false, vma);
 		lru_add_drain(); /* Push any new pages onto the LRU now */
 		return 0;
 	}
@@ -546,7 +546,7 @@  static void madvise_cold_page_range(struct mmu_gather *tlb,
 	};
 
 	tlb_start_vma(tlb, vma);
-	walk_page_range(vma->vm_mm, addr, end, &cold_walk_ops, &walk_private);
+	walk_page_range(vma->vm_mm, addr, end, &cold_walk_ops, false, &walk_private);
 	tlb_end_vma(tlb, vma);
 }
 
@@ -584,7 +584,7 @@  static void madvise_pageout_page_range(struct mmu_gather *tlb,
 	};
 
 	tlb_start_vma(tlb, vma);
-	walk_page_range(vma->vm_mm, addr, end, &cold_walk_ops, &walk_private);
+	walk_page_range(vma->vm_mm, addr, end, &cold_walk_ops, false, &walk_private);
 	tlb_end_vma(tlb, vma);
 }
 
@@ -786,7 +786,7 @@  static int madvise_free_single_vma(struct vm_area_struct *vma,
 	mmu_notifier_invalidate_range_start(&range);
 	tlb_start_vma(&tlb, vma);
 	walk_page_range(vma->vm_mm, range.start, range.end,
-			&madvise_free_walk_ops, &tlb);
+			&madvise_free_walk_ops, false, &tlb);
 	tlb_end_vma(&tlb, vma);
 	mmu_notifier_invalidate_range_end(&range);
 	tlb_finish_mmu(&tlb);
diff --git a/mm/memcontrol.c b/mm/memcontrol.c
index e8ca4bdcb03c..76aaadbd4bf9 100644
--- a/mm/memcontrol.c
+++ b/mm/memcontrol.c
@@ -6031,7 +6031,7 @@  static unsigned long mem_cgroup_count_precharge(struct mm_struct *mm)
 	unsigned long precharge;
 
 	mmap_read_lock(mm);
-	walk_page_range(mm, 0, ULONG_MAX, &precharge_walk_ops, NULL);
+	walk_page_range(mm, 0, ULONG_MAX, &precharge_walk_ops, false, NULL);
 	mmap_read_unlock(mm);
 
 	precharge = mc.precharge;
@@ -6332,7 +6332,7 @@  static void mem_cgroup_move_charge(void)
 	 * When we have consumed all precharges and failed in doing
 	 * additional charge, the page walk just aborts.
 	 */
-	walk_page_range(mc.mm, 0, ULONG_MAX, &charge_walk_ops, NULL);
+	walk_page_range(mc.mm, 0, ULONG_MAX, &charge_walk_ops, false, NULL);
 	mmap_read_unlock(mc.mm);
 	atomic_dec(&mc.from->moving_account);
 }
diff --git a/mm/memory-failure.c b/mm/memory-failure.c
index ece5d481b5ff..763297df9240 100644
--- a/mm/memory-failure.c
+++ b/mm/memory-failure.c
@@ -860,7 +860,7 @@  static int kill_accessing_process(struct task_struct *p, unsigned long pfn,
 
 	mmap_read_lock(p->mm);
 	ret = walk_page_range(p->mm, 0, TASK_SIZE, &hwp_walk_ops,
-			      (void *)&priv);
+			      false, (void *)&priv);
 	if (ret == 1 && priv.tk.addr)
 		kill_proc(&priv.tk, pfn, flags);
 	else
diff --git a/mm/mempolicy.c b/mm/mempolicy.c
index c53f8beeb507..70ba53c70700 100644
--- a/mm/mempolicy.c
+++ b/mm/mempolicy.c
@@ -738,7 +738,7 @@  static const struct mm_walk_ops queue_pages_walk_ops = {
 static int
 queue_pages_range(struct mm_struct *mm, unsigned long start, unsigned long end,
 		nodemask_t *nodes, unsigned long flags,
-		struct list_head *pagelist)
+		struct list_head *pagelist, bool lock_vma)
 {
 	int err;
 	struct queue_pages qp = {
@@ -750,7 +750,7 @@  queue_pages_range(struct mm_struct *mm, unsigned long start, unsigned long end,
 		.first = NULL,
 	};
 
-	err = walk_page_range(mm, start, end, &queue_pages_walk_ops, &qp);
+	err = walk_page_range(mm, start, end, &queue_pages_walk_ops, lock_vma, &qp);
 
 	if (!qp.first)
 		/* whole range in hole */
@@ -1078,7 +1078,7 @@  static int migrate_to_node(struct mm_struct *mm, int source, int dest,
 	vma = find_vma(mm, 0);
 	VM_BUG_ON(!(flags & (MPOL_MF_MOVE | MPOL_MF_MOVE_ALL)));
 	queue_pages_range(mm, vma->vm_start, mm->task_size, &nmask,
-			flags | MPOL_MF_DISCONTIG_OK, &pagelist);
+			flags | MPOL_MF_DISCONTIG_OK, &pagelist, false);
 
 	if (!list_empty(&pagelist)) {
 		err = migrate_pages(&pagelist, alloc_migration_target, NULL,
@@ -1321,12 +1321,8 @@  static long do_mbind(unsigned long start, unsigned long len,
 	 * Lock the VMAs before scanning for pages to migrate, to ensure we don't
 	 * miss a concurrently inserted page.
 	 */
-	vma_iter_init(&vmi, mm, start);
-	for_each_vma_range(vmi, vma, end)
-		vma_start_write(vma);
-
 	ret = queue_pages_range(mm, start, end, nmask,
-			  flags | MPOL_MF_INVERT, &pagelist);
+			  flags | MPOL_MF_INVERT, &pagelist, true);
 
 	if (ret < 0) {
 		err = ret;
diff --git a/mm/migrate_device.c b/mm/migrate_device.c
index 8365158460ed..1bc9937bf1fb 100644
--- a/mm/migrate_device.c
+++ b/mm/migrate_device.c
@@ -304,7 +304,7 @@  static void migrate_vma_collect(struct migrate_vma *migrate)
 	mmu_notifier_invalidate_range_start(&range);
 
 	walk_page_range(migrate->vma->vm_mm, migrate->start, migrate->end,
-			&migrate_vma_walk_ops, migrate);
+			&migrate_vma_walk_ops, false, migrate);
 
 	mmu_notifier_invalidate_range_end(&range);
 	migrate->end = migrate->start + (migrate->npages << PAGE_SHIFT);
diff --git a/mm/mincore.c b/mm/mincore.c
index b7f7a516b26c..a06288c6c126 100644
--- a/mm/mincore.c
+++ b/mm/mincore.c
@@ -198,7 +198,7 @@  static long do_mincore(unsigned long addr, unsigned long pages, unsigned char *v
 		memset(vec, 1, pages);
 		return pages;
 	}
-	err = walk_page_range(vma->vm_mm, addr, end, &mincore_walk_ops, vec);
+	err = walk_page_range(vma->vm_mm, addr, end, &mincore_walk_ops, false, vec);
 	if (err < 0)
 		return err;
 	return (end - addr) >> PAGE_SHIFT;
diff --git a/mm/mlock.c b/mm/mlock.c
index 0a0c996c5c21..3634de0b28e3 100644
--- a/mm/mlock.c
+++ b/mm/mlock.c
@@ -389,7 +389,7 @@  static void mlock_vma_pages_range(struct vm_area_struct *vma,
 	vm_flags_reset_once(vma, newflags);
 
 	lru_add_drain();
-	walk_page_range(vma->vm_mm, start, end, &mlock_walk_ops, NULL);
+	walk_page_range(vma->vm_mm, start, end, &mlock_walk_ops, true, NULL);
 	lru_add_drain();
 
 	if (newflags & VM_IO) {
diff --git a/mm/mprotect.c b/mm/mprotect.c
index 6f658d483704..f781f709c39d 100644
--- a/mm/mprotect.c
+++ b/mm/mprotect.c
@@ -599,7 +599,7 @@  mprotect_fixup(struct vma_iterator *vmi, struct mmu_gather *tlb,
 		pgprot_t new_pgprot = vm_get_page_prot(newflags);
 
 		error = walk_page_range(current->mm, start, end,
-				&prot_none_walk_ops, &new_pgprot);
+				&prot_none_walk_ops, true, &new_pgprot);
 		if (error)
 			return error;
 	}
diff --git a/mm/pagewalk.c b/mm/pagewalk.c
index 2022333805d3..7503885fae75 100644
--- a/mm/pagewalk.c
+++ b/mm/pagewalk.c
@@ -406,6 +406,7 @@  static int __walk_page_range(unsigned long start, unsigned long end,
  * @start:	start address of the virtual address range
  * @end:	end address of the virtual address range
  * @ops:	operation to call during the walk
+ * @lock_vma	write-lock the vma before operating on it
  * @private:	private data for callbacks' usage
  *
  * Recursively walk the page table tree of the process represented by @mm
@@ -442,7 +443,7 @@  static int __walk_page_range(unsigned long start, unsigned long end,
  */
 int walk_page_range(struct mm_struct *mm, unsigned long start,
 		unsigned long end, const struct mm_walk_ops *ops,
-		void *private)
+		bool lock_vma, void *private)
 {
 	int err = 0;
 	unsigned long next;
@@ -474,6 +475,8 @@  int walk_page_range(struct mm_struct *mm, unsigned long start,
 			if (ops->pte_hole)
 				err = ops->pte_hole(start, next, -1, &walk);
 		} else { /* inside vma */
+			if (lock_vma)
+				vma_start_write(vma);
 			walk.vma = vma;
 			next = min(end, vma->vm_end);
 			vma = find_vma(mm, vma->vm_end);
@@ -535,7 +538,7 @@  int walk_page_range_novma(struct mm_struct *mm, unsigned long start,
 
 int walk_page_range_vma(struct vm_area_struct *vma, unsigned long start,
 			unsigned long end, const struct mm_walk_ops *ops,
-			void *private)
+			bool lock_vma, void *private)
 {
 	struct mm_walk walk = {
 		.ops		= ops,
@@ -550,11 +553,13 @@  int walk_page_range_vma(struct vm_area_struct *vma, unsigned long start,
 		return -EINVAL;
 
 	mmap_assert_locked(walk.mm);
+	if (lock_vma)
+		vma_start_write(vma);
 	return __walk_page_range(start, end, &walk);
 }
 
 int walk_page_vma(struct vm_area_struct *vma, const struct mm_walk_ops *ops,
-		void *private)
+		  bool lock_vma, void *private)
 {
 	struct mm_walk walk = {
 		.ops		= ops,
@@ -567,6 +572,8 @@  int walk_page_vma(struct vm_area_struct *vma, const struct mm_walk_ops *ops,
 		return -EINVAL;
 
 	mmap_assert_locked(walk.mm);
+	if (lock_vma)
+		vma_start_write(vma);
 	return __walk_page_range(vma->vm_start, vma->vm_end, &walk);
 }
 
diff --git a/mm/vmscan.c b/mm/vmscan.c
index 1080209a568b..d85f86871fd9 100644
--- a/mm/vmscan.c
+++ b/mm/vmscan.c
@@ -4306,7 +4306,8 @@  static void walk_mm(struct lruvec *lruvec, struct mm_struct *mm, struct lru_gen_
 
 		/* the caller might be holding the lock for write */
 		if (mmap_read_trylock(mm)) {
-			err = walk_page_range(mm, walk->next_addr, ULONG_MAX, &mm_walk_ops, walk);
+			err = walk_page_range(mm, walk->next_addr, ULONG_MAX,
+					      &mm_walk_ops, false, walk);
 
 			mmap_read_unlock(mm);
 		}