diff mbox series

[4/8] crypto: arm64/aes-ccm - Replace bytewise tail handling with NEON permute

Message ID 20240111123302.589910-14-ardb+git@google.com (mailing list archive)
State Superseded
Delegated to: Herbert Xu
Headers show
Series crypto: Clean up arm64 AES-CCM code | expand

Commit Message

Ard Biesheuvel Jan. 11, 2024, 12:33 p.m. UTC
From: Ard Biesheuvel <ardb@kernel.org>

Implement the CCM tail handling using a single sequence that uses
permute vectors and overlapping loads and stores, rather than going over
the tail byte by byte in a loop, and using scalar operations. This is
more efficient, even though the measured speedup is only around 1-2% on
the CPUs I have tried.

Signed-off-by: Ard Biesheuvel <ardb@kernel.org>
---
 arch/arm64/crypto/aes-ce-ccm-core.S | 59 +++++++++++++-------
 arch/arm64/crypto/aes-ce-ccm-glue.c | 20 +++----
 2 files changed, 48 insertions(+), 31 deletions(-)

Comments

Ard Biesheuvel Jan. 11, 2024, 4:35 p.m. UTC | #1
On Thu, 11 Jan 2024 at 13:33, Ard Biesheuvel <ardb+git@google.com> wrote:
>
> From: Ard Biesheuvel <ardb@kernel.org>
>
> Implement the CCM tail handling using a single sequence that uses
> permute vectors and overlapping loads and stores, rather than going over
> the tail byte by byte in a loop, and using scalar operations. This is
> more efficient, even though the measured speedup is only around 1-2% on
> the CPUs I have tried.
>
> Signed-off-by: Ard Biesheuvel <ardb@kernel.org>
> ---
>  arch/arm64/crypto/aes-ce-ccm-core.S | 59 +++++++++++++-------
>  arch/arm64/crypto/aes-ce-ccm-glue.c | 20 +++----
>  2 files changed, 48 insertions(+), 31 deletions(-)
>
...

The hunks below don't belong here: they were supposed to be squashed
into the previous patch.

I will fix that up for the next revision.


> diff --git a/arch/arm64/crypto/aes-ce-ccm-glue.c b/arch/arm64/crypto/aes-ce-ccm-glue.c
> index 2f4e6a318fcd..4710e59075f5 100644
> --- a/arch/arm64/crypto/aes-ce-ccm-glue.c
> +++ b/arch/arm64/crypto/aes-ce-ccm-glue.c
> @@ -181,16 +181,16 @@ static int ccm_encrypt(struct aead_request *req)
>                 if (walk.nbytes == walk.total)
>                         tail = 0;
>
> -               if (unlikely(walk.total < AES_BLOCK_SIZE))
> -                       src = dst = memcpy(buf + sizeof(buf) - walk.total,
> -                                          src, walk.total);
> +               if (unlikely(walk.nbytes < AES_BLOCK_SIZE))
> +                       src = dst = memcpy(&buf[sizeof(buf) - walk.nbytes],
> +                                          src, walk.nbytes);
>
>                 ce_aes_ccm_encrypt(dst, src, walk.nbytes - tail,
>                                    ctx->key_enc, num_rounds(ctx),
>                                    mac, walk.iv);
>
> -               if (unlikely(walk.total < AES_BLOCK_SIZE))
> -                       memcpy(walk.dst.virt.addr, dst, walk.total);
> +               if (unlikely(walk.nbytes < AES_BLOCK_SIZE))
> +                       memcpy(walk.dst.virt.addr, dst, walk.nbytes);
>
>                 if (walk.nbytes == walk.total)
>                         ce_aes_ccm_final(mac, orig_iv, ctx->key_enc, num_rounds(ctx));
> @@ -248,16 +248,16 @@ static int ccm_decrypt(struct aead_request *req)
>                 if (walk.nbytes == walk.total)
>                         tail = 0;
>
> -               if (unlikely(walk.total < AES_BLOCK_SIZE))
> -                       src = dst = memcpy(buf + sizeof(buf) - walk.total,
> -                                          src, walk.total);
> +               if (unlikely(walk.nbytes < AES_BLOCK_SIZE))
> +                       src = dst = memcpy(&buf[sizeof(buf) - walk.nbytes],
> +                                          src, walk.nbytes);
>
>                 ce_aes_ccm_decrypt(dst, src, walk.nbytes - tail,
>                                    ctx->key_enc, num_rounds(ctx),
>                                    mac, walk.iv);
>
> -               if (unlikely(walk.total < AES_BLOCK_SIZE))
> -                       memcpy(walk.dst.virt.addr, dst, walk.total);
> +               if (unlikely(walk.nbytes < AES_BLOCK_SIZE))
> +                       memcpy(walk.dst.virt.addr, dst, walk.nbytes);
>
>                 if (walk.nbytes == walk.total)
>                         ce_aes_ccm_final(mac, orig_iv, ctx->key_enc, num_rounds(ctx));
> --
> 2.43.0.275.g3460e3d667-goog
>
diff mbox series

Patch

diff --git a/arch/arm64/crypto/aes-ce-ccm-core.S b/arch/arm64/crypto/aes-ce-ccm-core.S
index b03f7f71f893..b21a9b759ab2 100644
--- a/arch/arm64/crypto/aes-ce-ccm-core.S
+++ b/arch/arm64/crypto/aes-ce-ccm-core.S
@@ -1,8 +1,11 @@ 
 /* SPDX-License-Identifier: GPL-2.0-only */
 /*
- * aesce-ccm-core.S - AES-CCM transform for ARMv8 with Crypto Extensions
+ * aes-ce-ccm-core.S - AES-CCM transform for ARMv8 with Crypto Extensions
  *
- * Copyright (C) 2013 - 2017 Linaro Ltd <ard.biesheuvel@linaro.org>
+ * Copyright (C) 2013 - 2017 Linaro Ltd.
+ * Copyright (C) 2024 Google LLC
+ *
+ * Author: Ard Biesheuvel <ardb@kernel.org>
  */
 
 #include <linux/linkage.h>
@@ -168,13 +171,13 @@  CPU_LE(	rev	x8, x8			)	/* keep swabbed ctr in reg */
 	ld1	{v2.16b}, [x1], #16		/* load next input block */
 	.if	\enc == 1
 	eor	v2.16b, v2.16b, v5.16b		/* final round enc+mac */
-	eor	v1.16b, v1.16b, v2.16b		/* xor with crypted ctr */
+	eor	v6.16b, v1.16b, v2.16b		/* xor with crypted ctr */
 	.else
 	eor	v2.16b, v2.16b, v1.16b		/* xor with crypted ctr */
-	eor	v1.16b, v2.16b, v5.16b		/* final round enc */
+	eor	v6.16b, v2.16b, v5.16b		/* final round enc */
 	.endif
 	eor	v0.16b, v0.16b, v2.16b		/* xor mac with pt ^ rk[last] */
-	st1	{v1.16b}, [x0], #16		/* write output block */
+	st1	{v6.16b}, [x0], #16		/* write output block */
 	bne	0b
 CPU_LE(	rev	x8, x8			)
 	st1	{v0.16b}, [x5]			/* store mac */
@@ -183,25 +186,31 @@  CPU_LE(	rev	x8, x8			)
 
 6:	eor	v0.16b, v0.16b, v5.16b		/* final round mac */
 	eor	v1.16b, v1.16b, v5.16b		/* final round enc */
-	st1	{v0.16b}, [x5]			/* store mac */
-	add	w2, w2, #16			/* process partial tail block */
-7:	ldrb	w9, [x1], #1			/* get 1 byte of input */
-	umov	w6, v1.b[0]			/* get top crypted ctr byte */
-	umov	w7, v0.b[0]			/* get top mac byte */
+
+	add	x1, x1, w2, sxtw		/* rewind the input pointer (w2 < 0) */
+	add	x0, x0, w2, sxtw		/* rewind the output pointer */
+
+	adr_l	x8, .Lpermute			/* load permute vectors */
+	add	x9, x8, w2, sxtw
+	sub	x8, x8, w2, sxtw
+	ld1	{v7.16b-v8.16b}, [x9]
+	ld1	{v9.16b}, [x8]
+
+	ld1	{v2.16b}, [x1]			/* load a full block of input */
+	tbl	v1.16b, {v1.16b}, v7.16b	/* move keystream to end of register */
 	.if	\enc == 1
-	eor	w7, w7, w9
-	eor	w9, w9, w6
+	tbl	v7.16b, {v2.16b}, v9.16b	/* copy plaintext to start of v7 */
+	eor	v2.16b, v2.16b, v1.16b		/* encrypt partial input block */
 	.else
-	eor	w9, w9, w6
-	eor	w7, w7, w9
+	eor	v2.16b, v2.16b, v1.16b		/* decrypt partial input block */
+	tbl	v7.16b, {v2.16b}, v9.16b	/* copy plaintext to start of v7 */
 	.endif
-	strb	w9, [x0], #1			/* store out byte */
-	strb	w7, [x5], #1			/* store mac byte */
-	subs	w2, w2, #1
-	beq	5b
-	ext	v0.16b, v0.16b, v0.16b, #1	/* shift out mac byte */
-	ext	v1.16b, v1.16b, v1.16b, #1	/* shift out ctr byte */
-	b	7b
+	eor	v0.16b, v0.16b, v7.16b		/* fold plaintext into mac */
+	tbx	v2.16b, {v6.16b}, v8.16b	/* insert output from previous iteration */
+
+	st1	{v0.16b}, [x5]			/* store mac */
+	st1	{v2.16b}, [x0]			/* store output block */
+	ret
 	.endm
 
 	/*
@@ -219,3 +228,11 @@  SYM_FUNC_END(ce_aes_ccm_encrypt)
 SYM_FUNC_START(ce_aes_ccm_decrypt)
 	aes_ccm_do_crypt	0
 SYM_FUNC_END(ce_aes_ccm_decrypt)
+
+	.section ".rodata", "a"
+	.align	6
+	.fill	15, 1, 0xff
+.Lpermute:
+	.byte	0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7
+	.byte	0x8, 0x9, 0xa, 0xb, 0xc, 0xd, 0xe, 0xf
+	.fill	15, 1, 0xff
diff --git a/arch/arm64/crypto/aes-ce-ccm-glue.c b/arch/arm64/crypto/aes-ce-ccm-glue.c
index 2f4e6a318fcd..4710e59075f5 100644
--- a/arch/arm64/crypto/aes-ce-ccm-glue.c
+++ b/arch/arm64/crypto/aes-ce-ccm-glue.c
@@ -181,16 +181,16 @@  static int ccm_encrypt(struct aead_request *req)
 		if (walk.nbytes == walk.total)
 			tail = 0;
 
-		if (unlikely(walk.total < AES_BLOCK_SIZE))
-			src = dst = memcpy(buf + sizeof(buf) - walk.total,
-					   src, walk.total);
+		if (unlikely(walk.nbytes < AES_BLOCK_SIZE))
+			src = dst = memcpy(&buf[sizeof(buf) - walk.nbytes],
+					   src, walk.nbytes);
 
 		ce_aes_ccm_encrypt(dst, src, walk.nbytes - tail,
 				   ctx->key_enc, num_rounds(ctx),
 				   mac, walk.iv);
 
-		if (unlikely(walk.total < AES_BLOCK_SIZE))
-			memcpy(walk.dst.virt.addr, dst, walk.total);
+		if (unlikely(walk.nbytes < AES_BLOCK_SIZE))
+			memcpy(walk.dst.virt.addr, dst, walk.nbytes);
 
 		if (walk.nbytes == walk.total)
 			ce_aes_ccm_final(mac, orig_iv, ctx->key_enc, num_rounds(ctx));
@@ -248,16 +248,16 @@  static int ccm_decrypt(struct aead_request *req)
 		if (walk.nbytes == walk.total)
 			tail = 0;
 
-		if (unlikely(walk.total < AES_BLOCK_SIZE))
-			src = dst = memcpy(buf + sizeof(buf) - walk.total,
-					   src, walk.total);
+		if (unlikely(walk.nbytes < AES_BLOCK_SIZE))
+			src = dst = memcpy(&buf[sizeof(buf) - walk.nbytes],
+					   src, walk.nbytes);
 
 		ce_aes_ccm_decrypt(dst, src, walk.nbytes - tail,
 				   ctx->key_enc, num_rounds(ctx),
 				   mac, walk.iv);
 
-		if (unlikely(walk.total < AES_BLOCK_SIZE))
-			memcpy(walk.dst.virt.addr, dst, walk.total);
+		if (unlikely(walk.nbytes < AES_BLOCK_SIZE))
+			memcpy(walk.dst.virt.addr, dst, walk.nbytes);
 
 		if (walk.nbytes == walk.total)
 			ce_aes_ccm_final(mac, orig_iv, ctx->key_enc, num_rounds(ctx));