diff mbox series

mm: userfaultfd: add UFFDIO_CONTINUE_MODE_WP to install WP PTEs

Message ID 20230214215046.1187635-1-axelrasmussen@google.com (mailing list archive)
State New, archived
Headers show
Series mm: userfaultfd: add UFFDIO_CONTINUE_MODE_WP to install WP PTEs | expand

Commit Message

Axel Rasmussen Feb. 14, 2023, 9:50 p.m. UTC
UFFDIO_COPY already has UFFDIO_COPY_MODE_WP, so when installing a new
PTE to resolve a missing fault, one can install a write-protected one.
This is useful when using UFFDIO_REGISTER_MODE_{MISSING,WP} in
combination.

So, add an analogous UFFDIO_CONTINUE_MODE_WP, which does the same thing
but for *minor* faults.

Rename "wp_copy" arguments to "wp_mode", since the mode now applies
more widely than just to the copy operation.

Update the selftest to do some very basic exercising of the new flag.

Signed-off-by: Axel Rasmussen <axelrasmussen@google.com>
---
 fs/userfaultfd.c                         |  6 ++-
 include/linux/userfaultfd_k.h            |  3 +-
 include/uapi/linux/userfaultfd.h         |  7 ++++
 mm/hugetlb.c                             |  6 +--
 mm/shmem.c                               |  4 +-
 mm/userfaultfd.c                         | 49 ++++++++++++++----------
 tools/testing/selftests/vm/userfaultfd.c |  4 ++
 7 files changed, 51 insertions(+), 28 deletions(-)

Comments

Matthew Wilcox Feb. 14, 2023, 10:17 p.m. UTC | #1
On Tue, Feb 14, 2023 at 01:50:46PM -0800, Axel Rasmussen wrote:
> UFFDIO_COPY already has UFFDIO_COPY_MODE_WP, so when installing a new
> PTE to resolve a missing fault, one can install a write-protected one.
> This is useful when using UFFDIO_REGISTER_MODE_{MISSING,WP} in
> combination.
> 
> So, add an analogous UFFDIO_CONTINUE_MODE_WP, which does the same thing
> but for *minor* faults.
> 
> Rename "wp_copy" arguments to "wp_mode", since the mode now applies
> more widely than just to the copy operation.

Speaking as someone not familiar with this code, I do not care for
this rename.  Before, it said "bool wp_copy".  I'm not sure what "wp"
is, but "copy" I understand.  So if it's true, I need to copy, and if it's
false, I don't.  Now it's "bool wp_mode".  So if it's true, it's a mode, and if
it's false, it's not a mode?  What?

Stepping back for a second, I think the real problem is that the argument
list has got out of control: 8 arguments to hugetlb_mcopy_atomic_pte(),
the same number to shmem_mfill_atomic_pte(), and a subtly different 8 to
mfill_atomic_pte().  Someone needs to create a struct with a sensible
name to contain all of this.  I bet it'll shrink the code too; marshalling
all these arguments is unlikely to be cheap.

But more importantly, it lets you change the arguments without touching
any of the code which is just passing them through.  Look at struct
vm_fault for inspiration (you probably don't want to reuse it, but maybe
you do?)

> Update the selftest to do some very basic exercising of the new flag.
> 
> Signed-off-by: Axel Rasmussen <axelrasmussen@google.com>
> ---
>  fs/userfaultfd.c                         |  6 ++-
>  include/linux/userfaultfd_k.h            |  3 +-
>  include/uapi/linux/userfaultfd.h         |  7 ++++
>  mm/hugetlb.c                             |  6 +--
>  mm/shmem.c                               |  4 +-
>  mm/userfaultfd.c                         | 49 ++++++++++++++----------
>  tools/testing/selftests/vm/userfaultfd.c |  4 ++
>  7 files changed, 51 insertions(+), 28 deletions(-)
> 
> diff --git a/fs/userfaultfd.c b/fs/userfaultfd.c
> index cc694846617a..545dc033eec8 100644
> --- a/fs/userfaultfd.c
> +++ b/fs/userfaultfd.c
> @@ -1910,13 +1910,15 @@ static int userfaultfd_continue(struct userfaultfd_ctx *ctx, unsigned long arg)
>  	    uffdio_continue.range.start) {
>  		goto out;
>  	}
> -	if (uffdio_continue.mode & ~UFFDIO_CONTINUE_MODE_DONTWAKE)
> +	if (uffdio_continue.mode & ~(UFFDIO_CONTINUE_MODE_DONTWAKE |
> +				     UFFDIO_CONTINUE_MODE_WP))
>  		goto out;
>  
>  	if (mmget_not_zero(ctx->mm)) {
>  		ret = mcopy_continue(ctx->mm, uffdio_continue.range.start,
>  				     uffdio_continue.range.len,
> -				     &ctx->mmap_changing);
> +				     &ctx->mmap_changing,
> +				     uffdio_continue.mode);
>  		mmput(ctx->mm);
>  	} else {
>  		return -ESRCH;
> diff --git a/include/linux/userfaultfd_k.h b/include/linux/userfaultfd_k.h
> index 9df0b9a762cc..a53aa56e78ad 100644
> --- a/include/linux/userfaultfd_k.h
> +++ b/include/linux/userfaultfd_k.h
> @@ -69,7 +69,8 @@ extern ssize_t mfill_zeropage(struct mm_struct *dst_mm,
>  			      unsigned long len,
>  			      atomic_t *mmap_changing);
>  extern ssize_t mcopy_continue(struct mm_struct *dst_mm, unsigned long dst_start,
> -			      unsigned long len, atomic_t *mmap_changing);
> +			      unsigned long len, atomic_t *mmap_changing,
> +			      __u64 mode);
>  extern int mwriteprotect_range(struct mm_struct *dst_mm,
>  			       unsigned long start, unsigned long len,
>  			       bool enable_wp, atomic_t *mmap_changing);
> diff --git a/include/uapi/linux/userfaultfd.h b/include/uapi/linux/userfaultfd.h
> index 005e5e306266..14059a0861bf 100644
> --- a/include/uapi/linux/userfaultfd.h
> +++ b/include/uapi/linux/userfaultfd.h
> @@ -297,6 +297,13 @@ struct uffdio_writeprotect {
>  struct uffdio_continue {
>  	struct uffdio_range range;
>  #define UFFDIO_CONTINUE_MODE_DONTWAKE		((__u64)1<<0)
> +	/*
> +	 * UFFDIO_CONTINUE_MODE_WP will map the page write protected on
> +	 * the fly.  UFFDIO_CONTINUE_MODE_WP is available only if the
> +	 * write protected ioctl is implemented for the range
> +	 * according to the uffdio_register.ioctls.
> +	 */
> +#define UFFDIO_CONTINUE_MODE_WP			((__u64)1<<1)
>  	__u64 mode;
>  
>  	/*
> diff --git a/mm/hugetlb.c b/mm/hugetlb.c
> index bdbfeb6fb393..a1bd0b0c6c43 100644
> --- a/mm/hugetlb.c
> +++ b/mm/hugetlb.c
> @@ -6169,7 +6169,7 @@ int hugetlb_mcopy_atomic_pte(struct mm_struct *dst_mm,
>  			    unsigned long src_addr,
>  			    enum mcopy_atomic_mode mode,
>  			    struct page **pagep,
> -			    bool wp_copy)
> +			    bool wp_mode)
>  {
>  	bool is_continue = (mode == MCOPY_ATOMIC_CONTINUE);
>  	struct hstate *h = hstate_vma(dst_vma);
> @@ -6306,7 +6306,7 @@ int hugetlb_mcopy_atomic_pte(struct mm_struct *dst_mm,
>  	 * For either: (1) CONTINUE on a non-shared VMA, or (2) UFFDIO_COPY
>  	 * with wp flag set, don't set pte write bit.
>  	 */
> -	if (wp_copy || (is_continue && !vm_shared))
> +	if (wp_mode || (is_continue && !vm_shared))
>  		writable = 0;
>  	else
>  		writable = dst_vma->vm_flags & VM_WRITE;
> @@ -6321,7 +6321,7 @@ int hugetlb_mcopy_atomic_pte(struct mm_struct *dst_mm,
>  	_dst_pte = huge_pte_mkdirty(_dst_pte);
>  	_dst_pte = pte_mkyoung(_dst_pte);
>  
> -	if (wp_copy)
> +	if (wp_mode)
>  		_dst_pte = huge_pte_mkuffd_wp(_dst_pte);
>  
>  	set_huge_pte_at(dst_mm, dst_addr, dst_pte, _dst_pte);
> diff --git a/mm/shmem.c b/mm/shmem.c
> index 0005ab2c29af..725e5d5d85ab 100644
> --- a/mm/shmem.c
> +++ b/mm/shmem.c
> @@ -2402,7 +2402,7 @@ int shmem_mfill_atomic_pte(struct mm_struct *dst_mm,
>  			   struct vm_area_struct *dst_vma,
>  			   unsigned long dst_addr,
>  			   unsigned long src_addr,
> -			   bool zeropage, bool wp_copy,
> +			   bool zeropage, bool wp_mode,
>  			   struct page **pagep)
>  {
>  	struct inode *inode = file_inode(dst_vma->vm_file);
> @@ -2493,7 +2493,7 @@ int shmem_mfill_atomic_pte(struct mm_struct *dst_mm,
>  		goto out_release;
>  
>  	ret = mfill_atomic_install_pte(dst_mm, dst_pmd, dst_vma, dst_addr,
> -				       &folio->page, true, wp_copy);
> +				       &folio->page, true, wp_mode);
>  	if (ret)
>  		goto out_delete_from_cache;
>  
> diff --git a/mm/userfaultfd.c b/mm/userfaultfd.c
> index 0499907b6f1a..2ad4dcf72968 100644
> --- a/mm/userfaultfd.c
> +++ b/mm/userfaultfd.c
> @@ -58,7 +58,7 @@ struct vm_area_struct *find_dst_vma(struct mm_struct *dst_mm,
>  int mfill_atomic_install_pte(struct mm_struct *dst_mm, pmd_t *dst_pmd,
>  			     struct vm_area_struct *dst_vma,
>  			     unsigned long dst_addr, struct page *page,
> -			     bool newly_allocated, bool wp_copy)
> +			     bool newly_allocated, bool wp_mode)
>  {
>  	int ret;
>  	pte_t _dst_pte, *dst_pte;
> @@ -79,7 +79,7 @@ int mfill_atomic_install_pte(struct mm_struct *dst_mm, pmd_t *dst_pmd,
>  	 * Always mark a PTE as write-protected when needed, regardless of
>  	 * VM_WRITE, which the user might change.
>  	 */
> -	if (wp_copy) {
> +	if (wp_mode) {
>  		_dst_pte = pte_mkuffd_wp(_dst_pte);
>  		writable = false;
>  	}
> @@ -147,7 +147,7 @@ static int mcopy_atomic_pte(struct mm_struct *dst_mm,
>  			    unsigned long dst_addr,
>  			    unsigned long src_addr,
>  			    struct page **pagep,
> -			    bool wp_copy)
> +			    bool wp_mode)
>  {
>  	void *page_kaddr;
>  	int ret;
> @@ -208,7 +208,7 @@ static int mcopy_atomic_pte(struct mm_struct *dst_mm,
>  		goto out_release;
>  
>  	ret = mfill_atomic_install_pte(dst_mm, dst_pmd, dst_vma, dst_addr,
> -				       page, true, wp_copy);
> +				       page, true, wp_mode);
>  	if (ret)
>  		goto out_release;
>  out:
> @@ -258,7 +258,7 @@ static int mcontinue_atomic_pte(struct mm_struct *dst_mm,
>  				pmd_t *dst_pmd,
>  				struct vm_area_struct *dst_vma,
>  				unsigned long dst_addr,
> -				bool wp_copy)
> +				bool wp_mode)
>  {
>  	struct inode *inode = file_inode(dst_vma->vm_file);
>  	pgoff_t pgoff = linear_page_index(dst_vma, dst_addr);
> @@ -284,7 +284,7 @@ static int mcontinue_atomic_pte(struct mm_struct *dst_mm,
>  	}
>  
>  	ret = mfill_atomic_install_pte(dst_mm, dst_pmd, dst_vma, dst_addr,
> -				       page, false, wp_copy);
> +				       page, false, wp_mode);
>  	if (ret)
>  		goto out_release;
>  
> @@ -330,7 +330,7 @@ static __always_inline ssize_t __mcopy_atomic_hugetlb(struct mm_struct *dst_mm,
>  					      unsigned long src_start,
>  					      unsigned long len,
>  					      enum mcopy_atomic_mode mode,
> -					      bool wp_copy)
> +					      bool wp_mode)
>  {
>  	int vm_shared = dst_vma->vm_flags & VM_SHARED;
>  	ssize_t err;
> @@ -427,7 +427,7 @@ static __always_inline ssize_t __mcopy_atomic_hugetlb(struct mm_struct *dst_mm,
>  
>  		err = hugetlb_mcopy_atomic_pte(dst_mm, dst_pte, dst_vma,
>  					       dst_addr, src_addr, mode, &page,
> -					       wp_copy);
> +					       wp_mode);
>  
>  		hugetlb_vma_unlock_read(dst_vma);
>  		mutex_unlock(&hugetlb_fault_mutex_table[hash]);
> @@ -483,7 +483,7 @@ extern ssize_t __mcopy_atomic_hugetlb(struct mm_struct *dst_mm,
>  				      unsigned long src_start,
>  				      unsigned long len,
>  				      enum mcopy_atomic_mode mode,
> -				      bool wp_copy);
> +				      bool wp_mode);
>  #endif /* CONFIG_HUGETLB_PAGE */
>  
>  static __always_inline ssize_t mfill_atomic_pte(struct mm_struct *dst_mm,
> @@ -493,13 +493,13 @@ static __always_inline ssize_t mfill_atomic_pte(struct mm_struct *dst_mm,
>  						unsigned long src_addr,
>  						struct page **page,
>  						enum mcopy_atomic_mode mode,
> -						bool wp_copy)
> +						bool wp_mode)
>  {
>  	ssize_t err;
>  
>  	if (mode == MCOPY_ATOMIC_CONTINUE) {
>  		return mcontinue_atomic_pte(dst_mm, dst_pmd, dst_vma, dst_addr,
> -					    wp_copy);
> +					    wp_mode);
>  	}
>  
>  	/*
> @@ -516,7 +516,7 @@ static __always_inline ssize_t mfill_atomic_pte(struct mm_struct *dst_mm,
>  		if (mode == MCOPY_ATOMIC_NORMAL)
>  			err = mcopy_atomic_pte(dst_mm, dst_pmd, dst_vma,
>  					       dst_addr, src_addr, page,
> -					       wp_copy);
> +					       wp_mode);
>  		else
>  			err = mfill_zeropage_pte(dst_mm, dst_pmd,
>  						 dst_vma, dst_addr);
> @@ -524,12 +524,21 @@ static __always_inline ssize_t mfill_atomic_pte(struct mm_struct *dst_mm,
>  		err = shmem_mfill_atomic_pte(dst_mm, dst_pmd, dst_vma,
>  					     dst_addr, src_addr,
>  					     mode != MCOPY_ATOMIC_NORMAL,
> -					     wp_copy, page);
> +					     wp_mode, page);
>  	}
>  
>  	return err;
>  }
>  
> +static inline bool wp_mode_enabled(enum mcopy_atomic_mode mcopy_mode, __u64 mode)
> +{
> +	switch (mode) {
> +	case MCOPY_ATOMIC_NORMAL: return mode & UFFDIO_COPY_MODE_WP;
> +	case MCOPY_ATOMIC_CONTINUE: return mode & UFFDIO_CONTINUE_MODE_WP;
> +	default: return false;
> +	}
> +}
> +
>  static __always_inline ssize_t __mcopy_atomic(struct mm_struct *dst_mm,
>  					      unsigned long dst_start,
>  					      unsigned long src_start,
> @@ -544,7 +553,7 @@ static __always_inline ssize_t __mcopy_atomic(struct mm_struct *dst_mm,
>  	unsigned long src_addr, dst_addr;
>  	long copied;
>  	struct page *page;
> -	bool wp_copy;
> +	bool wp_mode;
>  
>  	/*
>  	 * Sanitize the command parameters:
> @@ -594,8 +603,8 @@ static __always_inline ssize_t __mcopy_atomic(struct mm_struct *dst_mm,
>  	 * validate 'mode' now that we know the dst_vma: don't allow
>  	 * a wrprotect copy if the userfaultfd didn't register as WP.
>  	 */
> -	wp_copy = mode & UFFDIO_COPY_MODE_WP;
> -	if (wp_copy && !(dst_vma->vm_flags & VM_UFFD_WP))
> +	wp_mode = wp_mode_enabled(mcopy_mode, mode);
> +	if (wp_mode && !(dst_vma->vm_flags & VM_UFFD_WP))
>  		goto out_unlock;
>  
>  	/*
> @@ -604,7 +613,7 @@ static __always_inline ssize_t __mcopy_atomic(struct mm_struct *dst_mm,
>  	if (is_vm_hugetlb_page(dst_vma))
>  		return  __mcopy_atomic_hugetlb(dst_mm, dst_vma, dst_start,
>  					       src_start, len, mcopy_mode,
> -					       wp_copy);
> +					       wp_mode);
>  
>  	if (!vma_is_anonymous(dst_vma) && !vma_is_shmem(dst_vma))
>  		goto out_unlock;
> @@ -656,7 +665,7 @@ static __always_inline ssize_t __mcopy_atomic(struct mm_struct *dst_mm,
>  		BUG_ON(pmd_trans_huge(*dst_pmd));
>  
>  		err = mfill_atomic_pte(dst_mm, dst_pmd, dst_vma, dst_addr,
> -				       src_addr, &page, mcopy_mode, wp_copy);
> +				       src_addr, &page, mcopy_mode, wp_mode);
>  		cond_resched();
>  
>  		if (unlikely(err == -ENOENT)) {
> @@ -718,10 +727,10 @@ ssize_t mfill_zeropage(struct mm_struct *dst_mm, unsigned long start,
>  }
>  
>  ssize_t mcopy_continue(struct mm_struct *dst_mm, unsigned long start,
> -		       unsigned long len, atomic_t *mmap_changing)
> +		       unsigned long len, atomic_t *mmap_changing, __u64 mode)
>  {
>  	return __mcopy_atomic(dst_mm, start, 0, len, MCOPY_ATOMIC_CONTINUE,
> -			      mmap_changing, 0);
> +			      mmap_changing, mode);
>  }
>  
>  void uffd_wp_range(struct mm_struct *dst_mm, struct vm_area_struct *dst_vma,
> diff --git a/tools/testing/selftests/vm/userfaultfd.c b/tools/testing/selftests/vm/userfaultfd.c
> index 7f22844ed704..41c1f9abc481 100644
> --- a/tools/testing/selftests/vm/userfaultfd.c
> +++ b/tools/testing/selftests/vm/userfaultfd.c
> @@ -585,6 +585,8 @@ static void continue_range(int ufd, __u64 start, __u64 len)
>  	req.range.start = start;
>  	req.range.len = len;
>  	req.mode = 0;
> +	if (test_uffdio_wp)
> +		req.mode |= UFFDIO_CONTINUE_MODE_WP;
>  
>  	if (ioctl(ufd, UFFDIO_CONTINUE, &req))
>  		err("UFFDIO_CONTINUE failed for address 0x%" PRIx64,
> @@ -1332,6 +1334,8 @@ static int userfaultfd_minor_test(void)
>  	uffdio_register.range.start = (unsigned long)area_dst_alias;
>  	uffdio_register.range.len = nr_pages * page_size;
>  	uffdio_register.mode = UFFDIO_REGISTER_MODE_MINOR;
> +	if (test_uffdio_wp)
> +		uffdio_register.mode |= UFFDIO_REGISTER_MODE_WP;
>  	if (ioctl(uffd, UFFDIO_REGISTER, &uffdio_register))
>  		err("register failure");
>  
> -- 
> 2.39.1.581.gbfd45094c4-goog
> 
>
Axel Rasmussen Feb. 14, 2023, 10:37 p.m. UTC | #2
On Tue, Feb 14, 2023 at 2:17 PM Matthew Wilcox <willy@infradead.org> wrote:
>
> On Tue, Feb 14, 2023 at 01:50:46PM -0800, Axel Rasmussen wrote:
> > UFFDIO_COPY already has UFFDIO_COPY_MODE_WP, so when installing a new
> > PTE to resolve a missing fault, one can install a write-protected one.
> > This is useful when using UFFDIO_REGISTER_MODE_{MISSING,WP} in
> > combination.
> >
> > So, add an analogous UFFDIO_CONTINUE_MODE_WP, which does the same thing
> > but for *minor* faults.
> >
> > Rename "wp_copy" arguments to "wp_mode", since the mode now applies
> > more widely than just to the copy operation.
>
> Speaking as someone not familiar with this code, i do not care for
> this rename.  before, it says "bool wp_copy".  i'm not sure what "wp"
> is, but "copy" i understand.  so if it's true, i need to copy and false,
> i don't.  now it's "bool wp_mode".  so if it's true, it's a mode and if
> it's false it's not a mode?  what?

Agree it's confusing. The "copy" refers to the "UFFDIO_COPY" ioctl,
which specifically allocates a new page, and *copies* its contents
from userspace. Continue, on the other hand, doesn't do any page
allocation or copying, it just installs a PTE for a page already in
the page cache. This is why I thought leaving the "copy" around would
be confusing.

Honestly there are other existing cases where this confusion remains;
e.g. the function mcopy_atomic doesn't necessarily do any page
allocation or copying, because the same code path also handles the
continue case.

Agree that at the very least wp_enabled or something would make more
sense, given it's a bool.

>
> stepping back for a second, i think the real problem is that the argument
> list has got out of control.  8 arguments to hugetlb_mcopy_atomic_pte(),
> the same # to shmem_mfill_atomic_pte(), a subtly different 8 to
> mfill_atomic_pte().  someone needs to create a struct with a sensible
> name to contain all of this.  i bet it'll shrink the code too; marshalling
> all these arguments is unlikely to be cheap.

Agreed, it would likely be a nice cleanup. Peter, any objections? I
wouldn't mind writing a commit to do this sort of refactor, and rebase
my change on top of that.

>
> but more importantly, it lets you change the arguments without touching
> any of the code which is just passing them through.  look at struct
> vm_fault for inspiration (you probably don't want to reuse it, but maybe
> you do?)
>
> > Update the selftest to do some very basic exercising of the new flag.
> >
> > Signed-off-by: Axel Rasmussen <axelrasmussen@google.com>
> > ---
> >  fs/userfaultfd.c                         |  6 ++-
> >  include/linux/userfaultfd_k.h            |  3 +-
> >  include/uapi/linux/userfaultfd.h         |  7 ++++
> >  mm/hugetlb.c                             |  6 +--
> >  mm/shmem.c                               |  4 +-
> >  mm/userfaultfd.c                         | 49 ++++++++++++++----------
> >  tools/testing/selftests/vm/userfaultfd.c |  4 ++
> >  7 files changed, 51 insertions(+), 28 deletions(-)
> >
> > diff --git a/fs/userfaultfd.c b/fs/userfaultfd.c
> > index cc694846617a..545dc033eec8 100644
> > --- a/fs/userfaultfd.c
> > +++ b/fs/userfaultfd.c
> > @@ -1910,13 +1910,15 @@ static int userfaultfd_continue(struct userfaultfd_ctx *ctx, unsigned long arg)
> >           uffdio_continue.range.start) {
> >               goto out;
> >       }
> > -     if (uffdio_continue.mode & ~UFFDIO_CONTINUE_MODE_DONTWAKE)
> > +     if (uffdio_continue.mode & ~(UFFDIO_CONTINUE_MODE_DONTWAKE |
> > +                                  UFFDIO_CONTINUE_MODE_WP))
> >               goto out;
> >
> >       if (mmget_not_zero(ctx->mm)) {
> >               ret = mcopy_continue(ctx->mm, uffdio_continue.range.start,
> >                                    uffdio_continue.range.len,
> > -                                  &ctx->mmap_changing);
> > +                                  &ctx->mmap_changing,
> > +                                  uffdio_continue.mode);
> >               mmput(ctx->mm);
> >       } else {
> >               return -ESRCH;
> > diff --git a/include/linux/userfaultfd_k.h b/include/linux/userfaultfd_k.h
> > index 9df0b9a762cc..a53aa56e78ad 100644
> > --- a/include/linux/userfaultfd_k.h
> > +++ b/include/linux/userfaultfd_k.h
> > @@ -69,7 +69,8 @@ extern ssize_t mfill_zeropage(struct mm_struct *dst_mm,
> >                             unsigned long len,
> >                             atomic_t *mmap_changing);
> >  extern ssize_t mcopy_continue(struct mm_struct *dst_mm, unsigned long dst_start,
> > -                           unsigned long len, atomic_t *mmap_changing);
> > +                           unsigned long len, atomic_t *mmap_changing,
> > +                           __u64 mode);
> >  extern int mwriteprotect_range(struct mm_struct *dst_mm,
> >                              unsigned long start, unsigned long len,
> >                              bool enable_wp, atomic_t *mmap_changing);
> > diff --git a/include/uapi/linux/userfaultfd.h b/include/uapi/linux/userfaultfd.h
> > index 005e5e306266..14059a0861bf 100644
> > --- a/include/uapi/linux/userfaultfd.h
> > +++ b/include/uapi/linux/userfaultfd.h
> > @@ -297,6 +297,13 @@ struct uffdio_writeprotect {
> >  struct uffdio_continue {
> >       struct uffdio_range range;
> >  #define UFFDIO_CONTINUE_MODE_DONTWAKE                ((__u64)1<<0)
> > +     /*
> > +      * UFFDIO_CONTINUE_MODE_WP will map the page write protected on
> > +      * the fly.  UFFDIO_CONTINUE_MODE_WP is available only if the
> > +      * write protected ioctl is implemented for the range
> > +      * according to the uffdio_register.ioctls.
> > +      */
> > +#define UFFDIO_CONTINUE_MODE_WP                      ((__u64)1<<1)
> >       __u64 mode;
> >
> >       /*
> > diff --git a/mm/hugetlb.c b/mm/hugetlb.c
> > index bdbfeb6fb393..a1bd0b0c6c43 100644
> > --- a/mm/hugetlb.c
> > +++ b/mm/hugetlb.c
> > @@ -6169,7 +6169,7 @@ int hugetlb_mcopy_atomic_pte(struct mm_struct *dst_mm,
> >                           unsigned long src_addr,
> >                           enum mcopy_atomic_mode mode,
> >                           struct page **pagep,
> > -                         bool wp_copy)
> > +                         bool wp_mode)
> >  {
> >       bool is_continue = (mode == MCOPY_ATOMIC_CONTINUE);
> >       struct hstate *h = hstate_vma(dst_vma);
> > @@ -6306,7 +6306,7 @@ int hugetlb_mcopy_atomic_pte(struct mm_struct *dst_mm,
> >        * For either: (1) CONTINUE on a non-shared VMA, or (2) UFFDIO_COPY
> >        * with wp flag set, don't set pte write bit.
> >        */
> > -     if (wp_copy || (is_continue && !vm_shared))
> > +     if (wp_mode || (is_continue && !vm_shared))
> >               writable = 0;
> >       else
> >               writable = dst_vma->vm_flags & VM_WRITE;
> > @@ -6321,7 +6321,7 @@ int hugetlb_mcopy_atomic_pte(struct mm_struct *dst_mm,
> >       _dst_pte = huge_pte_mkdirty(_dst_pte);
> >       _dst_pte = pte_mkyoung(_dst_pte);
> >
> > -     if (wp_copy)
> > +     if (wp_mode)
> >               _dst_pte = huge_pte_mkuffd_wp(_dst_pte);
> >
> >       set_huge_pte_at(dst_mm, dst_addr, dst_pte, _dst_pte);
> > diff --git a/mm/shmem.c b/mm/shmem.c
> > index 0005ab2c29af..725e5d5d85ab 100644
> > --- a/mm/shmem.c
> > +++ b/mm/shmem.c
> > @@ -2402,7 +2402,7 @@ int shmem_mfill_atomic_pte(struct mm_struct *dst_mm,
> >                          struct vm_area_struct *dst_vma,
> >                          unsigned long dst_addr,
> >                          unsigned long src_addr,
> > -                        bool zeropage, bool wp_copy,
> > +                        bool zeropage, bool wp_mode,
> >                          struct page **pagep)
> >  {
> >       struct inode *inode = file_inode(dst_vma->vm_file);
> > @@ -2493,7 +2493,7 @@ int shmem_mfill_atomic_pte(struct mm_struct *dst_mm,
> >               goto out_release;
> >
> >       ret = mfill_atomic_install_pte(dst_mm, dst_pmd, dst_vma, dst_addr,
> > -                                    &folio->page, true, wp_copy);
> > +                                    &folio->page, true, wp_mode);
> >       if (ret)
> >               goto out_delete_from_cache;
> >
> > diff --git a/mm/userfaultfd.c b/mm/userfaultfd.c
> > index 0499907b6f1a..2ad4dcf72968 100644
> > --- a/mm/userfaultfd.c
> > +++ b/mm/userfaultfd.c
> > @@ -58,7 +58,7 @@ struct vm_area_struct *find_dst_vma(struct mm_struct *dst_mm,
> >  int mfill_atomic_install_pte(struct mm_struct *dst_mm, pmd_t *dst_pmd,
> >                            struct vm_area_struct *dst_vma,
> >                            unsigned long dst_addr, struct page *page,
> > -                          bool newly_allocated, bool wp_copy)
> > +                          bool newly_allocated, bool wp_mode)
> >  {
> >       int ret;
> >       pte_t _dst_pte, *dst_pte;
> > @@ -79,7 +79,7 @@ int mfill_atomic_install_pte(struct mm_struct *dst_mm, pmd_t *dst_pmd,
> >        * Always mark a PTE as write-protected when needed, regardless of
> >        * VM_WRITE, which the user might change.
> >        */
> > -     if (wp_copy) {
> > +     if (wp_mode) {
> >               _dst_pte = pte_mkuffd_wp(_dst_pte);
> >               writable = false;
> >       }
> > @@ -147,7 +147,7 @@ static int mcopy_atomic_pte(struct mm_struct *dst_mm,
> >                           unsigned long dst_addr,
> >                           unsigned long src_addr,
> >                           struct page **pagep,
> > -                         bool wp_copy)
> > +                         bool wp_mode)
> >  {
> >       void *page_kaddr;
> >       int ret;
> > @@ -208,7 +208,7 @@ static int mcopy_atomic_pte(struct mm_struct *dst_mm,
> >               goto out_release;
> >
> >       ret = mfill_atomic_install_pte(dst_mm, dst_pmd, dst_vma, dst_addr,
> > -                                    page, true, wp_copy);
> > +                                    page, true, wp_mode);
> >       if (ret)
> >               goto out_release;
> >  out:
> > @@ -258,7 +258,7 @@ static int mcontinue_atomic_pte(struct mm_struct *dst_mm,
> >                               pmd_t *dst_pmd,
> >                               struct vm_area_struct *dst_vma,
> >                               unsigned long dst_addr,
> > -                             bool wp_copy)
> > +                             bool wp_mode)
> >  {
> >       struct inode *inode = file_inode(dst_vma->vm_file);
> >       pgoff_t pgoff = linear_page_index(dst_vma, dst_addr);
> > @@ -284,7 +284,7 @@ static int mcontinue_atomic_pte(struct mm_struct *dst_mm,
> >       }
> >
> >       ret = mfill_atomic_install_pte(dst_mm, dst_pmd, dst_vma, dst_addr,
> > -                                    page, false, wp_copy);
> > +                                    page, false, wp_mode);
> >       if (ret)
> >               goto out_release;
> >
> > @@ -330,7 +330,7 @@ static __always_inline ssize_t __mcopy_atomic_hugetlb(struct mm_struct *dst_mm,
> >                                             unsigned long src_start,
> >                                             unsigned long len,
> >                                             enum mcopy_atomic_mode mode,
> > -                                           bool wp_copy)
> > +                                           bool wp_mode)
> >  {
> >       int vm_shared = dst_vma->vm_flags & VM_SHARED;
> >       ssize_t err;
> > @@ -427,7 +427,7 @@ static __always_inline ssize_t __mcopy_atomic_hugetlb(struct mm_struct *dst_mm,
> >
> >               err = hugetlb_mcopy_atomic_pte(dst_mm, dst_pte, dst_vma,
> >                                              dst_addr, src_addr, mode, &page,
> > -                                            wp_copy);
> > +                                            wp_mode);
> >
> >               hugetlb_vma_unlock_read(dst_vma);
> >               mutex_unlock(&hugetlb_fault_mutex_table[hash]);
> > @@ -483,7 +483,7 @@ extern ssize_t __mcopy_atomic_hugetlb(struct mm_struct *dst_mm,
> >                                     unsigned long src_start,
> >                                     unsigned long len,
> >                                     enum mcopy_atomic_mode mode,
> > -                                   bool wp_copy);
> > +                                   bool wp_mode);
> >  #endif /* CONFIG_HUGETLB_PAGE */
> >
> >  static __always_inline ssize_t mfill_atomic_pte(struct mm_struct *dst_mm,
> > @@ -493,13 +493,13 @@ static __always_inline ssize_t mfill_atomic_pte(struct mm_struct *dst_mm,
> >                                               unsigned long src_addr,
> >                                               struct page **page,
> >                                               enum mcopy_atomic_mode mode,
> > -                                             bool wp_copy)
> > +                                             bool wp_mode)
> >  {
> >       ssize_t err;
> >
> >       if (mode == MCOPY_ATOMIC_CONTINUE) {
> >               return mcontinue_atomic_pte(dst_mm, dst_pmd, dst_vma, dst_addr,
> > -                                         wp_copy);
> > +                                         wp_mode);
> >       }
> >
> >       /*
> > @@ -516,7 +516,7 @@ static __always_inline ssize_t mfill_atomic_pte(struct mm_struct *dst_mm,
> >               if (mode == MCOPY_ATOMIC_NORMAL)
> >                       err = mcopy_atomic_pte(dst_mm, dst_pmd, dst_vma,
> >                                              dst_addr, src_addr, page,
> > -                                            wp_copy);
> > +                                            wp_mode);
> >               else
> >                       err = mfill_zeropage_pte(dst_mm, dst_pmd,
> >                                                dst_vma, dst_addr);
> > @@ -524,12 +524,21 @@ static __always_inline ssize_t mfill_atomic_pte(struct mm_struct *dst_mm,
> >               err = shmem_mfill_atomic_pte(dst_mm, dst_pmd, dst_vma,
> >                                            dst_addr, src_addr,
> >                                            mode != MCOPY_ATOMIC_NORMAL,
> > -                                          wp_copy, page);
> > +                                          wp_mode, page);
> >       }
> >
> >       return err;
> >  }
> >
> > +static inline bool wp_mode_enabled(enum mcopy_atomic_mode mcopy_mode, __u64 mode)
> > +{
> > +     switch (mode) {
> > +     case MCOPY_ATOMIC_NORMAL: return mode & UFFDIO_COPY_MODE_WP;
> > +     case MCOPY_ATOMIC_CONTINUE: return mode & UFFDIO_CONTINUE_MODE_WP;
> > +     default: return false;
> > +     }
> > +}
> > +
> >  static __always_inline ssize_t __mcopy_atomic(struct mm_struct *dst_mm,
> >                                             unsigned long dst_start,
> >                                             unsigned long src_start,
> > @@ -544,7 +553,7 @@ static __always_inline ssize_t __mcopy_atomic(struct mm_struct *dst_mm,
> >       unsigned long src_addr, dst_addr;
> >       long copied;
> >       struct page *page;
> > -     bool wp_copy;
> > +     bool wp_mode;
> >
> >       /*
> >        * Sanitize the command parameters:
> > @@ -594,8 +603,8 @@ static __always_inline ssize_t __mcopy_atomic(struct mm_struct *dst_mm,
> >        * validate 'mode' now that we know the dst_vma: don't allow
> >        * a wrprotect copy if the userfaultfd didn't register as WP.
> >        */
> > -     wp_copy = mode & UFFDIO_COPY_MODE_WP;
> > -     if (wp_copy && !(dst_vma->vm_flags & VM_UFFD_WP))
> > +     wp_mode = wp_mode_enabled(mcopy_mode, mode);
> > +     if (wp_mode && !(dst_vma->vm_flags & VM_UFFD_WP))
> >               goto out_unlock;
> >
> >       /*
> > @@ -604,7 +613,7 @@ static __always_inline ssize_t __mcopy_atomic(struct mm_struct *dst_mm,
> >       if (is_vm_hugetlb_page(dst_vma))
> >               return  __mcopy_atomic_hugetlb(dst_mm, dst_vma, dst_start,
> >                                              src_start, len, mcopy_mode,
> > -                                            wp_copy);
> > +                                            wp_mode);
> >
> >       if (!vma_is_anonymous(dst_vma) && !vma_is_shmem(dst_vma))
> >               goto out_unlock;
> > @@ -656,7 +665,7 @@ static __always_inline ssize_t __mcopy_atomic(struct mm_struct *dst_mm,
> >               BUG_ON(pmd_trans_huge(*dst_pmd));
> >
> >               err = mfill_atomic_pte(dst_mm, dst_pmd, dst_vma, dst_addr,
> > -                                    src_addr, &page, mcopy_mode, wp_copy);
> > +                                    src_addr, &page, mcopy_mode, wp_mode);
> >               cond_resched();
> >
> >               if (unlikely(err == -ENOENT)) {
> > @@ -718,10 +727,10 @@ ssize_t mfill_zeropage(struct mm_struct *dst_mm, unsigned long start,
> >  }
> >
> >  ssize_t mcopy_continue(struct mm_struct *dst_mm, unsigned long start,
> > -                    unsigned long len, atomic_t *mmap_changing)
> > +                    unsigned long len, atomic_t *mmap_changing, __u64 mode)
> >  {
> >       return __mcopy_atomic(dst_mm, start, 0, len, MCOPY_ATOMIC_CONTINUE,
> > -                           mmap_changing, 0);
> > +                           mmap_changing, mode);
> >  }
> >
> >  void uffd_wp_range(struct mm_struct *dst_mm, struct vm_area_struct *dst_vma,
> > diff --git a/tools/testing/selftests/vm/userfaultfd.c b/tools/testing/selftests/vm/userfaultfd.c
> > index 7f22844ed704..41c1f9abc481 100644
> > --- a/tools/testing/selftests/vm/userfaultfd.c
> > +++ b/tools/testing/selftests/vm/userfaultfd.c
> > @@ -585,6 +585,8 @@ static void continue_range(int ufd, __u64 start, __u64 len)
> >       req.range.start = start;
> >       req.range.len = len;
> >       req.mode = 0;
> > +     if (test_uffdio_wp)
> > +             req.mode |= UFFDIO_CONTINUE_MODE_WP;
> >
> >       if (ioctl(ufd, UFFDIO_CONTINUE, &req))
> >               err("UFFDIO_CONTINUE failed for address 0x%" PRIx64,
> > @@ -1332,6 +1334,8 @@ static int userfaultfd_minor_test(void)
> >       uffdio_register.range.start = (unsigned long)area_dst_alias;
> >       uffdio_register.range.len = nr_pages * page_size;
> >       uffdio_register.mode = UFFDIO_REGISTER_MODE_MINOR;
> > +     if (test_uffdio_wp)
> > +             uffdio_register.mode |= UFFDIO_REGISTER_MODE_WP;
> >       if (ioctl(uffd, UFFDIO_REGISTER, &uffdio_register))
> >               err("register failure");
> >
> > --
> > 2.39.1.581.gbfd45094c4-goog
> >
> >
Peter Xu Feb. 15, 2023, 12:32 a.m. UTC | #3
On Tue, Feb 14, 2023 at 02:37:51PM -0800, Axel Rasmussen wrote:
> Agreed, it would likely be a nice cleanup. Peter, any objections? I
> wouldn't mind writing a commit to do this sort of refactor, and rebase
> my change on top of that.

No objection here.  Personally I actually prefer keeping the parameters
around if possible because it's straightforward and no thinking of any
possible indirect accesses all over the place. But maybe growing as long as
8 is still a moot point.  It's just that I don't really know whether it'll
look that good if we put everything into a struct*.

Things like src_start/dst_start/... don't really belong there: each layer
could loop over its own range of start/end/... so even if they are not
function parameters, we'll need a variable to hold them anyway.

But I do see a few low hanging fruits:

  - I don't see why we need to pass over mmap_changing over all of the
    __mcopy_atomic() callers.  One chance is we simply pass in the ctx* to
    replace "dst_mm + mmap_changing".

  - Merge mcopy_atomic_mode and mode, having last 2 bits for the existing
    three modes, then bit 3 for WP, good enough to set it for the new case.

  - Optionally, we can avoid passing over dst_mm/src_mm all around, when
    dst_vma/src_vma is there?

How about we start from simple?
Mike Rapoport Feb. 15, 2023, 6:37 a.m. UTC | #4
On Tue, Feb 14, 2023 at 07:32:29PM -0500, Peter Xu wrote:
> On Tue, Feb 14, 2023 at 02:37:51PM -0800, Axel Rasmussen wrote:
> > Agreed, it would likely be a nice cleanup. Peter, any objections? I
> > wouldn't mind writing a commit to do this sort of refactor, and rebase
> > my change on top of that.
> 
> No objection here.  Personally I actually prefer keeping the parameters
> around if possible because it's straightforward and no thinking of any
> possible indirect accesses all over the place. But maybe growing as long as
> 8 is still a moot point..  It's just that I don't really know whether it'll
> look that good if we put everything into a struct*.
> 
> Things like src_start/dst_start/.. do not look good to be there: each layer
> could loop over its own range of start/end/... so even if not in the
> function parameter we'll need a variable to hold them anyway.
> 
> But I do see a few low hanging fruits:
> 
>   - I don't see why we need to pass over mmap_changing over all of the
>     __mcopy_atomic() callers.  One chance is we simply pass in the ctx* to
>     replace "dst_mm + mmap_changing".

Now ctx* is completely private to fs/userfaultfd.c and I think it'd be
better to keep it this way.
 
>   - Merge mcopy_atomic_mode and mode, having last 2 bits for the existing
>     three modes, then bit 3 for WP, good enough to set it for the new case.

Agree, having flags instead of an enum and bools sounds better to me.
 
>   - Optionally, we can avoid passing over dst_mm/src_mm all around, when
>     dst_vma/src_vma is there?

+1
 
> How about we start from simple?
> 
> -- 
> Peter Xu
>
diff mbox series

Patch

diff --git a/fs/userfaultfd.c b/fs/userfaultfd.c
index cc694846617a..545dc033eec8 100644
--- a/fs/userfaultfd.c
+++ b/fs/userfaultfd.c
@@ -1910,13 +1910,15 @@  static int userfaultfd_continue(struct userfaultfd_ctx *ctx, unsigned long arg)
 	    uffdio_continue.range.start) {
 		goto out;
 	}
-	if (uffdio_continue.mode & ~UFFDIO_CONTINUE_MODE_DONTWAKE)
+	if (uffdio_continue.mode & ~(UFFDIO_CONTINUE_MODE_DONTWAKE |
+				     UFFDIO_CONTINUE_MODE_WP))
 		goto out;
 
 	if (mmget_not_zero(ctx->mm)) {
 		ret = mcopy_continue(ctx->mm, uffdio_continue.range.start,
 				     uffdio_continue.range.len,
-				     &ctx->mmap_changing);
+				     &ctx->mmap_changing,
+				     uffdio_continue.mode);
 		mmput(ctx->mm);
 	} else {
 		return -ESRCH;
diff --git a/include/linux/userfaultfd_k.h b/include/linux/userfaultfd_k.h
index 9df0b9a762cc..a53aa56e78ad 100644
--- a/include/linux/userfaultfd_k.h
+++ b/include/linux/userfaultfd_k.h
@@ -69,7 +69,8 @@  extern ssize_t mfill_zeropage(struct mm_struct *dst_mm,
 			      unsigned long len,
 			      atomic_t *mmap_changing);
 extern ssize_t mcopy_continue(struct mm_struct *dst_mm, unsigned long dst_start,
-			      unsigned long len, atomic_t *mmap_changing);
+			      unsigned long len, atomic_t *mmap_changing,
+			      __u64 mode);
 extern int mwriteprotect_range(struct mm_struct *dst_mm,
 			       unsigned long start, unsigned long len,
 			       bool enable_wp, atomic_t *mmap_changing);
diff --git a/include/uapi/linux/userfaultfd.h b/include/uapi/linux/userfaultfd.h
index 005e5e306266..14059a0861bf 100644
--- a/include/uapi/linux/userfaultfd.h
+++ b/include/uapi/linux/userfaultfd.h
@@ -297,6 +297,13 @@  struct uffdio_writeprotect {
 struct uffdio_continue {
 	struct uffdio_range range;
 #define UFFDIO_CONTINUE_MODE_DONTWAKE		((__u64)1<<0)
+	/*
+	 * UFFDIO_CONTINUE_MODE_WP will map the page write protected on
+	 * the fly.  UFFDIO_CONTINUE_MODE_WP is available only if the
+	 * write protected ioctl is implemented for the range
+	 * according to the uffdio_register.ioctls.
+	 */
+#define UFFDIO_CONTINUE_MODE_WP			((__u64)1<<1)
 	__u64 mode;
 
 	/*
diff --git a/mm/hugetlb.c b/mm/hugetlb.c
index bdbfeb6fb393..a1bd0b0c6c43 100644
--- a/mm/hugetlb.c
+++ b/mm/hugetlb.c
@@ -6169,7 +6169,7 @@  int hugetlb_mcopy_atomic_pte(struct mm_struct *dst_mm,
 			    unsigned long src_addr,
 			    enum mcopy_atomic_mode mode,
 			    struct page **pagep,
-			    bool wp_copy)
+			    bool wp_mode)
 {
 	bool is_continue = (mode == MCOPY_ATOMIC_CONTINUE);
 	struct hstate *h = hstate_vma(dst_vma);
@@ -6306,7 +6306,7 @@  int hugetlb_mcopy_atomic_pte(struct mm_struct *dst_mm,
 	 * For either: (1) CONTINUE on a non-shared VMA, or (2) UFFDIO_COPY
 	 * with wp flag set, don't set pte write bit.
 	 */
-	if (wp_copy || (is_continue && !vm_shared))
+	if (wp_mode || (is_continue && !vm_shared))
 		writable = 0;
 	else
 		writable = dst_vma->vm_flags & VM_WRITE;
@@ -6321,7 +6321,7 @@  int hugetlb_mcopy_atomic_pte(struct mm_struct *dst_mm,
 	_dst_pte = huge_pte_mkdirty(_dst_pte);
 	_dst_pte = pte_mkyoung(_dst_pte);
 
-	if (wp_copy)
+	if (wp_mode)
 		_dst_pte = huge_pte_mkuffd_wp(_dst_pte);
 
 	set_huge_pte_at(dst_mm, dst_addr, dst_pte, _dst_pte);
diff --git a/mm/shmem.c b/mm/shmem.c
index 0005ab2c29af..725e5d5d85ab 100644
--- a/mm/shmem.c
+++ b/mm/shmem.c
@@ -2402,7 +2402,7 @@  int shmem_mfill_atomic_pte(struct mm_struct *dst_mm,
 			   struct vm_area_struct *dst_vma,
 			   unsigned long dst_addr,
 			   unsigned long src_addr,
-			   bool zeropage, bool wp_copy,
+			   bool zeropage, bool wp_mode,
 			   struct page **pagep)
 {
 	struct inode *inode = file_inode(dst_vma->vm_file);
@@ -2493,7 +2493,7 @@  int shmem_mfill_atomic_pte(struct mm_struct *dst_mm,
 		goto out_release;
 
 	ret = mfill_atomic_install_pte(dst_mm, dst_pmd, dst_vma, dst_addr,
-				       &folio->page, true, wp_copy);
+				       &folio->page, true, wp_mode);
 	if (ret)
 		goto out_delete_from_cache;
 
diff --git a/mm/userfaultfd.c b/mm/userfaultfd.c
index 0499907b6f1a..2ad4dcf72968 100644
--- a/mm/userfaultfd.c
+++ b/mm/userfaultfd.c
@@ -58,7 +58,7 @@  struct vm_area_struct *find_dst_vma(struct mm_struct *dst_mm,
 int mfill_atomic_install_pte(struct mm_struct *dst_mm, pmd_t *dst_pmd,
 			     struct vm_area_struct *dst_vma,
 			     unsigned long dst_addr, struct page *page,
-			     bool newly_allocated, bool wp_copy)
+			     bool newly_allocated, bool wp_mode)
 {
 	int ret;
 	pte_t _dst_pte, *dst_pte;
@@ -79,7 +79,7 @@  int mfill_atomic_install_pte(struct mm_struct *dst_mm, pmd_t *dst_pmd,
 	 * Always mark a PTE as write-protected when needed, regardless of
 	 * VM_WRITE, which the user might change.
 	 */
-	if (wp_copy) {
+	if (wp_mode) {
 		_dst_pte = pte_mkuffd_wp(_dst_pte);
 		writable = false;
 	}
@@ -147,7 +147,7 @@  static int mcopy_atomic_pte(struct mm_struct *dst_mm,
 			    unsigned long dst_addr,
 			    unsigned long src_addr,
 			    struct page **pagep,
-			    bool wp_copy)
+			    bool wp_mode)
 {
 	void *page_kaddr;
 	int ret;
@@ -208,7 +208,7 @@  static int mcopy_atomic_pte(struct mm_struct *dst_mm,
 		goto out_release;
 
 	ret = mfill_atomic_install_pte(dst_mm, dst_pmd, dst_vma, dst_addr,
-				       page, true, wp_copy);
+				       page, true, wp_mode);
 	if (ret)
 		goto out_release;
 out:
@@ -258,7 +258,7 @@  static int mcontinue_atomic_pte(struct mm_struct *dst_mm,
 				pmd_t *dst_pmd,
 				struct vm_area_struct *dst_vma,
 				unsigned long dst_addr,
-				bool wp_copy)
+				bool wp_mode)
 {
 	struct inode *inode = file_inode(dst_vma->vm_file);
 	pgoff_t pgoff = linear_page_index(dst_vma, dst_addr);
@@ -284,7 +284,7 @@  static int mcontinue_atomic_pte(struct mm_struct *dst_mm,
 	}
 
 	ret = mfill_atomic_install_pte(dst_mm, dst_pmd, dst_vma, dst_addr,
-				       page, false, wp_copy);
+				       page, false, wp_mode);
 	if (ret)
 		goto out_release;
 
@@ -330,7 +330,7 @@  static __always_inline ssize_t __mcopy_atomic_hugetlb(struct mm_struct *dst_mm,
 					      unsigned long src_start,
 					      unsigned long len,
 					      enum mcopy_atomic_mode mode,
-					      bool wp_copy)
+					      bool wp_mode)
 {
 	int vm_shared = dst_vma->vm_flags & VM_SHARED;
 	ssize_t err;
@@ -427,7 +427,7 @@  static __always_inline ssize_t __mcopy_atomic_hugetlb(struct mm_struct *dst_mm,
 
 		err = hugetlb_mcopy_atomic_pte(dst_mm, dst_pte, dst_vma,
 					       dst_addr, src_addr, mode, &page,
-					       wp_copy);
+					       wp_mode);
 
 		hugetlb_vma_unlock_read(dst_vma);
 		mutex_unlock(&hugetlb_fault_mutex_table[hash]);
@@ -483,7 +483,7 @@  extern ssize_t __mcopy_atomic_hugetlb(struct mm_struct *dst_mm,
 				      unsigned long src_start,
 				      unsigned long len,
 				      enum mcopy_atomic_mode mode,
-				      bool wp_copy);
+				      bool wp_mode);
 #endif /* CONFIG_HUGETLB_PAGE */
 
 static __always_inline ssize_t mfill_atomic_pte(struct mm_struct *dst_mm,
@@ -493,13 +493,13 @@  static __always_inline ssize_t mfill_atomic_pte(struct mm_struct *dst_mm,
 						unsigned long src_addr,
 						struct page **page,
 						enum mcopy_atomic_mode mode,
-						bool wp_copy)
+						bool wp_mode)
 {
 	ssize_t err;
 
 	if (mode == MCOPY_ATOMIC_CONTINUE) {
 		return mcontinue_atomic_pte(dst_mm, dst_pmd, dst_vma, dst_addr,
-					    wp_copy);
+					    wp_mode);
 	}
 
 	/*
@@ -516,7 +516,7 @@  static __always_inline ssize_t mfill_atomic_pte(struct mm_struct *dst_mm,
 		if (mode == MCOPY_ATOMIC_NORMAL)
 			err = mcopy_atomic_pte(dst_mm, dst_pmd, dst_vma,
 					       dst_addr, src_addr, page,
-					       wp_copy);
+					       wp_mode);
 		else
 			err = mfill_zeropage_pte(dst_mm, dst_pmd,
 						 dst_vma, dst_addr);
@@ -524,12 +524,21 @@  static __always_inline ssize_t mfill_atomic_pte(struct mm_struct *dst_mm,
 		err = shmem_mfill_atomic_pte(dst_mm, dst_pmd, dst_vma,
 					     dst_addr, src_addr,
 					     mode != MCOPY_ATOMIC_NORMAL,
-					     wp_copy, page);
+					     wp_mode, page);
 	}
 
 	return err;
 }
 
+static inline bool wp_mode_enabled(enum mcopy_atomic_mode mcopy_mode, __u64 mode)
+{
+	switch (mcopy_mode) {	/* dispatch on the operation; 'mode' is the ioctl flag word */
+	case MCOPY_ATOMIC_NORMAL: return mode & UFFDIO_COPY_MODE_WP;
+	case MCOPY_ATOMIC_CONTINUE: return mode & UFFDIO_CONTINUE_MODE_WP;
+	default: return false;	/* other operations accept no WP flag here */
+	}
+}
+
 static __always_inline ssize_t __mcopy_atomic(struct mm_struct *dst_mm,
 					      unsigned long dst_start,
 					      unsigned long src_start,
@@ -544,7 +553,7 @@  static __always_inline ssize_t __mcopy_atomic(struct mm_struct *dst_mm,
 	unsigned long src_addr, dst_addr;
 	long copied;
 	struct page *page;
-	bool wp_copy;
+	bool wp_mode;
 
 	/*
 	 * Sanitize the command parameters:
@@ -594,8 +603,8 @@  static __always_inline ssize_t __mcopy_atomic(struct mm_struct *dst_mm,
 	 * validate 'mode' now that we know the dst_vma: don't allow
 	 * a wrprotect copy if the userfaultfd didn't register as WP.
 	 */
-	wp_copy = mode & UFFDIO_COPY_MODE_WP;
-	if (wp_copy && !(dst_vma->vm_flags & VM_UFFD_WP))
+	wp_mode = wp_mode_enabled(mcopy_mode, mode);
+	if (wp_mode && !(dst_vma->vm_flags & VM_UFFD_WP))
 		goto out_unlock;
 
 	/*
@@ -604,7 +613,7 @@  static __always_inline ssize_t __mcopy_atomic(struct mm_struct *dst_mm,
 	if (is_vm_hugetlb_page(dst_vma))
 		return  __mcopy_atomic_hugetlb(dst_mm, dst_vma, dst_start,
 					       src_start, len, mcopy_mode,
-					       wp_copy);
+					       wp_mode);
 
 	if (!vma_is_anonymous(dst_vma) && !vma_is_shmem(dst_vma))
 		goto out_unlock;
@@ -656,7 +665,7 @@  static __always_inline ssize_t __mcopy_atomic(struct mm_struct *dst_mm,
 		BUG_ON(pmd_trans_huge(*dst_pmd));
 
 		err = mfill_atomic_pte(dst_mm, dst_pmd, dst_vma, dst_addr,
-				       src_addr, &page, mcopy_mode, wp_copy);
+				       src_addr, &page, mcopy_mode, wp_mode);
 		cond_resched();
 
 		if (unlikely(err == -ENOENT)) {
@@ -718,10 +727,10 @@  ssize_t mfill_zeropage(struct mm_struct *dst_mm, unsigned long start,
 }
 
 ssize_t mcopy_continue(struct mm_struct *dst_mm, unsigned long start,
-		       unsigned long len, atomic_t *mmap_changing)
+		       unsigned long len, atomic_t *mmap_changing, __u64 mode)
 {
 	return __mcopy_atomic(dst_mm, start, 0, len, MCOPY_ATOMIC_CONTINUE,
-			      mmap_changing, 0);
+			      mmap_changing, mode);
 }
 
 void uffd_wp_range(struct mm_struct *dst_mm, struct vm_area_struct *dst_vma,
diff --git a/tools/testing/selftests/vm/userfaultfd.c b/tools/testing/selftests/vm/userfaultfd.c
index 7f22844ed704..41c1f9abc481 100644
--- a/tools/testing/selftests/vm/userfaultfd.c
+++ b/tools/testing/selftests/vm/userfaultfd.c
@@ -585,6 +585,8 @@  static void continue_range(int ufd, __u64 start, __u64 len)
 	req.range.start = start;
 	req.range.len = len;
 	req.mode = 0;
+	if (test_uffdio_wp)
+		req.mode |= UFFDIO_CONTINUE_MODE_WP;
 
 	if (ioctl(ufd, UFFDIO_CONTINUE, &req))
 		err("UFFDIO_CONTINUE failed for address 0x%" PRIx64,
@@ -1332,6 +1334,8 @@  static int userfaultfd_minor_test(void)
 	uffdio_register.range.start = (unsigned long)area_dst_alias;
 	uffdio_register.range.len = nr_pages * page_size;
 	uffdio_register.mode = UFFDIO_REGISTER_MODE_MINOR;
+	if (test_uffdio_wp)
+		uffdio_register.mode |= UFFDIO_REGISTER_MODE_WP;
 	if (ioctl(uffd, UFFDIO_REGISTER, &uffdio_register))
 		err("register failure");