diff mbox series

crypto: ablkcipher - Unmap pages after an external error

Message ID 20190903070501.GA9978@gondor.apana.org.au (mailing list archive)
State Changes Requested
Delegated to: Herbert Xu
Headers show
Series crypto: ablkcipher - Unmap pages after an external error | expand

Commit Message

Herbert Xu Sept. 3, 2019, 7:05 a.m. UTC
ablkcipher_walk_done may be called with an error by internal or
external callers.  For those internal callers we shouldn't unmap
pages but for external callers we must unmap any pages that are
in use.

This patch adds a new function ablkcipher_walk_unwind so that we
can eliminate the internal callers.

Reported-by: Ard Biesheuvel <ard.biesheuvel@linaro.org>
Fixes: 318abdfbe708 ("crypto: ablkcipher - fix crash flushing...")
Cc: <stable@vger.kernel.org>
Signed-off-by: Herbert Xu <herbert@gondor.apana.org.au>
diff mbox series

Patch

diff --git a/crypto/ablkcipher.c b/crypto/ablkcipher.c
index 072b5646a0a3..a61d13fabe3c 100644
--- a/crypto/ablkcipher.c
+++ b/crypto/ablkcipher.c
@@ -66,9 +66,11 @@  static inline u8 *ablkcipher_get_spot(u8 *start, unsigned int len)
 	return max(start, end_page);
 }
 
-static inline void ablkcipher_done_slow(struct ablkcipher_walk *walk,
-					unsigned int n)
+static inline unsigned int ablkcipher_done_slow(struct ablkcipher_walk *walk,
+						unsigned int bsize)
 {
+	unsigned int n = bsize;
+
 	for (;;) {
 		unsigned int len_this_page = scatterwalk_pagelen(&walk->out);
 
@@ -80,59 +82,73 @@  static inline void ablkcipher_done_slow(struct ablkcipher_walk *walk,
 		n -= len_this_page;
 		scatterwalk_start(&walk->out, sg_next(walk->out.sg));
 	}
+
+	return bsize;
 }
 
-static inline void ablkcipher_done_fast(struct ablkcipher_walk *walk,
-					unsigned int n)
+static inline unsigned int ablkcipher_done_fast(struct ablkcipher_walk *walk,
+						unsigned int n)
 {
 	scatterwalk_advance(&walk->in, n);
 	scatterwalk_advance(&walk->out, n);
+
+	return n;
 }
 
+static int ablkcipher_walk_unwind(struct ablkcipher_request *req,
+				  struct ablkcipher_walk *walk, int err);
 static int ablkcipher_walk_next(struct ablkcipher_request *req,
 				struct ablkcipher_walk *walk);
 
 int ablkcipher_walk_done(struct ablkcipher_request *req,
 			 struct ablkcipher_walk *walk, int err)
 {
-	struct crypto_tfm *tfm = req->base.tfm;
-	unsigned int n; /* bytes processed */
-	bool more;
-
-	if (unlikely(err < 0))
-		goto finish;
+	unsigned int nbytes = 0;
 
-	n = walk->nbytes - err;
-	walk->total -= n;
-	more = (walk->total != 0);
+	if (likely(err >= 0)) {
+		unsigned int n = walk->nbytes - err;
 
-	if (likely(!(walk->flags & ABLKCIPHER_WALK_SLOW))) {
-		ablkcipher_done_fast(walk, n);
-	} else {
-		if (WARN_ON(err)) {
-			/* unexpected case; didn't process all bytes */
+		if (likely(!(walk->flags & ABLKCIPHER_WALK_SLOW)))
+			n = ablkcipher_done_fast(walk, n);
+		else if (WARN_ON(err)) {
 			err = -EINVAL;
-			goto finish;
-		}
-		ablkcipher_done_slow(walk, n);
+			goto err;
+		} else
+			n = ablkcipher_done_slow(walk, n);
+
+		nbytes = walk->total - n;
+		err = 0;
 	}
 
-	scatterwalk_done(&walk->in, 0, more);
-	scatterwalk_done(&walk->out, 1, more);
+	scatterwalk_done(&walk->in, 0, nbytes);
+	scatterwalk_done(&walk->out, 1, nbytes);
+
+err:
+	walk->total = nbytes;
+	walk->nbytes = nbytes;
 
-	if (more) {
+	if (nbytes) {
 		crypto_yield(req->base.flags);
 		return ablkcipher_walk_next(req, walk);
 	}
-	err = 0;
-finish:
+
+	return ablkcipher_walk_unwind(req, walk, err);
+}
+EXPORT_SYMBOL_GPL(ablkcipher_walk_done);
+
+static int ablkcipher_walk_unwind(struct ablkcipher_request *req,
+				  struct ablkcipher_walk *walk, int err)
+{
+	struct crypto_tfm *tfm = req->base.tfm;
+
 	walk->nbytes = 0;
+
 	if (walk->iv != req->info)
 		memcpy(req->info, walk->iv, tfm->crt_ablkcipher.ivsize);
 	kfree(walk->iv_buffer);
+
 	return err;
 }
-EXPORT_SYMBOL_GPL(ablkcipher_walk_done);
 
 static inline int ablkcipher_next_slow(struct ablkcipher_request *req,
 				       struct ablkcipher_walk *walk,
@@ -151,7 +167,7 @@  static inline int ablkcipher_next_slow(struct ablkcipher_request *req,
 
 	p = kmalloc(n, GFP_ATOMIC);
 	if (!p)
-		return ablkcipher_walk_done(req, walk, -ENOMEM);
+		return ablkcipher_walk_unwind(req, walk, -ENOMEM);
 
 	base = p + 1;
 
@@ -222,7 +238,7 @@  static int ablkcipher_walk_next(struct ablkcipher_request *req,
 	n = walk->total;
 	if (unlikely(n < crypto_tfm_alg_blocksize(tfm))) {
 		req->base.flags |= CRYPTO_TFM_RES_BAD_BLOCK_LEN;
-		return ablkcipher_walk_done(req, walk, -EINVAL);
+		return ablkcipher_walk_unwind(req, walk, -EINVAL);
 	}
 
 	walk->flags &= ~ABLKCIPHER_WALK_SLOW;