
PCI: hv: Move completion variable from stack to heap in hv_compose_msi_msg()

Message ID 1620806824-31151-1-git-send-email-longli@linuxonhyperv.com (mailing list archive)
State Superseded
Delegated to: Lorenzo Pieralisi
Series PCI: hv: Move completion variable from stack to heap in hv_compose_msi_msg()

Commit Message

Long Li May 12, 2021, 8:07 a.m. UTC
From: Long Li <longli@microsoft.com>

hv_compose_msi_msg() may be called with interrupts disabled. It polls for
completion in a loop and may exit the loop early if the device is being
ejected or another error is hit. However, the VSP may send a completion
packet after the loop has exited, by which time the completion variable on
the stack is no longer valid. This results in a kernel oops.

Fix this by moving the completion variable from the stack to the heap, and
use hbus to maintain a list of leftover completions for later cleanup if
necessary.

Signed-off-by: Long Li <longli@microsoft.com>
---
 drivers/pci/controller/pci-hyperv.c | 97 +++++++++++++++++++----------
 1 file changed, 65 insertions(+), 32 deletions(-)
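The lifetime problem and the fix can be sketched in userspace C: the completion context is heap-allocated, and a context the waiter abandons is parked on a "leftover" list so a late host completion still touches valid memory. Kernel primitives (list_head, spinlocks) are replaced with plain C stand-ins, and all names here are illustrative, not the driver's actual API:

```c
#include <stdbool.h>
#include <stdlib.h>

struct msi_ctxt {
	bool completed;
	struct msi_ctxt *next;		/* leftover list linkage */
};

static struct msi_ctxt *leftover;	/* would hang off hbus; lock elided */

/* The context lives on the heap, not the waiter's stack frame. */
struct msi_ctxt *compose_start(void)
{
	return calloc(1, sizeof(struct msi_ctxt));
}

/* Waiter gives up (device ejected / other error): keep the memory valid. */
void compose_abandon(struct msi_ctxt *c)
{
	c->next = leftover;
	leftover = c;
}

/* A late completion from the host writes into still-valid memory
 * instead of a dead stack frame. */
void host_complete(struct msi_ctxt *c)
{
	c->completed = true;
}

/* On bus removal, free whatever the host never got back to; returns
 * the number of contexts freed. */
int leftover_cleanup(void)
{
	int n = 0;

	while (leftover) {
		struct msi_ctxt *c = leftover;

		leftover = c->next;
		free(c);
		n++;
	}
	return n;
}
```

With a stack variable, `host_complete()` after `compose_abandon()` would be a use-after-return; here it is a harmless write into parked heap memory.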

Comments

Michael Kelley (LINUX) May 26, 2021, 6:27 p.m. UTC | #1
From: longli@linuxonhyperv.com <longli@linuxonhyperv.com> Sent: Wednesday, May 12, 2021 1:07 AM
> 
> hv_compose_msi_msg() may be called with interrupt disabled. It calls
> wait_for_completion() in a loop and may exit the loop earlier if the device is
> being ejected or it's hitting other errors. However the VSP may send
> completion packet after the loop exit and the completion variable is no
> longer valid on the stack. This results in a kernel oops.
> 
> Fix this by relocating completion variable from stack to heap, and use hbus
> to maintain a list of leftover completions for future cleanup if necessary.

Interesting problem.  I haven't reviewed the details of your implementation
because I'd like to propose an alternate approach to solving the problem.

You have fixed the problem for hv_compose_msi_msg(), but it seems like the
same problem could occur in other places in pci-hyperv.c where a VMbus
request is sent, and waiting for the response could be aborted by the device
being rescinded.

The current code (and with your patch) passes the guest memory address of
the completion packet to Hyper-V as the requestID.  Hyper-V responds and
passes back the requestID, whereupon hv_pci_onchannelcallback() treats it
as the guest memory address of the completion packet.  This all assumes that
Hyper-V is trusted and that it doesn't pass back a bogus value that will be
treated as a guest memory address.  But Andrea Parri has been updating
other VMbus drivers (like netvsc and storvsc) to *not* pass guest memory
addresses as the requestID. The pci-hyperv.c driver has not been fixed in this
regard, but I think this patch could take a big step in that direction.

My alternate approach is as follows:
1.  For each PCI VMbus channel, keep a 64-bit counter.  When a VMbus message
is to be sent, increment the counter atomically, and send the next value as the
requestID.   The counter will not wrap around in any practical time period, so
the requestIDs are essentially unique.  Or just read a clock value to get a unique
requestID.
2.  Also keep a per-channel list of mappings from requestID to the guest memory
address of the completion packet.  For PCI channels, there will be very few
requests outstanding concurrently, so this can be a simple linked list, protected
by a spin lock.
3. Before sending a new VMbus message that is expecting a response, add the
mapping to the list.  The guest memory address can be for a stack local, like
the current code.
4. When the sending function completes, either because the response was
received, or because wait_for_response() aborted, remove the mapping from
the linked list.
5. hv_pci_onchannelcallback() gets the requestID from Hyper-V and looks it
up in the linked list.  If there's no match in the linked list, the completion
response from Hyper-V is ignored.  It's either a late response or a completely
bogus response from Hyper-V.  If there is a match, then the address of the
completion packet is available and valid.  The completion function will need
to run while the spin lock is held on the linked list, so that the completion
packet address is ensured to remain valid while the completion function
executes.
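The counter-and-mapping scheme in steps 1-5 can be sketched in userspace C, with the kernel's atomic64_t, spinlock, and list_head replaced by plain C stand-ins (locking elided); the function names are hypothetical, not the pci-hyperv.c API:

```c
#include <stdint.h>
#include <stdlib.h>

struct req_map {			/* requestID -> completion ctx mapping */
	uint64_t request_id;
	void *comp_ctx;			/* may point at a sender's stack local */
	struct req_map *next;
};

static uint64_t next_request_id;	/* step 1: per-channel counter */
static struct req_map *map_head;	/* step 2: per-channel list */

/* Step 3: register the mapping before sending the VMbus message.
 * Returns the requestID to send, or 0 on allocation failure. */
uint64_t vmbus_request_add(void *comp_ctx)
{
	struct req_map *m = malloc(sizeof(*m));

	if (!m)
		return 0;
	m->request_id = ++next_request_id;  /* unique for any practical uptime */
	m->comp_ctx = comp_ctx;
	m->next = map_head;
	map_head = m;
	return m->request_id;
}

/* Step 4: remove the mapping when the sender finishes or gives up. */
void vmbus_request_remove(uint64_t request_id)
{
	struct req_map **pp = &map_head;

	while (*pp) {
		if ((*pp)->request_id == request_id) {
			struct req_map *m = *pp;

			*pp = m->next;
			free(m);
			return;
		}
		pp = &(*pp)->next;
	}
}

/* Step 5: the channel callback looks the requestID up; a late or bogus
 * ID from the host finds no match and is ignored. */
void *vmbus_request_lookup(uint64_t request_id)
{
	struct req_map *m;

	for (m = map_head; m; m = m->next)
		if (m->request_id == request_id)
			return m->comp_ctx;
	return NULL;			/* late or bogus response: ignore */
}
```

Because the callback only ever dereferences an address it found in the mapping list, the host can no longer hand back an arbitrary value that gets treated as a guest memory address.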

I don't think my proposed approach is any more complicated than what your
patch does, and it is a step in the direction of fully hardening the
pci-hyperv.c driver.

This approach is a bit different from netvsc and storvsc because those drivers
must handle lots of in-flight requests, and searching a linked list in the
onchannelcallback function would be too slow.  The overall idea is the same,
but a different approach is used to generate requestIDs and to map
between requestIDs and guest memory addresses.

Thoughts?

Michael

Long Li June 1, 2021, 7:27 p.m. UTC | #2
> Subject: RE: [PATCH] PCI: hv: Move completion variable from stack to heap in
> hv_compose_msi_msg()
> 
> From: longli@linuxonhyperv.com <longli@linuxonhyperv.com> Sent:
> Wednesday, May 12, 2021 1:07 AM
> >
> > hv_compose_msi_msg() may be called with interrupt disabled. It calls
> > wait_for_completion() in a loop and may exit the loop earlier if the
> > device is being ejected or it's hitting other errors. However the VSP
> > may send completion packet after the loop exit and the completion
> > variable is no longer valid on the stack. This results in a kernel oops.
> >
> > Fix this by relocating completion variable from stack to heap, and use
> > hbus to maintain a list of leftover completions for future cleanup if
> necessary.
> 
> Interesting problem.  I haven't reviewed the details of your implementation
> because I'd like to propose an alternate approach to solving the problem.
> 
> You have fixed the problem for hv_compose_msi_msg(), but it seems like the
> same problem could occur in other places in pci-hyperv.c where a VMbus
> request is sent, and waiting for the response could be aborted by the device
> being rescinded.

The problem in hv_compose_msi_msg() is different from the other places: it's a bug in the PCI driver that it doesn't handle the case where the device is ejected (PCI_EJECT). After the device is ejected, it is valid for the VSP to still send a completion for a prior pending request.

On the other hand, if a device is rescinded, it's not possible to get a completion for that device afterwards. If we still get a completion, it's either a bug in the VSP or it's coming from a malicious host.

I agree that if the intent is to deal with an untrusted host, I can follow the same principle and add this support to all requests to the VSP. But this is a different problem from the one this patch intends to address, although I can see they may share the same design principle and common code. My question on an untrusted host is: if a host is untrusted and misbehaving on purpose, what's the point of keeping the VM running instead of letting the PCI driver crash?

Andrea Parri June 1, 2021, 11:13 p.m. UTC | #3
> I agree that if the intent is to deal with an untrusted host, I can follow the same principle and add this support to all requests to the VSP. But this is a different problem from the one this patch intends to address, although I can see they may share the same design principle and common code. My question on an untrusted host is: if a host is untrusted and misbehaving on purpose, what's the point of keeping the VM running instead of letting the PCI driver crash?

I think the principle can be summarized as "keep the VM running, if you can
handle the misbehaviour (possibly warning that something wrong/unexpected just
happened); crash, otherwise".

Of course, this is just a principle: the exact meaning of 'handle' should be
weighed case by case (which I admittedly haven't done here); I'm thinking, e.g.,
of the corresponding complexity/performance impacts and the risks of 'mis-assessments'.

Thanks,
  Andrea
Long Li June 4, 2021, 8:49 a.m. UTC | #4
> Subject: Re: [PATCH] PCI: hv: Move completion variable from stack to heap in
> hv_compose_msi_msg()
> 
> > I agree if the intent is to deal with a untrusted host, I can follow the same
> principle to add this support to all requests to VSP. But this is a different
> problem to what this patch intends to address. I can see they may share the
> same design principle and common code. My question on a untrusted host is:
> If a host is untrusted and is misbehaving on purpose, what's the point of
> keep the VM running and not crashing the PCI driver?
> 
> I think the principle can be summarized with "keep the VM _running, if you
> can handle the misbehaviour (possibly, warning on "something
> wrong/unexpected just happened"); crash, otherwise".
> 
> Of course, this is just a principle: the exact meaning of that 'handle' should be
> leverage case by case (which I admittedly haven't here); I'm thinking, e.g., at
> corresponding complexity/performance impacts and risks of 'mis-
> assessments'.
> 
> Thanks,
>   Andrea

I will follow Michael's suggestion and send v2.

Long

Patch

diff --git a/drivers/pci/controller/pci-hyperv.c b/drivers/pci/controller/pci-hyperv.c
index 9499ae3275fe..29fe26e2193c 100644
--- a/drivers/pci/controller/pci-hyperv.c
+++ b/drivers/pci/controller/pci-hyperv.c
@@ -473,6 +473,9 @@  struct hv_pcibus_device {
 	struct msi_controller msi_chip;
 	struct irq_domain *irq_domain;
 
+	struct list_head compose_msi_msg_ctxt_list;
+	spinlock_t compose_msi_msg_ctxt_list_lock;
+
 	spinlock_t retarget_msi_interrupt_lock;
 
 	struct workqueue_struct *wq;
@@ -552,6 +555,17 @@  struct hv_pci_compl {
 	s32 completion_status;
 };
 
+struct compose_comp_ctxt {
+	struct hv_pci_compl comp_pkt;
+	struct tran_int_desc int_desc;
+};
+
+struct compose_msi_msg_ctxt {
+	struct list_head list;
+	struct pci_packet pci_pkt;
+	struct compose_comp_ctxt comp;
+};
+
 static void hv_pci_onchannelcallback(void *context);
 
 /**
@@ -1293,11 +1307,6 @@  static void hv_irq_unmask(struct irq_data *data)
 	pci_msi_unmask_irq(data);
 }
 
-struct compose_comp_ctxt {
-	struct hv_pci_compl comp_pkt;
-	struct tran_int_desc int_desc;
-};
-
 static void hv_pci_compose_compl(void *context, struct pci_response *resp,
 				 int resp_packet_size)
 {
@@ -1373,16 +1382,12 @@  static void hv_compose_msi_msg(struct irq_data *data, struct msi_msg *msg)
 	struct pci_bus *pbus;
 	struct pci_dev *pdev;
 	struct cpumask *dest;
-	struct compose_comp_ctxt comp;
 	struct tran_int_desc *int_desc;
-	struct {
-		struct pci_packet pci_pkt;
-		union {
-			struct pci_create_interrupt v1;
-			struct pci_create_interrupt2 v2;
-		} int_pkts;
-	} __packed ctxt;
-
+	struct compose_msi_msg_ctxt *ctxt;
+	union {
+		struct pci_create_interrupt v1;
+		struct pci_create_interrupt2 v2;
+	} int_pkts;
 	u32 size;
 	int ret;
 
@@ -1402,18 +1407,24 @@  static void hv_compose_msi_msg(struct irq_data *data, struct msi_msg *msg)
 		hv_int_desc_free(hpdev, int_desc);
 	}
 
+	ctxt = kzalloc(sizeof(*ctxt), GFP_ATOMIC);
+	if (!ctxt)
+		goto drop_reference;
+
 	int_desc = kzalloc(sizeof(*int_desc), GFP_ATOMIC);
-	if (!int_desc)
+	if (!int_desc) {
+		kfree(ctxt);
 		goto drop_reference;
+	}
 
-	memset(&ctxt, 0, sizeof(ctxt));
-	init_completion(&comp.comp_pkt.host_event);
-	ctxt.pci_pkt.completion_func = hv_pci_compose_compl;
-	ctxt.pci_pkt.compl_ctxt = &comp;
+	memset(ctxt, 0, sizeof(*ctxt));
+	init_completion(&ctxt->comp.comp_pkt.host_event);
+	ctxt->pci_pkt.completion_func = hv_pci_compose_compl;
+	ctxt->pci_pkt.compl_ctxt = &ctxt->comp;
 
 	switch (hbus->protocol_version) {
 	case PCI_PROTOCOL_VERSION_1_1:
-		size = hv_compose_msi_req_v1(&ctxt.int_pkts.v1,
+		size = hv_compose_msi_req_v1(&int_pkts.v1,
 					dest,
 					hpdev->desc.win_slot.slot,
 					cfg->vector);
@@ -1421,7 +1432,7 @@  static void hv_compose_msi_msg(struct irq_data *data, struct msi_msg *msg)
 
 	case PCI_PROTOCOL_VERSION_1_2:
 	case PCI_PROTOCOL_VERSION_1_3:
-		size = hv_compose_msi_req_v2(&ctxt.int_pkts.v2,
+		size = hv_compose_msi_req_v2(&int_pkts.v2,
 					dest,
 					hpdev->desc.win_slot.slot,
 					cfg->vector);
@@ -1434,17 +1445,18 @@  static void hv_compose_msi_msg(struct irq_data *data, struct msi_msg *msg)
 		 */
 		dev_err(&hbus->hdev->device,
 			"Unexpected vPCI protocol, update driver.");
+		kfree(ctxt);
 		goto free_int_desc;
 	}
 
-	ret = vmbus_sendpacket(hpdev->hbus->hdev->channel, &ctxt.int_pkts,
-			       size, (unsigned long)&ctxt.pci_pkt,
+	ret = vmbus_sendpacket(hpdev->hbus->hdev->channel, &int_pkts,
+			       size, (unsigned long)&ctxt->pci_pkt,
 			       VM_PKT_DATA_INBAND,
 			       VMBUS_DATA_PACKET_FLAG_COMPLETION_REQUESTED);
 	if (ret) {
 		dev_err(&hbus->hdev->device,
-			"Sending request for interrupt failed: 0x%x",
-			comp.comp_pkt.completion_status);
+			"Sending request for interrupt failed: 0x%x", ret);
+		kfree(ctxt);
 		goto free_int_desc;
 	}
 
@@ -1458,7 +1470,7 @@  static void hv_compose_msi_msg(struct irq_data *data, struct msi_msg *msg)
 	 * Since this function is called with IRQ locks held, can't
 	 * do normal wait for completion; instead poll.
 	 */
-	while (!try_wait_for_completion(&comp.comp_pkt.host_event)) {
+	while (!try_wait_for_completion(&ctxt->comp.comp_pkt.host_event)) {
 		unsigned long flags;
 
 		/* 0xFFFF means an invalid PCI VENDOR ID. */
@@ -1494,10 +1506,11 @@  static void hv_compose_msi_msg(struct irq_data *data, struct msi_msg *msg)
 
 	tasklet_enable(&channel->callback_event);
 
-	if (comp.comp_pkt.completion_status < 0) {
+	if (ctxt->comp.comp_pkt.completion_status < 0) {
 		dev_err(&hbus->hdev->device,
 			"Request for interrupt failed: 0x%x",
-			comp.comp_pkt.completion_status);
+			ctxt->comp.comp_pkt.completion_status);
+		kfree(ctxt);
 		goto free_int_desc;
 	}
 
@@ -1506,23 +1519,36 @@  static void hv_compose_msi_msg(struct irq_data *data, struct msi_msg *msg)
 	 * irq_set_chip_data() here would be appropriate, but the lock it takes
 	 * is already held.
 	 */
-	*int_desc = comp.int_desc;
+	*int_desc = ctxt->comp.int_desc;
 	data->chip_data = int_desc;
 
 	/* Pass up the result. */
-	msg->address_hi = comp.int_desc.address >> 32;
-	msg->address_lo = comp.int_desc.address & 0xffffffff;
-	msg->data = comp.int_desc.data;
+	msg->address_hi = ctxt->comp.int_desc.address >> 32;
+	msg->address_lo = ctxt->comp.int_desc.address & 0xffffffff;
+	msg->data = ctxt->comp.int_desc.data;
 
 	put_pcichild(hpdev);
+	kfree(ctxt);
 	return;
 
 enable_tasklet:
 	tasklet_enable(&channel->callback_event);
+
+	/*
+	 * Move uncompleted context to the leftover list.
+	 * The host may send completion at a later time, and we ignore this
+	 * completion but keep the memory reference valid.
+	 */
+	spin_lock(&hbus->compose_msi_msg_ctxt_list_lock);
+	list_add_tail(&ctxt->list, &hbus->compose_msi_msg_ctxt_list);
+	spin_unlock(&hbus->compose_msi_msg_ctxt_list_lock);
+
 free_int_desc:
 	kfree(int_desc);
+
 drop_reference:
 	put_pcichild(hpdev);
+
 return_null_message:
 	msg->address_hi = 0;
 	msg->address_lo = 0;
@@ -3076,9 +3102,11 @@  static int hv_pci_probe(struct hv_device *hdev,
 	INIT_LIST_HEAD(&hbus->children);
 	INIT_LIST_HEAD(&hbus->dr_list);
 	INIT_LIST_HEAD(&hbus->resources_for_children);
+	INIT_LIST_HEAD(&hbus->compose_msi_msg_ctxt_list);
 	spin_lock_init(&hbus->config_lock);
 	spin_lock_init(&hbus->device_list_lock);
 	spin_lock_init(&hbus->retarget_msi_interrupt_lock);
+	spin_lock_init(&hbus->compose_msi_msg_ctxt_list_lock);
 	hbus->wq = alloc_ordered_workqueue("hv_pci_%x", 0,
 					   hbus->sysdata.domain);
 	if (!hbus->wq) {
@@ -3282,6 +3310,7 @@  static int hv_pci_bus_exit(struct hv_device *hdev, bool keep_devs)
 static int hv_pci_remove(struct hv_device *hdev)
 {
 	struct hv_pcibus_device *hbus;
+	struct compose_msi_msg_ctxt *ctxt, *tmp;
 	int ret;
 
 	hbus = hv_get_drvdata(hdev);
@@ -3318,6 +3347,10 @@  static int hv_pci_remove(struct hv_device *hdev)
 
 	hv_put_dom_num(hbus->sysdata.domain);
 
+	list_for_each_entry_safe(ctxt, tmp, &hbus->compose_msi_msg_ctxt_list, list) {
+		list_del(&ctxt->list);
+		kfree(ctxt);
+	}
 	kfree(hbus);
 	return ret;
 }