Message ID: 20240222085232.62483-2-hffilwlqm@gmail.com (mailing list archive)
State: Changes Requested
Delegated to: BPF
Series: bpf, x64: Fix tailcall hierarchy
On 2024/2/22 16:52, Leon Hwang wrote: > From commit ebf7d1f508a73871 ("bpf, x64: rework pro/epilogue and tailcall > handling in JIT"), the tailcall on x64 works better than before. > > From commit e411901c0b775a3a ("bpf: allow for tailcalls in BPF subprograms > for x64 JIT"), tailcall is able to run in BPF subprograms on x64. > > How about: > > 1. More than 1 subprograms are called in a bpf program. > 2. The tailcalls in the subprograms call the bpf program. > > Because of missing tail_call_cnt back-propagation, a tailcall hierarchy > comes up. And MAX_TAIL_CALL_CNT limit does not work for this case. > [SNIP] > > Fixes: ebf7d1f508a7 ("bpf, x64: rework pro/epilogue and tailcall handling in JIT") > Fixes: e411901c0b77 ("bpf: allow for tailcalls in BPF subprograms for x64 JIT") > Signed-off-by: Leon Hwang <hffilwlqm@gmail.com> > --- > arch/x86/net/bpf_jit_comp.c | 128 ++++++++++++++++++++---------------- > 1 file changed, 71 insertions(+), 57 deletions(-) > > diff --git a/arch/x86/net/bpf_jit_comp.c b/arch/x86/net/bpf_jit_comp.c > index e1390d1e331b5..3d1498a13b04c 100644 > --- a/arch/x86/net/bpf_jit_comp.c > +++ b/arch/x86/net/bpf_jit_comp.c > @@ -18,6 +18,7 @@ > #include <asm/text-patching.h> > #include <asm/unwind.h> > #include <asm/cfi.h> > +#include <asm/percpu.h> > [SNIP] > + > /* > * Generate the following code: nit: the "tail_call_cnt++" of the comment should be updated too. > * > @@ -594,7 +641,6 @@ static void emit_bpf_tail_call_indirect(struct bpf_prog *bpf_prog, > u32 stack_depth, u8 *ip, > struct jit_context *ctx) > { > - int tcc_off = -4 - round_up(stack_depth, 8); > u8 *prog = *pprog, *start = *pprog; > int offset; > > @@ -615,17 +661,14 @@ static void emit_bpf_tail_call_indirect(struct bpf_prog *bpf_prog, > offset = ctx->tail_call_indirect_label - (prog + 2 - start); > EMIT2(X86_JBE, offset); /* jbe out */ > > - /* > - * if (tail_call_cnt++ >= MAX_TAIL_CALL_CNT) > + /* if (bpf_tail_call_cnt_fetch_and_inc() >= MAX_TAIL_CALL_CNT) > * goto out; > */ > - EMIT2_off32(0x8B, 0x85, tcc_off); /* mov eax, dword ptr [rbp - tcc_off] */ > + emit_call(&prog, bpf_tail_call_cnt_fetch_and_inc, ip + (prog - start)); > EMIT3(0x83, 0xF8, MAX_TAIL_CALL_CNT); /* cmp eax, MAX_TAIL_CALL_CNT */ [SNIP] Thanks, Leon
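For reference, the comment the nit refers to is the block above emit_bpf_tail_call_indirect(); a sketch of the updated wording, matching the new helper the patch calls, would be:

/*
 * Generate the following code:
 *
 * ... bpf_tail_call(void *ctx, struct bpf_array *array, u64 index) ...
 *	if (index >= array->map.max_entries)
 *		goto out;
 *	if (bpf_tail_call_cnt_fetch_and_inc() >= MAX_TAIL_CALL_CNT)
 *		goto out;
 *	prog = array->ptrs[index];
 *	if (prog == NULL)
 *		goto out;
 *	goto *(prog->bpf_func + prologue_size);
 * out:
 */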
On 2024/2/22 16:52, Leon Hwang wrote: >>From commit ebf7d1f508a73871 ("bpf, x64: rework pro/epilogue and tailcall > handling in JIT"), the tailcall on x64 works better than before. > >>From commit e411901c0b775a3a ("bpf: allow for tailcalls in BPF subprograms > for x64 JIT"), tailcall is able to run in BPF subprograms on x64. > > How about: > > 1. More than 1 subprograms are called in a bpf program. > 2. The tailcalls in the subprograms call the bpf program. > > Because of missing tail_call_cnt back-propagation, a tailcall hierarchy > comes up. And MAX_TAIL_CALL_CNT limit does not work for this case. > > Let's take a look into an example: > > \#include <linux/bpf.h> > \#include <bpf/bpf_helpers.h> > \#include "bpf_legacy.h" > > struct { > __uint(type, BPF_MAP_TYPE_PROG_ARRAY); > __uint(max_entries, 1); > __uint(key_size, sizeof(__u32)); > __uint(value_size, sizeof(__u32)); > } jmp_table SEC(".maps"); > > int count = 0; > > static __noinline > int subprog_tail(struct __sk_buff *skb) > { > bpf_tail_call_static(skb, &jmp_table, 0); > return 0; > } > > SEC("tc") > int entry(struct __sk_buff *skb) > { > volatile int ret = 1; > > count++; > subprog_tail(skb); /* subprog call1 */ > subprog_tail(skb); /* subprog call2 */ > > return ret; > } > > char __license[] SEC("license") = "GPL"; > > And the entry bpf prog is populated to the 0th slot of jmp_table. Then, > what happens when entry bpf prog runs? The CPU will be stalled because > of too many tailcalls, e.g. the test_progs failed to run on aarch64 and > s390x because of "rcu: INFO: rcu_sched self-detected stall on CPU". > > So, if CPU does not stall because of too many tailcalls, how many > tailcalls will be there for this case? And why MAX_TAIL_CALL_CNT limit > does not work for this case? > > Let's step into some running steps. > > At the very first time when subprog_tail() is called, subprog_tail() does > tailcall the entry bpf prog. Then, subprog_taill() is called at second time > at the position subprog call1, and it tailcalls the entry bpf prog again. > > Then, again and again. At the very first time when MAX_TAIL_CALL_CNT limit > works, subprog_tail() has been called for 34 times at the position subprog > call1. And at this time, the tail_call_cnt is 33 in subprog_tail(). > > Next, the 34th subprog_tail() returns to entry() because of > MAX_TAIL_CALL_CNT limit. > > In entry(), the 34th entry(), at the time after the 34th subprog_tail() at > the position subprog call1 finishes and before the 1st subprog_tail() at > the position subprog call2 calls in entry(), what's the value of > tail_call_cnt in entry()? It's 33. > > As we know, tail_all_cnt is pushed on the stack of entry(), and propagates > to subprog_tail() by %rax from stack. > > Then, at the time when subprog_tail() at the position subprog call2 is > called for its first time, tail_call_cnt 33 propagates to subprog_tail() > by %rax. And the tailcall in subprog_tail() is aborted because of > tail_call_cnt >= MAX_TAIL_CALL_CNT too. > > Then, subprog_tail() at the position subprog call2 ends, and the 34th > entry() ends. And it returns to the 33rd subprog_tail() called from the > position subprog call1. But wait, at this time, what's the value of > tail_call_cnt under the stack of subprog_tail()? It's 33. > > Then, in the 33rd entry(), at the time after the 33th subprog_tail() at > the position subprog call1 finishes and before the 2nd subprog_tail() at > the position subprog call2 calls, what's the value of tail_call_cnt > in current entry()? It's *32*. Why not 33? 
> > Before stepping into subprog_tail() at the position subprog call2 in 33rd > entry(), like stopping the time machine, let's have a look at the stack > memory: > > | STACK | > +---------+ RBP <-- current rbp > | ret | STACK of 33rd entry() > | tcc | its value is 32 > +---------+ RSP <-- current rsp > | rip | STACK of 34rd entry() > | rbp | reuse the STACK of 33rd subprog_tail() at the position > | ret | subprog call1 > | tcc | its value is 33 > +---------+ rsp > | rip | STACK of 1st subprog_tail() at the position subprog call2 > | rbp | > | tcc | its value is 33 > +---------+ rsp > > Why not 33? It's because tail_call_cnt does not back-propagate from > subprog_tail() to entry(). > > Then, while stepping into subprog_tail() at the position subprog call2 in > 33rd entry(): > > | STACK | > +---------+ > | ret | STACK of 33rd entry() > | tcc | its value is 32 > | rip | > | rbp | > +---------+ RBP <-- current rbp > | tcc | its value is 32; STACK of subprog_tail() at the position > +---------+ RSP <-- current rsp subprog call2 > > Then, while pausing after tailcalling in 2nd subprog_tail() at the position > subprog call2: > > | STACK | > +---------+ > | ret | STACK of 33rd entry() > | tcc | its value is 32 > | rip | > | rbp | > +---------+ RBP <-- current rbp > | tcc | its value is 33; STACK of subprog_tail() at the position > +---------+ RSP <-- current rsp subprog call2 > > Note: what happens to tail_call_cnt: > /* > * if (tail_call_cnt++ >= MAX_TAIL_CALL_CNT) > * goto out; > */ > It's to check >= MAX_TAIL_CALL_CNT first and then increment tail_call_cnt. > > So, current tailcall is allowed to run. > > Then, entry() is tailcalled. And the stack memory status is: > > | STACK | > +---------+ > | ret | STACK of 33rd entry() > | tcc | its value is 32 > | rip | > | rbp | > +---------+ RBP <-- current rbp > | ret | STACK of 35th entry(); reuse STACK of subprog_tail() at the > | tcc | its value is 33 the position subprog call2 > +---------+ RSP <-- current rsp > > So, the tailcalls in the 35th entry() will be aborted. > > And, ..., again and again. :( > > And, I hope you have understood the reason why MAX_TAIL_CALL_CNT limit > does not work for this case. > > And, how many tailcalls are there for this case if CPU does not stall? > >>From top-down view, does it look like hierarchy layer and layer? > > I think it is a hierarchy layer model with 2+4+8+...+2**33 tailcalls. As a > result, if CPU does not stall, there will be 2**34 - 2 = 17,179,869,182 > tailcalls. That's the guy making CPU stalled. > > What about there are N subprog_tail() in entry()? If CPU does not stall > because of too many tailcalls, there will be almost N**34 tailcalls. > > As we learn about the issue, how does this patch resolve it? > > In this patch, it uses PERCPU tail_call_cnt to store the temporary > tail_call_cnt. > > First, at the prologue of bpf prog, it initialise the PERCPU > tail_call_cnt by setting current CPU's tail_call_cnt to 0. > > Then, when a tailcall happens, it fetches and increments current CPU's > tail_call_cnt, and compares to MAX_TAIL_CALL_CNT. > > Additionally, in order to avoid touching other registers excluding %rax, > it uses asm to handle PERCPU tail_call_cnt by %rax only. > > As a result, the previous tailcall way can be removed totally, including > > 1. "push rax" at prologue. > 2. load tail_call_cnt to rax before calling function. > 3. "pop rax" before jumping to tailcallee when tailcall. > 4. "push rax" and load tail_call_cnt to rax at trampoline. 
> > Fixes: ebf7d1f508a7 ("bpf, x64: rework pro/epilogue and tailcall handling in JIT") > Fixes: e411901c0b77 ("bpf: allow for tailcalls in BPF subprograms for x64 JIT") > Signed-off-by: Leon Hwang <hffilwlqm@gmail.com> > --- > arch/x86/net/bpf_jit_comp.c | 128 ++++++++++++++++++++---------------- > 1 file changed, 71 insertions(+), 57 deletions(-) > > diff --git a/arch/x86/net/bpf_jit_comp.c b/arch/x86/net/bpf_jit_comp.c > index e1390d1e331b5..3d1498a13b04c 100644 > --- a/arch/x86/net/bpf_jit_comp.c > +++ b/arch/x86/net/bpf_jit_comp.c > @@ -18,6 +18,7 @@ > #include <asm/text-patching.h> > #include <asm/unwind.h> > #include <asm/cfi.h> > +#include <asm/percpu.h> > > static bool all_callee_regs_used[4] = {true, true, true, true}; > > @@ -259,7 +260,7 @@ struct jit_context { > /* Number of bytes emit_patch() needs to generate instructions */ > #define X86_PATCH_SIZE 5 > /* Number of bytes that will be skipped on tailcall */ > -#define X86_TAIL_CALL_OFFSET (11 + ENDBR_INSN_SIZE) > +#define X86_TAIL_CALL_OFFSET (14 + ENDBR_INSN_SIZE) > > static void push_r12(u8 **pprog) > { > @@ -389,6 +390,9 @@ static void emit_cfi(u8 **pprog, u32 hash) > *pprog = prog; > } > > +static int emit_call(u8 **pprog, void *func, void *ip); > +static __used void bpf_tail_call_cnt_prepare(void); > + > /* > * Emit x86-64 prologue code for BPF program. > * bpf_tail_call helper will skip the first X86_TAIL_CALL_OFFSET bytes > @@ -396,9 +400,9 @@ static void emit_cfi(u8 **pprog, u32 hash) > */ > static void emit_prologue(u8 **pprog, u32 stack_depth, bool ebpf_from_cbpf, > bool tail_call_reachable, bool is_subprog, > - bool is_exception_cb) > + bool is_exception_cb, u8 *ip) > { > - u8 *prog = *pprog; > + u8 *prog = *pprog, *start = *pprog; > > emit_cfi(&prog, is_subprog ? cfi_bpf_subprog_hash : cfi_bpf_hash); > /* BPF trampoline can be made to work without these nops, > @@ -407,13 +411,10 @@ static void emit_prologue(u8 **pprog, u32 stack_depth, bool ebpf_from_cbpf, > emit_nops(&prog, X86_PATCH_SIZE); > if (!ebpf_from_cbpf) { > if (tail_call_reachable && !is_subprog) > - /* When it's the entry of the whole tailcall context, > - * zeroing rax means initialising tail_call_cnt. > - */ > - EMIT2(0x31, 0xC0); /* xor eax, eax */ > + emit_call(&prog, bpf_tail_call_cnt_prepare, > + ip + (prog - start)); > else > - /* Keep the same instruction layout. */ > - EMIT2(0x66, 0x90); /* nop2 */ > + emit_nops(&prog, X86_PATCH_SIZE); > } > /* Exception callback receives FP as third parameter */ > if (is_exception_cb) { > @@ -438,8 +439,6 @@ static void emit_prologue(u8 **pprog, u32 stack_depth, bool ebpf_from_cbpf, > /* sub rsp, rounded_stack_depth */ > if (stack_depth) > EMIT3_off32(0x48, 0x81, 0xEC, round_up(stack_depth, 8)); > - if (tail_call_reachable) > - EMIT1(0x50); /* push rax */ > *pprog = prog; > } > > @@ -575,6 +574,54 @@ static void emit_return(u8 **pprog, u8 *ip) > *pprog = prog; > } > > +DEFINE_PER_CPU(u32, bpf_tail_call_cnt); Hi Leon, the solution is really simplifies complexity. If I understand correctly, this TAIL_CALL_CNT becomes the system global wise, not the prog global wise, but before it was limiting the TCC of entry prog. > + > +static __used void bpf_tail_call_cnt_prepare(void) > +{ > + /* The following asm equals to > + * > + * u32 *tcc_ptr = this_cpu_ptr(&bpf_tail_call_cnt); > + * > + * *tcc_ptr = 0; > + * > + * This asm must uses %rax only. 
> + */ > + > + asm volatile ( > + "addq " __percpu_arg(0) ", %1\n\t" > + "movl $0, (%%rax)\n\t" > + : > + : "m" (this_cpu_off), "r" (&bpf_tail_call_cnt) > + ); > +} > + > +static __used u32 bpf_tail_call_cnt_fetch_and_inc(void) > +{ > + u32 tail_call_cnt; > + > + /* The following asm equals to > + * > + * u32 *tcc_ptr = this_cpu_ptr(&bpf_tail_call_cnt); > + * > + * (*tcc_ptr)++; > + * tail_call_cnt = *tcc_ptr; > + * tail_call_cnt--; > + * > + * This asm must uses %rax only. > + */ > + > + asm volatile ( > + "addq " __percpu_arg(1) ", %2\n\t" > + "incl (%%rax)\n\t" > + "movl (%%rax), %0\n\t" > + "decl %0\n\t" > + : "=r" (tail_call_cnt) > + : "m" (this_cpu_off), "r" (&bpf_tail_call_cnt) > + ); > + > + return tail_call_cnt; > +} > + > /* > * Generate the following code: > * > @@ -594,7 +641,6 @@ static void emit_bpf_tail_call_indirect(struct bpf_prog *bpf_prog, > u32 stack_depth, u8 *ip, > struct jit_context *ctx) > { > - int tcc_off = -4 - round_up(stack_depth, 8); > u8 *prog = *pprog, *start = *pprog; > int offset; > > @@ -615,17 +661,14 @@ static void emit_bpf_tail_call_indirect(struct bpf_prog *bpf_prog, > offset = ctx->tail_call_indirect_label - (prog + 2 - start); > EMIT2(X86_JBE, offset); /* jbe out */ > > - /* > - * if (tail_call_cnt++ >= MAX_TAIL_CALL_CNT) > + /* if (bpf_tail_call_cnt_fetch_and_inc() >= MAX_TAIL_CALL_CNT) > * goto out; > */ > - EMIT2_off32(0x8B, 0x85, tcc_off); /* mov eax, dword ptr [rbp - tcc_off] */ > + emit_call(&prog, bpf_tail_call_cnt_fetch_and_inc, ip + (prog - start)); > EMIT3(0x83, 0xF8, MAX_TAIL_CALL_CNT); /* cmp eax, MAX_TAIL_CALL_CNT */ > > offset = ctx->tail_call_indirect_label - (prog + 2 - start); > EMIT2(X86_JAE, offset); /* jae out */ > - EMIT3(0x83, 0xC0, 0x01); /* add eax, 1 */ > - EMIT2_off32(0x89, 0x85, tcc_off); /* mov dword ptr [rbp - tcc_off], eax */ > > /* prog = array->ptrs[index]; */ > EMIT4_off32(0x48, 0x8B, 0x8C, 0xD6, /* mov rcx, [rsi + rdx * 8 + offsetof(...)] */ > @@ -647,7 +690,6 @@ static void emit_bpf_tail_call_indirect(struct bpf_prog *bpf_prog, > pop_callee_regs(&prog, callee_regs_used); > } > > - EMIT1(0x58); /* pop rax */ > if (stack_depth) > EMIT3_off32(0x48, 0x81, 0xC4, /* add rsp, sd */ > round_up(stack_depth, 8)); > @@ -675,21 +717,17 @@ static void emit_bpf_tail_call_direct(struct bpf_prog *bpf_prog, > bool *callee_regs_used, u32 stack_depth, > struct jit_context *ctx) > { > - int tcc_off = -4 - round_up(stack_depth, 8); > u8 *prog = *pprog, *start = *pprog; > int offset; > > - /* > - * if (tail_call_cnt++ >= MAX_TAIL_CALL_CNT) > + /* if (bpf_tail_call_cnt_fetch_and_inc() >= MAX_TAIL_CALL_CNT) > * goto out; > */ > - EMIT2_off32(0x8B, 0x85, tcc_off); /* mov eax, dword ptr [rbp - tcc_off] */ > + emit_call(&prog, bpf_tail_call_cnt_fetch_and_inc, ip); > EMIT3(0x83, 0xF8, MAX_TAIL_CALL_CNT); /* cmp eax, MAX_TAIL_CALL_CNT */ > > offset = ctx->tail_call_direct_label - (prog + 2 - start); > EMIT2(X86_JAE, offset); /* jae out */ > - EMIT3(0x83, 0xC0, 0x01); /* add eax, 1 */ > - EMIT2_off32(0x89, 0x85, tcc_off); /* mov dword ptr [rbp - tcc_off], eax */ > > poke->tailcall_bypass = ip + (prog - start); > poke->adj_off = X86_TAIL_CALL_OFFSET; > @@ -706,7 +744,6 @@ static void emit_bpf_tail_call_direct(struct bpf_prog *bpf_prog, > pop_callee_regs(&prog, callee_regs_used); > } > > - EMIT1(0x58); /* pop rax */ > if (stack_depth) > EMIT3_off32(0x48, 0x81, 0xC4, round_up(stack_depth, 8)); > > @@ -1133,10 +1170,6 @@ static void emit_shiftx(u8 **pprog, u32 dst_reg, u8 src_reg, bool is64, u8 op) > > #define INSN_SZ_DIFF (((addrs[i] - addrs[i - 
1]) - (prog - temp))) > > -/* mov rax, qword ptr [rbp - rounded_stack_depth - 8] */ > -#define RESTORE_TAIL_CALL_CNT(stack) \ > - EMIT3_off32(0x48, 0x8B, 0x85, -round_up(stack, 8) - 8) > - > static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image, u8 *rw_image, > int oldproglen, struct jit_context *ctx, bool jmp_padding) > { > @@ -1160,7 +1193,8 @@ static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image, u8 *rw_image > > emit_prologue(&prog, bpf_prog->aux->stack_depth, > bpf_prog_was_classic(bpf_prog), tail_call_reachable, > - bpf_is_subprog(bpf_prog), bpf_prog->aux->exception_cb); > + bpf_is_subprog(bpf_prog), bpf_prog->aux->exception_cb, > + image); > /* Exception callback will clobber callee regs for its own use, and > * restore the original callee regs from main prog's stack frame. > */ > @@ -1752,17 +1786,12 @@ st: if (is_imm8(insn->off)) > case BPF_JMP | BPF_CALL: { > int offs; > > + if (!imm32) > + return -EINVAL; > + > func = (u8 *) __bpf_call_base + imm32; > - if (tail_call_reachable) { > - RESTORE_TAIL_CALL_CNT(bpf_prog->aux->stack_depth); > - if (!imm32) > - return -EINVAL; > - offs = 7 + x86_call_depth_emit_accounting(&prog, func); > - } else { > - if (!imm32) > - return -EINVAL; > - offs = x86_call_depth_emit_accounting(&prog, func); > - } > + offs = x86_call_depth_emit_accounting(&prog, func); > + > if (emit_call(&prog, func, image + addrs[i - 1] + offs)) > return -EINVAL; > break; > @@ -2550,7 +2579,6 @@ static int __arch_prepare_bpf_trampoline(struct bpf_tramp_image *im, void *rw_im > * [ ... ] > * [ stack_arg2 ] > * RBP - arg_stack_off [ stack_arg1 ] > - * RSP [ tail_call_cnt ] BPF_TRAMP_F_TAIL_CALL_CTX > */ > > /* room for return value of orig_call or fentry prog */ > @@ -2622,8 +2650,6 @@ static int __arch_prepare_bpf_trampoline(struct bpf_tramp_image *im, void *rw_im > /* sub rsp, stack_size */ > EMIT4(0x48, 0x83, 0xEC, stack_size); > } > - if (flags & BPF_TRAMP_F_TAIL_CALL_CTX) > - EMIT1(0x50); /* push rax */ > /* mov QWORD PTR [rbp - rbx_off], rbx */ > emit_stx(&prog, BPF_DW, BPF_REG_FP, BPF_REG_6, -rbx_off); > > @@ -2678,16 +2704,9 @@ static int __arch_prepare_bpf_trampoline(struct bpf_tramp_image *im, void *rw_im > restore_regs(m, &prog, regs_off); > save_args(m, &prog, arg_stack_off, true); > > - if (flags & BPF_TRAMP_F_TAIL_CALL_CTX) { > - /* Before calling the original function, restore the > - * tail_call_cnt from stack to rax. > - */ > - RESTORE_TAIL_CALL_CNT(stack_size); > - } > - > if (flags & BPF_TRAMP_F_ORIG_STACK) { > - emit_ldx(&prog, BPF_DW, BPF_REG_6, BPF_REG_FP, 8); > - EMIT2(0xff, 0xd3); /* call *rbx */ > + emit_ldx(&prog, BPF_DW, BPF_REG_0, BPF_REG_FP, 8); > + EMIT2(0xff, 0xd0); /* call *rax */ > } else { > /* call original function */ > if (emit_rsb_call(&prog, orig_call, image + (prog - (u8 *)rw_image))) { > @@ -2740,11 +2759,6 @@ static int __arch_prepare_bpf_trampoline(struct bpf_tramp_image *im, void *rw_im > ret = -EINVAL; > goto cleanup; > } > - } else if (flags & BPF_TRAMP_F_TAIL_CALL_CTX) { > - /* Before running the original function, restore the > - * tail_call_cnt from stack to rax. > - */ > - RESTORE_TAIL_CALL_CNT(stack_size); > } > > /* restore return value of orig_call or fentry prog back into RAX */
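For readers skimming the inline asm above, here is a plain-C sketch of what the two per-CPU helpers are meant to do, taken directly from the patch's own comments (the real versions stay in asm so that only %rax is touched):

static void bpf_tail_call_cnt_prepare_sketch(void)
{
	u32 *tcc_ptr = this_cpu_ptr(&bpf_tail_call_cnt);

	*tcc_ptr = 0;
}

static u32 bpf_tail_call_cnt_fetch_and_inc_sketch(void)
{
	u32 *tcc_ptr = this_cpu_ptr(&bpf_tail_call_cnt);

	return (*tcc_ptr)++;	/* return the pre-increment value */
}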
On 2024/2/23 12:06, Pu Lehui wrote:
>
>
> On 2024/2/22 16:52, Leon Hwang wrote:

[SNIP]

>> }
>> @@ -575,6 +574,54 @@ static void emit_return(u8 **pprog, u8 *ip)
>>      *pprog = prog;
>> }
>> +DEFINE_PER_CPU(u32, bpf_tail_call_cnt);
>
> Hi Leon, the solution is really simplifies complexity. If I understand
> correctly, this TAIL_CALL_CNT becomes the system global wise, not the
> prog global wise, but before it was limiting the TCC of entry prog.
>

Correct. It becomes a PERCPU global variable.

But I think this solution is not robust enough.

For example:

time   prog1              prog1
==================================>
line            prog2

This is a timeline on one CPU. If prog1 and prog2 both have tailcalls to
run, prog2 will reset the tail_call_cnt on the current CPU, which prog1 is
still using. As a result, when the CPU schedules from prog2 back to prog1,
the tail_call_cnt on the current CPU has been reset to 0, no matter how far
prog1 had incremented it.

The tail_call_cnt reset issue can happen even if the PERCPU tail_call_cnt
moves to 'struct bpf_prog_aux', because one kprobe bpf prog can be attached
to many functions, e.g. cilium/pwru. However, that move is still better
than this solution.

I think my previous POC of 'struct bpf_prog_run_ctx' would be better.
I'll resend it later, with some improvements.

Thanks,
Leon
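To make the reset hazard above concrete, a toy user-space simulation (plain C, not kernel code; only the MAX_TAIL_CALL_CNT value is taken from the kernel, everything else is illustrative):

#include <stdio.h>

#define MAX_TAIL_CALL_CNT 33

static unsigned int tail_call_cnt;	/* stands in for the shared per-CPU counter */

static void prologue(void)  { tail_call_cnt = 0; }
static int  tail_call(void) { return tail_call_cnt++ < MAX_TAIL_CALL_CNT ? 0 : -1; }

int main(void)
{
	prologue();			/* prog1 enters */
	for (int i = 0; i < 20; i++)
		tail_call();		/* prog1 has used 20 of its 33 tail calls */

	prologue();			/* prog2 preempts prog1 and resets the shared counter */

	/* back in prog1: its remaining tail calls are now checked against 0, not 20 */
	printf("prog1 now sees tail_call_cnt = %u\n", tail_call_cnt);
	return 0;
}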
On Fri, Feb 23, 2024 at 7:30 AM Leon Hwang <hffilwlqm@gmail.com> wrote: > > > > On 2024/2/23 12:06, Pu Lehui wrote: > > > > > > On 2024/2/22 16:52, Leon Hwang wrote: > > [SNIP] > > >> } > >> @@ -575,6 +574,54 @@ static void emit_return(u8 **pprog, u8 *ip) > >> *pprog = prog; > >> } > >> +DEFINE_PER_CPU(u32, bpf_tail_call_cnt); > > > > Hi Leon, the solution is really simplifies complexity. If I understand > > correctly, this TAIL_CALL_CNT becomes the system global wise, not the > > prog global wise, but before it was limiting the TCC of entry prog. > > > > Correct. It becomes a PERCPU global variable. > > But, I think this solution is not robust enough. > > For example, > > time prog1 prog1 > ==================================> > line prog2 > > this is a time-line on a CPU. If prog1 and prog2 have tailcalls to run, > prog2 will reset the tail_call_cnt on current CPU, which is used by > prog1. As a result, when the CPU schedules from prog2 to prog1, > tail_call_cnt on current CPU has been reset to 0, no matter whether > prog1 incremented it. > > The tail_call_cnt reset issue happens too, even if PERCPU tail_call_cnt > moves to 'struct bpf_prog_aux', i.e. one kprobe bpf prog can be > triggered on many functions e.g. cilium/pwru. However, this moving is > better than this solution. kprobe progs are not preemptable. There is bpf_prog_active that disallows any recursion. Moving this percpu count to prog->aux should solve it. > I think, my previous POC of 'struct bpf_prog_run_ctx' would be better. > I'll resend it later, with some improvements. percpu approach is still prefered, since it removes rax mess.
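The recursion guard Alexei refers to is the per-CPU bpf_prog_active counter checked before running tracing programs; its core looks roughly like the following (simplified from kernel/trace/bpf_trace.c, not the exact code):

	if (unlikely(__this_cpu_inc_return(bpf_prog_active) != 1)) {
		/* another bpf program is already running on this CPU:
		 * don't call into a second one, so tracing progs cannot
		 * recurse into each other
		 */
		ret = 0;
		goto out;
	}
	...
out:
	__this_cpu_dec(bpf_prog_active);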
On Thu, Feb 22, 2024 at 12:53 AM Leon Hwang <hffilwlqm@gmail.com> wrote: > > +DEFINE_PER_CPU(u32, bpf_tail_call_cnt); > + > +static __used void bpf_tail_call_cnt_prepare(void) > +{ > + /* The following asm equals to > + * > + * u32 *tcc_ptr = this_cpu_ptr(&bpf_tail_call_cnt); > + * > + * *tcc_ptr = 0; > + * > + * This asm must uses %rax only. > + */ > + > + asm volatile ( > + "addq " __percpu_arg(0) ", %1\n\t" > + "movl $0, (%%rax)\n\t" This looks wrong. Should probably be "movl $0, (%1)" ? > + : > + : "m" (this_cpu_off), "r" (&bpf_tail_call_cnt) > + ); > +} > + > +static __used u32 bpf_tail_call_cnt_fetch_and_inc(void) > +{ > + u32 tail_call_cnt; > + > + /* The following asm equals to > + * > + * u32 *tcc_ptr = this_cpu_ptr(&bpf_tail_call_cnt); > + * > + * (*tcc_ptr)++; > + * tail_call_cnt = *tcc_ptr; > + * tail_call_cnt--; > + * > + * This asm must uses %rax only. > + */ > + > + asm volatile ( > + "addq " __percpu_arg(1) ", %2\n\t" > + "incl (%%rax)\n\t" > + "movl (%%rax), %0\n\t" and %2 here instead of rax ? > + "decl %0\n\t" > + : "=r" (tail_call_cnt) > + : "m" (this_cpu_off), "r" (&bpf_tail_call_cnt) > + ); > +
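Folding both of Alexei's corrections in, a revised sketch of the first helper could look like this; pinning the operand to %rax with a register variable and declaring it as an in/out operand are my assumptions, not something spelled out in the thread:

static __used void bpf_tail_call_cnt_prepare(void)
{
	/* Keep the pointer in %rax so the JIT's "only %rax is clobbered"
	 * assumption still holds, and dereference the same operand that
	 * the addq just updated.
	 */
	register u32 *tcc_ptr asm("rax") = (u32 *)&bpf_tail_call_cnt;

	asm volatile (
		"addq " __percpu_arg(1) ", %0\n\t"	/* tcc_ptr += this_cpu_off */
		"movl $0, (%0)\n\t"			/* *tcc_ptr = 0 */
		: "+r" (tcc_ptr)
		: "m" (this_cpu_off)
	);
}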
On 2024/2/24 00:35, Alexei Starovoitov wrote: > On Fri, Feb 23, 2024 at 7:30 AM Leon Hwang <hffilwlqm@gmail.com> wrote: >> >> >> >> On 2024/2/23 12:06, Pu Lehui wrote: >>> >>> >>> On 2024/2/22 16:52, Leon Hwang wrote: >> >> [SNIP] >> >>>> } >>>> @@ -575,6 +574,54 @@ static void emit_return(u8 **pprog, u8 *ip) >>>> *pprog = prog; >>>> } >>>> +DEFINE_PER_CPU(u32, bpf_tail_call_cnt); >>> >>> Hi Leon, the solution is really simplifies complexity. If I understand >>> correctly, this TAIL_CALL_CNT becomes the system global wise, not the >>> prog global wise, but before it was limiting the TCC of entry prog. >>> >> >> Correct. It becomes a PERCPU global variable. >> >> But, I think this solution is not robust enough. >> >> For example, >> >> time prog1 prog1 >> ==================================> >> line prog2 >> >> this is a time-line on a CPU. If prog1 and prog2 have tailcalls to run, >> prog2 will reset the tail_call_cnt on current CPU, which is used by >> prog1. As a result, when the CPU schedules from prog2 to prog1, >> tail_call_cnt on current CPU has been reset to 0, no matter whether >> prog1 incremented it. >> >> The tail_call_cnt reset issue happens too, even if PERCPU tail_call_cnt >> moves to 'struct bpf_prog_aux', i.e. one kprobe bpf prog can be >> triggered on many functions e.g. cilium/pwru. However, this moving is >> better than this solution. > > kprobe progs are not preemptable. > There is bpf_prog_active that disallows any recursion. > Moving this percpu count to prog->aux should solve it. > >> I think, my previous POC of 'struct bpf_prog_run_ctx' would be better. >> I'll resend it later, with some improvements. > > percpu approach is still prefered, since it removes rax mess. It seems that we cannot remove rax. Let's take a look at tailcall3.c selftest: struct { __uint(type, BPF_MAP_TYPE_PROG_ARRAY); __uint(max_entries, 1); __uint(key_size, sizeof(__u32)); __uint(value_size, sizeof(__u32)); } jmp_table SEC(".maps"); int count = 0; SEC("tc") int classifier_0(struct __sk_buff *skb) { count++; bpf_tail_call_static(skb, &jmp_table, 0); return 1; } SEC("tc") int entry(struct __sk_buff *skb) { bpf_tail_call_static(skb, &jmp_table, 0); return 0; } Here, classifier_0 is populated to jmp_table. Then, at classifier_0's prologue, when we 'move rax, classifier_0->tail_call_cnt' in order to use the PERCPU tail_call_cnt in 'struct bpf_prog' for current run-time, it fails to run selftests. It's because the tail_call_cnt is not from the entry bpf prog. The tail_call_cnt from the entry bpf prog is the expected one, even though classifier_0 bpf prog runs. (It seems that it's unnecessary to provide the diff of the exclusive approach with PERCPU tail_call_cnt.) Next, I tried a POC with PERCPU tail_call_cnt in 'struct bpf_prog' and rax: 1. At prologue, initialise tail_call_cnt from bpf_prog's tail_call_cnt. 2. Propagate tail_call_cnt pointer by the previous rax way. 
Here's the diff for the POC: diff --git a/arch/x86/net/bpf_jit_comp.c b/arch/x86/net/bpf_jit_comp.c index e1390d1e3..54f5770d9 100644 --- a/arch/x86/net/bpf_jit_comp.c +++ b/arch/x86/net/bpf_jit_comp.c @@ -18,18 +18,22 @@ #include <asm/text-patching.h> #include <asm/unwind.h> #include <asm/cfi.h> +#include <asm/percpu.h> static bool all_callee_regs_used[4] = {true, true, true, true}; -static u8 *emit_code(u8 *ptr, u32 bytes, unsigned int len) +static u8 *emit_code(u8 *ptr, u64 bytes, unsigned int len) { if (len == 1) *ptr = bytes; else if (len == 2) *(u16 *)ptr = bytes; - else { + else if (len == 4) { *(u32 *)ptr = bytes; barrier(); + } else { + *(u64 *)ptr = bytes; + barrier(); } return ptr + len; } @@ -51,6 +55,9 @@ static u8 *emit_code(u8 *ptr, u32 bytes, unsigned int len) #define EMIT4_off32(b1, b2, b3, b4, off) \ do { EMIT4(b1, b2, b3, b4); EMIT(off, 4); } while (0) +#define EMIT2_off64(b1, b2, off) \ + do { EMIT2(b1, b2); EMIT(off, 8); } while(0) + #ifdef CONFIG_X86_KERNEL_IBT #define EMIT_ENDBR() EMIT(gen_endbr(), 4) #define EMIT_ENDBR_POISON() EMIT(gen_endbr_poison(), 4) @@ -259,7 +266,7 @@ struct jit_context { /* Number of bytes emit_patch() needs to generate instructions */ #define X86_PATCH_SIZE 5 /* Number of bytes that will be skipped on tailcall */ -#define X86_TAIL_CALL_OFFSET (11 + ENDBR_INSN_SIZE) +#define X86_TAIL_CALL_OFFSET (24 + ENDBR_INSN_SIZE) static void push_r12(u8 **pprog) { @@ -389,16 +396,40 @@ static void emit_cfi(u8 **pprog, u32 hash) *pprog = prog; } + +static __used void bpf_tail_call_cnt_prepare(void) +{ + /* The following asm equals to + * + * u32 *tcc_ptr = this_cpu_ptr(prog->aux->entry->tail_call_cnt); + * + * *tcc_ptr = 0; + * + * This asm must uses %rax only. + */ + + /* %rax has been set as prog->aux->entry->tail_call_cnt. */ + asm volatile ( + "addq " __percpu_arg(0) ", %%rax\n\t" + "movl $0, (%%rax)\n\t" + : + : "m" (this_cpu_off) + ); +} + +static int emit_call(u8 **pprog, void *func, void *ip); + /* * Emit x86-64 prologue code for BPF program. * bpf_tail_call helper will skip the first X86_TAIL_CALL_OFFSET bytes * while jumping to another program */ -static void emit_prologue(u8 **pprog, u32 stack_depth, bool ebpf_from_cbpf, - bool tail_call_reachable, bool is_subprog, - bool is_exception_cb) +static void emit_prologue(struct bpf_prog *bpf_prog, u8 **pprog, u32 stack_depth, + bool ebpf_from_cbpf, bool tail_call_reachable, + bool is_subprog, bool is_exception_cb, u8 *ip) { - u8 *prog = *pprog; + struct bpf_prog *entry = bpf_prog->aux->entry; + u8 *prog = *pprog, *start = *pprog; emit_cfi(&prog, is_subprog ? cfi_bpf_subprog_hash : cfi_bpf_hash); /* BPF trampoline can be made to work without these nops, @@ -406,14 +437,16 @@ static void emit_prologue(u8 **pprog, u32 stack_depth, bool ebpf_from_cbpf, */ emit_nops(&prog, X86_PATCH_SIZE); if (!ebpf_from_cbpf) { - if (tail_call_reachable && !is_subprog) - /* When it's the entry of the whole tailcall context, - * zeroing rax means initialising tail_call_cnt. - */ - EMIT2(0x31, 0xC0); /* xor eax, eax */ - else + if (tail_call_reachable && !is_subprog) { + /* mov rax, entry->tail_call_cnt */ + EMIT2_off64(0x48, 0xB8, (u64) entry->tail_call_cnt); + /* call bpf_tail_call_cnt_prepare */ + emit_call(&prog, bpf_tail_call_cnt_prepare, + ip + (prog - start)); + } else { /* Keep the same instruction layout. 
*/ - EMIT2(0x66, 0x90); /* nop2 */ + emit_nops(&prog, 10 + X86_PATCH_SIZE); + } } /* Exception callback receives FP as third parameter */ if (is_exception_cb) { @@ -581,7 +614,7 @@ static void emit_return(u8 **pprog, u8 *ip) * ... bpf_tail_call(void *ctx, struct bpf_array *array, u64 index) ... * if (index >= array->map.max_entries) * goto out; - * if (tail_call_cnt++ >= MAX_TAIL_CALL_CNT) + * if ((*tcc_ptr)++ >= MAX_TAIL_CALL_CNT) * goto out; * prog = array->ptrs[index]; * if (prog == NULL) @@ -594,7 +627,7 @@ static void emit_bpf_tail_call_indirect(struct bpf_prog *bpf_prog, u32 stack_depth, u8 *ip, struct jit_context *ctx) { - int tcc_off = -4 - round_up(stack_depth, 8); + int tcc_ptr_off = -8 - round_up(stack_depth, 8); u8 *prog = *pprog, *start = *pprog; int offset; @@ -616,16 +649,15 @@ static void emit_bpf_tail_call_indirect(struct bpf_prog *bpf_prog, EMIT2(X86_JBE, offset); /* jbe out */ /* - * if (tail_call_cnt++ >= MAX_TAIL_CALL_CNT) + * if ((*tcc_ptr)++ >= MAX_TAIL_CALL_CNT) * goto out; */ - EMIT2_off32(0x8B, 0x85, tcc_off); /* mov eax, dword ptr [rbp - tcc_off] */ - EMIT3(0x83, 0xF8, MAX_TAIL_CALL_CNT); /* cmp eax, MAX_TAIL_CALL_CNT */ + EMIT3_off32(0x48, 0x8B, 0x85, tcc_ptr_off); /* mov rax, qword ptr [rbp - tcc_ptr_off] */ + EMIT3(0x83, 0x38, MAX_TAIL_CALL_CNT); /* cmp dword ptr [rax], MAX_TAIL_CALL_CNT */ offset = ctx->tail_call_indirect_label - (prog + 2 - start); EMIT2(X86_JAE, offset); /* jae out */ - EMIT3(0x83, 0xC0, 0x01); /* add eax, 1 */ - EMIT2_off32(0x89, 0x85, tcc_off); /* mov dword ptr [rbp - tcc_off], eax */ + EMIT3(0x83, 0x00, 0x01); /* add dword ptr [rax], 1 */ /* prog = array->ptrs[index]; */ EMIT4_off32(0x48, 0x8B, 0x8C, 0xD6, /* mov rcx, [rsi + rdx * 8 + offsetof(...)] */ @@ -647,6 +679,7 @@ static void emit_bpf_tail_call_indirect(struct bpf_prog *bpf_prog, pop_callee_regs(&prog, callee_regs_used); } + /* pop tail_call_cnt_ptr */ EMIT1(0x58); /* pop rax */ if (stack_depth) EMIT3_off32(0x48, 0x81, 0xC4, /* add rsp, sd */ @@ -675,21 +708,20 @@ static void emit_bpf_tail_call_direct(struct bpf_prog *bpf_prog, bool *callee_regs_used, u32 stack_depth, struct jit_context *ctx) { - int tcc_off = -4 - round_up(stack_depth, 8); + int tcc_ptr_off = -8 - round_up(stack_depth, 8); u8 *prog = *pprog, *start = *pprog; int offset; /* - * if (tail_call_cnt++ >= MAX_TAIL_CALL_CNT) + * if ((*tcc_ptr)++ >= MAX_TAIL_CALL_CNT) * goto out; */ - EMIT2_off32(0x8B, 0x85, tcc_off); /* mov eax, dword ptr [rbp - tcc_off] */ - EMIT3(0x83, 0xF8, MAX_TAIL_CALL_CNT); /* cmp eax, MAX_TAIL_CALL_CNT */ + EMIT3_off32(0x48, 0x8B, 0x85, tcc_ptr_off); /* mov rax, qword ptr [rbp - tcc_ptr_off] */ + EMIT3(0x83, 0x38, MAX_TAIL_CALL_CNT); /* cmp dword ptr [rax], MAX_TAIL_CALL_CNT */ offset = ctx->tail_call_direct_label - (prog + 2 - start); EMIT2(X86_JAE, offset); /* jae out */ - EMIT3(0x83, 0xC0, 0x01); /* add eax, 1 */ - EMIT2_off32(0x89, 0x85, tcc_off); /* mov dword ptr [rbp - tcc_off], eax */ + EMIT3(0x83, 0x00, 0x01); /* add dword ptr [rax], 1 */ poke->tailcall_bypass = ip + (prog - start); poke->adj_off = X86_TAIL_CALL_OFFSET; @@ -706,6 +738,7 @@ static void emit_bpf_tail_call_direct(struct bpf_prog *bpf_prog, pop_callee_regs(&prog, callee_regs_used); } + /* pop tail_call_cnt_ptr */ EMIT1(0x58); /* pop rax */ if (stack_depth) EMIT3_off32(0x48, 0x81, 0xC4, round_up(stack_depth, 8)); @@ -1134,7 +1167,7 @@ static void emit_shiftx(u8 **pprog, u32 dst_reg, u8 src_reg, bool is64, u8 op) #define INSN_SZ_DIFF (((addrs[i] - addrs[i - 1]) - (prog - temp))) /* mov rax, qword ptr [rbp - rounded_stack_depth 
- 8] */ -#define RESTORE_TAIL_CALL_CNT(stack) \ +#define LOAD_TAIL_CALL_CNT_PTR(stack) \ EMIT3_off32(0x48, 0x8B, 0x85, -round_up(stack, 8) - 8) static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image, u8 *rw_image, @@ -1158,9 +1191,10 @@ static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image, u8 *rw_image /* tail call's presence in current prog implies it is reachable */ tail_call_reachable |= tail_call_seen; - emit_prologue(&prog, bpf_prog->aux->stack_depth, + emit_prologue(bpf_prog, &prog, bpf_prog->aux->stack_depth, bpf_prog_was_classic(bpf_prog), tail_call_reachable, - bpf_is_subprog(bpf_prog), bpf_prog->aux->exception_cb); + bpf_is_subprog(bpf_prog), bpf_prog->aux->exception_cb, + image); /* Exception callback will clobber callee regs for its own use, and * restore the original callee regs from main prog's stack frame. */ @@ -1754,7 +1788,7 @@ st: if (is_imm8(insn->off)) func = (u8 *) __bpf_call_base + imm32; if (tail_call_reachable) { - RESTORE_TAIL_CALL_CNT(bpf_prog->aux->stack_depth); + LOAD_TAIL_CALL_CNT_PTR(bpf_prog->aux->stack_depth); if (!imm32) return -EINVAL; offs = 7 + x86_call_depth_emit_accounting(&prog, func); @@ -2679,10 +2713,10 @@ static int __arch_prepare_bpf_trampoline(struct bpf_tramp_image *im, void *rw_im save_args(m, &prog, arg_stack_off, true); if (flags & BPF_TRAMP_F_TAIL_CALL_CTX) { - /* Before calling the original function, restore the - * tail_call_cnt from stack to rax. + /* Before calling the original function, load the + * tail_call_cnt_ptr to rax. */ - RESTORE_TAIL_CALL_CNT(stack_size); + LOAD_TAIL_CALL_CNT_PTR(stack_size); } if (flags & BPF_TRAMP_F_ORIG_STACK) { @@ -2741,10 +2775,10 @@ static int __arch_prepare_bpf_trampoline(struct bpf_tramp_image *im, void *rw_im goto cleanup; } } else if (flags & BPF_TRAMP_F_TAIL_CALL_CTX) { - /* Before running the original function, restore the - * tail_call_cnt from stack to rax. + /* Before running the original function, load the + * tail_call_cnt_ptr to rax. */ - RESTORE_TAIL_CALL_CNT(stack_size); + LOAD_TAIL_CALL_CNT_PTR(stack_size); } /* restore return value of orig_call or fentry prog back into RAX */ diff --git a/include/linux/bpf.h b/include/linux/bpf.h index 814dc913a..5e8abcb11 100644 --- a/include/linux/bpf.h +++ b/include/linux/bpf.h @@ -1459,6 +1459,7 @@ struct bpf_prog_aux { /* function name for valid attach_btf_id */ const char *attach_func_name; struct bpf_prog **func; + struct bpf_prog *entry; void *jit_data; /* JIT specific data. 
arch dependent */ struct bpf_jit_poke_descriptor *poke_tab; struct bpf_kfunc_desc_tab *kfunc_tab; @@ -1542,6 +1543,7 @@ struct bpf_prog { u8 tag[BPF_TAG_SIZE]; struct bpf_prog_stats __percpu *stats; int __percpu *active; + u32 __percpu *tail_call_cnt; unsigned int (*bpf_func)(const void *ctx, const struct bpf_insn *insn); struct bpf_prog_aux *aux; /* Auxiliary fields */ diff --git a/kernel/bpf/core.c b/kernel/bpf/core.c index 71c459a51..7884f66fc 100644 --- a/kernel/bpf/core.c +++ b/kernel/bpf/core.c @@ -110,6 +110,13 @@ struct bpf_prog *bpf_prog_alloc_no_stats(unsigned int size, gfp_t gfp_extra_flag kfree(aux); return NULL; } + fp->tail_call_cnt = alloc_percpu_gfp(u32, bpf_memcg_flags(GFP_KERNEL | gfp_extra_flags)); + if (!fp->tail_call_cnt) { + free_percpu(fp->active); + vfree(fp); + kfree(aux); + return NULL; + } fp->pages = size / PAGE_SIZE; fp->aux = aux; @@ -142,6 +149,7 @@ struct bpf_prog *bpf_prog_alloc(unsigned int size, gfp_t gfp_extra_flags) prog->stats = alloc_percpu_gfp(struct bpf_prog_stats, gfp_flags); if (!prog->stats) { + free_percpu(prog->tail_call_cnt); free_percpu(prog->active); kfree(prog->aux); vfree(prog); @@ -261,6 +269,7 @@ struct bpf_prog *bpf_prog_realloc(struct bpf_prog *fp_old, unsigned int size, fp_old->aux = NULL; fp_old->stats = NULL; fp_old->active = NULL; + fp_old->tail_call_cnt = NULL; __bpf_prog_free(fp_old); } @@ -277,6 +286,7 @@ void __bpf_prog_free(struct bpf_prog *fp) } free_percpu(fp->stats); free_percpu(fp->active); + free_percpu(fp->tail_call_cnt); vfree(fp); } @@ -1428,6 +1438,7 @@ static void bpf_prog_clone_free(struct bpf_prog *fp) fp->aux = NULL; fp->stats = NULL; fp->active = NULL; + fp->tail_call_cnt = NULL; __bpf_prog_free(fp); } @@ -2379,6 +2390,7 @@ struct bpf_prog *bpf_prog_select_runtime(struct bpf_prog *fp, int *err) if (*err) return fp; + fp->aux->entry = fp; fp = bpf_int_jit_compile(fp); bpf_prog_jit_attempt_done(fp); if (!fp->jited && jit_needed) { diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c index 011d54a1d..442d0a4b2 100644 --- a/kernel/bpf/verifier.c +++ b/kernel/bpf/verifier.c @@ -19025,6 +19025,7 @@ static int jit_subprogs(struct bpf_verifier_env *env) } func[i]->aux->num_exentries = num_exentries; func[i]->aux->tail_call_reachable = env->subprog_info[i].tail_call_reachable; + func[i]->aux->entry = prog; func[i]->aux->exception_cb = env->subprog_info[i].is_exception_cb; if (!i) func[i]->aux->exception_boundary = env->seen_exception; Thanks, Leon
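A note on the X86_TAIL_CALL_OFFSET change in this POC: the bytes a tail call has to skip in the new prologue break down as follows (my own accounting, derived from the instructions emitted above):

	endbr64					/* ENDBR_INSN_SIZE, when IBT is enabled */
	nop5					/*  5 bytes, patchable (X86_PATCH_SIZE) */
	movabs rax, entry->tail_call_cnt	/* 10 bytes: 0x48 0xB8 + imm64 */
	call   bpf_tail_call_cnt_prepare	/*  5 bytes */
	push   rbp				/*  1 byte */
	mov    rbp, rsp				/*  3 bytes */
						/* 5 + 10 + 5 + 1 + 3 = 24 (+ ENDBR_INSN_SIZE) */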
On 2024/2/26 23:32, Leon Hwang wrote: > > > On 2024/2/24 00:35, Alexei Starovoitov wrote: >> On Fri, Feb 23, 2024 at 7:30 AM Leon Hwang <hffilwlqm@gmail.com> wrote: >>> >>> >>> >>> On 2024/2/23 12:06, Pu Lehui wrote: >>>> >>>> >>>> On 2024/2/22 16:52, Leon Hwang wrote: >>> >>> [SNIP] >>> >>>>> } >>>>> @@ -575,6 +574,54 @@ static void emit_return(u8 **pprog, u8 *ip) >>>>> *pprog = prog; >>>>> } >>>>> +DEFINE_PER_CPU(u32, bpf_tail_call_cnt); >>>> >>>> Hi Leon, the solution is really simplifies complexity. If I understand >>>> correctly, this TAIL_CALL_CNT becomes the system global wise, not the >>>> prog global wise, but before it was limiting the TCC of entry prog. >>>> >>> >>> Correct. It becomes a PERCPU global variable. >>> >>> But, I think this solution is not robust enough. >>> >>> For example, >>> >>> time prog1 prog1 >>> ==================================> >>> line prog2 >>> >>> this is a time-line on a CPU. If prog1 and prog2 have tailcalls to run, >>> prog2 will reset the tail_call_cnt on current CPU, which is used by >>> prog1. As a result, when the CPU schedules from prog2 to prog1, >>> tail_call_cnt on current CPU has been reset to 0, no matter whether >>> prog1 incremented it. >>> >>> The tail_call_cnt reset issue happens too, even if PERCPU tail_call_cnt >>> moves to 'struct bpf_prog_aux', i.e. one kprobe bpf prog can be >>> triggered on many functions e.g. cilium/pwru. However, this moving is >>> better than this solution. >> >> kprobe progs are not preemptable. >> There is bpf_prog_active that disallows any recursion. >> Moving this percpu count to prog->aux should solve it. >> >>> I think, my previous POC of 'struct bpf_prog_run_ctx' would be better. >>> I'll resend it later, with some improvements. >> >> percpu approach is still prefered, since it removes rax mess. > > It seems that we cannot remove rax. > > Let's take a look at tailcall3.c selftest: > > struct { > __uint(type, BPF_MAP_TYPE_PROG_ARRAY); > __uint(max_entries, 1); > __uint(key_size, sizeof(__u32)); > __uint(value_size, sizeof(__u32)); > } jmp_table SEC(".maps"); > > int count = 0; > > SEC("tc") > int classifier_0(struct __sk_buff *skb) > { > count++; > bpf_tail_call_static(skb, &jmp_table, 0); > return 1; > } > > SEC("tc") > int entry(struct __sk_buff *skb) > { > bpf_tail_call_static(skb, &jmp_table, 0); > return 0; > } > > Here, classifier_0 is populated to jmp_table. > > Then, at classifier_0's prologue, when we 'move rax, > classifier_0->tail_call_cnt' in order to use the PERCPU tail_call_cnt in > 'struct bpf_prog' for current run-time, it fails to run selftests. It's > because the tail_call_cnt is not from the entry bpf prog. The > tail_call_cnt from the entry bpf prog is the expected one, even though > classifier_0 bpf prog runs. (It seems that it's unnecessary to provide > the diff of the exclusive approach with PERCPU tail_call_cnt.) Sorry for the unclear message. It should be emit_bpf_tail_call_xxx() instead of emit_prologue(). I think it's better to provide the diff of the exclusive approach with PERCPU tail_call_cnt, in order to compare these two approachs. P.S. This POC failed to pass all selftests, like "Summary: 0/13 PASSED, 0 SKIPPED, 1 FAILED". 
diff --git a/arch/x86/net/bpf_jit_comp.c b/arch/x86/net/bpf_jit_comp.c index e1390d1e3..695c99c0f 100644 --- a/arch/x86/net/bpf_jit_comp.c +++ b/arch/x86/net/bpf_jit_comp.c @@ -18,18 +18,22 @@ #include <asm/text-patching.h> #include <asm/unwind.h> #include <asm/cfi.h> +#include <asm/percpu.h> static bool all_callee_regs_used[4] = {true, true, true, true}; -static u8 *emit_code(u8 *ptr, u32 bytes, unsigned int len) +static u8 *emit_code(u8 *ptr, u64 bytes, unsigned int len) { if (len == 1) *ptr = bytes; else if (len == 2) *(u16 *)ptr = bytes; - else { + else if (len == 4) { *(u32 *)ptr = bytes; barrier(); + } else { + *(u64 *)ptr = bytes; + barrier(); } return ptr + len; } @@ -51,6 +55,9 @@ static u8 *emit_code(u8 *ptr, u32 bytes, unsigned int len) #define EMIT4_off32(b1, b2, b3, b4, off) \ do { EMIT4(b1, b2, b3, b4); EMIT(off, 4); } while (0) +#define EMIT2_off64(b1, b2, off) \ + do { EMIT2(b1, b2); EMIT(off, 8); } while(0) + #ifdef CONFIG_X86_KERNEL_IBT #define EMIT_ENDBR() EMIT(gen_endbr(), 4) #define EMIT_ENDBR_POISON() EMIT(gen_endbr_poison(), 4) @@ -259,7 +266,7 @@ struct jit_context { /* Number of bytes emit_patch() needs to generate instructions */ #define X86_PATCH_SIZE 5 /* Number of bytes that will be skipped on tailcall */ -#define X86_TAIL_CALL_OFFSET (11 + ENDBR_INSN_SIZE) +#define X86_TAIL_CALL_OFFSET (24 + ENDBR_INSN_SIZE) static void push_r12(u8 **pprog) { @@ -389,16 +396,20 @@ static void emit_cfi(u8 **pprog, u32 hash) *pprog = prog; } +static int emit_call(u8 **pprog, void *func, void *ip); +static __used void bpf_tail_call_cnt_prepare(void); + /* * Emit x86-64 prologue code for BPF program. * bpf_tail_call helper will skip the first X86_TAIL_CALL_OFFSET bytes * while jumping to another program */ -static void emit_prologue(u8 **pprog, u32 stack_depth, bool ebpf_from_cbpf, - bool tail_call_reachable, bool is_subprog, - bool is_exception_cb) +static void emit_prologue(struct bpf_prog *bpf_prog, u8 **pprog, u32 stack_depth, + bool ebpf_from_cbpf, bool tail_call_reachable, + bool is_subprog, bool is_exception_cb, u8 *ip) { - u8 *prog = *pprog; + struct bpf_prog *entry = bpf_prog->aux->entry; + u8 *prog = *pprog, *start = *pprog; emit_cfi(&prog, is_subprog ? cfi_bpf_subprog_hash : cfi_bpf_hash); /* BPF trampoline can be made to work without these nops, @@ -406,14 +417,16 @@ static void emit_prologue(u8 **pprog, u32 stack_depth, bool ebpf_from_cbpf, */ emit_nops(&prog, X86_PATCH_SIZE); if (!ebpf_from_cbpf) { - if (tail_call_reachable && !is_subprog) - /* When it's the entry of the whole tailcall context, - * zeroing rax means initialising tail_call_cnt. - */ - EMIT2(0x31, 0xC0); /* xor eax, eax */ - else + if (tail_call_reachable && !is_subprog) { + /* mov rax, entry->tail_call_cnt */ + EMIT2_off64(0x48, 0xB8, (u64) entry->tail_call_cnt); + /* call bpf_tail_call_cnt_prepare */ + emit_call(&prog, bpf_tail_call_cnt_prepare, + ip + (prog - start)); + } else { /* Keep the same instruction layout. 
*/ - EMIT2(0x66, 0x90); /* nop2 */ + emit_nops(&prog, 10 + X86_PATCH_SIZE); + } } /* Exception callback receives FP as third parameter */ if (is_exception_cb) { @@ -438,8 +451,6 @@ static void emit_prologue(u8 **pprog, u32 stack_depth, bool ebpf_from_cbpf, /* sub rsp, rounded_stack_depth */ if (stack_depth) EMIT3_off32(0x48, 0x81, 0xEC, round_up(stack_depth, 8)); - if (tail_call_reachable) - EMIT1(0x50); /* push rax */ *pprog = prog; } @@ -575,13 +586,61 @@ static void emit_return(u8 **pprog, u8 *ip) *pprog = prog; } +static __used void bpf_tail_call_cnt_prepare(void) +{ + /* The following asm equals to + * + * u32 *tcc_ptr = this_cpu_ptr(prog->aux->entry->tail_call_cnt); + * + * *tcc_ptr = 0; + * + * This asm must uses %rax only. + */ + + /* %rax has been set as prog->aux->entry->tail_call_cnt. */ + asm volatile ( + "addq " __percpu_arg(0) ", %%rax\n\t" + "movl $0, (%%rax)\n\t" + : + : "m" (this_cpu_off) + ); +} + +static __used u32 bpf_tail_call_cnt_fetch_and_inc(void) +{ + u32 tail_call_cnt; + + /* The following asm equals to + * + * u32 *tcc_ptr = this_cpu_ptr(prog->aux->entry->tail_call_cnt); + * + * (*tcc_ptr)++; + * tail_call_cnt = *tcc_ptr; + * tail_call_cnt--; + * + * This asm must uses %rax only. + */ + + /* %rax has been set as prog->aux->entry->tail_call_cnt. */ + asm volatile ( + "addq " __percpu_arg(1) ", %%rax\n\t" + "incl (%%rax)\n\t" + "movl (%%rax), %0\n\t" + "decl %0\n\t" + : "=r" (tail_call_cnt) + : "m" (this_cpu_off) + ); + + return tail_call_cnt; +} + /* * Generate the following code: * * ... bpf_tail_call(void *ctx, struct bpf_array *array, u64 index) ... * if (index >= array->map.max_entries) * goto out; - * if (tail_call_cnt++ >= MAX_TAIL_CALL_CNT) + * if (bpf_tail_call_cnt_fetch_and_inc() >= MAX_TAIL_CALL_CNT) * goto out; * prog = array->ptrs[index]; * if (prog == NULL) @@ -594,7 +653,7 @@ static void emit_bpf_tail_call_indirect(struct bpf_prog *bpf_prog, u32 stack_depth, u8 *ip, struct jit_context *ctx) { - int tcc_off = -4 - round_up(stack_depth, 8); + struct bpf_prog *entry = bpf_prog->aux->entry; u8 *prog = *pprog, *start = *pprog; int offset; @@ -615,17 +674,16 @@ static void emit_bpf_tail_call_indirect(struct bpf_prog *bpf_prog, offset = ctx->tail_call_indirect_label - (prog + 2 - start); EMIT2(X86_JBE, offset); /* jbe out */ - /* - * if (tail_call_cnt++ >= MAX_TAIL_CALL_CNT) + /* if (bpf_tail_call_cnt_fetch_and_inc() >= MAX_TAIL_CALL_CNT) * goto out; */ - EMIT2_off32(0x8B, 0x85, tcc_off); /* mov eax, dword ptr [rbp - tcc_off] */ + /* mov rax, entry->tail_call_cnt */ + EMIT2_off64(0x48, 0xB8, (u64) entry->tail_call_cnt); + emit_call(&prog, bpf_tail_call_cnt_fetch_and_inc, ip + (prog - start)); EMIT3(0x83, 0xF8, MAX_TAIL_CALL_CNT); /* cmp eax, MAX_TAIL_CALL_CNT */ offset = ctx->tail_call_indirect_label - (prog + 2 - start); EMIT2(X86_JAE, offset); /* jae out */ - EMIT3(0x83, 0xC0, 0x01); /* add eax, 1 */ - EMIT2_off32(0x89, 0x85, tcc_off); /* mov dword ptr [rbp - tcc_off], eax */ /* prog = array->ptrs[index]; */ EMIT4_off32(0x48, 0x8B, 0x8C, 0xD6, /* mov rcx, [rsi + rdx * 8 + offsetof(...)] */ @@ -647,7 +705,6 @@ static void emit_bpf_tail_call_indirect(struct bpf_prog *bpf_prog, pop_callee_regs(&prog, callee_regs_used); } - EMIT1(0x58); /* pop rax */ if (stack_depth) EMIT3_off32(0x48, 0x81, 0xC4, /* add rsp, sd */ round_up(stack_depth, 8)); @@ -675,21 +732,20 @@ static void emit_bpf_tail_call_direct(struct bpf_prog *bpf_prog, bool *callee_regs_used, u32 stack_depth, struct jit_context *ctx) { - int tcc_off = -4 - round_up(stack_depth, 8); + struct bpf_prog 
*entry = bpf_prog->aux->entry; u8 *prog = *pprog, *start = *pprog; int offset; - /* - * if (tail_call_cnt++ >= MAX_TAIL_CALL_CNT) + /* if (bpf_tail_call_cnt_fetch_and_inc() >= MAX_TAIL_CALL_CNT) * goto out; */ - EMIT2_off32(0x8B, 0x85, tcc_off); /* mov eax, dword ptr [rbp - tcc_off] */ + /* mov rax, entry->tail_call_cnt */ + EMIT2_off64(0x48, 0xB8, (u64) entry->tail_call_cnt); + emit_call(&prog, bpf_tail_call_cnt_fetch_and_inc, ip + (prog - start)); EMIT3(0x83, 0xF8, MAX_TAIL_CALL_CNT); /* cmp eax, MAX_TAIL_CALL_CNT */ offset = ctx->tail_call_direct_label - (prog + 2 - start); EMIT2(X86_JAE, offset); /* jae out */ - EMIT3(0x83, 0xC0, 0x01); /* add eax, 1 */ - EMIT2_off32(0x89, 0x85, tcc_off); /* mov dword ptr [rbp - tcc_off], eax */ poke->tailcall_bypass = ip + (prog - start); poke->adj_off = X86_TAIL_CALL_OFFSET; @@ -706,7 +762,6 @@ static void emit_bpf_tail_call_direct(struct bpf_prog *bpf_prog, pop_callee_regs(&prog, callee_regs_used); } - EMIT1(0x58); /* pop rax */ if (stack_depth) EMIT3_off32(0x48, 0x81, 0xC4, round_up(stack_depth, 8)); @@ -1133,10 +1188,6 @@ static void emit_shiftx(u8 **pprog, u32 dst_reg, u8 src_reg, bool is64, u8 op) #define INSN_SZ_DIFF (((addrs[i] - addrs[i - 1]) - (prog - temp))) -/* mov rax, qword ptr [rbp - rounded_stack_depth - 8] */ -#define RESTORE_TAIL_CALL_CNT(stack) \ - EMIT3_off32(0x48, 0x8B, 0x85, -round_up(stack, 8) - 8) - static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image, u8 *rw_image, int oldproglen, struct jit_context *ctx, bool jmp_padding) { @@ -1158,9 +1209,10 @@ static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image, u8 *rw_image /* tail call's presence in current prog implies it is reachable */ tail_call_reachable |= tail_call_seen; - emit_prologue(&prog, bpf_prog->aux->stack_depth, + emit_prologue(bpf_prog, &prog, bpf_prog->aux->stack_depth, bpf_prog_was_classic(bpf_prog), tail_call_reachable, - bpf_is_subprog(bpf_prog), bpf_prog->aux->exception_cb); + bpf_is_subprog(bpf_prog), bpf_prog->aux->exception_cb, + image); /* Exception callback will clobber callee regs for its own use, and * restore the original callee regs from main prog's stack frame. */ @@ -1752,17 +1804,12 @@ st: if (is_imm8(insn->off)) case BPF_JMP | BPF_CALL: { int offs; + if (!imm32) + return -EINVAL; + func = (u8 *) __bpf_call_base + imm32; - if (tail_call_reachable) { - RESTORE_TAIL_CALL_CNT(bpf_prog->aux->stack_depth); - if (!imm32) - return -EINVAL; - offs = 7 + x86_call_depth_emit_accounting(&prog, func); - } else { - if (!imm32) - return -EINVAL; - offs = x86_call_depth_emit_accounting(&prog, func); - } + offs = x86_call_depth_emit_accounting(&prog, func); + if (emit_call(&prog, func, image + addrs[i - 1] + offs)) return -EINVAL; break; @@ -2550,7 +2597,6 @@ static int __arch_prepare_bpf_trampoline(struct bpf_tramp_image *im, void *rw_im * [ ... 
] * [ stack_arg2 ] * RBP - arg_stack_off [ stack_arg1 ] - * RSP [ tail_call_cnt ] BPF_TRAMP_F_TAIL_CALL_CTX */ /* room for return value of orig_call or fentry prog */ @@ -2622,8 +2668,6 @@ static int __arch_prepare_bpf_trampoline(struct bpf_tramp_image *im, void *rw_im /* sub rsp, stack_size */ EMIT4(0x48, 0x83, 0xEC, stack_size); } - if (flags & BPF_TRAMP_F_TAIL_CALL_CTX) - EMIT1(0x50); /* push rax */ /* mov QWORD PTR [rbp - rbx_off], rbx */ emit_stx(&prog, BPF_DW, BPF_REG_FP, BPF_REG_6, -rbx_off); @@ -2678,16 +2722,9 @@ static int __arch_prepare_bpf_trampoline(struct bpf_tramp_image *im, void *rw_im restore_regs(m, &prog, regs_off); save_args(m, &prog, arg_stack_off, true); - if (flags & BPF_TRAMP_F_TAIL_CALL_CTX) { - /* Before calling the original function, restore the - * tail_call_cnt from stack to rax. - */ - RESTORE_TAIL_CALL_CNT(stack_size); - } - if (flags & BPF_TRAMP_F_ORIG_STACK) { - emit_ldx(&prog, BPF_DW, BPF_REG_6, BPF_REG_FP, 8); - EMIT2(0xff, 0xd3); /* call *rbx */ + emit_ldx(&prog, BPF_DW, BPF_REG_0, BPF_REG_FP, 8); + EMIT2(0xff, 0xd0); /* call *rax */ } else { /* call original function */ if (emit_rsb_call(&prog, orig_call, image + (prog - (u8 *)rw_image))) { @@ -2740,11 +2777,6 @@ static int __arch_prepare_bpf_trampoline(struct bpf_tramp_image *im, void *rw_im ret = -EINVAL; goto cleanup; } - } else if (flags & BPF_TRAMP_F_TAIL_CALL_CTX) { - /* Before running the original function, restore the - * tail_call_cnt from stack to rax. - */ - RESTORE_TAIL_CALL_CNT(stack_size); } /* restore return value of orig_call or fentry prog back into RAX */ diff --git a/include/linux/bpf.h b/include/linux/bpf.h index 814dc913a..5e8abcb11 100644 --- a/include/linux/bpf.h +++ b/include/linux/bpf.h @@ -1459,6 +1459,7 @@ struct bpf_prog_aux { /* function name for valid attach_btf_id */ const char *attach_func_name; struct bpf_prog **func; + struct bpf_prog *entry; void *jit_data; /* JIT specific data. 
arch dependent */ struct bpf_jit_poke_descriptor *poke_tab; struct bpf_kfunc_desc_tab *kfunc_tab; @@ -1542,6 +1543,7 @@ struct bpf_prog { u8 tag[BPF_TAG_SIZE]; struct bpf_prog_stats __percpu *stats; int __percpu *active; + u32 __percpu *tail_call_cnt; unsigned int (*bpf_func)(const void *ctx, const struct bpf_insn *insn); struct bpf_prog_aux *aux; /* Auxiliary fields */ diff --git a/kernel/bpf/core.c b/kernel/bpf/core.c index 71c459a51..1b5baa922 100644 --- a/kernel/bpf/core.c +++ b/kernel/bpf/core.c @@ -110,6 +110,13 @@ struct bpf_prog *bpf_prog_alloc_no_stats(unsigned int size, gfp_t gfp_extra_flag kfree(aux); return NULL; } + fp->tail_call_cnt = alloc_percpu_gfp(u32, bpf_memcg_flags(GFP_KERNEL | gfp_extra_flags)); + if (!fp->tail_call_cnt) { + free_percpu(fp->active); + vfree(fp); + kfree(aux); + return NULL; + } fp->pages = size / PAGE_SIZE; fp->aux = aux; @@ -142,6 +149,7 @@ struct bpf_prog *bpf_prog_alloc(unsigned int size, gfp_t gfp_extra_flags) prog->stats = alloc_percpu_gfp(struct bpf_prog_stats, gfp_flags); if (!prog->stats) { + free_percpu(prog->tail_call_cnt); free_percpu(prog->active); kfree(prog->aux); vfree(prog); @@ -261,6 +269,7 @@ struct bpf_prog *bpf_prog_realloc(struct bpf_prog *fp_old, unsigned int size, fp_old->aux = NULL; fp_old->stats = NULL; fp_old->active = NULL; + fp_old->tail_call_cnt = NULL; __bpf_prog_free(fp_old); } @@ -277,6 +286,7 @@ void __bpf_prog_free(struct bpf_prog *fp) } free_percpu(fp->stats); free_percpu(fp->active); + free_percpu(fp->tail_call_cnt); vfree(fp); } @@ -1428,6 +1438,7 @@ static void bpf_prog_clone_free(struct bpf_prog *fp) fp->aux = NULL; fp->stats = NULL; fp->active = NULL; + fp->tail_call_cnt = NULL; __bpf_prog_free(fp); } diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c index 011d54a1d..616e1d7a5 100644 --- a/kernel/bpf/verifier.c +++ b/kernel/bpf/verifier.c @@ -18917,7 +18917,7 @@ static int convert_ctx_accesses(struct bpf_verifier_env *env) static int jit_subprogs(struct bpf_verifier_env *env) { - struct bpf_prog *prog = env->prog, **func, *tmp; + struct bpf_prog *prog = env->prog, **func, *tmp, *entry = prog; int i, j, subprog_start, subprog_end = 0, len, subprog; struct bpf_map *map_ptr; struct bpf_insn *insn; @@ -19025,6 +19025,7 @@ static int jit_subprogs(struct bpf_verifier_env *env) } func[i]->aux->num_exentries = num_exentries; func[i]->aux->tail_call_reachable = env->subprog_info[i].tail_call_reachable; + func[i]->aux->entry = entry; func[i]->aux->exception_cb = env->subprog_info[i].is_exception_cb; if (!i) func[i]->aux->exception_boundary = env->seen_exception; > > Next, I tried a POC with PERCPU tail_call_cnt in 'struct bpf_prog' and rax: > > 1. At prologue, initialise tail_call_cnt from bpf_prog's tail_call_cnt. > 2. Propagate tail_call_cnt pointer by the previous rax way. > > Here's the diff for the POC: > [SNIP] Thanks, Leon
On Mon, Feb 26, 2024 at 7:32 AM Leon Hwang <hffilwlqm@gmail.com> wrote: > > > > On 2024/2/24 00:35, Alexei Starovoitov wrote: > > On Fri, Feb 23, 2024 at 7:30 AM Leon Hwang <hffilwlqm@gmail.com> wrote: > >> > >> > >> > >> On 2024/2/23 12:06, Pu Lehui wrote: > >>> > >>> > >>> On 2024/2/22 16:52, Leon Hwang wrote: > >> > >> [SNIP] > >> > >>>> } > >>>> @@ -575,6 +574,54 @@ static void emit_return(u8 **pprog, u8 *ip) > >>>> *pprog = prog; > >>>> } > >>>> +DEFINE_PER_CPU(u32, bpf_tail_call_cnt); > >>> > >>> Hi Leon, the solution is really simplifies complexity. If I understand > >>> correctly, this TAIL_CALL_CNT becomes the system global wise, not the > >>> prog global wise, but before it was limiting the TCC of entry prog. > >>> > >> > >> Correct. It becomes a PERCPU global variable. > >> > >> But, I think this solution is not robust enough. > >> > >> For example, > >> > >> time prog1 prog1 > >> ==================================> > >> line prog2 > >> > >> this is a time-line on a CPU. If prog1 and prog2 have tailcalls to run, > >> prog2 will reset the tail_call_cnt on current CPU, which is used by > >> prog1. As a result, when the CPU schedules from prog2 to prog1, > >> tail_call_cnt on current CPU has been reset to 0, no matter whether > >> prog1 incremented it. > >> > >> The tail_call_cnt reset issue happens too, even if PERCPU tail_call_cnt > >> moves to 'struct bpf_prog_aux', i.e. one kprobe bpf prog can be > >> triggered on many functions e.g. cilium/pwru. However, this moving is > >> better than this solution. > > > > kprobe progs are not preemptable. > > There is bpf_prog_active that disallows any recursion. > > Moving this percpu count to prog->aux should solve it. > > > >> I think, my previous POC of 'struct bpf_prog_run_ctx' would be better. > >> I'll resend it later, with some improvements. > > > > percpu approach is still prefered, since it removes rax mess. > > It seems that we cannot remove rax. > > Let's take a look at tailcall3.c selftest: > > struct { > __uint(type, BPF_MAP_TYPE_PROG_ARRAY); > __uint(max_entries, 1); > __uint(key_size, sizeof(__u32)); > __uint(value_size, sizeof(__u32)); > } jmp_table SEC(".maps"); > > int count = 0; > > SEC("tc") > int classifier_0(struct __sk_buff *skb) > { > count++; > bpf_tail_call_static(skb, &jmp_table, 0); > return 1; > } > > SEC("tc") > int entry(struct __sk_buff *skb) > { > bpf_tail_call_static(skb, &jmp_table, 0); > return 0; > } > > Here, classifier_0 is populated to jmp_table. > > Then, at classifier_0's prologue, when we 'move rax, > classifier_0->tail_call_cnt' in order to use the PERCPU tail_call_cnt in > 'struct bpf_prog' for current run-time, it fails to run selftests. It's > because the tail_call_cnt is not from the entry bpf prog. The > tail_call_cnt from the entry bpf prog is the expected one, even though > classifier_0 bpf prog runs. (It seems that it's unnecessary to provide > the diff of the exclusive approach with PERCPU tail_call_cnt.) Not following. With percpu tail_call_cnt, does classifier_0 loop forever ? I doubt it. You mean expected 'count' value is different? The test expected 33 and instead it's ... what? > + if (tail_call_reachable && !is_subprog) { > + /* mov rax, entry->tail_call_cnt */ > + EMIT2_off64(0x48, 0xB8, (u64) entry->tail_call_cnt); > + /* call bpf_tail_call_cnt_prepare */ > + emit_call(&prog, bpf_tail_call_cnt_prepare, > + ip + (prog - start)); > + } else { > /* Keep the same instruction layout. */ > - EMIT2(0x66, 0x90); /* nop2 */ > + emit_nops(&prog, 10 + X86_PATCH_SIZE); As mentioned before... 
such "fix" is not acceptable. We will not be penalizing all progs this way. How about we make percpu tail_call_cnt per prog_array map, then remove rax as this patch does, but instead of zeroing tcc on entry, zero it on exit. While processing bpf_exit add: if (tail_call_reachable) emit_call(&prog, bpf_tail_call_cnt_prepare,...) if prog that tailcalls get preempted on this cpu and another prog starts that also tailcalls it won't zero the count. This way we can remove nop5 from prologue too. The preempted prog will eventually zero ttc on exit, and earlier prog that uses the same prog_array can tail call more than 32 times, but it cannot be abused reliably, since preemption is non deterministic.
On 2024/2/27 06:12, Alexei Starovoitov wrote: > On Mon, Feb 26, 2024 at 7:32 AM Leon Hwang <hffilwlqm@gmail.com> wrote: >> >> >> >> On 2024/2/24 00:35, Alexei Starovoitov wrote: >>> On Fri, Feb 23, 2024 at 7:30 AM Leon Hwang <hffilwlqm@gmail.com> wrote: [SNIP] >> >> Let's take a look at tailcall3.c selftest: >> >> struct { >> __uint(type, BPF_MAP_TYPE_PROG_ARRAY); >> __uint(max_entries, 1); >> __uint(key_size, sizeof(__u32)); >> __uint(value_size, sizeof(__u32)); >> } jmp_table SEC(".maps"); >> >> int count = 0; >> >> SEC("tc") >> int classifier_0(struct __sk_buff *skb) >> { >> count++; >> bpf_tail_call_static(skb, &jmp_table, 0); >> return 1; >> } >> >> SEC("tc") >> int entry(struct __sk_buff *skb) >> { >> bpf_tail_call_static(skb, &jmp_table, 0); >> return 0; >> } >> >> Here, classifier_0 is populated to jmp_table. >> >> Then, at classifier_0's prologue, when we 'move rax, >> classifier_0->tail_call_cnt' in order to use the PERCPU tail_call_cnt in >> 'struct bpf_prog' for current run-time, it fails to run selftests. It's >> because the tail_call_cnt is not from the entry bpf prog. The >> tail_call_cnt from the entry bpf prog is the expected one, even though >> classifier_0 bpf prog runs. (It seems that it's unnecessary to provide >> the diff of the exclusive approach with PERCPU tail_call_cnt.) > > Not following. > With percpu tail_call_cnt, does classifier_0 loop forever ? I doubt it. > You mean expected 'count' value is different? > The test expected 33 and instead it's ... what? Yeah, the test result is 34 instead of expected 33. test_tailcall_count:PASS:tailcall 0 nsec test_tailcall_count:PASS:tailcall retval 0 nsec test_tailcall_count:PASS:tailcall count 0 nsec test_tailcall_count:FAIL:tailcall count unexpected tailcall count: actual 34 != expected 33 test_tailcall_count:PASS:tailcall 0 nsec test_tailcall_count:PASS:tailcall retval 0 nsec #311/3 tailcalls/tailcall_3:FAIL > >> + if (tail_call_reachable && !is_subprog) { >> + /* mov rax, entry->tail_call_cnt */ >> + EMIT2_off64(0x48, 0xB8, (u64) entry->tail_call_cnt); >> + /* call bpf_tail_call_cnt_prepare */ >> + emit_call(&prog, bpf_tail_call_cnt_prepare, >> + ip + (prog - start)); >> + } else { >> /* Keep the same instruction layout. */ >> - EMIT2(0x66, 0x90); /* nop2 */ >> + emit_nops(&prog, 10 + X86_PATCH_SIZE); > > As mentioned before... such "fix" is not acceptable. > We will not be penalizing all progs this way. > > How about we make percpu tail_call_cnt per prog_array map, No, we can not store percpu tail_call_cnt on either bpf prog or prog_array map. Considering this case: 1. prog1 tailcall prog2 with prog_array map1. 2. prog2 tailcall prog3 with prog_array map2. 3. prog3 tailcall prog4 with prog_array map3. 4. ... We can not store percpu tail_call_cnt on either prog1 or prog_array map1. In conclusion, tail_call_cnt is a run-time variable that should be stored on stack ideally. Can we store tail_call_cnt on stack and then propagate tcc_ptr by some way instead of rax? > then remove rax as this patch does, > but instead of zeroing tcc on entry, zero it on exit. > While processing bpf_exit add: > if (tail_call_reachable) > emit_call(&prog, bpf_tail_call_cnt_prepare,...) > > if prog that tailcalls get preempted on this cpu and > another prog starts that also tailcalls it won't zero the count. > This way we can remove nop5 from prologue too. 
> > The preempted prog will eventually zero ttc on exit, > and earlier prog that uses the same prog_array can tail call more > than 32 times, but it cannot be abused reliably, > since preemption is non deterministic. We can not zeroing tcc on exit. If zero it on exit, the selftests of this patchset will run forever, e.g. struct { __uint(type, BPF_MAP_TYPE_PROG_ARRAY); __uint(max_entries, 1); __uint(key_size, sizeof(__u32)); __uint(value_size, sizeof(__u32)); } jmp_table SEC(".maps"); int count = 0; static __noinline int subprog_tail(struct __sk_buff *skb) { bpf_tail_call_static(skb, &jmp_table, 0); return 0; } SEC("tc") int entry(struct __sk_buff *skb) { int ret = 1; count++; subprog_tail(skb); /* tailcall-pos1 */ subprog_tail(skb); /* tailcall-pos2 */ return ret; /* zeroing tcc */ } The jmp_table populates with the entry bpf prog. The entry bpf prog zeros tcc always when returns. So, after the entry bpf prog returns from subprog_tail() at tailcall-pos1, tcc has been reset to 0, and the entry bpf prog tailcalled from subprog_tail() at tailcall-pos2 can run forever. Here's another alternative approach. Like this PATCH v2, it's ok to initialise tcc as MAX_TAIL_CALL_CNT, and then decrement it when tailcall happens. This approach does same with this PATCH v2, and passes all selftests. Here's the diff: diff --git a/arch/x86/net/bpf_jit_comp.c b/arch/x86/net/bpf_jit_comp.c index e1390d1e3..72773899a 100644 --- a/arch/x86/net/bpf_jit_comp.c +++ b/arch/x86/net/bpf_jit_comp.c @@ -11,6 +11,7 @@ #include <linux/bpf.h> #include <linux/memory.h> #include <linux/sort.h> +#include <linux/stringify.h> #include <asm/extable.h> #include <asm/ftrace.h> #include <asm/set_memory.h> @@ -18,6 +19,7 @@ #include <asm/text-patching.h> #include <asm/unwind.h> #include <asm/cfi.h> +#include <asm/percpu.h> static bool all_callee_regs_used[4] = {true, true, true, true}; @@ -259,7 +261,7 @@ struct jit_context { /* Number of bytes emit_patch() needs to generate instructions */ #define X86_PATCH_SIZE 5 /* Number of bytes that will be skipped on tailcall */ -#define X86_TAIL_CALL_OFFSET (11 + ENDBR_INSN_SIZE) +#define X86_TAIL_CALL_OFFSET (14 + ENDBR_INSN_SIZE) static void push_r12(u8 **pprog) { @@ -389,6 +391,9 @@ static void emit_cfi(u8 **pprog, u32 hash) *pprog = prog; } +static int emit_call(u8 **pprog, void *func, void *ip); +static __used void bpf_tail_call_cnt_prepare(void); + /* * Emit x86-64 prologue code for BPF program. * bpf_tail_call helper will skip the first X86_TAIL_CALL_OFFSET bytes @@ -396,9 +401,9 @@ static void emit_cfi(u8 **pprog, u32 hash) */ static void emit_prologue(u8 **pprog, u32 stack_depth, bool ebpf_from_cbpf, bool tail_call_reachable, bool is_subprog, - bool is_exception_cb) + bool is_exception_cb, u8 *ip) { - u8 *prog = *pprog; + u8 *prog = *pprog, *start = *pprog; emit_cfi(&prog, is_subprog ? cfi_bpf_subprog_hash : cfi_bpf_hash); /* BPF trampoline can be made to work without these nops, @@ -407,13 +412,11 @@ static void emit_prologue(u8 **pprog, u32 stack_depth, bool ebpf_from_cbpf, emit_nops(&prog, X86_PATCH_SIZE); if (!ebpf_from_cbpf) { if (tail_call_reachable && !is_subprog) - /* When it's the entry of the whole tailcall context, - * zeroing rax means initialising tail_call_cnt. - */ - EMIT2(0x31, 0xC0); /* xor eax, eax */ + emit_call(&prog, bpf_tail_call_cnt_prepare, + ip + (prog - start)); else /* Keep the same instruction layout. 
*/ - EMIT2(0x66, 0x90); /* nop2 */ + emit_nops(&prog, X86_PATCH_SIZE); } /* Exception callback receives FP as third parameter */ if (is_exception_cb) { @@ -438,8 +441,6 @@ static void emit_prologue(u8 **pprog, u32 stack_depth, bool ebpf_from_cbpf, /* sub rsp, rounded_stack_depth */ if (stack_depth) EMIT3_off32(0x48, 0x81, 0xEC, round_up(stack_depth, 8)); - if (tail_call_reachable) - EMIT1(0x50); /* push rax */ *pprog = prog; } @@ -575,13 +576,53 @@ static void emit_return(u8 **pprog, u8 *ip) *pprog = prog; } +DEFINE_PER_CPU(u32, bpf_tail_call_cnt); + +static __used void bpf_tail_call_cnt_prepare(void) +{ + /* The following asm equals to + * + * u32 *tcc_ptr = this_cpu_ptr(&bpf_tail_call_cnt); + * + * *tcc_ptr = MAX_TAIL_CALL_CNT; + */ + + asm volatile ( + "addq " __percpu_arg(0) ", %1\n\t" + "movl $" __stringify(MAX_TAIL_CALL_CNT) ", (%1)\n\t" + : + : "m" (this_cpu_off), "r" (&bpf_tail_call_cnt) + ); +} + +static __used u32 *bpf_tail_call_cnt_ptr(void) +{ + u32 *tcc_ptr; + + /* The following asm equals to + * + * u32 *tcc_ptr = this_cpu_ptr(&bpf_tail_call_cnt); + * + * return tcc_ptr; + */ + + asm volatile ( + "addq " __percpu_arg(1) ", %2\n\t" + "movq %2, %0\n\t" + : "=r" (tcc_ptr) + : "m" (this_cpu_off), "r" (&bpf_tail_call_cnt) + ); + + return tcc_ptr; +} + /* * Generate the following code: * * ... bpf_tail_call(void *ctx, struct bpf_array *array, u64 index) ... * if (index >= array->map.max_entries) * goto out; - * if (tail_call_cnt++ >= MAX_TAIL_CALL_CNT) + * if ((*tcc_ptr)-- == 0) * goto out; * prog = array->ptrs[index]; * if (prog == NULL) @@ -594,7 +635,6 @@ static void emit_bpf_tail_call_indirect(struct bpf_prog *bpf_prog, u32 stack_depth, u8 *ip, struct jit_context *ctx) { - int tcc_off = -4 - round_up(stack_depth, 8); u8 *prog = *pprog, *start = *pprog; int offset; @@ -616,16 +656,16 @@ static void emit_bpf_tail_call_indirect(struct bpf_prog *bpf_prog, EMIT2(X86_JBE, offset); /* jbe out */ /* - * if (tail_call_cnt++ >= MAX_TAIL_CALL_CNT) + * if ((*tcc_ptr)-- == 0) * goto out; */ - EMIT2_off32(0x8B, 0x85, tcc_off); /* mov eax, dword ptr [rbp - tcc_off] */ - EMIT3(0x83, 0xF8, MAX_TAIL_CALL_CNT); /* cmp eax, MAX_TAIL_CALL_CNT */ + /* call bpf_tail_call_cnt_ptr */ + emit_call(&prog, bpf_tail_call_cnt_ptr, ip + (prog - start)); + EMIT3(0x83, 0x38, 0); /* cmp dword ptr [rax], 0 */ offset = ctx->tail_call_indirect_label - (prog + 2 - start); - EMIT2(X86_JAE, offset); /* jae out */ - EMIT3(0x83, 0xC0, 0x01); /* add eax, 1 */ - EMIT2_off32(0x89, 0x85, tcc_off); /* mov dword ptr [rbp - tcc_off], eax */ + EMIT2(X86_JE, offset); /* je out */ + EMIT2(0xFF, 0x08); /* dec dword ptr [rax] */ /* prog = array->ptrs[index]; */ EMIT4_off32(0x48, 0x8B, 0x8C, 0xD6, /* mov rcx, [rsi + rdx * 8 + offsetof(...)] */ @@ -647,7 +687,6 @@ static void emit_bpf_tail_call_indirect(struct bpf_prog *bpf_prog, pop_callee_regs(&prog, callee_regs_used); } - EMIT1(0x58); /* pop rax */ if (stack_depth) EMIT3_off32(0x48, 0x81, 0xC4, /* add rsp, sd */ round_up(stack_depth, 8)); @@ -675,21 +714,20 @@ static void emit_bpf_tail_call_direct(struct bpf_prog *bpf_prog, bool *callee_regs_used, u32 stack_depth, struct jit_context *ctx) { - int tcc_off = -4 - round_up(stack_depth, 8); u8 *prog = *pprog, *start = *pprog; int offset; /* - * if (tail_call_cnt++ >= MAX_TAIL_CALL_CNT) + * if ((*tcc_ptr)-- == 0) * goto out; */ - EMIT2_off32(0x8B, 0x85, tcc_off); /* mov eax, dword ptr [rbp - tcc_off] */ - EMIT3(0x83, 0xF8, MAX_TAIL_CALL_CNT); /* cmp eax, MAX_TAIL_CALL_CNT */ + /* call bpf_tail_call_cnt_ptr */ + emit_call(&prog, 
bpf_tail_call_cnt_ptr, ip); + EMIT3(0x83, 0x38, 0); /* cmp dword ptr [rax], 0 */ offset = ctx->tail_call_direct_label - (prog + 2 - start); - EMIT2(X86_JAE, offset); /* jae out */ - EMIT3(0x83, 0xC0, 0x01); /* add eax, 1 */ - EMIT2_off32(0x89, 0x85, tcc_off); /* mov dword ptr [rbp - tcc_off], eax */ + EMIT2(X86_JE, offset); /* je out */ + EMIT2(0xFF, 0x08); /* dec dword ptr [rax] */ poke->tailcall_bypass = ip + (prog - start); poke->adj_off = X86_TAIL_CALL_OFFSET; @@ -706,7 +744,6 @@ static void emit_bpf_tail_call_direct(struct bpf_prog *bpf_prog, pop_callee_regs(&prog, callee_regs_used); } - EMIT1(0x58); /* pop rax */ if (stack_depth) EMIT3_off32(0x48, 0x81, 0xC4, round_up(stack_depth, 8)); @@ -1133,10 +1170,6 @@ static void emit_shiftx(u8 **pprog, u32 dst_reg, u8 src_reg, bool is64, u8 op) #define INSN_SZ_DIFF (((addrs[i] - addrs[i - 1]) - (prog - temp))) -/* mov rax, qword ptr [rbp - rounded_stack_depth - 8] */ -#define RESTORE_TAIL_CALL_CNT(stack) \ - EMIT3_off32(0x48, 0x8B, 0x85, -round_up(stack, 8) - 8) - static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image, u8 *rw_image, int oldproglen, struct jit_context *ctx, bool jmp_padding) { @@ -1160,7 +1193,8 @@ static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image, u8 *rw_image emit_prologue(&prog, bpf_prog->aux->stack_depth, bpf_prog_was_classic(bpf_prog), tail_call_reachable, - bpf_is_subprog(bpf_prog), bpf_prog->aux->exception_cb); + bpf_is_subprog(bpf_prog), bpf_prog->aux->exception_cb, + image); /* Exception callback will clobber callee regs for its own use, and * restore the original callee regs from main prog's stack frame. */ @@ -1752,17 +1786,11 @@ st: if (is_imm8(insn->off)) case BPF_JMP | BPF_CALL: { int offs; + if (!imm32) + return -EINVAL; + func = (u8 *) __bpf_call_base + imm32; - if (tail_call_reachable) { - RESTORE_TAIL_CALL_CNT(bpf_prog->aux->stack_depth); - if (!imm32) - return -EINVAL; - offs = 7 + x86_call_depth_emit_accounting(&prog, func); - } else { - if (!imm32) - return -EINVAL; - offs = x86_call_depth_emit_accounting(&prog, func); - } + offs = x86_call_depth_emit_accounting(&prog, func); if (emit_call(&prog, func, image + addrs[i - 1] + offs)) return -EINVAL; break; @@ -2550,7 +2578,6 @@ static int __arch_prepare_bpf_trampoline(struct bpf_tramp_image *im, void *rw_im * [ ... ] * [ stack_arg2 ] * RBP - arg_stack_off [ stack_arg1 ] - * RSP [ tail_call_cnt ] BPF_TRAMP_F_TAIL_CALL_CTX */ /* room for return value of orig_call or fentry prog */ @@ -2622,8 +2649,6 @@ static int __arch_prepare_bpf_trampoline(struct bpf_tramp_image *im, void *rw_im /* sub rsp, stack_size */ EMIT4(0x48, 0x83, 0xEC, stack_size); } - if (flags & BPF_TRAMP_F_TAIL_CALL_CTX) - EMIT1(0x50); /* push rax */ /* mov QWORD PTR [rbp - rbx_off], rbx */ emit_stx(&prog, BPF_DW, BPF_REG_FP, BPF_REG_6, -rbx_off); @@ -2678,16 +2703,9 @@ static int __arch_prepare_bpf_trampoline(struct bpf_tramp_image *im, void *rw_im restore_regs(m, &prog, regs_off); save_args(m, &prog, arg_stack_off, true); - if (flags & BPF_TRAMP_F_TAIL_CALL_CTX) { - /* Before calling the original function, restore the - * tail_call_cnt from stack to rax. 
- */ - RESTORE_TAIL_CALL_CNT(stack_size); - } - if (flags & BPF_TRAMP_F_ORIG_STACK) { - emit_ldx(&prog, BPF_DW, BPF_REG_6, BPF_REG_FP, 8); - EMIT2(0xff, 0xd3); /* call *rbx */ + emit_ldx(&prog, BPF_DW, BPF_REG_0, BPF_REG_FP, 8); + EMIT2(0xff, 0xd0); /* call *rax */ } else { /* call original function */ if (emit_rsb_call(&prog, orig_call, image + (prog - (u8 *)rw_image))) { @@ -2740,11 +2758,6 @@ static int __arch_prepare_bpf_trampoline(struct bpf_tramp_image *im, void *rw_im ret = -EINVAL; goto cleanup; } - } else if (flags & BPF_TRAMP_F_TAIL_CALL_CTX) { - /* Before running the original function, restore the - * tail_call_cnt from stack to rax. - */ - RESTORE_TAIL_CALL_CNT(stack_size); } /* restore return value of orig_call or fentry prog back into RAX */ Thanks, Leon
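[A C-level illustration of the decrement-based bookkeeping in the diff above. The _sketch names are made up for this example; the real patch emits the equivalent logic as x86 instructions, clobbering only %rax.]

#include <linux/bpf.h>		/* MAX_TAIL_CALL_CNT */
#include <linux/percpu.h>

static DEFINE_PER_CPU(u32, tcc_sketch);

/* prologue of the entry prog: start with a full tail-call budget */
static void tcc_sketch_prepare(void)
{
	*this_cpu_ptr(&tcc_sketch) = MAX_TAIL_CALL_CNT;
}

/* at each tail call site: if ((*tcc_ptr)-- == 0) goto out; */
static bool tcc_sketch_try_consume(void)
{
	u32 *tcc_ptr = this_cpu_ptr(&tcc_sketch);

	if (*tcc_ptr == 0)	/* cmp dword ptr [rax], 0 ; je out */
		return false;
	(*tcc_ptr)--;		/* dec dword ptr [rax] */
	return true;
}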
On Wed, Feb 28, 2024 at 6:31 AM Leon Hwang <hffilwlqm@gmail.com> wrote:
>
> Here's another alternative approach. Like this PATCH v2, it's ok to
> initialise tcc as MAX_TAIL_CALL_CNT, and then decrement it when tailcall
> happens. This approach does same with this PATCH v2, and passes all
> selftests.
>
> Here's the diff:

before this thread falls through the cracks...

Please never post a relative diff from some patch that was in some thread.
Always post a diff against the latest bpf or bpf-next _and_ submit it to the
list as a full patch with your SOB, etc so that CI can test it and
maintainers can review it. When the diff is not in patchwork it will be lost.
diff --git a/arch/x86/net/bpf_jit_comp.c b/arch/x86/net/bpf_jit_comp.c index e1390d1e331b5..3d1498a13b04c 100644 --- a/arch/x86/net/bpf_jit_comp.c +++ b/arch/x86/net/bpf_jit_comp.c @@ -18,6 +18,7 @@ #include <asm/text-patching.h> #include <asm/unwind.h> #include <asm/cfi.h> +#include <asm/percpu.h> static bool all_callee_regs_used[4] = {true, true, true, true}; @@ -259,7 +260,7 @@ struct jit_context { /* Number of bytes emit_patch() needs to generate instructions */ #define X86_PATCH_SIZE 5 /* Number of bytes that will be skipped on tailcall */ -#define X86_TAIL_CALL_OFFSET (11 + ENDBR_INSN_SIZE) +#define X86_TAIL_CALL_OFFSET (14 + ENDBR_INSN_SIZE) static void push_r12(u8 **pprog) { @@ -389,6 +390,9 @@ static void emit_cfi(u8 **pprog, u32 hash) *pprog = prog; } +static int emit_call(u8 **pprog, void *func, void *ip); +static __used void bpf_tail_call_cnt_prepare(void); + /* * Emit x86-64 prologue code for BPF program. * bpf_tail_call helper will skip the first X86_TAIL_CALL_OFFSET bytes @@ -396,9 +400,9 @@ static void emit_cfi(u8 **pprog, u32 hash) */ static void emit_prologue(u8 **pprog, u32 stack_depth, bool ebpf_from_cbpf, bool tail_call_reachable, bool is_subprog, - bool is_exception_cb) + bool is_exception_cb, u8 *ip) { - u8 *prog = *pprog; + u8 *prog = *pprog, *start = *pprog; emit_cfi(&prog, is_subprog ? cfi_bpf_subprog_hash : cfi_bpf_hash); /* BPF trampoline can be made to work without these nops, @@ -407,13 +411,10 @@ static void emit_prologue(u8 **pprog, u32 stack_depth, bool ebpf_from_cbpf, emit_nops(&prog, X86_PATCH_SIZE); if (!ebpf_from_cbpf) { if (tail_call_reachable && !is_subprog) - /* When it's the entry of the whole tailcall context, - * zeroing rax means initialising tail_call_cnt. - */ - EMIT2(0x31, 0xC0); /* xor eax, eax */ + emit_call(&prog, bpf_tail_call_cnt_prepare, + ip + (prog - start)); else - /* Keep the same instruction layout. */ - EMIT2(0x66, 0x90); /* nop2 */ + emit_nops(&prog, X86_PATCH_SIZE); } /* Exception callback receives FP as third parameter */ if (is_exception_cb) { @@ -438,8 +439,6 @@ static void emit_prologue(u8 **pprog, u32 stack_depth, bool ebpf_from_cbpf, /* sub rsp, rounded_stack_depth */ if (stack_depth) EMIT3_off32(0x48, 0x81, 0xEC, round_up(stack_depth, 8)); - if (tail_call_reachable) - EMIT1(0x50); /* push rax */ *pprog = prog; } @@ -575,6 +574,54 @@ static void emit_return(u8 **pprog, u8 *ip) *pprog = prog; } +DEFINE_PER_CPU(u32, bpf_tail_call_cnt); + +static __used void bpf_tail_call_cnt_prepare(void) +{ + /* The following asm equals to + * + * u32 *tcc_ptr = this_cpu_ptr(&bpf_tail_call_cnt); + * + * *tcc_ptr = 0; + * + * This asm must uses %rax only. + */ + + asm volatile ( + "addq " __percpu_arg(0) ", %1\n\t" + "movl $0, (%%rax)\n\t" + : + : "m" (this_cpu_off), "r" (&bpf_tail_call_cnt) + ); +} + +static __used u32 bpf_tail_call_cnt_fetch_and_inc(void) +{ + u32 tail_call_cnt; + + /* The following asm equals to + * + * u32 *tcc_ptr = this_cpu_ptr(&bpf_tail_call_cnt); + * + * (*tcc_ptr)++; + * tail_call_cnt = *tcc_ptr; + * tail_call_cnt--; + * + * This asm must uses %rax only. 
+ */ + + asm volatile ( + "addq " __percpu_arg(1) ", %2\n\t" + "incl (%%rax)\n\t" + "movl (%%rax), %0\n\t" + "decl %0\n\t" + : "=r" (tail_call_cnt) + : "m" (this_cpu_off), "r" (&bpf_tail_call_cnt) + ); + + return tail_call_cnt; +} + /* * Generate the following code: * @@ -594,7 +641,6 @@ static void emit_bpf_tail_call_indirect(struct bpf_prog *bpf_prog, u32 stack_depth, u8 *ip, struct jit_context *ctx) { - int tcc_off = -4 - round_up(stack_depth, 8); u8 *prog = *pprog, *start = *pprog; int offset; @@ -615,17 +661,14 @@ static void emit_bpf_tail_call_indirect(struct bpf_prog *bpf_prog, offset = ctx->tail_call_indirect_label - (prog + 2 - start); EMIT2(X86_JBE, offset); /* jbe out */ - /* - * if (tail_call_cnt++ >= MAX_TAIL_CALL_CNT) + /* if (bpf_tail_call_cnt_fetch_and_inc() >= MAX_TAIL_CALL_CNT) * goto out; */ - EMIT2_off32(0x8B, 0x85, tcc_off); /* mov eax, dword ptr [rbp - tcc_off] */ + emit_call(&prog, bpf_tail_call_cnt_fetch_and_inc, ip + (prog - start)); EMIT3(0x83, 0xF8, MAX_TAIL_CALL_CNT); /* cmp eax, MAX_TAIL_CALL_CNT */ offset = ctx->tail_call_indirect_label - (prog + 2 - start); EMIT2(X86_JAE, offset); /* jae out */ - EMIT3(0x83, 0xC0, 0x01); /* add eax, 1 */ - EMIT2_off32(0x89, 0x85, tcc_off); /* mov dword ptr [rbp - tcc_off], eax */ /* prog = array->ptrs[index]; */ EMIT4_off32(0x48, 0x8B, 0x8C, 0xD6, /* mov rcx, [rsi + rdx * 8 + offsetof(...)] */ @@ -647,7 +690,6 @@ static void emit_bpf_tail_call_indirect(struct bpf_prog *bpf_prog, pop_callee_regs(&prog, callee_regs_used); } - EMIT1(0x58); /* pop rax */ if (stack_depth) EMIT3_off32(0x48, 0x81, 0xC4, /* add rsp, sd */ round_up(stack_depth, 8)); @@ -675,21 +717,17 @@ static void emit_bpf_tail_call_direct(struct bpf_prog *bpf_prog, bool *callee_regs_used, u32 stack_depth, struct jit_context *ctx) { - int tcc_off = -4 - round_up(stack_depth, 8); u8 *prog = *pprog, *start = *pprog; int offset; - /* - * if (tail_call_cnt++ >= MAX_TAIL_CALL_CNT) + /* if (bpf_tail_call_cnt_fetch_and_inc() >= MAX_TAIL_CALL_CNT) * goto out; */ - EMIT2_off32(0x8B, 0x85, tcc_off); /* mov eax, dword ptr [rbp - tcc_off] */ + emit_call(&prog, bpf_tail_call_cnt_fetch_and_inc, ip); EMIT3(0x83, 0xF8, MAX_TAIL_CALL_CNT); /* cmp eax, MAX_TAIL_CALL_CNT */ offset = ctx->tail_call_direct_label - (prog + 2 - start); EMIT2(X86_JAE, offset); /* jae out */ - EMIT3(0x83, 0xC0, 0x01); /* add eax, 1 */ - EMIT2_off32(0x89, 0x85, tcc_off); /* mov dword ptr [rbp - tcc_off], eax */ poke->tailcall_bypass = ip + (prog - start); poke->adj_off = X86_TAIL_CALL_OFFSET; @@ -706,7 +744,6 @@ static void emit_bpf_tail_call_direct(struct bpf_prog *bpf_prog, pop_callee_regs(&prog, callee_regs_used); } - EMIT1(0x58); /* pop rax */ if (stack_depth) EMIT3_off32(0x48, 0x81, 0xC4, round_up(stack_depth, 8)); @@ -1133,10 +1170,6 @@ static void emit_shiftx(u8 **pprog, u32 dst_reg, u8 src_reg, bool is64, u8 op) #define INSN_SZ_DIFF (((addrs[i] - addrs[i - 1]) - (prog - temp))) -/* mov rax, qword ptr [rbp - rounded_stack_depth - 8] */ -#define RESTORE_TAIL_CALL_CNT(stack) \ - EMIT3_off32(0x48, 0x8B, 0x85, -round_up(stack, 8) - 8) - static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image, u8 *rw_image, int oldproglen, struct jit_context *ctx, bool jmp_padding) { @@ -1160,7 +1193,8 @@ static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image, u8 *rw_image emit_prologue(&prog, bpf_prog->aux->stack_depth, bpf_prog_was_classic(bpf_prog), tail_call_reachable, - bpf_is_subprog(bpf_prog), bpf_prog->aux->exception_cb); + bpf_is_subprog(bpf_prog), bpf_prog->aux->exception_cb, + image); /* 
Exception callback will clobber callee regs for its own use, and * restore the original callee regs from main prog's stack frame. */ @@ -1752,17 +1786,12 @@ st: if (is_imm8(insn->off)) case BPF_JMP | BPF_CALL: { int offs; + if (!imm32) + return -EINVAL; + func = (u8 *) __bpf_call_base + imm32; - if (tail_call_reachable) { - RESTORE_TAIL_CALL_CNT(bpf_prog->aux->stack_depth); - if (!imm32) - return -EINVAL; - offs = 7 + x86_call_depth_emit_accounting(&prog, func); - } else { - if (!imm32) - return -EINVAL; - offs = x86_call_depth_emit_accounting(&prog, func); - } + offs = x86_call_depth_emit_accounting(&prog, func); + if (emit_call(&prog, func, image + addrs[i - 1] + offs)) return -EINVAL; break; @@ -2550,7 +2579,6 @@ static int __arch_prepare_bpf_trampoline(struct bpf_tramp_image *im, void *rw_im * [ ... ] * [ stack_arg2 ] * RBP - arg_stack_off [ stack_arg1 ] - * RSP [ tail_call_cnt ] BPF_TRAMP_F_TAIL_CALL_CTX */ /* room for return value of orig_call or fentry prog */ @@ -2622,8 +2650,6 @@ static int __arch_prepare_bpf_trampoline(struct bpf_tramp_image *im, void *rw_im /* sub rsp, stack_size */ EMIT4(0x48, 0x83, 0xEC, stack_size); } - if (flags & BPF_TRAMP_F_TAIL_CALL_CTX) - EMIT1(0x50); /* push rax */ /* mov QWORD PTR [rbp - rbx_off], rbx */ emit_stx(&prog, BPF_DW, BPF_REG_FP, BPF_REG_6, -rbx_off); @@ -2678,16 +2704,9 @@ static int __arch_prepare_bpf_trampoline(struct bpf_tramp_image *im, void *rw_im restore_regs(m, &prog, regs_off); save_args(m, &prog, arg_stack_off, true); - if (flags & BPF_TRAMP_F_TAIL_CALL_CTX) { - /* Before calling the original function, restore the - * tail_call_cnt from stack to rax. - */ - RESTORE_TAIL_CALL_CNT(stack_size); - } - if (flags & BPF_TRAMP_F_ORIG_STACK) { - emit_ldx(&prog, BPF_DW, BPF_REG_6, BPF_REG_FP, 8); - EMIT2(0xff, 0xd3); /* call *rbx */ + emit_ldx(&prog, BPF_DW, BPF_REG_0, BPF_REG_FP, 8); + EMIT2(0xff, 0xd0); /* call *rax */ } else { /* call original function */ if (emit_rsb_call(&prog, orig_call, image + (prog - (u8 *)rw_image))) { @@ -2740,11 +2759,6 @@ static int __arch_prepare_bpf_trampoline(struct bpf_tramp_image *im, void *rw_im ret = -EINVAL; goto cleanup; } - } else if (flags & BPF_TRAMP_F_TAIL_CALL_CTX) { - /* Before running the original function, restore the - * tail_call_cnt from stack to rax. - */ - RESTORE_TAIL_CALL_CNT(stack_size); } /* restore return value of orig_call or fentry prog back into RAX */
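[Written as plain C for readability, the percpu bookkeeping that the inline-asm helpers in this diff implement looks roughly like the sketch below. The _c suffixed names are illustrative only; the real helpers are hand-written asm precisely so that only %rax is clobbered.]

/* prologue helper: reset the percpu counter for the new tailcall context */
static void tail_call_cnt_prepare_c(void)
{
	u32 *tcc_ptr = this_cpu_ptr(&bpf_tail_call_cnt);

	*tcc_ptr = 0;
}

/* tail-call-site helper: increment, then return the old value */
static u32 tail_call_cnt_fetch_and_inc_c(void)
{
	u32 *tcc_ptr = this_cpu_ptr(&bpf_tail_call_cnt);

	(*tcc_ptr)++;
	return *tcc_ptr - 1;	/* old value, compared against MAX_TAIL_CALL_CNT */
}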
From commit ebf7d1f508a73871 ("bpf, x64: rework pro/epilogue and tailcall handling in JIT"), the tailcall on x64 works better than before. From commit e411901c0b775a3a ("bpf: allow for tailcalls in BPF subprograms for x64 JIT"), tailcall is able to run in BPF subprograms on x64. How about: 1. More than 1 subprograms are called in a bpf program. 2. The tailcalls in the subprograms call the bpf program. Because of missing tail_call_cnt back-propagation, a tailcall hierarchy comes up. And MAX_TAIL_CALL_CNT limit does not work for this case. Let's take a look into an example: \#include <linux/bpf.h> \#include <bpf/bpf_helpers.h> \#include "bpf_legacy.h" struct { __uint(type, BPF_MAP_TYPE_PROG_ARRAY); __uint(max_entries, 1); __uint(key_size, sizeof(__u32)); __uint(value_size, sizeof(__u32)); } jmp_table SEC(".maps"); int count = 0; static __noinline int subprog_tail(struct __sk_buff *skb) { bpf_tail_call_static(skb, &jmp_table, 0); return 0; } SEC("tc") int entry(struct __sk_buff *skb) { volatile int ret = 1; count++; subprog_tail(skb); /* subprog call1 */ subprog_tail(skb); /* subprog call2 */ return ret; } char __license[] SEC("license") = "GPL"; And the entry bpf prog is populated to the 0th slot of jmp_table. Then, what happens when entry bpf prog runs? The CPU will be stalled because of too many tailcalls, e.g. the test_progs failed to run on aarch64 and s390x because of "rcu: INFO: rcu_sched self-detected stall on CPU". So, if CPU does not stall because of too many tailcalls, how many tailcalls will be there for this case? And why MAX_TAIL_CALL_CNT limit does not work for this case? Let's step into some running steps. At the very first time when subprog_tail() is called, subprog_tail() does tailcall the entry bpf prog. Then, subprog_taill() is called at second time at the position subprog call1, and it tailcalls the entry bpf prog again. Then, again and again. At the very first time when MAX_TAIL_CALL_CNT limit works, subprog_tail() has been called for 34 times at the position subprog call1. And at this time, the tail_call_cnt is 33 in subprog_tail(). Next, the 34th subprog_tail() returns to entry() because of MAX_TAIL_CALL_CNT limit. In entry(), the 34th entry(), at the time after the 34th subprog_tail() at the position subprog call1 finishes and before the 1st subprog_tail() at the position subprog call2 calls in entry(), what's the value of tail_call_cnt in entry()? It's 33. As we know, tail_all_cnt is pushed on the stack of entry(), and propagates to subprog_tail() by %rax from stack. Then, at the time when subprog_tail() at the position subprog call2 is called for its first time, tail_call_cnt 33 propagates to subprog_tail() by %rax. And the tailcall in subprog_tail() is aborted because of tail_call_cnt >= MAX_TAIL_CALL_CNT too. Then, subprog_tail() at the position subprog call2 ends, and the 34th entry() ends. And it returns to the 33rd subprog_tail() called from the position subprog call1. But wait, at this time, what's the value of tail_call_cnt under the stack of subprog_tail()? It's 33. Then, in the 33rd entry(), at the time after the 33th subprog_tail() at the position subprog call1 finishes and before the 2nd subprog_tail() at the position subprog call2 calls, what's the value of tail_call_cnt in current entry()? It's *32*. Why not 33? 
Before stepping into subprog_tail() at the position subprog call2 in the 33rd
entry(), like stopping the time machine, let's have a look at the stack memory:

 |  STACK  |
 +---------+ RBP  <-- current rbp
 |   ret   | STACK of 33rd entry()
 |   tcc   | its value is 32
 +---------+ RSP  <-- current rsp
 |   rip   | STACK of 34th entry()
 |   rbp   | reuse the STACK of 33rd subprog_tail() at the position
 |   ret   | subprog call1
 |   tcc   | its value is 33
 +---------+ rsp
 |   rip   | STACK of 1st subprog_tail() at the position subprog call2
 |   rbp   |
 |   tcc   | its value is 33
 +---------+ rsp

Why not 33? It's because tail_call_cnt does not back-propagate from
subprog_tail() to entry().

Then, while stepping into subprog_tail() at the position subprog call2 in the
33rd entry():

 |  STACK  |
 +---------+
 |   ret   | STACK of 33rd entry()
 |   tcc   | its value is 32
 |   rip   |
 |   rbp   |
 +---------+ RBP  <-- current rbp
 |   tcc   | its value is 32; STACK of subprog_tail() at the position
 +---------+ RSP  <-- current rsp      subprog call2

Then, while pausing after tailcalling in the 2nd subprog_tail() at the position
subprog call2:

 |  STACK  |
 +---------+
 |   ret   | STACK of 33rd entry()
 |   tcc   | its value is 32
 |   rip   |
 |   rbp   |
 +---------+ RBP  <-- current rbp
 |   tcc   | its value is 33; STACK of subprog_tail() at the position
 +---------+ RSP  <-- current rsp      subprog call2

Note what happens to tail_call_cnt:

/*
 * if (tail_call_cnt++ >= MAX_TAIL_CALL_CNT)
 *	goto out;
 */

It checks >= MAX_TAIL_CALL_CNT first and then increments tail_call_cnt. So,
the current tailcall is allowed to run.

Then, entry() is tailcalled. And the stack memory status is:

 |  STACK  |
 +---------+
 |   ret   | STACK of 33rd entry()
 |   tcc   | its value is 32
 |   rip   |
 |   rbp   |
 +---------+ RBP  <-- current rbp
 |   ret   | STACK of 35th entry(); reuse STACK of subprog_tail() at the
 |   tcc   | its value is 33                     position subprog call2
 +---------+ RSP  <-- current rsp

So, the tailcalls in the 35th entry() will be aborted.

And, ..., again and again. :(

And, I hope you have understood the reason why the MAX_TAIL_CALL_CNT limit
does not work for this case.

And, how many tailcalls are there for this case if the CPU does not stall?

From a top-down view, it looks like a hierarchy, layer by layer. I think it is
a hierarchy layer model with 2+4+8+...+2**33 tailcalls. As a result, if the
CPU does not stall, there will be 2**34 - 2 = 17,179,869,182 tailcalls. That's
what stalls the CPU.

What if there are N subprog_tail() calls in entry()? If the CPU does not stall
because of too many tailcalls, there will be almost N**34 tailcalls.

Now that we understand the issue, how does this patch resolve it?

This patch uses a PERCPU tail_call_cnt to store the temporary tail_call_cnt.
First, at the prologue of the bpf prog, it initialises the PERCPU
tail_call_cnt by setting the current CPU's tail_call_cnt to 0. Then, when a
tailcall happens, it fetches and increments the current CPU's tail_call_cnt,
and compares it to MAX_TAIL_CALL_CNT.

Additionally, in order to avoid touching any register other than %rax, it uses
asm to handle the PERCPU tail_call_cnt via %rax only.

As a result, the previous tailcall handling can be removed entirely, including:

1. "push rax" at prologue.
2. load tail_call_cnt to rax before calling function.
3. "pop rax" before jumping to tailcallee when tailcall.
4. "push rax" and load tail_call_cnt to rax at trampoline.
Fixes: ebf7d1f508a7 ("bpf, x64: rework pro/epilogue and tailcall handling in JIT")
Fixes: e411901c0b77 ("bpf: allow for tailcalls in BPF subprograms for x64 JIT")
Signed-off-by: Leon Hwang <hffilwlqm@gmail.com>
---
 arch/x86/net/bpf_jit_comp.c | 128 ++++++++++++++++++++----------------
 1 file changed, 71 insertions(+), 57 deletions(-)
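[The 2**34 - 2 figure in the commit message can be sanity-checked with a small userspace C program; this is a rough sketch assuming two tail-call sites per entry() and MAX_TAIL_CALL_CNT == 33.]

#include <stdio.h>

int main(void)
{
	unsigned long long total = 0, layer = 1;
	int max_tail_call_cnt = 33;	/* MAX_TAIL_CALL_CNT */

	/* each layer of the hierarchy doubles: 2 + 4 + ... + 2^33 */
	for (int depth = 1; depth <= max_tail_call_cnt; depth++) {
		layer *= 2;		/* 2 subprog_tail() call sites per entry() */
		total += layer;
	}
	printf("%llu\n", total);	/* prints 17179869182 = 2^34 - 2 */
	return 0;
}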