diff mbox series

[v2,27/29] crypto: skcipher - use the new scatterwalk functions

Message ID 20241230001418.74739-28-ebiggers@kernel.org (mailing list archive)
State Under Review
Delegated to: Herbert Xu
Headers show
Series crypto: scatterlist handling improvements | expand

Commit Message

Eric Biggers Dec. 30, 2024, 12:14 a.m. UTC
From: Eric Biggers <ebiggers@google.com>

Convert skcipher_walk to use the new scatterwalk functions.

This includes a few changes to exactly where the different parts of the
iteration happen.  For example the dcache flush that previously happened
in scatterwalk_done() now happens in scatterwalk_dst_done() or in
memcpy_to_scatterwalk().  Advancing to the next sg entry now happens
just-in-time in scatterwalk_clamp() instead of in scatterwalk_done().

Signed-off-by: Eric Biggers <ebiggers@google.com>
---
 crypto/skcipher.c | 51 ++++++++++++++++++-----------------------------
 1 file changed, 19 insertions(+), 32 deletions(-)
diff mbox series

Patch

diff --git a/crypto/skcipher.c b/crypto/skcipher.c
index 7abafe385fd5..8f6b09377368 100644
--- a/crypto/skcipher.c
+++ b/crypto/skcipher.c
@@ -46,20 +46,10 @@  static inline void skcipher_map_src(struct skcipher_walk *walk)
 static inline void skcipher_map_dst(struct skcipher_walk *walk)
 {
 	walk->dst.virt.addr = scatterwalk_map(&walk->out);
 }
 
-static inline void skcipher_unmap_src(struct skcipher_walk *walk)
-{
-	scatterwalk_unmap(walk->src.virt.addr);
-}
-
-static inline void skcipher_unmap_dst(struct skcipher_walk *walk)
-{
-	scatterwalk_unmap(walk->dst.virt.addr);
-}
-
 static inline gfp_t skcipher_walk_gfp(struct skcipher_walk *walk)
 {
 	return walk->flags & SKCIPHER_WALK_SLEEP ? GFP_KERNEL : GFP_ATOMIC;
 }
 
@@ -67,18 +57,10 @@  static inline struct skcipher_alg *__crypto_skcipher_alg(
 	struct crypto_alg *alg)
 {
 	return container_of(alg, struct skcipher_alg, base);
 }
 
-static int skcipher_done_slow(struct skcipher_walk *walk, unsigned int bsize)
-{
-	u8 *addr = PTR_ALIGN(walk->buffer, walk->alignmask + 1);
-
-	scatterwalk_copychunks(addr, &walk->out, bsize, 1);
-	return 0;
-}
-
 /**
  * skcipher_walk_done() - finish one step of a skcipher_walk
  * @walk: the skcipher_walk
  * @res: number of bytes *not* processed (>= 0) from walk->nbytes,
  *	 or a -errno value to terminate the walk due to an error
@@ -109,44 +91,45 @@  int skcipher_walk_done(struct skcipher_walk *walk, int res)
 	}
 
 	if (likely(!(walk->flags & (SKCIPHER_WALK_SLOW |
 				    SKCIPHER_WALK_COPY |
 				    SKCIPHER_WALK_DIFF)))) {
-unmap_src:
-		skcipher_unmap_src(walk);
+		scatterwalk_advance(&walk->in, n);
 	} else if (walk->flags & SKCIPHER_WALK_DIFF) {
-		skcipher_unmap_dst(walk);
-		goto unmap_src;
+		scatterwalk_unmap(walk->src.virt.addr);
+		scatterwalk_advance(&walk->in, n);
 	} else if (walk->flags & SKCIPHER_WALK_COPY) {
+		scatterwalk_advance(&walk->in, n);
 		skcipher_map_dst(walk);
 		memcpy(walk->dst.virt.addr, walk->page, n);
-		skcipher_unmap_dst(walk);
 	} else { /* SKCIPHER_WALK_SLOW */
 		if (res > 0) {
 			/*
 			 * Didn't process all bytes.  Either the algorithm is
 			 * broken, or this was the last step and it turned out
 			 * the message wasn't evenly divisible into blocks but
 			 * the algorithm requires it.
 			 */
 			res = -EINVAL;
 			total = 0;
-		} else
-			n = skcipher_done_slow(walk, n);
+		} else {
+			u8 *buf = PTR_ALIGN(walk->buffer, walk->alignmask + 1);
+
+			memcpy_to_scatterwalk(&walk->out, buf, n);
+		}
+		goto dst_done;
 	}
 
+	scatterwalk_done_dst(&walk->out, walk->dst.virt.addr, n);
+dst_done:
+
 	if (res > 0)
 		res = 0;
 
 	walk->total = total;
 	walk->nbytes = 0;
 
-	scatterwalk_advance(&walk->in, n);
-	scatterwalk_advance(&walk->out, n);
-	scatterwalk_done(&walk->in, 0, total);
-	scatterwalk_done(&walk->out, 1, total);
-
 	if (total) {
 		if (walk->flags & SKCIPHER_WALK_SLEEP)
 			cond_resched();
 		walk->flags &= ~(SKCIPHER_WALK_SLOW | SKCIPHER_WALK_COPY |
 				 SKCIPHER_WALK_DIFF);
@@ -189,11 +172,11 @@  static int skcipher_next_slow(struct skcipher_walk *walk, unsigned int bsize)
 		walk->buffer = buffer;
 	}
 	walk->dst.virt.addr = PTR_ALIGN(buffer, alignmask + 1);
 	walk->src.virt.addr = walk->dst.virt.addr;
 
-	scatterwalk_copychunks(walk->src.virt.addr, &walk->in, bsize, 0);
+	memcpy_from_scatterwalk(walk->src.virt.addr, &walk->in, bsize);
 
 	walk->nbytes = bsize;
 	walk->flags |= SKCIPHER_WALK_SLOW;
 
 	return 0;
@@ -203,11 +186,15 @@  static int skcipher_next_copy(struct skcipher_walk *walk)
 {
 	u8 *tmp = walk->page;
 
 	skcipher_map_src(walk);
 	memcpy(tmp, walk->src.virt.addr, walk->nbytes);
-	skcipher_unmap_src(walk);
+	scatterwalk_unmap(walk->src.virt.addr);
+	/*
+	 * walk->in is advanced later when the number of bytes actually
+	 * processed (which might be less than walk->nbytes) is known.
+	 */
 
 	walk->src.virt.addr = tmp;
 	walk->dst.virt.addr = tmp;
 	return 0;
 }