@@ -1,8 +1,11 @@
// SPDX-License-Identifier: GPL-2.0-only
/*
- * aes-ccm-glue.c - AES-CCM transform for ARMv8 with Crypto Extensions
+ * aes-ce-ccm-glue.c - AES-CCM transform for ARMv8 with Crypto Extensions
*
- * Copyright (C) 2013 - 2017 Linaro Ltd <ard.biesheuvel@linaro.org>
+ * Copyright (C) 2013 - 2017 Linaro Ltd.
+ * Copyright (C) 2024 Google LLC
+ *
+ * Author: Ard Biesheuvel <ardb@kernel.org>
*/
#include <asm/neon.h>
@@ -149,7 +152,7 @@ static int ccm_encrypt(struct aead_request *req)
struct crypto_aes_ctx *ctx = crypto_aead_ctx(aead);
struct skcipher_walk walk;
u8 __aligned(8) mac[AES_BLOCK_SIZE];
- u8 buf[AES_BLOCK_SIZE];
+ u8 orig_iv[AES_BLOCK_SIZE];
u32 len = req->cryptlen;
int err;
@@ -158,7 +161,7 @@ static int ccm_encrypt(struct aead_request *req)
return err;
/* preserve the original iv for the final round */
- memcpy(buf, req->iv, AES_BLOCK_SIZE);
+ memcpy(orig_iv, req->iv, AES_BLOCK_SIZE);
err = skcipher_walk_aead_encrypt(&walk, req, false);
if (unlikely(err))
@@ -171,16 +174,26 @@ static int ccm_encrypt(struct aead_request *req)
do {
u32 tail = walk.nbytes % AES_BLOCK_SIZE;
+ const u8 *src = walk.src.virt.addr;
+ u8 *dst = walk.dst.virt.addr;
+ u8 buf[AES_BLOCK_SIZE];
if (walk.nbytes == walk.total)
tail = 0;
- ce_aes_ccm_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
- walk.nbytes - tail, ctx->key_enc,
- num_rounds(ctx), mac, walk.iv);
+ if (unlikely(walk.total < AES_BLOCK_SIZE))
+ src = dst = memcpy(buf + sizeof(buf) - walk.total,
+ src, walk.total);
+
+ ce_aes_ccm_encrypt(dst, src, walk.nbytes - tail,
+ ctx->key_enc, num_rounds(ctx),
+ mac, walk.iv);
+
+ if (unlikely(walk.total < AES_BLOCK_SIZE))
+ memcpy(walk.dst.virt.addr, dst, walk.total);
if (walk.nbytes == walk.total)
- ce_aes_ccm_final(mac, buf, ctx->key_enc, num_rounds(ctx));
+ ce_aes_ccm_final(mac, orig_iv, ctx->key_enc, num_rounds(ctx));
if (walk.nbytes) {
err = skcipher_walk_done(&walk, tail);
@@ -206,7 +219,7 @@ static int ccm_decrypt(struct aead_request *req)
unsigned int authsize = crypto_aead_authsize(aead);
struct skcipher_walk walk;
u8 __aligned(8) mac[AES_BLOCK_SIZE];
- u8 buf[AES_BLOCK_SIZE];
+ u8 orig_iv[AES_BLOCK_SIZE];
u32 len = req->cryptlen - authsize;
int err;
@@ -215,7 +228,7 @@ static int ccm_decrypt(struct aead_request *req)
return err;
/* preserve the original iv for the final round */
- memcpy(buf, req->iv, AES_BLOCK_SIZE);
+ memcpy(orig_iv, req->iv, AES_BLOCK_SIZE);
err = skcipher_walk_aead_decrypt(&walk, req, false);
if (unlikely(err))
@@ -228,16 +241,26 @@ static int ccm_decrypt(struct aead_request *req)
do {
u32 tail = walk.nbytes % AES_BLOCK_SIZE;
+ const u8 *src = walk.src.virt.addr;
+ u8 *dst = walk.dst.virt.addr;
+ u8 buf[AES_BLOCK_SIZE];
if (walk.nbytes == walk.total)
tail = 0;
- ce_aes_ccm_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
- walk.nbytes - tail, ctx->key_enc,
- num_rounds(ctx), mac, walk.iv);
+ if (unlikely(walk.total < AES_BLOCK_SIZE))
+ src = dst = memcpy(buf + sizeof(buf) - walk.total,
+ src, walk.total);
+
+ ce_aes_ccm_decrypt(dst, src, walk.nbytes - tail,
+ ctx->key_enc, num_rounds(ctx),
+ mac, walk.iv);
+
+ if (unlikely(walk.total < AES_BLOCK_SIZE))
+ memcpy(walk.dst.virt.addr, dst, walk.total);
if (walk.nbytes == walk.total)
- ce_aes_ccm_final(mac, buf, ctx->key_enc, num_rounds(ctx));
+ ce_aes_ccm_final(mac, orig_iv, ctx->key_enc, num_rounds(ctx));
if (walk.nbytes) {
err = skcipher_walk_done(&walk, tail);
@@ -250,11 +273,11 @@ static int ccm_decrypt(struct aead_request *req)
return err;
/* compare calculated auth tag with the stored one */
- scatterwalk_map_and_copy(buf, req->src,
+ scatterwalk_map_and_copy(orig_iv, req->src,
req->assoclen + req->cryptlen - authsize,
authsize, 0);
- if (crypto_memneq(mac, buf, authsize))
+ if (crypto_memneq(mac, orig_iv, authsize))
return -EBADMSG;
return 0;
}
@@ -293,6 +316,6 @@ module_init(aes_mod_init);
module_exit(aes_mod_exit);
MODULE_DESCRIPTION("Synchronous AES in CCM mode using ARMv8 Crypto Extensions");
-MODULE_AUTHOR("Ard Biesheuvel <ard.biesheuvel@linaro.org>");
+MODULE_AUTHOR("Ard Biesheuvel <ardb@kernel.org>");
MODULE_LICENSE("GPL v2");
MODULE_ALIAS_CRYPTO("ccm(aes)");