
[RFC,bpf-next,v3,04/16] bpf/helpers: introduce sleepable bpf_timers

Message ID 20240221-hid-bpf-sleepable-v3-4-1fb378ca6301@kernel.org (mailing list archive)
State New
Series sleepable bpf_timer (was: allow HID-BPF to do device IOs)

Commit Message

Benjamin Tissoires Feb. 21, 2024, 4:25 p.m. UTC
Sleepable timers are implemented on top of a workqueue, which means that
there are no guarantees of timing or ordering.
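
From a BPF program, the intended usage looks roughly like this (sketch only,
with made-up map/callback names; the helpers are the existing bpf_timer
ones, and the nsecs argument is currently ignored when the sleepable flag
is set):

	bpf_timer_init(&elem->timer, &timer_map, CLOCK_MONOTONIC);
	bpf_timer_set_callback(&elem->timer, sleepable_cb);
	/* offloaded to a workqueue: no timing or ordering guarantee */
	bpf_timer_start(&elem->timer, 0, BPF_F_TIMER_SLEEPABLE);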

Signed-off-by: Benjamin Tissoires <bentiss@kernel.org>

---

changes in v3:
- extracted the implementation in bpf_timer only, without
  bpf_timer_set_sleepable_cb()
- rely on schedule_work() only, from bpf_timer_start()
- add semaphore to ensure bpf_timer_work_cb() is accessing
  consistent data

changes in v2 (compared to the one attached to v1 0/9):
- make use of a kfunc
- add a (non-used) BPF_F_TIMER_SLEEPABLE
- the callback is *not* called, as it makes the kernel crash
---
 include/uapi/linux/bpf.h |  4 +++
 kernel/bpf/helpers.c     | 92 ++++++++++++++++++++++++++++++++++++++++--------
 2 files changed, 82 insertions(+), 14 deletions(-)

Comments

Benjamin Tissoires Feb. 22, 2024, 8:05 a.m. UTC | #1
On Feb 21 2024, Benjamin Tissoires wrote:
> They are implemented as a workqueue, which means that there are no
> guarantees of timing nor ordering.
> 
> Signed-off-by: Benjamin Tissoires <bentiss@kernel.org>
> 
> ---
> 
> changes in v3:
> - extracted the implementation in bpf_timer only, without
>   bpf_timer_set_sleepable_cb()
> - rely on schedule_work() only, from bpf_timer_start()
> - add semaphore to ensure bpf_timer_work_cb() is accessing
>   consistent data
> 
> changes in v2 (compared to the one attaches to v1 0/9):
> - make use of a kfunc
> - add a (non-used) BPF_F_TIMER_SLEEPABLE
> - the callback is *not* called, it makes the kernel crashes
> ---
>  include/uapi/linux/bpf.h |  4 +++
>  kernel/bpf/helpers.c     | 92 ++++++++++++++++++++++++++++++++++++++++--------
>  2 files changed, 82 insertions(+), 14 deletions(-)
> 
> diff --git a/include/uapi/linux/bpf.h b/include/uapi/linux/bpf.h
> index d96708380e52..1fc7ecbd9d33 100644
> --- a/include/uapi/linux/bpf.h
> +++ b/include/uapi/linux/bpf.h
> @@ -7421,10 +7421,14 @@ struct bpf_core_relo {
>   *     - BPF_F_TIMER_ABS: Timeout passed is absolute time, by default it is
>   *       relative to current time.
>   *     - BPF_F_TIMER_CPU_PIN: Timer will be pinned to the CPU of the caller.
> + *     - BPF_F_TIMER_SLEEPABLE: Timer will run in a sleepable context, with
> + *       no guarantees of ordering nor timing (consider this as being just
> + *       offloaded immediately).
>   */
>  enum {
>  	BPF_F_TIMER_ABS = (1ULL << 0),
>  	BPF_F_TIMER_CPU_PIN = (1ULL << 1),
> +	BPF_F_TIMER_SLEEPABLE = (1ULL << 2),
>  };
>  
>  /* BPF numbers iterator state */
> diff --git a/kernel/bpf/helpers.c b/kernel/bpf/helpers.c
> index 93edf730d288..f9add0abe40a 100644
> --- a/kernel/bpf/helpers.c
> +++ b/kernel/bpf/helpers.c
> @@ -23,6 +23,7 @@
>  #include <linux/btf_ids.h>
>  #include <linux/bpf_mem_alloc.h>
>  #include <linux/kasan.h>
> +#include <linux/semaphore.h>
>  
>  #include "../../lib/kstrtox.h"
>  
> @@ -1094,13 +1095,19 @@ const struct bpf_func_proto bpf_snprintf_proto = {
>   * bpf_timer_cancel() cancels the timer and decrements prog's refcnt.
>   * Inner maps can contain bpf timers as well. ops->map_release_uref is
>   * freeing the timers when inner map is replaced or deleted by user space.
> + *
> + * sleepable_lock protects only the setup of the workqueue, not the callback
> + * itself. This is done to ensure we don't run concurrently a free of the
> + * callback or the associated program.
>   */
>  struct bpf_hrtimer {
>  	struct hrtimer timer;
> +	struct work_struct work;
>  	struct bpf_map *map;
>  	struct bpf_prog *prog;
>  	void __rcu *callback_fn;
>  	void *value;
> +	struct semaphore sleepable_lock;
>  };
>  
>  /* the actual struct hidden inside uapi struct bpf_timer */
> @@ -1113,6 +1120,55 @@ struct bpf_timer_kern {
>  	struct bpf_spin_lock lock;
>  } __attribute__((aligned(8)));
>  
> +static u32 __bpf_timer_compute_key(struct bpf_hrtimer *timer)
> +{
> +	struct bpf_map *map = timer->map;
> +	void *value = timer->value;
> +
> +	if (map->map_type == BPF_MAP_TYPE_ARRAY) {
> +		struct bpf_array *array = container_of(map, struct bpf_array, map);
> +
> +		/* compute the key */
> +		return ((char *)value - array->value) / array->elem_size;
> +	}
> +
> +	/* hash or lru */
> +	return *(u32 *)(value - round_up(map->key_size, 8));
> +}
> +
> +static void bpf_timer_work_cb(struct work_struct *work)
> +{
> +	struct bpf_hrtimer *t = container_of(work, struct bpf_hrtimer, work);
> +	struct bpf_map *map = t->map;
> +	void *value = t->value;
> +	bpf_callback_t callback_fn;
> +	u32 key;
> +
> +	BTF_TYPE_EMIT(struct bpf_timer);
> +
> +	down(&t->sleepable_lock);
> +
> +	callback_fn = READ_ONCE(t->callback_fn);
> +	if (!callback_fn) {
> +		up(&t->sleepable_lock);
> +		return;
> +	}
> +
> +	key = __bpf_timer_compute_key(t);
> +
> +	/* prevent the callback to be freed by bpf_timer_cancel() while running
> +	 * so we can release the semaphore
> +	 */
> +	bpf_prog_inc(t->prog);
> +
> +	up(&t->sleepable_lock);
> +
> +	callback_fn((u64)(long)map, (u64)(long)&key, (u64)(long)value, 0, 0);
> +	/* The verifier checked that return value is zero. */
> +
> +	bpf_prog_put(t->prog);
> +}
> +
>  static DEFINE_PER_CPU(struct bpf_hrtimer *, hrtimer_running);
>  
>  static enum hrtimer_restart bpf_timer_cb(struct hrtimer *hrtimer)
> @@ -1121,8 +1177,7 @@ static enum hrtimer_restart bpf_timer_cb(struct hrtimer *hrtimer)
>  	struct bpf_map *map = t->map;
>  	void *value = t->value;
>  	bpf_callback_t callback_fn;
> -	void *key;
> -	u32 idx;
> +	u32 key;
>  
>  	BTF_TYPE_EMIT(struct bpf_timer);
>  	callback_fn = rcu_dereference_check(t->callback_fn, rcu_read_lock_bh_held());
> @@ -1136,17 +1191,9 @@ static enum hrtimer_restart bpf_timer_cb(struct hrtimer *hrtimer)
>  	 * bpf_map_delete_elem() on the same timer.
>  	 */
>  	this_cpu_write(hrtimer_running, t);
> -	if (map->map_type == BPF_MAP_TYPE_ARRAY) {
> -		struct bpf_array *array = container_of(map, struct bpf_array, map);
> -
> -		/* compute the key */
> -		idx = ((char *)value - array->value) / array->elem_size;
> -		key = &idx;
> -	} else { /* hash or lru */
> -		key = value - round_up(map->key_size, 8);
> -	}
> +	key = __bpf_timer_compute_key(t);
>  
> -	callback_fn((u64)(long)map, (u64)(long)key, (u64)(long)value, 0, 0);
> +	callback_fn((u64)(long)map, (u64)(long)&key, (u64)(long)value, 0, 0);
>  	/* The verifier checked that return value is zero. */
>  
>  	this_cpu_write(hrtimer_running, NULL);
> @@ -1191,6 +1238,8 @@ BPF_CALL_3(bpf_timer_init, struct bpf_timer_kern *, timer, struct bpf_map *, map
>  	t->prog = NULL;
>  	rcu_assign_pointer(t->callback_fn, NULL);
>  	hrtimer_init(&t->timer, clockid, HRTIMER_MODE_REL_SOFT);
> +	INIT_WORK(&t->work, bpf_timer_work_cb);
> +	sema_init(&t->sleepable_lock, 1);
>  	t->timer.function = bpf_timer_cb;
>  	WRITE_ONCE(timer->timer, t);
>  	/* Guarantee the order between timer->timer and map->usercnt. So
> @@ -1245,6 +1294,7 @@ BPF_CALL_3(bpf_timer_set_callback, struct bpf_timer_kern *, timer, void *, callb
>  		ret = -EPERM;
>  		goto out;
>  	}
> +	down(&t->sleepable_lock);
>  	prev = t->prog;
>  	if (prev != prog) {
>  		/* Bump prog refcnt once. Every bpf_timer_set_callback()
> @@ -1261,6 +1311,7 @@ BPF_CALL_3(bpf_timer_set_callback, struct bpf_timer_kern *, timer, void *, callb
>  		t->prog = prog;
>  	}
>  	rcu_assign_pointer(t->callback_fn, callback_fn);
> +	up(&t->sleepable_lock);
>  out:
>  	__bpf_spin_unlock_irqrestore(&timer->lock);
>  	return ret;
> @@ -1282,7 +1333,7 @@ BPF_CALL_3(bpf_timer_start, struct bpf_timer_kern *, timer, u64, nsecs, u64, fla
>  
>  	if (in_nmi())
>  		return -EOPNOTSUPP;
> -	if (flags & ~(BPF_F_TIMER_ABS | BPF_F_TIMER_CPU_PIN))
> +	if (flags & ~(BPF_F_TIMER_ABS | BPF_F_TIMER_CPU_PIN | BPF_F_TIMER_SLEEPABLE))
>  		return -EINVAL;
>  	__bpf_spin_lock_irqsave(&timer->lock);
>  	t = timer->timer;
> @@ -1299,7 +1350,10 @@ BPF_CALL_3(bpf_timer_start, struct bpf_timer_kern *, timer, u64, nsecs, u64, fla
>  	if (flags & BPF_F_TIMER_CPU_PIN)
>  		mode |= HRTIMER_MODE_PINNED;
>  
> -	hrtimer_start(&t->timer, ns_to_ktime(nsecs), mode);
> +	if (flags & BPF_F_TIMER_SLEEPABLE)
> +		schedule_work(&t->work);
> +	else
> +		hrtimer_start(&t->timer, ns_to_ktime(nsecs), mode);
>  out:
>  	__bpf_spin_unlock_irqrestore(&timer->lock);
>  	return ret;
> @@ -1346,13 +1400,21 @@ BPF_CALL_1(bpf_timer_cancel, struct bpf_timer_kern *, timer)
>  		ret = -EDEADLK;
>  		goto out;
>  	}
> +	down(&t->sleepable_lock);

Sigh. I initially used a semaphore because I wanted a down_trylock() here
to mimic the behavior of hrtimer. However, this doesn't work because we
don't know who is actually calling bpf_timer_cancel(), and we might not be
able to cancel the timer from other threads. It actually doesn't matter,
though, because the semaphore only protects the setup of the callback, not
the sleepable callback itself, so it is fine to call bpf_timer_cancel()
from within the callback: the timer will be freed, but the callback will
not, because the associated prog's refcount is incremented before entering
the callback.

Anyway, I had better change this to a simple spinlock (or bpf_spinlock).
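
Something like this, guarding only the callback/prog setup and never the
sleepable callback invocation itself (untested sketch, reusing the names
from this patch):

	/* in struct bpf_hrtimer, instead of the semaphore */
	spinlock_t sleepable_lock;

	/* in bpf_timer_work_cb(), the critical section stays short */
	spin_lock(&t->sleepable_lock);
	callback_fn = READ_ONCE(t->callback_fn);
	if (!callback_fn) {
		spin_unlock(&t->sleepable_lock);
		return;
	}
	/* keep the prog alive while the sleepable callback runs */
	bpf_prog_inc(t->prog);
	spin_unlock(&t->sleepable_lock);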

Also, I realized that I still have the RFC tag in the prefix.
I can repost a v4 with the spinlock change if it is better not to have
the RFC.

Cheers,
Benjamin

>  	drop_prog_refcnt(t);
> +	up(&t->sleepable_lock);
>  out:
>  	__bpf_spin_unlock_irqrestore(&timer->lock);
>  	/* Cancel the timer and wait for associated callback to finish
>  	 * if it was running.
>  	 */
>  	ret = ret ?: hrtimer_cancel(&t->timer);
> +
> +	/* also cancel the sleepable work, but *do not* wait for
> +	 * it to finish if it was running as we might not be in a
> +	 * sleepable context
> +	 */
> +	ret = ret ?: cancel_work(&t->work);
>  	return ret;
>  }
>  
> @@ -1407,6 +1469,8 @@ void bpf_timer_cancel_and_free(void *val)
>  	 */
>  	if (this_cpu_read(hrtimer_running) != t)
>  		hrtimer_cancel(&t->timer);
> +
> +	cancel_work_sync(&t->work);
>  	kfree(t);
>  }
>  
> 
> -- 
> 2.43.0
>
Toke Høiland-Jørgensen Feb. 22, 2024, 11:50 a.m. UTC | #2
Benjamin Tissoires <bentiss@kernel.org> writes:

> @@ -1245,6 +1294,7 @@ BPF_CALL_3(bpf_timer_set_callback, struct bpf_timer_kern *, timer, void *, callb
>  		ret = -EPERM;
>  		goto out;
>  	}
> +	down(&t->sleepable_lock);
>  	prev = t->prog;
>  	if (prev != prog) {
>  		/* Bump prog refcnt once. Every bpf_timer_set_callback()
> @@ -1261,6 +1311,7 @@ BPF_CALL_3(bpf_timer_set_callback, struct bpf_timer_kern *, timer, void *, callb
>  		t->prog = prog;
>  	}
>  	rcu_assign_pointer(t->callback_fn, callback_fn);
> +	up(&t->sleepable_lock);
>  out:
>  	__bpf_spin_unlock_irqrestore(&timer->lock);
>  	return ret;
> @@ -1282,7 +1333,7 @@ BPF_CALL_3(bpf_timer_start, struct bpf_timer_kern *, timer, u64, nsecs, u64, fla
>  
>  	if (in_nmi())
>  		return -EOPNOTSUPP;
> -	if (flags & ~(BPF_F_TIMER_ABS | BPF_F_TIMER_CPU_PIN))
> +	if (flags & ~(BPF_F_TIMER_ABS | BPF_F_TIMER_CPU_PIN | BPF_F_TIMER_SLEEPABLE))
>  		return -EINVAL;
>  	__bpf_spin_lock_irqsave(&timer->lock);
>  	t = timer->timer;
> @@ -1299,7 +1350,10 @@ BPF_CALL_3(bpf_timer_start, struct bpf_timer_kern *, timer, u64, nsecs, u64, fla
>  	if (flags & BPF_F_TIMER_CPU_PIN)
>  		mode |= HRTIMER_MODE_PINNED;
>  
> -	hrtimer_start(&t->timer, ns_to_ktime(nsecs), mode);
> +	if (flags & BPF_F_TIMER_SLEEPABLE)
> +		schedule_work(&t->work);
> +	else
> +		hrtimer_start(&t->timer, ns_to_ktime(nsecs), mode);
>  out:
>  	__bpf_spin_unlock_irqrestore(&timer->lock);
>  	return ret;

I think it's a little weird to just ignore the timeout parameter when
called with the sleepable flag. I guess it can work as a first pass, but
in that case we should enforce that the caller passes in a timeout of 0,
so that if we do add support for a timeout on sleepable timers in the
future, callers will be able to detect this.
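
Something like this next to the existing flags check in bpf_timer_start()
would do it (rough sketch):

	if ((flags & BPF_F_TIMER_SLEEPABLE) && nsecs)
		return -EINVAL;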

-Toke
Alexei Starovoitov Feb. 22, 2024, 8:47 p.m. UTC | #3
On Wed, Feb 21, 2024 at 8:25 AM Benjamin Tissoires <bentiss@kernel.org> wrote:
>  /* the actual struct hidden inside uapi struct bpf_timer */
> @@ -1113,6 +1120,55 @@ struct bpf_timer_kern {
>         struct bpf_spin_lock lock;
>  } __attribute__((aligned(8)));
>
> +static u32 __bpf_timer_compute_key(struct bpf_hrtimer *timer)
> +{
> +       struct bpf_map *map = timer->map;
> +       void *value = timer->value;
> +
> +       if (map->map_type == BPF_MAP_TYPE_ARRAY) {
> +               struct bpf_array *array = container_of(map, struct bpf_array, map);
> +
> +               /* compute the key */
> +               return ((char *)value - array->value) / array->elem_size;
> +       }
> +
> +       /* hash or lru */
> +       return *(u32 *)(value - round_up(map->key_size, 8));
> +}
> +
> +static void bpf_timer_work_cb(struct work_struct *work)
> +{
> +       struct bpf_hrtimer *t = container_of(work, struct bpf_hrtimer, work);
> +       struct bpf_map *map = t->map;
> +       void *value = t->value;
> +       bpf_callback_t callback_fn;
> +       u32 key;
> +
> +       BTF_TYPE_EMIT(struct bpf_timer);
> +
> +       down(&t->sleepable_lock);
> +
> +       callback_fn = READ_ONCE(t->callback_fn);
> +       if (!callback_fn) {
> +               up(&t->sleepable_lock);
> +               return;
> +       }
> +
> +       key = __bpf_timer_compute_key(t);
> +
> +
> +       callback_fn((u64)(long)map, (u64)(long)&key, (u64)(long)value, 0, 0);
> +       /* The verifier checked that return value is zero. */
> +
> +       bpf_prog_put(t->prog);
> +}
> +
>  static DEFINE_PER_CPU(struct bpf_hrtimer *, hrtimer_running);
>
>  static enum hrtimer_restart bpf_timer_cb(struct hrtimer *hrtimer)
> @@ -1121,8 +1177,7 @@ static enum hrtimer_restart bpf_timer_cb(struct hrtimer *hrtimer)
>         struct bpf_map *map = t->map;
>         void *value = t->value;
>         bpf_callback_t callback_fn;
> -       void *key;
> -       u32 idx;
> +       u32 key;
>
>         BTF_TYPE_EMIT(struct bpf_timer);
>         callback_fn = rcu_dereference_check(t->callback_fn, rcu_read_lock_bh_held());
> @@ -1136,17 +1191,9 @@ static enum hrtimer_restart bpf_timer_cb(struct hrtimer *hrtimer)
>          * bpf_map_delete_elem() on the same timer.
>          */
>         this_cpu_write(hrtimer_running, t);
> -       if (map->map_type == BPF_MAP_TYPE_ARRAY) {
> -               struct bpf_array *array = container_of(map, struct bpf_array, map);
> -
> -               /* compute the key */
> -               idx = ((char *)value - array->value) / array->elem_size;
> -               key = &idx;
> -       } else { /* hash or lru */
> -               key = value - round_up(map->key_size, 8);
> -       }
> +       key = __bpf_timer_compute_key(t);

Please don't mix such "cleanup" with the main changes.
It's buggy for a hash map: instead of passing a pointer to the real key
into the bpf prog, you're reading the first 4 bytes of the key, copying
them into a temp var and passing an address to that.
It would have been very painful to debug such a bug if it had slipped
through, since the bpf prog would sort of work for 4-byte keys.
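
For example, with a hypothetical hash map keyed by a struct (not from this
patch), the callback would silently get a truncated key:

	struct pkt_key {
		__u32 ifindex;	/* the only field the callback would see */
		__u32 queue;	/* reads whatever sits next to the temp u32 */
	};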
Eduard Zingerman Feb. 22, 2024, 10:40 p.m. UTC | #4
On Wed, 2024-02-21 at 17:25 +0100, Benjamin Tissoires wrote:

[...]

> @@ -1282,7 +1333,7 @@ BPF_CALL_3(bpf_timer_start, struct bpf_timer_kern *, timer, u64, nsecs, u64, fla
>  
>  	if (in_nmi())
>  		return -EOPNOTSUPP;
> -	if (flags & ~(BPF_F_TIMER_ABS | BPF_F_TIMER_CPU_PIN))
> +	if (flags & ~(BPF_F_TIMER_ABS | BPF_F_TIMER_CPU_PIN | BPF_F_TIMER_SLEEPABLE))
>  		return -EINVAL;
>  	__bpf_spin_lock_irqsave(&timer->lock);
>  	t = timer->timer;
> @@ -1299,7 +1350,10 @@ BPF_CALL_3(bpf_timer_start, struct bpf_timer_kern *, timer, u64, nsecs, u64, fla
>  	if (flags & BPF_F_TIMER_CPU_PIN)
>  		mode |= HRTIMER_MODE_PINNED;
>  
> -	hrtimer_start(&t->timer, ns_to_ktime(nsecs), mode);
> +	if (flags & BPF_F_TIMER_SLEEPABLE)
> +		schedule_work(&t->work);
> +	else
> +		hrtimer_start(&t->timer, ns_to_ktime(nsecs), mode);

It looks like nsecs is simply ignored for sleepable timers.
Should this be an hrtimer_start() that waits nsecs and then schedules the
work, or schedule_delayed_work()? (But that takes the delay in jiffies,
which is probably too coarse.) Sorry if I'm missing something.
Benjamin Tissoires Feb. 27, 2024, 2:27 p.m. UTC | #5
On Feb 23 2024, Eduard Zingerman wrote:
> On Wed, 2024-02-21 at 17:25 +0100, Benjamin Tissoires wrote:
> 
> [...]
> 
> > @@ -1282,7 +1333,7 @@ BPF_CALL_3(bpf_timer_start, struct bpf_timer_kern *, timer, u64, nsecs, u64, fla
> >  
> >  	if (in_nmi())
> >  		return -EOPNOTSUPP;
> > -	if (flags & ~(BPF_F_TIMER_ABS | BPF_F_TIMER_CPU_PIN))
> > +	if (flags & ~(BPF_F_TIMER_ABS | BPF_F_TIMER_CPU_PIN | BPF_F_TIMER_SLEEPABLE))
> >  		return -EINVAL;
> >  	__bpf_spin_lock_irqsave(&timer->lock);
> >  	t = timer->timer;
> > @@ -1299,7 +1350,10 @@ BPF_CALL_3(bpf_timer_start, struct bpf_timer_kern *, timer, u64, nsecs, u64, fla
> >  	if (flags & BPF_F_TIMER_CPU_PIN)
> >  		mode |= HRTIMER_MODE_PINNED;
> >  
> > -	hrtimer_start(&t->timer, ns_to_ktime(nsecs), mode);
> > +	if (flags & BPF_F_TIMER_SLEEPABLE)
> > +		schedule_work(&t->work);
> > +	else
> > +		hrtimer_start(&t->timer, ns_to_ktime(nsecs), mode);
> 
> It looks like nsecs is simply ignored for sleepable timers.
> Should this be hrtimer_start() that waits nsecs and schedules work,
> or schedule_delayed_work()? (but it takes delay in jiffies, which is
> probably too coarse). Sorry if I miss something.

Yeah, I agree it's confusing, but as mentioned by Toke in his reply, we
should return -EINVAL if a timer value is provided (for now).

Alexei mentioned[0] that he didn't want to mix hrtimer delays with
workqueues, as they are non-deterministic. So AFAIU, I should only add the
one guarantee we can provide, a sleepable context, and proper delays in
sleepable contexts will be added once we have a better workqueue selection
available.

Cheers,
Benjamin

[0] https://lore.kernel.org/bpf/CAO-hwJKz+eRA+BFLANTrEqz2jQAOANTE3c7eqNJ6wDqJR7jMiQ@mail.gmail.com/T/#md15e431cbcddec9fcaddf1c305234523ed26f7ce

Patch

diff --git a/include/uapi/linux/bpf.h b/include/uapi/linux/bpf.h
index d96708380e52..1fc7ecbd9d33 100644
--- a/include/uapi/linux/bpf.h
+++ b/include/uapi/linux/bpf.h
@@ -7421,10 +7421,14 @@  struct bpf_core_relo {
  *     - BPF_F_TIMER_ABS: Timeout passed is absolute time, by default it is
  *       relative to current time.
  *     - BPF_F_TIMER_CPU_PIN: Timer will be pinned to the CPU of the caller.
+ *     - BPF_F_TIMER_SLEEPABLE: Timer will run in a sleepable context, with
+ *       no guarantees of ordering nor timing (consider this as being just
+ *       offloaded immediately).
  */
 enum {
 	BPF_F_TIMER_ABS = (1ULL << 0),
 	BPF_F_TIMER_CPU_PIN = (1ULL << 1),
+	BPF_F_TIMER_SLEEPABLE = (1ULL << 2),
 };
 
 /* BPF numbers iterator state */
diff --git a/kernel/bpf/helpers.c b/kernel/bpf/helpers.c
index 93edf730d288..f9add0abe40a 100644
--- a/kernel/bpf/helpers.c
+++ b/kernel/bpf/helpers.c
@@ -23,6 +23,7 @@ 
 #include <linux/btf_ids.h>
 #include <linux/bpf_mem_alloc.h>
 #include <linux/kasan.h>
+#include <linux/semaphore.h>
 
 #include "../../lib/kstrtox.h"
 
@@ -1094,13 +1095,19 @@  const struct bpf_func_proto bpf_snprintf_proto = {
  * bpf_timer_cancel() cancels the timer and decrements prog's refcnt.
  * Inner maps can contain bpf timers as well. ops->map_release_uref is
  * freeing the timers when inner map is replaced or deleted by user space.
+ *
+ * sleepable_lock protects only the setup of the workqueue, not the callback
+ * itself. This is done to ensure we don't run concurrently a free of the
+ * callback or the associated program.
  */
 struct bpf_hrtimer {
 	struct hrtimer timer;
+	struct work_struct work;
 	struct bpf_map *map;
 	struct bpf_prog *prog;
 	void __rcu *callback_fn;
 	void *value;
+	struct semaphore sleepable_lock;
 };
 
 /* the actual struct hidden inside uapi struct bpf_timer */
@@ -1113,6 +1120,55 @@  struct bpf_timer_kern {
 	struct bpf_spin_lock lock;
 } __attribute__((aligned(8)));
 
+static u32 __bpf_timer_compute_key(struct bpf_hrtimer *timer)
+{
+	struct bpf_map *map = timer->map;
+	void *value = timer->value;
+
+	if (map->map_type == BPF_MAP_TYPE_ARRAY) {
+		struct bpf_array *array = container_of(map, struct bpf_array, map);
+
+		/* compute the key */
+		return ((char *)value - array->value) / array->elem_size;
+	}
+
+	/* hash or lru */
+	return *(u32 *)(value - round_up(map->key_size, 8));
+}
+
+static void bpf_timer_work_cb(struct work_struct *work)
+{
+	struct bpf_hrtimer *t = container_of(work, struct bpf_hrtimer, work);
+	struct bpf_map *map = t->map;
+	void *value = t->value;
+	bpf_callback_t callback_fn;
+	u32 key;
+
+	BTF_TYPE_EMIT(struct bpf_timer);
+
+	down(&t->sleepable_lock);
+
+	callback_fn = READ_ONCE(t->callback_fn);
+	if (!callback_fn) {
+		up(&t->sleepable_lock);
+		return;
+	}
+
+	key = __bpf_timer_compute_key(t);
+
+	/* prevent the callback to be freed by bpf_timer_cancel() while running
+	 * so we can release the semaphore
+	 */
+	bpf_prog_inc(t->prog);
+
+	up(&t->sleepable_lock);
+
+	callback_fn((u64)(long)map, (u64)(long)&key, (u64)(long)value, 0, 0);
+	/* The verifier checked that return value is zero. */
+
+	bpf_prog_put(t->prog);
+}
+
 static DEFINE_PER_CPU(struct bpf_hrtimer *, hrtimer_running);
 
 static enum hrtimer_restart bpf_timer_cb(struct hrtimer *hrtimer)
@@ -1121,8 +1177,7 @@  static enum hrtimer_restart bpf_timer_cb(struct hrtimer *hrtimer)
 	struct bpf_map *map = t->map;
 	void *value = t->value;
 	bpf_callback_t callback_fn;
-	void *key;
-	u32 idx;
+	u32 key;
 
 	BTF_TYPE_EMIT(struct bpf_timer);
 	callback_fn = rcu_dereference_check(t->callback_fn, rcu_read_lock_bh_held());
@@ -1136,17 +1191,9 @@  static enum hrtimer_restart bpf_timer_cb(struct hrtimer *hrtimer)
 	 * bpf_map_delete_elem() on the same timer.
 	 */
 	this_cpu_write(hrtimer_running, t);
-	if (map->map_type == BPF_MAP_TYPE_ARRAY) {
-		struct bpf_array *array = container_of(map, struct bpf_array, map);
-
-		/* compute the key */
-		idx = ((char *)value - array->value) / array->elem_size;
-		key = &idx;
-	} else { /* hash or lru */
-		key = value - round_up(map->key_size, 8);
-	}
+	key = __bpf_timer_compute_key(t);
 
-	callback_fn((u64)(long)map, (u64)(long)key, (u64)(long)value, 0, 0);
+	callback_fn((u64)(long)map, (u64)(long)&key, (u64)(long)value, 0, 0);
 	/* The verifier checked that return value is zero. */
 
 	this_cpu_write(hrtimer_running, NULL);
@@ -1191,6 +1238,8 @@  BPF_CALL_3(bpf_timer_init, struct bpf_timer_kern *, timer, struct bpf_map *, map
 	t->prog = NULL;
 	rcu_assign_pointer(t->callback_fn, NULL);
 	hrtimer_init(&t->timer, clockid, HRTIMER_MODE_REL_SOFT);
+	INIT_WORK(&t->work, bpf_timer_work_cb);
+	sema_init(&t->sleepable_lock, 1);
 	t->timer.function = bpf_timer_cb;
 	WRITE_ONCE(timer->timer, t);
 	/* Guarantee the order between timer->timer and map->usercnt. So
@@ -1245,6 +1294,7 @@  BPF_CALL_3(bpf_timer_set_callback, struct bpf_timer_kern *, timer, void *, callb
 		ret = -EPERM;
 		goto out;
 	}
+	down(&t->sleepable_lock);
 	prev = t->prog;
 	if (prev != prog) {
 		/* Bump prog refcnt once. Every bpf_timer_set_callback()
@@ -1261,6 +1311,7 @@  BPF_CALL_3(bpf_timer_set_callback, struct bpf_timer_kern *, timer, void *, callb
 		t->prog = prog;
 	}
 	rcu_assign_pointer(t->callback_fn, callback_fn);
+	up(&t->sleepable_lock);
 out:
 	__bpf_spin_unlock_irqrestore(&timer->lock);
 	return ret;
@@ -1282,7 +1333,7 @@  BPF_CALL_3(bpf_timer_start, struct bpf_timer_kern *, timer, u64, nsecs, u64, fla
 
 	if (in_nmi())
 		return -EOPNOTSUPP;
-	if (flags & ~(BPF_F_TIMER_ABS | BPF_F_TIMER_CPU_PIN))
+	if (flags & ~(BPF_F_TIMER_ABS | BPF_F_TIMER_CPU_PIN | BPF_F_TIMER_SLEEPABLE))
 		return -EINVAL;
 	__bpf_spin_lock_irqsave(&timer->lock);
 	t = timer->timer;
@@ -1299,7 +1350,10 @@  BPF_CALL_3(bpf_timer_start, struct bpf_timer_kern *, timer, u64, nsecs, u64, fla
 	if (flags & BPF_F_TIMER_CPU_PIN)
 		mode |= HRTIMER_MODE_PINNED;
 
-	hrtimer_start(&t->timer, ns_to_ktime(nsecs), mode);
+	if (flags & BPF_F_TIMER_SLEEPABLE)
+		schedule_work(&t->work);
+	else
+		hrtimer_start(&t->timer, ns_to_ktime(nsecs), mode);
 out:
 	__bpf_spin_unlock_irqrestore(&timer->lock);
 	return ret;
@@ -1346,13 +1400,21 @@  BPF_CALL_1(bpf_timer_cancel, struct bpf_timer_kern *, timer)
 		ret = -EDEADLK;
 		goto out;
 	}
+	down(&t->sleepable_lock);
 	drop_prog_refcnt(t);
+	up(&t->sleepable_lock);
 out:
 	__bpf_spin_unlock_irqrestore(&timer->lock);
 	/* Cancel the timer and wait for associated callback to finish
 	 * if it was running.
 	 */
 	ret = ret ?: hrtimer_cancel(&t->timer);
+
+	/* also cancel the sleepable work, but *do not* wait for
+	 * it to finish if it was running as we might not be in a
+	 * sleepable context
+	 */
+	ret = ret ?: cancel_work(&t->work);
 	return ret;
 }
 
@@ -1407,6 +1469,8 @@  void bpf_timer_cancel_and_free(void *val)
 	 */
 	if (this_cpu_read(hrtimer_running) != t)
 		hrtimer_cancel(&t->timer);
+
+	cancel_work_sync(&t->work);
 	kfree(t);
 }