[v9,03/24] virt: sev-guest: Make payload a variable length array

Message ID 20240531043038.3370793-4-nikunj@amd.com
State New
Series Add Secure TSC support for SNP guests

Commit Message

Nikunj A. Dadhania May 31, 2024, 4:30 a.m. UTC
Currently, the guest message is PAGE_SIZE bytes and the payload is hard-coded
to 4000 bytes, assuming the snp_guest_msg_hdr structure is 96 bytes.

Remove the structure size assumption and the hard-coding of the payload size,
and instead use a variable-length array.

While at it, rename the local guest message variables for clarity.

Signed-off-by: Nikunj A Dadhania <nikunj@amd.com>
Suggested-by: Tom Lendacky <thomas.lendacky@amd.com>
---
 drivers/virt/coco/sev-guest/sev-guest.h |  5 +-
 drivers/virt/coco/sev-guest/sev-guest.c | 74 +++++++++++++++----------
 2 files changed, 48 insertions(+), 31 deletions(-)

Comments

Tom Lendacky June 17, 2024, 8:51 p.m. UTC | #1
On 5/30/24 23:30, Nikunj A Dadhania wrote:
> Currently, the guest message is PAGE_SIZE bytes and the payload is hard-coded
> to 4000 bytes, assuming the snp_guest_msg_hdr structure is 96 bytes.
> 
> Remove the structure size assumption and the hard-coding of the payload size,
> and instead use a variable-length array.
> 
> While at it, rename the local guest message variables for clarity.
> 
> Signed-off-by: Nikunj A Dadhania <nikunj@amd.com>
> Suggested-by: Tom Lendacky <thomas.lendacky@amd.com>

Reviewed-by: Tom Lendacky <thomas.lendacky@amd.com>

> ---
>  drivers/virt/coco/sev-guest/sev-guest.h |  5 +-
>  drivers/virt/coco/sev-guest/sev-guest.c | 74 +++++++++++++++----------
>  2 files changed, 48 insertions(+), 31 deletions(-)
>
Borislav Petkov June 21, 2024, 4:54 p.m. UTC | #2
On Fri, May 31, 2024 at 10:00:17AM +0530, Nikunj A Dadhania wrote:
> Currently, the guest message is PAGE_SIZE bytes and the payload is hard-coded
> to 4000 bytes, assuming the snp_guest_msg_hdr structure is 96 bytes.
> 
> Remove the structure size assumption and the hard-coding of the payload size,
> and instead use a variable-length array.

I don't understand here what hard-coding is being removed?

It is simply done differently:

from

> -     snp_dev->request = alloc_shared_pages(dev, sizeof(struct snp_guest_msg));

to

> +     snp_dev->request = alloc_shared_pages(dev, SNP_GUEST_MSG_SIZE);

Maybe I'm missing the point here but do you mean by removing the hard-coding
this:

+#define SNP_GUEST_MSG_SIZE 4096
+#define SNP_GUEST_MSG_PAYLOAD_SIZE (SNP_GUEST_MSG_SIZE - sizeof(struct snp_guest_msg))

where the msg payload size will get computed at build time and you won't have
to do that 4000 in the struct definition:

	u8 payload[4000];

?
Nikunj A. Dadhania June 23, 2024, 4:16 p.m. UTC | #3
On 6/21/2024 10:24 PM, Borislav Petkov wrote:
> On Fri, May 31, 2024 at 10:00:17AM +0530, Nikunj A Dadhania wrote:
>> Currently, the guest message is PAGE_SIZE bytes and the payload is hard-coded
>> to 4000 bytes, assuming the snp_guest_msg_hdr structure is 96 bytes.
>>
>> Remove the structure size assumption and the hard-coding of the payload size,
>> and instead use a variable-length array.
> 
> I don't understand here what hard-coding is being removed?
> 
> It is simply done differently:
> 
> from
> 
>> -     snp_dev->request = alloc_shared_pages(dev, sizeof(struct snp_guest_msg));
> 
> to
> 
>> +     snp_dev->request = alloc_shared_pages(dev, SNP_GUEST_MSG_SIZE);
> 
> Maybe I'm missing the point here but do you mean by removing the hard-coding
> this:
> 
> +#define SNP_GUEST_MSG_SIZE 4096
> +#define SNP_GUEST_MSG_PAYLOAD_SIZE (SNP_GUEST_MSG_SIZE - sizeof(struct snp_guest_msg))
> 
> where the msg payload size will get computed at build time and you won't have
> to do that 4000 in the struct definition:
> 
> 	u8 payload[4000];
> 
> ?

Yes, payload was earlier fixed at 4000 bytes, without considering the size of snp_guest_msg.

Regards
Nikunj
Borislav Petkov June 24, 2024, 6:11 a.m. UTC | #4
On Sun, Jun 23, 2024 at 09:46:09PM +0530, Nikunj A. Dadhania wrote:
> Yes, payload was earlier fixed at 4000 bytes, without considering the size
> of snp_guest_msg.

Sorry, you'd need to try explaining this again. Who wasn't considering the
size of snp_guest_msg?

AFAICT, the code currently does sizeof(struct snp_guest_msg) which contains
both the header *and* the payload.

What could help is if you structure your commit message this way:

1. Prepare the context for the explanation briefly.

2. Explain the problem at hand.

3. "It happens because of <...>"

4. "Fix it by doing X"

5. "(Potentially do Y)."

And some of those above are optional depending on the issue being
explained.

Thx.
Nikunj A. Dadhania June 24, 2024, 10:03 a.m. UTC | #5
On 6/24/2024 11:41 AM, Borislav Petkov wrote:
> On Sun, Jun 23, 2024 at 09:46:09PM +0530, Nikunj A. Dadhania wrote:
>> Yes, payload was earlier fixed at 4000 bytes, without considering the size
>> of snp_guest_msg.
> 
> Sorry, you'd need to try explaining this again. Who wasn't considering the
> size of snp_guest_msg?

Sorry, I meant snp_guest_msg_hdr here.

snp_guest_msg includes the header and the payload. There is an implicit
assumption that snp_guest_msg_hdr will always be 96 bytes, and with that
assumption the payload array size is set to the magic number 4000.


> AFAICT, the code currently does sizeof(struct snp_guest_msg) which contains
> both the header *and* the payload.
> 
> What could help is if you structure your commit message this way:

How about the below commit message:

-----------------------------------------------------------------------
Currently, snp_guest_msg includes a message header (96 bytes) and a
payload (4000 bytes). There is an implicit assumption here that the SNP
message header will always be 96 bytes, and with that assumption the
payload array size has been set to the magic number 4000. If any new
member is added to the SNP message header, the SNP guest message will
span more than a page.

Instead of using the magic number '4000' for the payload array in the
snp_guest_msg structure, use a variable-length array for the payload.
Allocate snp_guest_msg with a constant size (SNP_GUEST_MSG_SIZE=4096). This
ensures that the message size won't grow beyond the page size even if the
message header size increases. Also, add SNP_GUEST_MSG_PAYLOAD_SIZE for
checking for buffer overruns.

While at it, rename the local guest message variables for clarity.
-----------------------------------------------------------------------

Regards
Nikunj
Tom Lendacky June 24, 2024, 1 p.m. UTC | #6
On 6/23/24 11:16, Nikunj A. Dadhania wrote:
> On 6/21/2024 10:24 PM, Borislav Petkov wrote:
>> On Fri, May 31, 2024 at 10:00:17AM +0530, Nikunj A Dadhania wrote:
>>> Currently, the guest message is PAGE_SIZE bytes and the payload is hard-coded
>>> to 4000 bytes, assuming the snp_guest_msg_hdr structure is 96 bytes.
>>>
>>> Remove the structure size assumption and the hard-coding of the payload size,
>>> and instead use a variable-length array.
>>
>> I don't understand here what hard-coding is being removed?
>>
>> It is simply done differently:
>>
>> from
>>
>>> -     snp_dev->request = alloc_shared_pages(dev, sizeof(struct snp_guest_msg));
>>
>> to
>>
>>> +     snp_dev->request = alloc_shared_pages(dev, SNP_GUEST_MSG_SIZE);
>>
>> Maybe I'm missing the point here but do you mean by removing the hard-coding
>> this:
>>
>> +#define SNP_GUEST_MSG_SIZE 4096
>> +#define SNP_GUEST_MSG_PAYLOAD_SIZE (SNP_GUEST_MSG_SIZE - sizeof(struct snp_guest_msg))
>>
>> where the msg payload size will get computed at build time and you won't have
>> to do that 4000 in the struct definition:
>>
>> 	u8 payload[4000];
>>
>> ?
> 
> Yes, payload was earlier fixed at 4000 bytes, without considering the size of snp_guest_msg.

An alternative to the #defines would be something like:

struct snp_guest_msg {
	struct snp_guest_msg_hdr hdr;
	u8 payload[PAGE_SIZE - sizeof(struct snp_guest_msg_hdr)];
} __packed;

Not sure it matters, but does reduce the changes while ensuring the
payload plus header doesn't exceed a page.

Thanks,
Tom

> 
> Regards
> Nikunj
Borislav Petkov June 24, 2024, 1:39 p.m. UTC | #7
On Mon, Jun 24, 2024 at 08:00:38AM -0500, Tom Lendacky wrote:
> An alternative to the #defines would be something like:
> 
> struct snp_guest_msg {
> 	struct snp_guest_msg_hdr hdr;
> 	u8 payload[PAGE_SIZE - sizeof(struct snp_guest_msg_hdr)];
> } __packed;
> 
> Not sure it matters, but does reduce the changes while ensuring the
> payload plus header doesn't exceed a page.

Yeah, because that would've been my next question - the requirement to keep it
<= PAGE_SIZE.

So yeah, Nikunj, please do that also and add a 

	BUILD_BUG_ON(sizeof(struct snp_guest_msg) > PAGE_SIZE);

somewhere in the driver to catch all kinds of funky stuff.

Thx.
Nikunj A. Dadhania June 24, 2024, 6:12 p.m. UTC | #8
On 6/24/2024 7:09 PM, Borislav Petkov wrote:
> On Mon, Jun 24, 2024 at 08:00:38AM -0500, Tom Lendacky wrote:
>> An alternative to the #defines would be something like:
>>
>> struct snp_guest_msg {
>> 	struct snp_guest_msg_hdr hdr;
>> 	u8 payload[PAGE_SIZE - sizeof(struct snp_guest_msg_hdr)];
>> } __packed;
>>
>> Not sure it matters, but does reduce the changes while ensuring the
>> payload plus header doesn't exceed a page.

Yes, it does reduce a lot of churn.

> 
> Yeah, because that would've been my next question - the requirement to keep it
> <= PAGE_SIZE.
> 
> So yeah, Nikunj, please do that also and add a 
> 
> 	BUILD_BUG_ON(sizeof(struct snp_guest_msg) > PAGE_SIZE);
> 
> somewhere in the driver to catch all kinds of funky stuff.

Sure, here is the new patch. I have split the variable name changes into a separate patch.

Subject: [PATCH] virt: sev-guest: Ensure the SNP guest messages do not exceed a page

Currently, snp_guest_msg includes a message header (96 bytes) and a
payload (4000 bytes). There is an implicit assumption here that the SNP
message header will always be 96 bytes, and with that assumption the
payload array size has been set to the magic number 4000. If any new
member is added to the SNP message header, the SNP guest message will span
more than a page.

Instead of using the magic number '4000' for the payload, declare
snp_guest_msg in a way that the payload plus the message header do not
exceed a page.

Signed-off-by: Nikunj A Dadhania <nikunj@amd.com>
Suggested-by: Tom Lendacky <thomas.lendacky@amd.com>
---
 drivers/virt/coco/sev-guest/sev-guest.h | 2 +-
 drivers/virt/coco/sev-guest/sev-guest.c | 3 +++
 2 files changed, 4 insertions(+), 1 deletion(-)

diff --git a/drivers/virt/coco/sev-guest/sev-guest.h b/drivers/virt/coco/sev-guest/sev-guest.h
index ceb798a404d6..de14a4f01b9d 100644
--- a/drivers/virt/coco/sev-guest/sev-guest.h
+++ b/drivers/virt/coco/sev-guest/sev-guest.h
@@ -60,7 +60,7 @@ struct snp_guest_msg_hdr {
 
 struct snp_guest_msg {
 	struct snp_guest_msg_hdr hdr;
-	u8 payload[4000];
+	u8 payload[PAGE_SIZE - sizeof(struct snp_guest_msg_hdr)];
 } __packed;
 
 #endif /* __VIRT_SEVGUEST_H__ */
diff --git a/drivers/virt/coco/sev-guest/sev-guest.c b/drivers/virt/coco/sev-guest/sev-guest.c
index 427571a2d1a2..c4aae5d4308e 100644
--- a/drivers/virt/coco/sev-guest/sev-guest.c
+++ b/drivers/virt/coco/sev-guest/sev-guest.c
@@ -1033,6 +1033,9 @@ static int __init sev_guest_probe(struct platform_device *pdev)
 	snp_dev->dev = dev;
 	snp_dev->secrets = secrets;
 
+	/* Ensure SNP guest messages do not span more than a page */
+	BUILD_BUG_ON(sizeof(struct snp_guest_msg) > PAGE_SIZE);
+
 	/* Allocate the shared page used for the request and response message. */
 	snp_dev->request = alloc_shared_pages(dev, sizeof(struct snp_guest_msg));
 	if (!snp_dev->request)
Borislav Petkov June 25, 2024, 12:19 p.m. UTC | #9
On Mon, Jun 24, 2024 at 11:42:44PM +0530, Nikunj A. Dadhania wrote:
> Sure, here is the new patch. I have separated the variable name changes to a new patch.

Right, please next time you send your set, sort the cleanups to sev-guest
first so that I can pick them up separately.

> 
> Subject: [PATCH] virt: sev-guest: Ensure the SNP guest messages do not exceed
>  a page
> 
> Currently, snp_guest_msg includes a message header (96 bytes) and a
> payload (4000 bytes). There is an implicit assumption here that the SNP
> message header will always be 96 bytes, and with that assumption the
> payload array size has been set to the magic number 4000. If any new
> member is added to the SNP message header, the SNP guest message will span
> more than a page.
> 
> Instead of using the magic number '4000' for the payload, declare
> snp_guest_msg in a way that the payload plus the message header do not
> exceed a page.
> 
> Signed-off-by: Nikunj A Dadhania <nikunj@amd.com>
> Suggested-by: Tom Lendacky <thomas.lendacky@amd.com>
> ---
>  drivers/virt/coco/sev-guest/sev-guest.h | 2 +-
>  drivers/virt/coco/sev-guest/sev-guest.c | 3 +++
>  2 files changed, 4 insertions(+), 1 deletion(-)
> 
> diff --git a/drivers/virt/coco/sev-guest/sev-guest.h b/drivers/virt/coco/sev-guest/sev-guest.h
> index ceb798a404d6..de14a4f01b9d 100644
> --- a/drivers/virt/coco/sev-guest/sev-guest.h
> +++ b/drivers/virt/coco/sev-guest/sev-guest.h
> @@ -60,7 +60,7 @@ struct snp_guest_msg_hdr {
>  
>  struct snp_guest_msg {
>  	struct snp_guest_msg_hdr hdr;
> -	u8 payload[4000];
> +	u8 payload[PAGE_SIZE - sizeof(struct snp_guest_msg_hdr)];
>  } __packed;
>  
>  #endif /* __VIRT_SEVGUEST_H__ */
> diff --git a/drivers/virt/coco/sev-guest/sev-guest.c b/drivers/virt/coco/sev-guest/sev-guest.c
> index 427571a2d1a2..c4aae5d4308e 100644
> --- a/drivers/virt/coco/sev-guest/sev-guest.c
> +++ b/drivers/virt/coco/sev-guest/sev-guest.c
> @@ -1033,6 +1033,9 @@ static int __init sev_guest_probe(struct platform_device *pdev)
>  	snp_dev->dev = dev;
>  	snp_dev->secrets = secrets;
>  
> +	/* Ensure SNP guest messages do not span more than a page */
> +	BUILD_BUG_ON(sizeof(struct snp_guest_msg) > PAGE_SIZE);
> +
>  	/* Allocate the shared page used for the request and response message. */
>  	snp_dev->request = alloc_shared_pages(dev, sizeof(struct snp_guest_msg));
>  	if (!snp_dev->request)
> -- 

Yap, that's exactly how stuff like that should be done:

Acked-by: Borislav Petkov (AMD) <bp@alien8.de>

Patch

diff --git a/drivers/virt/coco/sev-guest/sev-guest.h b/drivers/virt/coco/sev-guest/sev-guest.h
index ceb798a404d6..97796f658fd3 100644
--- a/drivers/virt/coco/sev-guest/sev-guest.h
+++ b/drivers/virt/coco/sev-guest/sev-guest.h
@@ -60,7 +60,10 @@  struct snp_guest_msg_hdr {
 
 struct snp_guest_msg {
 	struct snp_guest_msg_hdr hdr;
-	u8 payload[4000];
+	u8 payload[];
 } __packed;
 
+#define SNP_GUEST_MSG_SIZE 4096
+#define SNP_GUEST_MSG_PAYLOAD_SIZE (SNP_GUEST_MSG_SIZE - sizeof(struct snp_guest_msg))
+
 #endif /* __VIRT_SEVGUEST_H__ */
diff --git a/drivers/virt/coco/sev-guest/sev-guest.c b/drivers/virt/coco/sev-guest/sev-guest.c
index 7e1bf2056b47..69bd817239d8 100644
--- a/drivers/virt/coco/sev-guest/sev-guest.c
+++ b/drivers/virt/coco/sev-guest/sev-guest.c
@@ -48,7 +48,7 @@  struct snp_guest_dev {
 	 * Avoid information leakage by double-buffering shared messages
 	 * in fields that are in regular encrypted memory.
 	 */
-	struct snp_guest_msg secret_request, secret_response;
+	struct snp_guest_msg *secret_request, *secret_response;
 
 	struct snp_secrets_page *secrets;
 	struct snp_req_data input;
@@ -171,40 +171,40 @@  static struct aesgcm_ctx *snp_init_crypto(u8 *key, size_t keylen)
 
 static int verify_and_dec_payload(struct snp_guest_dev *snp_dev, void *payload, u32 sz)
 {
-	struct snp_guest_msg *resp = &snp_dev->secret_response;
-	struct snp_guest_msg *req = &snp_dev->secret_request;
-	struct snp_guest_msg_hdr *req_hdr = &req->hdr;
-	struct snp_guest_msg_hdr *resp_hdr = &resp->hdr;
+	struct snp_guest_msg *resp_msg = snp_dev->secret_response;
+	struct snp_guest_msg *req_msg = snp_dev->secret_request;
+	struct snp_guest_msg_hdr *req_msg_hdr = &req_msg->hdr;
+	struct snp_guest_msg_hdr *resp_msg_hdr = &resp_msg->hdr;
 	struct aesgcm_ctx *ctx = snp_dev->ctx;
 	u8 iv[GCM_AES_IV_SIZE] = {};
 
 	pr_debug("response [seqno %lld type %d version %d sz %d]\n",
-		 resp_hdr->msg_seqno, resp_hdr->msg_type, resp_hdr->msg_version,
-		 resp_hdr->msg_sz);
+		 resp_msg_hdr->msg_seqno, resp_msg_hdr->msg_type, resp_msg_hdr->msg_version,
+		 resp_msg_hdr->msg_sz);
 
 	/* Copy response from shared memory to encrypted memory. */
-	memcpy(resp, snp_dev->response, sizeof(*resp));
+	memcpy(resp_msg, snp_dev->response, SNP_GUEST_MSG_SIZE);
 
 	/* Verify that the sequence counter is incremented by 1 */
-	if (unlikely(resp_hdr->msg_seqno != (req_hdr->msg_seqno + 1)))
+	if (unlikely(resp_msg_hdr->msg_seqno != (req_msg_hdr->msg_seqno + 1)))
 		return -EBADMSG;
 
 	/* Verify response message type and version number. */
-	if (resp_hdr->msg_type != (req_hdr->msg_type + 1) ||
-	    resp_hdr->msg_version != req_hdr->msg_version)
+	if (resp_msg_hdr->msg_type != (req_msg_hdr->msg_type + 1) ||
+	    resp_msg_hdr->msg_version != req_msg_hdr->msg_version)
 		return -EBADMSG;
 
 	/*
 	 * If the message size is greater than our buffer length then return
 	 * an error.
 	 */
-	if (unlikely((resp_hdr->msg_sz + ctx->authsize) > sz))
+	if (unlikely((resp_msg_hdr->msg_sz + ctx->authsize) > sz))
 		return -EBADMSG;
 
 	/* Decrypt the payload */
-	memcpy(iv, &resp_hdr->msg_seqno, min(sizeof(iv), sizeof(resp_hdr->msg_seqno)));
-	if (!aesgcm_decrypt(ctx, payload, resp->payload, resp_hdr->msg_sz,
-			    &resp_hdr->algo, AAD_LEN, iv, resp_hdr->authtag))
+	memcpy(iv, &resp_msg_hdr->msg_seqno, min(sizeof(iv), sizeof(resp_msg_hdr->msg_seqno)));
+	if (!aesgcm_decrypt(ctx, payload, resp_msg->payload, resp_msg_hdr->msg_sz,
+			    &resp_msg_hdr->algo, AAD_LEN, iv, resp_msg_hdr->authtag))
 		return -EBADMSG;
 
 	return 0;
@@ -213,12 +213,12 @@  static int verify_and_dec_payload(struct snp_guest_dev *snp_dev, void *payload,
 static int enc_payload(struct snp_guest_dev *snp_dev, u64 seqno, int version, u8 type,
 			void *payload, size_t sz)
 {
-	struct snp_guest_msg *req = &snp_dev->secret_request;
-	struct snp_guest_msg_hdr *hdr = &req->hdr;
+	struct snp_guest_msg *msg = snp_dev->secret_request;
+	struct snp_guest_msg_hdr *hdr = &msg->hdr;
 	struct aesgcm_ctx *ctx = snp_dev->ctx;
 	u8 iv[GCM_AES_IV_SIZE] = {};
 
-	memset(req, 0, sizeof(*req));
+	memset(msg, 0, SNP_GUEST_MSG_SIZE);
 
 	hdr->algo = SNP_AEAD_AES_256_GCM;
 	hdr->hdr_version = MSG_HDR_VER;
@@ -236,11 +236,11 @@  static int enc_payload(struct snp_guest_dev *snp_dev, u64 seqno, int version, u8
 	pr_debug("request [seqno %lld type %d version %d sz %d]\n",
 		 hdr->msg_seqno, hdr->msg_type, hdr->msg_version, hdr->msg_sz);
 
-	if (WARN_ON((sz + ctx->authsize) > sizeof(req->payload)))
+	if (WARN_ON((sz + ctx->authsize) > SNP_GUEST_MSG_PAYLOAD_SIZE))
 		return -EBADMSG;
 
 	memcpy(iv, &hdr->msg_seqno, min(sizeof(iv), sizeof(hdr->msg_seqno)));
-	aesgcm_encrypt(ctx, req->payload, payload, sz, &hdr->algo, AAD_LEN,
+	aesgcm_encrypt(ctx, msg->payload, payload, sz, &hdr->algo, AAD_LEN,
 		       iv, hdr->authtag);
 
 	return 0;
@@ -346,7 +346,7 @@  static int handle_guest_request(struct snp_guest_dev *snp_dev, u64 exit_code,
 		return -EIO;
 
 	/* Clear shared memory's response for the host to populate. */
-	memset(snp_dev->response, 0, sizeof(struct snp_guest_msg));
+	memset(snp_dev->response, 0, SNP_GUEST_MSG_SIZE);
 
 	/* Encrypt the userspace provided payload in snp_dev->secret_request. */
 	rc = enc_payload(snp_dev, seqno, rio->msg_version, type, req_buf, req_sz);
@@ -357,8 +357,7 @@  static int handle_guest_request(struct snp_guest_dev *snp_dev, u64 exit_code,
 	 * Write the fully encrypted request to the shared unencrypted
 	 * request page.
 	 */
-	memcpy(snp_dev->request, &snp_dev->secret_request,
-	       sizeof(snp_dev->secret_request));
+	memcpy(snp_dev->request, snp_dev->secret_request, SNP_GUEST_MSG_SIZE);
 
 	rc = __handle_guest_request(snp_dev, exit_code, rio);
 	if (rc) {
@@ -842,12 +841,21 @@  static int __init sev_guest_probe(struct platform_device *pdev)
 	snp_dev->dev = dev;
 	snp_dev->secrets = secrets;
 
+	/* Allocate secret request and response message for double buffering */
+	snp_dev->secret_request = kzalloc(SNP_GUEST_MSG_SIZE, GFP_KERNEL);
+	if (!snp_dev->secret_request)
+		goto e_unmap;
+
+	snp_dev->secret_response = kzalloc(SNP_GUEST_MSG_SIZE, GFP_KERNEL);
+	if (!snp_dev->secret_response)
+		goto e_free_secret_req;
+
 	/* Allocate the shared page used for the request and response message. */
-	snp_dev->request = alloc_shared_pages(dev, sizeof(struct snp_guest_msg));
+	snp_dev->request = alloc_shared_pages(dev, SNP_GUEST_MSG_SIZE);
 	if (!snp_dev->request)
-		goto e_unmap;
+		goto e_free_secret_resp;
 
-	snp_dev->response = alloc_shared_pages(dev, sizeof(struct snp_guest_msg));
+	snp_dev->response = alloc_shared_pages(dev, SNP_GUEST_MSG_SIZE);
 	if (!snp_dev->response)
 		goto e_free_request;
 
@@ -890,9 +898,13 @@  static int __init sev_guest_probe(struct platform_device *pdev)
 e_free_cert_data:
 	free_shared_pages(snp_dev->certs_data, SEV_FW_BLOB_MAX_SIZE);
 e_free_response:
-	free_shared_pages(snp_dev->response, sizeof(struct snp_guest_msg));
+	free_shared_pages(snp_dev->response, SNP_GUEST_MSG_SIZE);
 e_free_request:
-	free_shared_pages(snp_dev->request, sizeof(struct snp_guest_msg));
+	free_shared_pages(snp_dev->request, SNP_GUEST_MSG_SIZE);
+e_free_secret_resp:
+	kfree(snp_dev->secret_response);
+e_free_secret_req:
+	kfree(snp_dev->secret_request);
 e_unmap:
 	iounmap(mapping);
 	return ret;
@@ -903,8 +915,10 @@  static void __exit sev_guest_remove(struct platform_device *pdev)
 	struct snp_guest_dev *snp_dev = platform_get_drvdata(pdev);
 
 	free_shared_pages(snp_dev->certs_data, SEV_FW_BLOB_MAX_SIZE);
-	free_shared_pages(snp_dev->response, sizeof(struct snp_guest_msg));
-	free_shared_pages(snp_dev->request, sizeof(struct snp_guest_msg));
+	free_shared_pages(snp_dev->response, SNP_GUEST_MSG_SIZE);
+	free_shared_pages(snp_dev->request, SNP_GUEST_MSG_SIZE);
+	kfree(snp_dev->secret_response);
+	kfree(snp_dev->secret_request);
 	kfree(snp_dev->ctx);
 	misc_deregister(&snp_dev->misc);
 }