@@ -16,6 +16,7 @@
#include <linux/uuid.h>
#include <linux/sysfs.h>
#include <linux/mdev.h>
+#include <linux/iommu.h>
#include "mdev_private.h"
@@ -392,7 +393,13 @@ int mdev_device_remove(struct device *dev, bool force_remove)
static int __init mdev_init(void)
{
- return mdev_bus_register();
+ int ret;
+
+ ret = mdev_bus_register();
+ if (!ret)
+ iommu_set_bus(&mdev_bus_type);
+
+ return ret;
}
static void __exit mdev_exit(void)
@@ -21,6 +21,13 @@ static int mdev_attach_iommu(struct mdev_device *mdev)
int ret;
struct iommu_group *group;
+ /*
+ * If iommu_ops is set for bus, add_device() will allocate
+ * a group and add the device in the group.
+ */
+ if (iommu_present(mdev->dev.bus))
+ return 0;
+
group = iommu_group_alloc();
if (IS_ERR(group))
return PTR_ERR(group);
@@ -36,6 +43,9 @@ static int mdev_attach_iommu(struct mdev_device *mdev)
static void mdev_detach_iommu(struct mdev_device *mdev)
{
+ if (iommu_present(mdev->dev.bus))
+ return;
+
iommu_group_remove_device(&mdev->dev);
dev_info(&mdev->dev, "MDEV: detaching iommu\n");
}