diff mbox series

[bpf-next,v2,7/9] riscv, bpf: optimize calls

Message ID 20191216091343.23260-8-bjorn.topel@gmail.com (mailing list archive)
State New, archived
Headers show
Series riscv: BPF JIT fix, optimizations and far jumps support | expand

Commit Message

Björn Töpel Dec. 16, 2019, 9:13 a.m. UTC
Instead of using emit_imm() and emit_jalr() which can expand to six
instructions, start using jal or auipc+jalr.

Signed-off-by: Björn Töpel <bjorn.topel@gmail.com>
---
 arch/riscv/net/bpf_jit_comp.c | 101 +++++++++++++++++++++-------------
 1 file changed, 64 insertions(+), 37 deletions(-)

Comments

Palmer Dabbelt Dec. 23, 2019, 6:58 p.m. UTC | #1
On Mon, 16 Dec 2019 01:13:41 PST (-0800), Bjorn Topel wrote:
> Instead of using emit_imm() and emit_jalr() which can expand to six
> instructions, start using jal or auipc+jalr.
>
> Signed-off-by: Björn Töpel <bjorn.topel@gmail.com>
> ---
>  arch/riscv/net/bpf_jit_comp.c | 101 +++++++++++++++++++++-------------
>  1 file changed, 64 insertions(+), 37 deletions(-)
>
> diff --git a/arch/riscv/net/bpf_jit_comp.c b/arch/riscv/net/bpf_jit_comp.c
> index 46cff093f526..8d7e3343a08c 100644
> --- a/arch/riscv/net/bpf_jit_comp.c
> +++ b/arch/riscv/net/bpf_jit_comp.c
> @@ -811,11 +811,12 @@ static void emit_sext_32_rd(u8 *rd, struct rv_jit_context *ctx)
>  	*rd = RV_REG_T2;
>  }
>
> -static void emit_jump_and_link(u8 rd, int rvoff, struct rv_jit_context *ctx)
> +static void emit_jump_and_link(u8 rd, s64 rvoff, bool force_jalr,
> +			       struct rv_jit_context *ctx)
>  {
>  	s64 upper, lower;
>
> -	if (is_21b_int(rvoff)) {
> +	if (rvoff && is_21b_int(rvoff) && !force_jalr) {
>  		emit(rv_jal(rd, rvoff >> 1), ctx);
>  		return;
>  	}
> @@ -832,6 +833,28 @@ static bool is_signed_bpf_cond(u8 cond)
>  		cond == BPF_JSGE || cond == BPF_JSLE;
>  }
>
> +static int emit_call(bool fixed, u64 addr, struct rv_jit_context *ctx)
> +{
> +	s64 off = 0;
> +	u64 ip;
> +	u8 rd;
> +
> +	if (addr && ctx->insns) {
> +		ip = (u64)(long)(ctx->insns + ctx->ninsns);
> +		off = addr - ip;
> +		if (!is_32b_int(off)) {
> +			pr_err("bpf-jit: target call addr %pK is out of range\n",
> +			       (void *)addr);
> +			return -ERANGE;
> +		}
> +	}
> +
> +	emit_jump_and_link(RV_REG_RA, off, !fixed, ctx);
> +	rd = bpf_to_rv_reg(BPF_REG_0, ctx);
> +	emit(rv_addi(rd, RV_REG_A0, 0), ctx);

Why are they out of order?  It seems like it'd be better to just have the BPF
calling convention match the RISC-V calling convention, as that'd avoid
juggling the registers around.

> +	return 0;
> +}
> +
>  static int emit_insn(const struct bpf_insn *insn, struct rv_jit_context *ctx,
>  		     bool extra_pass)
>  {
> @@ -1107,7 +1130,7 @@ static int emit_insn(const struct bpf_insn *insn, struct rv_jit_context *ctx,
>  	/* JUMP off */
>  	case BPF_JMP | BPF_JA:
>  		rvoff = rv_offset(i, off, ctx);
> -		emit_jump_and_link(RV_REG_ZERO, rvoff, ctx);
> +		emit_jump_and_link(RV_REG_ZERO, rvoff, false, ctx);
>  		break;
>
>  	/* IF (dst COND src) JUMP off */
> @@ -1209,7 +1232,7 @@ static int emit_insn(const struct bpf_insn *insn, struct rv_jit_context *ctx,
>  	case BPF_JMP | BPF_CALL:
>  	{
>  		bool fixed;
> -		int i, ret;
> +		int ret;
>  		u64 addr;
>
>  		mark_call(ctx);
> @@ -1217,20 +1240,9 @@ static int emit_insn(const struct bpf_insn *insn, struct rv_jit_context *ctx,
>  					    &fixed);
>  		if (ret < 0)
>  			return ret;
> -		if (fixed) {
> -			emit_imm(RV_REG_T1, addr, ctx);
> -		} else {
> -			i = ctx->ninsns;
> -			emit_imm(RV_REG_T1, addr, ctx);
> -			for (i = ctx->ninsns - i; i < 8; i++) {
> -				/* nop */
> -				emit(rv_addi(RV_REG_ZERO, RV_REG_ZERO, 0),
> -				     ctx);
> -			}
> -		}
> -		emit(rv_jalr(RV_REG_RA, RV_REG_T1, 0), ctx);
> -		rd = bpf_to_rv_reg(BPF_REG_0, ctx);
> -		emit(rv_addi(rd, RV_REG_A0, 0), ctx);
> +		ret = emit_call(fixed, addr, ctx);
> +		if (ret)
> +			return ret;
>  		break;
>  	}
>  	/* tail call */
> @@ -1245,7 +1257,7 @@ static int emit_insn(const struct bpf_insn *insn, struct rv_jit_context *ctx,
>  			break;
>
>  		rvoff = epilogue_offset(ctx);
> -		emit_jump_and_link(RV_REG_ZERO, rvoff, ctx);
> +		emit_jump_and_link(RV_REG_ZERO, rvoff, false, ctx);
>  		break;
>
>  	/* dst = imm64 */
> @@ -1508,7 +1520,7 @@ static void build_epilogue(struct rv_jit_context *ctx)
>  	__build_epilogue(false, ctx);
>  }
>
> -static int build_body(struct rv_jit_context *ctx, bool extra_pass)
> +static int build_body(struct rv_jit_context *ctx, bool extra_pass, int *offset)
>  {
>  	const struct bpf_prog *prog = ctx->prog;
>  	int i;
> @@ -1520,12 +1532,12 @@ static int build_body(struct rv_jit_context *ctx, bool extra_pass)
>  		ret = emit_insn(insn, ctx, extra_pass);
>  		if (ret > 0) {
>  			i++;
> -			if (ctx->insns == NULL)
> -				ctx->offset[i] = ctx->ninsns;
> +			if (offset)
> +				offset[i] = ctx->ninsns;
>  			continue;
>  		}
> -		if (ctx->insns == NULL)
> -			ctx->offset[i] = ctx->ninsns;
> +		if (offset)
> +			offset[i] = ctx->ninsns;
>  		if (ret)
>  			return ret;
>  	}
> @@ -1553,8 +1565,8 @@ struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
>  	struct bpf_prog *tmp, *orig_prog = prog;
>  	int pass = 0, prev_ninsns = 0, i;
>  	struct rv_jit_data *jit_data;
> +	unsigned int image_size = 0;
>  	struct rv_jit_context *ctx;
> -	unsigned int image_size;
>
>  	if (!prog->jit_requested)
>  		return orig_prog;
> @@ -1599,36 +1611,51 @@ struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
>  	for (i = 0; i < 16; i++) {
>  		pass++;
>  		ctx->ninsns = 0;
> -		if (build_body(ctx, extra_pass)) {
> +		if (build_body(ctx, extra_pass, ctx->offset)) {
>  			prog = orig_prog;
>  			goto out_offset;
>  		}
>  		build_prologue(ctx);
>  		ctx->epilogue_offset = ctx->ninsns;
>  		build_epilogue(ctx);
> -		if (ctx->ninsns == prev_ninsns)
> -			break;
> +
> +		if (ctx->ninsns == prev_ninsns) {
> +			if (jit_data->header)
> +				break;
> +
> +			image_size = sizeof(u32) * ctx->ninsns;
> +			jit_data->header =
> +				bpf_jit_binary_alloc(image_size,
> +						     &jit_data->image,
> +						     sizeof(u32),
> +						     bpf_fill_ill_insns);
> +			if (!jit_data->header) {
> +				prog = orig_prog;
> +				goto out_offset;
> +			}
> +
> +			ctx->insns = (u32 *)jit_data->image;
> +			/* Now, when the image is allocated, the image
> +			 * can potentially shrink more (auipc/jalr ->
> +			 * jal).
> +			 */
> +		}

It seems like these fragments should go along with patch #2 that introduces the
code, as I don't see anything above that makes this necessary here.

>  		prev_ninsns = ctx->ninsns;
>  	}
>
> -	/* Allocate image, now that we know the size. */
> -	image_size = sizeof(u32) * ctx->ninsns;
> -	jit_data->header = bpf_jit_binary_alloc(image_size, &jit_data->image,
> -						sizeof(u32),
> -						bpf_fill_ill_insns);
> -	if (!jit_data->header) {
> +	if (i == 16) {
> +		pr_err("bpf-jit: image did not converge in <%d passes!\n", i);
> +		bpf_jit_binary_free(jit_data->header);
>  		prog = orig_prog;
>  		goto out_offset;
>  	}
>
> -	/* Second, real pass, that acutally emits the image. */
> -	ctx->insns = (u32 *)jit_data->image;
>  skip_init_ctx:
>  	pass++;
>  	ctx->ninsns = 0;
>
>  	build_prologue(ctx);
> -	if (build_body(ctx, extra_pass)) {
> +	if (build_body(ctx, extra_pass, NULL)) {
>  		bpf_jit_binary_free(jit_data->header);
>  		prog = orig_prog;
>  		goto out_offset;
Björn Töpel Jan. 7, 2020, 10:14 a.m. UTC | #2
On Mon, 23 Dec 2019 at 19:58, Palmer Dabbelt <palmerdabbelt@google.com> wrote:
>
> On Mon, 16 Dec 2019 01:13:41 PST (-0800), Bjorn Topel wrote:
> > Instead of using emit_imm() and emit_jalr() which can expand to six
> > instructions, start using jal or auipc+jalr.
> >
> > Signed-off-by: Björn Töpel <bjorn.topel@gmail.com>
> > ---
> >  arch/riscv/net/bpf_jit_comp.c | 101 +++++++++++++++++++++-------------
> >  1 file changed, 64 insertions(+), 37 deletions(-)
> >
> > diff --git a/arch/riscv/net/bpf_jit_comp.c b/arch/riscv/net/bpf_jit_comp.c
> > index 46cff093f526..8d7e3343a08c 100644
> > --- a/arch/riscv/net/bpf_jit_comp.c
> > +++ b/arch/riscv/net/bpf_jit_comp.c
> > @@ -811,11 +811,12 @@ static void emit_sext_32_rd(u8 *rd, struct rv_jit_context *ctx)
> >       *rd = RV_REG_T2;
> >  }
> >
> > -static void emit_jump_and_link(u8 rd, int rvoff, struct rv_jit_context *ctx)
> > +static void emit_jump_and_link(u8 rd, s64 rvoff, bool force_jalr,
> > +                            struct rv_jit_context *ctx)
> >  {
> >       s64 upper, lower;
> >
> > -     if (is_21b_int(rvoff)) {
> > +     if (rvoff && is_21b_int(rvoff) && !force_jalr) {
> >               emit(rv_jal(rd, rvoff >> 1), ctx);
> >               return;
> >       }
> > @@ -832,6 +833,28 @@ static bool is_signed_bpf_cond(u8 cond)
> >               cond == BPF_JSGE || cond == BPF_JSLE;
> >  }
> >
> > +static int emit_call(bool fixed, u64 addr, struct rv_jit_context *ctx)
> > +{
> > +     s64 off = 0;
> > +     u64 ip;
> > +     u8 rd;
> > +
> > +     if (addr && ctx->insns) {
> > +             ip = (u64)(long)(ctx->insns + ctx->ninsns);
> > +             off = addr - ip;
> > +             if (!is_32b_int(off)) {
> > +                     pr_err("bpf-jit: target call addr %pK is out of range\n",
> > +                            (void *)addr);
> > +                     return -ERANGE;
> > +             }
> > +     }
> > +
> > +     emit_jump_and_link(RV_REG_RA, off, !fixed, ctx);
> > +     rd = bpf_to_rv_reg(BPF_REG_0, ctx);
> > +     emit(rv_addi(rd, RV_REG_A0, 0), ctx);
>
> Why are they out of order?  It seems like it'd be better to just have the BPF
> calling convention match the RISC-V calling convention, as that'd avoid
> juggling the registers around.
>

BPF passes arguments in R1, R2, ..., and return value in R0. Given
that a0 plays the role of R1 and R0, how can we avoid the register
juggling (without complicating the JIT too much)? It would be nice
though... and ARM64 has the same concern AFAIK.

[...]
> > @@ -1599,36 +1611,51 @@ struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
> >       for (i = 0; i < 16; i++) {
> >               pass++;
> >               ctx->ninsns = 0;
> > -             if (build_body(ctx, extra_pass)) {
> > +             if (build_body(ctx, extra_pass, ctx->offset)) {
> >                       prog = orig_prog;
> >                       goto out_offset;
> >               }
> >               build_prologue(ctx);
> >               ctx->epilogue_offset = ctx->ninsns;
> >               build_epilogue(ctx);
> > -             if (ctx->ninsns == prev_ninsns)
> > -                     break;
> > +
> > +             if (ctx->ninsns == prev_ninsns) {
> > +                     if (jit_data->header)
> > +                             break;
> > +
> > +                     image_size = sizeof(u32) * ctx->ninsns;
> > +                     jit_data->header =
> > +                             bpf_jit_binary_alloc(image_size,
> > +                                                  &jit_data->image,
> > +                                                  sizeof(u32),
> > +                                                  bpf_fill_ill_insns);
> > +                     if (!jit_data->header) {
> > +                             prog = orig_prog;
> > +                             goto out_offset;
> > +                     }
> > +
> > +                     ctx->insns = (u32 *)jit_data->image;
> > +                     /* Now, when the image is allocated, the image
> > +                      * can potentially shrink more (auipc/jalr ->
> > +                      * jal).
> > +                      */
> > +             }
>
> It seems like these fragments should go along with patch #2 that introduces the
> code, as I don't see anything above that makes this necessary here.
>

No, you're right.


Björn
Palmer Dabbelt Jan. 28, 2020, 2:15 a.m. UTC | #3
On Tue, 07 Jan 2020 02:14:17 PST (-0800), Bjorn Topel wrote:
> On Mon, 23 Dec 2019 at 19:58, Palmer Dabbelt <palmerdabbelt@google.com> wrote:
>>
>> On Mon, 16 Dec 2019 01:13:41 PST (-0800), Bjorn Topel wrote:
>> > Instead of using emit_imm() and emit_jalr() which can expand to six
>> > instructions, start using jal or auipc+jalr.
>> >
>> > Signed-off-by: Björn Töpel <bjorn.topel@gmail.com>
>> > ---
>> >  arch/riscv/net/bpf_jit_comp.c | 101 +++++++++++++++++++++-------------
>> >  1 file changed, 64 insertions(+), 37 deletions(-)
>> >
>> > diff --git a/arch/riscv/net/bpf_jit_comp.c b/arch/riscv/net/bpf_jit_comp.c
>> > index 46cff093f526..8d7e3343a08c 100644
>> > --- a/arch/riscv/net/bpf_jit_comp.c
>> > +++ b/arch/riscv/net/bpf_jit_comp.c
>> > @@ -811,11 +811,12 @@ static void emit_sext_32_rd(u8 *rd, struct rv_jit_context *ctx)
>> >       *rd = RV_REG_T2;
>> >  }
>> >
>> > -static void emit_jump_and_link(u8 rd, int rvoff, struct rv_jit_context *ctx)
>> > +static void emit_jump_and_link(u8 rd, s64 rvoff, bool force_jalr,
>> > +                            struct rv_jit_context *ctx)
>> >  {
>> >       s64 upper, lower;
>> >
>> > -     if (is_21b_int(rvoff)) {
>> > +     if (rvoff && is_21b_int(rvoff) && !force_jalr) {
>> >               emit(rv_jal(rd, rvoff >> 1), ctx);
>> >               return;
>> >       }
>> > @@ -832,6 +833,28 @@ static bool is_signed_bpf_cond(u8 cond)
>> >               cond == BPF_JSGE || cond == BPF_JSLE;
>> >  }
>> >
>> > +static int emit_call(bool fixed, u64 addr, struct rv_jit_context *ctx)
>> > +{
>> > +     s64 off = 0;
>> > +     u64 ip;
>> > +     u8 rd;
>> > +
>> > +     if (addr && ctx->insns) {
>> > +             ip = (u64)(long)(ctx->insns + ctx->ninsns);
>> > +             off = addr - ip;
>> > +             if (!is_32b_int(off)) {
>> > +                     pr_err("bpf-jit: target call addr %pK is out of range\n",
>> > +                            (void *)addr);
>> > +                     return -ERANGE;
>> > +             }
>> > +     }
>> > +
>> > +     emit_jump_and_link(RV_REG_RA, off, !fixed, ctx);
>> > +     rd = bpf_to_rv_reg(BPF_REG_0, ctx);
>> > +     emit(rv_addi(rd, RV_REG_A0, 0), ctx);
>>
>> Why are they out of order?  It seems like it'd be better to just have the BPF
>> calling convention match the RISC-V calling convention, as that'd avoid
>> juggling the registers around.
>>
>
> BPF passes arguments in R1, R2, ..., and return value in R0. Given
> that a0 plays the role of R1 and R0, how can we avoid the register
> juggling (without complicating the JIT too much)? It would be nice
> though... and ARM64 has the same concern AFAIK.

Oh, why did you say that?  This kind of stuff is why I'm twenty days behind on
email...

https://lore.kernel.org/bpf/20200128021145.36774-1-palmerdabbelt@google.com/T/#t

:)

> [...]
>> > @@ -1599,36 +1611,51 @@ struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
>> >       for (i = 0; i < 16; i++) {
>> >               pass++;
>> >               ctx->ninsns = 0;
>> > -             if (build_body(ctx, extra_pass)) {
>> > +             if (build_body(ctx, extra_pass, ctx->offset)) {
>> >                       prog = orig_prog;
>> >                       goto out_offset;
>> >               }
>> >               build_prologue(ctx);
>> >               ctx->epilogue_offset = ctx->ninsns;
>> >               build_epilogue(ctx);
>> > -             if (ctx->ninsns == prev_ninsns)
>> > -                     break;
>> > +
>> > +             if (ctx->ninsns == prev_ninsns) {
>> > +                     if (jit_data->header)
>> > +                             break;
>> > +
>> > +                     image_size = sizeof(u32) * ctx->ninsns;
>> > +                     jit_data->header =
>> > +                             bpf_jit_binary_alloc(image_size,
>> > +                                                  &jit_data->image,
>> > +                                                  sizeof(u32),
>> > +                                                  bpf_fill_ill_insns);
>> > +                     if (!jit_data->header) {
>> > +                             prog = orig_prog;
>> > +                             goto out_offset;
>> > +                     }
>> > +
>> > +                     ctx->insns = (u32 *)jit_data->image;
>> > +                     /* Now, when the image is allocated, the image
>> > +                      * can potentially shrink more (auipc/jalr ->
>> > +                      * jal).
>> > +                      */
>> > +             }
>>
>> It seems like these fragments should go along with patch #2 that introduces the
>> code, as I don't see anything above that makes this necessary here.
>>
>
> No, you're right.
>
>
> Björn
Björn Töpel Feb. 3, 2020, 12:11 p.m. UTC | #4
On Tue, 28 Jan 2020 at 03:15, Palmer Dabbelt <palmerdabbelt@google.com> wrote:
>
[...]
> >
> > BPF passes arguments in R1, R2, ..., and return value in R0. Given
> > that a0 plays the role of R1 and R0, how can we avoid the register
> > juggling (without complicating the JIT too much)? It would be nice
> > though... and ARM64 has the same concern AFAIK.
>
> Oh, why did you say that?  This kind of stuff is why I'm twenty days behind on
> email...
>
> https://lore.kernel.org/bpf/20200128021145.36774-1-palmerdabbelt@google.com/T/#t
>
> :)
>

(back from vacation)

:-D Very nice, I'll take a look!


Björn
diff mbox series

Patch

diff --git a/arch/riscv/net/bpf_jit_comp.c b/arch/riscv/net/bpf_jit_comp.c
index 46cff093f526..8d7e3343a08c 100644
--- a/arch/riscv/net/bpf_jit_comp.c
+++ b/arch/riscv/net/bpf_jit_comp.c
@@ -811,11 +811,12 @@  static void emit_sext_32_rd(u8 *rd, struct rv_jit_context *ctx)
 	*rd = RV_REG_T2;
 }
 
-static void emit_jump_and_link(u8 rd, int rvoff, struct rv_jit_context *ctx)
+static void emit_jump_and_link(u8 rd, s64 rvoff, bool force_jalr,
+			       struct rv_jit_context *ctx)
 {
 	s64 upper, lower;
 
-	if (is_21b_int(rvoff)) {
+	if (rvoff && is_21b_int(rvoff) && !force_jalr) {
 		emit(rv_jal(rd, rvoff >> 1), ctx);
 		return;
 	}
@@ -832,6 +833,28 @@  static bool is_signed_bpf_cond(u8 cond)
 		cond == BPF_JSGE || cond == BPF_JSLE;
 }
 
+static int emit_call(bool fixed, u64 addr, struct rv_jit_context *ctx)
+{
+	s64 off = 0;
+	u64 ip;
+	u8 rd;
+
+	if (addr && ctx->insns) {
+		ip = (u64)(long)(ctx->insns + ctx->ninsns);
+		off = addr - ip;
+		if (!is_32b_int(off)) {
+			pr_err("bpf-jit: target call addr %pK is out of range\n",
+			       (void *)addr);
+			return -ERANGE;
+		}
+	}
+
+	emit_jump_and_link(RV_REG_RA, off, !fixed, ctx);
+	rd = bpf_to_rv_reg(BPF_REG_0, ctx);
+	emit(rv_addi(rd, RV_REG_A0, 0), ctx);
+	return 0;
+}
+
 static int emit_insn(const struct bpf_insn *insn, struct rv_jit_context *ctx,
 		     bool extra_pass)
 {
@@ -1107,7 +1130,7 @@  static int emit_insn(const struct bpf_insn *insn, struct rv_jit_context *ctx,
 	/* JUMP off */
 	case BPF_JMP | BPF_JA:
 		rvoff = rv_offset(i, off, ctx);
-		emit_jump_and_link(RV_REG_ZERO, rvoff, ctx);
+		emit_jump_and_link(RV_REG_ZERO, rvoff, false, ctx);
 		break;
 
 	/* IF (dst COND src) JUMP off */
@@ -1209,7 +1232,7 @@  static int emit_insn(const struct bpf_insn *insn, struct rv_jit_context *ctx,
 	case BPF_JMP | BPF_CALL:
 	{
 		bool fixed;
-		int i, ret;
+		int ret;
 		u64 addr;
 
 		mark_call(ctx);
@@ -1217,20 +1240,9 @@  static int emit_insn(const struct bpf_insn *insn, struct rv_jit_context *ctx,
 					    &fixed);
 		if (ret < 0)
 			return ret;
-		if (fixed) {
-			emit_imm(RV_REG_T1, addr, ctx);
-		} else {
-			i = ctx->ninsns;
-			emit_imm(RV_REG_T1, addr, ctx);
-			for (i = ctx->ninsns - i; i < 8; i++) {
-				/* nop */
-				emit(rv_addi(RV_REG_ZERO, RV_REG_ZERO, 0),
-				     ctx);
-			}
-		}
-		emit(rv_jalr(RV_REG_RA, RV_REG_T1, 0), ctx);
-		rd = bpf_to_rv_reg(BPF_REG_0, ctx);
-		emit(rv_addi(rd, RV_REG_A0, 0), ctx);
+		ret = emit_call(fixed, addr, ctx);
+		if (ret)
+			return ret;
 		break;
 	}
 	/* tail call */
@@ -1245,7 +1257,7 @@  static int emit_insn(const struct bpf_insn *insn, struct rv_jit_context *ctx,
 			break;
 
 		rvoff = epilogue_offset(ctx);
-		emit_jump_and_link(RV_REG_ZERO, rvoff, ctx);
+		emit_jump_and_link(RV_REG_ZERO, rvoff, false, ctx);
 		break;
 
 	/* dst = imm64 */
@@ -1508,7 +1520,7 @@  static void build_epilogue(struct rv_jit_context *ctx)
 	__build_epilogue(false, ctx);
 }
 
-static int build_body(struct rv_jit_context *ctx, bool extra_pass)
+static int build_body(struct rv_jit_context *ctx, bool extra_pass, int *offset)
 {
 	const struct bpf_prog *prog = ctx->prog;
 	int i;
@@ -1520,12 +1532,12 @@  static int build_body(struct rv_jit_context *ctx, bool extra_pass)
 		ret = emit_insn(insn, ctx, extra_pass);
 		if (ret > 0) {
 			i++;
-			if (ctx->insns == NULL)
-				ctx->offset[i] = ctx->ninsns;
+			if (offset)
+				offset[i] = ctx->ninsns;
 			continue;
 		}
-		if (ctx->insns == NULL)
-			ctx->offset[i] = ctx->ninsns;
+		if (offset)
+			offset[i] = ctx->ninsns;
 		if (ret)
 			return ret;
 	}
@@ -1553,8 +1565,8 @@  struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
 	struct bpf_prog *tmp, *orig_prog = prog;
 	int pass = 0, prev_ninsns = 0, i;
 	struct rv_jit_data *jit_data;
+	unsigned int image_size = 0;
 	struct rv_jit_context *ctx;
-	unsigned int image_size;
 
 	if (!prog->jit_requested)
 		return orig_prog;
@@ -1599,36 +1611,51 @@  struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
 	for (i = 0; i < 16; i++) {
 		pass++;
 		ctx->ninsns = 0;
-		if (build_body(ctx, extra_pass)) {
+		if (build_body(ctx, extra_pass, ctx->offset)) {
 			prog = orig_prog;
 			goto out_offset;
 		}
 		build_prologue(ctx);
 		ctx->epilogue_offset = ctx->ninsns;
 		build_epilogue(ctx);
-		if (ctx->ninsns == prev_ninsns)
-			break;
+
+		if (ctx->ninsns == prev_ninsns) {
+			if (jit_data->header)
+				break;
+
+			image_size = sizeof(u32) * ctx->ninsns;
+			jit_data->header =
+				bpf_jit_binary_alloc(image_size,
+						     &jit_data->image,
+						     sizeof(u32),
+						     bpf_fill_ill_insns);
+			if (!jit_data->header) {
+				prog = orig_prog;
+				goto out_offset;
+			}
+
+			ctx->insns = (u32 *)jit_data->image;
+			/* Now, when the image is allocated, the image
+			 * can potentially shrink more (auipc/jalr ->
+			 * jal).
+			 */
+		}
 		prev_ninsns = ctx->ninsns;
 	}
 
-	/* Allocate image, now that we know the size. */
-	image_size = sizeof(u32) * ctx->ninsns;
-	jit_data->header = bpf_jit_binary_alloc(image_size, &jit_data->image,
-						sizeof(u32),
-						bpf_fill_ill_insns);
-	if (!jit_data->header) {
+	if (i == 16) {
+		pr_err("bpf-jit: image did not converge in <%d passes!\n", i);
+		bpf_jit_binary_free(jit_data->header);
 		prog = orig_prog;
 		goto out_offset;
 	}
 
-	/* Second, real pass, that acutally emits the image. */
-	ctx->insns = (u32 *)jit_data->image;
 skip_init_ctx:
 	pass++;
 	ctx->ninsns = 0;
 
 	build_prologue(ctx);
-	if (build_body(ctx, extra_pass)) {
+	if (build_body(ctx, extra_pass, NULL)) {
 		bpf_jit_binary_free(jit_data->header);
 		prog = orig_prog;
 		goto out_offset;