diff mbox series

[RFC,3/3] io_uring/bpf: add kfuncs for BPF programs

Message ID e4c5bc9551109bef91c53be43a4296f3d317f19a.1731285516.git.asml.silence@gmail.com (mailing list archive)
State New
Headers show
Series Add BPF for io_uring | expand

Commit Message

Pavel Begunkov Nov. 11, 2024, 1:50 a.m. UTC
Add a way for io_uring BPF programs to look at CQEs and submit new
requests.

Signed-off-by: Pavel Begunkov <asml.silence@gmail.com>
---
 io_uring/bpf.c      | 118 ++++++++++++++++++++++++++++++++++++++++++++
 io_uring/bpf.h      |   2 +
 io_uring/io_uring.c |   1 +
 3 files changed, 121 insertions(+)
diff mbox series

Patch

diff --git a/io_uring/bpf.c b/io_uring/bpf.c
index 8b7c74761c63..d413c3712612 100644
--- a/io_uring/bpf.c
+++ b/io_uring/bpf.c
@@ -4,6 +4,123 @@ 
 #include <linux/filter.h>
 
 #include "bpf.h"
+#include "io_uring.h"
+
+static inline struct io_bpf_ctx *io_user_to_bpf_ctx(struct io_uring_bpf_ctx *ctx)
+{
+	struct io_bpf_ctx_kern *bc = (struct io_bpf_ctx_kern *)ctx;
+
+	return container_of(bc, struct io_bpf_ctx, kern);
+}
+
+__bpf_kfunc_start_defs();
+
+__bpf_kfunc int bpf_io_uring_queue_sqe(struct io_uring_bpf_ctx *user_ctx,
+					void *bpf_sqe, int mem__sz)
+{
+	struct io_bpf_ctx *bc = io_user_to_bpf_ctx(user_ctx);
+	struct io_ring_ctx *ctx = bc->ctx;
+	unsigned tail = ctx->rings->sq.tail;
+	struct io_uring_sqe *sqe;
+
+	if (mem__sz != sizeof(*sqe))
+		return -EINVAL;
+
+	ctx->rings->sq.tail++;
+	tail &= (ctx->sq_entries - 1);
+	/* double index for 128-byte SQEs, twice as long */
+	if (ctx->flags & IORING_SETUP_SQE128)
+		tail <<= 1;
+	sqe = &ctx->sq_sqes[tail];
+	memcpy(sqe, bpf_sqe, sizeof(*sqe));
+	return 0;
+}
+
+__bpf_kfunc int bpf_io_uring_submit_sqes(struct io_uring_bpf_ctx *user_ctx,
+					 unsigned nr)
+{
+	struct io_bpf_ctx *bc = io_user_to_bpf_ctx(user_ctx);
+	struct io_ring_ctx *ctx = bc->ctx;
+
+	return io_submit_sqes(ctx, nr);
+}
+
+__bpf_kfunc int bpf_io_uring_get_cqe(struct io_uring_bpf_ctx *user_ctx,
+				     struct io_uring_cqe *res__uninit)
+{
+	struct io_bpf_ctx *bc = io_user_to_bpf_ctx(user_ctx);
+	struct io_ring_ctx *ctx = bc->ctx;
+	struct io_rings *rings = ctx->rings;
+	unsigned int mask = ctx->cq_entries - 1;
+	unsigned head = rings->cq.head;
+	struct io_uring_cqe *cqe;
+
+	/* TODO CQE32 */
+	if (head == rings->cq.tail)
+		goto fail;
+
+	cqe = &rings->cqes[head & mask];
+	memcpy(res__uninit, cqe, sizeof(*cqe));
+	rings->cq.head++;
+	return 0;
+fail:
+	memset(res__uninit, 0, sizeof(*res__uninit));
+	return -EINVAL;
+}
+
+__bpf_kfunc
+struct io_uring_cqe *bpf_io_uring_get_cqe2(struct io_uring_bpf_ctx *user_ctx)
+{
+	struct io_bpf_ctx *bc = io_user_to_bpf_ctx(user_ctx);
+	struct io_ring_ctx *ctx = bc->ctx;
+	struct io_rings *rings = ctx->rings;
+	unsigned int mask = ctx->cq_entries - 1;
+	unsigned head = rings->cq.head;
+	struct io_uring_cqe *cqe;
+
+	/* TODO CQE32 */
+	if (head == rings->cq.tail)
+		return NULL;
+
+	cqe = &rings->cqes[head & mask];
+	rings->cq.head++;
+	return cqe;
+}
+
+__bpf_kfunc
+void bpf_io_uring_set_wait_params(struct io_uring_bpf_ctx *user_ctx,
+				  unsigned wait_nr)
+{
+	struct io_bpf_ctx *bc = io_user_to_bpf_ctx(user_ctx);
+	struct io_ring_ctx *ctx = bc->ctx;
+	struct io_wait_queue *wq = bc->waitq;
+
+	wait_nr = min_t(unsigned, wait_nr, ctx->cq_entries);
+	wq->cq_tail = READ_ONCE(ctx->rings->cq.head) + wait_nr;
+}
+
+__bpf_kfunc_end_defs();
+
+BTF_KFUNCS_START(io_uring_kfunc_set)
+BTF_ID_FLAGS(func, bpf_io_uring_queue_sqe, KF_SLEEPABLE);
+BTF_ID_FLAGS(func, bpf_io_uring_submit_sqes, KF_SLEEPABLE);
+BTF_ID_FLAGS(func, bpf_io_uring_get_cqe, 0);
+BTF_ID_FLAGS(func, bpf_io_uring_get_cqe2, KF_RET_NULL);
+BTF_ID_FLAGS(func, bpf_io_uring_set_wait_params, 0);
+BTF_KFUNCS_END(io_uring_kfunc_set)
+
+static const struct btf_kfunc_id_set bpf_io_uring_kfunc_set = {
+	.owner = THIS_MODULE,
+	.set = &io_uring_kfunc_set,
+};
+
+static int init_io_uring_bpf(void)
+{
+	return register_btf_kfunc_id_set(BPF_PROG_TYPE_IOURING,
+					 &bpf_io_uring_kfunc_set);
+}
+late_initcall(init_io_uring_bpf);
+
 
 static const struct bpf_func_proto *
 io_bpf_func_proto(enum bpf_func_id func_id, const struct bpf_prog *prog)
@@ -82,6 +199,7 @@  int io_register_bpf(struct io_ring_ctx *ctx, void __user *arg,
 	}
 
 	bc->prog = prog;
+	bc->ctx = ctx;
 	ctx->bpf_ctx = bc;
 	return 0;
 }
diff --git a/io_uring/bpf.h b/io_uring/bpf.h
index 2b4e555ff07a..9f578a48ce2e 100644
--- a/io_uring/bpf.h
+++ b/io_uring/bpf.h
@@ -9,6 +9,8 @@  struct bpf_prog;
 
 struct io_bpf_ctx {
 	struct io_bpf_ctx_kern kern;
+	struct io_ring_ctx *ctx;
+	struct io_wait_queue *waitq;
 	struct bpf_prog *prog;
 };
 
diff --git a/io_uring/io_uring.c b/io_uring/io_uring.c
index 82599e2a888a..98206e68ce70 100644
--- a/io_uring/io_uring.c
+++ b/io_uring/io_uring.c
@@ -2836,6 +2836,7 @@  static int io_cqring_wait(struct io_ring_ctx *ctx, int min_events, u32 flags,
 	io_napi_busy_loop(ctx, &iowq);
 
 	if (io_bpf_enabled(ctx)) {
+		ctx->bpf_ctx->waitq = &iowq;
 		ret = io_run_bpf(ctx);
 		if (ret == IOU_BPF_RET_STOP)
 			return 0;