
[v12,21/31] mm: Introduce find_vma_rcu()

Message ID 20190416134522.17540-22-ldufour@linux.ibm.com (mailing list archive)
State New, archived
Series Speculative page faults

Commit Message

Laurent Dufour April 16, 2019, 1:45 p.m. UTC
This allows searching for a VMA structure without holding the mmap_sem.

The search is repeated as long as the mm seqlock is changing, until a valid
VMA is found.

While under RCU protection, a reference is taken on the VMA, so the
caller must call put_vma() once it no longer needs the VMA structure.

When a VMA is inserted in the MM RB tree, in vma_rb_insert(), a
reference is taken on the VMA by calling get_vma().

When a VMA is removed from the MM RB tree, it is not released immediately
but at the end of the RCU grace period, through vm_rcu_put(). This ensures
that the VMA remains allocated until the end of the RCU grace period.

Since the vm_file pointer, if valid, is released in put_vma(), there is no
guarantee that the file pointer will still be valid for the returned VMA.
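
For illustration, a caller is expected to use the new helper roughly like
this (spf_lookup_example() is purely hypothetical; the real user is the
speculative page fault handler introduced elsewhere in this series):

        static int spf_lookup_example(struct mm_struct *mm, unsigned long addr)
        {
                struct vm_area_struct *vma;

                /* Lookup without holding mmap_sem; takes a reference on the VMA. */
                vma = find_vma_rcu(mm, addr);
                if (!vma || addr < vma->vm_start) {
                        if (vma)
                                put_vma(vma);
                        return -EFAULT;         /* fall back to the regular path */
                }

                /*
                 * Use the VMA here.  As noted above, vma->vm_file is not
                 * guaranteed to still be valid at this point.
                 */

                put_vma(vma);                   /* drop the reference taken above */
                return 0;
        }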

Signed-off-by: Laurent Dufour <ldufour@linux.ibm.com>
---
 include/linux/mm_types.h |  1 +
 mm/internal.h            |  5 ++-
 mm/mmap.c                | 76 ++++++++++++++++++++++++++++++++++++++--
 3 files changed, 78 insertions(+), 4 deletions(-)

Comments

Jerome Glisse April 22, 2019, 8:57 p.m. UTC | #1
On Tue, Apr 16, 2019 at 03:45:12PM +0200, Laurent Dufour wrote:
> This allows to search for a VMA structure without holding the mmap_sem.
> 
> The search is repeated while the mm seqlock is changing and until we found
> a valid VMA.
> 
> While under the RCU protection, a reference is taken on the VMA, so the
> caller must call put_vma() once it not more need the VMA structure.
> 
> At the time a VMA is inserted in the MM RB tree, in vma_rb_insert(), a
> reference is taken to the VMA by calling get_vma().
> 
> When removing a VMA from the MM RB tree, the VMA is not release immediately
> but at the end of the RCU grace period through vm_rcu_put(). This ensures
> that the VMA remains allocated until the end the RCU grace period.
> 
> Since the vm_file pointer, if valid, is released in put_vma(), there is no
> guarantee that the file pointer will be valid on the returned VMA.
> 
> Signed-off-by: Laurent Dufour <ldufour@linux.ibm.com>

Minor comments about comment (i love recursion :)) see below.

Reviewed-by: Jérôme Glisse <jglisse@redhat.com>

> ---
>  include/linux/mm_types.h |  1 +
>  mm/internal.h            |  5 ++-
>  mm/mmap.c                | 76 ++++++++++++++++++++++++++++++++++++++--
>  3 files changed, 78 insertions(+), 4 deletions(-)
> 
> diff --git a/include/linux/mm_types.h b/include/linux/mm_types.h
> index 6a6159e11a3f..9af6694cb95d 100644
> --- a/include/linux/mm_types.h
> +++ b/include/linux/mm_types.h
> @@ -287,6 +287,7 @@ struct vm_area_struct {
>  
>  #ifdef CONFIG_SPECULATIVE_PAGE_FAULT
>  	atomic_t vm_ref_count;
> +	struct rcu_head vm_rcu;
>  #endif
>  	struct rb_node vm_rb;
>  
> diff --git a/mm/internal.h b/mm/internal.h
> index 302382bed406..1e368e4afe3c 100644
> --- a/mm/internal.h
> +++ b/mm/internal.h
> @@ -55,7 +55,10 @@ static inline void put_vma(struct vm_area_struct *vma)
>  		__free_vma(vma);
>  }
>  
> -#else
> +extern struct vm_area_struct *find_vma_rcu(struct mm_struct *mm,
> +					   unsigned long addr);
> +
> +#else /* CONFIG_SPECULATIVE_PAGE_FAULT */
>  
>  static inline void get_vma(struct vm_area_struct *vma)
>  {
> diff --git a/mm/mmap.c b/mm/mmap.c
> index c106440dcae7..34bf261dc2c8 100644
> --- a/mm/mmap.c
> +++ b/mm/mmap.c
> @@ -179,6 +179,18 @@ static inline void mm_write_sequnlock(struct mm_struct *mm)
>  {
>  	write_sequnlock(&mm->mm_seq);
>  }
> +
> +static void __vm_rcu_put(struct rcu_head *head)
> +{
> +	struct vm_area_struct *vma = container_of(head, struct vm_area_struct,
> +						  vm_rcu);
> +	put_vma(vma);
> +}
> +static void vm_rcu_put(struct vm_area_struct *vma)
> +{
> +	VM_BUG_ON_VMA(!RB_EMPTY_NODE(&vma->vm_rb), vma);
> +	call_rcu(&vma->vm_rcu, __vm_rcu_put);
> +}
>  #else
>  static inline void mm_write_seqlock(struct mm_struct *mm)
>  {
> @@ -190,6 +202,8 @@ static inline void mm_write_sequnlock(struct mm_struct *mm)
>  
>  void __free_vma(struct vm_area_struct *vma)
>  {
> +	if (IS_ENABLED(CONFIG_SPECULATIVE_PAGE_FAULT))
> +		VM_BUG_ON_VMA(!RB_EMPTY_NODE(&vma->vm_rb), vma);
>  	mpol_put(vma_policy(vma));
>  	vm_area_free(vma);
>  }
> @@ -197,11 +211,24 @@ void __free_vma(struct vm_area_struct *vma)
>  /*
>   * Close a vm structure and free it, returning the next.
>   */
> -static struct vm_area_struct *remove_vma(struct vm_area_struct *vma)
> +static struct vm_area_struct *__remove_vma(struct vm_area_struct *vma)
>  {
>  	struct vm_area_struct *next = vma->vm_next;
>  
>  	might_sleep();
> +	if (IS_ENABLED(CONFIG_SPECULATIVE_PAGE_FAULT) &&
> +	    !RB_EMPTY_NODE(&vma->vm_rb)) {
> +		/*
> +		 * If the VMA is still linked in the RB tree, we must release
> +		 * that reference by calling put_vma().
> +		 * This should only happen when called from exit_mmap().
> +		 * We forcely clear the node to satisfy the chec in
                                                        ^
Typo: chec -> check

> +		 * __free_vma(). This is safe since the RB tree is not walked
> +		 * anymore.
> +		 */
> +		RB_CLEAR_NODE(&vma->vm_rb);
> +		put_vma(vma);
> +	}
>  	if (vma->vm_ops && vma->vm_ops->close)
>  		vma->vm_ops->close(vma);
>  	if (vma->vm_file)
> @@ -211,6 +238,13 @@ static struct vm_area_struct *remove_vma(struct vm_area_struct *vma)
>  	return next;
>  }
>  
> +static struct vm_area_struct *remove_vma(struct vm_area_struct *vma)
> +{
> +	if (IS_ENABLED(CONFIG_SPECULATIVE_PAGE_FAULT))
> +		VM_BUG_ON_VMA(!RB_EMPTY_NODE(&vma->vm_rb), vma);

Adding a comment here explaining the BUG_ON so people can understand
what is wrong if that happens. For instance:

/*
 * remove_vma() should be call only once a vma have been remove from the rbtree
 * at which point the vma->vm_rb is an empty node. The exception is when vmas
 * are destroy through exit_mmap() in which case we do not bother updating the
 * rbtree (see comment in __remove_vma()).
 */

> +	return __remove_vma(vma);
> +}
> +
>  static int do_brk_flags(unsigned long addr, unsigned long request, unsigned long flags,
>  		struct list_head *uf);
>  SYSCALL_DEFINE1(brk, unsigned long, brk)
> @@ -475,7 +509,7 @@ static inline void vma_rb_insert(struct vm_area_struct *vma,
>  
>  	/* All rb_subtree_gap values must be consistent prior to insertion */
>  	validate_mm_rb(root, NULL);
> -
> +	get_vma(vma);
>  	rb_insert_augmented(&vma->vm_rb, root, &vma_gap_callbacks);
>  }
>  
> @@ -491,6 +525,14 @@ static void __vma_rb_erase(struct vm_area_struct *vma, struct mm_struct *mm)
>  	mm_write_seqlock(mm);
>  	rb_erase_augmented(&vma->vm_rb, root, &vma_gap_callbacks);
>  	mm_write_sequnlock(mm);	/* wmb */
> +#ifdef CONFIG_SPECULATIVE_PAGE_FAULT
> +	/*
> +	 * Ensure the removal is complete before clearing the node.
> +	 * Matched by vma_has_changed()/handle_speculative_fault().
> +	 */
> +	RB_CLEAR_NODE(&vma->vm_rb);
> +	vm_rcu_put(vma);
> +#endif
>  }
>  
>  static __always_inline void vma_rb_erase_ignore(struct vm_area_struct *vma,
> @@ -2331,6 +2373,34 @@ struct vm_area_struct *find_vma(struct mm_struct *mm, unsigned long addr)
>  
>  EXPORT_SYMBOL(find_vma);
>  
> +#ifdef CONFIG_SPECULATIVE_PAGE_FAULT
> +/*
> + * Like find_vma() but under the protection of RCU and the mm sequence counter.
> + * The vma returned has to be relaesed by the caller through the call to
> + * put_vma()
> + */
> +struct vm_area_struct *find_vma_rcu(struct mm_struct *mm, unsigned long addr)
> +{
> +	struct vm_area_struct *vma = NULL;
> +	unsigned int seq;
> +
> +	do {
> +		if (vma)
> +			put_vma(vma);
> +
> +		seq = read_seqbegin(&mm->mm_seq);
> +
> +		rcu_read_lock();
> +		vma = find_vma(mm, addr);
> +		if (vma)
> +			get_vma(vma);
> +		rcu_read_unlock();
> +	} while (read_seqretry(&mm->mm_seq, seq));
> +
> +	return vma;
> +}
> +#endif
> +
>  /*
>   * Same as find_vma, but also return a pointer to the previous VMA in *pprev.
>   */
> @@ -3231,7 +3301,7 @@ void exit_mmap(struct mm_struct *mm)
>  	while (vma) {
>  		if (vma->vm_flags & VM_ACCOUNT)
>  			nr_accounted += vma_pages(vma);
> -		vma = remove_vma(vma);
> +		vma = __remove_vma(vma);
>  	}
>  	vm_unacct_memory(nr_accounted);
>  }
> -- 
> 2.21.0
>
Peter Zijlstra April 23, 2019, 9:27 a.m. UTC | #2
On Tue, Apr 16, 2019 at 03:45:12PM +0200, Laurent Dufour wrote:
> This allows to search for a VMA structure without holding the mmap_sem.
> 
> The search is repeated while the mm seqlock is changing and until we found
> a valid VMA.
> 
> While under the RCU protection, a reference is taken on the VMA, so the
> caller must call put_vma() once it not more need the VMA structure.
> 
> At the time a VMA is inserted in the MM RB tree, in vma_rb_insert(), a
> reference is taken to the VMA by calling get_vma().
> 
> When removing a VMA from the MM RB tree, the VMA is not release immediately
> but at the end of the RCU grace period through vm_rcu_put(). This ensures
> that the VMA remains allocated until the end the RCU grace period.
> 
> Since the vm_file pointer, if valid, is released in put_vma(), there is no
> guarantee that the file pointer will be valid on the returned VMA.

What I'm missing here, and in the previous patch introducing the
refcount (also see refcount_t), is _why_ we need the refcount thing at
all.

My original plan was to use SRCU, which at the time was not complete
enough so I abused/hacked preemptible RCU, but that is no longer the
case, SRCU has all the required bits and pieces.
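
Roughly, the idea was something like the below (mm_srcu would be a per-mm
srcu_struct; nothing in this series defines it, so this is only a sketch):

        struct vm_area_struct *find_vma_srcu(struct mm_struct *mm,
                                             unsigned long addr, int *srcu_idx)
        {
                struct vm_area_struct *vma;

                /*
                 * The reader stays inside the SRCU read-side section for the
                 * whole speculative fault and releases it afterwards with
                 * srcu_read_unlock(&mm->mm_srcu, *srcu_idx); VMAs are then
                 * freed through call_srcu() instead of being refcounted.
                 */
                *srcu_idx = srcu_read_lock(&mm->mm_srcu);
                vma = find_vma(mm, addr);
                return vma;
        }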

Also; the initial motivation was prefaulting large VMAs and the
contention on mmap was killing things; but similarly, the contention on
the refcount (I did try that) killed things just the same.

So I'm really sad to see the refcount return; and without any apparent
justification.
Davidlohr Bueso April 23, 2019, 6:13 p.m. UTC | #3
On Tue, 23 Apr 2019, Peter Zijlstra wrote:

>Also; the initial motivation was prefaulting large VMAs and the
>contention on mmap was killing things; but similarly, the contention on
>the refcount (I did try that) killed things just the same.

Right, this is just like what can happen with per-vma locking.

Thanks,
Davidlohr
Laurent Dufour April 24, 2019, 7:57 a.m. UTC | #4
On 23/04/2019 at 11:27, Peter Zijlstra wrote:
> On Tue, Apr 16, 2019 at 03:45:12PM +0200, Laurent Dufour wrote:
>> This allows to search for a VMA structure without holding the mmap_sem.
>>
>> The search is repeated while the mm seqlock is changing and until we found
>> a valid VMA.
>>
>> While under the RCU protection, a reference is taken on the VMA, so the
>> caller must call put_vma() once it not more need the VMA structure.
>>
>> At the time a VMA is inserted in the MM RB tree, in vma_rb_insert(), a
>> reference is taken to the VMA by calling get_vma().
>>
>> When removing a VMA from the MM RB tree, the VMA is not release immediately
>> but at the end of the RCU grace period through vm_rcu_put(). This ensures
>> that the VMA remains allocated until the end the RCU grace period.
>>
>> Since the vm_file pointer, if valid, is released in put_vma(), there is no
>> guarantee that the file pointer will be valid on the returned VMA.
> 
> What I'm missing here, and in the previous patch introducing the
> refcount (also see refcount_t), is _why_ we need the refcount thing at
> all.

The VMA's refcount is needed to ensure that the VMA remains valid until the 
end of the SPF handler. This is a consequence of using RCU instead of SRCU 
to protect the RB tree.

I was not aware of the refcount_t type; it would be better here to avoid 
wrapping.
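
Something along these lines, as a sketch (vm_ref_count would become a 
refcount_t instead of the atomic_t used so far in the series):

        static inline void get_vma(struct vm_area_struct *vma)
        {
                /* refcount_inc() saturates instead of wrapping and warns on 0 */
                refcount_inc(&vma->vm_ref_count);
        }

        static inline void put_vma(struct vm_area_struct *vma)
        {
                if (refcount_dec_and_test(&vma->vm_ref_count))
                        __free_vma(vma);
        }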

> My original plan was to use SRCU, which at the time was not complete
> enough so I abused/hacked preemptible RCU, but that is no longer the
> case, SRCU has all the required bits and pieces.

When I tested using SRCU, it involved a lot of scheduling to run the SRCU 
callback mechanism. In some workloads the impact on performance was 
significant [1].

I can't see this overhead using RCU.

> 
> Also; the initial motivation was prefaulting large VMAs and the
> contention on mmap was killing things; but similarly, the contention on
> the refcount (I did try that) killed things just the same.

Doing prefaulting should be doable; I'll try to think further about that.

Regarding the refcount, unless I missed something, this is an atomic 
counter, so there should not be contention on it, only cache-line 
exclusivity; not ideal, I agree, but I can't see what else to use here.

> So I'm really sad to see the refcount return; and without any apparent
> justification.

I'm not opposed to using another mechanism here, but SRCU didn't show good 
performance with some workloads, and I can't see how to use RCU without a 
reference counter here. So please advise.

Thanks,
Laurent.

[1] 
https://lore.kernel.org/linux-mm/7ca80231-fe02-a3a7-84bc-ce81690ea051@intel.com/
Laurent Dufour April 24, 2019, 2:39 p.m. UTC | #5
On 22/04/2019 at 22:57, Jerome Glisse wrote:
> On Tue, Apr 16, 2019 at 03:45:12PM +0200, Laurent Dufour wrote:
>> This allows to search for a VMA structure without holding the mmap_sem.
>>
>> The search is repeated while the mm seqlock is changing and until we found
>> a valid VMA.
>>
>> While under the RCU protection, a reference is taken on the VMA, so the
>> caller must call put_vma() once it not more need the VMA structure.
>>
>> At the time a VMA is inserted in the MM RB tree, in vma_rb_insert(), a
>> reference is taken to the VMA by calling get_vma().
>>
>> When removing a VMA from the MM RB tree, the VMA is not release immediately
>> but at the end of the RCU grace period through vm_rcu_put(). This ensures
>> that the VMA remains allocated until the end the RCU grace period.
>>
>> Since the vm_file pointer, if valid, is released in put_vma(), there is no
>> guarantee that the file pointer will be valid on the returned VMA.
>>
>> Signed-off-by: Laurent Dufour <ldufour@linux.ibm.com>
> 
> Minor comments about comment (i love recursion :)) see below.
> 
> Reviewed-by: Jérôme Glisse <jglisse@redhat.com>

Thanks Jérôme, see my comments to your comments on my comments below ;)

>> ---
>>   include/linux/mm_types.h |  1 +
>>   mm/internal.h            |  5 ++-
>>   mm/mmap.c                | 76 ++++++++++++++++++++++++++++++++++++++--
>>   3 files changed, 78 insertions(+), 4 deletions(-)
>>
>> diff --git a/include/linux/mm_types.h b/include/linux/mm_types.h
>> index 6a6159e11a3f..9af6694cb95d 100644
>> --- a/include/linux/mm_types.h
>> +++ b/include/linux/mm_types.h
>> @@ -287,6 +287,7 @@ struct vm_area_struct {
>>   
>>   #ifdef CONFIG_SPECULATIVE_PAGE_FAULT
>>   	atomic_t vm_ref_count;
>> +	struct rcu_head vm_rcu;
>>   #endif
>>   	struct rb_node vm_rb;
>>   
>> diff --git a/mm/internal.h b/mm/internal.h
>> index 302382bed406..1e368e4afe3c 100644
>> --- a/mm/internal.h
>> +++ b/mm/internal.h
>> @@ -55,7 +55,10 @@ static inline void put_vma(struct vm_area_struct *vma)
>>   		__free_vma(vma);
>>   }
>>   
>> -#else
>> +extern struct vm_area_struct *find_vma_rcu(struct mm_struct *mm,
>> +					   unsigned long addr);
>> +
>> +#else /* CONFIG_SPECULATIVE_PAGE_FAULT */
>>   
>>   static inline void get_vma(struct vm_area_struct *vma)
>>   {
>> diff --git a/mm/mmap.c b/mm/mmap.c
>> index c106440dcae7..34bf261dc2c8 100644
>> --- a/mm/mmap.c
>> +++ b/mm/mmap.c
>> @@ -179,6 +179,18 @@ static inline void mm_write_sequnlock(struct mm_struct *mm)
>>   {
>>   	write_sequnlock(&mm->mm_seq);
>>   }
>> +
>> +static void __vm_rcu_put(struct rcu_head *head)
>> +{
>> +	struct vm_area_struct *vma = container_of(head, struct vm_area_struct,
>> +						  vm_rcu);
>> +	put_vma(vma);
>> +}
>> +static void vm_rcu_put(struct vm_area_struct *vma)
>> +{
>> +	VM_BUG_ON_VMA(!RB_EMPTY_NODE(&vma->vm_rb), vma);
>> +	call_rcu(&vma->vm_rcu, __vm_rcu_put);
>> +}
>>   #else
>>   static inline void mm_write_seqlock(struct mm_struct *mm)
>>   {
>> @@ -190,6 +202,8 @@ static inline void mm_write_sequnlock(struct mm_struct *mm)
>>   
>>   void __free_vma(struct vm_area_struct *vma)
>>   {
>> +	if (IS_ENABLED(CONFIG_SPECULATIVE_PAGE_FAULT))
>> +		VM_BUG_ON_VMA(!RB_EMPTY_NODE(&vma->vm_rb), vma);
>>   	mpol_put(vma_policy(vma));
>>   	vm_area_free(vma);
>>   }
>> @@ -197,11 +211,24 @@ void __free_vma(struct vm_area_struct *vma)
>>   /*
>>    * Close a vm structure and free it, returning the next.
>>    */
>> -static struct vm_area_struct *remove_vma(struct vm_area_struct *vma)
>> +static struct vm_area_struct *__remove_vma(struct vm_area_struct *vma)
>>   {
>>   	struct vm_area_struct *next = vma->vm_next;
>>   
>>   	might_sleep();
>> +	if (IS_ENABLED(CONFIG_SPECULATIVE_PAGE_FAULT) &&
>> +	    !RB_EMPTY_NODE(&vma->vm_rb)) {
>> +		/*
>> +		 * If the VMA is still linked in the RB tree, we must release
>> +		 * that reference by calling put_vma().
>> +		 * This should only happen when called from exit_mmap().
>> +		 * We forcely clear the node to satisfy the chec in
>                                                          ^
> Typo: chec -> check

Yep

> 
>> +		 * __free_vma(). This is safe since the RB tree is not walked
>> +		 * anymore.
>> +		 */
>> +		RB_CLEAR_NODE(&vma->vm_rb);
>> +		put_vma(vma);
>> +	}
>>   	if (vma->vm_ops && vma->vm_ops->close)
>>   		vma->vm_ops->close(vma);
>>   	if (vma->vm_file)
>> @@ -211,6 +238,13 @@ static struct vm_area_struct *remove_vma(struct vm_area_struct *vma)
>>   	return next;
>>   }
>>   
>> +static struct vm_area_struct *remove_vma(struct vm_area_struct *vma)
>> +{
>> +	if (IS_ENABLED(CONFIG_SPECULATIVE_PAGE_FAULT))
>> +		VM_BUG_ON_VMA(!RB_EMPTY_NODE(&vma->vm_rb), vma);
> 
> Adding a comment here explaining the BUG_ON so people can understand
> what is wrong if that happens. For instance:
> 
> /*
>   * remove_vma() should be call only once a vma have been remove from the rbtree
>   * at which point the vma->vm_rb is an empty node. The exception is when vmas
>   * are destroy through exit_mmap() in which case we do not bother updating the
>   * rbtree (see comment in __remove_vma()).
>   */

I agree!


>> +	return __remove_vma(vma);
>> +}
>> +
>>   static int do_brk_flags(unsigned long addr, unsigned long request, unsigned long flags,
>>   		struct list_head *uf);
>>   SYSCALL_DEFINE1(brk, unsigned long, brk)
>> @@ -475,7 +509,7 @@ static inline void vma_rb_insert(struct vm_area_struct *vma,
>>   
>>   	/* All rb_subtree_gap values must be consistent prior to insertion */
>>   	validate_mm_rb(root, NULL);
>> -
>> +	get_vma(vma);
>>   	rb_insert_augmented(&vma->vm_rb, root, &vma_gap_callbacks);
>>   }
>>   
>> @@ -491,6 +525,14 @@ static void __vma_rb_erase(struct vm_area_struct *vma, struct mm_struct *mm)
>>   	mm_write_seqlock(mm);
>>   	rb_erase_augmented(&vma->vm_rb, root, &vma_gap_callbacks);
>>   	mm_write_sequnlock(mm);	/* wmb */
>> +#ifdef CONFIG_SPECULATIVE_PAGE_FAULT
>> +	/*
>> +	 * Ensure the removal is complete before clearing the node.
>> +	 * Matched by vma_has_changed()/handle_speculative_fault().
>> +	 */
>> +	RB_CLEAR_NODE(&vma->vm_rb);
>> +	vm_rcu_put(vma);
>> +#endif
>>   }
>>   
>>   static __always_inline void vma_rb_erase_ignore(struct vm_area_struct *vma,
>> @@ -2331,6 +2373,34 @@ struct vm_area_struct *find_vma(struct mm_struct *mm, unsigned long addr)
>>   
>>   EXPORT_SYMBOL(find_vma);
>>   
>> +#ifdef CONFIG_SPECULATIVE_PAGE_FAULT
>> +/*
>> + * Like find_vma() but under the protection of RCU and the mm sequence counter.
>> + * The vma returned has to be relaesed by the caller through the call to
>> + * put_vma()
>> + */
>> +struct vm_area_struct *find_vma_rcu(struct mm_struct *mm, unsigned long addr)
>> +{
>> +	struct vm_area_struct *vma = NULL;
>> +	unsigned int seq;
>> +
>> +	do {
>> +		if (vma)
>> +			put_vma(vma);
>> +
>> +		seq = read_seqbegin(&mm->mm_seq);
>> +
>> +		rcu_read_lock();
>> +		vma = find_vma(mm, addr);
>> +		if (vma)
>> +			get_vma(vma);
>> +		rcu_read_unlock();
>> +	} while (read_seqretry(&mm->mm_seq, seq));
>> +
>> +	return vma;
>> +}
>> +#endif
>> +
>>   /*
>>    * Same as find_vma, but also return a pointer to the previous VMA in *pprev.
>>    */
>> @@ -3231,7 +3301,7 @@ void exit_mmap(struct mm_struct *mm)
>>   	while (vma) {
>>   		if (vma->vm_flags & VM_ACCOUNT)
>>   			nr_accounted += vma_pages(vma);
>> -		vma = remove_vma(vma);
>> +		vma = __remove_vma(vma);
>>   	}
>>   	vm_unacct_memory(nr_accounted);
>>   }
>> -- 
>> 2.21.0
>>
>

Patch

diff --git a/include/linux/mm_types.h b/include/linux/mm_types.h
index 6a6159e11a3f..9af6694cb95d 100644
--- a/include/linux/mm_types.h
+++ b/include/linux/mm_types.h
@@ -287,6 +287,7 @@  struct vm_area_struct {
 
 #ifdef CONFIG_SPECULATIVE_PAGE_FAULT
 	atomic_t vm_ref_count;
+	struct rcu_head vm_rcu;
 #endif
 	struct rb_node vm_rb;
 
diff --git a/mm/internal.h b/mm/internal.h
index 302382bed406..1e368e4afe3c 100644
--- a/mm/internal.h
+++ b/mm/internal.h
@@ -55,7 +55,10 @@  static inline void put_vma(struct vm_area_struct *vma)
 		__free_vma(vma);
 }
 
-#else
+extern struct vm_area_struct *find_vma_rcu(struct mm_struct *mm,
+					   unsigned long addr);
+
+#else /* CONFIG_SPECULATIVE_PAGE_FAULT */
 
 static inline void get_vma(struct vm_area_struct *vma)
 {
diff --git a/mm/mmap.c b/mm/mmap.c
index c106440dcae7..34bf261dc2c8 100644
--- a/mm/mmap.c
+++ b/mm/mmap.c
@@ -179,6 +179,18 @@  static inline void mm_write_sequnlock(struct mm_struct *mm)
 {
 	write_sequnlock(&mm->mm_seq);
 }
+
+static void __vm_rcu_put(struct rcu_head *head)
+{
+	struct vm_area_struct *vma = container_of(head, struct vm_area_struct,
+						  vm_rcu);
+	put_vma(vma);
+}
+static void vm_rcu_put(struct vm_area_struct *vma)
+{
+	VM_BUG_ON_VMA(!RB_EMPTY_NODE(&vma->vm_rb), vma);
+	call_rcu(&vma->vm_rcu, __vm_rcu_put);
+}
 #else
 static inline void mm_write_seqlock(struct mm_struct *mm)
 {
@@ -190,6 +202,8 @@  static inline void mm_write_sequnlock(struct mm_struct *mm)
 
 void __free_vma(struct vm_area_struct *vma)
 {
+	if (IS_ENABLED(CONFIG_SPECULATIVE_PAGE_FAULT))
+		VM_BUG_ON_VMA(!RB_EMPTY_NODE(&vma->vm_rb), vma);
 	mpol_put(vma_policy(vma));
 	vm_area_free(vma);
 }
@@ -197,11 +211,24 @@  void __free_vma(struct vm_area_struct *vma)
 /*
  * Close a vm structure and free it, returning the next.
  */
-static struct vm_area_struct *remove_vma(struct vm_area_struct *vma)
+static struct vm_area_struct *__remove_vma(struct vm_area_struct *vma)
 {
 	struct vm_area_struct *next = vma->vm_next;
 
 	might_sleep();
+	if (IS_ENABLED(CONFIG_SPECULATIVE_PAGE_FAULT) &&
+	    !RB_EMPTY_NODE(&vma->vm_rb)) {
+		/*
+		 * If the VMA is still linked in the RB tree, we must release
+		 * that reference by calling put_vma().
+		 * This should only happen when called from exit_mmap().
+		 * We forcely clear the node to satisfy the chec in
+		 * __free_vma(). This is safe since the RB tree is not walked
+		 * anymore.
+		 */
+		RB_CLEAR_NODE(&vma->vm_rb);
+		put_vma(vma);
+	}
 	if (vma->vm_ops && vma->vm_ops->close)
 		vma->vm_ops->close(vma);
 	if (vma->vm_file)
@@ -211,6 +238,13 @@  static struct vm_area_struct *remove_vma(struct vm_area_struct *vma)
 	return next;
 }
 
+static struct vm_area_struct *remove_vma(struct vm_area_struct *vma)
+{
+	if (IS_ENABLED(CONFIG_SPECULATIVE_PAGE_FAULT))
+		VM_BUG_ON_VMA(!RB_EMPTY_NODE(&vma->vm_rb), vma);
+	return __remove_vma(vma);
+}
+
 static int do_brk_flags(unsigned long addr, unsigned long request, unsigned long flags,
 		struct list_head *uf);
 SYSCALL_DEFINE1(brk, unsigned long, brk)
@@ -475,7 +509,7 @@  static inline void vma_rb_insert(struct vm_area_struct *vma,
 
 	/* All rb_subtree_gap values must be consistent prior to insertion */
 	validate_mm_rb(root, NULL);
-
+	get_vma(vma);
 	rb_insert_augmented(&vma->vm_rb, root, &vma_gap_callbacks);
 }
 
@@ -491,6 +525,14 @@  static void __vma_rb_erase(struct vm_area_struct *vma, struct mm_struct *mm)
 	mm_write_seqlock(mm);
 	rb_erase_augmented(&vma->vm_rb, root, &vma_gap_callbacks);
 	mm_write_sequnlock(mm);	/* wmb */
+#ifdef CONFIG_SPECULATIVE_PAGE_FAULT
+	/*
+	 * Ensure the removal is complete before clearing the node.
+	 * Matched by vma_has_changed()/handle_speculative_fault().
+	 */
+	RB_CLEAR_NODE(&vma->vm_rb);
+	vm_rcu_put(vma);
+#endif
 }
 
 static __always_inline void vma_rb_erase_ignore(struct vm_area_struct *vma,
@@ -2331,6 +2373,34 @@  struct vm_area_struct *find_vma(struct mm_struct *mm, unsigned long addr)
 
 EXPORT_SYMBOL(find_vma);
 
+#ifdef CONFIG_SPECULATIVE_PAGE_FAULT
+/*
+ * Like find_vma() but under the protection of RCU and the mm sequence counter.
+ * The vma returned has to be relaesed by the caller through the call to
+ * put_vma()
+ */
+struct vm_area_struct *find_vma_rcu(struct mm_struct *mm, unsigned long addr)
+{
+	struct vm_area_struct *vma = NULL;
+	unsigned int seq;
+
+	do {
+		if (vma)
+			put_vma(vma);
+
+		seq = read_seqbegin(&mm->mm_seq);
+
+		rcu_read_lock();
+		vma = find_vma(mm, addr);
+		if (vma)
+			get_vma(vma);
+		rcu_read_unlock();
+	} while (read_seqretry(&mm->mm_seq, seq));
+
+	return vma;
+}
+#endif
+
 /*
  * Same as find_vma, but also return a pointer to the previous VMA in *pprev.
  */
@@ -3231,7 +3301,7 @@  void exit_mmap(struct mm_struct *mm)
 	while (vma) {
 		if (vma->vm_flags & VM_ACCOUNT)
 			nr_accounted += vma_pages(vma);
-		vma = remove_vma(vma);
+		vma = __remove_vma(vma);
 	}
 	vm_unacct_memory(nr_accounted);
 }