
[RFC/RFT] crypto: arm64/aes-ce - add support for CTS-CBC mode

Message ID 20180908114213.9839-1-ard.biesheuvel@linaro.org (mailing list archive)

Commit Message

Ard Biesheuvel Sept. 8, 2018, 11:42 a.m. UTC
Currently, we rely on the generic CTS chaining mode wrapper to
instantiate the cts(cbc(aes)) skcipher. Due to the high performance
of the ARMv8 Crypto Extensions AES instructions (~1 cycle per byte),
any overhead in the chaining mode layers is amplified, and so it pays
off considerably to fold the CTS handling into the core algorithm.

On Cortex-A53, this results in a ~50% speedup for smaller input sizes.
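
For reference (not part of this patch): callers of the crypto API are
unaffected, since they keep requesting "cts(cbc(aes))" by name and the
crypto core simply resolves the name to the accelerated implementation
instead of the generic cts template wrapped around cbc-aes. A minimal
caller-side sketch, assuming process context, a kmalloc'ed buffer and
synchronous completion; the function name and parameters below are
illustrative only:

#include <crypto/skcipher.h>
#include <linux/scatterlist.h>

static int cts_cbc_aes_example(const u8 *key, unsigned int keylen,
			       u8 *buf, unsigned int len, u8 iv[16])
{
	struct crypto_skcipher *tfm;
	struct skcipher_request *req;
	struct scatterlist sg;
	DECLARE_CRYPTO_WAIT(wait);
	int err;

	tfm = crypto_alloc_skcipher("cts(cbc(aes))", 0, 0);
	if (IS_ERR(tfm))
		return PTR_ERR(tfm);

	err = crypto_skcipher_setkey(tfm, key, keylen);
	if (err)
		goto out_free_tfm;

	req = skcipher_request_alloc(tfm, GFP_KERNEL);
	if (!req) {
		err = -ENOMEM;
		goto out_free_tfm;
	}

	/* buf must be linear kernel memory (e.g. kmalloc'ed), not stack */
	sg_init_one(&sg, buf, len);
	skcipher_request_set_callback(req, CRYPTO_TFM_REQ_MAY_BACKLOG,
				      crypto_req_done, &wait);
	skcipher_request_set_crypt(req, &sg, &sg, len, iv);

	err = crypto_wait_req(crypto_skcipher_encrypt(req), &wait);

	skcipher_request_free(req);
out_free_tfm:
	crypto_free_skcipher(tfm);
	return err;
}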

Signed-off-by: Ard Biesheuvel <ard.biesheuvel@linaro.org>
---
Raw performance numbers after the patch.

 arch/arm64/crypto/aes-glue.c  | 142 ++++++++++++++++++++
 arch/arm64/crypto/aes-modes.S |  73 ++++++++++
 2 files changed, 215 insertions(+)
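
For reference, the mode being folded in is the CBC-CS3 ciphertext
stealing variant used by the generic cts template: the final two
ciphertext blocks are emitted in swapped order, with the very last one
truncated to the length of the trailing partial plaintext block. Below
is a conceptual sketch (illustration only, not kernel code) of the
final-block handling the new encryption routine performs, assuming
aes_encrypt_block() is a single-block AES encryption with the expanded
key:

#include <stdint.h>
#include <string.h>

#define BLK 16

/* assumed to exist: one AES block encryption with the expanded key */
void aes_encrypt_block(uint8_t out[BLK], const uint8_t in[BLK]);

/*
 * in:  final 'len' plaintext bytes, with BLK < len <= 2 * BLK
 * iv:  CBC chaining value left over from the preceding full blocks
 * out: final 'len' ciphertext bytes in CS3 order (full block first,
 *      then the truncated block)
 */
static void cbc_cs3_encrypt_final(uint8_t *out, const uint8_t *in,
				  size_t len, const uint8_t iv[BLK])
{
	size_t d = len - BLK;	/* length of the trailing partial block */
	uint8_t c[BLK], last[BLK];
	size_t i;

	/* last full plaintext block, encrypted as in ordinary CBC */
	for (i = 0; i < BLK; i++)
		c[i] = in[i] ^ iv[i];
	aes_encrypt_block(c, c);

	/* zero-pad the partial block, XOR with c, encrypt */
	memset(last, 0, BLK);
	memcpy(last, in + BLK, d);
	for (i = 0; i < BLK; i++)
		last[i] ^= c[i];
	aes_encrypt_block(last, last);

	/* CS3: full final block first, then only d bytes of c */
	memcpy(out, last, BLK);
	memcpy(out + BLK, c, d);
}

The new aes_cbc_cts_encrypt/decrypt routines in the patch implement the
same dataflow entirely in NEON registers: overlapping loads pick up the
last two (partially overlapping) input blocks, the .Lcts_permute_table
entries drive tbl/tbx to zero-pad and reposition the partial block, and
overlapping stores emit the swapped output blocks without a separate
byte-copy pass.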

Patch

diff --git a/arch/arm64/crypto/aes-glue.c b/arch/arm64/crypto/aes-glue.c
index adcb83eb683c..0860feedbafe 100644
--- a/arch/arm64/crypto/aes-glue.c
+++ b/arch/arm64/crypto/aes-glue.c
@@ -15,6 +15,7 @@ 
 #include <crypto/internal/hash.h>
 #include <crypto/internal/simd.h>
 #include <crypto/internal/skcipher.h>
+#include <crypto/scatterwalk.h>
 #include <linux/module.h>
 #include <linux/cpufeature.h>
 #include <crypto/xts.h>
@@ -31,6 +32,8 @@ 
 #define aes_ecb_decrypt		ce_aes_ecb_decrypt
 #define aes_cbc_encrypt		ce_aes_cbc_encrypt
 #define aes_cbc_decrypt		ce_aes_cbc_decrypt
+#define aes_cbc_cts_encrypt	ce_aes_cbc_cts_encrypt
+#define aes_cbc_cts_decrypt	ce_aes_cbc_cts_decrypt
 #define aes_ctr_encrypt		ce_aes_ctr_encrypt
 #define aes_xts_encrypt		ce_aes_xts_encrypt
 #define aes_xts_decrypt		ce_aes_xts_decrypt
@@ -45,6 +48,8 @@  MODULE_DESCRIPTION("AES-ECB/CBC/CTR/XTS using ARMv8 Crypto Extensions");
 #define aes_ecb_decrypt		neon_aes_ecb_decrypt
 #define aes_cbc_encrypt		neon_aes_cbc_encrypt
 #define aes_cbc_decrypt		neon_aes_cbc_decrypt
+#define aes_cbc_cts_encrypt	neon_aes_cbc_cts_encrypt
+#define aes_cbc_cts_decrypt	neon_aes_cbc_cts_decrypt
 #define aes_ctr_encrypt		neon_aes_ctr_encrypt
 #define aes_xts_encrypt		neon_aes_xts_encrypt
 #define aes_xts_decrypt		neon_aes_xts_decrypt
@@ -73,6 +78,11 @@  asmlinkage void aes_cbc_encrypt(u8 out[], u8 const in[], u8 const rk[],
 asmlinkage void aes_cbc_decrypt(u8 out[], u8 const in[], u8 const rk[],
 				int rounds, int blocks, u8 iv[]);
 
+asmlinkage void aes_cbc_cts_encrypt(u8 out[], u8 const in[], u8 const rk[],
+				int rounds, int bytes, u8 iv[]);
+asmlinkage void aes_cbc_cts_decrypt(u8 out[], u8 const in[], u8 const rk[],
+				int rounds, int bytes, u8 iv[]);
+
 asmlinkage void aes_ctr_encrypt(u8 out[], u8 const in[], u8 const rk[],
 				int rounds, int blocks, u8 ctr[]);
 
@@ -209,6 +219,120 @@  static int cbc_decrypt(struct skcipher_request *req)
 	return err;
 }
 
+static int cts_cbc_encrypt(struct skcipher_request *req)
+{
+	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
+	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
+	int err, rounds = 6 + ctx->key_length / 4;
+	int cbc_blocks = DIV_ROUND_UP(req->cryptlen, AES_BLOCK_SIZE) - 2;
+	struct skcipher_request subreq = *req;
+	struct scatterlist sg_src[2], sg_dst[2];
+	struct scatterlist *src = req->src, *dst = req->dst;
+	struct skcipher_walk walk;
+	unsigned int blocks;
+
+	if (req->cryptlen == AES_BLOCK_SIZE)
+		cbc_blocks = 1;
+
+	if (cbc_blocks > 0) {
+		skcipher_request_set_crypt(&subreq, req->src, req->dst,
+					   cbc_blocks * AES_BLOCK_SIZE,
+					   req->iv);
+		err = skcipher_walk_virt(&walk, &subreq, false);
+
+		while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
+			kernel_neon_begin();
+			aes_cbc_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
+					(u8 *)ctx->key_enc, rounds, blocks,
+					walk.iv);
+			kernel_neon_end();
+			err = skcipher_walk_done(&walk,
+						 walk.nbytes % AES_BLOCK_SIZE);
+		}
+		if (err)
+			return err;
+
+		if (req->cryptlen == AES_BLOCK_SIZE)
+			return 0;
+
+		src = scatterwalk_ffwd(sg_src, req->src, subreq.cryptlen);
+		dst = scatterwalk_ffwd(sg_dst, req->dst, subreq.cryptlen);
+	}
+
+	/* handle ciphertext stealing */
+	skcipher_request_set_crypt(&subreq, src, dst,
+				   req->cryptlen - cbc_blocks * AES_BLOCK_SIZE,
+				   req->iv);
+
+	err = skcipher_walk_virt(&walk, &subreq, false);
+	if (err)
+		return err;
+
+	kernel_neon_begin();
+	aes_cbc_cts_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
+			    (u8 *)ctx->key_enc, rounds, walk.nbytes, walk.iv);
+	kernel_neon_end();
+
+	return skcipher_walk_done(&walk, 0);
+}
+
+static int cts_cbc_decrypt(struct skcipher_request *req)
+{
+	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
+	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
+	int err, rounds = 6 + ctx->key_length / 4;
+	int cbc_blocks = DIV_ROUND_UP(req->cryptlen, AES_BLOCK_SIZE) - 2;
+	struct skcipher_request subreq = *req;
+	struct scatterlist sg_src[2], sg_dst[2];
+	struct scatterlist *src = req->src, *dst = req->dst;
+	struct skcipher_walk walk;
+	unsigned int blocks;
+
+	if (req->cryptlen == AES_BLOCK_SIZE)
+		cbc_blocks = 1;
+
+	if (cbc_blocks > 0) {
+		skcipher_request_set_crypt(&subreq, req->src, req->dst,
+					   cbc_blocks * AES_BLOCK_SIZE,
+					   req->iv);
+		err = skcipher_walk_virt(&walk, &subreq, false);
+
+		while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
+			kernel_neon_begin();
+			aes_cbc_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
+					(u8 *)ctx->key_dec, rounds, blocks,
+					walk.iv);
+			kernel_neon_end();
+			err = skcipher_walk_done(&walk,
+						 walk.nbytes % AES_BLOCK_SIZE);
+		}
+		if (err)
+			return err;
+
+		if (req->cryptlen == AES_BLOCK_SIZE)
+			return 0;
+
+		src = scatterwalk_ffwd(sg_src, req->src, subreq.cryptlen);
+		dst = scatterwalk_ffwd(sg_dst, req->dst, subreq.cryptlen);
+	}
+
+	/* handle ciphertext stealing */
+	skcipher_request_set_crypt(&subreq, src, dst,
+				   req->cryptlen - cbc_blocks * AES_BLOCK_SIZE,
+				   req->iv);
+
+	err = skcipher_walk_virt(&walk, &subreq, false);
+	if (err)
+		return err;
+
+	kernel_neon_begin();
+	aes_cbc_cts_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
+			    (u8 *)ctx->key_dec, rounds, walk.nbytes, walk.iv);
+	kernel_neon_end();
+
+	return skcipher_walk_done(&walk, 0);
+}
+
 static int ctr_encrypt(struct skcipher_request *req)
 {
 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
@@ -334,6 +458,24 @@  static struct skcipher_alg aes_algs[] = { {
 	.setkey		= skcipher_aes_setkey,
 	.encrypt	= cbc_encrypt,
 	.decrypt	= cbc_decrypt,
+}, {
+	.base = {
+		.cra_name		= "__cts(cbc(aes))",
+		.cra_driver_name	= "__cts-cbc-aes-" MODE,
+		.cra_priority		= PRIO,
+		.cra_flags		= CRYPTO_ALG_INTERNAL,
+		.cra_blocksize		= 1,
+		.cra_ctxsize		= sizeof(struct crypto_aes_ctx),
+		.cra_module		= THIS_MODULE,
+	},
+	.min_keysize	= AES_MIN_KEY_SIZE,
+	.max_keysize	= AES_MAX_KEY_SIZE,
+	.ivsize		= AES_BLOCK_SIZE,
+	.chunksize	= AES_BLOCK_SIZE,
+	.walksize	= 2 * AES_BLOCK_SIZE,
+	.setkey		= skcipher_aes_setkey,
+	.encrypt	= cts_cbc_encrypt,
+	.decrypt	= cts_cbc_decrypt,
 }, {
 	.base = {
 		.cra_name		= "__ctr(aes)",
diff --git a/arch/arm64/crypto/aes-modes.S b/arch/arm64/crypto/aes-modes.S
index 483a7130cf0e..61bab20de8da 100644
--- a/arch/arm64/crypto/aes-modes.S
+++ b/arch/arm64/crypto/aes-modes.S
@@ -205,6 +205,79 @@  AES_ENTRY(aes_cbc_decrypt)
 	ret
 AES_ENDPROC(aes_cbc_decrypt)
 
+	/*
+	 * aes_cbc_cts_encrypt(u8 out[], u8 const in[], u8 const rk[],
+	 *		       int rounds, int bytes, u8 iv[])
+	 * aes_cbc_cts_decrypt(u8 out[], u8 const in[], u8 const rk[],
+	 *		       int rounds, int bytes, u8 iv[])
+	 */
+
+AES_ENTRY(aes_cbc_cts_encrypt)
+	adr		x8, .Lcts_permute_table + 48
+	sub		x9, x8, x4
+	sub		x4, x4, #16
+	sub		x8, x8, #48
+	add		x8, x8, x4
+	ld1		{v6.16b}, [x9]
+	ld1		{v7.16b}, [x8]
+
+	ld1		{v4.16b}, [x5]			/* get iv */
+	enc_prepare	w3, x2, x6
+
+	ld1		{v0.16b}, [x1], x4		/* overlapping loads */
+	ld1		{v1.16b}, [x1]
+
+	eor		v0.16b, v0.16b, v4.16b		/* xor with iv */
+	tbl		v1.16b, {v1.16b}, v6.16b
+	encrypt_block	v0, w3, x2, x6, w7
+
+	eor		v1.16b, v1.16b, v0.16b
+	tbl		v0.16b, {v0.16b}, v7.16b
+	encrypt_block	v1, w3, x2, x6, w7
+
+	add		x4, x0, x4
+	st1		{v0.16b}, [x4]			/* overlapping stores */
+	st1		{v1.16b}, [x0]
+	ret
+AES_ENDPROC(aes_cbc_cts_encrypt)
+
+AES_ENTRY(aes_cbc_cts_decrypt)
+	adr		x8, .Lcts_permute_table + 48
+	sub		x9, x8, x4
+	sub		x4, x4, #16
+	sub		x8, x8, #48
+	add		x8, x8, x4
+	ld1		{v6.16b}, [x9]
+	ld1		{v7.16b}, [x8]
+
+	ld1		{v4.16b}, [x5]			/* get iv */
+	dec_prepare	w3, x2, x6
+
+	ld1		{v0.16b}, [x1], x4		/* overlapping loads */
+	ld1		{v1.16b}, [x1]
+
+	tbl		v2.16b, {v1.16b}, v6.16b
+	decrypt_block	v0, w3, x2, x6, w7
+	eor		v2.16b, v2.16b, v0.16b
+
+	tbx		v0.16b, {v1.16b}, v6.16b
+	tbl		v2.16b, {v2.16b}, v7.16b
+	decrypt_block	v0, w3, x2, x6, w7
+	eor		v0.16b, v0.16b, v4.16b		/* xor with iv */
+
+	add		x4, x0, x4
+	st1		{v2.16b}, [x4]			/* overlapping stores */
+	st1		{v0.16b}, [x0]
+	ret
+AES_ENDPROC(aes_cbc_cts_decrypt)
+
+.Lcts_permute_table:
+	.byte		0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff
+	.byte		0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff
+	.byte		 0x0,  0x1,  0x2,  0x3,  0x4,  0x5,  0x6,  0x7
+	.byte		 0x8,  0x9,  0xa,  0xb,  0xc,  0xd,  0xe,  0xf
+	.byte		0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff
+	.byte		0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff
 
 	/*
 	 * aes_ctr_encrypt(u8 out[], u8 const in[], u8 const rk[], int rounds,