diff mbox series

[v5,bpf-next,19/23] bpf: generalize is_scalar_branch_taken() logic

Message ID 20231027181346.4019398-20-andrii@kernel.org (mailing list archive)
State Superseded
Delegated to: BPF
Headers show
Series BPF register bounds logic and testing improvements | expand

Checks

Context Check Description
bpf/vmtest-bpf-next-VM_Test-30 success Logs for x86_64-llvm-16 / test (test_progs_no_alu32_parallel, true, 30) / test_progs_no_alu32_parallel on x86_64 with llvm-16
bpf/vmtest-bpf-next-VM_Test-31 success Logs for x86_64-llvm-16 / test (test_progs_parallel, true, 30) / test_progs_parallel on x86_64 with llvm-16
bpf/vmtest-bpf-next-VM_Test-32 success Logs for x86_64-llvm-16 / test (test_verifier, false, 360) / test_verifier on x86_64 with llvm-16
bpf/vmtest-bpf-next-VM_Test-33 success Logs for x86_64-llvm-16 / veristat
netdev/series_format fail Series longer than 15 patches (and no cover letter)
netdev/tree_selection success Clearly marked for bpf-next, async
netdev/fixes_present success Fixes tag not required for -next series
netdev/header_inline success No static functions without inline keyword in header files
netdev/build_32bit success Errors and warnings before: 1374 this patch: 1374
netdev/cc_maintainers warning 8 maintainers not CCed: john.fastabend@gmail.com kpsingh@kernel.org song@kernel.org sdf@google.com jolsa@kernel.org martin.lau@linux.dev yonghong.song@linux.dev haoluo@google.com
netdev/build_clang fail Errors and warnings before: 15 this patch: 15
netdev/verify_signedoff success Signed-off-by tag matches author and committer
netdev/deprecated_api success None detected
netdev/check_selftest success No net selftest shell script
netdev/verify_fixes success No Fixes tag
netdev/build_allmodconfig_warn success Errors and warnings before: 1399 this patch: 1399
netdev/checkpatch warning WARNING: line length of 81 exceeds 80 columns
netdev/build_clang_rust success No Rust files in patch. Skipping build
netdev/kdoc success Errors and warnings before: 0 this patch: 0
netdev/source_inline success Was 0 now: 0
bpf/vmtest-bpf-next-VM_Test-1 success Logs for ShellCheck
bpf/vmtest-bpf-next-VM_Test-0 success Logs for Lint
bpf/vmtest-bpf-next-VM_Test-2 success Logs for Validate matrix.py
bpf/vmtest-bpf-next-VM_Test-3 success Logs for aarch64-gcc / build / build for aarch64 with gcc
bpf/vmtest-bpf-next-VM_Test-8 success Logs for aarch64-gcc / veristat
bpf/vmtest-bpf-next-VM_Test-4 success Logs for aarch64-gcc / test (test_maps, false, 360) / test_maps on aarch64 with gcc
bpf/vmtest-bpf-next-VM_Test-5 success Logs for aarch64-gcc / test (test_progs, false, 360) / test_progs on aarch64 with gcc
bpf/vmtest-bpf-next-VM_Test-6 success Logs for aarch64-gcc / test (test_progs_no_alu32, false, 360) / test_progs_no_alu32 on aarch64 with gcc
bpf/vmtest-bpf-next-VM_Test-7 success Logs for aarch64-gcc / test (test_verifier, false, 360) / test_verifier on aarch64 with gcc
bpf/vmtest-bpf-next-VM_Test-9 success Logs for s390x-gcc / build / build for s390x with gcc
bpf/vmtest-bpf-next-VM_Test-14 success Logs for s390x-gcc / veristat
bpf/vmtest-bpf-next-VM_Test-15 success Logs for set-matrix
bpf/vmtest-bpf-next-VM_Test-16 success Logs for x86_64-gcc / build / build for x86_64 with gcc
bpf/vmtest-bpf-next-VM_Test-17 success Logs for x86_64-gcc / test (test_maps, false, 360) / test_maps on x86_64 with gcc
bpf/vmtest-bpf-next-VM_Test-18 success Logs for x86_64-gcc / test (test_progs, false, 360) / test_progs on x86_64 with gcc
bpf/vmtest-bpf-next-VM_Test-19 success Logs for x86_64-gcc / test (test_progs_no_alu32, false, 360) / test_progs_no_alu32 on x86_64 with gcc
bpf/vmtest-bpf-next-VM_Test-20 success Logs for x86_64-gcc / test (test_progs_no_alu32_parallel, true, 30) / test_progs_no_alu32_parallel on x86_64 with gcc
bpf/vmtest-bpf-next-VM_Test-21 success Logs for x86_64-gcc / test (test_progs_parallel, true, 30) / test_progs_parallel on x86_64 with gcc
bpf/vmtest-bpf-next-VM_Test-22 success Logs for x86_64-gcc / test (test_verifier, false, 360) / test_verifier on x86_64 with gcc
bpf/vmtest-bpf-next-VM_Test-23 success Logs for x86_64-gcc / veristat / veristat on x86_64 with gcc
bpf/vmtest-bpf-next-VM_Test-24 success Logs for x86_64-llvm-16 / build / build for x86_64 with llvm-16
bpf/vmtest-bpf-next-VM_Test-25 success Logs for x86_64-llvm-16 / test (test_maps, false, 360) / test_maps on x86_64 with llvm-16
bpf/vmtest-bpf-next-VM_Test-26 success Logs for x86_64-llvm-16 / test (test_progs, false, 360) / test_progs on x86_64 with llvm-16
bpf/vmtest-bpf-next-VM_Test-27 success Logs for x86_64-llvm-16 / test (test_progs_no_alu32, false, 360) / test_progs_no_alu32 on x86_64 with llvm-16
bpf/vmtest-bpf-next-VM_Test-28 success Logs for x86_64-llvm-16 / test (test_verifier, false, 360) / test_verifier on x86_64 with llvm-16
bpf/vmtest-bpf-next-VM_Test-29 success Logs for x86_64-llvm-16 / veristat
bpf/vmtest-bpf-next-VM_Test-13 success Logs for s390x-gcc / test (test_verifier, false, 360) / test_verifier on s390x with gcc
bpf/vmtest-bpf-next-VM_Test-12 success Logs for s390x-gcc / test (test_progs_no_alu32, false, 360) / test_progs_no_alu32 on s390x with gcc
bpf/vmtest-bpf-next-PR success PR summary
bpf/vmtest-bpf-next-VM_Test-10 success Logs for s390x-gcc / test (test_maps, false, 360) / test_maps on s390x with gcc
bpf/vmtest-bpf-next-VM_Test-11 success Logs for s390x-gcc / test (test_progs, false, 360) / test_progs on s390x with gcc

Commit Message

Andrii Nakryiko Oct. 27, 2023, 6:13 p.m. UTC
Generalize is_branch_taken logic for SCALAR_VALUE register to handle
cases when both registers are not constants. Previously supported
<range> vs <scalar> cases are a natural subset of more generic <range>
vs <range> set of cases.

Generalized logic relies on straightforward segment intersection checks.

Signed-off-by: Andrii Nakryiko <andrii@kernel.org>
---
 kernel/bpf/verifier.c | 104 ++++++++++++++++++++++++++----------------
 1 file changed, 64 insertions(+), 40 deletions(-)

Comments

Alexei Starovoitov Oct. 31, 2023, 2:12 a.m. UTC | #1
On Fri, Oct 27, 2023 at 11:13:42AM -0700, Andrii Nakryiko wrote:
> Generalize is_branch_taken logic for SCALAR_VALUE register to handle
> cases when both registers are not constants. Previously supported
> <range> vs <scalar> cases are a natural subset of more generic <range>
> vs <range> set of cases.
> 
> Generalized logic relies on straightforward segment intersection checks.
> 
> Signed-off-by: Andrii Nakryiko <andrii@kernel.org>
> ---
>  kernel/bpf/verifier.c | 104 ++++++++++++++++++++++++++----------------
>  1 file changed, 64 insertions(+), 40 deletions(-)
> 
> diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c
> index 4c974296127b..f18a8247e5e2 100644
> --- a/kernel/bpf/verifier.c
> +++ b/kernel/bpf/verifier.c
> @@ -14189,82 +14189,105 @@ static int is_scalar_branch_taken(struct bpf_reg_state *reg1, struct bpf_reg_sta
>  				  u8 opcode, bool is_jmp32)
>  {
>  	struct tnum t1 = is_jmp32 ? tnum_subreg(reg1->var_off) : reg1->var_off;
> +	struct tnum t2 = is_jmp32 ? tnum_subreg(reg2->var_off) : reg2->var_off;
>  	u64 umin1 = is_jmp32 ? (u64)reg1->u32_min_value : reg1->umin_value;
>  	u64 umax1 = is_jmp32 ? (u64)reg1->u32_max_value : reg1->umax_value;
>  	s64 smin1 = is_jmp32 ? (s64)reg1->s32_min_value : reg1->smin_value;
>  	s64 smax1 = is_jmp32 ? (s64)reg1->s32_max_value : reg1->smax_value;
> -	u64 val = is_jmp32 ? (u32)tnum_subreg(reg2->var_off).value : reg2->var_off.value;
> -	s64 sval = is_jmp32 ? (s32)val : (s64)val;
> +	u64 umin2 = is_jmp32 ? (u64)reg2->u32_min_value : reg2->umin_value;
> +	u64 umax2 = is_jmp32 ? (u64)reg2->u32_max_value : reg2->umax_value;
> +	s64 smin2 = is_jmp32 ? (s64)reg2->s32_min_value : reg2->smin_value;
> +	s64 smax2 = is_jmp32 ? (s64)reg2->s32_max_value : reg2->smax_value;
>  
>  	switch (opcode) {
>  	case BPF_JEQ:
> -		if (tnum_is_const(t1))
> -			return !!tnum_equals_const(t1, val);
> -		else if (val < umin1 || val > umax1)
> +		/* const tnums */
> +		if (tnum_is_const(t1) && tnum_is_const(t2))
> +			return t1.value == t2.value;
> +		/* const ranges */
> +		if (umin1 == umax1 && umin2 == umax2)
> +			return umin1 == umin2;

I don't follow this logic.
umin1 == umax1 means that it's a single constant and
it should have been handled by earlier tnum_is_const check.

> +		if (smin1 == smax1 && smin2 == smax2)
> +			return umin1 == umin2;

here it's even more confusing. smin == smax -> single const,
but then compare umin1 with umin2 ?!

> +		/* non-overlapping ranges */
> +		if (umin1 > umax2 || umax1 < umin2)
>  			return 0;
> -		else if (sval < smin1 || sval > smax1)
> +		if (smin1 > smax2 || smax1 < smin2)
>  			return 0;

this part makes sense.
Andrii Nakryiko Oct. 31, 2023, 6:12 a.m. UTC | #2
On Mon, Oct 30, 2023 at 7:12 PM Alexei Starovoitov
<alexei.starovoitov@gmail.com> wrote:
>
> On Fri, Oct 27, 2023 at 11:13:42AM -0700, Andrii Nakryiko wrote:
> > Generalize is_branch_taken logic for SCALAR_VALUE register to handle
> > cases when both registers are not constants. Previously supported
> > <range> vs <scalar> cases are a natural subset of more generic <range>
> > vs <range> set of cases.
> >
> > Generalized logic relies on straightforward segment intersection checks.
> >
> > Signed-off-by: Andrii Nakryiko <andrii@kernel.org>
> > ---
> >  kernel/bpf/verifier.c | 104 ++++++++++++++++++++++++++----------------
> >  1 file changed, 64 insertions(+), 40 deletions(-)
> >
> > diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c
> > index 4c974296127b..f18a8247e5e2 100644
> > --- a/kernel/bpf/verifier.c
> > +++ b/kernel/bpf/verifier.c
> > @@ -14189,82 +14189,105 @@ static int is_scalar_branch_taken(struct bpf_reg_state *reg1, struct bpf_reg_sta
> >                                 u8 opcode, bool is_jmp32)
> >  {
> >       struct tnum t1 = is_jmp32 ? tnum_subreg(reg1->var_off) : reg1->var_off;
> > +     struct tnum t2 = is_jmp32 ? tnum_subreg(reg2->var_off) : reg2->var_off;
> >       u64 umin1 = is_jmp32 ? (u64)reg1->u32_min_value : reg1->umin_value;
> >       u64 umax1 = is_jmp32 ? (u64)reg1->u32_max_value : reg1->umax_value;
> >       s64 smin1 = is_jmp32 ? (s64)reg1->s32_min_value : reg1->smin_value;
> >       s64 smax1 = is_jmp32 ? (s64)reg1->s32_max_value : reg1->smax_value;
> > -     u64 val = is_jmp32 ? (u32)tnum_subreg(reg2->var_off).value : reg2->var_off.value;
> > -     s64 sval = is_jmp32 ? (s32)val : (s64)val;
> > +     u64 umin2 = is_jmp32 ? (u64)reg2->u32_min_value : reg2->umin_value;
> > +     u64 umax2 = is_jmp32 ? (u64)reg2->u32_max_value : reg2->umax_value;
> > +     s64 smin2 = is_jmp32 ? (s64)reg2->s32_min_value : reg2->smin_value;
> > +     s64 smax2 = is_jmp32 ? (s64)reg2->s32_max_value : reg2->smax_value;
> >
> >       switch (opcode) {
> >       case BPF_JEQ:
> > -             if (tnum_is_const(t1))
> > -                     return !!tnum_equals_const(t1, val);
> > -             else if (val < umin1 || val > umax1)
> > +             /* const tnums */
> > +             if (tnum_is_const(t1) && tnum_is_const(t2))
> > +                     return t1.value == t2.value;
> > +             /* const ranges */
> > +             if (umin1 == umax1 && umin2 == umax2)
> > +                     return umin1 == umin2;
>
> I don't follow this logic.
> umin1 == umax1 means that it's a single constant and
> it should have been handled by earlier tnum_is_const check.

I think you follow the logic, you just think it's redundant. Yes, it's
basically the same as

          if (tnum_is_const(t1) && tnum_is_const(t2))
                return t1.value == t2.value;

but based on ranges. I didn't feel comfortable to assume that if umin1
== umax1 then tnum_is_const(t1) will always be true. At worst we'll
perform one redundant check.

In short, I don't trust tnum to be as precise as umin/umax and other ranges.

>
> > +             if (smin1 == smax1 && smin2 == smax2)
> > +                     return umin1 == umin2;
>
> here it's even more confusing. smin == smax -> single const,
> but then compare umin1 with umin2 ?!

Eagle eyes! Typo, sorry :( it should be `smin1 == smin2`, of course.

What saves us is reg_bounds_sync(), and if we have umin1 == umax1 then
we'll have also smin1 == smax1 == umin1 == umax1 (and corresponding
relation for second register). But I fixed these typos in both BPF_JEQ
and BPF_JNE branches.


>
> > +             /* non-overlapping ranges */
> > +             if (umin1 > umax2 || umax1 < umin2)
> >                       return 0;
> > -             else if (sval < smin1 || sval > smax1)
> > +             if (smin1 > smax2 || smax1 < smin2)
> >                       return 0;
>
> this part makes sense.
Alexei Starovoitov Oct. 31, 2023, 4:34 p.m. UTC | #3
On Mon, Oct 30, 2023 at 11:12 PM Andrii Nakryiko
<andrii.nakryiko@gmail.com> wrote:
>
> On Mon, Oct 30, 2023 at 7:12 PM Alexei Starovoitov
> <alexei.starovoitov@gmail.com> wrote:
> >
> > On Fri, Oct 27, 2023 at 11:13:42AM -0700, Andrii Nakryiko wrote:
> > > Generalize is_branch_taken logic for SCALAR_VALUE register to handle
> > > cases when both registers are not constants. Previously supported
> > > <range> vs <scalar> cases are a natural subset of more generic <range>
> > > vs <range> set of cases.
> > >
> > > Generalized logic relies on straightforward segment intersection checks.
> > >
> > > Signed-off-by: Andrii Nakryiko <andrii@kernel.org>
> > > ---
> > >  kernel/bpf/verifier.c | 104 ++++++++++++++++++++++++++----------------
> > >  1 file changed, 64 insertions(+), 40 deletions(-)
> > >
> > > diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c
> > > index 4c974296127b..f18a8247e5e2 100644
> > > --- a/kernel/bpf/verifier.c
> > > +++ b/kernel/bpf/verifier.c
> > > @@ -14189,82 +14189,105 @@ static int is_scalar_branch_taken(struct bpf_reg_state *reg1, struct bpf_reg_sta
> > >                                 u8 opcode, bool is_jmp32)
> > >  {
> > >       struct tnum t1 = is_jmp32 ? tnum_subreg(reg1->var_off) : reg1->var_off;
> > > +     struct tnum t2 = is_jmp32 ? tnum_subreg(reg2->var_off) : reg2->var_off;
> > >       u64 umin1 = is_jmp32 ? (u64)reg1->u32_min_value : reg1->umin_value;
> > >       u64 umax1 = is_jmp32 ? (u64)reg1->u32_max_value : reg1->umax_value;
> > >       s64 smin1 = is_jmp32 ? (s64)reg1->s32_min_value : reg1->smin_value;
> > >       s64 smax1 = is_jmp32 ? (s64)reg1->s32_max_value : reg1->smax_value;
> > > -     u64 val = is_jmp32 ? (u32)tnum_subreg(reg2->var_off).value : reg2->var_off.value;
> > > -     s64 sval = is_jmp32 ? (s32)val : (s64)val;
> > > +     u64 umin2 = is_jmp32 ? (u64)reg2->u32_min_value : reg2->umin_value;
> > > +     u64 umax2 = is_jmp32 ? (u64)reg2->u32_max_value : reg2->umax_value;
> > > +     s64 smin2 = is_jmp32 ? (s64)reg2->s32_min_value : reg2->smin_value;
> > > +     s64 smax2 = is_jmp32 ? (s64)reg2->s32_max_value : reg2->smax_value;
> > >
> > >       switch (opcode) {
> > >       case BPF_JEQ:
> > > -             if (tnum_is_const(t1))
> > > -                     return !!tnum_equals_const(t1, val);
> > > -             else if (val < umin1 || val > umax1)
> > > +             /* const tnums */
> > > +             if (tnum_is_const(t1) && tnum_is_const(t2))
> > > +                     return t1.value == t2.value;
> > > +             /* const ranges */
> > > +             if (umin1 == umax1 && umin2 == umax2)
> > > +                     return umin1 == umin2;
> >
> > I don't follow this logic.
> > umin1 == umax1 means that it's a single constant and
> > it should have been handled by earlier tnum_is_const check.
>
> I think you follow the logic, you just think it's redundant. Yes, it's
> basically the same as
>
>           if (tnum_is_const(t1) && tnum_is_const(t2))
>                 return t1.value == t2.value;
>
> but based on ranges. I didn't feel comfortable to assume that if umin1
> == umax1 then tnum_is_const(t1) will always be true. At worst we'll
> perform one redundant check.
>
> In short, I don't trust tnum to be as precise as umin/umax and other ranges.
>
> >
> > > +             if (smin1 == smax1 && smin2 == smax2)
> > > +                     return umin1 == umin2;
> >
> > here it's even more confusing. smin == smax -> single const,
> > but then compare umin1 with umin2 ?!
>
> Eagle eyes! Typo, sorry :( it should be `smin1 == smin2`, of course.
>
> What saves us is reg_bounds_sync(), and if we have umin1 == umax1 then
> we'll have also smin1 == smax1 == umin1 == umax1 (and corresponding
> relation for second register). But I fixed these typos in both BPF_JEQ
> and BPF_JNE branches.

Not just 'saves us'. The tnum <-> bounds sync is mandatory.
I think we have a test where a function returns [-errno, 0]
and then we do if (ret < 0) check. At this point the reg has
to be tnum_is_const and zero.
So if smin1 == smax1 == umin1 == umax1 it should be tnum_is_const.
Otherwise it's a bug in sync logic.
I think instead of doing redundant and confusing check may be
add WARN either here or in sync logic to make sure it's all good ?
Andrii Nakryiko Oct. 31, 2023, 6:01 p.m. UTC | #4
On Tue, Oct 31, 2023 at 9:35 AM Alexei Starovoitov
<alexei.starovoitov@gmail.com> wrote:
>
> On Mon, Oct 30, 2023 at 11:12 PM Andrii Nakryiko
> <andrii.nakryiko@gmail.com> wrote:
> >
> > On Mon, Oct 30, 2023 at 7:12 PM Alexei Starovoitov
> > <alexei.starovoitov@gmail.com> wrote:
> > >
> > > On Fri, Oct 27, 2023 at 11:13:42AM -0700, Andrii Nakryiko wrote:
> > > > Generalize is_branch_taken logic for SCALAR_VALUE register to handle
> > > > cases when both registers are not constants. Previously supported
> > > > <range> vs <scalar> cases are a natural subset of more generic <range>
> > > > vs <range> set of cases.
> > > >
> > > > Generalized logic relies on straightforward segment intersection checks.
> > > >
> > > > Signed-off-by: Andrii Nakryiko <andrii@kernel.org>
> > > > ---
> > > >  kernel/bpf/verifier.c | 104 ++++++++++++++++++++++++++----------------
> > > >  1 file changed, 64 insertions(+), 40 deletions(-)
> > > >
> > > > diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c
> > > > index 4c974296127b..f18a8247e5e2 100644
> > > > --- a/kernel/bpf/verifier.c
> > > > +++ b/kernel/bpf/verifier.c
> > > > @@ -14189,82 +14189,105 @@ static int is_scalar_branch_taken(struct bpf_reg_state *reg1, struct bpf_reg_sta
> > > >                                 u8 opcode, bool is_jmp32)
> > > >  {
> > > >       struct tnum t1 = is_jmp32 ? tnum_subreg(reg1->var_off) : reg1->var_off;
> > > > +     struct tnum t2 = is_jmp32 ? tnum_subreg(reg2->var_off) : reg2->var_off;
> > > >       u64 umin1 = is_jmp32 ? (u64)reg1->u32_min_value : reg1->umin_value;
> > > >       u64 umax1 = is_jmp32 ? (u64)reg1->u32_max_value : reg1->umax_value;
> > > >       s64 smin1 = is_jmp32 ? (s64)reg1->s32_min_value : reg1->smin_value;
> > > >       s64 smax1 = is_jmp32 ? (s64)reg1->s32_max_value : reg1->smax_value;
> > > > -     u64 val = is_jmp32 ? (u32)tnum_subreg(reg2->var_off).value : reg2->var_off.value;
> > > > -     s64 sval = is_jmp32 ? (s32)val : (s64)val;
> > > > +     u64 umin2 = is_jmp32 ? (u64)reg2->u32_min_value : reg2->umin_value;
> > > > +     u64 umax2 = is_jmp32 ? (u64)reg2->u32_max_value : reg2->umax_value;
> > > > +     s64 smin2 = is_jmp32 ? (s64)reg2->s32_min_value : reg2->smin_value;
> > > > +     s64 smax2 = is_jmp32 ? (s64)reg2->s32_max_value : reg2->smax_value;
> > > >
> > > >       switch (opcode) {
> > > >       case BPF_JEQ:
> > > > -             if (tnum_is_const(t1))
> > > > -                     return !!tnum_equals_const(t1, val);
> > > > -             else if (val < umin1 || val > umax1)
> > > > +             /* const tnums */
> > > > +             if (tnum_is_const(t1) && tnum_is_const(t2))
> > > > +                     return t1.value == t2.value;
> > > > +             /* const ranges */
> > > > +             if (umin1 == umax1 && umin2 == umax2)
> > > > +                     return umin1 == umin2;
> > >
> > > I don't follow this logic.
> > > umin1 == umax1 means that it's a single constant and
> > > it should have been handled by earlier tnum_is_const check.
> >
> > I think you follow the logic, you just think it's redundant. Yes, it's
> > basically the same as
> >
> >           if (tnum_is_const(t1) && tnum_is_const(t2))
> >                 return t1.value == t2.value;
> >
> > but based on ranges. I didn't feel comfortable to assume that if umin1
> > == umax1 then tnum_is_const(t1) will always be true. At worst we'll
> > perform one redundant check.
> >
> > In short, I don't trust tnum to be as precise as umin/umax and other ranges.
> >
> > >
> > > > +             if (smin1 == smax1 && smin2 == smax2)
> > > > +                     return umin1 == umin2;
> > >
> > > here it's even more confusing. smin == smax -> single const,
> > > but then compare umin1 with umin2 ?!
> >
> > Eagle eyes! Typo, sorry :( it should be `smin1 == smin2`, of course.
> >
> > What saves us is reg_bounds_sync(), and if we have umin1 == umax1 then
> > we'll have also smin1 == smax1 == umin1 == umax1 (and corresponding
> > relation for second register). But I fixed these typos in both BPF_JEQ
> > and BPF_JNE branches.
>
> Not just 'saves us'. The tnum <-> bounds sync is mandatory.
> I think we have a test where a function returns [-errno, 0]
> and then we do if (ret < 0) check. At this point the reg has
> to be tnum_is_const and zero.
> So if smin1 == smax1 == umin1 == umax1 it should be tnum_is_const.
> Otherwise it's a bug in sync logic.
> I think instead of doing redundant and confusing check may be
> add WARN either here or in sync logic to make sure it's all good ?

Ok, let's add it as part of register state sanity checks we discussed
on another patch. I'll drop the checks and will re-run all the tests to
make sure we are not missing anything.
Andrii Nakryiko Oct. 31, 2023, 8:53 p.m. UTC | #5
On Tue, Oct 31, 2023 at 11:01 AM Andrii Nakryiko
<andrii.nakryiko@gmail.com> wrote:
>
> On Tue, Oct 31, 2023 at 9:35 AM Alexei Starovoitov
> <alexei.starovoitov@gmail.com> wrote:
> >
> > On Mon, Oct 30, 2023 at 11:12 PM Andrii Nakryiko
> > <andrii.nakryiko@gmail.com> wrote:
> > >
> > > On Mon, Oct 30, 2023 at 7:12 PM Alexei Starovoitov
> > > <alexei.starovoitov@gmail.com> wrote:
> > > >
> > > > On Fri, Oct 27, 2023 at 11:13:42AM -0700, Andrii Nakryiko wrote:
> > > > > Generalize is_branch_taken logic for SCALAR_VALUE register to handle
> > > > > cases when both registers are not constants. Previously supported
> > > > > <range> vs <scalar> cases are a natural subset of more generic <range>
> > > > > vs <range> set of cases.
> > > > >
> > > > > Generalized logic relies on straightforward segment intersection checks.
> > > > >
> > > > > Signed-off-by: Andrii Nakryiko <andrii@kernel.org>
> > > > > ---
> > > > >  kernel/bpf/verifier.c | 104 ++++++++++++++++++++++++++----------------
> > > > >  1 file changed, 64 insertions(+), 40 deletions(-)
> > > > >
> > > > > diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c
> > > > > index 4c974296127b..f18a8247e5e2 100644
> > > > > --- a/kernel/bpf/verifier.c
> > > > > +++ b/kernel/bpf/verifier.c
> > > > > @@ -14189,82 +14189,105 @@ static int is_scalar_branch_taken(struct bpf_reg_state *reg1, struct bpf_reg_sta
> > > > >                                 u8 opcode, bool is_jmp32)
> > > > >  {
> > > > >       struct tnum t1 = is_jmp32 ? tnum_subreg(reg1->var_off) : reg1->var_off;
> > > > > +     struct tnum t2 = is_jmp32 ? tnum_subreg(reg2->var_off) : reg2->var_off;
> > > > >       u64 umin1 = is_jmp32 ? (u64)reg1->u32_min_value : reg1->umin_value;
> > > > >       u64 umax1 = is_jmp32 ? (u64)reg1->u32_max_value : reg1->umax_value;
> > > > >       s64 smin1 = is_jmp32 ? (s64)reg1->s32_min_value : reg1->smin_value;
> > > > >       s64 smax1 = is_jmp32 ? (s64)reg1->s32_max_value : reg1->smax_value;
> > > > > -     u64 val = is_jmp32 ? (u32)tnum_subreg(reg2->var_off).value : reg2->var_off.value;
> > > > > -     s64 sval = is_jmp32 ? (s32)val : (s64)val;
> > > > > +     u64 umin2 = is_jmp32 ? (u64)reg2->u32_min_value : reg2->umin_value;
> > > > > +     u64 umax2 = is_jmp32 ? (u64)reg2->u32_max_value : reg2->umax_value;
> > > > > +     s64 smin2 = is_jmp32 ? (s64)reg2->s32_min_value : reg2->smin_value;
> > > > > +     s64 smax2 = is_jmp32 ? (s64)reg2->s32_max_value : reg2->smax_value;
> > > > >
> > > > >       switch (opcode) {
> > > > >       case BPF_JEQ:
> > > > > -             if (tnum_is_const(t1))
> > > > > -                     return !!tnum_equals_const(t1, val);
> > > > > -             else if (val < umin1 || val > umax1)
> > > > > +             /* const tnums */
> > > > > +             if (tnum_is_const(t1) && tnum_is_const(t2))
> > > > > +                     return t1.value == t2.value;
> > > > > +             /* const ranges */
> > > > > +             if (umin1 == umax1 && umin2 == umax2)
> > > > > +                     return umin1 == umin2;
> > > >
> > > > I don't follow this logic.
> > > > umin1 == umax1 means that it's a single constant and
> > > > it should have been handled by earlier tnum_is_const check.
> > >
> > > I think you follow the logic, you just think it's redundant. Yes, it's
> > > basically the same as
> > >
> > >           if (tnum_is_const(t1) && tnum_is_const(t2))
> > >                 return t1.value == t2.value;
> > >
> > > but based on ranges. I didn't feel comfortable to assume that if umin1
> > > == umax1 then tnum_is_const(t1) will always be true. At worst we'll
> > > perform one redundant check.
> > >
> > > In short, I don't trust tnum to be as precise as umin/umax and other ranges.
> > >
> > > >
> > > > > +             if (smin1 == smax1 && smin2 == smax2)
> > > > > +                     return umin1 == umin2;
> > > >
> > > > here it's even more confusing. smin == smax -> single const,
> > > > but then compare umin1 with umin2 ?!
> > >
> > > Eagle eyes! Typo, sorry :( it should be `smin1 == smin2`, of course.
> > >
> > > What saves us is reg_bounds_sync(), and if we have umin1 == umax1 then
> > > we'll have also smin1 == smax1 == umin1 == umax1 (and corresponding
> > > relation for second register). But I fixed these typos in both BPF_JEQ
> > > and BPF_JNE branches.
> >
> > Not just 'saves us'. The tnum <-> bounds sync is mandatory.
> > I think we have a test where a function returns [-errno, 0]
> > and then we do if (ret < 0) check. At this point the reg has
> > to be tnum_is_const and zero.
> > So if smin1 == smax1 == umin1 == umax1 it should be tnum_is_const.
> > Otherwise it's a bug in sync logic.
> > I think instead of doing redundant and confusing check may be
> > add WARN either here or in sync logic to make sure it's all good ?
>
> Ok, let's add it as part of register state sanity checks we discussed
> on another patch. I'll drop the checks and will re-run all the tests to
> make sure we are not missing anything.

So I have this as one more patch for the next revision (pending local
testing). If you hate any part of it, I'd appreciate early feedback :)
I'll wait for Eduard to finish going through the series (probably
tomorrow), and then will post the next version based on all the
feedback I got (and whatever might still come).

Note, in the below, I don't output the actual register state on
violation, which is unfortunate. But to make this happen I need to
refactor print_verifier_state() to allow me to print register state.
I've been wanting to move print_verifier_state() into kernel/bpf/log.c
for a while now, and fix how we print the state of spilled registers
(and maybe few more small things), so I'll do that separately, and
then add register state printing to sanity check error.


Author: Andrii Nakryiko <andrii@kernel.org>
Date:   Tue Oct 31 13:34:33 2023 -0700

    bpf: add register bounds sanity checks

    Add simple sanity checks that validate well-formed ranges (min <= max)
    across u64, s64, u32, and s32 ranges. Also for cases when the value is
    constant (either 64-bit or 32-bit), we validate that ranges and tnums
    are in agreement.

    These bounds checks are performed at the end of BPF_ALU/BPF_ALU64
    operations, on conditional jumps, and for LDX instructions (where subreg
    zero/sign extension is probably the most important to check). This
    covers most of the interesting cases.

    Also, we validate the sanity of the return register when manually
    adjusting it for some special helpers.

    Signed-off-by: Andrii Nakryiko <andrii@kernel.org>

diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c
index c85d974ba21f..b29c85089bc9 100644
--- a/kernel/bpf/verifier.c
+++ b/kernel/bpf/verifier.c
@@ -2615,6 +2615,46 @@ static void reg_bounds_sync(struct bpf_reg_state *reg)
        __update_reg_bounds(reg);
 }

+static int reg_bounds_sanity_check(struct bpf_verifier_env *env,
struct bpf_reg_state *reg)
+{
+       const char *msg;
+
+       if (reg->umin_value > reg->umax_value ||
+           reg->smin_value > reg->smax_value ||
+           reg->u32_min_value > reg->u32_max_value ||
+           reg->s32_min_value > reg->s32_max_value) {
+                   msg = "range bounds violation";
+                   goto out;
+       }
+
+       if (tnum_is_const(reg->var_off)) {
+               u64 uval = reg->var_off.value;
+               s64 sval = (s64)uval;
+
+               if (reg->umin_value != uval || reg->umax_value != uval ||
+                   reg->smin_value != sval || reg->smax_value != sval) {
+                       msg = "const tnum out of sync with range bounds";
+                       goto out;
+               }
+       }
+
+       if (tnum_subreg_is_const(reg->var_off)) {
+               u32 uval32 = tnum_subreg(reg->var_off).value;
+               s32 sval32 = (s32)uval32;
+
+               if (reg->u32_min_value != uval32 || reg->u32_max_value
!= uval32 ||
+                   reg->s32_min_value != sval32 || reg->s32_max_value
!= sval32) {
+                       msg = "const tnum (subreg) out of sync with
range bounds";
+                       goto out;
+               }
+       }
+
+       return 0;
+out:
+       verbose(env, "%s\n", msg);
+       return -EFAULT;
+}
+
 static bool __reg32_bound_s64(s32 a)
 {
        return a >= 0 && a <= S32_MAX;
@@ -9928,14 +9968,15 @@ static int prepare_func_exit(struct
bpf_verifier_env *env, int *insn_idx)
        return 0;
 }

-static void do_refine_retval_range(struct bpf_reg_state *regs, int ret_type,
-                                  int func_id,
-                                  struct bpf_call_arg_meta *meta)
+static int do_refine_retval_range(struct bpf_verifier_env *env,
+                                 struct bpf_reg_state *regs, int ret_type,
+                                 int func_id,
+                                 struct bpf_call_arg_meta *meta)
 {
        struct bpf_reg_state *ret_reg = &regs[BPF_REG_0];

        if (ret_type != RET_INTEGER)
-               return;
+               return 0;

        switch (func_id) {
        case BPF_FUNC_get_stack:
@@ -9961,6 +10002,8 @@ static void do_refine_retval_range(struct
bpf_reg_state *regs, int ret_type,
                reg_bounds_sync(ret_reg);
                break;
        }
+
+       return reg_bounds_sanity_check(env, ret_reg);
 }

 static int
@@ -10612,7 +10655,9 @@ static int check_helper_call(struct
bpf_verifier_env *env, struct bpf_insn *insn
                regs[BPF_REG_0].ref_obj_id = id;
        }

-       do_refine_retval_range(regs, fn->ret_type, func_id, &meta);
+       err = do_refine_retval_range(env, regs, fn->ret_type, func_id, &meta);
+       if (err)
+               return err;

        err = check_map_func_compatibility(env, meta.map_ptr, func_id);
        if (err)
@@ -14079,13 +14124,12 @@ static int check_alu_op(struct
bpf_verifier_env *env, struct bpf_insn *insn)

                /* check dest operand */
                err = check_reg_arg(env, insn->dst_reg, DST_OP_NO_MARK);
+               err = err ?: adjust_reg_min_max_vals(env, insn);
                if (err)
                        return err;
-
-               return adjust_reg_min_max_vals(env, insn);
        }

-       return 0;
+       return reg_bounds_sanity_check(env, &regs[insn->dst_reg]);
 }

 static void find_good_pkt_pointers(struct bpf_verifier_state *vstate,
@@ -14600,18 +14644,21 @@ static void regs_refine_cond_op(struct
bpf_reg_state *reg1, struct bpf_reg_state
  * Technically we can do similar adjustments for pointers to the same object,
  * but we don't support that right now.
  */
-static void reg_set_min_max(struct bpf_reg_state *true_reg1,
-                           struct bpf_reg_state *true_reg2,
-                           struct bpf_reg_state *false_reg1,
-                           struct bpf_reg_state *false_reg2,
-                           u8 opcode, bool is_jmp32)
+static int reg_set_min_max(struct bpf_verifier_env *env,
+                          struct bpf_reg_state *true_reg1,
+                          struct bpf_reg_state *true_reg2,
+                          struct bpf_reg_state *false_reg1,
+                          struct bpf_reg_state *false_reg2,
+                          u8 opcode, bool is_jmp32)
 {
+       int err;
+
        /* If either register is a pointer, we can't learn anything about its
         * variable offset from the compare (unless they were a pointer into
         * the same object, but we don't bother with that).
         */
        if (false_reg1->type != SCALAR_VALUE || false_reg2->type !=
SCALAR_VALUE)
-               return;
+               return 0;

        /* fallthrough (FALSE) branch */
        regs_refine_cond_op(false_reg1, false_reg2,
rev_opcode(opcode), is_jmp32);
@@ -14622,6 +14669,12 @@ static void reg_set_min_max(struct
bpf_reg_state *true_reg1,
        regs_refine_cond_op(true_reg1, true_reg2, opcode, is_jmp32);
        reg_bounds_sync(true_reg1);
        reg_bounds_sync(true_reg2);
+
+       err = reg_bounds_sanity_check(env, true_reg1);
+       err = err ?: reg_bounds_sanity_check(env, true_reg2);
+       err = err ?: reg_bounds_sanity_check(env, false_reg1);
+       err = err ?: reg_bounds_sanity_check(env, false_reg2);
+       return err;
 }

 static void mark_ptr_or_null_reg(struct bpf_func_state *state,
@@ -14915,15 +14968,20 @@ static int check_cond_jmp_op(struct
bpf_verifier_env *env,
        other_branch_regs = other_branch->frame[other_branch->curframe]->regs;

        if (BPF_SRC(insn->code) == BPF_X) {
-               reg_set_min_max(&other_branch_regs[insn->dst_reg],
-                               &other_branch_regs[insn->src_reg],
-                               dst_reg, src_reg, opcode, is_jmp32);
+               err = reg_set_min_max(env,
+                                     &other_branch_regs[insn->dst_reg],
+                                     &other_branch_regs[insn->src_reg],
+                                     dst_reg, src_reg, opcode, is_jmp32);
        } else /* BPF_SRC(insn->code) == BPF_K */ {
-               reg_set_min_max(&other_branch_regs[insn->dst_reg],
-                               src_reg /* fake one */,
-                               dst_reg, src_reg /* same fake one */,
-                               opcode, is_jmp32);
+               err = reg_set_min_max(env,
+                                     &other_branch_regs[insn->dst_reg],
+                                     src_reg /* fake one */,
+                                     dst_reg, src_reg /* same fake one */,
+                                     opcode, is_jmp32);
        }
+       if (err)
+               return err;
+
        if (BPF_SRC(insn->code) == BPF_X &&
            src_reg->type == SCALAR_VALUE && src_reg->id &&
            !WARN_ON_ONCE(src_reg->id != other_branch_regs[insn->src_reg].id)) {
@@ -17426,10 +17484,8 @@ static int do_check(struct bpf_verifier_env *env)
                                               insn->off, BPF_SIZE(insn->code),
                                               BPF_READ, insn->dst_reg, false,
                                               BPF_MODE(insn->code) ==
BPF_MEMSX);
-                       if (err)
-                               return err;
-
-                       err = save_aux_ptr_type(env, src_reg_type, true);
+                       err = err ?: save_aux_ptr_type(env, src_reg_type, true);
+                       err = err ?: reg_bounds_sanity_check(env,
&regs[insn->dst_reg]);
                        if (err)
                                return err;
                } else if (class == BPF_STX) {
Andrii Nakryiko Oct. 31, 2023, 8:55 p.m. UTC | #6
On Tue, Oct 31, 2023 at 1:53 PM Andrii Nakryiko
<andrii.nakryiko@gmail.com> wrote:
>
> On Tue, Oct 31, 2023 at 11:01 AM Andrii Nakryiko
> <andrii.nakryiko@gmail.com> wrote:
> >
> > On Tue, Oct 31, 2023 at 9:35 AM Alexei Starovoitov
> > <alexei.starovoitov@gmail.com> wrote:
> > >
> > > On Mon, Oct 30, 2023 at 11:12 PM Andrii Nakryiko
> > > <andrii.nakryiko@gmail.com> wrote:
> > > >
> > > > On Mon, Oct 30, 2023 at 7:12 PM Alexei Starovoitov
> > > > <alexei.starovoitov@gmail.com> wrote:
> > > > >
> > > > > On Fri, Oct 27, 2023 at 11:13:42AM -0700, Andrii Nakryiko wrote:
> > > > > > Generalize is_branch_taken logic for SCALAR_VALUE register to handle
> > > > > > cases when both registers are not constants. Previously supported
> > > > > > <range> vs <scalar> cases are a natural subset of more generic <range>
> > > > > > vs <range> set of cases.
> > > > > >
> > > > > > Generalized logic relies on straightforward segment intersection checks.
> > > > > >
> > > > > > Signed-off-by: Andrii Nakryiko <andrii@kernel.org>
> > > > > > ---
> > > > > >  kernel/bpf/verifier.c | 104 ++++++++++++++++++++++++++----------------
> > > > > >  1 file changed, 64 insertions(+), 40 deletions(-)
> > > > > >
> > > > > > diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c
> > > > > > index 4c974296127b..f18a8247e5e2 100644
> > > > > > --- a/kernel/bpf/verifier.c
> > > > > > +++ b/kernel/bpf/verifier.c
> > > > > > @@ -14189,82 +14189,105 @@ static int is_scalar_branch_taken(struct bpf_reg_state *reg1, struct bpf_reg_sta
> > > > > >                                 u8 opcode, bool is_jmp32)
> > > > > >  {
> > > > > >       struct tnum t1 = is_jmp32 ? tnum_subreg(reg1->var_off) : reg1->var_off;
> > > > > > +     struct tnum t2 = is_jmp32 ? tnum_subreg(reg2->var_off) : reg2->var_off;
> > > > > >       u64 umin1 = is_jmp32 ? (u64)reg1->u32_min_value : reg1->umin_value;
> > > > > >       u64 umax1 = is_jmp32 ? (u64)reg1->u32_max_value : reg1->umax_value;
> > > > > >       s64 smin1 = is_jmp32 ? (s64)reg1->s32_min_value : reg1->smin_value;
> > > > > >       s64 smax1 = is_jmp32 ? (s64)reg1->s32_max_value : reg1->smax_value;
> > > > > > -     u64 val = is_jmp32 ? (u32)tnum_subreg(reg2->var_off).value : reg2->var_off.value;
> > > > > > -     s64 sval = is_jmp32 ? (s32)val : (s64)val;
> > > > > > +     u64 umin2 = is_jmp32 ? (u64)reg2->u32_min_value : reg2->umin_value;
> > > > > > +     u64 umax2 = is_jmp32 ? (u64)reg2->u32_max_value : reg2->umax_value;
> > > > > > +     s64 smin2 = is_jmp32 ? (s64)reg2->s32_min_value : reg2->smin_value;
> > > > > > +     s64 smax2 = is_jmp32 ? (s64)reg2->s32_max_value : reg2->smax_value;
> > > > > >
> > > > > >       switch (opcode) {
> > > > > >       case BPF_JEQ:
> > > > > > -             if (tnum_is_const(t1))
> > > > > > -                     return !!tnum_equals_const(t1, val);
> > > > > > -             else if (val < umin1 || val > umax1)
> > > > > > +             /* const tnums */
> > > > > > +             if (tnum_is_const(t1) && tnum_is_const(t2))
> > > > > > +                     return t1.value == t2.value;
> > > > > > +             /* const ranges */
> > > > > > +             if (umin1 == umax1 && umin2 == umax2)
> > > > > > +                     return umin1 == umin2;
> > > > >
> > > > > I don't follow this logic.
> > > > > umin1 == umax1 means that it's a single constant and
> > > > > it should have been handled by earlier tnum_is_const check.
> > > >
> > > > I think you follow the logic, you just think it's redundant. Yes, it's
> > > > basically the same as
> > > >
> > > >           if (tnum_is_const(t1) && tnum_is_const(t2))
> > > >                 return t1.value == t2.value;
> > > >
> > > > but based on ranges. I didn't feel comfortable to assume that if umin1
> > > > == umax1 then tnum_is_const(t1) will always be true. At worst we'll
> > > > perform one redundant check.
> > > >
> > > > In short, I don't trust tnum to be as precise as umin/umax and other ranges.
> > > >
> > > > >
> > > > > > +             if (smin1 == smax1 && smin2 == smax2)
> > > > > > +                     return umin1 == umin2;
> > > > >
> > > > > here it's even more confusing. smin == smax -> single const,
> > > > > but then compare umin1 with umin2 ?!
> > > >
> > > > Eagle eyes! Typo, sorry :( it should be `smin1 == smin2`, of course.
> > > >
> > > > What saves us is reg_bounds_sync(), and if we have umin1 == umax1 then
> > > > we'll have also smin1 == smax1 == umin1 == umax1 (and corresponding
> > > > relation for second register). But I fixed these typos in both BPF_JEQ
> > > > and BPF_JNE branches.
> > >
> > > Not just 'saves us'. The tnum <-> bounds sync is mandatory.
> > > I think we have a test where a function returns [-errno, 0]
> > > and then we do if (ret < 0) check. At this point the reg has
> > > to be tnum_is_const and zero.
> > > So if smin1 == smax1 == umin1 == umax1 it should be tnum_is_const.
> > > Otherwise it's a bug in sync logic.
> > > I think instead of doing a redundant and confusing check, maybe
> > > add a WARN either here or in the sync logic to make sure it's all good?
> >
> > Ok, let's add it as part of register state sanity checks we discussed
> > on another patch. I'll drop the checks and will re-run all the tests to
> > make sure we are not missing anything.
>
> So I have this as one more patch for the next revision (pending local
> testing). If you hate any part of it, I'd appreciate early feedback :)
> I'll wait for Eduard to finish going through the series (probably
> tomorrow), and then will post the next version based on all the
> feedback I got (and whatever might still come).
>
> Note, in the below, I don't output the actual register state on
> violation, which is unfortunate. But to make this happen I need to
> refactor print_verifier_state() to allow me to print register state.
> I've been wanting to move print_verifier_state() into kernel/bpf/log.c
> for a while now, and fix how we print the state of spilled registers
> (and maybe few more small things), so I'll do that separately, and
> then add register state printing to sanity check error.
>
>
> Author: Andrii Nakryiko <andrii@kernel.org>
> Date:   Tue Oct 31 13:34:33 2023 -0700
>
>     bpf: add register bounds sanity checks
>
>     Add simple sanity checks that validate well-formed ranges (min <= max)
>     across u64, s64, u32, and s32 ranges. Also for cases when the value is
>     constant (either 64-bit or 32-bit), we validate that ranges and tnums
>     are in agreement.
>
>     These bounds checks are performed at the end of BPF_ALU/BPF_ALU64
>     operations, on conditional jumps, and for LDX instructions (where subreg
>     zero/sign extension is probably the most important to check). This
>     covers most of the interesting cases.
>
>     Also, we validate the sanity of the return register when manually
> adjusting it
>     for some special helpers.
>
>     Signed-off-by: Andrii Nakryiko <andrii@kernel.org>
>
> diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c
> index c85d974ba21f..b29c85089bc9 100644
> --- a/kernel/bpf/verifier.c
> +++ b/kernel/bpf/verifier.c
> @@ -2615,6 +2615,46 @@ static void reg_bounds_sync(struct bpf_reg_state *reg)
>         __update_reg_bounds(reg);
>  }
>
> +static int reg_bounds_sanity_check(struct bpf_verifier_env *env,
> struct bpf_reg_state *reg)
> +{
> +       const char *msg;
> +
> +       if (reg->umin_value > reg->umax_value ||
> +           reg->smin_value > reg->smax_value ||
> +           reg->u32_min_value > reg->u32_max_value ||
> +           reg->s32_min_value > reg->s32_max_value) {
> +                   msg = "range bounds violation";
> +                   goto out;
> +       }
> +
> +       if (tnum_is_const(reg->var_off)) {
> +               u64 uval = reg->var_off.value;
> +               s64 sval = (s64)uval;
> +
> +               if (reg->umin_value != uval || reg->umax_value != uval ||
> +                   reg->smin_value != sval || reg->smax_value != sval) {
> +                       msg = "const tnum out of sync with range bounds";
> +                       goto out;
> +               }
> +       }
> +
> +       if (tnum_subreg_is_const(reg->var_off)) {
> +               u32 uval32 = tnum_subreg(reg->var_off).value;
> +               s32 sval32 = (s32)uval32;
> +
> +               if (reg->u32_min_value != uval32 || reg->u32_max_value
> != uval32 ||
> +                   reg->s32_min_value != sval32 || reg->s32_max_value
> != sval32) {
> +                       msg = "const tnum (subreg) out of sync with
> range bounds";
> +                       goto out;
> +               }
> +       }
> +
> +       return 0;
> +out:
> +       verbose(env, "%s\n", msg);
> +       return -EFAULT;
> +}
> +
>  static bool __reg32_bound_s64(s32 a)
>  {
>         return a >= 0 && a <= S32_MAX;
> @@ -9928,14 +9968,15 @@ static int prepare_func_exit(struct
> bpf_verifier_env *env, int *insn_idx)
>         return 0;
>  }
>
> -static void do_refine_retval_range(struct bpf_reg_state *regs, int ret_type,
> -                                  int func_id,
> -                                  struct bpf_call_arg_meta *meta)
> +static int do_refine_retval_range(struct bpf_verifier_env *env,
> +                                 struct bpf_reg_state *regs, int ret_type,
> +                                 int func_id,
> +                                 struct bpf_call_arg_meta *meta)
>  {
>         struct bpf_reg_state *ret_reg = &regs[BPF_REG_0];
>
>         if (ret_type != RET_INTEGER)
> -               return;
> +               return 0;
>
>         switch (func_id) {
>         case BPF_FUNC_get_stack:
> @@ -9961,6 +10002,8 @@ static void do_refine_retval_range(struct
> bpf_reg_state *regs, int ret_type,
>                 reg_bounds_sync(ret_reg);
>                 break;
>         }
> +
> +       return reg_bounds_sanity_check(env, ret_reg);
>  }
>
>  static int
> @@ -10612,7 +10655,9 @@ static int check_helper_call(struct
> bpf_verifier_env *env, struct bpf_insn *insn
>                 regs[BPF_REG_0].ref_obj_id = id;
>         }
>
> -       do_refine_retval_range(regs, fn->ret_type, func_id, &meta);
> +       err = do_refine_retval_range(env, regs, fn->ret_type, func_id, &meta);
> +       if (err)
> +               return err;
>
>         err = check_map_func_compatibility(env, meta.map_ptr, func_id);
>         if (err)
> @@ -14079,13 +14124,12 @@ static int check_alu_op(struct
> bpf_verifier_env *env, struct bpf_insn *insn)
>
>                 /* check dest operand */
>                 err = check_reg_arg(env, insn->dst_reg, DST_OP_NO_MARK);
> +               err = err ?: adjust_reg_min_max_vals(env, insn);
>                 if (err)
>                         return err;
> -
> -               return adjust_reg_min_max_vals(env, insn);
>         }
>
> -       return 0;
> +       return reg_bounds_sanity_check(env, &regs[insn->dst_reg]);
>  }
>
>  static void find_good_pkt_pointers(struct bpf_verifier_state *vstate,
> @@ -14600,18 +14644,21 @@ static void regs_refine_cond_op(struct
> bpf_reg_state *reg1, struct bpf_reg_state
>   * Technically we can do similar adjustments for pointers to the same object,
>   * but we don't support that right now.
>   */
> -static void reg_set_min_max(struct bpf_reg_state *true_reg1,
> -                           struct bpf_reg_state *true_reg2,
> -                           struct bpf_reg_state *false_reg1,
> -                           struct bpf_reg_state *false_reg2,
> -                           u8 opcode, bool is_jmp32)
> +static int reg_set_min_max(struct bpf_verifier_env *env,
> +                          struct bpf_reg_state *true_reg1,
> +                          struct bpf_reg_state *true_reg2,
> +                          struct bpf_reg_state *false_reg1,
> +                          struct bpf_reg_state *false_reg2,
> +                          u8 opcode, bool is_jmp32)
>  {
> +       int err;
> +
>         /* If either register is a pointer, we can't learn anything about its
>          * variable offset from the compare (unless they were a pointer into
>          * the same object, but we don't bother with that).
>          */
>         if (false_reg1->type != SCALAR_VALUE || false_reg2->type !=
> SCALAR_VALUE)
> -               return;
> +               return 0;
>
>         /* fallthrough (FALSE) branch */
>         regs_refine_cond_op(false_reg1, false_reg2,
> rev_opcode(opcode), is_jmp32);
> @@ -14622,6 +14669,12 @@ static void reg_set_min_max(struct
> bpf_reg_state *true_reg1,
>         regs_refine_cond_op(true_reg1, true_reg2, opcode, is_jmp32);
>         reg_bounds_sync(true_reg1);
>         reg_bounds_sync(true_reg2);
> +
> +       err = reg_bounds_sanity_check(env, true_reg1);
> +       err = err ?: reg_bounds_sanity_check(env, true_reg2);
> +       err = err ?: reg_bounds_sanity_check(env, false_reg1);
> +       err = err ?: reg_bounds_sanity_check(env, false_reg2);
> +       return err;
>  }
>
>  static void mark_ptr_or_null_reg(struct bpf_func_state *state,
> @@ -14915,15 +14968,20 @@ static int check_cond_jmp_op(struct
> bpf_verifier_env *env,
>         other_branch_regs = other_branch->frame[other_branch->curframe]->regs;
>
>         if (BPF_SRC(insn->code) == BPF_X) {
> -               reg_set_min_max(&other_branch_regs[insn->dst_reg],
> -                               &other_branch_regs[insn->src_reg],
> -                               dst_reg, src_reg, opcode, is_jmp32);
> +               err = reg_set_min_max(env,
> +                                     &other_branch_regs[insn->dst_reg],
> +                                     &other_branch_regs[insn->src_reg],
> +                                     dst_reg, src_reg, opcode, is_jmp32);
>         } else /* BPF_SRC(insn->code) == BPF_K */ {
> -               reg_set_min_max(&other_branch_regs[insn->dst_reg],
> -                               src_reg /* fake one */,
> -                               dst_reg, src_reg /* same fake one */,
> -                               opcode, is_jmp32);
> +               err = reg_set_min_max(env,
> +                                     &other_branch_regs[insn->dst_reg],
> +                                     src_reg /* fake one */,
> +                                     dst_reg, src_reg /* same fake one */,
> +                                     opcode, is_jmp32);
>         }
> +       if (err)
> +               return err;
> +
>         if (BPF_SRC(insn->code) == BPF_X &&
>             src_reg->type == SCALAR_VALUE && src_reg->id &&
>             !WARN_ON_ONCE(src_reg->id != other_branch_regs[insn->src_reg].id)) {
> @@ -17426,10 +17484,8 @@ static int do_check(struct bpf_verifier_env *env)
>                                                insn->off, BPF_SIZE(insn->code),
>                                                BPF_READ, insn->dst_reg, false,
>                                                BPF_MODE(insn->code) ==
> BPF_MEMSX);
> -                       if (err)
> -                               return err;
> -
> -                       err = save_aux_ptr_type(env, src_reg_type, true);
> +                       err = err ?: save_aux_ptr_type(env, src_reg_type, true);
> +                       err = reg_bounds_sanity_check(env,
> &regs[insn->dst_reg]);

this should obviously be `err = err ?: reg_bounds_sanity_check(...)`
(somehow it gets obvious in the email, not locally)

>                         if (err)
>                                 return err;
>                 } else if (class == BPF_STX) {
diff mbox series

Patch

diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c
index 4c974296127b..f18a8247e5e2 100644
--- a/kernel/bpf/verifier.c
+++ b/kernel/bpf/verifier.c
@@ -14189,82 +14189,105 @@  static int is_scalar_branch_taken(struct bpf_reg_state *reg1, struct bpf_reg_sta
 				  u8 opcode, bool is_jmp32)
 {
 	struct tnum t1 = is_jmp32 ? tnum_subreg(reg1->var_off) : reg1->var_off;
+	struct tnum t2 = is_jmp32 ? tnum_subreg(reg2->var_off) : reg2->var_off;
 	u64 umin1 = is_jmp32 ? (u64)reg1->u32_min_value : reg1->umin_value;
 	u64 umax1 = is_jmp32 ? (u64)reg1->u32_max_value : reg1->umax_value;
 	s64 smin1 = is_jmp32 ? (s64)reg1->s32_min_value : reg1->smin_value;
 	s64 smax1 = is_jmp32 ? (s64)reg1->s32_max_value : reg1->smax_value;
-	u64 val = is_jmp32 ? (u32)tnum_subreg(reg2->var_off).value : reg2->var_off.value;
-	s64 sval = is_jmp32 ? (s32)val : (s64)val;
+	u64 umin2 = is_jmp32 ? (u64)reg2->u32_min_value : reg2->umin_value;
+	u64 umax2 = is_jmp32 ? (u64)reg2->u32_max_value : reg2->umax_value;
+	s64 smin2 = is_jmp32 ? (s64)reg2->s32_min_value : reg2->smin_value;
+	s64 smax2 = is_jmp32 ? (s64)reg2->s32_max_value : reg2->smax_value;
 
 	switch (opcode) {
 	case BPF_JEQ:
-		if (tnum_is_const(t1))
-			return !!tnum_equals_const(t1, val);
-		else if (val < umin1 || val > umax1)
+		/* const tnums */
+		if (tnum_is_const(t1) && tnum_is_const(t2))
+			return t1.value == t2.value;
+		/* const ranges */
+		if (umin1 == umax1 && umin2 == umax2)
+			return umin1 == umin2;
+		if (smin1 == smax1 && smin2 == smax2)
+			return smin1 == smin2;
+		/* non-overlapping ranges */
+		if (umin1 > umax2 || umax1 < umin2)
 			return 0;
-		else if (sval < smin1 || sval > smax1)
+		if (smin1 > smax2 || smax1 < smin2)
 			return 0;
 		break;
 	case BPF_JNE:
-		if (tnum_is_const(t1))
-			return !tnum_equals_const(t1, val);
-		else if (val < umin1 || val > umax1)
+		/* const tnums */
+		if (tnum_is_const(t1) && tnum_is_const(t2))
+			return t1.value != t2.value;
+		/* const ranges */
+		if (umin1 == umax1 && umin2 == umax2)
+			return umin1 != umin2;
+		if (smin1 == smax1 && smin2 == smax2)
+			return smin1 != smin2;
+		/* non-overlapping ranges */
+		if (umin1 > umax2 || umax1 < umin2)
 			return 1;
-		else if (sval < smin1 || sval > smax1)
+		if (smin1 > smax2 || smax1 < smin2)
 			return 1;
 		break;
 	case BPF_JSET:
-		if ((~t1.mask & t1.value) & val)
+		if (!is_reg_const(reg2, is_jmp32)) {
+			swap(reg1, reg2);
+			swap(t1, t2);
+		}
+		if (!is_reg_const(reg2, is_jmp32))
+			return -1;
+		if ((~t1.mask & t1.value) & t2.value)
 			return 1;
-		if (!((t1.mask | t1.value) & val))
+		if (!((t1.mask | t1.value) & t2.value))
 			return 0;
 		break;
 	case BPF_JGT:
-		if (umin1 > val )
+		if (umin1 > umax2)
 			return 1;
-		else if (umax1 <= val)
+		else if (umax1 <= umin2)
 			return 0;
 		break;
 	case BPF_JSGT:
-		if (smin1 > sval)
+		if (smin1 > smax2)
 			return 1;
-		else if (smax1 <= sval)
+		else if (smax1 <= smin2)
 			return 0;
 		break;
 	case BPF_JLT:
-		if (umax1 < val)
+		if (umax1 < umin2)
 			return 1;
-		else if (umin1 >= val)
+		else if (umin1 >= umax2)
 			return 0;
 		break;
 	case BPF_JSLT:
-		if (smax1 < sval)
+		if (smax1 < smin2)
 			return 1;
-		else if (smin1 >= sval)
+		else if (smin1 >= smax2)
 			return 0;
 		break;
 	case BPF_JGE:
-		if (umin1 >= val)
+		if (umin1 >= umax2)
 			return 1;
-		else if (umax1 < val)
+		else if (umax1 < umin2)
 			return 0;
 		break;
 	case BPF_JSGE:
-		if (smin1 >= sval)
+		if (smin1 >= smax2)
 			return 1;
-		else if (smax1 < sval)
+		else if (smax1 < smin2)
 			return 0;
 		break;
 	case BPF_JLE:
-		if (umax1 <= val)
+		if (umax1 <= umin2)
 			return 1;
-		else if (umin1 > val)
+		else if (umin1 > umax2)
 			return 0;
 		break;
 	case BPF_JSLE:
-		if (smax1 <= sval)
+		if (smax1 <= smin2)
 			return 1;
-		else if (smin1 > sval)
+		else if (smin1 > smax2)
 			return 0;
 		break;
 	}
@@ -14343,28 +14366,28 @@  static int is_pkt_ptr_branch_taken(struct bpf_reg_state *dst_reg,
 static int is_branch_taken(struct bpf_reg_state *reg1, struct bpf_reg_state *reg2,
 			   u8 opcode, bool is_jmp32)
 {
-	u64 val;
-
 	if (reg_is_pkt_pointer_any(reg1) && reg_is_pkt_pointer_any(reg2) && !is_jmp32)
 		return is_pkt_ptr_branch_taken(reg1, reg2, opcode);
 
-	/* try to make sure reg2 is a constant SCALAR_VALUE */
-	if (!is_reg_const(reg2, is_jmp32)) {
-		opcode = flip_opcode(opcode);
-		swap(reg1, reg2);
-	}
-	/* for now we expect reg2 to be a constant to make any useful decisions */
-	if (!is_reg_const(reg2, is_jmp32))
-		return -1;
-	val = reg_const_value(reg2, is_jmp32);
+	if (__is_pointer_value(false, reg1) || __is_pointer_value(false, reg2)) {
+		u64 val;
+
+		/* arrange that reg2 is a scalar, and reg1 is a pointer */
+		if (!is_reg_const(reg2, is_jmp32)) {
+			opcode = flip_opcode(opcode);
+			swap(reg1, reg2);
+		}
+		/* and ensure that reg2 is a constant */
+		if (!is_reg_const(reg2, is_jmp32))
+			return -1;
 
-	if (__is_pointer_value(false, reg1)) {
 		if (!reg_not_null(reg1))
 			return -1;
 
 		/* If pointer is valid tests against zero will fail so we can
 		 * use this to direct branch taken.
 		 */
+		val = reg_const_value(reg2, is_jmp32);
 		if (val != 0)
 			return -1;
 
@@ -14378,6 +14401,7 @@  static int is_branch_taken(struct bpf_reg_state *reg1, struct bpf_reg_state *reg
 		}
 	}
 
+	/* now deal with two scalars, but not necessarily constants */
 	return is_scalar_branch_taken(reg1, reg2, opcode, is_jmp32);
 }