diff mbox series

[RFC,17/21] coco/sev-guest: Implement the guest side of things

Message ID 20240823132137.336874-18-aik@amd.com (mailing list archive)
State New
Headers show
Series Secure VFIO, TDISP, SEV TIO | expand

Commit Message

Alexey Kardashevskiy Aug. 23, 2024, 1:21 p.m. UTC
Define tsm_ops for the guest and forward the ops calls to the HV via
SVM_VMGEXIT_SEV_TIO_GUEST_REQUEST.
Do the attestation report examination and enable MMIO.

Signed-off-by: Alexey Kardashevskiy <aik@amd.com>
---
 drivers/virt/coco/sev-guest/Makefile        |   2 +-
 arch/x86/include/asm/sev.h                  |   2 +
 drivers/virt/coco/sev-guest/sev-guest.h     |   2 +
 include/linux/psp-sev.h                     |  22 +
 arch/x86/coco/sev/core.c                    |  11 +
 drivers/virt/coco/sev-guest/sev_guest.c     |  16 +-
 drivers/virt/coco/sev-guest/sev_guest_tio.c | 513 ++++++++++++++++++++
 7 files changed, 566 insertions(+), 2 deletions(-)

Comments

Jonathan Cameron Aug. 28, 2024, 3:54 p.m. UTC | #1
On Fri, 23 Aug 2024 23:21:31 +1000
Alexey Kardashevskiy <aik@amd.com> wrote:

> Define tsm_ops for the guest and forward the ops calls to the HV via
> SVM_VMGEXIT_SEV_TIO_GUEST_REQUEST.
> Do the attestation report examination and enable MMIO.
> 
> Signed-off-by: Alexey Kardashevskiy <aik@amd.com>
More trivial stuff.

> diff --git a/drivers/virt/coco/sev-guest/sev_guest_tio.c b/drivers/virt/coco/sev-guest/sev_guest_tio.c
> new file mode 100644
> index 000000000000..33a082e7f039
> --- /dev/null
> +++ b/drivers/virt/coco/sev-guest/sev_guest_tio.c
> @@ -0,0 +1,513 @@



> +static int tio_tdi_sdte_write(struct tsm_tdi *tdi, struct snp_guest_dev *snp_dev, bool invalidate)
> +{
> +	struct snp_guest_crypto *crypto = snp_dev->crypto;
> +	size_t resp_len = sizeof(struct tio_msg_sdte_write_rsp) + crypto->a_len;
> +	struct tio_msg_sdte_write_rsp *rsp = kzalloc(resp_len, GFP_KERNEL);
> +	struct tio_msg_sdte_write_req req = {
> +		.guest_device_id = pci_dev_id(tdi->pdev),
> +		.sdte.vmpl = 0,
> +		.sdte.vtom = tsm_vtom,
> +		.sdte.vtom_en = 1,
> +		.sdte.iw = 1,
> +		.sdte.ir = 1,
> +		.sdte.v = 1,
> +	};
> +	u64 fw_err = 0;
> +	u64 bdfn = pci_dev_id(tdi->pdev);
> +	int rc;
> +
> +	BUILD_BUG_ON(sizeof(struct sdte) * 8 != 512);
> +
> +	if (invalidate)
> +		memset(&req, 0, sizeof(req));

Little odd to fill it then zero it.  Maybe just fill it
if !invalidate

> +
> +	pci_notice(tdi->pdev, "SDTE write vTOM=%lx", (unsigned long) req.sdte.vtom << 21);
> +
> +	if (!rsp)

I'd allocate rsp down here as then obvious what is going on.

> +		return -ENOMEM;
> +
> +	rc = handle_tio_guest_request(snp_dev, TIO_MSG_SDTE_WRITE_REQ,
> +			       &req, sizeof(req), rsp, resp_len,
> +			       NULL, NULL, &bdfn, NULL, &fw_err);
> +	if (rc) {
> +		pci_err(tdi->pdev, "SDTE write failed with 0x%llx\n", fw_err);
> +		goto free_exit;
> +	}
> +
> +free_exit:
> +	/* The response buffer contains the sensitive data, explicitly clear it. */
> +	memzero_explicit(&rsp, sizeof(resp_len));
> +	kfree(rsp);

kfree_sensitive() perhaps?

> +	return rc;
> +}

> +static int sev_guest_tdi_validate(struct tsm_tdi *tdi, bool invalidate, void *private_data)
> +{
> +	struct snp_guest_dev *snp_dev = private_data;
> +	struct tsm_tdi_status ts = { 0 };
> +	int ret;
> +
> +	if (!tdi->report) {
> +		ret = tio_tdi_status(tdi, snp_dev, &ts);
> +
> +		if (ret || !tdi->report) {
> +			pci_err(tdi->pdev, "No report available, ret=%d", ret);
> +			if (!ret && tdi->report)
> +				ret = -EIO;
> +			return ret;
I'd split the error paths to simplify the logic.
		if (ret) {
			pci_err(tdi->pdev, "No report available, ret=%d", ret);
			return ret;
		}
		if (!tdi->report) {
			pci_err(... some more meaningful message)
			return -EIO;
> +		}
> +
> +		if (ts.state != TDISP_STATE_RUN) {
> +			pci_err(tdi->pdev, "Not in RUN state, state=%d instead", ts.state);
> +			return -EIO;
> +		}
> +	}
> +
> +	ret = tio_tdi_sdte_write(tdi, snp_dev, invalidate);
> +	if (ret)
> +		return ret;
> +
> +	ret = tio_tdi_mmio_validate(tdi, snp_dev, invalidate);

return tio_tdi_mmio_validate();

> +	if (ret)
> +		return ret;
> +
> +	return 0;
> +}
Zhi Wang Sept. 14, 2024, 7:19 a.m. UTC | #2
On Fri, 23 Aug 2024 23:21:31 +1000
Alexey Kardashevskiy <aik@amd.com> wrote:

> Define tsm_ops for the guest and forward the ops calls to the HV via
> SVM_VMGEXIT_SEV_TIO_GUEST_REQUEST.
> Do the attestation report examination and enable MMIO.
> 

It seems in both guest side (this patch) and host side
(PATCH 7 tsm_report_show()), if the SW wants to reach the latest TDI
report, they have to call get TDI status verb first.

As this is about UABI, if this is expected, it would nice that we can
explicitly document this requirement. Or we just get the fresh report
from the device all the time?

Thanks,
Zhi.


> Signed-off-by: Alexey Kardashevskiy <aik@amd.com>
> ---
>  drivers/virt/coco/sev-guest/Makefile        |   2 +-
>  arch/x86/include/asm/sev.h                  |   2 +
>  drivers/virt/coco/sev-guest/sev-guest.h     |   2 +
>  include/linux/psp-sev.h                     |  22 +
>  arch/x86/coco/sev/core.c                    |  11 +
>  drivers/virt/coco/sev-guest/sev_guest.c     |  16 +-
>  drivers/virt/coco/sev-guest/sev_guest_tio.c | 513
> ++++++++++++++++++++ 7 files changed, 566 insertions(+), 2
> deletions(-)
> 
> diff --git a/drivers/virt/coco/sev-guest/Makefile
> b/drivers/virt/coco/sev-guest/Makefile index
> 2d7dffed7b2f..34ea9fab698b 100644 ---
> a/drivers/virt/coco/sev-guest/Makefile +++
> b/drivers/virt/coco/sev-guest/Makefile @@ -1,3 +1,3 @@
>  # SPDX-License-Identifier: GPL-2.0-only
>  obj-$(CONFIG_SEV_GUEST) += sev-guest.o
> -sev-guest-y += sev_guest.o
> +sev-guest-y += sev_guest.o sev_guest_tio.o
> diff --git a/arch/x86/include/asm/sev.h b/arch/x86/include/asm/sev.h
> index 8edd7bccabf2..431c12bbd337 100644
> --- a/arch/x86/include/asm/sev.h
> +++ b/arch/x86/include/asm/sev.h
> @@ -117,6 +117,8 @@ struct snp_req_data {
>  	unsigned long resp_gpa;
>  	unsigned long data_gpa;
>  	unsigned int data_npages;
> +	unsigned int guest_rid;
> +	unsigned long param;
>  };
>  
>  #define MAX_AUTHTAG_LEN		32
> diff --git a/drivers/virt/coco/sev-guest/sev-guest.h
> b/drivers/virt/coco/sev-guest/sev-guest.h index
> 765f42ff55aa..d1254148c83b 100644 ---
> a/drivers/virt/coco/sev-guest/sev-guest.h +++
> b/drivers/virt/coco/sev-guest/sev-guest.h @@ -51,4 +51,6 @@ int
> handle_guest_request(struct snp_guest_dev *snp_dev, u64 exit_code,
> void *alloc_shared_pages(struct device *dev, size_t sz); void
> free_shared_pages(void *buf, size_t sz); 
> +void sev_guest_tsm_set_ops(bool set, struct snp_guest_dev *snp_dev);
> +
>  #endif /* __VIRT_SEVGUEST_H__ */
> diff --git a/include/linux/psp-sev.h b/include/linux/psp-sev.h
> index adf40e0316dc..bff7396d18de 100644
> --- a/include/linux/psp-sev.h
> +++ b/include/linux/psp-sev.h
> @@ -1050,6 +1050,9 @@ static inline void snp_free_firmware_page(void
> *addr) { } #define MMIO_VALIDATE_RANGEID(r)  ((r) & 0x7)
>  #define MMIO_VALIDATE_RESERVED(r) ((r) & 0xFFF0000000000008ULL)
>  
> +#define MMIO_MK_VALIDATE(start, size, range_id) \
> +	(MMIO_VALIDATE_GPA(start) | (get_order(size >> 12) << 4) |
> ((range_id) & 0xFF)) +
>  /* Optional Certificates/measurements/report data from
> TIO_GUEST_REQUEST */ struct tio_blob_table_entry {
>  	guid_t guid;
> @@ -1067,4 +1070,23 @@ struct tio_blob_table_entry {
>  #define TIO_GUID_REPORT \
>  	GUID_INIT(0x70dc5b0e, 0x0cc0, 0x4cd5, 0x97, 0xbb, 0xff,
> 0x0b, 0xa2, 0x5b, 0xf3, 0x20) 
> +/*
> + * Status codes from TIO_MSG_MMIO_VALIDATE_REQ
> + */
> +enum mmio_validate_status {
> +	MMIO_VALIDATE_SUCCESS = 0,
> +	MMIO_VALIDATE_INVALID_TDI = 1,
> +	MMIO_VALIDATE_TDI_UNBOUND = 2,
> +	MMIO_VALIDATE_NOT_ASSIGNED = 3, /* At least one page is not
> assigned to the guest */
> +	MMIO_VALIDATE_NOT_UNIFORM = 4,  /* The Validated bit is not
> uniformly set for
> +					   the MMIO subrange */
> +	MMIO_VALIDATE_NOT_IMMUTABLE = 5,/* At least one page does
> not have immutable bit set
> +					   when validated bit is
> clear */
> +	MMIO_VALIDATE_NOT_MAPPED = 6,   /* At least one page is not
> mapped to the expected GPA */
> +	MMIO_VALIDATE_NOT_REPORTED = 7, /* The provided MMIO range
> ID is not reported in
> +					   the interface report */
> +	MMIO_VALIDATE_OUT_OF_RANGE = 8, /* The subrange is out the
> MMIO range in
> +					   the interface report */
> +};
> +
>  #endif	/* __PSP_SEV_H__ */
> diff --git a/arch/x86/coco/sev/core.c b/arch/x86/coco/sev/core.c
> index de1df0cb45da..d05a97421ffc 100644
> --- a/arch/x86/coco/sev/core.c
> +++ b/arch/x86/coco/sev/core.c
> @@ -2468,6 +2468,11 @@ int snp_issue_guest_request(u64 exit_code,
> struct snp_req_data *input, struct sn if (exit_code ==
> SVM_VMGEXIT_EXT_GUEST_REQUEST) { ghcb_set_rax(ghcb, input->data_gpa);
>  		ghcb_set_rbx(ghcb, input->data_npages);
> +	} else if (exit_code == SVM_VMGEXIT_SEV_TIO_GUEST_REQUEST) {
> +		ghcb_set_rax(ghcb, input->data_gpa);
> +		ghcb_set_rbx(ghcb, input->data_npages);
> +		ghcb_set_rcx(ghcb, input->guest_rid);
> +		ghcb_set_rdx(ghcb, input->param);
>  	}
>  
>  	ret = sev_es_ghcb_hv_call(ghcb, &ctxt, exit_code,
> input->req_gpa, input->resp_gpa); @@ -2477,6 +2482,8 @@ int
> snp_issue_guest_request(u64 exit_code, struct snp_req_data *input,
> struct sn rio->exitinfo2 = ghcb->save.sw_exit_info_2; switch
> (rio->exitinfo2) { case 0:
> +		if (exit_code == SVM_VMGEXIT_SEV_TIO_GUEST_REQUEST)
> +			input->param = ghcb_get_rdx(ghcb);
>  		break;
>  
>  	case SNP_GUEST_VMM_ERR(SNP_GUEST_VMM_ERR_BUSY):
> @@ -2489,6 +2496,10 @@ int snp_issue_guest_request(u64 exit_code,
> struct snp_req_data *input, struct sn input->data_npages =
> ghcb_get_rbx(ghcb); ret = -ENOSPC;
>  			break;
> +		} else if (exit_code ==
> SVM_VMGEXIT_SEV_TIO_GUEST_REQUEST) {
> +			input->data_npages = ghcb_get_rbx(ghcb);
> +			ret = -ENOSPC;
> +			break;
>  		}
>  		fallthrough;
>  	default:
> diff --git a/drivers/virt/coco/sev-guest/sev_guest.c
> b/drivers/virt/coco/sev-guest/sev_guest.c index
> d04d270f359e..571faade5690 100644 ---
> a/drivers/virt/coco/sev-guest/sev_guest.c +++
> b/drivers/virt/coco/sev-guest/sev_guest.c @@ -52,6 +52,10 @@ static
> int vmpck_id = -1; module_param(vmpck_id, int, 0444);
>  MODULE_PARM_DESC(vmpck_id, "The VMPCK ID to use when communicating
> with the PSP."); 
> +static bool tsm_enable = true;
> +module_param(tsm_enable, bool, 0644);
> +MODULE_PARM_DESC(tsm_enable, "Enable SEV TIO");
> +
>  /* Mutex to serialize the shared buffer access and command handling.
> */ DEFINE_MUTEX(snp_cmd_mutex);
>  
> @@ -277,7 +281,8 @@ static int verify_and_dec_payload(struct
> snp_guest_dev *snp_dev, void *payload, return -EBADMSG;
>  
>  	/* Verify response message type and version number. */
> -	if (resp_hdr->msg_type != (req_hdr->msg_type + 1) ||
> +	if ((resp_hdr->msg_type != (req_hdr->msg_type + 1) &&
> +	     (resp_hdr->msg_type != (req_hdr->msg_type - 0x80))) ||
>  	    resp_hdr->msg_version != req_hdr->msg_version)
>  		return -EBADMSG;
>  
> @@ -337,6 +342,10 @@ static int __handle_guest_request(struct
> snp_guest_dev *snp_dev, u64 exit_code, rc =
> snp_issue_guest_request(exit_code, &snp_dev->input, rio); switch (rc)
> { case -ENOSPC:
> +		if (exit_code == SVM_VMGEXIT_SEV_TIO_GUEST_REQUEST) {
> +			pr_warn("SVM_VMGEXIT_SEV_TIO_GUEST_REQUEST
> => -ENOSPC");
> +			break;
> +		}
>  		/*
>  		 * If the extended guest request fails due to having
> too
>  		 * small of a certificate data buffer, retry the same
> @@ -1142,6 +1151,9 @@ static int __init sev_guest_probe(struct
> platform_device *pdev) if (ret)
>  		goto e_free_cert_data;
>  
> +	if (tsm_enable)
> +		sev_guest_tsm_set_ops(true, snp_dev);
> +
>  	dev_info(dev, "Initialized SEV guest driver (using vmpck_id
> %d)\n", vmpck_id); return 0;
>  
> @@ -1160,6 +1172,8 @@ static void __exit sev_guest_remove(struct
> platform_device *pdev) {
>  	struct snp_guest_dev *snp_dev = platform_get_drvdata(pdev);
>  
> +	if (tsm_enable)
> +		sev_guest_tsm_set_ops(false, snp_dev);
>  	free_shared_pages(snp_dev->certs_data, SEV_FW_BLOB_MAX_SIZE);
>  	free_shared_pages(snp_dev->response, sizeof(struct
> snp_guest_msg)); free_shared_pages(snp_dev->request, sizeof(struct
> snp_guest_msg)); diff --git
> a/drivers/virt/coco/sev-guest/sev_guest_tio.c
> b/drivers/virt/coco/sev-guest/sev_guest_tio.c new file mode 100644
> index 000000000000..33a082e7f039 --- /dev/null
> +++ b/drivers/virt/coco/sev-guest/sev_guest_tio.c
> @@ -0,0 +1,513 @@
> +// SPDX-License-Identifier: GPL-2.0-only
> +
> +#include <linux/pci.h>
> +#include <linux/psp-sev.h>
> +#include <linux/tsm.h>
> +
> +#include <asm/svm.h>
> +#include <asm/sev.h>
> +
> +#include "sev-guest.h"
> +
> +#define TIO_MESSAGE_VERSION	1
> +
> +ulong tsm_vtom = 0x7fffffff;
> +module_param(tsm_vtom, ulong, 0644);
> +MODULE_PARM_DESC(tsm_vtom, "SEV TIO vTOM value");
> +
> +static void tio_guest_blob_free(struct tsm_blob *b)
> +{
> +	memset(b->data, 0, b->len);
> +}
> +
> +static int handle_tio_guest_request(struct snp_guest_dev *snp_dev,
> u8 type,
> +				   void *req_buf, size_t req_sz,
> void *resp_buf, u32 resp_sz,
> +				   u64 *pt_pa, u64 *npages, u64
> *bdfn, u64 *param, u64 *fw_err) +{
> +	struct snp_guest_request_ioctl rio = {
> +		.msg_version = TIO_MESSAGE_VERSION,
> +		.exitinfo2 = 0,
> +	};
> +	int ret;
> +
> +	snp_dev->input.data_gpa = 0;
> +	snp_dev->input.data_npages = 0;
> +	snp_dev->input.guest_rid = 0;
> +	snp_dev->input.param = 0;
> +
> +	if (pt_pa && npages) {
> +		snp_dev->input.data_gpa = *pt_pa;
> +		snp_dev->input.data_npages = *npages;
> +	}
> +	if (bdfn)
> +		snp_dev->input.guest_rid = *bdfn;
> +	if (param)
> +		snp_dev->input.param = *param;
> +
> +	mutex_lock(&snp_cmd_mutex);
> +	ret = handle_guest_request(snp_dev,
> SVM_VMGEXIT_SEV_TIO_GUEST_REQUEST,
> +				   &rio, type, req_buf, req_sz,
> resp_buf, resp_sz);
> +	mutex_unlock(&snp_cmd_mutex);
> +
> +	if (param)
> +		*param = snp_dev->input.param;
> +
> +	*fw_err = rio.exitinfo2;
> +
> +	return ret;
> +}
> +
> +static int guest_request_tio_certs(struct snp_guest_dev *snp_dev, u8
> type,
> +				   void *req_buf, size_t req_sz,
> void *resp_buf, u32 resp_sz,
> +				   u64 bdfn, enum tsm_tdisp_state
> *state,
> +				   struct tsm_blob **certs, struct
> tsm_blob **meas,
> +				   struct tsm_blob **report, u64
> *fw_err) +{
> +	u64 certs_size = SZ_32K, c1 = 0, pt_pa, param = 0;
> +	struct tio_blob_table_entry *pt;
> +	int rc;
> +
> +	pt = alloc_shared_pages(snp_dev->dev, certs_size);
> +	if (!pt)
> +		return -ENOMEM;
> +
> +	pt_pa = __pa(pt);
> +	c1 = certs_size;
> +	rc = handle_tio_guest_request(snp_dev, type, req_buf,
> req_sz, resp_buf, resp_sz,
> +				      &pt_pa, &c1, &bdfn, state ?
> &param : NULL, fw_err); +
> +	if (c1 > SZ_32K) {
> +		free_shared_pages(pt, certs_size);
> +		certs_size = c1;
> +		pt = alloc_shared_pages(snp_dev->dev, certs_size);
> +		if (!pt)
> +			return -ENOMEM;
> +
> +		pt_pa = __pa(pt);
> +		rc = handle_tio_guest_request(snp_dev, type,
> req_buf, req_sz, resp_buf, resp_sz,
> +					      &pt_pa, &c1, &bdfn,
> state ? &param : NULL, fw_err);
> +	}
> +
> +	if (rc)
> +		return rc;
> +
> +	tsm_blob_put(*meas);
> +	tsm_blob_put(*certs);
> +	tsm_blob_put(*report);
> +	*meas = NULL;
> +	*certs = NULL;
> +	*report = NULL;
> +
> +	for (unsigned int i = 0; i < 3; ++i) {
> +		u8 *ptr = ((u8 *) pt) + pt[i].offset;
> +		size_t len = pt[i].length;
> +		struct tsm_blob *b;
> +
> +		if (guid_is_null(&pt[i].guid))
> +			break;
> +
> +		if (!len)
> +			continue;
> +
> +		b = tsm_blob_new(ptr, len, tio_guest_blob_free);
> +		if (!b)
> +			break;
> +
> +		if (guid_equal(&pt[i].guid, &TIO_GUID_MEASUREMENTS))
> +			*meas = b;
> +		else if (guid_equal(&pt[i].guid,
> &TIO_GUID_CERTIFICATES))
> +			*certs = b;
> +		else if (guid_equal(&pt[i].guid, &TIO_GUID_REPORT))
> +			*report = b;
> +	}
> +	free_shared_pages(pt, certs_size);
> +
> +	if (state)
> +		*state = param;
> +
> +	return 0;
> +}
> +
> +struct tio_msg_tdi_info_req {
> +	__u16 guest_device_id;
> +	__u8 reserved[14];
> +} __packed;
> +
> +struct tio_msg_tdi_info_rsp {
> +	__u16 guest_device_id;
> +	__u16 status;
> +	__u8 reserved1[12];
> +	union {
> +		u32 meas_flags;
> +		struct {
> +			u32 meas_digest_valid : 1;
> +			u32 meas_digest_fresh : 1;
> +		};
> +	};
> +	union {
> +		u32 tdisp_lock_flags;
> +		/* These are TDISP's LOCK_INTERFACE_REQUEST flags */
> +		struct {
> +			u32 no_fw_update : 1;
> +			u32 cache_line_size : 1;
> +			u32 lock_msix : 1;
> +			u32 bind_p2p : 1;
> +			u32 all_request_redirect : 1;
> +		};
> +	};
> +	__u64 spdm_algos;
> +	__u8 certs_digest[48];
> +	__u8 meas_digest[48];
> +	__u8 interface_report_digest[48];
> +} __packed;
> +
> +static int tio_tdi_status(struct tsm_tdi *tdi, struct snp_guest_dev
> *snp_dev,
> +			  struct tsm_tdi_status *ts)
> +{
> +	struct snp_guest_crypto *crypto = snp_dev->crypto;
> +	size_t resp_len = sizeof(struct tio_msg_tdi_info_rsp) +
> crypto->a_len;
> +	struct tio_msg_tdi_info_rsp *rsp = kzalloc(resp_len,
> GFP_KERNEL);
> +	struct tio_msg_tdi_info_req req = {
> +		.guest_device_id = pci_dev_id(tdi->pdev),
> +	};
> +	u64 fw_err = 0;
> +	int rc;
> +	enum tsm_tdisp_state state = 0;
> +
> +	pci_notice(tdi->pdev, "TDI info");
> +	if (!rsp)
> +		return -ENOMEM;
> +
> +	rc = guest_request_tio_certs(snp_dev, TIO_MSG_TDI_INFO_REQ,
> &req,
> +				     sizeof(req), rsp, resp_len,
> +				     pci_dev_id(tdi->pdev), &state,
> +				     &tdi->tdev->certs,
> &tdi->tdev->meas,
> +				     &tdi->report, &fw_err);
> +
> +	ts->meas_digest_valid = rsp->meas_digest_valid;
> +	ts->meas_digest_fresh = rsp->meas_digest_fresh;
> +	ts->no_fw_update = rsp->no_fw_update;
> +	ts->cache_line_size = rsp->cache_line_size == 0 ? 64 : 128;
> +	ts->lock_msix = rsp->lock_msix;
> +	ts->bind_p2p = rsp->bind_p2p;
> +	ts->all_request_redirect = rsp->all_request_redirect;
> +#define __ALGO(x, n, y) \
> +	((((x) & (0xFFUL << (n))) == TIO_SPDM_ALGOS_##y) ? \
> +	 (1ULL << TSM_TDI_SPDM_ALGOS_##y) : 0)
> +	ts->spdm_algos =
> +		__ALGO(rsp->spdm_algos, 0, DHE_SECP256R1) |
> +		__ALGO(rsp->spdm_algos, 0, DHE_SECP384R1) |
> +		__ALGO(rsp->spdm_algos, 8, AEAD_AES_128_GCM) |
> +		__ALGO(rsp->spdm_algos, 8, AEAD_AES_256_GCM) |
> +		__ALGO(rsp->spdm_algos, 16,
> ASYM_TPM_ALG_RSASSA_3072) |
> +		__ALGO(rsp->spdm_algos, 16,
> ASYM_TPM_ALG_ECDSA_ECC_NIST_P256) |
> +		__ALGO(rsp->spdm_algos, 16,
> ASYM_TPM_ALG_ECDSA_ECC_NIST_P384) |
> +		__ALGO(rsp->spdm_algos, 24, HASH_TPM_ALG_SHA_256) |
> +		__ALGO(rsp->spdm_algos, 24, HASH_TPM_ALG_SHA_384) |
> +		__ALGO(rsp->spdm_algos, 32,
> KEY_SCHED_SPDM_KEY_SCHEDULE); +#undef __ALGO
> +	memcpy(ts->certs_digest, rsp->certs_digest,
> sizeof(ts->certs_digest));
> +	memcpy(ts->meas_digest, rsp->meas_digest,
> sizeof(ts->meas_digest));
> +	memcpy(ts->interface_report_digest,
> rsp->interface_report_digest,
> +	       sizeof(ts->interface_report_digest));
> +
> +	ts->valid = true;
> +	ts->state = state;
> +	/* The response buffer contains the sensitive data,
> explicitly clear it. */
> +	memzero_explicit(&rsp, sizeof(resp_len));
> +	kfree(rsp);
> +	return rc;
> +}
> +
> +struct tio_msg_mmio_validate_req {
> +	__u16 guest_device_id; /* Hypervisor provided identifier
> used by the guest
> +				  to identify the TDI in guest
> messages */
> +	__u16 reserved1;
> +	__u8 reserved2[12];
> +	__u64 subrange_base;
> +	__u32 subrange_page_count;
> +	__u32 range_offset;
> +	union {
> +		__u16 flags;
> +		struct {
> +			__u16 validated:1; /* Desired value to set
> RMP.Validated for the range */
> +			/* Force validated:
> +			 * 0: If subrange does not have
> RMP.Validated set uniformly, fail.
> +			 * 1: If subrange does not have
> RMP.Validated set uniformly, force
> +			 *    to requested value
> +			 */
> +			__u16 force_validated:1;
> +		};
> +	};
> +	__u16 range_id;
> +	__u8 reserved3[12];
> +} __packed;
> +
> +struct tio_msg_mmio_validate_rsp {
> +	__u16 guest_interface_id;
> +	__u16 status; /* MMIO_VALIDATE_xxx */
> +	__u8 reserved1[12];
> +	__u64 subrange_base;
> +	__u32 subrange_page_count;
> +	__u32 range_offset;
> +	union {
> +		__u16 flags;
> +		struct {
> +			__u16 changed:1; /* Indicates that the
> Validated bit has changed
> +					    due to this operation */
> +		};
> +	};
> +	__u16 range_id;
> +	__u8 reserved2[12];
> +} __packed;
> +
> +static int mmio_validate_range(struct snp_guest_dev *snp_dev, struct
> pci_dev *pdev,
> +			       unsigned int range_id,
> resource_size_t start, resource_size_t size,
> +			       bool invalidate, u64 *fw_err)
> +{
> +	struct snp_guest_crypto *crypto = snp_dev->crypto;
> +	size_t resp_len = sizeof(struct tio_msg_mmio_validate_rsp) +
> crypto->a_len;
> +	struct tio_msg_mmio_validate_rsp *rsp = kzalloc(resp_len,
> GFP_KERNEL);
> +	struct tio_msg_mmio_validate_req req = {
> +		.guest_device_id = pci_dev_id(pdev),
> +		.subrange_base = start,
> +		.subrange_page_count = size >> PAGE_SHIFT,
> +		.range_offset = 0,
> +		.validated = 1, /* Desired value to set
> RMP.Validated for the range */
> +		.force_validated = 0,
> +		.range_id = range_id,
> +	};
> +	u64 bdfn = pci_dev_id(pdev);
> +	u64 mmio_val = MMIO_MK_VALIDATE(start, size, range_id);
> +	int rc;
> +
> +	if (!rsp)
> +		return -ENOMEM;
> +
> +	if (invalidate)
> +		memset(&req, 0, sizeof(req));
> +
> +	rc = handle_tio_guest_request(snp_dev,
> TIO_MSG_MMIO_VALIDATE_REQ,
> +			       &req, sizeof(req), rsp, resp_len,
> +			       NULL, NULL, &bdfn, &mmio_val, fw_err);
> +	if (rc)
> +		goto free_exit;
> +
> +	if (rsp->status)
> +		rc = -EBADR;
> +
> +free_exit:
> +	/* The response buffer contains the sensitive data,
> explicitly clear it. */
> +	memzero_explicit(&rsp, sizeof(resp_len));
> +	kfree(rsp);
> +	return rc;
> +}
> +
> +static int tio_tdi_mmio_validate(struct tsm_tdi *tdi, struct
> snp_guest_dev *snp_dev,
> +				 bool invalidate)
> +{
> +	struct pci_dev *pdev = tdi->pdev;
> +	struct tdi_report_mmio_range mr;
> +	struct resource *r;
> +	u64 fw_err = 0;
> +	int i = 0, rc;
> +
> +	pci_notice(tdi->pdev, "MMIO validate");
> +
> +	if (WARN_ON_ONCE(!tdi->report || !tdi->report->data))
> +		return -EFAULT;
> +
> +	for (i = 0; i < TDI_REPORT_MR_NUM(tdi->report); ++i) {
> +		mr = TDI_REPORT_MR(tdi->report, i);
> +		r = pci_resource_n(tdi->pdev, mr.range_id);
> +
> +		if (r->end == r->start || ((r->end - r->start + 1) &
> ~PAGE_MASK) || !mr.num) {
> +			pci_warn(tdi->pdev, "Skipping broken range
> [%d] #%d %d pages, %llx..%llx\n",
> +				i, mr.range_id, mr.num, r->start,
> r->end);
> +			continue;
> +		}
> +
> +		if (mr.is_non_tee_mem) {
> +			pci_info(tdi->pdev, "Skipping non-TEE range
> [%d] #%d %d pages, %llx..%llx\n",
> +				 i, mr.range_id, mr.num, r->start,
> r->end);
> +			continue;
> +		}
> +
> +		rc = mmio_validate_range(snp_dev, pdev, mr.range_id,
> +					 r->start, r->end - r->start
> + 1, invalidate, &fw_err);
> +		if (rc) {
> +			pci_err(pdev, "MMIO #%d %llx..%llx
> validation failed 0x%llx\n",
> +				mr.range_id, r->start, r->end,
> fw_err);
> +			continue;
> +		}
> +
> +		pci_notice(pdev, "MMIO #%d %llx..%llx validated\n",
> mr.range_id, r->start, r->end);
> +	}
> +
> +	return rc;
> +}
> +
> +struct sdte {
> +	__u64 v                  : 1;
> +	__u64 reserved           : 3;
> +	__u64 cxlio              : 3;
> +	__u64 reserved1          : 45;
> +	__u64 ppr                : 1;
> +	__u64 reserved2          : 1;
> +	__u64 giov               : 1;
> +	__u64 gv                 : 1;
> +	__u64 glx                : 2;
> +	__u64 gcr3_tbl_rp0       : 3;
> +	__u64 ir                 : 1;
> +	__u64 iw                 : 1;
> +	__u64 reserved3          : 1;
> +	__u16 domain_id;
> +	__u16 gcr3_tbl_rp1;
> +	__u32 interrupt          : 1;
> +	__u32 reserved4          : 5;
> +	__u32 ex                 : 1;
> +	__u32 sd                 : 1;
> +	__u32 reserved5          : 2;
> +	__u32 sats               : 1;
> +	__u32 gcr3_tbl_rp2       : 21;
> +	__u64 giv                : 1;
> +	__u64 gint_tbl_len       : 4;
> +	__u64 reserved6          : 1;
> +	__u64 gint_tbl           : 46;
> +	__u64 reserved7          : 2;
> +	__u64 gpm                : 2;
> +	__u64 reserved8          : 3;
> +	__u64 hpt_mode           : 1;
> +	__u64 reserved9          : 4;
> +	__u32 asid               : 12;
> +	__u32 reserved10         : 3;
> +	__u32 viommu_en          : 1;
> +	__u32 guest_device_id    : 16;
> +	__u32 guest_id           : 15;
> +	__u32 guest_id_mbo       : 1;
> +	__u32 reserved11         : 1;
> +	__u32 vmpl               : 2;
> +	__u32 reserved12         : 3;
> +	__u32 attrv              : 1;
> +	__u32 reserved13         : 1;
> +	__u32 sa                 : 8;
> +	__u8 ide_stream_id[8];
> +	__u32 vtom_en            : 1;
> +	__u32 vtom               : 31;
> +	__u32 rp_id              : 5;
> +	__u32 reserved14         : 27;
> +	__u8  reserved15[0x40-0x30];
> +} __packed;
> +
> +struct tio_msg_sdte_write_req {
> +	__u16 guest_device_id;
> +	__u8 reserved[14];
> +	struct sdte sdte;
> +} __packed;
> +
> +#define SDTE_WRITE_SUCCESS		0
> +#define SDTE_WRITE_INVALID_TDI		1
> +#define SDTE_WRITE_TDI_NOT_BOUND	2
> +#define SDTE_WRITE_RESERVED		3
> +
> +struct tio_msg_sdte_write_rsp {
> +	__u16 guest_device_id;
> +	__u16 status; /* SDTE_WRITE_xxx */
> +	__u8 reserved[12];
> +} __packed;
> +
> +static int tio_tdi_sdte_write(struct tsm_tdi *tdi, struct
> snp_guest_dev *snp_dev, bool invalidate) +{
> +	struct snp_guest_crypto *crypto = snp_dev->crypto;
> +	size_t resp_len = sizeof(struct tio_msg_sdte_write_rsp) +
> crypto->a_len;
> +	struct tio_msg_sdte_write_rsp *rsp = kzalloc(resp_len,
> GFP_KERNEL);
> +	struct tio_msg_sdte_write_req req = {
> +		.guest_device_id = pci_dev_id(tdi->pdev),
> +		.sdte.vmpl = 0,
> +		.sdte.vtom = tsm_vtom,
> +		.sdte.vtom_en = 1,
> +		.sdte.iw = 1,
> +		.sdte.ir = 1,
> +		.sdte.v = 1,
> +	};
> +	u64 fw_err = 0;
> +	u64 bdfn = pci_dev_id(tdi->pdev);
> +	int rc;
> +
> +	BUILD_BUG_ON(sizeof(struct sdte) * 8 != 512);
> +
> +	if (invalidate)
> +		memset(&req, 0, sizeof(req));
> +
> +	pci_notice(tdi->pdev, "SDTE write vTOM=%lx", (unsigned long)
> req.sdte.vtom << 21); +
> +	if (!rsp)
> +		return -ENOMEM;
> +
> +	rc = handle_tio_guest_request(snp_dev,
> TIO_MSG_SDTE_WRITE_REQ,
> +			       &req, sizeof(req), rsp, resp_len,
> +			       NULL, NULL, &bdfn, NULL, &fw_err);
> +	if (rc) {
> +		pci_err(tdi->pdev, "SDTE write failed with
> 0x%llx\n", fw_err);
> +		goto free_exit;
> +	}
> +
> +free_exit:
> +	/* The response buffer contains the sensitive data,
> explicitly clear it. */
> +	memzero_explicit(&rsp, sizeof(resp_len));
> +	kfree(rsp);
> +	return rc;
> +}
> +
> +static int sev_guest_tdi_status(struct tsm_tdi *tdi, void
> *private_data, struct tsm_tdi_status *ts) +{
> +	struct snp_guest_dev *snp_dev = private_data;
> +
> +	return tio_tdi_status(tdi, snp_dev, ts);
> +}
> +
> +static int sev_guest_tdi_validate(struct tsm_tdi *tdi, bool
> invalidate, void *private_data) +{
> +	struct snp_guest_dev *snp_dev = private_data;
> +	struct tsm_tdi_status ts = { 0 };
> +	int ret;
> +
> +	if (!tdi->report) {
> +		ret = tio_tdi_status(tdi, snp_dev, &ts);
> +
> +		if (ret || !tdi->report) {
> +			pci_err(tdi->pdev, "No report available,
> ret=%d", ret);
> +			if (!ret && tdi->report)
> +				ret = -EIO;
> +			return ret;
> +		}
> +
> +		if (ts.state != TDISP_STATE_RUN) {
> +			pci_err(tdi->pdev, "Not in RUN state,
> state=%d instead", ts.state);
> +			return -EIO;
> +		}
> +	}
> +
> +	ret = tio_tdi_sdte_write(tdi, snp_dev, invalidate);
> +	if (ret)
> +		return ret;
> +
> +	ret = tio_tdi_mmio_validate(tdi, snp_dev, invalidate);
> +	if (ret)
> +		return ret;
> +
> +	return 0;
> +}
> +
> +struct tsm_ops sev_guest_tsm_ops = {
> +	.tdi_validate = sev_guest_tdi_validate,
> +	.tdi_status = sev_guest_tdi_status,
> +};
> +
> +void sev_guest_tsm_set_ops(bool set, struct snp_guest_dev *snp_dev)
> +{
> +	if (set)
> +		tsm_set_ops(&sev_guest_tsm_ops, snp_dev);
> +	else
> +		tsm_set_ops(NULL, NULL);
> +}
diff mbox series

Patch

diff --git a/drivers/virt/coco/sev-guest/Makefile b/drivers/virt/coco/sev-guest/Makefile
index 2d7dffed7b2f..34ea9fab698b 100644
--- a/drivers/virt/coco/sev-guest/Makefile
+++ b/drivers/virt/coco/sev-guest/Makefile
@@ -1,3 +1,3 @@ 
 # SPDX-License-Identifier: GPL-2.0-only
 obj-$(CONFIG_SEV_GUEST) += sev-guest.o
-sev-guest-y += sev_guest.o
+sev-guest-y += sev_guest.o sev_guest_tio.o
diff --git a/arch/x86/include/asm/sev.h b/arch/x86/include/asm/sev.h
index 8edd7bccabf2..431c12bbd337 100644
--- a/arch/x86/include/asm/sev.h
+++ b/arch/x86/include/asm/sev.h
@@ -117,6 +117,8 @@  struct snp_req_data {
 	unsigned long resp_gpa;
 	unsigned long data_gpa;
 	unsigned int data_npages;
+	unsigned int guest_rid;
+	unsigned long param;
 };
 
 #define MAX_AUTHTAG_LEN		32
diff --git a/drivers/virt/coco/sev-guest/sev-guest.h b/drivers/virt/coco/sev-guest/sev-guest.h
index 765f42ff55aa..d1254148c83b 100644
--- a/drivers/virt/coco/sev-guest/sev-guest.h
+++ b/drivers/virt/coco/sev-guest/sev-guest.h
@@ -51,4 +51,6 @@  int handle_guest_request(struct snp_guest_dev *snp_dev, u64 exit_code,
 void *alloc_shared_pages(struct device *dev, size_t sz);
 void free_shared_pages(void *buf, size_t sz);
 
+void sev_guest_tsm_set_ops(bool set, struct snp_guest_dev *snp_dev);
+
 #endif /* __VIRT_SEVGUEST_H__ */
diff --git a/include/linux/psp-sev.h b/include/linux/psp-sev.h
index adf40e0316dc..bff7396d18de 100644
--- a/include/linux/psp-sev.h
+++ b/include/linux/psp-sev.h
@@ -1050,6 +1050,9 @@  static inline void snp_free_firmware_page(void *addr) { }
 #define MMIO_VALIDATE_RANGEID(r)  ((r) & 0x7)
 #define MMIO_VALIDATE_RESERVED(r) ((r) & 0xFFF0000000000008ULL)
 
+#define MMIO_MK_VALIDATE(start, size, range_id) \
+	(MMIO_VALIDATE_GPA(start) | (get_order(size >> 12) << 4) | ((range_id) & 0xFF))
+
 /* Optional Certificates/measurements/report data from TIO_GUEST_REQUEST */
 struct tio_blob_table_entry {
 	guid_t guid;
@@ -1067,4 +1070,23 @@  struct tio_blob_table_entry {
 #define TIO_GUID_REPORT \
 	GUID_INIT(0x70dc5b0e, 0x0cc0, 0x4cd5, 0x97, 0xbb, 0xff, 0x0b, 0xa2, 0x5b, 0xf3, 0x20)
 
+/*
+ * Status codes from TIO_MSG_MMIO_VALIDATE_REQ
+ */
+enum mmio_validate_status {
+	MMIO_VALIDATE_SUCCESS = 0,
+	MMIO_VALIDATE_INVALID_TDI = 1,
+	MMIO_VALIDATE_TDI_UNBOUND = 2,
+	MMIO_VALIDATE_NOT_ASSIGNED = 3, /* At least one page is not assigned to the guest */
+	MMIO_VALIDATE_NOT_UNIFORM = 4,  /* The Validated bit is not uniformly set for
+					   the MMIO subrange */
+	MMIO_VALIDATE_NOT_IMMUTABLE = 5,/* At least one page does not have immutable bit set
+					   when validated bit is clear */
+	MMIO_VALIDATE_NOT_MAPPED = 6,   /* At least one page is not mapped to the expected GPA */
+	MMIO_VALIDATE_NOT_REPORTED = 7, /* The provided MMIO range ID is not reported in
+					   the interface report */
+	MMIO_VALIDATE_OUT_OF_RANGE = 8, /* The subrange is out the MMIO range in
+					   the interface report */
+};
+
 #endif	/* __PSP_SEV_H__ */
diff --git a/arch/x86/coco/sev/core.c b/arch/x86/coco/sev/core.c
index de1df0cb45da..d05a97421ffc 100644
--- a/arch/x86/coco/sev/core.c
+++ b/arch/x86/coco/sev/core.c
@@ -2468,6 +2468,11 @@  int snp_issue_guest_request(u64 exit_code, struct snp_req_data *input, struct sn
 	if (exit_code == SVM_VMGEXIT_EXT_GUEST_REQUEST) {
 		ghcb_set_rax(ghcb, input->data_gpa);
 		ghcb_set_rbx(ghcb, input->data_npages);
+	} else if (exit_code == SVM_VMGEXIT_SEV_TIO_GUEST_REQUEST) {
+		ghcb_set_rax(ghcb, input->data_gpa);
+		ghcb_set_rbx(ghcb, input->data_npages);
+		ghcb_set_rcx(ghcb, input->guest_rid);
+		ghcb_set_rdx(ghcb, input->param);
 	}
 
 	ret = sev_es_ghcb_hv_call(ghcb, &ctxt, exit_code, input->req_gpa, input->resp_gpa);
@@ -2477,6 +2482,8 @@  int snp_issue_guest_request(u64 exit_code, struct snp_req_data *input, struct sn
 	rio->exitinfo2 = ghcb->save.sw_exit_info_2;
 	switch (rio->exitinfo2) {
 	case 0:
+		if (exit_code == SVM_VMGEXIT_SEV_TIO_GUEST_REQUEST)
+			input->param = ghcb_get_rdx(ghcb);
 		break;
 
 	case SNP_GUEST_VMM_ERR(SNP_GUEST_VMM_ERR_BUSY):
@@ -2489,6 +2496,10 @@  int snp_issue_guest_request(u64 exit_code, struct snp_req_data *input, struct sn
 			input->data_npages = ghcb_get_rbx(ghcb);
 			ret = -ENOSPC;
 			break;
+		} else if (exit_code == SVM_VMGEXIT_SEV_TIO_GUEST_REQUEST) {
+			input->data_npages = ghcb_get_rbx(ghcb);
+			ret = -ENOSPC;
+			break;
 		}
 		fallthrough;
 	default:
diff --git a/drivers/virt/coco/sev-guest/sev_guest.c b/drivers/virt/coco/sev-guest/sev_guest.c
index d04d270f359e..571faade5690 100644
--- a/drivers/virt/coco/sev-guest/sev_guest.c
+++ b/drivers/virt/coco/sev-guest/sev_guest.c
@@ -52,6 +52,10 @@  static int vmpck_id = -1;
 module_param(vmpck_id, int, 0444);
 MODULE_PARM_DESC(vmpck_id, "The VMPCK ID to use when communicating with the PSP.");
 
+static bool tsm_enable = true;
+module_param(tsm_enable, bool, 0644);
+MODULE_PARM_DESC(tsm_enable, "Enable SEV TIO");
+
 /* Mutex to serialize the shared buffer access and command handling. */
 DEFINE_MUTEX(snp_cmd_mutex);
 
@@ -277,7 +281,8 @@  static int verify_and_dec_payload(struct snp_guest_dev *snp_dev, void *payload,
 		return -EBADMSG;
 
 	/* Verify response message type and version number. */
-	if (resp_hdr->msg_type != (req_hdr->msg_type + 1) ||
+	if ((resp_hdr->msg_type != (req_hdr->msg_type + 1) &&
+	     (resp_hdr->msg_type != (req_hdr->msg_type - 0x80))) ||
 	    resp_hdr->msg_version != req_hdr->msg_version)
 		return -EBADMSG;
 
@@ -337,6 +342,10 @@  static int __handle_guest_request(struct snp_guest_dev *snp_dev, u64 exit_code,
 	rc = snp_issue_guest_request(exit_code, &snp_dev->input, rio);
 	switch (rc) {
 	case -ENOSPC:
+		if (exit_code == SVM_VMGEXIT_SEV_TIO_GUEST_REQUEST) {
+			pr_warn("SVM_VMGEXIT_SEV_TIO_GUEST_REQUEST => -ENOSPC");
+			break;
+		}
 		/*
 		 * If the extended guest request fails due to having too
 		 * small of a certificate data buffer, retry the same
@@ -1142,6 +1151,9 @@  static int __init sev_guest_probe(struct platform_device *pdev)
 	if (ret)
 		goto e_free_cert_data;
 
+	if (tsm_enable)
+		sev_guest_tsm_set_ops(true, snp_dev);
+
 	dev_info(dev, "Initialized SEV guest driver (using vmpck_id %d)\n", vmpck_id);
 	return 0;
 
@@ -1160,6 +1172,8 @@  static void __exit sev_guest_remove(struct platform_device *pdev)
 {
 	struct snp_guest_dev *snp_dev = platform_get_drvdata(pdev);
 
+	if (tsm_enable)
+		sev_guest_tsm_set_ops(false, snp_dev);
 	free_shared_pages(snp_dev->certs_data, SEV_FW_BLOB_MAX_SIZE);
 	free_shared_pages(snp_dev->response, sizeof(struct snp_guest_msg));
 	free_shared_pages(snp_dev->request, sizeof(struct snp_guest_msg));
diff --git a/drivers/virt/coco/sev-guest/sev_guest_tio.c b/drivers/virt/coco/sev-guest/sev_guest_tio.c
new file mode 100644
index 000000000000..33a082e7f039
--- /dev/null
+++ b/drivers/virt/coco/sev-guest/sev_guest_tio.c
@@ -0,0 +1,513 @@ 
+// SPDX-License-Identifier: GPL-2.0-only
+
+#include <linux/pci.h>
+#include <linux/psp-sev.h>
+#include <linux/tsm.h>
+
+#include <asm/svm.h>
+#include <asm/sev.h>
+
+#include "sev-guest.h"
+
+#define TIO_MESSAGE_VERSION	1
+
+ulong tsm_vtom = 0x7fffffff;
+module_param(tsm_vtom, ulong, 0644);
+MODULE_PARM_DESC(tsm_vtom, "SEV TIO vTOM value");
+
+static void tio_guest_blob_free(struct tsm_blob *b)
+{
+	memset(b->data, 0, b->len);
+}
+
+static int handle_tio_guest_request(struct snp_guest_dev *snp_dev, u8 type,
+				   void *req_buf, size_t req_sz, void *resp_buf, u32 resp_sz,
+				   u64 *pt_pa, u64 *npages, u64 *bdfn, u64 *param, u64 *fw_err)
+{
+	struct snp_guest_request_ioctl rio = {
+		.msg_version = TIO_MESSAGE_VERSION,
+		.exitinfo2 = 0,
+	};
+	int ret;
+
+	snp_dev->input.data_gpa = 0;
+	snp_dev->input.data_npages = 0;
+	snp_dev->input.guest_rid = 0;
+	snp_dev->input.param = 0;
+
+	if (pt_pa && npages) {
+		snp_dev->input.data_gpa = *pt_pa;
+		snp_dev->input.data_npages = *npages;
+	}
+	if (bdfn)
+		snp_dev->input.guest_rid = *bdfn;
+	if (param)
+		snp_dev->input.param = *param;
+
+	mutex_lock(&snp_cmd_mutex);
+	ret = handle_guest_request(snp_dev, SVM_VMGEXIT_SEV_TIO_GUEST_REQUEST,
+				   &rio, type, req_buf, req_sz, resp_buf, resp_sz);
+	mutex_unlock(&snp_cmd_mutex);
+
+	if (param)
+		*param = snp_dev->input.param;
+
+	*fw_err = rio.exitinfo2;
+
+	return ret;
+}
+
+static int guest_request_tio_certs(struct snp_guest_dev *snp_dev, u8 type,
+				   void *req_buf, size_t req_sz, void *resp_buf, u32 resp_sz,
+				   u64 bdfn, enum tsm_tdisp_state *state,
+				   struct tsm_blob **certs, struct tsm_blob **meas,
+				   struct tsm_blob **report, u64 *fw_err)
+{
+	u64 certs_size = SZ_32K, c1 = 0, pt_pa, param = 0;
+	struct tio_blob_table_entry *pt;
+	int rc;
+
+	pt = alloc_shared_pages(snp_dev->dev, certs_size);
+	if (!pt)
+		return -ENOMEM;
+
+	pt_pa = __pa(pt);
+	c1 = certs_size;
+	rc = handle_tio_guest_request(snp_dev, type, req_buf, req_sz, resp_buf, resp_sz,
+				      &pt_pa, &c1, &bdfn, state ? &param : NULL, fw_err);
+
+	if (c1 > SZ_32K) {
+		free_shared_pages(pt, certs_size);
+		certs_size = c1;
+		pt = alloc_shared_pages(snp_dev->dev, certs_size);
+		if (!pt)
+			return -ENOMEM;
+
+		pt_pa = __pa(pt);
+		rc = handle_tio_guest_request(snp_dev, type, req_buf, req_sz, resp_buf, resp_sz,
+					      &pt_pa, &c1, &bdfn, state ? &param : NULL, fw_err);
+	}
+
+	if (rc)
+		return rc;
+
+	tsm_blob_put(*meas);
+	tsm_blob_put(*certs);
+	tsm_blob_put(*report);
+	*meas = NULL;
+	*certs = NULL;
+	*report = NULL;
+
+	for (unsigned int i = 0; i < 3; ++i) {
+		u8 *ptr = ((u8 *) pt) + pt[i].offset;
+		size_t len = pt[i].length;
+		struct tsm_blob *b;
+
+		if (guid_is_null(&pt[i].guid))
+			break;
+
+		if (!len)
+			continue;
+
+		b = tsm_blob_new(ptr, len, tio_guest_blob_free);
+		if (!b)
+			break;
+
+		if (guid_equal(&pt[i].guid, &TIO_GUID_MEASUREMENTS))
+			*meas = b;
+		else if (guid_equal(&pt[i].guid, &TIO_GUID_CERTIFICATES))
+			*certs = b;
+		else if (guid_equal(&pt[i].guid, &TIO_GUID_REPORT))
+			*report = b;
+	}
+	free_shared_pages(pt, certs_size);
+
+	if (state)
+		*state = param;
+
+	return 0;
+}
+
+struct tio_msg_tdi_info_req {
+	__u16 guest_device_id;
+	__u8 reserved[14];
+} __packed;
+
+struct tio_msg_tdi_info_rsp {
+	__u16 guest_device_id;
+	__u16 status;
+	__u8 reserved1[12];
+	union {
+		u32 meas_flags;
+		struct {
+			u32 meas_digest_valid : 1;
+			u32 meas_digest_fresh : 1;
+		};
+	};
+	union {
+		u32 tdisp_lock_flags;
+		/* These are TDISP's LOCK_INTERFACE_REQUEST flags */
+		struct {
+			u32 no_fw_update : 1;
+			u32 cache_line_size : 1;
+			u32 lock_msix : 1;
+			u32 bind_p2p : 1;
+			u32 all_request_redirect : 1;
+		};
+	};
+	__u64 spdm_algos;
+	__u8 certs_digest[48];
+	__u8 meas_digest[48];
+	__u8 interface_report_digest[48];
+} __packed;
+
+static int tio_tdi_status(struct tsm_tdi *tdi, struct snp_guest_dev *snp_dev,
+			  struct tsm_tdi_status *ts)
+{
+	struct snp_guest_crypto *crypto = snp_dev->crypto;
+	size_t resp_len = sizeof(struct tio_msg_tdi_info_rsp) + crypto->a_len;
+	struct tio_msg_tdi_info_rsp *rsp = kzalloc(resp_len, GFP_KERNEL);
+	struct tio_msg_tdi_info_req req = {
+		.guest_device_id = pci_dev_id(tdi->pdev),
+	};
+	u64 fw_err = 0;
+	int rc;
+	enum tsm_tdisp_state state = 0;
+
+	pci_notice(tdi->pdev, "TDI info");
+	if (!rsp)
+		return -ENOMEM;
+
+	rc = guest_request_tio_certs(snp_dev, TIO_MSG_TDI_INFO_REQ, &req,
+				     sizeof(req), rsp, resp_len,
+				     pci_dev_id(tdi->pdev), &state,
+				     &tdi->tdev->certs, &tdi->tdev->meas,
+				     &tdi->report, &fw_err);
+
+	ts->meas_digest_valid = rsp->meas_digest_valid;
+	ts->meas_digest_fresh = rsp->meas_digest_fresh;
+	ts->no_fw_update = rsp->no_fw_update;
+	ts->cache_line_size = rsp->cache_line_size == 0 ? 64 : 128;
+	ts->lock_msix = rsp->lock_msix;
+	ts->bind_p2p = rsp->bind_p2p;
+	ts->all_request_redirect = rsp->all_request_redirect;
+#define __ALGO(x, n, y) \
+	((((x) & (0xFFUL << (n))) == TIO_SPDM_ALGOS_##y) ? \
+	 (1ULL << TSM_TDI_SPDM_ALGOS_##y) : 0)
+	ts->spdm_algos =
+		__ALGO(rsp->spdm_algos, 0, DHE_SECP256R1) |
+		__ALGO(rsp->spdm_algos, 0, DHE_SECP384R1) |
+		__ALGO(rsp->spdm_algos, 8, AEAD_AES_128_GCM) |
+		__ALGO(rsp->spdm_algos, 8, AEAD_AES_256_GCM) |
+		__ALGO(rsp->spdm_algos, 16, ASYM_TPM_ALG_RSASSA_3072) |
+		__ALGO(rsp->spdm_algos, 16, ASYM_TPM_ALG_ECDSA_ECC_NIST_P256) |
+		__ALGO(rsp->spdm_algos, 16, ASYM_TPM_ALG_ECDSA_ECC_NIST_P384) |
+		__ALGO(rsp->spdm_algos, 24, HASH_TPM_ALG_SHA_256) |
+		__ALGO(rsp->spdm_algos, 24, HASH_TPM_ALG_SHA_384) |
+		__ALGO(rsp->spdm_algos, 32, KEY_SCHED_SPDM_KEY_SCHEDULE);
+#undef __ALGO
+	memcpy(ts->certs_digest, rsp->certs_digest, sizeof(ts->certs_digest));
+	memcpy(ts->meas_digest, rsp->meas_digest, sizeof(ts->meas_digest));
+	memcpy(ts->interface_report_digest, rsp->interface_report_digest,
+	       sizeof(ts->interface_report_digest));
+
+	ts->valid = true;
+	ts->state = state;
+	/* The response buffer contains the sensitive data, explicitly clear it. */
+	memzero_explicit(&rsp, sizeof(resp_len));
+	kfree(rsp);
+	return rc;
+}
+
+struct tio_msg_mmio_validate_req {
+	__u16 guest_device_id; /* Hypervisor provided identifier used by the guest
+				  to identify the TDI in guest messages */
+	__u16 reserved1;
+	__u8 reserved2[12];
+	__u64 subrange_base;
+	__u32 subrange_page_count;
+	__u32 range_offset;
+	union {
+		__u16 flags;
+		struct {
+			__u16 validated:1; /* Desired value to set RMP.Validated for the range */
+			/* Force validated:
+			 * 0: If subrange does not have RMP.Validated set uniformly, fail.
+			 * 1: If subrange does not have RMP.Validated set uniformly, force
+			 *    to requested value
+			 */
+			__u16 force_validated:1;
+		};
+	};
+	__u16 range_id;
+	__u8 reserved3[12];
+} __packed;
+
+struct tio_msg_mmio_validate_rsp {
+	__u16 guest_interface_id;
+	__u16 status; /* MMIO_VALIDATE_xxx */
+	__u8 reserved1[12];
+	__u64 subrange_base;
+	__u32 subrange_page_count;
+	__u32 range_offset;
+	union {
+		__u16 flags;
+		struct {
+			__u16 changed:1; /* Indicates that the Validated bit has changed
+					    due to this operation */
+		};
+	};
+	__u16 range_id;
+	__u8 reserved2[12];
+} __packed;
+
+static int mmio_validate_range(struct snp_guest_dev *snp_dev, struct pci_dev *pdev,
+			       unsigned int range_id, resource_size_t start, resource_size_t size,
+			       bool invalidate, u64 *fw_err)
+{
+	struct snp_guest_crypto *crypto = snp_dev->crypto;
+	size_t resp_len = sizeof(struct tio_msg_mmio_validate_rsp) + crypto->a_len;
+	struct tio_msg_mmio_validate_rsp *rsp = kzalloc(resp_len, GFP_KERNEL);
+	struct tio_msg_mmio_validate_req req = {
+		.guest_device_id = pci_dev_id(pdev),
+		.subrange_base = start,
+		.subrange_page_count = size >> PAGE_SHIFT,
+		.range_offset = 0,
+		.validated = 1, /* Desired value to set RMP.Validated for the range */
+		.force_validated = 0,
+		.range_id = range_id,
+	};
+	u64 bdfn = pci_dev_id(pdev);
+	u64 mmio_val = MMIO_MK_VALIDATE(start, size, range_id);
+	int rc;
+
+	if (!rsp)
+		return -ENOMEM;
+
+	if (invalidate)
+		memset(&req, 0, sizeof(req));
+
+	rc = handle_tio_guest_request(snp_dev, TIO_MSG_MMIO_VALIDATE_REQ,
+			       &req, sizeof(req), rsp, resp_len,
+			       NULL, NULL, &bdfn, &mmio_val, fw_err);
+	if (rc)
+		goto free_exit;
+
+	if (rsp->status)
+		rc = -EBADR;
+
+free_exit:
+	/* The response buffer contains the sensitive data, explicitly clear it. */
+	memzero_explicit(&rsp, sizeof(resp_len));
+	kfree(rsp);
+	return rc;
+}
+
+static int tio_tdi_mmio_validate(struct tsm_tdi *tdi, struct snp_guest_dev *snp_dev,
+				 bool invalidate)
+{
+	struct pci_dev *pdev = tdi->pdev;
+	struct tdi_report_mmio_range mr;
+	struct resource *r;
+	u64 fw_err = 0;
+	int i = 0, rc;
+
+	pci_notice(tdi->pdev, "MMIO validate");
+
+	if (WARN_ON_ONCE(!tdi->report || !tdi->report->data))
+		return -EFAULT;
+
+	for (i = 0; i < TDI_REPORT_MR_NUM(tdi->report); ++i) {
+		mr = TDI_REPORT_MR(tdi->report, i);
+		r = pci_resource_n(tdi->pdev, mr.range_id);
+
+		if (r->end == r->start || ((r->end - r->start + 1) & ~PAGE_MASK) || !mr.num) {
+			pci_warn(tdi->pdev, "Skipping broken range [%d] #%d %d pages, %llx..%llx\n",
+				i, mr.range_id, mr.num, r->start, r->end);
+			continue;
+		}
+
+		if (mr.is_non_tee_mem) {
+			pci_info(tdi->pdev, "Skipping non-TEE range [%d] #%d %d pages, %llx..%llx\n",
+				 i, mr.range_id, mr.num, r->start, r->end);
+			continue;
+		}
+
+		rc = mmio_validate_range(snp_dev, pdev, mr.range_id,
+					 r->start, r->end - r->start + 1, invalidate, &fw_err);
+		if (rc) {
+			pci_err(pdev, "MMIO #%d %llx..%llx validation failed 0x%llx\n",
+				mr.range_id, r->start, r->end, fw_err);
+			continue;
+		}
+
+		pci_notice(pdev, "MMIO #%d %llx..%llx validated\n",  mr.range_id, r->start, r->end);
+	}
+
+	return rc;
+}
+
+struct sdte {
+	__u64 v                  : 1;
+	__u64 reserved           : 3;
+	__u64 cxlio              : 3;
+	__u64 reserved1          : 45;
+	__u64 ppr                : 1;
+	__u64 reserved2          : 1;
+	__u64 giov               : 1;
+	__u64 gv                 : 1;
+	__u64 glx                : 2;
+	__u64 gcr3_tbl_rp0       : 3;
+	__u64 ir                 : 1;
+	__u64 iw                 : 1;
+	__u64 reserved3          : 1;
+	__u16 domain_id;
+	__u16 gcr3_tbl_rp1;
+	__u32 interrupt          : 1;
+	__u32 reserved4          : 5;
+	__u32 ex                 : 1;
+	__u32 sd                 : 1;
+	__u32 reserved5          : 2;
+	__u32 sats               : 1;
+	__u32 gcr3_tbl_rp2       : 21;
+	__u64 giv                : 1;
+	__u64 gint_tbl_len       : 4;
+	__u64 reserved6          : 1;
+	__u64 gint_tbl           : 46;
+	__u64 reserved7          : 2;
+	__u64 gpm                : 2;
+	__u64 reserved8          : 3;
+	__u64 hpt_mode           : 1;
+	__u64 reserved9          : 4;
+	__u32 asid               : 12;
+	__u32 reserved10         : 3;
+	__u32 viommu_en          : 1;
+	__u32 guest_device_id    : 16;
+	__u32 guest_id           : 15;
+	__u32 guest_id_mbo       : 1;
+	__u32 reserved11         : 1;
+	__u32 vmpl               : 2;
+	__u32 reserved12         : 3;
+	__u32 attrv              : 1;
+	__u32 reserved13         : 1;
+	__u32 sa                 : 8;
+	__u8 ide_stream_id[8];
+	__u32 vtom_en            : 1;
+	__u32 vtom               : 31;
+	__u32 rp_id              : 5;
+	__u32 reserved14         : 27;
+	__u8  reserved15[0x40-0x30];
+} __packed;
+
+struct tio_msg_sdte_write_req {
+	__u16 guest_device_id;
+	__u8 reserved[14];
+	struct sdte sdte;
+} __packed;
+
+#define SDTE_WRITE_SUCCESS		0
+#define SDTE_WRITE_INVALID_TDI		1
+#define SDTE_WRITE_TDI_NOT_BOUND	2
+#define SDTE_WRITE_RESERVED		3
+
+struct tio_msg_sdte_write_rsp {
+	__u16 guest_device_id;
+	__u16 status; /* SDTE_WRITE_xxx */
+	__u8 reserved[12];
+} __packed;
+
+static int tio_tdi_sdte_write(struct tsm_tdi *tdi, struct snp_guest_dev *snp_dev, bool invalidate)
+{
+	struct snp_guest_crypto *crypto = snp_dev->crypto;
+	size_t resp_len = sizeof(struct tio_msg_sdte_write_rsp) + crypto->a_len;
+	struct tio_msg_sdte_write_rsp *rsp = kzalloc(resp_len, GFP_KERNEL);
+	struct tio_msg_sdte_write_req req = {
+		.guest_device_id = pci_dev_id(tdi->pdev),
+		.sdte.vmpl = 0,
+		.sdte.vtom = tsm_vtom,
+		.sdte.vtom_en = 1,
+		.sdte.iw = 1,
+		.sdte.ir = 1,
+		.sdte.v = 1,
+	};
+	u64 fw_err = 0;
+	u64 bdfn = pci_dev_id(tdi->pdev);
+	int rc;
+
+	BUILD_BUG_ON(sizeof(struct sdte) * 8 != 512);
+
+	if (invalidate)
+		memset(&req, 0, sizeof(req));
+
+	pci_notice(tdi->pdev, "SDTE write vTOM=%lx", (unsigned long) req.sdte.vtom << 21);
+
+	if (!rsp)
+		return -ENOMEM;
+
+	rc = handle_tio_guest_request(snp_dev, TIO_MSG_SDTE_WRITE_REQ,
+			       &req, sizeof(req), rsp, resp_len,
+			       NULL, NULL, &bdfn, NULL, &fw_err);
+	if (rc) {
+		pci_err(tdi->pdev, "SDTE write failed with 0x%llx\n", fw_err);
+		goto free_exit;
+	}
+
+free_exit:
+	/* The response buffer contains the sensitive data, explicitly clear it. */
+	memzero_explicit(&rsp, sizeof(resp_len));
+	kfree(rsp);
+	return rc;
+}
+
+static int sev_guest_tdi_status(struct tsm_tdi *tdi, void *private_data, struct tsm_tdi_status *ts)
+{
+	struct snp_guest_dev *snp_dev = private_data;
+
+	return tio_tdi_status(tdi, snp_dev, ts);
+}
+
+static int sev_guest_tdi_validate(struct tsm_tdi *tdi, bool invalidate, void *private_data)
+{
+	struct snp_guest_dev *snp_dev = private_data;
+	struct tsm_tdi_status ts = { 0 };
+	int ret;
+
+	if (!tdi->report) {
+		ret = tio_tdi_status(tdi, snp_dev, &ts);
+
+		if (ret || !tdi->report) {
+			pci_err(tdi->pdev, "No report available, ret=%d", ret);
+			if (!ret && tdi->report)
+				ret = -EIO;
+			return ret;
+		}
+
+		if (ts.state != TDISP_STATE_RUN) {
+			pci_err(tdi->pdev, "Not in RUN state, state=%d instead", ts.state);
+			return -EIO;
+		}
+	}
+
+	ret = tio_tdi_sdte_write(tdi, snp_dev, invalidate);
+	if (ret)
+		return ret;
+
+	ret = tio_tdi_mmio_validate(tdi, snp_dev, invalidate);
+	if (ret)
+		return ret;
+
+	return 0;
+}
+
+struct tsm_ops sev_guest_tsm_ops = {
+	.tdi_validate = sev_guest_tdi_validate,
+	.tdi_status = sev_guest_tdi_status,
+};
+
+void sev_guest_tsm_set_ops(bool set, struct snp_guest_dev *snp_dev)
+{
+	if (set)
+		tsm_set_ops(&sev_guest_tsm_ops, snp_dev);
+	else
+		tsm_set_ops(NULL, NULL);
+}