diff mbox series

[v3,13/14] vfio/pci: Replace uses of vfio_device_data() with container_of

Message ID 13-v3-225de1400dfc+4e074-vfio1_jgg@nvidia.com (mailing list archive)
State New, archived
Headers show
Series Embed struct vfio_device in all sub-structures | expand

Commit Message

Jason Gunthorpe March 23, 2021, 4:15 p.m. UTC
This tidies a few confused places that think they can have a refcount on
the vfio_device but the device_data could be NULL, that isn't possible by
design.

Most of the change falls out when struct vfio_devices is updated to just
store the struct vfio_pci_device itself. This wasn't possible before
because there was no easy way to get from the 'struct vfio_pci_device' to
the 'struct vfio_device' to put back the refcount.

Reviewed-by: Christoph Hellwig <hch@lst.de>
Reviewed-by: Kevin Tian <kevin.tian@intel.com>
Reviewed-by: Cornelia Huck <cohuck@redhat.com>
Signed-off-by: Jason Gunthorpe <jgg@nvidia.com>
---
 drivers/vfio/pci/vfio_pci.c | 67 +++++++++++++------------------------
 1 file changed, 24 insertions(+), 43 deletions(-)
diff mbox series

Patch

diff --git a/drivers/vfio/pci/vfio_pci.c b/drivers/vfio/pci/vfio_pci.c
index 5f1a782d1c65ae..1f70387c8afe37 100644
--- a/drivers/vfio/pci/vfio_pci.c
+++ b/drivers/vfio/pci/vfio_pci.c
@@ -517,30 +517,29 @@  static void vfio_pci_disable(struct vfio_pci_device *vdev)
 
 static struct pci_driver vfio_pci_driver;
 
-static struct vfio_pci_device *get_pf_vdev(struct vfio_pci_device *vdev,
-					   struct vfio_device **pf_dev)
+static struct vfio_pci_device *get_pf_vdev(struct vfio_pci_device *vdev)
 {
 	struct pci_dev *physfn = pci_physfn(vdev->pdev);
+	struct vfio_device *pf_dev;
 
 	if (!vdev->pdev->is_virtfn)
 		return NULL;
 
-	*pf_dev = vfio_device_get_from_dev(&physfn->dev);
-	if (!*pf_dev)
+	pf_dev = vfio_device_get_from_dev(&physfn->dev);
+	if (!pf_dev)
 		return NULL;
 
 	if (pci_dev_driver(physfn) != &vfio_pci_driver) {
-		vfio_device_put(*pf_dev);
+		vfio_device_put(pf_dev);
 		return NULL;
 	}
 
-	return vfio_device_data(*pf_dev);
+	return container_of(pf_dev, struct vfio_pci_device, vdev);
 }
 
 static void vfio_pci_vf_token_user_add(struct vfio_pci_device *vdev, int val)
 {
-	struct vfio_device *pf_dev;
-	struct vfio_pci_device *pf_vdev = get_pf_vdev(vdev, &pf_dev);
+	struct vfio_pci_device *pf_vdev = get_pf_vdev(vdev);
 
 	if (!pf_vdev)
 		return;
@@ -550,7 +549,7 @@  static void vfio_pci_vf_token_user_add(struct vfio_pci_device *vdev, int val)
 	WARN_ON(pf_vdev->vf_token->users < 0);
 	mutex_unlock(&pf_vdev->vf_token->lock);
 
-	vfio_device_put(pf_dev);
+	vfio_device_put(&pf_vdev->vdev);
 }
 
 static void vfio_pci_release(struct vfio_device *core_vdev)
@@ -794,7 +793,7 @@  int vfio_pci_register_dev_region(struct vfio_pci_device *vdev,
 }
 
 struct vfio_devices {
-	struct vfio_device **devices;
+	struct vfio_pci_device **devices;
 	int cur_index;
 	int max_index;
 };
@@ -1283,9 +1282,7 @@  static long vfio_pci_ioctl(struct vfio_device *core_vdev,
 			goto hot_reset_release;
 
 		for (; mem_idx < devs.cur_index; mem_idx++) {
-			struct vfio_pci_device *tmp;
-
-			tmp = vfio_device_data(devs.devices[mem_idx]);
+			struct vfio_pci_device *tmp = devs.devices[mem_idx];
 
 			ret = down_write_trylock(&tmp->memory_lock);
 			if (!ret) {
@@ -1300,17 +1297,13 @@  static long vfio_pci_ioctl(struct vfio_device *core_vdev,
 
 hot_reset_release:
 		for (i = 0; i < devs.cur_index; i++) {
-			struct vfio_device *device;
-			struct vfio_pci_device *tmp;
-
-			device = devs.devices[i];
-			tmp = vfio_device_data(device);
+			struct vfio_pci_device *tmp = devs.devices[i];
 
 			if (i < mem_idx)
 				up_write(&tmp->memory_lock);
 			else
 				mutex_unlock(&tmp->vma_lock);
-			vfio_device_put(device);
+			vfio_device_put(&tmp->vdev);
 		}
 		kfree(devs.devices);
 
@@ -1777,8 +1770,7 @@  static int vfio_pci_validate_vf_token(struct vfio_pci_device *vdev,
 		return 0; /* No VF token provided or required */
 
 	if (vdev->pdev->is_virtfn) {
-		struct vfio_device *pf_dev;
-		struct vfio_pci_device *pf_vdev = get_pf_vdev(vdev, &pf_dev);
+		struct vfio_pci_device *pf_vdev = get_pf_vdev(vdev);
 		bool match;
 
 		if (!pf_vdev) {
@@ -1791,7 +1783,7 @@  static int vfio_pci_validate_vf_token(struct vfio_pci_device *vdev,
 		}
 
 		if (!vf_token) {
-			vfio_device_put(pf_dev);
+			vfio_device_put(&pf_vdev->vdev);
 			pci_info_ratelimited(vdev->pdev,
 				"VF token required to access device\n");
 			return -EACCES;
@@ -1801,7 +1793,7 @@  static int vfio_pci_validate_vf_token(struct vfio_pci_device *vdev,
 		match = uuid_equal(uuid, &pf_vdev->vf_token->uuid);
 		mutex_unlock(&pf_vdev->vf_token->lock);
 
-		vfio_device_put(pf_dev);
+		vfio_device_put(&pf_vdev->vdev);
 
 		if (!match) {
 			pci_info_ratelimited(vdev->pdev,
@@ -2122,11 +2114,7 @@  static pci_ers_result_t vfio_pci_aer_err_detected(struct pci_dev *pdev,
 	if (device == NULL)
 		return PCI_ERS_RESULT_DISCONNECT;
 
-	vdev = vfio_device_data(device);
-	if (vdev == NULL) {
-		vfio_device_put(device);
-		return PCI_ERS_RESULT_DISCONNECT;
-	}
+	vdev = container_of(device, struct vfio_pci_device, vdev);
 
 	mutex_lock(&vdev->igate);
 
@@ -2142,7 +2130,6 @@  static pci_ers_result_t vfio_pci_aer_err_detected(struct pci_dev *pdev,
 
 static int vfio_pci_sriov_configure(struct pci_dev *pdev, int nr_virtfn)
 {
-	struct vfio_pci_device *vdev;
 	struct vfio_device *device;
 	int ret = 0;
 
@@ -2155,12 +2142,6 @@  static int vfio_pci_sriov_configure(struct pci_dev *pdev, int nr_virtfn)
 	if (!device)
 		return -ENODEV;
 
-	vdev = vfio_device_data(device);
-	if (!vdev) {
-		vfio_device_put(device);
-		return -ENODEV;
-	}
-
 	if (nr_virtfn == 0)
 		pci_disable_sriov(pdev);
 	else
@@ -2220,7 +2201,7 @@  static int vfio_pci_reflck_find(struct pci_dev *pdev, void *data)
 		return 0;
 	}
 
-	vdev = vfio_device_data(device);
+	vdev = container_of(device, struct vfio_pci_device, vdev);
 
 	if (vdev->reflck) {
 		vfio_pci_reflck_get(vdev->reflck);
@@ -2282,7 +2263,7 @@  static int vfio_pci_get_unused_devs(struct pci_dev *pdev, void *data)
 		return -EBUSY;
 	}
 
-	vdev = vfio_device_data(device);
+	vdev = container_of(device, struct vfio_pci_device, vdev);
 
 	/* Fault if the device is not unused */
 	if (vdev->refcnt) {
@@ -2290,7 +2271,7 @@  static int vfio_pci_get_unused_devs(struct pci_dev *pdev, void *data)
 		return -EBUSY;
 	}
 
-	devs->devices[devs->cur_index++] = device;
+	devs->devices[devs->cur_index++] = vdev;
 	return 0;
 }
 
@@ -2312,7 +2293,7 @@  static int vfio_pci_try_zap_and_vma_lock_cb(struct pci_dev *pdev, void *data)
 		return -EBUSY;
 	}
 
-	vdev = vfio_device_data(device);
+	vdev = container_of(device, struct vfio_pci_device, vdev);
 
 	/*
 	 * Locking multiple devices is prone to deadlock, runaway and
@@ -2323,7 +2304,7 @@  static int vfio_pci_try_zap_and_vma_lock_cb(struct pci_dev *pdev, void *data)
 		return -EBUSY;
 	}
 
-	devs->devices[devs->cur_index++] = device;
+	devs->devices[devs->cur_index++] = vdev;
 	return 0;
 }
 
@@ -2371,7 +2352,7 @@  static void vfio_pci_try_bus_reset(struct vfio_pci_device *vdev)
 
 	/* Does at least one need a reset? */
 	for (i = 0; i < devs.cur_index; i++) {
-		tmp = vfio_device_data(devs.devices[i]);
+		tmp = devs.devices[i];
 		if (tmp->needs_reset) {
 			ret = pci_reset_bus(vdev->pdev);
 			break;
@@ -2380,7 +2361,7 @@  static void vfio_pci_try_bus_reset(struct vfio_pci_device *vdev)
 
 put_devs:
 	for (i = 0; i < devs.cur_index; i++) {
-		tmp = vfio_device_data(devs.devices[i]);
+		tmp = devs.devices[i];
 
 		/*
 		 * If reset was successful, affected devices no longer need
@@ -2396,7 +2377,7 @@  static void vfio_pci_try_bus_reset(struct vfio_pci_device *vdev)
 				vfio_pci_set_power_state(tmp, PCI_D3hot);
 		}
 
-		vfio_device_put(devs.devices[i]);
+		vfio_device_put(&tmp->vdev);
 	}
 
 	kfree(devs.devices);