@@ -2635,18 +2635,37 @@ static int vfio_iommu_resv_refresh(struct vfio_iommu *iommu,
return ret;
}
+static struct device *vfio_get_iommu_device(struct vfio_group *group,
+ struct device *dev)
+{
+ if (group->mdev_group)
+ return vfio_mdev_get_iommu_device(dev);
+ else
+ return dev;
+}
+
static int vfio_dev_bind_gpasid_fn(struct device *dev, void *data)
{
struct domain_capsule *dc = (struct domain_capsule *)data;
unsigned long arg = *(unsigned long *)dc->data;
+ struct device *iommu_device;
+
+ iommu_device = vfio_get_iommu_device(dc->group, dev);
+ if (!iommu_device)
+ return -EINVAL;
- return iommu_uapi_sva_bind_gpasid(dc->domain, dev,
+ return iommu_uapi_sva_bind_gpasid(dc->domain, iommu_device,
(void __user *)arg);
}
static int vfio_dev_unbind_gpasid_fn(struct device *dev, void *data)
{
struct domain_capsule *dc = (struct domain_capsule *)data;
+ struct device *iommu_device;
+
+ iommu_device = vfio_get_iommu_device(dc->group, dev);
+ if (!iommu_device)
+ return -EINVAL;
/*
* dc->user is a toggle for the unbind operation. When user
@@ -2659,12 +2678,12 @@ static int vfio_dev_unbind_gpasid_fn(struct device *dev, void *data)
if (dc->user) {
unsigned long arg = *(unsigned long *)dc->data;
- iommu_uapi_sva_unbind_gpasid(dc->domain,
- dev, (void __user *)arg);
+ iommu_uapi_sva_unbind_gpasid(dc->domain, iommu_device,
+ (void __user *)arg);
} else {
ioasid_t pasid = *(ioasid_t *)dc->data;
- iommu_sva_unbind_gpasid(dc->domain, dev, pasid);
+ iommu_sva_unbind_gpasid(dc->domain, iommu_device, pasid);
}
return 0;
}
@@ -3295,8 +3314,14 @@ static int vfio_dev_cache_invalidate_fn(struct device *dev, void *data)
{
struct domain_capsule *dc = (struct domain_capsule *)data;
unsigned long arg = *(unsigned long *)dc->data;
+ struct device *iommu_device;
+
+ iommu_device = vfio_get_iommu_device(dc->group, dev);
+ if (!iommu_device)
+ return -EINVAL;
- iommu_uapi_cache_invalidate(dc->domain, dev, (void __user *)arg);
+ iommu_uapi_cache_invalidate(dc->domain, iommu_device,
+ (void __user *)arg);
return 0;
}