diff mbox series

[v3,for-next,3/3] io_uring: support multishot in recvmsg

Message ID 20220714110258.1336200-4-dylany@fb.com (mailing list archive)
State New
Headers show
Series io_uring: multishot recvmsg | expand

Commit Message

Dylan Yudaken July 14, 2022, 11:02 a.m. UTC
Similar to multishot recv, this will require provided buffers to be
used. However, recvmsg is much more complex than recv as it has multiple
outputs: specifically flags, name, and control messages.

Support this by introducing a new struct io_uring_recvmsg_out with 4
fields. namelen, controllen and flags match the similar out fields in
msghdr from standard recvmsg(2); payloadlen is the length of the payload
following the header.
This struct is placed at the start of the returned buffer. Based on what
the user specifies in struct msghdr, the next bytes of the buffer will be
name (the next msg_namelen bytes), and then control (the next
msg_controllen bytes). The payload will come at the end. The return value
in the CQE is the total used size of the provided buffer.

Signed-off-by: Dylan Yudaken <dylany@fb.com>
---
 include/uapi/linux/io_uring.h |   7 ++
 io_uring/net.c                | 182 ++++++++++++++++++++++++++++++----
 io_uring/net.h                |   6 ++
 3 files changed, 176 insertions(+), 19 deletions(-)

Comments

Jens Axboe July 14, 2022, 1:36 p.m. UTC | #1
On 7/14/22 5:02 AM, Dylan Yudaken wrote:
> Similar to multishot recv, this will require provided buffers to be
> used. However, recvmsg is much more complex than recv as it has multiple
> outputs: specifically flags, name, and control messages.
> 
> Support this by introducing a new struct io_uring_recvmsg_out with 4
> fields. namelen, controllen and flags match the similar out fields in
> msghdr from standard recvmsg(2); payloadlen is the length of the payload
> following the header.
> This struct is placed at the start of the returned buffer. Based on what
> the user specifies in struct msghdr, the next bytes of the buffer will be
> name (the next msg_namelen bytes), and then control (the next
> msg_controllen bytes). The payload will come at the end. The return value
> in the CQE is the total used size of the provided buffer.

Just a few minor nits (some repeat ones too), otherwise looks fine to
me. I can either fold these in while applying, or you can spin a v4. Let
me know!

> +static int io_recvmsg_multishot_overflow(struct io_async_msghdr *iomsg)
> +{
> +	unsigned long hdr;
> +
> +	if (check_add_overflow(sizeof(struct io_uring_recvmsg_out),
> +			       (unsigned long)iomsg->namelen, &hdr))
> +		return -EOVERFLOW;
> +	if (check_add_overflow(hdr, iomsg->controllen, &hdr))
> +		return -EOVERFLOW;
> +	if (hdr > INT_MAX)
> +		return -EOVERFLOW;
> +
> +	return 0;
> +}

Nobody checks the specific return value of this helper, so we should
either actually do that, or just make this one return true/false
instead. The latter makes the most sense to me.

> +static int io_recvmsg_prep_multishot(struct io_async_msghdr *kmsg, struct io_sr_msg *sr,
> +				     void __user **buf, size_t *len)
> +{

The line breaks here are odd; lines should break at 80 columns unless
there's a good reason to exceed that.

Function reads better now though with the cast.

> +static int io_recvmsg_multishot(
> +	struct socket *sock,
> +	struct io_sr_msg *io,
> +	struct io_async_msghdr *kmsg,
> +	unsigned int flags,
> +	bool *finished)
> +{

This is still formatted badly.

> +		if (req->flags & REQ_F_APOLL_MULTISHOT) {
> +			ret = io_recvmsg_prep_multishot(kmsg, sr,
> +							&buf, &len);

			ret = io_recvmsg_prep_multishot(kmsg, sr, &buf, &len);

this still would fit nicely on an unbroken line.
diff mbox series

Patch

diff --git a/include/uapi/linux/io_uring.h b/include/uapi/linux/io_uring.h
index 499679134961..4c9b11e2e991 100644
--- a/include/uapi/linux/io_uring.h
+++ b/include/uapi/linux/io_uring.h
@@ -613,4 +613,11 @@  struct io_uring_file_index_range {
 	__u64	resv;
 };
 
+struct io_uring_recvmsg_out {
+	__u32 namelen;
+	__u32 controllen;
+	__u32 payloadlen;
+	__u32 flags;
+};
+
 #endif
diff --git a/io_uring/net.c b/io_uring/net.c
index 5bc3440a8290..7a4707cbf863 100644
--- a/io_uring/net.c
+++ b/io_uring/net.c
@@ -325,6 +325,21 @@  int io_send(struct io_kiocb *req, unsigned int issue_flags)
 	return IOU_OK;
 }
 
+static int io_recvmsg_multishot_overflow(struct io_async_msghdr *iomsg)
+{
+	unsigned long hdr;
+
+	if (check_add_overflow(sizeof(struct io_uring_recvmsg_out),
+			       (unsigned long)iomsg->namelen, &hdr))
+		return -EOVERFLOW;
+	if (check_add_overflow(hdr, iomsg->controllen, &hdr))
+		return -EOVERFLOW;
+	if (hdr > INT_MAX)
+		return -EOVERFLOW;
+
+	return 0;
+}
+
 static int __io_recvmsg_copy_hdr(struct io_kiocb *req,
 				 struct io_async_msghdr *iomsg)
 {
@@ -352,6 +367,13 @@  static int __io_recvmsg_copy_hdr(struct io_kiocb *req,
 			sr->len = iomsg->fast_iov[0].iov_len;
 			iomsg->free_iov = NULL;
 		}
+
+		if (req->flags & REQ_F_APOLL_MULTISHOT) {
+			iomsg->namelen = msg.msg_namelen;
+			iomsg->controllen = msg.msg_controllen;
+			if (io_recvmsg_multishot_overflow(iomsg))
+				return -EOVERFLOW;
+		}
 	} else {
 		iomsg->free_iov = iomsg->fast_iov;
 		ret = __import_iovec(READ, msg.msg_iov, msg.msg_iovlen, UIO_FASTIOV,
@@ -399,6 +421,13 @@  static int __io_compat_recvmsg_copy_hdr(struct io_kiocb *req,
 			sr->len = clen;
 			iomsg->free_iov = NULL;
 		}
+
+		if (req->flags & REQ_F_APOLL_MULTISHOT) {
+			iomsg->namelen = msg.msg_namelen;
+			iomsg->controllen = msg.msg_controllen;
+			if (io_recvmsg_multishot_overflow(iomsg))
+				return -EOVERFLOW;
+		}
 	} else {
 		iomsg->free_iov = iomsg->fast_iov;
 		ret = __import_iovec(READ, (struct iovec __user *)uiov, msg.msg_iovlen,
@@ -455,8 +484,6 @@  int io_recvmsg_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe)
 	if (sr->msg_flags & MSG_ERRQUEUE)
 		req->flags |= REQ_F_CLEAR_POLLIN;
 	if (sr->flags & IORING_RECV_MULTISHOT) {
-		if (req->opcode == IORING_OP_RECVMSG)
-			return -EINVAL;
 		if (!(req->flags & REQ_F_BUFFER_SELECT))
 			return -EINVAL;
 		if (sr->msg_flags & MSG_WAITALL)
@@ -483,12 +510,13 @@  static inline void io_recv_prep_retry(struct io_kiocb *req)
 }
 
 /*
- * Finishes io_recv
+ * Finishes io_recv and io_recvmsg.
  *
  * Returns true if it is actually finished, or false if it should run
  * again (for multishot).
  */
-static inline bool io_recv_finish(struct io_kiocb *req, int *ret, unsigned int cflags)
+static inline bool io_recv_finish(struct io_kiocb *req, int *ret,
+				  unsigned int cflags, bool mshot_finished)
 {
 	if (!(req->flags & REQ_F_APOLL_MULTISHOT)) {
 		io_req_set_res(req, *ret, cflags);
@@ -496,7 +524,7 @@  static inline bool io_recv_finish(struct io_kiocb *req, int *ret, unsigned int c
 		return true;
 	}
 
-	if (*ret > 0) {
+	if (!mshot_finished) {
 		if (io_post_aux_cqe(req->ctx, req->cqe.user_data, *ret,
 				    cflags | IORING_CQE_F_MORE, false)) {
 			io_recv_prep_retry(req);
@@ -518,6 +546,91 @@  static inline bool io_recv_finish(struct io_kiocb *req, int *ret, unsigned int c
 	return true;
 }
 
+static int io_recvmsg_prep_multishot(struct io_async_msghdr *kmsg, struct io_sr_msg *sr,
+				     void __user **buf, size_t *len)
+{
+	unsigned long ubuf = (unsigned long)*buf;
+	unsigned long hdr;
+
+	hdr = sizeof(struct io_uring_recvmsg_out) + kmsg->namelen + kmsg->controllen;
+	if (*len < hdr)
+		return -EFAULT;
+
+	if (kmsg->controllen) {
+		unsigned long control = ubuf + hdr - kmsg->controllen;
+
+		kmsg->msg.msg_control_user = (void *)control;
+		kmsg->msg.msg_controllen = kmsg->controllen;
+	}
+
+	sr->buf = *buf; /* stash for later copy */
+	*buf = (void *)(ubuf + hdr);
+	kmsg->payloadlen = *len = *len - hdr;
+	return 0;
+}
+
+struct io_recvmsg_multishot_hdr {
+	struct io_uring_recvmsg_out msg;
+	struct sockaddr_storage addr;
+};
+
+static int io_recvmsg_multishot(
+	struct socket *sock,
+	struct io_sr_msg *io,
+	struct io_async_msghdr *kmsg,
+	unsigned int flags,
+	bool *finished)
+{
+	int err;
+	int copy_len;
+	struct io_recvmsg_multishot_hdr hdr;
+
+	if (kmsg->namelen)
+		kmsg->msg.msg_name = &hdr.addr;
+	kmsg->msg.msg_flags = flags & (MSG_CMSG_CLOEXEC|MSG_CMSG_COMPAT);
+	kmsg->msg.msg_namelen = 0;
+
+	if (sock->file->f_flags & O_NONBLOCK)
+		flags |= MSG_DONTWAIT;
+
+	err = sock_recvmsg(sock, &kmsg->msg, flags);
+	*finished = err <= 0;
+	if (err < 0)
+		return err;
+
+	hdr.msg = (struct io_uring_recvmsg_out) {
+		.controllen = kmsg->controllen - kmsg->msg.msg_controllen,
+		.flags = kmsg->msg.msg_flags & ~MSG_CMSG_COMPAT
+	};
+
+	hdr.msg.payloadlen = err;
+	if (err > kmsg->payloadlen)
+		err = kmsg->payloadlen;
+
+	copy_len = sizeof(struct io_uring_recvmsg_out);
+	if (kmsg->msg.msg_namelen > kmsg->namelen)
+		copy_len += kmsg->namelen;
+	else
+		copy_len += kmsg->msg.msg_namelen;
+
+	/*
+	 *      "fromlen shall refer to the value before truncation.."
+	 *                      1003.1g
+	 */
+	hdr.msg.namelen = kmsg->msg.msg_namelen;
+
+	/* ensure that there is no gap between hdr and sockaddr_storage */
+	BUILD_BUG_ON(offsetof(struct io_recvmsg_multishot_hdr, addr) !=
+		     sizeof(struct io_uring_recvmsg_out));
+	if (copy_to_user(io->buf, &hdr, copy_len)) {
+		*finished = true;
+		return -EFAULT;
+	}
+
+	return sizeof(struct io_uring_recvmsg_out) + kmsg->namelen +
+			kmsg->controllen + err;
+}
+
 int io_recvmsg(struct io_kiocb *req, unsigned int issue_flags)
 {
 	struct io_sr_msg *sr = io_kiocb_to_cmd(req);
@@ -527,6 +640,7 @@  int io_recvmsg(struct io_kiocb *req, unsigned int issue_flags)
 	unsigned flags;
 	int ret, min_ret = 0;
 	bool force_nonblock = issue_flags & IO_URING_F_NONBLOCK;
+	bool mshot_finished = true;
 
 	sock = sock_from_file(req->file);
 	if (unlikely(!sock))
@@ -545,16 +659,29 @@  int io_recvmsg(struct io_kiocb *req, unsigned int issue_flags)
 	    (sr->flags & IORING_RECVSEND_POLL_FIRST))
 		return io_setup_async_msg(req, kmsg, issue_flags);
 
+retry_multishot:
 	if (io_do_buffer_select(req)) {
 		void __user *buf;
+		size_t len = sr->len;
 
-		buf = io_buffer_select(req, &sr->len, issue_flags);
+		buf = io_buffer_select(req, &len, issue_flags);
 		if (!buf)
 			return -ENOBUFS;
+
+		if (req->flags & REQ_F_APOLL_MULTISHOT) {
+			ret = io_recvmsg_prep_multishot(kmsg, sr,
+							&buf, &len);
+
+			if (ret) {
+				io_kbuf_recycle(req, issue_flags);
+				return ret;
+			}
+		}
+
 		kmsg->fast_iov[0].iov_base = buf;
-		kmsg->fast_iov[0].iov_len = sr->len;
+		kmsg->fast_iov[0].iov_len = len;
 		iov_iter_init(&kmsg->msg.msg_iter, READ, kmsg->fast_iov, 1,
-				sr->len);
+				len);
 	}
 
 	flags = sr->msg_flags;
@@ -564,10 +691,22 @@  int io_recvmsg(struct io_kiocb *req, unsigned int issue_flags)
 		min_ret = iov_iter_count(&kmsg->msg.msg_iter);
 
 	kmsg->msg.msg_get_inq = 1;
-	ret = __sys_recvmsg_sock(sock, &kmsg->msg, sr->umsg, kmsg->uaddr, flags);
+	if (req->flags & REQ_F_APOLL_MULTISHOT)
+		ret = io_recvmsg_multishot(sock, sr, kmsg, flags,
+					   &mshot_finished);
+	else
+		ret = __sys_recvmsg_sock(sock, &kmsg->msg, sr->umsg, kmsg->uaddr, flags);
+
 	if (ret < min_ret) {
-		if (ret == -EAGAIN && force_nonblock)
-			return io_setup_async_msg(req, kmsg, issue_flags);
+		if (ret == -EAGAIN && force_nonblock) {
+			ret = io_setup_async_msg(req, kmsg, issue_flags);
+			if (ret == -EAGAIN && (req->flags & IO_APOLL_MULTI_POLLED) ==
+					       IO_APOLL_MULTI_POLLED) {
+				io_kbuf_recycle(req, issue_flags);
+				return IOU_ISSUE_SKIP_COMPLETE;
+			}
+			return ret;
+		}
 		if (ret == -ERESTARTSYS)
 			ret = -EINTR;
 		if (ret > 0 && io_net_retry(sock, flags)) {
@@ -580,11 +719,6 @@  int io_recvmsg(struct io_kiocb *req, unsigned int issue_flags)
 		req_set_fail(req);
 	}
 
-	/* fast path, check for non-NULL to avoid function call */
-	if (kmsg->free_iov)
-		kfree(kmsg->free_iov);
-	io_netmsg_recycle(req, issue_flags);
-	req->flags &= ~REQ_F_NEED_CLEANUP;
 	if (ret > 0)
 		ret += sr->done_io;
 	else if (sr->done_io)
@@ -596,8 +730,18 @@  int io_recvmsg(struct io_kiocb *req, unsigned int issue_flags)
 	if (kmsg->msg.msg_inq)
 		cflags |= IORING_CQE_F_SOCK_NONEMPTY;
 
-	io_req_set_res(req, ret, cflags);
-	return IOU_OK;
+	if (!io_recv_finish(req, &ret, cflags, mshot_finished))
+		goto retry_multishot;
+
+	if (mshot_finished) {
+		io_netmsg_recycle(req, issue_flags);
+		/* fast path, check for non-NULL to avoid function call */
+		if (kmsg->free_iov)
+			kfree(kmsg->free_iov);
+		req->flags &= ~REQ_F_NEED_CLEANUP;
+	}
+
+	return ret;
 }
 
 int io_recv(struct io_kiocb *req, unsigned int issue_flags)
@@ -684,7 +828,7 @@  int io_recv(struct io_kiocb *req, unsigned int issue_flags)
 	if (msg.msg_inq)
 		cflags |= IORING_CQE_F_SOCK_NONEMPTY;
 
-	if (!io_recv_finish(req, &ret, cflags))
+	if (!io_recv_finish(req, &ret, cflags, ret <= 0))
 		goto retry_multishot;
 
 	return ret;
diff --git a/io_uring/net.h b/io_uring/net.h
index 178a6d8b76e0..db20ce9d6546 100644
--- a/io_uring/net.h
+++ b/io_uring/net.h
@@ -9,6 +9,12 @@ 
 struct io_async_msghdr {
 	union {
 		struct iovec		fast_iov[UIO_FASTIOV];
+		struct {
+			struct iovec	fast_iov_one;
+			__kernel_size_t	controllen;
+			int		namelen;
+			__kernel_size_t	payloadlen;
+		};
 		struct io_cache_entry	cache;
 	};
 	/* points to an allocated iov, if NULL we use fast_iov instead */