diff mbox series

[1/4] mm_notifiers: Rename invalidate_range notifier

Message ID c0daf0870f7220bbf815713463aff86970a5d0fa.1689666760.git-series.apopple@nvidia.com (mailing list archive)
State New, archived
Headers show
Series Invalidate secondary IOMMU TLB on permission upgrade | expand

Commit Message

Alistair Popple July 18, 2023, 7:56 a.m. UTC
There are two main use cases for mmu notifiers. One is by KVM which
uses mmu_notifier_invalidate_range_start()/end() to manage a software
TLB.

The other is to manage hardware TLBs which need to use the
invalidate_range() callback because HW can establish new TLB entries
at any time. Hence using start/end() can lead to memory corruption as
these callbacks happen too soon/late during page unmap.

mmu notifier users should therefore either use the start()/end()
callbacks or the invalidate_range() callbacks. To make this usage
clearer rename the invalidate_range() callback to
arch_invalidate_secondary_tlbs() and update documention.

Signed-off-by: Alistair Popple <apopple@nvidia.com>
Suggested-by: Jason Gunthorpe <jgg@nvidia.com>
---
 arch/x86/mm/tlb.c                               |  1 +-
 drivers/iommu/amd/iommu_v2.c                    | 10 +--
 drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3-sva.c | 13 ++--
 drivers/iommu/intel/svm.c                       |  8 +--
 drivers/misc/ocxl/link.c                        |  8 +--
 include/asm-generic/tlb.h                       |  2 +-
 include/linux/mmu_notifier.h                    | 54 +++++++++---------
 mm/huge_memory.c                                |  4 +-
 mm/hugetlb.c                                    | 10 +--
 mm/mmu_notifier.c                               | 52 ++++++++++-------
 mm/rmap.c                                       | 42 +++++++-------
 11 files changed, 110 insertions(+), 94 deletions(-)

Comments

Jason Gunthorpe July 18, 2023, 5:57 p.m. UTC | #1
On Tue, Jul 18, 2023 at 05:56:15PM +1000, Alistair Popple wrote:
> diff --git a/include/asm-generic/tlb.h b/include/asm-generic/tlb.h
> index b466172..48c81b9 100644
> --- a/include/asm-generic/tlb.h
> +++ b/include/asm-generic/tlb.h
> @@ -456,7 +456,7 @@ static inline void tlb_flush_mmu_tlbonly(struct mmu_gather *tlb)
>  		return;
>  
>  	tlb_flush(tlb);
> -	mmu_notifier_invalidate_range(tlb->mm, tlb->start, tlb->end);
> +	mmu_notifier_invalidate_secondary_tlbs(tlb->mm, tlb->start, tlb->end);
>  	__tlb_reset_range(tlb);

Does this compile? I don't see
"mmu_notifier_invalidate_secondary_tlbs" ?

Maybe we don't need to rename this function since you pretty much
remove it in the next patches?

> diff --git a/mm/mmu_notifier.c b/mm/mmu_notifier.c
> index 50c0dde..34c5a84 100644
> --- a/mm/mmu_notifier.c
> +++ b/mm/mmu_notifier.c
> @@ -207,7 +207,7 @@ mmu_interval_read_begin(struct mmu_interval_notifier *interval_sub)
>  	 *    spin_lock
>  	 *     seq = ++subscriptions->invalidate_seq
>  	 *    spin_unlock
> -	 *     op->invalidate_range():
> +	 *     op->invalidate_secondary_tlbs():

The later patch should delete this stuff from the comment too, we
no longer guarantee this relationship?

> @@ -560,23 +560,23 @@ mn_hlist_invalidate_end(struct mmu_notifier_subscriptions *subscriptions,
>  	hlist_for_each_entry_rcu(subscription, &subscriptions->list, hlist,
>  				 srcu_read_lock_held(&srcu)) {
>  		/*
> -		 * Call invalidate_range here too to avoid the need for the
> -		 * subsystem of having to register an invalidate_range_end
> -		 * call-back when there is invalidate_range already. Usually a
> -		 * subsystem registers either invalidate_range_start()/end() or
> -		 * invalidate_range(), so this will be no additional overhead
> -		 * (besides the pointer check).
> +		 * Subsystems should register either invalidate_secondary_tlbs()
> +		 * or invalidate_range_start()/end() callbacks.
>  		 *
> -		 * We skip call to invalidate_range() if we know it is safe ie
> -		 * call site use mmu_notifier_invalidate_range_only_end() which
> -		 * is safe to do when we know that a call to invalidate_range()
> -		 * already happen under page table lock.
> +		 * We call invalidate_secondary_tlbs() here so that subsystems
> +		 * can use larger range based invalidations. In some cases
> +		 * though invalidate_secondary_tlbs() needs to be called while
> +		 * holding the page table lock. In that case call sites use
> +		 * mmu_notifier_invalidate_range_only_end() and we know it is
> +		 * safe to skip secondary TLB invalidation as it will have
> +		 * already been done.
>  		 */
> -		if (!only_end && subscription->ops->invalidate_range)
> -			subscription->ops->invalidate_range(subscription,
> -							    range->mm,
> -							    range->start,
> -							    range->end);
> +		if (!only_end && subscription->ops->invalidate_secondary_tlbs)
> +			subscription->ops->invalidate_secondary_tlbs(

More doesn't compile, and the comment has the same issue..

But I think the approach in this series looks fine, it is so much
cleaner after we remove all the cruft in patch 4, just look at the
diffstat..

Jason
Andrew Morton July 18, 2023, 6:36 p.m. UTC | #2
On Tue, 18 Jul 2023 14:57:12 -0300 Jason Gunthorpe <jgg@ziepe.ca> wrote:

> On Tue, Jul 18, 2023 at 05:56:15PM +1000, Alistair Popple wrote:
> > diff --git a/include/asm-generic/tlb.h b/include/asm-generic/tlb.h
> > index b466172..48c81b9 100644
> > --- a/include/asm-generic/tlb.h
> > +++ b/include/asm-generic/tlb.h
> > @@ -456,7 +456,7 @@ static inline void tlb_flush_mmu_tlbonly(struct mmu_gather *tlb)
> >  		return;
> >  
> >  	tlb_flush(tlb);
> > -	mmu_notifier_invalidate_range(tlb->mm, tlb->start, tlb->end);
> > +	mmu_notifier_invalidate_secondary_tlbs(tlb->mm, tlb->start, tlb->end);
> >  	__tlb_reset_range(tlb);
> 
> Does this compile? I don't see
> "mmu_notifier_invalidate_secondary_tlbs" ?

Seems this call gets deleted later in the series.

> But I think the approach in this series looks fine, it is so much
> cleaner after we remove all the cruft in patch 4, just look at the
> diffstat..

I'll push this into -next if it compiles OK for me, but yes, a redo is
desirable please.
Alistair Popple July 18, 2023, 11:49 p.m. UTC | #3
Andrew Morton <akpm@linux-foundation.org> writes:

> On Tue, 18 Jul 2023 14:57:12 -0300 Jason Gunthorpe <jgg@ziepe.ca> wrote:
>
>> On Tue, Jul 18, 2023 at 05:56:15PM +1000, Alistair Popple wrote:
>> > diff --git a/include/asm-generic/tlb.h b/include/asm-generic/tlb.h
>> > index b466172..48c81b9 100644
>> > --- a/include/asm-generic/tlb.h
>> > +++ b/include/asm-generic/tlb.h
>> > @@ -456,7 +456,7 @@ static inline void tlb_flush_mmu_tlbonly(struct mmu_gather *tlb)
>> >  		return;
>> >  
>> >  	tlb_flush(tlb);
>> > -	mmu_notifier_invalidate_range(tlb->mm, tlb->start, tlb->end);
>> > +	mmu_notifier_invalidate_secondary_tlbs(tlb->mm, tlb->start, tlb->end);
>> >  	__tlb_reset_range(tlb);
>> 
>> Does this compile? I don't see
>> "mmu_notifier_invalidate_secondary_tlbs" ?

Dang, sorry. The original rename was to that but then we added *_arch_*
and I obviously missed some of the already renamed calls.

> Seems this call gets deleted later in the series.
>
>> But I think the approach in this series looks fine, it is so much
>> cleaner after we remove all the cruft in patch 4, just look at the
>> diffstat..
>
> I'll push this into -next if it compiles OK for me, but yes, a redo is
> desirable please.

Yep, will respin.
diff mbox series

Patch

diff --git a/arch/x86/mm/tlb.c b/arch/x86/mm/tlb.c
index 267acf2..eaefc10 100644
--- a/arch/x86/mm/tlb.c
+++ b/arch/x86/mm/tlb.c
@@ -10,6 +10,7 @@ 
 #include <linux/debugfs.h>
 #include <linux/sched/smt.h>
 #include <linux/task_work.h>
+#include <linux/mmu_notifier.h>
 
 #include <asm/tlbflush.h>
 #include <asm/mmu_context.h>
diff --git a/drivers/iommu/amd/iommu_v2.c b/drivers/iommu/amd/iommu_v2.c
index 261352a..2596466 100644
--- a/drivers/iommu/amd/iommu_v2.c
+++ b/drivers/iommu/amd/iommu_v2.c
@@ -355,9 +355,9 @@  static struct pasid_state *mn_to_state(struct mmu_notifier *mn)
 	return container_of(mn, struct pasid_state, mn);
 }
 
-static void mn_invalidate_range(struct mmu_notifier *mn,
-				struct mm_struct *mm,
-				unsigned long start, unsigned long end)
+static void mn_arch_invalidate_secondary_tlbs(struct mmu_notifier *mn,
+					struct mm_struct *mm,
+					unsigned long start, unsigned long end)
 {
 	struct pasid_state *pasid_state;
 	struct device_state *dev_state;
@@ -391,8 +391,8 @@  static void mn_release(struct mmu_notifier *mn, struct mm_struct *mm)
 }
 
 static const struct mmu_notifier_ops iommu_mn = {
-	.release		= mn_release,
-	.invalidate_range       = mn_invalidate_range,
+	.release			= mn_release,
+	.arch_invalidate_secondary_tlbs	= mn_arch_invalidate_secondary_tlbs,
 };
 
 static void set_pri_tag_status(struct pasid_state *pasid_state,
diff --git a/drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3-sva.c b/drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3-sva.c
index a5a63b1..aa63cff 100644
--- a/drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3-sva.c
+++ b/drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3-sva.c
@@ -186,9 +186,10 @@  static void arm_smmu_free_shared_cd(struct arm_smmu_ctx_desc *cd)
 	}
 }
 
-static void arm_smmu_mm_invalidate_range(struct mmu_notifier *mn,
-					 struct mm_struct *mm,
-					 unsigned long start, unsigned long end)
+static void arm_smmu_mm_arch_invalidate_secondary_tlbs(struct mmu_notifier *mn,
+						struct mm_struct *mm,
+						unsigned long start,
+						unsigned long end)
 {
 	struct arm_smmu_mmu_notifier *smmu_mn = mn_to_smmu(mn);
 	struct arm_smmu_domain *smmu_domain = smmu_mn->domain;
@@ -237,9 +238,9 @@  static void arm_smmu_mmu_notifier_free(struct mmu_notifier *mn)
 }
 
 static const struct mmu_notifier_ops arm_smmu_mmu_notifier_ops = {
-	.invalidate_range	= arm_smmu_mm_invalidate_range,
-	.release		= arm_smmu_mm_release,
-	.free_notifier		= arm_smmu_mmu_notifier_free,
+	.arch_invalidate_secondary_tlbs	= arm_smmu_mm_arch_invalidate_secondary_tlbs,
+	.release			= arm_smmu_mm_release,
+	.free_notifier			= arm_smmu_mmu_notifier_free,
 };
 
 /* Allocate or get existing MMU notifier for this {domain, mm} pair */
diff --git a/drivers/iommu/intel/svm.c b/drivers/iommu/intel/svm.c
index e95b339..8f6d680 100644
--- a/drivers/iommu/intel/svm.c
+++ b/drivers/iommu/intel/svm.c
@@ -219,9 +219,9 @@  static void intel_flush_svm_range(struct intel_svm *svm, unsigned long address,
 }
 
 /* Pages have been freed at this point */
-static void intel_invalidate_range(struct mmu_notifier *mn,
-				   struct mm_struct *mm,
-				   unsigned long start, unsigned long end)
+static void intel_arch_invalidate_secondary_tlbs(struct mmu_notifier *mn,
+					struct mm_struct *mm,
+					unsigned long start, unsigned long end)
 {
 	struct intel_svm *svm = container_of(mn, struct intel_svm, notifier);
 
@@ -256,7 +256,7 @@  static void intel_mm_release(struct mmu_notifier *mn, struct mm_struct *mm)
 
 static const struct mmu_notifier_ops intel_mmuops = {
 	.release = intel_mm_release,
-	.invalidate_range = intel_invalidate_range,
+	.arch_invalidate_secondary_tlbs = intel_arch_invalidate_secondary_tlbs,
 };
 
 static DEFINE_MUTEX(pasid_mutex);
diff --git a/drivers/misc/ocxl/link.c b/drivers/misc/ocxl/link.c
index 4cf4c55..c06c699 100644
--- a/drivers/misc/ocxl/link.c
+++ b/drivers/misc/ocxl/link.c
@@ -491,9 +491,9 @@  void ocxl_link_release(struct pci_dev *dev, void *link_handle)
 }
 EXPORT_SYMBOL_GPL(ocxl_link_release);
 
-static void invalidate_range(struct mmu_notifier *mn,
-			     struct mm_struct *mm,
-			     unsigned long start, unsigned long end)
+static void arch_invalidate_secondary_tlbs(struct mmu_notifier *mn,
+					struct mm_struct *mm,
+					unsigned long start, unsigned long end)
 {
 	struct pe_data *pe_data = container_of(mn, struct pe_data, mmu_notifier);
 	struct ocxl_link *link = pe_data->link;
@@ -509,7 +509,7 @@  static void invalidate_range(struct mmu_notifier *mn,
 }
 
 static const struct mmu_notifier_ops ocxl_mmu_notifier_ops = {
-	.invalidate_range = invalidate_range,
+	.arch_invalidate_secondary_tlbs = arch_invalidate_secondary_tlbs,
 };
 
 static u64 calculate_cfg_state(bool kernel)
diff --git a/include/asm-generic/tlb.h b/include/asm-generic/tlb.h
index b466172..48c81b9 100644
--- a/include/asm-generic/tlb.h
+++ b/include/asm-generic/tlb.h
@@ -456,7 +456,7 @@  static inline void tlb_flush_mmu_tlbonly(struct mmu_gather *tlb)
 		return;
 
 	tlb_flush(tlb);
-	mmu_notifier_invalidate_range(tlb->mm, tlb->start, tlb->end);
+	mmu_notifier_invalidate_secondary_tlbs(tlb->mm, tlb->start, tlb->end);
 	__tlb_reset_range(tlb);
 }
 
diff --git a/include/linux/mmu_notifier.h b/include/linux/mmu_notifier.h
index 64a3e05..a4bc818 100644
--- a/include/linux/mmu_notifier.h
+++ b/include/linux/mmu_notifier.h
@@ -187,27 +187,27 @@  struct mmu_notifier_ops {
 				     const struct mmu_notifier_range *range);
 
 	/*
-	 * invalidate_range() is either called between
-	 * invalidate_range_start() and invalidate_range_end() when the
-	 * VM has to free pages that where unmapped, but before the
-	 * pages are actually freed, or outside of _start()/_end() when
-	 * a (remote) TLB is necessary.
+	 * arch_invalidate_secondary_tlbs() is used to manage a non-CPU TLB
+	 * which shares page-tables with the CPU. The
+	 * invalidate_range_start()/end() callbacks should not be implemented as
+	 * invalidate_secondary_tlbs() already catches the points in time when
+	 * an external TLB needs to be flushed.
 	 *
-	 * If invalidate_range() is used to manage a non-CPU TLB with
-	 * shared page-tables, it not necessary to implement the
-	 * invalidate_range_start()/end() notifiers, as
-	 * invalidate_range() already catches the points in time when an
-	 * external TLB range needs to be flushed. For more in depth
-	 * discussion on this see Documentation/mm/mmu_notifier.rst
+	 * This requires arch_invalidate_secondary_tlbs() to be called while
+	 * holding the ptl spin-lock and therefore this callback is not allowed
+	 * to sleep.
 	 *
-	 * Note that this function might be called with just a sub-range
-	 * of what was passed to invalidate_range_start()/end(), if
-	 * called between those functions.
+	 * This is called by architecture code whenever invalidating a TLB
+	 * entry. It is assumed that any secondary TLB has the same rules for
+	 * when invalidations are required. If this is not the case architecture
+	 * code will need to call this explicitly when required for secondary
+	 * TLB invalidation.
 	 */
-	void (*invalidate_range)(struct mmu_notifier *subscription,
-				 struct mm_struct *mm,
-				 unsigned long start,
-				 unsigned long end);
+	void (*arch_invalidate_secondary_tlbs)(
+					struct mmu_notifier *subscription,
+					struct mm_struct *mm,
+					unsigned long start,
+					unsigned long end);
 
 	/*
 	 * These callbacks are used with the get/put interface to manage the
@@ -397,8 +397,8 @@  extern void __mmu_notifier_change_pte(struct mm_struct *mm,
 extern int __mmu_notifier_invalidate_range_start(struct mmu_notifier_range *r);
 extern void __mmu_notifier_invalidate_range_end(struct mmu_notifier_range *r,
 				  bool only_end);
-extern void __mmu_notifier_invalidate_range(struct mm_struct *mm,
-				  unsigned long start, unsigned long end);
+extern void __mmu_notifier_arch_invalidate_secondary_tlbs(struct mm_struct *mm,
+					unsigned long start, unsigned long end);
 extern bool
 mmu_notifier_range_update_to_read_only(const struct mmu_notifier_range *range);
 
@@ -491,11 +491,11 @@  mmu_notifier_invalidate_range_only_end(struct mmu_notifier_range *range)
 		__mmu_notifier_invalidate_range_end(range, true);
 }
 
-static inline void mmu_notifier_invalidate_range(struct mm_struct *mm,
-				  unsigned long start, unsigned long end)
+static inline void mmu_notifier_arch_invalidate_secondary_tlbs(struct mm_struct *mm,
+					unsigned long start, unsigned long end)
 {
 	if (mm_has_notifiers(mm))
-		__mmu_notifier_invalidate_range(mm, start, end);
+		__mmu_notifier_arch_invalidate_secondary_tlbs(mm, start, end);
 }
 
 static inline void mmu_notifier_subscriptions_init(struct mm_struct *mm)
@@ -589,7 +589,7 @@  static inline void mmu_notifier_range_init_owner(
 	pte_t ___pte;							\
 									\
 	___pte = ptep_clear_flush(__vma, __address, __ptep);		\
-	mmu_notifier_invalidate_range(___mm, ___addr,			\
+	mmu_notifier_arch_invalidate_secondary_tlbs(___mm, ___addr,		\
 					___addr + PAGE_SIZE);		\
 									\
 	___pte;								\
@@ -602,7 +602,7 @@  static inline void mmu_notifier_range_init_owner(
 	pmd_t ___pmd;							\
 									\
 	___pmd = pmdp_huge_clear_flush(__vma, __haddr, __pmd);		\
-	mmu_notifier_invalidate_range(___mm, ___haddr,			\
+	mmu_notifier_arch_invalidate_secondary_tlbs(___mm, ___haddr,		\
 				      ___haddr + HPAGE_PMD_SIZE);	\
 									\
 	___pmd;								\
@@ -615,7 +615,7 @@  static inline void mmu_notifier_range_init_owner(
 	pud_t ___pud;							\
 									\
 	___pud = pudp_huge_clear_flush(__vma, __haddr, __pud);		\
-	mmu_notifier_invalidate_range(___mm, ___haddr,			\
+	mmu_notifier_arch_invalidate_secondary_tlbs(___mm, ___haddr,		\
 				      ___haddr + HPAGE_PUD_SIZE);	\
 									\
 	___pud;								\
@@ -716,7 +716,7 @@  mmu_notifier_invalidate_range_only_end(struct mmu_notifier_range *range)
 {
 }
 
-static inline void mmu_notifier_invalidate_range(struct mm_struct *mm,
+static inline void mmu_notifier_arch_invalidate_secondary_tlbs(struct mm_struct *mm,
 				  unsigned long start, unsigned long end)
 {
 }
diff --git a/mm/huge_memory.c b/mm/huge_memory.c
index eb36783..a232891 100644
--- a/mm/huge_memory.c
+++ b/mm/huge_memory.c
@@ -2124,8 +2124,8 @@  static void __split_huge_pmd_locked(struct vm_area_struct *vma, pmd_t *pmd,
 	if (is_huge_zero_pmd(*pmd)) {
 		/*
 		 * FIXME: Do we want to invalidate secondary mmu by calling
-		 * mmu_notifier_invalidate_range() see comments below inside
-		 * __split_huge_pmd() ?
+		 * mmu_notifier_arch_invalidate_secondary_tlbs() see comments below
+		 * inside __split_huge_pmd() ?
 		 *
 		 * We are going from a zero huge page write protected to zero
 		 * small page also write protected so it does not seems useful
diff --git a/mm/hugetlb.c b/mm/hugetlb.c
index 64a3239..178c930 100644
--- a/mm/hugetlb.c
+++ b/mm/hugetlb.c
@@ -5690,7 +5690,8 @@  static vm_fault_t hugetlb_wp(struct mm_struct *mm, struct vm_area_struct *vma,
 
 		/* Break COW or unshare */
 		huge_ptep_clear_flush(vma, haddr, ptep);
-		mmu_notifier_invalidate_range(mm, range.start, range.end);
+		mmu_notifier_arch_invalidate_secondary_tlbs(mm, range.start,
+						range.end);
 		page_remove_rmap(&old_folio->page, vma, true);
 		hugepage_add_new_anon_rmap(new_folio, vma, haddr);
 		if (huge_pte_uffd_wp(pte))
@@ -6822,8 +6823,9 @@  long hugetlb_change_protection(struct vm_area_struct *vma,
 	else
 		flush_hugetlb_tlb_range(vma, start, end);
 	/*
-	 * No need to call mmu_notifier_invalidate_range() we are downgrading
-	 * page table protection not changing it to point to a new page.
+	 * No need to call mmu_notifier_arch_invalidate_secondary_tlbs() we are
+	 * downgrading page table protection not changing it to point to a new
+	 * page.
 	 *
 	 * See Documentation/mm/mmu_notifier.rst
 	 */
@@ -7467,7 +7469,7 @@  static void hugetlb_unshare_pmds(struct vm_area_struct *vma,
 	i_mmap_unlock_write(vma->vm_file->f_mapping);
 	hugetlb_vma_unlock_write(vma);
 	/*
-	 * No need to call mmu_notifier_invalidate_range(), see
+	 * No need to call mmu_notifier_arch_invalidate_secondary_tlbs(), see
 	 * Documentation/mm/mmu_notifier.rst.
 	 */
 	mmu_notifier_invalidate_range_end(&range);
diff --git a/mm/mmu_notifier.c b/mm/mmu_notifier.c
index 50c0dde..34c5a84 100644
--- a/mm/mmu_notifier.c
+++ b/mm/mmu_notifier.c
@@ -207,7 +207,7 @@  mmu_interval_read_begin(struct mmu_interval_notifier *interval_sub)
 	 *    spin_lock
 	 *     seq = ++subscriptions->invalidate_seq
 	 *    spin_unlock
-	 *     op->invalidate_range():
+	 *     op->invalidate_secondary_tlbs():
 	 *       user_lock
 	 *        mmu_interval_set_seq()
 	 *         interval_sub->invalidate_seq = seq
@@ -560,23 +560,23 @@  mn_hlist_invalidate_end(struct mmu_notifier_subscriptions *subscriptions,
 	hlist_for_each_entry_rcu(subscription, &subscriptions->list, hlist,
 				 srcu_read_lock_held(&srcu)) {
 		/*
-		 * Call invalidate_range here too to avoid the need for the
-		 * subsystem of having to register an invalidate_range_end
-		 * call-back when there is invalidate_range already. Usually a
-		 * subsystem registers either invalidate_range_start()/end() or
-		 * invalidate_range(), so this will be no additional overhead
-		 * (besides the pointer check).
+		 * Subsystems should register either invalidate_secondary_tlbs()
+		 * or invalidate_range_start()/end() callbacks.
 		 *
-		 * We skip call to invalidate_range() if we know it is safe ie
-		 * call site use mmu_notifier_invalidate_range_only_end() which
-		 * is safe to do when we know that a call to invalidate_range()
-		 * already happen under page table lock.
+		 * We call invalidate_secondary_tlbs() here so that subsystems
+		 * can use larger range based invalidations. In some cases
+		 * though invalidate_secondary_tlbs() needs to be called while
+		 * holding the page table lock. In that case call sites use
+		 * mmu_notifier_invalidate_range_only_end() and we know it is
+		 * safe to skip secondary TLB invalidation as it will have
+		 * already been done.
 		 */
-		if (!only_end && subscription->ops->invalidate_range)
-			subscription->ops->invalidate_range(subscription,
-							    range->mm,
-							    range->start,
-							    range->end);
+		if (!only_end && subscription->ops->invalidate_secondary_tlbs)
+			subscription->ops->invalidate_secondary_tlbs(
+				subscription,
+				range->mm,
+				range->start,
+				range->end);
 		if (subscription->ops->invalidate_range_end) {
 			if (!mmu_notifier_range_blockable(range))
 				non_block_start();
@@ -604,8 +604,8 @@  void __mmu_notifier_invalidate_range_end(struct mmu_notifier_range *range,
 	lock_map_release(&__mmu_notifier_invalidate_range_start_map);
 }
 
-void __mmu_notifier_invalidate_range(struct mm_struct *mm,
-				  unsigned long start, unsigned long end)
+void __mmu_notifier_arch_invalidate_secondary_tlbs(struct mm_struct *mm,
+					unsigned long start, unsigned long end)
 {
 	struct mmu_notifier *subscription;
 	int id;
@@ -614,9 +614,10 @@  void __mmu_notifier_invalidate_range(struct mm_struct *mm,
 	hlist_for_each_entry_rcu(subscription,
 				 &mm->notifier_subscriptions->list, hlist,
 				 srcu_read_lock_held(&srcu)) {
-		if (subscription->ops->invalidate_range)
-			subscription->ops->invalidate_range(subscription, mm,
-							    start, end);
+		if (subscription->ops->arch_invalidate_secondary_tlbs)
+			subscription->ops->arch_invalidate_secondary_tlbs(
+				subscription, mm,
+				start, end);
 	}
 	srcu_read_unlock(&srcu, id);
 }
@@ -635,6 +636,15 @@  int __mmu_notifier_register(struct mmu_notifier *subscription,
 	mmap_assert_write_locked(mm);
 	BUG_ON(atomic_read(&mm->mm_users) <= 0);
 
+	/*
+	 * Subsystems should only register for invalidate_secondary_tlbs() or
+	 * invalidate_range_start()/end() callbacks, not both.
+	 */
+	if (WARN_ON_ONCE(subscription->ops->arch_invalidate_secondary_tlbs &&
+				(subscription->ops->invalidate_range_start ||
+				subscription->ops->invalidate_range_end)))
+		return -EINVAL;
+
 	if (!mm->notifier_subscriptions) {
 		/*
 		 * kmalloc cannot be called under mm_take_all_locks(), but we
diff --git a/mm/rmap.c b/mm/rmap.c
index 0c0d885..b74fc2c 100644
--- a/mm/rmap.c
+++ b/mm/rmap.c
@@ -991,9 +991,9 @@  static int page_vma_mkclean_one(struct page_vma_mapped_walk *pvmw)
 		}
 
 		/*
-		 * No need to call mmu_notifier_invalidate_range() as we are
-		 * downgrading page table protection not changing it to point
-		 * to a new page.
+		 * No need to call mmu_notifier_arch_invalidate_secondary_tlbs() as
+		 * we are downgrading page table protection not changing it to
+		 * point to a new page.
 		 *
 		 * See Documentation/mm/mmu_notifier.rst
 		 */
@@ -1554,8 +1554,8 @@  static bool try_to_unmap_one(struct folio *folio, struct vm_area_struct *vma,
 					hugetlb_vma_unlock_write(vma);
 					flush_tlb_range(vma,
 						range.start, range.end);
-					mmu_notifier_invalidate_range(mm,
-						range.start, range.end);
+					mmu_notifier_arch_invalidate_secondary_tlbs(
+						mm, range.start, range.end);
 					/*
 					 * The ref count of the PMD page was
 					 * dropped which is part of the way map
@@ -1629,7 +1629,7 @@  static bool try_to_unmap_one(struct folio *folio, struct vm_area_struct *vma,
 			 */
 			dec_mm_counter(mm, mm_counter(&folio->page));
 			/* We have to invalidate as we cleared the pte */
-			mmu_notifier_invalidate_range(mm, address,
+			mmu_notifier_arch_invalidate_secondary_tlbs(mm, address,
 						      address + PAGE_SIZE);
 		} else if (folio_test_anon(folio)) {
 			swp_entry_t entry = { .val = page_private(subpage) };
@@ -1643,7 +1643,8 @@  static bool try_to_unmap_one(struct folio *folio, struct vm_area_struct *vma,
 				WARN_ON_ONCE(1);
 				ret = false;
 				/* We have to invalidate as we cleared the pte */
-				mmu_notifier_invalidate_range(mm, address,
+				mmu_notifier_arch_invalidate_secondary_tlbs(mm,
+							address,
 							address + PAGE_SIZE);
 				page_vma_mapped_walk_done(&pvmw);
 				break;
@@ -1676,8 +1677,9 @@  static bool try_to_unmap_one(struct folio *folio, struct vm_area_struct *vma,
 				if (ref_count == 1 + map_count &&
 				    !folio_test_dirty(folio)) {
 					/* Invalidate as we cleared the pte */
-					mmu_notifier_invalidate_range(mm,
-						address, address + PAGE_SIZE);
+					mmu_notifier_arch_invalidate_secondary_tlbs(
+						mm, address,
+						address + PAGE_SIZE);
 					dec_mm_counter(mm, MM_ANONPAGES);
 					goto discard;
 				}
@@ -1733,7 +1735,7 @@  static bool try_to_unmap_one(struct folio *folio, struct vm_area_struct *vma,
 				swp_pte = pte_swp_mkuffd_wp(swp_pte);
 			set_pte_at(mm, address, pvmw.pte, swp_pte);
 			/* Invalidate as we cleared the pte */
-			mmu_notifier_invalidate_range(mm, address,
+			mmu_notifier_arch_invalidate_secondary_tlbs(mm, address,
 						      address + PAGE_SIZE);
 		} else {
 			/*
@@ -1751,9 +1753,9 @@  static bool try_to_unmap_one(struct folio *folio, struct vm_area_struct *vma,
 		}
 discard:
 		/*
-		 * No need to call mmu_notifier_invalidate_range() it has be
-		 * done above for all cases requiring it to happen under page
-		 * table lock before mmu_notifier_invalidate_range_end()
+		 * No need to call mmu_notifier_arch_invalidate_secondary_tlbs() it
+		 * has be done above for all cases requiring it to happen under
+		 * page table lock before mmu_notifier_invalidate_range_end()
 		 *
 		 * See Documentation/mm/mmu_notifier.rst
 		 */
@@ -1935,8 +1937,8 @@  static bool try_to_migrate_one(struct folio *folio, struct vm_area_struct *vma,
 					hugetlb_vma_unlock_write(vma);
 					flush_tlb_range(vma,
 						range.start, range.end);
-					mmu_notifier_invalidate_range(mm,
-						range.start, range.end);
+					mmu_notifier_arch_invalidate_secondary_tlbs(
+						mm, range.start, range.end);
 
 					/*
 					 * The ref count of the PMD page was
@@ -2042,8 +2044,8 @@  static bool try_to_migrate_one(struct folio *folio, struct vm_area_struct *vma,
 			 */
 			dec_mm_counter(mm, mm_counter(&folio->page));
 			/* We have to invalidate as we cleared the pte */
-			mmu_notifier_invalidate_range(mm, address,
-						      address + PAGE_SIZE);
+			mmu_notifier_arch_invalidate_secondary_tlbs(mm, address,
+							address + PAGE_SIZE);
 		} else {
 			swp_entry_t entry;
 			pte_t swp_pte;
@@ -2108,9 +2110,9 @@  static bool try_to_migrate_one(struct folio *folio, struct vm_area_struct *vma,
 		}
 
 		/*
-		 * No need to call mmu_notifier_invalidate_range() it has be
-		 * done above for all cases requiring it to happen under page
-		 * table lock before mmu_notifier_invalidate_range_end()
+		 * No need to call mmu_notifier_arch_invalidate_secondary_tlbs() it
+		 * has be done above for all cases requiring it to happen under
+		 * page table lock before mmu_notifier_invalidate_range_end()
 		 *
 		 * See Documentation/mm/mmu_notifier.rst
 		 */