
[1/2] mm/mprotect: use mmu_gather

Message ID 20210925205423.168858-2-namit@vmware.com (mailing list archive)
State New
Series mm/mprotect: avoid unnecessary TLB flushes

Commit Message

Nadav Amit Sept. 25, 2021, 8:54 p.m. UTC
From: Nadav Amit <namit@vmware.com>

change_pXX_range() currently does not use mmu_gather, but instead
implements its own deferred TLB flushes scheme. This both complicates
the code, as developers need to be aware of different invalidation
schemes, and prevents opportunities to avoid TLB flushes or perform them
at a finer granularity.

Use mmu_gather in change_pXX_range(). As the pages are not released,
only record the flushed range using tlb_flush_pXX_range().
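
For reference, the resulting pattern in change_protection_range() looks roughly
as follows (a simplified sketch of the diff below):

	struct mmu_gather tlb;

	tlb_gather_mmu(&tlb, mm);
	tlb_start_vma(&tlb, vma);
	/* ... page-table walk; for each PTE that is actually modified: */
	tlb_flush_pte_range(&tlb, addr, PAGE_SIZE);	/* record range, no flush yet */
	/* ... */
	tlb_end_vma(&tlb, vma);
	tlb_finish_mmu(&tlb);	/* the deferred TLB flush happens here */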

Cc: Andrea Arcangeli <aarcange@redhat.com>
Cc: Andrew Cooper <andrew.cooper3@citrix.com>
Cc: Andrew Morton <akpm@linux-foundation.org>
Cc: Andy Lutomirski <luto@kernel.org>
Cc: Dave Hansen <dave.hansen@linux.intel.com>
Cc: Peter Zijlstra <peterz@infradead.org>
Cc: Thomas Gleixner <tglx@linutronix.de>
Cc: Will Deacon <will@kernel.org>
Cc: Yu Zhao <yuzhao@google.com>
Cc: Nick Piggin <npiggin@gmail.com>
Cc: x86@kernel.org
Signed-off-by: Nadav Amit <namit@vmware.com>
---
 mm/mprotect.c | 50 ++++++++++++++++++++++++++++----------------------
 1 file changed, 28 insertions(+), 22 deletions(-)

Comments

Peter Zijlstra Oct. 3, 2021, 12:10 p.m. UTC | #1
On Sat, Sep 25, 2021 at 01:54:22PM -0700, Nadav Amit wrote:

> @@ -338,25 +344,25 @@ static unsigned long change_protection_range(struct vm_area_struct *vma,
>  	struct mm_struct *mm = vma->vm_mm;
>  	pgd_t *pgd;
>  	unsigned long next;
> -	unsigned long start = addr;
>  	unsigned long pages = 0;
> +	struct mmu_gather tlb;
>  
>  	BUG_ON(addr >= end);
>  	pgd = pgd_offset(mm, addr);
>  	flush_cache_range(vma, addr, end);
>  	inc_tlb_flush_pending(mm);

That seems unbalanced...

> +	tlb_gather_mmu(&tlb, mm);
> +	tlb_start_vma(&tlb, vma);
>  	do {
>  		next = pgd_addr_end(addr, end);
>  		if (pgd_none_or_clear_bad(pgd))
>  			continue;
> -		pages += change_p4d_range(vma, pgd, addr, next, newprot,
> +		pages += change_p4d_range(&tlb, vma, pgd, addr, next, newprot,
>  					  cp_flags);
>  	} while (pgd++, addr = next, addr != end);
>  
> -	/* Only flush the TLB if we actually modified any entries: */
> -	if (pages)
> -		flush_tlb_range(vma, start, end);
> -	dec_tlb_flush_pending(mm);

... seeing you do remove the extra decrement.

> +	tlb_end_vma(&tlb, vma);
> +	tlb_finish_mmu(&tlb);
>  
>  	return pages;
>  }
> -- 
> 2.25.1
>
Nadav Amit Oct. 4, 2021, 7:24 p.m. UTC | #2
> On Oct 3, 2021, at 5:10 AM, Peter Zijlstra <peterz@infradead.org> wrote:
> 
> On Sat, Sep 25, 2021 at 01:54:22PM -0700, Nadav Amit wrote:
> 
>> @@ -338,25 +344,25 @@ static unsigned long change_protection_range(struct vm_area_struct *vma,
>> 	struct mm_struct *mm = vma->vm_mm;
>> 	pgd_t *pgd;
>> 	unsigned long next;
>> -	unsigned long start = addr;
>> 	unsigned long pages = 0;
>> +	struct mmu_gather tlb;
>> 
>> 	BUG_ON(addr >= end);
>> 	pgd = pgd_offset(mm, addr);
>> 	flush_cache_range(vma, addr, end);
>> 	inc_tlb_flush_pending(mm);
> 
> That seems unbalanced...

Bad rebase. Thanks for catching it!

> 
>> +	tlb_gather_mmu(&tlb, mm);
>> +	tlb_start_vma(&tlb, vma);
>> 	do {
>> 		next = pgd_addr_end(addr, end);
>> 		if (pgd_none_or_clear_bad(pgd))
>> 			continue;
>> -		pages += change_p4d_range(vma, pgd, addr, next, newprot,
>> +		pages += change_p4d_range(&tlb, vma, pgd, addr, next, newprot,
>> 					  cp_flags);
>> 	} while (pgd++, addr = next, addr != end);
>> 
>> -	/* Only flush the TLB if we actually modified any entries: */
>> -	if (pages)
>> -		flush_tlb_range(vma, start, end);
>> -	dec_tlb_flush_pending(mm);
> 
> ... seeing you do remove the extra decrement.

Is it really needed? We do not put this comment elsewhere for
tlb_finish_mmu(). But no problem, I’ll keep it.
Peter Zijlstra Oct. 5, 2021, 6:53 a.m. UTC | #3
On Mon, Oct 04, 2021 at 12:24:14PM -0700, Nadav Amit wrote:
> 
> 
> > On Oct 3, 2021, at 5:10 AM, Peter Zijlstra <peterz@infradead.org> wrote:
> > 
> > On Sat, Sep 25, 2021 at 01:54:22PM -0700, Nadav Amit wrote:
> > 
> >> @@ -338,25 +344,25 @@ static unsigned long change_protection_range(struct vm_area_struct *vma,
> >> 	struct mm_struct *mm = vma->vm_mm;
> >> 	pgd_t *pgd;
> >> 	unsigned long next;
> >> -	unsigned long start = addr;
> >> 	unsigned long pages = 0;
> >> +	struct mmu_gather tlb;
> >> 
> >> 	BUG_ON(addr >= end);
> >> 	pgd = pgd_offset(mm, addr);
> >> 	flush_cache_range(vma, addr, end);
> >> 	inc_tlb_flush_pending(mm);
> > 
> > That seems unbalanced...
> 
> Bad rebase. Thanks for catching it!
> 
> > 
> >> +	tlb_gather_mmu(&tlb, mm);
> >> +	tlb_start_vma(&tlb, vma);
> >> 	do {
> >> 		next = pgd_addr_end(addr, end);
> >> 		if (pgd_none_or_clear_bad(pgd))
> >> 			continue;
> >> -		pages += change_p4d_range(vma, pgd, addr, next, newprot,
> >> +		pages += change_p4d_range(&tlb, vma, pgd, addr, next, newprot,
> >> 					  cp_flags);
> >> 	} while (pgd++, addr = next, addr != end);
> >> 
> >> -	/* Only flush the TLB if we actually modified any entries: */
> >> -	if (pages)
> >> -		flush_tlb_range(vma, start, end);
> >> -	dec_tlb_flush_pending(mm);
> > 
> > ... seeing you do remove the extra decrement.
> 
> Is it really needed? We do not put this comment elsewhere for
> tlb_finish_mmu(). But no problem, I’ll keep it.

-ENOPARSE, did you read decrement as comment? In any case, I don't
particularly care about the comment, and tlb_*_mmu() imply the inc/dec
thingies.

All I tried to do is point out that removing the dec but leaving the inc
is somewhat inconsistent :-)
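
For reference, the pending-flush accounting mentioned above is indeed implied
by the mmu_gather interface itself; roughly, paraphrased from mm/mmu_gather.c:

static void __tlb_gather_mmu(struct mmu_gather *tlb, struct mm_struct *mm, bool fullmm)
{
	tlb->mm = mm;
	tlb->fullmm = fullmm;
	/* ... */
	inc_tlb_flush_pending(tlb->mm);		/* the "inc" */
}

void tlb_finish_mmu(struct mmu_gather *tlb)
{
	tlb_flush_mmu(tlb);			/* any still-pending flush happens here */
	/* ... */
	dec_tlb_flush_pending(tlb->mm);		/* the "dec" */
}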
Nadav Amit Oct. 5, 2021, 4:34 p.m. UTC | #4
> On Oct 4, 2021, at 11:53 PM, Peter Zijlstra <peterz@infradead.org> wrote:
> 
> On Mon, Oct 04, 2021 at 12:24:14PM -0700, Nadav Amit wrote:
>> 
>> 
>>> On Oct 3, 2021, at 5:10 AM, Peter Zijlstra <peterz@infradead.org> wrote:
>>> 
>>> On Sat, Sep 25, 2021 at 01:54:22PM -0700, Nadav Amit wrote:
>>> 
>>>> @@ -338,25 +344,25 @@ static unsigned long change_protection_range(struct vm_area_struct *vma,
>>>> 	struct mm_struct *mm = vma->vm_mm;
>>>> 	pgd_t *pgd;
>>>> 	unsigned long next;
>>>> -	unsigned long start = addr;
>>>> 	unsigned long pages = 0;
>>>> +	struct mmu_gather tlb;
>>>> 
>>>> 	BUG_ON(addr >= end);
>>>> 	pgd = pgd_offset(mm, addr);
>>>> 	flush_cache_range(vma, addr, end);
>>>> 	inc_tlb_flush_pending(mm);
>>> 
>>> That seems unbalanced...
>> 
>> Bad rebase. Thanks for catching it!
>> 
>>> 
>>>> +	tlb_gather_mmu(&tlb, mm);
>>>> +	tlb_start_vma(&tlb, vma);
>>>> 	do {
>>>> 		next = pgd_addr_end(addr, end);
>>>> 		if (pgd_none_or_clear_bad(pgd))
>>>> 			continue;
>>>> -		pages += change_p4d_range(vma, pgd, addr, next, newprot,
>>>> +		pages += change_p4d_range(&tlb, vma, pgd, addr, next, newprot,
>>>> 					  cp_flags);
>>>> 	} while (pgd++, addr = next, addr != end);
>>>> 
>>>> -	/* Only flush the TLB if we actually modified any entries: */
>>>> -	if (pages)
>>>> -		flush_tlb_range(vma, start, end);
>>>> -	dec_tlb_flush_pending(mm);
>>> 
>>> ... seeing you do remove the extra decrement.
>> 
>> Is it really needed? We do not put this comment elsewhere for
>> tlb_finish_mmu(). But no problem, I’ll keep it.
> 
> -ENOPARSE, did you read decrement as comment? In any case, I don't
> particularly care about the comment, and tlb_*_mmu() imply the inc/dec
> thingies.
> 
> All I tried to do is point out that removing the dec but leaving the inc
> is somewhat inconsistent :-)

The autocorrect in my mind was broken, so I read it as “documentation”
instead of “decrement”.

I will send v2 soon.

Thanks again!
Nadav
Nadav Amit Oct. 11, 2021, 3:45 a.m. UTC | #5
> On Sep 25, 2021, at 1:54 PM, Nadav Amit <nadav.amit@gmail.com> wrote:
> 
> From: Nadav Amit <namit@vmware.com>
> 
> change_pXX_range() currently does not use mmu_gather, but instead
> implements its own deferred TLB flushes scheme. This both complicates
> the code, as developers need to be aware of different invalidation
> schemes, and prevents opportunities to avoid TLB flushes or perform them
> at a finer granularity.
> 
> Use mmu_gather in change_pXX_range(). As the pages are not released,
> only record the flushed range using tlb_flush_pXX_range().

Andrea pointed out that I do not take care of THP. There is indeed a
missing TLB flush on the THP path, but in practice it is not required,
since pmdp_invalidate() already does the flush. Anyhow, the patch needs
to address it cleanly, and to try to avoid the flush on pmdp_invalidate(),
which at least on x86 does not appear to be necessary.
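
For context, the generic pmdp_invalidate() already does a PMD-sized TLB flush
on its own; roughly, paraphrased from mm/pgtable-generic.c:

pmd_t pmdp_invalidate(struct vm_area_struct *vma, unsigned long address,
		      pmd_t *pmdp)
{
	pmd_t old = pmdp_establish(vma, address, pmdp, pmd_mkinvalid(*pmdp));

	flush_pmd_tlb_range(vma, address, address + HPAGE_PMD_SIZE);
	return old;
}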

There is an additional bug, as tlb_change_page_size() needs to be
called.
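
For reference, tlb_change_page_size() only records the mapping granularity in
the mmu_gather, flushing first if the gathered size changes mid-walk; roughly,
paraphrased from include/asm-generic/tlb.h:

static inline void tlb_change_page_size(struct mmu_gather *tlb, unsigned int page_size)
{
#ifdef CONFIG_MMU_GATHER_PAGE_SIZE
	if (tlb->page_size && tlb->page_size != page_size) {
		if (!tlb->fullmm)
			tlb_flush_mmu(tlb);	/* flush what was gathered at the old size */
	}

	tlb->page_size = page_size;
#endif
}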

-- Jerome,

While reviewing my (bad) code, I wanted to understand whether updating
migration entries requires a TLB flush, because I do not think I got
that right either.

I thought it should not, but now I am not so sure. I am very
confused by the following code in migrate_vma_collect_pmd():

        pte_unmap_unlock(ptep - 1, ptl);

        /* Only flush the TLB if we actually modified any entries */
        if (unmapped)
                flush_tlb_range(walk->vma, start, end);


According to this code, flush_tlb_range() is called without holding the ptl.
So theoretically there is a possible race:


	CPU0				CPU1
	----				----
	migrate_vma_collect_pmd()
	 set_pte_at() [ present->
			non-present]

	 pte_unmap_unlock()

					madvise(MADV_DONTNEED)
					 zap_pte_range()

					[ PTE non-present =>
					  no flush ]

So my questions:

1. Is there a reason the above scenario is invalid?
2. Does one need to flush a migration entry when updating it?

Thanks,
Nadav
Peter Xu Oct. 12, 2021, 10:16 a.m. UTC | #6
On Sat, Sep 25, 2021 at 01:54:22PM -0700, Nadav Amit wrote:
> @@ -338,25 +344,25 @@ static unsigned long change_protection_range(struct vm_area_struct *vma,
>  	struct mm_struct *mm = vma->vm_mm;
>  	pgd_t *pgd;
>  	unsigned long next;
> -	unsigned long start = addr;
>  	unsigned long pages = 0;
> +	struct mmu_gather tlb;
>  
>  	BUG_ON(addr >= end);
>  	pgd = pgd_offset(mm, addr);
>  	flush_cache_range(vma, addr, end);
>  	inc_tlb_flush_pending(mm);
> +	tlb_gather_mmu(&tlb, mm);
> +	tlb_start_vma(&tlb, vma);

Pure question:

I actually have no idea why tlb_start_vma() is needed here, as the protection
range can be just a single page, but anyway.. I do see that tlb_start_vma()
contains a whole-vma flush_cache_range() when the arch needs it. Does that mean
that, besides the inc_tlb_flush_pending() to be dropped, the other call to
flush_cache_range() above should be dropped as well?

>  	do {
>  		next = pgd_addr_end(addr, end);
>  		if (pgd_none_or_clear_bad(pgd))
>  			continue;
> -		pages += change_p4d_range(vma, pgd, addr, next, newprot,
> +		pages += change_p4d_range(&tlb, vma, pgd, addr, next, newprot,
>  					  cp_flags);
>  	} while (pgd++, addr = next, addr != end);
>  
> -	/* Only flush the TLB if we actually modified any entries: */
> -	if (pages)
> -		flush_tlb_range(vma, start, end);
> -	dec_tlb_flush_pending(mm);
> +	tlb_end_vma(&tlb, vma);
> +	tlb_finish_mmu(&tlb);
>  
>  	return pages;
>  }
> -- 
> 2.25.1
>
Nadav Amit Oct. 12, 2021, 5:31 p.m. UTC | #7
> On Oct 12, 2021, at 3:16 AM, Peter Xu <peterx@redhat.com> wrote:
> 
> On Sat, Sep 25, 2021 at 01:54:22PM -0700, Nadav Amit wrote:
>> @@ -338,25 +344,25 @@ static unsigned long change_protection_range(struct vm_area_struct *vma,
>> 	struct mm_struct *mm = vma->vm_mm;
>> 	pgd_t *pgd;
>> 	unsigned long next;
>> -	unsigned long start = addr;
>> 	unsigned long pages = 0;
>> +	struct mmu_gather tlb;
>> 
>> 	BUG_ON(addr >= end);
>> 	pgd = pgd_offset(mm, addr);
>> 	flush_cache_range(vma, addr, end);
>> 	inc_tlb_flush_pending(mm);
>> +	tlb_gather_mmu(&tlb, mm);
>> +	tlb_start_vma(&tlb, vma);
> 
> Pure question:
> 
> I actually have no idea why tlb_start_vma() is needed here, as the protection
> range can be just a single page, but anyway.. I do see that tlb_start_vma()
> contains a whole-vma flush_cache_range() when the arch needs it. Does that mean
> that, besides the inc_tlb_flush_pending() to be dropped, the other call to
> flush_cache_range() above should be dropped as well?

Good point.

tlb_start_vma() and tlb_end_vma() are required since some archs do not
batch TLB flushes across VMAs (e.g., ARM). I am not sure whether that’s
the best behavior for all archs, but I do not want to change it.

Anyhow, you make a valid point that the flush_cache_range() should be
dropped as well. I will do so for next version.

Regards,
Nadav
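
For reference, the generic tlb_start_vma() discussed above does roughly the
following (paraphrased from include/asm-generic/tlb.h), which is why the
open-coded flush_cache_range() becomes redundant:

static inline void tlb_start_vma(struct mmu_gather *tlb, struct vm_area_struct *vma)
{
	if (tlb->fullmm)
		return;

	tlb_update_vma_flags(tlb, vma);		/* remember e.g. VM_EXEC/hugetlb for the arch */
	flush_cache_range(vma, vma->vm_start, vma->vm_end);
}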
Peter Xu Oct. 12, 2021, 11:20 p.m. UTC | #8
On Tue, Oct 12, 2021 at 10:31:45AM -0700, Nadav Amit wrote:
> 
> 
> > On Oct 12, 2021, at 3:16 AM, Peter Xu <peterx@redhat.com> wrote:
> > 
> > On Sat, Sep 25, 2021 at 01:54:22PM -0700, Nadav Amit wrote:
> >> @@ -338,25 +344,25 @@ static unsigned long change_protection_range(struct vm_area_struct *vma,
> >> 	struct mm_struct *mm = vma->vm_mm;
> >> 	pgd_t *pgd;
> >> 	unsigned long next;
> >> -	unsigned long start = addr;
> >> 	unsigned long pages = 0;
> >> +	struct mmu_gather tlb;
> >> 
> >> 	BUG_ON(addr >= end);
> >> 	pgd = pgd_offset(mm, addr);
> >> 	flush_cache_range(vma, addr, end);
> >> 	inc_tlb_flush_pending(mm);
> >> +	tlb_gather_mmu(&tlb, mm);
> >> +	tlb_start_vma(&tlb, vma);
> > 
> > Pure question:
> > 
> > I actually have no idea why tlb_start_vma() is needed here, as the protection
> > range can be just a single page, but anyway.. I do see that tlb_start_vma()
> > contains a whole-vma flush_cache_range() when the arch needs it. Does that mean
> > that, besides the inc_tlb_flush_pending() to be dropped, the other call to
> > flush_cache_range() above should be dropped as well?
> 
> Good point.
> 
> tlb_start_vma() and tlb_end_vma() are required since some archs do not
> batch TLB flushes across VMAs (e.g., ARM).

Sorry, I didn't follow here - change_protection() is per-vma anyway, so I
don't see why it needs to consider vma crossing.

In any case, it'll be great if you could add some explanation to the commit
message on why we need tlb_{start|end}_vma(), as I think it may not be
obvious to everyone.

> I am not sure whether that’s the best behavior for all archs, but I do not
> want to change it.
> 
> Anyhow, you make a valid point that the flush_cache_range() should be
> dropped as well. I will do so for next version.

Thanks,
Nadav Amit Oct. 13, 2021, 3:59 p.m. UTC | #9
> On Oct 12, 2021, at 4:20 PM, Peter Xu <peterx@redhat.com> wrote:
> 
> On Tue, Oct 12, 2021 at 10:31:45AM -0700, Nadav Amit wrote:
>> 
>> 
>>> On Oct 12, 2021, at 3:16 AM, Peter Xu <peterx@redhat.com> wrote:
>>> 
>>> On Sat, Sep 25, 2021 at 01:54:22PM -0700, Nadav Amit wrote:
>>>> @@ -338,25 +344,25 @@ static unsigned long change_protection_range(struct vm_area_struct *vma,
>>>> 	struct mm_struct *mm = vma->vm_mm;
>>>> 	pgd_t *pgd;
>>>> 	unsigned long next;
>>>> -	unsigned long start = addr;
>>>> 	unsigned long pages = 0;
>>>> +	struct mmu_gather tlb;
>>>> 
>>>> 	BUG_ON(addr >= end);
>>>> 	pgd = pgd_offset(mm, addr);
>>>> 	flush_cache_range(vma, addr, end);
>>>> 	inc_tlb_flush_pending(mm);
>>>> +	tlb_gather_mmu(&tlb, mm);
>>>> +	tlb_start_vma(&tlb, vma);
>>> 
>>> Pure question:
>>> 
>>> I actually have no idea why tlb_start_vma() is needed here, as the protection
>>> range can be just a single page, but anyway.. I do see that tlb_start_vma()
>>> contains a whole-vma flush_cache_range() when the arch needs it. Does that mean
>>> that, besides the inc_tlb_flush_pending() to be dropped, the other call to
>>> flush_cache_range() above should be dropped as well?
>> 
>> Good point.
>> 
>> tlb_start_vma() and tlb_end_vma() are required since some archs do not
>> batch TLB flushes across VMAs (e.g., ARM).
> 
> Sorry, I didn't follow here - change_protection() is per-vma anyway, so I
> don't see why it needs to consider vma crossing.
> 
> In any case, it'll be great if you could add some explanation to the commit
> message on why we need tlb_{start|end}_vma(), as I think it may not be
> obvious to everyone.

tlb_start_vma() is required when we switch from flush_tlb_range() because
certain properties of the VMA (e.g., whether it is executable) are needed on
certain archs. That’s the reason flush_tlb_range() requires the VMA that is
being invalidated to be provided.
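
For illustration, these VMA properties are what tlb_start_vma() records into
the mmu_gather via tlb_update_vma_flags(); roughly, paraphrased from
include/asm-generic/tlb.h:

static inline void tlb_update_vma_flags(struct mmu_gather *tlb, struct vm_area_struct *vma)
{
	tlb->vma_huge = is_vm_hugetlb_page(vma);
	tlb->vma_exec = !!(vma->vm_flags & VM_EXEC);
}

so that the eventual tlb_flush() can hand flush_tlb_range() an equivalent VMA
without keeping a pointer to the real one.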

Regardless, there is an interface and that is the way it is used. I am not
inclined to break it, even if it was possible, for unclear performance
benefits.

As I discussed offline with Andrea and David, switching to the tlb_gather_mmu()
interface has additional advantages beyond batching and avoiding unnecessary
flushes on PTE permission promotion (as done in patch 2). If a single PTE
is updated out of a bigger range, flush_tlb_range() currently flushes
the whole range instead of the single page. In addition, once I fix this
patch-set, if you update a THP, you would (at least on x86) be able to
flush a single PTE instead of flushing 512 entries (which would actually
be done using a full TLB flush).

I would say that, as I mentioned in a different thread and was not
upfront about before, one of my motivations behind this patch is that
I need a vectored UFFDIO_WRITEPROTECTV interface for performance.
Nevertheless, I think these two patches stand on their own and have
independent value.
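
For reference, the "512 entries turn into a full TLB flush" behavior on x86
comes from the flush ceiling; roughly, paraphrased from arch/x86/mm/tlb.c:

/* Ranges covering more than this many pages are flushed with a full flush. */
unsigned long tlb_single_page_flush_ceiling __read_mostly = 33;

/* in flush_tlb_mm_range(), roughly: */
	if ((end == TLB_FLUSH_ALL) ||
	    ((end - start) >> stride_shift) > tlb_single_page_flush_ceiling) {
		start = 0;
		end = TLB_FLUSH_ALL;	/* give up on per-page invalidation */
	}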

Patch

diff --git a/mm/mprotect.c b/mm/mprotect.c
index 883e2cc85cad..075ff94aa51c 100644
--- a/mm/mprotect.c
+++ b/mm/mprotect.c
@@ -32,12 +32,13 @@ 
 #include <asm/cacheflush.h>
 #include <asm/mmu_context.h>
 #include <asm/tlbflush.h>
+#include <asm/tlb.h>
 
 #include "internal.h"
 
-static unsigned long change_pte_range(struct vm_area_struct *vma, pmd_t *pmd,
-		unsigned long addr, unsigned long end, pgprot_t newprot,
-		unsigned long cp_flags)
+static unsigned long change_pte_range(struct mmu_gather *tlb,
+		struct vm_area_struct *vma, pmd_t *pmd, unsigned long addr,
+		unsigned long end, pgprot_t newprot, unsigned long cp_flags)
 {
 	pte_t *pte, oldpte;
 	spinlock_t *ptl;
@@ -138,6 +139,7 @@  static unsigned long change_pte_range(struct vm_area_struct *vma, pmd_t *pmd,
 				ptent = pte_mkwrite(ptent);
 			}
 			ptep_modify_prot_commit(vma, addr, pte, oldpte, ptent);
+			tlb_flush_pte_range(tlb, addr, PAGE_SIZE);
 			pages++;
 		} else if (is_swap_pte(oldpte)) {
 			swp_entry_t entry = pte_to_swp_entry(oldpte);
@@ -219,9 +221,9 @@  static inline int pmd_none_or_clear_bad_unless_trans_huge(pmd_t *pmd)
 	return 0;
 }
 
-static inline unsigned long change_pmd_range(struct vm_area_struct *vma,
-		pud_t *pud, unsigned long addr, unsigned long end,
-		pgprot_t newprot, unsigned long cp_flags)
+static inline unsigned long change_pmd_range(struct mmu_gather *tlb,
+		struct vm_area_struct *vma, pud_t *pud, unsigned long addr,
+		unsigned long end, pgprot_t newprot, unsigned long cp_flags)
 {
 	pmd_t *pmd;
 	unsigned long next;
@@ -261,6 +263,10 @@  static inline unsigned long change_pmd_range(struct vm_area_struct *vma,
 			if (next - addr != HPAGE_PMD_SIZE) {
 				__split_huge_pmd(vma, pmd, addr, false, NULL);
 			} else {
+				/*
+				 * change_huge_pmd() does not defer TLB flushes,
+				 * so no need to propagate the tlb argument.
+				 */
 				int nr_ptes = change_huge_pmd(vma, pmd, addr,
 							      newprot, cp_flags);
 
@@ -276,8 +282,8 @@  static inline unsigned long change_pmd_range(struct vm_area_struct *vma,
 			}
 			/* fall through, the trans huge pmd just split */
 		}
-		this_pages = change_pte_range(vma, pmd, addr, next, newprot,
-					      cp_flags);
+		this_pages = change_pte_range(tlb, vma, pmd, addr, next,
+					      newprot, cp_flags);
 		pages += this_pages;
 next:
 		cond_resched();
@@ -291,9 +297,9 @@  static inline unsigned long change_pmd_range(struct vm_area_struct *vma,
 	return pages;
 }
 
-static inline unsigned long change_pud_range(struct vm_area_struct *vma,
-		p4d_t *p4d, unsigned long addr, unsigned long end,
-		pgprot_t newprot, unsigned long cp_flags)
+static inline unsigned long change_pud_range(struct mmu_gather *tlb,
+		struct vm_area_struct *vma, p4d_t *p4d, unsigned long addr,
+		unsigned long end, pgprot_t newprot, unsigned long cp_flags)
 {
 	pud_t *pud;
 	unsigned long next;
@@ -304,16 +310,16 @@  static inline unsigned long change_pud_range(struct vm_area_struct *vma,
 		next = pud_addr_end(addr, end);
 		if (pud_none_or_clear_bad(pud))
 			continue;
-		pages += change_pmd_range(vma, pud, addr, next, newprot,
+		pages += change_pmd_range(tlb, vma, pud, addr, next, newprot,
 					  cp_flags);
 	} while (pud++, addr = next, addr != end);
 
 	return pages;
 }
 
-static inline unsigned long change_p4d_range(struct vm_area_struct *vma,
-		pgd_t *pgd, unsigned long addr, unsigned long end,
-		pgprot_t newprot, unsigned long cp_flags)
+static inline unsigned long change_p4d_range(struct mmu_gather *tlb,
+		struct vm_area_struct *vma, pgd_t *pgd, unsigned long addr,
+		unsigned long end, pgprot_t newprot, unsigned long cp_flags)
 {
 	p4d_t *p4d;
 	unsigned long next;
@@ -324,7 +330,7 @@  static inline unsigned long change_p4d_range(struct vm_area_struct *vma,
 		next = p4d_addr_end(addr, end);
 		if (p4d_none_or_clear_bad(p4d))
 			continue;
-		pages += change_pud_range(vma, p4d, addr, next, newprot,
+		pages += change_pud_range(tlb, vma, p4d, addr, next, newprot,
 					  cp_flags);
 	} while (p4d++, addr = next, addr != end);
 
@@ -338,25 +344,25 @@  static unsigned long change_protection_range(struct vm_area_struct *vma,
 	struct mm_struct *mm = vma->vm_mm;
 	pgd_t *pgd;
 	unsigned long next;
-	unsigned long start = addr;
 	unsigned long pages = 0;
+	struct mmu_gather tlb;
 
 	BUG_ON(addr >= end);
 	pgd = pgd_offset(mm, addr);
 	flush_cache_range(vma, addr, end);
 	inc_tlb_flush_pending(mm);
+	tlb_gather_mmu(&tlb, mm);
+	tlb_start_vma(&tlb, vma);
 	do {
 		next = pgd_addr_end(addr, end);
 		if (pgd_none_or_clear_bad(pgd))
 			continue;
-		pages += change_p4d_range(vma, pgd, addr, next, newprot,
+		pages += change_p4d_range(&tlb, vma, pgd, addr, next, newprot,
 					  cp_flags);
 	} while (pgd++, addr = next, addr != end);
 
-	/* Only flush the TLB if we actually modified any entries: */
-	if (pages)
-		flush_tlb_range(vma, start, end);
-	dec_tlb_flush_pending(mm);
+	tlb_end_vma(&tlb, vma);
+	tlb_finish_mmu(&tlb);
 
 	return pages;
 }
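
For reference, tlb_flush_pte_range() used above does not flush anything itself;
it only records the range and the granularity in the mmu_gather, and the actual
flush is deferred to tlb_end_vma()/tlb_finish_mmu(). Roughly, paraphrased from
include/asm-generic/tlb.h:

static inline void __tlb_adjust_range(struct mmu_gather *tlb,
				      unsigned long address, unsigned int range_size)
{
	tlb->start = min(tlb->start, address);
	tlb->end   = max(tlb->end, address + range_size);
}

static inline void tlb_flush_pte_range(struct mmu_gather *tlb,
				       unsigned long address, unsigned long size)
{
	__tlb_adjust_range(tlb, address, size);
	tlb->cleared_ptes = 1;		/* the eventual flush may use a 4K stride */
}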