
[bpf-next,v3,1/4] bpf: use scalar ids in mark_chain_precision()

Message ID 20230606222411.1820404-2-eddyz87@gmail.com (mailing list archive)
State Superseded
Delegated to: BPF
Series: verify scalar ids mapping in regsafe()

Checks

Context Check Description
bpf/vmtest-bpf-next-PR success PR summary
netdev/series_format success Posting correctly formatted
netdev/tree_selection success Clearly marked for bpf-next, async
netdev/fixes_present success Fixes tag not required for -next series
netdev/header_inline success No static functions without inline keyword in header files
netdev/build_32bit success Errors and warnings before: 81 this patch: 81
netdev/cc_maintainers warning 9 maintainers not CCed: kpsingh@kernel.org shuah@kernel.org sdf@google.com john.fastabend@gmail.com song@kernel.org mykolal@fb.com linux-kselftest@vger.kernel.org jolsa@kernel.org haoluo@google.com
netdev/build_clang success Errors and warnings before: 20 this patch: 20
netdev/verify_signedoff success Signed-off-by tag matches author and committer
netdev/deprecated_api success None detected
netdev/check_selftest success No net selftest shell script
netdev/verify_fixes success No Fixes tag
netdev/build_allmodconfig_warn success Errors and warnings before: 81 this patch: 81
netdev/checkpatch fail ERROR: space required after that ',' (ctx:VxV) ERROR: spaces required around that ':' (ctx:OxE) ERROR: spaces required around that ':' (ctx:VxW) ERROR: spaces required around that '=' (ctx:VxO) ERROR: spaces required around that '=' (ctx:VxV) ERROR: spaces required around that '=' (ctx:VxW) WARNING: labels should not be indented WARNING: line length of 81 exceeds 80 columns WARNING: line length of 90 exceeds 80 columns WARNING: line length of 96 exceeds 80 columns
netdev/kdoc success Errors and warnings before: 0 this patch: 0
netdev/source_inline fail Was 0 now: 3
bpf/vmtest-bpf-next-VM_Test-1 success Logs for ShellCheck
bpf/vmtest-bpf-next-VM_Test-2 success Logs for build for aarch64 with gcc
bpf/vmtest-bpf-next-VM_Test-4 success Logs for build for x86_64 with gcc
bpf/vmtest-bpf-next-VM_Test-5 success Logs for build for x86_64 with llvm-16
bpf/vmtest-bpf-next-VM_Test-6 success Logs for set-matrix
bpf/vmtest-bpf-next-VM_Test-3 success Logs for build for s390x with gcc
bpf/vmtest-bpf-next-VM_Test-7 success Logs for test_maps on aarch64 with gcc
bpf/vmtest-bpf-next-VM_Test-9 success Logs for test_maps on x86_64 with gcc
bpf/vmtest-bpf-next-VM_Test-10 success Logs for test_maps on x86_64 with llvm-16
bpf/vmtest-bpf-next-VM_Test-11 success Logs for test_progs on aarch64 with gcc
bpf/vmtest-bpf-next-VM_Test-15 success Logs for test_progs_no_alu32 on aarch64 with gcc
bpf/vmtest-bpf-next-VM_Test-17 success Logs for test_progs_no_alu32 on x86_64 with gcc
bpf/vmtest-bpf-next-VM_Test-19 success Logs for test_progs_no_alu32_parallel on aarch64 with gcc
bpf/vmtest-bpf-next-VM_Test-20 success Logs for test_progs_no_alu32_parallel on x86_64 with gcc
bpf/vmtest-bpf-next-VM_Test-22 success Logs for test_progs_parallel on aarch64 with gcc
bpf/vmtest-bpf-next-VM_Test-23 success Logs for test_progs_parallel on x86_64 with gcc
bpf/vmtest-bpf-next-VM_Test-25 success Logs for test_verifier on aarch64 with gcc
bpf/vmtest-bpf-next-VM_Test-27 success Logs for test_verifier on x86_64 with gcc
bpf/vmtest-bpf-next-VM_Test-28 success Logs for test_verifier on x86_64 with llvm-16
bpf/vmtest-bpf-next-VM_Test-29 success Logs for veristat
bpf/vmtest-bpf-next-VM_Test-13 fail Logs for test_progs on x86_64 with gcc
bpf/vmtest-bpf-next-VM_Test-14 success Logs for test_progs on x86_64 with llvm-16
bpf/vmtest-bpf-next-VM_Test-18 success Logs for test_progs_no_alu32 on x86_64 with llvm-16
bpf/vmtest-bpf-next-VM_Test-21 success Logs for test_progs_no_alu32_parallel on x86_64 with llvm-16
bpf/vmtest-bpf-next-VM_Test-24 success Logs for test_progs_parallel on x86_64 with llvm-16
bpf/vmtest-bpf-next-VM_Test-26 success Logs for test_verifier on s390x with gcc
bpf/vmtest-bpf-next-VM_Test-12 fail Logs for test_progs on s390x with gcc
bpf/vmtest-bpf-next-VM_Test-16 fail Logs for test_progs_no_alu32 on s390x with gcc
bpf/vmtest-bpf-next-VM_Test-8 success Logs for test_maps on s390x with gcc

Commit Message

Eduard Zingerman June 6, 2023, 10:24 p.m. UTC
Change mark_chain_precision() to track precision in situations
like below:

    r2 = unknown value
    ...
  --- state #0 ---
    ...
    r1 = r2                 // r1 and r2 now share the same ID
    ...
  --- state #1 {r1.id = A, r2.id = A} ---
    ...
    if (r2 > 10) goto exit; // find_equal_scalars() assigns range to r1
    ...
  --- state #2 {r1.id = A, r2.id = A} ---
    r3 = r10
    r3 += r1                // need to mark both r1 and r2

At the beginning of the processing of each state, ensure that if a
register with a scalar ID is marked as precise, all registers sharing
this ID are also marked as precise.

This property will be used by a follow-up change in regsafe().

Signed-off-by: Eduard Zingerman <eddyz87@gmail.com>
---
 include/linux/bpf_verifier.h                  |  10 +-
 kernel/bpf/verifier.c                         | 114 ++++++++++++++++++
 .../testing/selftests/bpf/verifier/precise.c  |   8 +-
 3 files changed, 127 insertions(+), 5 deletions(-)

Comments

Andrii Nakryiko June 7, 2023, 9:40 p.m. UTC | #1
On Tue, Jun 6, 2023 at 3:24 PM Eduard Zingerman <eddyz87@gmail.com> wrote:
>
> Change mark_chain_precision() to track precision in situations
> like below:
>
>     r2 = unknown value
>     ...
>   --- state #0 ---
>     ...
>     r1 = r2                 // r1 and r2 now share the same ID
>     ...
>   --- state #1 {r1.id = A, r2.id = A} ---
>     ...
>     if (r2 > 10) goto exit; // find_equal_scalars() assigns range to r1
>     ...
>   --- state #2 {r1.id = A, r2.id = A} ---
>     r3 = r10
>     r3 += r1                // need to mark both r1 and r2
>
> At the beginning of the processing of each state, ensure that if a
> register with a scalar ID is marked as precise, all registers sharing
> this ID are also marked as precise.
>
> This property would be used by a follow-up change in regsafe().
>
> Signed-off-by: Eduard Zingerman <eddyz87@gmail.com>
> ---
>  include/linux/bpf_verifier.h                  |  10 +-
>  kernel/bpf/verifier.c                         | 114 ++++++++++++++++++
>  .../testing/selftests/bpf/verifier/precise.c  |   8 +-
>  3 files changed, 127 insertions(+), 5 deletions(-)
>
> diff --git a/include/linux/bpf_verifier.h b/include/linux/bpf_verifier.h
> index ee4cc7471ed9..3f9856baa542 100644
> --- a/include/linux/bpf_verifier.h
> +++ b/include/linux/bpf_verifier.h
> @@ -559,6 +559,11 @@ struct backtrack_state {
>         u64 stack_masks[MAX_CALL_FRAMES];
>  };
>
> +struct reg_id_scratch {
> +       u32 count;
> +       u32 ids[BPF_ID_MAP_SIZE];
> +};
> +
>  /* single container for all structs
>   * one verifier_env per bpf_check() call
>   */
> @@ -590,7 +595,10 @@ struct bpf_verifier_env {
>         const struct bpf_line_info *prev_linfo;
>         struct bpf_verifier_log log;
>         struct bpf_subprog_info subprog_info[BPF_MAX_SUBPROGS + 1];
> -       struct bpf_id_pair idmap_scratch[BPF_ID_MAP_SIZE];
> +       union {
> +               struct bpf_id_pair idmap_scratch[BPF_ID_MAP_SIZE];
> +               struct reg_id_scratch precise_ids_scratch;

naming nit: "ids_scratch" or "idset_scratch" to stay in line with
"idmap_scratch"?

> +       };
>         struct {
>                 int *insn_state;
>                 int *insn_stack;
> diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c
> index d117deb03806..2aa60b73f1b5 100644
> --- a/kernel/bpf/verifier.c
> +++ b/kernel/bpf/verifier.c
> @@ -3779,6 +3779,96 @@ static void mark_all_scalars_imprecise(struct bpf_verifier_env *env, struct bpf_
>         }
>  }
>
> +static inline bool reg_id_scratch_contains(struct reg_id_scratch *s, u32 id)
> +{
> +       u32 i;
> +
> +       for (i = 0; i < s->count; ++i)
> +               if (s->ids[i] == id)
> +                       return true;
> +
> +       return false;
> +}
> +
> +static inline int reg_id_scratch_push(struct reg_id_scratch *s, u32 id)
> +{
> +       if (WARN_ON_ONCE(s->count >= ARRAY_SIZE(s->ids)))
> +               return -1;
> +       s->ids[s->count++] = id;

this will allow duplicated IDs to be added? Was it done in the name of speed?

> +       WARN_ONCE(s->count > 64,
> +                 "reg_id_scratch.count is unreasonably large (%d)", s->count);

do we need this one? Especially that it's not _ONCE variant? Maybe the
first WARN_ON_ONCE is enough?

> +       return 0;
> +}
> +
> +static inline void reg_id_scratch_reset(struct reg_id_scratch *s)

we probably don't need "inline" for all these helpers?

> +{
> +       s->count = 0;
> +}
> +
> +/* Collect a set of IDs for all registers currently marked as precise in env->bt.
> + * Mark all registers with these IDs as precise.
> + */
> +static void mark_precise_scalar_ids(struct bpf_verifier_env *env, struct bpf_verifier_state *st)
> +{
> +       struct reg_id_scratch *precise_ids = &env->precise_ids_scratch;
> +       struct backtrack_state *bt = &env->bt;
> +       struct bpf_func_state *func;
> +       struct bpf_reg_state *reg;
> +       DECLARE_BITMAP(mask, 64);
> +       int i, fr;
> +
> +       reg_id_scratch_reset(precise_ids);
> +
> +       for (fr = bt->frame; fr >= 0; fr--) {
> +               func = st->frame[fr];
> +
> +               bitmap_from_u64(mask, bt_frame_reg_mask(bt, fr));
> +               for_each_set_bit(i, mask, 32) {
> +                       reg = &func->regs[i];
> +                       if (!reg->id || reg->type != SCALAR_VALUE)
> +                               continue;
> +                       if (reg_id_scratch_push(precise_ids, reg->id))
> +                               return;
> +               }
> +
> +               bitmap_from_u64(mask, bt_frame_stack_mask(bt, fr));
> +               for_each_set_bit(i, mask, 64) {
> +                       if (i >= func->allocated_stack / BPF_REG_SIZE)
> +                               break;
> +                       if (!is_spilled_scalar_reg(&func->stack[i]))
> +                               continue;
> +                       reg = &func->stack[i].spilled_ptr;
> +                       if (!reg->id || reg->type != SCALAR_VALUE)

is_spilled_scalar_reg() already ensures reg->type is SCALAR_VALUE

> +                               continue;
> +                       if (reg_id_scratch_push(precise_ids, reg->id))
> +                               return;

if push fails (due to overflow of id set), shouldn't we propagate
error back and fallback to mark_all_precise?


> +               }
> +       }
> +
> +       for (fr = 0; fr <= st->curframe; ++fr) {
> +               func = st->frame[fr];
> +
> +               for (i = BPF_REG_0; i < BPF_REG_10; ++i) {
> +                       reg = &func->regs[i];
> +                       if (!reg->id)
> +                               continue;
> +                       if (!reg_id_scratch_contains(precise_ids, reg->id))
> +                               continue;
> +                       bt_set_frame_reg(bt, fr, i);
> +               }
> +               for (i = 0; i < func->allocated_stack / BPF_REG_SIZE; ++i) {
> +                       if (!is_spilled_scalar_reg(&func->stack[i]))
> +                               continue;
> +                       reg = &func->stack[i].spilled_ptr;
> +                       if (!reg->id)
> +                               continue;
> +                       if (!reg_id_scratch_contains(precise_ids, reg->id))
> +                               continue;
> +                       bt_set_frame_slot(bt, fr, i);
> +               }
> +       }
> +}
> +
>  /*
>   * __mark_chain_precision() backtracks BPF program instruction sequence and
>   * chain of verifier states making sure that register *regno* (if regno >= 0)
> @@ -3910,6 +4000,30 @@ static int __mark_chain_precision(struct bpf_verifier_env *env, int regno)
>                                 bt->frame, last_idx, first_idx, subseq_idx);
>                 }
>
> +               /* If some register with scalar ID is marked as precise,
> +                * make sure that all registers sharing this ID are also precise.
> +                * This is needed to estimate effect of find_equal_scalars().
> +                * Do this at the last instruction of each state,
> +                * bpf_reg_state::id fields are valid for these instructions.
> +                *
> +                * Allows to track precision in situation like below:
> +                *
> +                *     r2 = unknown value
> +                *     ...
> +                *   --- state #0 ---
> +                *     ...
> +                *     r1 = r2                 // r1 and r2 now share the same ID
> +                *     ...
> +                *   --- state #1 {r1.id = A, r2.id = A} ---
> +                *     ...
> +                *     if (r2 > 10) goto exit; // find_equal_scalars() assigns range to r1
> +                *     ...
> +                *   --- state #2 {r1.id = A, r2.id = A} ---
> +                *     r3 = r10
> +                *     r3 += r1                // need to mark both r1 and r2
> +                */
> +               mark_precise_scalar_ids(env, st);
> +
>                 if (last_idx < 0) {
>                         /* we are at the entry into subprog, which
>                          * is expected for global funcs, but only if
> diff --git a/tools/testing/selftests/bpf/verifier/precise.c b/tools/testing/selftests/bpf/verifier/precise.c
> index b8c0aae8e7ec..99272bb890da 100644
> --- a/tools/testing/selftests/bpf/verifier/precise.c
> +++ b/tools/testing/selftests/bpf/verifier/precise.c
> @@ -46,7 +46,7 @@
>         mark_precise: frame0: regs=r2 stack= before 20\
>         mark_precise: frame0: parent state regs=r2 stack=:\
>         mark_precise: frame0: last_idx 19 first_idx 10\
> -       mark_precise: frame0: regs=r2 stack= before 19\
> +       mark_precise: frame0: regs=r2,r9 stack= before 19\
>         mark_precise: frame0: regs=r9 stack= before 18\
>         mark_precise: frame0: regs=r8,r9 stack= before 17\
>         mark_precise: frame0: regs=r0,r9 stack= before 15\
> @@ -106,10 +106,10 @@
>         mark_precise: frame0: regs=r2 stack= before 22\
>         mark_precise: frame0: parent state regs=r2 stack=:\
>         mark_precise: frame0: last_idx 20 first_idx 20\
> -       mark_precise: frame0: regs=r2 stack= before 20\
> -       mark_precise: frame0: parent state regs=r2 stack=:\
> +       mark_precise: frame0: regs=r2,r9 stack= before 20\
> +       mark_precise: frame0: parent state regs=r2,r9 stack=:\
>         mark_precise: frame0: last_idx 19 first_idx 17\
> -       mark_precise: frame0: regs=r2 stack= before 19\
> +       mark_precise: frame0: regs=r2,r9 stack= before 19\
>         mark_precise: frame0: regs=r9 stack= before 18\
>         mark_precise: frame0: regs=r8,r9 stack= before 17\
>         mark_precise: frame0: parent state regs= stack=:",
> --
> 2.40.1
>
Eduard Zingerman June 8, 2023, 12:35 p.m. UTC | #2
On Wed, 2023-06-07 at 14:40 -0700, Andrii Nakryiko wrote:
> On Tue, Jun 6, 2023 at 3:24 PM Eduard Zingerman <eddyz87@gmail.com> wrote:
> > 
> > Change mark_chain_precision() to track precision in situations
> > like below:
> > 
> >     r2 = unknown value
> >     ...
> >   --- state #0 ---
> >     ...
> >     r1 = r2                 // r1 and r2 now share the same ID
> >     ...
> >   --- state #1 {r1.id = A, r2.id = A} ---
> >     ...
> >     if (r2 > 10) goto exit; // find_equal_scalars() assigns range to r1
> >     ...
> >   --- state #2 {r1.id = A, r2.id = A} ---
> >     r3 = r10
> >     r3 += r1                // need to mark both r1 and r2
> > 
> > At the beginning of the processing of each state, ensure that if a
> > register with a scalar ID is marked as precise, all registers sharing
> > this ID are also marked as precise.
> > 
> > This property would be used by a follow-up change in regsafe().
> > 
> > Signed-off-by: Eduard Zingerman <eddyz87@gmail.com>
> > ---
> >  include/linux/bpf_verifier.h                  |  10 +-
> >  kernel/bpf/verifier.c                         | 114 ++++++++++++++++++
> >  .../testing/selftests/bpf/verifier/precise.c  |   8 +-
> >  3 files changed, 127 insertions(+), 5 deletions(-)
> > 
> > diff --git a/include/linux/bpf_verifier.h b/include/linux/bpf_verifier.h
> > index ee4cc7471ed9..3f9856baa542 100644
> > --- a/include/linux/bpf_verifier.h
> > +++ b/include/linux/bpf_verifier.h
> > @@ -559,6 +559,11 @@ struct backtrack_state {
> >         u64 stack_masks[MAX_CALL_FRAMES];
> >  };
> > 
> > +struct reg_id_scratch {
> > +       u32 count;
> > +       u32 ids[BPF_ID_MAP_SIZE];
> > +};
> > +
> >  /* single container for all structs
> >   * one verifier_env per bpf_check() call
> >   */
> > @@ -590,7 +595,10 @@ struct bpf_verifier_env {
> >         const struct bpf_line_info *prev_linfo;
> >         struct bpf_verifier_log log;
> >         struct bpf_subprog_info subprog_info[BPF_MAX_SUBPROGS + 1];
> > -       struct bpf_id_pair idmap_scratch[BPF_ID_MAP_SIZE];
> > +       union {
> > +               struct bpf_id_pair idmap_scratch[BPF_ID_MAP_SIZE];
> > +               struct reg_id_scratch precise_ids_scratch;
> 
> naming nit: "ids_scratch" or "idset_scratch" to stay in line with
> "idmap_scratch"?

Makes sense, will change to "idset_scratch".

> 
> > +       };
> >         struct {
> >                 int *insn_state;
> >                 int *insn_stack;
> > diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c
> > index d117deb03806..2aa60b73f1b5 100644
> > --- a/kernel/bpf/verifier.c
> > +++ b/kernel/bpf/verifier.c
> > @@ -3779,6 +3779,96 @@ static void mark_all_scalars_imprecise(struct bpf_verifier_env *env, struct bpf_
> >         }
> >  }
> > 
> > +static inline bool reg_id_scratch_contains(struct reg_id_scratch *s, u32 id)
> > +{
> > +       u32 i;
> > +
> > +       for (i = 0; i < s->count; ++i)
> > +               if (s->ids[i] == id)
> > +                       return true;
> > +
> > +       return false;
> > +}
> > +
> > +static inline int reg_id_scratch_push(struct reg_id_scratch *s, u32 id)
> > +{
> > +       if (WARN_ON_ONCE(s->count >= ARRAY_SIZE(s->ids)))
> > +               return -1;
> > +       s->ids[s->count++] = id;
> 
> this will allow duplicated IDs to be added? Was it done in the name of speed?

tbh, it's an artifact of the bsearch/sort migration earlier in the series.
While doing test veristat runs I found that the maximal value of s->count is 5,
so it looks like either option would be fine: keeping it as is, or adding a
linear scan to avoid duplicate ids. I don't have a strong preference.

> 
> > +       WARN_ONCE(s->count > 64,
> > +                 "reg_id_scratch.count is unreasonably large (%d)", s->count);
> 
> do we need this one? Especially that it's not _ONCE variant? Maybe the
> first WARN_ON_ONCE is enough?

We assume that linear scans of this array are cheap, yet it will be scanned
often; I'd like some indication if this assumption is broken. The s->ids
array is large: (10 regs + 64 spills) * 8 frames. If you think this logging
is unnecessary, I'll remove it.

> 
> > +       return 0;
> > +}
> > +
> > +static inline void reg_id_scratch_reset(struct reg_id_scratch *s)
> 
> we probably don't need "inline" for all these helpers?

Ok, will remove "inline".

> 
> > +{
> > +       s->count = 0;
> > +}
> > +
> > +/* Collect a set of IDs for all registers currently marked as precise in env->bt.
> > + * Mark all registers with these IDs as precise.
> > + */
> > +static void mark_precise_scalar_ids(struct bpf_verifier_env *env, struct bpf_verifier_state *st)
> > +{
> > +       struct reg_id_scratch *precise_ids = &env->precise_ids_scratch;
> > +       struct backtrack_state *bt = &env->bt;
> > +       struct bpf_func_state *func;
> > +       struct bpf_reg_state *reg;
> > +       DECLARE_BITMAP(mask, 64);
> > +       int i, fr;
> > +
> > +       reg_id_scratch_reset(precise_ids);
> > +
> > +       for (fr = bt->frame; fr >= 0; fr--) {
> > +               func = st->frame[fr];
> > +
> > +               bitmap_from_u64(mask, bt_frame_reg_mask(bt, fr));
> > +               for_each_set_bit(i, mask, 32) {
> > +                       reg = &func->regs[i];
> > +                       if (!reg->id || reg->type != SCALAR_VALUE)
> > +                               continue;
> > +                       if (reg_id_scratch_push(precise_ids, reg->id))
> > +                               return;
> > +               }
> > +
> > +               bitmap_from_u64(mask, bt_frame_stack_mask(bt, fr));
> > +               for_each_set_bit(i, mask, 64) {
> > +                       if (i >= func->allocated_stack / BPF_REG_SIZE)
> > +                               break;
> > +                       if (!is_spilled_scalar_reg(&func->stack[i]))
> > +                               continue;
> > +                       reg = &func->stack[i].spilled_ptr;
> > +                       if (!reg->id || reg->type != SCALAR_VALUE)
> 
> is_spilled_scalar_reg() already ensures reg->type is SCALAR_VALUE

Yes, my bad.

> 
> > +                               continue;
> > +                       if (reg_id_scratch_push(precise_ids, reg->id))
> > +                               return;
> 
> if push fails (due to overflow of id set), shouldn't we propagate
> error back and fallback to mark_all_precise?

In theory this push should never fail, as we pre-allocate enough slots in the
scratch set. I'll propagate the error to __mark_chain_precision() and exit
from there with -EFAULT.

> 
> 
> > +               }
> > +       }
> > +
> > +       for (fr = 0; fr <= st->curframe; ++fr) {
> > +               func = st->frame[fr];
> > +
> > +               for (i = BPF_REG_0; i < BPF_REG_10; ++i) {
> > +                       reg = &func->regs[i];
> > +                       if (!reg->id)
> > +                               continue;
> > +                       if (!reg_id_scratch_contains(precise_ids, reg->id))
> > +                               continue;
> > +                       bt_set_frame_reg(bt, fr, i);
> > +               }
> > +               for (i = 0; i < func->allocated_stack / BPF_REG_SIZE; ++i) {
> > +                       if (!is_spilled_scalar_reg(&func->stack[i]))
> > +                               continue;
> > +                       reg = &func->stack[i].spilled_ptr;
> > +                       if (!reg->id)
> > +                               continue;
> > +                       if (!reg_id_scratch_contains(precise_ids, reg->id))
> > +                               continue;
> > +                       bt_set_frame_slot(bt, fr, i);
> > +               }
> > +       }
> > +}
> > +
> >  /*
> >   * __mark_chain_precision() backtracks BPF program instruction sequence and
> >   * chain of verifier states making sure that register *regno* (if regno >= 0)
> > @@ -3910,6 +4000,30 @@ static int __mark_chain_precision(struct bpf_verifier_env *env, int regno)
> >                                 bt->frame, last_idx, first_idx, subseq_idx);
> >                 }
> > 
> > +               /* If some register with scalar ID is marked as precise,
> > +                * make sure that all registers sharing this ID are also precise.
> > +                * This is needed to estimate effect of find_equal_scalars().
> > +                * Do this at the last instruction of each state,
> > +                * bpf_reg_state::id fields are valid for these instructions.
> > +                *
> > +                * Allows to track precision in situation like below:
> > +                *
> > +                *     r2 = unknown value
> > +                *     ...
> > +                *   --- state #0 ---
> > +                *     ...
> > +                *     r1 = r2                 // r1 and r2 now share the same ID
> > +                *     ...
> > +                *   --- state #1 {r1.id = A, r2.id = A} ---
> > +                *     ...
> > +                *     if (r2 > 10) goto exit; // find_equal_scalars() assigns range to r1
> > +                *     ...
> > +                *   --- state #2 {r1.id = A, r2.id = A} ---
> > +                *     r3 = r10
> > +                *     r3 += r1                // need to mark both r1 and r2
> > +                */
> > +               mark_precise_scalar_ids(env, st);
> > +
> >                 if (last_idx < 0) {
> >                         /* we are at the entry into subprog, which
> >                          * is expected for global funcs, but only if
> > diff --git a/tools/testing/selftests/bpf/verifier/precise.c b/tools/testing/selftests/bpf/verifier/precise.c
> > index b8c0aae8e7ec..99272bb890da 100644
> > --- a/tools/testing/selftests/bpf/verifier/precise.c
> > +++ b/tools/testing/selftests/bpf/verifier/precise.c
> > @@ -46,7 +46,7 @@
> >         mark_precise: frame0: regs=r2 stack= before 20\
> >         mark_precise: frame0: parent state regs=r2 stack=:\
> >         mark_precise: frame0: last_idx 19 first_idx 10\
> > -       mark_precise: frame0: regs=r2 stack= before 19\
> > +       mark_precise: frame0: regs=r2,r9 stack= before 19\
> >         mark_precise: frame0: regs=r9 stack= before 18\
> >         mark_precise: frame0: regs=r8,r9 stack= before 17\
> >         mark_precise: frame0: regs=r0,r9 stack= before 15\
> > @@ -106,10 +106,10 @@
> >         mark_precise: frame0: regs=r2 stack= before 22\
> >         mark_precise: frame0: parent state regs=r2 stack=:\
> >         mark_precise: frame0: last_idx 20 first_idx 20\
> > -       mark_precise: frame0: regs=r2 stack= before 20\
> > -       mark_precise: frame0: parent state regs=r2 stack=:\
> > +       mark_precise: frame0: regs=r2,r9 stack= before 20\
> > +       mark_precise: frame0: parent state regs=r2,r9 stack=:\
> >         mark_precise: frame0: last_idx 19 first_idx 17\
> > -       mark_precise: frame0: regs=r2 stack= before 19\
> > +       mark_precise: frame0: regs=r2,r9 stack= before 19\
> >         mark_precise: frame0: regs=r9 stack= before 18\
> >         mark_precise: frame0: regs=r8,r9 stack= before 17\
> >         mark_precise: frame0: parent state regs= stack=:",
> > --
> > 2.40.1
> >
Alexei Starovoitov June 8, 2023, 3:43 p.m. UTC | #3
On Thu, Jun 8, 2023 at 5:35 AM Eduard Zingerman <eddyz87@gmail.com> wrote:
>
> On Wed, 2023-06-07 at 14:40 -0700, Andrii Nakryiko wrote:
> > On Tue, Jun 6, 2023 at 3:24 PM Eduard Zingerman <eddyz87@gmail.com> wrote:
> > >
> > > Change mark_chain_precision() to track precision in situations
> > > like below:
> > >
> > >     r2 = unknown value
> > >     ...
> > >   --- state #0 ---
> > >     ...
> > >     r1 = r2                 // r1 and r2 now share the same ID
> > >     ...
> > >   --- state #1 {r1.id = A, r2.id = A} ---
> > >     ...
> > >     if (r2 > 10) goto exit; // find_equal_scalars() assigns range to r1
> > >     ...
> > >   --- state #2 {r1.id = A, r2.id = A} ---
> > >     r3 = r10
> > >     r3 += r1                // need to mark both r1 and r2
> > >
> > > At the beginning of the processing of each state, ensure that if a
> > > register with a scalar ID is marked as precise, all registers sharing
> > > this ID are also marked as precise.
> > >
> > > This property would be used by a follow-up change in regsafe().
> > >
> > > Signed-off-by: Eduard Zingerman <eddyz87@gmail.com>
> > > ---
> > >  include/linux/bpf_verifier.h                  |  10 +-
> > >  kernel/bpf/verifier.c                         | 114 ++++++++++++++++++
> > >  .../testing/selftests/bpf/verifier/precise.c  |   8 +-
> > >  3 files changed, 127 insertions(+), 5 deletions(-)
> > >
> > > diff --git a/include/linux/bpf_verifier.h b/include/linux/bpf_verifier.h
> > > index ee4cc7471ed9..3f9856baa542 100644
> > > --- a/include/linux/bpf_verifier.h
> > > +++ b/include/linux/bpf_verifier.h
> > > @@ -559,6 +559,11 @@ struct backtrack_state {
> > >         u64 stack_masks[MAX_CALL_FRAMES];
> > >  };
> > >
> > > +struct reg_id_scratch {
> > > +       u32 count;
> > > +       u32 ids[BPF_ID_MAP_SIZE];
> > > +};
> > > +
> > >  /* single container for all structs
> > >   * one verifier_env per bpf_check() call
> > >   */
> > > @@ -590,7 +595,10 @@ struct bpf_verifier_env {
> > >         const struct bpf_line_info *prev_linfo;
> > >         struct bpf_verifier_log log;
> > >         struct bpf_subprog_info subprog_info[BPF_MAX_SUBPROGS + 1];
> > > -       struct bpf_id_pair idmap_scratch[BPF_ID_MAP_SIZE];
> > > +       union {
> > > +               struct bpf_id_pair idmap_scratch[BPF_ID_MAP_SIZE];
> > > +               struct reg_id_scratch precise_ids_scratch;
> >
> > naming nit: "ids_scratch" or "idset_scratch" to stay in line with
> > "idmap_scratch"?
>
> Makes sense, will change to "idset_scratch".
>
> >
> > > +       };
> > >         struct {
> > >                 int *insn_state;
> > >                 int *insn_stack;
> > > diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c
> > > index d117deb03806..2aa60b73f1b5 100644
> > > --- a/kernel/bpf/verifier.c
> > > +++ b/kernel/bpf/verifier.c
> > > @@ -3779,6 +3779,96 @@ static void mark_all_scalars_imprecise(struct bpf_verifier_env *env, struct bpf_
> > >         }
> > >  }
> > >
> > > +static inline bool reg_id_scratch_contains(struct reg_id_scratch *s, u32 id)
> > > +{
> > > +       u32 i;
> > > +
> > > +       for (i = 0; i < s->count; ++i)
> > > +               if (s->ids[i] == id)
> > > +                       return true;
> > > +
> > > +       return false;
> > > +}
> > > +
> > > +static inline int reg_id_scratch_push(struct reg_id_scratch *s, u32 id)
> > > +{
> > > +       if (WARN_ON_ONCE(s->count >= ARRAY_SIZE(s->ids)))
> > > +               return -1;
> > > +       s->ids[s->count++] = id;
> >
> > this will allow duplicated IDs to be added? Was it done in the name of speed?
>
> tbh, it's an artifact from bsearch/sort migration of a series.
> While doing test veristat runs I found that maximal value of s->count is 5,
> so looks like it would be fine the way it is now and it would be fine
> if linear scan is added to avoid duplicate ids. Don't think I have a preference.

Maybe return -EFAULT for count > 64 and print a verifier error message?
If/when syzbot or a human manages to craft such a program, we'll know that
this is something to address.
Eduard Zingerman June 8, 2023, 5:30 p.m. UTC | #4
On Thu, 2023-06-08 at 08:43 -0700, Alexei Starovoitov wrote:
> On Thu, Jun 8, 2023 at 5:35 AM Eduard Zingerman <eddyz87@gmail.com> wrote:
> > 
> > On Wed, 2023-06-07 at 14:40 -0700, Andrii Nakryiko wrote:
> > > On Tue, Jun 6, 2023 at 3:24 PM Eduard Zingerman <eddyz87@gmail.com> wrote:
> > > > 
> > > > Change mark_chain_precision() to track precision in situations
> > > > like below:
> > > > 
> > > >     r2 = unknown value
> > > >     ...
> > > >   --- state #0 ---
> > > >     ...
> > > >     r1 = r2                 // r1 and r2 now share the same ID
> > > >     ...
> > > >   --- state #1 {r1.id = A, r2.id = A} ---
> > > >     ...
> > > >     if (r2 > 10) goto exit; // find_equal_scalars() assigns range to r1
> > > >     ...
> > > >   --- state #2 {r1.id = A, r2.id = A} ---
> > > >     r3 = r10
> > > >     r3 += r1                // need to mark both r1 and r2
> > > > 
> > > > At the beginning of the processing of each state, ensure that if a
> > > > register with a scalar ID is marked as precise, all registers sharing
> > > > this ID are also marked as precise.
> > > > 
> > > > This property would be used by a follow-up change in regsafe().
> > > > 
> > > > Signed-off-by: Eduard Zingerman <eddyz87@gmail.com>
> > > > ---
> > > >  include/linux/bpf_verifier.h                  |  10 +-
> > > >  kernel/bpf/verifier.c                         | 114 ++++++++++++++++++
> > > >  .../testing/selftests/bpf/verifier/precise.c  |   8 +-
> > > >  3 files changed, 127 insertions(+), 5 deletions(-)
> > > > 
> > > > diff --git a/include/linux/bpf_verifier.h b/include/linux/bpf_verifier.h
> > > > index ee4cc7471ed9..3f9856baa542 100644
> > > > --- a/include/linux/bpf_verifier.h
> > > > +++ b/include/linux/bpf_verifier.h
> > > > @@ -559,6 +559,11 @@ struct backtrack_state {
> > > >         u64 stack_masks[MAX_CALL_FRAMES];
> > > >  };
> > > > 
> > > > +struct reg_id_scratch {
> > > > +       u32 count;
> > > > +       u32 ids[BPF_ID_MAP_SIZE];
> > > > +};
> > > > +
> > > >  /* single container for all structs
> > > >   * one verifier_env per bpf_check() call
> > > >   */
> > > > @@ -590,7 +595,10 @@ struct bpf_verifier_env {
> > > >         const struct bpf_line_info *prev_linfo;
> > > >         struct bpf_verifier_log log;
> > > >         struct bpf_subprog_info subprog_info[BPF_MAX_SUBPROGS + 1];
> > > > -       struct bpf_id_pair idmap_scratch[BPF_ID_MAP_SIZE];
> > > > +       union {
> > > > +               struct bpf_id_pair idmap_scratch[BPF_ID_MAP_SIZE];
> > > > +               struct reg_id_scratch precise_ids_scratch;
> > > 
> > > naming nit: "ids_scratch" or "idset_scratch" to stay in line with
> > > "idmap_scratch"?
> > 
> > Makes sense, will change to "idset_scratch".
> > 
> > > 
> > > > +       };
> > > >         struct {
> > > >                 int *insn_state;
> > > >                 int *insn_stack;
> > > > diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c
> > > > index d117deb03806..2aa60b73f1b5 100644
> > > > --- a/kernel/bpf/verifier.c
> > > > +++ b/kernel/bpf/verifier.c
> > > > @@ -3779,6 +3779,96 @@ static void mark_all_scalars_imprecise(struct bpf_verifier_env *env, struct bpf_
> > > >         }
> > > >  }
> > > > 
> > > > +static inline bool reg_id_scratch_contains(struct reg_id_scratch *s, u32 id)
> > > > +{
> > > > +       u32 i;
> > > > +
> > > > +       for (i = 0; i < s->count; ++i)
> > > > +               if (s->ids[i] == id)
> > > > +                       return true;
> > > > +
> > > > +       return false;
> > > > +}
> > > > +
> > > > +static inline int reg_id_scratch_push(struct reg_id_scratch *s, u32 id)
> > > > +{
> > > > +       if (WARN_ON_ONCE(s->count >= ARRAY_SIZE(s->ids)))
> > > > +               return -1;
> > > > +       s->ids[s->count++] = id;
> > > 
> > > this will allow duplicated IDs to be added? Was it done in the name of speed?
> > 
> > tbh, it's an artifact of the series' migration away from bsearch/sort.
> > While doing test veristat runs I found that the maximal value of s->count is 5,
> > so it looks like it would be fine either way: as it is now, or with a
> > linear scan added to avoid duplicate IDs. I don't think I have a preference.
> 
> Maybe return -EFAULT for count > 64 and print a verifier error message?
> If/when syzbot/human manages to craft such a program we'll know that
> this is something to address.

Sounds a bit heavy-handed.
Should the same logic apply to idmap?

I did some silly testing: 1,000,000 searches over a u32 array of size (10+64)*8:
- linear search is done in 0.7s
- qsort/bsearch is done in 23s

It looks like my concerns are completely overblown. I'm inclined to
remove the size warning and just check for array overflow.

Patch

diff --git a/include/linux/bpf_verifier.h b/include/linux/bpf_verifier.h
index ee4cc7471ed9..3f9856baa542 100644
--- a/include/linux/bpf_verifier.h
+++ b/include/linux/bpf_verifier.h
@@ -559,6 +559,11 @@  struct backtrack_state {
 	u64 stack_masks[MAX_CALL_FRAMES];
 };
 
+struct reg_id_scratch {
+	u32 count;
+	u32 ids[BPF_ID_MAP_SIZE];
+};
+
 /* single container for all structs
  * one verifier_env per bpf_check() call
  */
@@ -590,7 +595,10 @@  struct bpf_verifier_env {
 	const struct bpf_line_info *prev_linfo;
 	struct bpf_verifier_log log;
 	struct bpf_subprog_info subprog_info[BPF_MAX_SUBPROGS + 1];
-	struct bpf_id_pair idmap_scratch[BPF_ID_MAP_SIZE];
+	union {
+		struct bpf_id_pair idmap_scratch[BPF_ID_MAP_SIZE];
+		struct reg_id_scratch precise_ids_scratch;
+	};
 	struct {
 		int *insn_state;
 		int *insn_stack;
diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c
index d117deb03806..2aa60b73f1b5 100644
--- a/kernel/bpf/verifier.c
+++ b/kernel/bpf/verifier.c
@@ -3779,6 +3779,96 @@  static void mark_all_scalars_imprecise(struct bpf_verifier_env *env, struct bpf_
 	}
 }
 
+static inline bool reg_id_scratch_contains(struct reg_id_scratch *s, u32 id)
+{
+	u32 i;
+
+	for (i = 0; i < s->count; ++i)
+		if (s->ids[i] == id)
+			return true;
+
+	return false;
+}
+
+static inline int reg_id_scratch_push(struct reg_id_scratch *s, u32 id)
+{
+	if (WARN_ON_ONCE(s->count >= ARRAY_SIZE(s->ids)))
+		return -1;
+	s->ids[s->count++] = id;
+	WARN_ONCE(s->count > 64,
+		  "reg_id_scratch.count is unreasonably large (%d)", s->count);
+	return 0;
+}
+
+static inline void reg_id_scratch_reset(struct reg_id_scratch *s)
+{
+	s->count = 0;
+}
+
+/* Collect a set of IDs for all registers currently marked as precise in env->bt.
+ * Mark all registers with these IDs as precise.
+ */
+static void mark_precise_scalar_ids(struct bpf_verifier_env *env, struct bpf_verifier_state *st)
+{
+	struct reg_id_scratch *precise_ids = &env->precise_ids_scratch;
+	struct backtrack_state *bt = &env->bt;
+	struct bpf_func_state *func;
+	struct bpf_reg_state *reg;
+	DECLARE_BITMAP(mask, 64);
+	int i, fr;
+
+	reg_id_scratch_reset(precise_ids);
+
+	for (fr = bt->frame; fr >= 0; fr--) {
+		func = st->frame[fr];
+
+		bitmap_from_u64(mask, bt_frame_reg_mask(bt, fr));
+		for_each_set_bit(i, mask, 32) {
+			reg = &func->regs[i];
+			if (!reg->id || reg->type != SCALAR_VALUE)
+				continue;
+			if (reg_id_scratch_push(precise_ids, reg->id))
+				return;
+		}
+
+		bitmap_from_u64(mask, bt_frame_stack_mask(bt, fr));
+		for_each_set_bit(i, mask, 64) {
+			if (i >= func->allocated_stack / BPF_REG_SIZE)
+				break;
+			if (!is_spilled_scalar_reg(&func->stack[i]))
+				continue;
+			reg = &func->stack[i].spilled_ptr;
+			if (!reg->id || reg->type != SCALAR_VALUE)
+				continue;
+			if (reg_id_scratch_push(precise_ids, reg->id))
+				return;
+		}
+	}
+
+	for (fr = 0; fr <= st->curframe; ++fr) {
+		func = st->frame[fr];
+
+		for (i = BPF_REG_0; i < BPF_REG_10; ++i) {
+			reg = &func->regs[i];
+			if (!reg->id)
+				continue;
+			if (!reg_id_scratch_contains(precise_ids, reg->id))
+				continue;
+			bt_set_frame_reg(bt, fr, i);
+		}
+		for (i = 0; i < func->allocated_stack / BPF_REG_SIZE; ++i) {
+			if (!is_spilled_scalar_reg(&func->stack[i]))
+				continue;
+			reg = &func->stack[i].spilled_ptr;
+			if (!reg->id)
+				continue;
+			if (!reg_id_scratch_contains(precise_ids, reg->id))
+				continue;
+			bt_set_frame_slot(bt, fr, i);
+		}
+	}
+}
+
 /*
  * __mark_chain_precision() backtracks BPF program instruction sequence and
  * chain of verifier states making sure that register *regno* (if regno >= 0)
@@ -3910,6 +4000,30 @@  static int __mark_chain_precision(struct bpf_verifier_env *env, int regno)
 				bt->frame, last_idx, first_idx, subseq_idx);
 		}
 
+		/* If some register with a scalar ID is marked as precise,
+		 * make sure that all registers sharing this ID are also precise.
+		 * This is needed to estimate the effect of find_equal_scalars().
+		 * Do this at the last instruction of each state, as
+		 * bpf_reg_state::id fields are valid for these instructions.
+		 *
+		 * This allows tracking precision in situations like below:
+		 *
+		 *     r2 = unknown value
+		 *     ...
+		 *   --- state #0 ---
+		 *     ...
+		 *     r1 = r2                 // r1 and r2 now share the same ID
+		 *     ...
+		 *   --- state #1 {r1.id = A, r2.id = A} ---
+		 *     ...
+		 *     if (r2 > 10) goto exit; // find_equal_scalars() assigns range to r1
+		 *     ...
+		 *   --- state #2 {r1.id = A, r2.id = A} ---
+		 *     r3 = r10
+		 *     r3 += r1                // need to mark both r1 and r2
+		 */
+		mark_precise_scalar_ids(env, st);
+
 		if (last_idx < 0) {
 			/* we are at the entry into subprog, which
 			 * is expected for global funcs, but only if
diff --git a/tools/testing/selftests/bpf/verifier/precise.c b/tools/testing/selftests/bpf/verifier/precise.c
index b8c0aae8e7ec..99272bb890da 100644
--- a/tools/testing/selftests/bpf/verifier/precise.c
+++ b/tools/testing/selftests/bpf/verifier/precise.c
@@ -46,7 +46,7 @@ 
 	mark_precise: frame0: regs=r2 stack= before 20\
 	mark_precise: frame0: parent state regs=r2 stack=:\
 	mark_precise: frame0: last_idx 19 first_idx 10\
-	mark_precise: frame0: regs=r2 stack= before 19\
+	mark_precise: frame0: regs=r2,r9 stack= before 19\
 	mark_precise: frame0: regs=r9 stack= before 18\
 	mark_precise: frame0: regs=r8,r9 stack= before 17\
 	mark_precise: frame0: regs=r0,r9 stack= before 15\
@@ -106,10 +106,10 @@ 
 	mark_precise: frame0: regs=r2 stack= before 22\
 	mark_precise: frame0: parent state regs=r2 stack=:\
 	mark_precise: frame0: last_idx 20 first_idx 20\
-	mark_precise: frame0: regs=r2 stack= before 20\
-	mark_precise: frame0: parent state regs=r2 stack=:\
+	mark_precise: frame0: regs=r2,r9 stack= before 20\
+	mark_precise: frame0: parent state regs=r2,r9 stack=:\
 	mark_precise: frame0: last_idx 19 first_idx 17\
-	mark_precise: frame0: regs=r2 stack= before 19\
+	mark_precise: frame0: regs=r2,r9 stack= before 19\
 	mark_precise: frame0: regs=r9 stack= before 18\
 	mark_precise: frame0: regs=r8,r9 stack= before 17\
 	mark_precise: frame0: parent state regs= stack=:",