diff mbox series

[bpf-next,01/13] bpf: generalize reg_set_min_max() to handle non-const register comparisons

Message ID 20231103000822.2509815-2-andrii@kernel.org (mailing list archive)
State Superseded
Delegated to: BPF
Headers show
Series BPF register bounds range vs range support | expand

Checks

Context Check Description
netdev/series_format success Posting correctly formatted
netdev/tree_selection success Clearly marked for bpf-next, async
netdev/fixes_present success Fixes tag not required for -next series
netdev/header_inline success No static functions without inline keyword in header files
netdev/build_32bit success Errors and warnings before: 1413 this patch: 1413
netdev/cc_maintainers warning 8 maintainers not CCed: jolsa@kernel.org sdf@google.com john.fastabend@gmail.com kpsingh@kernel.org song@kernel.org yonghong.song@linux.dev haoluo@google.com martin.lau@linux.dev
netdev/build_clang success Errors and warnings before: 1379 this patch: 1379
netdev/verify_signedoff success Signed-off-by tag matches author and committer
netdev/deprecated_api success None detected
netdev/check_selftest success No net selftest shell script
netdev/verify_fixes success No Fixes tag
netdev/build_allmodconfig_warn success Errors and warnings before: 1441 this patch: 1441
netdev/checkpatch warning WARNING: Prefer 'fallthrough;' over fallthrough comment WARNING: line length of 81 exceeds 80 columns WARNING: line length of 82 exceeds 80 columns WARNING: line length of 83 exceeds 80 columns WARNING: line length of 85 exceeds 80 columns WARNING: line length of 87 exceeds 80 columns WARNING: line length of 88 exceeds 80 columns WARNING: line length of 89 exceeds 80 columns WARNING: line length of 90 exceeds 80 columns WARNING: line length of 91 exceeds 80 columns WARNING: line length of 92 exceeds 80 columns WARNING: line length of 96 exceeds 80 columns WARNING: line length of 99 exceeds 80 columns
netdev/build_clang_rust success No Rust files in patch. Skipping build
netdev/kdoc success Errors and warnings before: 0 this patch: 0
netdev/source_inline success Was 0 now: 0
bpf/vmtest-bpf-next-VM_Test-0 success Logs for Lint
bpf/vmtest-bpf-next-VM_Test-1 success Logs for ShellCheck
bpf/vmtest-bpf-next-VM_Test-2 success Logs for Validate matrix.py
bpf/vmtest-bpf-next-VM_Test-3 success Logs for aarch64-gcc / build / build for aarch64 with gcc
bpf/vmtest-bpf-next-VM_Test-8 success Logs for aarch64-gcc / veristat
bpf/vmtest-bpf-next-VM_Test-4 success Logs for aarch64-gcc / test (test_maps, false, 360) / test_maps on aarch64 with gcc
bpf/vmtest-bpf-next-VM_Test-5 success Logs for aarch64-gcc / test (test_progs, false, 360) / test_progs on aarch64 with gcc
bpf/vmtest-bpf-next-VM_Test-6 success Logs for aarch64-gcc / test (test_progs_no_alu32, false, 360) / test_progs_no_alu32 on aarch64 with gcc
bpf/vmtest-bpf-next-VM_Test-7 success Logs for aarch64-gcc / test (test_verifier, false, 360) / test_verifier on aarch64 with gcc
bpf/vmtest-bpf-next-VM_Test-9 success Logs for s390x-gcc / build / build for s390x with gcc
bpf/vmtest-bpf-next-VM_Test-14 success Logs for s390x-gcc / veristat
bpf/vmtest-bpf-next-VM_Test-15 success Logs for set-matrix
bpf/vmtest-bpf-next-VM_Test-16 success Logs for x86_64-gcc / build / build for x86_64 with gcc
bpf/vmtest-bpf-next-VM_Test-17 success Logs for x86_64-gcc / test (test_maps, false, 360) / test_maps on x86_64 with gcc
bpf/vmtest-bpf-next-VM_Test-18 success Logs for x86_64-gcc / test (test_progs, false, 360) / test_progs on x86_64 with gcc
bpf/vmtest-bpf-next-VM_Test-19 success Logs for x86_64-gcc / test (test_progs_no_alu32, false, 360) / test_progs_no_alu32 on x86_64 with gcc
bpf/vmtest-bpf-next-VM_Test-20 success Logs for x86_64-gcc / test (test_progs_no_alu32_parallel, true, 30) / test_progs_no_alu32_parallel on x86_64 with gcc
bpf/vmtest-bpf-next-VM_Test-21 success Logs for x86_64-gcc / test (test_progs_parallel, true, 30) / test_progs_parallel on x86_64 with gcc
bpf/vmtest-bpf-next-VM_Test-22 success Logs for x86_64-gcc / test (test_verifier, false, 360) / test_verifier on x86_64 with gcc
bpf/vmtest-bpf-next-VM_Test-23 success Logs for x86_64-gcc / veristat / veristat on x86_64 with gcc
bpf/vmtest-bpf-next-VM_Test-24 success Logs for x86_64-llvm-16 / build / build for x86_64 with llvm-16
bpf/vmtest-bpf-next-VM_Test-25 success Logs for x86_64-llvm-16 / test (test_maps, false, 360) / test_maps on x86_64 with llvm-16
bpf/vmtest-bpf-next-VM_Test-26 success Logs for x86_64-llvm-16 / test (test_progs, false, 360) / test_progs on x86_64 with llvm-16
bpf/vmtest-bpf-next-VM_Test-27 success Logs for x86_64-llvm-16 / test (test_progs_no_alu32, false, 360) / test_progs_no_alu32 on x86_64 with llvm-16
bpf/vmtest-bpf-next-VM_Test-28 success Logs for x86_64-llvm-16 / test (test_verifier, false, 360) / test_verifier on x86_64 with llvm-16
bpf/vmtest-bpf-next-VM_Test-29 success Logs for x86_64-llvm-16 / veristat
bpf/vmtest-bpf-next-VM_Test-11 success Logs for s390x-gcc / test (test_progs, false, 360) / test_progs on s390x with gcc
bpf/vmtest-bpf-next-VM_Test-12 success Logs for s390x-gcc / test (test_progs_no_alu32, false, 360) / test_progs_no_alu32 on s390x with gcc
bpf/vmtest-bpf-next-VM_Test-13 success Logs for s390x-gcc / test (test_verifier, false, 360) / test_verifier on s390x with gcc
bpf/vmtest-bpf-next-VM_Test-10 success Logs for s390x-gcc / test (test_maps, false, 360) / test_maps on s390x with gcc
bpf/vmtest-bpf-next-PR success PR summary

Commit Message

Andrii Nakryiko Nov. 3, 2023, 12:08 a.m. UTC
Generalize the bounds adjustment logic of reg_set_min_max() to handle not
just the register vs constant case, but any register vs register case in
general. For most of the operations it's a trivial extension based on
range vs range comparison logic; we just need to properly pick the
min/max of one range to compare against the min/max of the other range.

For BPF_JSET we keep the original capabilities, just make sure JSET is
integrated in the common framework. This is manifested in the
internal-only BPF_KSET + BPF_X "opcode" to allow for simpler and more
uniform rev_opcode() handling. See the code for details. This allows us
to reuse exactly the same code for both TRUE and FALSE branches without
explicitly handling both conditions with custom code.

Note also that we no longer need special handling of the BPF_JEQ/BPF_JNE
case when none of the registers are constants. This is now just a normal
generic case handled by reg_set_min_max().

To make tnum handling cleaner, a tnum_with_subreg() helper is added, as
that's a common operation when dealing with 32-bit subregister bounds.
This keeps the overall logic much less noisy when it comes to tnums.

Signed-off-by: Andrii Nakryiko <andrii@kernel.org>
---
 include/linux/tnum.h  |   4 +
 kernel/bpf/tnum.c     |   7 +-
 kernel/bpf/verifier.c | 327 ++++++++++++++++++++----------------------
 3 files changed, 165 insertions(+), 173 deletions(-)
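
As an illustration of the range-vs-range refinement described above, here
is a minimal self-contained userspace sketch (unsigned 64-bit bounds only;
the struct and helper names are illustrative, not the kernel's):

	#include <inttypes.h>
	#include <stdio.h>

	struct urange { uint64_t umin, umax; };

	#define min(a, b) ((a) < (b) ? (a) : (b))
	#define max(a, b) ((a) > (b) ? (a) : (b))

	/* TRUE branch of "r1 < r2" (BPF_JLT): r1 can't reach r2's max and
	 * r2 must exceed r1's min. Assumes the branch is actually takeable,
	 * which the verifier checks separately (see is_branch_taken()).
	 */
	static void refine_ult(struct urange *r1, struct urange *r2)
	{
		r1->umax = min(r1->umax, r2->umax - 1);
		r2->umin = max(r2->umin, r1->umin + 1);
	}

	/* FALSE branch of "r1 < r2" is "r1 >= r2" (cf. rev_opcode()). */
	static void refine_uge(struct urange *r1, struct urange *r2)
	{
		r1->umin = max(r1->umin, r2->umin);
		r2->umax = min(r2->umax, r1->umax);
	}

	int main(void)
	{
		struct urange t1 = { 0, 100 }, t2 = { 50, 60 };
		struct urange f1 = t1, f2 = t2;

		refine_ult(&t1, &t2);	/* taken:     r1=[0,59],   r2=[50,60] */
		refine_uge(&f1, &f2);	/* not taken: r1=[50,100], r2=[50,60] */
		printf("taken: r1=[%" PRIu64 ",%" PRIu64 "] r2=[%" PRIu64 ",%" PRIu64 "]\n",
		       t1.umin, t1.umax, t2.umin, t2.umax);
		printf("else:  r1=[%" PRIu64 ",%" PRIu64 "] r2=[%" PRIu64 ",%" PRIu64 "]\n",
		       f1.umin, f1.umax, f2.umin, f2.umax);
		return 0;
	}

The same pattern extends to the signed and 32-bit subregister bounds; in
the kernel the remaining bounds are then re-derived via reg_bounds_sync().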

Comments

Shung-Hsi Yu Nov. 3, 2023, 7:52 a.m. UTC | #1
On Thu, Nov 02, 2023 at 05:08:10PM -0700, Andrii Nakryiko wrote:
> Generalize the bounds adjustment logic of reg_set_min_max() to handle not
> just the register vs constant case, but any register vs register case in
> general. For most of the operations it's a trivial extension based on
> range vs range comparison logic; we just need to properly pick the
> min/max of one range to compare against the min/max of the other range.
> 
> For BPF_JSET we keep the original capabilities, just make sure JSET is
> integrated in the common framework. This is manifested in the
> internal-only BPF_KSET + BPF_X "opcode" to allow for simpler and more
                    ^ typo?

Two more comments below

> uniform rev_opcode() handling. See the code for details. This allows us
> to reuse exactly the same code for both TRUE and FALSE branches without
> explicitly handling both conditions with custom code.
> 
> Note also that we no longer need special handling of the BPF_JEQ/BPF_JNE
> case when none of the registers are constants. This is now just a normal
> generic case handled by reg_set_min_max().
> 
> To make tnum handling cleaner, a tnum_with_subreg() helper is added, as
> that's a common operation when dealing with 32-bit subregister bounds.
> This keeps the overall logic much less noisy when it comes to tnums.
> 
> Signed-off-by: Andrii Nakryiko <andrii@kernel.org>
> ---
>  include/linux/tnum.h  |   4 +
>  kernel/bpf/tnum.c     |   7 +-
>  kernel/bpf/verifier.c | 327 ++++++++++++++++++++----------------------
>  3 files changed, 165 insertions(+), 173 deletions(-)
> 
> diff --git a/include/linux/tnum.h b/include/linux/tnum.h
> index 1c3948a1d6ad..3c13240077b8 100644
> --- a/include/linux/tnum.h
> +++ b/include/linux/tnum.h
> @@ -106,6 +106,10 @@ int tnum_sbin(char *str, size_t size, struct tnum a);
>  struct tnum tnum_subreg(struct tnum a);
>  /* Returns the tnum with the lower 32-bit subreg cleared */
>  struct tnum tnum_clear_subreg(struct tnum a);
> +/* Returns the tnum with the lower 32-bit subreg in *reg* set to the lower
> + * 32-bit subreg in *subreg*
> + */
> +struct tnum tnum_with_subreg(struct tnum reg, struct tnum subreg);
>  /* Returns the tnum with the lower 32-bit subreg set to value */
>  struct tnum tnum_const_subreg(struct tnum a, u32 value);
>  /* Returns true if 32-bit subreg @a is a known constant*/
> diff --git a/kernel/bpf/tnum.c b/kernel/bpf/tnum.c
> index 3d7127f439a1..f4c91c9b27d7 100644
> --- a/kernel/bpf/tnum.c
> +++ b/kernel/bpf/tnum.c
> @@ -208,7 +208,12 @@ struct tnum tnum_clear_subreg(struct tnum a)
>  	return tnum_lshift(tnum_rshift(a, 32), 32);
>  }
>  
> +struct tnum tnum_with_subreg(struct tnum reg, struct tnum subreg)
> +{
> +	return tnum_or(tnum_clear_subreg(reg), tnum_subreg(subreg));
> +}
> +
>  struct tnum tnum_const_subreg(struct tnum a, u32 value)
>  {
> -	return tnum_or(tnum_clear_subreg(a), tnum_const(value));
> +	return tnum_with_subreg(a, tnum_const(value));
>  }
> diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c
> index 2197385d91dc..52934080042c 100644
> --- a/kernel/bpf/verifier.c
> +++ b/kernel/bpf/verifier.c
> @@ -14379,218 +14379,211 @@ static int is_branch_taken(struct bpf_reg_state *reg1, struct bpf_reg_state *reg
>  	return is_scalar_branch_taken(reg1, reg2, opcode, is_jmp32);
>  }
>  
> -/* Adjusts the register min/max values in the case that the dst_reg and
> - * src_reg are both SCALAR_VALUE registers (or we are simply doing a BPF_K
> - * check, in which case we havea fake SCALAR_VALUE representing insn->imm).
> - * Technically we can do similar adjustments for pointers to the same object,
> - * but we don't support that right now.
> +/* Opcode that corresponds to a *false* branch condition.
> + * E.g., if r1 < r2, then reverse (false) condition is r1 >= r2
>   */
> -static void reg_set_min_max(struct bpf_reg_state *true_reg1,
> -			    struct bpf_reg_state *true_reg2,
> -			    struct bpf_reg_state *false_reg1,
> -			    struct bpf_reg_state *false_reg2,
> -			    u8 opcode, bool is_jmp32)
> +static u8 rev_opcode(u8 opcode)

Nit: rev_opcode and flip_opcode seem like a possible source of confusion
down the line. Flip and reverse are often interchangeable, i.e. "flip the
order" and "reverse the order" are the same thing.

Maybe "neg_opcode" or "neg_cond_opcode"?

Or do it the other way around: keep rev_opcode but rename flip_opcode.

One more comment about BPF_JSET below

>  {
> -	struct tnum false_32off, false_64off;
> -	struct tnum true_32off, true_64off;
> -	u64 uval;
> -	u32 uval32;
> -	s64 sval;
> -	s32 sval32;
> -
> -	/* If either register is a pointer, we can't learn anything about its
> -	 * variable offset from the compare (unless they were a pointer into
> -	 * the same object, but we don't bother with that).
> +	switch (opcode) {
> +	case BPF_JEQ:		return BPF_JNE;
> +	case BPF_JNE:		return BPF_JEQ;
> +	/* JSET doesn't have it's reverse opcode in BPF, so add
> +	 * BPF_X flag to denote the reverse of that operation
>  	 */
> -	if (false_reg1->type != SCALAR_VALUE || false_reg2->type != SCALAR_VALUE)
> -		return;
> -
> -	/* we expect right-hand registers (src ones) to be constants, for now */
> -	if (!is_reg_const(false_reg2, is_jmp32)) {
> -		opcode = flip_opcode(opcode);
> -		swap(true_reg1, true_reg2);
> -		swap(false_reg1, false_reg2);
> +	case BPF_JSET:		return BPF_JSET | BPF_X;
> +	case BPF_JSET | BPF_X:	return BPF_JSET;
> +	case BPF_JGE:		return BPF_JLT;
> +	case BPF_JGT:		return BPF_JLE;
> +	case BPF_JLE:		return BPF_JGT;
> +	case BPF_JLT:		return BPF_JGE;
> +	case BPF_JSGE:		return BPF_JSLT;
> +	case BPF_JSGT:		return BPF_JSLE;
> +	case BPF_JSLE:		return BPF_JSGT;
> +	case BPF_JSLT:		return BPF_JSGE;
> +	default:		return 0;
>  	}
> -	if (!is_reg_const(false_reg2, is_jmp32))
> -		return;
> +}
>  
> -	false_32off = tnum_subreg(false_reg1->var_off);
> -	false_64off = false_reg1->var_off;
> -	true_32off = tnum_subreg(true_reg1->var_off);
> -	true_64off = true_reg1->var_off;
> -	uval = false_reg2->var_off.value;
> -	uval32 = (u32)tnum_subreg(false_reg2->var_off).value;
> -	sval = (s64)uval;
> -	sval32 = (s32)uval32;
> +/* Refine range knowledge for <reg1> <op> <reg>2 conditional operation. */
> +static void regs_refine_cond_op(struct bpf_reg_state *reg1, struct bpf_reg_state *reg2,
> +				u8 opcode, bool is_jmp32)
> +{
> +	struct tnum t;
>  
>  	switch (opcode) {
> -	/* JEQ/JNE comparison doesn't change the register equivalence.
> -	 *
> -	 * r1 = r2;
> -	 * if (r1 == 42) goto label;
> -	 * ...
> -	 * label: // here both r1 and r2 are known to be 42.
> -	 *
> -	 * Hence when marking register as known preserve it's ID.
> -	 */
>  	case BPF_JEQ:
>  		if (is_jmp32) {
> -			__mark_reg32_known(true_reg1, uval32);
> -			true_32off = tnum_subreg(true_reg1->var_off);
> +			reg1->u32_min_value = max(reg1->u32_min_value, reg2->u32_min_value);
> +			reg1->u32_max_value = min(reg1->u32_max_value, reg2->u32_max_value);
> +			reg1->s32_min_value = max(reg1->s32_min_value, reg2->s32_min_value);
> +			reg1->s32_max_value = min(reg1->s32_max_value, reg2->s32_max_value);
> +			reg2->u32_min_value = reg1->u32_min_value;
> +			reg2->u32_max_value = reg1->u32_max_value;
> +			reg2->s32_min_value = reg1->s32_min_value;
> +			reg2->s32_max_value = reg1->s32_max_value;
> +
> +			t = tnum_intersect(tnum_subreg(reg1->var_off), tnum_subreg(reg2->var_off));
> +			reg1->var_off = tnum_with_subreg(reg1->var_off, t);
> +			reg2->var_off = tnum_with_subreg(reg2->var_off, t);
>  		} else {
> -			___mark_reg_known(true_reg1, uval);
> -			true_64off = true_reg1->var_off;
> +			reg1->umin_value = max(reg1->umin_value, reg2->umin_value);
> +			reg1->umax_value = min(reg1->umax_value, reg2->umax_value);
> +			reg1->smin_value = max(reg1->smin_value, reg2->smin_value);
> +			reg1->smax_value = min(reg1->smax_value, reg2->smax_value);
> +			reg2->umin_value = reg1->umin_value;
> +			reg2->umax_value = reg1->umax_value;
> +			reg2->smin_value = reg1->smin_value;
> +			reg2->smax_value = reg1->smax_value;
> +
> +			reg1->var_off = tnum_intersect(reg1->var_off, reg2->var_off);
> +			reg2->var_off = reg1->var_off;
>  		}
>  		break;
>  	case BPF_JNE:
> +		/* we don't derive any new information for inequality yet */
> +		break;
> +	case BPF_JSET:
> +	case BPF_JSET | BPF_X: { /* BPF_JSET and its reverse, see rev_opcode() */
> +		u64 val;
> +
> +		if (!is_reg_const(reg2, is_jmp32))
> +			swap(reg1, reg2);
> +		if (!is_reg_const(reg2, is_jmp32))
> +			break;
> +
> +		val = reg_const_value(reg2, is_jmp32);
> +		/* BPF_JSET (i.e., TRUE branch, *not* BPF_JSET | BPF_X)
> +		 * requires single bit to learn something useful. E.g., if we
> +		 * know that `r1 & 0x3` is true, then which bits (0, 1, or both)
> +		 * are actually set? We can learn something definite only if
> +		 * it's a single-bit value to begin with.
> +		 *
> +		 * BPF_JSET | BPF_X (i.e., negation of BPF_JSET) doesn't have
> +		 * this restriction. I.e., !(r1 & 0x3) means neither bit 0 nor
> +		 * bit 1 is set, which we can readily use in adjustments.
> +		 */
> +		if (!(opcode & BPF_X) && !is_power_of_2(val))
> +			break;
> +
>  		if (is_jmp32) {
> -			__mark_reg32_known(false_reg1, uval32);
> -			false_32off = tnum_subreg(false_reg1->var_off);
> +			if (opcode & BPF_X)
> +				t = tnum_and(tnum_subreg(reg1->var_off), tnum_const(~val));
> +			else
> +				t = tnum_or(tnum_subreg(reg1->var_off), tnum_const(val));
> +			reg1->var_off = tnum_with_subreg(reg1->var_off, t);
>  		} else {
> -			___mark_reg_known(false_reg1, uval);
> -			false_64off = false_reg1->var_off;
> +			if (opcode & BPF_X)
> +				reg1->var_off = tnum_and(reg1->var_off, tnum_const(~val));
> +			else
> +				reg1->var_off = tnum_or(reg1->var_off, tnum_const(val));
>  		}
>  		break;

Since you're already adding a tnum helper, I think we can add one more
for BPF_JSET here

	struct tnum tnum_neg(struct tnum a)
	{
		return TNUM(~a.value, a.mask);
	}

So instead of getting a value out of the tnum and then putting the value
back into a tnum again

    u64 val;
    val = reg_const_value(reg2, is_jmp32);
    tnum_ops(..., tnum_const(val or ~val));

keep the value in a tnum and process it as-is if possible

    tnum_ops(..., reg2->var_off or tnum_neg(reg2->var_off));

And with that, hopefully this fragment becomes short enough that we don't
mind duplicating a bit of code to separate the BPF_JSET case from the
BPF_JSET | BPF_X case. IMO a conditional is_power_of_2 check followed by
two levels of branching is a bit too much to follow; it is better to have
them separated, just like you're doing for the others already.

I.e. something like the following

	case BPF_JSET: {
		if (!is_reg_const(reg2, is_jmp32))
			swap(reg1, reg2);
		if (!is_reg_const(reg2, is_jmp32))
			break;
		/* comment */
		if (!is_power_of_2(reg_const_value(reg2, is_jmp32)))
			break;

		if (is_jmp32) {
			t = tnum_or(tnum_subreg(reg1->var_off), tnum_subreg(reg2->var_off));
			reg1->var_off = tnum_with_subreg(reg1->var_off, t);
		} else {
			reg1->var_off = tnum_or(reg1->var_off, reg2->var_off);
		}
		break;
	}
	case BPF_JSET | BPF_X: {
		if (!is_reg_const(reg2, is_jmp32))
			swap(reg1, reg2);
		if (!is_reg_const(reg2, is_jmp32))
			break;

		if (is_jmp32) {
			/* a slightly long line ... */
			t = tnum_and(tnum_subreg(reg1->var_off), tnum_neg(tnum_subreg(reg2->var_off)));
			reg1->var_off = tnum_with_subreg(reg1->var_off, t);
		} else {
			reg1->var_off = tnum_and(reg1->var_off, tnum_neg(reg2->var_off));
		}
		break;
	}

> ...
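
To make the JSET bit refinement being discussed here concrete, below is a
minimal self-contained sketch; the two helpers mirror the kernel's
tnum_or()/tnum_and(), while the register values are arbitrary examples:

	#include <inttypes.h>
	#include <stdio.h>

	/* invariant: value & mask == 0; mask marks the unknown bits */
	struct tnum { uint64_t value, mask; };

	static struct tnum tnum_or(struct tnum a, struct tnum b)
	{
		uint64_t v = a.value | b.value;
		uint64_t mu = a.mask | b.mask;

		return (struct tnum){ v, mu & ~v };
	}

	static struct tnum tnum_and(struct tnum a, struct tnum b)
	{
		uint64_t alpha = a.value | a.mask;
		uint64_t beta = b.value | b.mask;
		uint64_t v = a.value & b.value;

		return (struct tnum){ v, alpha & beta & ~v };
	}

	int main(void)
	{
		struct tnum r1 = { 0, ~0ULL };	/* fully unknown scalar */

		/* TRUE branch of `if (r1 & 0x4)`: 0x4 is a power of two, so
		 * bit 2 must be set -> OR it into the known value bits.
		 */
		struct tnum t = tnum_or(r1, (struct tnum){ 0x4, 0 });
		/* FALSE branch of `if (r1 & 0x3)`: neither bit 0 nor bit 1
		 * can be set -> AND with ~0x3 to make them known zeros.
		 */
		struct tnum f = tnum_and(r1, (struct tnum){ ~0x3ULL, 0 });

		printf("true:  value=%#" PRIx64 " mask=%#" PRIx64 "\n", t.value, t.mask);
		printf("false: value=%#" PRIx64 " mask=%#" PRIx64 "\n", f.value, f.mask);
		return 0;
	}

For the is_jmp32 case the kernel applies the same operations to the lower
32 bits only, via tnum_subreg() and the new tnum_with_subreg() helper.
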
Shung-Hsi Yu Nov. 3, 2023, 8:33 a.m. UTC | #2
On Fri, Nov 03, 2023 at 03:52:36PM +0800, Shung-Hsi Yu wrote:
> [...]
> Since you're already adding a tnum helper, I think we can add one more
> for BPF_JSET here
> 
> 	struct tnum tnum_neg(struct tnum a)
> 	{
> 		return TNUM(~a.value, a.mask);
> 	}

Didn't think it through well enough: with the above we might end up with an
invalid tnum because the unknown bits get negated as well; we need to mask
the unknown bits out.

 	struct tnum tnum_neg(struct tnum a)
 	{
 		return TNUM(~a.value & ~a.mask, a.mask);
 	}
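
Concretely, take a 4-bit example with a = TNUM(value=0b0100, mask=0b0010),
i.e. bit 1 unknown. The first version returns value ~0b0100 = 0b1011, which
has bit 1 set in both value and mask, violating the value & mask == 0
invariant. With the masking, ~0b0100 & ~0b0010 = 0b1001: only the known
bits get negated and bit 1 stays unknown.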

> [...]
Eduard Zingerman Nov. 3, 2023, 4:20 p.m. UTC | #3
On Thu, 2023-11-02 at 17:08 -0700, Andrii Nakryiko wrote:
> [...]
> Signed-off-by: Andrii Nakryiko <andrii@kernel.org>

Acked-by: Eduard Zingerman <eddyz87@gmail.com>

(With one bit of bikeshedding below).

> [...]
> -	case BPF_JSET:
> +	}
> +	case BPF_JGE:
>  		if (is_jmp32) {
> -			false_32off = tnum_and(false_32off, tnum_const(~uval32));
> -			if (is_power_of_2(uval32))
> -				true_32off = tnum_or(true_32off,
> -						     tnum_const(uval32));
> +			reg1->u32_min_value = max(reg1->u32_min_value, reg2->u32_min_value);
> +			reg2->u32_max_value = min(reg1->u32_max_value, reg2->u32_max_value);
>  		} else {
> -			false_64off = tnum_and(false_64off, tnum_const(~uval));
> -			if (is_power_of_2(uval))
> -				true_64off = tnum_or(true_64off,
> -						     tnum_const(uval));
> +			reg1->umin_value = max(reg1->umin_value, reg2->umin_value);
> +			reg2->umax_value = min(reg1->umax_value, reg2->umax_value);
>  		}
>  		break;
> -	case BPF_JGE:
>  	case BPF_JGT:
> -	{
>  		if (is_jmp32) {
> -			u32 false_umax = opcode == BPF_JGT ? uval32  : uval32 - 1;
> -			u32 true_umin = opcode == BPF_JGT ? uval32 + 1 : uval32;
> -
> -			false_reg1->u32_max_value = min(false_reg1->u32_max_value,
> -						       false_umax);
> -			true_reg1->u32_min_value = max(true_reg1->u32_min_value,
> -						      true_umin);
> +			reg1->u32_min_value = max(reg1->u32_min_value, reg2->u32_min_value + 1);
> +			reg2->u32_max_value = min(reg1->u32_max_value - 1, reg2->u32_max_value);
>  		} else {
> -			u64 false_umax = opcode == BPF_JGT ? uval    : uval - 1;
> -			u64 true_umin = opcode == BPF_JGT ? uval + 1 : uval;
> -
> -			false_reg1->umax_value = min(false_reg1->umax_value, false_umax);
> -			true_reg1->umin_value = max(true_reg1->umin_value, true_umin);
> +			reg1->umin_value = max(reg1->umin_value, reg2->umin_value + 1);
> +			reg2->umax_value = min(reg1->umax_value - 1, reg2->umax_value);
>  		}
>  		break;
> -	}
>  	case BPF_JSGE:
> +		if (is_jmp32) {
> +			reg1->s32_min_value = max(reg1->s32_min_value, reg2->s32_min_value);
> +			reg2->s32_max_value = min(reg1->s32_max_value, reg2->s32_max_value);
> +		} else {
> +			reg1->smin_value = max(reg1->smin_value, reg2->smin_value);
> +			reg2->smax_value = min(reg1->smax_value, reg2->smax_value);
> +		}
> +		break;
>  	case BPF_JSGT:

It is possible to spare some code by swapping arguments here:

	case BPF_JLE:
	case BPF_JLT:
	case BPF_JSLE:
	case BPF_JSLT:
		return regs_refine_cond_op(reg2, reg1, flip_opcode(opcode), is_jmp32);


> -	{
>  		if (is_jmp32) {
> -			s32 false_smax = opcode == BPF_JSGT ? sval32    : sval32 - 1;
> -			s32 true_smin = opcode == BPF_JSGT ? sval32 + 1 : sval32;
> -
> -			false_reg1->s32_max_value = min(false_reg1->s32_max_value, false_smax);
> -			true_reg1->s32_min_value = max(true_reg1->s32_min_value, true_smin);
> +			reg1->s32_min_value = max(reg1->s32_min_value, reg2->s32_min_value + 1);
> +			reg2->s32_max_value = min(reg1->s32_max_value - 1, reg2->s32_max_value);
>  		} else {
> -			s64 false_smax = opcode == BPF_JSGT ? sval    : sval - 1;
> -			s64 true_smin = opcode == BPF_JSGT ? sval + 1 : sval;
> -
> -			false_reg1->smax_value = min(false_reg1->smax_value, false_smax);
> -			true_reg1->smin_value = max(true_reg1->smin_value, true_smin);
> +			reg1->smin_value = max(reg1->smin_value, reg2->smin_value + 1);
> +			reg2->smax_value = min(reg1->smax_value - 1, reg2->smax_value);
>  		}
>  		break;
> -	}
>  	case BPF_JLE:
> +		if (is_jmp32) {
> +			reg1->u32_max_value = min(reg1->u32_max_value, reg2->u32_max_value);
> +			reg2->u32_min_value = max(reg1->u32_min_value, reg2->u32_min_value);
> +		} else {
> +			reg1->umax_value = min(reg1->umax_value, reg2->umax_value);
> +			reg2->umin_value = max(reg1->umin_value, reg2->umin_value);
> +		}
> +		break;
>  	case BPF_JLT:
> -	{
>  		if (is_jmp32) {
> -			u32 false_umin = opcode == BPF_JLT ? uval32  : uval32 + 1;
> -			u32 true_umax = opcode == BPF_JLT ? uval32 - 1 : uval32;
> -
> -			false_reg1->u32_min_value = max(false_reg1->u32_min_value,
> -						       false_umin);
> -			true_reg1->u32_max_value = min(true_reg1->u32_max_value,
> -						      true_umax);
> +			reg1->u32_max_value = min(reg1->u32_max_value, reg2->u32_max_value - 1);
> +			reg2->u32_min_value = max(reg1->u32_min_value + 1, reg2->u32_min_value);
>  		} else {
> -			u64 false_umin = opcode == BPF_JLT ? uval    : uval + 1;
> -			u64 true_umax = opcode == BPF_JLT ? uval - 1 : uval;
> -
> -			false_reg1->umin_value = max(false_reg1->umin_value, false_umin);
> -			true_reg1->umax_value = min(true_reg1->umax_value, true_umax);
> +			reg1->umax_value = min(reg1->umax_value, reg2->umax_value - 1);
> +			reg2->umin_value = max(reg1->umin_value + 1, reg2->umin_value);
>  		}
>  		break;
> -	}
>  	case BPF_JSLE:
> +		if (is_jmp32) {
> +			reg1->s32_max_value = min(reg1->s32_max_value, reg2->s32_max_value);
> +			reg2->s32_min_value = max(reg1->s32_min_value, reg2->s32_min_value);
> +		} else {
> +			reg1->smax_value = min(reg1->smax_value, reg2->smax_value);
> +			reg2->smin_value = max(reg1->smin_value, reg2->smin_value);
> +		}
> +		break;
>  	case BPF_JSLT:
> -	{
>  		if (is_jmp32) {
> -			s32 false_smin = opcode == BPF_JSLT ? sval32    : sval32 + 1;
> -			s32 true_smax = opcode == BPF_JSLT ? sval32 - 1 : sval32;
> -
> -			false_reg1->s32_min_value = max(false_reg1->s32_min_value, false_smin);
> -			true_reg1->s32_max_value = min(true_reg1->s32_max_value, true_smax);
> +			reg1->s32_max_value = min(reg1->s32_max_value, reg2->s32_max_value - 1);
> +			reg2->s32_min_value = max(reg1->s32_min_value + 1, reg2->s32_min_value);
>  		} else {
> -			s64 false_smin = opcode == BPF_JSLT ? sval    : sval + 1;
> -			s64 true_smax = opcode == BPF_JSLT ? sval - 1 : sval;
> -
> -			false_reg1->smin_value = max(false_reg1->smin_value, false_smin);
> -			true_reg1->smax_value = min(true_reg1->smax_value, true_smax);
> +			reg1->smax_value = min(reg1->smax_value, reg2->smax_value - 1);
> +			reg2->smin_value = max(reg1->smin_value + 1, reg2->smin_value);
>  		}
>  		break;
> -	}
>  	default:
>  		return;
>  	}
> -
> -	if (is_jmp32) {
> -		false_reg1->var_off = tnum_or(tnum_clear_subreg(false_64off),
> -					     tnum_subreg(false_32off));
> -		true_reg1->var_off = tnum_or(tnum_clear_subreg(true_64off),
> -					    tnum_subreg(true_32off));
> -		reg_bounds_sync(false_reg1);
> -		reg_bounds_sync(true_reg1);
> -	} else {
> -		false_reg1->var_off = false_64off;
> -		true_reg1->var_off = true_64off;
> -		reg_bounds_sync(false_reg1);
> -		reg_bounds_sync(true_reg1);
> -	}
> -}
> -
> -/* Regs are known to be equal, so intersect their min/max/var_off */
> -static void __reg_combine_min_max(struct bpf_reg_state *src_reg,
> -				  struct bpf_reg_state *dst_reg)
> -{
> -	src_reg->umin_value = dst_reg->umin_value = max(src_reg->umin_value,
> -							dst_reg->umin_value);
> -	src_reg->umax_value = dst_reg->umax_value = min(src_reg->umax_value,
> -							dst_reg->umax_value);
> -	src_reg->smin_value = dst_reg->smin_value = max(src_reg->smin_value,
> -							dst_reg->smin_value);
> -	src_reg->smax_value = dst_reg->smax_value = min(src_reg->smax_value,
> -							dst_reg->smax_value);
> -	src_reg->var_off = dst_reg->var_off = tnum_intersect(src_reg->var_off,
> -							     dst_reg->var_off);
> -	reg_bounds_sync(src_reg);
> -	reg_bounds_sync(dst_reg);
>  }
>  
> -static void reg_combine_min_max(struct bpf_reg_state *true_src,
> -				struct bpf_reg_state *true_dst,
> -				struct bpf_reg_state *false_src,
> -				struct bpf_reg_state *false_dst,
> -				u8 opcode)
> +/* Adjusts the register min/max values in the case that the dst_reg and
> + * src_reg are both SCALAR_VALUE registers (or we are simply doing a BPF_K
> + * check, in which case we havea fake SCALAR_VALUE representing insn->imm).
> + * Technically we can do similar adjustments for pointers to the same object,
> + * but we don't support that right now.
> + */
> +static void reg_set_min_max(struct bpf_reg_state *true_reg1,
> +			    struct bpf_reg_state *true_reg2,
> +			    struct bpf_reg_state *false_reg1,
> +			    struct bpf_reg_state *false_reg2,
> +			    u8 opcode, bool is_jmp32)
>  {
> -	switch (opcode) {
> -	case BPF_JEQ:
> -		__reg_combine_min_max(true_src, true_dst);
> -		break;
> -	case BPF_JNE:
> -		__reg_combine_min_max(false_src, false_dst);
> -		break;
> -	}
> +	/* If either register is a pointer, we can't learn anything about its
> +	 * variable offset from the compare (unless they were a pointer into
> +	 * the same object, but we don't bother with that).
> +	 */
> +	if (false_reg1->type != SCALAR_VALUE || false_reg2->type != SCALAR_VALUE)
> +		return;
> +
> +	/* fallthrough (FALSE) branch */
> +	regs_refine_cond_op(false_reg1, false_reg2, rev_opcode(opcode), is_jmp32);
> +	reg_bounds_sync(false_reg1);
> +	reg_bounds_sync(false_reg2);
> +
> +	/* jump (TRUE) branch */
> +	regs_refine_cond_op(true_reg1, true_reg2, opcode, is_jmp32);
> +	reg_bounds_sync(true_reg1);
> +	reg_bounds_sync(true_reg2);
>  }
>  
>  static void mark_ptr_or_null_reg(struct bpf_func_state *state,
> @@ -14887,22 +14880,12 @@ static int check_cond_jmp_op(struct bpf_verifier_env *env,
>  		reg_set_min_max(&other_branch_regs[insn->dst_reg],
>  				&other_branch_regs[insn->src_reg],
>  				dst_reg, src_reg, opcode, is_jmp32);
> -
> -		if (dst_reg->type == SCALAR_VALUE &&
> -		    src_reg->type == SCALAR_VALUE &&
> -		    !is_jmp32 && (opcode == BPF_JEQ || opcode == BPF_JNE)) {
> -			/* Comparing for equality, we can combine knowledge */
> -			reg_combine_min_max(&other_branch_regs[insn->src_reg],
> -					    &other_branch_regs[insn->dst_reg],
> -					    src_reg, dst_reg, opcode);
> -		}
>  	} else /* BPF_SRC(insn->code) == BPF_K */ {
>  		reg_set_min_max(&other_branch_regs[insn->dst_reg],
>  				src_reg /* fake one */,
>  				dst_reg, src_reg /* same fake one */,
>  				opcode, is_jmp32);
>  	}
> -
>  	if (BPF_SRC(insn->code) == BPF_X &&
>  	    src_reg->type == SCALAR_VALUE && src_reg->id &&
>  	    !WARN_ON_ONCE(src_reg->id != other_branch_regs[insn->src_reg].id)) {
Andrii Nakryiko Nov. 3, 2023, 8:39 p.m. UTC | #4
On Fri, Nov 3, 2023 at 9:20 AM Eduard Zingerman <eddyz87@gmail.com> wrote:
>
> On Thu, 2023-11-02 at 17:08 -0700, Andrii Nakryiko wrote:
> > [...]
> > Signed-off-by: Andrii Nakryiko <andrii@kernel.org>
>
> Acked-by: Eduard Zingerman <eddyz87@gmail.com>
>
> (With one bit of bikeshedding below).
>
> > ---
> >  include/linux/tnum.h  |   4 +
> >  kernel/bpf/tnum.c     |   7 +-
> >  kernel/bpf/verifier.c | 327 ++++++++++++++++++++----------------------
> >  3 files changed, 165 insertions(+), 173 deletions(-)
> >

please trim irrelevant parts

[...]

> >       case BPF_JSGE:
> > +             if (is_jmp32) {
> > +                     reg1->s32_min_value = max(reg1->s32_min_value, reg2->s32_min_value);
> > +                     reg2->s32_max_value = min(reg1->s32_max_value, reg2->s32_max_value);
> > +             } else {
> > +                     reg1->smin_value = max(reg1->smin_value, reg2->smin_value);
> > +                     reg2->smax_value = min(reg1->smax_value, reg2->smax_value);
> > +             }
> > +             break;
> >       case BPF_JSGT:
>
> It is possible to spare some code by swapping arguments here:
>
>         case BPF_JLE:
>         case BPF_JLT:
>         case BPF_JSLE:
>         case BPF_JSLT:
>                 return regs_refine_cond_op(reg2, reg1, flip_opcode(opcode), is_jmp32);

yep, math is nice like that :) I'm a bit hesitant to add
recursive-looking calls (even though it's not recursion), so maybe
I'll just do:

case BPF_JLE:
case BPF_JLT:
case BPF_JSLE:
case BPF_JSLT:
    opcode = flip_opcode(opcode);
    swap(reg1, reg2);
    goto again;


and goto again will just jump to the beginning of this function?

Oh, and I more naturally think about LT/LE as "base conditions", so
I'll do the above for GE/GT operations.
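
For illustration, the identity this relies on -- refining "r1 > r2" is the
same as refining "r2 < r1" with the operands swapped -- can be checked with
a small self-contained sketch (unsigned bounds only; the names are
illustrative, not the actual verifier code):

	#include <assert.h>
	#include <inttypes.h>
	#include <stdio.h>

	struct urange { uint64_t umin, umax; };

	#define min(a, b) ((a) < (b) ? (a) : (b))
	#define max(a, b) ((a) > (b) ? (a) : (b))

	/* direct refinement for "r1 > r2" (BPF_JGT, TRUE branch) */
	static void refine_ugt(struct urange *r1, struct urange *r2)
	{
		r1->umin = max(r1->umin, r2->umin + 1);
		r2->umax = min(r2->umax, r1->umax - 1);
	}

	/* "base condition" refinement for "r1 < r2" (BPF_JLT) */
	static void refine_ult(struct urange *r1, struct urange *r2)
	{
		r1->umax = min(r1->umax, r2->umax - 1);
		r2->umin = max(r2->umin, r1->umin + 1);
	}

	int main(void)
	{
		struct urange a1 = { 10, 100 }, a2 = { 5, 50 };
		struct urange b1 = a1, b2 = a2;

		refine_ugt(&a1, &a2);	/* handle JGT directly */
		refine_ult(&b2, &b1);	/* flip to JLT and swap operands */
		assert(a1.umin == b1.umin && a1.umax == b1.umax);
		assert(a2.umin == b2.umin && a2.umax == b2.umax);
		printf("r1=[%" PRIu64 ",%" PRIu64 "] r2=[%" PRIu64 ",%" PRIu64 "]\n",
		       a1.umin, a1.umax, a2.umin, a2.umax);
		return 0;
	}

Either spelling (a swapped self-call or the flip-and-goto above) funnels
all four GT/GE cases through the LT/LE implementations.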


>
>
> > -     {
> >               if (is_jmp32) {
> > -                     s32 false_smax = opcode == BPF_JSGT ? sval32    : sval32 - 1;
> > -                     s32 true_smin = opcode == BPF_JSGT ? sval32 + 1 : sval32;
> > -
> > -                     false_reg1->s32_max_value = min(false_reg1->s32_max_value, false_smax);
> > -                     true_reg1->s32_min_value = max(true_reg1->s32_min_value, true_smin);
> > +                     reg1->s32_min_value = max(reg1->s32_min_value, reg2->s32_min_value + 1);
> > +                     reg2->s32_max_value = min(reg1->s32_max_value - 1, reg2->s32_max_value);

[...]
Andrii Nakryiko Nov. 3, 2023, 8:39 p.m. UTC | #5
On Fri, Nov 3, 2023 at 12:52 AM Shung-Hsi Yu <shung-hsi.yu@suse.com> wrote:
>
> On Thu, Nov 02, 2023 at 05:08:10PM -0700, Andrii Nakryiko wrote:
> > [...]
>
> Nit: rev_opcode and flip_opcode seem like a possible source of confusion
> down the line. Flip and reverse are often interchangeable, i.e. "flip the
> order" and "reverse the order" are the same thing.
>
> Maybe "neg_opcode" or "neg_cond_opcode"?

neg has too strong connotation with BPF_NEG, so not really happy with
this one. In selftest I used "complement_op", but it's also quite
arbitrary.
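
To spell out the distinction being named here (a sketch, with BPF_JLT as
the example):

    /* flip_opcode(): swap the operands, truth value preserved:
     *     r1 < r2    <=>  r2 > r1     (BPF_JLT -> BPF_JGT)
     * rev_opcode():  negate the condition, operands unchanged:
     *   !(r1 < r2)   <=>  r1 >= r2    (BPF_JLT -> BPF_JGE)
     */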

>
> Or do it the other way around, keep rev_opcode but rename flip_opcode.

how about flip_opcode -> swap_opcode? and then keep rev_opcode as is?

>
> One more comment about BPF_JSET below
>

please trim big chunks of code you are not commenting on to keep
emails a bit shorter

[...]


> >               if (is_jmp32) {
> > -                     __mark_reg32_known(false_reg1, uval32);
> > -                     false_32off = tnum_subreg(false_reg1->var_off);
> > +                     if (opcode & BPF_X)
> > +                             t = tnum_and(tnum_subreg(reg1->var_off), tnum_const(~val));
> > +                     else
> > +                             t = tnum_or(tnum_subreg(reg1->var_off), tnum_const(val));
> > +                     reg1->var_off = tnum_with_subreg(reg1->var_off, t);
> >               } else {
> > -                     ___mark_reg_known(false_reg1, uval);
> > -                     false_64off = false_reg1->var_off;
> > +                     if (opcode & BPF_X)
> > +                             reg1->var_off = tnum_and(reg1->var_off, tnum_const(~val));
> > +                     else
> > +                             reg1->var_off = tnum_or(reg1->var_off, tnum_const(val));
> >               }
> >               break;
>
> Since you're already adding a tnum helper, I think we can add one more
> for BPF_JSET here
>
>         struct tnum tnum_neg(struct tnum a)
>         {
>                 return TNUM(~a.value, a.mask);
>         }
>

I'm not sure what tnum_neg() does (or whether that's the correct
implementation), but either way I'd like to minimize touching tnum
stuff, it's too tricky :) we can address that as a separate patch if
you'd like
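
FWIW, for readers following along: a tnum is a (value, mask) pair with
the invariant value & mask == 0 -- bits set in mask are unknown, and
value carries the known bits. Bitwise NOT of a tnum flips the known bits
and keeps the unknown ones unknown, so a sketch that preserves the
invariant (hypothetical name; note the extra & ~a.mask compared to the
snippet above, which would otherwise leave mask bits set in value) could
look like:

        struct tnum tnum_bitwise_not(struct tnum a)
        {
                /* complement the known bits; clear the unknown
                 * positions in value to keep value & mask == 0
                 */
                return TNUM(~a.value & ~a.mask, a.mask);
        }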


> So instead of getting a value out of tnum then putting the value back
> into tnum again
>
>     u64 val;
>     val = reg_const_value(reg2, is_jmp32);
>     tnum_ops(..., tnum_const(val or ~val));
>
> Keep the value in tnum and process it as-is if possible
>
>     tnum_ops(..., reg2->var_off or tnum_neg(reg2->var_off));

>
> And with that hopefully make this fragment short enough that we don't
> mind duplicating a bit of code to separate the BPF_JSET case from the
> BPF_JSET | BPF_X case. IMO a conditional is_power_of_2 check followed by
> two levels of branching is a bit too much to follow; it is better to have
> them separated just like you're doing for the others already.

I can split those two cases without any new tnum helpers; the
duplicated part is basically just the const checking, no big deal.
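
E.g., a rough sketch of what that split could look like (64-bit path
only; the is_jmp32 path would mirror it via tnum_subreg() and
tnum_with_subreg()):

        case BPF_JSET: {
                u64 val;

                if (!is_reg_const(reg2, is_jmp32))
                        swap(reg1, reg2);
                if (!is_reg_const(reg2, is_jmp32))
                        break;
                val = reg_const_value(reg2, is_jmp32);
                /* TRUE branch pins down a bit only if val is a single bit */
                if (!is_power_of_2(val))
                        break;
                reg1->var_off = tnum_or(reg1->var_off, tnum_const(val));
                break;
        }
        case BPF_JSET | BPF_X: {
                u64 val;

                if (!is_reg_const(reg2, is_jmp32))
                        swap(reg1, reg2);
                if (!is_reg_const(reg2, is_jmp32))
                        break;
                val = reg_const_value(reg2, is_jmp32);
                /* !(r1 & val): every bit set in val is known zero in r1 */
                reg1->var_off = tnum_and(reg1->var_off, tnum_const(~val));
                break;
        }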

>
> I.e. something like the following
>
>         case BPF_JSET: {
>                 if (!is_reg_const(reg2, is_jmp32))
>                         swap(reg1, reg2);
>                 if (!is_reg_const(reg2, is_jmp32))
>                         break;
>                 /* comment */
>                 if (!is_power_of_2(reg_const_value(reg2, is_jmp32)))
>                         break;
>
>                 if (is_jmp32) {
>                         t = tnum_or(tnum_subreg(reg1->var_off), tnum_subreg(reg2->var_off));
>                         reg1->var_off = tnum_with_subreg(reg1->var_off, t);
>                 } else {
>                         reg1->var_off = tnum_or(reg1->var_off, reg2->var_off);
>                 }
>                 break;
>         }
>         case BPF_JSET | BPF_X: {
>                 if (!is_reg_const(reg2, is_jmp32))
>                         swap(reg1, reg2);
>                 if (!is_reg_const(reg2, is_jmp32))
>                         break;
>
>                 if (is_jmp32) {
>                         /* a slightly long line ... */
>                         t = tnum_and(tnum_subreg(reg1->var_off), tnum_neg(tnum_subreg(reg2->var_off)));
>                         reg1->var_off = tnum_with_subreg(reg1->var_off, t);
>                 } else {
>                         reg1->var_off = tnum_and(reg1->var_off, tnum_neg(reg2->var_off));
>                 }
>                 break;
>         }
>
> > ...
Andrii Nakryiko Nov. 3, 2023, 8:48 p.m. UTC | #6
On Fri, Nov 3, 2023 at 1:39 PM Andrii Nakryiko
<andrii.nakryiko@gmail.com> wrote:
>
> On Fri, Nov 3, 2023 at 12:52 AM Shung-Hsi Yu <shung-hsi.yu@suse.com> wrote:
> >
> > On Thu, Nov 02, 2023 at 05:08:10PM -0700, Andrii Nakryiko wrote:
> > > Generalize bounds adjustment logic of reg_set_min_max() to handle not
> > > just register vs constant case, but in general any register vs any
> > > register cases. For most of the operations it's trivial extension based
> > > on range vs range comparison logic, we just need to properly pick
> > > min/max of a range to compare against min/max of the other range.
> > >
> > > For BPF_JSET we keep the original capabilities, just make sure JSET is
> > > integrated in the common framework. This is manifested in the
> > > internal-only BPF_KSET + BPF_X "opcode" to allow for simpler and more
> >                     ^ typo?
> >
> > Two more comments below
> >
> > > uniform rev_opcode() handling. See the code for details. This allows
> > > reusing exactly the same code for both TRUE and FALSE branches without
> > > explicitly handling each condition with custom code.
> > >
> > > Note also that now we don't need special handling of the BPF_JEQ/BPF_JNE
> > > case when none of the registers are constants. This is now just a normal
> > > generic case handled by reg_set_min_max().
> > >
> > > To make tnum handling cleaner, tnum_with_subreg() helper is added, as
> > > that's a common operation when dealing with 32-bit subregister bounds.
> > > This keeps the overall logic much less noisy when it comes to tnums.
> > >
> > > Signed-off-by: Andrii Nakryiko <andrii@kernel.org>
> > > ---
> > >  include/linux/tnum.h  |   4 +
> > >  kernel/bpf/tnum.c     |   7 +-
> > >  kernel/bpf/verifier.c | 327 ++++++++++++++++++++----------------------
> > >  3 files changed, 165 insertions(+), 173 deletions(-)
> > >
> > > diff --git a/include/linux/tnum.h b/include/linux/tnum.h
> > > index 1c3948a1d6ad..3c13240077b8 100644
> > > --- a/include/linux/tnum.h
> > > +++ b/include/linux/tnum.h
> > > @@ -106,6 +106,10 @@ int tnum_sbin(char *str, size_t size, struct tnum a);
> > >  struct tnum tnum_subreg(struct tnum a);
> > >  /* Returns the tnum with the lower 32-bit subreg cleared */
> > >  struct tnum tnum_clear_subreg(struct tnum a);
> > > +/* Returns the tnum with the lower 32-bit subreg in *reg* set to the lower
> > > + * 32-bit subreg in *subreg*
> > > + */
> > > +struct tnum tnum_with_subreg(struct tnum reg, struct tnum subreg);
> > >  /* Returns the tnum with the lower 32-bit subreg set to value */
> > >  struct tnum tnum_const_subreg(struct tnum a, u32 value);
> > >  /* Returns true if 32-bit subreg @a is a known constant*/
> > > diff --git a/kernel/bpf/tnum.c b/kernel/bpf/tnum.c
> > > index 3d7127f439a1..f4c91c9b27d7 100644
> > > --- a/kernel/bpf/tnum.c
> > > +++ b/kernel/bpf/tnum.c
> > > @@ -208,7 +208,12 @@ struct tnum tnum_clear_subreg(struct tnum a)
> > >       return tnum_lshift(tnum_rshift(a, 32), 32);
> > >  }
> > >
> > > +struct tnum tnum_with_subreg(struct tnum reg, struct tnum subreg)
> > > +{
> > > +     return tnum_or(tnum_clear_subreg(reg), tnum_subreg(subreg));
> > > +}
> > > +
> > >  struct tnum tnum_const_subreg(struct tnum a, u32 value)
> > >  {
> > > -     return tnum_or(tnum_clear_subreg(a), tnum_const(value));
> > > +     return tnum_with_subreg(a, tnum_const(value));
> > >  }
> > > diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c
> > > index 2197385d91dc..52934080042c 100644
> > > --- a/kernel/bpf/verifier.c
> > > +++ b/kernel/bpf/verifier.c
> > > @@ -14379,218 +14379,211 @@ static int is_branch_taken(struct bpf_reg_state *reg1, struct bpf_reg_state *reg
> > >       return is_scalar_branch_taken(reg1, reg2, opcode, is_jmp32);
> > >  }
> > >
> > > -/* Adjusts the register min/max values in the case that the dst_reg and
> > > - * src_reg are both SCALAR_VALUE registers (or we are simply doing a BPF_K
> > > - * check, in which case we havea fake SCALAR_VALUE representing insn->imm).
> > > - * Technically we can do similar adjustments for pointers to the same object,
> > > - * but we don't support that right now.
> > > +/* Opcode that corresponds to a *false* branch condition.
> > > + * E.g., if r1 < r2, then reverse (false) condition is r1 >= r2
> > >   */
> > > -static void reg_set_min_max(struct bpf_reg_state *true_reg1,
> > > -                         struct bpf_reg_state *true_reg2,
> > > -                         struct bpf_reg_state *false_reg1,
> > > -                         struct bpf_reg_state *false_reg2,
> > > -                         u8 opcode, bool is_jmp32)
> > > +static u8 rev_opcode(u8 opcode)
> >
> > Nit: rev_opcode and flip_opcode seem like a possible source of confusion
> > down the line. Flip and reverse are often interchangeable, i.e. "flip the
> > order" and "reverse the order" are the same thing.
> >
> > Maybe "neg_opcode" or "neg_cond_opcode"?
>
> neg has too strong connotation with BPF_NEG, so not really happy with
> this one. In selftest I used "complement_op", but it's also quite
> arbitrary.
>
> >
> > Or do it the other way around, keep rev_opcode but rename flip_opcode.
>
> how about flip_opcode -> swap_opcode? and then keep rev_opcode as is?

nah, swap_opcode sounds wrong as well. I guess I'll just leave it as is for now.

>
> >
> > One more comment about BPF_JSET below
> >
>
> please trim big chunks of code you are not commenting on to keep
> emails a bit shorter
>
> [...]
>
>
> > >               if (is_jmp32) {
> > > -                     __mark_reg32_known(false_reg1, uval32);
> > > -                     false_32off = tnum_subreg(false_reg1->var_off);
> > > +                     if (opcode & BPF_X)
> > > +                             t = tnum_and(tnum_subreg(reg1->var_off), tnum_const(~val));
> > > +                     else
> > > +                             t = tnum_or(tnum_subreg(reg1->var_off), tnum_const(val));
> > > +                     reg1->var_off = tnum_with_subreg(reg1->var_off, t);
> > >               } else {
> > > -                     ___mark_reg_known(false_reg1, uval);
> > > -                     false_64off = false_reg1->var_off;
> > > +                     if (opcode & BPF_X)
> > > +                             reg1->var_off = tnum_and(reg1->var_off, tnum_const(~val));
> > > +                     else
> > > +                             reg1->var_off = tnum_or(reg1->var_off, tnum_const(val));
> > >               }
> > >               break;
> >
> > Since you're already adding a tnum helper, I think we can add one more
> > for BPF_JSET here
> >
> >         struct tnum tnum_neg(struct tnum a)
> >         {
> >                 return TNUM(~a.value, a.mask);
> >         }
> >
>
> I'm not sure what tnum_neg() does (or whether that's the correct
> implementation), but either way I'd like to minimize touching tnum
> stuff, it's too tricky :) we can address that as a separate patch if
> you'd like
>
>
> > So instead of getting a value out of tnum then putting the value back
> > into tnum again
> >
> >     u64 val;
> >     val = reg_const_value(reg2, is_jmp32);
> >     tnum_ops(..., tnum_const(val or ~val));
> >
> > Keep the value in tnum and process it as-is if possible
> >
> >     tnum_ops(..., reg2->var_off or tnum_neg(reg2->var_off));
>
> >
> > And with that hopefully make this fragment short enough that we don't
> > mind duplicating a bit of code to separate the BPF_JSET case from the
> > BPF_JSET | BPF_X case. IMO a conditional is_power_of_2 check followed by
> > two levels of branching is a bit too much to follow; it is better to have
> > them separated just like you're doing for the others already.
>
> I can split those two cases without any new tnum helpers; the
> duplicated part is basically just the const checking, no big deal.
>
> >
> > I.e. something like the following
> >
> >         case BPF_JSET: {
> >                 if (!is_reg_const(reg2, is_jmp32))
> >                         swap(reg1, reg2);
> >                 if (!is_reg_const(reg2, is_jmp32))
> >                         break;
> >                 /* comment */
> >                 if (!is_power_of_2(reg_const_value(reg2, is_jmp32)))
> >                         break;
> >
> >                 if (is_jmp32) {
> >                         t = tnum_or(tnum_subreg(reg1->var_off), tnum_subreg(reg2->var_off));
> >                         reg1->var_off = tnum_with_subreg(reg1->var_off, t);
> >                 } else {
> >                         reg1->var_off = tnum_or(reg1->var_off, reg2->var_off);
> >                 }
> >                 break;
> >         }
> >         case BPF_JSET | BPF_X: {
> >                 if (!is_reg_const(reg2, is_jmp32))
> >                         swap(reg1, reg2);
> >                 if (!is_reg_const(reg2, is_jmp32))
> >                         break;
> >
> >                 if (is_jmp32) {
> >                         /* a slightly long line ... */
> >                         t = tnum_and(tnum_subreg(reg1->var_off), tnum_neg(tnum_subreg(reg2->var_off)));
> >                         reg1->var_off = tnum_with_subreg(reg1->var_off, t);
> >                 } else {
> >                         reg1->var_off = tnum_and(reg1->var_off, tnum_neg(reg2->var_off));
> >                 }
> >                 break;
> >         }
> >
> > > ...
Shung-Hsi Yu Nov. 6, 2023, 2:22 a.m. UTC | #7
On Fri, Nov 03, 2023 at 01:48:32PM -0700, Andrii Nakryiko wrote:
> On Fri, Nov 3, 2023 at 1:39 PM Andrii Nakryiko
> <andrii.nakryiko@gmail.com> wrote:
> > On Fri, Nov 3, 2023 at 12:52 AM Shung-Hsi Yu <shung-hsi.yu@suse.com> wrote:
> > > On Thu, Nov 02, 2023 at 05:08:10PM -0700, Andrii Nakryiko wrote:
> > > > Generalize bounds adjustment logic of reg_set_min_max() to handle not
> > > > just register vs constant case, but in general any register vs any
> > > > register cases. For most of the operations it's trivial extension based
> > > > on range vs range comparison logic, we just need to properly pick
> > > > min/max of a range to compare against min/max of the other range.
> > > >
> > > > For BPF_JSET we keep the original capabilities, just make sure JSET is
> > > > integrated in the common framework. This is manifested in the
> > > > internal-only BPF_KSET + BPF_X "opcode" to allow for simpler and more
> > >                     ^ typo?
> > >
> > > Two more comments below
> > >
> > > > uniform rev_opcode() handling. See the code for details. This allows
> > > > reusing exactly the same code for both TRUE and FALSE branches without
> > > > explicitly handling each condition with custom code.
> > > >
> > > > Note also that now we don't need special handling of the BPF_JEQ/BPF_JNE
> > > > case when none of the registers are constants. This is now just a normal
> > > > generic case handled by reg_set_min_max().
> > > >
> > > > To make tnum handling cleaner, tnum_with_subreg() helper is added, as
> > > > that's a common operation when dealing with 32-bit subregister bounds.
> > > > This keeps the overall logic much less noisy when it comes to tnums.
> > > >
> > > > Signed-off-by: Andrii Nakryiko <andrii@kernel.org>
> > > > ---
> > > >  include/linux/tnum.h  |   4 +
> > > >  kernel/bpf/tnum.c     |   7 +-
> > > >  kernel/bpf/verifier.c | 327 ++++++++++++++++++++----------------------
> > > >  3 files changed, 165 insertions(+), 173 deletions(-)

[...]

> > > > diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c
> > > > index 2197385d91dc..52934080042c 100644
> > > > --- a/kernel/bpf/verifier.c
> > > > +++ b/kernel/bpf/verifier.c
> > > > @@ -14379,218 +14379,211 @@ static int is_branch_taken(struct bpf_reg_state *reg1, struct bpf_reg_state *reg
> > > >       return is_scalar_branch_taken(reg1, reg2, opcode, is_jmp32);
> > > >  }
> > > >
> > > > -/* Adjusts the register min/max values in the case that the dst_reg and
> > > > - * src_reg are both SCALAR_VALUE registers (or we are simply doing a BPF_K
> > > > - * check, in which case we havea fake SCALAR_VALUE representing insn->imm).
> > > > - * Technically we can do similar adjustments for pointers to the same object,
> > > > - * but we don't support that right now.
> > > > +/* Opcode that corresponds to a *false* branch condition.
> > > > + * E.g., if r1 < r2, then reverse (false) condition is r1 >= r2
> > > >   */
> > > > -static void reg_set_min_max(struct bpf_reg_state *true_reg1,
> > > > -                         struct bpf_reg_state *true_reg2,
> > > > -                         struct bpf_reg_state *false_reg1,
> > > > -                         struct bpf_reg_state *false_reg2,
> > > > -                         u8 opcode, bool is_jmp32)
> > > > +static u8 rev_opcode(u8 opcode)
> > >
> > > Nit: rev_opcode and flip_opcode seem like a possible source of confusion
> > > down the line. Flip and reverse are often interchangeable, i.e. "flip the
> > > order" and "reverse the order" are the same thing.
> > >
> > > Maybe "neg_opcode" or "neg_cond_opcode"?
> >
> > neg has too strong connotation with BPF_NEG, so not really happy with
> > this one.

That's true.

> > In selftest I used "complement_op", but it's also quite arbitrary.
> >
> > > Or do it the other way around, keep rev_opcode but rename flip_opcode.
> >
> > how about flip_opcode -> swap_opcode? and then keep rev_opcode as is?
> 
> nah, swap_opcode sounds wrong as well. I guess I'll just leave it as is for now.

I don't have any better suggestion in mind, so no objection here.

> > >
> > > One more comment about BPF_JSET below
> >
> > please trim big chunks of code you are not commenting on to keep
> > emails a bit shorter

Noted, will do so next time.

> > [...]
> >
> > > >               if (is_jmp32) {
> > > > -                     __mark_reg32_known(false_reg1, uval32);
> > > > -                     false_32off = tnum_subreg(false_reg1->var_off);
> > > > +                     if (opcode & BPF_X)
> > > > +                             t = tnum_and(tnum_subreg(reg1->var_off), tnum_const(~val));
> > > > +                     else
> > > > +                             t = tnum_or(tnum_subreg(reg1->var_off), tnum_const(val));
> > > > +                     reg1->var_off = tnum_with_subreg(reg1->var_off, t);
> > > >               } else {
> > > > -                     ___mark_reg_known(false_reg1, uval);
> > > > -                     false_64off = false_reg1->var_off;
> > > > +                     if (opcode & BPF_X)
> > > > +                             reg1->var_off = tnum_and(reg1->var_off, tnum_const(~val));
> > > > +                     else
> > > > +                             reg1->var_off = tnum_or(reg1->var_off, tnum_const(val));
> > > >               }
> > > >               break;
> > >
> > > Since you're already adding a tnum helper, I think we can add one more
> > > for BPF_JSET here
> > >
> > >         struct tnum tnum_neg(struct tnum a)
> > >         {
> > >                 return TNUM(~a.value, a.mask);
> > >         }
> > >
> >
> > I'm not sure what tnum_neg() does (or whether that's the correct
> > implementation), but either way I'd like to minimize touching tnum
> > stuff, it's too tricky :) we can address that as a separate patch if
> > you'd like

Tricky, but not as tricky as this patchset :)

Seizing this chance for some shameless self-promotion of slides I
have on tnum

  https://docs.google.com/presentation/d/1Nz2AIvYwAi3rgMNiLV_bn5JjulHJynu9JHulNrTJuZU/edit#slide=id.g16cabc3ff80_0_87
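
For the unfamiliar, a one-line example of the representation:

        /* TNUM(0b100, 0b011): bit 2 known set, bits 0 and 1 unknown,
         * i.e. the set of concrete values {4, 5, 6, 7}
         */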

I've sent out the tnum change as an RFC for now[0]; will resend it along
with the changes proposed here once this patchset or its successor is
merged as suggested.

0: https://lore.kernel.org/bpf/20231106021119.10455-1-shung-hsi.yu@suse.com/

> > > So instead of getting a value out of tnum then putting the value back
> > > into tnum again
> > >
> > >     u64 val;
> > >     val = reg_const_value(reg2, is_jmp32);
> > >     tnum_ops(..., tnum_const(val or ~val));
> > >
> > > Keep the value in tnum and process it as-is if possible
> > >
> > >     tnum_ops(..., reg2->var_off or tnum_neg(reg2->var_off));
> >
> > >
> > > And with that hopefully make this fragment short enough that we don't
> > > mind duplicating a bit of code to separate the BPF_JSET case from the
> > > BPF_JSET | BPF_X case. IMO a conditional is_power_of_2 check followed by
> > > two levels of branching is a bit too much to follow; it is better to have
> > > them separated just like you're doing for the others already.
> >
> > I can split those two cases without any new tnum helpers; the
> > duplicated part is basically just the const checking, no big deal.
> >
> > >
> > > I.e. something like the following
> > >
> > >         case BPF_JSET: {
> > >                 if (!is_reg_const(reg2, is_jmp32))
> > >                         swap(reg1, reg2);
> > >                 if (!is_reg_const(reg2, is_jmp32))
> > >                         break;
> > >                 /* comment */
> > >                 if (!is_power_of_2(reg_const_value(reg2, is_jmp32)))
> > >                         break;
> > >
> > >                 if (is_jmp32) {
> > >                         t = tnum_or(tnum_subreg(reg1->var_off), tnum_subreg(reg2->var_off));
> > >                         reg1->var_off = tnum_with_subreg(reg1->var_off, t);
> > >                 } else {
> > >                         reg1->var_off = tnum_or(reg1->var_off, reg2->var_off);
> > >                 }
> > >                 break;
> > >         }
> > >         case BPF_JSET | BPF_X: {
> > >                 if (!is_reg_const(reg2, is_jmp32))
> > >                         swap(reg1, reg2);
> > >                 if (!is_reg_const(reg2, is_jmp32))
> > >                         break;
> > >
> > >                 if (is_jmp32) {
> > >                         /* a slightly long line ... */
> > >                         t = tnum_and(tnum_subreg(reg1->var_off), tnum_neg(tnum_subreg(reg2->var_off)));
> > >                         reg1->var_off = tnum_with_subreg(reg1->var_off, t);
> > >                 } else {
> > >                         reg1->var_off = tnum_and(reg1->var_off, tnum_neg(reg2->var_off));
> > >                 }
> > >                 break;
> > >         }
> > >
> > > > ...
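
To make the range-vs-range refinement concrete, here is a worked example
of the BPF_JGT case from regs_refine_cond_op() in the patch below
(unsigned 64-bit path, made-up bounds):

        /* Before `if r1 > r2 goto ...`:
         *   r1: umin=3,  umax=10
         *   r2: umin=5,  umax=8
         *
         * TRUE branch (BPF_JGT) -- note both registers are refined now:
         *   r1->umin_value = max(3, 5 + 1)  = 6;  // r1 > r2 and r2 >= 5
         *   r2->umax_value = min(10 - 1, 8) = 8;  // r2 < r1 and r1 <= 10
         *
         * FALSE branch uses rev_opcode(BPF_JGT) == BPF_JLE:
         *   r1->umax_value = min(10, 8) = 8;      // r1 <= r2 <= 8
         *   r2->umin_value = max(3, 5)  = 5;      // r2 >= r1 >= 3
         */
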
diff mbox series

Patch

diff --git a/include/linux/tnum.h b/include/linux/tnum.h
index 1c3948a1d6ad..3c13240077b8 100644
--- a/include/linux/tnum.h
+++ b/include/linux/tnum.h
@@ -106,6 +106,10 @@  int tnum_sbin(char *str, size_t size, struct tnum a);
 struct tnum tnum_subreg(struct tnum a);
 /* Returns the tnum with the lower 32-bit subreg cleared */
 struct tnum tnum_clear_subreg(struct tnum a);
+/* Returns the tnum with the lower 32-bit subreg in *reg* set to the lower
+ * 32-bit subreg in *subreg*
+ */
+struct tnum tnum_with_subreg(struct tnum reg, struct tnum subreg);
 /* Returns the tnum with the lower 32-bit subreg set to value */
 struct tnum tnum_const_subreg(struct tnum a, u32 value);
 /* Returns true if 32-bit subreg @a is a known constant*/
diff --git a/kernel/bpf/tnum.c b/kernel/bpf/tnum.c
index 3d7127f439a1..f4c91c9b27d7 100644
--- a/kernel/bpf/tnum.c
+++ b/kernel/bpf/tnum.c
@@ -208,7 +208,12 @@  struct tnum tnum_clear_subreg(struct tnum a)
 	return tnum_lshift(tnum_rshift(a, 32), 32);
 }
 
+struct tnum tnum_with_subreg(struct tnum reg, struct tnum subreg)
+{
+	return tnum_or(tnum_clear_subreg(reg), tnum_subreg(subreg));
+}
+
 struct tnum tnum_const_subreg(struct tnum a, u32 value)
 {
-	return tnum_or(tnum_clear_subreg(a), tnum_const(value));
+	return tnum_with_subreg(a, tnum_const(value));
 }
diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c
index 2197385d91dc..52934080042c 100644
--- a/kernel/bpf/verifier.c
+++ b/kernel/bpf/verifier.c
@@ -14379,218 +14379,211 @@  static int is_branch_taken(struct bpf_reg_state *reg1, struct bpf_reg_state *reg
 	return is_scalar_branch_taken(reg1, reg2, opcode, is_jmp32);
 }
 
-/* Adjusts the register min/max values in the case that the dst_reg and
- * src_reg are both SCALAR_VALUE registers (or we are simply doing a BPF_K
- * check, in which case we havea fake SCALAR_VALUE representing insn->imm).
- * Technically we can do similar adjustments for pointers to the same object,
- * but we don't support that right now.
+/* Opcode that corresponds to a *false* branch condition.
+ * E.g., if r1 < r2, then reverse (false) condition is r1 >= r2
  */
-static void reg_set_min_max(struct bpf_reg_state *true_reg1,
-			    struct bpf_reg_state *true_reg2,
-			    struct bpf_reg_state *false_reg1,
-			    struct bpf_reg_state *false_reg2,
-			    u8 opcode, bool is_jmp32)
+static u8 rev_opcode(u8 opcode)
 {
-	struct tnum false_32off, false_64off;
-	struct tnum true_32off, true_64off;
-	u64 uval;
-	u32 uval32;
-	s64 sval;
-	s32 sval32;
-
-	/* If either register is a pointer, we can't learn anything about its
-	 * variable offset from the compare (unless they were a pointer into
-	 * the same object, but we don't bother with that).
+	switch (opcode) {
+	case BPF_JEQ:		return BPF_JNE;
+	case BPF_JNE:		return BPF_JEQ;
+	/* JSET doesn't have its reverse opcode in BPF, so add
+	 * BPF_X flag to denote the reverse of that operation
 	 */
-	if (false_reg1->type != SCALAR_VALUE || false_reg2->type != SCALAR_VALUE)
-		return;
-
-	/* we expect right-hand registers (src ones) to be constants, for now */
-	if (!is_reg_const(false_reg2, is_jmp32)) {
-		opcode = flip_opcode(opcode);
-		swap(true_reg1, true_reg2);
-		swap(false_reg1, false_reg2);
+	case BPF_JSET:		return BPF_JSET | BPF_X;
+	case BPF_JSET | BPF_X:	return BPF_JSET;
+	case BPF_JGE:		return BPF_JLT;
+	case BPF_JGT:		return BPF_JLE;
+	case BPF_JLE:		return BPF_JGT;
+	case BPF_JLT:		return BPF_JGE;
+	case BPF_JSGE:		return BPF_JSLT;
+	case BPF_JSGT:		return BPF_JSLE;
+	case BPF_JSLE:		return BPF_JSGT;
+	case BPF_JSLT:		return BPF_JSGE;
+	default:		return 0;
 	}
-	if (!is_reg_const(false_reg2, is_jmp32))
-		return;
+}
 
-	false_32off = tnum_subreg(false_reg1->var_off);
-	false_64off = false_reg1->var_off;
-	true_32off = tnum_subreg(true_reg1->var_off);
-	true_64off = true_reg1->var_off;
-	uval = false_reg2->var_off.value;
-	uval32 = (u32)tnum_subreg(false_reg2->var_off).value;
-	sval = (s64)uval;
-	sval32 = (s32)uval32;
+/* Refine range knowledge for <reg1> <op> <reg2> conditional operation. */
+static void regs_refine_cond_op(struct bpf_reg_state *reg1, struct bpf_reg_state *reg2,
+				u8 opcode, bool is_jmp32)
+{
+	struct tnum t;
 
 	switch (opcode) {
-	/* JEQ/JNE comparison doesn't change the register equivalence.
-	 *
-	 * r1 = r2;
-	 * if (r1 == 42) goto label;
-	 * ...
-	 * label: // here both r1 and r2 are known to be 42.
-	 *
-	 * Hence when marking register as known preserve it's ID.
-	 */
 	case BPF_JEQ:
 		if (is_jmp32) {
-			__mark_reg32_known(true_reg1, uval32);
-			true_32off = tnum_subreg(true_reg1->var_off);
+			reg1->u32_min_value = max(reg1->u32_min_value, reg2->u32_min_value);
+			reg1->u32_max_value = min(reg1->u32_max_value, reg2->u32_max_value);
+			reg1->s32_min_value = max(reg1->s32_min_value, reg2->s32_min_value);
+			reg1->s32_max_value = min(reg1->s32_max_value, reg2->s32_max_value);
+			reg2->u32_min_value = reg1->u32_min_value;
+			reg2->u32_max_value = reg1->u32_max_value;
+			reg2->s32_min_value = reg1->s32_min_value;
+			reg2->s32_max_value = reg1->s32_max_value;
+
+			t = tnum_intersect(tnum_subreg(reg1->var_off), tnum_subreg(reg2->var_off));
+			reg1->var_off = tnum_with_subreg(reg1->var_off, t);
+			reg2->var_off = tnum_with_subreg(reg2->var_off, t);
 		} else {
-			___mark_reg_known(true_reg1, uval);
-			true_64off = true_reg1->var_off;
+			reg1->umin_value = max(reg1->umin_value, reg2->umin_value);
+			reg1->umax_value = min(reg1->umax_value, reg2->umax_value);
+			reg1->smin_value = max(reg1->smin_value, reg2->smin_value);
+			reg1->smax_value = min(reg1->smax_value, reg2->smax_value);
+			reg2->umin_value = reg1->umin_value;
+			reg2->umax_value = reg1->umax_value;
+			reg2->smin_value = reg1->smin_value;
+			reg2->smax_value = reg1->smax_value;
+
+			reg1->var_off = tnum_intersect(reg1->var_off, reg2->var_off);
+			reg2->var_off = reg1->var_off;
 		}
 		break;
 	case BPF_JNE:
+		/* we don't derive any new information for inequality yet */
+		break;
+	case BPF_JSET:
+	case BPF_JSET | BPF_X: { /* BPF_JSET and its reverse, see rev_opcode() */
+		u64 val;
+
+		if (!is_reg_const(reg2, is_jmp32))
+			swap(reg1, reg2);
+		if (!is_reg_const(reg2, is_jmp32))
+			break;
+
+		val = reg_const_value(reg2, is_jmp32);
+		/* BPF_JSET (i.e., TRUE branch, *not* BPF_JSET | BPF_X)
+		 * requires single bit to learn something useful. E.g., if we
+		 * know that `r1 & 0x3` is true, then which bits (0, 1, or both)
+		 * are actually set? We can learn something definite only if
+		 * it's a single-bit value to begin with.
+		 *
+		 * BPF_JSET | BPF_X (i.e., negation of BPF_JSET) doesn't have
+		 * this restriction. I.e., !(r1 & 0x3) means neither bit 0 nor
+		 * bit 1 is set, which we can readily use in adjustments.
+		 */
+		if (!(opcode & BPF_X) && !is_power_of_2(val))
+			break;
+
 		if (is_jmp32) {
-			__mark_reg32_known(false_reg1, uval32);
-			false_32off = tnum_subreg(false_reg1->var_off);
+			if (opcode & BPF_X)
+				t = tnum_and(tnum_subreg(reg1->var_off), tnum_const(~val));
+			else
+				t = tnum_or(tnum_subreg(reg1->var_off), tnum_const(val));
+			reg1->var_off = tnum_with_subreg(reg1->var_off, t);
 		} else {
-			___mark_reg_known(false_reg1, uval);
-			false_64off = false_reg1->var_off;
+			if (opcode & BPF_X)
+				reg1->var_off = tnum_and(reg1->var_off, tnum_const(~val));
+			else
+				reg1->var_off = tnum_or(reg1->var_off, tnum_const(val));
 		}
 		break;
-	case BPF_JSET:
+	}
+	case BPF_JGE:
 		if (is_jmp32) {
-			false_32off = tnum_and(false_32off, tnum_const(~uval32));
-			if (is_power_of_2(uval32))
-				true_32off = tnum_or(true_32off,
-						     tnum_const(uval32));
+			reg1->u32_min_value = max(reg1->u32_min_value, reg2->u32_min_value);
+			reg2->u32_max_value = min(reg1->u32_max_value, reg2->u32_max_value);
 		} else {
-			false_64off = tnum_and(false_64off, tnum_const(~uval));
-			if (is_power_of_2(uval))
-				true_64off = tnum_or(true_64off,
-						     tnum_const(uval));
+			reg1->umin_value = max(reg1->umin_value, reg2->umin_value);
+			reg2->umax_value = min(reg1->umax_value, reg2->umax_value);
 		}
 		break;
-	case BPF_JGE:
 	case BPF_JGT:
-	{
 		if (is_jmp32) {
-			u32 false_umax = opcode == BPF_JGT ? uval32  : uval32 - 1;
-			u32 true_umin = opcode == BPF_JGT ? uval32 + 1 : uval32;
-
-			false_reg1->u32_max_value = min(false_reg1->u32_max_value,
-						       false_umax);
-			true_reg1->u32_min_value = max(true_reg1->u32_min_value,
-						      true_umin);
+			reg1->u32_min_value = max(reg1->u32_min_value, reg2->u32_min_value + 1);
+			reg2->u32_max_value = min(reg1->u32_max_value - 1, reg2->u32_max_value);
 		} else {
-			u64 false_umax = opcode == BPF_JGT ? uval    : uval - 1;
-			u64 true_umin = opcode == BPF_JGT ? uval + 1 : uval;
-
-			false_reg1->umax_value = min(false_reg1->umax_value, false_umax);
-			true_reg1->umin_value = max(true_reg1->umin_value, true_umin);
+			reg1->umin_value = max(reg1->umin_value, reg2->umin_value + 1);
+			reg2->umax_value = min(reg1->umax_value - 1, reg2->umax_value);
 		}
 		break;
-	}
 	case BPF_JSGE:
+		if (is_jmp32) {
+			reg1->s32_min_value = max(reg1->s32_min_value, reg2->s32_min_value);
+			reg2->s32_max_value = min(reg1->s32_max_value, reg2->s32_max_value);
+		} else {
+			reg1->smin_value = max(reg1->smin_value, reg2->smin_value);
+			reg2->smax_value = min(reg1->smax_value, reg2->smax_value);
+		}
+		break;
 	case BPF_JSGT:
-	{
 		if (is_jmp32) {
-			s32 false_smax = opcode == BPF_JSGT ? sval32    : sval32 - 1;
-			s32 true_smin = opcode == BPF_JSGT ? sval32 + 1 : sval32;
-
-			false_reg1->s32_max_value = min(false_reg1->s32_max_value, false_smax);
-			true_reg1->s32_min_value = max(true_reg1->s32_min_value, true_smin);
+			reg1->s32_min_value = max(reg1->s32_min_value, reg2->s32_min_value + 1);
+			reg2->s32_max_value = min(reg1->s32_max_value - 1, reg2->s32_max_value);
 		} else {
-			s64 false_smax = opcode == BPF_JSGT ? sval    : sval - 1;
-			s64 true_smin = opcode == BPF_JSGT ? sval + 1 : sval;
-
-			false_reg1->smax_value = min(false_reg1->smax_value, false_smax);
-			true_reg1->smin_value = max(true_reg1->smin_value, true_smin);
+			reg1->smin_value = max(reg1->smin_value, reg2->smin_value + 1);
+			reg2->smax_value = min(reg1->smax_value - 1, reg2->smax_value);
 		}
 		break;
-	}
 	case BPF_JLE:
+		if (is_jmp32) {
+			reg1->u32_max_value = min(reg1->u32_max_value, reg2->u32_max_value);
+			reg2->u32_min_value = max(reg1->u32_min_value, reg2->u32_min_value);
+		} else {
+			reg1->umax_value = min(reg1->umax_value, reg2->umax_value);
+			reg2->umin_value = max(reg1->umin_value, reg2->umin_value);
+		}
+		break;
 	case BPF_JLT:
-	{
 		if (is_jmp32) {
-			u32 false_umin = opcode == BPF_JLT ? uval32  : uval32 + 1;
-			u32 true_umax = opcode == BPF_JLT ? uval32 - 1 : uval32;
-
-			false_reg1->u32_min_value = max(false_reg1->u32_min_value,
-						       false_umin);
-			true_reg1->u32_max_value = min(true_reg1->u32_max_value,
-						      true_umax);
+			reg1->u32_max_value = min(reg1->u32_max_value, reg2->u32_max_value - 1);
+			reg2->u32_min_value = max(reg1->u32_min_value + 1, reg2->u32_min_value);
 		} else {
-			u64 false_umin = opcode == BPF_JLT ? uval    : uval + 1;
-			u64 true_umax = opcode == BPF_JLT ? uval - 1 : uval;
-
-			false_reg1->umin_value = max(false_reg1->umin_value, false_umin);
-			true_reg1->umax_value = min(true_reg1->umax_value, true_umax);
+			reg1->umax_value = min(reg1->umax_value, reg2->umax_value - 1);
+			reg2->umin_value = max(reg1->umin_value + 1, reg2->umin_value);
 		}
 		break;
-	}
 	case BPF_JSLE:
+		if (is_jmp32) {
+			reg1->s32_max_value = min(reg1->s32_max_value, reg2->s32_max_value);
+			reg2->s32_min_value = max(reg1->s32_min_value, reg2->s32_min_value);
+		} else {
+			reg1->smax_value = min(reg1->smax_value, reg2->smax_value);
+			reg2->smin_value = max(reg1->smin_value, reg2->smin_value);
+		}
+		break;
 	case BPF_JSLT:
-	{
 		if (is_jmp32) {
-			s32 false_smin = opcode == BPF_JSLT ? sval32    : sval32 + 1;
-			s32 true_smax = opcode == BPF_JSLT ? sval32 - 1 : sval32;
-
-			false_reg1->s32_min_value = max(false_reg1->s32_min_value, false_smin);
-			true_reg1->s32_max_value = min(true_reg1->s32_max_value, true_smax);
+			reg1->s32_max_value = min(reg1->s32_max_value, reg2->s32_max_value - 1);
+			reg2->s32_min_value = max(reg1->s32_min_value + 1, reg2->s32_min_value);
 		} else {
-			s64 false_smin = opcode == BPF_JSLT ? sval    : sval + 1;
-			s64 true_smax = opcode == BPF_JSLT ? sval - 1 : sval;
-
-			false_reg1->smin_value = max(false_reg1->smin_value, false_smin);
-			true_reg1->smax_value = min(true_reg1->smax_value, true_smax);
+			reg1->smax_value = min(reg1->smax_value, reg2->smax_value - 1);
+			reg2->smin_value = max(reg1->smin_value + 1, reg2->smin_value);
 		}
 		break;
-	}
 	default:
 		return;
 	}
-
-	if (is_jmp32) {
-		false_reg1->var_off = tnum_or(tnum_clear_subreg(false_64off),
-					     tnum_subreg(false_32off));
-		true_reg1->var_off = tnum_or(tnum_clear_subreg(true_64off),
-					    tnum_subreg(true_32off));
-		reg_bounds_sync(false_reg1);
-		reg_bounds_sync(true_reg1);
-	} else {
-		false_reg1->var_off = false_64off;
-		true_reg1->var_off = true_64off;
-		reg_bounds_sync(false_reg1);
-		reg_bounds_sync(true_reg1);
-	}
-}
-
-/* Regs are known to be equal, so intersect their min/max/var_off */
-static void __reg_combine_min_max(struct bpf_reg_state *src_reg,
-				  struct bpf_reg_state *dst_reg)
-{
-	src_reg->umin_value = dst_reg->umin_value = max(src_reg->umin_value,
-							dst_reg->umin_value);
-	src_reg->umax_value = dst_reg->umax_value = min(src_reg->umax_value,
-							dst_reg->umax_value);
-	src_reg->smin_value = dst_reg->smin_value = max(src_reg->smin_value,
-							dst_reg->smin_value);
-	src_reg->smax_value = dst_reg->smax_value = min(src_reg->smax_value,
-							dst_reg->smax_value);
-	src_reg->var_off = dst_reg->var_off = tnum_intersect(src_reg->var_off,
-							     dst_reg->var_off);
-	reg_bounds_sync(src_reg);
-	reg_bounds_sync(dst_reg);
 }
 
-static void reg_combine_min_max(struct bpf_reg_state *true_src,
-				struct bpf_reg_state *true_dst,
-				struct bpf_reg_state *false_src,
-				struct bpf_reg_state *false_dst,
-				u8 opcode)
+/* Adjusts the register min/max values in the case that the dst_reg and
+ * src_reg are both SCALAR_VALUE registers (or we are simply doing a BPF_K
+ * check, in which case we have a fake SCALAR_VALUE representing insn->imm).
+ * Technically we can do similar adjustments for pointers to the same object,
+ * but we don't support that right now.
+ */
+static void reg_set_min_max(struct bpf_reg_state *true_reg1,
+			    struct bpf_reg_state *true_reg2,
+			    struct bpf_reg_state *false_reg1,
+			    struct bpf_reg_state *false_reg2,
+			    u8 opcode, bool is_jmp32)
 {
-	switch (opcode) {
-	case BPF_JEQ:
-		__reg_combine_min_max(true_src, true_dst);
-		break;
-	case BPF_JNE:
-		__reg_combine_min_max(false_src, false_dst);
-		break;
-	}
+	/* If either register is a pointer, we can't learn anything about its
+	 * variable offset from the compare (unless they were a pointer into
+	 * the same object, but we don't bother with that).
+	 */
+	if (false_reg1->type != SCALAR_VALUE || false_reg2->type != SCALAR_VALUE)
+		return;
+
+	/* fallthrough (FALSE) branch */
+	regs_refine_cond_op(false_reg1, false_reg2, rev_opcode(opcode), is_jmp32);
+	reg_bounds_sync(false_reg1);
+	reg_bounds_sync(false_reg2);
+
+	/* jump (TRUE) branch */
+	regs_refine_cond_op(true_reg1, true_reg2, opcode, is_jmp32);
+	reg_bounds_sync(true_reg1);
+	reg_bounds_sync(true_reg2);
 }
 
 static void mark_ptr_or_null_reg(struct bpf_func_state *state,
@@ -14887,22 +14880,12 @@  static int check_cond_jmp_op(struct bpf_verifier_env *env,
 		reg_set_min_max(&other_branch_regs[insn->dst_reg],
 				&other_branch_regs[insn->src_reg],
 				dst_reg, src_reg, opcode, is_jmp32);
-
-		if (dst_reg->type == SCALAR_VALUE &&
-		    src_reg->type == SCALAR_VALUE &&
-		    !is_jmp32 && (opcode == BPF_JEQ || opcode == BPF_JNE)) {
-			/* Comparing for equality, we can combine knowledge */
-			reg_combine_min_max(&other_branch_regs[insn->src_reg],
-					    &other_branch_regs[insn->dst_reg],
-					    src_reg, dst_reg, opcode);
-		}
 	} else /* BPF_SRC(insn->code) == BPF_K */ {
 		reg_set_min_max(&other_branch_regs[insn->dst_reg],
 				src_reg /* fake one */,
 				dst_reg, src_reg /* same fake one */,
 				opcode, is_jmp32);
 	}
-
 	if (BPF_SRC(insn->code) == BPF_X &&
 	    src_reg->type == SCALAR_VALUE && src_reg->id &&
 	    !WARN_ON_ONCE(src_reg->id != other_branch_regs[insn->src_reg].id)) {