[v4,08/10] userfaultfd: add UFFDIO_CONTINUE ioctl

Message ID 20210204183433.1431202-9-axelrasmussen@google.com (mailing list archive)
State New, archived
Series userfaultfd: add minor fault handling

Commit Message

Axel Rasmussen Feb. 4, 2021, 6:34 p.m. UTC
This ioctl is how userspace ought to resolve "minor" userfaults. The
idea is: userspace is notified that a minor fault has occurred. It may
change the contents of the page through its second, non-UFFD mapping,
or leave it as-is. Then it calls UFFDIO_CONTINUE to tell the kernel "I
have ensured the page contents are correct, carry on setting up the
mapping".
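
For example, once a handler sees a minor fault message for some address,
it might do roughly the following (a sketch only; `area` is the
UFFD-registered mapping, `area_alias` a second mapping of the same
backing pages, and the offset/size bookkeeping is the application's own,
not part of the UFFD API):

	#include <err.h>
	#include <string.h>
	#include <sys/ioctl.h>
	#include <linux/userfaultfd.h>

	/*
	 * area:       the UFFD-registered (MINOR) mapping that faulted
	 * area_alias: second mapping of the same backing pages, not registered
	 * offset:     page-aligned offset of the faulting address in the range
	 */
	static void resolve_minor_fault(int uffd, char *area, char *area_alias,
					unsigned long offset, size_t page_size)
	{
		struct uffdio_continue cont;

		/* Fix up the contents via the non-UFFD mapping, if needed. */
		memset(area_alias + offset, 0, page_size);

		cont.range.start = (unsigned long)(area + offset);
		cont.range.len = page_size;
		cont.mode = 0;
		if (ioctl(uffd, UFFDIO_CONTINUE, &cont) == -1)
			err(1, "UFFDIO_CONTINUE");
	}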

Note that it doesn't make much sense to use UFFDIO_{COPY,ZEROPAGE} for
MINOR registered VMAs. ZEROPAGE maps the VMA to the zero page; but in
the minor fault case, we already have some pre-existing underlying page.
Likewise, UFFDIO_COPY isn't useful if we have a second non-UFFD mapping.
We'd just use memcpy() or similar instead.

It turns out hugetlb_mcopy_atomic_pte() already does something very
close to what we want when an existing page is provided via `struct
page **pagep`. We already special-case the behavior a bit for the
UFFDIO_ZEROPAGE case, so just extend that design: add an enum for the
three modes of operation, and make the small adjustments needed for the
MCOPY_ATOMIC_CONTINUE case. (Basically: look up the existing page, and
avoid adding it to the page cache or calling set_page_huge_active() on
it.)

Signed-off-by: Axel Rasmussen <axelrasmussen@google.com>
---
 fs/userfaultfd.c                 | 67 ++++++++++++++++++++++++++++++++
 include/linux/hugetlb.h          |  3 ++
 include/linux/userfaultfd_k.h    | 18 +++++++++
 include/uapi/linux/userfaultfd.h | 21 +++++++++-
 mm/hugetlb.c                     | 36 ++++++++++-------
 mm/userfaultfd.c                 | 49 ++++++++++++++---------
 6 files changed, 160 insertions(+), 34 deletions(-)

Comments

Peter Xu Feb. 8, 2021, 11:54 p.m. UTC | #1
On Thu, Feb 04, 2021 at 10:34:31AM -0800, Axel Rasmussen wrote:
> +enum mcopy_atomic_mode {
> +	/* A normal copy_from_user into the destination range. */
> +	MCOPY_ATOMIC_NORMAL,
> +	/* Don't copy; map the destination range to the zero page. */
> +	MCOPY_ATOMIC_ZEROPAGE,
> +	/* Just setup the dst_vma, without modifying the underlying page(s). */

"setup the dst_vma" sounds odd.  How about "install pte with the existing page
in the page cache"?

> +	MCOPY_ATOMIC_CONTINUE,
> +};

[...]

> @@ -4749,22 +4754,27 @@ int hugetlb_mcopy_atomic_pte(struct mm_struct *dst_mm,
>  		hugepage_add_new_anon_rmap(page, dst_vma, dst_addr);
>  	}
>  
> -	_dst_pte = make_huge_pte(dst_vma, page, dst_vma->vm_flags & VM_WRITE);
> -	if (dst_vma->vm_flags & VM_WRITE)
> +	dst_pte_flags = dst_vma->vm_flags & VM_WRITE;
> +	/* For CONTINUE on a non-shared VMA, don't set VM_WRITE for CoW. */
> +	if (mode == MCOPY_ATOMIC_CONTINUE && !vm_shared)
> +		dst_pte_flags &= ~VM_WRITE;

I agree it should work, but it's odd to explicitly clear a VM_WRITE bit, since
imho what we want is not to change the vma or its flags but to decide whether
to keep the write bit in the ptes.  How about something as simple as:

        bool writable;

        if (mode == MCOPY_ATOMIC_CONTINUE && !vm_shared)
            writable = false;
        else
            writable = dst_vma->vm_flags & VM_WRITE;

        _dst_pte = make_huge_pte(dst_vma, page, writable);
        if (writable)
        	_dst_pte = huge_pte_mkdirty(_dst_pte);

?

> +	_dst_pte = make_huge_pte(dst_vma, page, dst_pte_flags);
> +	if (dst_pte_flags & VM_WRITE)
>  		_dst_pte = huge_pte_mkdirty(_dst_pte);
>  	_dst_pte = pte_mkyoung(_dst_pte);
>  
>  	set_huge_pte_at(dst_mm, dst_addr, dst_pte, _dst_pte);
>  
>  	(void)huge_ptep_set_access_flags(dst_vma, dst_addr, dst_pte, _dst_pte,
> -					dst_vma->vm_flags & VM_WRITE);
> +					 dst_pte_flags);
>  	hugetlb_count_add(pages_per_huge_page(h), dst_mm);
>  
>  	/* No need to invalidate - it was non-present before */
>  	update_mmu_cache(dst_vma, dst_addr, dst_pte);
>  
>  	spin_unlock(ptl);
> -	set_page_huge_active(page);
> +	if (mode != MCOPY_ATOMIC_CONTINUE)
> +		set_page_huge_active(page);

This has been changed to SetHPageMigratable(page) in akpm-next by Mike's new
series.  So maybe it's time to rebase your series to that starting from the
next post.

>  	if (vm_shared)
>  		unlock_page(page);

After removing the shared restriction, I think we need:

        if (vm_shared || (mode == MCOPY_ATOMIC_CONTINUE))
        	unlock_page(page);

Since we seem to check (mode == MCOPY_ATOMIC_CONTINUE) a lot, maybe we can
introduce a temp var for that too.

>  	ret = 0;
> diff --git a/mm/userfaultfd.c b/mm/userfaultfd.c
> index b2ce61c1b50d..7bf83ffa456b 100644
> --- a/mm/userfaultfd.c
> +++ b/mm/userfaultfd.c
> @@ -207,7 +207,7 @@ static __always_inline ssize_t __mcopy_atomic_hugetlb(struct mm_struct *dst_mm,
>  					      unsigned long dst_start,
>  					      unsigned long src_start,
>  					      unsigned long len,
> -					      bool zeropage)
> +					      enum mcopy_atomic_mode mode)
>  {
>  	int vm_alloc_shared = dst_vma->vm_flags & VM_SHARED;
>  	int vm_shared = dst_vma->vm_flags & VM_SHARED;
> @@ -227,7 +227,7 @@ static __always_inline ssize_t __mcopy_atomic_hugetlb(struct mm_struct *dst_mm,
>  	 * by THP.  Since we can not reliably insert a zero page, this
>  	 * feature is not supported.
>  	 */
> -	if (zeropage) {
> +	if (mode == MCOPY_ATOMIC_ZEROPAGE) {
>  		mmap_read_unlock(dst_mm);
>  		return -EINVAL;
>  	}
> @@ -273,8 +273,6 @@ static __always_inline ssize_t __mcopy_atomic_hugetlb(struct mm_struct *dst_mm,
>  	}
>  
>  	while (src_addr < src_start + len) {
> -		pte_t dst_pteval;
> -
>  		BUG_ON(dst_addr >= dst_start + len);
>  
>  		/*
> @@ -297,16 +295,17 @@ static __always_inline ssize_t __mcopy_atomic_hugetlb(struct mm_struct *dst_mm,
>  			goto out_unlock;
>  		}
>  
> -		err = -EEXIST;
> -		dst_pteval = huge_ptep_get(dst_pte);
> -		if (!huge_pte_none(dst_pteval)) {
> -			mutex_unlock(&hugetlb_fault_mutex_table[hash]);
> -			i_mmap_unlock_read(mapping);
> -			goto out_unlock;
> +		if (mode != MCOPY_ATOMIC_CONTINUE) {
> +			if (!huge_pte_none(huge_ptep_get(dst_pte))) {

Maybe merge the two "if"s?

> +				err = -EEXIST;
> +				mutex_unlock(&hugetlb_fault_mutex_table[hash]);
> +				i_mmap_unlock_read(mapping);
> +				goto out_unlock;
> +			}
>  		}
>  
>  		err = hugetlb_mcopy_atomic_pte(dst_mm, dst_pte, dst_vma,
> -						dst_addr, src_addr, &page);
> +					       dst_addr, src_addr, mode, &page);
>  
>  		mutex_unlock(&hugetlb_fault_mutex_table[hash]);
>  		i_mmap_unlock_read(mapping);
> @@ -408,7 +407,7 @@ extern ssize_t __mcopy_atomic_hugetlb(struct mm_struct *dst_mm,
>  				      unsigned long dst_start,
>  				      unsigned long src_start,
>  				      unsigned long len,
> -				      bool zeropage);
> +				      enum mcopy_atomic_mode mode);
>  #endif /* CONFIG_HUGETLB_PAGE */
>  
>  static __always_inline ssize_t mfill_atomic_pte(struct mm_struct *dst_mm,
> @@ -417,10 +416,14 @@ static __always_inline ssize_t mfill_atomic_pte(struct mm_struct *dst_mm,
>  						unsigned long dst_addr,
>  						unsigned long src_addr,
>  						struct page **page,
> -						bool zeropage,
> +						enum mcopy_atomic_mode mode,
>  						bool wp_copy)
>  {
>  	ssize_t err;
> +	bool zeropage = (mode == MCOPY_ATOMIC_ZEROPAGE);
> +
> +	if (mode == MCOPY_ATOMIC_CONTINUE)
> +		return -EINVAL;

So you still passed in the mode into mfill_atomic_pte() just to make sure
CONTINUE is not called there.  It's okay, but again I think it's not extremely
necessary: we should make sure to fail early at the entry of uffdio_continue()
by checking against the vma type to be hugetlb, rather than reaching here.

Thanks,
Axel Rasmussen Feb. 10, 2021, 6 p.m. UTC | #2
On Mon, Feb 8, 2021 at 3:54 PM Peter Xu <peterx@redhat.com> wrote:
>
> On Thu, Feb 04, 2021 at 10:34:31AM -0800, Axel Rasmussen wrote:
> > +enum mcopy_atomic_mode {
> > +     /* A normal copy_from_user into the destination range. */
> > +     MCOPY_ATOMIC_NORMAL,
> > +     /* Don't copy; map the destination range to the zero page. */
> > +     MCOPY_ATOMIC_ZEROPAGE,
> > +     /* Just setup the dst_vma, without modifying the underlying page(s). */
>
> "setup the dst_vma" sounds odd.  How about "install pte with the existing page
> in the page cache"?
>
> > +     MCOPY_ATOMIC_CONTINUE,
> > +};
>
> [...]
>
> > @@ -4749,22 +4754,27 @@ int hugetlb_mcopy_atomic_pte(struct mm_struct *dst_mm,
> >               hugepage_add_new_anon_rmap(page, dst_vma, dst_addr);
> >       }
> >
> > -     _dst_pte = make_huge_pte(dst_vma, page, dst_vma->vm_flags & VM_WRITE);
> > -     if (dst_vma->vm_flags & VM_WRITE)
> > +     dst_pte_flags = dst_vma->vm_flags & VM_WRITE;
> > +     /* For CONTINUE on a non-shared VMA, don't set VM_WRITE for CoW. */
> > +     if (mode == MCOPY_ATOMIC_CONTINUE && !vm_shared)
> > +             dst_pte_flags &= ~VM_WRITE;
>
> I agree it should work but it's odd to explicitly remove a VM_WRITE bit, since
> imho what we want to do is not changing vma or vma flags but deciding whether
> to keep the write bit in the ptes.  How about as simple as:
>
>         bool writable;
>
>         if (mode == MCOPY_ATOMIC_CONTINUE && !vm_shared)
>             writable = false;
>         else
>             writable = dst_vma->vm_flags & VM_WRITE;
>
>         _dst_pte = make_huge_pte(dst_vma, page, writable);
>         if (writable)
>                 _dst_pte = huge_pte_mkdirty(_dst_pte);
>
> ?
>
> > +     _dst_pte = make_huge_pte(dst_vma, page, dst_pte_flags);
> > +     if (dst_pte_flags & VM_WRITE)
> >               _dst_pte = huge_pte_mkdirty(_dst_pte);
> >       _dst_pte = pte_mkyoung(_dst_pte);
> >
> >       set_huge_pte_at(dst_mm, dst_addr, dst_pte, _dst_pte);
> >
> >       (void)huge_ptep_set_access_flags(dst_vma, dst_addr, dst_pte, _dst_pte,
> > -                                     dst_vma->vm_flags & VM_WRITE);
> > +                                      dst_pte_flags);
> >       hugetlb_count_add(pages_per_huge_page(h), dst_mm);
> >
> >       /* No need to invalidate - it was non-present before */
> >       update_mmu_cache(dst_vma, dst_addr, dst_pte);
> >
> >       spin_unlock(ptl);
> > -     set_page_huge_active(page);
> > +     if (mode != MCOPY_ATOMIC_CONTINUE)
> > +             set_page_huge_active(page);
>
> This has been changed to SetHPageMigratable(page) in akpm-next by Mike's new
> series.  So maybe it's time to rebase your series to that starting from the
> next post.
>
> >       if (vm_shared)
> >               unlock_page(page);
>
> After removing the shared restriction, I think we need:
>
>         if (vm_shared || (mode == MCOPY_ATOMIC_CONTINUE))
>                 unlock_page(page);
>
> Since we seem to check (mode == MCOPY_ATOMIC_CONTINUE) a lot, maybe we can
> introduce a temp var for that too.
>
> >       ret = 0;
> > diff --git a/mm/userfaultfd.c b/mm/userfaultfd.c
> > index b2ce61c1b50d..7bf83ffa456b 100644
> > --- a/mm/userfaultfd.c
> > +++ b/mm/userfaultfd.c
> > @@ -207,7 +207,7 @@ static __always_inline ssize_t __mcopy_atomic_hugetlb(struct mm_struct *dst_mm,
> >                                             unsigned long dst_start,
> >                                             unsigned long src_start,
> >                                             unsigned long len,
> > -                                           bool zeropage)
> > +                                           enum mcopy_atomic_mode mode)
> >  {
> >       int vm_alloc_shared = dst_vma->vm_flags & VM_SHARED;
> >       int vm_shared = dst_vma->vm_flags & VM_SHARED;
> > @@ -227,7 +227,7 @@ static __always_inline ssize_t __mcopy_atomic_hugetlb(struct mm_struct *dst_mm,
> >        * by THP.  Since we can not reliably insert a zero page, this
> >        * feature is not supported.
> >        */
> > -     if (zeropage) {
> > +     if (mode == MCOPY_ATOMIC_ZEROPAGE) {
> >               mmap_read_unlock(dst_mm);
> >               return -EINVAL;
> >       }
> > @@ -273,8 +273,6 @@ static __always_inline ssize_t __mcopy_atomic_hugetlb(struct mm_struct *dst_mm,
> >       }
> >
> >       while (src_addr < src_start + len) {
> > -             pte_t dst_pteval;
> > -
> >               BUG_ON(dst_addr >= dst_start + len);
> >
> >               /*
> > @@ -297,16 +295,17 @@ static __always_inline ssize_t __mcopy_atomic_hugetlb(struct mm_struct *dst_mm,
> >                       goto out_unlock;
> >               }
> >
> > -             err = -EEXIST;
> > -             dst_pteval = huge_ptep_get(dst_pte);
> > -             if (!huge_pte_none(dst_pteval)) {
> > -                     mutex_unlock(&hugetlb_fault_mutex_table[hash]);
> > -                     i_mmap_unlock_read(mapping);
> > -                     goto out_unlock;
> > +             if (mode != MCOPY_ATOMIC_CONTINUE) {
> > +                     if (!huge_pte_none(huge_ptep_get(dst_pte))) {
>
> Maybe merge the two "if"s?
>
> > +                             err = -EEXIST;
> > +                             mutex_unlock(&hugetlb_fault_mutex_table[hash]);
> > +                             i_mmap_unlock_read(mapping);
> > +                             goto out_unlock;
> > +                     }
> >               }
> >
> >               err = hugetlb_mcopy_atomic_pte(dst_mm, dst_pte, dst_vma,
> > -                                             dst_addr, src_addr, &page);
> > +                                            dst_addr, src_addr, mode, &page);
> >
> >               mutex_unlock(&hugetlb_fault_mutex_table[hash]);
> >               i_mmap_unlock_read(mapping);
> > @@ -408,7 +407,7 @@ extern ssize_t __mcopy_atomic_hugetlb(struct mm_struct *dst_mm,
> >                                     unsigned long dst_start,
> >                                     unsigned long src_start,
> >                                     unsigned long len,
> > -                                   bool zeropage);
> > +                                   enum mcopy_atomic_mode mode);
> >  #endif /* CONFIG_HUGETLB_PAGE */
> >
> >  static __always_inline ssize_t mfill_atomic_pte(struct mm_struct *dst_mm,
> > @@ -417,10 +416,14 @@ static __always_inline ssize_t mfill_atomic_pte(struct mm_struct *dst_mm,
> >                                               unsigned long dst_addr,
> >                                               unsigned long src_addr,
> >                                               struct page **page,
> > -                                             bool zeropage,
> > +                                             enum mcopy_atomic_mode mode,
> >                                               bool wp_copy)
> >  {
> >       ssize_t err;
> > +     bool zeropage = (mode == MCOPY_ATOMIC_ZEROPAGE);
> > +
> > +     if (mode == MCOPY_ATOMIC_CONTINUE)
> > +             return -EINVAL;
>
> So you still passed in the mode into mfill_atomic_pte() just to make sure
> CONTINUE is not called there.  It's okay, but again I think it's not extremely
> necessary: we should make sure to fail early at the entry of uffdio_continue()
> by checking against the vma type to be hugetlb, rather than reaching here.

Hmm, it's not quite as simple as that. We don't have the dst_vma yet
in uffdio_continue(), __mcopy_atomic looks it up.

I'd prefer not to look it up in uffdio_continue(), because I think
that means changing the API so all the ioctls look up the vma, and
then plumb it into __mcopy_atomic. (We don't want to look it up twice,
since each lookup has to traverse the rbtree.) This is complicated too
by the fact that the ioctl handlers would need to perform various
validation / checks - e.g., acquiring mmap_lock, dealing with
*mmap_changing, validating the range, ....

We can move the enforcement up one more layer, into __mcopy_atomic,
easily enough, though.



The other comments above I agree with, so I'll send a v5. :)

>
> Thanks,
>
> --
> Peter Xu
>
Peter Xu Feb. 10, 2021, 7:06 p.m. UTC | #3
On Wed, Feb 10, 2021 at 10:00:21AM -0800, Axel Rasmussen wrote:
> > >  static __always_inline ssize_t mfill_atomic_pte(struct mm_struct *dst_mm,
> > > @@ -417,10 +416,14 @@ static __always_inline ssize_t mfill_atomic_pte(struct mm_struct *dst_mm,
> > >                                               unsigned long dst_addr,
> > >                                               unsigned long src_addr,
> > >                                               struct page **page,
> > > -                                             bool zeropage,
> > > +                                             enum mcopy_atomic_mode mode,
> > >                                               bool wp_copy)
> > >  {
> > >       ssize_t err;
> > > +     bool zeropage = (mode == MCOPY_ATOMIC_ZEROPAGE);
> > > +
> > > +     if (mode == MCOPY_ATOMIC_CONTINUE)
> > > +             return -EINVAL;
> >
> > So you still passed in the mode into mfill_atomic_pte() just to make sure
> > CONTINUE is not called there.  It's okay, but again I think it's not extremely
> > necessary: we should make sure to fail early at the entry of uffdio_continue()
> > by checking against the vma type to be hugetlb, rather than reaching here.
> 
> Hmm, it's not quite as simple as that. We don't have the dst_vma yet
> in uffdio_continue(), __mcopy_atomic looks it up.
> 
> I'd prefer not to look it up in uffdio_continue(), because I think
> that means changing the API so all the ioctls look up the vma, and
> then plumb it into __mcopy_atomic. (We don't want to look it up twice,
> since each lookup has to traverse the rbtree.) This is complicated too
> by the fact that the ioctl handlers would need to perform various
> validation / checks - e.g., acquiring mmap_lock, dealing with
> *mmap_changing, validating the range, ....

Sure.

> 
> We can move the enforcement up one more layer, into __mcopy_atomic,
> easily enough, though.

Right, that sounds good to me.  It should be right after the "if
(!vma_is_anonymous(dst_vma) && !vma_is_shmem(dst_vma))" check.
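
Roughly (just a sketch; this assumes err is already -EINVAL at that
point, as it is for the surrounding checks):

        if (!vma_is_anonymous(dst_vma) && !vma_is_shmem(dst_vma))
                goto out_unlock;

        /*
         * MCOPY_ATOMIC_CONTINUE is only implemented for hugetlb, and
         * hugetlb VMAs were already dispatched above, so reject it for
         * anon/shmem here.
         */
        if (mcopy_mode == MCOPY_ATOMIC_CONTINUE)
                goto out_unlock;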

Thanks,

Patch

diff --git a/fs/userfaultfd.c b/fs/userfaultfd.c
index c643cf13d957..c311086b6085 100644
--- a/fs/userfaultfd.c
+++ b/fs/userfaultfd.c
@@ -1536,6 +1536,10 @@  static int userfaultfd_register(struct userfaultfd_ctx *ctx,
 		if (!(uffdio_register.mode & UFFDIO_REGISTER_MODE_WP))
 			ioctls_out &= ~((__u64)1 << _UFFDIO_WRITEPROTECT);
 
+		/* CONTINUE ioctl is only supported for MINOR ranges. */
+		if (!(uffdio_register.mode & UFFDIO_REGISTER_MODE_MINOR))
+			ioctls_out &= ~((__u64)1 << _UFFDIO_CONTINUE);
+
 		/*
 		 * Now that we scanned all vmas we can already tell
 		 * userland which ioctls methods are guaranteed to
@@ -1889,6 +1893,66 @@  static int userfaultfd_writeprotect(struct userfaultfd_ctx *ctx,
 	return ret;
 }
 
+static int userfaultfd_continue(struct userfaultfd_ctx *ctx, unsigned long arg)
+{
+	__s64 ret;
+	struct uffdio_continue uffdio_continue;
+	struct uffdio_continue __user *user_uffdio_continue;
+	struct userfaultfd_wake_range range;
+
+	user_uffdio_continue = (struct uffdio_continue __user *)arg;
+
+	ret = -EAGAIN;
+	if (READ_ONCE(ctx->mmap_changing))
+		goto out;
+
+	ret = -EFAULT;
+	if (copy_from_user(&uffdio_continue, user_uffdio_continue,
+			   /* don't copy the output fields */
+			   sizeof(uffdio_continue) - (sizeof(__s64))))
+		goto out;
+
+	ret = validate_range(ctx->mm, &uffdio_continue.range.start,
+			     uffdio_continue.range.len);
+	if (ret)
+		goto out;
+
+	ret = -EINVAL;
+	/* double check for wraparound just in case. */
+	if (uffdio_continue.range.start + uffdio_continue.range.len <=
+	    uffdio_continue.range.start) {
+		goto out;
+	}
+	if (uffdio_continue.mode & ~UFFDIO_CONTINUE_MODE_DONTWAKE)
+		goto out;
+
+	if (mmget_not_zero(ctx->mm)) {
+		ret = mcopy_continue(ctx->mm, uffdio_continue.range.start,
+				     uffdio_continue.range.len,
+				     &ctx->mmap_changing);
+		mmput(ctx->mm);
+	} else {
+		return -ESRCH;
+	}
+
+	if (unlikely(put_user(ret, &user_uffdio_continue->mapped)))
+		return -EFAULT;
+	if (ret < 0)
+		goto out;
+
+	/* len == 0 would wake all */
+	BUG_ON(!ret);
+	range.len = ret;
+	if (!(uffdio_continue.mode & UFFDIO_CONTINUE_MODE_DONTWAKE)) {
+		range.start = uffdio_continue.range.start;
+		wake_userfault(ctx, &range);
+	}
+	ret = range.len == uffdio_continue.range.len ? 0 : -EAGAIN;
+
+out:
+	return ret;
+}
+
 static inline unsigned int uffd_ctx_features(__u64 user_features)
 {
 	/*
@@ -1973,6 +2037,9 @@  static long userfaultfd_ioctl(struct file *file, unsigned cmd,
 	case UFFDIO_WRITEPROTECT:
 		ret = userfaultfd_writeprotect(ctx, arg);
 		break;
+	case UFFDIO_CONTINUE:
+		ret = userfaultfd_continue(ctx, arg);
+		break;
 	}
 	return ret;
 }
diff --git a/include/linux/hugetlb.h b/include/linux/hugetlb.h
index 261c3284015d..9d4badc73cc8 100644
--- a/include/linux/hugetlb.h
+++ b/include/linux/hugetlb.h
@@ -11,6 +11,7 @@ 
 #include <linux/kref.h>
 #include <linux/pgtable.h>
 #include <linux/gfp.h>
+#include <linux/userfaultfd_k.h>
 
 struct ctl_table;
 struct user_struct;
@@ -139,6 +140,7 @@  int hugetlb_mcopy_atomic_pte(struct mm_struct *dst_mm, pte_t *dst_pte,
 				struct vm_area_struct *dst_vma,
 				unsigned long dst_addr,
 				unsigned long src_addr,
+				enum mcopy_atomic_mode mode,
 				struct page **pagep);
 #endif /* CONFIG_USERFAULTFD */
 int hugetlb_reserve_pages(struct inode *inode, long from, long to,
@@ -317,6 +319,7 @@  static inline int hugetlb_mcopy_atomic_pte(struct mm_struct *dst_mm,
 						struct vm_area_struct *dst_vma,
 						unsigned long dst_addr,
 						unsigned long src_addr,
+						enum mcopy_atomic_mode mode,
 						struct page **pagep)
 {
 	BUG();
diff --git a/include/linux/userfaultfd_k.h b/include/linux/userfaultfd_k.h
index e060d5f77cc5..333ee531e8be 100644
--- a/include/linux/userfaultfd_k.h
+++ b/include/linux/userfaultfd_k.h
@@ -37,6 +37,22 @@  extern int sysctl_unprivileged_userfaultfd;
 
 extern vm_fault_t handle_userfault(struct vm_fault *vmf, unsigned long reason);
 
+/*
+ * The mode of operation for __mcopy_atomic and its helpers.
+ *
+ * This is almost an implementation detail (mcopy_atomic below doesn't take this
+ * as a parameter), but it's exposed here because memory-kind-specific
+ * implementations (e.g. hugetlbfs) need to know the mode of operation.
+ */
+enum mcopy_atomic_mode {
+	/* A normal copy_from_user into the destination range. */
+	MCOPY_ATOMIC_NORMAL,
+	/* Don't copy; map the destination range to the zero page. */
+	MCOPY_ATOMIC_ZEROPAGE,
+	/* Just setup the dst_vma, without modifying the underlying page(s). */
+	MCOPY_ATOMIC_CONTINUE,
+};
+
 extern ssize_t mcopy_atomic(struct mm_struct *dst_mm, unsigned long dst_start,
 			    unsigned long src_start, unsigned long len,
 			    bool *mmap_changing, __u64 mode);
@@ -44,6 +60,8 @@  extern ssize_t mfill_zeropage(struct mm_struct *dst_mm,
 			      unsigned long dst_start,
 			      unsigned long len,
 			      bool *mmap_changing);
+extern ssize_t mcopy_continue(struct mm_struct *dst_mm, unsigned long dst_start,
+			      unsigned long len, bool *mmap_changing);
 extern int mwriteprotect_range(struct mm_struct *dst_mm,
 			       unsigned long start, unsigned long len,
 			       bool enable_wp, bool *mmap_changing);
diff --git a/include/uapi/linux/userfaultfd.h b/include/uapi/linux/userfaultfd.h
index f24dd4fcbad9..bafbeb1a2624 100644
--- a/include/uapi/linux/userfaultfd.h
+++ b/include/uapi/linux/userfaultfd.h
@@ -40,10 +40,12 @@ 
 	((__u64)1 << _UFFDIO_WAKE |		\
 	 (__u64)1 << _UFFDIO_COPY |		\
 	 (__u64)1 << _UFFDIO_ZEROPAGE |		\
-	 (__u64)1 << _UFFDIO_WRITEPROTECT)
+	 (__u64)1 << _UFFDIO_WRITEPROTECT |	\
+	 (__u64)1 << _UFFDIO_CONTINUE)
 #define UFFD_API_RANGE_IOCTLS_BASIC		\
 	((__u64)1 << _UFFDIO_WAKE |		\
-	 (__u64)1 << _UFFDIO_COPY)
+	 (__u64)1 << _UFFDIO_COPY |		\
+	 (__u64)1 << _UFFDIO_CONTINUE)
 
 /*
  * Valid ioctl command number range with this API is from 0x00 to
@@ -59,6 +61,7 @@ 
 #define _UFFDIO_COPY			(0x03)
 #define _UFFDIO_ZEROPAGE		(0x04)
 #define _UFFDIO_WRITEPROTECT		(0x06)
+#define _UFFDIO_CONTINUE		(0x07)
 #define _UFFDIO_API			(0x3F)
 
 /* userfaultfd ioctl ids */
@@ -77,6 +80,8 @@ 
 				      struct uffdio_zeropage)
 #define UFFDIO_WRITEPROTECT	_IOWR(UFFDIO, _UFFDIO_WRITEPROTECT, \
 				      struct uffdio_writeprotect)
+#define UFFDIO_CONTINUE		_IOR(UFFDIO, _UFFDIO_CONTINUE,	\
+				     struct uffdio_continue)
 
 /* read() structure */
 struct uffd_msg {
@@ -268,6 +273,18 @@  struct uffdio_writeprotect {
 	__u64 mode;
 };
 
+struct uffdio_continue {
+	struct uffdio_range range;
+#define UFFDIO_CONTINUE_MODE_DONTWAKE		((__u64)1<<0)
+	__u64 mode;
+
+	/*
+	 * Fields below here are written by the ioctl and must be at the end:
+	 * the copy_from_user will not read past here.
+	 */
+	__s64 mapped;
+};
+
 /*
  * Flags for the userfaultfd(2) system call itself.
  */
diff --git a/mm/hugetlb.c b/mm/hugetlb.c
index 868292cf148a..fd0693996029 100644
--- a/mm/hugetlb.c
+++ b/mm/hugetlb.c
@@ -39,7 +39,6 @@ 
 #include <linux/hugetlb.h>
 #include <linux/hugetlb_cgroup.h>
 #include <linux/node.h>
-#include <linux/userfaultfd_k.h>
 #include <linux/page_owner.h>
 #include "internal.h"
 
@@ -4657,6 +4656,7 @@  int hugetlb_mcopy_atomic_pte(struct mm_struct *dst_mm,
 			    struct vm_area_struct *dst_vma,
 			    unsigned long dst_addr,
 			    unsigned long src_addr,
+			    enum mcopy_atomic_mode mode,
 			    struct page **pagep)
 {
 	struct address_space *mapping;
@@ -4668,8 +4668,18 @@  int hugetlb_mcopy_atomic_pte(struct mm_struct *dst_mm,
 	spinlock_t *ptl;
 	int ret;
 	struct page *page;
+	vm_flags_t dst_pte_flags;
 
-	if (!*pagep) {
+	mapping = dst_vma->vm_file->f_mapping;
+	idx = vma_hugecache_offset(h, dst_vma, dst_addr);
+
+	if (mode == MCOPY_ATOMIC_CONTINUE) {
+		ret = -EFAULT;
+		page = find_lock_page(mapping, idx);
+		*pagep = NULL;
+		if (!page)
+			goto out;
+	} else if (!*pagep) {
 		ret = -ENOMEM;
 		page = alloc_huge_page(dst_vma, dst_addr, 0);
 		if (IS_ERR(page))
@@ -4698,13 +4708,8 @@  int hugetlb_mcopy_atomic_pte(struct mm_struct *dst_mm,
 	 */
 	__SetPageUptodate(page);
 
-	mapping = dst_vma->vm_file->f_mapping;
-	idx = vma_hugecache_offset(h, dst_vma, dst_addr);
-
-	/*
-	 * If shared, add to page cache
-	 */
-	if (vm_shared) {
+	/* Add shared, newly allocated pages to the page cache. */
+	if (vm_shared && mode != MCOPY_ATOMIC_CONTINUE) {
 		size = i_size_read(mapping->host) >> huge_page_shift(h);
 		ret = -EFAULT;
 		if (idx >= size)
@@ -4749,22 +4754,27 @@  int hugetlb_mcopy_atomic_pte(struct mm_struct *dst_mm,
 		hugepage_add_new_anon_rmap(page, dst_vma, dst_addr);
 	}
 
-	_dst_pte = make_huge_pte(dst_vma, page, dst_vma->vm_flags & VM_WRITE);
-	if (dst_vma->vm_flags & VM_WRITE)
+	dst_pte_flags = dst_vma->vm_flags & VM_WRITE;
+	/* For CONTINUE on a non-shared VMA, don't set VM_WRITE for CoW. */
+	if (mode == MCOPY_ATOMIC_CONTINUE && !vm_shared)
+		dst_pte_flags &= ~VM_WRITE;
+	_dst_pte = make_huge_pte(dst_vma, page, dst_pte_flags);
+	if (dst_pte_flags & VM_WRITE)
 		_dst_pte = huge_pte_mkdirty(_dst_pte);
 	_dst_pte = pte_mkyoung(_dst_pte);
 
 	set_huge_pte_at(dst_mm, dst_addr, dst_pte, _dst_pte);
 
 	(void)huge_ptep_set_access_flags(dst_vma, dst_addr, dst_pte, _dst_pte,
-					dst_vma->vm_flags & VM_WRITE);
+					 dst_pte_flags);
 	hugetlb_count_add(pages_per_huge_page(h), dst_mm);
 
 	/* No need to invalidate - it was non-present before */
 	update_mmu_cache(dst_vma, dst_addr, dst_pte);
 
 	spin_unlock(ptl);
-	set_page_huge_active(page);
+	if (mode != MCOPY_ATOMIC_CONTINUE)
+		set_page_huge_active(page);
 	if (vm_shared)
 		unlock_page(page);
 	ret = 0;
diff --git a/mm/userfaultfd.c b/mm/userfaultfd.c
index b2ce61c1b50d..7bf83ffa456b 100644
--- a/mm/userfaultfd.c
+++ b/mm/userfaultfd.c
@@ -207,7 +207,7 @@  static __always_inline ssize_t __mcopy_atomic_hugetlb(struct mm_struct *dst_mm,
 					      unsigned long dst_start,
 					      unsigned long src_start,
 					      unsigned long len,
-					      bool zeropage)
+					      enum mcopy_atomic_mode mode)
 {
 	int vm_alloc_shared = dst_vma->vm_flags & VM_SHARED;
 	int vm_shared = dst_vma->vm_flags & VM_SHARED;
@@ -227,7 +227,7 @@  static __always_inline ssize_t __mcopy_atomic_hugetlb(struct mm_struct *dst_mm,
 	 * by THP.  Since we can not reliably insert a zero page, this
 	 * feature is not supported.
 	 */
-	if (zeropage) {
+	if (mode == MCOPY_ATOMIC_ZEROPAGE) {
 		mmap_read_unlock(dst_mm);
 		return -EINVAL;
 	}
@@ -273,8 +273,6 @@  static __always_inline ssize_t __mcopy_atomic_hugetlb(struct mm_struct *dst_mm,
 	}
 
 	while (src_addr < src_start + len) {
-		pte_t dst_pteval;
-
 		BUG_ON(dst_addr >= dst_start + len);
 
 		/*
@@ -297,16 +295,17 @@  static __always_inline ssize_t __mcopy_atomic_hugetlb(struct mm_struct *dst_mm,
 			goto out_unlock;
 		}
 
-		err = -EEXIST;
-		dst_pteval = huge_ptep_get(dst_pte);
-		if (!huge_pte_none(dst_pteval)) {
-			mutex_unlock(&hugetlb_fault_mutex_table[hash]);
-			i_mmap_unlock_read(mapping);
-			goto out_unlock;
+		if (mode != MCOPY_ATOMIC_CONTINUE) {
+			if (!huge_pte_none(huge_ptep_get(dst_pte))) {
+				err = -EEXIST;
+				mutex_unlock(&hugetlb_fault_mutex_table[hash]);
+				i_mmap_unlock_read(mapping);
+				goto out_unlock;
+			}
 		}
 
 		err = hugetlb_mcopy_atomic_pte(dst_mm, dst_pte, dst_vma,
-						dst_addr, src_addr, &page);
+					       dst_addr, src_addr, mode, &page);
 
 		mutex_unlock(&hugetlb_fault_mutex_table[hash]);
 		i_mmap_unlock_read(mapping);
@@ -408,7 +407,7 @@  extern ssize_t __mcopy_atomic_hugetlb(struct mm_struct *dst_mm,
 				      unsigned long dst_start,
 				      unsigned long src_start,
 				      unsigned long len,
-				      bool zeropage);
+				      enum mcopy_atomic_mode mode);
 #endif /* CONFIG_HUGETLB_PAGE */
 
 static __always_inline ssize_t mfill_atomic_pte(struct mm_struct *dst_mm,
@@ -417,10 +416,14 @@  static __always_inline ssize_t mfill_atomic_pte(struct mm_struct *dst_mm,
 						unsigned long dst_addr,
 						unsigned long src_addr,
 						struct page **page,
-						bool zeropage,
+						enum mcopy_atomic_mode mode,
 						bool wp_copy)
 {
 	ssize_t err;
+	bool zeropage = (mode == MCOPY_ATOMIC_ZEROPAGE);
+
+	if (mode == MCOPY_ATOMIC_CONTINUE)
+		return -EINVAL;
 
 	/*
 	 * The normal page fault path for a shmem will invoke the
@@ -458,7 +461,7 @@  static __always_inline ssize_t __mcopy_atomic(struct mm_struct *dst_mm,
 					      unsigned long dst_start,
 					      unsigned long src_start,
 					      unsigned long len,
-					      bool zeropage,
+					      enum mcopy_atomic_mode mcopy_mode,
 					      bool *mmap_changing,
 					      __u64 mode)
 {
@@ -527,7 +530,7 @@  static __always_inline ssize_t __mcopy_atomic(struct mm_struct *dst_mm,
 	 */
 	if (is_vm_hugetlb_page(dst_vma))
 		return  __mcopy_atomic_hugetlb(dst_mm, dst_vma, dst_start,
-						src_start, len, zeropage);
+						src_start, len, mcopy_mode);
 
 	if (!vma_is_anonymous(dst_vma) && !vma_is_shmem(dst_vma))
 		goto out_unlock;
@@ -577,7 +580,7 @@  static __always_inline ssize_t __mcopy_atomic(struct mm_struct *dst_mm,
 		BUG_ON(pmd_trans_huge(*dst_pmd));
 
 		err = mfill_atomic_pte(dst_mm, dst_pmd, dst_vma, dst_addr,
-				       src_addr, &page, zeropage, wp_copy);
+				       src_addr, &page, mcopy_mode, wp_copy);
 		cond_resched();
 
 		if (unlikely(err == -ENOENT)) {
@@ -626,14 +629,22 @@  ssize_t mcopy_atomic(struct mm_struct *dst_mm, unsigned long dst_start,
 		     unsigned long src_start, unsigned long len,
 		     bool *mmap_changing, __u64 mode)
 {
-	return __mcopy_atomic(dst_mm, dst_start, src_start, len, false,
-			      mmap_changing, mode);
+	return __mcopy_atomic(dst_mm, dst_start, src_start, len,
+			      MCOPY_ATOMIC_NORMAL, mmap_changing, mode);
 }
 
 ssize_t mfill_zeropage(struct mm_struct *dst_mm, unsigned long start,
 		       unsigned long len, bool *mmap_changing)
 {
-	return __mcopy_atomic(dst_mm, start, 0, len, true, mmap_changing, 0);
+	return __mcopy_atomic(dst_mm, start, 0, len, MCOPY_ATOMIC_ZEROPAGE,
+			      mmap_changing, 0);
+}
+
+ssize_t mcopy_continue(struct mm_struct *dst_mm, unsigned long start,
+		       unsigned long len, bool *mmap_changing)
+{
+	return __mcopy_atomic(dst_mm, start, 0, len, MCOPY_ATOMIC_CONTINUE,
+			      mmap_changing, 0);
 }
 
 int mwriteprotect_range(struct mm_struct *dst_mm, unsigned long start,