Message ID | 20240731150811.156771-11-nikunj@amd.com (mailing list archive) |
---|---|
State | New, archived |
Headers | show |
Series | Add Secure TSC support for SNP guests | expand |
On 7/31/24 10:08, Nikunj A Dadhania wrote: > Currently, the sev-guest driver is the only user of SNP guest messaging. > snp_guest_dev structure holds all the allocated buffers, secrets page and The snp_guest_dev structure... > VMPCK details. In preparation of adding messaging allocation and s/of/for/ > initialization APIs, decouple snp_guest_dev from messaging-related > information by carving out guest message context structure(snp_msg_desc). s/out guest/out the guest/ > > Incorporate this newly added context into snp_send_guest_request() and all > related functions, replacing the use of the snp_guest_dev. > > No functional change. > > Signed-off-by: Nikunj A Dadhania <nikunj@amd.com> Reviewed-by: Tom Lendacky <thomas.lendacky@amd.com> > --- > arch/x86/include/asm/sev.h | 21 +++ > drivers/virt/coco/sev-guest/sev-guest.c | 183 ++++++++++++------------ > 2 files changed, 111 insertions(+), 93 deletions(-) > > diff --git a/arch/x86/include/asm/sev.h b/arch/x86/include/asm/sev.h > index 27fa1c9c3465..2e49c4a9e7fe 100644 > --- a/arch/x86/include/asm/sev.h > +++ b/arch/x86/include/asm/sev.h > @@ -234,6 +234,27 @@ struct snp_secrets_page { > u8 rsvd4[3744]; > } __packed; > > +struct snp_msg_desc { > + /* request and response are in unencrypted memory */ > + struct snp_guest_msg *request, *response; > + > + /* > + * Avoid information leakage by double-buffering shared messages > + * in fields that are in regular encrypted memory. > + */ > + struct snp_guest_msg secret_request, secret_response; > + > + struct snp_secrets_page *secrets; > + struct snp_req_data input; > + > + void *certs_data; > + > + struct aesgcm_ctx *ctx; > + > + u32 *os_area_msg_seqno; > + u8 *vmpck; > +}; > + > /* > * The SVSM Calling Area (CA) related structures. 
> */ > diff --git a/drivers/virt/coco/sev-guest/sev-guest.c b/drivers/virt/coco/sev-guest/sev-guest.c > index 42f7126f1718..38ddabcd7ba3 100644 > --- a/drivers/virt/coco/sev-guest/sev-guest.c > +++ b/drivers/virt/coco/sev-guest/sev-guest.c > @@ -40,26 +40,13 @@ struct snp_guest_dev { > struct device *dev; > struct miscdevice misc; > > - void *certs_data; > - struct aesgcm_ctx *ctx; > - /* request and response are in unencrypted memory */ > - struct snp_guest_msg *request, *response; > - > - /* > - * Avoid information leakage by double-buffering shared messages > - * in fields that are in regular encrypted memory. > - */ > - struct snp_guest_msg secret_request, secret_response; > + struct snp_msg_desc *msg_desc; > > - struct snp_secrets_page *secrets; > - struct snp_req_data input; > union { > struct snp_report_req report; > struct snp_derived_key_req derived_key; > struct snp_ext_report_req ext_report; > } req; > - u32 *os_area_msg_seqno; > - u8 *vmpck; > }; > > /* > @@ -76,12 +63,12 @@ MODULE_PARM_DESC(vmpck_id, "The VMPCK ID to use when communicating with the PSP. > /* Mutex to serialize the shared buffer access and command handling. */ > static DEFINE_MUTEX(snp_cmd_mutex); > > -static bool is_vmpck_empty(struct snp_guest_dev *snp_dev) > +static bool is_vmpck_empty(struct snp_msg_desc *mdesc) > { > char zero_key[VMPCK_KEY_LEN] = {0}; > > - if (snp_dev->vmpck) > - return !memcmp(snp_dev->vmpck, zero_key, VMPCK_KEY_LEN); > + if (mdesc->vmpck) > + return !memcmp(mdesc->vmpck, zero_key, VMPCK_KEY_LEN); > > return true; > } > @@ -103,30 +90,30 @@ static bool is_vmpck_empty(struct snp_guest_dev *snp_dev) > * vulnerable. If the sequence number were incremented for a fresh IV the ASP > * will reject the request. 
> */ > -static void snp_disable_vmpck(struct snp_guest_dev *snp_dev) > +static void snp_disable_vmpck(struct snp_msg_desc *mdesc) > { > - dev_alert(snp_dev->dev, "Disabling VMPCK%d communication key to prevent IV reuse.\n", > + pr_alert("Disabling VMPCK%d communication key to prevent IV reuse.\n", > vmpck_id); > - memzero_explicit(snp_dev->vmpck, VMPCK_KEY_LEN); > - snp_dev->vmpck = NULL; > + memzero_explicit(mdesc->vmpck, VMPCK_KEY_LEN); > + mdesc->vmpck = NULL; > } > > -static inline u64 __snp_get_msg_seqno(struct snp_guest_dev *snp_dev) > +static inline u64 __snp_get_msg_seqno(struct snp_msg_desc *mdesc) > { > u64 count; > > lockdep_assert_held(&snp_cmd_mutex); > > /* Read the current message sequence counter from secrets pages */ > - count = *snp_dev->os_area_msg_seqno; > + count = *mdesc->os_area_msg_seqno; > > return count + 1; > } > > /* Return a non-zero on success */ > -static u64 snp_get_msg_seqno(struct snp_guest_dev *snp_dev) > +static u64 snp_get_msg_seqno(struct snp_msg_desc *mdesc) > { > - u64 count = __snp_get_msg_seqno(snp_dev); > + u64 count = __snp_get_msg_seqno(mdesc); > > /* > * The message sequence counter for the SNP guest request is a 64-bit > @@ -137,20 +124,20 @@ static u64 snp_get_msg_seqno(struct snp_guest_dev *snp_dev) > * invalid number and will fail the message request. > */ > if (count >= UINT_MAX) { > - dev_err(snp_dev->dev, "request message sequence counter overflow\n"); > + pr_err("request message sequence counter overflow\n"); > return 0; > } > > return count; > } > > -static void snp_inc_msg_seqno(struct snp_guest_dev *snp_dev) > +static void snp_inc_msg_seqno(struct snp_msg_desc *mdesc) > { > /* > * The counter is also incremented by the PSP, so increment it by 2 > * and save in secrets page. 
> */ > - *snp_dev->os_area_msg_seqno += 2; > + *mdesc->os_area_msg_seqno += 2; > } > > static inline struct snp_guest_dev *to_snp_dev(struct file *file) > @@ -177,13 +164,13 @@ static struct aesgcm_ctx *snp_init_crypto(u8 *key, size_t keylen) > return ctx; > } > > -static int verify_and_dec_payload(struct snp_guest_dev *snp_dev, struct snp_guest_req *req) > +static int verify_and_dec_payload(struct snp_msg_desc *mdesc, struct snp_guest_req *req) > { > - struct snp_guest_msg *resp_msg = &snp_dev->secret_response; > - struct snp_guest_msg *req_msg = &snp_dev->secret_request; > + struct snp_guest_msg *resp_msg = &mdesc->secret_response; > + struct snp_guest_msg *req_msg = &mdesc->secret_request; > struct snp_guest_msg_hdr *req_msg_hdr = &req_msg->hdr; > struct snp_guest_msg_hdr *resp_msg_hdr = &resp_msg->hdr; > - struct aesgcm_ctx *ctx = snp_dev->ctx; > + struct aesgcm_ctx *ctx = mdesc->ctx; > u8 iv[GCM_AES_IV_SIZE] = {}; > > pr_debug("response [seqno %lld type %d version %d sz %d]\n", > @@ -191,7 +178,7 @@ static int verify_and_dec_payload(struct snp_guest_dev *snp_dev, struct snp_gues > resp_msg_hdr->msg_sz); > > /* Copy response from shared memory to encrypted memory. 
*/ > - memcpy(resp_msg, snp_dev->response, sizeof(*resp_msg)); > + memcpy(resp_msg, mdesc->response, sizeof(*resp_msg)); > > /* Verify that the sequence counter is incremented by 1 */ > if (unlikely(resp_msg_hdr->msg_seqno != (req_msg_hdr->msg_seqno + 1))) > @@ -218,11 +205,11 @@ static int verify_and_dec_payload(struct snp_guest_dev *snp_dev, struct snp_gues > return 0; > } > > -static int enc_payload(struct snp_guest_dev *snp_dev, u64 seqno, struct snp_guest_req *req) > +static int enc_payload(struct snp_msg_desc *mdesc, u64 seqno, struct snp_guest_req *req) > { > - struct snp_guest_msg *msg = &snp_dev->secret_request; > + struct snp_guest_msg *msg = &mdesc->secret_request; > struct snp_guest_msg_hdr *hdr = &msg->hdr; > - struct aesgcm_ctx *ctx = snp_dev->ctx; > + struct aesgcm_ctx *ctx = mdesc->ctx; > u8 iv[GCM_AES_IV_SIZE] = {}; > > memset(msg, 0, sizeof(*msg)); > @@ -253,7 +240,7 @@ static int enc_payload(struct snp_guest_dev *snp_dev, u64 seqno, struct snp_gues > return 0; > } > > -static int __handle_guest_request(struct snp_guest_dev *snp_dev, struct snp_guest_req *req, > +static int __handle_guest_request(struct snp_msg_desc *mdesc, struct snp_guest_req *req, > struct snp_guest_request_ioctl *rio) > { > unsigned long req_start = jiffies; > @@ -268,7 +255,7 @@ static int __handle_guest_request(struct snp_guest_dev *snp_dev, struct snp_gues > * sequence number must be incremented or the VMPCK must be deleted to > * prevent reuse of the IV. > */ > - rc = snp_issue_guest_request(req, &snp_dev->input, rio); > + rc = snp_issue_guest_request(req, &mdesc->input, rio); > switch (rc) { > case -ENOSPC: > /* > @@ -278,7 +265,7 @@ static int __handle_guest_request(struct snp_guest_dev *snp_dev, struct snp_gues > * order to increment the sequence number and thus avoid > * IV reuse. 
> */ > - override_npages = snp_dev->input.data_npages; > + override_npages = mdesc->input.data_npages; > req->exit_code = SVM_VMGEXIT_GUEST_REQUEST; > > /* > @@ -318,7 +305,7 @@ static int __handle_guest_request(struct snp_guest_dev *snp_dev, struct snp_gues > * structure and any failure will wipe the VMPCK, preventing further > * use anyway. > */ > - snp_inc_msg_seqno(snp_dev); > + snp_inc_msg_seqno(mdesc); > > if (override_err) { > rio->exitinfo2 = override_err; > @@ -334,12 +321,12 @@ static int __handle_guest_request(struct snp_guest_dev *snp_dev, struct snp_gues > } > > if (override_npages) > - snp_dev->input.data_npages = override_npages; > + mdesc->input.data_npages = override_npages; > > return rc; > } > > -static int snp_send_guest_request(struct snp_guest_dev *snp_dev, struct snp_guest_req *req, > +static int snp_send_guest_request(struct snp_msg_desc *mdesc, struct snp_guest_req *req, > struct snp_guest_request_ioctl *rio) > { > u64 seqno; > @@ -348,15 +335,15 @@ static int snp_send_guest_request(struct snp_guest_dev *snp_dev, struct snp_gues > guard(mutex)(&snp_cmd_mutex); > > /* Get message sequence and verify that its a non-zero */ > - seqno = snp_get_msg_seqno(snp_dev); > + seqno = snp_get_msg_seqno(mdesc); > if (!seqno) > return -EIO; > > /* Clear shared memory's response for the host to populate. */ > - memset(snp_dev->response, 0, sizeof(struct snp_guest_msg)); > + memset(mdesc->response, 0, sizeof(struct snp_guest_msg)); > > - /* Encrypt the userspace provided payload in snp_dev->secret_request. */ > - rc = enc_payload(snp_dev, seqno, req); > + /* Encrypt the userspace provided payload in mdesc->secret_request. */ > + rc = enc_payload(mdesc, seqno, req); > if (rc) > return rc; > > @@ -364,34 +351,33 @@ static int snp_send_guest_request(struct snp_guest_dev *snp_dev, struct snp_gues > * Write the fully encrypted request to the shared unencrypted > * request page. 
> */ > - memcpy(snp_dev->request, &snp_dev->secret_request, > - sizeof(snp_dev->secret_request)); > + memcpy(mdesc->request, &mdesc->secret_request, > + sizeof(mdesc->secret_request)); > > - rc = __handle_guest_request(snp_dev, req, rio); > + rc = __handle_guest_request(mdesc, req, rio); > if (rc) { > if (rc == -EIO && > rio->exitinfo2 == SNP_GUEST_VMM_ERR(SNP_GUEST_VMM_ERR_INVALID_LEN)) > return rc; > > - dev_alert(snp_dev->dev, > - "Detected error from ASP request. rc: %d, exitinfo2: 0x%llx\n", > - rc, rio->exitinfo2); > + pr_alert("Detected error from ASP request. rc: %d, exitinfo2: 0x%llx\n", > + rc, rio->exitinfo2); > > - snp_disable_vmpck(snp_dev); > + snp_disable_vmpck(mdesc); > return rc; > } > > - rc = verify_and_dec_payload(snp_dev, req); > + rc = verify_and_dec_payload(mdesc, req); > if (rc) { > - dev_alert(snp_dev->dev, "Detected unexpected decode failure from ASP. rc: %d\n", rc); > - snp_disable_vmpck(snp_dev); > + pr_alert("Detected unexpected decode failure from ASP. rc: %d\n", rc); > + snp_disable_vmpck(mdesc); > return rc; > } > > return 0; > } > > -static int handle_guest_request(struct snp_guest_dev *snp_dev, u64 exit_code, > +static int handle_guest_request(struct snp_msg_desc *mdesc, u64 exit_code, > struct snp_guest_request_ioctl *rio, u8 type, > void *req_buf, size_t req_sz, void *resp_buf, > u32 resp_sz) > @@ -407,7 +393,7 @@ static int handle_guest_request(struct snp_guest_dev *snp_dev, u64 exit_code, > .exit_code = exit_code, > }; > > - return snp_send_guest_request(snp_dev, &req, rio); > + return snp_send_guest_request(mdesc, &req, rio); > } > > struct snp_req_resp { > @@ -418,6 +404,7 @@ struct snp_req_resp { > static int get_report(struct snp_guest_dev *snp_dev, struct snp_guest_request_ioctl *arg) > { > struct snp_report_req *report_req = &snp_dev->req.report; > + struct snp_msg_desc *mdesc = snp_dev->msg_desc; > struct snp_report_resp *report_resp; > int rc, resp_len; > > @@ -432,12 +419,12 @@ static int get_report(struct 
snp_guest_dev *snp_dev, struct snp_guest_request_io > * response payload. Make sure that it has enough space to cover the > * authtag. > */ > - resp_len = sizeof(report_resp->data) + snp_dev->ctx->authsize; > + resp_len = sizeof(report_resp->data) + mdesc->ctx->authsize; > report_resp = kzalloc(resp_len, GFP_KERNEL_ACCOUNT); > if (!report_resp) > return -ENOMEM; > > - rc = handle_guest_request(snp_dev, SVM_VMGEXIT_GUEST_REQUEST, arg, SNP_MSG_REPORT_REQ, > + rc = handle_guest_request(mdesc, SVM_VMGEXIT_GUEST_REQUEST, arg, SNP_MSG_REPORT_REQ, > report_req, sizeof(*report_req), report_resp->data, resp_len); > if (rc) > goto e_free; > @@ -454,6 +441,7 @@ static int get_derived_key(struct snp_guest_dev *snp_dev, struct snp_guest_reque > { > struct snp_derived_key_req *derived_key_req = &snp_dev->req.derived_key; > struct snp_derived_key_resp derived_key_resp = {0}; > + struct snp_msg_desc *mdesc = snp_dev->msg_desc; > int rc, resp_len; > /* Response data is 64 bytes and max authsize for GCM is 16 bytes. */ > u8 buf[64 + 16]; > @@ -466,7 +454,7 @@ static int get_derived_key(struct snp_guest_dev *snp_dev, struct snp_guest_reque > * response payload. Make sure that it has enough space to cover the > * authtag. 
> */ > - resp_len = sizeof(derived_key_resp.data) + snp_dev->ctx->authsize; > + resp_len = sizeof(derived_key_resp.data) + mdesc->ctx->authsize; > if (sizeof(buf) < resp_len) > return -ENOMEM; > > @@ -474,7 +462,7 @@ static int get_derived_key(struct snp_guest_dev *snp_dev, struct snp_guest_reque > sizeof(*derived_key_req))) > return -EFAULT; > > - rc = handle_guest_request(snp_dev, SVM_VMGEXIT_GUEST_REQUEST, arg, SNP_MSG_KEY_REQ, > + rc = handle_guest_request(mdesc, SVM_VMGEXIT_GUEST_REQUEST, arg, SNP_MSG_KEY_REQ, > derived_key_req, sizeof(*derived_key_req), buf, resp_len); > if (rc) > return rc; > @@ -495,6 +483,7 @@ static int get_ext_report(struct snp_guest_dev *snp_dev, struct snp_guest_reques > > { > struct snp_ext_report_req *report_req = &snp_dev->req.ext_report; > + struct snp_msg_desc *mdesc = snp_dev->msg_desc; > struct snp_report_resp *report_resp; > int ret, npages = 0, resp_len; > sockptr_t certs_address; > @@ -527,7 +516,7 @@ static int get_ext_report(struct snp_guest_dev *snp_dev, struct snp_guest_reques > * the host. If host does not supply any certs in it, then copy > * zeros to indicate that certificate data was not provided. > */ > - memset(snp_dev->certs_data, 0, report_req->certs_len); > + memset(mdesc->certs_data, 0, report_req->certs_len); > npages = report_req->certs_len >> PAGE_SHIFT; > cmd: > /* > @@ -535,19 +524,19 @@ static int get_ext_report(struct snp_guest_dev *snp_dev, struct snp_guest_reques > * response payload. Make sure that it has enough space to cover the > * authtag. 
> */ > - resp_len = sizeof(report_resp->data) + snp_dev->ctx->authsize; > + resp_len = sizeof(report_resp->data) + mdesc->ctx->authsize; > report_resp = kzalloc(resp_len, GFP_KERNEL_ACCOUNT); > if (!report_resp) > return -ENOMEM; > > - snp_dev->input.data_npages = npages; > - ret = handle_guest_request(snp_dev, SVM_VMGEXIT_EXT_GUEST_REQUEST, arg, SNP_MSG_REPORT_REQ, > + mdesc->input.data_npages = npages; > + ret = handle_guest_request(mdesc, SVM_VMGEXIT_EXT_GUEST_REQUEST, arg, SNP_MSG_REPORT_REQ, > &report_req->data, sizeof(report_req->data), > report_resp->data, resp_len); > > /* If certs length is invalid then copy the returned length */ > if (arg->vmm_error == SNP_GUEST_VMM_ERR_INVALID_LEN) { > - report_req->certs_len = snp_dev->input.data_npages << PAGE_SHIFT; > + report_req->certs_len = mdesc->input.data_npages << PAGE_SHIFT; > > if (copy_to_sockptr(io->req_data, report_req, sizeof(*report_req))) > ret = -EFAULT; > @@ -556,7 +545,7 @@ static int get_ext_report(struct snp_guest_dev *snp_dev, struct snp_guest_reques > if (ret) > goto e_free; > > - if (npages && copy_to_sockptr(certs_address, snp_dev->certs_data, report_req->certs_len)) { > + if (npages && copy_to_sockptr(certs_address, mdesc->certs_data, report_req->certs_len)) { > ret = -EFAULT; > goto e_free; > } > @@ -572,6 +561,7 @@ static int get_ext_report(struct snp_guest_dev *snp_dev, struct snp_guest_reques > static long snp_guest_ioctl(struct file *file, unsigned int ioctl, unsigned long arg) > { > struct snp_guest_dev *snp_dev = to_snp_dev(file); > + struct snp_msg_desc *mdesc = snp_dev->msg_desc; > void __user *argp = (void __user *)arg; > struct snp_guest_request_ioctl input; > struct snp_req_resp io; > @@ -587,7 +577,7 @@ static long snp_guest_ioctl(struct file *file, unsigned int ioctl, unsigned long > return -EINVAL; > > /* Check if the VMPCK is not empty */ > - if (is_vmpck_empty(snp_dev)) { > + if (is_vmpck_empty(mdesc)) { > dev_err_ratelimited(snp_dev->dev, "VMPCK is disabled\n"); > return 
-ENOTTY; > } > @@ -862,7 +852,7 @@ static int sev_report_new(struct tsm_report *report, void *data) > return -ENOMEM; > > /* Check if the VMPCK is not empty */ > - if (is_vmpck_empty(snp_dev)) { > + if (is_vmpck_empty(snp_dev->msg_desc)) { > dev_err_ratelimited(snp_dev->dev, "VMPCK is disabled\n"); > return -ENOTTY; > } > @@ -992,6 +982,7 @@ static int __init sev_guest_probe(struct platform_device *pdev) > struct snp_secrets_page *secrets; > struct device *dev = &pdev->dev; > struct snp_guest_dev *snp_dev; > + struct snp_msg_desc *mdesc; > struct miscdevice *misc; > void __iomem *mapping; > int ret; > @@ -1014,46 +1005,50 @@ static int __init sev_guest_probe(struct platform_device *pdev) > if (!snp_dev) > goto e_unmap; > > + mdesc = devm_kzalloc(&pdev->dev, sizeof(struct snp_msg_desc), GFP_KERNEL); > + if (!mdesc) > + goto e_unmap; > + > /* Adjust the default VMPCK key based on the executing VMPL level */ > if (vmpck_id == -1) > vmpck_id = snp_vmpl; > > ret = -EINVAL; > - snp_dev->vmpck = get_vmpck(vmpck_id, secrets, &snp_dev->os_area_msg_seqno); > - if (!snp_dev->vmpck) { > + mdesc->vmpck = get_vmpck(vmpck_id, secrets, &mdesc->os_area_msg_seqno); > + if (!mdesc->vmpck) { > dev_err(dev, "Invalid VMPCK%d communication key\n", vmpck_id); > goto e_unmap; > } > > /* Verify that VMPCK is not zero. */ > - if (is_vmpck_empty(snp_dev)) { > + if (is_vmpck_empty(mdesc)) { > dev_err(dev, "Empty VMPCK%d communication key\n", vmpck_id); > goto e_unmap; > } > > platform_set_drvdata(pdev, snp_dev); > snp_dev->dev = dev; > - snp_dev->secrets = secrets; > + mdesc->secrets = secrets; > > /* Ensure SNP guest messages do not span more than a page */ > BUILD_BUG_ON(sizeof(struct snp_guest_msg) > PAGE_SIZE); > > /* Allocate the shared page used for the request and response message. 
*/ > - snp_dev->request = alloc_shared_pages(dev, sizeof(struct snp_guest_msg)); > - if (!snp_dev->request) > + mdesc->request = alloc_shared_pages(dev, sizeof(struct snp_guest_msg)); > + if (!mdesc->request) > goto e_unmap; > > - snp_dev->response = alloc_shared_pages(dev, sizeof(struct snp_guest_msg)); > - if (!snp_dev->response) > + mdesc->response = alloc_shared_pages(dev, sizeof(struct snp_guest_msg)); > + if (!mdesc->response) > goto e_free_request; > > - snp_dev->certs_data = alloc_shared_pages(dev, SEV_FW_BLOB_MAX_SIZE); > - if (!snp_dev->certs_data) > + mdesc->certs_data = alloc_shared_pages(dev, SEV_FW_BLOB_MAX_SIZE); > + if (!mdesc->certs_data) > goto e_free_response; > > ret = -EIO; > - snp_dev->ctx = snp_init_crypto(snp_dev->vmpck, VMPCK_KEY_LEN); > - if (!snp_dev->ctx) > + mdesc->ctx = snp_init_crypto(mdesc->vmpck, VMPCK_KEY_LEN); > + if (!mdesc->ctx) > goto e_free_cert_data; > > misc = &snp_dev->misc; > @@ -1062,9 +1057,9 @@ static int __init sev_guest_probe(struct platform_device *pdev) > misc->fops = &snp_guest_fops; > > /* Initialize the input addresses for guest request */ > - snp_dev->input.req_gpa = __pa(snp_dev->request); > - snp_dev->input.resp_gpa = __pa(snp_dev->response); > - snp_dev->input.data_gpa = __pa(snp_dev->certs_data); > + mdesc->input.req_gpa = __pa(mdesc->request); > + mdesc->input.resp_gpa = __pa(mdesc->response); > + mdesc->input.data_gpa = __pa(mdesc->certs_data); > > /* Set the privlevel_floor attribute based on the vmpck_id */ > sev_tsm_ops.privlevel_floor = vmpck_id; > @@ -1081,17 +1076,18 @@ static int __init sev_guest_probe(struct platform_device *pdev) > if (ret) > goto e_free_ctx; > > + snp_dev->msg_desc = mdesc; > dev_info(dev, "Initialized SEV guest driver (using VMPCK%d communication key)\n", vmpck_id); > return 0; > > e_free_ctx: > - kfree(snp_dev->ctx); > + kfree(mdesc->ctx); > e_free_cert_data: > - free_shared_pages(snp_dev->certs_data, SEV_FW_BLOB_MAX_SIZE); > + free_shared_pages(mdesc->certs_data, 
SEV_FW_BLOB_MAX_SIZE); > e_free_response: > - free_shared_pages(snp_dev->response, sizeof(struct snp_guest_msg)); > + free_shared_pages(mdesc->response, sizeof(struct snp_guest_msg)); > e_free_request: > - free_shared_pages(snp_dev->request, sizeof(struct snp_guest_msg)); > + free_shared_pages(mdesc->request, sizeof(struct snp_guest_msg)); > e_unmap: > iounmap(mapping); > return ret; > @@ -1100,11 +1096,12 @@ static int __init sev_guest_probe(struct platform_device *pdev) > static void __exit sev_guest_remove(struct platform_device *pdev) > { > struct snp_guest_dev *snp_dev = platform_get_drvdata(pdev); > + struct snp_msg_desc *mdesc = snp_dev->msg_desc; > > - free_shared_pages(snp_dev->certs_data, SEV_FW_BLOB_MAX_SIZE); > - free_shared_pages(snp_dev->response, sizeof(struct snp_guest_msg)); > - free_shared_pages(snp_dev->request, sizeof(struct snp_guest_msg)); > - kfree(snp_dev->ctx); > + free_shared_pages(mdesc->certs_data, SEV_FW_BLOB_MAX_SIZE); > + free_shared_pages(mdesc->response, sizeof(struct snp_guest_msg)); > + free_shared_pages(mdesc->request, sizeof(struct snp_guest_msg)); > + kfree(mdesc->ctx); > misc_deregister(&snp_dev->misc); > } >
diff --git a/arch/x86/include/asm/sev.h b/arch/x86/include/asm/sev.h index 27fa1c9c3465..2e49c4a9e7fe 100644 --- a/arch/x86/include/asm/sev.h +++ b/arch/x86/include/asm/sev.h @@ -234,6 +234,27 @@ struct snp_secrets_page { u8 rsvd4[3744]; } __packed; +struct snp_msg_desc { + /* request and response are in unencrypted memory */ + struct snp_guest_msg *request, *response; + + /* + * Avoid information leakage by double-buffering shared messages + * in fields that are in regular encrypted memory. + */ + struct snp_guest_msg secret_request, secret_response; + + struct snp_secrets_page *secrets; + struct snp_req_data input; + + void *certs_data; + + struct aesgcm_ctx *ctx; + + u32 *os_area_msg_seqno; + u8 *vmpck; +}; + /* * The SVSM Calling Area (CA) related structures. */ diff --git a/drivers/virt/coco/sev-guest/sev-guest.c b/drivers/virt/coco/sev-guest/sev-guest.c index 42f7126f1718..38ddabcd7ba3 100644 --- a/drivers/virt/coco/sev-guest/sev-guest.c +++ b/drivers/virt/coco/sev-guest/sev-guest.c @@ -40,26 +40,13 @@ struct snp_guest_dev { struct device *dev; struct miscdevice misc; - void *certs_data; - struct aesgcm_ctx *ctx; - /* request and response are in unencrypted memory */ - struct snp_guest_msg *request, *response; - - /* - * Avoid information leakage by double-buffering shared messages - * in fields that are in regular encrypted memory. - */ - struct snp_guest_msg secret_request, secret_response; + struct snp_msg_desc *msg_desc; - struct snp_secrets_page *secrets; - struct snp_req_data input; union { struct snp_report_req report; struct snp_derived_key_req derived_key; struct snp_ext_report_req ext_report; } req; - u32 *os_area_msg_seqno; - u8 *vmpck; }; /* @@ -76,12 +63,12 @@ MODULE_PARM_DESC(vmpck_id, "The VMPCK ID to use when communicating with the PSP. /* Mutex to serialize the shared buffer access and command handling. 
*/ static DEFINE_MUTEX(snp_cmd_mutex); -static bool is_vmpck_empty(struct snp_guest_dev *snp_dev) +static bool is_vmpck_empty(struct snp_msg_desc *mdesc) { char zero_key[VMPCK_KEY_LEN] = {0}; - if (snp_dev->vmpck) - return !memcmp(snp_dev->vmpck, zero_key, VMPCK_KEY_LEN); + if (mdesc->vmpck) + return !memcmp(mdesc->vmpck, zero_key, VMPCK_KEY_LEN); return true; } @@ -103,30 +90,30 @@ static bool is_vmpck_empty(struct snp_guest_dev *snp_dev) * vulnerable. If the sequence number were incremented for a fresh IV the ASP * will reject the request. */ -static void snp_disable_vmpck(struct snp_guest_dev *snp_dev) +static void snp_disable_vmpck(struct snp_msg_desc *mdesc) { - dev_alert(snp_dev->dev, "Disabling VMPCK%d communication key to prevent IV reuse.\n", + pr_alert("Disabling VMPCK%d communication key to prevent IV reuse.\n", vmpck_id); - memzero_explicit(snp_dev->vmpck, VMPCK_KEY_LEN); - snp_dev->vmpck = NULL; + memzero_explicit(mdesc->vmpck, VMPCK_KEY_LEN); + mdesc->vmpck = NULL; } -static inline u64 __snp_get_msg_seqno(struct snp_guest_dev *snp_dev) +static inline u64 __snp_get_msg_seqno(struct snp_msg_desc *mdesc) { u64 count; lockdep_assert_held(&snp_cmd_mutex); /* Read the current message sequence counter from secrets pages */ - count = *snp_dev->os_area_msg_seqno; + count = *mdesc->os_area_msg_seqno; return count + 1; } /* Return a non-zero on success */ -static u64 snp_get_msg_seqno(struct snp_guest_dev *snp_dev) +static u64 snp_get_msg_seqno(struct snp_msg_desc *mdesc) { - u64 count = __snp_get_msg_seqno(snp_dev); + u64 count = __snp_get_msg_seqno(mdesc); /* * The message sequence counter for the SNP guest request is a 64-bit @@ -137,20 +124,20 @@ static u64 snp_get_msg_seqno(struct snp_guest_dev *snp_dev) * invalid number and will fail the message request. 
*/ if (count >= UINT_MAX) { - dev_err(snp_dev->dev, "request message sequence counter overflow\n"); + pr_err("request message sequence counter overflow\n"); return 0; } return count; } -static void snp_inc_msg_seqno(struct snp_guest_dev *snp_dev) +static void snp_inc_msg_seqno(struct snp_msg_desc *mdesc) { /* * The counter is also incremented by the PSP, so increment it by 2 * and save in secrets page. */ - *snp_dev->os_area_msg_seqno += 2; + *mdesc->os_area_msg_seqno += 2; } static inline struct snp_guest_dev *to_snp_dev(struct file *file) @@ -177,13 +164,13 @@ static struct aesgcm_ctx *snp_init_crypto(u8 *key, size_t keylen) return ctx; } -static int verify_and_dec_payload(struct snp_guest_dev *snp_dev, struct snp_guest_req *req) +static int verify_and_dec_payload(struct snp_msg_desc *mdesc, struct snp_guest_req *req) { - struct snp_guest_msg *resp_msg = &snp_dev->secret_response; - struct snp_guest_msg *req_msg = &snp_dev->secret_request; + struct snp_guest_msg *resp_msg = &mdesc->secret_response; + struct snp_guest_msg *req_msg = &mdesc->secret_request; struct snp_guest_msg_hdr *req_msg_hdr = &req_msg->hdr; struct snp_guest_msg_hdr *resp_msg_hdr = &resp_msg->hdr; - struct aesgcm_ctx *ctx = snp_dev->ctx; + struct aesgcm_ctx *ctx = mdesc->ctx; u8 iv[GCM_AES_IV_SIZE] = {}; pr_debug("response [seqno %lld type %d version %d sz %d]\n", @@ -191,7 +178,7 @@ static int verify_and_dec_payload(struct snp_guest_dev *snp_dev, struct snp_gues resp_msg_hdr->msg_sz); /* Copy response from shared memory to encrypted memory. 
*/ - memcpy(resp_msg, snp_dev->response, sizeof(*resp_msg)); + memcpy(resp_msg, mdesc->response, sizeof(*resp_msg)); /* Verify that the sequence counter is incremented by 1 */ if (unlikely(resp_msg_hdr->msg_seqno != (req_msg_hdr->msg_seqno + 1))) @@ -218,11 +205,11 @@ static int verify_and_dec_payload(struct snp_guest_dev *snp_dev, struct snp_gues return 0; } -static int enc_payload(struct snp_guest_dev *snp_dev, u64 seqno, struct snp_guest_req *req) +static int enc_payload(struct snp_msg_desc *mdesc, u64 seqno, struct snp_guest_req *req) { - struct snp_guest_msg *msg = &snp_dev->secret_request; + struct snp_guest_msg *msg = &mdesc->secret_request; struct snp_guest_msg_hdr *hdr = &msg->hdr; - struct aesgcm_ctx *ctx = snp_dev->ctx; + struct aesgcm_ctx *ctx = mdesc->ctx; u8 iv[GCM_AES_IV_SIZE] = {}; memset(msg, 0, sizeof(*msg)); @@ -253,7 +240,7 @@ static int enc_payload(struct snp_guest_dev *snp_dev, u64 seqno, struct snp_gues return 0; } -static int __handle_guest_request(struct snp_guest_dev *snp_dev, struct snp_guest_req *req, +static int __handle_guest_request(struct snp_msg_desc *mdesc, struct snp_guest_req *req, struct snp_guest_request_ioctl *rio) { unsigned long req_start = jiffies; @@ -268,7 +255,7 @@ static int __handle_guest_request(struct snp_guest_dev *snp_dev, struct snp_gues * sequence number must be incremented or the VMPCK must be deleted to * prevent reuse of the IV. */ - rc = snp_issue_guest_request(req, &snp_dev->input, rio); + rc = snp_issue_guest_request(req, &mdesc->input, rio); switch (rc) { case -ENOSPC: /* @@ -278,7 +265,7 @@ static int __handle_guest_request(struct snp_guest_dev *snp_dev, struct snp_gues * order to increment the sequence number and thus avoid * IV reuse. 
*/ - override_npages = snp_dev->input.data_npages; + override_npages = mdesc->input.data_npages; req->exit_code = SVM_VMGEXIT_GUEST_REQUEST; /* @@ -318,7 +305,7 @@ static int __handle_guest_request(struct snp_guest_dev *snp_dev, struct snp_gues * structure and any failure will wipe the VMPCK, preventing further * use anyway. */ - snp_inc_msg_seqno(snp_dev); + snp_inc_msg_seqno(mdesc); if (override_err) { rio->exitinfo2 = override_err; @@ -334,12 +321,12 @@ static int __handle_guest_request(struct snp_guest_dev *snp_dev, struct snp_gues } if (override_npages) - snp_dev->input.data_npages = override_npages; + mdesc->input.data_npages = override_npages; return rc; } -static int snp_send_guest_request(struct snp_guest_dev *snp_dev, struct snp_guest_req *req, +static int snp_send_guest_request(struct snp_msg_desc *mdesc, struct snp_guest_req *req, struct snp_guest_request_ioctl *rio) { u64 seqno; @@ -348,15 +335,15 @@ static int snp_send_guest_request(struct snp_guest_dev *snp_dev, struct snp_gues guard(mutex)(&snp_cmd_mutex); /* Get message sequence and verify that its a non-zero */ - seqno = snp_get_msg_seqno(snp_dev); + seqno = snp_get_msg_seqno(mdesc); if (!seqno) return -EIO; /* Clear shared memory's response for the host to populate. */ - memset(snp_dev->response, 0, sizeof(struct snp_guest_msg)); + memset(mdesc->response, 0, sizeof(struct snp_guest_msg)); - /* Encrypt the userspace provided payload in snp_dev->secret_request. */ - rc = enc_payload(snp_dev, seqno, req); + /* Encrypt the userspace provided payload in mdesc->secret_request. */ + rc = enc_payload(mdesc, seqno, req); if (rc) return rc; @@ -364,34 +351,33 @@ static int snp_send_guest_request(struct snp_guest_dev *snp_dev, struct snp_gues * Write the fully encrypted request to the shared unencrypted * request page. 
*/ - memcpy(snp_dev->request, &snp_dev->secret_request, - sizeof(snp_dev->secret_request)); + memcpy(mdesc->request, &mdesc->secret_request, + sizeof(mdesc->secret_request)); - rc = __handle_guest_request(snp_dev, req, rio); + rc = __handle_guest_request(mdesc, req, rio); if (rc) { if (rc == -EIO && rio->exitinfo2 == SNP_GUEST_VMM_ERR(SNP_GUEST_VMM_ERR_INVALID_LEN)) return rc; - dev_alert(snp_dev->dev, - "Detected error from ASP request. rc: %d, exitinfo2: 0x%llx\n", - rc, rio->exitinfo2); + pr_alert("Detected error from ASP request. rc: %d, exitinfo2: 0x%llx\n", + rc, rio->exitinfo2); - snp_disable_vmpck(snp_dev); + snp_disable_vmpck(mdesc); return rc; } - rc = verify_and_dec_payload(snp_dev, req); + rc = verify_and_dec_payload(mdesc, req); if (rc) { - dev_alert(snp_dev->dev, "Detected unexpected decode failure from ASP. rc: %d\n", rc); - snp_disable_vmpck(snp_dev); + pr_alert("Detected unexpected decode failure from ASP. rc: %d\n", rc); + snp_disable_vmpck(mdesc); return rc; } return 0; } -static int handle_guest_request(struct snp_guest_dev *snp_dev, u64 exit_code, +static int handle_guest_request(struct snp_msg_desc *mdesc, u64 exit_code, struct snp_guest_request_ioctl *rio, u8 type, void *req_buf, size_t req_sz, void *resp_buf, u32 resp_sz) @@ -407,7 +393,7 @@ static int handle_guest_request(struct snp_guest_dev *snp_dev, u64 exit_code, .exit_code = exit_code, }; - return snp_send_guest_request(snp_dev, &req, rio); + return snp_send_guest_request(mdesc, &req, rio); } struct snp_req_resp { @@ -418,6 +404,7 @@ struct snp_req_resp { static int get_report(struct snp_guest_dev *snp_dev, struct snp_guest_request_ioctl *arg) { struct snp_report_req *report_req = &snp_dev->req.report; + struct snp_msg_desc *mdesc = snp_dev->msg_desc; struct snp_report_resp *report_resp; int rc, resp_len; @@ -432,12 +419,12 @@ static int get_report(struct snp_guest_dev *snp_dev, struct snp_guest_request_io * response payload. Make sure that it has enough space to cover the * authtag. 
*/ - resp_len = sizeof(report_resp->data) + snp_dev->ctx->authsize; + resp_len = sizeof(report_resp->data) + mdesc->ctx->authsize; report_resp = kzalloc(resp_len, GFP_KERNEL_ACCOUNT); if (!report_resp) return -ENOMEM; - rc = handle_guest_request(snp_dev, SVM_VMGEXIT_GUEST_REQUEST, arg, SNP_MSG_REPORT_REQ, + rc = handle_guest_request(mdesc, SVM_VMGEXIT_GUEST_REQUEST, arg, SNP_MSG_REPORT_REQ, report_req, sizeof(*report_req), report_resp->data, resp_len); if (rc) goto e_free; @@ -454,6 +441,7 @@ static int get_derived_key(struct snp_guest_dev *snp_dev, struct snp_guest_reque { struct snp_derived_key_req *derived_key_req = &snp_dev->req.derived_key; struct snp_derived_key_resp derived_key_resp = {0}; + struct snp_msg_desc *mdesc = snp_dev->msg_desc; int rc, resp_len; /* Response data is 64 bytes and max authsize for GCM is 16 bytes. */ u8 buf[64 + 16]; @@ -466,7 +454,7 @@ static int get_derived_key(struct snp_guest_dev *snp_dev, struct snp_guest_reque * response payload. Make sure that it has enough space to cover the * authtag. 
*/ - resp_len = sizeof(derived_key_resp.data) + snp_dev->ctx->authsize; + resp_len = sizeof(derived_key_resp.data) + mdesc->ctx->authsize; if (sizeof(buf) < resp_len) return -ENOMEM; @@ -474,7 +462,7 @@ static int get_derived_key(struct snp_guest_dev *snp_dev, struct snp_guest_reque sizeof(*derived_key_req))) return -EFAULT; - rc = handle_guest_request(snp_dev, SVM_VMGEXIT_GUEST_REQUEST, arg, SNP_MSG_KEY_REQ, + rc = handle_guest_request(mdesc, SVM_VMGEXIT_GUEST_REQUEST, arg, SNP_MSG_KEY_REQ, derived_key_req, sizeof(*derived_key_req), buf, resp_len); if (rc) return rc; @@ -495,6 +483,7 @@ static int get_ext_report(struct snp_guest_dev *snp_dev, struct snp_guest_reques { struct snp_ext_report_req *report_req = &snp_dev->req.ext_report; + struct snp_msg_desc *mdesc = snp_dev->msg_desc; struct snp_report_resp *report_resp; int ret, npages = 0, resp_len; sockptr_t certs_address; @@ -527,7 +516,7 @@ static int get_ext_report(struct snp_guest_dev *snp_dev, struct snp_guest_reques * the host. If host does not supply any certs in it, then copy * zeros to indicate that certificate data was not provided. */ - memset(snp_dev->certs_data, 0, report_req->certs_len); + memset(mdesc->certs_data, 0, report_req->certs_len); npages = report_req->certs_len >> PAGE_SHIFT; cmd: /* @@ -535,19 +524,19 @@ static int get_ext_report(struct snp_guest_dev *snp_dev, struct snp_guest_reques * response payload. Make sure that it has enough space to cover the * authtag. 
*/ - resp_len = sizeof(report_resp->data) + snp_dev->ctx->authsize; + resp_len = sizeof(report_resp->data) + mdesc->ctx->authsize; report_resp = kzalloc(resp_len, GFP_KERNEL_ACCOUNT); if (!report_resp) return -ENOMEM; - snp_dev->input.data_npages = npages; - ret = handle_guest_request(snp_dev, SVM_VMGEXIT_EXT_GUEST_REQUEST, arg, SNP_MSG_REPORT_REQ, + mdesc->input.data_npages = npages; + ret = handle_guest_request(mdesc, SVM_VMGEXIT_EXT_GUEST_REQUEST, arg, SNP_MSG_REPORT_REQ, &report_req->data, sizeof(report_req->data), report_resp->data, resp_len); /* If certs length is invalid then copy the returned length */ if (arg->vmm_error == SNP_GUEST_VMM_ERR_INVALID_LEN) { - report_req->certs_len = snp_dev->input.data_npages << PAGE_SHIFT; + report_req->certs_len = mdesc->input.data_npages << PAGE_SHIFT; if (copy_to_sockptr(io->req_data, report_req, sizeof(*report_req))) ret = -EFAULT; @@ -556,7 +545,7 @@ static int get_ext_report(struct snp_guest_dev *snp_dev, struct snp_guest_reques if (ret) goto e_free; - if (npages && copy_to_sockptr(certs_address, snp_dev->certs_data, report_req->certs_len)) { + if (npages && copy_to_sockptr(certs_address, mdesc->certs_data, report_req->certs_len)) { ret = -EFAULT; goto e_free; } @@ -572,6 +561,7 @@ static int get_ext_report(struct snp_guest_dev *snp_dev, struct snp_guest_reques static long snp_guest_ioctl(struct file *file, unsigned int ioctl, unsigned long arg) { struct snp_guest_dev *snp_dev = to_snp_dev(file); + struct snp_msg_desc *mdesc = snp_dev->msg_desc; void __user *argp = (void __user *)arg; struct snp_guest_request_ioctl input; struct snp_req_resp io; @@ -587,7 +577,7 @@ static long snp_guest_ioctl(struct file *file, unsigned int ioctl, unsigned long return -EINVAL; /* Check if the VMPCK is not empty */ - if (is_vmpck_empty(snp_dev)) { + if (is_vmpck_empty(mdesc)) { dev_err_ratelimited(snp_dev->dev, "VMPCK is disabled\n"); return -ENOTTY; } @@ -862,7 +852,7 @@ static int sev_report_new(struct tsm_report *report, void *data) 
return -ENOMEM; /* Check if the VMPCK is not empty */ - if (is_vmpck_empty(snp_dev)) { + if (is_vmpck_empty(snp_dev->msg_desc)) { dev_err_ratelimited(snp_dev->dev, "VMPCK is disabled\n"); return -ENOTTY; } @@ -992,6 +982,7 @@ static int __init sev_guest_probe(struct platform_device *pdev) struct snp_secrets_page *secrets; struct device *dev = &pdev->dev; struct snp_guest_dev *snp_dev; + struct snp_msg_desc *mdesc; struct miscdevice *misc; void __iomem *mapping; int ret; @@ -1014,46 +1005,50 @@ static int __init sev_guest_probe(struct platform_device *pdev) if (!snp_dev) goto e_unmap; + mdesc = devm_kzalloc(&pdev->dev, sizeof(struct snp_msg_desc), GFP_KERNEL); + if (!mdesc) + goto e_unmap; + /* Adjust the default VMPCK key based on the executing VMPL level */ if (vmpck_id == -1) vmpck_id = snp_vmpl; ret = -EINVAL; - snp_dev->vmpck = get_vmpck(vmpck_id, secrets, &snp_dev->os_area_msg_seqno); - if (!snp_dev->vmpck) { + mdesc->vmpck = get_vmpck(vmpck_id, secrets, &mdesc->os_area_msg_seqno); + if (!mdesc->vmpck) { dev_err(dev, "Invalid VMPCK%d communication key\n", vmpck_id); goto e_unmap; } /* Verify that VMPCK is not zero. */ - if (is_vmpck_empty(snp_dev)) { + if (is_vmpck_empty(mdesc)) { dev_err(dev, "Empty VMPCK%d communication key\n", vmpck_id); goto e_unmap; } platform_set_drvdata(pdev, snp_dev); snp_dev->dev = dev; - snp_dev->secrets = secrets; + mdesc->secrets = secrets; /* Ensure SNP guest messages do not span more than a page */ BUILD_BUG_ON(sizeof(struct snp_guest_msg) > PAGE_SIZE); /* Allocate the shared page used for the request and response message. 
*/ - snp_dev->request = alloc_shared_pages(dev, sizeof(struct snp_guest_msg)); - if (!snp_dev->request) + mdesc->request = alloc_shared_pages(dev, sizeof(struct snp_guest_msg)); + if (!mdesc->request) goto e_unmap; - snp_dev->response = alloc_shared_pages(dev, sizeof(struct snp_guest_msg)); - if (!snp_dev->response) + mdesc->response = alloc_shared_pages(dev, sizeof(struct snp_guest_msg)); + if (!mdesc->response) goto e_free_request; - snp_dev->certs_data = alloc_shared_pages(dev, SEV_FW_BLOB_MAX_SIZE); - if (!snp_dev->certs_data) + mdesc->certs_data = alloc_shared_pages(dev, SEV_FW_BLOB_MAX_SIZE); + if (!mdesc->certs_data) goto e_free_response; ret = -EIO; - snp_dev->ctx = snp_init_crypto(snp_dev->vmpck, VMPCK_KEY_LEN); - if (!snp_dev->ctx) + mdesc->ctx = snp_init_crypto(mdesc->vmpck, VMPCK_KEY_LEN); + if (!mdesc->ctx) goto e_free_cert_data; misc = &snp_dev->misc; @@ -1062,9 +1057,9 @@ static int __init sev_guest_probe(struct platform_device *pdev) misc->fops = &snp_guest_fops; /* Initialize the input addresses for guest request */ - snp_dev->input.req_gpa = __pa(snp_dev->request); - snp_dev->input.resp_gpa = __pa(snp_dev->response); - snp_dev->input.data_gpa = __pa(snp_dev->certs_data); + mdesc->input.req_gpa = __pa(mdesc->request); + mdesc->input.resp_gpa = __pa(mdesc->response); + mdesc->input.data_gpa = __pa(mdesc->certs_data); /* Set the privlevel_floor attribute based on the vmpck_id */ sev_tsm_ops.privlevel_floor = vmpck_id; @@ -1081,17 +1076,18 @@ static int __init sev_guest_probe(struct platform_device *pdev) if (ret) goto e_free_ctx; + snp_dev->msg_desc = mdesc; dev_info(dev, "Initialized SEV guest driver (using VMPCK%d communication key)\n", vmpck_id); return 0; e_free_ctx: - kfree(snp_dev->ctx); + kfree(mdesc->ctx); e_free_cert_data: - free_shared_pages(snp_dev->certs_data, SEV_FW_BLOB_MAX_SIZE); + free_shared_pages(mdesc->certs_data, SEV_FW_BLOB_MAX_SIZE); e_free_response: - free_shared_pages(snp_dev->response, sizeof(struct snp_guest_msg)); + 
free_shared_pages(mdesc->response, sizeof(struct snp_guest_msg)); e_free_request: - free_shared_pages(snp_dev->request, sizeof(struct snp_guest_msg)); + free_shared_pages(mdesc->request, sizeof(struct snp_guest_msg)); e_unmap: iounmap(mapping); return ret; @@ -1100,11 +1096,12 @@ static int __init sev_guest_probe(struct platform_device *pdev) static void __exit sev_guest_remove(struct platform_device *pdev) { struct snp_guest_dev *snp_dev = platform_get_drvdata(pdev); + struct snp_msg_desc *mdesc = snp_dev->msg_desc; - free_shared_pages(snp_dev->certs_data, SEV_FW_BLOB_MAX_SIZE); - free_shared_pages(snp_dev->response, sizeof(struct snp_guest_msg)); - free_shared_pages(snp_dev->request, sizeof(struct snp_guest_msg)); - kfree(snp_dev->ctx); + free_shared_pages(mdesc->certs_data, SEV_FW_BLOB_MAX_SIZE); + free_shared_pages(mdesc->response, sizeof(struct snp_guest_msg)); + free_shared_pages(mdesc->request, sizeof(struct snp_guest_msg)); + kfree(mdesc->ctx); misc_deregister(&snp_dev->misc); }
Currently, the sev-guest driver is the only user of SNP guest messaging. The snp_guest_dev structure holds all the allocated buffers, secrets page and VMPCK details. In preparation for adding messaging allocation and initialization APIs, decouple snp_guest_dev from messaging-related information by carving out the guest message context structure (snp_msg_desc). Incorporate this newly added context into snp_send_guest_request() and all related functions, replacing the use of snp_guest_dev. Because these messaging helpers no longer have access to a struct device, their error alerts now use pr_alert() instead of dev_alert(). No functional change. Signed-off-by: Nikunj A Dadhania <nikunj@amd.com> --- arch/x86/include/asm/sev.h | 21 +++ drivers/virt/coco/sev-guest/sev-guest.c | 183 ++++++++++++------------ 2 files changed, 111 insertions(+), 93 deletions(-)