[bpf-next,2/5] bpf: bpf_csum_diff: optimize and homogenize for all archs

Message ID 20241021122112.101513-3-puranjay@kernel.org (mailing list archive)
State Superseded
Series Optimize bpf_csum_diff() and homogenize for all archs

Checks

Context Check Description
conchuod/vmtest-for-next-PR fail PR summary
conchuod/patch-2-test-1 success .github/scripts/patches/tests/build_rv32_defconfig.sh took 138.27s
conchuod/patch-2-test-2 success .github/scripts/patches/tests/build_rv64_clang_allmodconfig.sh took 1276.66s
conchuod/patch-2-test-3 success .github/scripts/patches/tests/build_rv64_gcc_allmodconfig.sh took 1491.11s
conchuod/patch-2-test-4 success .github/scripts/patches/tests/build_rv64_nommu_k210_defconfig.sh took 20.97s
conchuod/patch-2-test-5 success .github/scripts/patches/tests/build_rv64_nommu_virt_defconfig.sh took 22.80s
conchuod/patch-2-test-6 success .github/scripts/patches/tests/checkpatch.sh took 0.45s
conchuod/patch-2-test-7 success .github/scripts/patches/tests/dtb_warn_rv64.sh took 43.84s
conchuod/patch-2-test-8 success .github/scripts/patches/tests/header_inline.sh took 0.00s
conchuod/patch-2-test-9 success .github/scripts/patches/tests/kdoc.sh took 0.66s
conchuod/patch-2-test-10 success .github/scripts/patches/tests/module_param.sh took 0.03s
conchuod/patch-2-test-11 success .github/scripts/patches/tests/verify_fixes.sh took 0.00s
conchuod/patch-2-test-12 success .github/scripts/patches/tests/verify_signedoff.sh took 0.03s

Commit Message

Puranjay Mohan Oct. 21, 2024, 12:21 p.m. UTC
1. Optimization
   ------------

The current implementation copies the 'from' and 'to' buffers to a
scratchpad and takes the bitwise NOT of the 'from' buffer while copying.
In the next step, csum_partial() is called on this scratchpad.

So, mathematically, the current implementation computes:

	result = csum(to - from)

where subtracting 'from' means adding its bitwise NOT, as usual in
one's-complement arithmetic. Here, 'to' and '~from' are copied into the
scratchpad buffer. They need to be in the scratchpad because
csum_partial() takes a single contiguous buffer, not two disjoint
buffers like 'to' and 'from'.

We can rewrite this equation as:

	result = csum(to) - csum(from)

using the distributive property of csum().

This allows 'to' and 'from' to be at different locations, so the
scratchpad and the copying are no longer needed.

In C, this looks like:

result = csum_sub(csum_partial(to, to_size, seed),
                  csum_partial(from, from_size, 0));
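
To illustrate the identity above outside the kernel, here is a minimal
userspace sketch. It is not the kernel code: sum_words(), sub32(),
fold16(), diff_scratch(), diff_direct() and SCRATCH_WORDS are
hypothetical stand-ins, and the sum is taken over whole 32-bit words
with end-around carry. Under those assumptions, the scratchpad variant
and the direct variant fold to the same 16-bit value, up to the two
representations of zero in one's-complement arithmetic.

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>

#define SCRATCH_WORDS 16	/* hypothetical scratchpad capacity */

/* Fold a 32-bit accumulator to 16 bits with end-around carry. */
static uint16_t fold16(uint32_t sum)
{
	sum = (sum & 0xffff) + (sum >> 16);
	sum = (sum & 0xffff) + (sum >> 16);
	return (uint16_t)sum;
}

/* One's-complement sum of n 32-bit words on top of seed. */
static uint32_t sum_words(const uint32_t *buf, size_t n, uint32_t seed)
{
	uint64_t acc = seed;
	size_t i;

	for (i = 0; i < n; i++)
		acc += buf[i];
	/* Fold the 64-bit accumulator back into 32 bits. */
	acc = (acc & 0xffffffffu) + (acc >> 32);
	acc = (acc & 0xffffffffu) + (acc >> 32);
	return (uint32_t)acc;
}

/* a - b in one's complement: add ~b, then wrap the carry around. */
static uint32_t sub32(uint32_t a, uint32_t b)
{
	uint64_t acc = (uint64_t)a + (uint32_t)~b;

	acc = (acc & 0xffffffffu) + (acc >> 32);
	return (uint32_t)acc;
}

/* Old scheme: copy ~from and to into one buffer, then sum it once. */
static uint16_t diff_scratch(const uint32_t *from, size_t nf,
			     const uint32_t *to, size_t nt, uint32_t seed)
{
	uint32_t scratch[SCRATCH_WORDS];
	size_t i, j = 0;

	assert(nf + nt <= SCRATCH_WORDS);
	for (i = 0; i < nf; i++)
		scratch[j++] = ~from[i];
	for (i = 0; i < nt; i++)
		scratch[j++] = to[i];
	return fold16(sum_words(scratch, j, seed));
}

/* New scheme: sum each buffer in place and subtract the sums. */
static uint16_t diff_direct(const uint32_t *from, size_t nf,
			    const uint32_t *to, size_t nt, uint32_t seed)
{
	return fold16(sub32(sum_words(to, nt, seed),
			    sum_words(from, nf, 0)));
}
```

Calling both functions on the same 'from' and 'to' arrays with the same
seed should give equal results, and only diff_scratch() needs a size
limit on the inputs.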

2. Homogenization
   --------------

The bpf_csum_diff() helper calls csum_partial(), which some
architectures such as arm and x86 implement themselves, while other
architectures rely on the generic implementation in lib/checksum.c.

The generic implementation in lib/checksum.c returns a 16-bit value, but
the arch-specific implementations can return more than 16 bits. This
works out in most places because the result is passed through
csum_fold(), which turns it into a 16-bit value, before it is used.

bpf_csum_diff() directly returns the value from csum_partial(), so the
returned values can differ between architectures. See the discussion
in [1].

For the int value 28, the calculated checksums are:

x86                    :    -29 : 0xffffffe3
generic (arm64, riscv) :  65507 : 0x0000ffe3
arm                    : 131042 : 0x0001ffe2

Pass the checksum through csum_from32to16() before returning it from
bpf_csum_diff() so that the result is the same on all architectures.

NOTE: csum_from32to16() is used instead of csum_fold() because
csum_fold() performs the same fold and then takes the bitwise NOT of
the result, which is not what we want here.
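
As a worked check of the values in the table above, here is a small
userspace sketch of the fold. fold32to16() and fold_and_invert() are
hypothetical names that only model the behaviour described in this
message (add the high and low 16-bit halves with end-around carry, and
for the csum_fold()-style variant, invert afterwards). They are not the
kernel helpers themselves.

```c
#include <assert.h>
#include <stdint.h>

/* Add the two 16-bit halves, then add the carry back in. */
static uint16_t fold32to16(uint32_t sum)
{
	sum = (sum & 0xffff) + (sum >> 16);
	sum = (sum & 0xffff) + (sum >> 16);
	return (uint16_t)sum;
}

/* csum_fold()-style variant: same fold, then bitwise NOT. */
static uint16_t fold_and_invert(uint32_t sum)
{
	return (uint16_t)~fold32to16(sum);
}

/* The three per-architecture values from the table fold to one value. */
static void check_table_values(void)
{
	assert(fold32to16(0xffffffe3u) == 0xffe3);	/* x86 */
	assert(fold32to16(0x0000ffe3u) == 0xffe3);	/* generic */
	assert(fold32to16(0x0001ffe2u) == 0xffe3);	/* arm */
}
```

With the extra inversion, the same inputs would come out as 0x001c
instead of 0xffe3, which is why the plain fold is used here.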

[1] https://lore.kernel.org/bpf/CAJ+HfNiQbOcqCLxFUP2FMm5QrLXUUaj852Fxe3hn_2JNiucn6g@mail.gmail.com/

Signed-off-by: Puranjay Mohan <puranjay@kernel.org>
---
 net/core/filter.c | 37 +++++++++----------------------------
 1 file changed, 9 insertions(+), 28 deletions(-)

Comments

Daniel Borkmann Oct. 21, 2024, 1:42 p.m. UTC | #1
On 10/21/24 2:21 PM, Puranjay Mohan wrote:
> [...]

Thanks for looking into this!

Acked-by: Daniel Borkmann <daniel@iogearbox.net>
Toke Høiland-Jørgensen Oct. 22, 2024, 9:54 a.m. UTC | #2
Puranjay Mohan <puranjay@kernel.org> writes:

> [...]

Pretty neat simplification :)

Reviewed-by: Toke Høiland-Jørgensen <toke@redhat.com>
kernel test robot Oct. 22, 2024, 6:09 p.m. UTC | #3
Hi Puranjay,

kernel test robot noticed the following build warnings:

[auto build test WARNING on bpf-next/master]

url:    https://github.com/intel-lab-lkp/linux/commits/Puranjay-Mohan/net-checksum-move-from32to16-to-generic-header/20241021-202707
base:   https://git.kernel.org/pub/scm/linux/kernel/git/bpf/bpf-next.git master
patch link:    https://lore.kernel.org/r/20241021122112.101513-3-puranjay%40kernel.org
patch subject: [PATCH bpf-next 2/5] bpf: bpf_csum_diff: optimize and homogenize for all archs
config: x86_64-randconfig-122-20241022 (https://download.01.org/0day-ci/archive/20241023/202410230122.BYZLEUHz-lkp@intel.com/config)
compiler: gcc-12 (Debian 12.2.0-14) 12.2.0
reproduce (this is a W=1 build): (https://download.01.org/0day-ci/archive/20241023/202410230122.BYZLEUHz-lkp@intel.com/reproduce)

If you fix the issue in a separate patch/commit (i.e. not just a new version of
the same patch/commit), kindly add following tags
| Reported-by: kernel test robot <lkp@intel.com>
| Closes: https://lore.kernel.org/oe-kbuild-all/202410230122.BYZLEUHz-lkp@intel.com/

sparse warnings: (new ones prefixed by >>)
   net/core/filter.c:1423:39: sparse: sparse: incorrect type in argument 1 (different address spaces) @@     expected struct sock_filter const *filter @@     got struct sock_filter [noderef] __user *filter @@
   net/core/filter.c:1423:39: sparse:     expected struct sock_filter const *filter
   net/core/filter.c:1423:39: sparse:     got struct sock_filter [noderef] __user *filter
   net/core/filter.c:1501:39: sparse: sparse: incorrect type in argument 1 (different address spaces) @@     expected struct sock_filter const *filter @@     got struct sock_filter [noderef] __user *filter @@
   net/core/filter.c:1501:39: sparse:     expected struct sock_filter const *filter
   net/core/filter.c:1501:39: sparse:     got struct sock_filter [noderef] __user *filter
   net/core/filter.c:2321:45: sparse: sparse: incorrect type in argument 2 (different base types) @@     expected restricted __be32 [usertype] daddr @@     got unsigned int [usertype] ipv4_nh @@
   net/core/filter.c:2321:45: sparse:     expected restricted __be32 [usertype] daddr
   net/core/filter.c:2321:45: sparse:     got unsigned int [usertype] ipv4_nh
   net/core/filter.c:10993:31: sparse: sparse: symbol 'sk_filter_verifier_ops' was not declared. Should it be static?
   net/core/filter.c:11000:27: sparse: sparse: symbol 'sk_filter_prog_ops' was not declared. Should it be static?
   net/core/filter.c:11004:31: sparse: sparse: symbol 'tc_cls_act_verifier_ops' was not declared. Should it be static?
   net/core/filter.c:11013:27: sparse: sparse: symbol 'tc_cls_act_prog_ops' was not declared. Should it be static?
   net/core/filter.c:11017:31: sparse: sparse: symbol 'xdp_verifier_ops' was not declared. Should it be static?
   net/core/filter.c:11029:31: sparse: sparse: symbol 'cg_skb_verifier_ops' was not declared. Should it be static?
   net/core/filter.c:11035:27: sparse: sparse: symbol 'cg_skb_prog_ops' was not declared. Should it be static?
   net/core/filter.c:11039:31: sparse: sparse: symbol 'lwt_in_verifier_ops' was not declared. Should it be static?
   net/core/filter.c:11045:27: sparse: sparse: symbol 'lwt_in_prog_ops' was not declared. Should it be static?
   net/core/filter.c:11049:31: sparse: sparse: symbol 'lwt_out_verifier_ops' was not declared. Should it be static?
   net/core/filter.c:11055:27: sparse: sparse: symbol 'lwt_out_prog_ops' was not declared. Should it be static?
   net/core/filter.c:11059:31: sparse: sparse: symbol 'lwt_xmit_verifier_ops' was not declared. Should it be static?
   net/core/filter.c:11066:27: sparse: sparse: symbol 'lwt_xmit_prog_ops' was not declared. Should it be static?
   net/core/filter.c:11070:31: sparse: sparse: symbol 'lwt_seg6local_verifier_ops' was not declared. Should it be static?
   net/core/filter.c:11076:27: sparse: sparse: symbol 'lwt_seg6local_prog_ops' was not declared. Should it be static?
   net/core/filter.c:11079:31: sparse: sparse: symbol 'cg_sock_verifier_ops' was not declared. Should it be static?
   net/core/filter.c:11085:27: sparse: sparse: symbol 'cg_sock_prog_ops' was not declared. Should it be static?
   net/core/filter.c:11088:31: sparse: sparse: symbol 'cg_sock_addr_verifier_ops' was not declared. Should it be static?
   net/core/filter.c:11094:27: sparse: sparse: symbol 'cg_sock_addr_prog_ops' was not declared. Should it be static?
   net/core/filter.c:11097:31: sparse: sparse: symbol 'sock_ops_verifier_ops' was not declared. Should it be static?
   net/core/filter.c:11103:27: sparse: sparse: symbol 'sock_ops_prog_ops' was not declared. Should it be static?
   net/core/filter.c:11106:31: sparse: sparse: symbol 'sk_skb_verifier_ops' was not declared. Should it be static?
   net/core/filter.c:11113:27: sparse: sparse: symbol 'sk_skb_prog_ops' was not declared. Should it be static?
   net/core/filter.c:11116:31: sparse: sparse: symbol 'sk_msg_verifier_ops' was not declared. Should it be static?
   net/core/filter.c:11123:27: sparse: sparse: symbol 'sk_msg_prog_ops' was not declared. Should it be static?
   net/core/filter.c:11126:31: sparse: sparse: symbol 'flow_dissector_verifier_ops' was not declared. Should it be static?
   net/core/filter.c:11132:27: sparse: sparse: symbol 'flow_dissector_prog_ops' was not declared. Should it be static?
   net/core/filter.c:11460:31: sparse: sparse: symbol 'sk_reuseport_verifier_ops' was not declared. Should it be static?
   net/core/filter.c:11466:27: sparse: sparse: symbol 'sk_reuseport_prog_ops' was not declared. Should it be static?
   net/core/filter.c:11668:27: sparse: sparse: symbol 'sk_lookup_prog_ops' was not declared. Should it be static?
   net/core/filter.c:11672:31: sparse: sparse: symbol 'sk_lookup_verifier_ops' was not declared. Should it be static?
   net/core/filter.c:1931:43: sparse: sparse: incorrect type in argument 2 (different base types) @@     expected restricted __wsum [usertype] diff @@     got unsigned long long [usertype] to @@
   net/core/filter.c:1931:43: sparse:     expected restricted __wsum [usertype] diff
   net/core/filter.c:1931:43: sparse:     got unsigned long long [usertype] to
   net/core/filter.c:1934:36: sparse: sparse: incorrect type in argument 2 (different base types) @@     expected restricted __be16 [usertype] old @@     got unsigned long long [usertype] from @@
   net/core/filter.c:1934:36: sparse:     expected restricted __be16 [usertype] old
   net/core/filter.c:1934:36: sparse:     got unsigned long long [usertype] from
   net/core/filter.c:1934:42: sparse: sparse: incorrect type in argument 3 (different base types) @@     expected restricted __be16 [usertype] new @@     got unsigned long long [usertype] to @@
   net/core/filter.c:1934:42: sparse:     expected restricted __be16 [usertype] new
   net/core/filter.c:1934:42: sparse:     got unsigned long long [usertype] to
   net/core/filter.c:1937:36: sparse: sparse: incorrect type in argument 2 (different base types) @@     expected restricted __be32 [usertype] from @@     got unsigned long long [usertype] from @@
   net/core/filter.c:1937:36: sparse:     expected restricted __be32 [usertype] from
   net/core/filter.c:1937:36: sparse:     got unsigned long long [usertype] from
   net/core/filter.c:1937:42: sparse: sparse: incorrect type in argument 3 (different base types) @@     expected restricted __be32 [usertype] to @@     got unsigned long long [usertype] to @@
   net/core/filter.c:1937:42: sparse:     expected restricted __be32 [usertype] to
   net/core/filter.c:1937:42: sparse:     got unsigned long long [usertype] to
   net/core/filter.c:1982:59: sparse: sparse: incorrect type in argument 3 (different base types) @@     expected restricted __wsum [usertype] diff @@     got unsigned long long [usertype] to @@
   net/core/filter.c:1982:59: sparse:     expected restricted __wsum [usertype] diff
   net/core/filter.c:1982:59: sparse:     got unsigned long long [usertype] to
   net/core/filter.c:1985:52: sparse: sparse: incorrect type in argument 3 (different base types) @@     expected restricted __be16 [usertype] from @@     got unsigned long long [usertype] from @@
   net/core/filter.c:1985:52: sparse:     expected restricted __be16 [usertype] from
   net/core/filter.c:1985:52: sparse:     got unsigned long long [usertype] from
   net/core/filter.c:1985:58: sparse: sparse: incorrect type in argument 4 (different base types) @@     expected restricted __be16 [usertype] to @@     got unsigned long long [usertype] to @@
   net/core/filter.c:1985:58: sparse:     expected restricted __be16 [usertype] to
   net/core/filter.c:1985:58: sparse:     got unsigned long long [usertype] to
   net/core/filter.c:1988:52: sparse: sparse: incorrect type in argument 3 (different base types) @@     expected restricted __be32 [usertype] from @@     got unsigned long long [usertype] from @@
   net/core/filter.c:1988:52: sparse:     expected restricted __be32 [usertype] from
   net/core/filter.c:1988:52: sparse:     got unsigned long long [usertype] from
   net/core/filter.c:1988:58: sparse: sparse: incorrect type in argument 4 (different base types) @@     expected restricted __be32 [usertype] to @@     got unsigned long long [usertype] to @@
   net/core/filter.c:1988:58: sparse:     expected restricted __be32 [usertype] to
   net/core/filter.c:1988:58: sparse:     got unsigned long long [usertype] to
>> net/core/filter.c:2023:39: sparse: sparse: incorrect type in return expression (different base types) @@     expected unsigned long long @@     got restricted __sum16 @@
   net/core/filter.c:2023:39: sparse:     expected unsigned long long
   net/core/filter.c:2023:39: sparse:     got restricted __sum16
   net/core/filter.c:2026:39: sparse: sparse: incorrect type in return expression (different base types) @@     expected unsigned long long @@     got restricted __sum16 @@
   net/core/filter.c:2026:39: sparse:     expected unsigned long long
   net/core/filter.c:2026:39: sparse:     got restricted __sum16
   net/core/filter.c:2029:39: sparse: sparse: incorrect type in return expression (different base types) @@     expected unsigned long long @@     got restricted __sum16 @@
   net/core/filter.c:2029:39: sparse:     expected unsigned long long
   net/core/filter.c:2029:39: sparse:     got restricted __sum16
>> net/core/filter.c:2031:16: sparse: sparse: incorrect type in return expression (different base types) @@     expected unsigned long long @@     got restricted __wsum [usertype] seed @@
   net/core/filter.c:2031:16: sparse:     expected unsigned long long
   net/core/filter.c:2031:16: sparse:     got restricted __wsum [usertype] seed
   net/core/filter.c:2053:35: sparse: sparse: incorrect type in return expression (different base types) @@     expected unsigned long long @@     got restricted __wsum [usertype] csum @@
   net/core/filter.c:2053:35: sparse:     expected unsigned long long
   net/core/filter.c:2053:35: sparse:     got restricted __wsum [usertype] csum

vim +2023 net/core/filter.c

  1956	
  1957	BPF_CALL_5(bpf_l4_csum_replace, struct sk_buff *, skb, u32, offset,
  1958		   u64, from, u64, to, u64, flags)
  1959	{
  1960		bool is_pseudo = flags & BPF_F_PSEUDO_HDR;
  1961		bool is_mmzero = flags & BPF_F_MARK_MANGLED_0;
  1962		bool do_mforce = flags & BPF_F_MARK_ENFORCE;
  1963		__sum16 *ptr;
  1964	
  1965		if (unlikely(flags & ~(BPF_F_MARK_MANGLED_0 | BPF_F_MARK_ENFORCE |
  1966				       BPF_F_PSEUDO_HDR | BPF_F_HDR_FIELD_MASK)))
  1967			return -EINVAL;
  1968		if (unlikely(offset > 0xffff || offset & 1))
  1969			return -EFAULT;
  1970		if (unlikely(bpf_try_make_writable(skb, offset + sizeof(*ptr))))
  1971			return -EFAULT;
  1972	
  1973		ptr = (__sum16 *)(skb->data + offset);
  1974		if (is_mmzero && !do_mforce && !*ptr)
  1975			return 0;
  1976	
  1977		switch (flags & BPF_F_HDR_FIELD_MASK) {
  1978		case 0:
  1979			if (unlikely(from != 0))
  1980				return -EINVAL;
  1981	
  1982			inet_proto_csum_replace_by_diff(ptr, skb, to, is_pseudo);
  1983			break;
  1984		case 2:
> 1985			inet_proto_csum_replace2(ptr, skb, from, to, is_pseudo);
  1986			break;
  1987		case 4:
  1988			inet_proto_csum_replace4(ptr, skb, from, to, is_pseudo);
  1989			break;
  1990		default:
  1991			return -EINVAL;
  1992		}
  1993	
  1994		if (is_mmzero && !*ptr)
  1995			*ptr = CSUM_MANGLED_0;
  1996		return 0;
  1997	}
  1998	
  1999	static const struct bpf_func_proto bpf_l4_csum_replace_proto = {
  2000		.func		= bpf_l4_csum_replace,
  2001		.gpl_only	= false,
  2002		.ret_type	= RET_INTEGER,
  2003		.arg1_type	= ARG_PTR_TO_CTX,
  2004		.arg2_type	= ARG_ANYTHING,
  2005		.arg3_type	= ARG_ANYTHING,
  2006		.arg4_type	= ARG_ANYTHING,
  2007		.arg5_type	= ARG_ANYTHING,
  2008	};
  2009	
  2010	BPF_CALL_5(bpf_csum_diff, __be32 *, from, u32, from_size,
  2011		   __be32 *, to, u32, to_size, __wsum, seed)
  2012	{
  2013		/* This is quite flexible, some examples:
  2014		 *
  2015		 * from_size == 0, to_size > 0,  seed := csum --> pushing data
  2016		 * from_size > 0,  to_size == 0, seed := csum --> pulling data
  2017		 * from_size > 0,  to_size > 0,  seed := 0    --> diffing data
  2018		 *
  2019		 * Even for diffing, from_size and to_size don't need to be equal.
  2020		 */
  2021	
  2022		if (from_size && to_size)
> 2023			return csum_from32to16(csum_sub(csum_partial(to, to_size, seed),
  2024							csum_partial(from, from_size, 0)));
  2025		if (to_size)
  2026			return csum_from32to16(csum_partial(to, to_size, seed));
  2027	
  2028		if (from_size)
  2029			return csum_from32to16(~csum_partial(from, from_size, ~seed));
  2030	
> 2031		return seed;
  2032	}
  2033
Patch

diff --git a/net/core/filter.c b/net/core/filter.c
index bd0d08bf76bb8..e00bec7de9edd 100644
--- a/net/core/filter.c
+++ b/net/core/filter.c
@@ -1654,18 +1654,6 @@  void sk_reuseport_prog_free(struct bpf_prog *prog)
 		bpf_prog_destroy(prog);
 }
 
-struct bpf_scratchpad {
-	union {
-		__be32 diff[MAX_BPF_STACK / sizeof(__be32)];
-		u8     buff[MAX_BPF_STACK];
-	};
-	local_lock_t	bh_lock;
-};
-
-static DEFINE_PER_CPU(struct bpf_scratchpad, bpf_sp) = {
-	.bh_lock	= INIT_LOCAL_LOCK(bh_lock),
-};
-
 static inline int __bpf_try_make_writable(struct sk_buff *skb,
 					  unsigned int write_len)
 {
@@ -2022,11 +2010,6 @@  static const struct bpf_func_proto bpf_l4_csum_replace_proto = {
 BPF_CALL_5(bpf_csum_diff, __be32 *, from, u32, from_size,
 	   __be32 *, to, u32, to_size, __wsum, seed)
 {
-	struct bpf_scratchpad *sp = this_cpu_ptr(&bpf_sp);
-	u32 diff_size = from_size + to_size;
-	int i, j = 0;
-	__wsum ret;
-
 	/* This is quite flexible, some examples:
 	 *
 	 * from_size == 0, to_size > 0,  seed := csum --> pushing data
@@ -2035,19 +2018,17 @@  BPF_CALL_5(bpf_csum_diff, __be32 *, from, u32, from_size,
 	 *
 	 * Even for diffing, from_size and to_size don't need to be equal.
 	 */
-	if (unlikely(((from_size | to_size) & (sizeof(__be32) - 1)) ||
-		     diff_size > sizeof(sp->diff)))
-		return -EINVAL;
 
-	local_lock_nested_bh(&bpf_sp.bh_lock);
-	for (i = 0; i < from_size / sizeof(__be32); i++, j++)
-		sp->diff[j] = ~from[i];
-	for (i = 0; i <   to_size / sizeof(__be32); i++, j++)
-		sp->diff[j] = to[i];
+	if (from_size && to_size)
+		return csum_from32to16(csum_sub(csum_partial(to, to_size, seed),
+						csum_partial(from, from_size, 0)));
+	if (to_size)
+		return csum_from32to16(csum_partial(to, to_size, seed));
 
-	ret = csum_partial(sp->diff, diff_size, seed);
-	local_unlock_nested_bh(&bpf_sp.bh_lock);
-	return ret;
+	if (from_size)
+		return csum_from32to16(~csum_partial(from, from_size, ~seed));
+
+	return seed;
 }
 
 static const struct bpf_func_proto bpf_csum_diff_proto = {