diff mbox series

[bpf-next,v4,1/6] bpf: Factorize bpf_trace_printk and bpf_seq_printf

Message ID 20210414185406.917890-2-revest@chromium.org (mailing list archive)
State Superseded
Delegated to: BPF
Headers show
Series Add a snprintf eBPF helper | expand

Checks

Context Check Description
netdev/cover_letter success Link
netdev/fixes_present success Link
netdev/patch_count success Link
netdev/tree_selection success Clearly marked for bpf-next
netdev/subject_prefix success Link
netdev/cc_maintainers warning 6 maintainers not CCed: netdev@vger.kernel.org mingo@redhat.com kafai@fb.com rostedt@goodmis.org john.fastabend@gmail.com songliubraving@fb.com
netdev/source_inline success Was 0 now: 0
netdev/verify_signedoff success Link
netdev/module_param success Was 0 now: 0
netdev/build_32bit success Errors and warnings before: 11919 this patch: 11892
netdev/kdoc success Errors and warnings before: 0 this patch: 0
netdev/verify_fixes success Link
netdev/checkpatch warning CHECK: Alignment should match open parenthesis CHECK: Please use a blank line after function/struct/union/enum declarations WARNING: line length of 81 exceeds 80 columns
netdev/build_allmodconfig_warn success Errors and warnings before: 12406 this patch: 12379
netdev/header_inline success Link

Commit Message

Florent Revest April 14, 2021, 6:54 p.m. UTC
Two helpers (trace_printk and seq_printf) have very similar
implementations of format string parsing and a third one is coming
(snprintf). To avoid code duplication and make the code easier to
maintain, this moves the operations associated with format string
parsing (validation and argument sanitization) into one generic
function.

The implementation of the two existing helpers already drifted quite a
bit so unifying them entailed a lot of changes:

- bpf_trace_printk always expected fmt[fmt_size] to be the terminating
  NULL character, this is no longer true, the first 0 is terminating.
- bpf_trace_printk now supports %% (which produces the percentage char).
- bpf_trace_printk now skips width formating fields.
- bpf_trace_printk now supports the X modifier (capital hexadecimal).
- bpf_trace_printk now supports %pK, %px, %pB, %pi4, %pI4, %pi6 and %pI6
- argument casting on 32 bit has been simplified into one macro and
  using an enum instead of obscure int increments.

- bpf_seq_printf now uses bpf_trace_copy_string instead of
  strncpy_from_kernel_nofault and handles the %pks %pus specifiers.
- bpf_seq_printf now prints longs correctly on 32 bit architectures.

- both were changed to use a global per-cpu tmp buffer instead of one
  stack buffer for trace_printk and 6 small buffers for seq_printf.
- to avoid per-cpu buffer usage conflict, these helpers disable
  preemption while the per-cpu buffer is in use.
- both helpers now support the %ps and %pS specifiers to print symbols.

The implementation is also moved from bpf_trace.c to helpers.c because
the upcoming bpf_snprintf helper will be made available to all BPF
programs and will need it.

Signed-off-by: Florent Revest <revest@chromium.org>
---
 include/linux/bpf.h      |  20 +++
 kernel/bpf/helpers.c     | 254 +++++++++++++++++++++++++++
 kernel/trace/bpf_trace.c | 371 ++++-----------------------------------
 3 files changed, 311 insertions(+), 334 deletions(-)

Comments

Andrii Nakryiko April 15, 2021, 12:37 a.m. UTC | #1
On Wed, Apr 14, 2021 at 11:54 AM Florent Revest <revest@chromium.org> wrote:
>
> Two helpers (trace_printk and seq_printf) have very similar
> implementations of format string parsing and a third one is coming
> (snprintf). To avoid code duplication and make the code easier to
> maintain, this moves the operations associated with format string
> parsing (validation and argument sanitization) into one generic
> function.
>
> The implementation of the two existing helpers already drifted quite a
> bit so unifying them entailed a lot of changes:
>
> - bpf_trace_printk always expected fmt[fmt_size] to be the terminating
>   NULL character, this is no longer true, the first 0 is terminating.
> - bpf_trace_printk now supports %% (which produces the percentage char).
> - bpf_trace_printk now skips width formating fields.
> - bpf_trace_printk now supports the X modifier (capital hexadecimal).
> - bpf_trace_printk now supports %pK, %px, %pB, %pi4, %pI4, %pi6 and %pI6
> - argument casting on 32 bit has been simplified into one macro and
>   using an enum instead of obscure int increments.
>
> - bpf_seq_printf now uses bpf_trace_copy_string instead of
>   strncpy_from_kernel_nofault and handles the %pks %pus specifiers.
> - bpf_seq_printf now prints longs correctly on 32 bit architectures.
>
> - both were changed to use a global per-cpu tmp buffer instead of one
>   stack buffer for trace_printk and 6 small buffers for seq_printf.
> - to avoid per-cpu buffer usage conflict, these helpers disable
>   preemption while the per-cpu buffer is in use.
> - both helpers now support the %ps and %pS specifiers to print symbols.
>
> The implementation is also moved from bpf_trace.c to helpers.c because
> the upcoming bpf_snprintf helper will be made available to all BPF
> programs and will need it.
>
> Signed-off-by: Florent Revest <revest@chromium.org>
> ---
>  include/linux/bpf.h      |  20 +++
>  kernel/bpf/helpers.c     | 254 +++++++++++++++++++++++++++
>  kernel/trace/bpf_trace.c | 371 ++++-----------------------------------
>  3 files changed, 311 insertions(+), 334 deletions(-)
>

[...]

> +static int try_get_fmt_tmp_buf(char **tmp_buf)
> +{
> +       struct bpf_printf_buf *bufs;
> +       int used;
> +
> +       if (*tmp_buf)
> +               return 0;
> +
> +       preempt_disable();
> +       used = this_cpu_inc_return(bpf_printf_buf_used);
> +       if (WARN_ON_ONCE(used > 1)) {
> +               this_cpu_dec(bpf_printf_buf_used);

this makes me uncomfortable. If used > 1, you won't preempt_enable()
here, but you'll decrease count. Then later bpf_printf_cleanup() will
be called (inside bpf_printf_prepare()) and will further decrease
count (which it didn't increase, so it's a mess now).

> +               return -EBUSY;
> +       }
> +       bufs = this_cpu_ptr(&bpf_printf_buf);
> +       *tmp_buf = bufs->tmp_buf;
> +
> +       return 0;
> +}
> +

[...]

> +                       i += 2;
> +                       if (!final_args)
> +                               goto fmt_next;
> +
> +                       if (try_get_fmt_tmp_buf(&tmp_buf)) {
> +                               err = -EBUSY;
> +                               goto out;

this probably should bypass doing bpf_printf_cleanup() and
try_get_fmt_tmp_buf() should enable preemption internally on error.

> +                       }
> +
> +                       copy_size = (fmt[i + 2] == '4') ? 4 : 16;
> +                       if (tmp_buf_len < copy_size) {
> +                               err = -ENOSPC;
> +                               goto out;
> +                       }
> +

[...]

> -static __printf(1, 0) int bpf_do_trace_printk(const char *fmt, ...)
> +BPF_CALL_5(bpf_trace_printk, char *, fmt, u32, fmt_size, u64, arg1,
> +          u64, arg2, u64, arg3)
>  {
> +       u64 args[MAX_TRACE_PRINTK_VARARGS] = { arg1, arg2, arg3 };
> +       enum bpf_printf_mod_type mod[MAX_TRACE_PRINTK_VARARGS];
>         static char buf[BPF_TRACE_PRINTK_SIZE];
>         unsigned long flags;
> -       va_list ap;
>         int ret;
>
> -       raw_spin_lock_irqsave(&trace_printk_lock, flags);
> -       va_start(ap, fmt);
> -       ret = vsnprintf(buf, sizeof(buf), fmt, ap);
> -       va_end(ap);
> -       /* vsnprintf() will not append null for zero-length strings */
> +       ret = bpf_printf_prepare(fmt, fmt_size, args, args, mod,
> +                                MAX_TRACE_PRINTK_VARARGS);
> +       if (ret < 0)
> +               return ret;
> +
> +       ret = snprintf(buf, sizeof(buf), fmt, BPF_CAST_FMT_ARG(0, args, mod),
> +               BPF_CAST_FMT_ARG(1, args, mod), BPF_CAST_FMT_ARG(2, args, mod));
> +       /* snprintf() will not append null for zero-length strings */
>         if (ret == 0)
>                 buf[0] = '\0';
> +
> +       raw_spin_lock_irqsave(&trace_printk_lock, flags);
>         trace_bpf_trace_printk(buf);
>         raw_spin_unlock_irqrestore(&trace_printk_lock, flags);
>
> -       return ret;

see here, no + 1 :(

> -}
> -
> -/*
> - * Only limited trace_printk() conversion specifiers allowed:
> - * %d %i %u %x %ld %li %lu %lx %lld %lli %llu %llx %p %pB %pks %pus %s
> - */

[...]
Florent Revest April 15, 2021, 9:32 a.m. UTC | #2
On Thu, Apr 15, 2021 at 2:38 AM Andrii Nakryiko
<andrii.nakryiko@gmail.com> wrote:
> On Wed, Apr 14, 2021 at 11:54 AM Florent Revest <revest@chromium.org> wrote:
> > +static int try_get_fmt_tmp_buf(char **tmp_buf)
> > +{
> > +       struct bpf_printf_buf *bufs;
> > +       int used;
> > +
> > +       if (*tmp_buf)
> > +               return 0;
> > +
> > +       preempt_disable();
> > +       used = this_cpu_inc_return(bpf_printf_buf_used);
> > +       if (WARN_ON_ONCE(used > 1)) {
> > +               this_cpu_dec(bpf_printf_buf_used);
>
> this makes me uncomfortable. If used > 1, you won't preempt_enable()
> here, but you'll decrease count. Then later bpf_printf_cleanup() will
> be called (inside bpf_printf_prepare()) and will further decrease
> count (which it didn't increase, so it's a mess now).

Awkward, yes. :( This code is untested because it only covers a niche
preempt_rt usecase that is hard to reproduce but I should have thought
harder about these corner cases.

> > +                       i += 2;
> > +                       if (!final_args)
> > +                               goto fmt_next;
> > +
> > +                       if (try_get_fmt_tmp_buf(&tmp_buf)) {
> > +                               err = -EBUSY;
> > +                               goto out;
>
> this probably should bypass doing bpf_printf_cleanup() and
> try_get_fmt_tmp_buf() should enable preemption internally on error.

Yes. I'll fix this and spend some more brain cycles thinking about
what I'm doing. ;)

> > -static __printf(1, 0) int bpf_do_trace_printk(const char *fmt, ...)
> > +BPF_CALL_5(bpf_trace_printk, char *, fmt, u32, fmt_size, u64, arg1,
> > +          u64, arg2, u64, arg3)
> >  {
> > +       u64 args[MAX_TRACE_PRINTK_VARARGS] = { arg1, arg2, arg3 };
> > +       enum bpf_printf_mod_type mod[MAX_TRACE_PRINTK_VARARGS];
> >         static char buf[BPF_TRACE_PRINTK_SIZE];
> >         unsigned long flags;
> > -       va_list ap;
> >         int ret;
> >
> > -       raw_spin_lock_irqsave(&trace_printk_lock, flags);
> > -       va_start(ap, fmt);
> > -       ret = vsnprintf(buf, sizeof(buf), fmt, ap);
> > -       va_end(ap);
> > -       /* vsnprintf() will not append null for zero-length strings */
> > +       ret = bpf_printf_prepare(fmt, fmt_size, args, args, mod,
> > +                                MAX_TRACE_PRINTK_VARARGS);
> > +       if (ret < 0)
> > +               return ret;
> > +
> > +       ret = snprintf(buf, sizeof(buf), fmt, BPF_CAST_FMT_ARG(0, args, mod),
> > +               BPF_CAST_FMT_ARG(1, args, mod), BPF_CAST_FMT_ARG(2, args, mod));
> > +       /* snprintf() will not append null for zero-length strings */
> >         if (ret == 0)
> >                 buf[0] = '\0';
> > +
> > +       raw_spin_lock_irqsave(&trace_printk_lock, flags);
> >         trace_bpf_trace_printk(buf);
> >         raw_spin_unlock_irqrestore(&trace_printk_lock, flags);
> >
> > -       return ret;
>
> see here, no + 1 :(

I wonder if it's a bug or a feature though. The helper documentation
says the helper returns "the number of bytes written to the buffer". I
am not familiar with the internals of trace_printk but if the
terminating \0 is not outputted in the trace_printk buffer, then it
kind of makes sense.

Also, if anyone uses this return value, I can imagine that the usecase
would be if (ret == 0) assume_nothing_was_written(). And if we
suddenly output 1 here, we might break something.

Because the helper is quite old, maybe we should improve the helper
documentation instead? Your call :)
Andrii Nakryiko April 15, 2021, 10:37 p.m. UTC | #3
On Thu, Apr 15, 2021 at 2:33 AM Florent Revest <revest@chromium.org> wrote:
>
> On Thu, Apr 15, 2021 at 2:38 AM Andrii Nakryiko
> <andrii.nakryiko@gmail.com> wrote:
> > On Wed, Apr 14, 2021 at 11:54 AM Florent Revest <revest@chromium.org> wrote:
> > > +static int try_get_fmt_tmp_buf(char **tmp_buf)
> > > +{
> > > +       struct bpf_printf_buf *bufs;
> > > +       int used;
> > > +
> > > +       if (*tmp_buf)
> > > +               return 0;
> > > +
> > > +       preempt_disable();
> > > +       used = this_cpu_inc_return(bpf_printf_buf_used);
> > > +       if (WARN_ON_ONCE(used > 1)) {
> > > +               this_cpu_dec(bpf_printf_buf_used);
> >
> > this makes me uncomfortable. If used > 1, you won't preempt_enable()
> > here, but you'll decrease count. Then later bpf_printf_cleanup() will
> > be called (inside bpf_printf_prepare()) and will further decrease
> > count (which it didn't increase, so it's a mess now).
>
> Awkward, yes. :( This code is untested because it only covers a niche
> preempt_rt usecase that is hard to reproduce but I should have thought
> harder about these corner cases.
>
> > > +                       i += 2;
> > > +                       if (!final_args)
> > > +                               goto fmt_next;
> > > +
> > > +                       if (try_get_fmt_tmp_buf(&tmp_buf)) {
> > > +                               err = -EBUSY;
> > > +                               goto out;
> >
> > this probably should bypass doing bpf_printf_cleanup() and
> > try_get_fmt_tmp_buf() should enable preemption internally on error.
>
> Yes. I'll fix this and spend some more brain cycles thinking about
> what I'm doing. ;)
>
> > > -static __printf(1, 0) int bpf_do_trace_printk(const char *fmt, ...)
> > > +BPF_CALL_5(bpf_trace_printk, char *, fmt, u32, fmt_size, u64, arg1,
> > > +          u64, arg2, u64, arg3)
> > >  {
> > > +       u64 args[MAX_TRACE_PRINTK_VARARGS] = { arg1, arg2, arg3 };
> > > +       enum bpf_printf_mod_type mod[MAX_TRACE_PRINTK_VARARGS];
> > >         static char buf[BPF_TRACE_PRINTK_SIZE];
> > >         unsigned long flags;
> > > -       va_list ap;
> > >         int ret;
> > >
> > > -       raw_spin_lock_irqsave(&trace_printk_lock, flags);
> > > -       va_start(ap, fmt);
> > > -       ret = vsnprintf(buf, sizeof(buf), fmt, ap);
> > > -       va_end(ap);
> > > -       /* vsnprintf() will not append null for zero-length strings */
> > > +       ret = bpf_printf_prepare(fmt, fmt_size, args, args, mod,
> > > +                                MAX_TRACE_PRINTK_VARARGS);
> > > +       if (ret < 0)
> > > +               return ret;
> > > +
> > > +       ret = snprintf(buf, sizeof(buf), fmt, BPF_CAST_FMT_ARG(0, args, mod),
> > > +               BPF_CAST_FMT_ARG(1, args, mod), BPF_CAST_FMT_ARG(2, args, mod));
> > > +       /* snprintf() will not append null for zero-length strings */
> > >         if (ret == 0)
> > >                 buf[0] = '\0';
> > > +
> > > +       raw_spin_lock_irqsave(&trace_printk_lock, flags);
> > >         trace_bpf_trace_printk(buf);
> > >         raw_spin_unlock_irqrestore(&trace_printk_lock, flags);
> > >
> > > -       return ret;
> >
> > see here, no + 1 :(
>
> I wonder if it's a bug or a feature though. The helper documentation
> says the helper returns "the number of bytes written to the buffer". I
> am not familiar with the internals of trace_printk but if the
> terminating \0 is not outputted in the trace_printk buffer, then it
> kind of makes sense.
>
> Also, if anyone uses this return value, I can imagine that the usecase
> would be if (ret == 0) assume_nothing_was_written(). And if we
> suddenly output 1 here, we might break something.
>
> Because the helper is quite old, maybe we should improve the helper
> documentation instead? Your call :)

Yeah, let's make helper's doc a bit more precise, otherwise let's not
touch it. I doubt many users ever check return result of
bpf_trace_printk() at all, tbh.
diff mbox series

Patch

diff --git a/include/linux/bpf.h b/include/linux/bpf.h
index ff8cd68c01b3..77d1d8c65b81 100644
--- a/include/linux/bpf.h
+++ b/include/linux/bpf.h
@@ -2077,4 +2077,24 @@  int bpf_arch_text_poke(void *ip, enum bpf_text_poke_type t,
 struct btf_id_set;
 bool btf_id_set_contains(const struct btf_id_set *set, u32 id);
 
+enum bpf_printf_mod_type {
+	BPF_PRINTF_INT,
+	BPF_PRINTF_LONG,
+	BPF_PRINTF_LONG_LONG,
+};
+
+/* Workaround for getting va_list handling working with different argument type
+ * combinations generically for 32 and 64 bit archs.
+ */
+#define BPF_CAST_FMT_ARG(arg_nb, args, mod)				\
+	(mod[arg_nb] == BPF_PRINTF_LONG_LONG ||				\
+	 (mod[arg_nb] == BPF_PRINTF_LONG && __BITS_PER_LONG == 64)	\
+	  ? (u64)args[arg_nb]						\
+	  : (u32)args[arg_nb])
+
+int bpf_printf_prepare(char *fmt, u32 fmt_size, const u64 *raw_args,
+		       u64 *final_args, enum bpf_printf_mod_type *mod,
+		       u32 num_args);
+void bpf_printf_cleanup(void);
+
 #endif /* _LINUX_BPF_H */
diff --git a/kernel/bpf/helpers.c b/kernel/bpf/helpers.c
index f306611c4ddf..ff427f5b3358 100644
--- a/kernel/bpf/helpers.c
+++ b/kernel/bpf/helpers.c
@@ -669,6 +669,260 @@  const struct bpf_func_proto bpf_this_cpu_ptr_proto = {
 	.arg1_type	= ARG_PTR_TO_PERCPU_BTF_ID,
 };
 
+static int bpf_trace_copy_string(char *buf, void *unsafe_ptr, char fmt_ptype,
+		size_t bufsz)
+{
+	void __user *user_ptr = (__force void __user *)unsafe_ptr;
+
+	buf[0] = 0;
+
+	switch (fmt_ptype) {
+	case 's':
+#ifdef CONFIG_ARCH_HAS_NON_OVERLAPPING_ADDRESS_SPACE
+		if ((unsigned long)unsafe_ptr < TASK_SIZE)
+			return strncpy_from_user_nofault(buf, user_ptr, bufsz);
+		fallthrough;
+#endif
+	case 'k':
+		return strncpy_from_kernel_nofault(buf, unsafe_ptr, bufsz);
+	case 'u':
+		return strncpy_from_user_nofault(buf, user_ptr, bufsz);
+	}
+
+	return -EINVAL;
+}
+
+/* Per-cpu temp buffers which can be used by printf-like helpers for %s or %p
+ */
+#define MAX_PRINTF_BUF_LEN	512
+
+struct bpf_printf_buf {
+	char tmp_buf[MAX_PRINTF_BUF_LEN];
+};
+static DEFINE_PER_CPU(struct bpf_printf_buf, bpf_printf_buf);
+static DEFINE_PER_CPU(int, bpf_printf_buf_used);
+
+static int try_get_fmt_tmp_buf(char **tmp_buf)
+{
+	struct bpf_printf_buf *bufs;
+	int used;
+
+	if (*tmp_buf)
+		return 0;
+
+	preempt_disable();
+	used = this_cpu_inc_return(bpf_printf_buf_used);
+	if (WARN_ON_ONCE(used > 1)) {
+		this_cpu_dec(bpf_printf_buf_used);
+		return -EBUSY;
+	}
+	bufs = this_cpu_ptr(&bpf_printf_buf);
+	*tmp_buf = bufs->tmp_buf;
+
+	return 0;
+}
+
+void bpf_printf_cleanup(void)
+{
+	if (this_cpu_read(bpf_printf_buf_used)) {
+		this_cpu_dec(bpf_printf_buf_used);
+		preempt_enable();
+	}
+}
+
+/*
+ * bpf_parse_fmt_str - Generic pass on format strings for printf-like helpers
+ *
+ * Returns a negative value if fmt is an invalid format string or 0 otherwise.
+ *
+ * This can be used in two ways:
+ * - Format string verification only: when final_args and mod are NULL
+ * - Arguments preparation: in addition to the above verification, it writes in
+ *   final_args a copy of raw_args where pointers from BPF have been sanitized
+ *   into pointers safe to use by snprintf. This also writes in the mod array
+ *   the size requirement of each argument, usable by BPF_CAST_FMT_ARG for ex.
+ *
+ * In argument preparation mode, if 0 is returned, safe temporary buffers are
+ * allocated and bpf_printf_cleanup should be called to free them after use.
+ */
+int bpf_printf_prepare(char *fmt, u32 fmt_size, const u64 *raw_args,
+			u64 *final_args, enum bpf_printf_mod_type *mod,
+			u32 num_args)
+{
+	char *unsafe_ptr = NULL, *tmp_buf = NULL, *fmt_end;
+	size_t tmp_buf_len = MAX_PRINTF_BUF_LEN;
+	int err, i, num_spec = 0, copy_size;
+	enum bpf_printf_mod_type cur_mod;
+	u64 cur_arg;
+	char fmt_ptype;
+
+	if (!!final_args != !!mod)
+		return -EINVAL;
+
+	fmt_end = strnchr(fmt, fmt_size, 0);
+	if (!fmt_end)
+		return -EINVAL;
+	fmt_size = fmt_end - fmt;
+
+	for (i = 0; i < fmt_size; i++) {
+		if ((!isprint(fmt[i]) && !isspace(fmt[i])) || !isascii(fmt[i])) {
+			err = -EINVAL;
+			goto out;
+		}
+
+		if (fmt[i] != '%')
+			continue;
+
+		if (fmt[i + 1] == '%') {
+			i++;
+			continue;
+		}
+
+		if (num_spec >= num_args) {
+			err = -EINVAL;
+			goto out;
+		}
+
+		/* The string is zero-terminated so if fmt[i] != 0, we can
+		 * always access fmt[i + 1], in the worst case it will be a 0
+		 */
+		i++;
+
+		/* skip optional "[0 +-][num]" width formatting field */
+		while (fmt[i] == '0' || fmt[i] == '+'  || fmt[i] == '-' ||
+		       fmt[i] == ' ')
+			i++;
+		if (fmt[i] >= '1' && fmt[i] <= '9') {
+			i++;
+			while (fmt[i] >= '0' && fmt[i] <= '9')
+				i++;
+		}
+
+		if (fmt[i] == 'p') {
+			cur_mod = BPF_PRINTF_LONG;
+
+			if ((fmt[i + 1] == 'k' || fmt[i + 1] == 'u') &&
+			    fmt[i + 2] == 's') {
+				fmt_ptype = fmt[i + 1];
+				i += 2;
+				goto fmt_str;
+			}
+
+			if (fmt[i + 1] == 0 || isspace(fmt[i + 1]) ||
+			    ispunct(fmt[i + 1]) || fmt[i + 1] == 'K' ||
+			    fmt[i + 1] == 'x' || fmt[i + 1] == 'B' ||
+			    fmt[i + 1] == 's' || fmt[i + 1] == 'S') {
+				/* just kernel pointers */
+				if (final_args)
+					cur_arg = raw_args[num_spec];
+				goto fmt_next;
+			}
+
+			/* only support "%pI4", "%pi4", "%pI6" and "%pi6". */
+			if ((fmt[i + 1] != 'i' && fmt[i + 1] != 'I') ||
+			    (fmt[i + 2] != '4' && fmt[i + 2] != '6')) {
+				err = -EINVAL;
+				goto out;
+			}
+
+			i += 2;
+			if (!final_args)
+				goto fmt_next;
+
+			if (try_get_fmt_tmp_buf(&tmp_buf)) {
+				err = -EBUSY;
+				goto out;
+			}
+
+			copy_size = (fmt[i + 2] == '4') ? 4 : 16;
+			if (tmp_buf_len < copy_size) {
+				err = -ENOSPC;
+				goto out;
+			}
+
+			unsafe_ptr = (char *)(long)raw_args[num_spec];
+			err = copy_from_kernel_nofault(tmp_buf, unsafe_ptr,
+						       copy_size);
+			if (err < 0)
+				memset(tmp_buf, 0, copy_size);
+			cur_arg = (u64)(long)tmp_buf;
+			tmp_buf += copy_size;
+			tmp_buf_len -= copy_size;
+
+			goto fmt_next;
+		} else if (fmt[i] == 's') {
+			cur_mod = BPF_PRINTF_LONG;
+			fmt_ptype = fmt[i];
+fmt_str:
+			if (fmt[i + 1] != 0 &&
+			    !isspace(fmt[i + 1]) &&
+			    !ispunct(fmt[i + 1])) {
+				err = -EINVAL;
+				goto out;
+			}
+
+			if (!final_args)
+				goto fmt_next;
+
+			if (try_get_fmt_tmp_buf(&tmp_buf)) {
+				err = -EBUSY;
+				goto out;
+			}
+
+			if (!tmp_buf_len) {
+				err = -ENOSPC;
+				goto out;
+			}
+
+			unsafe_ptr = (char *)(long)raw_args[num_spec];
+			err = bpf_trace_copy_string(tmp_buf, unsafe_ptr,
+						    fmt_ptype, tmp_buf_len);
+			if (err < 0) {
+				tmp_buf[0] = '\0';
+				err = 1;
+			}
+
+			cur_arg = (u64)(long)tmp_buf;
+			tmp_buf += err;
+			tmp_buf_len -= err;
+
+			goto fmt_next;
+		}
+
+		cur_mod = BPF_PRINTF_INT;
+
+		if (fmt[i] == 'l') {
+			cur_mod = BPF_PRINTF_LONG;
+			i++;
+		}
+		if (fmt[i] == 'l') {
+			cur_mod = BPF_PRINTF_LONG_LONG;
+			i++;
+		}
+
+		if (fmt[i] != 'i' && fmt[i] != 'd' && fmt[i] != 'u' &&
+		    fmt[i] != 'x' && fmt[i] != 'X') {
+			err = -EINVAL;
+			goto out;
+		}
+
+		if (final_args)
+			cur_arg = raw_args[num_spec];
+fmt_next:
+		if (final_args) {
+			mod[num_spec] = cur_mod;
+			final_args[num_spec] = cur_arg;
+		}
+		num_spec++;
+	}
+
+	err = 0;
+out:
+	if (err)
+		bpf_printf_cleanup();
+	return err;
+}
+
 const struct bpf_func_proto bpf_get_current_task_proto __weak;
 const struct bpf_func_proto bpf_probe_read_user_proto __weak;
 const struct bpf_func_proto bpf_probe_read_user_str_proto __weak;
diff --git a/kernel/trace/bpf_trace.c b/kernel/trace/bpf_trace.c
index 0d23755c2747..a13f8644b357 100644
--- a/kernel/trace/bpf_trace.c
+++ b/kernel/trace/bpf_trace.c
@@ -372,188 +372,38 @@  static const struct bpf_func_proto *bpf_get_probe_write_proto(void)
 	return &bpf_probe_write_user_proto;
 }
 
-static void bpf_trace_copy_string(char *buf, void *unsafe_ptr, char fmt_ptype,
-		size_t bufsz)
-{
-	void __user *user_ptr = (__force void __user *)unsafe_ptr;
-
-	buf[0] = 0;
-
-	switch (fmt_ptype) {
-	case 's':
-#ifdef CONFIG_ARCH_HAS_NON_OVERLAPPING_ADDRESS_SPACE
-		if ((unsigned long)unsafe_ptr < TASK_SIZE) {
-			strncpy_from_user_nofault(buf, user_ptr, bufsz);
-			break;
-		}
-		fallthrough;
-#endif
-	case 'k':
-		strncpy_from_kernel_nofault(buf, unsafe_ptr, bufsz);
-		break;
-	case 'u':
-		strncpy_from_user_nofault(buf, user_ptr, bufsz);
-		break;
-	}
-}
-
 static DEFINE_RAW_SPINLOCK(trace_printk_lock);
 
-#define BPF_TRACE_PRINTK_SIZE   1024
+#define MAX_TRACE_PRINTK_VARARGS	3
+#define BPF_TRACE_PRINTK_SIZE		1024
 
-static __printf(1, 0) int bpf_do_trace_printk(const char *fmt, ...)
+BPF_CALL_5(bpf_trace_printk, char *, fmt, u32, fmt_size, u64, arg1,
+	   u64, arg2, u64, arg3)
 {
+	u64 args[MAX_TRACE_PRINTK_VARARGS] = { arg1, arg2, arg3 };
+	enum bpf_printf_mod_type mod[MAX_TRACE_PRINTK_VARARGS];
 	static char buf[BPF_TRACE_PRINTK_SIZE];
 	unsigned long flags;
-	va_list ap;
 	int ret;
 
-	raw_spin_lock_irqsave(&trace_printk_lock, flags);
-	va_start(ap, fmt);
-	ret = vsnprintf(buf, sizeof(buf), fmt, ap);
-	va_end(ap);
-	/* vsnprintf() will not append null for zero-length strings */
+	ret = bpf_printf_prepare(fmt, fmt_size, args, args, mod,
+				 MAX_TRACE_PRINTK_VARARGS);
+	if (ret < 0)
+		return ret;
+
+	ret = snprintf(buf, sizeof(buf), fmt, BPF_CAST_FMT_ARG(0, args, mod),
+		BPF_CAST_FMT_ARG(1, args, mod), BPF_CAST_FMT_ARG(2, args, mod));
+	/* snprintf() will not append null for zero-length strings */
 	if (ret == 0)
 		buf[0] = '\0';
+
+	raw_spin_lock_irqsave(&trace_printk_lock, flags);
 	trace_bpf_trace_printk(buf);
 	raw_spin_unlock_irqrestore(&trace_printk_lock, flags);
 
-	return ret;
-}
-
-/*
- * Only limited trace_printk() conversion specifiers allowed:
- * %d %i %u %x %ld %li %lu %lx %lld %lli %llu %llx %p %pB %pks %pus %s
- */
-BPF_CALL_5(bpf_trace_printk, char *, fmt, u32, fmt_size, u64, arg1,
-	   u64, arg2, u64, arg3)
-{
-	int i, mod[3] = {}, fmt_cnt = 0;
-	char buf[64], fmt_ptype;
-	void *unsafe_ptr = NULL;
-	bool str_seen = false;
+	bpf_printf_cleanup();
 
-	/*
-	 * bpf_check()->check_func_arg()->check_stack_boundary()
-	 * guarantees that fmt points to bpf program stack,
-	 * fmt_size bytes of it were initialized and fmt_size > 0
-	 */
-	if (fmt[--fmt_size] != 0)
-		return -EINVAL;
-
-	/* check format string for allowed specifiers */
-	for (i = 0; i < fmt_size; i++) {
-		if ((!isprint(fmt[i]) && !isspace(fmt[i])) || !isascii(fmt[i]))
-			return -EINVAL;
-
-		if (fmt[i] != '%')
-			continue;
-
-		if (fmt_cnt >= 3)
-			return -EINVAL;
-
-		/* fmt[i] != 0 && fmt[last] == 0, so we can access fmt[i + 1] */
-		i++;
-		if (fmt[i] == 'l') {
-			mod[fmt_cnt]++;
-			i++;
-		} else if (fmt[i] == 'p') {
-			mod[fmt_cnt]++;
-			if ((fmt[i + 1] == 'k' ||
-			     fmt[i + 1] == 'u') &&
-			    fmt[i + 2] == 's') {
-				fmt_ptype = fmt[i + 1];
-				i += 2;
-				goto fmt_str;
-			}
-
-			if (fmt[i + 1] == 'B') {
-				i++;
-				goto fmt_next;
-			}
-
-			/* disallow any further format extensions */
-			if (fmt[i + 1] != 0 &&
-			    !isspace(fmt[i + 1]) &&
-			    !ispunct(fmt[i + 1]))
-				return -EINVAL;
-
-			goto fmt_next;
-		} else if (fmt[i] == 's') {
-			mod[fmt_cnt]++;
-			fmt_ptype = fmt[i];
-fmt_str:
-			if (str_seen)
-				/* allow only one '%s' per fmt string */
-				return -EINVAL;
-			str_seen = true;
-
-			if (fmt[i + 1] != 0 &&
-			    !isspace(fmt[i + 1]) &&
-			    !ispunct(fmt[i + 1]))
-				return -EINVAL;
-
-			switch (fmt_cnt) {
-			case 0:
-				unsafe_ptr = (void *)(long)arg1;
-				arg1 = (long)buf;
-				break;
-			case 1:
-				unsafe_ptr = (void *)(long)arg2;
-				arg2 = (long)buf;
-				break;
-			case 2:
-				unsafe_ptr = (void *)(long)arg3;
-				arg3 = (long)buf;
-				break;
-			}
-
-			bpf_trace_copy_string(buf, unsafe_ptr, fmt_ptype,
-					sizeof(buf));
-			goto fmt_next;
-		}
-
-		if (fmt[i] == 'l') {
-			mod[fmt_cnt]++;
-			i++;
-		}
-
-		if (fmt[i] != 'i' && fmt[i] != 'd' &&
-		    fmt[i] != 'u' && fmt[i] != 'x')
-			return -EINVAL;
-fmt_next:
-		fmt_cnt++;
-	}
-
-/* Horrid workaround for getting va_list handling working with different
- * argument type combinations generically for 32 and 64 bit archs.
- */
-#define __BPF_TP_EMIT()	__BPF_ARG3_TP()
-#define __BPF_TP(...)							\
-	bpf_do_trace_printk(fmt, ##__VA_ARGS__)
-
-#define __BPF_ARG1_TP(...)						\
-	((mod[0] == 2 || (mod[0] == 1 && __BITS_PER_LONG == 64))	\
-	  ? __BPF_TP(arg1, ##__VA_ARGS__)				\
-	  : ((mod[0] == 1 || (mod[0] == 0 && __BITS_PER_LONG == 32))	\
-	      ? __BPF_TP((long)arg1, ##__VA_ARGS__)			\
-	      : __BPF_TP((u32)arg1, ##__VA_ARGS__)))
-
-#define __BPF_ARG2_TP(...)						\
-	((mod[1] == 2 || (mod[1] == 1 && __BITS_PER_LONG == 64))	\
-	  ? __BPF_ARG1_TP(arg2, ##__VA_ARGS__)				\
-	  : ((mod[1] == 1 || (mod[1] == 0 && __BITS_PER_LONG == 32))	\
-	      ? __BPF_ARG1_TP((long)arg2, ##__VA_ARGS__)		\
-	      : __BPF_ARG1_TP((u32)arg2, ##__VA_ARGS__)))
-
-#define __BPF_ARG3_TP(...)						\
-	((mod[2] == 2 || (mod[2] == 1 && __BITS_PER_LONG == 64))	\
-	  ? __BPF_ARG2_TP(arg3, ##__VA_ARGS__)				\
-	  : ((mod[2] == 1 || (mod[2] == 0 && __BITS_PER_LONG == 32))	\
-	      ? __BPF_ARG2_TP((long)arg3, ##__VA_ARGS__)		\
-	      : __BPF_ARG2_TP((u32)arg3, ##__VA_ARGS__)))
-
-	return __BPF_TP_EMIT();
+	return ret;
 }
 
 static const struct bpf_func_proto bpf_trace_printk_proto = {
@@ -581,184 +431,37 @@  const struct bpf_func_proto *bpf_get_trace_printk_proto(void)
 }
 
 #define MAX_SEQ_PRINTF_VARARGS		12
-#define MAX_SEQ_PRINTF_MAX_MEMCPY	6
-#define MAX_SEQ_PRINTF_STR_LEN		128
-
-struct bpf_seq_printf_buf {
-	char buf[MAX_SEQ_PRINTF_MAX_MEMCPY][MAX_SEQ_PRINTF_STR_LEN];
-};
-static DEFINE_PER_CPU(struct bpf_seq_printf_buf, bpf_seq_printf_buf);
-static DEFINE_PER_CPU(int, bpf_seq_printf_buf_used);
 
 BPF_CALL_5(bpf_seq_printf, struct seq_file *, m, char *, fmt, u32, fmt_size,
 	   const void *, data, u32, data_len)
 {
-	int err = -EINVAL, fmt_cnt = 0, memcpy_cnt = 0;
-	int i, buf_used, copy_size, num_args;
-	u64 params[MAX_SEQ_PRINTF_VARARGS];
-	struct bpf_seq_printf_buf *bufs;
-	const u64 *args = data;
-
-	buf_used = this_cpu_inc_return(bpf_seq_printf_buf_used);
-	if (WARN_ON_ONCE(buf_used > 1)) {
-		err = -EBUSY;
-		goto out;
-	}
-
-	bufs = this_cpu_ptr(&bpf_seq_printf_buf);
-
-	/*
-	 * bpf_check()->check_func_arg()->check_stack_boundary()
-	 * guarantees that fmt points to bpf program stack,
-	 * fmt_size bytes of it were initialized and fmt_size > 0
-	 */
-	if (fmt[--fmt_size] != 0)
-		goto out;
-
-	if (data_len & 7)
-		goto out;
-
-	for (i = 0; i < fmt_size; i++) {
-		if (fmt[i] == '%') {
-			if (fmt[i + 1] == '%')
-				i++;
-			else if (!data || !data_len)
-				goto out;
-		}
-	}
+	enum bpf_printf_mod_type mod[MAX_SEQ_PRINTF_VARARGS];
+	u64 args[MAX_SEQ_PRINTF_VARARGS];
+	int err, num_args;
 
+	if (data_len & 7 || data_len > MAX_SEQ_PRINTF_VARARGS * 8 ||
+	    (data_len && !data))
+		return -EINVAL;
 	num_args = data_len / 8;
 
-	/* check format string for allowed specifiers */
-	for (i = 0; i < fmt_size; i++) {
-		/* only printable ascii for now. */
-		if ((!isprint(fmt[i]) && !isspace(fmt[i])) || !isascii(fmt[i])) {
-			err = -EINVAL;
-			goto out;
-		}
-
-		if (fmt[i] != '%')
-			continue;
-
-		if (fmt[i + 1] == '%') {
-			i++;
-			continue;
-		}
-
-		if (fmt_cnt >= MAX_SEQ_PRINTF_VARARGS) {
-			err = -E2BIG;
-			goto out;
-		}
-
-		if (fmt_cnt >= num_args) {
-			err = -EINVAL;
-			goto out;
-		}
-
-		/* fmt[i] != 0 && fmt[last] == 0, so we can access fmt[i + 1] */
-		i++;
-
-		/* skip optional "[0 +-][num]" width formating field */
-		while (fmt[i] == '0' || fmt[i] == '+'  || fmt[i] == '-' ||
-		       fmt[i] == ' ')
-			i++;
-		if (fmt[i] >= '1' && fmt[i] <= '9') {
-			i++;
-			while (fmt[i] >= '0' && fmt[i] <= '9')
-				i++;
-		}
-
-		if (fmt[i] == 's') {
-			void *unsafe_ptr;
-
-			/* try our best to copy */
-			if (memcpy_cnt >= MAX_SEQ_PRINTF_MAX_MEMCPY) {
-				err = -E2BIG;
-				goto out;
-			}
-
-			unsafe_ptr = (void *)(long)args[fmt_cnt];
-			err = strncpy_from_kernel_nofault(bufs->buf[memcpy_cnt],
-					unsafe_ptr, MAX_SEQ_PRINTF_STR_LEN);
-			if (err < 0)
-				bufs->buf[memcpy_cnt][0] = '\0';
-			params[fmt_cnt] = (u64)(long)bufs->buf[memcpy_cnt];
-
-			fmt_cnt++;
-			memcpy_cnt++;
-			continue;
-		}
-
-		if (fmt[i] == 'p') {
-			if (fmt[i + 1] == 0 ||
-			    fmt[i + 1] == 'K' ||
-			    fmt[i + 1] == 'x' ||
-			    fmt[i + 1] == 'B') {
-				/* just kernel pointers */
-				params[fmt_cnt] = args[fmt_cnt];
-				fmt_cnt++;
-				continue;
-			}
-
-			/* only support "%pI4", "%pi4", "%pI6" and "%pi6". */
-			if (fmt[i + 1] != 'i' && fmt[i + 1] != 'I') {
-				err = -EINVAL;
-				goto out;
-			}
-			if (fmt[i + 2] != '4' && fmt[i + 2] != '6') {
-				err = -EINVAL;
-				goto out;
-			}
-
-			if (memcpy_cnt >= MAX_SEQ_PRINTF_MAX_MEMCPY) {
-				err = -E2BIG;
-				goto out;
-			}
-
-
-			copy_size = (fmt[i + 2] == '4') ? 4 : 16;
-
-			err = copy_from_kernel_nofault(bufs->buf[memcpy_cnt],
-						(void *) (long) args[fmt_cnt],
-						copy_size);
-			if (err < 0)
-				memset(bufs->buf[memcpy_cnt], 0, copy_size);
-			params[fmt_cnt] = (u64)(long)bufs->buf[memcpy_cnt];
-
-			i += 2;
-			fmt_cnt++;
-			memcpy_cnt++;
-			continue;
-		}
-
-		if (fmt[i] == 'l') {
-			i++;
-			if (fmt[i] == 'l')
-				i++;
-		}
-
-		if (fmt[i] != 'i' && fmt[i] != 'd' &&
-		    fmt[i] != 'u' && fmt[i] != 'x' &&
-		    fmt[i] != 'X') {
-			err = -EINVAL;
-			goto out;
-		}
-
-		params[fmt_cnt] = args[fmt_cnt];
-		fmt_cnt++;
-	}
+	err = bpf_printf_prepare(fmt, fmt_size, data, args, mod, num_args);
+	if (err < 0)
+		return err;
 
 	/* Maximumly we can have MAX_SEQ_PRINTF_VARARGS parameter, just give
 	 * all of them to seq_printf().
 	 */
-	seq_printf(m, fmt, params[0], params[1], params[2], params[3],
-		   params[4], params[5], params[6], params[7], params[8],
-		   params[9], params[10], params[11]);
+	seq_printf(m, fmt, BPF_CAST_FMT_ARG(0, args, mod),
+		BPF_CAST_FMT_ARG(1, args, mod), BPF_CAST_FMT_ARG(2, args, mod),
+		BPF_CAST_FMT_ARG(3, args, mod), BPF_CAST_FMT_ARG(4, args, mod),
+		BPF_CAST_FMT_ARG(5, args, mod), BPF_CAST_FMT_ARG(6, args, mod),
+		BPF_CAST_FMT_ARG(7, args, mod), BPF_CAST_FMT_ARG(8, args, mod),
+		BPF_CAST_FMT_ARG(9, args, mod), BPF_CAST_FMT_ARG(10, args, mod),
+		BPF_CAST_FMT_ARG(11, args, mod));
 
-	err = seq_has_overflowed(m) ? -EOVERFLOW : 0;
-out:
-	this_cpu_dec(bpf_seq_printf_buf_used);
-	return err;
+	bpf_printf_cleanup();
+
+	return seq_has_overflowed(m) ? -EOVERFLOW : 0;
 }
 
 BTF_ID_LIST_SINGLE(btf_seq_file_ids, struct, seq_file)