@@ -1309,20 +1309,62 @@ static struct device *vfio_mdev_get_iommu_device(struct device *dev)
return NULL;
}
+static int vfio_mdev_set_domain(struct device *dev, struct iommu_domain *domain)
+{
+ void (*fn)(struct device *dev, void *domain);
+
+ fn = symbol_get(mdev_set_iommu_domain);
+ if (fn) {
+ fn(dev, domain);
+ symbol_put(mdev_set_iommu_domain);
+
+ return 0;
+ }
+
+ return -EINVAL;
+}
+
+static struct iommu_domain *vfio_mdev_get_domain(struct device *dev)
+{
+ void *(*fn)(struct device *dev);
+
+ fn = symbol_get(mdev_get_iommu_domain);
+ if (fn) {
+ struct iommu_domain *domain;
+
+ domain = fn(dev);
+ symbol_put(mdev_get_iommu_domain);
+
+ return domain;
+ }
+
+ return NULL;
+}
+
static int vfio_mdev_attach_domain(struct device *dev, void *data)
{
- struct iommu_domain *domain = data;
+ struct iommu_domain *domain;
struct device *iommu_device;
+ int ret = -ENODEV;
+
+ /* Only single domain is allowed to attach to an mdev. */
+ domain = vfio_mdev_get_domain(dev);
+ if (domain)
+ return -EINVAL;
+ domain = data;
iommu_device = vfio_mdev_get_iommu_device(dev);
if (iommu_device) {
if (iommu_dev_feature_enabled(iommu_device, IOMMU_DEV_FEAT_AUX))
- return iommu_aux_attach_device(domain, iommu_device);
+ ret = iommu_aux_attach_device(domain, iommu_device);
else
- return iommu_attach_device(domain, iommu_device);
+ ret = iommu_attach_device(domain, iommu_device);
}
- return -EINVAL;
+ if (!ret)
+ vfio_mdev_set_domain(dev, domain);
+
+ return ret;
}
static int vfio_mdev_detach_domain(struct device *dev, void *data)
@@ -1338,6 +1380,8 @@ static int vfio_mdev_detach_domain(struct device *dev, void *data)
iommu_detach_device(domain, iommu_device);
}
+ vfio_mdev_set_domain(dev, NULL);
+
return 0;
}