@@ -18,8 +18,10 @@
#ifdef CONFIG_BPF_JIT
static struct bpf_struct_ops bpf_mptcp_sched_ops;
-static const struct btf_type *mptcp_sock_type, *mptcp_subflow_type __read_mostly;
-static u32 mptcp_sock_id, mptcp_subflow_id;
+static u32 mptcp_sock_id,
+ mptcp_subflow_id;
+
+/* MPTCP BPF packet scheduler */
static const struct bpf_func_proto *
bpf_mptcp_sched_get_func_proto(enum bpf_func_id func_id,
@@ -43,12 +45,10 @@ static int bpf_mptcp_sched_btf_struct_access(struct bpf_verifier_log *log,
const struct bpf_reg_state *reg,
int off, int size)
{
- const struct btf_type *t;
+ u32 id = reg->btf_id;
size_t end;
- t = btf_type_by_id(reg->btf, reg->btf_id);
-
- if (t == mptcp_sock_type) {
+ if (id == mptcp_sock_id) {
switch (off) {
case offsetof(struct mptcp_sock, snd_burst):
end = offsetofend(struct mptcp_sock, snd_burst);
@@ -58,11 +58,14 @@ static int bpf_mptcp_sched_btf_struct_access(struct bpf_verifier_log *log,
off);
return -EACCES;
}
- } else if (t == mptcp_subflow_type) {
+ } else if (id == mptcp_subflow_id) {
switch (off) {
case offsetof(struct mptcp_subflow_context, avg_pacing_rate):
end = offsetofend(struct mptcp_subflow_context, avg_pacing_rate);
break;
+ case offsetof(struct mptcp_subflow_context, scheduled):
+ end = offsetofend(struct mptcp_subflow_context, scheduled);
+ break;
default:
bpf_log(log, "no write support to mptcp_subflow_context at off %d\n",
off);
@@ -75,7 +78,7 @@ static int bpf_mptcp_sched_btf_struct_access(struct bpf_verifier_log *log,
if (off + size > end) {
bpf_log(log, "access beyond %s at off %u size %u ended at %zu",
- t == mptcp_sock_type ? "mptcp_sock" : "mptcp_subflow_context",
+ id == mptcp_sock_id ? "mptcp_sock" : "mptcp_subflow_context",
off, size, end);
return -EACCES;
}
@@ -113,7 +116,6 @@ static int bpf_mptcp_sched_init_member(const struct btf_type *t,
const struct mptcp_sched_ops *usched;
struct mptcp_sched_ops *sched;
u32 moff;
- int ret;
usched = (const struct mptcp_sched_ops *)udata;
sched = (struct mptcp_sched_ops *)kdata;
@@ -124,12 +126,7 @@ static int bpf_mptcp_sched_init_member(const struct btf_type *t,
if (bpf_obj_name_cpy(sched->name, usched->name,
sizeof(sched->name)) <= 0)
return -EINVAL;
-
- rcu_read_lock();
- ret = mptcp_sched_find(usched->name) ? -EEXIST : 1;
- rcu_read_unlock();
-
- return ret;
+ return 1;
}
return 0;
@@ -144,18 +141,21 @@ static int bpf_mptcp_sched_init(struct btf *btf)
if (type_id < 0)
return -EINVAL;
mptcp_sock_id = type_id;
- mptcp_sock_type = btf_type_by_id(btf, mptcp_sock_id);
type_id = btf_find_by_name_kind(btf, "mptcp_subflow_context",
BTF_KIND_STRUCT);
if (type_id < 0)
return -EINVAL;
mptcp_subflow_id = type_id;
- mptcp_subflow_type = btf_type_by_id(btf, mptcp_subflow_id);
return 0;
}
+static int bpf_mptcp_sched_validate(void *kdata)
+{
+ return mptcp_validate_scheduler(kdata);
+}
+
static int __bpf_mptcp_sched_get_subflow(struct mptcp_sock *msk,
struct mptcp_sched_data *data)
{
@@ -183,6 +183,7 @@ static struct bpf_struct_ops bpf_mptcp_sched_ops = {
.check_member = bpf_mptcp_sched_check_member,
.init_member = bpf_mptcp_sched_init_member,
.init = bpf_mptcp_sched_init,
+ .validate = bpf_mptcp_sched_validate,
.name = "mptcp_sched_ops",
.cfi_stubs = &__bpf_mptcp_sched_ops,
};