diff mbox series

[v5,bpf-next,1/4] bpf: Introduce may_goto instruction

Message ID 20240305045219.66142-2-alexei.starovoitov@gmail.com (mailing list archive)
State Superseded
Delegated to: BPF
Headers show
Series bpf: Introduce may_goto and cond_break | expand

Checks

Context Check Description
bpf/vmtest-bpf-next-VM_Test-0 success Logs for Lint
bpf/vmtest-bpf-next-VM_Test-3 success Logs for Validate matrix.py
bpf/vmtest-bpf-next-VM_Test-2 success Logs for Unittests
bpf/vmtest-bpf-next-VM_Test-1 success Logs for ShellCheck
bpf/vmtest-bpf-next-VM_Test-5 success Logs for aarch64-gcc / build-release
bpf/vmtest-bpf-next-VM_Test-24 success Logs for x86_64-gcc / test (test_progs_no_alu32_parallel, true, 30) / test_progs_no_alu32_parallel on x86_64 with gcc
bpf/vmtest-bpf-next-VM_Test-25 success Logs for x86_64-gcc / test (test_progs_parallel, true, 30) / test_progs_parallel on x86_64 with gcc
bpf/vmtest-bpf-next-VM_Test-26 success Logs for x86_64-gcc / test (test_verifier, false, 360) / test_verifier on x86_64 with gcc
bpf/vmtest-bpf-next-VM_Test-27 success Logs for x86_64-gcc / veristat / veristat on x86_64 with gcc
bpf/vmtest-bpf-next-VM_Test-4 success Logs for aarch64-gcc / build / build for aarch64 with gcc
bpf/vmtest-bpf-next-VM_Test-10 success Logs for aarch64-gcc / veristat
bpf/vmtest-bpf-next-VM_Test-11 success Logs for s390x-gcc / build / build for s390x with gcc
bpf/vmtest-bpf-next-VM_Test-9 success Logs for aarch64-gcc / test (test_verifier, false, 360) / test_verifier on aarch64 with gcc
bpf/vmtest-bpf-next-VM_Test-12 success Logs for s390x-gcc / build-release
bpf/vmtest-bpf-next-VM_Test-19 success Logs for x86_64-gcc / build / build for x86_64 with gcc
bpf/vmtest-bpf-next-VM_Test-17 success Logs for s390x-gcc / veristat
bpf/vmtest-bpf-next-VM_Test-18 success Logs for set-matrix
bpf/vmtest-bpf-next-VM_Test-16 success Logs for s390x-gcc / test (test_verifier, false, 360) / test_verifier on s390x with gcc
bpf/vmtest-bpf-next-VM_Test-20 success Logs for x86_64-gcc / build-release
bpf/vmtest-bpf-next-VM_Test-21 success Logs for x86_64-gcc / test (test_maps, false, 360) / test_maps on x86_64 with gcc
bpf/vmtest-bpf-next-VM_Test-28 success Logs for x86_64-llvm-17 / build / build for x86_64 with llvm-17
bpf/vmtest-bpf-next-VM_Test-30 success Logs for x86_64-llvm-17 / test (test_maps, false, 360) / test_maps on x86_64 with llvm-17
bpf/vmtest-bpf-next-VM_Test-33 success Logs for x86_64-llvm-17 / test (test_verifier, false, 360) / test_verifier on x86_64 with llvm-17
bpf/vmtest-bpf-next-VM_Test-34 success Logs for x86_64-llvm-17 / veristat
bpf/vmtest-bpf-next-VM_Test-35 success Logs for x86_64-llvm-18 / build / build for x86_64 with llvm-18
bpf/vmtest-bpf-next-VM_Test-37 success Logs for x86_64-llvm-18 / test (test_maps, false, 360) / test_maps on x86_64 with llvm-18
bpf/vmtest-bpf-next-VM_Test-41 success Logs for x86_64-llvm-18 / test (test_verifier, false, 360) / test_verifier on x86_64 with llvm-18
bpf/vmtest-bpf-next-VM_Test-42 success Logs for x86_64-llvm-18 / veristat
bpf/vmtest-bpf-next-VM_Test-6 success Logs for aarch64-gcc / test (test_maps, false, 360) / test_maps on aarch64 with gcc
bpf/vmtest-bpf-next-VM_Test-7 success Logs for aarch64-gcc / test (test_progs, false, 360) / test_progs on aarch64 with gcc
bpf/vmtest-bpf-next-VM_Test-8 success Logs for aarch64-gcc / test (test_progs_no_alu32, false, 360) / test_progs_no_alu32 on aarch64 with gcc
bpf/vmtest-bpf-next-VM_Test-13 success Logs for s390x-gcc / test (test_maps, false, 360) / test_maps on s390x with gcc
bpf/vmtest-bpf-next-VM_Test-22 success Logs for x86_64-gcc / test (test_progs, false, 360) / test_progs on x86_64 with gcc
bpf/vmtest-bpf-next-VM_Test-23 success Logs for x86_64-gcc / test (test_progs_no_alu32, false, 360) / test_progs_no_alu32 on x86_64 with gcc
bpf/vmtest-bpf-next-VM_Test-31 success Logs for x86_64-llvm-17 / test (test_progs, false, 360) / test_progs on x86_64 with llvm-17
bpf/vmtest-bpf-next-VM_Test-32 success Logs for x86_64-llvm-17 / test (test_progs_no_alu32, false, 360) / test_progs_no_alu32 on x86_64 with llvm-17
bpf/vmtest-bpf-next-VM_Test-36 success Logs for x86_64-llvm-18 / build-release / build for x86_64 with llvm-18 and -O2 optimization
bpf/vmtest-bpf-next-VM_Test-38 success Logs for x86_64-llvm-18 / test (test_progs, false, 360) / test_progs on x86_64 with llvm-18
bpf/vmtest-bpf-next-VM_Test-39 success Logs for x86_64-llvm-18 / test (test_progs_cpuv4, false, 360) / test_progs_cpuv4 on x86_64 with llvm-18
bpf/vmtest-bpf-next-VM_Test-40 success Logs for x86_64-llvm-18 / test (test_progs_no_alu32, false, 360) / test_progs_no_alu32 on x86_64 with llvm-18
bpf/vmtest-bpf-next-PR success PR summary
bpf/vmtest-bpf-next-VM_Test-15 success Logs for s390x-gcc / test (test_progs_no_alu32, false, 360) / test_progs_no_alu32 on s390x with gcc
bpf/vmtest-bpf-next-VM_Test-14 success Logs for s390x-gcc / test (test_progs, false, 360) / test_progs on s390x with gcc
bpf/vmtest-bpf-next-VM_Test-29 success Logs for x86_64-llvm-17 / build-release / build for x86_64 with llvm-17 and -O2 optimization
netdev/series_format success Posting correctly formatted
netdev/tree_selection success Clearly marked for bpf-next, async
netdev/ynl success Generated files up to date; no warnings/errors; no diff in generated;
netdev/fixes_present success Fixes tag not required for -next series
netdev/header_inline success No static functions without inline keyword in header files
netdev/build_32bit success Errors and warnings before: 7577 this patch: 7577
netdev/build_tools success Errors and warnings before: 1 this patch: 1
netdev/cc_maintainers warning 13 maintainers not CCed: jolsa@kernel.org kpsingh@kernel.org ndesaulniers@google.com song@kernel.org nathan@kernel.org yonghong.song@linux.dev haoluo@google.com justinstitt@google.com martin.lau@linux.dev morbo@google.com sdf@google.com quentin@isovalent.com llvm@lists.linux.dev
netdev/build_clang success Errors and warnings before: 2250 this patch: 2250
netdev/verify_signedoff success Signed-off-by tag matches author and committer
netdev/deprecated_api success None detected
netdev/check_selftest success No net selftest shell script
netdev/verify_fixes success No Fixes tag
netdev/build_allmodconfig_warn success Errors and warnings before: 8070 this patch: 8070
netdev/checkpatch warning CHECK: multiple assignments should be avoided WARNING: Prefer 'fallthrough;' over fallthrough comment WARNING: line length of 81 exceeds 80 columns WARNING: line length of 84 exceeds 80 columns WARNING: line length of 86 exceeds 80 columns WARNING: line length of 89 exceeds 80 columns WARNING: line length of 92 exceeds 80 columns WARNING: line length of 93 exceeds 80 columns WARNING: line length of 98 exceeds 80 columns
netdev/build_clang_rust success No Rust files in patch. Skipping build
netdev/kdoc success Errors and warnings before: 0 this patch: 0
netdev/source_inline success Was 0 now: 0

Commit Message

Alexei Starovoitov March 5, 2024, 4:52 a.m. UTC
From: Alexei Starovoitov <ast@kernel.org>

Introduce may_goto instruction that acts on a hidden bpf_iter_num, so that
bpf_iter_num_new(), bpf_iter_num_destroy() don't need to be called explicitly.
It can be used in any normal "for" or "while" loop, like

  for (i = zero; i < cnt; cond_break, i++) {

The verifier recognizes that may_goto is used in the program,
reserves additional 8 bytes of stack, initializes them in subprog
prologue, and replaces may_goto instruction with:
aux_reg = *(u64 *)(fp - 40)
if aux_reg == 0 goto pc+off
aux_reg -= 1
*(u64 *)(fp - 40) = aux_reg

may_goto instruction can be used by LLVM to implement __builtin_memcpy,
__builtin_strcmp.

may_goto is not a full substitute for bpf_for() macro.
bpf_for() doesn't have an induction variable that the verifier sees,
so 'i' in bpf_for(i, 0, 100) is seen as imprecise and bounded.

But when the code is written as:
for (i = 0; i < 100; cond_break, i++)
the verifier sees 'i' as a precise constant zero,
hence cond_break (aka may_goto) doesn't help to converge the loop.
A static or global variable can be used as a workaround:
static int zero = 0;
for (i = zero; i < 100; cond_break, i++) // works!

may_goto works well with arena pointers that don't need to be bounds-checked
on every iteration. Load/store from arena returns imprecise unbounded scalars.

Reserve new opcode BPF_JMP | BPF_JMA for may_goto insn.
JMA stands for "jump maybe", and "jump multipurpose", and "jump multi always".
Since goto_or_nop insn was proposed, it may use the same opcode.
may_goto vs goto_or_nop can be distinguished by src_reg:
code = BPF_JMP | BPF_JMA:
src_reg = 0 - may_goto
src_reg = 1 - goto_or_nop
We could have reused BPF_JMP | BPF_JA like:
src_reg = 0 - normal goto
src_reg = 1 - may_goto
src_reg = 2 - goto_or_nop
but JA is a real insn and it's unconditional, while may_goto and goto_or_nop
are pseudo instructions, and both are conditional. Hence it's better to
have a different opcode for them. Hence BPF_JMA.

Signed-off-by: Alexei Starovoitov <ast@kernel.org>
---
 include/linux/bpf_verifier.h   |   2 +
 include/uapi/linux/bpf.h       |   1 +
 kernel/bpf/core.c              |   1 +
 kernel/bpf/disasm.c            |   3 +
 kernel/bpf/verifier.c          | 156 ++++++++++++++++++++++++++-------
 tools/include/uapi/linux/bpf.h |   1 +
 6 files changed, 134 insertions(+), 30 deletions(-)

Comments

Andrii Nakryiko March 5, 2024, 6:37 p.m. UTC | #1
On Mon, Mar 4, 2024 at 8:52 PM Alexei Starovoitov
<alexei.starovoitov@gmail.com> wrote:
>
> From: Alexei Starovoitov <ast@kernel.org>
>
> Introduce may_goto instruction that acts on a hidden bpf_iter_num, so that
> bpf_iter_num_new(), bpf_iter_num_destroy() don't need to be called explicitly.

bpf_iter_num was probably an inspiration, but I think by now the
analogy is pretty weak. bpf_iter_num_next() returns NULL or pointer to
int (i.e., it returns some usable value), while may_goto jumps or not.
So it's not just implicit new/destroy. The above doesn't confuse me,
but I wonder if someone less familiar with iterators would be confused
by the above?

> It can be used in any normal "for" or "while" loop, like
>
>   for (i = zero; i < cnt; cond_break, i++) {
>
> The verifier recognizes that may_goto is used in the program,
> reserves additional 8 bytes of stack, initializes them in subprog
> prologue, and replaces may_goto instruction with:
> aux_reg = *(u64 *)(fp - 40)
> if aux_reg == 0 goto pc+off
> aux_reg += 1

`aux_reg -= 1`?

> *(u64 *)(fp - 40) = aux_reg
>
> may_goto instruction can be used by LLVM to implement __builtin_memcpy,
> __builtin_strcmp.
>
> may_goto is not a full substitute for bpf_for() macro.
> bpf_for() doesn't have induction variable that verifiers sees,
> so 'i' in bpf_for(i, 0, 100) is seen as imprecise and bounded.
>
> But when the code is written as:
> for (i = 0; i < 100; cond_break, i++)
> the verifier see 'i' as precise constant zero,
> hence cond_break (aka may_goto) doesn't help to converge the loop.
> A static or global variable can be used as a workaround:
> static int zero = 0;
> for (i = zero; i < 100; cond_break, i++) // works!
>
> may_goto works well with arena pointers that don't need to be bounds-checked
> on every iteration. Load/store from arena returns imprecise unbounded scalars.
>
> Reserve new opcode BPF_JMP | BPF_JMA for may_goto insn.
> JMA stands for "jump maybe", and "jump multipurpose", and "jump multi always".
> Since goto_or_nop insn was proposed, it may use the same opcode.
> may_goto vs goto_or_nop can be distinguished by src_reg:
> code = BPF_JMP | BPF_JMA:
> src_reg = 0 - may_goto
> src_reg = 1 - goto_or_nop
> We could have reused BPF_JMP | BPF_JA like:
> src_reg = 0 - normal goto
> src_reg = 1 - may_goto
> src_reg = 2 - goto_or_nop
> but JA is a real insn and it's unconditional, while may_goto and goto_or_nop
> are pseudo instructions, and both are conditional. Hence it's better to
> have a different opcode for them. Hence BPF_JMA.
>
> Signed-off-by: Alexei Starovoitov <ast@kernel.org>
> ---
>  include/linux/bpf_verifier.h   |   2 +
>  include/uapi/linux/bpf.h       |   1 +
>  kernel/bpf/core.c              |   1 +
>  kernel/bpf/disasm.c            |   3 +
>  kernel/bpf/verifier.c          | 156 ++++++++++++++++++++++++++-------
>  tools/include/uapi/linux/bpf.h |   1 +
>  6 files changed, 134 insertions(+), 30 deletions(-)
>

Not a huge fan of BPF_JMA, but there is no clear naming winner.
BPF_JAUX, BPF_JPSEUDO, BPF_JMAYBE, would be a bit more
greppable/recognizable, but it's not a big deal.

Left few nits below, but overall LGTM

Acked-by: Andrii Nakryiko <andrii@kernel.org>

> diff --git a/include/linux/bpf_verifier.h b/include/linux/bpf_verifier.h
> index 84365e6dd85d..917ca603059b 100644
> --- a/include/linux/bpf_verifier.h
> +++ b/include/linux/bpf_verifier.h
> @@ -449,6 +449,7 @@ struct bpf_verifier_state {
>         u32 jmp_history_cnt;
>         u32 dfs_depth;
>         u32 callback_unroll_depth;
> +       u32 may_goto_cnt;

naming nit: seems like we consistently use "depth" terminology for
bpf_loop and open-coded iters, any reason to deviate with "cnt"
terminology here?

>  };
>
>  #define bpf_get_spilled_reg(slot, frame, mask)                         \

[...]

>
> +static bool is_may_goto_insn(struct bpf_verifier_env *env, int insn_idx)
> +{
> +       return env->prog->insnsi[insn_idx].code == (BPF_JMP | BPF_JMA);
> +}
> +
>  /* process_iter_next_call() is called when verifier gets to iterator's next
>   * "method" (e.g., bpf_iter_num_next() for numbers iterator) call. We'll refer
>   * to it as just "iter_next()" in comments below.
> @@ -14871,11 +14877,35 @@ static int check_cond_jmp_op(struct bpf_verifier_env *env,
>         int err;
>
>         /* Only conditional jumps are expected to reach here. */
> -       if (opcode == BPF_JA || opcode > BPF_JSLE) {
> +       if (opcode == BPF_JA || opcode > BPF_JMA) {
>                 verbose(env, "invalid BPF_JMP/JMP32 opcode %x\n", opcode);
>                 return -EINVAL;
>         }
>
> +       if (opcode == BPF_JMA) {
> +               struct bpf_verifier_state *cur_st = env->cur_state, *queued_st, *prev_st;
> +               int idx = *insn_idx;
> +
> +               if (insn->code != (BPF_JMP | BPF_JMA) ||
> +                   insn->src_reg || insn->dst_reg || insn->imm || insn->off == 0) {
> +                       verbose(env, "invalid may_goto off %d imm %d\n",
> +                               insn->off, insn->imm);
> +                       return -EINVAL;
> +               }
> +               prev_st = find_prev_entry(env, cur_st->parent, idx);
> +
> +               /* branch out 'fallthrough' insn as a new state to explore */
> +               queued_st = push_stack(env, idx + 1, idx, false);
> +               if (!queued_st)
> +                       return -ENOMEM;
> +
> +               queued_st->may_goto_cnt++;
> +               if (prev_st)
> +                       widen_imprecise_scalars(env, prev_st, queued_st);
> +               *insn_idx += insn->off;
> +               return 0;
> +       }
> +
>         /* check src2 operand */
>         err = check_reg_arg(env, insn->dst_reg, SRC_OP);
>         if (err)
> @@ -15659,6 +15689,8 @@ static int visit_insn(int t, struct bpf_verifier_env *env)
>         default:
>                 /* conditional jump with two edges */
>                 mark_prune_point(env, t);
> +               if (insn->code == (BPF_JMP | BPF_JMA))

maybe use is_may_goto_insn() here for consistency?

> +                       mark_force_checkpoint(env, t);
>
>                 ret = push_insn(t, t + 1, FALLTHROUGH, env);
>                 if (ret)

[...]

>  patch_call_imm:
>                 fn = env->ops->get_func_proto(insn->imm, env->prog);
> @@ -19952,6 +20015,39 @@ static int do_misc_fixups(struct bpf_verifier_env *env)
>                         return -EFAULT;
>                 }
>                 insn->imm = fn->func - __bpf_call_base;
> +next_insn:
> +               if (subprogs[cur_subprog + 1].start == i + delta + 1) {
> +                       subprogs[cur_subprog].stack_depth += stack_depth_extra;
> +                       subprogs[cur_subprog].stack_extra = stack_depth_extra;
> +                       cur_subprog++;
> +                       stack_depth = subprogs[cur_subprog].stack_depth;
> +                       stack_depth_extra = 0;
> +               }
> +               i++; insn++;

Is there a code path where we don't do i++, insn++? From cursory look
at this loop, I think we always do this, so not sure why `i++, insn++`
had to be moved from for() clause?

But if I missed it and we have to do these increments here, these are
two separate statements, so let's put them on separate lines?

> +       }
> +
> +       env->prog->aux->stack_depth = subprogs[0].stack_depth;
> +       for (i = 0; i < env->subprog_cnt; i++) {
> +               int subprog_start = subprogs[i].start, j;
> +               int stack_slots = subprogs[i].stack_extra / 8;
> +
> +               if (stack_slots >= ARRAY_SIZE(insn_buf)) {
> +                       verbose(env, "verifier bug: stack_extra is too large\n");
> +                       return -EFAULT;
> +               }
> +
> +               /* Add insns to subprog prologue to init extra stack */
> +               for (j = 0; j < stack_slots; j++)
> +                       insn_buf[j] = BPF_ST_MEM(BPF_DW, BPF_REG_FP,
> +                                                -subprogs[i].stack_depth + j * 8, BPF_MAX_LOOPS);
> +               if (j) {
> +                       insn_buf[j] = env->prog->insnsi[subprog_start];
> +
> +                       new_prog = bpf_patch_insn_data(env, subprog_start, insn_buf, j + 1);
> +                       if (!new_prog)
> +                               return -ENOMEM;
> +                       env->prog = prog = new_prog;
> +               }

this code is sort of generic (you don't assume just 0 or 1 extra
slots), but then it initializes each extra slot with BPF_MAX_LOOPS,
which doesn't look generic at all. So it's neither as simple as it
could be nor generic, really...

Maybe let's add WARN_ON if stack_extra>1 (so we catch it if we ever
extend this), but otherwise just have a simple and easier to follow

if (stack_slots) {
    insn_buf[0] = BPF_ST_MEM(..., BPF_MAX_LOOPS);
    /* bpf_patch_insn_data() replaces instruction,
     * so we need to copy first actual insn to preserve it (it's not
that obvious)
     */
    insn_buf[1] = env->prog->insnsi[subprog_start];
    ... patch ...
}

It's pretty minor, overall, but definitely caused some pause for me.

>         }
>
>         /* Since poke tab is now finalized, publish aux to tracker. */
> diff --git a/tools/include/uapi/linux/bpf.h b/tools/include/uapi/linux/bpf.h
> index a241f407c234..932ffef0dc88 100644
> --- a/tools/include/uapi/linux/bpf.h
> +++ b/tools/include/uapi/linux/bpf.h
> @@ -42,6 +42,7 @@
>  #define BPF_JSGE       0x70    /* SGE is signed '>=', GE in x86 */
>  #define BPF_JSLT       0xc0    /* SLT is signed, '<' */
>  #define BPF_JSLE       0xd0    /* SLE is signed, '<=' */
> +#define BPF_JMA                0xe0    /* may_goto */
>  #define BPF_CALL       0x80    /* function call */
>  #define BPF_EXIT       0x90    /* function return */
>
> --
> 2.43.0
>
Alexei Starovoitov March 5, 2024, 8:52 p.m. UTC | #2
On Tue, Mar 5, 2024 at 10:38 AM Andrii Nakryiko
<andrii.nakryiko@gmail.com> wrote:
>
> On Mon, Mar 4, 2024 at 8:52 PM Alexei Starovoitov
> <alexei.starovoitov@gmail.com> wrote:
> >
> > From: Alexei Starovoitov <ast@kernel.org>
> >
> > Introduce may_goto instruction that acts on a hidden bpf_iter_num, so that
> > bpf_iter_num_new(), bpf_iter_num_destroy() don't need to be called explicitly.
>
> bpf_iter_num was probably an inspiration, but I think by now the
> analogy is pretty weak. bpf_iter_num_next() returns NULL or pointer to
> int (i.e., it returns some usable value), while may_goto jumps or not.
> So it's not just implicit new/destroy. The above doesn't confuse me,
> but I wonder if someone less familiar with iterators would be confused
> by the above?

Agree. Will reword.

>
> > It can be used in any normal "for" or "while" loop, like
> >
> >   for (i = zero; i < cnt; cond_break, i++) {
> >
> > The verifier recognizes that may_goto is used in the program,
> > reserves additional 8 bytes of stack, initializes them in subprog
> > prologue, and replaces may_goto instruction with:
> > aux_reg = *(u64 *)(fp - 40)
> > if aux_reg == 0 goto pc+off
> > aux_reg += 1
>
> `aux_reg -= 1`?

+1 for -1

>
> > *(u64 *)(fp - 40) = aux_reg
> >
> > may_goto instruction can be used by LLVM to implement __builtin_memcpy,
> > __builtin_strcmp.
> >
> > may_goto is not a full substitute for bpf_for() macro.
> > bpf_for() doesn't have induction variable that verifiers sees,
> > so 'i' in bpf_for(i, 0, 100) is seen as imprecise and bounded.
> >
> > But when the code is written as:
> > for (i = 0; i < 100; cond_break, i++)
> > the verifier see 'i' as precise constant zero,
> > hence cond_break (aka may_goto) doesn't help to converge the loop.
> > A static or global variable can be used as a workaround:
> > static int zero = 0;
> > for (i = zero; i < 100; cond_break, i++) // works!
> >
> > may_goto works well with arena pointers that don't need to be bounds-checked
> > on every iteration. Load/store from arena returns imprecise unbounded scalars.
> >
> > Reserve new opcode BPF_JMP | BPF_JMA for may_goto insn.
> > JMA stands for "jump maybe", and "jump multipurpose", and "jump multi always".
> > Since goto_or_nop insn was proposed, it may use the same opcode.
> > may_goto vs goto_or_nop can be distinguished by src_reg:
> > code = BPF_JMP | BPF_JMA:
> > src_reg = 0 - may_goto
> > src_reg = 1 - goto_or_nop
> > We could have reused BPF_JMP | BPF_JA like:
> > src_reg = 0 - normal goto
> > src_reg = 1 - may_goto
> > src_reg = 2 - goto_or_nop
> > but JA is a real insn and it's unconditional, while may_goto and goto_or_nop
> > are pseudo instructions, and both are conditional. Hence it's better to
> > have a different opcode for them. Hence BPF_JMA.
> >
> > Signed-off-by: Alexei Starovoitov <ast@kernel.org>
> > ---
> >  include/linux/bpf_verifier.h   |   2 +
> >  include/uapi/linux/bpf.h       |   1 +
> >  kernel/bpf/core.c              |   1 +
> >  kernel/bpf/disasm.c            |   3 +
> >  kernel/bpf/verifier.c          | 156 ++++++++++++++++++++++++++-------
> >  tools/include/uapi/linux/bpf.h |   1 +
> >  6 files changed, 134 insertions(+), 30 deletions(-)
> >
>
> Not a huge fan of BPF_JMA, but there is no clear naming winner.
> BPF_JAUX, BPF_JPSEUDO, BPF_JMAYBE, would be a bit more
> greppable/recognizable, but it's not a big deal.

In the next version I'm planning to go with BPF_JCOND
to describe a new class of conditional pseudo jumps.
A comment hopefully will be good enough to avoid confusion
with existing conditional jumps (all of them except BPF_JA)
I will also add
enum {
 BPF_MAY_GOTO = 0,
};
and check that src_reg is equal to that.
Then in the future it can be extended with:
 BPF_NOP_OR_GOTO = 1,

> Left few nits below, but overall LGTM
>
> Acked-by: Andrii Nakryiko <andrii@kernel.org>
>
> > diff --git a/include/linux/bpf_verifier.h b/include/linux/bpf_verifier.h
> > index 84365e6dd85d..917ca603059b 100644
> > --- a/include/linux/bpf_verifier.h
> > +++ b/include/linux/bpf_verifier.h
> > @@ -449,6 +449,7 @@ struct bpf_verifier_state {
> >         u32 jmp_history_cnt;
> >         u32 dfs_depth;
> >         u32 callback_unroll_depth;
> > +       u32 may_goto_cnt;
>
> naming nit: seems like we consistently use "depth" terminology for
> bpf_loop and open-coded iters, any reason to deviate with "cnt"
> terminology here?

sure.

> >  };
> >
> >  #define bpf_get_spilled_reg(slot, frame, mask)                         \
>
> [...]
>
> >
> > +static bool is_may_goto_insn(struct bpf_verifier_env *env, int insn_idx)
> > +{
> > +       return env->prog->insnsi[insn_idx].code == (BPF_JMP | BPF_JMA);
> > +}
> > +
> >  /* process_iter_next_call() is called when verifier gets to iterator's next
> >   * "method" (e.g., bpf_iter_num_next() for numbers iterator) call. We'll refer
> >   * to it as just "iter_next()" in comments below.
> > @@ -14871,11 +14877,35 @@ static int check_cond_jmp_op(struct bpf_verifier_env *env,
> >         int err;
> >
> >         /* Only conditional jumps are expected to reach here. */
> > -       if (opcode == BPF_JA || opcode > BPF_JSLE) {
> > +       if (opcode == BPF_JA || opcode > BPF_JMA) {
> >                 verbose(env, "invalid BPF_JMP/JMP32 opcode %x\n", opcode);
> >                 return -EINVAL;
> >         }
> >
> > +       if (opcode == BPF_JMA) {
> > +               struct bpf_verifier_state *cur_st = env->cur_state, *queued_st, *prev_st;
> > +               int idx = *insn_idx;
> > +
> > +               if (insn->code != (BPF_JMP | BPF_JMA) ||
> > +                   insn->src_reg || insn->dst_reg || insn->imm || insn->off == 0) {
> > +                       verbose(env, "invalid may_goto off %d imm %d\n",
> > +                               insn->off, insn->imm);
> > +                       return -EINVAL;
> > +               }
> > +               prev_st = find_prev_entry(env, cur_st->parent, idx);
> > +
> > +               /* branch out 'fallthrough' insn as a new state to explore */
> > +               queued_st = push_stack(env, idx + 1, idx, false);
> > +               if (!queued_st)
> > +                       return -ENOMEM;
> > +
> > +               queued_st->may_goto_cnt++;
> > +               if (prev_st)
> > +                       widen_imprecise_scalars(env, prev_st, queued_st);
> > +               *insn_idx += insn->off;
> > +               return 0;
> > +       }
> > +
> >         /* check src2 operand */
> >         err = check_reg_arg(env, insn->dst_reg, SRC_OP);
> >         if (err)
> > @@ -15659,6 +15689,8 @@ static int visit_insn(int t, struct bpf_verifier_env *env)
> >         default:
> >                 /* conditional jump with two edges */
> >                 mark_prune_point(env, t);
> > +               if (insn->code == (BPF_JMP | BPF_JMA))
>
> maybe use is_may_goto_insn() here for consistency?

+1

> > +                       mark_force_checkpoint(env, t);
> >
> >                 ret = push_insn(t, t + 1, FALLTHROUGH, env);
> >                 if (ret)
>
> [...]
>
> >  patch_call_imm:
> >                 fn = env->ops->get_func_proto(insn->imm, env->prog);
> > @@ -19952,6 +20015,39 @@ static int do_misc_fixups(struct bpf_verifier_env *env)
> >                         return -EFAULT;
> >                 }
> >                 insn->imm = fn->func - __bpf_call_base;
> > +next_insn:
> > +               if (subprogs[cur_subprog + 1].start == i + delta + 1) {
> > +                       subprogs[cur_subprog].stack_depth += stack_depth_extra;
> > +                       subprogs[cur_subprog].stack_extra = stack_depth_extra;
> > +                       cur_subprog++;
> > +                       stack_depth = subprogs[cur_subprog].stack_depth;
> > +                       stack_depth_extra = 0;
> > +               }
> > +               i++; insn++;
>
> Is there a code path where we don't do i++, insn++? From cursory look
> at this loop, I think we always do this, so not sure why `i++, insn++`
> had to be moved from for() clause?

I was worried about missing replacing 'continue' with 'goto next_insn'.
Since 'continue' will work for all tests except may_goto tests,
and in arena patch set I have a hunk that is added to this loop too
and it is written with 'continue'.
Technically we can keep 'i++; insn++;' in the for()...
if we're going to be code reviewing any future additions carefully.
In other words 'stack_extra logic plus i and insn increments'
are part of the same logical block. It's an action that should
be done after each insn, hence better to keep them together.
Either inside for() as I did in v1/v2 or here toward the last '}'.
The latter is more canonical C.

>
> But if I missed it and we have to do these increments here, these are
> two separate statements, so let's put them on separate lines?

sure.

>
> > +       }
> > +
> > +       env->prog->aux->stack_depth = subprogs[0].stack_depth;
> > +       for (i = 0; i < env->subprog_cnt; i++) {
> > +               int subprog_start = subprogs[i].start, j;
> > +               int stack_slots = subprogs[i].stack_extra / 8;
> > +
> > +               if (stack_slots >= ARRAY_SIZE(insn_buf)) {
> > +                       verbose(env, "verifier bug: stack_extra is too large\n");
> > +                       return -EFAULT;
> > +               }
> > +
> > +               /* Add insns to subprog prologue to init extra stack */
> > +               for (j = 0; j < stack_slots; j++)
> > +                       insn_buf[j] = BPF_ST_MEM(BPF_DW, BPF_REG_FP,
> > +                                                -subprogs[i].stack_depth + j * 8, BPF_MAX_LOOPS);
> > +               if (j) {
> > +                       insn_buf[j] = env->prog->insnsi[subprog_start];
> > +
> > +                       new_prog = bpf_patch_insn_data(env, subprog_start, insn_buf, j + 1);
> > +                       if (!new_prog)
> > +                               return -ENOMEM;
> > +                       env->prog = prog = new_prog;
> > +               }
>
> this code is sort of generic (you don't assume just 0 or 1 extra
> slots), but then it initializes each extra slot with BPF_MAX_LOOPS,
> which doesn't look generic at all. So it's neither as simple as it
> could be nor generic, really...
>
> Maybe let's add WARN_ON if stack_extra>1 (so we catch it if we ever
> extend this), but otherwise just have a simple and easier to follow
>
> if (stack_slots) {
>     insn_buf[0] = BPF_ST_MEM(..., BPF_MAX_LOOPS);
>     /* bpf_patch_insn_data() replaces instruction,
>      * so we need to copy first actual insn to preserve it (it's not
> that obvious)
>      */
>     insn_buf[1] = env->prog->insnsi[subprog_start];
>     ... patch ...
> }
>
> It's pretty minor, overall, but definitely caused some pause for me.

I had it like that in v2, but then thought that I might
try to use this to fix tail_call from subprogs issue discussed
in the separate thread, so I changed it to generic zero init in v3,
but then went with BPF_MAX_LOOPS init in v4.
Looking at this now it seems hard coding to stack_slots==1 is better indeed.
diff mbox series

Patch

diff --git a/include/linux/bpf_verifier.h b/include/linux/bpf_verifier.h
index 84365e6dd85d..917ca603059b 100644
--- a/include/linux/bpf_verifier.h
+++ b/include/linux/bpf_verifier.h
@@ -449,6 +449,7 @@  struct bpf_verifier_state {
 	u32 jmp_history_cnt;
 	u32 dfs_depth;
 	u32 callback_unroll_depth;
+	u32 may_goto_cnt;
 };
 
 #define bpf_get_spilled_reg(slot, frame, mask)				\
@@ -619,6 +620,7 @@  struct bpf_subprog_info {
 	u32 start; /* insn idx of function entry point */
 	u32 linfo_idx; /* The idx to the main_prog->aux->linfo */
 	u16 stack_depth; /* max. stack depth used by this function */
+	u16 stack_extra;
 	bool has_tail_call: 1;
 	bool tail_call_reachable: 1;
 	bool has_ld_abs: 1;
diff --git a/include/uapi/linux/bpf.h b/include/uapi/linux/bpf.h
index a241f407c234..932ffef0dc88 100644
--- a/include/uapi/linux/bpf.h
+++ b/include/uapi/linux/bpf.h
@@ -42,6 +42,7 @@ 
 #define BPF_JSGE	0x70	/* SGE is signed '>=', GE in x86 */
 #define BPF_JSLT	0xc0	/* SLT is signed, '<' */
 #define BPF_JSLE	0xd0	/* SLE is signed, '<=' */
+#define BPF_JMA		0xe0	/* may_goto */
 #define BPF_CALL	0x80	/* function call */
 #define BPF_EXIT	0x90	/* function return */
 
diff --git a/kernel/bpf/core.c b/kernel/bpf/core.c
index 71c459a51d9e..ba6101447b49 100644
--- a/kernel/bpf/core.c
+++ b/kernel/bpf/core.c
@@ -1675,6 +1675,7 @@  bool bpf_opcode_in_insntable(u8 code)
 		[BPF_LD | BPF_IND | BPF_B] = true,
 		[BPF_LD | BPF_IND | BPF_H] = true,
 		[BPF_LD | BPF_IND | BPF_W] = true,
+		[BPF_JMP | BPF_JMA] = true,
 	};
 #undef BPF_INSN_3_TBL
 #undef BPF_INSN_2_TBL
diff --git a/kernel/bpf/disasm.c b/kernel/bpf/disasm.c
index 49940c26a227..598cd38af84c 100644
--- a/kernel/bpf/disasm.c
+++ b/kernel/bpf/disasm.c
@@ -322,6 +322,9 @@  void print_bpf_insn(const struct bpf_insn_cbs *cbs,
 		} else if (insn->code == (BPF_JMP | BPF_JA)) {
 			verbose(cbs->private_data, "(%02x) goto pc%+d\n",
 				insn->code, insn->off);
+		} else if (insn->code == (BPF_JMP | BPF_JMA)) {
+			verbose(cbs->private_data, "(%02x) may_goto pc%+d\n",
+				insn->code, insn->off);
 		} else if (insn->code == (BPF_JMP32 | BPF_JA)) {
 			verbose(cbs->private_data, "(%02x) gotol pc%+d\n",
 				insn->code, insn->imm);
diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c
index 4dd84e13bbfe..226bb65f9c2c 100644
--- a/kernel/bpf/verifier.c
+++ b/kernel/bpf/verifier.c
@@ -1429,6 +1429,7 @@  static int copy_verifier_state(struct bpf_verifier_state *dst_state,
 	dst_state->dfs_depth = src->dfs_depth;
 	dst_state->callback_unroll_depth = src->callback_unroll_depth;
 	dst_state->used_as_loop_entry = src->used_as_loop_entry;
+	dst_state->may_goto_cnt = src->may_goto_cnt;
 	for (i = 0; i <= src->curframe; i++) {
 		dst = dst_state->frame[i];
 		if (!dst) {
@@ -7880,6 +7881,11 @@  static int widen_imprecise_scalars(struct bpf_verifier_env *env,
 	return 0;
 }
 
+static bool is_may_goto_insn(struct bpf_verifier_env *env, int insn_idx)
+{
+	return env->prog->insnsi[insn_idx].code == (BPF_JMP | BPF_JMA);
+}
+
 /* process_iter_next_call() is called when verifier gets to iterator's next
  * "method" (e.g., bpf_iter_num_next() for numbers iterator) call. We'll refer
  * to it as just "iter_next()" in comments below.
@@ -14871,11 +14877,35 @@  static int check_cond_jmp_op(struct bpf_verifier_env *env,
 	int err;
 
 	/* Only conditional jumps are expected to reach here. */
-	if (opcode == BPF_JA || opcode > BPF_JSLE) {
+	if (opcode == BPF_JA || opcode > BPF_JMA) {
 		verbose(env, "invalid BPF_JMP/JMP32 opcode %x\n", opcode);
 		return -EINVAL;
 	}
 
+	if (opcode == BPF_JMA) {
+		struct bpf_verifier_state *cur_st = env->cur_state, *queued_st, *prev_st;
+		int idx = *insn_idx;
+
+		if (insn->code != (BPF_JMP | BPF_JMA) ||
+		    insn->src_reg || insn->dst_reg || insn->imm || insn->off == 0) {
+			verbose(env, "invalid may_goto off %d imm %d\n",
+				insn->off, insn->imm);
+			return -EINVAL;
+		}
+		prev_st = find_prev_entry(env, cur_st->parent, idx);
+
+		/* branch out 'fallthrough' insn as a new state to explore */
+		queued_st = push_stack(env, idx + 1, idx, false);
+		if (!queued_st)
+			return -ENOMEM;
+
+		queued_st->may_goto_cnt++;
+		if (prev_st)
+			widen_imprecise_scalars(env, prev_st, queued_st);
+		*insn_idx += insn->off;
+		return 0;
+	}
+
 	/* check src2 operand */
 	err = check_reg_arg(env, insn->dst_reg, SRC_OP);
 	if (err)
@@ -15659,6 +15689,8 @@  static int visit_insn(int t, struct bpf_verifier_env *env)
 	default:
 		/* conditional jump with two edges */
 		mark_prune_point(env, t);
+		if (insn->code == (BPF_JMP | BPF_JMA))
+			mark_force_checkpoint(env, t);
 
 		ret = push_insn(t, t + 1, FALLTHROUGH, env);
 		if (ret)
@@ -17135,6 +17167,13 @@  static int is_state_visited(struct bpf_verifier_env *env, int insn_idx)
 				}
 				goto skip_inf_loop_check;
 			}
+			if (is_may_goto_insn(env, insn_idx)) {
+				if (states_equal(env, &sl->state, cur, true)) {
+					update_loop_entry(cur, &sl->state);
+					goto hit;
+				}
+				goto skip_inf_loop_check;
+			}
 			if (calls_callback(env, insn_idx)) {
 				if (states_equal(env, &sl->state, cur, true))
 					goto hit;
@@ -17144,6 +17183,7 @@  static int is_state_visited(struct bpf_verifier_env *env, int insn_idx)
 			if (states_maybe_looping(&sl->state, cur) &&
 			    states_equal(env, &sl->state, cur, true) &&
 			    !iter_active_depths_differ(&sl->state, cur) &&
+			    sl->state.may_goto_cnt == cur->may_goto_cnt &&
 			    sl->state.callback_unroll_depth == cur->callback_unroll_depth) {
 				verbose_linfo(env, insn_idx, "; ");
 				verbose(env, "infinite loop detected at insn %d\n", insn_idx);
@@ -19408,7 +19448,10 @@  static int do_misc_fixups(struct bpf_verifier_env *env)
 	struct bpf_insn insn_buf[16];
 	struct bpf_prog *new_prog;
 	struct bpf_map *map_ptr;
-	int i, ret, cnt, delta = 0;
+	int i, ret, cnt, delta = 0, cur_subprog = 0;
+	struct bpf_subprog_info *subprogs = env->subprog_info;
+	u16 stack_depth = subprogs[cur_subprog].stack_depth;
+	u16 stack_depth_extra = 0;
 
 	if (env->seen_exception && !env->exception_callback_subprog) {
 		struct bpf_insn patch[] = {
@@ -19428,7 +19471,7 @@  static int do_misc_fixups(struct bpf_verifier_env *env)
 		mark_subprog_exc_cb(env, env->exception_callback_subprog);
 	}
 
-	for (i = 0; i < insn_cnt; i++, insn++) {
+	for (i = 0; i < insn_cnt;) {
 		/* Make divide-by-zero exceptions impossible. */
 		if (insn->code == (BPF_ALU64 | BPF_MOD | BPF_X) ||
 		    insn->code == (BPF_ALU64 | BPF_DIV | BPF_X) ||
@@ -19467,7 +19510,7 @@  static int do_misc_fixups(struct bpf_verifier_env *env)
 			delta    += cnt - 1;
 			env->prog = prog = new_prog;
 			insn      = new_prog->insnsi + i + delta;
-			continue;
+			goto next_insn;
 		}
 
 		/* Implement LD_ABS and LD_IND with a rewrite, if supported by the program type. */
@@ -19487,7 +19530,7 @@  static int do_misc_fixups(struct bpf_verifier_env *env)
 			delta    += cnt - 1;
 			env->prog = prog = new_prog;
 			insn      = new_prog->insnsi + i + delta;
-			continue;
+			goto next_insn;
 		}
 
 		/* Rewrite pointer arithmetic to mitigate speculation attacks. */
@@ -19502,7 +19545,7 @@  static int do_misc_fixups(struct bpf_verifier_env *env)
 			aux = &env->insn_aux_data[i + delta];
 			if (!aux->alu_state ||
 			    aux->alu_state == BPF_ALU_NON_POINTER)
-				continue;
+				goto next_insn;
 
 			isneg = aux->alu_state & BPF_ALU_NEG_VALUE;
 			issrc = (aux->alu_state & BPF_ALU_SANITIZE) ==
@@ -19540,19 +19583,39 @@  static int do_misc_fixups(struct bpf_verifier_env *env)
 			delta    += cnt - 1;
 			env->prog = prog = new_prog;
 			insn      = new_prog->insnsi + i + delta;
-			continue;
+			goto next_insn;
+		}
+
+		if (insn->code == (BPF_JMP | BPF_JMA)) {
+			int stack_off = -stack_depth - 8;
+
+			stack_depth_extra = 8;
+			insn_buf[0] = BPF_LDX_MEM(BPF_DW, BPF_REG_AX, BPF_REG_10, stack_off);
+			insn_buf[1] = BPF_JMP_IMM(BPF_JEQ, BPF_REG_AX, 0, insn->off + 2);
+			insn_buf[2] = BPF_ALU64_IMM(BPF_SUB, BPF_REG_AX, 1);
+			insn_buf[3] = BPF_STX_MEM(BPF_DW, BPF_REG_10, BPF_REG_AX, stack_off);
+			cnt = 4;
+
+			new_prog = bpf_patch_insn_data(env, i + delta, insn_buf, cnt);
+			if (!new_prog)
+				return -ENOMEM;
+
+			delta += cnt - 1;
+			env->prog = prog = new_prog;
+			insn = new_prog->insnsi + i + delta;
+			goto next_insn;
 		}
 
 		if (insn->code != (BPF_JMP | BPF_CALL))
-			continue;
+			goto next_insn;
 		if (insn->src_reg == BPF_PSEUDO_CALL)
-			continue;
+			goto next_insn;
 		if (insn->src_reg == BPF_PSEUDO_KFUNC_CALL) {
 			ret = fixup_kfunc_call(env, insn, insn_buf, i + delta, &cnt);
 			if (ret)
 				return ret;
 			if (cnt == 0)
-				continue;
+				goto next_insn;
 
 			new_prog = bpf_patch_insn_data(env, i + delta, insn_buf, cnt);
 			if (!new_prog)
@@ -19561,7 +19624,7 @@  static int do_misc_fixups(struct bpf_verifier_env *env)
 			delta	 += cnt - 1;
 			env->prog = prog = new_prog;
 			insn	  = new_prog->insnsi + i + delta;
-			continue;
+			goto next_insn;
 		}
 
 		if (insn->imm == BPF_FUNC_get_route_realm)
@@ -19609,11 +19672,11 @@  static int do_misc_fixups(struct bpf_verifier_env *env)
 				}
 
 				insn->imm = ret + 1;
-				continue;
+				goto next_insn;
 			}
 
 			if (!bpf_map_ptr_unpriv(aux))
-				continue;
+				goto next_insn;
 
 			/* instead of changing every JIT dealing with tail_call
 			 * emit two extra insns:
@@ -19642,7 +19705,7 @@  static int do_misc_fixups(struct bpf_verifier_env *env)
 			delta    += cnt - 1;
 			env->prog = prog = new_prog;
 			insn      = new_prog->insnsi + i + delta;
-			continue;
+			goto next_insn;
 		}
 
 		if (insn->imm == BPF_FUNC_timer_set_callback) {
@@ -19754,7 +19817,7 @@  static int do_misc_fixups(struct bpf_verifier_env *env)
 				delta    += cnt - 1;
 				env->prog = prog = new_prog;
 				insn      = new_prog->insnsi + i + delta;
-				continue;
+				goto next_insn;
 			}
 
 			BUILD_BUG_ON(!__same_type(ops->map_lookup_elem,
@@ -19785,31 +19848,31 @@  static int do_misc_fixups(struct bpf_verifier_env *env)
 			switch (insn->imm) {
 			case BPF_FUNC_map_lookup_elem:
 				insn->imm = BPF_CALL_IMM(ops->map_lookup_elem);
-				continue;
+				goto next_insn;
 			case BPF_FUNC_map_update_elem:
 				insn->imm = BPF_CALL_IMM(ops->map_update_elem);
-				continue;
+				goto next_insn;
 			case BPF_FUNC_map_delete_elem:
 				insn->imm = BPF_CALL_IMM(ops->map_delete_elem);
-				continue;
+				goto next_insn;
 			case BPF_FUNC_map_push_elem:
 				insn->imm = BPF_CALL_IMM(ops->map_push_elem);
-				continue;
+				goto next_insn;
 			case BPF_FUNC_map_pop_elem:
 				insn->imm = BPF_CALL_IMM(ops->map_pop_elem);
-				continue;
+				goto next_insn;
 			case BPF_FUNC_map_peek_elem:
 				insn->imm = BPF_CALL_IMM(ops->map_peek_elem);
-				continue;
+				goto next_insn;
 			case BPF_FUNC_redirect_map:
 				insn->imm = BPF_CALL_IMM(ops->map_redirect);
-				continue;
+				goto next_insn;
 			case BPF_FUNC_for_each_map_elem:
 				insn->imm = BPF_CALL_IMM(ops->map_for_each_callback);
-				continue;
+				goto next_insn;
 			case BPF_FUNC_map_lookup_percpu_elem:
 				insn->imm = BPF_CALL_IMM(ops->map_lookup_percpu_elem);
-				continue;
+				goto next_insn;
 			}
 
 			goto patch_call_imm;
@@ -19837,7 +19900,7 @@  static int do_misc_fixups(struct bpf_verifier_env *env)
 			delta    += cnt - 1;
 			env->prog = prog = new_prog;
 			insn      = new_prog->insnsi + i + delta;
-			continue;
+			goto next_insn;
 		}
 
 		/* Implement bpf_get_func_arg inline. */
@@ -19862,7 +19925,7 @@  static int do_misc_fixups(struct bpf_verifier_env *env)
 			delta    += cnt - 1;
 			env->prog = prog = new_prog;
 			insn      = new_prog->insnsi + i + delta;
-			continue;
+			goto next_insn;
 		}
 
 		/* Implement bpf_get_func_ret inline. */
@@ -19890,7 +19953,7 @@  static int do_misc_fixups(struct bpf_verifier_env *env)
 			delta    += cnt - 1;
 			env->prog = prog = new_prog;
 			insn      = new_prog->insnsi + i + delta;
-			continue;
+			goto next_insn;
 		}
 
 		/* Implement get_func_arg_cnt inline. */
@@ -19905,7 +19968,7 @@  static int do_misc_fixups(struct bpf_verifier_env *env)
 
 			env->prog = prog = new_prog;
 			insn      = new_prog->insnsi + i + delta;
-			continue;
+			goto next_insn;
 		}
 
 		/* Implement bpf_get_func_ip inline. */
@@ -19920,7 +19983,7 @@  static int do_misc_fixups(struct bpf_verifier_env *env)
 
 			env->prog = prog = new_prog;
 			insn      = new_prog->insnsi + i + delta;
-			continue;
+			goto next_insn;
 		}
 
 		/* Implement bpf_kptr_xchg inline */
@@ -19938,7 +20001,7 @@  static int do_misc_fixups(struct bpf_verifier_env *env)
 			delta    += cnt - 1;
 			env->prog = prog = new_prog;
 			insn      = new_prog->insnsi + i + delta;
-			continue;
+			goto next_insn;
 		}
 patch_call_imm:
 		fn = env->ops->get_func_proto(insn->imm, env->prog);
@@ -19952,6 +20015,39 @@  static int do_misc_fixups(struct bpf_verifier_env *env)
 			return -EFAULT;
 		}
 		insn->imm = fn->func - __bpf_call_base;
+next_insn:
+		if (subprogs[cur_subprog + 1].start == i + delta + 1) {
+			subprogs[cur_subprog].stack_depth += stack_depth_extra;
+			subprogs[cur_subprog].stack_extra = stack_depth_extra;
+			cur_subprog++;
+			stack_depth = subprogs[cur_subprog].stack_depth;
+			stack_depth_extra = 0;
+		}
+		i++; insn++;
+	}
+
+	env->prog->aux->stack_depth = subprogs[0].stack_depth;
+	for (i = 0; i < env->subprog_cnt; i++) {
+		int subprog_start = subprogs[i].start, j;
+		int stack_slots = subprogs[i].stack_extra / 8;
+
+		if (stack_slots >= ARRAY_SIZE(insn_buf)) {
+			verbose(env, "verifier bug: stack_extra is too large\n");
+			return -EFAULT;
+		}
+
+		/* Add insns to subprog prologue to init extra stack */
+		for (j = 0; j < stack_slots; j++)
+			insn_buf[j] = BPF_ST_MEM(BPF_DW, BPF_REG_FP,
+						 -subprogs[i].stack_depth + j * 8, BPF_MAX_LOOPS);
+		if (j) {
+			insn_buf[j] = env->prog->insnsi[subprog_start];
+
+			new_prog = bpf_patch_insn_data(env, subprog_start, insn_buf, j + 1);
+			if (!new_prog)
+				return -ENOMEM;
+			env->prog = prog = new_prog;
+		}
 	}
 
 	/* Since poke tab is now finalized, publish aux to tracker. */
diff --git a/tools/include/uapi/linux/bpf.h b/tools/include/uapi/linux/bpf.h
index a241f407c234..932ffef0dc88 100644
--- a/tools/include/uapi/linux/bpf.h
+++ b/tools/include/uapi/linux/bpf.h
@@ -42,6 +42,7 @@ 
 #define BPF_JSGE	0x70	/* SGE is signed '>=', GE in x86 */
 #define BPF_JSLT	0xc0	/* SLT is signed, '<' */
 #define BPF_JSLE	0xd0	/* SLE is signed, '<=' */
+#define BPF_JMA		0xe0	/* may_goto */
 #define BPF_CALL	0x80	/* function call */
 #define BPF_EXIT	0x90	/* function return */