[bpf-next,2/8] bpf: add bpf_for_each_map_elem() helper

Message ID 20210204234829.1629159-1-yhs@fb.com (mailing list archive)
State Superseded
Delegated to: BPF
Series: bpf: add bpf_for_each_map_elem() helper

Checks

Context Check Description
netdev/cover_letter success
netdev/fixes_present success
netdev/patch_count success
netdev/tree_selection success Clearly marked for bpf-next
netdev/subject_prefix success
netdev/cc_maintainers warning 9 maintainers not CCed: quentin@isovalent.com netdev@vger.kernel.org songliubraving@fb.com andrii@kernel.org rostedt@goodmis.org mingo@redhat.com kpsingh@kernel.org john.fastabend@gmail.com kafai@fb.com
netdev/source_inline success Was 0 now: 0
netdev/verify_signedoff success
netdev/module_param success Was 0 now: 0
netdev/build_32bit success Errors and warnings before: 12238 this patch: 12238
netdev/kdoc success Errors and warnings before: 0 this patch: 0
netdev/verify_fixes success
netdev/checkpatch warning
    CHECK: Please don't use multiple blank lines
    WARNING: ENOTSUPP is not a SUSV4 error code, prefer EOPNOTSUPP
    WARNING: line length of 100 exceeds 80 columns
    WARNING: line length of 81 exceeds 80 columns
    WARNING: line length of 82 exceeds 80 columns
    WARNING: line length of 83 exceeds 80 columns
    WARNING: line length of 85 exceeds 80 columns
    WARNING: line length of 86 exceeds 80 columns
    WARNING: line length of 87 exceeds 80 columns
    WARNING: line length of 89 exceeds 80 columns
    WARNING: line length of 91 exceeds 80 columns
    WARNING: line length of 95 exceeds 80 columns
netdev/build_allmodconfig_warn success Errors and warnings before: 12886 this patch: 12886
netdev/header_inline success
netdev/stable success Stable not CCed

Commit Message

Yonghong Song Feb. 4, 2021, 11:48 p.m. UTC
The bpf_for_each_map_elem() helper is introduced, which
iterates over all map elements with a callback function. The
helper signature looks like
  long bpf_for_each_map_elem(map, callback_fn, callback_ctx, flags)
and the callback_fn will be called for each map element. For example,
for a hashmap, the callback signature may look like
  long callback_fn(map, key, val, callback_ctx)

There are two known use cases for this. One is from upstream ([1]), where
a for_each_map_elem helper may help implement a timeout mechanism
in a more generic way. Another is from our internal discussion
of a firewall use case where a map contains all the rules. The packet
data can be compared against all these rules to decide whether to
allow or deny the packet.

For array maps, users can already use a bounded loop to traverse
elements. Using this helper avoids the need for a bounded loop. For other
types of maps (e.g., hash maps) where a bounded loop is hard or
impossible to use, this helper provides a convenient way to
operate on all elements.

For callback_fn, besides the map and map element, a callback_ctx,
allocated on the caller's stack, is also passed to the callback
function. This callback_ctx argument can provide additional
input and allows writing to the caller's stack for output.

If the callback_fn returns 0, the helper will iterate to the next
element if available. If the callback_fn returns 1, the helper
will stop iterating and return to the bpf program. Other return
values are not used for now.
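
For illustration, a minimal usage sketch (the map, program section, and
callback names are hypothetical; the callback follows the hashmap
signature above, and <linux/bpf.h> plus <bpf/bpf_helpers.h> are assumed
to be included):

  struct {
          __uint(type, BPF_MAP_TYPE_HASH);
          __uint(max_entries, 64);
          __type(key, __u32);
          __type(value, __u64);
  } hashmap SEC(".maps");

  struct callback_ctx {
          __u64 output;
  };

  static long check_elem(struct bpf_map *map, __u32 *key, __u64 *val,
                         struct callback_ctx *data)
  {
          /* accumulate into the caller-stack callback_ctx */
          data->output += *val;

          /* returning 1 stops the iteration, 0 continues */
          return *val > 100 ? 1 : 0;
  }

  SEC("tc")
  int sum_values(struct __sk_buff *skb)
  {
          struct callback_ctx data = { .output = 0 };

          bpf_for_each_map_elem(&hashmap, check_elem, &data, 0);
          /* data.output now holds the (possibly partial) sum */
          return 0;
  }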

Currently, this helper is only available with JIT. It is possible
to make it work with the interpreter with some effort, but I leave
that as future work.

[1]: https://lore.kernel.org/bpf/20210122205415.113822-1-xiyou.wangcong@gmail.com/

Signed-off-by: Yonghong Song <yhs@fb.com>
---
 include/linux/bpf.h            |  14 ++
 include/linux/bpf_verifier.h   |   3 +
 include/uapi/linux/bpf.h       |  28 ++++
 kernel/bpf/bpf_iter.c          |  16 +++
 kernel/bpf/helpers.c           |   2 +
 kernel/bpf/verifier.c          | 251 ++++++++++++++++++++++++++++++---
 kernel/trace/bpf_trace.c       |   2 +
 tools/include/uapi/linux/bpf.h |  28 ++++
 8 files changed, 328 insertions(+), 16 deletions(-)

Comments

Alexei Starovoitov Feb. 5, 2021, 5:49 a.m. UTC | #1
On Thu, Feb 04, 2021 at 03:48:29PM -0800, Yonghong Song wrote:
> The bpf_for_each_map_elem() helper is introduced which
> iterates all map elements with a callback function. The
> helper signature looks like
>   long bpf_for_each_map_elem(map, callback_fn, callback_ctx, flags)
> and for each map element, the callback_fn will be called. For example,
> like hashmap, the callback signature may look like
>   long callback_fn(map, key, val, callback_ctx)
> 
> There are two known use cases for this. One is from upstream ([1]) where
> a for_each_map_elem helper may help implement a timeout mechanism
> in a more generic way. Another is from our internal discussion
> for a firewall use case where a map contains all the rules. The packet
> data can be compared to all these rules to decide allow or deny
> the packet.
> 
> For array maps, users can already use a bounded loop to traverse
> elements. Using this helper can avoid using bounded loop. For other
> type of maps (e.g., hash maps) where bounded loop is hard or
> impossible to use, this helper provides a convenient way to
> operate on all elements.
> 
> For callback_fn, besides map and map element, a callback_ctx,
> allocated on caller stack, is also passed to the callback
> function. This callback_ctx argument can provide additional
> input and allow to write to caller stack for output.

The approach and implementation look great!
Few ideas below:

> +static int check_map_elem_callback(struct bpf_verifier_env *env, int *insn_idx)
> +{
> +	struct bpf_verifier_state *state = env->cur_state;
> +	struct bpf_prog_aux *aux = env->prog->aux;
> +	struct bpf_func_state *caller, *callee;
> +	struct bpf_map *map;
> +	int err, subprog;
> +
> +	if (state->curframe + 1 >= MAX_CALL_FRAMES) {
> +		verbose(env, "the call stack of %d frames is too deep\n",
> +			state->curframe + 2);
> +		return -E2BIG;
> +	}
> +
> +	caller = state->frame[state->curframe];
> +	if (state->frame[state->curframe + 1]) {
> +		verbose(env, "verifier bug. Frame %d already allocated\n",
> +			state->curframe + 1);
> +		return -EFAULT;
> +	}
> +
> +	caller->with_callback_fn = true;
> +
> +	callee = kzalloc(sizeof(*callee), GFP_KERNEL);
> +	if (!callee)
> +		return -ENOMEM;
> +	state->frame[state->curframe + 1] = callee;
> +
> +	/* callee cannot access r0, r6 - r9 for reading and has to write
> +	 * into its own stack before reading from it.
> +	 * callee can read/write into caller's stack
> +	 */
> +	init_func_state(env, callee,
> +			/* remember the callsite, it will be used by bpf_exit */
> +			*insn_idx /* callsite */,
> +			state->curframe + 1 /* frameno within this callchain */,
> +			subprog /* subprog number within this prog */);
> +
> +	/* Transfer references to the callee */
> +	err = transfer_reference_state(callee, caller);
> +	if (err)
> +		return err;
> +
> +	subprog = caller->regs[BPF_REG_2].subprog;
> +	if (aux->func_info && aux->func_info_aux[subprog].linkage != BTF_FUNC_STATIC) {
> +		verbose(env, "callback function R2 not static\n");
> +		return -EINVAL;
> +	}
> +
> +	map = caller->regs[BPF_REG_1].map_ptr;

Take a look at the 'for (i = 0; i < 5; i++) err = check_func_arg(...)' loop and record_func_map().
It stores the map pointer into map_ptr_state and makes sure it's unique,
so that the program doesn't try to pass two different maps into the same 'call insn'.
It can make this function a bit more generic.
There would be no need to hard-code regs[BPF_REG_1];
the code would take it from map_ptr_state.
Also, it will help later with optimizing
  return map->ops->map_for_each_callback(map, callback_fn, callback_ctx, flags);
since the map pointer will be the same, the optimization (that is applied to other map
operations) can be applied to this callback as well.
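
A sketch of that shape (assuming the check_func_arg() loop has already
run for this call and record_func_map() recorded the map, as it does
for other map helpers):

  /* in check_helper_call(), after the check_func_arg() loop */
  map = meta.map_ptr;
  if (!map) {
          /* record_func_map() guarantees a unique, known map here */
          verbose(env, "kernel subsystem misconfigured verifier\n");
          return -EINVAL;
  }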

The regs[BPF_REG_2] handling can be generalized a bit as well.
I think the linkage != BTF_FUNC_STATIC check can be moved to the early check_ld_imm phase.
While here, the check_func_arg() loop can look for the PTR_TO_FUNC type,
remember the subprog in meta (just like map_ptr_state) and ... continues below

> +	if (!map->ops->map_set_for_each_callback_args ||
> +	    !map->ops->map_for_each_callback) {
> +		verbose(env, "callback function not allowed for map R1\n");
> +		return -ENOTSUPP;
> +	}
> +
> +	/* the following is only for hashmap, different maps
> +	 * can have different callback signatures.
> +	 */
> +	err = map->ops->map_set_for_each_callback_args(env, caller, callee);
> +	if (err)
> +		return err;
> +
> +	clear_caller_saved_regs(env, caller->regs);
> +
> +	/* only increment it after check_reg_arg() finished */
> +	state->curframe++;
> +
> +	/* and go analyze first insn of the callee */
> +	*insn_idx = env->subprog_info[subprog].start - 1;
> +
> +	if (env->log.level & BPF_LOG_LEVEL) {
> +		verbose(env, "caller:\n");
> +		print_verifier_state(env, caller);
> +		verbose(env, "callee:\n");
> +		print_verifier_state(env, callee);
> +	}
> +	return 0;
> +}
> +
>  static int prepare_func_exit(struct bpf_verifier_env *env, int *insn_idx)
>  {
>  	struct bpf_verifier_state *state = env->cur_state;
>  	struct bpf_func_state *caller, *callee;
>  	struct bpf_reg_state *r0;
> -	int err;
> +	int i, err;
>  
>  	callee = state->frame[state->curframe];
>  	r0 = &callee->regs[BPF_REG_0];
> @@ -4955,7 +5090,17 @@ static int prepare_func_exit(struct bpf_verifier_env *env, int *insn_idx)
>  	state->curframe--;
>  	caller = state->frame[state->curframe];
>  	/* return to the caller whatever r0 had in the callee */
> -	caller->regs[BPF_REG_0] = *r0;
> +	if (caller->with_callback_fn) {
> +		/* reset caller saved regs, the helper calling callback_fn
> +		 * has RET_INTEGER return types.
> +		 */
> +		for (i = 0; i < CALLER_SAVED_REGS; i++)
> +			mark_reg_not_init(env, caller->regs, caller_saved[i]);
> +		caller->regs[BPF_REG_0].subreg_def = DEF_NOT_SUBREG;
> +		mark_reg_unknown(env, caller->regs, BPF_REG_0);

this part can stay in check_helper_call().

> +	} else {
> +		caller->regs[BPF_REG_0] = *r0;
> +	}
>  
>  	/* Transfer references to the caller */
>  	err = transfer_reference_state(caller, callee);
> @@ -5091,7 +5236,8 @@ static int check_reference_leak(struct bpf_verifier_env *env)
>  	return state->acquired_refs ? -EINVAL : 0;
>  }
>  
> -static int check_helper_call(struct bpf_verifier_env *env, int func_id, int insn_idx)
> +static int check_helper_call(struct bpf_verifier_env *env, int func_id, int *insn_idx,
> +			     bool map_elem_callback)
>  {
>  	const struct bpf_func_proto *fn = NULL;
>  	struct bpf_reg_state *regs;
> @@ -5151,11 +5297,11 @@ static int check_helper_call(struct bpf_verifier_env *env, int func_id, int insn
>  			return err;
>  	}
>  
> -	err = record_func_map(env, &meta, func_id, insn_idx);
> +	err = record_func_map(env, &meta, func_id, *insn_idx);
>  	if (err)
>  		return err;
>  
> -	err = record_func_key(env, &meta, func_id, insn_idx);
> +	err = record_func_key(env, &meta, func_id, *insn_idx);
>  	if (err)
>  		return err;
>  
> @@ -5163,7 +5309,7 @@ static int check_helper_call(struct bpf_verifier_env *env, int func_id, int insn
>  	 * is inferred from register state.
>  	 */
>  	for (i = 0; i < meta.access_size; i++) {
> -		err = check_mem_access(env, insn_idx, meta.regno, i, BPF_B,
> +		err = check_mem_access(env, *insn_idx, meta.regno, i, BPF_B,
>  				       BPF_WRITE, -1, false);
>  		if (err)
>  			return err;
> @@ -5195,6 +5341,11 @@ static int check_helper_call(struct bpf_verifier_env *env, int func_id, int insn
>  		return -EINVAL;
>  	}
>  
> +	if (map_elem_callback) {
> +		env->prog->aux->with_callback_fn = true;
> +		return check_map_elem_callback(env, insn_idx);

Instead of returning early here,
the check_func_arg() loop can look for the PTR_TO_FUNC type,
then allocate the new callee state,
do map_set_for_each_callback_args() here,
and then proceed further.

> +	}
> +
>  	/* reset caller saved regs */
>  	for (i = 0; i < CALLER_SAVED_REGS; i++) {
>  		mark_reg_not_init(env, regs, caller_saved[i]);

Instead of doing this loop in prepare_func_exit(),
this code can just proceed here and clear the caller regs. This loop can stay as-is.
The transfer of caller->callee would have happened already.

Then there are a few lines here that the diff didn't show.
They do regs[BPF_REG_0].subreg_def = DEF_NOT_SUBREG and mark_reg_unknown.
No need to do them in prepare_func_exit().
This function can proceed further, reusing this caller-regs clearing loop and the r0 marking.

Then, before returning from check_helper_call(),
it will do what you have in check_map_elem_callback() and adjust *insn_idx.

At this point the caller would have its regs cleared and r0=undef,
and the callee would have its regs set up the way the map_set_for_each_callback_args
callback meant to set them up.
The only thing prepare_func_exit() would need to do is to make sure that the assignment
caller->regs[BPF_REG_0] = *r0
doesn't happen, since the caller's r0 was already set to undef.
To achieve that, I think it would be a bit cleaner to mark the callee state instead of
the caller state.
So instead of caller->with_callback_fn=true, maybe callee->in_callback_fn=true?
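
A sketch of the resulting split (the in_callback_fn flag is the
suggested rename, not code from this patch):

  /* prepare_func_exit(), after popping the frame */
  if (callee->in_callback_fn) {
          /* caller's regs were cleared and r0 marked unknown in
           * check_helper_call() already; don't copy r0 back
           */
  } else {
          caller->regs[BPF_REG_0] = *r0;
  }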

> @@ -5306,7 +5457,7 @@ static int check_helper_call(struct bpf_verifier_env *env, int func_id, int insn
>  		/* For release_reference() */
>  		regs[BPF_REG_0].ref_obj_id = meta.ref_obj_id;
>  	} else if (is_acquire_function(func_id, meta.map_ptr)) {
> -		int id = acquire_reference_state(env, insn_idx);
> +		int id = acquire_reference_state(env, *insn_idx);
>  
>  		if (id < 0)
>  			return id;
> @@ -5448,6 +5599,14 @@ static int retrieve_ptr_limit(const struct bpf_reg_state *ptr_reg,
>  		else
>  			*ptr_limit = -off;
>  		return 0;
> +	case PTR_TO_MAP_KEY:
> +		if (mask_to_left) {
> +			*ptr_limit = ptr_reg->umax_value + ptr_reg->off;
> +		} else {
> +			off = ptr_reg->smin_value + ptr_reg->off;
> +			*ptr_limit = ptr_reg->map_ptr->key_size - off;
> +		}
> +		return 0;
>  	case PTR_TO_MAP_VALUE:
>  		if (mask_to_left) {
>  			*ptr_limit = ptr_reg->umax_value + ptr_reg->off;
> @@ -5614,6 +5773,7 @@ static int adjust_ptr_min_max_vals(struct bpf_verifier_env *env,
>  		verbose(env, "R%d pointer arithmetic on %s prohibited\n",
>  			dst, reg_type_str[ptr_reg->type]);
>  		return -EACCES;
> +	case PTR_TO_MAP_KEY:
>  	case PTR_TO_MAP_VALUE:
>  		if (!env->allow_ptr_leaks && !known && (smin_val < 0) != (smax_val < 0)) {
>  			verbose(env, "R%d has unknown scalar with mixed signed bounds, pointer arithmetic with it prohibited for !root\n",
> @@ -7818,6 +7978,12 @@ static int check_ld_imm(struct bpf_verifier_env *env, struct bpf_insn *insn)
>  		return 0;
>  	}
>  
> +	if (insn->src_reg == BPF_PSEUDO_FUNC) {
> +		dst_reg->type = PTR_TO_FUNC;
> +		dst_reg->subprog = insn[1].imm;

Like here, the check for linkage==static can happen?
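
Roughly, a sketch of doing the check in check_ld_imm() (assuming
func_info is available at this point):

  if (insn->src_reg == BPF_PSEUDO_FUNC) {
          struct bpf_prog_aux *aux = env->prog->aux;
          u32 subprogno = insn[1].imm;

          if (aux->func_info &&
              aux->func_info_aux[subprogno].linkage != BTF_FUNC_STATIC) {
                  verbose(env, "callback function not static\n");
                  return -EINVAL;
          }
          dst_reg->type = PTR_TO_FUNC;
          dst_reg->subprog = subprogno;
          return 0;
  }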

> +		return 0;
> +	}
> +
>  	map = env->used_maps[aux->map_index];
>  	mark_reg_known_zero(env, regs, insn->dst_reg);
>  	dst_reg->map_ptr = map;
> @@ -8195,9 +8361,23 @@ static int visit_insn(int t, int insn_cnt, struct bpf_verifier_env *env)
>  
>  	/* All non-branch instructions have a single fall-through edge. */
>  	if (BPF_CLASS(insns[t].code) != BPF_JMP &&
> -	    BPF_CLASS(insns[t].code) != BPF_JMP32)
> +	    BPF_CLASS(insns[t].code) != BPF_JMP32 &&
> +	    !bpf_pseudo_func(insns + t))
>  		return push_insn(t, t + 1, FALLTHROUGH, env, false);
>  
> +	if (bpf_pseudo_func(insns + t)) {
> +		ret = push_insn(t, t + 1, FALLTHROUGH, env, false);
> +		if (ret)
> +			return ret;
> +
> +		if (t + 1 < insn_cnt)
> +			init_explored_state(env, t + 1);
> +		init_explored_state(env, t);
> +		ret = push_insn(t, t + insns[t].imm + 1, BRANCH,
> +				env, false);
> +		return ret;
> +	}
> +
>  	switch (BPF_OP(insns[t].code)) {
>  	case BPF_EXIT:
>  		return DONE_EXPLORING;
> @@ -8819,6 +8999,7 @@ static bool regsafe(struct bpf_reg_state *rold, struct bpf_reg_state *rcur,
>  			 */
>  			return false;
>  		}
> +	case PTR_TO_MAP_KEY:
>  	case PTR_TO_MAP_VALUE:
>  		/* If the new min/max/var_off satisfy the old ones and
>  		 * everything else matches, we are OK.
> @@ -9646,6 +9827,8 @@ static int do_check(struct bpf_verifier_env *env)
>  
>  			env->jmps_processed++;
>  			if (opcode == BPF_CALL) {
> +				bool map_elem_callback;
> +
>  				if (BPF_SRC(insn->code) != BPF_K ||
>  				    insn->off != 0 ||
>  				    (insn->src_reg != BPF_REG_0 &&
> @@ -9662,13 +9845,15 @@ static int do_check(struct bpf_verifier_env *env)
>  					verbose(env, "function calls are not allowed while holding a lock\n");
>  					return -EINVAL;
>  				}
> +				map_elem_callback = insn->src_reg != BPF_PSEUDO_CALL &&
> +						   insn->imm == BPF_FUNC_for_each_map_elem;
>  				if (insn->src_reg == BPF_PSEUDO_CALL)
>  					err = check_func_call(env, insn, &env->insn_idx);
>  				else
> -					err = check_helper_call(env, insn->imm, env->insn_idx);
> +					err = check_helper_call(env, insn->imm, &env->insn_idx,
> +								map_elem_callback);

then hopefully this extra 'map_elem_callback' boolean won't be needed.
Only env->insn_idx into &env->insn_idx.
In that sense check_helper_call will become a superset of check_func_call.
Maybe some code between them can be shared too.
Beyond bpf_for_each_map_elem() helper other helpers might use PTR_TO_FUNC.
I hope with this approach all of them will be handled a bit more generically.

>  				if (err)
>  					return err;
> -
>  			} else if (opcode == BPF_JA) {
>  				if (BPF_SRC(insn->code) != BPF_K ||
>  				    insn->imm != 0 ||
> @@ -10090,6 +10275,12 @@ static int resolve_pseudo_ldimm64(struct bpf_verifier_env *env)
>  				goto next_insn;
>  			}
>  
> +			if (insn[0].src_reg == BPF_PSEUDO_FUNC) {
> +				aux = &env->insn_aux_data[i];
> +				aux->ptr_type = PTR_TO_FUNC;
> +				goto next_insn;
> +			}
> +
>  			/* In final convert_pseudo_ld_imm64() step, this is
>  			 * converted into regular 64-bit imm load insn.
>  			 */
> @@ -10222,9 +10413,13 @@ static void convert_pseudo_ld_imm64(struct bpf_verifier_env *env)
>  	int insn_cnt = env->prog->len;
>  	int i;
>  
> -	for (i = 0; i < insn_cnt; i++, insn++)
> -		if (insn->code == (BPF_LD | BPF_IMM | BPF_DW))
> -			insn->src_reg = 0;
> +	for (i = 0; i < insn_cnt; i++, insn++) {
> +		if (insn->code != (BPF_LD | BPF_IMM | BPF_DW))
> +			continue;
> +		if (insn->src_reg == BPF_PSEUDO_FUNC)
> +			continue;
> +		insn->src_reg = 0;
> +	}
>  }
>  
>  /* single env->prog->insni[off] instruction was replaced with the range
> @@ -10846,6 +11041,12 @@ static int jit_subprogs(struct bpf_verifier_env *env)
>  		return 0;
>  
>  	for (i = 0, insn = prog->insnsi; i < prog->len; i++, insn++) {
> +		if (bpf_pseudo_func(insn)) {
> +			env->insn_aux_data[i].call_imm = insn->imm;
> +			/* subprog is encoded in insn[1].imm */
> +			continue;
> +		}
> +
>  		if (!bpf_pseudo_call(insn))
>  			continue;
>  		/* Upon error here we cannot fall back to interpreter but
> @@ -10975,6 +11176,12 @@ static int jit_subprogs(struct bpf_verifier_env *env)
>  	for (i = 0; i < env->subprog_cnt; i++) {
>  		insn = func[i]->insnsi;
>  		for (j = 0; j < func[i]->len; j++, insn++) {
> +			if (bpf_pseudo_func(insn)) {
> +				subprog = insn[1].imm;
> +				insn[0].imm = (u32)(long)func[subprog]->bpf_func;
> +				insn[1].imm = ((u64)(long)func[subprog]->bpf_func) >> 32;
> +				continue;
> +			}
>  			if (!bpf_pseudo_call(insn))
>  				continue;
>  			subprog = insn->off;
> @@ -11020,6 +11227,11 @@ static int jit_subprogs(struct bpf_verifier_env *env)
>  	 * later look the same as if they were interpreted only.
>  	 */
>  	for (i = 0, insn = prog->insnsi; i < prog->len; i++, insn++) {
> +		if (bpf_pseudo_func(insn)) {
> +			insn[0].imm = env->insn_aux_data[i].call_imm;
> +			insn[1].imm = find_subprog(env, i + insn[0].imm + 1);
> +			continue;
> +		}
>  		if (!bpf_pseudo_call(insn))
>  			continue;
>  		insn->off = env->insn_aux_data[i].call_imm;
> @@ -11083,6 +11295,13 @@ static int fixup_call_args(struct bpf_verifier_env *env)
>  		verbose(env, "tail_calls are not allowed in non-JITed programs with bpf-to-bpf calls\n");
>  		return -EINVAL;
>  	}
> +	if (env->subprog_cnt > 1 && env->prog->aux->with_callback_fn) {

Does this bool really need to be be part of 'aux'?
There is a loop below that does if (!bpf_pseudo_call
to fixup insns for the interpreter.
May be add if (bpf_pseudo_func()) { return callbacks are not allowed in non-JITed }
to the loop below as well?
It's a trade off between memory and few extra insn.
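
That could look roughly like the following, reusing the existing
interpreter fixup loop in fixup_call_args() instead of the prog->aux
flag:

  for (i = 0; i < prog->len; i++, insn++) {
          if (bpf_pseudo_func(insn)) {
                  /* interpreter doesn't support callbacks yet */
                  verbose(env, "callbacks are not allowed in non-JITed programs\n");
                  return -EINVAL;
          }
          if (!bpf_pseudo_call(insn))
                  continue;
          /* existing pseudo-call fixup continues here */
  }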

> +		/* When JIT fails the progs with callback calls
> +		 * have to be rejected, since interpreter doesn't support them yet.
> +		 */
> +		verbose(env, "callbacks are not allowed in non-JITed programs\n");
> +		return -EINVAL;
> +	}
>  	for (i = 0; i < prog->len; i++, insn++) {
>  		if (!bpf_pseudo_call(insn))
>  			continue;

to this loop.

Thanks!
Yonghong Song Feb. 5, 2021, 5:39 p.m. UTC | #2
On 2/4/21 9:49 PM, Alexei Starovoitov wrote:
> On Thu, Feb 04, 2021 at 03:48:29PM -0800, Yonghong Song wrote:
>> The bpf_for_each_map_elem() helper is introduced which
>> iterates all map elements with a callback function. The
>> helper signature looks like
>>    long bpf_for_each_map_elem(map, callback_fn, callback_ctx, flags)
>> and for each map element, the callback_fn will be called. For example,
>> like hashmap, the callback signature may look like
>>    long callback_fn(map, key, val, callback_ctx)
>>
>> There are two known use cases for this. One is from upstream ([1]) where
>> a for_each_map_elem helper may help implement a timeout mechanism
>> in a more generic way. Another is from our internal discussion
>> for a firewall use case where a map contains all the rules. The packet
>> data can be compared to all these rules to decide allow or deny
>> the packet.
>>
>> For array maps, users can already use a bounded loop to traverse
>> elements. Using this helper can avoid using bounded loop. For other
>> type of maps (e.g., hash maps) where bounded loop is hard or
>> impossible to use, this helper provides a convenient way to
>> operate on all elements.
>>
>> For callback_fn, besides map and map element, a callback_ctx,
>> allocated on caller stack, is also passed to the callback
>> function. This callback_ctx argument can provide additional
>> input and allow to write to caller stack for output.
> 
> The approach and implementation look great!
> Few ideas below:
> 
>> +static int check_map_elem_callback(struct bpf_verifier_env *env, int *insn_idx)
>> +{
>> +	struct bpf_verifier_state *state = env->cur_state;
>> +	struct bpf_prog_aux *aux = env->prog->aux;
>> +	struct bpf_func_state *caller, *callee;
>> +	struct bpf_map *map;
>> +	int err, subprog;
>> +
>> +	if (state->curframe + 1 >= MAX_CALL_FRAMES) {
>> +		verbose(env, "the call stack of %d frames is too deep\n",
>> +			state->curframe + 2);
>> +		return -E2BIG;
>> +	}
>> +
>> +	caller = state->frame[state->curframe];
>> +	if (state->frame[state->curframe + 1]) {
>> +		verbose(env, "verifier bug. Frame %d already allocated\n",
>> +			state->curframe + 1);
>> +		return -EFAULT;
>> +	}
>> +
>> +	caller->with_callback_fn = true;
>> +
>> +	callee = kzalloc(sizeof(*callee), GFP_KERNEL);
>> +	if (!callee)
>> +		return -ENOMEM;
>> +	state->frame[state->curframe + 1] = callee;
>> +
>> +	/* callee cannot access r0, r6 - r9 for reading and has to write
>> +	 * into its own stack before reading from it.
>> +	 * callee can read/write into caller's stack
>> +	 */
>> +	init_func_state(env, callee,
>> +			/* remember the callsite, it will be used by bpf_exit */
>> +			*insn_idx /* callsite */,
>> +			state->curframe + 1 /* frameno within this callchain */,
>> +			subprog /* subprog number within this prog */);
>> +
>> +	/* Transfer references to the callee */
>> +	err = transfer_reference_state(callee, caller);
>> +	if (err)
>> +		return err;
>> +
>> +	subprog = caller->regs[BPF_REG_2].subprog;
>> +	if (aux->func_info && aux->func_info_aux[subprog].linkage != BTF_FUNC_STATIC) {
>> +		verbose(env, "callback function R2 not static\n");
>> +		return -EINVAL;
>> +	}
>> +
>> +	map = caller->regs[BPF_REG_1].map_ptr;
> 
> Take a look at for (i = 0; i < 5; i++)  err = check_func_arg loop and record_func_map.
> It stores the map pointer into map_ptr_state and makes sure it's unique,
> so that program doesn't try to pass two different maps into the same 'call insn'.
> It can make this function a bit more generic.
> There would be no need to hard code regs[BPF_REG_1].
> The code would take it from map_ptr_state.
> Also it will help later with optimizing
>    return map->ops->map_for_each_callback(map, callback_fn, callback_ctx, flags);
> since the map pointer will be the same the optimization (that is applied to other map
> operations) can be applied for this callback as well.

Sounds good. Will try this approach in the next revision.

> 
> The regs[BPF_REG_2] can be generalized a bit as well.
> It think linkage != BTF_FUNC_STATIC can be moved to early check_ld_imm phase.
> While here the check_func_arg() loop can look for PTR_TO_FUNC type,
> remeber the subprog into meta (just like map_ptr_state) and ... continues below

Yes, PTR_TO_FUNC might be used for future helpers like bpf_mod_timer()
or bpf_for_each_task() etc., so making it more general does make sense.

> 
>> +	if (!map->ops->map_set_for_each_callback_args ||
>> +	    !map->ops->map_for_each_callback) {
>> +		verbose(env, "callback function not allowed for map R1\n");
>> +		return -ENOTSUPP;
>> +	}
>> +
>> +	/* the following is only for hashmap, different maps
>> +	 * can have different callback signatures.
>> +	 */
>> +	err = map->ops->map_set_for_each_callback_args(env, caller, callee);
>> +	if (err)
>> +		return err;
>> +
>> +	clear_caller_saved_regs(env, caller->regs);
>> +
>> +	/* only increment it after check_reg_arg() finished */
>> +	state->curframe++;
>> +
>> +	/* and go analyze first insn of the callee */
>> +	*insn_idx = env->subprog_info[subprog].start - 1;
>> +
>> +	if (env->log.level & BPF_LOG_LEVEL) {
>> +		verbose(env, "caller:\n");
>> +		print_verifier_state(env, caller);
>> +		verbose(env, "callee:\n");
>> +		print_verifier_state(env, callee);
>> +	}
>> +	return 0;
>> +}
>> +
>>   static int prepare_func_exit(struct bpf_verifier_env *env, int *insn_idx)
>>   {
>>   	struct bpf_verifier_state *state = env->cur_state;
>>   	struct bpf_func_state *caller, *callee;
>>   	struct bpf_reg_state *r0;
>> -	int err;
>> +	int i, err;
>>   
>>   	callee = state->frame[state->curframe];
>>   	r0 = &callee->regs[BPF_REG_0];
>> @@ -4955,7 +5090,17 @@ static int prepare_func_exit(struct bpf_verifier_env *env, int *insn_idx)
>>   	state->curframe--;
>>   	caller = state->frame[state->curframe];
>>   	/* return to the caller whatever r0 had in the callee */
>> -	caller->regs[BPF_REG_0] = *r0;
>> +	if (caller->with_callback_fn) {
>> +		/* reset caller saved regs, the helper calling callback_fn
>> +		 * has RET_INTEGER return types.
>> +		 */
>> +		for (i = 0; i < CALLER_SAVED_REGS; i++)
>> +			mark_reg_not_init(env, caller->regs, caller_saved[i]);
>> +		caller->regs[BPF_REG_0].subreg_def = DEF_NOT_SUBREG;
>> +		mark_reg_unknown(env, caller->regs, BPF_REG_0);
> 
> this part can stay in check_helper_call().

Yes, verifying the callback function from the helper is the most
complex part of my patch set. Your suggestions above should make it
more streamlined and a better fit with the existing infrastructure.
Will give it a try.

> 
>> +	} else {
>> +		caller->regs[BPF_REG_0] = *r0;
>> +	}
>>   
>>   	/* Transfer references to the caller */
>>   	err = transfer_reference_state(caller, callee);
>> @@ -5091,7 +5236,8 @@ static int check_reference_leak(struct bpf_verifier_env *env)
>>   	return state->acquired_refs ? -EINVAL : 0;
>>   }
>>   
>> -static int check_helper_call(struct bpf_verifier_env *env, int func_id, int insn_idx)
>> +static int check_helper_call(struct bpf_verifier_env *env, int func_id, int *insn_idx,
>> +			     bool map_elem_callback)
>>   {
>>   	const struct bpf_func_proto *fn = NULL;
>>   	struct bpf_reg_state *regs;
>> @@ -5151,11 +5297,11 @@ static int check_helper_call(struct bpf_verifier_env *env, int func_id, int insn
>>   			return err;
>>   	}
>>   
>> -	err = record_func_map(env, &meta, func_id, insn_idx);
>> +	err = record_func_map(env, &meta, func_id, *insn_idx);
>>   	if (err)
>>   		return err;
>>   
>> -	err = record_func_key(env, &meta, func_id, insn_idx);
>> +	err = record_func_key(env, &meta, func_id, *insn_idx);
>>   	if (err)
>>   		return err;
>>   
>> @@ -5163,7 +5309,7 @@ static int check_helper_call(struct bpf_verifier_env *env, int func_id, int insn
>>   	 * is inferred from register state.
>>   	 */
>>   	for (i = 0; i < meta.access_size; i++) {
>> -		err = check_mem_access(env, insn_idx, meta.regno, i, BPF_B,
>> +		err = check_mem_access(env, *insn_idx, meta.regno, i, BPF_B,
>>   				       BPF_WRITE, -1, false);
>>   		if (err)
>>   			return err;
>> @@ -5195,6 +5341,11 @@ static int check_helper_call(struct bpf_verifier_env *env, int func_id, int insn
>>   		return -EINVAL;
>>   	}
>>   
>> +	if (map_elem_callback) {
>> +		env->prog->aux->with_callback_fn = true;
>> +		return check_map_elem_callback(env, insn_idx);
> 
> Instead of returning early here.
> The check_func_arg() loop can look for PTR_TO_FUNC type.
> The allocate new callee state,
> do map_set_for_each_callback_args() here.
> and then proceed further.

Ditto. Will reorganize to avoid the early return here.

> 
>> +	}
>> +
>>   	/* reset caller saved regs */
>>   	for (i = 0; i < CALLER_SAVED_REGS; i++) {
>>   		mark_reg_not_init(env, regs, caller_saved[i]);
> 
> Instead of doing this loop in prepare_func_exit().
> This code can just proceed here and clear caller regs. This loop can stay as-is.
> The transfer of caller->callee would happen already.
> 
> Then there are few lines here that diff didn't show.
> They do regs[BPF_REG_0].subreg_def = DEF_NOT_SUBREG and mark_reg_unknown.
> No need to do them in prepare_func_exit().
> This function can proceed further reusing this caller regs clearing loop and r0 marking.
> 
> Then before returning from check_helper_call()
> it will do what you have in check_map_elem_callback() and it will adjust *insn_idx.
> 
> At this point caller would have regs cleared and r0=undef.
> And callee would have regs setup the way map_set_for_each_callback_args callback meant to do it.
> The only thing prepare_func_exit would need to do is to make sure that assignment:
> caller->regs[BPF_REG_0] = *r0
> doesn't happen. caller's r0 was already set to undef.
> To achieve that I think would be a bit cleaner to mark callee state instead of caller state.
> So instead of caller->with_callback_fn=true maybe callee->in_callback_fn=true ?

Agree. Will try to do normal helper return verification here
instead of a cut-down version in prepare_func_exit().

> 
>> @@ -5306,7 +5457,7 @@ static int check_helper_call(struct bpf_verifier_env *env, int func_id, int insn
>>   		/* For release_reference() */
>>   		regs[BPF_REG_0].ref_obj_id = meta.ref_obj_id;
>>   	} else if (is_acquire_function(func_id, meta.map_ptr)) {
>> -		int id = acquire_reference_state(env, insn_idx);
>> +		int id = acquire_reference_state(env, *insn_idx);
>>   
>>   		if (id < 0)
>>   			return id;
>> @@ -5448,6 +5599,14 @@ static int retrieve_ptr_limit(const struct bpf_reg_state *ptr_reg,
>>   		else
>>   			*ptr_limit = -off;
>>   		return 0;
>> +	case PTR_TO_MAP_KEY:
>> +		if (mask_to_left) {
>> +			*ptr_limit = ptr_reg->umax_value + ptr_reg->off;
>> +		} else {
>> +			off = ptr_reg->smin_value + ptr_reg->off;
>> +			*ptr_limit = ptr_reg->map_ptr->key_size - off;
>> +		}
>> +		return 0;
>>   	case PTR_TO_MAP_VALUE:
>>   		if (mask_to_left) {
>>   			*ptr_limit = ptr_reg->umax_value + ptr_reg->off;
>> @@ -5614,6 +5773,7 @@ static int adjust_ptr_min_max_vals(struct bpf_verifier_env *env,
>>   		verbose(env, "R%d pointer arithmetic on %s prohibited\n",
>>   			dst, reg_type_str[ptr_reg->type]);
>>   		return -EACCES;
>> +	case PTR_TO_MAP_KEY:
>>   	case PTR_TO_MAP_VALUE:
>>   		if (!env->allow_ptr_leaks && !known && (smin_val < 0) != (smax_val < 0)) {
>>   			verbose(env, "R%d has unknown scalar with mixed signed bounds, pointer arithmetic with it prohibited for !root\n",
>> @@ -7818,6 +7978,12 @@ static int check_ld_imm(struct bpf_verifier_env *env, struct bpf_insn *insn)
>>   		return 0;
>>   	}
>>   
>> +	if (insn->src_reg == BPF_PSEUDO_FUNC) {
>> +		dst_reg->type = PTR_TO_FUNC;
>> +		dst_reg->subprog = insn[1].imm;
> 
> Like here check for linkage==static can happen ?

will do.

> 
>> +		return 0;
>> +	}
>> +
>>   	map = env->used_maps[aux->map_index];
>>   	mark_reg_known_zero(env, regs, insn->dst_reg);
>>   	dst_reg->map_ptr = map;
>> @@ -8195,9 +8361,23 @@ static int visit_insn(int t, int insn_cnt, struct bpf_verifier_env *env)
>>   
>>   	/* All non-branch instructions have a single fall-through edge. */
>>   	if (BPF_CLASS(insns[t].code) != BPF_JMP &&
>> -	    BPF_CLASS(insns[t].code) != BPF_JMP32)
>> +	    BPF_CLASS(insns[t].code) != BPF_JMP32 &&
>> +	    !bpf_pseudo_func(insns + t))
>>   		return push_insn(t, t + 1, FALLTHROUGH, env, false);
>>   
>> +	if (bpf_pseudo_func(insns + t)) {
>> +		ret = push_insn(t, t + 1, FALLTHROUGH, env, false);
>> +		if (ret)
>> +			return ret;
>> +
>> +		if (t + 1 < insn_cnt)
>> +			init_explored_state(env, t + 1);
>> +		init_explored_state(env, t);
>> +		ret = push_insn(t, t + insns[t].imm + 1, BRANCH,
>> +				env, false);
>> +		return ret;
>> +	}
>> +
>>   	switch (BPF_OP(insns[t].code)) {
>>   	case BPF_EXIT:
>>   		return DONE_EXPLORING;
>> @@ -8819,6 +8999,7 @@ static bool regsafe(struct bpf_reg_state *rold, struct bpf_reg_state *rcur,
>>   			 */
>>   			return false;
>>   		}
>> +	case PTR_TO_MAP_KEY:
>>   	case PTR_TO_MAP_VALUE:
>>   		/* If the new min/max/var_off satisfy the old ones and
>>   		 * everything else matches, we are OK.
>> @@ -9646,6 +9827,8 @@ static int do_check(struct bpf_verifier_env *env)
>>   
>>   			env->jmps_processed++;
>>   			if (opcode == BPF_CALL) {
>> +				bool map_elem_callback;
>> +
>>   				if (BPF_SRC(insn->code) != BPF_K ||
>>   				    insn->off != 0 ||
>>   				    (insn->src_reg != BPF_REG_0 &&
>> @@ -9662,13 +9845,15 @@ static int do_check(struct bpf_verifier_env *env)
>>   					verbose(env, "function calls are not allowed while holding a lock\n");
>>   					return -EINVAL;
>>   				}
>> +				map_elem_callback = insn->src_reg != BPF_PSEUDO_CALL &&
>> +						   insn->imm == BPF_FUNC_for_each_map_elem;
>>   				if (insn->src_reg == BPF_PSEUDO_CALL)
>>   					err = check_func_call(env, insn, &env->insn_idx);
>>   				else
>> -					err = check_helper_call(env, insn->imm, env->insn_idx);
>> +					err = check_helper_call(env, insn->imm, &env->insn_idx,
>> +								map_elem_callback);
> 
> then hopefully this extra 'map_elem_callback' boolean won't be needed.
> Only env->insn_idx into &env->insn_idx.
> In that sense check_helper_call will become a superset of check_func_call.
> Maybe some code between them can be shared too.
> Beyond bpf_for_each_map_elem() helper other helpers might use PTR_TO_FUNC.
> I hope with this approach all of them will be handled a bit more generically.

We shouldn't use this since we have insn->imm as the helper id. Will 
remove it.

> 
>>   				if (err)
>>   					return err;
>> -
>>   			} else if (opcode == BPF_JA) {
>>   				if (BPF_SRC(insn->code) != BPF_K ||
>>   				    insn->imm != 0 ||
>> @@ -10090,6 +10275,12 @@ static int resolve_pseudo_ldimm64(struct bpf_verifier_env *env)
>>   				goto next_insn;
>>   			}
>>   
>> +			if (insn[0].src_reg == BPF_PSEUDO_FUNC) {
>> +				aux = &env->insn_aux_data[i];
>> +				aux->ptr_type = PTR_TO_FUNC;
>> +				goto next_insn;
>> +			}
>> +
>>   			/* In final convert_pseudo_ld_imm64() step, this is
>>   			 * converted into regular 64-bit imm load insn.
>>   			 */
>> @@ -10222,9 +10413,13 @@ static void convert_pseudo_ld_imm64(struct bpf_verifier_env *env)
>>   	int insn_cnt = env->prog->len;
>>   	int i;
>>   
>> -	for (i = 0; i < insn_cnt; i++, insn++)
>> -		if (insn->code == (BPF_LD | BPF_IMM | BPF_DW))
>> -			insn->src_reg = 0;
>> +	for (i = 0; i < insn_cnt; i++, insn++) {
>> +		if (insn->code != (BPF_LD | BPF_IMM | BPF_DW))
>> +			continue;
>> +		if (insn->src_reg == BPF_PSEUDO_FUNC)
>> +			continue;
>> +		insn->src_reg = 0;
>> +	}
>>   }
>>   
>>   /* single env->prog->insni[off] instruction was replaced with the range
>> @@ -10846,6 +11041,12 @@ static int jit_subprogs(struct bpf_verifier_env *env)
>>   		return 0;
>>   
>>   	for (i = 0, insn = prog->insnsi; i < prog->len; i++, insn++) {
>> +		if (bpf_pseudo_func(insn)) {
>> +			env->insn_aux_data[i].call_imm = insn->imm;
>> +			/* subprog is encoded in insn[1].imm */
>> +			continue;
>> +		}
>> +
>>   		if (!bpf_pseudo_call(insn))
>>   			continue;
>>   		/* Upon error here we cannot fall back to interpreter but
>> @@ -10975,6 +11176,12 @@ static int jit_subprogs(struct bpf_verifier_env *env)
>>   	for (i = 0; i < env->subprog_cnt; i++) {
>>   		insn = func[i]->insnsi;
>>   		for (j = 0; j < func[i]->len; j++, insn++) {
>> +			if (bpf_pseudo_func(insn)) {
>> +				subprog = insn[1].imm;
>> +				insn[0].imm = (u32)(long)func[subprog]->bpf_func;
>> +				insn[1].imm = ((u64)(long)func[subprog]->bpf_func) >> 32;
>> +				continue;
>> +			}
>>   			if (!bpf_pseudo_call(insn))
>>   				continue;
>>   			subprog = insn->off;
>> @@ -11020,6 +11227,11 @@ static int jit_subprogs(struct bpf_verifier_env *env)
>>   	 * later look the same as if they were interpreted only.
>>   	 */
>>   	for (i = 0, insn = prog->insnsi; i < prog->len; i++, insn++) {
>> +		if (bpf_pseudo_func(insn)) {
>> +			insn[0].imm = env->insn_aux_data[i].call_imm;
>> +			insn[1].imm = find_subprog(env, i + insn[0].imm + 1);
>> +			continue;
>> +		}
>>   		if (!bpf_pseudo_call(insn))
>>   			continue;
>>   		insn->off = env->insn_aux_data[i].call_imm;
>> @@ -11083,6 +11295,13 @@ static int fixup_call_args(struct bpf_verifier_env *env)
>>   		verbose(env, "tail_calls are not allowed in non-JITed programs with bpf-to-bpf calls\n");
>>   		return -EINVAL;
>>   	}
>> +	if (env->subprog_cnt > 1 && env->prog->aux->with_callback_fn) {
> 
> Does this bool really need to be be part of 'aux'?
> There is a loop below that does if (!bpf_pseudo_call
> to fixup insns for the interpreter.
> May be add if (bpf_pseudo_func()) { return callbacks are not allowed in non-JITed }
> to the loop below as well?
> It's a trade off between memory and few extra insn.

I used this bit similarly to tailcall. But I agree with you that we
can do the check in the loop below.

> 
>> +		/* When JIT fails the progs with callback calls
>> +		 * have to be rejected, since interpreter doesn't support them yet.
>> +		 */
>> +		verbose(env, "callbacks are not allowed in non-JITed programs\n");
>> +		return -EINVAL;
>> +	}
>>   	for (i = 0; i < prog->len; i++, insn++) {
>>   		if (!bpf_pseudo_call(insn))
>>   			continue;
> 
> to this loop.
> 
> Thanks!
>
Andrii Nakryiko Feb. 8, 2021, 6:16 p.m. UTC | #3
On Thu, Feb 4, 2021 at 5:53 PM Yonghong Song <yhs@fb.com> wrote:
>
> The bpf_for_each_map_elem() helper is introduced which
> iterates all map elements with a callback function. The
> helper signature looks like
>   long bpf_for_each_map_elem(map, callback_fn, callback_ctx, flags)
> and for each map element, the callback_fn will be called. For example,
> like hashmap, the callback signature may look like
>   long callback_fn(map, key, val, callback_ctx)
>
> There are two known use cases for this. One is from upstream ([1]) where
> a for_each_map_elem helper may help implement a timeout mechanism
> in a more generic way. Another is from our internal discussion
> for a firewall use case where a map contains all the rules. The packet
> data can be compared to all these rules to decide allow or deny
> the packet.
>
> For array maps, users can already use a bounded loop to traverse
> elements. Using this helper can avoid using bounded loop. For other
> type of maps (e.g., hash maps) where bounded loop is hard or
> impossible to use, this helper provides a convenient way to
> operate on all elements.
>
> For callback_fn, besides map and map element, a callback_ctx,
> allocated on caller stack, is also passed to the callback
> function. This callback_ctx argument can provide additional
> input and allow to write to caller stack for output.
>
> If the callback_fn returns 0, the helper will iterate through next
> element if available. If the callback_fn returns 1, the helper
> will stop iterating and returns to the bpf program. Other return
> values are not used for now.
>
> Currently, this helper is only available with jit. It is possible
> to make it work with interpreter with so effort but I leave it
> as the future work.
>
> [1]: https://lore.kernel.org/bpf/20210122205415.113822-1-xiyou.wangcong@gmail.com/
>
> Signed-off-by: Yonghong Song <yhs@fb.com>
> ---

This is a great feature! Few questions and nits below.

>  include/linux/bpf.h            |  14 ++
>  include/linux/bpf_verifier.h   |   3 +
>  include/uapi/linux/bpf.h       |  28 ++++
>  kernel/bpf/bpf_iter.c          |  16 +++
>  kernel/bpf/helpers.c           |   2 +
>  kernel/bpf/verifier.c          | 251 ++++++++++++++++++++++++++++++---
>  kernel/trace/bpf_trace.c       |   2 +
>  tools/include/uapi/linux/bpf.h |  28 ++++
>  8 files changed, 328 insertions(+), 16 deletions(-)
>

[...]

>  const struct bpf_func_proto *bpf_tracing_func_proto(
>         enum bpf_func_id func_id, const struct bpf_prog *prog);
> diff --git a/include/linux/bpf_verifier.h b/include/linux/bpf_verifier.h
> index dfe6f85d97dd..c4366b3da342 100644
> --- a/include/linux/bpf_verifier.h
> +++ b/include/linux/bpf_verifier.h
> @@ -68,6 +68,8 @@ struct bpf_reg_state {
>                         unsigned long raw1;
>                         unsigned long raw2;
>                 } raw;
> +
> +               u32 subprog; /* for PTR_TO_FUNC */

Is it an offset to the subprog (in bytes or instructions?) or is it a subprog
index? Let's make it clear with a better name or at least a comment.

>         };
>         /* For PTR_TO_PACKET, used to find other pointers with the same variable
>          * offset, so they can share range knowledge.
> @@ -204,6 +206,7 @@ struct bpf_func_state {
>         int acquired_refs;
>         struct bpf_reference_state *refs;
>         int allocated_stack;
> +       bool with_callback_fn;
>         struct bpf_stack_state *stack;
>  };
>
> diff --git a/include/uapi/linux/bpf.h b/include/uapi/linux/bpf.h
> index c001766adcbc..d55bd4557376 100644
> --- a/include/uapi/linux/bpf.h
> +++ b/include/uapi/linux/bpf.h
> @@ -393,6 +393,15 @@ enum bpf_link_type {
>   *                   is struct/union.
>   */
>  #define BPF_PSEUDO_BTF_ID      3
> +/* insn[0].src_reg:  BPF_PSEUDO_FUNC
> + * insn[0].imm:      insn offset to the func
> + * insn[1].imm:      0
> + * insn[0].off:      0
> + * insn[1].off:      0
> + * ldimm64 rewrite:  address of the function
> + * verifier type:    PTR_TO_FUNC.
> + */
> +#define BPF_PSEUDO_FUNC                4
>
>  /* when bpf_call->src_reg == BPF_PSEUDO_CALL, bpf_call->imm == pc-relative
>   * offset to another bpf function
> @@ -3836,6 +3845,24 @@ union bpf_attr {
>   *     Return
>   *             A pointer to a struct socket on success or NULL if the file is
>   *             not a socket.
> + *
> + * long bpf_for_each_map_elem(struct bpf_map *map, void *callback_fn, void *callback_ctx, u64 flags)

struct bpf_map * here might be problematic. In other instances where
we pass a map (bpf_map_update_elem, for example) we specify this as
(void *). Let's do that here instead?

> + *     Description
> + *             For each element in **map**, call **callback_fn** function with
> + *             **map**, **callback_ctx** and other map-specific parameters.
> + *             For example, for hash and array maps, the callback signature can
> + *             be `u64 callback_fn(map, map_key, map_value, callback_ctx)`.
> + *             The **callback_fn** should be a static function and
> + *             the **callback_ctx** should be a pointer to the stack.
> + *             The **flags** is used to control certain aspects of the helper.
> + *             Currently, the **flags** must be 0.
> + *
> + *             If **callback_fn** return 0, the helper will continue to the next
> + *             element. If return value is 1, the helper will skip the rest of
> + *             elements and return. Other return values are not used now.
> + *     Return
> + *             0 for success, **-EINVAL** for invalid **flags** or unsupported
> + *             **callback_fn** return value.

just a thought: returning the number of elements *actually* iterated
seems useful (even though I don't have a specific use case right now).

>   */
>  #define __BPF_FUNC_MAPPER(FN)          \
>         FN(unspec),                     \
> @@ -4001,6 +4028,7 @@ union bpf_attr {
>         FN(ktime_get_coarse_ns),        \
>         FN(ima_inode_hash),             \
>         FN(sock_from_file),             \
> +       FN(for_each_map_elem),          \

to be more in sync with other map operations, can we call this
`bpf_map_for_each_elem`? I think it makes sense and doesn't read
backwards at all.

>         /* */
>
>  /* integer value in 'imm' field of BPF_CALL instruction selects which helper
> diff --git a/kernel/bpf/bpf_iter.c b/kernel/bpf/bpf_iter.c
> index 5454161407f1..5187f49d3216 100644
> --- a/kernel/bpf/bpf_iter.c
> +++ b/kernel/bpf/bpf_iter.c
> @@ -675,3 +675,19 @@ int bpf_iter_run_prog(struct bpf_prog *prog, void *ctx)
>          */
>         return ret == 0 ? 0 : -EAGAIN;
>  }
> +
> +BPF_CALL_4(bpf_for_each_map_elem, struct bpf_map *, map, void *, callback_fn,
> +          void *, callback_ctx, u64, flags)
> +{
> +       return map->ops->map_for_each_callback(map, callback_fn, callback_ctx, flags);
> +}
> +
> +const struct bpf_func_proto bpf_for_each_map_elem_proto = {
> +       .func           = bpf_for_each_map_elem,
> +       .gpl_only       = false,
> +       .ret_type       = RET_INTEGER,
> +       .arg1_type      = ARG_CONST_MAP_PTR,
> +       .arg2_type      = ARG_PTR_TO_FUNC,
> +       .arg3_type      = ARG_PTR_TO_STACK_OR_NULL,

I looked through this code just once but haven't noticed anything that
would strictly require that the pointer is specifically to the stack. Can this
be made into a pointer to any allocated memory? E.g., why can't we
allow passing a pointer to a ringbuf sample, for instance? Or
MAP_VALUE?

> +       .arg4_type      = ARG_ANYTHING,
> +};
> diff --git a/kernel/bpf/helpers.c b/kernel/bpf/helpers.c
> index 308427fe03a3..074800226327 100644
> --- a/kernel/bpf/helpers.c
> +++ b/kernel/bpf/helpers.c
> @@ -708,6 +708,8 @@ bpf_base_func_proto(enum bpf_func_id func_id)
>                 return &bpf_ringbuf_discard_proto;
>         case BPF_FUNC_ringbuf_query:
>                 return &bpf_ringbuf_query_proto;
> +       case BPF_FUNC_for_each_map_elem:
> +               return &bpf_for_each_map_elem_proto;
>         default:
>                 break;
>         }
> diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c
> index db294b75d03b..050b067a0be6 100644
> --- a/kernel/bpf/verifier.c
> +++ b/kernel/bpf/verifier.c
> @@ -234,6 +234,12 @@ static bool bpf_pseudo_call(const struct bpf_insn *insn)
>                insn->src_reg == BPF_PSEUDO_CALL;
>  }
>

[...]

>         map = env->used_maps[aux->map_index];
>         mark_reg_known_zero(env, regs, insn->dst_reg);
>         dst_reg->map_ptr = map;
> @@ -8195,9 +8361,23 @@ static int visit_insn(int t, int insn_cnt, struct bpf_verifier_env *env)
>
>         /* All non-branch instructions have a single fall-through edge. */
>         if (BPF_CLASS(insns[t].code) != BPF_JMP &&
> -           BPF_CLASS(insns[t].code) != BPF_JMP32)
> +           BPF_CLASS(insns[t].code) != BPF_JMP32 &&
> +           !bpf_pseudo_func(insns + t))
>                 return push_insn(t, t + 1, FALLTHROUGH, env, false);
>
> +       if (bpf_pseudo_func(insns + t)) {


If you check this before the JMP|JMP32 check above, you won't need the
!bpf_pseudo_func check, right? I think it's cleaner.

> +               ret = push_insn(t, t + 1, FALLTHROUGH, env, false);
> +               if (ret)
> +                       return ret;
> +
> +               if (t + 1 < insn_cnt)
> +                       init_explored_state(env, t + 1);
> +               init_explored_state(env, t);
> +               ret = push_insn(t, t + insns[t].imm + 1, BRANCH,
> +                               env, false);
> +               return ret;
> +       }
> +
>         switch (BPF_OP(insns[t].code)) {
>         case BPF_EXIT:
>                 return DONE_EXPLORING;

[...]
Yonghong Song Feb. 9, 2021, 6:41 a.m. UTC | #4
On 2/8/21 10:16 AM, Andrii Nakryiko wrote:
> On Thu, Feb 4, 2021 at 5:53 PM Yonghong Song <yhs@fb.com> wrote:
>>
>> The bpf_for_each_map_elem() helper is introduced which
>> iterates all map elements with a callback function. The
>> helper signature looks like
>>    long bpf_for_each_map_elem(map, callback_fn, callback_ctx, flags)
>> and for each map element, the callback_fn will be called. For example,
>> like hashmap, the callback signature may look like
>>    long callback_fn(map, key, val, callback_ctx)
>>
>> There are two known use cases for this. One is from upstream ([1]) where
>> a for_each_map_elem helper may help implement a timeout mechanism
>> in a more generic way. Another is from our internal discussion
>> for a firewall use case where a map contains all the rules. The packet
>> data can be compared to all these rules to decide allow or deny
>> the packet.
>>
>> For array maps, users can already use a bounded loop to traverse
>> elements. Using this helper can avoid using bounded loop. For other
>> type of maps (e.g., hash maps) where bounded loop is hard or
>> impossible to use, this helper provides a convenient way to
>> operate on all elements.
>>
>> For callback_fn, besides map and map element, a callback_ctx,
>> allocated on caller stack, is also passed to the callback
>> function. This callback_ctx argument can provide additional
>> input and allow to write to caller stack for output.
>>
>> If the callback_fn returns 0, the helper will iterate through next
>> element if available. If the callback_fn returns 1, the helper
>> will stop iterating and returns to the bpf program. Other return
>> values are not used for now.
>>
>> Currently, this helper is only available with jit. It is possible
>> to make it work with interpreter with so effort but I leave it
>> as the future work.
>>
>> [1]: https://lore.kernel.org/bpf/20210122205415.113822-1-xiyou.wangcong@gmail.com/
>>
>> Signed-off-by: Yonghong Song <yhs@fb.com>
>> ---
> 
> This is a great feature! Few questions and nits below.
> 
>>   include/linux/bpf.h            |  14 ++
>>   include/linux/bpf_verifier.h   |   3 +
>>   include/uapi/linux/bpf.h       |  28 ++++
>>   kernel/bpf/bpf_iter.c          |  16 +++
>>   kernel/bpf/helpers.c           |   2 +
>>   kernel/bpf/verifier.c          | 251 ++++++++++++++++++++++++++++++---
>>   kernel/trace/bpf_trace.c       |   2 +
>>   tools/include/uapi/linux/bpf.h |  28 ++++
>>   8 files changed, 328 insertions(+), 16 deletions(-)
>>
> 
> [...]
> 
>>   const struct bpf_func_proto *bpf_tracing_func_proto(
>>          enum bpf_func_id func_id, const struct bpf_prog *prog);
>> diff --git a/include/linux/bpf_verifier.h b/include/linux/bpf_verifier.h
>> index dfe6f85d97dd..c4366b3da342 100644
>> --- a/include/linux/bpf_verifier.h
>> +++ b/include/linux/bpf_verifier.h
>> @@ -68,6 +68,8 @@ struct bpf_reg_state {
>>                          unsigned long raw1;
>>                          unsigned long raw2;
>>                  } raw;
>> +
>> +               u32 subprog; /* for PTR_TO_FUNC */
> 
> is it offset to subprog (in bytes or instructions?) or it's subprog
> index? Let's make it clear with a better name or at least a comment.

This is the subprog number (i.e., an index into some subprog-related arrays).
In verifier.c, subprog or subprogno is used to represent the subprog
number. I will use subprogno in the next revision.

> 
>>          };
>>          /* For PTR_TO_PACKET, used to find other pointers with the same variable
>>           * offset, so they can share range knowledge.
>> @@ -204,6 +206,7 @@ struct bpf_func_state {
>>          int acquired_refs;
>>          struct bpf_reference_state *refs;
>>          int allocated_stack;
>> +       bool with_callback_fn;
>>          struct bpf_stack_state *stack;
>>   };
>>
>> diff --git a/include/uapi/linux/bpf.h b/include/uapi/linux/bpf.h
>> index c001766adcbc..d55bd4557376 100644
>> --- a/include/uapi/linux/bpf.h
>> +++ b/include/uapi/linux/bpf.h
>> @@ -393,6 +393,15 @@ enum bpf_link_type {
>>    *                   is struct/union.
>>    */
>>   #define BPF_PSEUDO_BTF_ID      3
>> +/* insn[0].src_reg:  BPF_PSEUDO_FUNC
>> + * insn[0].imm:      insn offset to the func
>> + * insn[1].imm:      0
>> + * insn[0].off:      0
>> + * insn[1].off:      0
>> + * ldimm64 rewrite:  address of the function
>> + * verifier type:    PTR_TO_FUNC.
>> + */
>> +#define BPF_PSEUDO_FUNC                4
>>
>>   /* when bpf_call->src_reg == BPF_PSEUDO_CALL, bpf_call->imm == pc-relative
>>    * offset to another bpf function
>> @@ -3836,6 +3845,24 @@ union bpf_attr {
>>    *     Return
>>    *             A pointer to a struct socket on success or NULL if the file is
>>    *             not a socket.
>> + *
>> + * long bpf_for_each_map_elem(struct bpf_map *map, void *callback_fn, void *callback_ctx, u64 flags)
> 
> struct bpf_map * here might be problematic. In other instances where
> we pass map (bpf_map_update_elem, for example) we specify this as
> (void *). Let's do that instead here?

We should be fine here. bpf_map_lookup_elem etc. all have "struct
bpf_map *map"; it is rewritten by bpf_helpers_doc.py to "void *map".

> 
>> + *     Description
>> + *             For each element in **map**, call **callback_fn** function with
>> + *             **map**, **callback_ctx** and other map-specific parameters.
>> + *             For example, for hash and array maps, the callback signature can
>> + *             be `u64 callback_fn(map, map_key, map_value, callback_ctx)`.
>> + *             The **callback_fn** should be a static function and
>> + *             the **callback_ctx** should be a pointer to the stack.
>> + *             The **flags** is used to control certain aspects of the helper.
>> + *             Currently, the **flags** must be 0.
>> + *
>> + *             If **callback_fn** return 0, the helper will continue to the next
>> + *             element. If return value is 1, the helper will skip the rest of
>> + *             elements and return. Other return values are not used now.
>> + *     Return
>> + *             0 for success, **-EINVAL** for invalid **flags** or unsupported
>> + *             **callback_fn** return value.
> 
> just a thought: returning the number of elements *actually* iterated
> seems useful (even though I don't have a specific use case right now).

Good idea. Will change to this in the next revision.
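
For reference, the map-side iteration (implemented per map type in
later patches of this series) could then count as it goes; a rough
sketch for the array-map case, with the exact function and field names
assumed rather than taken from this patch:

  static long bpf_for_each_array_elem(struct bpf_map *map, void *callback_fn,
                                      void *callback_ctx, u64 flags)
  {
          struct bpf_array *array = container_of(map, struct bpf_array, map);
          u32 i, key, num_elems = 0;
          void *val;
          u64 ret = 0;

          if (flags != 0)
                  return -EINVAL;

          for (i = 0; i < map->max_entries; i++) {
                  val = array->value + array->elem_size * i;
                  num_elems++;
                  key = i;
                  ret = BPF_CAST_CALL(callback_fn)((u64)(long)map,
                                                   (u64)(long)&key,
                                                   (u64)(long)val,
                                                   (u64)(long)callback_ctx, 0);
                  /* a non-zero callback return value stops the iteration */
                  if (ret)
                          break;
          }
          return num_elems;
  }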

> 
>>    */
>>   #define __BPF_FUNC_MAPPER(FN)          \
>>          FN(unspec),                     \
>> @@ -4001,6 +4028,7 @@ union bpf_attr {
>>          FN(ktime_get_coarse_ns),        \
>>          FN(ima_inode_hash),             \
>>          FN(sock_from_file),             \
>> +       FN(for_each_map_elem),          \
> 
> to be more in sync with other map operations, can we call this
> `bpf_map_for_each_elem`? I think it makes sense and doesn't read
> backwards at all.

I am using the for_each prefix since in the future I (or others) may add
more for_each_* helpers, e.g., for_each_task, for_each_hlist_rcu, etc.
This represents a family of helpers with callback functions, so I
would like to stick with the for_each_* names.

> 
>>          /* */
>>
>>   /* integer value in 'imm' field of BPF_CALL instruction selects which helper
>> diff --git a/kernel/bpf/bpf_iter.c b/kernel/bpf/bpf_iter.c
>> index 5454161407f1..5187f49d3216 100644
>> --- a/kernel/bpf/bpf_iter.c
>> +++ b/kernel/bpf/bpf_iter.c
>> @@ -675,3 +675,19 @@ int bpf_iter_run_prog(struct bpf_prog *prog, void *ctx)
>>           */
>>          return ret == 0 ? 0 : -EAGAIN;
>>   }
>> +
>> +BPF_CALL_4(bpf_for_each_map_elem, struct bpf_map *, map, void *, callback_fn,
>> +          void *, callback_ctx, u64, flags)
>> +{
>> +       return map->ops->map_for_each_callback(map, callback_fn, callback_ctx, flags);
>> +}
>> +
>> +const struct bpf_func_proto bpf_for_each_map_elem_proto = {
>> +       .func           = bpf_for_each_map_elem,
>> +       .gpl_only       = false,
>> +       .ret_type       = RET_INTEGER,
>> +       .arg1_type      = ARG_CONST_MAP_PTR,
>> +       .arg2_type      = ARG_PTR_TO_FUNC,
>> +       .arg3_type      = ARG_PTR_TO_STACK_OR_NULL,
> 
> I looked through this code just once but haven't noticed anything that
> would strictly require that pointer is specifically to stack. Can this
> be made into a pointer to any allocated memory? E.g., why can't we
> allow passing a pointer to a ringbuf sample, for instance? Or
> MAP_VALUE?

ARG_PTR_TO_STACK_OR_NULL is the most flexible one. For example, if you
want to pass a map_value or a ringbuf sample, you can store these pointers
on the stack like
    struct ctx_t {
       struct map_value_t *map_value;
       char *ringbuf_mem;
    } tmp;
    tmp.map_value = ...;
    tmp.ringbuf_mem = ...;
    bpf_for_each_map_elem(map, callback_fn, &tmp, flags);
and callback_fn will be able to access map_value/ringbuf_mem
with their original register types.

This does not allow passing ringbuf/map_value etc. as
first-class citizens. But I think this is a good compromise
that permits greater flexibility.
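
Fleshing the snippet out into something compilable (all names below
are illustrative; the ringbuf pointer is declared but unused, just to
mirror the snippet above):

    struct map_value_t { __u64 cnt; };

    struct {
            __uint(type, BPF_MAP_TYPE_ARRAY);
            __uint(max_entries, 1);
            __type(key, __u32);
            __type(value, struct map_value_t);
    } other_map SEC(".maps");

    struct {
            __uint(type, BPF_MAP_TYPE_HASH);
            __uint(max_entries, 128);
            __type(key, __u32);
            __type(value, __u64);
    } hmap SEC(".maps");

    struct ctx_t {
            struct map_value_t *map_value;
            char *ringbuf_mem;
    };

    static long cb(struct bpf_map *map, __u32 *key, __u64 *val,
                   struct ctx_t *ctx)
    {
            /* the pointer spilled to the caller's stack keeps its
             * original register type, so after the usual NULL check
             * it can be dereferenced here
             */
            if (ctx->map_value)
                    ctx->map_value->cnt += *val;
            return 0;
    }

    SEC("tracepoint/syscalls/sys_enter_getpid")
    int sum_elems(void *tp_ctx)
    {
            __u32 zero = 0;
            struct ctx_t tmp = {};

            tmp.map_value = bpf_map_lookup_elem(&other_map, &zero);
            bpf_for_each_map_elem(&hmap, cb, &tmp, 0);
            return 0;
    }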

> 
>> +       .arg4_type      = ARG_ANYTHING,
>> +};
>> diff --git a/kernel/bpf/helpers.c b/kernel/bpf/helpers.c
>> index 308427fe03a3..074800226327 100644
>> --- a/kernel/bpf/helpers.c
>> +++ b/kernel/bpf/helpers.c
>> @@ -708,6 +708,8 @@ bpf_base_func_proto(enum bpf_func_id func_id)
>>                  return &bpf_ringbuf_discard_proto;
>>          case BPF_FUNC_ringbuf_query:
>>                  return &bpf_ringbuf_query_proto;
>> +       case BPF_FUNC_for_each_map_elem:
>> +               return &bpf_for_each_map_elem_proto;
>>          default:
>>                  break;
>>          }
>> diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c
>> index db294b75d03b..050b067a0be6 100644
>> --- a/kernel/bpf/verifier.c
>> +++ b/kernel/bpf/verifier.c
>> @@ -234,6 +234,12 @@ static bool bpf_pseudo_call(const struct bpf_insn *insn)
>>                 insn->src_reg == BPF_PSEUDO_CALL;
>>   }
>>
> 
> [...]
> 
>>          map = env->used_maps[aux->map_index];
>>          mark_reg_known_zero(env, regs, insn->dst_reg);
>>          dst_reg->map_ptr = map;
>> @@ -8195,9 +8361,23 @@ static int visit_insn(int t, int insn_cnt, struct bpf_verifier_env *env)
>>
>>          /* All non-branch instructions have a single fall-through edge. */
>>          if (BPF_CLASS(insns[t].code) != BPF_JMP &&
>> -           BPF_CLASS(insns[t].code) != BPF_JMP32)
>> +           BPF_CLASS(insns[t].code) != BPF_JMP32 &&
>> +           !bpf_pseudo_func(insns + t))
>>                  return push_insn(t, t + 1, FALLTHROUGH, env, false);
>>
>> +       if (bpf_pseudo_func(insns + t)) {
> 
> 
> if you check this before the JMP|JMP32 check above, you won't need the
> !bpf_pseudo_func check, right? I think it's cleaner.

Agree. will change in v2.
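
So the top of visit_insn() would become something like:

    if (bpf_pseudo_func(insns + t)) {
            ret = push_insn(t, t + 1, FALLTHROUGH, env, false);
            if (ret)
                    return ret;

            if (t + 1 < insn_cnt)
                    init_explored_state(env, t + 1);
            init_explored_state(env, t);
            return push_insn(t, t + insns[t].imm + 1, BRANCH, env, false);
    }

    /* All non-branch instructions have a single fall-through edge. */
    if (BPF_CLASS(insns[t].code) != BPF_JMP &&
        BPF_CLASS(insns[t].code) != BPF_JMP32)
            return push_insn(t, t + 1, FALLTHROUGH, env, false);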

> 
>> +               ret = push_insn(t, t + 1, FALLTHROUGH, env, false);
>> +               if (ret)
>> +                       return ret;
>> +
>> +               if (t + 1 < insn_cnt)
>> +                       init_explored_state(env, t + 1);
>> +               init_explored_state(env, t);
>> +               ret = push_insn(t, t + insns[t].imm + 1, BRANCH,
>> +                               env, false);
>> +               return ret;
>> +       }
>> +
>>          switch (BPF_OP(insns[t].code)) {
>>          case BPF_EXIT:
>>                  return DONE_EXPLORING;
> 
> [...]
>
Andrii Nakryiko Feb. 9, 2021, 5:33 p.m. UTC | #5
On Mon, Feb 8, 2021 at 10:41 PM Yonghong Song <yhs@fb.com> wrote:
>
>
>
> On 2/8/21 10:16 AM, Andrii Nakryiko wrote:
> > On Thu, Feb 4, 2021 at 5:53 PM Yonghong Song <yhs@fb.com> wrote:
> >>
> >> The bpf_for_each_map_elem() helper is introduced which
> >> iterates all map elements with a callback function. The
> >> helper signature looks like
> >>    long bpf_for_each_map_elem(map, callback_fn, callback_ctx, flags)
> >> and for each map element, the callback_fn will be called. For example,
> >> like hashmap, the callback signature may look like
> >>    long callback_fn(map, key, val, callback_ctx)
> >>
> >> There are two known use cases for this. One is from upstream ([1]) where
> >> a for_each_map_elem helper may help implement a timeout mechanism
> >> in a more generic way. Another is from our internal discussion
> >> for a firewall use case where a map contains all the rules. The packet
> >> data can be compared to all these rules to decide allow or deny
> >> the packet.
> >>
> >> For array maps, users can already use a bounded loop to traverse
> >> elements. Using this helper can avoid using bounded loop. For other
> >> type of maps (e.g., hash maps) where bounded loop is hard or
> >> impossible to use, this helper provides a convenient way to
> >> operate on all elements.
> >>
> >> For callback_fn, besides map and map element, a callback_ctx,
> >> allocated on caller stack, is also passed to the callback
> >> function. This callback_ctx argument can provide additional
> >> input and allow to write to caller stack for output.
> >>
> >> If the callback_fn returns 0, the helper will iterate through next
> >> element if available. If the callback_fn returns 1, the helper
> >> will stop iterating and returns to the bpf program. Other return
> >> values are not used for now.
> >>
> >> Currently, this helper is only available with jit. It is possible
> >> to make it work with interpreter with so effort but I leave it
> >> as the future work.
> >>
> >> [1]: https://lore.kernel.org/bpf/20210122205415.113822-1-xiyou.wangcong@gmail.com/
> >>
> >> Signed-off-by: Yonghong Song <yhs@fb.com>
> >> ---
> >
> > This is a great feature! Few questions and nits below.
> >
> >>   include/linux/bpf.h            |  14 ++
> >>   include/linux/bpf_verifier.h   |   3 +
> >>   include/uapi/linux/bpf.h       |  28 ++++
> >>   kernel/bpf/bpf_iter.c          |  16 +++
> >>   kernel/bpf/helpers.c           |   2 +
> >>   kernel/bpf/verifier.c          | 251 ++++++++++++++++++++++++++++++---
> >>   kernel/trace/bpf_trace.c       |   2 +
> >>   tools/include/uapi/linux/bpf.h |  28 ++++
> >>   8 files changed, 328 insertions(+), 16 deletions(-)
> >>
> >
> > [...]
> >
> >>   const struct bpf_func_proto *bpf_tracing_func_proto(
> >>          enum bpf_func_id func_id, const struct bpf_prog *prog);
> >> diff --git a/include/linux/bpf_verifier.h b/include/linux/bpf_verifier.h
> >> index dfe6f85d97dd..c4366b3da342 100644
> >> --- a/include/linux/bpf_verifier.h
> >> +++ b/include/linux/bpf_verifier.h
> >> @@ -68,6 +68,8 @@ struct bpf_reg_state {
> >>                          unsigned long raw1;
> >>                          unsigned long raw2;
> >>                  } raw;
> >> +
> >> +               u32 subprog; /* for PTR_TO_FUNC */
> >
> > is it an offset to the subprog (in bytes or instructions?) or is it the
> > subprog index? Let's make it clear with a better name or at least a comment.
>
> This is the subprog number (or an index into some subprog-related arrays).
> In verifier.c, subprog or subprogno is used to represent the subprog
> number. I will use subprogno in the next revision.
>

yeah, that's more clear
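(i.e., presumably just

    u32 subprogno; /* for PTR_TO_FUNC, subprog number within the prog */

in the next revision)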

> >
> >>          };
> >>          /* For PTR_TO_PACKET, used to find other pointers with the same variable
> >>           * offset, so they can share range knowledge.
> >> @@ -204,6 +206,7 @@ struct bpf_func_state {
> >>          int acquired_refs;
> >>          struct bpf_reference_state *refs;
> >>          int allocated_stack;
> >> +       bool with_callback_fn;
> >>          struct bpf_stack_state *stack;
> >>   };
> >>
> >> diff --git a/include/uapi/linux/bpf.h b/include/uapi/linux/bpf.h
> >> index c001766adcbc..d55bd4557376 100644
> >> --- a/include/uapi/linux/bpf.h
> >> +++ b/include/uapi/linux/bpf.h
> >> @@ -393,6 +393,15 @@ enum bpf_link_type {
> >>    *                   is struct/union.
> >>    */
> >>   #define BPF_PSEUDO_BTF_ID      3
> >> +/* insn[0].src_reg:  BPF_PSEUDO_FUNC
> >> + * insn[0].imm:      insn offset to the func
> >> + * insn[1].imm:      0
> >> + * insn[0].off:      0
> >> + * insn[1].off:      0
> >> + * ldimm64 rewrite:  address of the function
> >> + * verifier type:    PTR_TO_FUNC.
> >> + */
> >> +#define BPF_PSEUDO_FUNC                4
> >>
> >>   /* when bpf_call->src_reg == BPF_PSEUDO_CALL, bpf_call->imm == pc-relative
> >>    * offset to another bpf function
> >> @@ -3836,6 +3845,24 @@ union bpf_attr {
> >>    *     Return
> >>    *             A pointer to a struct socket on success or NULL if the file is
> >>    *             not a socket.
> >> + *
> >> + * long bpf_for_each_map_elem(struct bpf_map *map, void *callback_fn, void *callback_ctx, u64 flags)
> >
> > struct bpf_map * here might be problematic. In other instances where
> > we pass map (bpf_map_update_elem, for example) we specify this as
> > (void *). Let's do that instead here?
>
> We should be fine here. bpf_map_lookup_elem etc. all have "struct
> bpf_map *map"; it is rewritten by bpf_helpers_doc.py to "void *map".

ok, cool

>
> >
> >> + *     Description
> >> + *             For each element in **map**, call **callback_fn** function with
> >> + *             **map**, **callback_ctx** and other map-specific parameters.
> >> + *             For example, for hash and array maps, the callback signature can
> >> + *             be `u64 callback_fn(map, map_key, map_value, callback_ctx)`.
> >> + *             The **callback_fn** should be a static function and
> >> + *             the **callback_ctx** should be a pointer to the stack.
> >> + *             The **flags** is used to control certain aspects of the helper.
> >> + *             Currently, the **flags** must be 0.
> >> + *
> >> + *             If **callback_fn** returns 0, the helper will continue to the next
> >> + *             element. If the return value is 1, the helper will skip the rest of
> >> + *             the elements and return. Other return values are not used for now.
> >> + *     Return
> >> + *             0 for success, **-EINVAL** for invalid **flags** or unsupported
> >> + *             **callback_fn** return value.
> >
> > just a thought: returning the number of elements *actually* iterated
> > seems useful (even though I don't have a specific use case right now).
>
> Good idea. Will change to this in the next revision.
>
> >
> >>    */
> >>   #define __BPF_FUNC_MAPPER(FN)          \
> >>          FN(unspec),                     \
> >> @@ -4001,6 +4028,7 @@ union bpf_attr {
> >>          FN(ktime_get_coarse_ns),        \
> >>          FN(ima_inode_hash),             \
> >>          FN(sock_from_file),             \
> >> +       FN(for_each_map_elem),          \
> >
> > to be more in sync with other map operations, can we call this
> > `bpf_map_for_each_elem`? I think it makes sense and doesn't read
> > backwards at all.
>
> I am using the for_each_ prefix since in the future I (or others) may add
> more for_each_* helpers, e.g., for_each_task, for_each_hlist_rcu, etc.
> These represent a family of helpers with callback functions, so I
> would like to stick with for_each_* names.
>

fair enough, not a big deal

> >
> >>          /* */
> >>
> >>   /* integer value in 'imm' field of BPF_CALL instruction selects which helper
> >> diff --git a/kernel/bpf/bpf_iter.c b/kernel/bpf/bpf_iter.c
> >> index 5454161407f1..5187f49d3216 100644
> >> --- a/kernel/bpf/bpf_iter.c
> >> +++ b/kernel/bpf/bpf_iter.c
> >> @@ -675,3 +675,19 @@ int bpf_iter_run_prog(struct bpf_prog *prog, void *ctx)
> >>           */
> >>          return ret == 0 ? 0 : -EAGAIN;
> >>   }
> >> +
> >> +BPF_CALL_4(bpf_for_each_map_elem, struct bpf_map *, map, void *, callback_fn,
> >> +          void *, callback_ctx, u64, flags)
> >> +{
> >> +       return map->ops->map_for_each_callback(map, callback_fn, callback_ctx, flags);
> >> +}
> >> +
> >> +const struct bpf_func_proto bpf_for_each_map_elem_proto = {
> >> +       .func           = bpf_for_each_map_elem,
> >> +       .gpl_only       = false,
> >> +       .ret_type       = RET_INTEGER,
> >> +       .arg1_type      = ARG_CONST_MAP_PTR,
> >> +       .arg2_type      = ARG_PTR_TO_FUNC,
> >> +       .arg3_type      = ARG_PTR_TO_STACK_OR_NULL,
> >
> > I looked through this code just once but haven't noticed anything that
> > would strictly require that the pointer be specifically to the stack. Can this
> > be made into a pointer to any allocated memory? E.g., why can't we
> > allow passing a pointer to a ringbuf sample, for instance? Or
> > MAP_VALUE?
>
> ARG_PTR_TO_STACK_OR_NULL is the most flexible one. For example, if you
> want to pass a map_value or a ringbuf sample, you can store these pointers
> on the stack like
>     struct ctx_t {
>        struct map_value_t *map_value;
>        char *ringbuf_mem;
>     } tmp;
>     tmp.map_value = ...;
>     tmp.ringbuf_mem = ...;
>     bpf_for_each_map_elem(map, callback_fn, &tmp, flags);
> and callback_fn will be able to access map_value/ringbuf_mem
> with their original register types.
>
> This does not allow passing ringbuf/map_value etc. as
> first-class citizens. But I think this is a good compromise
> that permits greater flexibility.

Yeah, thanks for the explanation.

>
> >
> >> +       .arg4_type      = ARG_ANYTHING,
> >> +};
> >> diff --git a/kernel/bpf/helpers.c b/kernel/bpf/helpers.c
> >> index 308427fe03a3..074800226327 100644
> >> --- a/kernel/bpf/helpers.c
> >> +++ b/kernel/bpf/helpers.c
> >> @@ -708,6 +708,8 @@ bpf_base_func_proto(enum bpf_func_id func_id)
> >>                  return &bpf_ringbuf_discard_proto;
> >>          case BPF_FUNC_ringbuf_query:
> >>                  return &bpf_ringbuf_query_proto;
> >> +       case BPF_FUNC_for_each_map_elem:
> >> +               return &bpf_for_each_map_elem_proto;
> >>          default:
> >>                  break;
> >>          }
> >> diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c
> >> index db294b75d03b..050b067a0be6 100644
> >> --- a/kernel/bpf/verifier.c
> >> +++ b/kernel/bpf/verifier.c
> >> @@ -234,6 +234,12 @@ static bool bpf_pseudo_call(const struct bpf_insn *insn)
> >>                 insn->src_reg == BPF_PSEUDO_CALL;
> >>   }
> >>
> >
> > [...]
> >
> >>          map = env->used_maps[aux->map_index];
> >>          mark_reg_known_zero(env, regs, insn->dst_reg);
> >>          dst_reg->map_ptr = map;
> >> @@ -8195,9 +8361,23 @@ static int visit_insn(int t, int insn_cnt, struct bpf_verifier_env *env)
> >>
> >>          /* All non-branch instructions have a single fall-through edge. */
> >>          if (BPF_CLASS(insns[t].code) != BPF_JMP &&
> >> -           BPF_CLASS(insns[t].code) != BPF_JMP32)
> >> +           BPF_CLASS(insns[t].code) != BPF_JMP32 &&
> >> +           !bpf_pseudo_func(insns + t))
> >>                  return push_insn(t, t + 1, FALLTHROUGH, env, false);
> >>
> >> +       if (bpf_pseudo_func(insns + t)) {
> >
> >
> > if you check this before the JMP|JMP32 check above, you won't need the
> > !bpf_pseudo_func check, right? I think it's cleaner.
>
> Agree. will change in v2.
>
> >
> >> +               ret = push_insn(t, t + 1, FALLTHROUGH, env, false);
> >> +               if (ret)
> >> +                       return ret;
> >> +
> >> +               if (t + 1 < insn_cnt)
> >> +                       init_explored_state(env, t + 1);
> >> +               init_explored_state(env, t);
> >> +               ret = push_insn(t, t + insns[t].imm + 1, BRANCH,
> >> +                               env, false);
> >> +               return ret;
> >> +       }
> >> +
> >>          switch (BPF_OP(insns[t].code)) {
> >>          case BPF_EXIT:
> >>                  return DONE_EXPLORING;
> >
> > [...]
> >

Patch

diff --git a/include/linux/bpf.h b/include/linux/bpf.h
index 321966fc35db..c8b72ae16cc5 100644
--- a/include/linux/bpf.h
+++ b/include/linux/bpf.h
@@ -40,6 +40,7 @@  struct bpf_local_storage;
 struct bpf_local_storage_map;
 struct kobject;
 struct mem_cgroup;
+struct bpf_func_state;
 
 extern struct idr btf_idr;
 extern spinlock_t btf_idr_lock;
@@ -130,6 +131,13 @@  struct bpf_map_ops {
 	bool (*map_meta_equal)(const struct bpf_map *meta0,
 			       const struct bpf_map *meta1);
 
+
+	int (*map_set_for_each_callback_args)(struct bpf_verifier_env *env,
+					      struct bpf_func_state *caller,
+					      struct bpf_func_state *callee);
+	int (*map_for_each_callback)(struct bpf_map *map, void *callback_fn,
+				     void *callback_ctx, u64 flags);
+
 	/* BTF name and id of struct allocated by map_alloc */
 	const char * const map_btf_name;
 	int *map_btf_id;
@@ -296,6 +304,8 @@  enum bpf_arg_type {
 	ARG_CONST_ALLOC_SIZE_OR_ZERO,	/* number of allocated bytes requested */
 	ARG_PTR_TO_BTF_ID_SOCK_COMMON,	/* pointer to in-kernel sock_common or bpf-mirrored bpf_sock */
 	ARG_PTR_TO_PERCPU_BTF_ID,	/* pointer to in-kernel percpu type */
+	ARG_PTR_TO_FUNC,	/* pointer to a bpf program function */
+	ARG_PTR_TO_STACK_OR_NULL,	/* pointer to stack or NULL */
 	__BPF_ARG_TYPE_MAX,
 };
 
@@ -412,6 +422,8 @@  enum bpf_reg_type {
 	PTR_TO_RDWR_BUF,	 /* reg points to a read/write buffer */
 	PTR_TO_RDWR_BUF_OR_NULL, /* reg points to a read/write buffer or NULL */
 	PTR_TO_PERCPU_BTF_ID,	 /* reg points to a percpu kernel variable */
+	PTR_TO_FUNC,		 /* reg points to a bpf program function */
+	PTR_TO_MAP_KEY,		 /* reg points to map element key */
 };
 
 /* The information passed from prog-specific *_is_valid_access
@@ -794,6 +806,7 @@  struct bpf_prog_aux {
 	bool func_proto_unreliable;
 	bool sleepable;
 	bool tail_call_reachable;
+	bool with_callback_fn;
 	enum bpf_tramp_prog_type trampoline_prog_type;
 	struct hlist_node tramp_hlist;
 	/* BTF_KIND_FUNC_PROTO for valid attach_btf_id */
@@ -1888,6 +1901,7 @@  extern const struct bpf_func_proto bpf_per_cpu_ptr_proto;
 extern const struct bpf_func_proto bpf_this_cpu_ptr_proto;
 extern const struct bpf_func_proto bpf_ktime_get_coarse_ns_proto;
 extern const struct bpf_func_proto bpf_sock_from_file_proto;
+extern const struct bpf_func_proto bpf_for_each_map_elem_proto;
 
 const struct bpf_func_proto *bpf_tracing_func_proto(
 	enum bpf_func_id func_id, const struct bpf_prog *prog);
diff --git a/include/linux/bpf_verifier.h b/include/linux/bpf_verifier.h
index dfe6f85d97dd..c4366b3da342 100644
--- a/include/linux/bpf_verifier.h
+++ b/include/linux/bpf_verifier.h
@@ -68,6 +68,8 @@  struct bpf_reg_state {
 			unsigned long raw1;
 			unsigned long raw2;
 		} raw;
+
+		u32 subprog; /* for PTR_TO_FUNC */
 	};
 	/* For PTR_TO_PACKET, used to find other pointers with the same variable
 	 * offset, so they can share range knowledge.
@@ -204,6 +206,7 @@  struct bpf_func_state {
 	int acquired_refs;
 	struct bpf_reference_state *refs;
 	int allocated_stack;
+	bool with_callback_fn;
 	struct bpf_stack_state *stack;
 };
 
diff --git a/include/uapi/linux/bpf.h b/include/uapi/linux/bpf.h
index c001766adcbc..d55bd4557376 100644
--- a/include/uapi/linux/bpf.h
+++ b/include/uapi/linux/bpf.h
@@ -393,6 +393,15 @@  enum bpf_link_type {
  *                   is struct/union.
  */
 #define BPF_PSEUDO_BTF_ID	3
+/* insn[0].src_reg:  BPF_PSEUDO_FUNC
+ * insn[0].imm:      insn offset to the func
+ * insn[1].imm:      0
+ * insn[0].off:      0
+ * insn[1].off:      0
+ * ldimm64 rewrite:  address of the function
+ * verifier type:    PTR_TO_FUNC.
+ */
+#define BPF_PSEUDO_FUNC		4
 
 /* when bpf_call->src_reg == BPF_PSEUDO_CALL, bpf_call->imm == pc-relative
  * offset to another bpf function
@@ -3836,6 +3845,24 @@  union bpf_attr {
  *	Return
  *		A pointer to a struct socket on success or NULL if the file is
  *		not a socket.
+ *
+ * long bpf_for_each_map_elem(struct bpf_map *map, void *callback_fn, void *callback_ctx, u64 flags)
+ *	Description
+ *		For each element in **map**, call **callback_fn** function with
+ *		**map**, **callback_ctx** and other map-specific parameters.
+ *		For example, for hash and array maps, the callback signature can
+ *		be `u64 callback_fn(map, map_key, map_value, callback_ctx)`.
+ *		The **callback_fn** should be a static function and
+ *		the **callback_ctx** should be a pointer to the stack.
+ *		The **flags** is used to control certain aspects of the helper.
+ *		Currently, the **flags** must be 0.
+ *
+ *		If **callback_fn** returns 0, the helper will continue to the next
+ *		element. If the return value is 1, the helper will skip the rest of
+ *		the elements and return. Other return values are not used for now.
+ *	Return
+ *		0 for success, **-EINVAL** for invalid **flags** or unsupported
+ *		**callback_fn** return value.
  */
 #define __BPF_FUNC_MAPPER(FN)		\
 	FN(unspec),			\
@@ -4001,6 +4028,7 @@  union bpf_attr {
 	FN(ktime_get_coarse_ns),	\
 	FN(ima_inode_hash),		\
 	FN(sock_from_file),		\
+	FN(for_each_map_elem),		\
 	/* */
 
 /* integer value in 'imm' field of BPF_CALL instruction selects which helper
diff --git a/kernel/bpf/bpf_iter.c b/kernel/bpf/bpf_iter.c
index 5454161407f1..5187f49d3216 100644
--- a/kernel/bpf/bpf_iter.c
+++ b/kernel/bpf/bpf_iter.c
@@ -675,3 +675,19 @@  int bpf_iter_run_prog(struct bpf_prog *prog, void *ctx)
 	 */
 	return ret == 0 ? 0 : -EAGAIN;
 }
+
+BPF_CALL_4(bpf_for_each_map_elem, struct bpf_map *, map, void *, callback_fn,
+	   void *, callback_ctx, u64, flags)
+{
+	return map->ops->map_for_each_callback(map, callback_fn, callback_ctx, flags);
+}
+
+const struct bpf_func_proto bpf_for_each_map_elem_proto = {
+	.func		= bpf_for_each_map_elem,
+	.gpl_only	= false,
+	.ret_type	= RET_INTEGER,
+	.arg1_type	= ARG_CONST_MAP_PTR,
+	.arg2_type	= ARG_PTR_TO_FUNC,
+	.arg3_type	= ARG_PTR_TO_STACK_OR_NULL,
+	.arg4_type	= ARG_ANYTHING,
+};
diff --git a/kernel/bpf/helpers.c b/kernel/bpf/helpers.c
index 308427fe03a3..074800226327 100644
--- a/kernel/bpf/helpers.c
+++ b/kernel/bpf/helpers.c
@@ -708,6 +708,8 @@  bpf_base_func_proto(enum bpf_func_id func_id)
 		return &bpf_ringbuf_discard_proto;
 	case BPF_FUNC_ringbuf_query:
 		return &bpf_ringbuf_query_proto;
+	case BPF_FUNC_for_each_map_elem:
+		return &bpf_for_each_map_elem_proto;
 	default:
 		break;
 	}
diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c
index db294b75d03b..050b067a0be6 100644
--- a/kernel/bpf/verifier.c
+++ b/kernel/bpf/verifier.c
@@ -234,6 +234,12 @@  static bool bpf_pseudo_call(const struct bpf_insn *insn)
 	       insn->src_reg == BPF_PSEUDO_CALL;
 }
 
+static bool bpf_pseudo_func(const struct bpf_insn *insn)
+{
+	return insn->code == (BPF_LD | BPF_IMM | BPF_DW) &&
+	       insn->src_reg == BPF_PSEUDO_FUNC;
+}
+
 struct bpf_call_arg_meta {
 	struct bpf_map *map_ptr;
 	bool raw_mode;
@@ -409,6 +415,7 @@  static bool reg_type_not_null(enum bpf_reg_type type)
 	return type == PTR_TO_SOCKET ||
 		type == PTR_TO_TCP_SOCK ||
 		type == PTR_TO_MAP_VALUE ||
+		type == PTR_TO_MAP_KEY ||
 		type == PTR_TO_SOCK_COMMON;
 }
 
@@ -451,7 +458,8 @@  static bool arg_type_may_be_null(enum bpf_arg_type type)
 	       type == ARG_PTR_TO_MEM_OR_NULL ||
 	       type == ARG_PTR_TO_CTX_OR_NULL ||
 	       type == ARG_PTR_TO_SOCKET_OR_NULL ||
-	       type == ARG_PTR_TO_ALLOC_MEM_OR_NULL;
+	       type == ARG_PTR_TO_ALLOC_MEM_OR_NULL ||
+	       type == ARG_PTR_TO_STACK_OR_NULL;
 }
 
 /* Determine whether the function releases some resources allocated by another
@@ -534,6 +542,8 @@  static const char * const reg_type_str[] = {
 	[PTR_TO_RDONLY_BUF_OR_NULL] = "rdonly_buf_or_null",
 	[PTR_TO_RDWR_BUF]	= "rdwr_buf",
 	[PTR_TO_RDWR_BUF_OR_NULL] = "rdwr_buf_or_null",
+	[PTR_TO_FUNC]		= "func",
+	[PTR_TO_MAP_KEY]	= "map_key",
 };
 
 static char slot_type_char[] = {
@@ -605,6 +615,7 @@  static void print_verifier_state(struct bpf_verifier_env *env,
 			if (type_is_pkt_pointer(t))
 				verbose(env, ",r=%d", reg->range);
 			else if (t == CONST_PTR_TO_MAP ||
+				 t == PTR_TO_MAP_KEY ||
 				 t == PTR_TO_MAP_VALUE ||
 				 t == PTR_TO_MAP_VALUE_OR_NULL)
 				verbose(env, ",ks=%d,vs=%d",
@@ -1492,6 +1503,19 @@  static int check_subprogs(struct bpf_verifier_env *env)
 
 	/* determine subprog starts. The end is one before the next starts */
 	for (i = 0; i < insn_cnt; i++) {
+		if (bpf_pseudo_func(insn + i)) {
+			if (!env->bpf_capable) {
+				verbose(env,
+					"function pointers are allowed for CAP_BPF and CAP_SYS_ADMIN\n");
+				return -EPERM;
+			}
+			ret = add_subprog(env, i + insn[i].imm + 1);
+			if (ret < 0)
+				return ret;
+			/* remember subprog */
+			insn[i + 1].imm = find_subprog(env, i + insn[i].imm + 1);
+			continue;
+		}
 		if (!bpf_pseudo_call(insn + i))
 			continue;
 		if (!env->bpf_capable) {
@@ -2223,6 +2247,8 @@  static bool is_spillable_regtype(enum bpf_reg_type type)
 	case PTR_TO_PERCPU_BTF_ID:
 	case PTR_TO_MEM:
 	case PTR_TO_MEM_OR_NULL:
+	case PTR_TO_FUNC:
+	case PTR_TO_MAP_KEY:
 		return true;
 	default:
 		return false;
@@ -2567,6 +2593,10 @@  static int __check_mem_access(struct bpf_verifier_env *env, int regno,
 
 	reg = &cur_regs(env)[regno];
 	switch (reg->type) {
+	case PTR_TO_MAP_KEY:
+		verbose(env, "invalid access to map key, key_size=%d off=%d size=%d\n",
+			mem_size, off, size);
+		break;
 	case PTR_TO_MAP_VALUE:
 		verbose(env, "invalid access to map value, value_size=%d off=%d size=%d\n",
 			mem_size, off, size);
@@ -2977,6 +3007,9 @@  static int check_ptr_alignment(struct bpf_verifier_env *env,
 	case PTR_TO_FLOW_KEYS:
 		pointer_desc = "flow keys ";
 		break;
+	case PTR_TO_MAP_KEY:
+		pointer_desc = "key ";
+		break;
 	case PTR_TO_MAP_VALUE:
 		pointer_desc = "value ";
 		break;
@@ -3078,7 +3111,7 @@  static int check_max_stack_depth(struct bpf_verifier_env *env)
 continue_func:
 	subprog_end = subprog[idx + 1].start;
 	for (; i < subprog_end; i++) {
-		if (!bpf_pseudo_call(insn + i))
+		if (!bpf_pseudo_call(insn + i) && !bpf_pseudo_func(insn + i))
 			continue;
 		/* remember insn and function to return to */
 		ret_insn[frame] = i + 1;
@@ -3430,7 +3463,19 @@  static int check_mem_access(struct bpf_verifier_env *env, int insn_idx, u32 regn
 	/* for access checks, reg->off is just part of off */
 	off += reg->off;
 
-	if (reg->type == PTR_TO_MAP_VALUE) {
+	if (reg->type == PTR_TO_MAP_KEY) {
+		if (t == BPF_WRITE) {
+			verbose(env, "write to change key R%d not allowed\n", regno);
+			return -EACCES;
+		}
+
+		err = check_mem_region_access(env, regno, off, size,
+					      reg->map_ptr->key_size, false);
+		if (err)
+			return err;
+		if (value_regno >= 0)
+			mark_reg_unknown(env, regs, value_regno);
+	} else if (reg->type == PTR_TO_MAP_VALUE) {
 		if (t == BPF_WRITE && value_regno >= 0 &&
 		    is_pointer_value(env, value_regno)) {
 			verbose(env, "R%d leaks addr into map\n", value_regno);
@@ -3858,6 +3903,9 @@  static int check_helper_mem_access(struct bpf_verifier_env *env, int regno,
 	case PTR_TO_PACKET_META:
 		return check_packet_access(env, regno, reg->off, access_size,
 					   zero_size_allowed);
+	case PTR_TO_MAP_KEY:
+		return check_mem_region_access(env, regno, reg->off, access_size,
+					       reg->map_ptr->key_size, false);
 	case PTR_TO_MAP_VALUE:
 		if (check_map_access_type(env, regno, reg->off, access_size,
 					  meta && meta->raw_mode ? BPF_WRITE :
@@ -4049,6 +4097,7 @@  static const struct bpf_reg_types map_key_value_types = {
 		PTR_TO_STACK,
 		PTR_TO_PACKET,
 		PTR_TO_PACKET_META,
+		PTR_TO_MAP_KEY,
 		PTR_TO_MAP_VALUE,
 	},
 };
@@ -4080,6 +4129,7 @@  static const struct bpf_reg_types mem_types = {
 		PTR_TO_STACK,
 		PTR_TO_PACKET,
 		PTR_TO_PACKET_META,
+		PTR_TO_MAP_KEY,
 		PTR_TO_MAP_VALUE,
 		PTR_TO_MEM,
 		PTR_TO_RDONLY_BUF,
@@ -4092,6 +4142,7 @@  static const struct bpf_reg_types int_ptr_types = {
 		PTR_TO_STACK,
 		PTR_TO_PACKET,
 		PTR_TO_PACKET_META,
+		PTR_TO_MAP_KEY,
 		PTR_TO_MAP_VALUE,
 	},
 };
@@ -4104,6 +4155,8 @@  static const struct bpf_reg_types const_map_ptr_types = { .types = { CONST_PTR_T
 static const struct bpf_reg_types btf_ptr_types = { .types = { PTR_TO_BTF_ID } };
 static const struct bpf_reg_types spin_lock_types = { .types = { PTR_TO_MAP_VALUE } };
 static const struct bpf_reg_types percpu_btf_ptr_types = { .types = { PTR_TO_PERCPU_BTF_ID } };
+static const struct bpf_reg_types func_ptr_types = { .types = { PTR_TO_FUNC } };
+static const struct bpf_reg_types stack_ptr_types = { .types = { PTR_TO_STACK } };
 
 static const struct bpf_reg_types *compatible_reg_types[__BPF_ARG_TYPE_MAX] = {
 	[ARG_PTR_TO_MAP_KEY]		= &map_key_value_types,
@@ -4132,6 +4185,8 @@  static const struct bpf_reg_types *compatible_reg_types[__BPF_ARG_TYPE_MAX] = {
 	[ARG_PTR_TO_INT]		= &int_ptr_types,
 	[ARG_PTR_TO_LONG]		= &int_ptr_types,
 	[ARG_PTR_TO_PERCPU_BTF_ID]	= &percpu_btf_ptr_types,
+	[ARG_PTR_TO_FUNC]		= &func_ptr_types,
+	[ARG_PTR_TO_STACK_OR_NULL]	= &stack_ptr_types,
 };
 
 static int check_reg_type(struct bpf_verifier_env *env, u32 regno,
@@ -4932,12 +4987,92 @@  static int check_func_call(struct bpf_verifier_env *env, struct bpf_insn *insn,
 	return 0;
 }
 
+static int check_map_elem_callback(struct bpf_verifier_env *env, int *insn_idx)
+{
+	struct bpf_verifier_state *state = env->cur_state;
+	struct bpf_prog_aux *aux = env->prog->aux;
+	struct bpf_func_state *caller, *callee;
+	struct bpf_map *map;
+	int err, subprog;
+
+	if (state->curframe + 1 >= MAX_CALL_FRAMES) {
+		verbose(env, "the call stack of %d frames is too deep\n",
+			state->curframe + 2);
+		return -E2BIG;
+	}
+
+	caller = state->frame[state->curframe];
+	if (state->frame[state->curframe + 1]) {
+		verbose(env, "verifier bug. Frame %d already allocated\n",
+			state->curframe + 1);
+		return -EFAULT;
+	}
+
+	caller->with_callback_fn = true;
+	subprog = caller->regs[BPF_REG_2].subprog;
+
+	callee = kzalloc(sizeof(*callee), GFP_KERNEL);
+	if (!callee)
+		return -ENOMEM;
+	state->frame[state->curframe + 1] = callee;
+
+	/* callee cannot access r0, r6 - r9 for reading and has to write
+	 * into its own stack before reading from it.
+	 * callee can read/write into caller's stack
+	 */
+	init_func_state(env, callee,
+			/* remember the callsite, it will be used by bpf_exit */
+			*insn_idx /* callsite */,
+			state->curframe + 1 /* frameno within this callchain */,
+			subprog /* subprog number within this prog */);
+
+	/* Transfer references to the callee */
+	err = transfer_reference_state(callee, caller);
+	if (err)
+		return err;
+
+	if (aux->func_info && aux->func_info_aux[subprog].linkage != BTF_FUNC_STATIC) {
+		verbose(env, "callback function R2 not static\n");
+		return -EINVAL;
+	}
+
+	map = caller->regs[BPF_REG_1].map_ptr;
+	if (!map->ops->map_set_for_each_callback_args ||
+	    !map->ops->map_for_each_callback) {
+		verbose(env, "callback function not allowed for map R1\n");
+		return -ENOTSUPP;
+	}
+
+	/* different maps can have different callback signatures, so the
+	 * callback args setup is delegated to map-specific ops.
+	 */
+	err = map->ops->map_set_for_each_callback_args(env, caller, callee);
+	if (err)
+		return err;
+
+	clear_caller_saved_regs(env, caller->regs);
+
+	/* only increment it after check_reg_arg() finished */
+	state->curframe++;
+
+	/* and go analyze first insn of the callee */
+	*insn_idx = env->subprog_info[subprog].start - 1;
+
+	if (env->log.level & BPF_LOG_LEVEL) {
+		verbose(env, "caller:\n");
+		print_verifier_state(env, caller);
+		verbose(env, "callee:\n");
+		print_verifier_state(env, callee);
+	}
+	return 0;
+}
+
 static int prepare_func_exit(struct bpf_verifier_env *env, int *insn_idx)
 {
 	struct bpf_verifier_state *state = env->cur_state;
 	struct bpf_func_state *caller, *callee;
 	struct bpf_reg_state *r0;
-	int err;
+	int i, err;
 
 	callee = state->frame[state->curframe];
 	r0 = &callee->regs[BPF_REG_0];
@@ -4955,7 +5090,17 @@  static int prepare_func_exit(struct bpf_verifier_env *env, int *insn_idx)
 	state->curframe--;
 	caller = state->frame[state->curframe];
 	/* return to the caller whatever r0 had in the callee */
-	caller->regs[BPF_REG_0] = *r0;
+	if (caller->with_callback_fn) {
+		/* reset caller saved regs, the helper calling callback_fn
+		 * has RET_INTEGER return types.
+		 */
+		for (i = 0; i < CALLER_SAVED_REGS; i++)
+			mark_reg_not_init(env, caller->regs, caller_saved[i]);
+		caller->regs[BPF_REG_0].subreg_def = DEF_NOT_SUBREG;
+		mark_reg_unknown(env, caller->regs, BPF_REG_0);
+	} else {
+		caller->regs[BPF_REG_0] = *r0;
+	}
 
 	/* Transfer references to the caller */
 	err = transfer_reference_state(caller, callee);
@@ -5091,7 +5236,8 @@  static int check_reference_leak(struct bpf_verifier_env *env)
 	return state->acquired_refs ? -EINVAL : 0;
 }
 
-static int check_helper_call(struct bpf_verifier_env *env, int func_id, int insn_idx)
+static int check_helper_call(struct bpf_verifier_env *env, int func_id, int *insn_idx,
+			     bool map_elem_callback)
 {
 	const struct bpf_func_proto *fn = NULL;
 	struct bpf_reg_state *regs;
@@ -5151,11 +5297,11 @@  static int check_helper_call(struct bpf_verifier_env *env, int func_id, int insn
 			return err;
 	}
 
-	err = record_func_map(env, &meta, func_id, insn_idx);
+	err = record_func_map(env, &meta, func_id, *insn_idx);
 	if (err)
 		return err;
 
-	err = record_func_key(env, &meta, func_id, insn_idx);
+	err = record_func_key(env, &meta, func_id, *insn_idx);
 	if (err)
 		return err;
 
@@ -5163,7 +5309,7 @@  static int check_helper_call(struct bpf_verifier_env *env, int func_id, int insn
 	 * is inferred from register state.
 	 */
 	for (i = 0; i < meta.access_size; i++) {
-		err = check_mem_access(env, insn_idx, meta.regno, i, BPF_B,
+		err = check_mem_access(env, *insn_idx, meta.regno, i, BPF_B,
 				       BPF_WRITE, -1, false);
 		if (err)
 			return err;
@@ -5195,6 +5341,11 @@  static int check_helper_call(struct bpf_verifier_env *env, int func_id, int insn
 		return -EINVAL;
 	}
 
+	if (map_elem_callback) {
+		env->prog->aux->with_callback_fn = true;
+		return check_map_elem_callback(env, insn_idx);
+	}
+
 	/* reset caller saved regs */
 	for (i = 0; i < CALLER_SAVED_REGS; i++) {
 		mark_reg_not_init(env, regs, caller_saved[i]);
@@ -5306,7 +5457,7 @@  static int check_helper_call(struct bpf_verifier_env *env, int func_id, int insn
 		/* For release_reference() */
 		regs[BPF_REG_0].ref_obj_id = meta.ref_obj_id;
 	} else if (is_acquire_function(func_id, meta.map_ptr)) {
-		int id = acquire_reference_state(env, insn_idx);
+		int id = acquire_reference_state(env, *insn_idx);
 
 		if (id < 0)
 			return id;
@@ -5448,6 +5599,14 @@  static int retrieve_ptr_limit(const struct bpf_reg_state *ptr_reg,
 		else
 			*ptr_limit = -off;
 		return 0;
+	case PTR_TO_MAP_KEY:
+		if (mask_to_left) {
+			*ptr_limit = ptr_reg->umax_value + ptr_reg->off;
+		} else {
+			off = ptr_reg->smin_value + ptr_reg->off;
+			*ptr_limit = ptr_reg->map_ptr->key_size - off;
+		}
+		return 0;
 	case PTR_TO_MAP_VALUE:
 		if (mask_to_left) {
 			*ptr_limit = ptr_reg->umax_value + ptr_reg->off;
@@ -5614,6 +5773,7 @@  static int adjust_ptr_min_max_vals(struct bpf_verifier_env *env,
 		verbose(env, "R%d pointer arithmetic on %s prohibited\n",
 			dst, reg_type_str[ptr_reg->type]);
 		return -EACCES;
+	case PTR_TO_MAP_KEY:
 	case PTR_TO_MAP_VALUE:
 		if (!env->allow_ptr_leaks && !known && (smin_val < 0) != (smax_val < 0)) {
 			verbose(env, "R%d has unknown scalar with mixed signed bounds, pointer arithmetic with it prohibited for !root\n",
@@ -7818,6 +7978,12 @@  static int check_ld_imm(struct bpf_verifier_env *env, struct bpf_insn *insn)
 		return 0;
 	}
 
+	if (insn->src_reg == BPF_PSEUDO_FUNC) {
+		dst_reg->type = PTR_TO_FUNC;
+		dst_reg->subprog = insn[1].imm;
+		return 0;
+	}
+
 	map = env->used_maps[aux->map_index];
 	mark_reg_known_zero(env, regs, insn->dst_reg);
 	dst_reg->map_ptr = map;
@@ -8195,9 +8361,23 @@  static int visit_insn(int t, int insn_cnt, struct bpf_verifier_env *env)
 
 	/* All non-branch instructions have a single fall-through edge. */
 	if (BPF_CLASS(insns[t].code) != BPF_JMP &&
-	    BPF_CLASS(insns[t].code) != BPF_JMP32)
+	    BPF_CLASS(insns[t].code) != BPF_JMP32 &&
+	    !bpf_pseudo_func(insns + t))
 		return push_insn(t, t + 1, FALLTHROUGH, env, false);
 
+	if (bpf_pseudo_func(insns + t)) {
+		ret = push_insn(t, t + 1, FALLTHROUGH, env, false);
+		if (ret)
+			return ret;
+
+		if (t + 1 < insn_cnt)
+			init_explored_state(env, t + 1);
+		init_explored_state(env, t);
+		ret = push_insn(t, t + insns[t].imm + 1, BRANCH,
+				env, false);
+		return ret;
+	}
+
 	switch (BPF_OP(insns[t].code)) {
 	case BPF_EXIT:
 		return DONE_EXPLORING;
@@ -8819,6 +8999,7 @@  static bool regsafe(struct bpf_reg_state *rold, struct bpf_reg_state *rcur,
 			 */
 			return false;
 		}
+	case PTR_TO_MAP_KEY:
 	case PTR_TO_MAP_VALUE:
 		/* If the new min/max/var_off satisfy the old ones and
 		 * everything else matches, we are OK.
@@ -9646,6 +9827,8 @@  static int do_check(struct bpf_verifier_env *env)
 
 			env->jmps_processed++;
 			if (opcode == BPF_CALL) {
+				bool map_elem_callback;
+
 				if (BPF_SRC(insn->code) != BPF_K ||
 				    insn->off != 0 ||
 				    (insn->src_reg != BPF_REG_0 &&
@@ -9662,13 +9845,15 @@  static int do_check(struct bpf_verifier_env *env)
 					verbose(env, "function calls are not allowed while holding a lock\n");
 					return -EINVAL;
 				}
+				map_elem_callback = insn->src_reg != BPF_PSEUDO_CALL &&
+						   insn->imm == BPF_FUNC_for_each_map_elem;
 				if (insn->src_reg == BPF_PSEUDO_CALL)
 					err = check_func_call(env, insn, &env->insn_idx);
 				else
-					err = check_helper_call(env, insn->imm, env->insn_idx);
+					err = check_helper_call(env, insn->imm, &env->insn_idx,
+								map_elem_callback);
 				if (err)
 					return err;
-
 			} else if (opcode == BPF_JA) {
 				if (BPF_SRC(insn->code) != BPF_K ||
 				    insn->imm != 0 ||
@@ -10090,6 +10275,12 @@  static int resolve_pseudo_ldimm64(struct bpf_verifier_env *env)
 				goto next_insn;
 			}
 
+			if (insn[0].src_reg == BPF_PSEUDO_FUNC) {
+				aux = &env->insn_aux_data[i];
+				aux->ptr_type = PTR_TO_FUNC;
+				goto next_insn;
+			}
+
 			/* In final convert_pseudo_ld_imm64() step, this is
 			 * converted into regular 64-bit imm load insn.
 			 */
@@ -10222,9 +10413,13 @@  static void convert_pseudo_ld_imm64(struct bpf_verifier_env *env)
 	int insn_cnt = env->prog->len;
 	int i;
 
-	for (i = 0; i < insn_cnt; i++, insn++)
-		if (insn->code == (BPF_LD | BPF_IMM | BPF_DW))
-			insn->src_reg = 0;
+	for (i = 0; i < insn_cnt; i++, insn++) {
+		if (insn->code != (BPF_LD | BPF_IMM | BPF_DW))
+			continue;
+		if (insn->src_reg == BPF_PSEUDO_FUNC)
+			continue;
+		insn->src_reg = 0;
+	}
 }
 
 /* single env->prog->insni[off] instruction was replaced with the range
@@ -10846,6 +11041,12 @@  static int jit_subprogs(struct bpf_verifier_env *env)
 		return 0;
 
 	for (i = 0, insn = prog->insnsi; i < prog->len; i++, insn++) {
+		if (bpf_pseudo_func(insn)) {
+			env->insn_aux_data[i].call_imm = insn->imm;
+			/* subprog is encoded in insn[1].imm */
+			continue;
+		}
+
 		if (!bpf_pseudo_call(insn))
 			continue;
 		/* Upon error here we cannot fall back to interpreter but
@@ -10975,6 +11176,12 @@  static int jit_subprogs(struct bpf_verifier_env *env)
 	for (i = 0; i < env->subprog_cnt; i++) {
 		insn = func[i]->insnsi;
 		for (j = 0; j < func[i]->len; j++, insn++) {
+			if (bpf_pseudo_func(insn)) {
+				subprog = insn[1].imm;
+				insn[0].imm = (u32)(long)func[subprog]->bpf_func;
+				insn[1].imm = ((u64)(long)func[subprog]->bpf_func) >> 32;
+				continue;
+			}
 			if (!bpf_pseudo_call(insn))
 				continue;
 			subprog = insn->off;
@@ -11020,6 +11227,11 @@  static int jit_subprogs(struct bpf_verifier_env *env)
 	 * later look the same as if they were interpreted only.
 	 */
 	for (i = 0, insn = prog->insnsi; i < prog->len; i++, insn++) {
+		if (bpf_pseudo_func(insn)) {
+			insn[0].imm = env->insn_aux_data[i].call_imm;
+			insn[1].imm = find_subprog(env, i + insn[0].imm + 1);
+			continue;
+		}
 		if (!bpf_pseudo_call(insn))
 			continue;
 		insn->off = env->insn_aux_data[i].call_imm;
@@ -11083,6 +11295,13 @@  static int fixup_call_args(struct bpf_verifier_env *env)
 		verbose(env, "tail_calls are not allowed in non-JITed programs with bpf-to-bpf calls\n");
 		return -EINVAL;
 	}
+	if (env->subprog_cnt > 1 && env->prog->aux->with_callback_fn) {
+		/* When JIT fails the progs with callback calls
+		 * have to be rejected, since interpreter doesn't support them yet.
+		 */
+		verbose(env, "callbacks are not allowed in non-JITed programs\n");
+		return -EINVAL;
+	}
 	for (i = 0; i < prog->len; i++, insn++) {
 		if (!bpf_pseudo_call(insn))
 			continue;
diff --git a/kernel/trace/bpf_trace.c b/kernel/trace/bpf_trace.c
index 6c0018abe68a..8338333bfeb0 100644
--- a/kernel/trace/bpf_trace.c
+++ b/kernel/trace/bpf_trace.c
@@ -1366,6 +1366,8 @@  bpf_tracing_func_proto(enum bpf_func_id func_id, const struct bpf_prog *prog)
 		return &bpf_per_cpu_ptr_proto;
 	case BPF_FUNC_this_cpu_ptr:
 		return &bpf_this_cpu_ptr_proto;
+	case BPF_FUNC_for_each_map_elem:
+		return &bpf_for_each_map_elem_proto;
 	default:
 		return NULL;
 	}
diff --git a/tools/include/uapi/linux/bpf.h b/tools/include/uapi/linux/bpf.h
index c001766adcbc..d55bd4557376 100644
--- a/tools/include/uapi/linux/bpf.h
+++ b/tools/include/uapi/linux/bpf.h
@@ -393,6 +393,15 @@  enum bpf_link_type {
  *                   is struct/union.
  */
 #define BPF_PSEUDO_BTF_ID	3
+/* insn[0].src_reg:  BPF_PSEUDO_FUNC
+ * insn[0].imm:      insn offset to the func
+ * insn[1].imm:      0
+ * insn[0].off:      0
+ * insn[1].off:      0
+ * ldimm64 rewrite:  address of the function
+ * verifier type:    PTR_TO_FUNC.
+ */
+#define BPF_PSEUDO_FUNC		4
 
 /* when bpf_call->src_reg == BPF_PSEUDO_CALL, bpf_call->imm == pc-relative
  * offset to another bpf function
@@ -3836,6 +3845,24 @@  union bpf_attr {
  *	Return
  *		A pointer to a struct socket on success or NULL if the file is
  *		not a socket.
+ *
+ * long bpf_for_each_map_elem(struct bpf_map *map, void *callback_fn, void *callback_ctx, u64 flags)
+ *	Description
+ *		For each element in **map**, call **callback_fn** function with
+ *		**map**, **callback_ctx** and other map-specific parameters.
+ *		For example, for hash and array maps, the callback signature can
+ *		be `u64 callback_fn(map, map_key, map_value, callback_ctx)`.
+ *		The **callback_fn** should be a static function and
+ *		the **callback_ctx** should be a pointer to the stack.
+ *		The **flags** is used to control certain aspects of the helper.
+ *		Currently, the **flags** must be 0.
+ *
+ *		If **callback_fn** returns 0, the helper will continue to the next
+ *		element. If the return value is 1, the helper will skip the rest of
+ *		the elements and return. Other return values are not used for now.
+ *	Return
+ *		0 for success, **-EINVAL** for invalid **flags** or unsupported
+ *		**callback_fn** return value.
  */
 #define __BPF_FUNC_MAPPER(FN)		\
 	FN(unspec),			\
@@ -4001,6 +4028,7 @@  union bpf_attr {
 	FN(ktime_get_coarse_ns),	\
 	FN(ima_inode_hash),		\
 	FN(sock_from_file),		\
+	FN(for_each_map_elem),		\
 	/* */
 
 /* integer value in 'imm' field of BPF_CALL instruction selects which helper