@@ -110,6 +110,8 @@ struct qat_dh_ctx {
unsigned int p_size;
bool g2;
struct qat_crypto_instance *inst;
+ struct crypto_kpp *ftfm;
+ bool fallback;
} __packed __aligned(64);
struct qat_asym_request {
@@ -381,6 +383,36 @@ static int qat_dh_compute_value(struct kpp_request *req)
return ret;
}
+static int qat_dh_generate_public_key(struct kpp_request *req)
+{
+ struct kpp_request *nreq = kpp_request_ctx(req);
+ struct crypto_kpp *tfm = crypto_kpp_reqtfm(req);
+ struct qat_dh_ctx *ctx = kpp_tfm_ctx(tfm);
+
+ if (ctx->fallback) {
+ memcpy(nreq, req, sizeof(*req));
+ kpp_request_set_tfm(nreq, ctx->ftfm);
+ return crypto_kpp_generate_public_key(nreq);
+ }
+
+ return qat_dh_compute_value(req);
+}
+
+static int qat_dh_compute_shared_secret(struct kpp_request *req)
+{
+ struct kpp_request *nreq = kpp_request_ctx(req);
+ struct crypto_kpp *tfm = crypto_kpp_reqtfm(req);
+ struct qat_dh_ctx *ctx = kpp_tfm_ctx(tfm);
+
+ if (ctx->fallback) {
+ memcpy(nreq, req, sizeof(*req));
+ kpp_request_set_tfm(nreq, ctx->ftfm);
+ return crypto_kpp_compute_shared_secret(nreq);
+ }
+
+ return qat_dh_compute_value(req);
+}
+
static int qat_dh_check_params_length(unsigned int p_len)
{
switch (p_len) {
@@ -398,9 +430,6 @@ static int qat_dh_set_params(struct qat_dh_ctx *ctx, struct dh *params)
struct qat_crypto_instance *inst = ctx->inst;
struct device *dev = &GET_DEV(inst->accel_dev);
- if (qat_dh_check_params_length(params->p_size << 3))
- return -EINVAL;
-
ctx->p_size = params->p_size;
ctx->p = dma_alloc_coherent(dev, ctx->p_size, &ctx->dma_p, GFP_KERNEL);
if (!ctx->p)
@@ -454,6 +483,13 @@ static int qat_dh_set_secret(struct crypto_kpp *tfm, const void *buf,
if (crypto_dh_decode_key(buf, len, ¶ms) < 0)
return -EINVAL;
+ if (qat_dh_check_params_length(params.p_size << 3)) {
+ ctx->fallback = true;
+ return crypto_kpp_set_secret(ctx->ftfm, buf, len);
+ }
+
+ ctx->fallback = false;
+
/* Free old secret if any */
qat_dh_clear_ctx(dev, ctx);
@@ -481,6 +517,9 @@ static unsigned int qat_dh_max_size(struct crypto_kpp *tfm)
{
struct qat_dh_ctx *ctx = kpp_tfm_ctx(tfm);
+ if (ctx->fallback)
+ return crypto_kpp_maxsize(ctx->ftfm);
+
return ctx->p_size;
}
@@ -489,11 +528,22 @@ static int qat_dh_init_tfm(struct crypto_kpp *tfm)
struct qat_dh_ctx *ctx = kpp_tfm_ctx(tfm);
struct qat_crypto_instance *inst =
qat_crypto_get_instance_node(numa_node_id());
+ const char *alg = kpp_alg_name(tfm);
+ unsigned int reqsize;
if (!inst)
return -EINVAL;
- kpp_set_reqsize(tfm, sizeof(struct qat_asym_request) + 64);
+ ctx->ftfm = crypto_alloc_kpp(alg, 0, CRYPTO_ALG_NEED_FALLBACK);
+ if (IS_ERR(ctx->ftfm))
+ return PTR_ERR(ctx->ftfm);
+
+ crypto_kpp_set_flags(ctx->ftfm, crypto_kpp_get_flags(tfm));
+
+ reqsize = max(sizeof(struct qat_asym_request) + 64,
+ sizeof(struct kpp_request) + crypto_kpp_reqsize(ctx->ftfm));
+
+ kpp_set_reqsize(tfm, reqsize);
ctx->p_size = 0;
ctx->g2 = false;
@@ -506,6 +556,9 @@ static void qat_dh_exit_tfm(struct crypto_kpp *tfm)
struct qat_dh_ctx *ctx = kpp_tfm_ctx(tfm);
struct device *dev = &GET_DEV(ctx->inst->accel_dev);
+ if (ctx->ftfm)
+ crypto_free_kpp(ctx->ftfm);
+
qat_dh_clear_ctx(dev, ctx);
qat_crypto_put_instance(ctx->inst);
}
@@ -1265,8 +1318,8 @@ static struct akcipher_alg rsa = {
static struct kpp_alg dh = {
.set_secret = qat_dh_set_secret,
- .generate_public_key = qat_dh_compute_value,
- .compute_shared_secret = qat_dh_compute_value,
+ .generate_public_key = qat_dh_generate_public_key,
+ .compute_shared_secret = qat_dh_compute_shared_secret,
.max_size = qat_dh_max_size,
.init = qat_dh_init_tfm,
.exit = qat_dh_exit_tfm,
@@ -1276,6 +1329,7 @@ static struct kpp_alg dh = {
.cra_priority = 1000,
.cra_module = THIS_MODULE,
.cra_ctxsize = sizeof(struct qat_dh_ctx),
+ .cra_flags = CRYPTO_ALG_NEED_FALLBACK,
},
};