diff mbox series

[hmm,03/15] mm/hmm: allow hmm_range to be used with a mmu_range_notifier or hmm_mirror

Message ID 20191015181242.8343-4-jgg@ziepe.ca (mailing list archive)
State New, archived
Headers show
Series Consolidate the mmu notifier interval_tree and locking | expand

Commit Message

Jason Gunthorpe Oct. 15, 2019, 6:12 p.m. UTC
From: Jason Gunthorpe <jgg@mellanox.com>

hmm_mirror's handling of ranges does not use a sequence count, which
results in this bug:

         CPU0                                   CPU1
                                     hmm_range_wait_until_valid(range)
                                         valid == true
                                     hmm_range_fault(range)
hmm_invalidate_range_start()
   range->valid = false
hmm_invalidate_range_end()
   range->valid = true
                                     hmm_range_valid(range)
                                          valid == true

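Here the final hmm_range_valid() should not have succeeded, because an
invalidation of the range started and completed after the earlier
hmm_range_wait_until_valid() returned.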

Adding the required sequence count would make it nearly identical to the
new mmu_range_notifier. Instead, replace hmm_mirror's range handling with
mmu_range_notifier.

Co-existence of the two APIs is the first step.

Signed-off-by: Jason Gunthorpe <jgg@mellanox.com>
---
 include/linux/hmm.h |  5 +++++
 mm/hmm.c            | 25 +++++++++++++++++++------
 2 files changed, 24 insertions(+), 6 deletions(-)

Comments

Jerome Glisse Oct. 21, 2019, 6:33 p.m. UTC | #1
On Tue, Oct 15, 2019 at 03:12:30PM -0300, Jason Gunthorpe wrote:
> From: Jason Gunthorpe <jgg@mellanox.com>
> 
> hmm_mirror's handling of ranges does not use a sequence count which
> results in this bug:
> 
>          CPU0                                   CPU1
>                                      hmm_range_wait_until_valid(range)
>                                          valid == true
>                                      hmm_range_fault(range)
> hmm_invalidate_range_start()
>    range->valid = false
> hmm_invalidate_range_end()
>    range->valid = true
>                                      hmm_range_valid(range)
>                                           valid == true
> 
> Where the hmm_range_valid should not have succeeded.
> 
> Adding the required sequence count would make it nearly identical to the
> new mmu_range_notifier. Instead replace the hmm_mirror stuff with
> mmu_range_notifier.
> 
> Co-existence of the two APIs is the first step.
> 
> Signed-off-by: Jason Gunthorpe <jgg@mellanox.com>

Reviewed-by: Jérôme Glisse <jglisse@redhat.com>

> ---
>  include/linux/hmm.h |  5 +++++
>  mm/hmm.c            | 25 +++++++++++++++++++------
>  2 files changed, 24 insertions(+), 6 deletions(-)
> 
> diff --git a/include/linux/hmm.h b/include/linux/hmm.h
> index 3fec513b9c00f1..8ac1fd6a81af8f 100644
> --- a/include/linux/hmm.h
> +++ b/include/linux/hmm.h
> @@ -145,6 +145,9 @@ enum hmm_pfn_value_e {
>  /*
>   * struct hmm_range - track invalidation lock on virtual address range
>   *
> + * @notifier: an optional mmu_range_notifier
> + * @notifier_seq: when notifier is used this is the result of
> + *                mmu_range_read_begin()
>   * @hmm: the core HMM structure this range is active against
>   * @vma: the vm area struct for the range
>   * @list: all range lock are on a list
> @@ -159,6 +162,8 @@ enum hmm_pfn_value_e {
>   * @valid: pfns array did not change since it has been fill by an HMM function
>   */
>  struct hmm_range {
> +	struct mmu_range_notifier *notifier;
> +	unsigned long		notifier_seq;
>  	struct hmm		*hmm;
>  	struct list_head	list;
>  	unsigned long		start;
> diff --git a/mm/hmm.c b/mm/hmm.c
> index 902f5fa6bf93ad..22ac3595771feb 100644
> --- a/mm/hmm.c
> +++ b/mm/hmm.c
> @@ -852,6 +852,14 @@ void hmm_range_unregister(struct hmm_range *range)
>  }
>  EXPORT_SYMBOL(hmm_range_unregister);
>  
> +static bool needs_retry(struct hmm_range *range)
> +{
> +	if (range->notifier)
> +		return mmu_range_check_retry(range->notifier,
> +					     range->notifier_seq);
> +	return !range->valid;
> +}
> +
>  static const struct mm_walk_ops hmm_walk_ops = {
>  	.pud_entry	= hmm_vma_walk_pud,
>  	.pmd_entry	= hmm_vma_walk_pmd,
> @@ -892,18 +900,23 @@ long hmm_range_fault(struct hmm_range *range, unsigned int flags)
>  	const unsigned long device_vma = VM_IO | VM_PFNMAP | VM_MIXEDMAP;
>  	unsigned long start = range->start, end;
>  	struct hmm_vma_walk hmm_vma_walk;
> -	struct hmm *hmm = range->hmm;
> +	struct mm_struct *mm;
>  	struct vm_area_struct *vma;
>  	int ret;
>  
> -	lockdep_assert_held(&hmm->mmu_notifier.mm->mmap_sem);
> +	if (range->notifier)
> +		mm = range->notifier->mm;
> +	else
> +		mm = range->hmm->mmu_notifier.mm;
> +
> +	lockdep_assert_held(&mm->mmap_sem);
>  
>  	do {
>  		/* If range is no longer valid force retry. */
> -		if (!range->valid)
> +		if (needs_retry(range))
>  			return -EBUSY;
>  
> -		vma = find_vma(hmm->mmu_notifier.mm, start);
> +		vma = find_vma(mm, start);
>  		if (vma == NULL || (vma->vm_flags & device_vma))
>  			return -EFAULT;
>  
> @@ -933,7 +946,7 @@ long hmm_range_fault(struct hmm_range *range, unsigned int flags)
>  			start = hmm_vma_walk.last;
>  
>  			/* Keep trying while the range is valid. */
> -		} while (ret == -EBUSY && range->valid);
> +		} while (ret == -EBUSY && !needs_retry(range));
>  
>  		if (ret) {
>  			unsigned long i;
> @@ -991,7 +1004,7 @@ long hmm_range_dma_map(struct hmm_range *range, struct device *device,
>  			continue;
>  
>  		/* Check if range is being invalidated */
> -		if (!range->valid) {
> +		if (needs_retry(range)) {
>  			ret = -EBUSY;
>  			goto unmap;
>  		}
> -- 
> 2.23.0
>
diff mbox series

Patch

diff --git a/include/linux/hmm.h b/include/linux/hmm.h
index 3fec513b9c00f1..8ac1fd6a81af8f 100644
--- a/include/linux/hmm.h
+++ b/include/linux/hmm.h
@@ -145,6 +145,9 @@  enum hmm_pfn_value_e {
 /*
  * struct hmm_range - track invalidation lock on virtual address range
  *
+ * @notifier: an optional mmu_range_notifier
+ * @notifier_seq: when notifier is used this is the result of
+ *                mmu_range_read_begin()
  * @hmm: the core HMM structure this range is active against
  * @vma: the vm area struct for the range
  * @list: all range lock are on a list
@@ -159,6 +162,8 @@  enum hmm_pfn_value_e {
  * @valid: pfns array did not change since it has been fill by an HMM function
  */
 struct hmm_range {
+	struct mmu_range_notifier *notifier;
+	unsigned long		notifier_seq;
 	struct hmm		*hmm;
 	struct list_head	list;
 	unsigned long		start;
diff --git a/mm/hmm.c b/mm/hmm.c
index 902f5fa6bf93ad..22ac3595771feb 100644
--- a/mm/hmm.c
+++ b/mm/hmm.c
@@ -852,6 +852,14 @@  void hmm_range_unregister(struct hmm_range *range)
 }
 EXPORT_SYMBOL(hmm_range_unregister);
 
+static bool needs_retry(struct hmm_range *range)
+{
+	if (range->notifier)
+		return mmu_range_check_retry(range->notifier,
+					     range->notifier_seq);
+	return !range->valid;
+}
+
 static const struct mm_walk_ops hmm_walk_ops = {
 	.pud_entry	= hmm_vma_walk_pud,
 	.pmd_entry	= hmm_vma_walk_pmd,
@@ -892,18 +900,23 @@  long hmm_range_fault(struct hmm_range *range, unsigned int flags)
 	const unsigned long device_vma = VM_IO | VM_PFNMAP | VM_MIXEDMAP;
 	unsigned long start = range->start, end;
 	struct hmm_vma_walk hmm_vma_walk;
-	struct hmm *hmm = range->hmm;
+	struct mm_struct *mm;
 	struct vm_area_struct *vma;
 	int ret;
 
-	lockdep_assert_held(&hmm->mmu_notifier.mm->mmap_sem);
+	if (range->notifier)
+		mm = range->notifier->mm;
+	else
+		mm = range->hmm->mmu_notifier.mm;
+
+	lockdep_assert_held(&mm->mmap_sem);
 
 	do {
 		/* If range is no longer valid force retry. */
-		if (!range->valid)
+		if (needs_retry(range))
 			return -EBUSY;
 
-		vma = find_vma(hmm->mmu_notifier.mm, start);
+		vma = find_vma(mm, start);
 		if (vma == NULL || (vma->vm_flags & device_vma))
 			return -EFAULT;
 
@@ -933,7 +946,7 @@  long hmm_range_fault(struct hmm_range *range, unsigned int flags)
 			start = hmm_vma_walk.last;
 
 			/* Keep trying while the range is valid. */
-		} while (ret == -EBUSY && range->valid);
+		} while (ret == -EBUSY && !needs_retry(range));
 
 		if (ret) {
 			unsigned long i;
@@ -991,7 +1004,7 @@  long hmm_range_dma_map(struct hmm_range *range, struct device *device,
 			continue;
 
 		/* Check if range is being invalidated */
-		if (!range->valid) {
+		if (needs_retry(range)) {
 			ret = -EBUSY;
 			goto unmap;
 		}