@@ -126,6 +126,8 @@ struct socket {
const struct proto_ops *ops; /* Might change with IPV6_ADDRFORM or MPTCP. */
struct socket_wq wq;
+
+ unsigned zc_rx_idx;
};
/*
@@ -579,6 +579,7 @@ enum {
/* register a network interface queue for zerocopy */
IORING_REGISTER_ZC_RX_IFQ = 29,
+ IORING_REGISTER_ZC_RX_SOCK = 30,
/* this goes last */
IORING_REGISTER_LAST,
@@ -824,6 +825,12 @@ struct io_uring_zc_rx_ifq_reg {
struct io_rbuf_rqring_offsets rq_off;
};
+/*
+ * Argument block for IORING_REGISTER_ZC_RX_SOCK: names a socket and the
+ * zero-copy receive interface queue it should be attached to.
+ */
+struct io_uring_zc_rx_sock_reg {
+ /* file descriptor of the socket to register */
+ __u32 sockfd;
+ /* interface queue index; the registration handler only accepts 0 */
+ __u32 zc_rx_ifq_idx;
+ /* reserved; the registration handler returns -EINVAL unless both are 0 */
+ __u32 __resv[2];
+};
+
#ifdef __cplusplus
}
#endif
@@ -16,6 +16,7 @@
#include "net.h"
#include "notif.h"
#include "rsrc.h"
+#include "zc_rx.h"
#if defined(CONFIG_NET)
struct io_shutdown {
@@ -1033,6 +1034,25 @@ int io_recv(struct io_kiocb *req, unsigned int issue_flags)
return ret;
}
+/*
+ * Check that @sock has been registered for zero-copy receive with the ring
+ * that owns @req, and hand back the matching interface queue.
+ *
+ * The token stored in sock->zc_rx_idx is split into an ifq index (bits from
+ * IO_ZC_IFQ_IDX_OFFSET upwards) and a slot in the ifq's socket table (bits
+ * under IO_ZC_IFQ_IDX_MASK). Only ifq index 0 is accepted, and the slot must
+ * hold the very file @req operates on.
+ *
+ * Returns the ring's ifq on a match, NULL otherwise.
+ */
+static __maybe_unused
+struct io_zc_rx_ifq *io_zc_verify_sock(struct io_kiocb *req,
+					struct socket *sock)
+{
+	const unsigned token = READ_ONCE(sock->zc_rx_idx);
+	const unsigned slot = token & IO_ZC_IFQ_IDX_MASK;
+	struct io_zc_rx_ifq *ifq;
+
+	/* the ring exposes one ifq (ctx->ifq), so a non-zero index never matches */
+	if (token >> IO_ZC_IFQ_IDX_OFFSET)
+		return NULL;
+
+	ifq = req->ctx->ifq;
+	/* comparing the stored file also rejects sockets that were never registered */
+	if (ifq && slot < ifq->nr_sockets && ifq->sockets[slot] == req->file)
+		return ifq;
+	return NULL;
+}
+
void io_send_zc_cleanup(struct io_kiocb *req)
{
struct io_sr_msg *zc = io_kiocb_to_cmd(req, struct io_sr_msg);
@@ -570,6 +570,12 @@ static int __io_uring_register(struct io_ring_ctx *ctx, unsigned opcode,
break;
ret = io_register_zc_rx_ifq(ctx, arg);
break;
+ case IORING_REGISTER_ZC_RX_SOCK:
+ ret = -EINVAL;
+ if (!arg || nr_args != 1)
+ break;
+ ret = io_register_zc_rx_sock(ctx, arg);
+ break;
default:
ret = -EINVAL;
break;
@@ -5,12 +5,15 @@
#include <linux/mm.h>
#include <linux/io_uring.h>
#include <linux/netdevice.h>
+#include <net/tcp.h>
+#include <net/af_unix.h>
#include <uapi/linux/io_uring.h>
#include "io_uring.h"
#include "kbuf.h"
#include "zc_rx.h"
+#include "rsrc.h"
typedef int (*bpf_op_t)(struct net_device *dev, struct netdev_bpf *bpf);
@@ -76,10 +79,31 @@ static struct io_zc_rx_ifq *io_zc_rx_ifq_alloc(struct io_ring_ctx *ctx)
return ifq;
}
-static void io_zc_rx_ifq_free(struct io_zc_rx_ifq *ifq)
+/*
+ * Release what an ifq holds at runtime: the file references of its
+ * registered sockets and its hardware rx queue.
+ *
+ * A NULL @ifq is a no-op. The bookkeeping is reset as each resource is
+ * released (slots set to NULL, nr_sockets to 0, if_rxq_id to -1), so a
+ * second call on the same ifq finds nothing left to do.
+ *
+ * NOTE(review): the tokens written to socket::zc_rx_idx at registration are
+ * not cleared here - confirm that is intended.
+ */
+static void io_shutdown_ifq(struct io_zc_rx_ifq *ifq)
{
-if (ifq->if_rxq_id != -1)
+ int i;
+
+ if (!ifq)
+ return;
+
+ /* drop the file reference stored for every registered socket */
+ for (i = 0; i < ifq->nr_sockets; i++) {
+ if (ifq->sockets[i]) {
+ fput(ifq->sockets[i]);
+ ifq->sockets[i] = NULL;
+ }
+ }
+ ifq->nr_sockets = 0;
+
+ /* -1 means there is no rx queue left to close */
+ if (ifq->if_rxq_id != -1) {
io_close_zc_rxq(ifq);
+ ifq->if_rxq_id = -1;
+ }
+}
+
+static void io_zc_rx_ifq_free(struct io_zc_rx_ifq *ifq)
+{
+ io_shutdown_ifq(ifq);
+
if (ifq->dev)
dev_put(ifq->dev);
io_free_rbuf_ring(ifq);
@@ -132,7 +156,6 @@ int io_register_zc_rx_ifq(struct io_ring_ctx *ctx,
reg.rq_off.tail = offsetof(struct io_uring, tail);
if (copy_to_user(arg, ®, sizeof(reg))) {
- io_close_zc_rxq(ifq);
ret = -EFAULT;
goto err;
}
@@ -153,6 +176,8 @@ void io_unregister_zc_rx_ifqs(struct io_ring_ctx *ctx)
if (!ifq)
return;
+ WARN_ON_ONCE(ifq->nr_sockets);
+
ctx->ifq = NULL;
io_zc_rx_ifq_free(ifq);
}
@@ -160,6 +185,66 @@ void io_unregister_zc_rx_ifqs(struct io_ring_ctx *ctx)
+/*
+ * Shut down the ring's zero-copy rx interface queue by passing ctx->ifq to
+ * io_shutdown_ifq(). The caller must hold ctx->uring_lock.
+ */
void io_shutdown_zc_rx_ifqs(struct io_ring_ctx *ctx)
{
lockdep_assert_held(&ctx->uring_lock);
+
+ io_shutdown_ifq(ctx->ifq);
+}
+
+/*
+ * Register a socket with the ring's zero-copy receive interface queue.
+ *
+ * @ctx: ring the socket is attached to; ctx->ifq must already exist.
+ * @arg: user pointer to a struct io_uring_zc_rx_sock_reg.
+ *
+ * On success the socket's file reference is kept in ifq->sockets[] (it is
+ * dropped again by io_shutdown_ifq()) and a token encoding the ifq index and
+ * the table slot is published in sock->zc_rx_idx.
+ *
+ * Returns 0 on success or a negative errno:
+ *   -EFAULT    @arg could not be copied from user space
+ *   -EINVAL    reserved fields set, ifq index != 0, no ifq, or table full
+ *   -EBADF     sockfd is not an open file, or unix_get_socket() claims it
+ *   -ENOTSOCK  the file is not a socket
+ *   -EEXIST    the socket is already registered
+ *
+ * NOTE(review): like the nr_sockets update below, the duplicate scan relies
+ * on the caller serialising registrations against each other - confirm.
+ */
+int io_register_zc_rx_sock(struct io_ring_ctx *ctx,
+			   struct io_uring_zc_rx_sock_reg __user *arg)
+{
+	struct io_uring_zc_rx_sock_reg sr;
+	struct io_zc_rx_ifq *ifq;
+	struct socket *sock;
+	struct file *file;
+	unsigned int i;
+	int idx;
+	int ret;
+
+	if (copy_from_user(&sr, arg, sizeof(sr)))
+		return -EFAULT;
+	if (sr.__resv[0] || sr.__resv[1])
+		return -EINVAL;
+	if (sr.zc_rx_ifq_idx != 0 || !ctx->ifq)
+		return -EINVAL;
+
+	ifq = ctx->ifq;
+	if (ifq->nr_sockets >= ARRAY_SIZE(ifq->sockets))
+		return -EINVAL;
+
+	/* every slot index must fit into the low bits of the token */
+	BUILD_BUG_ON(ARRAY_SIZE(ifq->sockets) > IO_ZC_IFQ_IDX_MASK);
+
+	file = fget(sr.sockfd);
+	if (!file)
+		return -EBADF;
+
+	ret = -EBADF;
+	if (unix_get_socket(file))
+		goto out_put;
+
+	ret = -ENOTSOCK;
+	sock = sock_from_file(file);
+	if (unlikely(!sock || !sock->sk))
+		goto out_put;
+
+	/*
+	 * The socket in slot 0 of ifq 0 carries token 0, which the
+	 * zc_rx_idx test below cannot tell apart from "never registered".
+	 * Look the file up in the table so that a repeated registration is
+	 * refused for every slot, including the first one.
+	 */
+	ret = -EEXIST;
+	for (i = 0; i < ifq->nr_sockets; i++) {
+		if (ifq->sockets[i] == file)
+			goto out_put;
+	}
+
+	idx = ifq->nr_sockets;
+	lock_sock(sock->sk);
+	if (!sock->zc_rx_idx) {
+		unsigned token;
+
+		token = idx + (sr.zc_rx_ifq_idx << IO_ZC_IFQ_IDX_OFFSET);
+		WRITE_ONCE(sock->zc_rx_idx, token);
+		ret = 0;
+	}
+	release_sock(sock->sk);
+	if (ret)
+		goto out_put;
+
+	/* the table now owns the reference taken by fget() */
+	ifq->sockets[idx] = file;
+	ifq->nr_sockets++;
+	return 0;
+
+out_put:
+	fput(file);
+	return ret;
}
#endif
@@ -2,6 +2,13 @@
#ifndef IOU_ZC_RX_H
#define IOU_ZC_RX_H
+#include <linux/io_uring_types.h>
+#include <linux/skbuff.h>
+
+/* Capacity of the per-ifq table of registered sockets (io_zc_rx_ifq::sockets). */
+#define IO_ZC_MAX_IFQ_SOCKETS 16
+/*
+ * Layout of the token kept in socket::zc_rx_idx: the bits from
+ * IO_ZC_IFQ_IDX_OFFSET upwards carry the ifq index, the bits selected by
+ * IO_ZC_IFQ_IDX_MASK carry the slot in that ifq's socket table.
+ */
+#define IO_ZC_IFQ_IDX_OFFSET 16
+#define IO_ZC_IFQ_IDX_MASK ((1U << IO_ZC_IFQ_IDX_OFFSET) - 1)
+
struct io_zc_rx_ifq {
struct io_ring_ctx *ctx;
struct net_device *dev;
@@ -11,6 +18,9 @@ struct io_zc_rx_ifq {
/* hw rx descriptor ring id */
u32 if_rxq_id;
+
+ unsigned nr_sockets;
+ struct file *sockets[IO_ZC_MAX_IFQ_SOCKETS];
};
#if defined(CONFIG_PAGE_POOL)
@@ -18,6 +28,8 @@ int io_register_zc_rx_ifq(struct io_ring_ctx *ctx,
struct io_uring_zc_rx_ifq_reg __user *arg);
void io_unregister_zc_rx_ifqs(struct io_ring_ctx *ctx);
void io_shutdown_zc_rx_ifqs(struct io_ring_ctx *ctx);
+int io_register_zc_rx_sock(struct io_ring_ctx *ctx,
+ struct io_uring_zc_rx_sock_reg __user *arg);
#else
static inline int io_register_zc_rx_ifq(struct io_ring_ctx *ctx,
struct io_uring_zc_rx_ifq_reg __user *arg)
@@ -30,6 +42,11 @@ static inline void io_unregister_zc_rx_ifqs(struct io_ring_ctx *ctx)
static inline void io_shutdown_zc_rx_ifqs(struct io_ring_ctx *ctx)
{
}
+/*
+ * Fallback on the #else side of this header's configuration guard: socket
+ * registration is not available, so every request fails with -EOPNOTSUPP.
+ */
+static inline int io_register_zc_rx_sock(struct io_ring_ctx *ctx,
+ struct io_uring_zc_rx_sock_reg __user *arg)
+{
+ return -EOPNOTSUPP;
+}
#endif
#endif
@@ -637,6 +637,7 @@ struct socket *sock_alloc(void)
sock = SOCKET_I(inode);
+ sock->zc_rx_idx = 0;
inode->i_ino = get_next_ino();
inode->i_mode = S_IFSOCK | S_IRWXUGO;
inode->i_uid = current_fsuid();