
[bpf-next,2/4] bpf: track find_equal_scalars history on per-instruction level

Message ID 20240222005005.31784-3-eddyz87@gmail.com (mailing list archive)
State Changes Requested
Delegated to: BPF
Series bpf: track find_equal_scalars history on per-instruction level

Commit Message

Eduard Zingerman Feb. 22, 2024, 12:50 a.m. UTC
Use bpf_verifier_state->jmp_history to track which registers were
updated by find_equal_scalars() when a conditional jump was verified.
Use the recorded information in backtrack_insn() to propagate precision.

E.g. for the following program:

            while verifying instructions
  r1 = r0              |
  if r1 < 8  goto ...  | push r0,r1 as equal_scalars in jmp_history
  if r0 > 16 goto ...  | push r0,r1 as equal_scalars in jmp_history
  r2 = r10             |
  r2 += r0             v mark_chain_precision(r0)

            while doing mark_chain_precision(r0)
  r1 = r0              ^
  if r1 < 8  goto ...  | mark r0,r1 as precise
  if r0 > 16 goto ...  | mark r0,r1 as precise
  r2 = r10             |
  r2 += r0             | mark r0 precise

Technically, this is achieved in the following steps (a worked encoding
example follows the list):
- Use 10 bits to identify each register that gains range because of
  find_equal_scalars():
  - 3 bits for frame number;
  - 6 bits for register or stack slot number;
  - 1 bit to indicate if register is spilled.
- Use u64 as a vector of 6 such records + 4 bits for vector length.
- Augment struct bpf_jmp_history_entry with field 'equal_scalars'
  representing such vector.
- When doing check_cond_jmp_op() remember up to 6 registers that
  gain range because of find_equal_scalars() in such a vector.
- Don't propagate range information and reset IDs for registers that
  don't fit in the 6-value vector.
- Push collected vector to bpf_verifier_state->jmp_history for
  instruction index of conditional jump.
- When doing backtrack_insn() for conditional jumps,
  check if any of the recorded equal scalars is currently marked precise;
  if so, mark all recorded equal scalars as precise.
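
For reference, here is a tiny standalone sketch (not part of the patch;
written purely for illustration of the layout above) that packs two
linked registers, r0 and r1 in frame 0, the same way equal_scalars_push()
does:

#include <stdio.h>
#include <stdint.h>

/* Each entry is 10 bits: 3-bit frameno, 6-bit reg/spi, 1-bit is_reg.
 * Up to 6 entries are stacked in a u64, the low 4 bits hold the count.
 */
int main(void)
{
	uint64_t vec = 0;
	/* r0 and r1 in frame 0, both plain (non-spilled) registers */
	struct { uint32_t frameno, regno; int is_reg; } regs[] = {
		{ 0, 0, 1 },
		{ 0, 1, 1 },
	};

	for (int i = 0; i < 2; i++) {
		uint64_t entry = regs[i].frameno |
				 (regs[i].regno << 3) |
				 ((uint64_t)regs[i].is_reg << 9);
		uint64_t cnt = vec & 0xf;	/* low 4 bits = count */

		vec >>= 4;			/* drop the count */
		vec <<= 10;			/* make room for one entry */
		vec |= entry;
		vec <<= 4;			/* re-append the count */
		vec |= cnt + 1;
	}
	/* prints 0x802082: count 2; popped first: 0x208 (r1), then 0x200 (r0) */
	printf("packed equal_scalars: %#llx\n", (unsigned long long)vec);
	return 0;
}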

Fixes: 904e6ddf4133 ("bpf: Use scalar ids in mark_chain_precision()")
Reported-by: Hao Sun <sunhao.th@gmail.com>
Closes: https://lore.kernel.org/bpf/CAEf4BzZ0xidVCqB47XnkXcNhkPWF6_nTV7yt+_Lf0kcFEut2Mg@mail.gmail.com/
Suggested-by: Andrii Nakryiko <andrii@kernel.org>
Signed-off-by: Eduard Zingerman <eddyz87@gmail.com>
---
 include/linux/bpf_verifier.h                  |   1 +
 kernel/bpf/verifier.c                         | 207 ++++++++++++++++--
 .../bpf/progs/verifier_subprog_precision.c    |   2 +-
 .../testing/selftests/bpf/verifier/precise.c  |   2 +-
 4 files changed, 195 insertions(+), 17 deletions(-)

Comments

Andrii Nakryiko Feb. 28, 2024, 7:58 p.m. UTC | #1
On Wed, Feb 21, 2024 at 4:50 PM Eduard Zingerman <eddyz87@gmail.com> wrote:
>
> Use bpf_verifier_state->jmp_history to track which registers were
> updated by find_equal_scalars() when conditional jump was verified.
> Use recorded information in backtrack_insn() to propagate precision.
>
> E.g. for the following program:
>
>             while verifying instructions
>   r1 = r0              |
>   if r1 < 8  goto ...  | push r0,r1 as equal_scalars in jmp_history
>   if r0 > 16 goto ...  | push r0,r1 as equal_scalars in jmp_history
>   r2 = r10             |
>   r2 += r0             v mark_chain_precision(r0)
>
>             while doing mark_chain_precision(r0)
>   r1 = r0              ^
>   if r1 < 8  goto ...  | mark r0,r1 as precise
>   if r0 > 16 goto ...  | mark r0,r1 as precise
>   r2 = r10             |
>   r2 += r0             | mark r0 precise
>
> Technically achieve this in following steps:
> - Use 10 bits to identify each register that gains range because of
>   find_equal_scalars():
>   - 3 bits for frame number;
>   - 6 bits for register or stack slot number;
>   - 1 bit to indicate if register is spilled.
> - Use u64 as a vector of 6 such records + 4 bits for vector length.
> - Augment struct bpf_jmp_history_entry with field 'equal_scalars'
>   representing such vector.
> - When doing check_cond_jmp_op() for remember up to 6 registers that
>   gain range because of find_equal_scalars() in such a vector.
> - Don't propagate range information and reset IDs for registers that
>   don't fit in 6-value vector.
> - Push collected vector to bpf_verifier_state->jmp_history for
>   instruction index of conditional jump.
> - When doing backtrack_insn() for conditional jumps
>   check if any of recorded equal scalars is currently marked precise,
>   if so mark all equal recorded scalars as precise.
>
> Fixes: 904e6ddf4133 ("bpf: Use scalar ids in mark_chain_precision()")
> Reported-by: Hao Sun <sunhao.th@gmail.com>
> Closes: https://lore.kernel.org/bpf/CAEf4BzZ0xidVCqB47XnkXcNhkPWF6_nTV7yt+_Lf0kcFEut2Mg@mail.gmail.com/
> Suggested-by: Andrii Nakryiko <andrii@kernel.org>
> Signed-off-by: Eduard Zingerman <eddyz87@gmail.com>
> ---
>  include/linux/bpf_verifier.h                  |   1 +
>  kernel/bpf/verifier.c                         | 207 ++++++++++++++++--
>  .../bpf/progs/verifier_subprog_precision.c    |   2 +-
>  .../testing/selftests/bpf/verifier/precise.c  |   2 +-
>  4 files changed, 195 insertions(+), 17 deletions(-)
>
> diff --git a/include/linux/bpf_verifier.h b/include/linux/bpf_verifier.h
> index cbfb235984c8..26e32555711c 100644
> --- a/include/linux/bpf_verifier.h
> +++ b/include/linux/bpf_verifier.h
> @@ -361,6 +361,7 @@ struct bpf_jmp_history_entry {
>         u32 prev_idx : 22;
>         /* special flags, e.g., whether insn is doing register stack spill/load */
>         u32 flags : 10;
> +       u64 equal_scalars;

nit: should we call this concept something a bit more generic, like "linked
registers", instead of "equal scalars"?

>  };
>
>  /* Maximum number of register states that can exist at once */
> diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c
> index 759ef089b33c..b95b6842703c 100644
> --- a/kernel/bpf/verifier.c
> +++ b/kernel/bpf/verifier.c
> @@ -3304,6 +3304,76 @@ static bool is_jmp_point(struct bpf_verifier_env *env, int insn_idx)
>         return env->insn_aux_data[insn_idx].jmp_point;
>  }
>
> +#define ES_FRAMENO_BITS        3
> +#define ES_SPI_BITS    6
> +#define ES_ENTRY_BITS  (ES_SPI_BITS + ES_FRAMENO_BITS + 1)
> +#define ES_SIZE_BITS   4
> +#define ES_FRAMENO_MASK        ((1ul << ES_FRAMENO_BITS) - 1)
> +#define ES_SPI_MASK    ((1ul << ES_SPI_BITS)     - 1)
> +#define ES_SIZE_MASK   ((1ul << ES_SIZE_BITS)    - 1)
> +#define ES_SPI_OFF     ES_FRAMENO_BITS
> +#define ES_IS_REG_OFF  (ES_SPI_BITS + ES_FRAMENO_BITS)
> +
> +/* Pack one history entry for equal scalars as 10 bits in the following format:
> + * - 3-bits frameno
> + * - 6-bits spi_or_reg
> + * - 1-bit  is_reg
> + */
> +static u64 equal_scalars_pack(u32 frameno, u32 spi_or_reg, bool is_reg)
> +{
> +       u64 val = 0;
> +
> +       val |= frameno & ES_FRAMENO_MASK;
> +       val |= (spi_or_reg & ES_SPI_MASK) << ES_SPI_OFF;
> +       val |= (is_reg ? 1 : 0) << ES_IS_REG_OFF;
> +       return val;
> +}
> +
> +static void equal_scalars_unpack(u64 val, u32 *frameno, u32 *spi_or_reg, bool *is_reg)
> +{
> +       *frameno    =  val & ES_FRAMENO_MASK;
> +       *spi_or_reg = (val >> ES_SPI_OFF) & ES_SPI_MASK;
> +       *is_reg     = (val >> ES_IS_REG_OFF) & 0x1;
> +}
> +
> +static u32 equal_scalars_size(u64 equal_scalars)
> +{
> +       return equal_scalars & ES_SIZE_MASK;
> +}
> +
> +/* Use u64 as a stack of 6 10-bit values, use first 4-bits to track
> + * number of elements currently in stack.
> + */
> +static bool equal_scalars_push(u64 *equal_scalars, u32 frameno, u32 spi_or_reg, bool is_reg)
> +{
> +       u32 num;
> +
> +       num = equal_scalars_size(*equal_scalars);
> +       if (num == 6)
> +               return false;
> +       *equal_scalars >>= ES_SIZE_BITS;
> +       *equal_scalars <<= ES_ENTRY_BITS;
> +       *equal_scalars |= equal_scalars_pack(frameno, spi_or_reg, is_reg);
> +       *equal_scalars <<= ES_SIZE_BITS;
> +       *equal_scalars |= num + 1;
> +       return true;
> +}
> +
> +static bool equal_scalars_pop(u64 *equal_scalars, u32 *frameno, u32 *spi_or_reg, bool *is_reg)
> +{
> +       u32 num;
> +
> +       num = equal_scalars_size(*equal_scalars);
> +       if (num == 0)
> +               return false;
> +       *equal_scalars >>= ES_SIZE_BITS;
> +       equal_scalars_unpack(*equal_scalars, frameno, spi_or_reg, is_reg);
> +       *equal_scalars >>= ES_ENTRY_BITS;
> +       *equal_scalars <<= ES_SIZE_BITS;
> +       *equal_scalars |= num - 1;
> +       return true;
> +}
> +

I'm wondering if this pop/push set of primitives is the best approach?
What if we had pack/unpack operations, where for various checking
logic we'd be working with "unpacked" representation, e.g., something
like this:

struct linked_reg_set {
    int cnt;
    struct {
        int frameno;
        union {
            int spi;
            int regno;
        };
        bool is_set;
        bool is_reg;
    } reg_set[6];
};

bt_set_equal_scalars() could accept `struct linked_reg_set*` instead
of bitmask itself. Same for find_equal_scalars().

I think even the implementation of packing/unpacking would be more
straightforward, and we wouldn't need all those ES_xxx consts (or at
least we'd need fewer of them).

WDYT?

>  static struct bpf_jmp_history_entry *get_jmp_hist_entry(struct bpf_verifier_state *st,
>                                                         u32 hist_end, int insn_idx)
>  {

[...]
Eduard Zingerman Feb. 28, 2024, 9:16 p.m. UTC | #2
On Wed, 2024-02-28 at 11:58 -0800, Andrii Nakryiko wrote:
[...]

> > diff --git a/include/linux/bpf_verifier.h b/include/linux/bpf_verifier.h
> > index cbfb235984c8..26e32555711c 100644
> > --- a/include/linux/bpf_verifier.h
> > +++ b/include/linux/bpf_verifier.h
> > @@ -361,6 +361,7 @@ struct bpf_jmp_history_entry {
> >         u32 prev_idx : 22;
> >         /* special flags, e.g., whether insn is doing register stack spill/load */
> >         u32 flags : 10;
> > +       u64 equal_scalars;
> 
> nit: should we call this concept as a bit more generic "linked
> registers" instead of "equal scalars"?

It's a historical name for the feature and it is present in a few commits and tests.
Agree that "linked_registers" is better in the current context.
A bit reluctant, but I can change it here.

[...]

> I'm wondering if this pop/push set of primitives is the best approach?

I kinda like it :)

> What if we had pack/unpack operations, where for various checking
> logic we'd be working with "unpacked" representation, e.g., something
> like this:
> 
> struct linked_reg_set {
>     int cnt;
>     struct {

Will need a name here, otherwise iteration would be somewhat inconvenient.
Suppose 'struct reg_or_spill'.

>         int frameno;
>         union {
>             int spi;
>             int regno;
>         };
>         bool is_set;
>         bool is_reg;
>     } reg_set[6];
> };
> 
> bt_set_equal_scalars() could accept `struct linked_reg_set*` instead
> of bitmask itself. Same for find_equal_scalars().

For clients it would be

        while (equal_scalars_pop(&equal_scalars, &fr, &spi, &is_reg)) {
                if ((is_reg && bt_is_frame_reg_set(bt, fr, spi)) ||
                    (!is_reg && bt_is_frame_slot_set(bt, fr, spi)))
                    ...
        }

    --- vs ---
 
        for (i = 0; i < equal_scalars->cnt; ++i) {
                struct reg_or_spill *r = &equal_scalars->reg_set[i];

                if ((r->is_reg && bt_is_frame_reg_set(bt, r->frameno, r->regno)) ||
                    (!r->is_reg && bt_is_frame_slot_set(bt, r->frameno, r->spi)))
                    ...
        }

I'd say, no significant difference.

> I think even implementation of packing/unpacking would be more
> straightforward and we won't even need all those ES_xxx consts (or at
> least fewer of them).
> 
> WDYT?

I wouldn't say it simplifies packing/unpacking much.
Below is the code using the new data structure; it's roughly 59 lines
for the old version vs 56 lines for the new one.

--- 8< ----------------------------------------------------------------

struct reg_or_spill {
	int frameno;
	union {
		int spi;
		int regno;
	};
	bool is_reg;
};

struct linked_reg_set {
	int cnt;
	struct reg_or_spill reg_set[6];
};

/* Pack one history entry for equal scalars as 10 bits in the following format:
 * - 3-bits frameno
 * - 6-bits spi_or_reg
 * - 1-bit  is_reg
 */
static u64 linked_reg_set_pack(struct linked_reg_set *s)
{
	u64 val = 0;
	int i;

	for (i = 0; i < s->cnt; ++i) {
		struct reg_or_spill *r = &s->reg_set[i];
		u64 tmp = 0;

		tmp |= r->frameno & ES_FRAMENO_MASK;
		tmp |= (r->spi & ES_SPI_MASK) << ES_SPI_OFF;
		tmp |= (r->is_reg ? 1 : 0) << ES_IS_REG_OFF;

		val <<= ES_ENTRY_BITS;
		val |= tmp;
	}
	val <<= ES_SIZE_BITS;
	val |= s->cnt;
	return val;
}

static void linked_reg_set_unpack(u64 val, struct linked_reg_set *s)
{
	int i;

	s->cnt = val & ES_SIZE_MASK;
	val >>= ES_SIZE_BITS;

	for (i = 0; i < s->cnt; ++i) {
		struct reg_or_spill *r = &s->reg_set[i];

		r->frameno =  val & ES_FRAMENO_MASK;
		r->spi     = (val >> ES_SPI_OFF) & ES_SPI_MASK;
		r->is_reg  = (val >> ES_IS_REG_OFF) & 0x1;
		val >>= ES_ENTRY_BITS;
	}
}

---------------------------------------------------------------- >8 ---
Andrii Nakryiko Feb. 28, 2024, 9:36 p.m. UTC | #3
On Wed, Feb 28, 2024 at 1:16 PM Eduard Zingerman <eddyz87@gmail.com> wrote:
>
> On Wed, 2024-02-28 at 11:58 -0800, Andrii Nakryiko wrote:
> [...]
>
> > > diff --git a/include/linux/bpf_verifier.h b/include/linux/bpf_verifier.h
> > > index cbfb235984c8..26e32555711c 100644
> > > --- a/include/linux/bpf_verifier.h
> > > +++ b/include/linux/bpf_verifier.h
> > > @@ -361,6 +361,7 @@ struct bpf_jmp_history_entry {
> > >         u32 prev_idx : 22;
> > >         /* special flags, e.g., whether insn is doing register stack spill/load */
> > >         u32 flags : 10;
> > > +       u64 equal_scalars;
> >
> > nit: should we call this concept as a bit more generic "linked
> > registers" instead of "equal scalars"?
>
> It's a historical name for the feature and it is present in a few commit and tests.
> Agree that "linked_registers" is better in current context.
> A bit reluctant but can change it here.
>
> [...]
>
> > I'm wondering if this pop/push set of primitives is the best approach?
>
> I kinda like it :)
>
> > What if we had pack/unpack operations, where for various checking
> > logic we'd be working with "unpacked" representation, e.g., something
> > like this:
> >
> > struct linked_reg_set {
> >     int cnt;
> >     struct {
>
> Will need a name here, otherwise iteration would be somewhat inconvenient.
> Suppose 'struct reg_or_spill'.
>
> >         int frameno;
> >         union {
> >             int spi;
> >             int regno;
> >         };
> >         bool is_set;
> >         bool is_reg;
> >     } reg_set[6];
> > };
> >
> > bt_set_equal_scalars() could accept `struct linked_reg_set*` instead
> > of bitmask itself. Same for find_equal_scalars().
>
> For clients it would be
>
>         while (equal_scalars_pop(&equal_scalars, &fr, &spi, &is_reg)) {
>                 if ((is_reg && bt_is_frame_reg_set(bt, fr, spi)) ||
>                     (!is_reg && bt_is_frame_slot_set(bt, fr, spi)))
>                     ...
>         }
>
>     --- vs ---
>
>         for (i = 0; i < equal_scalars->cnt; ++i) {
>                 struct reg_or_spill *r = equal_scalars->reg_set[i];
>
>                 if ((r->is_reg && bt_is_frame_reg_set(bt, r->frameno, r->regno)) ||
>                     (!r->is_reg && bt_is_frame_slot_set(bt, r->frameno, r->spi)))
>                     ...
>         }
>
> I'd say, no significant difference.

Can I disagree? I find the second to be much better. There is no
in-place modification of a mask, no out parameters, we have a clean
record r with a few fields. We also know the count upfront, though we
maintain a simple rule (mask == 0 => cnt == 0), so not really a big
deal either way.

>
> > I think even implementation of packing/unpacking would be more
> > straightforward and we won't even need all those ES_xxx consts (or at
> > least fewer of them).
> >
> > WDYT?
>
> I wouldn't say it simplifies packing/unpacking much.
> Below is the code using new data structure and it's like
> 59 lines old version vs 56 lines new version.

I'd say it's not about a number of lines, it's about ease of
understanding, reasoning, and using these helpers.

I do prefer the code you wrote below, but I'm not going to die on this
hill if you insist. I'll go think about the rest of the logic.

>
> --- 8< ----------------------------------------------------------------
>
> struct reg_or_spill {
>         int frameno;
>         union {
>                 int spi;
>                 int regno;
>         };
>         bool is_reg;
> };
>
> struct linked_reg_set {
>         int cnt;
>         struct reg_or_spill reg_set[6];
> };
>
> /* Pack one history entry for equal scalars as 10 bits in the following format:
>  * - 3-bits frameno
>  * - 6-bits spi_or_reg
>  * - 1-bit  is_reg
>  */
> static u64 linked_reg_set_pack(struct linked_reg_set *s)
> {
>         u64 val = 0;
>         int i;
>
>         for (i = 0; i < s->cnt; ++i) {
>                 struct reg_or_spill *r = &s->reg_set[i];
>                 u64 tmp = 0;
>
>                 tmp |= r->frameno & ES_FRAMENO_MASK;
>                 tmp |= (r->spi & ES_SPI_MASK) << ES_SPI_OFF;

nit: we shouldn't mask anything here, it just gives the impression that
r->frameno can be bigger than we have bits for it in the bitmask

>                 tmp |= (r->is_reg ? 1 : 0) << ES_IS_REG_OFF;
>
>                 val <<= ES_ENTRY_BITS;
>                 val |= tmp;

val <<= ES_ENTRY_BITS;
val |= r->frameno | (r->spi << ES_SPI_OFF) | ((r->is_reg ? 1 : 0) << ES_IS_REG_OFF);

or you can do it as three assignments, but there is no need for tmp

>         }
>         val <<= ES_SIZE_BITS;
>         val |= s->cnt;
>         return val;
> }
>
> static void linked_reg_set_unpack(u64 val, struct linked_reg_set *s)
> {
>         int i;
>
>         s->cnt = val & ES_SIZE_MASK;
>         val >>= ES_SIZE_BITS;
>
>         for (i = 0; i < s->cnt; ++i) {
>                 struct reg_or_spill *r = &s->reg_set[i];
>
>                 r->frameno =  val & ES_FRAMENO_MASK;
>                 r->spi     = (val >> ES_SPI_OFF) & ES_SPI_MASK;
>                 r->is_reg  = (val >> ES_IS_REG_OFF) & 0x1;
>                 val >>= ES_ENTRY_BITS;
>         }
> }
>

I do think that the above is much easier to read and follow.

> ---------------------------------------------------------------- >8 ---
Andrii Nakryiko Feb. 28, 2024, 9:40 p.m. UTC | #4
On Wed, Feb 28, 2024 at 1:16 PM Eduard Zingerman <eddyz87@gmail.com> wrote:
>
> On Wed, 2024-02-28 at 11:58 -0800, Andrii Nakryiko wrote:
> [...]
>
> > > diff --git a/include/linux/bpf_verifier.h b/include/linux/bpf_verifier.h
> > > index cbfb235984c8..26e32555711c 100644
> > > --- a/include/linux/bpf_verifier.h
> > > +++ b/include/linux/bpf_verifier.h
> > > @@ -361,6 +361,7 @@ struct bpf_jmp_history_entry {
> > >         u32 prev_idx : 22;
> > >         /* special flags, e.g., whether insn is doing register stack spill/load */
> > >         u32 flags : 10;
> > > +       u64 equal_scalars;
> >
> > nit: should we call this concept as a bit more generic "linked
> > registers" instead of "equal scalars"?
>
> It's a historical name for the feature and it is present in a few commit and tests.
> Agree that "linked_registers" is better in current context.
> A bit reluctant but can change it here.

I'd start with calling this specific field either "linked_regs" or
"linked_set". It's a superset of "equal scalars", so we don't strictly
need to rename all the existing mentions of "equal_scalars" in
existing code.

>
> [...]
>

[...]
Eduard Zingerman Feb. 28, 2024, 10:39 p.m. UTC | #5
On Wed, 2024-02-28 at 13:36 -0800, Andrii Nakryiko wrote:
[...]

> I'd say it's not about a number of lines, it's about ease of
> understanding, reasoning, and using these helpers.
> 
> I do prefer the code you wrote below, but I'm not going to die on this
> hill if you insist. I'll go think about the rest of the logic.

Ok, code is meant to be read, so I'll switch to the code quoted below in v2.

[...]

> > static u64 linked_reg_set_pack(struct linked_reg_set *s)
> > {
> >         u64 val = 0;
> >         int i;
> > 
> >         for (i = 0; i < s->cnt; ++i) {
> >                 struct reg_or_spill *r = &s->reg_set[i];
> >                 u64 tmp = 0;
> > 
> >                 tmp |= r->frameno & ES_FRAMENO_MASK;
> >                 tmp |= (r->spi & ES_SPI_MASK) << ES_SPI_OFF;
> 
> nit: we shouldn't mask anything here, it just makes an impression that
> r->frameno can be bigger than we have bits for it in a bitmask

Ok, I'll add bitmasks to field definitions and remove masks here.

[...]
Andrii Nakryiko Feb. 28, 2024, 11:01 p.m. UTC | #6
On Wed, Feb 21, 2024 at 4:50 PM Eduard Zingerman <eddyz87@gmail.com> wrote:
>
> Use bpf_verifier_state->jmp_history to track which registers were
> updated by find_equal_scalars() when conditional jump was verified.
> Use recorded information in backtrack_insn() to propagate precision.
>
> E.g. for the following program:
>
>             while verifying instructions
>   r1 = r0              |
>   if r1 < 8  goto ...  | push r0,r1 as equal_scalars in jmp_history
>   if r0 > 16 goto ...  | push r0,r1 as equal_scalars in jmp_history
>   r2 = r10             |
>   r2 += r0             v mark_chain_precision(r0)
>
>             while doing mark_chain_precision(r0)
>   r1 = r0              ^
>   if r1 < 8  goto ...  | mark r0,r1 as precise
>   if r0 > 16 goto ...  | mark r0,r1 as precise
>   r2 = r10             |
>   r2 += r0             | mark r0 precise
>
> Technically achieve this in following steps:
> - Use 10 bits to identify each register that gains range because of
>   find_equal_scalars():
>   - 3 bits for frame number;
>   - 6 bits for register or stack slot number;
>   - 1 bit to indicate if register is spilled.
> - Use u64 as a vector of 6 such records + 4 bits for vector length.
> - Augment struct bpf_jmp_history_entry with field 'equal_scalars'
>   representing such vector.
> - When doing check_cond_jmp_op() for remember up to 6 registers that
>   gain range because of find_equal_scalars() in such a vector.
> - Don't propagate range information and reset IDs for registers that
>   don't fit in 6-value vector.
> - Push collected vector to bpf_verifier_state->jmp_history for
>   instruction index of conditional jump.
> - When doing backtrack_insn() for conditional jumps
>   check if any of recorded equal scalars is currently marked precise,
>   if so mark all equal recorded scalars as precise.
>
> Fixes: 904e6ddf4133 ("bpf: Use scalar ids in mark_chain_precision()")
> Reported-by: Hao Sun <sunhao.th@gmail.com>
> Closes: https://lore.kernel.org/bpf/CAEf4BzZ0xidVCqB47XnkXcNhkPWF6_nTV7yt+_Lf0kcFEut2Mg@mail.gmail.com/
> Suggested-by: Andrii Nakryiko <andrii@kernel.org>
> Signed-off-by: Eduard Zingerman <eddyz87@gmail.com>
> ---
>  include/linux/bpf_verifier.h                  |   1 +
>  kernel/bpf/verifier.c                         | 207 ++++++++++++++++--
>  .../bpf/progs/verifier_subprog_precision.c    |   2 +-
>  .../testing/selftests/bpf/verifier/precise.c  |   2 +-
>  4 files changed, 195 insertions(+), 17 deletions(-)
>
> diff --git a/include/linux/bpf_verifier.h b/include/linux/bpf_verifier.h
> index cbfb235984c8..26e32555711c 100644
> --- a/include/linux/bpf_verifier.h
> +++ b/include/linux/bpf_verifier.h
> @@ -361,6 +361,7 @@ struct bpf_jmp_history_entry {
>         u32 prev_idx : 22;
>         /* special flags, e.g., whether insn is doing register stack spill/load */
>         u32 flags : 10;
> +       u64 equal_scalars;
>  };
>

[...]

> @@ -3314,7 +3384,7 @@ static struct bpf_jmp_history_entry *get_jmp_hist_entry(struct bpf_verifier_stat
>
>  /* for any branch, call, exit record the history of jmps in the given state */
>  static int push_jmp_history(struct bpf_verifier_env *env, struct bpf_verifier_state *cur,
> -                           int insn_flags)
> +                           int insn_flags, u64 equal_scalars)
>  {
>         struct bpf_jmp_history_entry *p, *cur_hist_ent;
>         u32 cnt = cur->jmp_history_cnt;
> @@ -3332,6 +3402,12 @@ static int push_jmp_history(struct bpf_verifier_env *env, struct bpf_verifier_st
>                           "verifier insn history bug: insn_idx %d cur flags %x new flags %x\n",
>                           env->insn_idx, cur_hist_ent->flags, insn_flags);
>                 cur_hist_ent->flags |= insn_flags;
> +               if (cur_hist_ent->equal_scalars != 0) {
> +                       verbose(env, "verifier bug: insn_idx %d equal_scalars != 0: %#llx\n",
> +                               env->insn_idx, cur_hist_ent->equal_scalars);
> +                       return -EFAULT;
> +               }

let's do WARN_ONCE() just like we do for flags? why deviating?

> +               cur_hist_ent->equal_scalars = equal_scalars;
>                 return 0;
>         }
>
> @@ -3346,6 +3422,7 @@ static int push_jmp_history(struct bpf_verifier_env *env, struct bpf_verifier_st
>         p->idx = env->insn_idx;
>         p->prev_idx = env->prev_insn_idx;
>         p->flags = insn_flags;
> +       p->equal_scalars = equal_scalars;
>         cur->jmp_history_cnt = cnt;
>
>         return 0;

[...]

>  static bool calls_callback(struct bpf_verifier_env *env, int insn_idx);
>
>  /* For given verifier state backtrack_insn() is called from the last insn to
> @@ -3802,6 +3917,7 @@ static int backtrack_insn(struct bpf_verifier_env *env, int idx, int subseq_idx,
>                          */
>                         return 0;
>                 } else if (BPF_SRC(insn->code) == BPF_X) {
> +                       bt_set_equal_scalars(bt, hist);
>                         if (!bt_is_reg_set(bt, dreg) && !bt_is_reg_set(bt, sreg))
>                                 return 0;
>                         /* dreg <cond> sreg
> @@ -3812,6 +3928,9 @@ static int backtrack_insn(struct bpf_verifier_env *env, int idx, int subseq_idx,
>                          */
>                         bt_set_reg(bt, dreg);
>                         bt_set_reg(bt, sreg);
> +                       bt_set_equal_scalars(bt, hist);
> +               } else if (BPF_SRC(insn->code) == BPF_K) {
> +                       bt_set_equal_scalars(bt, hist);

Can you please elaborate why we are doing bt_set_equal_scalars() in
these three places and not everywhere else? I'm trying to understand
whether we should do it more generically for any instruction either
before or after all the bt_set_xxx() calls...

>                          /* else dreg <cond> K
>                           * Only dreg still needs precision before
>                           * this insn, so for the K-based conditional
> @@ -4579,7 +4698,7 @@ static int check_stack_write_fixed_off(struct bpf_verifier_env *env,
>         }
>
>         if (insn_flags)
> -               return push_jmp_history(env, env->cur_state, insn_flags);
> +               return push_jmp_history(env, env->cur_state, insn_flags, 0);
>         return 0;
>  }
>

[...]
Eduard Zingerman Feb. 28, 2024, 11:29 p.m. UTC | #7
On Wed, 2024-02-28 at 15:01 -0800, Andrii Nakryiko wrote:
[...]

> > @@ -3332,6 +3402,12 @@ static int push_jmp_history(struct bpf_verifier_env *env, struct bpf_verifier_st
> >                           "verifier insn history bug: insn_idx %d cur flags %x new flags %x\n",
> >                           env->insn_idx, cur_hist_ent->flags, insn_flags);
> >                 cur_hist_ent->flags |= insn_flags;
> > +               if (cur_hist_ent->equal_scalars != 0) {
> > +                       verbose(env, "verifier bug: insn_idx %d equal_scalars != 0: %#llx\n",
> > +                               env->insn_idx, cur_hist_ent->equal_scalars);
> > +                       return -EFAULT;
> > +               }
> 
> let's do WARN_ONCE() just like we do for flags? why deviating?

Ok

[...]

> >  /* For given verifier state backtrack_insn() is called from the last insn to
> > @@ -3802,6 +3917,7 @@ static int backtrack_insn(struct bpf_verifier_env *env, int idx, int subseq_idx,
> >                          */
> >                         return 0;
> >                 } else if (BPF_SRC(insn->code) == BPF_X) {
> > +                       bt_set_equal_scalars(bt, hist);
> >                         if (!bt_is_reg_set(bt, dreg) && !bt_is_reg_set(bt, sreg))
> >                                 return 0;
> >                         /* dreg <cond> sreg
> > @@ -3812,6 +3928,9 @@ static int backtrack_insn(struct bpf_verifier_env *env, int idx, int subseq_idx,
> >                          */
> >                         bt_set_reg(bt, dreg);
> >                         bt_set_reg(bt, sreg);
> > +                       bt_set_equal_scalars(bt, hist);
> > +               } else if (BPF_SRC(insn->code) == BPF_K) {
> > +                       bt_set_equal_scalars(bt, hist);
> 
> Can you please elaborate why we are doing bt_set_equal_scalars() in
> these three places and not everywhere else? I'm trying to understand
> whether we should do it more generically for any instruction either
> before or after all the bt_set_xxx() calls...

The before call for BPF_X is for the situation when dreg/sreg are not yet
tracked precise but one of the registers that gained range because of
this 'if' is already tracked.

The after call for BPF_X is for the situation when, say, dreg is already
tracked precise but sreg is not, and some registers that had the same
id as sreg gained range when this 'if' was processed.
The equal_scalars_bpf_x_dst() test case covers this situation.
Here it is for your convenience:

    /* Registers r{0,1,2} share same ID when 'if r1 > r3' insn is processed,
     * check that verifier marks r{0,1,2} as precise while backtracking
     * 'if r1 > r3' with r3 already marked.
     */
    SEC("socket")
    __success __log_level(2)
    __flag(BPF_F_TEST_STATE_FREQ)
    __msg("frame0: regs=r3 stack= before 5: (2d) if r1 > r3 goto pc+0")
    __msg("frame0: parent state regs=r0,r1,r2,r3 stack=:")
    __msg("frame0: regs=r0,r1,r2,r3 stack= before 4: (b7) r3 = 7")
    __naked void equal_scalars_bpf_x_dst(void)
    {
    	asm volatile (
    	/* r0 = random number up to 0xff */
    	"call %[bpf_ktime_get_ns];"
    	"r0 &= 0xff;"
    	/* tie r0.id == r1.id == r2.id */
    	"r1 = r0;"
    	"r2 = r0;"
    	"r3 = 7;"
    	"if r1 > r3 goto +0;"
    	/* force r0 to be precise, this eventually marks r1 and r2 as
    	 * precise as well because of shared IDs
    	 */
    	"r4 = r10;"
    	"r4 += r3;"
    	"r0 = 0;"
    	"exit;"
    	:
    	: __imm(bpf_ktime_get_ns)
    	: __clobber_all);
    }

The before call for BPF_K is the same as the before call for BPF_X: for the
situation when dreg is not yet tracked precise, but one of the
registers that gained range because of this 'if' is already tracked.

The calls are placed at the point where conditional jumps are processed
because 'equal_scalars' is only recorded for conditional jumps.

> 
> >                          /* else dreg <cond> K
> >                           * Only dreg still needs precision before
> >                           * this insn, so for the K-based conditional

[...]
Andrii Nakryiko March 1, 2024, 5:34 p.m. UTC | #8
On Wed, Feb 28, 2024 at 3:29 PM Eduard Zingerman <eddyz87@gmail.com> wrote:
>
> On Wed, 2024-02-28 at 15:01 -0800, Andrii Nakryiko wrote:
> [...]
>
> > > @@ -3332,6 +3402,12 @@ static int push_jmp_history(struct bpf_verifier_env *env, struct bpf_verifier_st
> > >                           "verifier insn history bug: insn_idx %d cur flags %x new flags %x\n",
> > >                           env->insn_idx, cur_hist_ent->flags, insn_flags);
> > >                 cur_hist_ent->flags |= insn_flags;
> > > +               if (cur_hist_ent->equal_scalars != 0) {
> > > +                       verbose(env, "verifier bug: insn_idx %d equal_scalars != 0: %#llx\n",
> > > +                               env->insn_idx, cur_hist_ent->equal_scalars);
> > > +                       return -EFAULT;
> > > +               }
> >
> > let's do WARN_ONCE() just like we do for flags? why deviating?
>
> Ok
>
> [...]
>
> > >  /* For given verifier state backtrack_insn() is called from the last insn to
> > > @@ -3802,6 +3917,7 @@ static int backtrack_insn(struct bpf_verifier_env *env, int idx, int subseq_idx,
> > >                          */
> > >                         return 0;
> > >                 } else if (BPF_SRC(insn->code) == BPF_X) {
> > > +                       bt_set_equal_scalars(bt, hist);
> > >                         if (!bt_is_reg_set(bt, dreg) && !bt_is_reg_set(bt, sreg))
> > >                                 return 0;
> > >                         /* dreg <cond> sreg
> > > @@ -3812,6 +3928,9 @@ static int backtrack_insn(struct bpf_verifier_env *env, int idx, int subseq_idx,
> > >                          */
> > >                         bt_set_reg(bt, dreg);
> > >                         bt_set_reg(bt, sreg);
> > > +                       bt_set_equal_scalars(bt, hist);
> > > +               } else if (BPF_SRC(insn->code) == BPF_K) {
> > > +                       bt_set_equal_scalars(bt, hist);
> >
> > Can you please elaborate why we are doing bt_set_equal_scalars() in
> > these three places and not everywhere else? I'm trying to understand
> > whether we should do it more generically for any instruction either
> > before or after all the bt_set_xxx() calls...
>
> The before call for BPF_X is for situation when dreg/sreg are not yet
> tracked precise but one of the registers that gained range because of
> this 'if' is already tracked.
>
> The after call for BPF_X is for situation when say dreg is already
> tracked precise but sreg is not and there are some registers had same
> id as sreg, that gained range when this 'if' was processed.
> The equal_scalars_bpf_x_dst() test case covers this situation.
> Here it is for your convenience:
>
>     /* Registers r{0,1,2} share same ID when 'if r1 > r3' insn is processed,
>      * check that verifier marks r{0,1,2} as precise while backtracking
>      * 'if r1 > r3' with r3 already marked.
>      */
>     SEC("socket")
>     __success __log_level(2)
>     __flag(BPF_F_TEST_STATE_FREQ)
>     __msg("frame0: regs=r3 stack= before 5: (2d) if r1 > r3 goto pc+0")
>     __msg("frame0: parent state regs=r0,r1,r2,r3 stack=:")
>     __msg("frame0: regs=r0,r1,r2,r3 stack= before 4: (b7) r3 = 7")
>     __naked void equal_scalars_bpf_x_dst(void)
>     {
>         asm volatile (
>         /* r0 = random number up to 0xff */
>         "call %[bpf_ktime_get_ns];"
>         "r0 &= 0xff;"
>         /* tie r0.id == r1.id == r2.id */
>         "r1 = r0;"
>         "r2 = r0;"
>         "r3 = 7;"
>         "if r1 > r3 goto +0;"
>         /* force r0 to be precise, this eventually marks r1 and r2 as
>          * precise as well because of shared IDs
>          */
>         "r4 = r10;"
>         "r4 += r3;"
>         "r0 = 0;"
>         "exit;"
>         :
>         : __imm(bpf_ktime_get_ns)
>         : __clobber_all);
>     }
>
> The before call for BPF_K is the same as before call for BPF_X: for
> situation when dreg is not yet tracked precise, but one of the
> registers that gained range because of this 'if' is already tracked.
>
> The calls are placed at point where conditional jumps are processed
> because 'equal_scalars' are only recorded for conditional jumps.

As I mentioned in an offline conversation, I wonder if it's better and
less error-prone to do linked register processing in backtrack_insn()
not just for conditional jumps but for all instructions? Whenever we
currently do bt_set_reg(), we can first check if there are linked
registers and they contain a register we are about to set precise. If
that's the case, set all of them precise.

That would make it unnecessary to have these "before BPF_X|BPF_K"
checks, I think.

It might be sufficient to process just conditional jumps given today's
use of linked registers, but is there any downside to doing it across
all instructions? Are you worried about a regression in the number of
states due to precision? Or performance?

>
> >
> > >                          /* else dreg <cond> K
> > >                           * Only dreg still needs precision before
> > >                           * this insn, so for the K-based conditional
>
> [...]
Eduard Zingerman March 1, 2024, 5:44 p.m. UTC | #9
On Fri, 2024-03-01 at 09:34 -0800, Andrii Nakryiko wrote:
[...]

> As I mentioned in offline conversation, I wonder if it's better and
> less error-prone to do linked register processing in backtrack_insn()
> not just for conditional jumps, for all instructions? Whenever we
> currently do bpf_set_reg(), we can first check if there are linked
> registers and they contain a register we are about to set precise. If
> that's the case, set all of them precise.
> 
> That would make it unnecessary to have this "before BPF_X|BPF_K"
> checks, I think.

It should not be a problem to do bt_set_equal_scalars() at the
beginning of backtrack_insn().
In the same way, I can put the after call at the end of backtrack_insn().
Is this what you have in mind?

> It might be sufficient to process just conditional jumps given today's
> use of linked registers, but is there any downside to doing it across
> all instructions? Are you worried about regression in number of states
> due to precision? Or performance?

Changing the position of the bt_set_equal_scalars() calls should not affect
anything semantically at the moment. Changes to the backtracking state
would be made only if some linked registers are present in 'hist', and
that is true only for conditional jumps.
It may cost some more CPU cycles, but I don't think that would be noticeable.
Andrii Nakryiko March 4, 2024, 11:37 p.m. UTC | #10
On Fri, Mar 1, 2024 at 9:44 AM Eduard Zingerman <eddyz87@gmail.com> wrote:
>
> On Fri, 2024-03-01 at 09:34 -0800, Andrii Nakryiko wrote:
> [...]
>
> > As I mentioned in offline conversation, I wonder if it's better and
> > less error-prone to do linked register processing in backtrack_insn()
> > not just for conditional jumps, for all instructions? Whenever we
> > currently do bpf_set_reg(), we can first check if there are linked
> > registers and they contain a register we are about to set precise. If
> > that's the case, set all of them precise.
> >
> > That would make it unnecessary to have this "before BPF_X|BPF_K"
> > checks, I think.
>
> It should not be a problem to do bt_set_equal_scalars() at the
> beginning of backtrack_insn().
> Same way, I can put the after call at the end of backtrack_insn().
> Is this what you have in mind?

Not exactly. It was more a proposal to replace the current use of
bt_set_reg() with bt_set_linked_regs(), which would take linked
registers into account, and to do it throughout the entire
backtrack_insn(), regardless of the specific instruction being
backtracked. I think that would eliminate the need to have
bt_set_equal_scalars() before the instruction as well.
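
A minimal sketch of what such a helper might look like (hypothetical; not
part of the patch, and the name bt_set_linked_regs() is only the name
suggested above), built on the patch's equal_scalars_pop() and
bt_set_equal_scalars():

/* Hypothetical replacement for bt_set_reg(): mark (frame, reg) precise and,
 * if it is part of the linked set recorded in the history entry for the
 * current instruction, propagate the mark to the whole set.
 */
static void bt_set_linked_regs(struct backtrack_state *bt,
			       struct bpf_jmp_history_entry *hist,
			       u32 frame, u32 reg)
{
	u64 equal_scalars;
	u32 fr, spi;
	bool is_reg;

	bt_set_frame_reg(bt, frame, reg);
	if (!hist || hist->equal_scalars == 0)
		return;
	/* if (frame, reg) belongs to the recorded linked set, mark the
	 * whole set precise
	 */
	equal_scalars = hist->equal_scalars;
	while (equal_scalars_pop(&equal_scalars, &fr, &spi, &is_reg)) {
		if (is_reg && fr == frame && spi == reg) {
			bt_set_equal_scalars(bt, hist);
			break;
		}
	}
}

With something like this, backtrack_insn() could call it wherever it
currently calls bt_set_reg(), and the explicit "before" calls would no
longer be needed.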

>
> > It might be sufficient to process just conditional jumps given today's
> > use of linked registers, but is there any downside to doing it across
> > all instructions? Are you worried about regression in number of states
> > due to precision? Or performance?
>
> Changing position for bt_set_equal_scalars() calls should not affect
> anything semantically at the moment. Changes to backtracking state
> would be done only if some linked registers are present in 'hist' and
> that would be true only for conditional jumps.
> Maybe some more CPU cycles but I don't think that would be noticeable.

Patch

diff --git a/include/linux/bpf_verifier.h b/include/linux/bpf_verifier.h
index cbfb235984c8..26e32555711c 100644
--- a/include/linux/bpf_verifier.h
+++ b/include/linux/bpf_verifier.h
@@ -361,6 +361,7 @@  struct bpf_jmp_history_entry {
 	u32 prev_idx : 22;
 	/* special flags, e.g., whether insn is doing register stack spill/load */
 	u32 flags : 10;
+	u64 equal_scalars;
 };
 
 /* Maximum number of register states that can exist at once */
diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c
index 759ef089b33c..b95b6842703c 100644
--- a/kernel/bpf/verifier.c
+++ b/kernel/bpf/verifier.c
@@ -3304,6 +3304,76 @@  static bool is_jmp_point(struct bpf_verifier_env *env, int insn_idx)
 	return env->insn_aux_data[insn_idx].jmp_point;
 }
 
+#define ES_FRAMENO_BITS	3
+#define ES_SPI_BITS	6
+#define ES_ENTRY_BITS	(ES_SPI_BITS + ES_FRAMENO_BITS + 1)
+#define ES_SIZE_BITS	4
+#define ES_FRAMENO_MASK	((1ul << ES_FRAMENO_BITS) - 1)
+#define ES_SPI_MASK	((1ul << ES_SPI_BITS)     - 1)
+#define ES_SIZE_MASK	((1ul << ES_SIZE_BITS)    - 1)
+#define ES_SPI_OFF	ES_FRAMENO_BITS
+#define ES_IS_REG_OFF	(ES_SPI_BITS + ES_FRAMENO_BITS)
+
+/* Pack one history entry for equal scalars as 10 bits in the following format:
+ * - 3-bits frameno
+ * - 6-bits spi_or_reg
+ * - 1-bit  is_reg
+ */
+static u64 equal_scalars_pack(u32 frameno, u32 spi_or_reg, bool is_reg)
+{
+	u64 val = 0;
+
+	val |= frameno & ES_FRAMENO_MASK;
+	val |= (spi_or_reg & ES_SPI_MASK) << ES_SPI_OFF;
+	val |= (is_reg ? 1 : 0) << ES_IS_REG_OFF;
+	return val;
+}
+
+static void equal_scalars_unpack(u64 val, u32 *frameno, u32 *spi_or_reg, bool *is_reg)
+{
+	*frameno    =  val & ES_FRAMENO_MASK;
+	*spi_or_reg = (val >> ES_SPI_OFF) & ES_SPI_MASK;
+	*is_reg     = (val >> ES_IS_REG_OFF) & 0x1;
+}
+
+static u32 equal_scalars_size(u64 equal_scalars)
+{
+	return equal_scalars & ES_SIZE_MASK;
+}
+
+/* Use u64 as a stack of 6 10-bit values, use first 4-bits to track
+ * number of elements currently in stack.
+ */
+static bool equal_scalars_push(u64 *equal_scalars, u32 frameno, u32 spi_or_reg, bool is_reg)
+{
+	u32 num;
+
+	num = equal_scalars_size(*equal_scalars);
+	if (num == 6)
+		return false;
+	*equal_scalars >>= ES_SIZE_BITS;
+	*equal_scalars <<= ES_ENTRY_BITS;
+	*equal_scalars |= equal_scalars_pack(frameno, spi_or_reg, is_reg);
+	*equal_scalars <<= ES_SIZE_BITS;
+	*equal_scalars |= num + 1;
+	return true;
+}
+
+static bool equal_scalars_pop(u64 *equal_scalars, u32 *frameno, u32 *spi_or_reg, bool *is_reg)
+{
+	u32 num;
+
+	num = equal_scalars_size(*equal_scalars);
+	if (num == 0)
+		return false;
+	*equal_scalars >>= ES_SIZE_BITS;
+	equal_scalars_unpack(*equal_scalars, frameno, spi_or_reg, is_reg);
+	*equal_scalars >>= ES_ENTRY_BITS;
+	*equal_scalars <<= ES_SIZE_BITS;
+	*equal_scalars |= num - 1;
+	return true;
+}
+
 static struct bpf_jmp_history_entry *get_jmp_hist_entry(struct bpf_verifier_state *st,
 							u32 hist_end, int insn_idx)
 {
@@ -3314,7 +3384,7 @@  static struct bpf_jmp_history_entry *get_jmp_hist_entry(struct bpf_verifier_stat
 
 /* for any branch, call, exit record the history of jmps in the given state */
 static int push_jmp_history(struct bpf_verifier_env *env, struct bpf_verifier_state *cur,
-			    int insn_flags)
+			    int insn_flags, u64 equal_scalars)
 {
 	struct bpf_jmp_history_entry *p, *cur_hist_ent;
 	u32 cnt = cur->jmp_history_cnt;
@@ -3332,6 +3402,12 @@  static int push_jmp_history(struct bpf_verifier_env *env, struct bpf_verifier_st
 			  "verifier insn history bug: insn_idx %d cur flags %x new flags %x\n",
 			  env->insn_idx, cur_hist_ent->flags, insn_flags);
 		cur_hist_ent->flags |= insn_flags;
+		if (cur_hist_ent->equal_scalars != 0) {
+			verbose(env, "verifier bug: insn_idx %d equal_scalars != 0: %#llx\n",
+				env->insn_idx, cur_hist_ent->equal_scalars);
+			return -EFAULT;
+		}
+		cur_hist_ent->equal_scalars = equal_scalars;
 		return 0;
 	}
 
@@ -3346,6 +3422,7 @@  static int push_jmp_history(struct bpf_verifier_env *env, struct bpf_verifier_st
 	p->idx = env->insn_idx;
 	p->prev_idx = env->prev_insn_idx;
 	p->flags = insn_flags;
+	p->equal_scalars = equal_scalars;
 	cur->jmp_history_cnt = cnt;
 
 	return 0;
@@ -3502,6 +3579,11 @@  static inline bool bt_is_reg_set(struct backtrack_state *bt, u32 reg)
 	return bt->reg_masks[bt->frame] & (1 << reg);
 }
 
+static inline bool bt_is_frame_reg_set(struct backtrack_state *bt, u32 frame, u32 reg)
+{
+	return bt->reg_masks[frame] & (1 << reg);
+}
+
 static inline bool bt_is_frame_slot_set(struct backtrack_state *bt, u32 frame, u32 slot)
 {
 	return bt->stack_masks[frame] & (1ull << slot);
@@ -3546,6 +3628,39 @@  static void fmt_stack_mask(char *buf, ssize_t buf_sz, u64 stack_mask)
 	}
 }
 
+/* If any register R in hist->equal_scalars is marked as precise in bt,
+ * do bt_set_frame_{reg,slot}(bt, R) for all registers in hist->equal_scalars.
+ */
+static void bt_set_equal_scalars(struct backtrack_state *bt, struct bpf_jmp_history_entry *hist)
+{
+	bool is_reg, some_precise = false;
+	u64 equal_scalars;
+	u32 fr, spi;
+
+	if (!hist || hist->equal_scalars == 0)
+		return;
+
+	equal_scalars = hist->equal_scalars;
+	while (equal_scalars_pop(&equal_scalars, &fr, &spi, &is_reg)) {
+		if ((is_reg && bt_is_frame_reg_set(bt, fr, spi)) ||
+		    (!is_reg && bt_is_frame_slot_set(bt, fr, spi))) {
+			some_precise = true;
+			break;
+		}
+	}
+
+	if (!some_precise)
+		return;
+
+	equal_scalars = hist->equal_scalars;
+	while (equal_scalars_pop(&equal_scalars, &fr, &spi, &is_reg)) {
+		if (is_reg)
+			bt_set_frame_reg(bt, fr, spi);
+		else
+			bt_set_frame_slot(bt, fr, spi);
+	}
+}
+
 static bool calls_callback(struct bpf_verifier_env *env, int insn_idx);
 
 /* For given verifier state backtrack_insn() is called from the last insn to
@@ -3802,6 +3917,7 @@  static int backtrack_insn(struct bpf_verifier_env *env, int idx, int subseq_idx,
 			 */
 			return 0;
 		} else if (BPF_SRC(insn->code) == BPF_X) {
+			bt_set_equal_scalars(bt, hist);
 			if (!bt_is_reg_set(bt, dreg) && !bt_is_reg_set(bt, sreg))
 				return 0;
 			/* dreg <cond> sreg
@@ -3812,6 +3928,9 @@  static int backtrack_insn(struct bpf_verifier_env *env, int idx, int subseq_idx,
 			 */
 			bt_set_reg(bt, dreg);
 			bt_set_reg(bt, sreg);
+			bt_set_equal_scalars(bt, hist);
+		} else if (BPF_SRC(insn->code) == BPF_K) {
+			bt_set_equal_scalars(bt, hist);
 			 /* else dreg <cond> K
 			  * Only dreg still needs precision before
 			  * this insn, so for the K-based conditional
@@ -4579,7 +4698,7 @@  static int check_stack_write_fixed_off(struct bpf_verifier_env *env,
 	}
 
 	if (insn_flags)
-		return push_jmp_history(env, env->cur_state, insn_flags);
+		return push_jmp_history(env, env->cur_state, insn_flags, 0);
 	return 0;
 }
 
@@ -4884,7 +5003,7 @@  static int check_stack_read_fixed_off(struct bpf_verifier_env *env,
 		insn_flags = 0; /* we are not restoring spilled register */
 	}
 	if (insn_flags)
-		return push_jmp_history(env, env->cur_state, insn_flags);
+		return push_jmp_history(env, env->cur_state, insn_flags, 0);
 	return 0;
 }
 
@@ -14835,16 +14954,58 @@  static bool try_match_pkt_pointers(const struct bpf_insn *insn,
 	return true;
 }
 
-static void find_equal_scalars(struct bpf_verifier_state *vstate,
-			       struct bpf_reg_state *known_reg)
+static void __find_equal_scalars(u64 *equal_scalars,
+				 struct bpf_reg_state *reg,
+				 u32 id, u32 frameno, u32 spi_or_reg, bool is_reg)
 {
-	struct bpf_func_state *state;
+	if (reg->type != SCALAR_VALUE || reg->id != id)
+		return;
+
+	if (!equal_scalars_push(equal_scalars, frameno, spi_or_reg, is_reg))
+		reg->id = 0;
+}
+
+/* For all R being scalar registers or spilled scalar registers
+ * in verifier state, save R in equal_scalars if R->id == id.
+ * If there are too many Rs sharing same id, reset id for leftover Rs.
+ */
+static void find_equal_scalars(struct bpf_verifier_state *vstate, u32 id, u64 *equal_scalars)
+{
+	struct bpf_func_state *func;
 	struct bpf_reg_state *reg;
+	int i, j;
 
-	bpf_for_each_reg_in_vstate(vstate, state, reg, ({
-		if (reg->type == SCALAR_VALUE && reg->id == known_reg->id)
+	for (i = vstate->curframe; i >= 0; i--) {
+		func = vstate->frame[i];
+		for (j = 0; j < BPF_REG_FP; j++) {
+			reg = &func->regs[j];
+			__find_equal_scalars(equal_scalars, reg, id, i, j, true);
+		}
+		for (j = 0; j < func->allocated_stack / BPF_REG_SIZE; j++) {
+			if (!is_spilled_reg(&func->stack[j]))
+				continue;
+			reg = &func->stack[j].spilled_ptr;
+			__find_equal_scalars(equal_scalars, reg, id, i, j, false);
+		}
+	}
+}
+
+/* For all R in equal_scalars, copy known_reg range into R
+ * if R->id == known_reg->id.
+ */
+static void copy_known_reg(struct bpf_verifier_state *vstate,
+			   struct bpf_reg_state *known_reg, u64 equal_scalars)
+{
+	struct bpf_reg_state *reg;
+	u32 fr, spi;
+	bool is_reg;
+
+	while (equal_scalars_pop(&equal_scalars, &fr, &spi, &is_reg)) {
+		reg = is_reg ? &vstate->frame[fr]->regs[spi]
+			     : &vstate->frame[fr]->stack[spi].spilled_ptr;
+		if (reg->id == known_reg->id)
 			copy_register_state(reg, known_reg);
-	}));
+	}
 }
 
 static int check_cond_jmp_op(struct bpf_verifier_env *env,
@@ -14857,6 +15018,7 @@  static int check_cond_jmp_op(struct bpf_verifier_env *env,
 	struct bpf_reg_state *eq_branch_regs;
 	struct bpf_reg_state fake_reg = {};
 	u8 opcode = BPF_OP(insn->code);
+	u64 equal_scalars = 0;
 	bool is_jmp32;
 	int pred = -1;
 	int err;
@@ -14944,6 +15106,21 @@  static int check_cond_jmp_op(struct bpf_verifier_env *env,
 		return 0;
 	}
 
+	/* Push scalar registers sharing same ID to jump history,
+	 * do this before creating 'other_branch', so that both
+	 * 'this_branch' and 'other_branch' share this history
+	 * if parent state is created.
+	 */
+	if (BPF_SRC(insn->code) == BPF_X && src_reg->type == SCALAR_VALUE && src_reg->id)
+		find_equal_scalars(this_branch, src_reg->id, &equal_scalars);
+	if (dst_reg->type == SCALAR_VALUE && dst_reg->id)
+		find_equal_scalars(this_branch, dst_reg->id, &equal_scalars);
+	if (equal_scalars_size(equal_scalars) > 1) {
+		err = push_jmp_history(env, this_branch, 0, equal_scalars);
+		if (err)
+			return err;
+	}
+
 	other_branch = push_stack(env, *insn_idx + insn->off + 1, *insn_idx,
 				  false);
 	if (!other_branch)
@@ -14968,13 +15145,13 @@  static int check_cond_jmp_op(struct bpf_verifier_env *env,
 	if (BPF_SRC(insn->code) == BPF_X &&
 	    src_reg->type == SCALAR_VALUE && src_reg->id &&
 	    !WARN_ON_ONCE(src_reg->id != other_branch_regs[insn->src_reg].id)) {
-		find_equal_scalars(this_branch, src_reg);
-		find_equal_scalars(other_branch, &other_branch_regs[insn->src_reg]);
+		copy_known_reg(this_branch, src_reg, equal_scalars);
+		copy_known_reg(other_branch, &other_branch_regs[insn->src_reg], equal_scalars);
 	}
 	if (dst_reg->type == SCALAR_VALUE && dst_reg->id &&
 	    !WARN_ON_ONCE(dst_reg->id != other_branch_regs[insn->dst_reg].id)) {
-		find_equal_scalars(this_branch, dst_reg);
-		find_equal_scalars(other_branch, &other_branch_regs[insn->dst_reg]);
+		copy_known_reg(this_branch, dst_reg, equal_scalars);
+		copy_known_reg(other_branch, &other_branch_regs[insn->dst_reg], equal_scalars);
 	}
 
 	/* if one pointer register is compared to another pointer
@@ -17213,7 +17390,7 @@  static int is_state_visited(struct bpf_verifier_env *env, int insn_idx)
 			 * the current state.
 			 */
 			if (is_jmp_point(env, env->insn_idx))
-				err = err ? : push_jmp_history(env, cur, 0);
+				err = err ? : push_jmp_history(env, cur, 0, 0);
 			err = err ? : propagate_precision(env, &sl->state);
 			if (err)
 				return err;
@@ -17477,7 +17654,7 @@  static int do_check(struct bpf_verifier_env *env)
 		}
 
 		if (is_jmp_point(env, env->insn_idx)) {
-			err = push_jmp_history(env, state, 0);
+			err = push_jmp_history(env, state, 0, 0);
 			if (err)
 				return err;
 		}
diff --git a/tools/testing/selftests/bpf/progs/verifier_subprog_precision.c b/tools/testing/selftests/bpf/progs/verifier_subprog_precision.c
index 6f5d19665cf6..2c7261834149 100644
--- a/tools/testing/selftests/bpf/progs/verifier_subprog_precision.c
+++ b/tools/testing/selftests/bpf/progs/verifier_subprog_precision.c
@@ -191,7 +191,7 @@  __msg("mark_precise: frame0: last_idx 14 first_idx 9")
 __msg("mark_precise: frame0: regs=r6 stack= before 13: (bf) r1 = r7")
 __msg("mark_precise: frame0: regs=r6 stack= before 12: (27) r6 *= 4")
 __msg("mark_precise: frame0: regs=r6 stack= before 11: (25) if r6 > 0x3 goto pc+4")
-__msg("mark_precise: frame0: regs=r6 stack= before 10: (bf) r6 = r0")
+__msg("mark_precise: frame0: regs=r0,r6 stack= before 10: (bf) r6 = r0")
 __msg("mark_precise: frame0: regs=r0 stack= before 9: (85) call bpf_loop")
 /* State entering callback body popped from states stack */
 __msg("from 9 to 17: frame1:")
diff --git a/tools/testing/selftests/bpf/verifier/precise.c b/tools/testing/selftests/bpf/verifier/precise.c
index 0a9293a57211..64d722199e8f 100644
--- a/tools/testing/selftests/bpf/verifier/precise.c
+++ b/tools/testing/selftests/bpf/verifier/precise.c
@@ -44,7 +44,7 @@ 
 	mark_precise: frame0: regs=r2 stack= before 23\
 	mark_precise: frame0: regs=r2 stack= before 22\
 	mark_precise: frame0: regs=r2 stack= before 20\
-	mark_precise: frame0: parent state regs=r2 stack=:\
+	mark_precise: frame0: parent state regs=r2,r9 stack=:\
 	mark_precise: frame0: last_idx 19 first_idx 10\
 	mark_precise: frame0: regs=r2,r9 stack= before 19\
 	mark_precise: frame0: regs=r9 stack= before 18\