diff mbox series

[RFC,08/13] fs/userfaultfd: complete reads asynchronously

Message ID 20201129004548.1619714-9-namit@vmware.com (mailing list archive)
State New, archived
Headers show
Series fs/userfaultfd: support iouring and polling | expand

Commit Message

Nadav Amit Nov. 29, 2020, 12:45 a.m. UTC
From: Nadav Amit <namit@vmware.com>

Complete reads asynchronously to allow io_uring to complete reads
asynchronously.

Reads, which report page-faults and events, can only be performed
asynchronously if the read is performed into a kernel buffer, and
therefore guarantee that no page-fault would occur during the completion
of the read. Otherwise, we would have needed to handle nested
page-faults or do expensive pinning/unpinning of the pages into which
the read is performed.

Userfaultfd holds in its context the kiocb and iov_iter that would be
used for the next asynchronous read (can be extended later into a list
to hold more than a single enqueued read).  If such a buffer is
available and a fault occurs, the fault is reported to the user and the
fault is added to the fault workqueue instead of the pending-fault
workqueue.

There is a need to prevent a race between synchronous and asynchronous
reads, so reads will first use buffers that were previous enqueued and
only later pending-faults and events. For this matter a new
"notification" lock is introduced that is held while enqueuing new
events and pending faults and during event reads. It may be possible to
use the fd_wqh.lock instead, but having a separate lock for the matter
seems cleaner.

Cc: Jens Axboe <axboe@kernel.dk>
Cc: Andrea Arcangeli <aarcange@redhat.com>
Cc: Peter Xu <peterx@redhat.com>
Cc: Alexander Viro <viro@zeniv.linux.org.uk>
Cc: io-uring@vger.kernel.org
Cc: linux-fsdevel@vger.kernel.org
Cc: linux-kernel@vger.kernel.org
Cc: linux-mm@kvack.org
Signed-off-by: Nadav Amit <namit@vmware.com>
---
 fs/userfaultfd.c | 265 +++++++++++++++++++++++++++++++++++++++++------
 1 file changed, 235 insertions(+), 30 deletions(-)
diff mbox series

Patch

diff --git a/fs/userfaultfd.c b/fs/userfaultfd.c
index 6333b4632742..db1a963f6ae2 100644
--- a/fs/userfaultfd.c
+++ b/fs/userfaultfd.c
@@ -44,9 +44,10 @@  enum userfaultfd_state {
  *
  * Locking order:
  *	fd_wqh.lock
- *		fault_pending_wqh.lock
- *			fault_wqh.lock
- *		event_wqh.lock
+ *		notification_lock
+ *			fault_pending_wqh.lock
+ *				fault_wqh.lock
+ *			event_wqh.lock
  *
  * To avoid deadlocks, IRQs must be disabled when taking any of the above locks,
  * since fd_wqh.lock is taken by aio_poll() while it's holding a lock that's
@@ -79,6 +80,16 @@  struct userfaultfd_ctx {
 	struct mm_struct *mm;
 	/* controlling process files as they might be different than current */
 	struct files_struct *files;
+	/*
+	 * lock for sync and async userfaultfd reads, which must be held when
+	 * enqueueing into fault_pending_wqh or event_wqh, upon userfaultfd
+	 * reads and on accesses of iocb_callback and to.
+	 */
+	spinlock_t notification_lock;
+	/* kiocb struct that is used for the next asynchronous read */
+	struct kiocb *iocb_callback;
+	/* the iterator that is used for the next asynchronous read */
+	struct iov_iter to;
 };
 
 struct userfaultfd_fork_ctx {
@@ -356,6 +367,53 @@  static inline long userfaultfd_get_blocking_state(unsigned int flags)
 	return TASK_UNINTERRUPTIBLE;
 }
 
+static bool userfaultfd_get_async_complete_locked(struct userfaultfd_ctx *ctx,
+				struct kiocb **iocb, struct iov_iter *iter)
+{
+	if (!ctx->released)
+		lockdep_assert_held(&ctx->notification_lock);
+
+	if (ctx->iocb_callback == NULL)
+		return false;
+
+	*iocb = ctx->iocb_callback;
+	*iter = ctx->to;
+
+	ctx->iocb_callback = NULL;
+	ctx->to.kvec = NULL;
+	return true;
+}
+
+static bool userfaultfd_get_async_complete(struct userfaultfd_ctx *ctx,
+				struct kiocb **iocb, struct iov_iter *iter)
+{
+	bool r;
+
+	spin_lock_irq(&ctx->notification_lock);
+	r = userfaultfd_get_async_complete_locked(ctx, iocb, iter);
+	spin_unlock_irq(&ctx->notification_lock);
+	return r;
+}
+
+static void userfaultfd_copy_async_msg(struct kiocb *iocb,
+				       struct iov_iter *iter,
+				       struct uffd_msg *msg,
+				       int ret)
+{
+
+	const struct kvec *kvec = iter->kvec;
+
+	if (ret == 0)
+		ret = copy_to_iter(msg, sizeof(*msg), iter);
+
+	/* Should never fail as we guarantee that we use a kernel buffer */
+	WARN_ON_ONCE(ret != sizeof(*msg));
+	iocb->ki_complete(iocb, ret, 0);
+
+	kfree(kvec);
+	iter->kvec = NULL;
+}
+
 /*
  * The locking rules involved in returning VM_FAULT_RETRY depending on
  * FAULT_FLAG_ALLOW_RETRY, FAULT_FLAG_RETRY_NOWAIT and
@@ -380,6 +438,10 @@  vm_fault_t handle_userfault(struct vm_fault *vmf, unsigned long reason)
 	bool must_wait;
 	long blocking_state;
 	bool poll;
+	bool async = false;
+	struct kiocb *iocb;
+	struct iov_iter iter;
+	wait_queue_head_t *wqh;
 
 	/*
 	 * We don't do userfault handling for the final child pid update.
@@ -489,12 +551,29 @@  vm_fault_t handle_userfault(struct vm_fault *vmf, unsigned long reason)
 
 	blocking_state = userfaultfd_get_blocking_state(vmf->flags);
 
-	spin_lock_irq(&ctx->fault_pending_wqh.lock);
+	/*
+	 * Abuse fd_wqh.lock to protect against concurrent reads to avoid a
+	 * scenario in which a fault/event is queued, and read returns
+	 * -EIOCBQUEUED.
+	 */
+	spin_lock_irq(&ctx->notification_lock);
+	async = userfaultfd_get_async_complete_locked(ctx, &iocb, &iter);
+	wqh = &ctx->fault_pending_wqh;
+
+	if (async)
+		wqh = &ctx->fault_wqh;
+
 	/*
 	 * After the __add_wait_queue the uwq is visible to userland
 	 * through poll/read().
 	 */
-	__add_wait_queue(&ctx->fault_pending_wqh, &uwq.wq);
+	spin_lock(&wqh->lock);
+
+	__add_wait_queue(wqh, &uwq.wq);
+
+	/* Ensure it is queued before userspace is informed. */
+	smp_wmb();
+
 	/*
 	 * The smp_mb() after __set_current_state prevents the reads
 	 * following the spin_unlock to happen before the list_add in
@@ -504,7 +583,15 @@  vm_fault_t handle_userfault(struct vm_fault *vmf, unsigned long reason)
 	if (!poll)
 		set_current_state(blocking_state);
 
-	spin_unlock_irq(&ctx->fault_pending_wqh.lock);
+	spin_unlock(&wqh->lock);
+	spin_unlock_irq(&ctx->notification_lock);
+
+	/*
+	 * Do the copy after the lock is relinquished to avoid circular lock
+	 * dependencies.
+	 */
+	if (async)
+		userfaultfd_copy_async_msg(iocb, &iter, &uwq.msg, 0);
 
 	if (!is_vm_hugetlb_page(vmf->vma))
 		must_wait = userfaultfd_must_wait(ctx, vmf->address, vmf->flags,
@@ -516,7 +603,9 @@  vm_fault_t handle_userfault(struct vm_fault *vmf, unsigned long reason)
 	mmap_read_unlock(mm);
 
 	if (likely(must_wait && !READ_ONCE(ctx->released))) {
-		wake_up_poll(&ctx->fd_wqh, EPOLLIN);
+		if (!async)
+			wake_up_poll(&ctx->fd_wqh, EPOLLIN);
+
 		if (poll) {
 			while (!READ_ONCE(uwq.waken) && !READ_ONCE(ctx->released) &&
 			       !signal_pending(current)) {
@@ -544,13 +633,21 @@  vm_fault_t handle_userfault(struct vm_fault *vmf, unsigned long reason)
 	 * kernel stack can be released after the list_del_init.
 	 */
 	if (!list_empty_careful(&uwq.wq.entry)) {
-		spin_lock_irq(&ctx->fault_pending_wqh.lock);
+		local_irq_disable();
+		if (!async)
+			spin_lock(&ctx->fault_pending_wqh.lock);
+		spin_lock(&ctx->fault_wqh.lock);
+
 		/*
 		 * No need of list_del_init(), the uwq on the stack
 		 * will be freed shortly anyway.
 		 */
 		list_del(&uwq.wq.entry);
-		spin_unlock_irq(&ctx->fault_pending_wqh.lock);
+
+		spin_unlock(&ctx->fault_wqh.lock);
+		if (!async)
+			spin_unlock(&ctx->fault_pending_wqh.lock);
+		local_irq_enable();
 	}
 
 	/*
@@ -563,10 +660,17 @@  vm_fault_t handle_userfault(struct vm_fault *vmf, unsigned long reason)
 	return ret;
 }
 
+
+static int resolve_userfault_fork(struct userfaultfd_ctx *ctx,
+				  struct userfaultfd_ctx *new,
+				  struct uffd_msg *msg);
+
 static void userfaultfd_event_wait_completion(struct userfaultfd_ctx *ctx,
 					      struct userfaultfd_wait_queue *ewq)
 {
 	struct userfaultfd_ctx *release_new_ctx;
+	struct iov_iter iter;
+	struct kiocb *iocb;
 
 	if (WARN_ON_ONCE(current->flags & PF_EXITING))
 		goto out;
@@ -575,12 +679,42 @@  static void userfaultfd_event_wait_completion(struct userfaultfd_ctx *ctx,
 	init_waitqueue_entry(&ewq->wq, current);
 	release_new_ctx = NULL;
 
-	spin_lock_irq(&ctx->event_wqh.lock);
+retry:
+	spin_lock_irq(&ctx->notification_lock);
+
 	/*
-	 * After the __add_wait_queue the uwq is visible to userland
-	 * through poll/read().
+	 * Submit asynchronously when needed, and release the notification lock
+	 * as soon as the event is either queued on the work queue or an entry
+	 * is taken.
+	 */
+	if (userfaultfd_get_async_complete_locked(ctx, &iocb, &iter)) {
+		int r = 0;
+
+		spin_unlock_irq(&ctx->notification_lock);
+		if (ewq->msg.event == UFFD_EVENT_FORK) {
+			struct userfaultfd_ctx *new =
+				(struct userfaultfd_ctx *)(unsigned long)
+					ewq->msg.arg.reserved.reserved1;
+
+			r = resolve_userfault_fork(ctx, new, &ewq->msg);
+		}
+		userfaultfd_copy_async_msg(iocb, &iter, &ewq->msg, r);
+
+		if (r != 0)
+			goto retry;
+
+		goto out;
+	}
+
+	spin_lock(&ctx->event_wqh.lock);
+	/*
+	 * After the __add_wait_queue or the call to ki_complete the uwq is
+	 * visible to userland through poll/read().
 	 */
 	__add_wait_queue(&ctx->event_wqh, &ewq->wq);
+
+	spin_unlock(&ctx->notification_lock);
+
 	for (;;) {
 		set_current_state(TASK_KILLABLE);
 		if (ewq->msg.event == 0)
@@ -683,6 +817,7 @@  int dup_userfaultfd(struct vm_area_struct *vma, struct list_head *fcs)
 		ctx->features = octx->features;
 		ctx->released = false;
 		ctx->mmap_changing = false;
+		ctx->iocb_callback = NULL;
 		ctx->mm = vma->vm_mm;
 		mmgrab(ctx->mm);
 
@@ -854,6 +989,15 @@  void userfaultfd_unmap_complete(struct mm_struct *mm, struct list_head *uf)
 	}
 }
 
+static void userfaultfd_cancel_async_reads(struct userfaultfd_ctx *ctx)
+{
+	struct iov_iter iter;
+	struct kiocb *iocb;
+
+	while (userfaultfd_get_async_complete(ctx, &iocb, &iter))
+		userfaultfd_copy_async_msg(iocb, &iter, NULL, -EBADF);
+}
+
 static int userfaultfd_release(struct inode *inode, struct file *file)
 {
 	struct userfaultfd_ctx *ctx = file->private_data;
@@ -912,6 +1056,8 @@  static int userfaultfd_release(struct inode *inode, struct file *file)
 	__wake_up(&ctx->fault_wqh, TASK_NORMAL, 1, &range);
 	spin_unlock_irq(&ctx->fault_pending_wqh.lock);
 
+	userfaultfd_cancel_async_reads(ctx);
+
 	/* Flush pending events that may still wait on event_wqh */
 	wake_up_all(&ctx->event_wqh);
 
@@ -1032,8 +1178,39 @@  static int resolve_userfault_fork(struct userfaultfd_ctx *ctx,
 	return 0;
 }
 
-static ssize_t userfaultfd_ctx_read(struct userfaultfd_ctx *ctx, int no_wait,
-				    struct uffd_msg *msg)
+static ssize_t userfaultfd_enqueue(struct kiocb *iocb,
+				   struct userfaultfd_ctx *ctx,
+				   struct iov_iter *to)
+{
+	lockdep_assert_irqs_disabled();
+
+	if (!to)
+		return -EAGAIN;
+
+	if (is_sync_kiocb(iocb) ||
+	    (!iov_iter_is_bvec(to) && !iov_iter_is_kvec(to)))
+		return -EAGAIN;
+
+	/* Check again if there are pending events */
+	if (waitqueue_active(&ctx->fault_pending_wqh) ||
+	    waitqueue_active(&ctx->event_wqh))
+		return -EAGAIN;
+
+	/*
+	 * Check that there is no other callback already registered, as
+	 * we only support one at the moment.
+	 */
+	if (ctx->iocb_callback)
+		return -EAGAIN;
+
+	ctx->iocb_callback = iocb;
+	ctx->to = *to;
+	return -EIOCBQUEUED;
+}
+
+static ssize_t userfaultfd_ctx_read(struct kiocb *iocb,
+				    struct userfaultfd_ctx *ctx, int no_wait,
+				    struct uffd_msg *msg, struct iov_iter *to)
 {
 	ssize_t ret;
 	DECLARE_WAITQUEUE(wait, current);
@@ -1051,6 +1228,7 @@  static ssize_t userfaultfd_ctx_read(struct userfaultfd_ctx *ctx, int no_wait,
 	/* always take the fd_wqh lock before the fault_pending_wqh lock */
 	spin_lock_irq(&ctx->fd_wqh.lock);
 	__add_wait_queue(&ctx->fd_wqh, &wait);
+	spin_lock(&ctx->notification_lock);
 	for (;;) {
 		set_current_state(TASK_INTERRUPTIBLE);
 		spin_lock(&ctx->fault_pending_wqh.lock);
@@ -1122,21 +1300,23 @@  static ssize_t userfaultfd_ctx_read(struct userfaultfd_ctx *ctx, int no_wait,
 			ret = 0;
 		}
 		spin_unlock(&ctx->event_wqh.lock);
-		if (!ret)
-			break;
 
-		if (signal_pending(current)) {
+		if (ret == -EAGAIN && signal_pending(current))
 			ret = -ERESTARTSYS;
+
+		if (ret == -EAGAIN && no_wait)
+			ret = userfaultfd_enqueue(iocb, ctx, to);
+
+		if (no_wait || ret != -EAGAIN)
 			break;
-		}
-		if (no_wait) {
-			ret = -EAGAIN;
-			break;
-		}
+
+		spin_unlock(&ctx->notification_lock);
 		spin_unlock_irq(&ctx->fd_wqh.lock);
 		schedule();
 		spin_lock_irq(&ctx->fd_wqh.lock);
+		spin_lock(&ctx->notification_lock);
 	}
+	spin_unlock(&ctx->notification_lock);
 	__remove_wait_queue(&ctx->fd_wqh, &wait);
 	__set_current_state(TASK_RUNNING);
 	spin_unlock_irq(&ctx->fd_wqh.lock);
@@ -1202,20 +1382,38 @@  static ssize_t userfaultfd_read_iter(struct kiocb *iocb, struct iov_iter *to)
 	ssize_t _ret, ret = 0;
 	struct uffd_msg msg;
 	int no_wait = file->f_flags & O_NONBLOCK;
+	struct iov_iter _to, *async_to = NULL;
 
-	if (ctx->state == UFFD_STATE_WAIT_API)
+	if (ctx->state == UFFD_STATE_WAIT_API || READ_ONCE(ctx->released))
 		return -EINVAL;
 
+	/* Duplicate before taking the lock */
+	if (no_wait && !is_sync_kiocb(iocb) &&
+	    (iov_iter_is_bvec(to) || iov_iter_is_kvec(to))) {
+		async_to = &_to;
+		dup_iter(async_to, to, GFP_KERNEL);
+	}
+
 	for (;;) {
-		if (iov_iter_count(to) < sizeof(msg))
-			return ret ? ret : -EINVAL;
-		_ret = userfaultfd_ctx_read(ctx, no_wait, &msg);
-		if (_ret < 0)
-			return ret ? ret : _ret;
+		if (iov_iter_count(to) < sizeof(msg)) {
+			if (!ret)
+				ret = -EINVAL;
+			break;
+		}
+		_ret = userfaultfd_ctx_read(iocb, ctx, no_wait, &msg, async_to);
+		if (_ret < 0) {
+			if (ret == 0)
+				ret = _ret;
+			break;
+		}
+		async_to = NULL;
 
 		_ret = copy_to_iter(&msg, sizeof(msg), to);
-		if (_ret != sizeof(msg))
-			return ret ? ret : -EINVAL;
+		if (_ret != sizeof(msg)) {
+			if (ret == 0)
+				ret = -EINVAL;
+			break;
+		}
 
 		ret += sizeof(msg);
 
@@ -1225,6 +1423,11 @@  static ssize_t userfaultfd_read_iter(struct kiocb *iocb, struct iov_iter *to)
 		 */
 		no_wait = O_NONBLOCK;
 	}
+
+	if (ret != -EIOCBQUEUED && async_to != NULL)
+		kfree(async_to->kvec);
+
+	return ret;
 }
 
 static void __wake_userfault(struct userfaultfd_ctx *ctx,
@@ -1997,6 +2200,7 @@  static void init_once_userfaultfd_ctx(void *mem)
 	init_waitqueue_head(&ctx->event_wqh);
 	init_waitqueue_head(&ctx->fd_wqh);
 	seqcount_spinlock_init(&ctx->refile_seq, &ctx->fault_pending_wqh.lock);
+	spin_lock_init(&ctx->notification_lock);
 }
 
 SYSCALL_DEFINE1(userfaultfd, int, flags)
@@ -2027,6 +2231,7 @@  SYSCALL_DEFINE1(userfaultfd, int, flags)
 	ctx->released = false;
 	ctx->mmap_changing = false;
 	ctx->mm = current->mm;
+	ctx->iocb_callback = NULL;
 	/* prevent the mm struct to be freed */
 	mmgrab(ctx->mm);