diff mbox series

[v6,5/6] nouveau: use new mmu interval notifiers

Message ID 20200113224703.5917-6-rcampbell@nvidia.com
State New, archived
Headers show
Series mm/hmm/test: add self tests for HMM | expand

Commit Message

Ralph Campbell Jan. 13, 2020, 10:47 p.m. UTC
Update nouveau to only use the mmu interval notifiers.

Signed-off-by: Ralph Campbell <rcampbell@nvidia.com>
---
 drivers/gpu/drm/nouveau/nouveau_svm.c | 313 +++++++++++++++++---------
 1 file changed, 201 insertions(+), 112 deletions(-)

Comments

Jason Gunthorpe Jan. 14, 2020, 1 p.m. UTC | #1
On Mon, Jan 13, 2020 at 02:47:02PM -0800, Ralph Campbell wrote:
>  void
>  nouveau_svmm_fini(struct nouveau_svmm **psvmm)
>  {
>  	struct nouveau_svmm *svmm = *psvmm;
> +	struct mmu_interval_notifier *mni;
> +
>  	if (svmm) {
>  		mutex_lock(&svmm->mutex);
> +		while (true) {
> +			mni = mmu_interval_notifier_find(svmm->mm,
> +					&nouveau_svm_mni_ops, 0UL, ~0UL);
> +			if (!mni)
> +				break;
> +			mmu_interval_notifier_put(mni);

Oh, now I really don't like the name 'put'. It looks like mni is
refcounted here, and it isn't. put should be called 'remove_deferred'

And then you also need a way to barrier this scheme on driver unload.

> +		}
>  		svmm->vmm = NULL;
>  		mutex_unlock(&svmm->mutex);
> -		mmu_notifier_put(&svmm->notifier);

While here it was actually a refcount.

> +static void nouveau_svmm_do_unmap(struct mmu_interval_notifier *mni,
> +				 const struct mmu_notifier_range *range)
> +{
> +	struct svmm_interval *smi =
> +		container_of(mni, struct svmm_interval, notifier);
> +	struct nouveau_svmm *svmm = smi->svmm;
> +	unsigned long start = mmu_interval_notifier_start(mni);
> +	unsigned long last = mmu_interval_notifier_last(mni);

This whole algorithm only works if it is protected by the read side of
the interval tree lock. Deserves at least a comment if not an
assertion too.

>  static int nouveau_range_fault(struct nouveau_svmm *svmm,
>  			       struct nouveau_drm *drm, void *data, u32 size,
> -			       u64 *pfns, struct svm_notifier *notifier)
> +			       u64 *pfns, u64 start, u64 end)
>  {
>  	unsigned long timeout =
>  		jiffies + msecs_to_jiffies(HMM_RANGE_DEFAULT_TIMEOUT);
>  	/* Have HMM fault pages within the fault window to the GPU. */
>  	struct hmm_range range = {
> -		.notifier = &notifier->notifier,
> -		.start = notifier->notifier.interval_tree.start,
> -		.end = notifier->notifier.interval_tree.last + 1,
> +		.start = start,
> +		.end = end,
>  		.pfns = pfns,
>  		.flags = nouveau_svm_pfn_flags,
>  		.values = nouveau_svm_pfn_values,
> +		.default_flags = 0,
> +		.pfn_flags_mask = ~0UL,
>  		.pfn_shift = NVIF_VMM_PFNMAP_V0_ADDR_SHIFT,
>  	};
> -	struct mm_struct *mm = notifier->notifier.mm;
> +	struct mm_struct *mm = svmm->mm;
>  	long ret;
>  
>  	while (true) {
>  		if (time_after(jiffies, timeout))
>  			return -EBUSY;
>  
> -		range.notifier_seq = mmu_interval_read_begin(range.notifier);
> -		range.default_flags = 0;
> -		range.pfn_flags_mask = -1UL;
>  		down_read(&mm->mmap_sem);

mmap sem doesn't have to be held for the interval search, and again we
have lifetime issues with the membership here.

> +		ret = nouveau_svmm_interval_find(svmm, &range);
> +		if (ret) {
> +			up_read(&mm->mmap_sem);
> +			return ret;
> +		}
> +		range.notifier_seq = mmu_interval_read_begin(range.notifier);
>  		ret = hmm_range_fault(&range, 0);
>  		up_read(&mm->mmap_sem);
>  		if (ret <= 0) {

I'm still not sure this is a better approach than what ODP does. It
looks very expensive on the fault path..

Jason
Ralph Campbell Jan. 15, 2020, 10:09 p.m. UTC | #2
On 1/14/20 5:00 AM, Jason Gunthorpe wrote:
> On Mon, Jan 13, 2020 at 02:47:02PM -0800, Ralph Campbell wrote:
>>   void
>>   nouveau_svmm_fini(struct nouveau_svmm **psvmm)
>>   {
>>   	struct nouveau_svmm *svmm = *psvmm;
>> +	struct mmu_interval_notifier *mni;
>> +
>>   	if (svmm) {
>>   		mutex_lock(&svmm->mutex);
>> +		while (true) {
>> +			mni = mmu_interval_notifier_find(svmm->mm,
>> +					&nouveau_svm_mni_ops, 0UL, ~0UL);
>> +			if (!mni)
>> +				break;
>> +			mmu_interval_notifier_put(mni);
> 
> Oh, now I really don't like the name 'put'. It looks like mni is
> refcounted here, and it isn't. put should be called 'remove_deferred'

OK.

> And then you also need a way to barrier this scheme on driver unload.

Good point. I can add something like
void mmu_interval_notifier_synchronize(struct mm_struct *mm)
that waits for deferred operations to complete similar to
mmu_interval_read_begin().

>> +		}
>>   		svmm->vmm = NULL;
>>   		mutex_unlock(&svmm->mutex);
>> -		mmu_notifier_put(&svmm->notifier);
> 
> While here it was actually a refcount.
> 
>> +static void nouveau_svmm_do_unmap(struct mmu_interval_notifier *mni,
>> +				 const struct mmu_notifier_range *range)
>> +{
>> +	struct svmm_interval *smi =
>> +		container_of(mni, struct svmm_interval, notifier);
>> +	struct nouveau_svmm *svmm = smi->svmm;
>> +	unsigned long start = mmu_interval_notifier_start(mni);
>> +	unsigned long last = mmu_interval_notifier_last(mni);
> 
> This whole algorithm only works if it is protected by the read side of
> the interval tree lock. Deserves at least a comment if not an
> assertion too.

This is called from the invalidate() callback and while holding the
driver page table lock so the struct mmu_interval_notifier and
the interval tree can't change.
I will add comments for v7.

>>   static int nouveau_range_fault(struct nouveau_svmm *svmm,
>>   			       struct nouveau_drm *drm, void *data, u32 size,
>> -			       u64 *pfns, struct svm_notifier *notifier)
>> +			       u64 *pfns, u64 start, u64 end)
>>   {
>>   	unsigned long timeout =
>>   		jiffies + msecs_to_jiffies(HMM_RANGE_DEFAULT_TIMEOUT);
>>   	/* Have HMM fault pages within the fault window to the GPU. */
>>   	struct hmm_range range = {
>> -		.notifier = &notifier->notifier,
>> -		.start = notifier->notifier.interval_tree.start,
>> -		.end = notifier->notifier.interval_tree.last + 1,
>> +		.start = start,
>> +		.end = end,
>>   		.pfns = pfns,
>>   		.flags = nouveau_svm_pfn_flags,
>>   		.values = nouveau_svm_pfn_values,
>> +		.default_flags = 0,
>> +		.pfn_flags_mask = ~0UL,
>>   		.pfn_shift = NVIF_VMM_PFNMAP_V0_ADDR_SHIFT,
>>   	};
>> -	struct mm_struct *mm = notifier->notifier.mm;
>> +	struct mm_struct *mm = svmm->mm;
>>   	long ret;
>>   
>>   	while (true) {
>>   		if (time_after(jiffies, timeout))
>>   			return -EBUSY;
>>   
>> -		range.notifier_seq = mmu_interval_read_begin(range.notifier);
>> -		range.default_flags = 0;
>> -		range.pfn_flags_mask = -1UL;
>>   		down_read(&mm->mmap_sem);
> 
> mmap sem doesn't have to be held for the interval search, and again we
> have lifetime issues with the membership here.

I agree mmap_sem isn't needed for the interval search, it is needed if
the search doesn't find a registered interval and one needs to be created
to cover the underlying VMA. If an arbitrary size interval was created
instead, then mmap_sem wouldn't be needed.
I don't understand the lifetime/membership issue. The driver is the only thing
that allocates, inserts, or removes struct mmu_interval_notifier and thus
completely controls the lifetime.

>> +		ret = nouveau_svmm_interval_find(svmm, &range);
>> +		if (ret) {
>> +			up_read(&mm->mmap_sem);
>> +			return ret;
>> +		}
>> +		range.notifier_seq = mmu_interval_read_begin(range.notifier);
>>   		ret = hmm_range_fault(&range, 0);
>>   		up_read(&mm->mmap_sem);
>>   		if (ret <= 0) {
> 
> I'm still not sure this is a better approach than what ODP does. It
> looks very expensive on the fault path..
> 
> Jason
> 

ODP doesn't have this problem because users have to call ib_reg_mr()
before any I/O can happen to the process address space. That is when
mmu_interval_notifier_insert() / mmu_interval_notifier_remove() can
be called and the driver doesn't have to worry about the interval
changing sizes or being removed while I/O is happening.
For GPU like devices, I'm trying to allow hardware access to any user
level address without pre-registering it. That means inserting mmu
interval notifiers for the ranges the GPU page faults on and updating
the intervals as munmap() calls remove parts of the address space.
I don't want to register an interval per page so the logical range
is the underlying VMA.

It isn't that expensive, there is an extra driver lock/unlock as
part of the lookup and possibly a find_vma() and kmalloc(GFP_ATOMIC)
for new intervals. Also, the deferred interval updates for munmap().
Compared to the cost of updating PTEs in the device and GPU fault
handling, this is minimal overhead.
Jason Gunthorpe Jan. 16, 2020, 4 p.m. UTC | #3
On Wed, Jan 15, 2020 at 02:09:47PM -0800, Ralph Campbell wrote:

> I don't understand the lifetime/membership issue. The driver is the only thing
> that allocates, inserts, or removes struct mmu_interval_notifier and thus
> completely controls the lifetime.

If the returned value is on the defered list it could be freed at any
moment. The existing locks do not prevent it.

> > > +		ret = nouveau_svmm_interval_find(svmm, &range);
> > > +		if (ret) {
> > > +			up_read(&mm->mmap_sem);
> > > +			return ret;
> > > +		}
> > > +		range.notifier_seq = mmu_interval_read_begin(range.notifier);
> > >   		ret = hmm_range_fault(&range, 0);
> > >   		up_read(&mm->mmap_sem);
> > >   		if (ret <= 0) {
> > 
> > I'm still not sure this is a better approach than what ODP does. It
> > looks very expensive on the fault path..
> > 
> > Jason
> > 
> 
> ODP doesn't have this problem because users have to call ib_reg_mr()
> before any I/O can happen to the process address space.

ODP supports a single 'full VA' call at process startup, just like
these cases.

> That is when mmu_interval_notifier_insert() /
> mmu_interval_notifier_remove() can be called and the driver doesn't
> have to worry about the interval changing sizes or being removed
> while I/O is happening.  

No, for the 'ODP full process VA' (aka implicit ODP) mode it
dynamically maintains a list of intervals. ODP chooses the align the
dynamic intervals to it's HW page table levels, and not to SW VMAs.
This is much simpler to manage and faster to fault, at the cost of
capturing more VA for invalidations which have to be probed against
the HW shadow PTEs.

> It isn't that expensive, there is an extra driver lock/unlock as
> part of the lookup and possibly a find_vma() and kmalloc(GFP_ATOMIC)
> for new intervals. Also, the deferred interval updates for munmap().
> Compared to the cost of updating PTEs in the device and GPU fault
> handling, this is minimal overhead.

Well, compared to ODP which does a single xa lookup with no lock to
find its interval, this looks very expensive and not parallel.

I think if there is merit in having ranges cover the vmas and track
changes then there is probably merit in having the core code provide
much of that logic, not the driver.

But it would be interesting to see some kind of analysis on the two
methods to decide if the complexity is worthwhile.

Jason
Ralph Campbell Jan. 16, 2020, 8:16 p.m. UTC | #4
On 1/16/20 8:00 AM, Jason Gunthorpe wrote:
> On Wed, Jan 15, 2020 at 02:09:47PM -0800, Ralph Campbell wrote:
> 
>> I don't understand the lifetime/membership issue. The driver is the only thing
>> that allocates, inserts, or removes struct mmu_interval_notifier and thus
>> completely controls the lifetime.
> 
> If the returned value is on the defered list it could be freed at any
> moment. The existing locks do not prevent it.
> 
>>>> +		ret = nouveau_svmm_interval_find(svmm, &range);
>>>> +		if (ret) {
>>>> +			up_read(&mm->mmap_sem);
>>>> +			return ret;
>>>> +		}
>>>> +		range.notifier_seq = mmu_interval_read_begin(range.notifier);
>>>>    		ret = hmm_range_fault(&range, 0);
>>>>    		up_read(&mm->mmap_sem);
>>>>    		if (ret <= 0) {
>>>
>>> I'm still not sure this is a better approach than what ODP does. It
>>> looks very expensive on the fault path..
>>>
>>> Jason
>>>
>>
>> ODP doesn't have this problem because users have to call ib_reg_mr()
>> before any I/O can happen to the process address space.
> 
> ODP supports a single 'full VA' call at process startup, just like
> these cases.
> 
>> That is when mmu_interval_notifier_insert() /
>> mmu_interval_notifier_remove() can be called and the driver doesn't
>> have to worry about the interval changing sizes or being removed
>> while I/O is happening.
> 
> No, for the 'ODP full process VA' (aka implicit ODP) mode it
> dynamically maintains a list of intervals. ODP chooses the align the
> dynamic intervals to it's HW page table levels, and not to SW VMAs.
> This is much simpler to manage and faster to fault, at the cost of
> capturing more VA for invalidations which have to be probed against
> the HW shadow PTEs.
> 
>> It isn't that expensive, there is an extra driver lock/unlock as
>> part of the lookup and possibly a find_vma() and kmalloc(GFP_ATOMIC)
>> for new intervals. Also, the deferred interval updates for munmap().
>> Compared to the cost of updating PTEs in the device and GPU fault
>> handling, this is minimal overhead.
> 
> Well, compared to ODP which does a single xa lookup with no lock to
> find its interval, this looks very expensive and not parallel.
> 
> I think if there is merit in having ranges cover the vmas and track
> changes then there is probably merit in having the core code provide
> much of that logic, not the driver.
> 
> But it would be interesting to see some kind of analysis on the two
> methods to decide if the complexity is worthwhile.
> 
> Jason
> 

Can you point me to the latest ODP code? Seems like my understanding is
quite off.
Jason Gunthorpe Jan. 16, 2020, 8:21 p.m. UTC | #5
On Thu, Jan 16, 2020 at 12:16:30PM -0800, Ralph Campbell wrote:
> Can you point me to the latest ODP code? Seems like my understanding is
> quite off.

https://elixir.bootlin.com/linux/v5.5-rc6/source/drivers/infiniband/hw/mlx5/odp.c

Look for the word 'implicit'

mlx5_ib_invalidate_range() releases the interval_notifier when there are
no populated shadow PTEs in its leaf

pagefault_implicit_mr() creates an interval_notifier that covers the
level in the page table that needs population. Notice it just uses an
unlocked xa_load to find the page table level.

The locking is pretty tricky as it relies on RCU, but the fault flow
is fairly lightweight.

Jason
Ralph Campbell Feb. 20, 2020, 1:10 a.m. UTC | #6
On 1/16/20 12:21 PM, Jason Gunthorpe wrote:
> On Thu, Jan 16, 2020 at 12:16:30PM -0800, Ralph Campbell wrote:
>> Can you point me to the latest ODP code? Seems like my understanding is
>> quite off.
> 
> https://elixir.bootlin.com/linux/v5.5-rc6/source/drivers/infiniband/hw/mlx5/odp.c
> 
> Look for the word 'implicit'
> 
> mlx5_ib_invalidate_range() releases the interval_notifier when there are
> no populated shadow PTEs in its leaf
> 
> pagefault_implicit_mr() creates an interval_notifier that covers the
> level in the page table that needs population. Notice it just uses an
> unlocked xa_load to find the page table level.
> 
> The locking is pretty tricky as it relies on RCU, but the fault flow
> is fairly lightweight.
> 
> Jason
> 
Thanks for the information, Jason.

I'm still interested in finding a way to support range based hints to device drivers.
madvise() looks like it only sets a bit in vma->vm_flags or acts on the
advice immediately. mbind() and set_mempolicy() only work with CPUs and memory
with NUMA a node number. What I'm looking for is a way for the device to know
whether to migrate pages to device private memory on a fault, whether to duplicate
read-only pages in device private memory, or remote map/access a page instead of migrating it.
For example, there is a working draft extension to OpenCL,
https://github.com/intel/llvm/blob/sycl/sycl/doc/extensions/USM/cl_intel_unified_shared_memory.asciidoc
that could provide a way to specify this sort of advice.
C++ is also looking at extentions for specifying affinity attributes.
In any case, these are probably a long ways off before being finalized and implemented.

I also have some changes to support THP migration to device private memory but that
will require updating nouveau to use 2MB TLB mappings.

In the mean time, I can update the HMM self tests to do something like ODP without
changing mm/mmu_notifier.c but I don't think I can easily change nouveau to that model.
diff mbox series

Patch

diff --git a/drivers/gpu/drm/nouveau/nouveau_svm.c b/drivers/gpu/drm/nouveau/nouveau_svm.c
index df9bf1fd1bc0..0343e48d41d7 100644
--- a/drivers/gpu/drm/nouveau/nouveau_svm.c
+++ b/drivers/gpu/drm/nouveau/nouveau_svm.c
@@ -88,7 +88,7 @@  nouveau_ivmm_find(struct nouveau_svm *svm, u64 inst)
 }
 
 struct nouveau_svmm {
-	struct mmu_notifier notifier;
+	struct mm_struct *mm;
 	struct nouveau_vmm *vmm;
 	struct {
 		unsigned long start;
@@ -98,6 +98,13 @@  struct nouveau_svmm {
 	struct mutex mutex;
 };
 
+struct svmm_interval {
+	struct mmu_interval_notifier notifier;
+	struct nouveau_svmm *svmm;
+};
+
+static const struct mmu_interval_notifier_ops nouveau_svm_mni_ops;
+
 #define SVMM_DBG(s,f,a...)                                                     \
 	NV_DEBUG((s)->vmm->cli->drm, "svm-%p: "f"\n", (s), ##a)
 #define SVMM_ERR(s,f,a...)                                                     \
@@ -236,6 +243,8 @@  nouveau_svmm_join(struct nouveau_svmm *svmm, u64 inst)
 static void
 nouveau_svmm_invalidate(struct nouveau_svmm *svmm, u64 start, u64 limit)
 {
+	SVMM_DBG(svmm, "invalidate %016llx-%016llx", start, limit);
+
 	if (limit > start) {
 		bool super = svmm->vmm->vmm.object.client->super;
 		svmm->vmm->vmm.object.client->super = true;
@@ -248,58 +257,25 @@  nouveau_svmm_invalidate(struct nouveau_svmm *svmm, u64 start, u64 limit)
 	}
 }
 
-static int
-nouveau_svmm_invalidate_range_start(struct mmu_notifier *mn,
-				    const struct mmu_notifier_range *update)
-{
-	struct nouveau_svmm *svmm =
-		container_of(mn, struct nouveau_svmm, notifier);
-	unsigned long start = update->start;
-	unsigned long limit = update->end;
-
-	if (!mmu_notifier_range_blockable(update))
-		return -EAGAIN;
-
-	SVMM_DBG(svmm, "invalidate %016lx-%016lx", start, limit);
-
-	mutex_lock(&svmm->mutex);
-	if (unlikely(!svmm->vmm))
-		goto out;
-
-	if (limit > svmm->unmanaged.start && start < svmm->unmanaged.limit) {
-		if (start < svmm->unmanaged.start) {
-			nouveau_svmm_invalidate(svmm, start,
-						svmm->unmanaged.limit);
-		}
-		start = svmm->unmanaged.limit;
-	}
-
-	nouveau_svmm_invalidate(svmm, start, limit);
-
-out:
-	mutex_unlock(&svmm->mutex);
-	return 0;
-}
-
-static void nouveau_svmm_free_notifier(struct mmu_notifier *mn)
-{
-	kfree(container_of(mn, struct nouveau_svmm, notifier));
-}
-
-static const struct mmu_notifier_ops nouveau_mn_ops = {
-	.invalidate_range_start = nouveau_svmm_invalidate_range_start,
-	.free_notifier = nouveau_svmm_free_notifier,
-};
-
 void
 nouveau_svmm_fini(struct nouveau_svmm **psvmm)
 {
 	struct nouveau_svmm *svmm = *psvmm;
+	struct mmu_interval_notifier *mni;
+
 	if (svmm) {
 		mutex_lock(&svmm->mutex);
+		while (true) {
+			mni = mmu_interval_notifier_find(svmm->mm,
+					&nouveau_svm_mni_ops, 0UL, ~0UL);
+			if (!mni)
+				break;
+			mmu_interval_notifier_put(mni);
+		}
 		svmm->vmm = NULL;
 		mutex_unlock(&svmm->mutex);
-		mmu_notifier_put(&svmm->notifier);
+		mmdrop(svmm->mm);
+		kfree(svmm);
 		*psvmm = NULL;
 	}
 }
@@ -343,11 +319,12 @@  nouveau_svmm_init(struct drm_device *dev, void *data,
 		goto out_free;
 
 	down_write(&current->mm->mmap_sem);
-	svmm->notifier.ops = &nouveau_mn_ops;
-	ret = __mmu_notifier_register(&svmm->notifier, current->mm);
+	ret = __mmu_notifier_register(NULL, current->mm);
 	if (ret)
 		goto out_mm_unlock;
-	/* Note, ownership of svmm transfers to mmu_notifier */
+
+	mmgrab(current->mm);
+	svmm->mm = current->mm;
 
 	cli->svm.svmm = svmm;
 	cli->svm.cli = cli;
@@ -482,65 +459,212 @@  nouveau_svm_fault_cache(struct nouveau_svm *svm,
 		fault->inst, fault->addr, fault->access);
 }
 
-struct svm_notifier {
-	struct mmu_interval_notifier notifier;
-	struct nouveau_svmm *svmm;
-};
+static struct svmm_interval *nouveau_svmm_new_interval(
+					struct nouveau_svmm *svmm,
+					unsigned long start,
+					unsigned long last)
+{
+	struct svmm_interval *smi;
+	int ret;
+
+	smi = kmalloc(sizeof(*smi), GFP_ATOMIC);
+	if (!smi)
+		return NULL;
+
+	smi->svmm = svmm;
+
+	ret = mmu_interval_notifier_insert_safe(&smi->notifier, svmm->mm,
+				start, last - start + 1, &nouveau_svm_mni_ops);
+	if (ret) {
+		kfree(smi);
+		return NULL;
+	}
+
+	return smi;
+}
+
+static void nouveau_svmm_do_unmap(struct mmu_interval_notifier *mni,
+				 const struct mmu_notifier_range *range)
+{
+	struct svmm_interval *smi =
+		container_of(mni, struct svmm_interval, notifier);
+	struct nouveau_svmm *svmm = smi->svmm;
+	unsigned long start = mmu_interval_notifier_start(mni);
+	unsigned long last = mmu_interval_notifier_last(mni);
+
+	if (start >= range->start) {
+		/* Remove the whole interval or keep the right-hand part. */
+		if (last <= range->end)
+			mmu_interval_notifier_put(mni);
+		else
+			mmu_interval_notifier_update(mni, range->end, last);
+		return;
+	}
+
+	/* Keep the left-hand part of the interval. */
+	mmu_interval_notifier_update(mni, start, range->start - 1);
+
+	/* If a hole is created, create an interval for the right-hand part. */
+	if (last >= range->end) {
+		smi = nouveau_svmm_new_interval(svmm, range->end, last);
+		/*
+		 * If we can't allocate an interval, we won't get invalidation
+		 * callbacks so clear the mapping and rely on faults to reload
+		 * the mappings if needed.
+		 */
+		if (!smi)
+			nouveau_svmm_invalidate(svmm, range->end, last + 1);
+	}
+}
 
-static bool nouveau_svm_range_invalidate(struct mmu_interval_notifier *mni,
-					 const struct mmu_notifier_range *range,
-					 unsigned long cur_seq)
+static bool nouveau_svmm_interval_invalidate(struct mmu_interval_notifier *mni,
+				const struct mmu_notifier_range *range,
+				unsigned long cur_seq)
 {
-	struct svm_notifier *sn =
-		container_of(mni, struct svm_notifier, notifier);
+	struct svmm_interval *smi =
+		container_of(mni, struct svmm_interval, notifier);
+	struct nouveau_svmm *svmm = smi->svmm;
 
 	/*
-	 * serializes the update to mni->invalidate_seq done by caller and
+	 * Serializes the update to mni->invalidate_seq done by the caller and
 	 * prevents invalidation of the PTE from progressing while HW is being
-	 * programmed. This is very hacky and only works because the normal
-	 * notifier that does invalidation is always called after the range
-	 * notifier.
+	 * programmed.
 	 */
 	if (mmu_notifier_range_blockable(range))
-		mutex_lock(&sn->svmm->mutex);
-	else if (!mutex_trylock(&sn->svmm->mutex))
+		mutex_lock(&svmm->mutex);
+	else if (!mutex_trylock(&svmm->mutex))
 		return false;
+
 	mmu_interval_set_seq(mni, cur_seq);
-	mutex_unlock(&sn->svmm->mutex);
+	nouveau_svmm_invalidate(svmm, range->start, range->end);
+
+	/* Stop tracking the range if it is an unmap. */
+	if (range->event == MMU_NOTIFY_UNMAP)
+		nouveau_svmm_do_unmap(mni, range);
+
+	mutex_unlock(&svmm->mutex);
 	return true;
 }
 
+static void nouveau_svmm_interval_release(struct mmu_interval_notifier *mni)
+{
+	struct svmm_interval *smi =
+		container_of(mni, struct svmm_interval, notifier);
+
+	kfree(smi);
+}
+
 static const struct mmu_interval_notifier_ops nouveau_svm_mni_ops = {
-	.invalidate = nouveau_svm_range_invalidate,
+	.invalidate = nouveau_svmm_interval_invalidate,
+	.release = nouveau_svmm_interval_release,
 };
 
+/*
+ * Find or create a mmu_interval_notifier for the given range.
+ * Although mmu_interval_notifier_insert_safe() can handle overlapping
+ * intervals, we only create non-overlapping intervals, shrinking the hmm_range
+ * if it spans more than one svmm_interval.
+ */
+static int nouveau_svmm_interval_find(struct nouveau_svmm *svmm,
+				 struct hmm_range *range)
+{
+	struct mmu_interval_notifier *mni;
+	struct svmm_interval *smi;
+	struct vm_area_struct *vma;
+	unsigned long start = range->start;
+	unsigned long last = range->end - 1;
+	int ret;
+
+	mutex_lock(&svmm->mutex);
+	mni = mmu_interval_notifier_find(svmm->mm, &nouveau_svm_mni_ops, start,
+					 last);
+	if (mni) {
+		if (start >= mmu_interval_notifier_start(mni)) {
+			smi = container_of(mni, struct svmm_interval, notifier);
+			if (last > mmu_interval_notifier_last(mni))
+				range->end =
+					mmu_interval_notifier_last(mni) + 1;
+			goto found;
+		}
+		WARN_ON(last <= mmu_interval_notifier_start(mni));
+		range->end = mmu_interval_notifier_start(mni);
+		last = range->end - 1;
+	}
+	/*
+	 * Might as well create an interval covering the underlying VMA to
+	 * avoid having to create a bunch of small intervals.
+	 */
+	vma = find_vma(svmm->mm, range->start);
+	if (!vma || start < vma->vm_start) {
+		ret = -ENOENT;
+		goto err;
+	}
+	if (range->end > vma->vm_end) {
+		range->end = vma->vm_end;
+		last = range->end - 1;
+	} else if (!mni) {
+		/* Anything registered on the right part of the vma? */
+		mni = mmu_interval_notifier_find(svmm->mm, &nouveau_svm_mni_ops,
+						 range->end, vma->vm_end - 1);
+		if (mni)
+			last = mmu_interval_notifier_start(mni) - 1;
+		else
+			last = vma->vm_end - 1;
+	}
+	/* Anything registered on the left part of the vma? */
+	mni = mmu_interval_notifier_find(svmm->mm, &nouveau_svm_mni_ops,
+					 vma->vm_start, start - 1);
+	if (mni)
+		start = mmu_interval_notifier_last(mni) + 1;
+	else
+		start = vma->vm_start;
+	smi = nouveau_svmm_new_interval(svmm, start, last);
+	if (!smi) {
+		ret = -ENOMEM;
+		goto err;
+	}
+
+found:
+	range->notifier = &smi->notifier;
+	mutex_unlock(&svmm->mutex);
+	return 0;
+
+err:
+	mutex_unlock(&svmm->mutex);
+	return ret;
+}
+
 static int nouveau_range_fault(struct nouveau_svmm *svmm,
 			       struct nouveau_drm *drm, void *data, u32 size,
-			       u64 *pfns, struct svm_notifier *notifier)
+			       u64 *pfns, u64 start, u64 end)
 {
 	unsigned long timeout =
 		jiffies + msecs_to_jiffies(HMM_RANGE_DEFAULT_TIMEOUT);
 	/* Have HMM fault pages within the fault window to the GPU. */
 	struct hmm_range range = {
-		.notifier = &notifier->notifier,
-		.start = notifier->notifier.interval_tree.start,
-		.end = notifier->notifier.interval_tree.last + 1,
+		.start = start,
+		.end = end,
 		.pfns = pfns,
 		.flags = nouveau_svm_pfn_flags,
 		.values = nouveau_svm_pfn_values,
+		.default_flags = 0,
+		.pfn_flags_mask = ~0UL,
 		.pfn_shift = NVIF_VMM_PFNMAP_V0_ADDR_SHIFT,
 	};
-	struct mm_struct *mm = notifier->notifier.mm;
+	struct mm_struct *mm = svmm->mm;
 	long ret;
 
 	while (true) {
 		if (time_after(jiffies, timeout))
 			return -EBUSY;
 
-		range.notifier_seq = mmu_interval_read_begin(range.notifier);
-		range.default_flags = 0;
-		range.pfn_flags_mask = -1UL;
 		down_read(&mm->mmap_sem);
+		ret = nouveau_svmm_interval_find(svmm, &range);
+		if (ret) {
+			up_read(&mm->mmap_sem);
+			return ret;
+		}
+		range.notifier_seq = mmu_interval_read_begin(range.notifier);
 		ret = hmm_range_fault(&range, 0);
 		up_read(&mm->mmap_sem);
 		if (ret <= 0) {
@@ -585,7 +709,6 @@  nouveau_svm_fault(struct nvif_notify *notify)
 		} i;
 		u64 phys[16];
 	} args;
-	struct vm_area_struct *vma;
 	u64 inst, start, limit;
 	int fi, fn, pi, fill;
 	int replay = 0, ret;
@@ -640,7 +763,6 @@  nouveau_svm_fault(struct nvif_notify *notify)
 	args.i.p.version = 0;
 
 	for (fi = 0; fn = fi + 1, fi < buffer->fault_nr; fi = fn) {
-		struct svm_notifier notifier;
 		struct mm_struct *mm;
 
 		/* Cancel any faults from non-SVM channels. */
@@ -662,36 +784,12 @@  nouveau_svm_fault(struct nvif_notify *notify)
 			start = max_t(u64, start, svmm->unmanaged.limit);
 		SVMM_DBG(svmm, "wndw %016llx-%016llx", start, limit);
 
-		mm = svmm->notifier.mm;
+		mm = svmm->mm;
 		if (!mmget_not_zero(mm)) {
 			nouveau_svm_fault_cancel_fault(svm, buffer->fault[fi]);
 			continue;
 		}
 
-		/* Intersect fault window with the CPU VMA, cancelling
-		 * the fault if the address is invalid.
-		 */
-		down_read(&mm->mmap_sem);
-		vma = find_vma_intersection(mm, start, limit);
-		if (!vma) {
-			SVMM_ERR(svmm, "wndw %016llx-%016llx", start, limit);
-			up_read(&mm->mmap_sem);
-			mmput(mm);
-			nouveau_svm_fault_cancel_fault(svm, buffer->fault[fi]);
-			continue;
-		}
-		start = max_t(u64, start, vma->vm_start);
-		limit = min_t(u64, limit, vma->vm_end);
-		up_read(&mm->mmap_sem);
-		SVMM_DBG(svmm, "wndw %016llx-%016llx", start, limit);
-
-		if (buffer->fault[fi]->addr != start) {
-			SVMM_ERR(svmm, "addr %016llx", buffer->fault[fi]->addr);
-			mmput(mm);
-			nouveau_svm_fault_cancel_fault(svm, buffer->fault[fi]);
-			continue;
-		}
-
 		/* Prepare the GPU-side update of all pages within the
 		 * fault window, determining required pages and access
 		 * permissions based on pending faults.
@@ -743,18 +841,9 @@  nouveau_svm_fault(struct nvif_notify *notify)
 			 args.i.p.addr,
 			 args.i.p.addr + args.i.p.size, fn - fi);
 
-		notifier.svmm = svmm;
-		ret = mmu_interval_notifier_insert(&notifier.notifier,
-						   svmm->notifier.mm,
-						   args.i.p.addr, args.i.p.size,
-						   &nouveau_svm_mni_ops);
-		if (!ret) {
-			ret = nouveau_range_fault(
-				svmm, svm->drm, &args,
-				sizeof(args.i) + pi * sizeof(args.phys[0]),
-				args.phys, &notifier);
-			mmu_interval_notifier_remove(&notifier.notifier);
-		}
+		ret = nouveau_range_fault(svmm, svm->drm, &args,
+			sizeof(args.i) + pi * sizeof(args.phys[0]), args.phys,
+			start, start + args.i.p.size);
 		mmput(mm);
 
 		/* Cancel any faults in the window whose pages didn't manage