diff mbox series

[v9,07/24] virt: sev-guest: Store VMPCK index to SNP guest device structure

Message ID 20240531043038.3370793-8-nikunj@amd.com (mailing list archive)
State New
Headers show
Series Add Secure TSC support for SNP guests | expand

Commit Message

Nikunj A. Dadhania May 31, 2024, 4:30 a.m. UTC
Currently, SEV guest driver retrieves the pointers to VMPCK and
os_area_msg_seqno from the secrets page. In order to get rid of this
dependency, use vmpck_id to index the appropriate key and the corresponding
message sequence number.

Signed-off-by: Nikunj A Dadhania <nikunj@amd.com>
---
 drivers/virt/coco/sev-guest/sev-guest.c | 67 ++++++++++++-------------
 1 file changed, 33 insertions(+), 34 deletions(-)
diff mbox series

Patch

diff --git a/drivers/virt/coco/sev-guest/sev-guest.c b/drivers/virt/coco/sev-guest/sev-guest.c
index a3c0b22d2e14..0729d0b73495 100644
--- a/drivers/virt/coco/sev-guest/sev-guest.c
+++ b/drivers/virt/coco/sev-guest/sev-guest.c
@@ -55,8 +55,7 @@  struct snp_guest_dev {
 		struct snp_derived_key_req derived_key;
 		struct snp_ext_report_req ext_report;
 	} req;
-	u32 *os_area_msg_seqno;
-	u8 *vmpck;
+	unsigned int vmpck_id;
 };
 
 static u32 vmpck_id;
@@ -66,14 +65,17 @@  MODULE_PARM_DESC(vmpck_id, "The VMPCK ID to use when communicating with the PSP.
 /* Mutex to serialize the shared buffer access and command handling. */
 static DEFINE_MUTEX(snp_cmd_mutex);
 
+static inline u8 *get_vmpck(struct snp_guest_dev *snp_dev)
+{
+	return snp_dev->secrets->vmpck[snp_dev->vmpck_id];
+}
+
 static bool is_vmpck_empty(struct snp_guest_dev *snp_dev)
 {
 	char zero_key[VMPCK_KEY_LEN] = {0};
+	u8 *key = get_vmpck(snp_dev);
 
-	if (snp_dev->vmpck)
-		return !memcmp(snp_dev->vmpck, zero_key, VMPCK_KEY_LEN);
-
-	return true;
+	return !memcmp(key, zero_key, VMPCK_KEY_LEN);
 }
 
 /*
@@ -95,28 +97,23 @@  static bool is_vmpck_empty(struct snp_guest_dev *snp_dev)
  */
 static void snp_disable_vmpck(struct snp_guest_dev *snp_dev)
 {
-	dev_alert(snp_dev->dev, "Disabling VMPCK%d to prevent IV reuse.\n",
-		  vmpck_id);
-	memzero_explicit(snp_dev->vmpck, VMPCK_KEY_LEN);
-	snp_dev->vmpck = NULL;
-}
-
-static inline u64 __snp_get_msg_seqno(struct snp_guest_dev *snp_dev)
-{
-	u64 count;
-
-	lockdep_assert_held(&snp_cmd_mutex);
+	u8 *key = get_vmpck(snp_dev);
 
-	/* Read the current message sequence counter from secrets pages */
-	count = *snp_dev->os_area_msg_seqno;
+	if (is_vmpck_empty(snp_dev))
+		return;
 
-	return count + 1;
+	dev_alert(snp_dev->dev, "Disabling VMPCK%u to prevent IV reuse.\n", snp_dev->vmpck_id);
+	memzero_explicit(key, VMPCK_KEY_LEN);
 }
 
 /* Return a non-zero on success */
 static u64 snp_get_msg_seqno(struct snp_guest_dev *snp_dev)
 {
-	u64 count = __snp_get_msg_seqno(snp_dev);
+	u64 count;
+
+	lockdep_assert_held(&snp_cmd_mutex);
+
+	count = snp_dev->secrets->os_area.msg_seqno[snp_dev->vmpck_id] + 1;
 
 	/*
 	 * The message sequence counter for the SNP guest request is a  64-bit
@@ -140,7 +137,7 @@  static void snp_inc_msg_seqno(struct snp_guest_dev *snp_dev)
 	 * The counter is also incremented by the PSP, so increment it by 2
 	 * and save in secrets page.
 	 */
-	*snp_dev->os_area_msg_seqno += 2;
+	snp_dev->secrets->os_area.msg_seqno[snp_dev->vmpck_id] += 2;
 }
 
 static inline struct snp_guest_dev *to_snp_dev(struct file *file)
@@ -150,15 +147,17 @@  static inline struct snp_guest_dev *to_snp_dev(struct file *file)
 	return container_of(dev, struct snp_guest_dev, misc);
 }
 
-static struct aesgcm_ctx *snp_init_crypto(u8 *key, size_t keylen)
+static struct aesgcm_ctx *snp_init_crypto(struct snp_guest_dev *snp_dev)
 {
 	struct aesgcm_ctx *ctx;
+	u8 *key;
 
 	ctx = kzalloc(sizeof(*ctx), GFP_KERNEL_ACCOUNT);
 	if (!ctx)
 		return NULL;
 
-	if (aesgcm_expandkey(ctx, key, keylen, AUTHTAG_LEN)) {
+	key = get_vmpck(snp_dev);
+	if (aesgcm_expandkey(ctx, key, VMPCK_KEY_LEN, AUTHTAG_LEN)) {
 		pr_err("Crypto context initialization failed\n");
 		kfree(ctx);
 		return NULL;
@@ -666,13 +665,14 @@  static const struct file_operations snp_guest_fops = {
 	.unlocked_ioctl = snp_guest_ioctl,
 };
 
-static u8 *get_vmpck(int id, struct snp_secrets_page *secrets, u32 **seqno)
+static bool assign_vmpck(struct snp_guest_dev *dev, unsigned int vmpck_id)
 {
-	if ((id + 1) > VMPCK_MAX_NUM)
-		return NULL;
+	if ((vmpck_id + 1) > VMPCK_MAX_NUM)
+		return false;
+
+	dev->vmpck_id = vmpck_id;
 
-	*seqno = &secrets->os_area.msg_seqno[id];
-	return secrets->vmpck[id];
+	return true;
 }
 
 struct snp_msg_report_resp_hdr {
@@ -828,21 +828,20 @@  static int __init sev_guest_probe(struct platform_device *pdev)
 		goto e_unmap;
 
 	ret = -EINVAL;
-	snp_dev->vmpck = get_vmpck(vmpck_id, secrets, &snp_dev->os_area_msg_seqno);
-	if (!snp_dev->vmpck) {
+	snp_dev->secrets = secrets;
+	if (!assign_vmpck(snp_dev, vmpck_id)) {
 		dev_err(dev, "Invalid VMPCK%d communication key\n", vmpck_id);
 		goto e_unmap;
 	}
 
 	/* Verify that VMPCK is not zero. */
 	if (is_vmpck_empty(snp_dev)) {
-		dev_err(dev, "Empty VMPCK%d communication key\n", vmpck_id);
+		dev_err(dev, "Empty VMPCK%d communication key\n", snp_dev->vmpck_id);
 		goto e_unmap;
 	}
 
 	platform_set_drvdata(pdev, snp_dev);
 	snp_dev->dev = dev;
-	snp_dev->secrets = secrets;
 
 	/* Allocate secret request and response message for double buffering */
 	snp_dev->secret_request = kzalloc(SNP_GUEST_MSG_SIZE, GFP_KERNEL);
@@ -867,7 +866,7 @@  static int __init sev_guest_probe(struct platform_device *pdev)
 		goto e_free_response;
 
 	ret = -EIO;
-	snp_dev->ctx = snp_init_crypto(snp_dev->vmpck, VMPCK_KEY_LEN);
+	snp_dev->ctx = snp_init_crypto(snp_dev);
 	if (!snp_dev->ctx)
 		goto e_free_cert_data;