diff mbox series

[v8,04/16] virt: sev-guest: Add vmpck_id to snp_guest_dev struct

Message ID 20240215113128.275608-5-nikunj@amd.com (mailing list archive)
State New, archived
Headers show
Series Add Secure TSC support for SNP guests | expand

Commit Message

Nikunj A. Dadhania Feb. 15, 2024, 11:31 a.m. UTC
Drop vmpck and os_area_msg_seqno pointers so that secret page layout
does not need to be exposed to the sev-guest driver after the rework.
Instead, add helper APIs to access vmpck and os_area_msg_seqno when
needed. Added define for maximum supported VMPCK.

Also, change function is_vmpck_empty() to snp_is_vmpck_empty() in
preparation for moving to sev.c.

Signed-off-by: Nikunj A Dadhania <nikunj@amd.com>
Reviewed-by: Tom Lendacky <thomas.lendacky@amd.com>
Tested-by: Peter Gonda <pgonda@google.com>
---
 arch/x86/include/asm/sev.h              |  1 +
 drivers/virt/coco/sev-guest/sev-guest.c | 95 ++++++++++++-------------
 2 files changed, 48 insertions(+), 48 deletions(-)

Comments

Borislav Petkov April 9, 2024, 10:23 a.m. UTC | #1
On Thu, Feb 15, 2024 at 05:01:16PM +0530, Nikunj A Dadhania wrote:
> Drop vmpck and os_area_msg_seqno pointers so that secret page layout
> does not need to be exposed to the sev-guest driver after the rework.
> Instead, add helper APIs to access vmpck and os_area_msg_seqno when
> needed. Added define for maximum supported VMPCK.

Do not talk about *what* the patch is doing in the commit message - that
should be obvious from the diff itself. Rather, concentrate on the *why*
it needs to be done.

Imagine one fine day you're doing git archeology, you find the place in
the code about which you want to find out why it was changed the way it
is now.

You do git annotate <filename> ... find the line, see the commit id and
you do:

git show <commit id>

You read the commit message and there's just gibberish and nothing's
explaining *why* that change was done. And you start scratching your
head, trying to figure out why...

I'm sure you're getting the idea.

> Also, change function is_vmpck_empty() to snp_is_vmpck_empty() in
> preparation for moving to sev.c.
> 
> Signed-off-by: Nikunj A Dadhania <nikunj@amd.com>
> Reviewed-by: Tom Lendacky <thomas.lendacky@amd.com>
> Tested-by: Peter Gonda <pgonda@google.com>
> ---
>  arch/x86/include/asm/sev.h              |  1 +
>  drivers/virt/coco/sev-guest/sev-guest.c | 95 ++++++++++++-------------
>  2 files changed, 48 insertions(+), 48 deletions(-)
> 
> diff --git a/arch/x86/include/asm/sev.h b/arch/x86/include/asm/sev.h
> index 0c0b11af9f89..e4f52a606487 100644
> --- a/arch/x86/include/asm/sev.h
> +++ b/arch/x86/include/asm/sev.h
> @@ -135,6 +135,7 @@ struct secrets_os_area {
>  } __packed;
>  
>  #define VMPCK_KEY_LEN		32
> +#define VMPCK_MAX_NUM		4
>  
>  /* See the SNP spec version 0.9 for secrets page format */
>  struct snp_secrets_page_layout {
> diff --git a/drivers/virt/coco/sev-guest/sev-guest.c b/drivers/virt/coco/sev-guest/sev-guest.c
> index 596cec03f9eb..646eb215f3c7 100644
> --- a/drivers/virt/coco/sev-guest/sev-guest.c
> +++ b/drivers/virt/coco/sev-guest/sev-guest.c
> @@ -55,8 +55,7 @@ struct snp_guest_dev {
>  		struct snp_derived_key_req derived_key;
>  		struct snp_ext_report_req ext_report;
>  	} req;
> -	u32 *os_area_msg_seqno;
> -	u8 *vmpck;
> +	unsigned int vmpck_id;
>  };
>  
>  static u32 vmpck_id;
> @@ -66,14 +65,22 @@ MODULE_PARM_DESC(vmpck_id, "The VMPCK ID to use when communicating with the PSP.
>  /* Mutex to serialize the shared buffer access and command handling. */
>  static DEFINE_MUTEX(snp_cmd_mutex);
>  
> -static bool is_vmpck_empty(struct snp_guest_dev *snp_dev)
> +static inline u8 *snp_get_vmpck(struct snp_guest_dev *snp_dev)

static functions don't need a prefix like "snp_".

>  {
> -	char zero_key[VMPCK_KEY_LEN] = {0};
> +	return snp_dev->layout->vmpck0 + snp_dev->vmpck_id * VMPCK_KEY_LEN;
> +}
>  
> -	if (snp_dev->vmpck)
> -		return !memcmp(snp_dev->vmpck, zero_key, VMPCK_KEY_LEN);
> +static inline u32 *snp_get_os_area_msg_seqno(struct snp_guest_dev *snp_dev)

Ditto.

> +{
> +	return &snp_dev->layout->os_area.msg_seqno_0 + snp_dev->vmpck_id;
> +}
>  
> -	return true;
> +static bool snp_is_vmpck_empty(struct snp_guest_dev *snp_dev)
> +{
> +	char zero_key[VMPCK_KEY_LEN] = {0};
> +	u8 *key = snp_get_vmpck(snp_dev);
> +
> +	return !memcmp(key, zero_key, VMPCK_KEY_LEN);
>  }
>  
>  /*
> @@ -95,20 +102,22 @@ static bool is_vmpck_empty(struct snp_guest_dev *snp_dev)
>   */
>  static void snp_disable_vmpck(struct snp_guest_dev *snp_dev)
>  {
> -	dev_alert(snp_dev->dev, "Disabling vmpck_id %d to prevent IV reuse.\n",
> -		  vmpck_id);
> -	memzero_explicit(snp_dev->vmpck, VMPCK_KEY_LEN);
> -	snp_dev->vmpck = NULL;
> +	u8 *key = snp_get_vmpck(snp_dev);

Check whether is_vmpck_empty before you disable?

> +
> +	dev_alert(snp_dev->dev, "Disabling vmpck_id %u to prevent IV reuse.\n",
> +		  snp_dev->vmpck_id);
> +	memzero_explicit(key, VMPCK_KEY_LEN);
>  }
>  
>  static inline u64 __snp_get_msg_seqno(struct snp_guest_dev *snp_dev)
>  {
> +	u32 *os_area_msg_seqno = snp_get_os_area_msg_seqno(snp_dev);
>  	u64 count;
>  
>  	lockdep_assert_held(&snp_cmd_mutex);
>  
>  	/* Read the current message sequence counter from secrets pages */
> -	count = *snp_dev->os_area_msg_seqno;
> +	count = *os_area_msg_seqno;

Why does that snp_get_os_area_msg_seqno() returns a pointer when you
deref it here again?

A function which returns a sequence number should return that number
- not a pointer to it.

Which then makes that u32 *os_area_msg_seqno redundant and you can use
the function directly.

IOW:

static inline u32 snp_get_os_area_msg_seqno(struct snp_guest_dev *snp_dev)
{
        return snp_dev->layout->os_area.msg_seqno_0 + snp_dev->vmpck_id;
}

Simple.

>  
>  	return count + 1;
>  }
> @@ -136,11 +145,13 @@ static u64 snp_get_msg_seqno(struct snp_guest_dev *snp_dev)
>  
>  static void snp_inc_msg_seqno(struct snp_guest_dev *snp_dev)
>  {
> +	u32 *os_area_msg_seqno = snp_get_os_area_msg_seqno(snp_dev);
> +
>  	/*
>  	 * The counter is also incremented by the PSP, so increment it by 2
>  	 * and save in secrets page.
>  	 */
> -	*snp_dev->os_area_msg_seqno += 2;
> +	*os_area_msg_seqno += 2;

Yah, you have a getter but not a setter. You're setting it through the
pointer. Do you see the imbalance in the APIs?

>  }
>  
>  static inline struct snp_guest_dev *to_snp_dev(struct file *file)
> @@ -150,15 +161,22 @@ static inline struct snp_guest_dev *to_snp_dev(struct file *file)
>  	return container_of(dev, struct snp_guest_dev, misc);
>  }
>  
> -static struct aesgcm_ctx *snp_init_crypto(u8 *key, size_t keylen)
> +static struct aesgcm_ctx *snp_init_crypto(struct snp_guest_dev *snp_dev)
>  {
>  	struct aesgcm_ctx *ctx;
> +	u8 *key;
> +
> +	if (snp_is_vmpck_empty(snp_dev)) {
> +		pr_err("VM communication key VMPCK%u is null\n", vmpck_id);

		      "Empty/invalid VMPCK%u communication key"

or so.

In a pre-patch, fix all your user-visible strings to say "VMPCK"
- capitalized as it is an abbreviation.

> +		return NULL;
> +	}
>  
>  	ctx = kzalloc(sizeof(*ctx), GFP_KERNEL_ACCOUNT);
>  	if (!ctx)
>  		return NULL;
>  
> -	if (aesgcm_expandkey(ctx, key, keylen, AUTHTAG_LEN)) {
> +	key = snp_get_vmpck(snp_dev);
> +	if (aesgcm_expandkey(ctx, key, VMPCK_KEY_LEN, AUTHTAG_LEN)) {
>  		pr_err("Crypto context initialization failed\n");
>  		kfree(ctx);
>  		return NULL;
> @@ -590,7 +608,7 @@ static long snp_guest_ioctl(struct file *file, unsigned int ioctl, unsigned long
>  	mutex_lock(&snp_cmd_mutex);
>  
>  	/* Check if the VMPCK is not empty */
> -	if (is_vmpck_empty(snp_dev)) {
> +	if (snp_is_vmpck_empty(snp_dev)) {
>  		dev_err_ratelimited(snp_dev->dev, "VMPCK is disabled\n");
>  		mutex_unlock(&snp_cmd_mutex);
>  		return -ENOTTY;
> @@ -667,32 +685,14 @@ static const struct file_operations snp_guest_fops = {
>  	.unlocked_ioctl = snp_guest_ioctl,
>  };
>  
> -static u8 *get_vmpck(int id, struct snp_secrets_page_layout *layout, u32 **seqno)
> +static bool snp_assign_vmpck(struct snp_guest_dev *dev, unsigned int vmpck_id)
>  {
> -	u8 *key = NULL;
> +	if (WARN_ON((vmpck_id + 1) > VMPCK_MAX_NUM))
> +		return false;

So this will warn *and*, at the call site too. Let's tone that down.

>  
> -	switch (id) {
> -	case 0:
> -		*seqno = &layout->os_area.msg_seqno_0;
> -		key = layout->vmpck0;
> -		break;
> -	case 1:
> -		*seqno = &layout->os_area.msg_seqno_1;
> -		key = layout->vmpck1;
> -		break;
> -	case 2:
> -		*seqno = &layout->os_area.msg_seqno_2;
> -		key = layout->vmpck2;
> -		break;
> -	case 3:
> -		*seqno = &layout->os_area.msg_seqno_3;
> -		key = layout->vmpck3;
> -		break;
> -	default:
> -		break;
> -	}

Your commit message could explain why this is not needed, all of
a sudden.

> +	dev->vmpck_id = vmpck_id;
>  
> -	return key;
> +	return true;
>  }
>  
>  struct snp_msg_report_resp_hdr {
> @@ -728,7 +728,7 @@ static int sev_report_new(struct tsm_report *report, void *data)
>  	guard(mutex)(&snp_cmd_mutex);
>  
>  	/* Check if the VMPCK is not empty */
> -	if (is_vmpck_empty(snp_dev)) {
> +	if (snp_is_vmpck_empty(snp_dev)) {
>  		dev_err_ratelimited(snp_dev->dev, "VMPCK is disabled\n");
>  		return -ENOTTY;
>  	}
> @@ -848,21 +848,20 @@ static int __init sev_guest_probe(struct platform_device *pdev)
>  		goto e_unmap;
>  
>  	ret = -EINVAL;
> -	snp_dev->vmpck = get_vmpck(vmpck_id, layout, &snp_dev->os_area_msg_seqno);
> -	if (!snp_dev->vmpck) {
> -		dev_err(dev, "invalid vmpck id %d\n", vmpck_id);
> +	snp_dev->layout = layout;
> +	if (!snp_assign_vmpck(snp_dev, vmpck_id)) {
> +		dev_err(dev, "invalid vmpck id %u\n", vmpck_id);
>  		goto e_unmap;
>  	}
>  
>  	/* Verify that VMPCK is not zero. */
> -	if (is_vmpck_empty(snp_dev)) {
> -		dev_err(dev, "vmpck id %d is null\n", vmpck_id);
> +	if (snp_is_vmpck_empty(snp_dev)) {
> +		dev_err(dev, "vmpck id %u is null\n", vmpck_id);

s!null!Invalid/Empty!

>  		goto e_unmap;
>  	}
>  
>  	platform_set_drvdata(pdev, snp_dev);
>  	snp_dev->dev = dev;
> -	snp_dev->layout = layout;
>  
>  	/* Allocate the shared page used for the request and response message. */
>  	snp_dev->request = alloc_shared_pages(dev, sizeof(struct snp_guest_msg));
> @@ -878,7 +877,7 @@ static int __init sev_guest_probe(struct platform_device *pdev)
>  		goto e_free_response;
>  
>  	ret = -EIO;
> -	snp_dev->ctx = snp_init_crypto(snp_dev->vmpck, VMPCK_KEY_LEN);
> +	snp_dev->ctx = snp_init_crypto(snp_dev);
>  	if (!snp_dev->ctx)
>  		goto e_free_cert_data;
>  
> @@ -903,7 +902,7 @@ static int __init sev_guest_probe(struct platform_device *pdev)
>  	if (ret)
>  		goto e_free_ctx;
>  
> -	dev_info(dev, "Initialized SEV guest driver (using vmpck_id %d)\n", vmpck_id);
> +	dev_info(dev, "Initialized SEV guest driver (using vmpck_id %u)\n", vmpck_id);

Yet another spelling: "vmpck_id". Unify all those in a pre-patch pls
because it looks stupid.

Thx.
Nikunj A. Dadhania April 16, 2024, 5:57 a.m. UTC | #2
On 4/9/2024 3:53 PM, Borislav Petkov wrote:
> On Thu, Feb 15, 2024 at 05:01:16PM +0530, Nikunj A Dadhania wrote:
>> Drop vmpck and os_area_msg_seqno pointers so that secret page layout
>> does not need to be exposed to the sev-guest driver after the rework.
>> Instead, add helper APIs to access vmpck and os_area_msg_seqno when
>> needed. Added define for maximum supported VMPCK.
> 
> Do not talk about *what* the patch is doing in the commit message - that
> should be obvious from the diff itself. Rather, concentrate on the *why*
> it needs to be done.
> 
> Imagine one fine day you're doing git archeology, you find the place in
> the code about which you want to find out why it was changed the way it
> is now.
> 
> You do git annotate <filename> ... find the line, see the commit id and
> you do:
> 
> git show <commit id>
> 
> You read the commit message and there's just gibberish and nothing's
> explaining *why* that change was done. And you start scratching your
> head, trying to figure out why...
> 
> I'm sure you're getting the idea.

Sure, will reword the commit message and send the patch.

> 
>> Also, change function is_vmpck_empty() to snp_is_vmpck_empty() in
>> preparation for moving to sev.c.
>>
>> Signed-off-by: Nikunj A Dadhania <nikunj@amd.com>
>> Reviewed-by: Tom Lendacky <thomas.lendacky@amd.com>
>> Tested-by: Peter Gonda <pgonda@google.com>
>> ---
>>  arch/x86/include/asm/sev.h              |  1 +
>>  drivers/virt/coco/sev-guest/sev-guest.c | 95 ++++++++++++-------------
>>  2 files changed, 48 insertions(+), 48 deletions(-)
>>
>> diff --git a/arch/x86/include/asm/sev.h b/arch/x86/include/asm/sev.h
>> index 0c0b11af9f89..e4f52a606487 100644
>> --- a/arch/x86/include/asm/sev.h
>> +++ b/arch/x86/include/asm/sev.h
>> @@ -135,6 +135,7 @@ struct secrets_os_area {
>>  } __packed;
>>  
>>  #define VMPCK_KEY_LEN		32
>> +#define VMPCK_MAX_NUM		4
>>  
>>  /* See the SNP spec version 0.9 for secrets page format */
>>  struct snp_secrets_page_layout {
>> diff --git a/drivers/virt/coco/sev-guest/sev-guest.c b/drivers/virt/coco/sev-guest/sev-guest.c
>> index 596cec03f9eb..646eb215f3c7 100644
>> --- a/drivers/virt/coco/sev-guest/sev-guest.c
>> +++ b/drivers/virt/coco/sev-guest/sev-guest.c
>> @@ -55,8 +55,7 @@ struct snp_guest_dev {
>>  		struct snp_derived_key_req derived_key;
>>  		struct snp_ext_report_req ext_report;
>>  	} req;
>> -	u32 *os_area_msg_seqno;
>> -	u8 *vmpck;
>> +	unsigned int vmpck_id;
>>  };
>>  
>>  static u32 vmpck_id;
>> @@ -66,14 +65,22 @@ MODULE_PARM_DESC(vmpck_id, "The VMPCK ID to use when communicating with the PSP.
>>  /* Mutex to serialize the shared buffer access and command handling. */
>>  static DEFINE_MUTEX(snp_cmd_mutex);
>>  
>> -static bool is_vmpck_empty(struct snp_guest_dev *snp_dev)
>> +static inline u8 *snp_get_vmpck(struct snp_guest_dev *snp_dev)
> 
> static functions don't need a prefix like "snp_".

Sure

> 
>>  {
>> -	char zero_key[VMPCK_KEY_LEN] = {0};
>> +	return snp_dev->layout->vmpck0 + snp_dev->vmpck_id * VMPCK_KEY_LEN;
>> +}
>>  
>> -	if (snp_dev->vmpck)
>> -		return !memcmp(snp_dev->vmpck, zero_key, VMPCK_KEY_LEN);
>> +static inline u32 *snp_get_os_area_msg_seqno(struct snp_guest_dev *snp_dev)
> 
> Ditto.

Sure

> 
>> +{
>> +	return &snp_dev->layout->os_area.msg_seqno_0 + snp_dev->vmpck_id;
>> +}
>>  
>> -	return true;
>> +static bool snp_is_vmpck_empty(struct snp_guest_dev *snp_dev)
>> +{
>> +	char zero_key[VMPCK_KEY_LEN] = {0};
>> +	u8 *key = snp_get_vmpck(snp_dev);
>> +
>> +	return !memcmp(key, zero_key, VMPCK_KEY_LEN);
>>  }
>>  
>>  /*
>> @@ -95,20 +102,22 @@ static bool is_vmpck_empty(struct snp_guest_dev *snp_dev)
>>   */
>>  static void snp_disable_vmpck(struct snp_guest_dev *snp_dev)
>>  {
>> -	dev_alert(snp_dev->dev, "Disabling vmpck_id %d to prevent IV reuse.\n",
>> -		  vmpck_id);
>> -	memzero_explicit(snp_dev->vmpck, VMPCK_KEY_LEN);
>> -	snp_dev->vmpck = NULL;
>> +	u8 *key = snp_get_vmpck(snp_dev);
> 
> Check whether is_vmpck_empty before you disable?
> 
>> +
>> +	dev_alert(snp_dev->dev, "Disabling vmpck_id %u to prevent IV reuse.\n",
>> +		  snp_dev->vmpck_id);
>> +	memzero_explicit(key, VMPCK_KEY_LEN);
>>  }
>>  
>>  static inline u64 __snp_get_msg_seqno(struct snp_guest_dev *snp_dev)
>>  {
>> +	u32 *os_area_msg_seqno = snp_get_os_area_msg_seqno(snp_dev);
>>  	u64 count;
>>  
>>  	lockdep_assert_held(&snp_cmd_mutex);
>>  
>>  	/* Read the current message sequence counter from secrets pages */
>> -	count = *snp_dev->os_area_msg_seqno;
>> +	count = *os_area_msg_seqno;
> 
> Why does that snp_get_os_area_msg_seqno() returns a pointer when you
> deref it here again?
> 
> A function which returns a sequence number should return that number
> - not a pointer to it.
> 
> Which then makes that u32 *os_area_msg_seqno redundant and you can use
> the function directly.
> 
> IOW:
> 
> static inline u32 snp_get_os_area_msg_seqno(struct snp_guest_dev *snp_dev)
> {
>         return snp_dev->layout->os_area.msg_seqno_0 + snp_dev->vmpck_id;

This patch removes setting of layour page in snp_dev structure.

static inline u32 snp_get_os_area_msg_seqno(struct snp_guest_dev *snp_dev)
{
        if (!platform_data)
                return NULL;

        return *(&platform_data->layout->os_area.msg_seqno_0 + vmpck_id);
}

> }
> 

> Simple.
> 
>>  
>>  	return count + 1;
>>  }
>> @@ -136,11 +145,13 @@ static u64 snp_get_msg_seqno(struct snp_guest_dev *snp_dev)
>>  
>>  static void snp_inc_msg_seqno(struct snp_guest_dev *snp_dev)
>>  {
>> +	u32 *os_area_msg_seqno = snp_get_os_area_msg_seqno(snp_dev);
>> +
>>  	/*
>>  	 * The counter is also incremented by the PSP, so increment it by 2
>>  	 * and save in secrets page.
>>  	 */
>> -	*snp_dev->os_area_msg_seqno += 2;
>> +	*os_area_msg_seqno += 2;
> 
> Yah, you have a getter but not a setter. You're setting it through the
> pointer. 

I had a getter for getting the os_area_msg_seqno pointer, probably not a good function name.

> Do you see the imbalance in the APIs?

The msg_seqno should only be incremented by 2 (always), that was the reason to avoid a setter.

> 
>>  }
>>  
>>  static inline struct snp_guest_dev *to_snp_dev(struct file *file)
>> @@ -150,15 +161,22 @@ static inline struct snp_guest_dev *to_snp_dev(struct file *file)
>>  	return container_of(dev, struct snp_guest_dev, misc);
>>  }
>>  
>> -static struct aesgcm_ctx *snp_init_crypto(u8 *key, size_t keylen)
>> +static struct aesgcm_ctx *snp_init_crypto(struct snp_guest_dev *snp_dev)
>>  {
>>  	struct aesgcm_ctx *ctx;
>> +	u8 *key;
>> +
>> +	if (snp_is_vmpck_empty(snp_dev)) {
>> +		pr_err("VM communication key VMPCK%u is null\n", vmpck_id);
> 
> 		      "Empty/invalid VMPCK%u communication key"
> 
> or so.
> 
> In a pre-patch, fix all your user-visible strings to say "VMPCK"
> - capitalized as it is an abbreviation.

Sure, will do.

> 
>> +		return NULL;
>> +	}
>>  
>>  	ctx = kzalloc(sizeof(*ctx), GFP_KERNEL_ACCOUNT);
>>  	if (!ctx)
>>  		return NULL;
>>  
>> -	if (aesgcm_expandkey(ctx, key, keylen, AUTHTAG_LEN)) {
>> +	key = snp_get_vmpck(snp_dev);
>> +	if (aesgcm_expandkey(ctx, key, VMPCK_KEY_LEN, AUTHTAG_LEN)) {
>>  		pr_err("Crypto context initialization failed\n");
>>  		kfree(ctx);
>>  		return NULL;
>> @@ -590,7 +608,7 @@ static long snp_guest_ioctl(struct file *file, unsigned int ioctl, unsigned long
>>  	mutex_lock(&snp_cmd_mutex);
>>  
>>  	/* Check if the VMPCK is not empty */
>> -	if (is_vmpck_empty(snp_dev)) {
>> +	if (snp_is_vmpck_empty(snp_dev)) {
>>  		dev_err_ratelimited(snp_dev->dev, "VMPCK is disabled\n");
>>  		mutex_unlock(&snp_cmd_mutex);
>>  		return -ENOTTY;
>> @@ -667,32 +685,14 @@ static const struct file_operations snp_guest_fops = {
>>  	.unlocked_ioctl = snp_guest_ioctl,
>>  };
>>  
>> -static u8 *get_vmpck(int id, struct snp_secrets_page_layout *layout, u32 **seqno)
>> +static bool snp_assign_vmpck(struct snp_guest_dev *dev, unsigned int vmpck_id)
>>  {
>> -	u8 *key = NULL;
>> +	if (WARN_ON((vmpck_id + 1) > VMPCK_MAX_NUM))
>> +		return false;
> 
> So this will warn *and*, at the call site too. Let's tone that down.

Sure.

> 
>>  
>> -	switch (id) {
>> -	case 0:
>> -		*seqno = &layout->os_area.msg_seqno_0;
>> -		key = layout->vmpck0;
>> -		break;
>> -	case 1:
>> -		*seqno = &layout->os_area.msg_seqno_1;
>> -		key = layout->vmpck1;
>> -		break;
>> -	case 2:
>> -		*seqno = &layout->os_area.msg_seqno_2;
>> -		key = layout->vmpck2;
>> -		break;
>> -	case 3:
>> -		*seqno = &layout->os_area.msg_seqno_3;
>> -		key = layout->vmpck3;
>> -		break;
>> -	default:
>> -		break;
>> -	}
> 
> Your commit message could explain why this is not needed, all of
> a sudden.

This was replaced by two independent APIs returning pointers to VMPCK and seqno pointer.

static inline u8 *snp_get_vmpck(unsigned int vmpck_id)
{
        if (!platform_data)
                return NULL;

        return platform_data->layout->vmpck0 + vmpck_id * VMPCK_KEY_LEN;
}

static inline u32 *snp_get_os_area_msg_seqno(unsigned int vmpck_id)
{
        if (!platform_data)
                return NULL;

        return &platform_data->layout->os_area.msg_seqno_0 + vmpck_id;
}

I will add more details.

> 
>> +	dev->vmpck_id = vmpck_id;
>>  
>> -	return key;
>> +	return true;
>>  }
>>  
>>  struct snp_msg_report_resp_hdr {
>> @@ -728,7 +728,7 @@ static int sev_report_new(struct tsm_report *report, void *data)
>>  	guard(mutex)(&snp_cmd_mutex);
>>  
>>  	/* Check if the VMPCK is not empty */
>> -	if (is_vmpck_empty(snp_dev)) {
>> +	if (snp_is_vmpck_empty(snp_dev)) {
>>  		dev_err_ratelimited(snp_dev->dev, "VMPCK is disabled\n");
>>  		return -ENOTTY;
>>  	}
>> @@ -848,21 +848,20 @@ static int __init sev_guest_probe(struct platform_device *pdev)
>>  		goto e_unmap;
>>  
>>  	ret = -EINVAL;
>> -	snp_dev->vmpck = get_vmpck(vmpck_id, layout, &snp_dev->os_area_msg_seqno);
>> -	if (!snp_dev->vmpck) {
>> -		dev_err(dev, "invalid vmpck id %d\n", vmpck_id);
>> +	snp_dev->layout = layout;
>> +	if (!snp_assign_vmpck(snp_dev, vmpck_id)) {
>> +		dev_err(dev, "invalid vmpck id %u\n", vmpck_id);
>>  		goto e_unmap;
>>  	}
>>  
>>  	/* Verify that VMPCK is not zero. */
>> -	if (is_vmpck_empty(snp_dev)) {
>> -		dev_err(dev, "vmpck id %d is null\n", vmpck_id);
>> +	if (snp_is_vmpck_empty(snp_dev)) {
>> +		dev_err(dev, "vmpck id %u is null\n", vmpck_id);
> 
> s!null!Invalid/Empty!

Okay

> 
>>  		goto e_unmap;
>>  	}
>>  
>>  	platform_set_drvdata(pdev, snp_dev);
>>  	snp_dev->dev = dev;
>> -	snp_dev->layout = layout;
>>  
>>  	/* Allocate the shared page used for the request and response message. */
>>  	snp_dev->request = alloc_shared_pages(dev, sizeof(struct snp_guest_msg));
>> @@ -878,7 +877,7 @@ static int __init sev_guest_probe(struct platform_device *pdev)
>>  		goto e_free_response;
>>  
>>  	ret = -EIO;
>> -	snp_dev->ctx = snp_init_crypto(snp_dev->vmpck, VMPCK_KEY_LEN);
>> +	snp_dev->ctx = snp_init_crypto(snp_dev);
>>  	if (!snp_dev->ctx)
>>  		goto e_free_cert_data;
>>  
>> @@ -903,7 +902,7 @@ static int __init sev_guest_probe(struct platform_device *pdev)
>>  	if (ret)
>>  		goto e_free_ctx;
>>  
>> -	dev_info(dev, "Initialized SEV guest driver (using vmpck_id %d)\n", vmpck_id);
>> +	dev_info(dev, "Initialized SEV guest driver (using vmpck_id %u)\n", vmpck_id);
> 
> Yet another spelling: "vmpck_id". Unify all those in a pre-patch pls
> because it looks stupid.

Sure
> 
> Thx.
> 

Thanks for the review.

Regards
Nikunj
Borislav Petkov April 16, 2024, 9:06 a.m. UTC | #3
On Tue, Apr 16, 2024 at 11:27:24AM +0530, Nikunj A. Dadhania wrote:
> > Why does that snp_get_os_area_msg_seqno() returns a pointer when you
> > deref it here again?
> > 
> > A function which returns a sequence number should return that number
> > - not a pointer to it.
> > 
> > Which then makes that u32 *os_area_msg_seqno redundant and you can use
> > the function directly.
> > 
> > IOW:
> > 
> > static inline u32 snp_get_os_area_msg_seqno(struct snp_guest_dev *snp_dev)
> > {
> >         return snp_dev->layout->os_area.msg_seqno_0 + snp_dev->vmpck_id;
> 
> This patch removes setting of layour page in snp_dev structure.

So?

> static inline u32 snp_get_os_area_msg_seqno(struct snp_guest_dev *snp_dev)
> {
>         if (!platform_data)
>                 return NULL;
> 
>         return *(&platform_data->layout->os_area.msg_seqno_0 + vmpck_id);
> }

What?!

This snp_get_os_area_msg_seqno() is a new function added by this patch.

> I had a getter for getting the os_area_msg_seqno pointer, probably not
> a good function name.

Probably you need to go back to the drawing board and think about how
this thing should look like.

> > Do you see the imbalance in the APIs?
> 
> The msg_seqno should only be incremented by 2 (always), that was the reason to avoid a setter.

And what's wrong with the setter doing the incrementation so that
callers can't even get it wrong?

It sounds to me like you should redesign this sequence number handling
in a *separate* patch.

Thx.
Nikunj A. Dadhania April 17, 2024, 4:18 a.m. UTC | #4
On 4/16/2024 2:36 PM, Borislav Petkov wrote:
> On Tue, Apr 16, 2024 at 11:27:24AM +0530, Nikunj A. Dadhania wrote:
>>> Why does that snp_get_os_area_msg_seqno() returns a pointer when you
>>> deref it here again?
>>>
>>> A function which returns a sequence number should return that number
>>> - not a pointer to it.
>>>
>>> Which then makes that u32 *os_area_msg_seqno redundant and you can use
>>> the function directly.
>>>
>>> IOW:
>>>
>>> static inline u32 snp_get_os_area_msg_seqno(struct snp_guest_dev *snp_dev)
>>> {
>>>         return snp_dev->layout->os_area.msg_seqno_0 + snp_dev->vmpck_id;
>>
>> This patch removes setting of layour page in snp_dev structure.
> 
> So?

* Instead of using snp_dev->layout, we will need to access it using platform_data->layout structure.
* Below will give incorrect value of sequence number, it will get VMPCK_0's sequence number and will add vmpck_id to that. Will work by fluke for VMPCK=0, but will fail for all other keys.

  return snp_dev->layout->os_area.msg_seqno_0 + snp_dev->vmpck_id;


struct secrets_os_area {
...
        u32 msg_seqno_0;
        u32 msg_seqno_1;
        u32 msg_seqno_2;
        u32 msg_seqno_3;
...
}

* I am using vmpck_id to index to correct msg_seqno_*


Changing this to

struct secrets_os_area {
...
        u32 msg_seqno[VMPCK_MAX_NUM];
...
}


> 
>> static inline u32 snp_get_os_area_msg_seqno(struct snp_guest_dev *snp_dev)
>> {
>>         if (!platform_data)
>>                 return NULL;
>>
>>         return *(&platform_data->layout->os_area.msg_seqno_0 + vmpck_id);
>> }
> 
> What?!

I can change the secrets_os_area like below to simplify things:

struct secrets_os_area {
...
        u32 msg_seqno[VMPCK_MAX_NUM];
...
}


static inline u32 snp_get_os_area_msg_seqno(struct snp_guest_dev *snp_dev)
{
         if (!platform_data)
                 return NULL;

         return platform_data->layout->os_area.msg_seqno[snp_dev->vmpck_id];
}


> 
> This snp_get_os_area_msg_seqno() is a new function added by this patch.
> 
>> I had a getter for getting the os_area_msg_seqno pointer, probably not
>> a good function name.
> 
> Probably you need to go back to the drawing board and think about how
> this thing should look like.
> 
>>> Do you see the imbalance in the APIs?
>>
>> The msg_seqno should only be incremented by 2 (always), that was the reason to avoid a setter.
> 
> And what's wrong with the setter doing the incrementation so that
> callers can't even get it wrong?

Are you suggesting that setter should always increment by 2?

static inline u32 snp_set_os_area_msg_seqno(struct snp_guest_dev *snp_dev)
{
...
	os_area.msg_seqno[snp_dev->vmpck_id] += 2;
...

}

> 
> It sounds to me like you should redesign this sequence number handling
> in a *separate* patch.

Sure, let me rethink and will post it as separate patch.

Regards
Nikunj
diff mbox series

Patch

diff --git a/arch/x86/include/asm/sev.h b/arch/x86/include/asm/sev.h
index 0c0b11af9f89..e4f52a606487 100644
--- a/arch/x86/include/asm/sev.h
+++ b/arch/x86/include/asm/sev.h
@@ -135,6 +135,7 @@  struct secrets_os_area {
 } __packed;
 
 #define VMPCK_KEY_LEN		32
+#define VMPCK_MAX_NUM		4
 
 /* See the SNP spec version 0.9 for secrets page format */
 struct snp_secrets_page_layout {
diff --git a/drivers/virt/coco/sev-guest/sev-guest.c b/drivers/virt/coco/sev-guest/sev-guest.c
index 596cec03f9eb..646eb215f3c7 100644
--- a/drivers/virt/coco/sev-guest/sev-guest.c
+++ b/drivers/virt/coco/sev-guest/sev-guest.c
@@ -55,8 +55,7 @@  struct snp_guest_dev {
 		struct snp_derived_key_req derived_key;
 		struct snp_ext_report_req ext_report;
 	} req;
-	u32 *os_area_msg_seqno;
-	u8 *vmpck;
+	unsigned int vmpck_id;
 };
 
 static u32 vmpck_id;
@@ -66,14 +65,22 @@  MODULE_PARM_DESC(vmpck_id, "The VMPCK ID to use when communicating with the PSP.
 /* Mutex to serialize the shared buffer access and command handling. */
 static DEFINE_MUTEX(snp_cmd_mutex);
 
-static bool is_vmpck_empty(struct snp_guest_dev *snp_dev)
+static inline u8 *snp_get_vmpck(struct snp_guest_dev *snp_dev)
 {
-	char zero_key[VMPCK_KEY_LEN] = {0};
+	return snp_dev->layout->vmpck0 + snp_dev->vmpck_id * VMPCK_KEY_LEN;
+}
 
-	if (snp_dev->vmpck)
-		return !memcmp(snp_dev->vmpck, zero_key, VMPCK_KEY_LEN);
+static inline u32 *snp_get_os_area_msg_seqno(struct snp_guest_dev *snp_dev)
+{
+	return &snp_dev->layout->os_area.msg_seqno_0 + snp_dev->vmpck_id;
+}
 
-	return true;
+static bool snp_is_vmpck_empty(struct snp_guest_dev *snp_dev)
+{
+	char zero_key[VMPCK_KEY_LEN] = {0};
+	u8 *key = snp_get_vmpck(snp_dev);
+
+	return !memcmp(key, zero_key, VMPCK_KEY_LEN);
 }
 
 /*
@@ -95,20 +102,22 @@  static bool is_vmpck_empty(struct snp_guest_dev *snp_dev)
  */
 static void snp_disable_vmpck(struct snp_guest_dev *snp_dev)
 {
-	dev_alert(snp_dev->dev, "Disabling vmpck_id %d to prevent IV reuse.\n",
-		  vmpck_id);
-	memzero_explicit(snp_dev->vmpck, VMPCK_KEY_LEN);
-	snp_dev->vmpck = NULL;
+	u8 *key = snp_get_vmpck(snp_dev);
+
+	dev_alert(snp_dev->dev, "Disabling vmpck_id %u to prevent IV reuse.\n",
+		  snp_dev->vmpck_id);
+	memzero_explicit(key, VMPCK_KEY_LEN);
 }
 
 static inline u64 __snp_get_msg_seqno(struct snp_guest_dev *snp_dev)
 {
+	u32 *os_area_msg_seqno = snp_get_os_area_msg_seqno(snp_dev);
 	u64 count;
 
 	lockdep_assert_held(&snp_cmd_mutex);
 
 	/* Read the current message sequence counter from secrets pages */
-	count = *snp_dev->os_area_msg_seqno;
+	count = *os_area_msg_seqno;
 
 	return count + 1;
 }
@@ -136,11 +145,13 @@  static u64 snp_get_msg_seqno(struct snp_guest_dev *snp_dev)
 
 static void snp_inc_msg_seqno(struct snp_guest_dev *snp_dev)
 {
+	u32 *os_area_msg_seqno = snp_get_os_area_msg_seqno(snp_dev);
+
 	/*
 	 * The counter is also incremented by the PSP, so increment it by 2
 	 * and save in secrets page.
 	 */
-	*snp_dev->os_area_msg_seqno += 2;
+	*os_area_msg_seqno += 2;
 }
 
 static inline struct snp_guest_dev *to_snp_dev(struct file *file)
@@ -150,15 +161,22 @@  static inline struct snp_guest_dev *to_snp_dev(struct file *file)
 	return container_of(dev, struct snp_guest_dev, misc);
 }
 
-static struct aesgcm_ctx *snp_init_crypto(u8 *key, size_t keylen)
+static struct aesgcm_ctx *snp_init_crypto(struct snp_guest_dev *snp_dev)
 {
 	struct aesgcm_ctx *ctx;
+	u8 *key;
+
+	if (snp_is_vmpck_empty(snp_dev)) {
+		pr_err("VM communication key VMPCK%u is null\n", vmpck_id);
+		return NULL;
+	}
 
 	ctx = kzalloc(sizeof(*ctx), GFP_KERNEL_ACCOUNT);
 	if (!ctx)
 		return NULL;
 
-	if (aesgcm_expandkey(ctx, key, keylen, AUTHTAG_LEN)) {
+	key = snp_get_vmpck(snp_dev);
+	if (aesgcm_expandkey(ctx, key, VMPCK_KEY_LEN, AUTHTAG_LEN)) {
 		pr_err("Crypto context initialization failed\n");
 		kfree(ctx);
 		return NULL;
@@ -590,7 +608,7 @@  static long snp_guest_ioctl(struct file *file, unsigned int ioctl, unsigned long
 	mutex_lock(&snp_cmd_mutex);
 
 	/* Check if the VMPCK is not empty */
-	if (is_vmpck_empty(snp_dev)) {
+	if (snp_is_vmpck_empty(snp_dev)) {
 		dev_err_ratelimited(snp_dev->dev, "VMPCK is disabled\n");
 		mutex_unlock(&snp_cmd_mutex);
 		return -ENOTTY;
@@ -667,32 +685,14 @@  static const struct file_operations snp_guest_fops = {
 	.unlocked_ioctl = snp_guest_ioctl,
 };
 
-static u8 *get_vmpck(int id, struct snp_secrets_page_layout *layout, u32 **seqno)
+static bool snp_assign_vmpck(struct snp_guest_dev *dev, unsigned int vmpck_id)
 {
-	u8 *key = NULL;
+	if (WARN_ON((vmpck_id + 1) > VMPCK_MAX_NUM))
+		return false;
 
-	switch (id) {
-	case 0:
-		*seqno = &layout->os_area.msg_seqno_0;
-		key = layout->vmpck0;
-		break;
-	case 1:
-		*seqno = &layout->os_area.msg_seqno_1;
-		key = layout->vmpck1;
-		break;
-	case 2:
-		*seqno = &layout->os_area.msg_seqno_2;
-		key = layout->vmpck2;
-		break;
-	case 3:
-		*seqno = &layout->os_area.msg_seqno_3;
-		key = layout->vmpck3;
-		break;
-	default:
-		break;
-	}
+	dev->vmpck_id = vmpck_id;
 
-	return key;
+	return true;
 }
 
 struct snp_msg_report_resp_hdr {
@@ -728,7 +728,7 @@  static int sev_report_new(struct tsm_report *report, void *data)
 	guard(mutex)(&snp_cmd_mutex);
 
 	/* Check if the VMPCK is not empty */
-	if (is_vmpck_empty(snp_dev)) {
+	if (snp_is_vmpck_empty(snp_dev)) {
 		dev_err_ratelimited(snp_dev->dev, "VMPCK is disabled\n");
 		return -ENOTTY;
 	}
@@ -848,21 +848,20 @@  static int __init sev_guest_probe(struct platform_device *pdev)
 		goto e_unmap;
 
 	ret = -EINVAL;
-	snp_dev->vmpck = get_vmpck(vmpck_id, layout, &snp_dev->os_area_msg_seqno);
-	if (!snp_dev->vmpck) {
-		dev_err(dev, "invalid vmpck id %d\n", vmpck_id);
+	snp_dev->layout = layout;
+	if (!snp_assign_vmpck(snp_dev, vmpck_id)) {
+		dev_err(dev, "invalid vmpck id %u\n", vmpck_id);
 		goto e_unmap;
 	}
 
 	/* Verify that VMPCK is not zero. */
-	if (is_vmpck_empty(snp_dev)) {
-		dev_err(dev, "vmpck id %d is null\n", vmpck_id);
+	if (snp_is_vmpck_empty(snp_dev)) {
+		dev_err(dev, "vmpck id %u is null\n", vmpck_id);
 		goto e_unmap;
 	}
 
 	platform_set_drvdata(pdev, snp_dev);
 	snp_dev->dev = dev;
-	snp_dev->layout = layout;
 
 	/* Allocate the shared page used for the request and response message. */
 	snp_dev->request = alloc_shared_pages(dev, sizeof(struct snp_guest_msg));
@@ -878,7 +877,7 @@  static int __init sev_guest_probe(struct platform_device *pdev)
 		goto e_free_response;
 
 	ret = -EIO;
-	snp_dev->ctx = snp_init_crypto(snp_dev->vmpck, VMPCK_KEY_LEN);
+	snp_dev->ctx = snp_init_crypto(snp_dev);
 	if (!snp_dev->ctx)
 		goto e_free_cert_data;
 
@@ -903,7 +902,7 @@  static int __init sev_guest_probe(struct platform_device *pdev)
 	if (ret)
 		goto e_free_ctx;
 
-	dev_info(dev, "Initialized SEV guest driver (using vmpck_id %d)\n", vmpck_id);
+	dev_info(dev, "Initialized SEV guest driver (using vmpck_id %u)\n", vmpck_id);
 	return 0;
 
 e_free_ctx: