[bpf-next,v4,4/5] bpf, arm64: Fix tailcall hierarchy

Message ID 20240509150541.81799-5-hffilwlqm@gmail.com (mailing list archive)
State New, archived
Delegated to: BPF
Series bpf: Fix tailcall hierarchy

Checks

Context Check Description
bpf/vmtest-bpf-next-VM_Test-0 success Logs for Lint
bpf/vmtest-bpf-next-VM_Test-1 success Logs for ShellCheck
bpf/vmtest-bpf-next-VM_Test-2 success Logs for Unittests
bpf/vmtest-bpf-next-VM_Test-3 success Logs for Validate matrix.py
bpf/vmtest-bpf-next-VM_Test-4 success Logs for aarch64-gcc / build / build for aarch64 with gcc
bpf/vmtest-bpf-next-VM_Test-5 success Logs for aarch64-gcc / build-release
bpf/vmtest-bpf-next-VM_Test-6 success Logs for aarch64-gcc / test (test_maps, false, 360) / test_maps on aarch64 with gcc
bpf/vmtest-bpf-next-VM_Test-7 success Logs for aarch64-gcc / test (test_progs, false, 360) / test_progs on aarch64 with gcc
bpf/vmtest-bpf-next-VM_Test-8 success Logs for aarch64-gcc / test (test_progs_no_alu32, false, 360) / test_progs_no_alu32 on aarch64 with gcc
bpf/vmtest-bpf-next-VM_Test-9 success Logs for aarch64-gcc / test (test_verifier, false, 360) / test_verifier on aarch64 with gcc
bpf/vmtest-bpf-next-VM_Test-10 success Logs for aarch64-gcc / veristat
bpf/vmtest-bpf-next-VM_Test-11 success Logs for s390x-gcc / build / build for s390x with gcc
bpf/vmtest-bpf-next-VM_Test-12 success Logs for s390x-gcc / build-release
bpf/vmtest-bpf-next-VM_Test-13 success Logs for s390x-gcc / test (test_maps, false, 360) / test_maps on s390x with gcc
bpf/vmtest-bpf-next-VM_Test-14 success Logs for s390x-gcc / test (test_progs, false, 360) / test_progs on s390x with gcc
bpf/vmtest-bpf-next-VM_Test-15 success Logs for s390x-gcc / test (test_progs_no_alu32, false, 360) / test_progs_no_alu32 on s390x with gcc
bpf/vmtest-bpf-next-VM_Test-16 success Logs for s390x-gcc / test (test_verifier, false, 360) / test_verifier on s390x with gcc
bpf/vmtest-bpf-next-VM_Test-17 success Logs for s390x-gcc / veristat
bpf/vmtest-bpf-next-VM_Test-18 success Logs for set-matrix
bpf/vmtest-bpf-next-VM_Test-19 success Logs for x86_64-gcc / build / build for x86_64 with gcc
bpf/vmtest-bpf-next-VM_Test-20 success Logs for x86_64-gcc / build-release
bpf/vmtest-bpf-next-VM_Test-21 success Logs for x86_64-gcc / test (test_maps, false, 360) / test_maps on x86_64 with gcc
bpf/vmtest-bpf-next-VM_Test-22 success Logs for x86_64-gcc / test (test_progs, false, 360) / test_progs on x86_64 with gcc
bpf/vmtest-bpf-next-VM_Test-23 success Logs for x86_64-gcc / test (test_progs_no_alu32, false, 360) / test_progs_no_alu32 on x86_64 with gcc
bpf/vmtest-bpf-next-VM_Test-24 success Logs for x86_64-gcc / test (test_progs_no_alu32_parallel, true, 30) / test_progs_no_alu32_parallel on x86_64 with gcc
bpf/vmtest-bpf-next-VM_Test-25 success Logs for x86_64-gcc / test (test_progs_parallel, true, 30) / test_progs_parallel on x86_64 with gcc
bpf/vmtest-bpf-next-VM_Test-26 success Logs for x86_64-gcc / test (test_verifier, false, 360) / test_verifier on x86_64 with gcc
bpf/vmtest-bpf-next-VM_Test-27 success Logs for x86_64-gcc / veristat / veristat on x86_64 with gcc
bpf/vmtest-bpf-next-VM_Test-28 success Logs for x86_64-llvm-17 / build / build for x86_64 with llvm-17
bpf/vmtest-bpf-next-VM_Test-29 success Logs for x86_64-llvm-17 / build-release / build for x86_64 with llvm-17 and -O2 optimization
bpf/vmtest-bpf-next-VM_Test-30 success Logs for x86_64-llvm-17 / test (test_maps, false, 360) / test_maps on x86_64 with llvm-17
bpf/vmtest-bpf-next-VM_Test-31 success Logs for x86_64-llvm-17 / test (test_progs, false, 360) / test_progs on x86_64 with llvm-17
bpf/vmtest-bpf-next-VM_Test-32 success Logs for x86_64-llvm-17 / test (test_progs_no_alu32, false, 360) / test_progs_no_alu32 on x86_64 with llvm-17
bpf/vmtest-bpf-next-VM_Test-33 success Logs for x86_64-llvm-17 / test (test_verifier, false, 360) / test_verifier on x86_64 with llvm-17
bpf/vmtest-bpf-next-VM_Test-34 success Logs for x86_64-llvm-17 / veristat
bpf/vmtest-bpf-next-VM_Test-35 success Logs for x86_64-llvm-18 / build / build for x86_64 with llvm-18
bpf/vmtest-bpf-next-VM_Test-36 success Logs for x86_64-llvm-18 / build-release / build for x86_64 with llvm-18 and -O2 optimization
bpf/vmtest-bpf-next-VM_Test-37 success Logs for x86_64-llvm-18 / test (test_maps, false, 360) / test_maps on x86_64 with llvm-18
bpf/vmtest-bpf-next-VM_Test-38 success Logs for x86_64-llvm-18 / test (test_progs, false, 360) / test_progs on x86_64 with llvm-18
bpf/vmtest-bpf-next-VM_Test-39 success Logs for x86_64-llvm-18 / test (test_progs_cpuv4, false, 360) / test_progs_cpuv4 on x86_64 with llvm-18
bpf/vmtest-bpf-next-VM_Test-40 success Logs for x86_64-llvm-18 / test (test_progs_no_alu32, false, 360) / test_progs_no_alu32 on x86_64 with llvm-18
bpf/vmtest-bpf-next-VM_Test-41 success Logs for x86_64-llvm-18 / test (test_verifier, false, 360) / test_verifier on x86_64 with llvm-18
bpf/vmtest-bpf-next-VM_Test-42 success Logs for x86_64-llvm-18 / veristat
bpf/vmtest-bpf-next-PR success PR summary

Commit Message

Leon Hwang May 9, 2024, 3:05 p.m. UTC
As with "bpf, x64: Fix tailcall hierarchy", this patch fixes the tailcall
hierarchy issue on arm64.

In the prologue, it loads tail_call_cnt_ptr from bpf_tail_call_run_ctx into
the TCCNT_PTR register, and at the same time restores the original ctx from
bpf_tail_call_run_ctx into the X0 register.
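
For reference, the prologue hunk below reads tcc_ptr from offset 8 and the
original ctx from offset 0 of X0. Those offsets suggest a run-ctx layout
roughly like the following sketch; the tail_call_cnt_ptr name comes from
this patch, while the ctx field name is assumed for illustration:

	struct bpf_tail_call_run_ctx {
		const void *ctx;	/* original ctx, at offset 0 */
		u32 *tail_call_cnt_ptr;	/* tcc_ptr, at offset 8 */
	};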

Then, when a tailcall runs, the JITed code does the following (sketched in C
after this list):
1. load tail_call_cnt from tail_call_cnt_ptr
2. compare tail_call_cnt with MAX_TAIL_CALL_CNT
3. increment tail_call_cnt
4. store tail_call_cnt back through tail_call_cnt_ptr
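
In C terms, the counter handling emitted by emit_bpf_tail_call() is roughly
this sketch, where tcc_ptr stands for the pointer cached in the TCCNT_PTR
register:

	u32 tail_call_cnt;

	tail_call_cnt = *tcc_ptr;			/* 1. load */
	if (tail_call_cnt >= MAX_TAIL_CALL_CNT)		/* 2. compare */
		goto out;
	*tcc_ptr = tail_call_cnt + 1;			/* 3. + 4. increment, store back */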

Furthermore, when a trampoline is the caller of a bpf prog, the trampoline is
required to prepare tail_call_cnt and the tail call run ctx on its own stack.
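
The invoke_bpf_prog() hunk below emits this preparation; as a C sketch, with
the run-ctx member names taken from the hunk's offsetof() uses (only the ctx
field name is assumed):

	/* On the trampoline's stack, before calling the prog: */
	run_ctx.tail_call_run_ctx.ctx = ctx;	/* cache the original ctx */
	run_ctx.tail_call_cnt = 0;		/* fresh count for this call */
	run_ctx.tail_call_run_ctx.tail_call_cnt_ptr = &run_ctx.tail_call_cnt;
	ctx = &run_ctx.tail_call_run_ctx;	/* prog's X0 points at the run ctx */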

Finally, enable bpf_jit_supports_tail_call_cnt_ptr() so that __bpf_prog_run()
uses bpf_tail_call_run_ctx.
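
The generic side is not part of this patch, but from the series description
the dispatch presumably ends up doing something like this hypothetical
sketch:

	/* Hypothetical sketch, not this patch: when the JIT supports
	 * tcc_ptr, hand the prog a tail call run ctx instead of the
	 * raw ctx.
	 */
	struct bpf_tail_call_run_ctx run_ctx = {};
	u32 tail_call_cnt = 0;

	if (bpf_jit_supports_tail_call_cnt_ptr()) {
		run_ctx.ctx = ctx;
		run_ctx.tail_call_cnt_ptr = &tail_call_cnt;
		ctx = &run_ctx;
	}
	ret = prog->bpf_func(ctx, prog->insnsi);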

Fixes: d4609a5d8c70 ("bpf, arm64: Keep tail call count across bpf2bpf calls")
Signed-off-by: Leon Hwang <hffilwlqm@gmail.com>
---
 arch/arm64/net/bpf_jit_comp.c | 63 +++++++++++++++++++++++++++--------
 1 file changed, 50 insertions(+), 13 deletions(-)

Patch

diff --git a/arch/arm64/net/bpf_jit_comp.c b/arch/arm64/net/bpf_jit_comp.c
index 53347d4217f4b..1160b3619f821 100644
--- a/arch/arm64/net/bpf_jit_comp.c
+++ b/arch/arm64/net/bpf_jit_comp.c
@@ -26,7 +26,7 @@ 
 
 #define TMP_REG_1 (MAX_BPF_JIT_REG + 0)
 #define TMP_REG_2 (MAX_BPF_JIT_REG + 1)
-#define TCALL_CNT (MAX_BPF_JIT_REG + 2)
+#define TCCNT_PTR (MAX_BPF_JIT_REG + 2)
 #define TMP_REG_3 (MAX_BPF_JIT_REG + 3)
 #define FP_BOTTOM (MAX_BPF_JIT_REG + 4)
 #define ARENA_VM_START (MAX_BPF_JIT_REG + 5)
@@ -63,8 +63,8 @@  static const int bpf2a64[] = {
 	[TMP_REG_1] = A64_R(10),
 	[TMP_REG_2] = A64_R(11),
 	[TMP_REG_3] = A64_R(12),
-	/* tail_call_cnt */
-	[TCALL_CNT] = A64_R(26),
+	/* tail_call_cnt_ptr */
+	[TCCNT_PTR] = A64_R(26),
 	/* temporary register for blinding constants */
 	[BPF_REG_AX] = A64_R(9),
 	[FP_BOTTOM] = A64_R(27),
@@ -296,19 +296,20 @@  static bool is_lsi_offset(int offset, int scale)
 #define POKE_OFFSET (BTI_INSNS + 1)
 
 /* Tail call offset to jump into */
-#define PROLOGUE_OFFSET (BTI_INSNS + 2 + PAC_INSNS + 8)
+#define PROLOGUE_OFFSET (BTI_INSNS + 2 + PAC_INSNS + 9)
 
 static int build_prologue(struct jit_ctx *ctx, bool ebpf_from_cbpf,
 			  bool is_exception_cb, u64 arena_vm_start)
 {
 	const struct bpf_prog *prog = ctx->prog;
 	const bool is_main_prog = !bpf_is_subprog(prog);
+	const u8 r1 = bpf2a64[BPF_REG_1];
 	const u8 r6 = bpf2a64[BPF_REG_6];
 	const u8 r7 = bpf2a64[BPF_REG_7];
 	const u8 r8 = bpf2a64[BPF_REG_8];
 	const u8 r9 = bpf2a64[BPF_REG_9];
 	const u8 fp = bpf2a64[BPF_REG_FP];
-	const u8 tcc = bpf2a64[TCALL_CNT];
+	const u8 ptr = bpf2a64[TCCNT_PTR];
 	const u8 fpb = bpf2a64[FP_BOTTOM];
 	const u8 arena_vm_base = bpf2a64[ARENA_VM_START];
 	const int idx0 = ctx->idx;
@@ -359,7 +360,7 @@  static int build_prologue(struct jit_ctx *ctx, bool ebpf_from_cbpf,
 		/* Save callee-saved registers */
 		emit(A64_PUSH(r6, r7, A64_SP), ctx);
 		emit(A64_PUSH(r8, r9, A64_SP), ctx);
-		emit(A64_PUSH(fp, tcc, A64_SP), ctx);
+		emit(A64_PUSH(fp, ptr, A64_SP), ctx);
 		emit(A64_PUSH(fpb, A64_R(28), A64_SP), ctx);
 	} else {
 		/*
@@ -381,8 +382,15 @@  static int build_prologue(struct jit_ctx *ctx, bool ebpf_from_cbpf,
 	emit(A64_MOV(1, fp, A64_SP), ctx);
 
 	if (!ebpf_from_cbpf && is_main_prog) {
-		/* Initialize tail_call_cnt */
-		emit(A64_MOVZ(1, tcc, 0, 0), ctx);
+		if (prog->aux->tail_call_reachable) {
+			/* Cache tcc_ptr. */
+			emit(A64_LDR64I(ptr, r1, 8), ctx);
+			/* Restore the original ctx. */
+			emit(A64_LDR64I(r1, r1, 0), ctx);
+		} else {
+			emit(A64_NOP, ctx);
+			emit(A64_NOP, ctx);
+		}
 
 		cur_offset = ctx->idx - idx0;
 		if (cur_offset != PROLOGUE_OFFSET) {
@@ -432,7 +440,8 @@  static int emit_bpf_tail_call(struct jit_ctx *ctx)
 
 	const u8 tmp = bpf2a64[TMP_REG_1];
 	const u8 prg = bpf2a64[TMP_REG_2];
-	const u8 tcc = bpf2a64[TCALL_CNT];
+	const u8 ptr = bpf2a64[TCCNT_PTR];
+	const u8 tcc = bpf2a64[TMP_REG_3];
 	const int idx0 = ctx->idx;
 #define cur_offset (ctx->idx - idx0)
 #define jmp_offset (out_offset - (cur_offset))
@@ -449,14 +458,16 @@  static int emit_bpf_tail_call(struct jit_ctx *ctx)
 	emit(A64_B_(A64_COND_CS, jmp_offset), ctx);
 
 	/*
-	 * if (tail_call_cnt >= MAX_TAIL_CALL_CNT)
+	 * if ((*tcc_ptr) >= MAX_TAIL_CALL_CNT)
 	 *     goto out;
-	 * tail_call_cnt++;
+	 * (*tcc_ptr)++;
 	 */
 	emit_a64_mov_i64(tmp, MAX_TAIL_CALL_CNT, ctx);
+	emit(A64_LDR32I(tcc, ptr, 0), ctx);
 	emit(A64_CMP(1, tcc, tmp), ctx);
 	emit(A64_B_(A64_COND_CS, jmp_offset), ctx);
 	emit(A64_ADD_I(1, tcc, tcc, 1), ctx);
+	emit(A64_STR32I(tcc, ptr, 0), ctx);
 
 	/* prog = array->ptrs[index];
 	 * if (prog == NULL)
@@ -1890,15 +1901,28 @@  bool bpf_jit_supports_subprog_tailcalls(void)
 	return true;
 }
 
+/* Indicate the JIT backend supports tail call count pointer in tailcall context. */
+bool bpf_jit_supports_tail_call_cnt_ptr(void)
+{
+	return true;
+}
+
 static void invoke_bpf_prog(struct jit_ctx *ctx, struct bpf_tramp_link *l,
 			    int args_off, int retval_off, int run_ctx_off,
 			    bool save_ret)
 {
+	int tail_call_run_ctx_off = offsetof(struct bpf_tramp_run_ctx, tail_call_run_ctx);
+	int tcc_ptr_off = tail_call_run_ctx_off + offsetof(struct bpf_tail_call_run_ctx,
+							   tail_call_cnt_ptr);
+	int tail_call_cnt_off = offsetof(struct bpf_tramp_run_ctx, tail_call_cnt);
+	int cookie_off = offsetof(struct bpf_tramp_run_ctx, bpf_cookie);
+	struct bpf_prog *p = l->link.prog;
+	const u8 tmp = bpf2a64[TMP_REG_1];
+	const u8 r1 = bpf2a64[BPF_REG_1];
+	const u8 sp = A64_SP;
 	__le32 *branch;
 	u64 enter_prog;
 	u64 exit_prog;
-	struct bpf_prog *p = l->link.prog;
-	int cookie_off = offsetof(struct bpf_tramp_run_ctx, bpf_cookie);
 
 	enter_prog = (u64)bpf_trampoline_enter(p);
 	exit_prog = (u64)bpf_trampoline_exit(p);
@@ -1936,6 +1960,19 @@  static void invoke_bpf_prog(struct jit_ctx *ctx, struct bpf_tramp_link *l,
 	emit(A64_ADD_I(1, A64_R(0), A64_SP, args_off), ctx);
 	if (!p->jited)
 		emit_addr_mov_i64(A64_R(1), (const u64)p->insnsi, ctx);
+	if (p->aux->use_tail_call_run_ctx) {
+		/* Cache the original ctx. */
+		emit(A64_STR64I(r1, sp, run_ctx_off + tail_call_run_ctx_off), ctx);
+		/* Update r1 as tcc_ptr. */
+		emit(A64_ADD_I(1, r1, sp, run_ctx_off + tail_call_cnt_off), ctx);
+		/* Clear tail_call_cnt. */
+		emit_a64_mov_i(0, tmp, 0, ctx);
+		emit(A64_STR32I(tmp, r1, 0), ctx);
+		/* Cache tcc_ptr. */
+		emit(A64_STR64I(r1, sp, run_ctx_off + tcc_ptr_off), ctx);
+		/* Update r1 as tail call run ctx. */
+		emit(A64_ADD_I(1, r1, sp, run_ctx_off + tail_call_run_ctx_off), ctx);
+	}
 
 	emit_call((const u64)p->bpf_func, ctx);