@@ -81,6 +81,7 @@ struct rv_jit_context {
int nexentries;
unsigned long flags;
int stack_size;
+ int tcc_offset;
};
/* Convert from ninsns to bytes. */
@@ -13,13 +13,11 @@
#include <asm/patch.h>
#include "bpf_jit.h"
+#define RV_REG_TCC RV_REG_A6
#define RV_FENTRY_NINSNS 2
/* fentry and TCC init insns will be skipped on tailcall */
#define RV_TAILCALL_OFFSET ((RV_FENTRY_NINSNS + 1) * 4)
-#define RV_REG_TCC RV_REG_A6
-#define RV_REG_TCC_SAVED RV_REG_S6 /* Store A6 in S6 if program do calls */
-
static const int regmap[] = {
[BPF_REG_0] = RV_REG_A5,
[BPF_REG_1] = RV_REG_A0,
@@ -51,14 +49,12 @@ static const int pt_regmap[] = {
};
enum {
- RV_CTX_F_SEEN_TAIL_CALL = 0,
RV_CTX_F_SEEN_CALL = RV_REG_RA,
RV_CTX_F_SEEN_S1 = RV_REG_S1,
RV_CTX_F_SEEN_S2 = RV_REG_S2,
RV_CTX_F_SEEN_S3 = RV_REG_S3,
RV_CTX_F_SEEN_S4 = RV_REG_S4,
RV_CTX_F_SEEN_S5 = RV_REG_S5,
- RV_CTX_F_SEEN_S6 = RV_REG_S6,
};
static u8 bpf_to_rv_reg(int bpf_reg, struct rv_jit_context *ctx)
@@ -71,7 +67,6 @@ static u8 bpf_to_rv_reg(int bpf_reg, struct rv_jit_context *ctx)
case RV_CTX_F_SEEN_S3:
case RV_CTX_F_SEEN_S4:
case RV_CTX_F_SEEN_S5:
- case RV_CTX_F_SEEN_S6:
__set_bit(reg, &ctx->flags);
}
return reg;
@@ -86,7 +81,6 @@ static bool seen_reg(int reg, struct rv_jit_context *ctx)
case RV_CTX_F_SEEN_S3:
case RV_CTX_F_SEEN_S4:
case RV_CTX_F_SEEN_S5:
- case RV_CTX_F_SEEN_S6:
return test_bit(reg, &ctx->flags);
}
return false;
@@ -102,32 +96,6 @@ static void mark_call(struct rv_jit_context *ctx)
__set_bit(RV_CTX_F_SEEN_CALL, &ctx->flags);
}
-static bool seen_call(struct rv_jit_context *ctx)
-{
- return test_bit(RV_CTX_F_SEEN_CALL, &ctx->flags);
-}
-
-static void mark_tail_call(struct rv_jit_context *ctx)
-{
- __set_bit(RV_CTX_F_SEEN_TAIL_CALL, &ctx->flags);
-}
-
-static bool seen_tail_call(struct rv_jit_context *ctx)
-{
- return test_bit(RV_CTX_F_SEEN_TAIL_CALL, &ctx->flags);
-}
-
-static u8 rv_tail_call_reg(struct rv_jit_context *ctx)
-{
- mark_tail_call(ctx);
-
- if (seen_call(ctx)) {
- __set_bit(RV_CTX_F_SEEN_S6, &ctx->flags);
- return RV_REG_S6;
- }
- return RV_REG_A6;
-}
-
static bool is_32b_int(s64 val)
{
return -(1L << 31) <= val && val < (1L << 31);
@@ -252,10 +220,6 @@ static void __build_epilogue(bool is_tail_call, struct rv_jit_context *ctx)
emit_ld(RV_REG_S5, store_offset, RV_REG_SP, ctx);
store_offset -= 8;
}
- if (seen_reg(RV_REG_S6, ctx)) {
- emit_ld(RV_REG_S6, store_offset, RV_REG_SP, ctx);
- store_offset -= 8;
- }
emit_addi(RV_REG_SP, RV_REG_SP, stack_adjust, ctx);
/* Set return value. */
@@ -343,7 +307,6 @@ static void emit_branch(u8 cond, u8 rd, u8 rs, int rvoff,
static int emit_bpf_tail_call(int insn, struct rv_jit_context *ctx)
{
int tc_ninsn, off, start_insn = ctx->ninsns;
- u8 tcc = rv_tail_call_reg(ctx);
/* a0: &ctx
* a1: &array
@@ -366,9 +329,11 @@ static int emit_bpf_tail_call(int insn, struct rv_jit_context *ctx)
/* if (--TCC < 0)
* goto out;
*/
- emit_addi(RV_REG_TCC, tcc, -1, ctx);
+ emit_ld(RV_REG_TCC, ctx->tcc_offset, RV_REG_SP, ctx);
+ emit_addi(RV_REG_TCC, RV_REG_TCC, -1, ctx);
off = ninsns_rvoff(tc_ninsn - (ctx->ninsns - start_insn));
emit_branch(BPF_JSLT, RV_REG_TCC, RV_REG_ZERO, off, ctx);
+ emit_sd(RV_REG_SP, ctx->tcc_offset, RV_REG_TCC, ctx);
/* prog = array->ptrs[index];
* if (!prog)
@@ -767,7 +732,7 @@ static int __arch_prepare_bpf_trampoline(struct bpf_tramp_image *im,
int i, ret, offset;
int *branches_off = NULL;
int stack_size = 0, nregs = m->nr_args;
- int retval_off, args_off, nregs_off, ip_off, run_ctx_off, sreg_off;
+ int retval_off, args_off, nregs_off, ip_off, run_ctx_off, sreg_off, tcc_off;
struct bpf_tramp_links *fentry = &tlinks[BPF_TRAMP_FENTRY];
struct bpf_tramp_links *fexit = &tlinks[BPF_TRAMP_FEXIT];
struct bpf_tramp_links *fmod_ret = &tlinks[BPF_TRAMP_MODIFY_RETURN];
@@ -812,6 +777,8 @@ static int __arch_prepare_bpf_trampoline(struct bpf_tramp_image *im,
*
* FP - sreg_off [ callee saved reg ]
*
+ * FP - tcc_off [ tail call count ] BPF_TRAMP_F_TAIL_CALL_CTX
+ *
* [ pads ] pads for 16 bytes alignment
*/
@@ -853,6 +820,11 @@ static int __arch_prepare_bpf_trampoline(struct bpf_tramp_image *im,
stack_size += 8;
sreg_off = stack_size;
+ if (flags & BPF_TRAMP_F_TAIL_CALL_CTX) {
+ stack_size += 8;
+ tcc_off = stack_size;
+ }
+
stack_size = round_up(stack_size, 16);
if (!is_struct_ops) {
@@ -879,6 +851,10 @@ static int __arch_prepare_bpf_trampoline(struct bpf_tramp_image *im,
emit_addi(RV_REG_FP, RV_REG_SP, stack_size, ctx);
}
+ /* store tail call count */
+ if (flags & BPF_TRAMP_F_TAIL_CALL_CTX)
+ emit_sd(RV_REG_FP, -tcc_off, RV_REG_TCC, ctx);
+
/* callee saved register S1 to pass start time */
emit_sd(RV_REG_FP, -sreg_off, RV_REG_S1, ctx);
@@ -932,6 +908,9 @@ static int __arch_prepare_bpf_trampoline(struct bpf_tramp_image *im,
if (flags & BPF_TRAMP_F_CALL_ORIG) {
restore_args(nregs, args_off, ctx);
+ /* restore TCC to RV_REG_TCC before calling the original function */
+ if (flags & BPF_TRAMP_F_TAIL_CALL_CTX)
+ emit_ld(RV_REG_TCC, -tcc_off, RV_REG_FP, ctx);
ret = emit_call((const u64)orig_call, true, ctx);
if (ret)
goto out;
@@ -963,6 +942,9 @@ static int __arch_prepare_bpf_trampoline(struct bpf_tramp_image *im,
ret = emit_call((const u64)__bpf_tramp_exit, true, ctx);
if (ret)
goto out;
+ } else if (flags & BPF_TRAMP_F_TAIL_CALL_CTX) {
+		/* not calling the original function here; restore TCC to
+		 * RV_REG_TCC before returning to the traced function
+		 */
+ emit_ld(RV_REG_TCC, -tcc_off, RV_REG_FP, ctx);
}
if (flags & BPF_TRAMP_F_RESTORE_REGS)
@@ -1455,6 +1437,9 @@ int bpf_jit_emit_insn(const struct bpf_insn *insn, struct rv_jit_context *ctx,
if (ret < 0)
return ret;
+ /* restore TCC from stack to RV_REG_TCC */
+ emit_ld(RV_REG_TCC, ctx->tcc_offset, RV_REG_SP, ctx);
+
ret = emit_call(addr, fixed_addr, ctx);
if (ret)
return ret;
@@ -1733,8 +1718,7 @@ void bpf_jit_build_prologue(struct rv_jit_context *ctx)
stack_adjust += 8;
if (seen_reg(RV_REG_S5, ctx))
stack_adjust += 8;
- if (seen_reg(RV_REG_S6, ctx))
- stack_adjust += 8;
+ stack_adjust += 8; /* RV_REG_TCC */
stack_adjust = round_up(stack_adjust, 16);
stack_adjust += bpf_stack_adjust;
@@ -1749,7 +1733,8 @@ void bpf_jit_build_prologue(struct rv_jit_context *ctx)
* (TCC) register. This instruction is skipped for tail calls.
* Force using a 4-byte (non-compressed) instruction.
*/
- emit(rv_addi(RV_REG_TCC, RV_REG_ZERO, MAX_TAIL_CALL_CNT), ctx);
+ if (!bpf_is_subprog(ctx->prog))
+ emit(rv_addi(RV_REG_TCC, RV_REG_ZERO, MAX_TAIL_CALL_CNT), ctx);
emit_addi(RV_REG_SP, RV_REG_SP, -stack_adjust, ctx);
@@ -1779,22 +1764,14 @@ void bpf_jit_build_prologue(struct rv_jit_context *ctx)
emit_sd(RV_REG_SP, store_offset, RV_REG_S5, ctx);
store_offset -= 8;
}
- if (seen_reg(RV_REG_S6, ctx)) {
- emit_sd(RV_REG_SP, store_offset, RV_REG_S6, ctx);
- store_offset -= 8;
- }
+ emit_sd(RV_REG_SP, store_offset, RV_REG_TCC, ctx);
+ ctx->tcc_offset = store_offset;
emit_addi(RV_REG_FP, RV_REG_SP, stack_adjust, ctx);
if (bpf_stack_adjust)
emit_addi(RV_REG_S5, RV_REG_SP, bpf_stack_adjust, ctx);
- /* Program contains calls and tail calls, so RV_REG_TCC need
- * to be saved across calls.
- */
- if (seen_tail_call(ctx) && seen_call(ctx))
- emit_mv(RV_REG_TCC_SAVED, RV_REG_TCC, ctx);
-
ctx->stack_size = stack_adjust;
}
@@ -1807,3 +1784,8 @@ bool bpf_jit_supports_kfunc_call(void)
{
return true;
}
+
+bool bpf_jit_supports_subprog_tailcalls(void)
+{
+ return true;
+}