@@ -823,10 +823,39 @@ void vfio_uninit_group_dev(struct vfio_device *device)
}
EXPORT_SYMBOL_GPL(vfio_uninit_group_dev);
+struct vfio_group *vfio_group_find_or_alloc(struct device *dev)
+{
+ struct iommu_group *iommu_group;
+ struct vfio_group *group;
+
+ iommu_group = vfio_iommu_group_get(dev);
+ if (!iommu_group)
+ return ERR_PTR(-EINVAL);
+
+ /* a found vfio_group already holds a reference to the iommu_group */
+ group = vfio_group_get_from_iommu(iommu_group);
+ if (group)
+ goto out_put;
+
+ /* a newly created vfio_group keeps the reference. */
+ group = vfio_create_group(iommu_group);
+ if (IS_ERR(group))
+ goto out_remove;
+ return group;
+
+out_remove:
+#ifdef CONFIG_VFIO_NOIOMMU
+ if (iommu_group_get_iommudata(iommu_group) == &noiommu)
+ iommu_group_remove_device(dev);
+#endif
+out_put:
+ iommu_group_put(iommu_group);
+ return group;
+}
+
int vfio_register_group_dev(struct vfio_device *device)
{
struct vfio_device *existing_device;
- struct iommu_group *iommu_group;
struct vfio_group *group;
/*
@@ -836,36 +865,17 @@ int vfio_register_group_dev(struct vfio_device *device)
if (!device->dev_set)
vfio_assign_device_set(device, device);
- iommu_group = vfio_iommu_group_get(device->dev);
- if (!iommu_group)
- return -EINVAL;
-
- group = vfio_group_get_from_iommu(iommu_group);
- if (!group) {
- group = vfio_create_group(iommu_group);
- if (IS_ERR(group)) {
-#ifdef CONFIG_VFIO_NOIOMMU
- if (iommu_group_get_iommudata(iommu_group) == &noiommu)
- iommu_group_remove_device(device->dev);
-#endif
- iommu_group_put(iommu_group);
- return PTR_ERR(group);
- }
- } else {
- /*
- * A found vfio_group already holds a reference to the
- * iommu_group. A created vfio_group keeps the reference.
- */
- iommu_group_put(iommu_group);
- }
+ group = vfio_group_find_or_alloc(device->dev);
+ if (IS_ERR(group))
+ return PTR_ERR(group);
existing_device = vfio_group_get_device(group, device->dev);
if (existing_device) {
dev_WARN(device->dev, "Device already exists on group %d\n",
- iommu_group_id(iommu_group));
+ iommu_group_id(group->iommu_group));
vfio_device_put(existing_device);
#ifdef CONFIG_VFIO_NOIOMMU
- if (iommu_group_get_iommudata(iommu_group) == &noiommu)
+ if (iommu_group_get_iommudata(group->iommu_group) == &noiommu)
iommu_group_remove_device(device->dev);
#endif
vfio_group_put(group);