diff mbox series

[for-next,6/7] io_uring: iopoll protect complete_post

Message ID cc6d854065c57c838ca8e8806f707a226b70fd2d.1669203009.git.asml.silence@gmail.com (mailing list archive)
State New
Headers show
Series iopoll cqe posting fixes | expand

Commit Message

Pavel Begunkov Nov. 23, 2022, 11:33 a.m. UTC
io_req_complete_post() may be used by iopoll enabled rings, grab locks
in this case. That requires to pass issue_flags to propagate the locking
state.

Signed-off-by: Pavel Begunkov <asml.silence@gmail.com>
---
 io_uring/io_uring.c  | 21 +++++++++++++++------
 io_uring/io_uring.h  | 10 ++++++++--
 io_uring/kbuf.c      |  4 ++--
 io_uring/poll.c      |  2 +-
 io_uring/uring_cmd.c |  2 +-
 5 files changed, 27 insertions(+), 12 deletions(-)
diff mbox series

Patch

diff --git a/io_uring/io_uring.c b/io_uring/io_uring.c
index bd9b286eb031..54a9966bbb47 100644
--- a/io_uring/io_uring.c
+++ b/io_uring/io_uring.c
@@ -813,7 +813,7 @@  bool io_post_aux_cqe(struct io_ring_ctx *ctx,
 	return filled;
 }
 
-void io_req_complete_post(struct io_kiocb *req)
+static void __io_req_complete_post(struct io_kiocb *req)
 {
 	struct io_ring_ctx *ctx = req->ctx;
 
@@ -849,9 +849,18 @@  void io_req_complete_post(struct io_kiocb *req)
 	io_cq_unlock_post(ctx);
 }
 
-inline void __io_req_complete(struct io_kiocb *req, unsigned issue_flags)
+void io_req_complete_post(struct io_kiocb *req, unsigned issue_flags)
 {
-	io_req_complete_post(req);
+	if (!(issue_flags & IO_URING_F_UNLOCKED) ||
+	    !(req->ctx->flags & IORING_SETUP_IOPOLL)) {
+		__io_req_complete_post(req);
+	} else {
+		struct io_ring_ctx *ctx = req->ctx;
+
+		mutex_lock(&ctx->uring_lock);
+		__io_req_complete_post(req);
+		mutex_unlock(&ctx->uring_lock);
+	}
 }
 
 void io_req_complete_failed(struct io_kiocb *req, s32 res)
@@ -865,7 +874,7 @@  void io_req_complete_failed(struct io_kiocb *req, s32 res)
 	io_req_set_res(req, res, io_put_kbuf(req, IO_URING_F_UNLOCKED));
 	if (def->fail)
 		def->fail(req);
-	io_req_complete_post(req);
+	io_req_complete_post(req, 0);
 }
 
 /*
@@ -1449,7 +1458,7 @@  void io_req_task_complete(struct io_kiocb *req, bool *locked)
 	if (*locked)
 		io_req_complete_defer(req);
 	else
-		io_req_complete_post(req);
+		io_req_complete_post_tw(req, locked);
 }
 
 /*
@@ -1717,7 +1726,7 @@  static int io_issue_sqe(struct io_kiocb *req, unsigned int issue_flags)
 		if (issue_flags & IO_URING_F_COMPLETE_DEFER)
 			io_req_complete_defer(req);
 		else
-			io_req_complete_post(req);
+			io_req_complete_post(req, issue_flags);
 	} else if (ret != IOU_ISSUE_SKIP_COMPLETE)
 		return ret;
 
diff --git a/io_uring/io_uring.h b/io_uring/io_uring.h
index 002b6cc842a5..e908966f081a 100644
--- a/io_uring/io_uring.h
+++ b/io_uring/io_uring.h
@@ -30,12 +30,18 @@  int io_run_task_work_sig(struct io_ring_ctx *ctx);
 int __io_run_local_work(struct io_ring_ctx *ctx, bool *locked);
 int io_run_local_work(struct io_ring_ctx *ctx);
 void io_req_complete_failed(struct io_kiocb *req, s32 res);
-void __io_req_complete(struct io_kiocb *req, unsigned issue_flags);
-void io_req_complete_post(struct io_kiocb *req);
+void io_req_complete_post(struct io_kiocb *req, unsigned issue_flags);
 bool io_post_aux_cqe(struct io_ring_ctx *ctx, u64 user_data, s32 res, u32 cflags);
 bool io_fill_cqe_aux(struct io_ring_ctx *ctx, u64 user_data, s32 res, u32 cflags);
 void __io_commit_cqring_flush(struct io_ring_ctx *ctx);
 
+static inline void io_req_complete_post_tw(struct io_kiocb *req, bool *locked)
+{
+	unsigned flags = *locked ? 0 : IO_URING_F_UNLOCKED;
+
+	io_req_complete_post(req, flags);
+}
+
 struct page **io_pin_pages(unsigned long ubuf, unsigned long len, int *npages);
 
 struct file *io_file_get_normal(struct io_kiocb *req, int fd);
diff --git a/io_uring/kbuf.c b/io_uring/kbuf.c
index 25cd724ade18..c33b53b7f435 100644
--- a/io_uring/kbuf.c
+++ b/io_uring/kbuf.c
@@ -311,7 +311,7 @@  int io_remove_buffers(struct io_kiocb *req, unsigned int issue_flags)
 
 	/* complete before unlock, IOPOLL may need the lock */
 	io_req_set_res(req, ret, 0);
-	__io_req_complete(req, issue_flags);
+	io_req_complete_post(req, 0);
 	io_ring_submit_unlock(ctx, issue_flags);
 	return IOU_ISSUE_SKIP_COMPLETE;
 }
@@ -460,7 +460,7 @@  int io_provide_buffers(struct io_kiocb *req, unsigned int issue_flags)
 		req_set_fail(req);
 	/* complete before unlock, IOPOLL may need the lock */
 	io_req_set_res(req, ret, 0);
-	__io_req_complete(req, issue_flags);
+	io_req_complete_post(req, 0);
 	io_ring_submit_unlock(ctx, issue_flags);
 	return IOU_ISSUE_SKIP_COMPLETE;
 }
diff --git a/io_uring/poll.c b/io_uring/poll.c
index 2830b7daf952..583fc0d745a6 100644
--- a/io_uring/poll.c
+++ b/io_uring/poll.c
@@ -298,7 +298,7 @@  static void io_apoll_task_func(struct io_kiocb *req, bool *locked)
 	io_poll_tw_hash_eject(req, locked);
 
 	if (ret == IOU_POLL_REMOVE_POLL_USE_RES)
-		io_req_complete_post(req);
+		io_req_complete_post_tw(req, locked);
 	else if (ret == IOU_POLL_DONE)
 		io_req_task_submit(req, locked);
 	else
diff --git a/io_uring/uring_cmd.c b/io_uring/uring_cmd.c
index e50de0b6b9f8..446a189b78b0 100644
--- a/io_uring/uring_cmd.c
+++ b/io_uring/uring_cmd.c
@@ -56,7 +56,7 @@  void io_uring_cmd_done(struct io_uring_cmd *ioucmd, ssize_t ret, ssize_t res2)
 		/* order with io_iopoll_req_issued() checking ->iopoll_complete */
 		smp_store_release(&req->iopoll_completed, 1);
 	else
-		__io_req_complete(req, 0);
+		io_req_complete_post(req, 0);
 }
 EXPORT_SYMBOL_GPL(io_uring_cmd_done);