
[bpf-next,2/4] bpf, arm64: Fix tailcall infinite loop caused by freplace

Message ID 20240825130943.7738-3-leon.hwang@linux.dev (mailing list archive)
State Superseded
Delegated to: BPF
Series bpf: Fix tailcall infinite loop caused by freplace

Checks

Context Check Description
bpf/vmtest-bpf-next-PR success PR summary
bpf/vmtest-bpf-next-VM_Test-0 success Logs for Lint
bpf/vmtest-bpf-next-VM_Test-1 success Logs for ShellCheck
bpf/vmtest-bpf-next-VM_Test-2 success Logs for Unittests
bpf/vmtest-bpf-next-VM_Test-3 success Logs for Validate matrix.py
bpf/vmtest-bpf-next-VM_Test-5 success Logs for aarch64-gcc / build-release
bpf/vmtest-bpf-next-VM_Test-4 success Logs for aarch64-gcc / build / build for aarch64 with gcc
bpf/vmtest-bpf-next-VM_Test-9 success Logs for aarch64-gcc / test (test_verifier, false, 360) / test_verifier on aarch64 with gcc
bpf/vmtest-bpf-next-VM_Test-27 success Logs for x86_64-llvm-17 / build / build for x86_64 with llvm-17
bpf/vmtest-bpf-next-VM_Test-25 success Logs for x86_64-gcc / test (test_verifier, false, 360) / test_verifier on x86_64 with gcc
bpf/vmtest-bpf-next-VM_Test-28 success Logs for x86_64-llvm-17 / build-release / build for x86_64 with llvm-17-O2
bpf/vmtest-bpf-next-VM_Test-29 success Logs for x86_64-llvm-17 / test (test_maps, false, 360) / test_maps on x86_64 with llvm-17
bpf/vmtest-bpf-next-VM_Test-33 success Logs for x86_64-llvm-17 / veristat
bpf/vmtest-bpf-next-VM_Test-32 success Logs for x86_64-llvm-17 / test (test_verifier, false, 360) / test_verifier on x86_64 with llvm-17
bpf/vmtest-bpf-next-VM_Test-34 success Logs for x86_64-llvm-18 / build / build for x86_64 with llvm-18
bpf/vmtest-bpf-next-VM_Test-35 success Logs for x86_64-llvm-18 / build-release / build for x86_64 with llvm-18-O2
bpf/vmtest-bpf-next-VM_Test-40 success Logs for x86_64-llvm-18 / test (test_verifier, false, 360) / test_verifier on x86_64 with llvm-18
bpf/vmtest-bpf-next-VM_Test-41 success Logs for x86_64-llvm-18 / veristat
bpf/vmtest-bpf-next-VM_Test-17 success Logs for set-matrix
bpf/vmtest-bpf-next-VM_Test-20 success Logs for x86_64-gcc / test (test_maps, false, 360) / test_maps on x86_64 with gcc
bpf/vmtest-bpf-next-VM_Test-10 success Logs for aarch64-gcc / veristat
bpf/vmtest-bpf-next-VM_Test-18 success Logs for x86_64-gcc / build / build for x86_64 with gcc
bpf/vmtest-bpf-next-VM_Test-15 success Logs for s390x-gcc / test (test_verifier, false, 360) / test_verifier on s390x with gcc
bpf/vmtest-bpf-next-VM_Test-12 success Logs for s390x-gcc / build-release
bpf/vmtest-bpf-next-VM_Test-16 success Logs for s390x-gcc / veristat
bpf/vmtest-bpf-next-VM_Test-19 success Logs for x86_64-gcc / build-release
bpf/vmtest-bpf-next-VM_Test-11 success Logs for s390x-gcc / build / build for s390x with gcc
bpf/vmtest-bpf-next-VM_Test-7 success Logs for aarch64-gcc / test (test_progs, false, 360) / test_progs on aarch64 with gcc
bpf/vmtest-bpf-next-VM_Test-6 success Logs for aarch64-gcc / test (test_maps, false, 360) / test_maps on aarch64 with gcc
bpf/vmtest-bpf-next-VM_Test-8 success Logs for aarch64-gcc / test (test_progs_no_alu32, false, 360) / test_progs_no_alu32 on aarch64 with gcc
bpf/vmtest-bpf-next-VM_Test-24 success Logs for x86_64-gcc / test (test_progs_parallel, true, 30) / test_progs_parallel on x86_64 with gcc
bpf/vmtest-bpf-next-VM_Test-26 success Logs for x86_64-gcc / veristat / veristat on x86_64 with gcc
bpf/vmtest-bpf-next-VM_Test-30 success Logs for x86_64-llvm-17 / test (test_progs, false, 360) / test_progs on x86_64 with llvm-17
bpf/vmtest-bpf-next-VM_Test-31 success Logs for x86_64-llvm-17 / test (test_progs_no_alu32, false, 360) / test_progs_no_alu32 on x86_64 with llvm-17
bpf/vmtest-bpf-next-VM_Test-36 success Logs for x86_64-llvm-18 / test (test_maps, false, 360) / test_maps on x86_64 with llvm-18
bpf/vmtest-bpf-next-VM_Test-38 success Logs for x86_64-llvm-18 / test (test_progs_cpuv4, false, 360) / test_progs_cpuv4 on x86_64 with llvm-18
bpf/vmtest-bpf-next-VM_Test-37 success Logs for x86_64-llvm-18 / test (test_progs, false, 360) / test_progs on x86_64 with llvm-18
bpf/vmtest-bpf-next-VM_Test-39 success Logs for x86_64-llvm-18 / test (test_progs_no_alu32, false, 360) / test_progs_no_alu32 on x86_64 with llvm-18
bpf/vmtest-bpf-next-VM_Test-13 success Logs for s390x-gcc / test (test_progs, false, 360) / test_progs on s390x with gcc
bpf/vmtest-bpf-next-VM_Test-21 success Logs for x86_64-gcc / test (test_progs, false, 360) / test_progs on x86_64 with gcc
bpf/vmtest-bpf-next-VM_Test-14 success Logs for s390x-gcc / test (test_progs_no_alu32, false, 360) / test_progs_no_alu32 on s390x with gcc
bpf/vmtest-bpf-next-VM_Test-23 success Logs for x86_64-gcc / test (test_progs_no_alu32_parallel, true, 30) / test_progs_no_alu32_parallel on x86_64 with gcc
bpf/vmtest-bpf-next-VM_Test-22 success Logs for x86_64-gcc / test (test_progs_no_alu32, false, 360) / test_progs_no_alu32 on x86_64 with gcc
netdev/series_format success Posting correctly formatted
netdev/tree_selection success Clearly marked for bpf-next, async
netdev/ynl success Generated files up to date; no warnings/errors; no diff in generated;
netdev/fixes_present success Fixes tag not required for -next series
netdev/header_inline success No static functions without inline keyword in header files
netdev/build_32bit success Errors and warnings before: 7 this patch: 7
netdev/build_tools success No tools touched, skip
netdev/cc_maintainers fail 1 blamed authors not CCed: martin.lau@linux.dev; 10 maintainers not CCed: sdf@fomichev.me haoluo@google.com linux-arm-kernel@lists.infradead.org jolsa@kernel.org catalin.marinas@arm.com song@kernel.org will@kernel.org kpsingh@kernel.org martin.lau@linux.dev john.fastabend@gmail.com
netdev/build_clang success Errors and warnings before: 7 this patch: 7
netdev/verify_signedoff success Signed-off-by tag matches author and committer
netdev/deprecated_api success None detected
netdev/check_selftest success No net selftest shell script
netdev/verify_fixes success Fixes tag looks correct
netdev/build_allmodconfig_warn success Errors and warnings before: 7 this patch: 7
netdev/checkpatch success total: 0 errors, 0 warnings, 0 checks, 115 lines checked
netdev/build_clang_rust success No Rust files in patch. Skipping build
netdev/kdoc success Errors and warnings before: 0 this patch: 0
netdev/source_inline success Was 0 now: 0

Commit Message

Leon Hwang Aug. 25, 2024, 1:09 p.m. UTC
Like "bpf, x64: Fix tailcall infinite loop caused by freplace", the same
issue happens on arm64, too.

For example:

tc_bpf2bpf.c:

// SPDX-License-Identifier: GPL-2.0
#include <linux/bpf.h>
#include <bpf/bpf_helpers.h>

__noinline
int subprog_tc(struct __sk_buff *skb)
{
	return skb->len * 2;
}

SEC("tc")
int entry_tc(struct __sk_buff *skb)
{
	return subprog_tc(skb);
}

char __license[] SEC("license") = "GPL";

tailcall_bpf2bpf_hierarchy_freplace.c:

// SPDX-License-Identifier: GPL-2.0
#include <linux/bpf.h>
#include <bpf/bpf_helpers.h>

struct {
	__uint(type, BPF_MAP_TYPE_PROG_ARRAY);
	__uint(max_entries, 1);
	__uint(key_size, sizeof(__u32));
	__uint(value_size, sizeof(__u32));
} jmp_table SEC(".maps");

int count = 0;

static __noinline
int subprog_tail(struct __sk_buff *skb)
{
	bpf_tail_call_static(skb, &jmp_table, 0);
	return 0;
}

SEC("freplace")
int entry_freplace(struct __sk_buff *skb)
{
	count++;
	subprog_tail(skb);
	subprog_tail(skb);
	return count;
}

char __license[] SEC("license") = "GPL";

The attach target of entry_freplace is subprog_tc, and the tail callee
in subprog_tail is entry_tc.

Then, the infinite loop will be entry_tc -> subprog_tc (replaced by
entry_freplace) -> subprog_tail --tailcall-> entry_tc -> ..., because
tail_call_cnt in entry_freplace counts from zero on every execution of
entry_freplace.

This patch fixes the issue by not touching tail_call_cnt in the prologue
when the program is a subprog or a freplace prog.

Then, when a freplace prog attaches to entry_tc, it has to initialize
tail_call_cnt and tail_call_cnt_ptr itself, because its target is a main
prog whose prologue has not initialized them before the attach hook.

So, this patch uses the x7 register to tell the freplace prog whether
its target prog is a main prog or not.

Meanwhile, when tail calling to a freplace prog, the x7 register has to
be reset to prevent re-initializing tail_call_cnt in the freplace prog's
prologue.
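
To illustrate the intended behaviour after this fix, here is a minimal
user-space model (illustration only, not kernel code; the names mirror
the example programs above): tail_call_cnt lives in the entry prog's
frame and is shared through tail_call_cnt_ptr, so entry_freplace keeps
incrementing the same counter instead of restarting from zero, and the
tail-call chain terminates once MAX_TAIL_CALL_CNT is reached.

#include <stdio.h>

#define MAX_TAIL_CALL_CNT 33

static int count;

static void entry_tc(int *tail_call_cnt_ptr);

/* models subprog_tail(): bpf_tail_call_static(skb, &jmp_table, 0) */
static void subprog_tail(int *tail_call_cnt_ptr)
{
	if (*tail_call_cnt_ptr >= MAX_TAIL_CALL_CNT)
		return;			/* tail call rejected, fall through */
	(*tail_call_cnt_ptr)++;
	entry_tc(tail_call_cnt_ptr);	/* tail call to entry_tc */
}

/* models entry_freplace(), which replaces subprog_tc() */
static void entry_freplace(int *tail_call_cnt_ptr)
{
	count++;
	subprog_tail(tail_call_cnt_ptr);
	subprog_tail(tail_call_cnt_ptr);
}

/* models entry_tc(); its prologue owns the real tail_call_cnt */
static void entry_tc(int *tail_call_cnt_ptr)
{
	entry_freplace(tail_call_cnt_ptr);	/* call to the replaced subprog_tc */
}

int main(void)
{
	int tail_call_cnt = 0;	/* initialized once, at the entry of the chain */

	entry_tc(&tail_call_cnt);
	printf("count = %d\n", count);	/* bounded: 34 with this model */
	return 0;
}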

Fixes: 1c123c567fb1 ("bpf: Resolve fext program type when checking map compatibility")
Signed-off-by: Leon Hwang <leon.hwang@linux.dev>
---
 arch/arm64/net/bpf_jit_comp.c | 39 +++++++++++++++++++++++++++++++----
 1 file changed, 35 insertions(+), 4 deletions(-)

Comments

Xu Kuohai Aug. 26, 2024, 2:32 p.m. UTC | #1
On 8/25/2024 9:09 PM, Leon Hwang wrote:
> Like "bpf, x64: Fix tailcall infinite loop caused by freplace", the same
> issue happens on arm64, too.
> 
> For example:
> 
> tc_bpf2bpf.c:
> 
> // SPDX-License-Identifier: GPL-2.0
> #include <linux/bpf.h>
> #include <bpf/bpf_helpers.h>
> 
> __noinline
> int subprog_tc(struct __sk_buff *skb)
> {
> 	return skb->len * 2;
> }
> 
> SEC("tc")
> int entry_tc(struct __sk_buff *skb)
> {
> 	return subprog(skb);
> }
> 
> char __license[] SEC("license") = "GPL";
> 
> tailcall_bpf2bpf_hierarchy_freplace.c:
> 
> // SPDX-License-Identifier: GPL-2.0
> #include <linux/bpf.h>
> #include <bpf/bpf_helpers.h>
> 
> struct {
> 	__uint(type, BPF_MAP_TYPE_PROG_ARRAY);
> 	__uint(max_entries, 1);
> 	__uint(key_size, sizeof(__u32));
> 	__uint(value_size, sizeof(__u32));
> } jmp_table SEC(".maps");
> 
> int count = 0;
> 
> static __noinline
> int subprog_tail(struct __sk_buff *skb)
> {
> 	bpf_tail_call_static(skb, &jmp_table, 0);
> 	return 0;
> }
> 
> SEC("freplace")
> int entry_freplace(struct __sk_buff *skb)
> {
> 	count++;
> 	subprog_tail(skb);
> 	subprog_tail(skb);
> 	return count;
> }
> 
> char __license[] SEC("license") = "GPL";
> 
> The attach target of entry_freplace is subprog_tc, and the tail callee
> in subprog_tail is entry_tc.
> 
> Then, the infinite loop will be entry_tc -> entry_tc -> entry_freplace ->
> subprog_tail --tailcall-> entry_tc, because tail_call_cnt in
> entry_freplace will count from zero for every time of entry_freplace
> execution.
> 
> This patch fixes the issue by avoiding touching tail_call_cnt at
> prologue when it's subprog or freplace prog.
> 
> Then, when freplace prog attaches to entry_tc, it has to initialize
> tail_call_cnt and tail_call_cnt_ptr, because its target is main prog and
> its target's prologue hasn't initialize them before the attach hook.
> 
> So, this patch uses x7 register to tell freplace prog that its target
> prog is main prog or not.
> 
> Meanwhile, while tail calling to a freplace prog, it is required to
> reset x7 register to prevent re-initializing tail_call_cnt at freplace
> prog's prologue.
> 
> Fixes: 1c123c567fb1 ("bpf: Resolve fext program type when checking map compatibility")
> Signed-off-by: Leon Hwang <leon.hwang@linux.dev>
> ---
>   arch/arm64/net/bpf_jit_comp.c | 39 +++++++++++++++++++++++++++++++----
>   1 file changed, 35 insertions(+), 4 deletions(-)
> 
> diff --git a/arch/arm64/net/bpf_jit_comp.c b/arch/arm64/net/bpf_jit_comp.c
> index 59e05a7aea56a..4f8189824973f 100644
> --- a/arch/arm64/net/bpf_jit_comp.c
> +++ b/arch/arm64/net/bpf_jit_comp.c
> @@ -276,6 +276,7 @@ static bool is_lsi_offset(int offset, int scale)
>   /* generated prologue:
>    *      bti c // if CONFIG_ARM64_BTI_KERNEL
>    *      mov x9, lr
> + *      mov x7, 1 // if not-freplace main prog
>    *      nop  // POKE_OFFSET
>    *      paciasp // if CONFIG_ARM64_PTR_AUTH_KERNEL
>    *      stp x29, lr, [sp, #-16]!
> @@ -293,13 +294,14 @@ static bool is_lsi_offset(int offset, int scale)
>   static void prepare_bpf_tail_call_cnt(struct jit_ctx *ctx)
>   {
>   	const struct bpf_prog *prog = ctx->prog;
> +	const bool is_ext = prog->type == BPF_PROG_TYPE_EXT;
>   	const bool is_main_prog = !bpf_is_subprog(prog);
>   	const u8 ptr = bpf2a64[TCCNT_PTR];
>   	const u8 fp = bpf2a64[BPF_REG_FP];
>   	const u8 tcc = ptr;
>   
>   	emit(A64_PUSH(ptr, fp, A64_SP), ctx);
> -	if (is_main_prog) {
> +	if (is_main_prog && !is_ext) {
>   		/* Initialize tail_call_cnt. */
>   		emit(A64_MOVZ(1, tcc, 0, 0), ctx);
>   		emit(A64_PUSH(tcc, fp, A64_SP), ctx);
> @@ -315,22 +317,26 @@ static void prepare_bpf_tail_call_cnt(struct jit_ctx *ctx)
>   #define PAC_INSNS (IS_ENABLED(CONFIG_ARM64_PTR_AUTH_KERNEL) ? 1 : 0)
>   
>   /* Offset of nop instruction in bpf prog entry to be poked */
> -#define POKE_OFFSET (BTI_INSNS + 1)
> +#define POKE_OFFSET (BTI_INSNS + 2)
>   
>   /* Tail call offset to jump into */
> -#define PROLOGUE_OFFSET (BTI_INSNS + 2 + PAC_INSNS + 10)
> +#define PROLOGUE_OFFSET (BTI_INSNS + 3 + PAC_INSNS + 10)
>   
>   static int build_prologue(struct jit_ctx *ctx, bool ebpf_from_cbpf,
>   			  bool is_exception_cb, u64 arena_vm_start)
>   {
>   	const struct bpf_prog *prog = ctx->prog;
> +	const bool is_ext = prog->type == BPF_PROG_TYPE_EXT;
>   	const bool is_main_prog = !bpf_is_subprog(prog);
> +	const u8 r0 = bpf2a64[BPF_REG_0];
>   	const u8 r6 = bpf2a64[BPF_REG_6];
>   	const u8 r7 = bpf2a64[BPF_REG_7];
>   	const u8 r8 = bpf2a64[BPF_REG_8];
>   	const u8 r9 = bpf2a64[BPF_REG_9];
>   	const u8 fp = bpf2a64[BPF_REG_FP];
>   	const u8 fpb = bpf2a64[FP_BOTTOM];
> +	const u8 ptr = bpf2a64[TCCNT_PTR];
> +	const u8 tmp = bpf2a64[TMP_REG_1];
>   	const u8 arena_vm_base = bpf2a64[ARENA_VM_START];
>   	const int idx0 = ctx->idx;
>   	int cur_offset;
> @@ -367,6 +373,10 @@ static int build_prologue(struct jit_ctx *ctx, bool ebpf_from_cbpf,
>   	emit_bti(A64_BTI_JC, ctx);
>   
>   	emit(A64_MOV(1, A64_R(9), A64_LR), ctx);
> +	if (!is_ext)
> +		emit(A64_MOVZ(1, r0, is_main_prog, 0), ctx);
> +	else
> +		emit(A64_NOP, ctx);
>   	emit(A64_NOP, ctx);
>   
>   	if (!is_exception_cb) {
> @@ -413,6 +423,19 @@ static int build_prologue(struct jit_ctx *ctx, bool ebpf_from_cbpf,
>   		emit_bti(A64_BTI_J, ctx);
>   	}
>   
> +	/* If freplace's target prog is main prog, it has to make x26 as
> +	 * tail_call_cnt_ptr, and then initialize tail_call_cnt via the
> +	 * tail_call_cnt_ptr.
> +	 */
> +	if (is_main_prog && is_ext) {
> +		emit(A64_MOVZ(1, tmp, 1, 0), ctx);
> +		emit(A64_CMP(1, r0, tmp), ctx);
> +		emit(A64_B_(A64_COND_NE, 4), ctx);
> +		emit(A64_ADD_I(1, ptr, A64_SP, 16), ctx);
> +		emit(A64_MOVZ(1, r0, 0, 0), ctx);
> +		emit(A64_STR64I(r0, ptr, 0), ctx);
> +	}
> +
>   	/*
>   	 * Program acting as exception boundary should save all ARM64
>   	 * Callee-saved registers as the exception callback needs to recover
> @@ -444,6 +467,7 @@ static int build_prologue(struct jit_ctx *ctx, bool ebpf_from_cbpf,
>   static int out_offset = -1; /* initialized on the first pass of build_body() */
>   static int emit_bpf_tail_call(struct jit_ctx *ctx)
>   {
> +	const u8 r0 = bpf2a64[BPF_REG_0];
>   	/* bpf_tail_call(void *prog_ctx, struct bpf_array *array, u64 index) */
>   	const u8 r2 = bpf2a64[BPF_REG_2];
>   	const u8 r3 = bpf2a64[BPF_REG_3];
> @@ -491,6 +515,11 @@ static int emit_bpf_tail_call(struct jit_ctx *ctx)
>   
>   	/* Update tail_call_cnt if the slot is populated. */
>   	emit(A64_STR64I(tcc, ptr, 0), ctx);
> +	/* When freplace prog tail calls freplace prog, setting r0 as 0 is to
> +	 * prevent re-initializing tail_call_cnt at the prologue of target
> +	 * freplace prog.
> +	 */
> +	emit(A64_MOVZ(1, r0, 0, 0), ctx);
>   
>   	/* goto *(prog->bpf_func + prologue_offset); */
>   	off = offsetof(struct bpf_prog, bpf_func);
> @@ -2199,9 +2228,10 @@ static int prepare_trampoline(struct jit_ctx *ctx, struct bpf_tramp_image *im,
>   		emit(A64_RET(A64_R(10)), ctx);
>   		/* store return value */
>   		emit(A64_STR64I(A64_R(0), A64_SP, retval_off), ctx);
> -		/* reserve a nop for bpf_tramp_image_put */
> +		/* reserve two nops for bpf_tramp_image_put */
>   		im->ip_after_call = ctx->ro_image + ctx->idx;
>   		emit(A64_NOP, ctx);
> +		emit(A64_NOP, ctx);
>   	}
>   
>   	/* update the branches saved in invoke_bpf_mod_ret with cbnz */
> @@ -2484,6 +2514,7 @@ int bpf_arch_text_poke(void *ip, enum bpf_text_poke_type poke_type,
>   		/* skip to the nop instruction in bpf prog entry:
>   		 * bti c // if BTI enabled
>   		 * mov x9, x30
> +		 * mov x7, 1 // if not-freplace main prog
>   		 * nop
>   		 */
>   		ip = image + POKE_OFFSET * AARCH64_INSN_SIZE;

This patch makes arm64 jited prologue even more complex. I've posted a series [1]
to simplify the arm64 jited prologue/epilogue. I think we can fix this issue based
on [1]. I'll give it a try.

[1] https://lore.kernel.org/bpf/20240826071624.350108-1-xukuohai@huaweicloud.com/
Leon Hwang Aug. 27, 2024, 2:23 a.m. UTC | #2
On 26/8/24 22:32, Xu Kuohai wrote:
> On 8/25/2024 9:09 PM, Leon Hwang wrote:
>> Like "bpf, x64: Fix tailcall infinite loop caused by freplace", the same
>> issue happens on arm64, too.
>>

[...]

> 
> This patch makes arm64 jited prologue even more complex. I've posted a
> series [1]
> to simplify the arm64 jited prologue/epilogue. I think we can fix this
> issue based
> on [1]. I'll give it a try.
> 
> [1]
> https://lore.kernel.org/bpf/20240826071624.350108-1-xukuohai@huaweicloud.com/
> 

Your patch series looks great. We can fix this issue based on it.

Please let me know once you have a successful attempt.

Thanks,
Leon
Xu Kuohai Aug. 30, 2024, 7:37 a.m. UTC | #3
On 8/27/2024 10:23 AM, Leon Hwang wrote:
> 
> 
> On 26/8/24 22:32, Xu Kuohai wrote:
>> On 8/25/2024 9:09 PM, Leon Hwang wrote:
>>> Like "bpf, x64: Fix tailcall infinite loop caused by freplace", the same
>>> issue happens on arm64, too.
>>>
> 
> [...]
> 
>>
>> This patch makes arm64 jited prologue even more complex. I've posted a
>> series [1]
>> to simplify the arm64 jited prologue/epilogue. I think we can fix this
>> issue based
>> on [1]. I'll give it a try.
>>
>> [1]
>> https://lore.kernel.org/bpf/20240826071624.350108-1-xukuohai@huaweicloud.com/
>>
> 
> Your patch series seems great. We can fix it based on it.
> 
> Please notify me if you have a successful try.
> 

I think the complexity arises from having to decide whether
to initialize or keep the tail counter value in the prologue.

To get rid of this complexity, a straightforward idea is to
move the tail call counter initialization to the entry of
bpf world, and in the bpf world, we only increase and check
the tail call counter, never save/restore or set it. The
"entry of the bpf world" here refers to mechanisms like
bpf_prog_run, bpf dispatcher, or bpf trampoline that
allows bpf prog to be invoked from C function.

Below is a rough POC diff for arm64 that could pass all
of your tests. The tail call counter is held in callee-saved
register x26, and is set to 0 by arch_run_bpf.

diff --git a/arch/arm64/net/bpf_jit_comp.c b/arch/arm64/net/bpf_jit_comp.c
index 8aa32cb140b9..2c0f7daf1655 100644
--- a/arch/arm64/net/bpf_jit_comp.c
+++ b/arch/arm64/net/bpf_jit_comp.c
@@ -26,7 +26,7 @@

  #define TMP_REG_1 (MAX_BPF_JIT_REG + 0)
  #define TMP_REG_2 (MAX_BPF_JIT_REG + 1)
-#define TCCNT_PTR (MAX_BPF_JIT_REG + 2)
+#define TCALL_CNT (MAX_BPF_JIT_REG + 2)
  #define TMP_REG_3 (MAX_BPF_JIT_REG + 3)
  #define ARENA_VM_START (MAX_BPF_JIT_REG + 5)

@@ -63,7 +63,7 @@ static const int bpf2a64[] = {
  	[TMP_REG_2] = A64_R(11),
  	[TMP_REG_3] = A64_R(12),
  	/* tail_call_cnt_ptr */
-	[TCCNT_PTR] = A64_R(26),
+	[TCALL_CNT] = A64_R(26), // x26 is used to hold tail call counter
  	/* temporary register for blinding constants */
  	[BPF_REG_AX] = A64_R(9),
  	/* callee saved register for kern_vm_start address */
@@ -286,19 +286,6 @@ static bool is_lsi_offset(int offset, int scale)
   *      // PROLOGUE_OFFSET
   *	// save callee-saved registers
   */
-static void prepare_bpf_tail_call_cnt(struct jit_ctx *ctx)
-{
-	const bool is_main_prog = !bpf_is_subprog(ctx->prog);
-	const u8 ptr = bpf2a64[TCCNT_PTR];
-
-	if (is_main_prog) {
-		/* Initialize tail_call_cnt. */
-		emit(A64_PUSH(A64_ZR, ptr, A64_SP), ctx);
-		emit(A64_MOV(1, ptr, A64_SP), ctx);
-	} else
-		emit(A64_PUSH(ptr, ptr, A64_SP), ctx);
-}
-
  static void find_used_callee_regs(struct jit_ctx *ctx)
  {
  	int i;
@@ -419,7 +406,7 @@ static void pop_callee_regs(struct jit_ctx *ctx)
  #define POKE_OFFSET (BTI_INSNS + 1)

  /* Tail call offset to jump into */
-#define PROLOGUE_OFFSET (BTI_INSNS + 2 + PAC_INSNS + 4)
+#define PROLOGUE_OFFSET (BTI_INSNS + 2 + PAC_INSNS + 2)

  static int build_prologue(struct jit_ctx *ctx, bool ebpf_from_cbpf)
  {
@@ -473,8 +460,6 @@ static int build_prologue(struct jit_ctx *ctx, bool ebpf_from_cbpf)
  		emit(A64_PUSH(A64_FP, A64_LR, A64_SP), ctx);
  		emit(A64_MOV(1, A64_FP, A64_SP), ctx);

-		prepare_bpf_tail_call_cnt(ctx);
-
  		if (!ebpf_from_cbpf && is_main_prog) {
  			cur_offset = ctx->idx - idx0;
  			if (cur_offset != PROLOGUE_OFFSET) {
@@ -499,7 +484,7 @@ static int build_prologue(struct jit_ctx *ctx, bool ebpf_from_cbpf)
  		 *
  		 * 12 registers are on the stack
  		 */
-		emit(A64_SUB_I(1, A64_SP, A64_FP, 96), ctx);
+		emit(A64_SUB_I(1, A64_SP, A64_FP, 80), ctx);
  	}

  	if (ctx->fp_used)
@@ -527,8 +512,7 @@ static int emit_bpf_tail_call(struct jit_ctx *ctx)

  	const u8 tmp = bpf2a64[TMP_REG_1];
  	const u8 prg = bpf2a64[TMP_REG_2];
-	const u8 tcc = bpf2a64[TMP_REG_3];
-	const u8 ptr = bpf2a64[TCCNT_PTR];
+	const u8 tcc = bpf2a64[TCALL_CNT];
  	size_t off;
  	__le32 *branch1 = NULL;
  	__le32 *branch2 = NULL;
@@ -546,16 +530,15 @@ static int emit_bpf_tail_call(struct jit_ctx *ctx)
  	emit(A64_NOP, ctx);

  	/*
-	 * if ((*tail_call_cnt_ptr) >= MAX_TAIL_CALL_CNT)
+	 * if (tail_call_cnt >= MAX_TAIL_CALL_CNT)
  	 *     goto out;
  	 */
  	emit_a64_mov_i64(tmp, MAX_TAIL_CALL_CNT, ctx);
-	emit(A64_LDR64I(tcc, ptr, 0), ctx);
  	emit(A64_CMP(1, tcc, tmp), ctx);
  	branch2 = ctx->image + ctx->idx;
  	emit(A64_NOP, ctx);

-	/* (*tail_call_cnt_ptr)++; */
+	/* tail_call_cnt++; */
  	emit(A64_ADD_I(1, tcc, tcc, 1), ctx);

  	/* prog = array->ptrs[index];
@@ -570,9 +553,6 @@ static int emit_bpf_tail_call(struct jit_ctx *ctx)
  	branch3 = ctx->image + ctx->idx;
  	emit(A64_NOP, ctx);

-	/* Update tail_call_cnt if the slot is populated. */
-	emit(A64_STR64I(tcc, ptr, 0), ctx);
-
  	/* restore SP */
  	if (ctx->stack_size)
  		emit(A64_ADD_I(1, A64_SP, A64_SP, ctx->stack_size), ctx);
@@ -793,6 +773,27 @@ asm (
  "	.popsection\n"
  );

+unsigned int arch_run_bpf(const void *ctx, const struct bpf_insn *insnsi, bpf_func_t bpf_func);
+asm (
+"	.pushsection .text, \"ax\", @progbits\n"
+"	.global arch_run_bpf\n"
+"	.type arch_run_bpf, %function\n"
+"arch_run_bpf:\n"
+#if IS_ENABLED(CONFIG_ARM64_BTI_KERNEL)
+"	bti j\n"
+#endif
+"	stp x29, x30, [sp, #-16]!\n"
+"	stp xzr, x26, [sp, #-16]!\n"
+"	mov x26, #0\n"
+"	blr x2\n"
+"	ldp xzr, x26, [sp], #16\n"
+"	ldp x29, x30, [sp], #16\n"
+"	ret x30\n"
+"	.size arch_run_bpf, . - arch_run_bpf\n"
+"	.popsection\n"
+);
+EXPORT_SYMBOL_GPL(arch_run_bpf);
+
  /* build a plt initialized like this:
   *
   * plt:
@@ -826,7 +827,6 @@ static void build_plt(struct jit_ctx *ctx)
  static void build_epilogue(struct jit_ctx *ctx)
  {
  	const u8 r0 = bpf2a64[BPF_REG_0];
-	const u8 ptr = bpf2a64[TCCNT_PTR];

  	/* We're done with BPF stack */
  	if (ctx->stack_size)
@@ -834,8 +834,6 @@ static void build_epilogue(struct jit_ctx *ctx)

  	pop_callee_regs(ctx);

-	emit(A64_POP(A64_ZR, ptr, A64_SP), ctx);
-
  	/* Restore FP/LR registers */
  	emit(A64_POP(A64_FP, A64_LR, A64_SP), ctx);

@@ -2066,6 +2064,8 @@ static int prepare_trampoline(struct jit_ctx *ctx, struct bpf_tramp_image *im,
  	bool save_ret;
  	__le32 **branches = NULL;

+	bool target_is_bpf = is_bpf_text_address((unsigned long)func_addr);
+
  	/* trampoline stack layout:
  	 *                  [ parent ip         ]
  	 *                  [ FP                ]
@@ -2133,6 +2133,11 @@ static int prepare_trampoline(struct jit_ctx *ctx, struct bpf_tramp_image *im,
  	 */
  	emit_bti(A64_BTI_JC, ctx);

+	if (!target_is_bpf) {
+		emit(A64_PUSH(A64_ZR, A64_R(26), A64_SP), ctx);
+		emit(A64_MOVZ(1, A64_R(26), 0, 0), ctx);
+	}
+
  	/* frame for parent function */
  	emit(A64_PUSH(A64_FP, A64_R(9), A64_SP), ctx);
  	emit(A64_MOV(1, A64_FP, A64_SP), ctx);
@@ -2226,6 +2231,8 @@ static int prepare_trampoline(struct jit_ctx *ctx, struct bpf_tramp_image *im,
  	/* pop frames  */
  	emit(A64_POP(A64_FP, A64_LR, A64_SP), ctx);
  	emit(A64_POP(A64_FP, A64_R(9), A64_SP), ctx);
+	if (!target_is_bpf)
+		emit(A64_POP(A64_ZR, A64_R(26), A64_SP), ctx);

  	if (flags & BPF_TRAMP_F_SKIP_FRAME) {
  		/* skip patched function, return to parent */
diff --git a/include/linux/bpf.h b/include/linux/bpf.h
index dc63083f76b7..8660d15dd50c 100644
--- a/include/linux/bpf.h
+++ b/include/linux/bpf.h
@@ -1244,12 +1244,14 @@ struct bpf_dispatcher {
  #define __bpfcall __nocfi
  #endif

+unsigned int arch_run_bpf(const void *ctx, const struct bpf_insn *insnsi, bpf_func_t bpf_func);
+
  static __always_inline __bpfcall unsigned int bpf_dispatcher_nop_func(
  	const void *ctx,
  	const struct bpf_insn *insnsi,
  	bpf_func_t bpf_func)
  {
-	return bpf_func(ctx, insnsi);
+	return arch_run_bpf(ctx, insnsi, bpf_func);
  }

  /* the implementation of the opaque uapi struct bpf_dynptr */
@@ -1317,7 +1319,7 @@ int arch_prepare_bpf_dispatcher(void *image, void *buf, s64 *funcs, int num_func
  #else
  #define __BPF_DISPATCHER_SC_INIT(name)
  #define __BPF_DISPATCHER_SC(name)
-#define __BPF_DISPATCHER_CALL(name)		bpf_func(ctx, insnsi)
+#define __BPF_DISPATCHER_CALL(name)		arch_run_bpf(ctx, insnsi, bpf_func);
  #define __BPF_DISPATCHER_UPDATE(_d, _new)
  #endif

> Thanks,
> Leon
Leon Hwang Aug. 30, 2024, 9:08 a.m. UTC | #4
On 30/8/24 15:37, Xu Kuohai wrote:
> On 8/27/2024 10:23 AM, Leon Hwang wrote:
>>

[...]

> 
> I think the complexity arises from having to decide whether
> to initialize or keep the tail counter value in the prologue.
> 
> To get rid of this complexity, a straightforward idea is to
> move the tail call counter initialization to the entry of
> bpf world, and in the bpf world, we only increase and check
> the tail call counter, never save/restore or set it. The
> "entry of the bpf world" here refers to mechanisms like
> bpf_prog_run, bpf dispatcher, or bpf trampoline that
> allows bpf prog to be invoked from C function.
> 
> Below is a rough POC diff for arm64 that could pass all
> of your tests. The tail call counter is held in callee-saved
> register x26, and is set to 0 by arch_run_bpf.
> 
> diff --git a/arch/arm64/net/bpf_jit_comp.c b/arch/arm64/net/bpf_jit_comp.c
> index 8aa32cb140b9..2c0f7daf1655 100644
> --- a/arch/arm64/net/bpf_jit_comp.c
> +++ b/arch/arm64/net/bpf_jit_comp.c
> @@ -26,7 +26,7 @@
> 
>  #define TMP_REG_1 (MAX_BPF_JIT_REG + 0)
>  #define TMP_REG_2 (MAX_BPF_JIT_REG + 1)
> -#define TCCNT_PTR (MAX_BPF_JIT_REG + 2)
> +#define TCALL_CNT (MAX_BPF_JIT_REG + 2)
>  #define TMP_REG_3 (MAX_BPF_JIT_REG + 3)
>  #define ARENA_VM_START (MAX_BPF_JIT_REG + 5)
> 
> @@ -63,7 +63,7 @@ static const int bpf2a64[] = {
>      [TMP_REG_2] = A64_R(11),
>      [TMP_REG_3] = A64_R(12),
>      /* tail_call_cnt_ptr */
> -    [TCCNT_PTR] = A64_R(26),
> +    [TCALL_CNT] = A64_R(26), // x26 is used to hold tail call counter
>      /* temporary register for blinding constants */
>      [BPF_REG_AX] = A64_R(9),
>      /* callee saved register for kern_vm_start address */
> @@ -286,19 +286,6 @@ static bool is_lsi_offset(int offset, int scale)
>   *      // PROLOGUE_OFFSET
>   *    // save callee-saved registers
>   */
> -static void prepare_bpf_tail_call_cnt(struct jit_ctx *ctx)
> -{
> -    const bool is_main_prog = !bpf_is_subprog(ctx->prog);
> -    const u8 ptr = bpf2a64[TCCNT_PTR];
> -
> -    if (is_main_prog) {
> -        /* Initialize tail_call_cnt. */
> -        emit(A64_PUSH(A64_ZR, ptr, A64_SP), ctx);
> -        emit(A64_MOV(1, ptr, A64_SP), ctx);
> -    } else
> -        emit(A64_PUSH(ptr, ptr, A64_SP), ctx);
> -}
> -
>  static void find_used_callee_regs(struct jit_ctx *ctx)
>  {
>      int i;
> @@ -419,7 +406,7 @@ static void pop_callee_regs(struct jit_ctx *ctx)
>  #define POKE_OFFSET (BTI_INSNS + 1)
> 
>  /* Tail call offset to jump into */
> -#define PROLOGUE_OFFSET (BTI_INSNS + 2 + PAC_INSNS + 4)
> +#define PROLOGUE_OFFSET (BTI_INSNS + 2 + PAC_INSNS + 2)
> 
>  static int build_prologue(struct jit_ctx *ctx, bool ebpf_from_cbpf)
>  {
> @@ -473,8 +460,6 @@ static int build_prologue(struct jit_ctx *ctx, bool
> ebpf_from_cbpf)
>          emit(A64_PUSH(A64_FP, A64_LR, A64_SP), ctx);
>          emit(A64_MOV(1, A64_FP, A64_SP), ctx);
> 
> -        prepare_bpf_tail_call_cnt(ctx);
> -
>          if (!ebpf_from_cbpf && is_main_prog) {
>              cur_offset = ctx->idx - idx0;
>              if (cur_offset != PROLOGUE_OFFSET) {
> @@ -499,7 +484,7 @@ static int build_prologue(struct jit_ctx *ctx, bool
> ebpf_from_cbpf)
>           *
>           * 12 registers are on the stack
>           */
> -        emit(A64_SUB_I(1, A64_SP, A64_FP, 96), ctx);
> +        emit(A64_SUB_I(1, A64_SP, A64_FP, 80), ctx);
>      }
> 
>      if (ctx->fp_used)
> @@ -527,8 +512,7 @@ static int emit_bpf_tail_call(struct jit_ctx *ctx)
> 
>      const u8 tmp = bpf2a64[TMP_REG_1];
>      const u8 prg = bpf2a64[TMP_REG_2];
> -    const u8 tcc = bpf2a64[TMP_REG_3];
> -    const u8 ptr = bpf2a64[TCCNT_PTR];
> +    const u8 tcc = bpf2a64[TCALL_CNT];
>      size_t off;
>      __le32 *branch1 = NULL;
>      __le32 *branch2 = NULL;
> @@ -546,16 +530,15 @@ static int emit_bpf_tail_call(struct jit_ctx *ctx)
>      emit(A64_NOP, ctx);
> 
>      /*
> -     * if ((*tail_call_cnt_ptr) >= MAX_TAIL_CALL_CNT)
> +     * if (tail_call_cnt >= MAX_TAIL_CALL_CNT)
>       *     goto out;
>       */
>      emit_a64_mov_i64(tmp, MAX_TAIL_CALL_CNT, ctx);
> -    emit(A64_LDR64I(tcc, ptr, 0), ctx);
>      emit(A64_CMP(1, tcc, tmp), ctx);
>      branch2 = ctx->image + ctx->idx;
>      emit(A64_NOP, ctx);
> 
> -    /* (*tail_call_cnt_ptr)++; */
> +    /* tail_call_cnt++; */
>      emit(A64_ADD_I(1, tcc, tcc, 1), ctx);
> 
>      /* prog = array->ptrs[index];
> @@ -570,9 +553,6 @@ static int emit_bpf_tail_call(struct jit_ctx *ctx)
>      branch3 = ctx->image + ctx->idx;
>      emit(A64_NOP, ctx);
> 
> -    /* Update tail_call_cnt if the slot is populated. */
> -    emit(A64_STR64I(tcc, ptr, 0), ctx);
> -
>      /* restore SP */
>      if (ctx->stack_size)
>          emit(A64_ADD_I(1, A64_SP, A64_SP, ctx->stack_size), ctx);
> @@ -793,6 +773,27 @@ asm (
>  "    .popsection\n"
>  );
> 
> +unsigned int arch_run_bpf(const void *ctx, const struct bpf_insn
> *insnsi, bpf_func_t bpf_func);
> +asm (
> +"    .pushsection .text, \"ax\", @progbits\n"
> +"    .global arch_run_bpf\n"
> +"    .type arch_run_bpf, %function\n"
> +"arch_run_bpf:\n"
> +#if IS_ENABLED(CONFIG_ARM64_BTI_KERNEL)
> +"    bti j\n"
> +#endif
> +"    stp x29, x30, [sp, #-16]!\n"
> +"    stp xzr, x26, [sp, #-16]!\n"
> +"    mov x26, #0\n"
> +"    blr x2\n"
> +"    ldp xzr, x26, [sp], #16\n"
> +"    ldp x29, x30, [sp], #16\n"
> +"    ret x30\n"
> +"    .size arch_run_bpf, . - arch_run_bpf\n"
> +"    .popsection\n"
> +);
> +EXPORT_SYMBOL_GPL(arch_run_bpf);
> +
>  /* build a plt initialized like this:
>   *
>   * plt:
> @@ -826,7 +827,6 @@ static void build_plt(struct jit_ctx *ctx)
>  static void build_epilogue(struct jit_ctx *ctx)
>  {
>      const u8 r0 = bpf2a64[BPF_REG_0];
> -    const u8 ptr = bpf2a64[TCCNT_PTR];
> 
>      /* We're done with BPF stack */
>      if (ctx->stack_size)
> @@ -834,8 +834,6 @@ static void build_epilogue(struct jit_ctx *ctx)
> 
>      pop_callee_regs(ctx);
> 
> -    emit(A64_POP(A64_ZR, ptr, A64_SP), ctx);
> -
>      /* Restore FP/LR registers */
>      emit(A64_POP(A64_FP, A64_LR, A64_SP), ctx);
> 
> @@ -2066,6 +2064,8 @@ static int prepare_trampoline(struct jit_ctx *ctx,
> struct bpf_tramp_image *im,
>      bool save_ret;
>      __le32 **branches = NULL;
> 
> +    bool target_is_bpf = is_bpf_text_address((unsigned long)func_addr);
> +
>      /* trampoline stack layout:
>       *                  [ parent ip         ]
>       *                  [ FP                ]
> @@ -2133,6 +2133,11 @@ static int prepare_trampoline(struct jit_ctx
> *ctx, struct bpf_tramp_image *im,
>       */
>      emit_bti(A64_BTI_JC, ctx);
> 
> +    if (!target_is_bpf) {
> +        emit(A64_PUSH(A64_ZR, A64_R(26), A64_SP), ctx);
> +        emit(A64_MOVZ(1, A64_R(26), 0, 0), ctx);
> +    }
> +
>      /* frame for parent function */
>      emit(A64_PUSH(A64_FP, A64_R(9), A64_SP), ctx);
>      emit(A64_MOV(1, A64_FP, A64_SP), ctx);
> @@ -2226,6 +2231,8 @@ static int prepare_trampoline(struct jit_ctx *ctx,
> struct bpf_tramp_image *im,
>      /* pop frames  */
>      emit(A64_POP(A64_FP, A64_LR, A64_SP), ctx);
>      emit(A64_POP(A64_FP, A64_R(9), A64_SP), ctx);
> +    if (!target_is_bpf)
> +        emit(A64_POP(A64_ZR, A64_R(26), A64_SP), ctx);
> 
>      if (flags & BPF_TRAMP_F_SKIP_FRAME) {
>          /* skip patched function, return to parent */
> diff --git a/include/linux/bpf.h b/include/linux/bpf.h
> index dc63083f76b7..8660d15dd50c 100644
> --- a/include/linux/bpf.h
> +++ b/include/linux/bpf.h
> @@ -1244,12 +1244,14 @@ struct bpf_dispatcher {
>  #define __bpfcall __nocfi
>  #endif
> 
> +unsigned int arch_run_bpf(const void *ctx, const struct bpf_insn
> *insnsi, bpf_func_t bpf_func);
> +
>  static __always_inline __bpfcall unsigned int bpf_dispatcher_nop_func(
>      const void *ctx,
>      const struct bpf_insn *insnsi,
>      bpf_func_t bpf_func)
>  {
> -    return bpf_func(ctx, insnsi);
> +    return arch_run_bpf(ctx, insnsi, bpf_func);
>  }
> 
>  /* the implementation of the opaque uapi struct bpf_dynptr */
> @@ -1317,7 +1319,7 @@ int arch_prepare_bpf_dispatcher(void *image, void
> *buf, s64 *funcs, int num_func
>  #else
>  #define __BPF_DISPATCHER_SC_INIT(name)
>  #define __BPF_DISPATCHER_SC(name)
> -#define __BPF_DISPATCHER_CALL(name)        bpf_func(ctx, insnsi)
> +#define __BPF_DISPATCHER_CALL(name)        arch_run_bpf(ctx, insnsi,
> bpf_func);
>  #define __BPF_DISPATCHER_UPDATE(_d, _new)
>  #endif
> 

This approach is really cool!

I want a similar approach on x86, but I failed, because on x86 it is an
indirect call, "callq *%rdx", aka "bpf_func(ctx, insnsi)".

Let us imagine the arch_run_bpf() on x86:

unsigned int __naked arch_run_bpf(const void *ctx, const struct bpf_insn
*insnsi, bpf_func_t bpf_func)
{
	asm (
		"pushq %rbp\n\t"
		"movq %rsp, %rbp\n\t"
		"xor %rax, %rax\n\t"
		"pushq %rax\n\t"
		"movq %rsp, %rax\n\t"
		"callq *%rdx\n\t"
		"leave\n\t"
		"ret\n\t"
	);
}

If we can change "callq *%rdx" to a direct call, it'll be really
wonderful to resolve this tailcall issue on x86.

How to introduce arch_run_bpf() for all JIT backends?

Thanks,
Leon
Xu Kuohai Aug. 30, 2024, 10 a.m. UTC | #5
On 8/30/2024 5:08 PM, Leon Hwang wrote:
> 
> 
> On 30/8/24 15:37, Xu Kuohai wrote:
>> On 8/27/2024 10:23 AM, Leon Hwang wrote:
>>>
> 
> [...]
> 
>>
>> I think the complexity arises from having to decide whether
>> to initialize or keep the tail counter value in the prologue.
>>
>> To get rid of this complexity, a straightforward idea is to
>> move the tail call counter initialization to the entry of
>> bpf world, and in the bpf world, we only increase and check
>> the tail call counter, never save/restore or set it. The
>> "entry of the bpf world" here refers to mechanisms like
>> bpf_prog_run, bpf dispatcher, or bpf trampoline that
>> allows bpf prog to be invoked from C function.
>>
>> Below is a rough POC diff for arm64 that could pass all
>> of your tests. The tail call counter is held in callee-saved
>> register x26, and is set to 0 by arch_run_bpf.
>>
>> diff --git a/arch/arm64/net/bpf_jit_comp.c b/arch/arm64/net/bpf_jit_comp.c
>> index 8aa32cb140b9..2c0f7daf1655 100644
>> --- a/arch/arm64/net/bpf_jit_comp.c
>> +++ b/arch/arm64/net/bpf_jit_comp.c
>> @@ -26,7 +26,7 @@
>>
>>   #define TMP_REG_1 (MAX_BPF_JIT_REG + 0)
>>   #define TMP_REG_2 (MAX_BPF_JIT_REG + 1)
>> -#define TCCNT_PTR (MAX_BPF_JIT_REG + 2)
>> +#define TCALL_CNT (MAX_BPF_JIT_REG + 2)
>>   #define TMP_REG_3 (MAX_BPF_JIT_REG + 3)
>>   #define ARENA_VM_START (MAX_BPF_JIT_REG + 5)
>>
>> @@ -63,7 +63,7 @@ static const int bpf2a64[] = {
>>       [TMP_REG_2] = A64_R(11),
>>       [TMP_REG_3] = A64_R(12),
>>       /* tail_call_cnt_ptr */
>> -    [TCCNT_PTR] = A64_R(26),
>> +    [TCALL_CNT] = A64_R(26), // x26 is used to hold tail call counter
>>       /* temporary register for blinding constants */
>>       [BPF_REG_AX] = A64_R(9),
>>       /* callee saved register for kern_vm_start address */
>> @@ -286,19 +286,6 @@ static bool is_lsi_offset(int offset, int scale)
>>    *      // PROLOGUE_OFFSET
>>    *    // save callee-saved registers
>>    */
>> -static void prepare_bpf_tail_call_cnt(struct jit_ctx *ctx)
>> -{
>> -    const bool is_main_prog = !bpf_is_subprog(ctx->prog);
>> -    const u8 ptr = bpf2a64[TCCNT_PTR];
>> -
>> -    if (is_main_prog) {
>> -        /* Initialize tail_call_cnt. */
>> -        emit(A64_PUSH(A64_ZR, ptr, A64_SP), ctx);
>> -        emit(A64_MOV(1, ptr, A64_SP), ctx);
>> -    } else
>> -        emit(A64_PUSH(ptr, ptr, A64_SP), ctx);
>> -}
>> -
>>   static void find_used_callee_regs(struct jit_ctx *ctx)
>>   {
>>       int i;
>> @@ -419,7 +406,7 @@ static void pop_callee_regs(struct jit_ctx *ctx)
>>   #define POKE_OFFSET (BTI_INSNS + 1)
>>
>>   /* Tail call offset to jump into */
>> -#define PROLOGUE_OFFSET (BTI_INSNS + 2 + PAC_INSNS + 4)
>> +#define PROLOGUE_OFFSET (BTI_INSNS + 2 + PAC_INSNS + 2)
>>
>>   static int build_prologue(struct jit_ctx *ctx, bool ebpf_from_cbpf)
>>   {
>> @@ -473,8 +460,6 @@ static int build_prologue(struct jit_ctx *ctx, bool
>> ebpf_from_cbpf)
>>           emit(A64_PUSH(A64_FP, A64_LR, A64_SP), ctx);
>>           emit(A64_MOV(1, A64_FP, A64_SP), ctx);
>>
>> -        prepare_bpf_tail_call_cnt(ctx);
>> -
>>           if (!ebpf_from_cbpf && is_main_prog) {
>>               cur_offset = ctx->idx - idx0;
>>               if (cur_offset != PROLOGUE_OFFSET) {
>> @@ -499,7 +484,7 @@ static int build_prologue(struct jit_ctx *ctx, bool
>> ebpf_from_cbpf)
>>            *
>>            * 12 registers are on the stack
>>            */
>> -        emit(A64_SUB_I(1, A64_SP, A64_FP, 96), ctx);
>> +        emit(A64_SUB_I(1, A64_SP, A64_FP, 80), ctx);
>>       }
>>
>>       if (ctx->fp_used)
>> @@ -527,8 +512,7 @@ static int emit_bpf_tail_call(struct jit_ctx *ctx)
>>
>>       const u8 tmp = bpf2a64[TMP_REG_1];
>>       const u8 prg = bpf2a64[TMP_REG_2];
>> -    const u8 tcc = bpf2a64[TMP_REG_3];
>> -    const u8 ptr = bpf2a64[TCCNT_PTR];
>> +    const u8 tcc = bpf2a64[TCALL_CNT];
>>       size_t off;
>>       __le32 *branch1 = NULL;
>>       __le32 *branch2 = NULL;
>> @@ -546,16 +530,15 @@ static int emit_bpf_tail_call(struct jit_ctx *ctx)
>>       emit(A64_NOP, ctx);
>>
>>       /*
>> -     * if ((*tail_call_cnt_ptr) >= MAX_TAIL_CALL_CNT)
>> +     * if (tail_call_cnt >= MAX_TAIL_CALL_CNT)
>>        *     goto out;
>>        */
>>       emit_a64_mov_i64(tmp, MAX_TAIL_CALL_CNT, ctx);
>> -    emit(A64_LDR64I(tcc, ptr, 0), ctx);
>>       emit(A64_CMP(1, tcc, tmp), ctx);
>>       branch2 = ctx->image + ctx->idx;
>>       emit(A64_NOP, ctx);
>>
>> -    /* (*tail_call_cnt_ptr)++; */
>> +    /* tail_call_cnt++; */
>>       emit(A64_ADD_I(1, tcc, tcc, 1), ctx);
>>
>>       /* prog = array->ptrs[index];
>> @@ -570,9 +553,6 @@ static int emit_bpf_tail_call(struct jit_ctx *ctx)
>>       branch3 = ctx->image + ctx->idx;
>>       emit(A64_NOP, ctx);
>>
>> -    /* Update tail_call_cnt if the slot is populated. */
>> -    emit(A64_STR64I(tcc, ptr, 0), ctx);
>> -
>>       /* restore SP */
>>       if (ctx->stack_size)
>>           emit(A64_ADD_I(1, A64_SP, A64_SP, ctx->stack_size), ctx);
>> @@ -793,6 +773,27 @@ asm (
>>   "    .popsection\n"
>>   );
>>
>> +unsigned int arch_run_bpf(const void *ctx, const struct bpf_insn
>> *insnsi, bpf_func_t bpf_func);
>> +asm (
>> +"    .pushsection .text, \"ax\", @progbits\n"
>> +"    .global arch_run_bpf\n"
>> +"    .type arch_run_bpf, %function\n"
>> +"arch_run_bpf:\n"
>> +#if IS_ENABLED(CONFIG_ARM64_BTI_KERNEL)
>> +"    bti j\n"
>> +#endif
>> +"    stp x29, x30, [sp, #-16]!\n"
>> +"    stp xzr, x26, [sp, #-16]!\n"
>> +"    mov x26, #0\n"
>> +"    blr x2\n"
>> +"    ldp xzr, x26, [sp], #16\n"
>> +"    ldp x29, x30, [sp], #16\n"
>> +"    ret x30\n"
>> +"    .size arch_run_bpf, . - arch_run_bpf\n"
>> +"    .popsection\n"
>> +);
>> +EXPORT_SYMBOL_GPL(arch_run_bpf);
>> +
>>   /* build a plt initialized like this:
>>    *
>>    * plt:
>> @@ -826,7 +827,6 @@ static void build_plt(struct jit_ctx *ctx)
>>   static void build_epilogue(struct jit_ctx *ctx)
>>   {
>>       const u8 r0 = bpf2a64[BPF_REG_0];
>> -    const u8 ptr = bpf2a64[TCCNT_PTR];
>>
>>       /* We're done with BPF stack */
>>       if (ctx->stack_size)
>> @@ -834,8 +834,6 @@ static void build_epilogue(struct jit_ctx *ctx)
>>
>>       pop_callee_regs(ctx);
>>
>> -    emit(A64_POP(A64_ZR, ptr, A64_SP), ctx);
>> -
>>       /* Restore FP/LR registers */
>>       emit(A64_POP(A64_FP, A64_LR, A64_SP), ctx);
>>
>> @@ -2066,6 +2064,8 @@ static int prepare_trampoline(struct jit_ctx *ctx,
>> struct bpf_tramp_image *im,
>>       bool save_ret;
>>       __le32 **branches = NULL;
>>
>> +    bool target_is_bpf = is_bpf_text_address((unsigned long)func_addr);
>> +
>>       /* trampoline stack layout:
>>        *                  [ parent ip         ]
>>        *                  [ FP                ]
>> @@ -2133,6 +2133,11 @@ static int prepare_trampoline(struct jit_ctx
>> *ctx, struct bpf_tramp_image *im,
>>        */
>>       emit_bti(A64_BTI_JC, ctx);
>>
>> +    if (!target_is_bpf) {
>> +        emit(A64_PUSH(A64_ZR, A64_R(26), A64_SP), ctx);
>> +        emit(A64_MOVZ(1, A64_R(26), 0, 0), ctx);
>> +    }
>> +
>>       /* frame for parent function */
>>       emit(A64_PUSH(A64_FP, A64_R(9), A64_SP), ctx);
>>       emit(A64_MOV(1, A64_FP, A64_SP), ctx);
>> @@ -2226,6 +2231,8 @@ static int prepare_trampoline(struct jit_ctx *ctx,
>> struct bpf_tramp_image *im,
>>       /* pop frames  */
>>       emit(A64_POP(A64_FP, A64_LR, A64_SP), ctx);
>>       emit(A64_POP(A64_FP, A64_R(9), A64_SP), ctx);
>> +    if (!target_is_bpf)
>> +        emit(A64_POP(A64_ZR, A64_R(26), A64_SP), ctx);
>>
>>       if (flags & BPF_TRAMP_F_SKIP_FRAME) {
>>           /* skip patched function, return to parent */
>> diff --git a/include/linux/bpf.h b/include/linux/bpf.h
>> index dc63083f76b7..8660d15dd50c 100644
>> --- a/include/linux/bpf.h
>> +++ b/include/linux/bpf.h
>> @@ -1244,12 +1244,14 @@ struct bpf_dispatcher {
>>   #define __bpfcall __nocfi
>>   #endif
>>
>> +unsigned int arch_run_bpf(const void *ctx, const struct bpf_insn
>> *insnsi, bpf_func_t bpf_func);
>> +
>>   static __always_inline __bpfcall unsigned int bpf_dispatcher_nop_func(
>>       const void *ctx,
>>       const struct bpf_insn *insnsi,
>>       bpf_func_t bpf_func)
>>   {
>> -    return bpf_func(ctx, insnsi);
>> +    return arch_run_bpf(ctx, insnsi, bpf_func);
>>   }
>>
>>   /* the implementation of the opaque uapi struct bpf_dynptr */
>> @@ -1317,7 +1319,7 @@ int arch_prepare_bpf_dispatcher(void *image, void
>> *buf, s64 *funcs, int num_func
>>   #else
>>   #define __BPF_DISPATCHER_SC_INIT(name)
>>   #define __BPF_DISPATCHER_SC(name)
>> -#define __BPF_DISPATCHER_CALL(name)        bpf_func(ctx, insnsi)
>> +#define __BPF_DISPATCHER_CALL(name)        arch_run_bpf(ctx, insnsi,
>> bpf_func);
>>   #define __BPF_DISPATCHER_UPDATE(_d, _new)
>>   #endif
>>
> 
> This approach is really cool!
> 
> I want an alike approach on x86. But I failed. Because, on x86, it's an
> indirect call to "call *rdx", aka "bpf_func(ctx, insnsi)".
> 
> Let us imagine the arch_run_bpf() on x86:
> 
> unsigned int __naked arch_run_bpf(const void *ctx, const struct bpf_insn
> *insnsi, bpf_func_t bpf_func)
> {
> 	asm (
> 		"pushq %rbp\n\t"
> 		"movq %rsp, %rbp\n\t"
> 		"xor %rax, %rax\n\t"
> 		"pushq %rax\n\t"
> 		"movq %rsp, %rax\n\t"
> 		"callq *%rdx\n\t"
> 		"leave\n\t"
> 		"ret\n\t"
> 	);
> }
> 
> If we can change "callq *%rdx" to a direct call, it'll be really
> wonderful to resolve this tailcall issue on x86.
>

Right, so we need a static call here; perhaps we can create a custom
static call trampoline to set up the tail call counter.

> How to introduce arch_bpf_run() for all JIT backends?
>

It seems we cannot avoid arch-specific code. One approach could be
to define a default __weak function that calls bpf_func directly,
and let each arch provide its own overridden implementation.
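
For illustration, a minimal sketch (hypothetical, untested) of such a
default, reusing the arch_run_bpf() signature from the POC above:

/* Possible generic default: just run the prog. An arch JIT that needs
 * to set up its tail call counter (like the arm64 POC above) provides
 * a non-weak override.
 */
unsigned int __weak arch_run_bpf(const void *ctx,
				 const struct bpf_insn *insnsi,
				 bpf_func_t bpf_func)
{
	return bpf_func(ctx, insnsi);
}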

> Thanks,
> Leon
Leon Hwang Aug. 30, 2024, 12:11 p.m. UTC | #6
On 2024/8/30 18:00, Xu Kuohai wrote:
> On 8/30/2024 5:08 PM, Leon Hwang wrote:
>>
>>
>> On 30/8/24 15:37, Xu Kuohai wrote:
>>> On 8/27/2024 10:23 AM, Leon Hwang wrote:
>>>>
>>

[...]

>>
>> This approach is really cool!
>>
>> I want an alike approach on x86. But I failed. Because, on x86, it's an
>> indirect call to "call *rdx", aka "bpf_func(ctx, insnsi)".
>>
>> Let us imagine the arch_run_bpf() on x86:
>>
>> unsigned int __naked arch_run_bpf(const void *ctx, const struct bpf_insn
>> *insnsi, bpf_func_t bpf_func)
>> {
>>     asm (
>>         "pushq %rbp\n\t"
>>         "movq %rsp, %rbp\n\t"
>>         "xor %rax, %rax\n\t"
>>         "pushq %rax\n\t"
>>         "movq %rsp, %rax\n\t"
>>         "callq *%rdx\n\t"
>>         "leave\n\t"
>>         "ret\n\t"
>>     );
>> }
>>
>> If we can change "callq *%rdx" to a direct call, it'll be really
>> wonderful to resolve this tailcall issue on x86.
>>
> 
> Right, so we need static call here, perhaps we can create a custom
> static call trampoline to setup tail call counter.
> 
>> How to introduce arch_bpf_run() for all JIT backends?
>>
> 
> Seems we can not avoid arch specific code. One approach could be
> to define a default __weak function to call bpf_func directly,
> and let each arch to provide its own overridden implementation.
> 

Hi Xu Kuohai,

Can you send a separate patch to fix this issue on arm64?

After you fix it, I'll send the patch to fix it on x64.

Thanks,
Leon
Alexei Starovoitov Aug. 30, 2024, 4:03 p.m. UTC | #7
On Fri, Aug 30, 2024 at 5:11 AM Leon Hwang <leon.hwang@linux.dev> wrote:
>
>
>
> On 2024/8/30 18:00, Xu Kuohai wrote:
> > On 8/30/2024 5:08 PM, Leon Hwang wrote:
> >>
> >>
> >> On 30/8/24 15:37, Xu Kuohai wrote:
> >>> On 8/27/2024 10:23 AM, Leon Hwang wrote:
> >>>>
> >>
>
> [...]
>
> >>
> >> This approach is really cool!
> >>
> >> I want an alike approach on x86. But I failed. Because, on x86, it's an
> >> indirect call to "call *rdx", aka "bpf_func(ctx, insnsi)".
> >>
> >> Let us imagine the arch_run_bpf() on x86:
> >>
> >> unsigned int __naked arch_run_bpf(const void *ctx, const struct bpf_insn
> >> *insnsi, bpf_func_t bpf_func)
> >> {
> >>     asm (
> >>         "pushq %rbp\n\t"
> >>         "movq %rsp, %rbp\n\t"
> >>         "xor %rax, %rax\n\t"
> >>         "pushq %rax\n\t"
> >>         "movq %rsp, %rax\n\t"
> >>         "callq *%rdx\n\t"
> >>         "leave\n\t"
> >>         "ret\n\t"
> >>     );
> >> }
> >>
> >> If we can change "callq *%rdx" to a direct call, it'll be really
> >> wonderful to resolve this tailcall issue on x86.
> >>
> >
> > Right, so we need static call here, perhaps we can create a custom
> > static call trampoline to setup tail call counter.
> >
> >> How to introduce arch_bpf_run() for all JIT backends?
> >>
> >
> > Seems we can not avoid arch specific code. One approach could be
> > to define a default __weak function to call bpf_func directly,
> > and let each arch to provide its own overridden implementation.
> >
>
> Hi Xu Kuohai,
>
> Can you send a separate patch to fix this issue on arm64?
>
> After you fixing it, I'll send the patch to fix it on x64.

Hold on.
We're disabling freplace+tail_call in the verifier.
No need to change any JITs.
Puranjay Mohan Sept. 5, 2024, 9:13 a.m. UTC | #8
Xu Kuohai <xukuohai@huaweicloud.com> writes:

> On 8/27/2024 10:23 AM, Leon Hwang wrote:
>> 
>> 
>> On 26/8/24 22:32, Xu Kuohai wrote:
>>> On 8/25/2024 9:09 PM, Leon Hwang wrote:
>>>> Like "bpf, x64: Fix tailcall infinite loop caused by freplace", the same
>>>> issue happens on arm64, too.
>>>>
>> 
>> [...]
>> 
>>>
>>> This patch makes arm64 jited prologue even more complex. I've posted a
>>> series [1]
>>> to simplify the arm64 jited prologue/epilogue. I think we can fix this
>>> issue based
>>> on [1]. I'll give it a try.
>>>
>>> [1]
>>> https://lore.kernel.org/bpf/20240826071624.350108-1-xukuohai@huaweicloud.com/
>>>
>> 
>> Your patch series seems great. We can fix it based on it.
>> 
>> Please notify me if you have a successful try.
>> 
>
> I think the complexity arises from having to decide whether
> to initialize or keep the tail counter value in the prologue.
>
> To get rid of this complexity, a straightforward idea is to
> move the tail call counter initialization to the entry of
> bpf world, and in the bpf world, we only increase and check
> the tail call counter, never save/restore or set it. The
> "entry of the bpf world" here refers to mechanisms like
> bpf_prog_run, bpf dispatcher, or bpf trampoline that
> allows bpf prog to be invoked from C function.
>
> Below is a rough POC diff for arm64 that could pass all
> of your tests. The tail call counter is held in callee-saved
> register x26, and is set to 0 by arch_run_bpf.

I like this approach as it removes all the complexity of handling tcc in
different cases. Can we go ahead with this for arm64 and make
arch_run_bpf a weak function, so other architectures can override it if
they want to use a similar approach, and skip implementing arch_run_bpf
if they want to do something else?

Thanks,
Puranjay

>
> diff --git a/arch/arm64/net/bpf_jit_comp.c b/arch/arm64/net/bpf_jit_comp.c
> index 8aa32cb140b9..2c0f7daf1655 100644
> --- a/arch/arm64/net/bpf_jit_comp.c
> +++ b/arch/arm64/net/bpf_jit_comp.c
> @@ -26,7 +26,7 @@
>
>   #define TMP_REG_1 (MAX_BPF_JIT_REG + 0)
>   #define TMP_REG_2 (MAX_BPF_JIT_REG + 1)
> -#define TCCNT_PTR (MAX_BPF_JIT_REG + 2)
> +#define TCALL_CNT (MAX_BPF_JIT_REG + 2)
>   #define TMP_REG_3 (MAX_BPF_JIT_REG + 3)
>   #define ARENA_VM_START (MAX_BPF_JIT_REG + 5)
>
> @@ -63,7 +63,7 @@ static const int bpf2a64[] = {
>   	[TMP_REG_2] = A64_R(11),
>   	[TMP_REG_3] = A64_R(12),
>   	/* tail_call_cnt_ptr */
> -	[TCCNT_PTR] = A64_R(26),
> +	[TCALL_CNT] = A64_R(26), // x26 is used to hold tail call counter
>   	/* temporary register for blinding constants */
>   	[BPF_REG_AX] = A64_R(9),
>   	/* callee saved register for kern_vm_start address */
> @@ -286,19 +286,6 @@ static bool is_lsi_offset(int offset, int scale)
>    *      // PROLOGUE_OFFSET
>    *	// save callee-saved registers
>    */
> -static void prepare_bpf_tail_call_cnt(struct jit_ctx *ctx)
> -{
> -	const bool is_main_prog = !bpf_is_subprog(ctx->prog);
> -	const u8 ptr = bpf2a64[TCCNT_PTR];
> -
> -	if (is_main_prog) {
> -		/* Initialize tail_call_cnt. */
> -		emit(A64_PUSH(A64_ZR, ptr, A64_SP), ctx);
> -		emit(A64_MOV(1, ptr, A64_SP), ctx);
> -	} else
> -		emit(A64_PUSH(ptr, ptr, A64_SP), ctx);
> -}
> -
>   static void find_used_callee_regs(struct jit_ctx *ctx)
>   {
>   	int i;
> @@ -419,7 +406,7 @@ static void pop_callee_regs(struct jit_ctx *ctx)
>   #define POKE_OFFSET (BTI_INSNS + 1)
>
>   /* Tail call offset to jump into */
> -#define PROLOGUE_OFFSET (BTI_INSNS + 2 + PAC_INSNS + 4)
> +#define PROLOGUE_OFFSET (BTI_INSNS + 2 + PAC_INSNS + 2)
>
>   static int build_prologue(struct jit_ctx *ctx, bool ebpf_from_cbpf)
>   {
> @@ -473,8 +460,6 @@ static int build_prologue(struct jit_ctx *ctx, bool ebpf_from_cbpf)
>   		emit(A64_PUSH(A64_FP, A64_LR, A64_SP), ctx);
>   		emit(A64_MOV(1, A64_FP, A64_SP), ctx);
>
> -		prepare_bpf_tail_call_cnt(ctx);
> -
>   		if (!ebpf_from_cbpf && is_main_prog) {
>   			cur_offset = ctx->idx - idx0;
>   			if (cur_offset != PROLOGUE_OFFSET) {
> @@ -499,7 +484,7 @@ static int build_prologue(struct jit_ctx *ctx, bool ebpf_from_cbpf)
>   		 *
>   		 * 12 registers are on the stack
>   		 */
> -		emit(A64_SUB_I(1, A64_SP, A64_FP, 96), ctx);
> +		emit(A64_SUB_I(1, A64_SP, A64_FP, 80), ctx);
>   	}
>
>   	if (ctx->fp_used)
> @@ -527,8 +512,7 @@ static int emit_bpf_tail_call(struct jit_ctx *ctx)
>
>   	const u8 tmp = bpf2a64[TMP_REG_1];
>   	const u8 prg = bpf2a64[TMP_REG_2];
> -	const u8 tcc = bpf2a64[TMP_REG_3];
> -	const u8 ptr = bpf2a64[TCCNT_PTR];
> +	const u8 tcc = bpf2a64[TCALL_CNT];
>   	size_t off;
>   	__le32 *branch1 = NULL;
>   	__le32 *branch2 = NULL;
> @@ -546,16 +530,15 @@ static int emit_bpf_tail_call(struct jit_ctx *ctx)
>   	emit(A64_NOP, ctx);
>
>   	/*
> -	 * if ((*tail_call_cnt_ptr) >= MAX_TAIL_CALL_CNT)
> +	 * if (tail_call_cnt >= MAX_TAIL_CALL_CNT)
>   	 *     goto out;
>   	 */
>   	emit_a64_mov_i64(tmp, MAX_TAIL_CALL_CNT, ctx);
> -	emit(A64_LDR64I(tcc, ptr, 0), ctx);
>   	emit(A64_CMP(1, tcc, tmp), ctx);
>   	branch2 = ctx->image + ctx->idx;
>   	emit(A64_NOP, ctx);
>
> -	/* (*tail_call_cnt_ptr)++; */
> +	/* tail_call_cnt++; */
>   	emit(A64_ADD_I(1, tcc, tcc, 1), ctx);
>
>   	/* prog = array->ptrs[index];
> @@ -570,9 +553,6 @@ static int emit_bpf_tail_call(struct jit_ctx *ctx)
>   	branch3 = ctx->image + ctx->idx;
>   	emit(A64_NOP, ctx);
>
> -	/* Update tail_call_cnt if the slot is populated. */
> -	emit(A64_STR64I(tcc, ptr, 0), ctx);
> -
>   	/* restore SP */
>   	if (ctx->stack_size)
>   		emit(A64_ADD_I(1, A64_SP, A64_SP, ctx->stack_size), ctx);
> @@ -793,6 +773,27 @@ asm (
>   "	.popsection\n"
>   );
>
> +unsigned int arch_run_bpf(const void *ctx, const struct bpf_insn *insnsi, bpf_func_t bpf_func);
> +asm (
> +"	.pushsection .text, \"ax\", @progbits\n"
> +"	.global arch_run_bpf\n"
> +"	.type arch_run_bpf, %function\n"
> +"arch_run_bpf:\n"
> +#if IS_ENABLED(CONFIG_ARM64_BTI_KERNEL)
> +"	bti j\n"
> +#endif
> +"	stp x29, x30, [sp, #-16]!\n"
> +"	stp xzr, x26, [sp, #-16]!\n"
> +"	mov x26, #0\n"
> +"	blr x2\n"
> +"	ldp xzr, x26, [sp], #16\n"
> +"	ldp x29, x30, [sp], #16\n"
> +"	ret x30\n"
> +"	.size arch_run_bpf, . - arch_run_bpf\n"
> +"	.popsection\n"
> +);
> +EXPORT_SYMBOL_GPL(arch_run_bpf);
> +
>   /* build a plt initialized like this:
>    *
>    * plt:
> @@ -826,7 +827,6 @@ static void build_plt(struct jit_ctx *ctx)
>   static void build_epilogue(struct jit_ctx *ctx)
>   {
>   	const u8 r0 = bpf2a64[BPF_REG_0];
> -	const u8 ptr = bpf2a64[TCCNT_PTR];
>
>   	/* We're done with BPF stack */
>   	if (ctx->stack_size)
> @@ -834,8 +834,6 @@ static void build_epilogue(struct jit_ctx *ctx)
>
>   	pop_callee_regs(ctx);
>
> -	emit(A64_POP(A64_ZR, ptr, A64_SP), ctx);
> -
>   	/* Restore FP/LR registers */
>   	emit(A64_POP(A64_FP, A64_LR, A64_SP), ctx);
>
> @@ -2066,6 +2064,8 @@ static int prepare_trampoline(struct jit_ctx *ctx, struct bpf_tramp_image *im,
>   	bool save_ret;
>   	__le32 **branches = NULL;
>
> +	bool target_is_bpf = is_bpf_text_address((unsigned long)func_addr);
> +
>   	/* trampoline stack layout:
>   	 *                  [ parent ip         ]
>   	 *                  [ FP                ]
> @@ -2133,6 +2133,11 @@ static int prepare_trampoline(struct jit_ctx *ctx, struct bpf_tramp_image *im,
>   	 */
>   	emit_bti(A64_BTI_JC, ctx);
>
> +	if (!target_is_bpf) {
> +		emit(A64_PUSH(A64_ZR, A64_R(26), A64_SP), ctx);
> +		emit(A64_MOVZ(1, A64_R(26), 0, 0), ctx);
> +	}
> +
>   	/* frame for parent function */
>   	emit(A64_PUSH(A64_FP, A64_R(9), A64_SP), ctx);
>   	emit(A64_MOV(1, A64_FP, A64_SP), ctx);
> @@ -2226,6 +2231,8 @@ static int prepare_trampoline(struct jit_ctx *ctx, struct bpf_tramp_image *im,
>   	/* pop frames  */
>   	emit(A64_POP(A64_FP, A64_LR, A64_SP), ctx);
>   	emit(A64_POP(A64_FP, A64_R(9), A64_SP), ctx);
> +	if (!target_is_bpf)
> +		emit(A64_POP(A64_ZR, A64_R(26), A64_SP), ctx);
>
>   	if (flags & BPF_TRAMP_F_SKIP_FRAME) {
>   		/* skip patched function, return to parent */
> diff --git a/include/linux/bpf.h b/include/linux/bpf.h
> index dc63083f76b7..8660d15dd50c 100644
> --- a/include/linux/bpf.h
> +++ b/include/linux/bpf.h
> @@ -1244,12 +1244,14 @@ struct bpf_dispatcher {
>   #define __bpfcall __nocfi
>   #endif
>
> +unsigned int arch_run_bpf(const void *ctx, const struct bpf_insn *insnsi, bpf_func_t bpf_func);
> +
>   static __always_inline __bpfcall unsigned int bpf_dispatcher_nop_func(
>   	const void *ctx,
>   	const struct bpf_insn *insnsi,
>   	bpf_func_t bpf_func)
>   {
> -	return bpf_func(ctx, insnsi);
> +	return arch_run_bpf(ctx, insnsi, bpf_func);
>   }
>
>   /* the implementation of the opaque uapi struct bpf_dynptr */
> @@ -1317,7 +1319,7 @@ int arch_prepare_bpf_dispatcher(void *image, void *buf, s64 *funcs, int num_func
>   #else
>   #define __BPF_DISPATCHER_SC_INIT(name)
>   #define __BPF_DISPATCHER_SC(name)
> -#define __BPF_DISPATCHER_CALL(name)		bpf_func(ctx, insnsi)
> +#define __BPF_DISPATCHER_CALL(name)		arch_run_bpf(ctx, insnsi, bpf_func);
>   #define __BPF_DISPATCHER_UPDATE(_d, _new)
>   #endif
>
>> Thanks,
>> Leon
Leon Hwang Sept. 6, 2024, 2:32 p.m. UTC | #9
On 2024/9/5 17:13, Puranjay Mohan wrote:
> Xu Kuohai <xukuohai@huaweicloud.com> writes:
> 
>> On 8/27/2024 10:23 AM, Leon Hwang wrote:
>>>
>>>
>>> On 26/8/24 22:32, Xu Kuohai wrote:
>>>> On 8/25/2024 9:09 PM, Leon Hwang wrote:
>>>>> Like "bpf, x64: Fix tailcall infinite loop caused by freplace", the same
>>>>> issue happens on arm64, too.
>>>>>
>>>
>>> [...]
>>>
>>>>
>>>> This patch makes arm64 jited prologue even more complex. I've posted a
>>>> series [1]
>>>> to simplify the arm64 jited prologue/epilogue. I think we can fix this
>>>> issue based
>>>> on [1]. I'll give it a try.
>>>>
>>>> [1]
>>>> https://lore.kernel.org/bpf/20240826071624.350108-1-xukuohai@huaweicloud.com/
>>>>
>>>
>>> Your patch series seems great. We can fix it based on it.
>>>
>>> Please notify me if you have a successful try.
>>>
>>
>> I think the complexity arises from having to decide whether
>> to initialize or keep the tail counter value in the prologue.
>>
>> To get rid of this complexity, a straightforward idea is to
>> move the tail call counter initialization to the entry of
>> bpf world, and in the bpf world, we only increase and check
>> the tail call counter, never save/restore or set it. The
>> "entry of the bpf world" here refers to mechanisms like
>> bpf_prog_run, bpf dispatcher, or bpf trampoline that
>> allows bpf prog to be invoked from C function.
>>
>> Below is a rough POC diff for arm64 that could pass all
>> of your tests. The tail call counter is held in callee-saved
>> register x26, and is set to 0 by arch_run_bpf.
> 
> I like this approach as it removes all the complexity of handling tcc in

I like this approach, too.

> different cases. Can we go ahead with this for arm64 and make
> arch_run_bpf a weak function and let other architectures override this
> if they want to use a similar approach to this and if other archs want to
> do something else they can skip implementing arch_run_bpf.
> 

Hi Alexei,

What do you think about this idea?

Thanks,
Leon
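
Puranjay's weak-function suggestion quoted above could look roughly like the kernel-side fragment below. This is a hypothetical sketch, not code from the posted POC: the generic definition simply calls the program, and an architecture such as arm64 would supply a strong definition (the assembly version) that zeroes the tail-call counter register first.

#include <linux/bpf.h>

/* Hypothetical generic default for arch_run_bpf(); an architecture that
 * needs special entry handling overrides this weak symbol with its own
 * implementation.
 */
unsigned int __weak arch_run_bpf(const void *ctx,
				 const struct bpf_insn *insnsi,
				 bpf_func_t bpf_func)
{
	return bpf_func(ctx, insnsi);
}
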
Alexei Starovoitov Sept. 6, 2024, 3:24 p.m. UTC | #10
On Fri, Sep 6, 2024 at 7:32 AM Leon Hwang <leon.hwang@linux.dev> wrote:
>
>
>
> On 2024/9/5 17:13, Puranjay Mohan wrote:
> > Xu Kuohai <xukuohai@huaweicloud.com> writes:
> >
> >> On 8/27/2024 10:23 AM, Leon Hwang wrote:
> >>>
> >>>
> >>> On 26/8/24 22:32, Xu Kuohai wrote:
> >>>> On 8/25/2024 9:09 PM, Leon Hwang wrote:
> >>>>> Like "bpf, x64: Fix tailcall infinite loop caused by freplace", the same
> >>>>> issue happens on arm64, too.
> >>>>>
> >>>
> >>> [...]
> >>>
> >>>>
> >>>> This patch makes arm64 jited prologue even more complex. I've posted a
> >>>> series [1]
> >>>> to simplify the arm64 jited prologue/epilogue. I think we can fix this
> >>>> issue based
> >>>> on [1]. I'll give it a try.
> >>>>
> >>>> [1]
> >>>> https://lore.kernel.org/bpf/20240826071624.350108-1-xukuohai@huaweicloud.com/
> >>>>
> >>>
> >>> Your patch series seems great. We can fix it based on it.
> >>>
> >>> Please notify me if you have a successful try.
> >>>
> >>
> >> I think the complexity arises from having to decide whether
> >> to initialize or keep the tail counter value in the prologue.
> >>
> >> To get rid of this complexity, a straightforward idea is to
> >> move the tail call counter initialization to the entry of
> >> bpf world, and in the bpf world, we only increase and check
> >> the tail call counter, never save/restore or set it. The
> >> "entry of the bpf world" here refers to mechanisms like
> >> bpf_prog_run, bpf dispatcher, or bpf trampoline that
> >> allows bpf prog to be invoked from C function.
> >>
> >> Below is a rough POC diff for arm64 that could pass all
> >> of your tests. The tail call counter is held in callee-saved
> >> register x26, and is set to 0 by arch_run_bpf.
> >
> > I like this approach as it removes all the complexity of handling tcc in
>
> I like this approach, too.
>
> > different cases. Can we go ahead with this for arm64 and make
> > arch_run_bpf a weak function and let other architectures override this
> > if they want to use a similar approach to this and if other archs want to
> > do something else they can skip implementing arch_run_bpf.
> >
>
> Hi Alexei,
>
> What do you think about this idea?

This was discussed before and no, we're not going to add an extra tcc init
to bpf_prog_run and penalize everybody for this niche case.
Xu Kuohai Sept. 7, 2024, 7:03 a.m. UTC | #11
On 9/6/2024 11:24 PM, Alexei Starovoitov wrote:
> On Fri, Sep 6, 2024 at 7:32 AM Leon Hwang <leon.hwang@linux.dev> wrote:
>>
>>
>>
>> On 2024/9/5 17:13, Puranjay Mohan wrote:
>>> Xu Kuohai <xukuohai@huaweicloud.com> writes:
>>>
>>>> On 8/27/2024 10:23 AM, Leon Hwang wrote:
>>>>>
>>>>>
>>>>> On 26/8/24 22:32, Xu Kuohai wrote:
>>>>>> On 8/25/2024 9:09 PM, Leon Hwang wrote:
>>>>>>> Like "bpf, x64: Fix tailcall infinite loop caused by freplace", the same
>>>>>>> issue happens on arm64, too.
>>>>>>>
>>>>>
>>>>> [...]
>>>>>
>>>>>>
>>>>>> This patch makes arm64 jited prologue even more complex. I've posted a
>>>>>> series [1]
>>>>>> to simplify the arm64 jited prologue/epilogue. I think we can fix this
>>>>>> issue based
>>>>>> on [1]. I'll give it a try.
>>>>>>
>>>>>> [1]
>>>>>> https://lore.kernel.org/bpf/20240826071624.350108-1-xukuohai@huaweicloud.com/
>>>>>>
>>>>>
>>>>> Your patch series seems great. We can fix it based on it.
>>>>>
>>>>> Please notify me if you have a successful try.
>>>>>
>>>>
>>>> I think the complexity arises from having to decide whether
>>>> to initialize or keep the tail counter value in the prologue.
>>>>
>>>> To get rid of this complexity, a straightforward idea is to
>>>> move the tail call counter initialization to the entry of
>>>> bpf world, and in the bpf world, we only increase and check
>>>> the tail call counter, never save/restore or set it. The
>>>> "entry of the bpf world" here refers to mechanisms like
>>>> bpf_prog_run, bpf dispatcher, or bpf trampoline that
>>>> allows bpf prog to be invoked from C function.
>>>>
>>>> Below is a rough POC diff for arm64 that could pass all
>>>> of your tests. The tail call counter is held in callee-saved
>>>> register x26, and is set to 0 by arch_run_bpf.
>>>
>>> I like this approach as it removes all the complexity of handling tcc in
>>
>> I like this approach, too.
>>
>>> different cases. Can we go ahead with this for arm64 and make
>>> arch_run_bpf a weak function and let other architectures override this
>>> if they want to use a similar approach to this and if other archs want to
>>> do something else they can skip implementing arch_run_bpf.
>>>
>>
>> Hi Alexei,
>>
>> What do you think about this idea?
> 
> This was discussed before and no, we're not going to add an extra tcc init
> to bpf_prog_run and penalize everybody for this niche case.

+1, we should avoid hacking the JIT and adding complexity just for a niche case.

Patch

diff --git a/arch/arm64/net/bpf_jit_comp.c b/arch/arm64/net/bpf_jit_comp.c
index 59e05a7aea56a..4f8189824973f 100644
--- a/arch/arm64/net/bpf_jit_comp.c
+++ b/arch/arm64/net/bpf_jit_comp.c
@@ -276,6 +276,7 @@  static bool is_lsi_offset(int offset, int scale)
 /* generated prologue:
  *      bti c // if CONFIG_ARM64_BTI_KERNEL
  *      mov x9, lr
+ *      mov x7, 1 // if not-freplace main prog
  *      nop  // POKE_OFFSET
  *      paciasp // if CONFIG_ARM64_PTR_AUTH_KERNEL
  *      stp x29, lr, [sp, #-16]!
@@ -293,13 +294,14 @@  static bool is_lsi_offset(int offset, int scale)
 static void prepare_bpf_tail_call_cnt(struct jit_ctx *ctx)
 {
 	const struct bpf_prog *prog = ctx->prog;
+	const bool is_ext = prog->type == BPF_PROG_TYPE_EXT;
 	const bool is_main_prog = !bpf_is_subprog(prog);
 	const u8 ptr = bpf2a64[TCCNT_PTR];
 	const u8 fp = bpf2a64[BPF_REG_FP];
 	const u8 tcc = ptr;
 
 	emit(A64_PUSH(ptr, fp, A64_SP), ctx);
-	if (is_main_prog) {
+	if (is_main_prog && !is_ext) {
 		/* Initialize tail_call_cnt. */
 		emit(A64_MOVZ(1, tcc, 0, 0), ctx);
 		emit(A64_PUSH(tcc, fp, A64_SP), ctx);
@@ -315,22 +317,26 @@  static void prepare_bpf_tail_call_cnt(struct jit_ctx *ctx)
 #define PAC_INSNS (IS_ENABLED(CONFIG_ARM64_PTR_AUTH_KERNEL) ? 1 : 0)
 
 /* Offset of nop instruction in bpf prog entry to be poked */
-#define POKE_OFFSET (BTI_INSNS + 1)
+#define POKE_OFFSET (BTI_INSNS + 2)
 
 /* Tail call offset to jump into */
-#define PROLOGUE_OFFSET (BTI_INSNS + 2 + PAC_INSNS + 10)
+#define PROLOGUE_OFFSET (BTI_INSNS + 3 + PAC_INSNS + 10)
 
 static int build_prologue(struct jit_ctx *ctx, bool ebpf_from_cbpf,
 			  bool is_exception_cb, u64 arena_vm_start)
 {
 	const struct bpf_prog *prog = ctx->prog;
+	const bool is_ext = prog->type == BPF_PROG_TYPE_EXT;
 	const bool is_main_prog = !bpf_is_subprog(prog);
+	const u8 r0 = bpf2a64[BPF_REG_0];
 	const u8 r6 = bpf2a64[BPF_REG_6];
 	const u8 r7 = bpf2a64[BPF_REG_7];
 	const u8 r8 = bpf2a64[BPF_REG_8];
 	const u8 r9 = bpf2a64[BPF_REG_9];
 	const u8 fp = bpf2a64[BPF_REG_FP];
 	const u8 fpb = bpf2a64[FP_BOTTOM];
+	const u8 ptr = bpf2a64[TCCNT_PTR];
+	const u8 tmp = bpf2a64[TMP_REG_1];
 	const u8 arena_vm_base = bpf2a64[ARENA_VM_START];
 	const int idx0 = ctx->idx;
 	int cur_offset;
@@ -367,6 +373,10 @@  static int build_prologue(struct jit_ctx *ctx, bool ebpf_from_cbpf,
 	emit_bti(A64_BTI_JC, ctx);
 
 	emit(A64_MOV(1, A64_R(9), A64_LR), ctx);
+	if (!is_ext)
+		emit(A64_MOVZ(1, r0, is_main_prog, 0), ctx);
+	else
+		emit(A64_NOP, ctx);
 	emit(A64_NOP, ctx);
 
 	if (!is_exception_cb) {
@@ -413,6 +423,19 @@  static int build_prologue(struct jit_ctx *ctx, bool ebpf_from_cbpf,
 		emit_bti(A64_BTI_J, ctx);
 	}
 
+	/* If the freplace prog's target is a main prog, it has to set up
+	 * x26 as tail_call_cnt_ptr and then initialize tail_call_cnt
+	 * through that pointer.
+	 */
+	if (is_main_prog && is_ext) {
+		emit(A64_MOVZ(1, tmp, 1, 0), ctx);
+		emit(A64_CMP(1, r0, tmp), ctx);
+		emit(A64_B_(A64_COND_NE, 4), ctx);
+		emit(A64_ADD_I(1, ptr, A64_SP, 16), ctx);
+		emit(A64_MOVZ(1, r0, 0, 0), ctx);
+		emit(A64_STR64I(r0, ptr, 0), ctx);
+	}
+
 	/*
 	 * Program acting as exception boundary should save all ARM64
 	 * Callee-saved registers as the exception callback needs to recover
@@ -444,6 +467,7 @@  static int build_prologue(struct jit_ctx *ctx, bool ebpf_from_cbpf,
 static int out_offset = -1; /* initialized on the first pass of build_body() */
 static int emit_bpf_tail_call(struct jit_ctx *ctx)
 {
+	const u8 r0 = bpf2a64[BPF_REG_0];
 	/* bpf_tail_call(void *prog_ctx, struct bpf_array *array, u64 index) */
 	const u8 r2 = bpf2a64[BPF_REG_2];
 	const u8 r3 = bpf2a64[BPF_REG_3];
@@ -491,6 +515,11 @@  static int emit_bpf_tail_call(struct jit_ctx *ctx)
 
 	/* Update tail_call_cnt if the slot is populated. */
 	emit(A64_STR64I(tcc, ptr, 0), ctx);
+	/* When a freplace prog tail calls another freplace prog, setting
+	 * r0 to 0 prevents re-initializing tail_call_cnt in the prologue
+	 * of the target freplace prog.
+	 */
+	emit(A64_MOVZ(1, r0, 0, 0), ctx);
 
 	/* goto *(prog->bpf_func + prologue_offset); */
 	off = offsetof(struct bpf_prog, bpf_func);
@@ -2199,9 +2228,10 @@  static int prepare_trampoline(struct jit_ctx *ctx, struct bpf_tramp_image *im,
 		emit(A64_RET(A64_R(10)), ctx);
 		/* store return value */
 		emit(A64_STR64I(A64_R(0), A64_SP, retval_off), ctx);
-		/* reserve a nop for bpf_tramp_image_put */
+		/* reserve two nops for bpf_tramp_image_put */
 		im->ip_after_call = ctx->ro_image + ctx->idx;
 		emit(A64_NOP, ctx);
+		emit(A64_NOP, ctx);
 	}
 
 	/* update the branches saved in invoke_bpf_mod_ret with cbnz */
@@ -2484,6 +2514,7 @@  int bpf_arch_text_poke(void *ip, enum bpf_text_poke_type poke_type,
 		/* skip to the nop instruction in bpf prog entry:
 		 * bti c // if BTI enabled
 		 * mov x9, x30
+		 * mov x7, 1 // if not-freplace main prog
 		 * nop
 		 */
 		ip = image + POKE_OFFSET * AARCH64_INSN_SIZE;
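
Stepping back from the arm64 specifics, the invariant both this patch and the rejected POC are after can be modelled in plain C: the tail-call counter is initialized exactly once per entry from C code, shared by every program reached from that entry (including freplace programs), and bumped on every tail call so the chain stays bounded. Below is a small self-contained model; MAX_TAIL_CALL_CNT is the only value taken from the kernel, and all other names are made up for the sketch.

#include <stdio.h>

#define MAX_TAIL_CALL_CNT 33	/* same limit the kernel enforces */

struct run_ctx {
	unsigned int *tail_call_cnt;	/* models what x26 points at on arm64 */
};

/* Models bpf_tail_call(): allowed only while the shared counter is below
 * the limit, and every successful tail call bumps that shared counter.
 */
static int tail_call(struct run_ctx *rc)
{
	if (*rc->tail_call_cnt >= MAX_TAIL_CALL_CNT)
		return -1;		/* refused: the chain stays bounded */
	(*rc->tail_call_cnt)++;
	return 0;
}

/* Models a freplace program that keeps tail-calling back into itself;
 * in a JIT the tail call is a jump back past the prologue, i.e. a loop.
 */
static void freplace_prog(struct run_ctx *rc)
{
	while (tail_call(rc) == 0)
		;
}

/* Models the main program entered from C code: the only place the counter
 * is initialized.  Re-initializing it anywhere downstream (e.g. in the
 * freplace prologue) is exactly what re-opens the infinite-loop bug these
 * patches close.
 */
static void main_prog(void)
{
	unsigned int cnt = 0;
	struct run_ctx rc = { .tail_call_cnt = &cnt };

	freplace_prog(&rc);	/* models the poked branch into the freplace prog */
	printf("tail calls taken: %u\n", cnt);	/* prints 33, not unbounded */
}

int main(void)
{
	main_prog();
	return 0;
}
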