diff mbox series

[bpf-next,2/8] arm32, bpf: add support for sign-extension load instruction

Message ID 20230905210621.1711859-3-puranjay12@gmail.com (mailing list archive)
State New
Headers show
Series arm32, bpf: add support for cpuv4 insns | expand

Commit Message

Puranjay Mohan Sept. 5, 2023, 9:06 p.m. UTC
The cpuv4 added the support of an instruction that is similar to load
but also sign-extends the result after the load.

BPF_MEMSX | <size> | BPF_LDX means dst = *(signed size *) (src + offset)
here <size> can be one of BPF_B, BPF_H, BPF_W.

ARM32 has instructions to load a byte or a half word with sign
extension into a 32bit register. As the JIT uses two 32 bit registers
to simulate a 64-bit BPF register, an extra instruction is emitted to
sign-extent the result up to the second register.

Signed-off-by: Puranjay Mohan <puranjay12@gmail.com>
---
 arch/arm/net/bpf_jit_32.c | 69 ++++++++++++++++++++++++++++++++++++++-
 arch/arm/net/bpf_jit_32.h |  2 ++
 2 files changed, 70 insertions(+), 1 deletion(-)

Comments

Russell King (Oracle) Sept. 5, 2023, 9:24 p.m. UTC | #1
On Tue, Sep 05, 2023 at 09:06:15PM +0000, Puranjay Mohan wrote:
> The cpuv4 added the support of an instruction that is similar to load
> but also sign-extends the result after the load.
> 
> BPF_MEMSX | <size> | BPF_LDX means dst = *(signed size *) (src + offset)
> here <size> can be one of BPF_B, BPF_H, BPF_W.
> 
> ARM32 has instructions to load a byte or a half word with sign
> extension into a 32bit register. As the JIT uses two 32 bit registers
> to simulate a 64-bit BPF register, an extra instruction is emitted to
> sign-extent the result up to the second register.
> 
> Signed-off-by: Puranjay Mohan <puranjay12@gmail.com>
> ---
>  arch/arm/net/bpf_jit_32.c | 69 ++++++++++++++++++++++++++++++++++++++-
>  arch/arm/net/bpf_jit_32.h |  2 ++
>  2 files changed, 70 insertions(+), 1 deletion(-)
> 
> diff --git a/arch/arm/net/bpf_jit_32.c b/arch/arm/net/bpf_jit_32.c
> index b26579da770e..f7c162479cf2 100644
> --- a/arch/arm/net/bpf_jit_32.c
> +++ b/arch/arm/net/bpf_jit_32.c
> @@ -333,6 +333,9 @@ static u32 arm_bpf_ldst_imm8(u32 op, u8 rt, u8 rn, s16 imm8)
>  #define ARM_LDRD_I(rt, rn, off)	arm_bpf_ldst_imm8(ARM_INST_LDRD_I, rt, rn, off)
>  #define ARM_LDRH_I(rt, rn, off)	arm_bpf_ldst_imm8(ARM_INST_LDRH_I, rt, rn, off)
>  
> +#define ARM_LDRSH_I(rt, rn, off) arm_bpf_ldst_imm8(ARM_INST_LDRSH_I, rt, rn, off)
> +#define ARM_LDRSB_I(rt, rn, off) arm_bpf_ldst_imm8(ARM_INST_LDRSB_I, rt, rn, off)
> +
>  #define ARM_STR_I(rt, rn, off)	arm_bpf_ldst_imm12(ARM_INST_STR_I, rt, rn, off)
>  #define ARM_STRB_I(rt, rn, off)	arm_bpf_ldst_imm12(ARM_INST_STRB_I, rt, rn, off)
>  #define ARM_STRD_I(rt, rn, off)	arm_bpf_ldst_imm8(ARM_INST_STRD_I, rt, rn, off)
> @@ -1026,6 +1029,24 @@ static bool is_ldst_imm(s16 off, const u8 size)
>  	return -off_max <= off && off <= off_max;
>  }
>  
> +static bool is_ldst_imm8(s16 off, const u8 size)
> +{
> +	s16 off_max = 0;
> +
> +	switch (size) {
> +	case BPF_B:
> +		off_max = 0xff;
> +		break;
> +	case BPF_W:
> +		off_max = 0xfff;
> +		break;
> +	case BPF_H:
> +		off_max = 0xff;
> +		break;
> +	}
> +	return -off_max <= off && off <= off_max;
> +}
> +
>  /* *(size *)(dst + off) = src */
>  static inline void emit_str_r(const s8 dst, const s8 src[],
>  			      s16 off, struct jit_ctx *ctx, const u8 sz){
> @@ -1105,6 +1126,45 @@ static inline void emit_ldx_r(const s8 dst[], const s8 src,
>  	arm_bpf_put_reg64(dst, rd, ctx);
>  }
>  
> +/* dst = *(signed size*)(src + off) */
> +static inline void emit_ldsx_r(const s8 dst[], const s8 src,
> +			       s16 off, struct jit_ctx *ctx, const u8 sz){
> +	const s8 *tmp = bpf2a32[TMP_REG_1];
> +	const s8 *rd = is_stacked(dst_lo) ? tmp : dst;
> +	s8 rm = src;
> +
> +	if (!is_ldst_imm8(off, sz)) {
> +		emit_a32_mov_i(tmp[0], off, ctx);
> +		emit(ARM_ADD_R(tmp[0], tmp[0], src), ctx);

Hmm. This looks inefficient when "off" is able to fit in an immediate.
Please try:

	int add_off;

	if (!is_ldst_imm8(off, sz)) {
		add_off = imm8m(off);
		if (add_off > 0) {
			emit(ARM_ADD_I(tmp[0], src, add_off), ctx);
			rm = tmp[0];
		} else {
			emit_a32_mov_i(tmp[0], off, ctx);
			emit(ARM_ADD_R(tmp[0], tmp[0], src), ctx);
			rm = tmp[0];
		}
		off = 0;
> +	} else if (rd[1] == rm) {
> +		emit(ARM_MOV_R(tmp[0], rm), ctx);
> +		rm = tmp[0];

Why do you need this? rd and rm can be the same for LDRS[BH].

> +	}
> +	switch (sz) {
> +	case BPF_B:
> +		/* Load a Byte with sign extension*/
> +		emit(ARM_LDRSB_I(rd[1], rm, off), ctx);
> +		/* Carry the sign extension to upper 32 bits */
> +		emit(ARM_ASR_I(rd[0], rd[1], 31), ctx);
> +		break;
> +	case BPF_H:
> +		/* Load a HalfWord with sign extension*/
> +		emit(ARM_LDRSH_I(rd[1], rm, off), ctx);
> +		/* Carry the sign extension to upper 32 bits */
> +		emit(ARM_ASR_I(rd[0], rd[1], 31), ctx);
> +		break;
> +	case BPF_W:
> +		/* Load a Word*/
> +		emit(ARM_LDR_I(rd[1], rm, off), ctx);
> +		/* Carry the sign extension to upper 32 bits */
> +		emit(ARM_ASR_I(rd[0], rd[1], 31), ctx);

The last instruction extending to the upper 32 bits is the same in each
of these cases, so is there any reason not to do it outside the switch
statement?
Puranjay Mohan Sept. 6, 2023, 11:47 a.m. UTC | #2
On Tue, Sep 05 2023, Russell King (Oracle) wrote:

[...]
>> +/* dst = *(signed size*)(src + off) */
>> +static inline void emit_ldsx_r(const s8 dst[], const s8 src,
>> +			       s16 off, struct jit_ctx *ctx, const u8 sz){
>> +	const s8 *tmp = bpf2a32[TMP_REG_1];
>> +	const s8 *rd = is_stacked(dst_lo) ? tmp : dst;
>> +	s8 rm = src;
>> +
>> +	if (!is_ldst_imm8(off, sz)) {
>> +		emit_a32_mov_i(tmp[0], off, ctx);
>> +		emit(ARM_ADD_R(tmp[0], tmp[0], src), ctx);
>
> Hmm. This looks inefficient when "off" is able to fit in an immediate.
> Please try:
>
> 	int add_off;
>
> 	if (!is_ldst_imm8(off, sz)) {
> 		add_off = imm8m(off);
> 		if (add_off > 0) {
> 			emit(ARM_ADD_I(tmp[0], src, add_off), ctx);
> 			rm = tmp[0];
> 		} else {
> 			emit_a32_mov_i(tmp[0], off, ctx);
> 			emit(ARM_ADD_R(tmp[0], tmp[0], src), ctx);
> 			rm = tmp[0];
> 		}
> 		off = 0;
>> +	} else if (rd[1] == rm) {
>> +		emit(ARM_MOV_R(tmp[0], rm), ctx);
>> +		rm = tmp[0];
>
> Why do you need this? rd and rm can be the same for LDRS[BH].

I agree that this is not required, will remove in the next version.
Will also use the suggested optimization for immediate.

>> +	}
>> +	switch (sz) {
>> +	case BPF_B:
>> +		/* Load a Byte with sign extension*/
>> +		emit(ARM_LDRSB_I(rd[1], rm, off), ctx);
>> +		/* Carry the sign extension to upper 32 bits */
>> +		emit(ARM_ASR_I(rd[0], rd[1], 31), ctx);
>> +		break;
>> +	case BPF_H:
>> +		/* Load a HalfWord with sign extension*/
>> +		emit(ARM_LDRSH_I(rd[1], rm, off), ctx);
>> +		/* Carry the sign extension to upper 32 bits */
>> +		emit(ARM_ASR_I(rd[0], rd[1], 31), ctx);
>> +		break;
>> +	case BPF_W:
>> +		/* Load a Word*/
>> +		emit(ARM_LDR_I(rd[1], rm, off), ctx);
>> +		/* Carry the sign extension to upper 32 bits */
>> +		emit(ARM_ASR_I(rd[0], rd[1], 31), ctx);
>
> The last instruction extending to the upper 32 bits is the same in each
> of these cases, so is there any reason not to do it outside the switch
> statement?

Will move it outside in the next version.


Thanks,
Puranjay
diff mbox series

Patch

diff --git a/arch/arm/net/bpf_jit_32.c b/arch/arm/net/bpf_jit_32.c
index b26579da770e..f7c162479cf2 100644
--- a/arch/arm/net/bpf_jit_32.c
+++ b/arch/arm/net/bpf_jit_32.c
@@ -333,6 +333,9 @@  static u32 arm_bpf_ldst_imm8(u32 op, u8 rt, u8 rn, s16 imm8)
 #define ARM_LDRD_I(rt, rn, off)	arm_bpf_ldst_imm8(ARM_INST_LDRD_I, rt, rn, off)
 #define ARM_LDRH_I(rt, rn, off)	arm_bpf_ldst_imm8(ARM_INST_LDRH_I, rt, rn, off)
 
+#define ARM_LDRSH_I(rt, rn, off) arm_bpf_ldst_imm8(ARM_INST_LDRSH_I, rt, rn, off)
+#define ARM_LDRSB_I(rt, rn, off) arm_bpf_ldst_imm8(ARM_INST_LDRSB_I, rt, rn, off)
+
 #define ARM_STR_I(rt, rn, off)	arm_bpf_ldst_imm12(ARM_INST_STR_I, rt, rn, off)
 #define ARM_STRB_I(rt, rn, off)	arm_bpf_ldst_imm12(ARM_INST_STRB_I, rt, rn, off)
 #define ARM_STRD_I(rt, rn, off)	arm_bpf_ldst_imm8(ARM_INST_STRD_I, rt, rn, off)
@@ -1026,6 +1029,24 @@  static bool is_ldst_imm(s16 off, const u8 size)
 	return -off_max <= off && off <= off_max;
 }
 
+static bool is_ldst_imm8(s16 off, const u8 size)
+{
+	s16 off_max = 0;
+
+	switch (size) {
+	case BPF_B:
+		off_max = 0xff;
+		break;
+	case BPF_W:
+		off_max = 0xfff;
+		break;
+	case BPF_H:
+		off_max = 0xff;
+		break;
+	}
+	return -off_max <= off && off <= off_max;
+}
+
 /* *(size *)(dst + off) = src */
 static inline void emit_str_r(const s8 dst, const s8 src[],
 			      s16 off, struct jit_ctx *ctx, const u8 sz){
@@ -1105,6 +1126,45 @@  static inline void emit_ldx_r(const s8 dst[], const s8 src,
 	arm_bpf_put_reg64(dst, rd, ctx);
 }
 
+/* dst = *(signed size*)(src + off) */
+static inline void emit_ldsx_r(const s8 dst[], const s8 src,
+			       s16 off, struct jit_ctx *ctx, const u8 sz){
+	const s8 *tmp = bpf2a32[TMP_REG_1];
+	const s8 *rd = is_stacked(dst_lo) ? tmp : dst;
+	s8 rm = src;
+
+	if (!is_ldst_imm8(off, sz)) {
+		emit_a32_mov_i(tmp[0], off, ctx);
+		emit(ARM_ADD_R(tmp[0], tmp[0], src), ctx);
+		rm = tmp[0];
+		off = 0;
+	} else if (rd[1] == rm) {
+		emit(ARM_MOV_R(tmp[0], rm), ctx);
+		rm = tmp[0];
+	}
+	switch (sz) {
+	case BPF_B:
+		/* Load a Byte with sign extension*/
+		emit(ARM_LDRSB_I(rd[1], rm, off), ctx);
+		/* Carry the sign extension to upper 32 bits */
+		emit(ARM_ASR_I(rd[0], rd[1], 31), ctx);
+		break;
+	case BPF_H:
+		/* Load a HalfWord with sign extension*/
+		emit(ARM_LDRSH_I(rd[1], rm, off), ctx);
+		/* Carry the sign extension to upper 32 bits */
+		emit(ARM_ASR_I(rd[0], rd[1], 31), ctx);
+		break;
+	case BPF_W:
+		/* Load a Word*/
+		emit(ARM_LDR_I(rd[1], rm, off), ctx);
+		/* Carry the sign extension to upper 32 bits */
+		emit(ARM_ASR_I(rd[0], rd[1], 31), ctx);
+		break;
+	}
+	arm_bpf_put_reg64(dst, rd, ctx);
+}
+
 /* Arithmatic Operation */
 static inline void emit_ar_r(const u8 rd, const u8 rt, const u8 rm,
 			     const u8 rn, struct jit_ctx *ctx, u8 op,
@@ -1603,8 +1663,15 @@  static int build_insn(const struct bpf_insn *insn, struct jit_ctx *ctx)
 	case BPF_LDX | BPF_MEM | BPF_H:
 	case BPF_LDX | BPF_MEM | BPF_B:
 	case BPF_LDX | BPF_MEM | BPF_DW:
+	/* LDSX: dst = *(signed size *)(src + off) */
+	case BPF_LDX | BPF_MEMSX | BPF_B:
+	case BPF_LDX | BPF_MEMSX | BPF_H:
+	case BPF_LDX | BPF_MEMSX | BPF_W:
 		rn = arm_bpf_get_reg32(src_lo, tmp2[1], ctx);
-		emit_ldx_r(dst, rn, off, ctx, BPF_SIZE(code));
+		if (BPF_MODE(insn->code) == BPF_MEMSX)
+			emit_ldsx_r(dst, rn, off, ctx, BPF_SIZE(code));
+		else
+			emit_ldx_r(dst, rn, off, ctx, BPF_SIZE(code));
 		break;
 	/* speculation barrier */
 	case BPF_ST | BPF_NOSPEC:
diff --git a/arch/arm/net/bpf_jit_32.h b/arch/arm/net/bpf_jit_32.h
index e0b593a1498d..79c7373fadce 100644
--- a/arch/arm/net/bpf_jit_32.h
+++ b/arch/arm/net/bpf_jit_32.h
@@ -79,9 +79,11 @@ 
 #define ARM_INST_LDST__IMM12	0x00000fff
 #define ARM_INST_LDRB_I		0x05500000
 #define ARM_INST_LDRB_R		0x07d00000
+#define ARM_INST_LDRSB_I	0x015000d0
 #define ARM_INST_LDRD_I		0x014000d0
 #define ARM_INST_LDRH_I		0x015000b0
 #define ARM_INST_LDRH_R		0x019000b0
+#define ARM_INST_LDRSH_I	0x015000f0
 #define ARM_INST_LDR_I		0x05100000
 #define ARM_INST_LDR_R		0x07900000