diff mbox series

[bpf-next,v7,4/4] bpf, arm64: bpf trampoline for arm64

Message ID 20220708093032.1832755-5-xukuohai@huawei.com (mailing list archive)
State New, archived
Headers show
Series bpf trampoline for arm64 | expand

Commit Message

Xu Kuohai July 8, 2022, 9:30 a.m. UTC
This is the arm64 version of commit fec56f5890d9 ("bpf: Introduce BPF
trampoline"). A bpf trampoline converts the native calling convention to
the bpf calling convention and is used to implement various bpf features,
such as fentry, fexit, fmod_ret and struct_ops.

This patch does essentially the same thing as the bpf trampoline does on x86.

Tested on Raspberry Pi 4B and QEMU:

 #18 /1     bpf_tcp_ca/dctcp:OK
 #18 /2     bpf_tcp_ca/cubic:OK
 #18 /3     bpf_tcp_ca/invalid_license:OK
 #18 /4     bpf_tcp_ca/dctcp_fallback:OK
 #18 /5     bpf_tcp_ca/rel_setsockopt:OK
 #18        bpf_tcp_ca:OK
 #51 /1     dummy_st_ops/dummy_st_ops_attach:OK
 #51 /2     dummy_st_ops/dummy_init_ret_value:OK
 #51 /3     dummy_st_ops/dummy_init_ptr_arg:OK
 #51 /4     dummy_st_ops/dummy_multiple_args:OK
 #51        dummy_st_ops:OK
 #57 /1     fexit_bpf2bpf/target_no_callees:OK
 #57 /2     fexit_bpf2bpf/target_yes_callees:OK
 #57 /3     fexit_bpf2bpf/func_replace:OK
 #57 /4     fexit_bpf2bpf/func_replace_verify:OK
 #57 /5     fexit_bpf2bpf/func_sockmap_update:OK
 #57 /6     fexit_bpf2bpf/func_replace_return_code:OK
 #57 /7     fexit_bpf2bpf/func_map_prog_compatibility:OK
 #57 /8     fexit_bpf2bpf/func_replace_multi:OK
 #57 /9     fexit_bpf2bpf/fmod_ret_freplace:OK
 #57        fexit_bpf2bpf:OK
 #237       xdp_bpf2bpf:OK

Signed-off-by: Xu Kuohai <xukuohai@huawei.com>
Acked-by: Song Liu <songliubraving@fb.com>
Acked-by: KP Singh <kpsingh@kernel.org>
---
 arch/arm64/net/bpf_jit_comp.c | 394 +++++++++++++++++++++++++++++++++-
 1 file changed, 391 insertions(+), 3 deletions(-)

Comments

Jean-Philippe Brucker July 11, 2022, 11:57 a.m. UTC | #1
On Fri, Jul 08, 2022 at 05:30:32AM -0400, Xu Kuohai wrote:
> +static void invoke_bpf_prog(struct jit_ctx *ctx, struct bpf_tramp_link *l,
> +			    int args_off, int retval_off, int run_ctx_off,
> +			    bool save_ret)
> +{
> +	u32 *branch;
> +	u64 enter_prog;
> +	u64 exit_prog;
> +	u8 r0 = bpf2a64[BPF_REG_0];
> +	struct bpf_prog *p = l->link.prog;
> +	int cookie_off = offsetof(struct bpf_tramp_run_ctx, bpf_cookie);
> +
> +	if (p->aux->sleepable) {
> +		enter_prog = (u64)__bpf_prog_enter_sleepable;
> +		exit_prog = (u64)__bpf_prog_exit_sleepable;
> +	} else {
> +		enter_prog = (u64)__bpf_prog_enter;
> +		exit_prog = (u64)__bpf_prog_exit;
> +	}
> +
> +	if (l->cookie == 0) {
> +		/* if cookie is zero, one instruction is enough to store it */
> +		emit(A64_STR64I(A64_ZR, A64_SP, run_ctx_off + cookie_off), ctx);
> +	} else {
> +		emit_a64_mov_i64(A64_R(10), l->cookie, ctx);
> +		emit(A64_STR64I(A64_R(10), A64_SP, run_ctx_off + cookie_off),
> +		     ctx);
> +	}
> +
> +	/* save p to callee saved register x19 to avoid loading p with mov_i64
> +	 * each time.
> +	 */
> +	emit_addr_mov_i64(A64_R(19), (const u64)p, ctx);
> +
> +	/* arg1: prog */
> +	emit(A64_MOV(1, A64_R(0), A64_R(19)), ctx);
> +	/* arg2: &run_ctx */
> +	emit(A64_ADD_I(1, A64_R(1), A64_SP, run_ctx_off), ctx);
> +
> +	emit_call(enter_prog, ctx);
> +
> +	/* if (__bpf_prog_enter(prog) == 0)
> +	 *         goto skip_exec_of_prog;
> +	 */
> +	branch = ctx->image + ctx->idx;
> +	emit(A64_NOP, ctx);
> +
> +	/* save return value to callee saved register x20 */
> +	emit(A64_MOV(1, A64_R(20), A64_R(0)), ctx);
> +
> +	emit(A64_ADD_I(1, A64_R(0), A64_SP, args_off), ctx);
> +	if (!p->jited)
> +		emit_addr_mov_i64(A64_R(1), (const u64)p->insnsi, ctx);
> +
> +	emit_call((const u64)p->bpf_func, ctx);
> +
> +	/* store return value, which is held in r0 for JIT and in x0
> +	 * for interpreter.
> +	 */
> +	if (save_ret)
> +		emit(A64_STR64I(p->jited ? r0 : A64_R(0), A64_SP, retval_off),
> +		     ctx);

This should be only A64_R(0), not r0. r0 happens to equal A64_R(0) when
jitted due to the way build_epilogue() builds the function at the moment,
but we shouldn't rely on that.

Apart from that, for the series

Reviewed-by: Jean-Philippe Brucker <jean-philippe@linaro.org>
Xu Kuohai July 11, 2022, 2:16 p.m. UTC | #2
On 7/11/2022 7:57 PM, Jean-Philippe Brucker wrote:
> On Fri, Jul 08, 2022 at 05:30:32AM -0400, Xu Kuohai wrote:
>> +static void invoke_bpf_prog(struct jit_ctx *ctx, struct bpf_tramp_link *l,
>> +			    int args_off, int retval_off, int run_ctx_off,
>> +			    bool save_ret)
>> +{
>> +	u32 *branch;
>> +	u64 enter_prog;
>> +	u64 exit_prog;
>> +	u8 r0 = bpf2a64[BPF_REG_0];
>> +	struct bpf_prog *p = l->link.prog;
>> +	int cookie_off = offsetof(struct bpf_tramp_run_ctx, bpf_cookie);
>> +
>> +	if (p->aux->sleepable) {
>> +		enter_prog = (u64)__bpf_prog_enter_sleepable;
>> +		exit_prog = (u64)__bpf_prog_exit_sleepable;
>> +	} else {
>> +		enter_prog = (u64)__bpf_prog_enter;
>> +		exit_prog = (u64)__bpf_prog_exit;
>> +	}
>> +
>> +	if (l->cookie == 0) {
>> +		/* if cookie is zero, one instruction is enough to store it */
>> +		emit(A64_STR64I(A64_ZR, A64_SP, run_ctx_off + cookie_off), ctx);
>> +	} else {
>> +		emit_a64_mov_i64(A64_R(10), l->cookie, ctx);
>> +		emit(A64_STR64I(A64_R(10), A64_SP, run_ctx_off + cookie_off),
>> +		     ctx);
>> +	}
>> +
>> +	/* save p to callee saved register x19 to avoid loading p with mov_i64
>> +	 * each time.
>> +	 */
>> +	emit_addr_mov_i64(A64_R(19), (const u64)p, ctx);
>> +
>> +	/* arg1: prog */
>> +	emit(A64_MOV(1, A64_R(0), A64_R(19)), ctx);
>> +	/* arg2: &run_ctx */
>> +	emit(A64_ADD_I(1, A64_R(1), A64_SP, run_ctx_off), ctx);
>> +
>> +	emit_call(enter_prog, ctx);
>> +
>> +	/* if (__bpf_prog_enter(prog) == 0)
>> +	 *         goto skip_exec_of_prog;
>> +	 */
>> +	branch = ctx->image + ctx->idx;
>> +	emit(A64_NOP, ctx);
>> +
>> +	/* save return value to callee saved register x20 */
>> +	emit(A64_MOV(1, A64_R(20), A64_R(0)), ctx);
>> +
>> +	emit(A64_ADD_I(1, A64_R(0), A64_SP, args_off), ctx);
>> +	if (!p->jited)
>> +		emit_addr_mov_i64(A64_R(1), (const u64)p->insnsi, ctx);
>> +
>> +	emit_call((const u64)p->bpf_func, ctx);
>> +
>> +	/* store return value, which is held in r0 for JIT and in x0
>> +	 * for interpreter.
>> +	 */
>> +	if (save_ret)
>> +		emit(A64_STR64I(p->jited ? r0 : A64_R(0), A64_SP, retval_off),
>> +		     ctx);
> 
> This should be only A64_R(0), not r0. r0 happens to equal A64_R(0) when
> jitted due to the way build_epilogue() builds the function at the moment,
> but we shouldn't rely on that.
> 

looks like I misunderstood something, will change it to:

/* store return value, which is held in x0 for interpreter and in
 * bpf register r0 for JIT, but r0 happens to equal x0 due to the
 * way build_epilogue() builds the JIT image.
 */
if (save_ret)
        emit(A64_STR64I(A64_R(0), A64_SP, retval_off), ctx);

> Apart from that, for the series
> 
> Reviewed-by: Jean-Philippe Brucker <jean-philippe@linaro.org>
> 
> .
Jean-Philippe Brucker July 11, 2022, 2:37 p.m. UTC | #3
On Mon, Jul 11, 2022 at 10:16:00PM +0800, Xu Kuohai wrote:
> >> +	if (save_ret)
> >> +		emit(A64_STR64I(p->jited ? r0 : A64_R(0), A64_SP, retval_off),
> >> +		     ctx);
> > 
> > This should be only A64_R(0), not r0. r0 happens to equal A64_R(0) when
> > jitted due to the way build_epilogue() builds the function at the moment,
> > but we shouldn't rely on that.
> > 
> 
> looks like I misunderstood something, will change it to:
> 
> /* store return value, which is held in x0 for interpreter and in
>  * bpf register r0 for JIT,

It's simpler than that: in both cases the return value is in x0 because
the function follows the procedure call standard. You could drop the
comment to avoid confusion and only do the change to A64_R(0)

Thanks,
Jean

>
>
>  but r0 happens to equal x0 due to the
>  * way build_epilogue() builds the JIT image.
>  */
> if (save_ret)
>         emit(A64_STR64I(A64_R(0), A64_SP, retval_off), ctx);
> 
> > Apart from that, for the series
> > 
> > Reviewed-by: Jean-Philippe Brucker <jean-philippe@linaro.org>
> > 
> > .
Xu Kuohai July 11, 2022, 2:40 p.m. UTC | #4
On 7/11/2022 10:37 PM, Jean-Philippe Brucker wrote:
> On Mon, Jul 11, 2022 at 10:16:00PM +0800, Xu Kuohai wrote:
>>>> +	if (save_ret)
>>>> +		emit(A64_STR64I(p->jited ? r0 : A64_R(0), A64_SP, retval_off),
>>>> +		     ctx);
>>>
>>> This should be only A64_R(0), not r0. r0 happens to equal A64_R(0) when
>>> jitted due to the way build_epilogue() builds the function at the moment,
>>> but we shouldn't rely on that.
>>>
>>
>> looks like I misunderstood something, will change it to:
>>
>> /* store return value, which is held in x0 for interpreter and in
>>  * bpf register r0 for JIT,
> 
> It's simpler than that: in both cases the return value is in x0 because
> the function follows the procedure call standard. You could drop the
> comment to avoid confusion and only do the change to A64_R(0)
> 

OK, will send v9 since v8 was just sent

> Thanks,
> Jean
> 
>>
>>
>>  but r0 happens to equal x0 due to the
>>  * way build_epilogue() builds the JIT image.
>>  */
>> if (save_ret)
>>         emit(A64_STR64I(A64_R(0), A64_SP, retval_off), ctx);
>>
>>> Apart from that, for the series
>>>
>>> Reviewed-by: Jean-Philippe Brucker <jean-philippe@linaro.org>
>>>
>>> .
> .
Jean-Philippe Brucker July 11, 2022, 2:48 p.m. UTC | #5
On Mon, Jul 11, 2022 at 10:40:42PM +0800, Xu Kuohai wrote:
> On 7/11/2022 10:37 PM, Jean-Philippe Brucker wrote:
> > On Mon, Jul 11, 2022 at 10:16:00PM +0800, Xu Kuohai wrote:
> >>>> +	if (save_ret)
> >>>> +		emit(A64_STR64I(p->jited ? r0 : A64_R(0), A64_SP, retval_off),
> >>>> +		     ctx);
> >>>
> >>> This should be only A64_R(0), not r0. r0 happens to equal A64_R(0) when
> >>> jitted due to the way build_epilogue() builds the function at the moment,
> >>> but we shouldn't rely on that.
> >>>
> >>
> >> looks like I misunderstood something, will change it to:
> >>
> >> /* store return value, which is held in x0 for interpreter and in
> >>  * bpf register r0 for JIT,
> > 
> > It's simpler than that: in both cases the return value is in x0 because
> > the function follows the procedure call standard. You could drop the
> > comment to avoid confusion and only do the change to A64_R(0)
> > 
> 
> OK, will send v9 since v8 was just sent

Right sorry about this, I could have been clearer

Thanks,
Jean
diff mbox series

Patch

diff --git a/arch/arm64/net/bpf_jit_comp.c b/arch/arm64/net/bpf_jit_comp.c
index 0ef35ec30d4e..073dad95a6a1 100644
--- a/arch/arm64/net/bpf_jit_comp.c
+++ b/arch/arm64/net/bpf_jit_comp.c
@@ -176,6 +176,14 @@  static inline void emit_addr_mov_i64(const int reg, const u64 val,
 	}
 }
 
+/* Emit an indirect call to @target through the JIT scratch register. */
+static inline void emit_call(u64 target, struct jit_ctx *ctx)
+{
+	/* materialize the absolute address, then branch-with-link to it */
+	emit_addr_mov_i64(bpf2a64[TMP_REG_1], target, ctx);
+	emit(A64_BLR(bpf2a64[TMP_REG_1]), ctx);
+}
+
 static inline int bpf2a64_offset(int bpf_insn, int off,
 				 const struct jit_ctx *ctx)
 {
@@ -1072,8 +1080,7 @@  static int build_insn(const struct bpf_insn *insn, struct jit_ctx *ctx,
 					    &func_addr, &func_addr_fixed);
 		if (ret < 0)
 			return ret;
-		emit_addr_mov_i64(tmp, func_addr, ctx);
-		emit(A64_BLR(tmp), ctx);
+		emit_call(func_addr, ctx);
 		emit(A64_MOV(1, r0, A64_R(0)), ctx);
 		break;
 	}
@@ -1417,6 +1424,13 @@  static int validate_code(struct jit_ctx *ctx)
 		if (a64_insn == AARCH64_BREAK_FAULT)
 			return -1;
 	}
+	return 0;
+}
+
+static int validate_ctx(struct jit_ctx *ctx)
+{
+	if (validate_code(ctx))
+		return -1;
 
 	if (WARN_ON_ONCE(ctx->exentry_idx != ctx->prog->aux->num_exentries))
 		return -1;
@@ -1546,7 +1560,7 @@  struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
 	build_plt(&ctx);
 
 	/* 3. Extra pass to validate JITed code. */
-	if (validate_code(&ctx)) {
+	if (validate_ctx(&ctx)) {
 		bpf_jit_binary_free(header);
 		prog = orig_prog;
 		goto out_off;
@@ -1624,6 +1638,380 @@  bool bpf_jit_supports_subprog_tailcalls(void)
 	return true;
 }
 
+/* Emit the code that runs one bpf program from the trampoline: store the
+ * link's cookie in the bpf_tramp_run_ctx at SP + run_ctx_off, call the
+ * enter hook, call the program with x0 = SP + args_off (skipped when the
+ * enter hook returns 0), optionally store the program's return value at
+ * SP + retval_off, and finally call the exit hook.
+ * x10, x19 and x20 are used as scratch; the caller must preserve x19/x20.
+ */
+static void invoke_bpf_prog(struct jit_ctx *ctx, struct bpf_tramp_link *l,
+			    int args_off, int retval_off, int run_ctx_off,
+			    bool save_ret)
+{
+	__le32 *branch;
+	u64 enter_prog;
+	u64 exit_prog;
+	struct bpf_prog *p = l->link.prog;
+	int cookie_off = offsetof(struct bpf_tramp_run_ctx, bpf_cookie);
+
+	if (p->aux->sleepable) {
+		enter_prog = (u64)__bpf_prog_enter_sleepable;
+		exit_prog = (u64)__bpf_prog_exit_sleepable;
+	} else {
+		enter_prog = (u64)__bpf_prog_enter;
+		exit_prog = (u64)__bpf_prog_exit;
+	}
+
+	if (l->cookie == 0) {
+		/* if cookie is zero, one instruction is enough to store it */
+		emit(A64_STR64I(A64_ZR, A64_SP, run_ctx_off + cookie_off), ctx);
+	} else {
+		emit_a64_mov_i64(A64_R(10), l->cookie, ctx);
+		emit(A64_STR64I(A64_R(10), A64_SP, run_ctx_off + cookie_off),
+		     ctx);
+	}
+
+	/* save p to callee saved register x19 to avoid loading p with mov_i64
+	 * each time.
+	 */
+	emit_addr_mov_i64(A64_R(19), (const u64)p, ctx);
+
+	/* arg1: prog */
+	emit(A64_MOV(1, A64_R(0), A64_R(19)), ctx);
+	/* arg2: &run_ctx */
+	emit(A64_ADD_I(1, A64_R(1), A64_SP, run_ctx_off), ctx);
+
+	emit_call(enter_prog, ctx);
+
+	/* if (__bpf_prog_enter(prog) == 0)
+	 *         goto skip_exec_of_prog;
+	 * The target is not known yet: reserve a nop here and patch it into
+	 * a cbz once the program call has been emitted.
+	 */
+	branch = ctx->image + ctx->idx;
+	emit(A64_NOP, ctx);
+
+	/* save return value (passed to exit_prog as start time) to callee
+	 * saved register x20
+	 */
+	emit(A64_MOV(1, A64_R(20), A64_R(0)), ctx);
+
+	/* arg1: pointer to the saved args; arg2 (interpreter only): insns */
+	emit(A64_ADD_I(1, A64_R(0), A64_SP, args_off), ctx);
+	if (!p->jited)
+		emit_addr_mov_i64(A64_R(1), (const u64)p->insnsi, ctx);
+
+	emit_call((const u64)p->bpf_func, ctx);
+
+	/* bpf_func follows the procedure call standard, so the return value
+	 * is in x0 whether the program is JITed or interpreted.
+	 */
+	if (save_ret)
+		emit(A64_STR64I(A64_R(0), A64_SP, retval_off), ctx);
+
+	/* patch the reserved nop into "cbz x0, <here>"; the image holds
+	 * little-endian (__le32) instruction words, so convert first.
+	 */
+	if (ctx->image) {
+		int offset = &ctx->image[ctx->idx] - branch;
+		*branch = cpu_to_le32(A64_CBZ(1, A64_R(0), offset));
+	}
+
+	/* arg1: prog */
+	emit(A64_MOV(1, A64_R(0), A64_R(19)), ctx);
+	/* arg2: start time */
+	emit(A64_MOV(1, A64_R(1), A64_R(20)), ctx);
+	/* arg3: &run_ctx */
+	emit(A64_ADD_I(1, A64_R(2), A64_SP, run_ctx_off), ctx);
+
+	emit_call(exit_prog, ctx);
+}
+
+/* Run each fmod_ret program in @tl, with every program's return value
+ * stored at SP + retval_off. After each program the return value is
+ * loaded into x10 and a nop is emitted; the nop's address is recorded in
+ * @branches[i] (which must have room for tl->nr_links entries) so the
+ * caller can later patch it into a cbnz on x10. The entries point into
+ * the JIT image, whose instruction words are __le32.
+ */
+static void invoke_bpf_mod_ret(struct jit_ctx *ctx, struct bpf_tramp_links *tl,
+			       int args_off, int retval_off, int run_ctx_off,
+			       __le32 **branches)
+{
+	int i;
+
+	/* The first fmod_ret program will receive a garbage return value.
+	 * Set this to 0 to avoid confusing the program.
+	 */
+	emit(A64_STR64I(A64_ZR, A64_SP, retval_off), ctx);
+	for (i = 0; i < tl->nr_links; i++) {
+		invoke_bpf_prog(ctx, tl->links[i], args_off, retval_off,
+				run_ctx_off, true);
+		/* if (*(u64 *)(sp + retval_off) != 0)
+		 *	goto do_fexit;
+		 */
+		emit(A64_LDR64I(A64_R(10), A64_SP, retval_off), ctx);
+		/* Save the location of branch, and generate a nop.
+		 * This nop will be replaced with a cbnz later.
+		 */
+		branches[i] = ctx->image + ctx->idx;
+		emit(A64_NOP, ctx);
+	}
+}
+
+/* Store argument registers 0..nargs-1 (x0..x(nargs - 1)) into
+ * consecutive 8-byte stack slots, the first one at SP + args_off.
+ */
+static void save_args(struct jit_ctx *ctx, int args_off, int nargs)
+{
+	int reg;
+
+	for (reg = 0; reg < nargs; reg++)
+		emit(A64_STR64I(reg, A64_SP, args_off + reg * 8), ctx);
+}
+
+/* Reload argument registers 0..nargs-1 (x0..x(nargs - 1)) from the
+ * consecutive 8-byte stack slots starting at SP + args_off.
+ */
+static void restore_args(struct jit_ctx *ctx, int args_off, int nargs)
+{
+	int reg;
+
+	for (reg = 0; reg < nargs; reg++)
+		emit(A64_LDR64I(reg, A64_SP, args_off + reg * 8), ctx);
+}
+
+/* Based on the x86's implementation of arch_prepare_bpf_trampoline().
+ *
+ * bpf prog and function entry before bpf trampoline hooked:
+ *   mov x9, lr
+ *   nop
+ *
+ * bpf prog and function entry after bpf trampoline hooked:
+ *   mov x9, lr
+ *   bl  <bpf_trampoline or plt>
+ *
+ * Emit the trampoline for @tlinks into @ctx and return the number of
+ * instructions emitted, or -ENOMEM if the fmod_ret branch table cannot be
+ * allocated. arch_prepare_bpf_trampoline() calls this twice: first with
+ * ctx->image == NULL to size the trampoline, then to emit it. Already
+ * emitted instructions are only patched when ctx->image is set.
+ */
+static int prepare_trampoline(struct jit_ctx *ctx, struct bpf_tramp_image *im,
+			      struct bpf_tramp_links *tlinks, void *orig_call,
+			      int nargs, u32 flags)
+{
+	int i;
+	int stack_size;
+	int retaddr_off;
+	int regs_off;
+	int retval_off;
+	int args_off;
+	int nargs_off;
+	int ip_off;
+	int run_ctx_off;
+	struct bpf_tramp_links *fentry = &tlinks[BPF_TRAMP_FENTRY];
+	struct bpf_tramp_links *fexit = &tlinks[BPF_TRAMP_FEXIT];
+	struct bpf_tramp_links *fmod_ret = &tlinks[BPF_TRAMP_MODIFY_RETURN];
+	bool save_ret;
+	/* nops emitted by invoke_bpf_mod_ret(), patched into cbnz below;
+	 * they point into the JIT image of __le32 instruction words.
+	 */
+	__le32 **branches = NULL;
+
+	/* trampoline stack layout:
+	 *                  [ parent ip         ]
+	 *                  [ FP                ]
+	 * SP + retaddr_off [ self ip           ]
+	 *                  [ FP                ]
+	 *
+	 *                  [ padding           ] align SP to multiples of 16
+	 *
+	 *                  [ x20               ] callee saved reg x20
+	 * SP + regs_off    [ x19               ] callee saved reg x19
+	 *
+	 * SP + retval_off  [ return value      ] BPF_TRAMP_F_CALL_ORIG or
+	 *                                        BPF_TRAMP_F_RET_FENTRY_RET
+	 *
+	 *                  [ argN              ]
+	 *                  [ ...               ]
+	 * SP + args_off    [ arg1              ]
+	 *
+	 * SP + nargs_off   [ args count        ]
+	 *
+	 * SP + ip_off      [ traced function   ] BPF_TRAMP_F_IP_ARG flag
+	 *
+	 * SP + run_ctx_off [ bpf_tramp_run_ctx ]
+	 */
+
+	stack_size = 0;
+	run_ctx_off = stack_size;
+	/* room for bpf_tramp_run_ctx */
+	stack_size += round_up(sizeof(struct bpf_tramp_run_ctx), 8);
+
+	ip_off = stack_size;
+	/* room for IP address argument */
+	if (flags & BPF_TRAMP_F_IP_ARG)
+		stack_size += 8;
+
+	nargs_off = stack_size;
+	/* room for args count */
+	stack_size += 8;
+
+	args_off = stack_size;
+	/* room for args */
+	stack_size += nargs * 8;
+
+	/* room for return value */
+	retval_off = stack_size;
+	save_ret = flags & (BPF_TRAMP_F_CALL_ORIG | BPF_TRAMP_F_RET_FENTRY_RET);
+	if (save_ret)
+		stack_size += 8;
+
+	/* room for callee saved registers, currently x19 and x20 are used */
+	regs_off = stack_size;
+	stack_size += 16;
+
+	/* round up to multiples of 16 to avoid SPAlignmentFault */
+	stack_size = round_up(stack_size, 16);
+
+	/* return address locates above FP */
+	retaddr_off = stack_size + 8;
+
+	/* bpf trampoline may be invoked by 3 instruction types:
+	 * 1. bl, attached to bpf prog or kernel function via short jump
+	 * 2. br, attached to bpf prog or kernel function via long jump
+	 * 3. blr, working as a function pointer, used by struct_ops.
+	 * So BTI_JC should be used here to support both br and blr.
+	 */
+	emit_bti(A64_BTI_JC, ctx);
+
+	/* frame for parent function */
+	emit(A64_PUSH(A64_FP, A64_R(9), A64_SP), ctx);
+	emit(A64_MOV(1, A64_FP, A64_SP), ctx);
+
+	/* frame for patched function */
+	emit(A64_PUSH(A64_FP, A64_LR, A64_SP), ctx);
+	emit(A64_MOV(1, A64_FP, A64_SP), ctx);
+
+	/* allocate stack space */
+	emit(A64_SUB_I(1, A64_SP, A64_SP, stack_size), ctx);
+
+	if (flags & BPF_TRAMP_F_IP_ARG) {
+		/* save ip address of the traced function */
+		emit_addr_mov_i64(A64_R(10), (const u64)orig_call, ctx);
+		emit(A64_STR64I(A64_R(10), A64_SP, ip_off), ctx);
+	}
+
+	/* save args count */
+	emit(A64_MOVZ(1, A64_R(10), nargs, 0), ctx);
+	emit(A64_STR64I(A64_R(10), A64_SP, nargs_off), ctx);
+
+	/* save args */
+	save_args(ctx, args_off, nargs);
+
+	/* save callee saved registers */
+	emit(A64_STR64I(A64_R(19), A64_SP, regs_off), ctx);
+	emit(A64_STR64I(A64_R(20), A64_SP, regs_off + 8), ctx);
+
+	if (flags & BPF_TRAMP_F_CALL_ORIG) {
+		emit_addr_mov_i64(A64_R(0), (const u64)im, ctx);
+		emit_call((const u64)__bpf_tramp_enter, ctx);
+	}
+
+	for (i = 0; i < fentry->nr_links; i++)
+		invoke_bpf_prog(ctx, fentry->links[i], args_off,
+				retval_off, run_ctx_off,
+				flags & BPF_TRAMP_F_RET_FENTRY_RET);
+
+	if (fmod_ret->nr_links) {
+		branches = kcalloc(fmod_ret->nr_links, sizeof(*branches),
+				   GFP_KERNEL);
+		if (!branches)
+			return -ENOMEM;
+
+		invoke_bpf_mod_ret(ctx, fmod_ret, args_off, retval_off,
+				   run_ctx_off, branches);
+	}
+
+	if (flags & BPF_TRAMP_F_CALL_ORIG) {
+		restore_args(ctx, args_off, nargs);
+		/* call original func: it continues at the return address
+		 * saved in the patched function's frame (SP + retaddr_off)
+		 */
+		emit(A64_LDR64I(A64_R(10), A64_SP, retaddr_off), ctx);
+		emit(A64_BLR(A64_R(10)), ctx);
+		/* store return value: the original function, whether a bpf
+		 * prog or a kernel function, follows the procedure call
+		 * standard, so the value is in x0 in both cases
+		 */
+		emit(A64_STR64I(A64_R(0), A64_SP, retval_off), ctx);
+		/* reserve a nop for bpf_tramp_image_put */
+		im->ip_after_call = ctx->image + ctx->idx;
+		emit(A64_NOP, ctx);
+	}
+
+	/* update the branches saved in invoke_bpf_mod_ret with cbnz; the
+	 * image holds little-endian (__le32) words, so convert first
+	 */
+	for (i = 0; i < fmod_ret->nr_links && ctx->image != NULL; i++) {
+		int offset = &ctx->image[ctx->idx] - branches[i];
+		*branches[i] = cpu_to_le32(A64_CBNZ(1, A64_R(10), offset));
+	}
+
+	for (i = 0; i < fexit->nr_links; i++)
+		invoke_bpf_prog(ctx, fexit->links[i], args_off, retval_off,
+				run_ctx_off, false);
+
+	if (flags & BPF_TRAMP_F_CALL_ORIG) {
+		im->ip_epilogue = ctx->image + ctx->idx;
+		emit_addr_mov_i64(A64_R(0), (const u64)im, ctx);
+		emit_call((const u64)__bpf_tramp_exit, ctx);
+	}
+
+	if (flags & BPF_TRAMP_F_RESTORE_REGS)
+		restore_args(ctx, args_off, nargs);
+
+	/* restore callee saved register x19 and x20 */
+	emit(A64_LDR64I(A64_R(19), A64_SP, regs_off), ctx);
+	emit(A64_LDR64I(A64_R(20), A64_SP, regs_off + 8), ctx);
+
+	if (save_ret)
+		emit(A64_LDR64I(A64_R(0), A64_SP, retval_off), ctx);
+
+	/* reset SP */
+	emit(A64_MOV(1, A64_SP, A64_FP), ctx);
+
+	/* pop frames */
+	emit(A64_POP(A64_FP, A64_LR, A64_SP), ctx);
+	emit(A64_POP(A64_FP, A64_R(9), A64_SP), ctx);
+
+	if (flags & BPF_TRAMP_F_SKIP_FRAME) {
+		/* skip patched function, return to parent */
+		emit(A64_MOV(1, A64_LR, A64_R(9)), ctx);
+		emit(A64_RET(A64_R(9)), ctx);
+	} else {
+		/* return to patched function */
+		emit(A64_MOV(1, A64_R(10), A64_LR), ctx);
+		emit(A64_MOV(1, A64_LR, A64_R(9)), ctx);
+		emit(A64_RET(A64_R(10)), ctx);
+	}
+
+	if (ctx->image)
+		bpf_flush_icache(ctx->image, ctx->image + ctx->idx);
+
+	kfree(branches);
+
+	return ctx->idx;
+}
+
+/* Build a bpf trampoline into the buffer [image, image_end).
+ *
+ * A sizing pass with ctx.image == NULL first checks that the trampoline
+ * fits. The buffer is then prefilled with jit_fill_hole(), the trampoline
+ * is emitted for real, and the result is checked with validate_code().
+ *
+ * Return: the trampoline size in bytes; -ENOTSUPP if the traced function
+ * takes more than 8 arguments; -EFBIG if the buffer is too small; -EINVAL
+ * if validation fails; otherwise the non-positive result of
+ * prepare_trampoline() when it fails.
+ */
+int arch_prepare_bpf_trampoline(struct bpf_tramp_image *im, void *image,
+				void *image_end, const struct btf_func_model *m,
+				u32 flags, struct bpf_tramp_links *tlinks,
+				void *orig_call)
+{
+	struct jit_ctx ctx = { .image = NULL, .idx = 0 };
+	int nargs = m->nr_args;
+	int max_insns = ((long)image_end - (long)image) / AARCH64_INSN_SIZE;
+	int insns;
+
+	/* only the register-passed arguments (at most 8) are supported */
+	if (nargs > 8)
+		return -ENOTSUPP;
+
+	/* sizing pass */
+	insns = prepare_trampoline(&ctx, im, tlinks, orig_call, nargs, flags);
+	if (insns < 0)
+		return insns;
+	if (insns > max_insns)
+		return -EFBIG;
+
+	/* emitting pass into the prefilled buffer */
+	ctx.image = image;
+	ctx.idx = 0;
+	jit_fill_hole(image, (unsigned int)(image_end - image));
+
+	insns = prepare_trampoline(&ctx, im, tlinks, orig_call, nargs, flags);
+	if (insns <= 0)
+		return insns;
+	if (validate_code(&ctx) < 0)
+		return -EINVAL;
+
+	return insns * AARCH64_INSN_SIZE;
+}
+
 static bool is_long_jump(void *ip, void *target)
 {
 	long offset;