diff mbox series

[bpf-next,v2,01/15] bpf: Support new sign-extension load insns

Message ID 20230713060724.389084-1-yhs@fb.com (mailing list archive)
State Superseded
Delegated to: BPF
Headers show
Series bpf: Support new insns from cpu v4 | expand

Checks

Context Check Description
bpf/vmtest-bpf-next-PR fail PR summary
bpf/vmtest-bpf-next-VM_Test-1 success Logs for ${{ matrix.test }} on ${{ matrix.arch }} with ${{ matrix.toolchain_full }}
bpf/vmtest-bpf-next-VM_Test-2 success Logs for ShellCheck
bpf/vmtest-bpf-next-VM_Test-3 fail Logs for build for aarch64 with gcc
bpf/vmtest-bpf-next-VM_Test-4 fail Logs for build for s390x with gcc
bpf/vmtest-bpf-next-VM_Test-5 fail Logs for build for x86_64 with gcc
bpf/vmtest-bpf-next-VM_Test-6 fail Logs for build for x86_64 with llvm-16
bpf/vmtest-bpf-next-VM_Test-7 success Logs for set-matrix
bpf/vmtest-bpf-next-VM_Test-8 success Logs for veristat
netdev/series_format success Posting correctly formatted
netdev/tree_selection success Clearly marked for bpf-next, async
netdev/fixes_present success Fixes tag not required for -next series
netdev/header_inline success No static functions without inline keyword in header files
netdev/build_32bit success Errors and warnings before: 1736 this patch: 1736
netdev/cc_maintainers warning 16 maintainers not CCed: tglx@linutronix.de hpa@zytor.com dsahern@kernel.org mingo@redhat.com kpsingh@kernel.org x86@kernel.org john.fastabend@gmail.com sdf@google.com netdev@vger.kernel.org martin.lau@linux.dev song@kernel.org dave.hansen@linux.intel.com davem@davemloft.net jolsa@kernel.org haoluo@google.com bp@alien8.de
netdev/build_clang fail Errors and warnings before: 186 this patch: 186
netdev/verify_signedoff success Signed-off-by tag matches author and committer
netdev/deprecated_api success None detected
netdev/check_selftest success No net selftest shell script
netdev/verify_fixes success No Fixes tag
netdev/build_allmodconfig_warn success Errors and warnings before: 1735 this patch: 1735
netdev/checkpatch fail CHECK: Macro argument 'SIZE' may be better as '(SIZE)' to avoid precedence issues CHECK: No space is necessary after a cast CHECK: multiple assignments should be avoided ERROR: Macros with multiple statements should be enclosed in a do - while loop ERROR: spaces required around that ':' (ctx:VxE) WARNING: line length of 81 exceeds 80 columns WARNING: line length of 82 exceeds 80 columns WARNING: line length of 86 exceeds 80 columns WARNING: line length of 87 exceeds 80 columns WARNING: line length of 94 exceeds 80 columns WARNING: line length of 96 exceeds 80 columns WARNING: line length of 98 exceeds 80 columns WARNING: line length of 99 exceeds 80 columns WARNING: macros should not use a trailing semicolon
netdev/kdoc success Errors and warnings before: 0 this patch: 0
netdev/source_inline success Was 0 now: 0

Commit Message

Yonghong Song July 13, 2023, 6:07 a.m. UTC
Add interpreter/jit support for the new sign-extension load insns,
which add a new mode (BPF_MEMSX).
Also add verifier support to recognize these insns and to
do proper verification with them. In the verifier, besides
deducing proper bounds for the dst_reg, probed memory access
is handled by remembering the insn mode in the insn->imm field so
that proper jit insns can be emitted later on.

Signed-off-by: Yonghong Song <yhs@fb.com>
---
 arch/x86/net/bpf_jit_comp.c    |  32 ++++++++-
 include/uapi/linux/bpf.h       |   1 +
 kernel/bpf/core.c              |  13 ++++
 kernel/bpf/verifier.c          | 125 +++++++++++++++++++++++++++------
 tools/include/uapi/linux/bpf.h |   1 +
 5 files changed, 151 insertions(+), 21 deletions(-)
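
For readers new to the mode: an existing BPF_MEM load zero-extends the
narrow value into the 64-bit destination register, while BPF_MEMSX
sign-extends it. A minimal userspace sketch of the two semantics (plain
C standing in for the BPF instructions, not code from this series):

	#include <stdio.h>
	#include <stdint.h>

	int main(void)
	{
		int8_t byte = -5;	/* in memory: 0xfb */

		/* BPF_LDX | BPF_MEM | BPF_B: zero-extend into 64 bits */
		uint64_t zext = *(uint8_t *)&byte;		/* 0xfb */

		/* BPF_LDX | BPF_MEMSX | BPF_B: sign-extend into 64 bits */
		uint64_t sext = (int64_t)*(int8_t *)&byte;	/* 0xfffffffffffffffb */

		printf("zext=%#llx sext=%#llx\n",
		       (unsigned long long)zext, (unsigned long long)sext);
		return 0;
	}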

Comments

Alexei Starovoitov July 14, 2023, 6:13 p.m. UTC | #1
On Wed, Jul 12, 2023 at 11:07:24PM -0700, Yonghong Song wrote:
>  
> @@ -1942,6 +1945,16 @@ static u64 ___bpf_prog_run(u64 *regs, const struct bpf_insn *insn)
>  	LDST(DW, u64)
>  #undef LDST
>  
> +#define LDS(SIZEOP, SIZE)						\

LDSX ?

> +	LDX_MEMSX_##SIZEOP:						\
> +		DST = *(SIZE *)(unsigned long) (SRC + insn->off);	\
> +		CONT;
> +
> +	LDS(B,   s8)
> +	LDS(H,  s16)
> +	LDS(W,  s32)
> +#undef LDS

...

> @@ -17503,7 +17580,10 @@ static int convert_ctx_accesses(struct bpf_verifier_env *env)
>  		if (insn->code == (BPF_LDX | BPF_MEM | BPF_B) ||
>  		    insn->code == (BPF_LDX | BPF_MEM | BPF_H) ||
>  		    insn->code == (BPF_LDX | BPF_MEM | BPF_W) ||
> -		    insn->code == (BPF_LDX | BPF_MEM | BPF_DW)) {
> +		    insn->code == (BPF_LDX | BPF_MEM | BPF_DW) ||
> +		    insn->code == (BPF_LDX | BPF_MEMSX | BPF_B) ||
> +		    insn->code == (BPF_LDX | BPF_MEMSX | BPF_H) ||
> +		    insn->code == (BPF_LDX | BPF_MEMSX | BPF_W)) {
>  			type = BPF_READ;
>  		} else if (insn->code == (BPF_STX | BPF_MEM | BPF_B) ||
>  			   insn->code == (BPF_STX | BPF_MEM | BPF_H) ||
> @@ -17562,6 +17642,11 @@ static int convert_ctx_accesses(struct bpf_verifier_env *env)
>  		 */
>  		case PTR_TO_BTF_ID | MEM_ALLOC | PTR_UNTRUSTED:
>  			if (type == BPF_READ) {
> +				/* it is hard to differentiate whether the
> +				 * BPF_PROBE_MEM is for BPF_MEM or BPF_MEMSX,
> +				 * so let us use insn->imm to remember it.
> +				 */
> +				insn->imm = BPF_MODE(insn->code);

That's a fragile approach.
And the evidence is in this patch.
This part of the interpreter:
        LDX_PROBE_MEM_##SIZEOP:                                         \
                bpf_probe_read_kernel(&DST, sizeof(SIZE),               \
                                      (const void *)(long) (SRC + insn->off));  \
                DST = *((SIZE *)&DST);                                  \

wasn't updated to handle sign extension.

How about
#define BPF_PROBE_MEMSX 0x40 /* same as BPF_IND */

and handle it in JITs and interpreter.
We need a selftest for BTF style access to signed fields to make sure both
interpreter and JIT handling of BPF_PROBE_MEMSX is tested.
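
For illustration, a hypothetical interpreter handler for such a probed
sign-extending load could mirror the existing LDX_PROBE_MEM path and let
the signed cast do the extension (a sketch of the suggestion above,
assuming the same byte placement as the existing macro; not the code
that eventually landed):

	#define LDSX_PROBE(SIZEOP, SIZE)				\
		LDX_PROBE_MEMSX_##SIZEOP:				\
			bpf_probe_read_kernel(&DST, sizeof(SIZE),	\
				      (const void *)(long) (SRC + insn->off)); \
			DST = *((SIZE *)&DST);	/* signed cast sign-extends */ \
			CONT;

	LDSX_PROBE(B,  s8)
	LDSX_PROBE(H, s16)
	LDSX_PROBE(W, s32)
	#undef LDSX_PROBE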
Yonghong Song July 14, 2023, 11:22 p.m. UTC | #2
On 7/14/23 11:13 AM, Alexei Starovoitov wrote:
> On Wed, Jul 12, 2023 at 11:07:24PM -0700, Yonghong Song wrote:
>>   
>> @@ -1942,6 +1945,16 @@ static u64 ___bpf_prog_run(u64 *regs, const struct bpf_insn *insn)
>>   	LDST(DW, u64)
>>   #undef LDST
>>   
>> +#define LDS(SIZEOP, SIZE)						\
> 
> LDSX ?

Ack.

> 
>> +	LDX_MEMSX_##SIZEOP:						\
>> +		DST = *(SIZE *)(unsigned long) (SRC + insn->off);	\
>> +		CONT;
>> +
>> +	LDS(B,   s8)
>> +	LDS(H,  s16)
>> +	LDS(W,  s32)
>> +#undef LDS
> 
> ...
> 
>> @@ -17503,7 +17580,10 @@ static int convert_ctx_accesses(struct bpf_verifier_env *env)
>>   		if (insn->code == (BPF_LDX | BPF_MEM | BPF_B) ||
>>   		    insn->code == (BPF_LDX | BPF_MEM | BPF_H) ||
>>   		    insn->code == (BPF_LDX | BPF_MEM | BPF_W) ||
>> -		    insn->code == (BPF_LDX | BPF_MEM | BPF_DW)) {
>> +		    insn->code == (BPF_LDX | BPF_MEM | BPF_DW) ||
>> +		    insn->code == (BPF_LDX | BPF_MEMSX | BPF_B) ||
>> +		    insn->code == (BPF_LDX | BPF_MEMSX | BPF_H) ||
>> +		    insn->code == (BPF_LDX | BPF_MEMSX | BPF_W)) {
>>   			type = BPF_READ;
>>   		} else if (insn->code == (BPF_STX | BPF_MEM | BPF_B) ||
>>   			   insn->code == (BPF_STX | BPF_MEM | BPF_H) ||
>> @@ -17562,6 +17642,11 @@ static int convert_ctx_accesses(struct bpf_verifier_env *env)
>>   		 */
>>   		case PTR_TO_BTF_ID | MEM_ALLOC | PTR_UNTRUSTED:
>>   			if (type == BPF_READ) {
>> +				/* it is hard to differentiate whether the
>> +				 * BPF_PROBE_MEM is for BPF_MEM or BPF_MEMSX,
>> +				 * so let us use insn->imm to remember it.
>> +				 */
>> +				insn->imm = BPF_MODE(insn->code);
> 
> That's a fragile approach.
> And the evidence is in this patch.
> This part of the interpreter:
>          LDX_PROBE_MEM_##SIZEOP:                                         \
>                  bpf_probe_read_kernel(&DST, sizeof(SIZE),               \
>                                        (const void *)(long) (SRC + insn->off));  \
>                  DST = *((SIZE *)&DST);                                  \
> 
> wasn't updated to handle sign extension.

Thanks for catching this!
> 
> How about
> #define BPF_PROBE_MEMSX 0x40 /* same as BPF_IND */
> 
> and handle it in JITs and interpreter.

Good idea. Will do.

> We need a selftest for BTF style access to signed fields to make sure both
> interpreter and JIT handling of BPF_PROBE_MEMSX is tested.

Will do.
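
For reference, such a selftest might look roughly like the BPF program
below (a hypothetical sketch only; the section, struct and field choices
are illustrative, not the test that was eventually added). Built with
-mcpu=v4, the narrow load of the signed field through the BTF pointer
can be emitted as BPF_MEMSX, which the verifier then rewrites into the
probed form exercised by both the interpreter and the JIT:

	#include "vmlinux.h"
	#include <bpf/bpf_helpers.h>
	#include <bpf/bpf_tracing.h>

	int seen_negative = 0;

	SEC("tp_btf/task_newtask")
	int BPF_PROG(check_signed_field, struct task_struct *task, u64 clone_flags)
	{
		/* exit_code is a signed 32-bit field; the sign bit must
		 * survive the load into the 64-bit BPF register */
		long v = task->exit_code;

		if (v < 0)
			seen_negative = 1;
		return 0;
	}

	char _license[] SEC("license") = "GPL";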
Eduard Zingerman July 17, 2023, 1:39 a.m. UTC | #3
On Wed, 2023-07-12 at 23:07 -0700, Yonghong Song wrote:
> > Add interpreter/jit support for the new sign-extension load insns,
> > which add a new mode (BPF_MEMSX).
> > Also add verifier support to recognize these insns and to
> > do proper verification with them. In the verifier, besides
> > deducing proper bounds for the dst_reg, probed memory access
> > is handled by remembering the insn mode in the insn->imm field so
> > that proper jit insns can be emitted later on.
> > 
> > Signed-off-by: Yonghong Song <yhs@fb.com>
> > ---
> >  arch/x86/net/bpf_jit_comp.c    |  32 ++++++++-
> >  include/uapi/linux/bpf.h       |   1 +
> >  kernel/bpf/core.c              |  13 ++++
> >  kernel/bpf/verifier.c          | 125 +++++++++++++++++++++++++++------
> >  tools/include/uapi/linux/bpf.h |   1 +
> >  5 files changed, 151 insertions(+), 21 deletions(-)
> > 
> > diff --git a/arch/x86/net/bpf_jit_comp.c b/arch/x86/net/bpf_jit_comp.c
> > index 438adb695daa..addeea95f397 100644
> > --- a/arch/x86/net/bpf_jit_comp.c
> > +++ b/arch/x86/net/bpf_jit_comp.c
> > @@ -779,6 +779,29 @@ static void emit_ldx(u8 **pprog, u32 size, u32 dst_reg, u32 src_reg, int off)
> >  	*pprog = prog;
> >  }
> >  
> > +/* LDX: dst_reg = *(s8*)(src_reg + off) */
> > +static void emit_ldsx(u8 **pprog, u32 size, u32 dst_reg, u32 src_reg, int off)
> > +{
> > +	u8 *prog = *pprog;
> > +
> > +	switch (size) {
> > +	case BPF_B:
> > +		/* Emit 'movsx rax, byte ptr [rax + off]' */
> > +		EMIT3(add_2mod(0x48, src_reg, dst_reg), 0x0F, 0xBE);
> > +		break;
> > +	case BPF_H:
> > +		/* Emit 'movsx rax, word ptr [rax + off]' */
> > +		EMIT3(add_2mod(0x48, src_reg, dst_reg), 0x0F, 0xBF);
> > +		break;
> > +	case BPF_W:
> > +		/* Emit 'movsx rax, dword ptr [rax + off]' */
> > +		EMIT2(add_2mod(0x48, src_reg, dst_reg), 0x63);
> > +		break;
> > +	}
> > +	emit_insn_suffix(&prog, src_reg, dst_reg, off);
> > +	*pprog = prog;
> > +}
> > +
> >  /* STX: *(u8*)(dst_reg + off) = src_reg */
> >  static void emit_stx(u8 **pprog, u32 size, u32 dst_reg, u32 src_reg, int off)
> >  {
> > @@ -1370,6 +1393,9 @@ st:			if (is_imm8(insn->off))
> >  		case BPF_LDX | BPF_PROBE_MEM | BPF_W:
> >  		case BPF_LDX | BPF_MEM | BPF_DW:
> >  		case BPF_LDX | BPF_PROBE_MEM | BPF_DW:
> > +		case BPF_LDX | BPF_MEMSX | BPF_B:
> > +		case BPF_LDX | BPF_MEMSX | BPF_H:
> > +		case BPF_LDX | BPF_MEMSX | BPF_W:
> >  			insn_off = insn->off;
> >  
> >  			if (BPF_MODE(insn->code) == BPF_PROBE_MEM) {
> > @@ -1415,7 +1441,11 @@ st:			if (is_imm8(insn->off))
> >  				start_of_ldx = prog;
> >  				end_of_jmp[-1] = start_of_ldx - end_of_jmp;
> >  			}
> > -			emit_ldx(&prog, BPF_SIZE(insn->code), dst_reg, src_reg, insn_off);
> > +			if ((BPF_MODE(insn->code) == BPF_PROBE_MEM && insn->imm == BPF_MEMSX) ||
> > +			    BPF_MODE(insn->code) == BPF_MEMSX)
> > +				emit_ldsx(&prog, BPF_SIZE(insn->code), dst_reg, src_reg, insn_off);
> > +			else
> > +				emit_ldx(&prog, BPF_SIZE(insn->code), dst_reg, src_reg, insn_off);
> >  			if (BPF_MODE(insn->code) == BPF_PROBE_MEM) {
> >  				struct exception_table_entry *ex;
> >  				u8 *_insn = image + proglen + (start_of_ldx - temp);
> > diff --git a/include/uapi/linux/bpf.h b/include/uapi/linux/bpf.h
> > index 600d0caebbd8..c7196302d1eb 100644
> > --- a/include/uapi/linux/bpf.h
> > +++ b/include/uapi/linux/bpf.h
> > @@ -19,6 +19,7 @@
> >  
> >  /* ld/ldx fields */
> >  #define BPF_DW		0x18	/* double word (64-bit) */
> > +#define BPF_MEMSX	0x80	/* load with sign extension */
> >  #define BPF_ATOMIC	0xc0	/* atomic memory ops - op type in immediate */
> >  #define BPF_XADD	0xc0	/* exclusive add - legacy name */
> >  
> > diff --git a/kernel/bpf/core.c b/kernel/bpf/core.c
> > index dc85240a0134..8a1cc658789e 100644
> > --- a/kernel/bpf/core.c
> > +++ b/kernel/bpf/core.c
> > @@ -1610,6 +1610,9 @@ EXPORT_SYMBOL_GPL(__bpf_call_base);
> >  	INSN_3(LDX, MEM, H),			\
> >  	INSN_3(LDX, MEM, W),			\
> >  	INSN_3(LDX, MEM, DW),			\
> > +	INSN_3(LDX, MEMSX, B),			\
> > +	INSN_3(LDX, MEMSX, H),			\
> > +	INSN_3(LDX, MEMSX, W),			\
> >  	/*   Immediate based. */		\
> >  	INSN_3(LD, IMM, DW)
> >  
> > @@ -1942,6 +1945,16 @@ static u64 ___bpf_prog_run(u64 *regs, const struct bpf_insn *insn)
> >  	LDST(DW, u64)
> >  #undef LDST
> >  
> > +#define LDS(SIZEOP, SIZE)						\
> > +	LDX_MEMSX_##SIZEOP:						\
> > +		DST = *(SIZE *)(unsigned long) (SRC + insn->off);	\
> > +		CONT;
> > +
> > +	LDS(B,   s8)
> > +	LDS(H,  s16)
> > +	LDS(W,  s32)
> > +#undef LDS
> > +
> >  #define ATOMIC_ALU_OP(BOP, KOP)						\
> >  		case BOP:						\
> >  			if (BPF_SIZE(insn->code) == BPF_W)		\
> > diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c
> > index 81a93eeac7a0..fbe4ca72d4c1 100644
> > --- a/kernel/bpf/verifier.c
> > +++ b/kernel/bpf/verifier.c
> > @@ -5795,6 +5795,77 @@ static void coerce_reg_to_size(struct bpf_reg_state *reg, int size)
> >  	__reg_combine_64_into_32(reg);
> >  }
> >  
> > +static void set_sext64_default_val(struct bpf_reg_state *reg, int size)
> > +{
> > +	if (size == 1) {
> > +		reg->smin_value = reg->s32_min_value = S8_MIN;
> > +		reg->smax_value = reg->s32_max_value = S8_MAX;
> > +	} else if (size == 2) {
> > +		reg->smin_value = reg->s32_min_value = S16_MIN;
> > +		reg->smax_value = reg->s32_max_value = S16_MAX;
> > +	} else {
> > +		/* size == 4 */
> > +		reg->smin_value = reg->s32_min_value = S32_MIN;
> > +		reg->smax_value = reg->s32_max_value = S32_MAX;
> > +	}
> > +	reg->umin_value = reg->u32_min_value = 0;
> > +	reg->umax_value = U64_MAX;
> > +	reg->u32_max_value = U32_MAX;
> > +	reg->var_off = tnum_unknown;
> > +}
> > +
> > +static void coerce_reg_to_size_sx(struct bpf_reg_state *reg, int size)
> > +{
> > +	u64 top_smax_value, top_smin_value;
> > +	s64 init_s64_max, init_s64_min, s64_max, s64_min;
> > +	u64 num_bits = size * 8;
> > +
> > +	top_smax_value = ((u64)reg->smax_value >> num_bits) << num_bits;
> > +	top_smin_value = ((u64)reg->smin_value >> num_bits) << num_bits;
> > +
> > +	if (top_smax_value != top_smin_value)
> > +		goto out;
> > +
> > +	/* find the s64_min and s64_max after sign extension */
> > +	if (size == 1) {
> > +		init_s64_max = (s8)reg->smax_value;
> > +		init_s64_min = (s8)reg->smin_value;
> > +	} else if (size == 2) {
> > +		init_s64_max = (s16)reg->smax_value;
> > +		init_s64_min = (s16)reg->smin_value;
> > +	} else {
> > +		/* size == 4 */
> > +		init_s64_max = (s32)reg->smax_value;
> > +		init_s64_min = (s32)reg->smin_value;
> > +	}
> > +
> > +	s64_max = max(init_s64_max, init_s64_min);
> > +	s64_min = min(init_s64_max, init_s64_min);
> > +
> > +	if (s64_max >= 0 && s64_min >= 0) {
> > +		reg->smin_value = reg->s32_min_value = s64_min;
> > +		reg->smax_value = reg->s32_max_value = s64_max;
> > +		reg->umin_value = reg->u32_min_value = s64_min;
> > +		reg->umax_value = reg->u32_max_value = s64_max;
> > +		reg->var_off = tnum_range(s64_min, s64_max);
> > +		return;
> > +	}
> > +
> > +	if (s64_min < 0 && s64_max < 0) {
> > +		reg->smin_value = reg->s32_min_value = s64_min;
> > +		reg->smax_value = reg->s32_max_value = s64_max;
> > +		reg->umin_value = (u64)s64_max;
> > +		reg->umax_value = (u64)s64_min;
> > +		reg->u32_min_value = (u32)s64_max;
> > +		reg->u32_max_value = (u32)s64_min;
> > +		reg->var_off = tnum_range((u64)s64_max, (u64)s64_min);
> > +		return;
> > +	}
> > +
> > +out:
> > +	set_sext64_default_val(reg, size);
> > +}
> > +
> >  static bool bpf_map_is_rdonly(const struct bpf_map *map)
> >  {
> >  	/* A map is considered read-only if the following condition are true:
> > @@ -5815,7 +5886,8 @@ static bool bpf_map_is_rdonly(const struct bpf_map *map)
> >  	       !bpf_map_write_active(map);
> >  }
> >  
> > -static int bpf_map_direct_read(struct bpf_map *map, int off, int size, u64 *val)
> > +static int bpf_map_direct_read(struct bpf_map *map, int off, int size, u64 *val,
> > +			       bool is_ldsx)
> >  {
> >  	void *ptr;
> >  	u64 addr;
> > @@ -5828,13 +5900,13 @@ static int bpf_map_direct_read(struct bpf_map *map, int off, int size, u64 *val)
> >  
> >  	switch (size) {
> >  	case sizeof(u8):
> > -		*val = (u64)*(u8 *)ptr;
> > +		*val = is_ldsx ? (s64)*(s8 *)ptr : (u64)*(u8 *)ptr;
> >  		break;
> >  	case sizeof(u16):
> > -		*val = (u64)*(u16 *)ptr;
> > +		*val = is_ldsx ? (s64)*(s16 *)ptr : (u64)*(u16 *)ptr;
> >  		break;
> >  	case sizeof(u32):
> > -		*val = (u64)*(u32 *)ptr;
> > +		*val = is_ldsx ? (s64)*(s32 *)ptr : (u64)*(u32 *)ptr;
> >  		break;
> >  	case sizeof(u64):
> >  		*val = *(u64 *)ptr;
> > @@ -6248,7 +6320,7 @@ static int check_stack_access_within_bounds(
> >   */
> >  static int check_mem_access(struct bpf_verifier_env *env, int insn_idx, u32 regno,
> >  			    int off, int bpf_size, enum bpf_access_type t,
> > -			    int value_regno, bool strict_alignment_once)
> > +			    int value_regno, bool strict_alignment_once, bool is_ldsx)
> >  {
> >  	struct bpf_reg_state *regs = cur_regs(env);
> >  	struct bpf_reg_state *reg = regs + regno;
> > @@ -6309,7 +6381,7 @@ static int check_mem_access(struct bpf_verifier_env *env, int insn_idx, u32 regn
> >  				u64 val = 0;
> >  
> >  				err = bpf_map_direct_read(map, map_off, size,
> > -							  &val);
> > +							  &val, is_ldsx);
> >  				if (err)
> >  					return err;
> >  
> > @@ -6479,8 +6551,11 @@ static int check_mem_access(struct bpf_verifier_env *env, int insn_idx, u32 regn
> >  
> >  	if (!err && size < BPF_REG_SIZE && value_regno >= 0 && t == BPF_READ &&
> >  	    regs[value_regno].type == SCALAR_VALUE) {
> > -		/* b/h/w load zero-extends, mark upper bits as known 0 */
> > -		coerce_reg_to_size(&regs[value_regno], size);
> > +		if (!is_ldsx)
> > +			/* b/h/w load zero-extends, mark upper bits as known 0 */
> > +			coerce_reg_to_size(&regs[value_regno], size);
> > +		else
> > +			coerce_reg_to_size_sx(&regs[value_regno], size);
> >  	}
> >  	return err;
> >  }
> > @@ -6572,17 +6647,17 @@ static int check_atomic(struct bpf_verifier_env *env, int insn_idx, struct bpf_i
> >  	 * case to simulate the register fill.
> >  	 */
> >  	err = check_mem_access(env, insn_idx, insn->dst_reg, insn->off,
> > -			       BPF_SIZE(insn->code), BPF_READ, -1, true);
> > +			       BPF_SIZE(insn->code), BPF_READ, -1, true, false);
> >  	if (!err && load_reg >= 0)
> >  		err = check_mem_access(env, insn_idx, insn->dst_reg, insn->off,
> >  				       BPF_SIZE(insn->code), BPF_READ, load_reg,
> > -				       true);
> > +				       true, false);
> >  	if (err)
> >  		return err;
> >  
> >  	/* Check whether we can write into the same memory. */
> >  	err = check_mem_access(env, insn_idx, insn->dst_reg, insn->off,
> > -			       BPF_SIZE(insn->code), BPF_WRITE, -1, true);
> > +			       BPF_SIZE(insn->code), BPF_WRITE, -1, true, false);
> >  	if (err)
> >  		return err;
> >  
> > @@ -6828,7 +6903,7 @@ static int check_helper_mem_access(struct bpf_verifier_env *env, int regno,
> >  				return zero_size_allowed ? 0 : -EACCES;
> >  
> >  			return check_mem_access(env, env->insn_idx, regno, offset, BPF_B,
> > -						atype, -1, false);
> > +						atype, -1, false, false);
> >  		}
> >  
> >  		fallthrough;
> > @@ -7200,7 +7275,7 @@ static int process_dynptr_func(struct bpf_verifier_env *env, int regno, int insn
> >  		/* we write BPF_DW bits (8 bytes) at a time */
> >  		for (i = 0; i < BPF_DYNPTR_SIZE; i += 8) {
> >  			err = check_mem_access(env, insn_idx, regno,
> > -					       i, BPF_DW, BPF_WRITE, -1, false);
> > +					       i, BPF_DW, BPF_WRITE, -1, false, false);
> >  			if (err)
> >  				return err;
> >  		}
> > @@ -7293,7 +7368,7 @@ static int process_iter_arg(struct bpf_verifier_env *env, int regno, int insn_id
> >  
> >  		for (i = 0; i < nr_slots * 8; i += BPF_REG_SIZE) {
> >  			err = check_mem_access(env, insn_idx, regno,
> > -					       i, BPF_DW, BPF_WRITE, -1, false);
> > +					       i, BPF_DW, BPF_WRITE, -1, false, false);
> >  			if (err)
> >  				return err;
> >  		}
> > @@ -9437,7 +9512,7 @@ static int check_helper_call(struct bpf_verifier_env *env, struct bpf_insn *insn
> >  	 */
> >  	for (i = 0; i < meta.access_size; i++) {
> >  		err = check_mem_access(env, insn_idx, meta.regno, i, BPF_B,
> > -				       BPF_WRITE, -1, false);
> > +				       BPF_WRITE, -1, false, false);
> >  		if (err)
> >  			return err;
> >  	}
> > @@ -16315,7 +16390,8 @@ static int do_check(struct bpf_verifier_env *env)
> >  			 */
> >  			err = check_mem_access(env, env->insn_idx, insn->src_reg,
> >  					       insn->off, BPF_SIZE(insn->code),
> > -					       BPF_READ, insn->dst_reg, false);
> > +					       BPF_READ, insn->dst_reg, false,
> > +					       BPF_MODE(insn->code) == BPF_MEMSX);
> >  			if (err)
> >  				return err;
> >  
> > @@ -16352,7 +16428,7 @@ static int do_check(struct bpf_verifier_env *env)
> >  			/* check that memory (dst_reg + off) is writeable */
> >  			err = check_mem_access(env, env->insn_idx, insn->dst_reg,
> >  					       insn->off, BPF_SIZE(insn->code),
> > -					       BPF_WRITE, insn->src_reg, false);
> > +					       BPF_WRITE, insn->src_reg, false, false);
> >  			if (err)
> >  				return err;
> >  
> > @@ -16377,7 +16453,7 @@ static int do_check(struct bpf_verifier_env *env)
> >  			/* check that memory (dst_reg + off) is writeable */
> >  			err = check_mem_access(env, env->insn_idx, insn->dst_reg,
> >  					       insn->off, BPF_SIZE(insn->code),
> > -					       BPF_WRITE, -1, false);
> > +					       BPF_WRITE, -1, false, false);
> >  			if (err)
> >  				return err;
> >  
> > @@ -16805,7 +16881,8 @@ static int resolve_pseudo_ldimm64(struct bpf_verifier_env *env)
> >  
> >  	for (i = 0; i < insn_cnt; i++, insn++) {
> >  		if (BPF_CLASS(insn->code) == BPF_LDX &&
> > -		    (BPF_MODE(insn->code) != BPF_MEM || insn->imm != 0)) {
> > +		    ((BPF_MODE(insn->code) != BPF_MEM && BPF_MODE(insn->code) != BPF_MEMSX) ||
> > +		    insn->imm != 0)) {
> >  			verbose(env, "BPF_LDX uses reserved fields\n");
> >  			return -EINVAL;
> >  		}
> > @@ -17503,7 +17580,10 @@ static int convert_ctx_accesses(struct bpf_verifier_env *env)
> >  		if (insn->code == (BPF_LDX | BPF_MEM | BPF_B) ||
> >  		    insn->code == (BPF_LDX | BPF_MEM | BPF_H) ||
> >  		    insn->code == (BPF_LDX | BPF_MEM | BPF_W) ||
> > -		    insn->code == (BPF_LDX | BPF_MEM | BPF_DW)) {
> > +		    insn->code == (BPF_LDX | BPF_MEM | BPF_DW) ||
> > +		    insn->code == (BPF_LDX | BPF_MEMSX | BPF_B) ||
> > +		    insn->code == (BPF_LDX | BPF_MEMSX | BPF_H) ||
> > +		    insn->code == (BPF_LDX | BPF_MEMSX | BPF_W)) {

Later in this function there is code that deals with the
`is_narrower_load` condition (line 17785 in my case).
This code handles the case when e.g. 1 byte is read from a 4-byte field.
It does so by first converting such a load to a 4-byte load and
then adding BPF_RSH and BPF_AND instructions.
It appears to me that this code should handle sign extension as well.
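
For what it's worth, a sign-extending variant of that rewrite could
shift the narrow value up and arithmetic-shift it back down instead of
using BPF_RSH + BPF_AND, e.g. for a 1-byte read at offset 0 (a sketch of
the idea only, not code from this series):

	/* after converting to the full-width load into dst_reg: */
	BPF_ALU64_IMM(BPF_LSH,  insn->dst_reg, 56), /* byte's sign bit -> bit 63 */
	BPF_ALU64_IMM(BPF_ARSH, insn->dst_reg, 56), /* arithmetic shift sign-extends */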

> >  			type = BPF_READ;
> >  		} else if (insn->code == (BPF_STX | BPF_MEM | BPF_B) ||
> >  			   insn->code == (BPF_STX | BPF_MEM | BPF_H) ||
> > @@ -17562,6 +17642,11 @@ static int convert_ctx_accesses(struct bpf_verifier_env *env)
> >  		 */
> >  		case PTR_TO_BTF_ID | MEM_ALLOC | PTR_UNTRUSTED:
> >  			if (type == BPF_READ) {
> > +				/* it is hard to differentiate whether the
> > +				 * BPF_PROBE_MEM is for BPF_MEM or BPF_MEMSX,
> > +				 * so let us use insn->imm to remember it.
> > +				 */
> > +				insn->imm = BPF_MODE(insn->code);
> >  				insn->code = BPF_LDX | BPF_PROBE_MEM |
> >  					BPF_SIZE((insn)->code);
> >  				env->prog->aux->num_exentries++;
> > diff --git a/tools/include/uapi/linux/bpf.h b/tools/include/uapi/linux/bpf.h
> > index 600d0caebbd8..c7196302d1eb 100644
> > --- a/tools/include/uapi/linux/bpf.h
> > +++ b/tools/include/uapi/linux/bpf.h
> > @@ -19,6 +19,7 @@
> >  
> >  /* ld/ldx fields */
> >  #define BPF_DW		0x18	/* double word (64-bit) */
> > +#define BPF_MEMSX	0x80	/* load with sign extension */
> >  #define BPF_ATOMIC	0xc0	/* atomic memory ops - op type in immediate */
> >  #define BPF_XADD	0xc0	/* exclusive add - legacy name */
> >
Eduard Zingerman July 19, 2023, 12:15 a.m. UTC | #4
On Wed, 2023-07-12 at 23:07 -0700, Yonghong Song wrote:
> Add interpreter/jit support for the new sign-extension load insns,
> which add a new mode (BPF_MEMSX).
> Also add verifier support to recognize these insns and to
> do proper verification with them. In the verifier, besides
> deducing proper bounds for the dst_reg, probed memory access
> is handled by remembering the insn mode in the insn->imm field so
> that proper jit insns can be emitted later on.
> 
> Signed-off-by: Yonghong Song <yhs@fb.com>
> ---
>  arch/x86/net/bpf_jit_comp.c    |  32 ++++++++-
>  include/uapi/linux/bpf.h       |   1 +
>  kernel/bpf/core.c              |  13 ++++
>  kernel/bpf/verifier.c          | 125 +++++++++++++++++++++++++++------
>  tools/include/uapi/linux/bpf.h |   1 +
>  5 files changed, 151 insertions(+), 21 deletions(-)
> 
> diff --git a/arch/x86/net/bpf_jit_comp.c b/arch/x86/net/bpf_jit_comp.c
> index 438adb695daa..addeea95f397 100644
> --- a/arch/x86/net/bpf_jit_comp.c
> +++ b/arch/x86/net/bpf_jit_comp.c
> @@ -779,6 +779,29 @@ static void emit_ldx(u8 **pprog, u32 size, u32 dst_reg, u32 src_reg, int off)
>  	*pprog = prog;
>  }
>  
> +/* LDX: dst_reg = *(s8*)(src_reg + off) */
> +static void emit_ldsx(u8 **pprog, u32 size, u32 dst_reg, u32 src_reg, int off)
> +{
> +	u8 *prog = *pprog;
> +
> +	switch (size) {
> +	case BPF_B:
> +		/* Emit 'movsx rax, byte ptr [rax + off]' */
> +		EMIT3(add_2mod(0x48, src_reg, dst_reg), 0x0F, 0xBE);
> +		break;
> +	case BPF_H:
> +		/* Emit 'movsx rax, word ptr [rax + off]' */
> +		EMIT3(add_2mod(0x48, src_reg, dst_reg), 0x0F, 0xBF);
> +		break;
> +	case BPF_W:
> +		/* Emit 'movsx rax, dword ptr [rax + off]' */
> +		EMIT2(add_2mod(0x48, src_reg, dst_reg), 0x63);
> +		break;
> +	}
> +	emit_insn_suffix(&prog, src_reg, dst_reg, off);
> +	*pprog = prog;
> +}
> +
>  /* STX: *(u8*)(dst_reg + off) = src_reg */
>  static void emit_stx(u8 **pprog, u32 size, u32 dst_reg, u32 src_reg, int off)
>  {
> @@ -1370,6 +1393,9 @@ st:			if (is_imm8(insn->off))
>  		case BPF_LDX | BPF_PROBE_MEM | BPF_W:
>  		case BPF_LDX | BPF_MEM | BPF_DW:
>  		case BPF_LDX | BPF_PROBE_MEM | BPF_DW:
> +		case BPF_LDX | BPF_MEMSX | BPF_B:
> +		case BPF_LDX | BPF_MEMSX | BPF_H:
> +		case BPF_LDX | BPF_MEMSX | BPF_W:
>  			insn_off = insn->off;
>  
>  			if (BPF_MODE(insn->code) == BPF_PROBE_MEM) {
> @@ -1415,7 +1441,11 @@ st:			if (is_imm8(insn->off))
>  				start_of_ldx = prog;
>  				end_of_jmp[-1] = start_of_ldx - end_of_jmp;
>  			}
> -			emit_ldx(&prog, BPF_SIZE(insn->code), dst_reg, src_reg, insn_off);
> +			if ((BPF_MODE(insn->code) == BPF_PROBE_MEM && insn->imm == BPF_MEMSX) ||
> +			    BPF_MODE(insn->code) == BPF_MEMSX)
> +				emit_ldsx(&prog, BPF_SIZE(insn->code), dst_reg, src_reg, insn_off);
> +			else
> +				emit_ldx(&prog, BPF_SIZE(insn->code), dst_reg, src_reg, insn_off);
>  			if (BPF_MODE(insn->code) == BPF_PROBE_MEM) {
>  				struct exception_table_entry *ex;
>  				u8 *_insn = image + proglen + (start_of_ldx - temp);
> diff --git a/include/uapi/linux/bpf.h b/include/uapi/linux/bpf.h
> index 600d0caebbd8..c7196302d1eb 100644
> --- a/include/uapi/linux/bpf.h
> +++ b/include/uapi/linux/bpf.h
> @@ -19,6 +19,7 @@
>  
>  /* ld/ldx fields */
>  #define BPF_DW		0x18	/* double word (64-bit) */
> +#define BPF_MEMSX	0x80	/* load with sign extension */
>  #define BPF_ATOMIC	0xc0	/* atomic memory ops - op type in immediate */
>  #define BPF_XADD	0xc0	/* exclusive add - legacy name */
>  
> diff --git a/kernel/bpf/core.c b/kernel/bpf/core.c
> index dc85240a0134..8a1cc658789e 100644
> --- a/kernel/bpf/core.c
> +++ b/kernel/bpf/core.c
> @@ -1610,6 +1610,9 @@ EXPORT_SYMBOL_GPL(__bpf_call_base);
>  	INSN_3(LDX, MEM, H),			\
>  	INSN_3(LDX, MEM, W),			\
>  	INSN_3(LDX, MEM, DW),			\
> +	INSN_3(LDX, MEMSX, B),			\
> +	INSN_3(LDX, MEMSX, H),			\
> +	INSN_3(LDX, MEMSX, W),			\
>  	/*   Immediate based. */		\
>  	INSN_3(LD, IMM, DW)
>  
> @@ -1942,6 +1945,16 @@ static u64 ___bpf_prog_run(u64 *regs, const struct bpf_insn *insn)
>  	LDST(DW, u64)
>  #undef LDST
>  
> +#define LDS(SIZEOP, SIZE)						\
> +	LDX_MEMSX_##SIZEOP:						\
> +		DST = *(SIZE *)(unsigned long) (SRC + insn->off);	\
> +		CONT;
> +
> +	LDS(B,   s8)
> +	LDS(H,  s16)
> +	LDS(W,  s32)
> +#undef LDS
> +
>  #define ATOMIC_ALU_OP(BOP, KOP)						\
>  		case BOP:						\
>  			if (BPF_SIZE(insn->code) == BPF_W)		\
> diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c
> index 81a93eeac7a0..fbe4ca72d4c1 100644
> --- a/kernel/bpf/verifier.c
> +++ b/kernel/bpf/verifier.c
> @@ -5795,6 +5795,77 @@ static void coerce_reg_to_size(struct bpf_reg_state *reg, int size)
>  	__reg_combine_64_into_32(reg);
>  }
>  
> +static void set_sext64_default_val(struct bpf_reg_state *reg, int size)
> +{
> +	if (size == 1) {
> +		reg->smin_value = reg->s32_min_value = S8_MIN;
> +		reg->smax_value = reg->s32_max_value = S8_MAX;
> +	} else if (size == 2) {
> +		reg->smin_value = reg->s32_min_value = S16_MIN;
> +		reg->smax_value = reg->s32_max_value = S16_MAX;
> +	} else {
> +		/* size == 4 */
> +		reg->smin_value = reg->s32_min_value = S32_MIN;
> +		reg->smax_value = reg->s32_max_value = S32_MAX;
> +	}
> +	reg->umin_value = reg->u32_min_value = 0;
> +	reg->umax_value = U64_MAX;
> +	reg->u32_max_value = U32_MAX;
> +	reg->var_off = tnum_unknown;
> +}
> +
> +static void coerce_reg_to_size_sx(struct bpf_reg_state *reg, int size)
> +{
> +	u64 top_smax_value, top_smin_value;
> +	s64 init_s64_max, init_s64_min, s64_max, s64_min;
> +	u64 num_bits = size * 8;
> +
> +	top_smax_value = ((u64)reg->smax_value >> num_bits) << num_bits;
> +	top_smin_value = ((u64)reg->smin_value >> num_bits) << num_bits;
> +
> +	if (top_smax_value != top_smin_value)
> +		goto out;
> +
> +	/* find the s64_min and s64_max after sign extension */
> +	if (size == 1) {
> +		init_s64_max = (s8)reg->smax_value;
> +		init_s64_min = (s8)reg->smin_value;
> +	} else if (size == 2) {
> +		init_s64_max = (s16)reg->smax_value;
> +		init_s64_min = (s16)reg->smin_value;
> +	} else {
> +		/* size == 4 */
> +		init_s64_max = (s32)reg->smax_value;
> +		init_s64_min = (s32)reg->smin_value;
> +	}
> +
> +	s64_max = max(init_s64_max, init_s64_min);
> +	s64_min = min(init_s64_max, init_s64_min);
> +
> +	if (s64_max >= 0 && s64_min >= 0) {
> +		reg->smin_value = reg->s32_min_value = s64_min;
> +		reg->smax_value = reg->s32_max_value = s64_max;
> +		reg->umin_value = reg->u32_min_value = s64_min;
> +		reg->umax_value = reg->u32_max_value = s64_max;
> +		reg->var_off = tnum_range(s64_min, s64_max);
> +		return;
> +	}
> +
> +	if (s64_min < 0 && s64_max < 0) {
> +		reg->smin_value = reg->s32_min_value = s64_min;
> +		reg->smax_value = reg->s32_max_value = s64_max;
> +		reg->umin_value = (u64)s64_max;
> +		reg->umax_value = (u64)s64_min;

I think the last two assignments are not correct for the following example:

{
	"testtesttest",
	.insns = {
		BPF_EMIT_CALL(BPF_FUNC_get_prandom_u32),
		BPF_JMP_IMM(BPF_JLT, BPF_REG_0, 0xff80, 2),
		BPF_JMP_IMM(BPF_JGT, BPF_REG_0, 0xffff, 1),
		{
			.code  = BPF_ALU64 | BPF_MOV | BPF_X,
			.dst_reg = BPF_REG_0,
			.src_reg = BPF_REG_0,
			.off   = 8,
			.imm   = 0,
		},
		BPF_EXIT_INSN(),
	},
	.result = ACCEPT,
	.retval = 0,
},

Here is execution log:

0: R1=ctx(off=0,imm=0) R10=fp0
0: (85) call bpf_get_prandom_u32#7 ; R0_w=Pscalar()
1: (a5) if r0 < 0xff80 goto pc+2   ; R0_w=Pscalar(umin=65408)
2: (25) if r0 > 0xffff goto pc+1   ; R0_w=Pscalar(umin=65408,umax=65535,var_off=(0xff80; 0x7f))
3: (bf) r0 = r0                    ; R0_w=Pscalar
                                      (smin=-128,smax=-1,
                                       umin=18'446'744'073'709'551'615,
                                       umax=18'446'744'073'709'551'488,
                                       var_off=(0xffffffffffffff80; 0x7f),
                                       u32_min=-1,u32_max=-128)
4: (95) exit

Note that umax < umin, which should not happen.
In this case the assignments in question are:

    reg->umin_value = (u64)s64_max; // == -1   == 0xffffffffffffffff
    reg->umax_value = (u64)s64_min; // == -128 == 0xffffffffffffff80
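
Presumably the two unsigned bounds just need to be swapped: for an
all-negative signed range, the smaller signed value is also the smaller
value after the cast to unsigned. Something like the following (a sketch
of the apparent intent, not a confirmed fix):

    reg->umin_value = (u64)s64_min;	/* -128 == 0xffffffffffffff80 */
    reg->umax_value = (u64)s64_max;	/* -1   == 0xffffffffffffffff */
    reg->u32_min_value = (u32)s64_min;
    reg->u32_max_value = (u32)s64_max;
    reg->var_off = tnum_range((u64)s64_min, (u64)s64_max);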


> +		reg->u32_min_value = (u32)s64_max;
> +		reg->u32_max_value = (u32)s64_min;
> +		reg->var_off = tnum_range((u64)s64_max, (u64)s64_min);
> +		return;
> +	}
> +
> +out:
> +	set_sext64_default_val(reg, size);
> +}
> +
>  static bool bpf_map_is_rdonly(const struct bpf_map *map)
>  {
>  	/* A map is considered read-only if the following condition are true:
> @@ -5815,7 +5886,8 @@ static bool bpf_map_is_rdonly(const struct bpf_map *map)
>  	       !bpf_map_write_active(map);
>  }
>  
> -static int bpf_map_direct_read(struct bpf_map *map, int off, int size, u64 *val)
> +static int bpf_map_direct_read(struct bpf_map *map, int off, int size, u64 *val,
> +			       bool is_ldsx)
>  {
>  	void *ptr;
>  	u64 addr;
> @@ -5828,13 +5900,13 @@ static int bpf_map_direct_read(struct bpf_map *map, int off, int size, u64 *val)
>  
>  	switch (size) {
>  	case sizeof(u8):
> -		*val = (u64)*(u8 *)ptr;
> +		*val = is_ldsx ? (s64)*(s8 *)ptr : (u64)*(u8 *)ptr;
>  		break;
>  	case sizeof(u16):
> -		*val = (u64)*(u16 *)ptr;
> +		*val = is_ldsx ? (s64)*(s16 *)ptr : (u64)*(u16 *)ptr;
>  		break;
>  	case sizeof(u32):
> -		*val = (u64)*(u32 *)ptr;
> +		*val = is_ldsx ? (s64)*(s32 *)ptr : (u64)*(u32 *)ptr;
>  		break;
>  	case sizeof(u64):
>  		*val = *(u64 *)ptr;
> @@ -6248,7 +6320,7 @@ static int check_stack_access_within_bounds(
>   */
>  static int check_mem_access(struct bpf_verifier_env *env, int insn_idx, u32 regno,
>  			    int off, int bpf_size, enum bpf_access_type t,
> -			    int value_regno, bool strict_alignment_once)
> +			    int value_regno, bool strict_alignment_once, bool is_ldsx)
>  {
>  	struct bpf_reg_state *regs = cur_regs(env);
>  	struct bpf_reg_state *reg = regs + regno;
> @@ -6309,7 +6381,7 @@ static int check_mem_access(struct bpf_verifier_env *env, int insn_idx, u32 regn
>  				u64 val = 0;
>  
>  				err = bpf_map_direct_read(map, map_off, size,
> -							  &val);
> +							  &val, is_ldsx);
>  				if (err)
>  					return err;
>  
> @@ -6479,8 +6551,11 @@ static int check_mem_access(struct bpf_verifier_env *env, int insn_idx, u32 regn
>  
>  	if (!err && size < BPF_REG_SIZE && value_regno >= 0 && t == BPF_READ &&
>  	    regs[value_regno].type == SCALAR_VALUE) {
> -		/* b/h/w load zero-extends, mark upper bits as known 0 */
> -		coerce_reg_to_size(&regs[value_regno], size);
> +		if (!is_ldsx)
> +			/* b/h/w load zero-extends, mark upper bits as known 0 */
> +			coerce_reg_to_size(&regs[value_regno], size);
> +		else
> +			coerce_reg_to_size_sx(&regs[value_regno], size);
>  	}
>  	return err;
>  }
> @@ -6572,17 +6647,17 @@ static int check_atomic(struct bpf_verifier_env *env, int insn_idx, struct bpf_i
>  	 * case to simulate the register fill.
>  	 */
>  	err = check_mem_access(env, insn_idx, insn->dst_reg, insn->off,
> -			       BPF_SIZE(insn->code), BPF_READ, -1, true);
> +			       BPF_SIZE(insn->code), BPF_READ, -1, true, false);
>  	if (!err && load_reg >= 0)
>  		err = check_mem_access(env, insn_idx, insn->dst_reg, insn->off,
>  				       BPF_SIZE(insn->code), BPF_READ, load_reg,
> -				       true);
> +				       true, false);
>  	if (err)
>  		return err;
>  
>  	/* Check whether we can write into the same memory. */
>  	err = check_mem_access(env, insn_idx, insn->dst_reg, insn->off,
> -			       BPF_SIZE(insn->code), BPF_WRITE, -1, true);
> +			       BPF_SIZE(insn->code), BPF_WRITE, -1, true, false);
>  	if (err)
>  		return err;
>  
> @@ -6828,7 +6903,7 @@ static int check_helper_mem_access(struct bpf_verifier_env *env, int regno,
>  				return zero_size_allowed ? 0 : -EACCES;
>  
>  			return check_mem_access(env, env->insn_idx, regno, offset, BPF_B,
> -						atype, -1, false);
> +						atype, -1, false, false);
>  		}
>  
>  		fallthrough;
> @@ -7200,7 +7275,7 @@ static int process_dynptr_func(struct bpf_verifier_env *env, int regno, int insn
>  		/* we write BPF_DW bits (8 bytes) at a time */
>  		for (i = 0; i < BPF_DYNPTR_SIZE; i += 8) {
>  			err = check_mem_access(env, insn_idx, regno,
> -					       i, BPF_DW, BPF_WRITE, -1, false);
> +					       i, BPF_DW, BPF_WRITE, -1, false, false);
>  			if (err)
>  				return err;
>  		}
> @@ -7293,7 +7368,7 @@ static int process_iter_arg(struct bpf_verifier_env *env, int regno, int insn_id
>  
>  		for (i = 0; i < nr_slots * 8; i += BPF_REG_SIZE) {
>  			err = check_mem_access(env, insn_idx, regno,
> -					       i, BPF_DW, BPF_WRITE, -1, false);
> +					       i, BPF_DW, BPF_WRITE, -1, false, false);
>  			if (err)
>  				return err;
>  		}
> @@ -9437,7 +9512,7 @@ static int check_helper_call(struct bpf_verifier_env *env, struct bpf_insn *insn
>  	 */
>  	for (i = 0; i < meta.access_size; i++) {
>  		err = check_mem_access(env, insn_idx, meta.regno, i, BPF_B,
> -				       BPF_WRITE, -1, false);
> +				       BPF_WRITE, -1, false, false);
>  		if (err)
>  			return err;
>  	}
> @@ -16315,7 +16390,8 @@ static int do_check(struct bpf_verifier_env *env)
>  			 */
>  			err = check_mem_access(env, env->insn_idx, insn->src_reg,
>  					       insn->off, BPF_SIZE(insn->code),
> -					       BPF_READ, insn->dst_reg, false);
> +					       BPF_READ, insn->dst_reg, false,
> +					       BPF_MODE(insn->code) == BPF_MEMSX);
>  			if (err)
>  				return err;
>  
> @@ -16352,7 +16428,7 @@ static int do_check(struct bpf_verifier_env *env)
>  			/* check that memory (dst_reg + off) is writeable */
>  			err = check_mem_access(env, env->insn_idx, insn->dst_reg,
>  					       insn->off, BPF_SIZE(insn->code),
> -					       BPF_WRITE, insn->src_reg, false);
> +					       BPF_WRITE, insn->src_reg, false, false);
>  			if (err)
>  				return err;
>  
> @@ -16377,7 +16453,7 @@ static int do_check(struct bpf_verifier_env *env)
>  			/* check that memory (dst_reg + off) is writeable */
>  			err = check_mem_access(env, env->insn_idx, insn->dst_reg,
>  					       insn->off, BPF_SIZE(insn->code),
> -					       BPF_WRITE, -1, false);
> +					       BPF_WRITE, -1, false, false);
>  			if (err)
>  				return err;
>  
> @@ -16805,7 +16881,8 @@ static int resolve_pseudo_ldimm64(struct bpf_verifier_env *env)
>  
>  	for (i = 0; i < insn_cnt; i++, insn++) {
>  		if (BPF_CLASS(insn->code) == BPF_LDX &&
> -		    (BPF_MODE(insn->code) != BPF_MEM || insn->imm != 0)) {
> +		    ((BPF_MODE(insn->code) != BPF_MEM && BPF_MODE(insn->code) != BPF_MEMSX) ||
> +		    insn->imm != 0)) {
>  			verbose(env, "BPF_LDX uses reserved fields\n");
>  			return -EINVAL;
>  		}
> @@ -17503,7 +17580,10 @@ static int convert_ctx_accesses(struct bpf_verifier_env *env)
>  		if (insn->code == (BPF_LDX | BPF_MEM | BPF_B) ||
>  		    insn->code == (BPF_LDX | BPF_MEM | BPF_H) ||
>  		    insn->code == (BPF_LDX | BPF_MEM | BPF_W) ||
> -		    insn->code == (BPF_LDX | BPF_MEM | BPF_DW)) {
> +		    insn->code == (BPF_LDX | BPF_MEM | BPF_DW) ||
> +		    insn->code == (BPF_LDX | BPF_MEMSX | BPF_B) ||
> +		    insn->code == (BPF_LDX | BPF_MEMSX | BPF_H) ||
> +		    insn->code == (BPF_LDX | BPF_MEMSX | BPF_W)) {
>  			type = BPF_READ;
>  		} else if (insn->code == (BPF_STX | BPF_MEM | BPF_B) ||
>  			   insn->code == (BPF_STX | BPF_MEM | BPF_H) ||
> @@ -17562,6 +17642,11 @@ static int convert_ctx_accesses(struct bpf_verifier_env *env)
>  		 */
>  		case PTR_TO_BTF_ID | MEM_ALLOC | PTR_UNTRUSTED:
>  			if (type == BPF_READ) {
> +				/* it is hard to differentiate whether the
> +				 * BPF_PROBE_MEM is for BPF_MEM or BPF_MEMSX,
> +				 * so let us use insn->imm to remember it.
> +				 */
> +				insn->imm = BPF_MODE(insn->code);
>  				insn->code = BPF_LDX | BPF_PROBE_MEM |
>  					BPF_SIZE((insn)->code);
>  				env->prog->aux->num_exentries++;
> diff --git a/tools/include/uapi/linux/bpf.h b/tools/include/uapi/linux/bpf.h
> index 600d0caebbd8..c7196302d1eb 100644
> --- a/tools/include/uapi/linux/bpf.h
> +++ b/tools/include/uapi/linux/bpf.h
> @@ -19,6 +19,7 @@
>  
>  /* ld/ldx fields */
>  #define BPF_DW		0x18	/* double word (64-bit) */
> +#define BPF_MEMSX	0x80	/* load with sign extension */
>  #define BPF_ATOMIC	0xc0	/* atomic memory ops - op type in immediate */
>  #define BPF_XADD	0xc0	/* exclusive add - legacy name */
>
Yonghong Song July 19, 2023, 2:28 a.m. UTC | #5
On 7/18/23 5:15 PM, Eduard Zingerman wrote:
> On Wed, 2023-07-12 at 23:07 -0700, Yonghong Song wrote:
>> Add interpreter/jit support for the new sign-extension load insns,
>> which add a new mode (BPF_MEMSX).
>> Also add verifier support to recognize these insns and to
>> do proper verification with them. In the verifier, besides
>> deducing proper bounds for the dst_reg, probed memory access
>> is handled by remembering the insn mode in the insn->imm field so
>> that proper jit insns can be emitted later on.
>>
>> Signed-off-by: Yonghong Song <yhs@fb.com>
>> ---
>>   arch/x86/net/bpf_jit_comp.c    |  32 ++++++++-
>>   include/uapi/linux/bpf.h       |   1 +
>>   kernel/bpf/core.c              |  13 ++++
>>   kernel/bpf/verifier.c          | 125 +++++++++++++++++++++++++++------
>>   tools/include/uapi/linux/bpf.h |   1 +
>>   5 files changed, 151 insertions(+), 21 deletions(-)
>>
>> diff --git a/arch/x86/net/bpf_jit_comp.c b/arch/x86/net/bpf_jit_comp.c
>> index 438adb695daa..addeea95f397 100644
>> --- a/arch/x86/net/bpf_jit_comp.c
>> +++ b/arch/x86/net/bpf_jit_comp.c
>> @@ -779,6 +779,29 @@ static void emit_ldx(u8 **pprog, u32 size, u32 dst_reg, u32 src_reg, int off)
>>   	*pprog = prog;
>>   }
>>   
>> +/* LDX: dst_reg = *(s8*)(src_reg + off) */
>> +static void emit_ldsx(u8 **pprog, u32 size, u32 dst_reg, u32 src_reg, int off)
>> +{
>> +	u8 *prog = *pprog;
>> +
>> +	switch (size) {
>> +	case BPF_B:
>> +		/* Emit 'movsx rax, byte ptr [rax + off]' */
>> +		EMIT3(add_2mod(0x48, src_reg, dst_reg), 0x0F, 0xBE);
>> +		break;
>> +	case BPF_H:
>> +		/* Emit 'movsx rax, word ptr [rax + off]' */
>> +		EMIT3(add_2mod(0x48, src_reg, dst_reg), 0x0F, 0xBF);
>> +		break;
>> +	case BPF_W:
>> +		/* Emit 'movsx rax, dword ptr [rax + off]' */
>> +		EMIT2(add_2mod(0x48, src_reg, dst_reg), 0x63);
>> +		break;
>> +	}
>> +	emit_insn_suffix(&prog, src_reg, dst_reg, off);
>> +	*pprog = prog;
>> +}
>> +
>>   /* STX: *(u8*)(dst_reg + off) = src_reg */
>>   static void emit_stx(u8 **pprog, u32 size, u32 dst_reg, u32 src_reg, int off)
>>   {
>> @@ -1370,6 +1393,9 @@ st:			if (is_imm8(insn->off))
>>   		case BPF_LDX | BPF_PROBE_MEM | BPF_W:
>>   		case BPF_LDX | BPF_MEM | BPF_DW:
>>   		case BPF_LDX | BPF_PROBE_MEM | BPF_DW:
>> +		case BPF_LDX | BPF_MEMSX | BPF_B:
>> +		case BPF_LDX | BPF_MEMSX | BPF_H:
>> +		case BPF_LDX | BPF_MEMSX | BPF_W:
>>   			insn_off = insn->off;
>>   
>>   			if (BPF_MODE(insn->code) == BPF_PROBE_MEM) {
>> @@ -1415,7 +1441,11 @@ st:			if (is_imm8(insn->off))
>>   				start_of_ldx = prog;
>>   				end_of_jmp[-1] = start_of_ldx - end_of_jmp;
>>   			}
>> -			emit_ldx(&prog, BPF_SIZE(insn->code), dst_reg, src_reg, insn_off);
>> +			if ((BPF_MODE(insn->code) == BPF_PROBE_MEM && insn->imm == BPF_MEMSX) ||
>> +			    BPF_MODE(insn->code) == BPF_MEMSX)
>> +				emit_ldsx(&prog, BPF_SIZE(insn->code), dst_reg, src_reg, insn_off);
>> +			else
>> +				emit_ldx(&prog, BPF_SIZE(insn->code), dst_reg, src_reg, insn_off);
>>   			if (BPF_MODE(insn->code) == BPF_PROBE_MEM) {
>>   				struct exception_table_entry *ex;
>>   				u8 *_insn = image + proglen + (start_of_ldx - temp);
>> diff --git a/include/uapi/linux/bpf.h b/include/uapi/linux/bpf.h
>> index 600d0caebbd8..c7196302d1eb 100644
>> --- a/include/uapi/linux/bpf.h
>> +++ b/include/uapi/linux/bpf.h
>> @@ -19,6 +19,7 @@
>>   
>>   /* ld/ldx fields */
>>   #define BPF_DW		0x18	/* double word (64-bit) */
>> +#define BPF_MEMSX	0x80	/* load with sign extension */
>>   #define BPF_ATOMIC	0xc0	/* atomic memory ops - op type in immediate */
>>   #define BPF_XADD	0xc0	/* exclusive add - legacy name */
>>   
>> diff --git a/kernel/bpf/core.c b/kernel/bpf/core.c
>> index dc85240a0134..8a1cc658789e 100644
>> --- a/kernel/bpf/core.c
>> +++ b/kernel/bpf/core.c
>> @@ -1610,6 +1610,9 @@ EXPORT_SYMBOL_GPL(__bpf_call_base);
>>   	INSN_3(LDX, MEM, H),			\
>>   	INSN_3(LDX, MEM, W),			\
>>   	INSN_3(LDX, MEM, DW),			\
>> +	INSN_3(LDX, MEMSX, B),			\
>> +	INSN_3(LDX, MEMSX, H),			\
>> +	INSN_3(LDX, MEMSX, W),			\
>>   	/*   Immediate based. */		\
>>   	INSN_3(LD, IMM, DW)
>>   
>> @@ -1942,6 +1945,16 @@ static u64 ___bpf_prog_run(u64 *regs, const struct bpf_insn *insn)
>>   	LDST(DW, u64)
>>   #undef LDST
>>   
>> +#define LDS(SIZEOP, SIZE)						\
>> +	LDX_MEMSX_##SIZEOP:						\
>> +		DST = *(SIZE *)(unsigned long) (SRC + insn->off);	\
>> +		CONT;
>> +
>> +	LDS(B,   s8)
>> +	LDS(H,  s16)
>> +	LDS(W,  s32)
>> +#undef LDS
>> +
>>   #define ATOMIC_ALU_OP(BOP, KOP)						\
>>   		case BOP:						\
>>   			if (BPF_SIZE(insn->code) == BPF_W)		\
>> diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c
>> index 81a93eeac7a0..fbe4ca72d4c1 100644
>> --- a/kernel/bpf/verifier.c
>> +++ b/kernel/bpf/verifier.c
>> @@ -5795,6 +5795,77 @@ static void coerce_reg_to_size(struct bpf_reg_state *reg, int size)
>>   	__reg_combine_64_into_32(reg);
>>   }
>>   
>> +static void set_sext64_default_val(struct bpf_reg_state *reg, int size)
>> +{
>> +	if (size == 1) {
>> +		reg->smin_value = reg->s32_min_value = S8_MIN;
>> +		reg->smax_value = reg->s32_max_value = S8_MAX;
>> +	} else if (size == 2) {
>> +		reg->smin_value = reg->s32_min_value = S16_MIN;
>> +		reg->smax_value = reg->s32_max_value = S16_MAX;
>> +	} else {
>> +		/* size == 4 */
>> +		reg->smin_value = reg->s32_min_value = S32_MIN;
>> +		reg->smax_value = reg->s32_max_value = S32_MAX;
>> +	}
>> +	reg->umin_value = reg->u32_min_value = 0;
>> +	reg->umax_value = U64_MAX;
>> +	reg->u32_max_value = U32_MAX;
>> +	reg->var_off = tnum_unknown;
>> +}
>> +
>> +static void coerce_reg_to_size_sx(struct bpf_reg_state *reg, int size)
>> +{
>> +	u64 top_smax_value, top_smin_value;
>> +	s64 init_s64_max, init_s64_min, s64_max, s64_min;
>> +	u64 num_bits = size * 8;
>> +
>> +	top_smax_value = ((u64)reg->smax_value >> num_bits) << num_bits;
>> +	top_smin_value = ((u64)reg->smin_value >> num_bits) << num_bits;
>> +
>> +	if (top_smax_value != top_smin_value)
>> +		goto out;
>> +
>> +	/* find the s64_min and s64_max after sign extension */
>> +	if (size == 1) {
>> +		init_s64_max = (s8)reg->smax_value;
>> +		init_s64_min = (s8)reg->smin_value;
>> +	} else if (size == 2) {
>> +		init_s64_max = (s16)reg->smax_value;
>> +		init_s64_min = (s16)reg->smin_value;
>> +	} else {
>> +		/* size == 4 */
>> +		init_s64_max = (s32)reg->smax_value;
>> +		init_s64_min = (s32)reg->smin_value;
>> +	}
>> +
>> +	s64_max = max(init_s64_max, init_s64_min);
>> +	s64_min = min(init_s64_max, init_s64_min);
>> +
>> +	if (s64_max >= 0 && s64_min >= 0) {
>> +		reg->smin_value = reg->s32_min_value = s64_min;
>> +		reg->smax_value = reg->s32_max_value = s64_max;
>> +		reg->umin_value = reg->u32_min_value = s64_min;
>> +		reg->umax_value = reg->u32_max_value = s64_max;
>> +		reg->var_off = tnum_range(s64_min, s64_max);
>> +		return;
>> +	}
>> +
>> +	if (s64_min < 0 && s64_max < 0) {
>> +		reg->smin_value = reg->s32_min_value = s64_min;
>> +		reg->smax_value = reg->s32_max_value = s64_max;
>> +		reg->umin_value = (u64)s64_max;
>> +		reg->umax_value = (u64)s64_min;
> 
> I think the last two assignments are not correct for the following example:
> 
> {
> 	"testtesttest",
> 	.insns = {
> 		BPF_EMIT_CALL(BPF_FUNC_get_prandom_u32),
> 		BPF_JMP_IMM(BPF_JLT, BPF_REG_0, 0xff80, 2),
> 		BPF_JMP_IMM(BPF_JGT, BPF_REG_0, 0xffff, 1),
> 		{
> 			.code  = BPF_ALU64 | BPF_MOV | BPF_X,
> 			.dst_reg = BPF_REG_0,
> 			.src_reg = BPF_REG_0,
> 			.off   = 8,
> 			.imm   = 0,
> 		},
> 		BPF_EXIT_INSN(),
> 	},
> 	.result = ACCEPT,
> 	.retval = 0,
> },
> 
> Here is execution log:
> 
> 0: R1=ctx(off=0,imm=0) R10=fp0
> 0: (85) call bpf_get_prandom_u32#7 ; R0_w=Pscalar()
> 1: (a5) if r0 < 0xff80 goto pc+2   ; R0_w=Pscalar(umin=65408)
> 2: (25) if r0 > 0xffff goto pc+1   ; R0_w=Pscalar(umin=65408,umax=65535,var_off=(0xff80; 0x7f))
> 3: (bf) r0 = r0                    ; R0_w=Pscalar
>                                        (smin=-128,smax=-1,
>                                         umin=18'446'744'073'709'551'615,
>                                         umax=18'446'744'073'709'551'488,
>                                         var_off=(0xffffffffffffff80; 0x7f),
>                                         u32_min=-1,u32_max=-128)
> 4: (95) exit
> 
> Note that umax < umin, which should not happen.
> In this case the assignments in question are:
> 
>      reg->umin_value = (u64)s64_max; // == -1   == 0xffffffffffffffff
>      reg->umax_value = (u64)s64_min; // == -128 == 0xffffffffffffff80

Thanks for pointing this out. Yes, the assignments are incorrect and
mismatched. Will fix the issue and add a test for this.
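
To see concretely why the unsigned ordering flips for an all-negative
signed range, a tiny userspace check (illustration only):

	#include <stdio.h>
	#include <stdint.h>

	int main(void)
	{
		int64_t s64_min = -128, s64_max = -1;

		/* -128 -> 0xffffffffffffff80, -1 -> 0xffffffffffffffff:
		 * the signed min stays the unsigned min after the cast */
		printf("umin=%#llx umax=%#llx\n",
		       (unsigned long long)s64_min,
		       (unsigned long long)s64_max);
		return 0;
	}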

> 
> 
>> +		reg->u32_min_value = (u32)s64_max;
>> +		reg->u32_max_value = (u32)s64_min;
>> +		reg->var_off = tnum_range((u64)s64_max, (u64)s64_min);
>> +		return;
>> +	}
>> +
>> +out:
>> +	set_sext64_default_val(reg, size);
>> +}
>> +
>[...]
diff mbox series

Patch

diff --git a/arch/x86/net/bpf_jit_comp.c b/arch/x86/net/bpf_jit_comp.c
index 438adb695daa..addeea95f397 100644
--- a/arch/x86/net/bpf_jit_comp.c
+++ b/arch/x86/net/bpf_jit_comp.c
@@ -779,6 +779,29 @@  static void emit_ldx(u8 **pprog, u32 size, u32 dst_reg, u32 src_reg, int off)
 	*pprog = prog;
 }
 
+/* LDX: dst_reg = *(s8*)(src_reg + off) */
+static void emit_ldsx(u8 **pprog, u32 size, u32 dst_reg, u32 src_reg, int off)
+{
+	u8 *prog = *pprog;
+
+	switch (size) {
+	case BPF_B:
+		/* Emit 'movsx rax, byte ptr [rax + off]' */
+		EMIT3(add_2mod(0x48, src_reg, dst_reg), 0x0F, 0xBE);
+		break;
+	case BPF_H:
+		/* Emit 'movsx rax, word ptr [rax + off]' */
+		EMIT3(add_2mod(0x48, src_reg, dst_reg), 0x0F, 0xBF);
+		break;
+	case BPF_W:
+		/* Emit 'movsx rax, dword ptr [rax + off]' */
+		EMIT2(add_2mod(0x48, src_reg, dst_reg), 0x63);
+		break;
+	}
+	emit_insn_suffix(&prog, src_reg, dst_reg, off);
+	*pprog = prog;
+}
+
 /* STX: *(u8*)(dst_reg + off) = src_reg */
 static void emit_stx(u8 **pprog, u32 size, u32 dst_reg, u32 src_reg, int off)
 {
@@ -1370,6 +1393,9 @@  st:			if (is_imm8(insn->off))
 		case BPF_LDX | BPF_PROBE_MEM | BPF_W:
 		case BPF_LDX | BPF_MEM | BPF_DW:
 		case BPF_LDX | BPF_PROBE_MEM | BPF_DW:
+		case BPF_LDX | BPF_MEMSX | BPF_B:
+		case BPF_LDX | BPF_MEMSX | BPF_H:
+		case BPF_LDX | BPF_MEMSX | BPF_W:
 			insn_off = insn->off;
 
 			if (BPF_MODE(insn->code) == BPF_PROBE_MEM) {
@@ -1415,7 +1441,11 @@  st:			if (is_imm8(insn->off))
 				start_of_ldx = prog;
 				end_of_jmp[-1] = start_of_ldx - end_of_jmp;
 			}
-			emit_ldx(&prog, BPF_SIZE(insn->code), dst_reg, src_reg, insn_off);
+			if ((BPF_MODE(insn->code) == BPF_PROBE_MEM && insn->imm == BPF_MEMSX) ||
+			    BPF_MODE(insn->code) == BPF_MEMSX)
+				emit_ldsx(&prog, BPF_SIZE(insn->code), dst_reg, src_reg, insn_off);
+			else
+				emit_ldx(&prog, BPF_SIZE(insn->code), dst_reg, src_reg, insn_off);
 			if (BPF_MODE(insn->code) == BPF_PROBE_MEM) {
 				struct exception_table_entry *ex;
 				u8 *_insn = image + proglen + (start_of_ldx - temp);
diff --git a/include/uapi/linux/bpf.h b/include/uapi/linux/bpf.h
index 600d0caebbd8..c7196302d1eb 100644
--- a/include/uapi/linux/bpf.h
+++ b/include/uapi/linux/bpf.h
@@ -19,6 +19,7 @@ 
 
 /* ld/ldx fields */
 #define BPF_DW		0x18	/* double word (64-bit) */
+#define BPF_MEMSX	0x80	/* load with sign extension */
 #define BPF_ATOMIC	0xc0	/* atomic memory ops - op type in immediate */
 #define BPF_XADD	0xc0	/* exclusive add - legacy name */
 
diff --git a/kernel/bpf/core.c b/kernel/bpf/core.c
index dc85240a0134..8a1cc658789e 100644
--- a/kernel/bpf/core.c
+++ b/kernel/bpf/core.c
@@ -1610,6 +1610,9 @@  EXPORT_SYMBOL_GPL(__bpf_call_base);
 	INSN_3(LDX, MEM, H),			\
 	INSN_3(LDX, MEM, W),			\
 	INSN_3(LDX, MEM, DW),			\
+	INSN_3(LDX, MEMSX, B),			\
+	INSN_3(LDX, MEMSX, H),			\
+	INSN_3(LDX, MEMSX, W),			\
 	/*   Immediate based. */		\
 	INSN_3(LD, IMM, DW)
 
@@ -1942,6 +1945,16 @@  static u64 ___bpf_prog_run(u64 *regs, const struct bpf_insn *insn)
 	LDST(DW, u64)
 #undef LDST
 
+#define LDS(SIZEOP, SIZE)						\
+	LDX_MEMSX_##SIZEOP:						\
+		DST = *(SIZE *)(unsigned long) (SRC + insn->off);	\
+		CONT;
+
+	LDS(B,   s8)
+	LDS(H,  s16)
+	LDS(W,  s32)
+#undef LDS
+
 #define ATOMIC_ALU_OP(BOP, KOP)						\
 		case BOP:						\
 			if (BPF_SIZE(insn->code) == BPF_W)		\
diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c
index 81a93eeac7a0..fbe4ca72d4c1 100644
--- a/kernel/bpf/verifier.c
+++ b/kernel/bpf/verifier.c
@@ -5795,6 +5795,77 @@  static void coerce_reg_to_size(struct bpf_reg_state *reg, int size)
 	__reg_combine_64_into_32(reg);
 }
 
+static void set_sext64_default_val(struct bpf_reg_state *reg, int size)
+{
+	if (size == 1) {
+		reg->smin_value = reg->s32_min_value = S8_MIN;
+		reg->smax_value = reg->s32_max_value = S8_MAX;
+	} else if (size == 2) {
+		reg->smin_value = reg->s32_min_value = S16_MIN;
+		reg->smax_value = reg->s32_max_value = S16_MAX;
+	} else {
+		/* size == 4 */
+		reg->smin_value = reg->s32_min_value = S32_MIN;
+		reg->smax_value = reg->s32_max_value = S32_MAX;
+	}
+	reg->umin_value = reg->u32_min_value = 0;
+	reg->umax_value = U64_MAX;
+	reg->u32_max_value = U32_MAX;
+	reg->var_off = tnum_unknown;
+}
+
+static void coerce_reg_to_size_sx(struct bpf_reg_state *reg, int size)
+{
+	u64 top_smax_value, top_smin_value;
+	s64 init_s64_max, init_s64_min, s64_max, s64_min;
+	u64 num_bits = size * 8;
+
+	top_smax_value = ((u64)reg->smax_value >> num_bits) << num_bits;
+	top_smin_value = ((u64)reg->smin_value >> num_bits) << num_bits;
+
+	if (top_smax_value != top_smin_value)
+		goto out;
+
+	/* find the s64_min and s64_max after sign extension */
+	if (size == 1) {
+		init_s64_max = (s8)reg->smax_value;
+		init_s64_min = (s8)reg->smin_value;
+	} else if (size == 2) {
+		init_s64_max = (s16)reg->smax_value;
+		init_s64_min = (s16)reg->smin_value;
+	} else {
+		/* size == 4 */
+		init_s64_max = (s32)reg->smax_value;
+		init_s64_min = (s32)reg->smin_value;
+	}
+
+	s64_max = max(init_s64_max, init_s64_min);
+	s64_min = min(init_s64_max, init_s64_min);
+
+	if (s64_max >= 0 && s64_min >= 0) {
+		reg->smin_value = reg->s32_min_value = s64_min;
+		reg->smax_value = reg->s32_max_value = s64_max;
+		reg->umin_value = reg->u32_min_value = s64_min;
+		reg->umax_value = reg->u32_max_value = s64_max;
+		reg->var_off = tnum_range(s64_min, s64_max);
+		return;
+	}
+
+	if (s64_min < 0 && s64_max < 0) {
+		reg->smin_value = reg->s32_min_value = s64_min;
+		reg->smax_value = reg->s32_max_value = s64_max;
+		reg->umin_value = (u64)s64_max;
+		reg->umax_value = (u64)s64_min;
+		reg->u32_min_value = (u32)s64_max;
+		reg->u32_max_value = (u32)s64_min;
+		reg->var_off = tnum_range((u64)s64_max, (u64)s64_min);
+		return;
+	}
+
+out:
+	set_sext64_default_val(reg, size);
+}
+
 static bool bpf_map_is_rdonly(const struct bpf_map *map)
 {
 	/* A map is considered read-only if the following condition are true:
@@ -5815,7 +5886,8 @@  static bool bpf_map_is_rdonly(const struct bpf_map *map)
 	       !bpf_map_write_active(map);
 }
 
-static int bpf_map_direct_read(struct bpf_map *map, int off, int size, u64 *val)
+static int bpf_map_direct_read(struct bpf_map *map, int off, int size, u64 *val,
+			       bool is_ldsx)
 {
 	void *ptr;
 	u64 addr;
@@ -5828,13 +5900,13 @@  static int bpf_map_direct_read(struct bpf_map *map, int off, int size, u64 *val)
 
 	switch (size) {
 	case sizeof(u8):
-		*val = (u64)*(u8 *)ptr;
+		*val = is_ldsx ? (s64)*(s8 *)ptr : (u64)*(u8 *)ptr;
 		break;
 	case sizeof(u16):
-		*val = (u64)*(u16 *)ptr;
+		*val = is_ldsx ? (s64)*(s16 *)ptr : (u64)*(u16 *)ptr;
 		break;
 	case sizeof(u32):
-		*val = (u64)*(u32 *)ptr;
+		*val = is_ldsx ? (s64)*(s32 *)ptr : (u64)*(u32 *)ptr;
 		break;
 	case sizeof(u64):
 		*val = *(u64 *)ptr;
@@ -6248,7 +6320,7 @@  static int check_stack_access_within_bounds(
  */
 static int check_mem_access(struct bpf_verifier_env *env, int insn_idx, u32 regno,
 			    int off, int bpf_size, enum bpf_access_type t,
-			    int value_regno, bool strict_alignment_once)
+			    int value_regno, bool strict_alignment_once, bool is_ldsx)
 {
 	struct bpf_reg_state *regs = cur_regs(env);
 	struct bpf_reg_state *reg = regs + regno;
@@ -6309,7 +6381,7 @@  static int check_mem_access(struct bpf_verifier_env *env, int insn_idx, u32 regn
 				u64 val = 0;
 
 				err = bpf_map_direct_read(map, map_off, size,
-							  &val);
+							  &val, is_ldsx);
 				if (err)
 					return err;
 
@@ -6479,8 +6551,11 @@  static int check_mem_access(struct bpf_verifier_env *env, int insn_idx, u32 regn
 
 	if (!err && size < BPF_REG_SIZE && value_regno >= 0 && t == BPF_READ &&
 	    regs[value_regno].type == SCALAR_VALUE) {
-		/* b/h/w load zero-extends, mark upper bits as known 0 */
-		coerce_reg_to_size(&regs[value_regno], size);
+		if (!is_ldsx)
+			/* b/h/w load zero-extends, mark upper bits as known 0 */
+			coerce_reg_to_size(&regs[value_regno], size);
+		else
+			coerce_reg_to_size_sx(&regs[value_regno], size);
 	}
 	return err;
 }
@@ -6572,17 +6647,17 @@  static int check_atomic(struct bpf_verifier_env *env, int insn_idx, struct bpf_i
 	 * case to simulate the register fill.
 	 */
 	err = check_mem_access(env, insn_idx, insn->dst_reg, insn->off,
-			       BPF_SIZE(insn->code), BPF_READ, -1, true);
+			       BPF_SIZE(insn->code), BPF_READ, -1, true, false);
 	if (!err && load_reg >= 0)
 		err = check_mem_access(env, insn_idx, insn->dst_reg, insn->off,
 				       BPF_SIZE(insn->code), BPF_READ, load_reg,
-				       true);
+				       true, false);
 	if (err)
 		return err;
 
 	/* Check whether we can write into the same memory. */
 	err = check_mem_access(env, insn_idx, insn->dst_reg, insn->off,
-			       BPF_SIZE(insn->code), BPF_WRITE, -1, true);
+			       BPF_SIZE(insn->code), BPF_WRITE, -1, true, false);
 	if (err)
 		return err;
 
@@ -6828,7 +6903,7 @@  static int check_helper_mem_access(struct bpf_verifier_env *env, int regno,
 				return zero_size_allowed ? 0 : -EACCES;
 
 			return check_mem_access(env, env->insn_idx, regno, offset, BPF_B,
-						atype, -1, false);
+						atype, -1, false, false);
 		}
 
 		fallthrough;
@@ -7200,7 +7275,7 @@  static int process_dynptr_func(struct bpf_verifier_env *env, int regno, int insn
 		/* we write BPF_DW bits (8 bytes) at a time */
 		for (i = 0; i < BPF_DYNPTR_SIZE; i += 8) {
 			err = check_mem_access(env, insn_idx, regno,
-					       i, BPF_DW, BPF_WRITE, -1, false);
+					       i, BPF_DW, BPF_WRITE, -1, false, false);
 			if (err)
 				return err;
 		}
@@ -7293,7 +7368,7 @@  static int process_iter_arg(struct bpf_verifier_env *env, int regno, int insn_id
 
 		for (i = 0; i < nr_slots * 8; i += BPF_REG_SIZE) {
 			err = check_mem_access(env, insn_idx, regno,
-					       i, BPF_DW, BPF_WRITE, -1, false);
+					       i, BPF_DW, BPF_WRITE, -1, false, false);
 			if (err)
 				return err;
 		}
@@ -9437,7 +9512,7 @@  static int check_helper_call(struct bpf_verifier_env *env, struct bpf_insn *insn
 	 */
 	for (i = 0; i < meta.access_size; i++) {
 		err = check_mem_access(env, insn_idx, meta.regno, i, BPF_B,
-				       BPF_WRITE, -1, false);
+				       BPF_WRITE, -1, false, false);
 		if (err)
 			return err;
 	}
@@ -16315,7 +16390,8 @@  static int do_check(struct bpf_verifier_env *env)
 			 */
 			err = check_mem_access(env, env->insn_idx, insn->src_reg,
 					       insn->off, BPF_SIZE(insn->code),
-					       BPF_READ, insn->dst_reg, false);
+					       BPF_READ, insn->dst_reg, false,
+					       BPF_MODE(insn->code) == BPF_MEMSX);
 			if (err)
 				return err;
 
@@ -16352,7 +16428,7 @@  static int do_check(struct bpf_verifier_env *env)
 			/* check that memory (dst_reg + off) is writeable */
 			err = check_mem_access(env, env->insn_idx, insn->dst_reg,
 					       insn->off, BPF_SIZE(insn->code),
-					       BPF_WRITE, insn->src_reg, false);
+					       BPF_WRITE, insn->src_reg, false, false);
 			if (err)
 				return err;
 
@@ -16377,7 +16453,7 @@  static int do_check(struct bpf_verifier_env *env)
 			/* check that memory (dst_reg + off) is writeable */
 			err = check_mem_access(env, env->insn_idx, insn->dst_reg,
 					       insn->off, BPF_SIZE(insn->code),
-					       BPF_WRITE, -1, false);
+					       BPF_WRITE, -1, false, false);
 			if (err)
 				return err;
 
@@ -16805,7 +16881,8 @@  static int resolve_pseudo_ldimm64(struct bpf_verifier_env *env)
 
 	for (i = 0; i < insn_cnt; i++, insn++) {
 		if (BPF_CLASS(insn->code) == BPF_LDX &&
-		    (BPF_MODE(insn->code) != BPF_MEM || insn->imm != 0)) {
+		    ((BPF_MODE(insn->code) != BPF_MEM && BPF_MODE(insn->code) != BPF_MEMSX) ||
+		    insn->imm != 0)) {
 			verbose(env, "BPF_LDX uses reserved fields\n");
 			return -EINVAL;
 		}
@@ -17503,7 +17580,10 @@  static int convert_ctx_accesses(struct bpf_verifier_env *env)
 		if (insn->code == (BPF_LDX | BPF_MEM | BPF_B) ||
 		    insn->code == (BPF_LDX | BPF_MEM | BPF_H) ||
 		    insn->code == (BPF_LDX | BPF_MEM | BPF_W) ||
-		    insn->code == (BPF_LDX | BPF_MEM | BPF_DW)) {
+		    insn->code == (BPF_LDX | BPF_MEM | BPF_DW) ||
+		    insn->code == (BPF_LDX | BPF_MEMSX | BPF_B) ||
+		    insn->code == (BPF_LDX | BPF_MEMSX | BPF_H) ||
+		    insn->code == (BPF_LDX | BPF_MEMSX | BPF_W)) {
 			type = BPF_READ;
 		} else if (insn->code == (BPF_STX | BPF_MEM | BPF_B) ||
 			   insn->code == (BPF_STX | BPF_MEM | BPF_H) ||
@@ -17562,6 +17642,11 @@  static int convert_ctx_accesses(struct bpf_verifier_env *env)
 		 */
 		case PTR_TO_BTF_ID | MEM_ALLOC | PTR_UNTRUSTED:
 			if (type == BPF_READ) {
+				/* it is hard to differentiate whether the
+				 * BPF_PROBE_MEM is for BPF_MEM or BPF_MEMSX,
+				 * so let us use insn->imm to remember it.
+				 */
+				insn->imm = BPF_MODE(insn->code);
 				insn->code = BPF_LDX | BPF_PROBE_MEM |
 					BPF_SIZE((insn)->code);
 				env->prog->aux->num_exentries++;
diff --git a/tools/include/uapi/linux/bpf.h b/tools/include/uapi/linux/bpf.h
index 600d0caebbd8..c7196302d1eb 100644
--- a/tools/include/uapi/linux/bpf.h
+++ b/tools/include/uapi/linux/bpf.h
@@ -19,6 +19,7 @@ 
 
 /* ld/ldx fields */
 #define BPF_DW		0x18	/* double word (64-bit) */
+#define BPF_MEMSX	0x80	/* load with sign extension */
 #define BPF_ATOMIC	0xc0	/* atomic memory ops - op type in immediate */
 #define BPF_XADD	0xc0	/* exclusive add - legacy name */