diff mbox series

[V12,1/6] mdev: make mdev bus agnostic

Message ID 20191118061703.8669-2-jasowang@redhat.com (mailing list archive)
State New, archived
Headers show
Series mdev based hardware virtio offloading support | expand

Commit Message

Jason Wang Nov. 18, 2019, 6:16 a.m. UTC
Current mdev is tied to a VFIO specific "mdev" bus. This prevent mdev
from being used by other types of API/buses. So this patch tries to make
mdev bus agnostic through making a mdev core a thin module:

- decouple VFIO bus specific bits from mdev_core.c to mdev_vfio.c and
  introduce mdev_vfio module
- require to specify the type of bus when registering mdev device and
  mdev driver

With those modifications mdev become a generic module that could be
used by multiple types of virtual buses and devices.

Signed-off-by: Jason Wang <jasowang@redhat.com>
---
 .../driver-api/vfio-mediated-device.rst       |  68 ++++++------
 MAINTAINERS                                   |   1 +
 drivers/gpu/drm/i915/gvt/kvmgt.c              |  10 +-
 drivers/s390/cio/vfio_ccw_ops.c               |   6 +-
 drivers/s390/crypto/vfio_ap_ops.c             |  21 ++--
 drivers/s390/crypto/vfio_ap_private.h         |   2 +-
 drivers/vfio/mdev/Kconfig                     |  17 ++-
 drivers/vfio/mdev/Makefile                    |   4 +-
 drivers/vfio/mdev/mdev_core.c                 | 104 +++++++++++++-----
 drivers/vfio/mdev/mdev_driver.c               |  29 ++---
 drivers/vfio/mdev/mdev_private.h              |  13 ++-
 drivers/vfio/mdev/mdev_vfio.c                 |  48 ++++++++
 drivers/vfio/mdev/vfio_mdev.c                 |   5 +-
 drivers/vfio/vfio_iommu_type1.c               |   6 +-
 include/linux/mdev.h                          |  16 ++-
 include/linux/mdev_vfio.h                     |  25 +++++
 samples/vfio-mdev/mbochs.c                    |   8 +-
 samples/vfio-mdev/mdpy.c                      |   8 +-
 samples/vfio-mdev/mtty.c                      |   8 +-
 19 files changed, 270 insertions(+), 129 deletions(-)
 create mode 100644 drivers/vfio/mdev/mdev_vfio.c
 create mode 100644 include/linux/mdev_vfio.h
diff mbox series

Patch

diff --git a/Documentation/driver-api/vfio-mediated-device.rst b/Documentation/driver-api/vfio-mediated-device.rst
index 25eb7d5b834b..1887d27a565e 100644
--- a/Documentation/driver-api/vfio-mediated-device.rst
+++ b/Documentation/driver-api/vfio-mediated-device.rst
@@ -49,35 +49,37 @@  devices as examples, as these devices are the first devices to use this module::
 
      +---------------+
      |               |
-     | +-----------+ |  mdev_register_driver() +--------------+
-     | |           | +<------------------------+              |
-     | |  mdev     | |                         |              |
-     | |  bus      | +------------------------>+ vfio_mdev.ko |<-> VFIO user
-     | |  driver   | |     probe()/remove()    |              |    APIs
-     | |           | |                         +--------------+
-     | +-----------+ |
+     |   MDEV CORE   |  mdev_register_driver() +--------------+
+     |    MODULE     +<------------------------+              |
+     |    mdev.ko    |                         |              |
+     |               +------------------------>+ vfio_mdev.ko |<-> VFIO user
+     |               |     probe()/remove()    |              |    APIs
+     |               |                         +--------------+
+     +---+-------+---+
+         |      /|\
+         |       |
+callbacks|       | mdev_register_device()
+         |       | mdev_register_bus()
+        \|/      |
+     +---+-------+---+
+     |               |  mdev_vfio_register_device() +--------------+
+     |               +<-----------------------------+              |
+     |               |                              |  nvidia.ko   |<-> physical
+     |               +----------------------------->+              |    device
+     |   MDEV VFIO   |        callbacks             +--------------+
+     |   Physical    |
+     |    device     |  mdev_vfio_register_device() +--------------+
+     |   interface   |<-----------------------------+              |
+     |               |                              |  i915.ko     |<-> physical
+     | mdev_vfio.ko  +----------------------------->+              |    device
+     |               |        callbacks             +--------------+
+     |               |
+     |               |  mdev_vfio_register_device() +--------------+
+     |               +<-----------------------------+              |
+     |               |                              | ccw_device.ko|<-> physical
+     |               +----------------------------->+              |    device
+     |               |        callbacks             +--------------+
      |               |
-     |  MDEV CORE    |
-     |   MODULE      |
-     |   mdev.ko     |
-     | +-----------+ |  mdev_register_device() +--------------+
-     | |           | +<------------------------+              |
-     | |           | |                         |  nvidia.ko   |<-> physical
-     | |           | +------------------------>+              |    device
-     | |           | |        callbacks        +--------------+
-     | | Physical  | |
-     | |  device   | |  mdev_register_device() +--------------+
-     | | interface | |<------------------------+              |
-     | |           | |                         |  i915.ko     |<-> physical
-     | |           | +------------------------>+              |    device
-     | |           | |        callbacks        +--------------+
-     | |           | |
-     | |           | |  mdev_register_device() +--------------+
-     | |           | +<------------------------+              |
-     | |           | |                         | ccw_device.ko|<-> physical
-     | |           | +------------------------>+              |    device
-     | |           | |        callbacks        +--------------+
-     | +-----------+ |
      +---------------+
 
 
@@ -116,7 +118,8 @@  to register and unregister itself with the core driver:
 * Register::
 
     extern int  mdev_register_driver(struct mdev_driver *drv,
-				   struct module *owner);
+                                     struct module *owner,
+                                     struct bus_type *bus);
 
 * Unregister::
 
@@ -159,11 +162,12 @@  The callbacks in the mdev_parent_ops structure are as follows:
 * write: write emulation callback
 * mmap: mmap emulation callback
 
-A driver should use the mdev_parent_ops structure in the function call to
-register itself with the mdev core driver::
+A driver should use the mdev_parent_ops structure and bus type in the
+function call to register itself with the mdev core driver::
 
 	extern int  mdev_register_device(struct device *dev,
-	                                 const struct mdev_parent_ops *ops);
+	                                 const struct mdev_parent_ops *ops,
+                                         struct bus_type *bus);
 
 However, the mdev_parent_ops structure is not required in the function call
 that a driver should use to unregister itself with the mdev core driver::
diff --git a/MAINTAINERS b/MAINTAINERS
index cba1095547fd..d335949240dc 100644
--- a/MAINTAINERS
+++ b/MAINTAINERS
@@ -17121,6 +17121,7 @@  S:	Maintained
 F:	Documentation/driver-api/vfio-mediated-device.rst
 F:	drivers/vfio/mdev/
 F:	include/linux/mdev.h
+F:	include/linux/mdev_vfio.h
 F:	samples/vfio-mdev/
 
 VFIO PLATFORM DRIVER
diff --git a/drivers/gpu/drm/i915/gvt/kvmgt.c b/drivers/gpu/drm/i915/gvt/kvmgt.c
index 343d79c1cb7e..8c02572c9b42 100644
--- a/drivers/gpu/drm/i915/gvt/kvmgt.c
+++ b/drivers/gpu/drm/i915/gvt/kvmgt.c
@@ -41,7 +41,7 @@ 
 #include <linux/uuid.h>
 #include <linux/kvm_host.h>
 #include <linux/vfio.h>
-#include <linux/mdev.h>
+#include <linux/mdev_vfio.h>
 #include <linux/debugfs.h>
 
 #include <linux/nospec.h>
@@ -1554,7 +1554,7 @@  static ssize_t
 vgpu_id_show(struct device *dev, struct device_attribute *attr,
 	     char *buf)
 {
-	struct mdev_device *mdev = mdev_from_dev(dev);
+	struct mdev_device *mdev = vfio_mdev_from_dev(dev);
 
 	if (mdev) {
 		struct intel_vgpu *vgpu = (struct intel_vgpu *)
@@ -1568,7 +1568,7 @@  static ssize_t
 hw_id_show(struct device *dev, struct device_attribute *attr,
 	   char *buf)
 {
-	struct mdev_device *mdev = mdev_from_dev(dev);
+	struct mdev_device *mdev = vfio_mdev_from_dev(dev);
 
 	if (mdev) {
 		struct intel_vgpu *vgpu = (struct intel_vgpu *)
@@ -1623,12 +1623,12 @@  static int kvmgt_host_init(struct device *dev, void *gvt, const void *ops)
 		return -EFAULT;
 	intel_vgpu_ops.supported_type_groups = kvm_vgpu_type_groups;
 
-	return mdev_register_device(dev, &intel_vgpu_ops);
+	return mdev_vfio_register_device(dev, &intel_vgpu_ops);
 }
 
 static void kvmgt_host_exit(struct device *dev)
 {
-	mdev_unregister_device(dev);
+	mdev_vfio_unregister_device(dev);
 }
 
 static int kvmgt_page_track_add(unsigned long handle, u64 gfn)
diff --git a/drivers/s390/cio/vfio_ccw_ops.c b/drivers/s390/cio/vfio_ccw_ops.c
index f0d71ab77c50..791b8b0eb027 100644
--- a/drivers/s390/cio/vfio_ccw_ops.c
+++ b/drivers/s390/cio/vfio_ccw_ops.c
@@ -11,7 +11,7 @@ 
  */
 
 #include <linux/vfio.h>
-#include <linux/mdev.h>
+#include <linux/mdev_vfio.h>
 #include <linux/nospec.h>
 #include <linux/slab.h>
 
@@ -588,10 +588,10 @@  static const struct mdev_parent_ops vfio_ccw_mdev_ops = {
 
 int vfio_ccw_mdev_reg(struct subchannel *sch)
 {
-	return mdev_register_device(&sch->dev, &vfio_ccw_mdev_ops);
+	return mdev_vfio_register_device(&sch->dev, &vfio_ccw_mdev_ops);
 }
 
 void vfio_ccw_mdev_unreg(struct subchannel *sch)
 {
-	mdev_unregister_device(&sch->dev);
+	mdev_vfio_unregister_device(&sch->dev);
 }
diff --git a/drivers/s390/crypto/vfio_ap_ops.c b/drivers/s390/crypto/vfio_ap_ops.c
index 5c0f53c6dde7..78048e670374 100644
--- a/drivers/s390/crypto/vfio_ap_ops.c
+++ b/drivers/s390/crypto/vfio_ap_ops.c
@@ -602,7 +602,7 @@  static ssize_t assign_adapter_store(struct device *dev,
 {
 	int ret;
 	unsigned long apid;
-	struct mdev_device *mdev = mdev_from_dev(dev);
+	struct mdev_device *mdev = vfio_mdev_from_dev(dev);
 	struct ap_matrix_mdev *matrix_mdev = mdev_get_drvdata(mdev);
 
 	/* If the guest is running, disallow assignment of adapter */
@@ -668,7 +668,7 @@  static ssize_t unassign_adapter_store(struct device *dev,
 {
 	int ret;
 	unsigned long apid;
-	struct mdev_device *mdev = mdev_from_dev(dev);
+	struct mdev_device *mdev = vfio_mdev_from_dev(dev);
 	struct ap_matrix_mdev *matrix_mdev = mdev_get_drvdata(mdev);
 
 	/* If the guest is running, disallow un-assignment of adapter */
@@ -748,7 +748,7 @@  static ssize_t assign_domain_store(struct device *dev,
 {
 	int ret;
 	unsigned long apqi;
-	struct mdev_device *mdev = mdev_from_dev(dev);
+	struct mdev_device *mdev = vfio_mdev_from_dev(dev);
 	struct ap_matrix_mdev *matrix_mdev = mdev_get_drvdata(mdev);
 	unsigned long max_apqi = matrix_mdev->matrix.aqm_max;
 
@@ -810,7 +810,7 @@  static ssize_t unassign_domain_store(struct device *dev,
 {
 	int ret;
 	unsigned long apqi;
-	struct mdev_device *mdev = mdev_from_dev(dev);
+	struct mdev_device *mdev = vfio_mdev_from_dev(dev);
 	struct ap_matrix_mdev *matrix_mdev = mdev_get_drvdata(mdev);
 
 	/* If the guest is running, disallow un-assignment of domain */
@@ -854,7 +854,7 @@  static ssize_t assign_control_domain_store(struct device *dev,
 {
 	int ret;
 	unsigned long id;
-	struct mdev_device *mdev = mdev_from_dev(dev);
+	struct mdev_device *mdev = vfio_mdev_from_dev(dev);
 	struct ap_matrix_mdev *matrix_mdev = mdev_get_drvdata(mdev);
 
 	/* If the guest is running, disallow assignment of control domain */
@@ -903,7 +903,7 @@  static ssize_t unassign_control_domain_store(struct device *dev,
 {
 	int ret;
 	unsigned long domid;
-	struct mdev_device *mdev = mdev_from_dev(dev);
+	struct mdev_device *mdev = vfio_mdev_from_dev(dev);
 	struct ap_matrix_mdev *matrix_mdev = mdev_get_drvdata(mdev);
 	unsigned long max_domid =  matrix_mdev->matrix.adm_max;
 
@@ -933,7 +933,7 @@  static ssize_t control_domains_show(struct device *dev,
 	int nchars = 0;
 	int n;
 	char *bufpos = buf;
-	struct mdev_device *mdev = mdev_from_dev(dev);
+	struct mdev_device *mdev = vfio_mdev_from_dev(dev);
 	struct ap_matrix_mdev *matrix_mdev = mdev_get_drvdata(mdev);
 	unsigned long max_domid = matrix_mdev->matrix.adm_max;
 
@@ -952,7 +952,7 @@  static DEVICE_ATTR_RO(control_domains);
 static ssize_t matrix_show(struct device *dev, struct device_attribute *attr,
 			   char *buf)
 {
-	struct mdev_device *mdev = mdev_from_dev(dev);
+	struct mdev_device *mdev = vfio_mdev_from_dev(dev);
 	struct ap_matrix_mdev *matrix_mdev = mdev_get_drvdata(mdev);
 	char *bufpos = buf;
 	unsigned long apid;
@@ -1295,10 +1295,11 @@  int vfio_ap_mdev_register(void)
 {
 	atomic_set(&matrix_dev->available_instances, MAX_ZDEV_ENTRIES_EXT);
 
-	return mdev_register_device(&matrix_dev->device, &vfio_ap_matrix_ops);
+	return mdev_vfio_register_device(&matrix_dev->device,
+					 &vfio_ap_matrix_ops);
 }
 
 void vfio_ap_mdev_unregister(void)
 {
-	mdev_unregister_device(&matrix_dev->device);
+	mdev_vfio_unregister_device(&matrix_dev->device);
 }
diff --git a/drivers/s390/crypto/vfio_ap_private.h b/drivers/s390/crypto/vfio_ap_private.h
index f46dde56b464..4e37e0e3433a 100644
--- a/drivers/s390/crypto/vfio_ap_private.h
+++ b/drivers/s390/crypto/vfio_ap_private.h
@@ -14,7 +14,7 @@ 
 
 #include <linux/types.h>
 #include <linux/device.h>
-#include <linux/mdev.h>
+#include <linux/mdev_vfio.h>
 #include <linux/delay.h>
 #include <linux/mutex.h>
 #include <linux/kvm_host.h>
diff --git a/drivers/vfio/mdev/Kconfig b/drivers/vfio/mdev/Kconfig
index 5da27f2100f9..2e07ca915a96 100644
--- a/drivers/vfio/mdev/Kconfig
+++ b/drivers/vfio/mdev/Kconfig
@@ -1,15 +1,24 @@ 
-# SPDX-License-Identifier: GPL-2.0-only
 
-config VFIO_MDEV
+config MDEV
 	tristate "Mediated device driver framework"
-	depends on VFIO
 	default n
 	help
 	  Provides a framework to virtualize devices.
-	  See Documentation/driver-api/vfio-mediated-device.rst for more details.
 
 	  If you don't know what do here, say N.
 
+config VFIO_MDEV
+	tristate "VFIO Mediated device driver"
+        depends on VFIO && MDEV
+        default n
+	help
+	  Proivdes a mediated BUS for userspace driver through VFIO
+	  framework. See Documentation/vfio-mediated-device.txt for
+	  more details.
+
+	  If you don't know what do here, say N.
+
+
 config VFIO_MDEV_DEVICE
 	tristate "VFIO driver for Mediated devices"
 	depends on VFIO && VFIO_MDEV
diff --git a/drivers/vfio/mdev/Makefile b/drivers/vfio/mdev/Makefile
index 101516fdf375..e9675501271a 100644
--- a/drivers/vfio/mdev/Makefile
+++ b/drivers/vfio/mdev/Makefile
@@ -1,6 +1,6 @@ 
-# SPDX-License-Identifier: GPL-2.0-only
 
 mdev-y := mdev_core.o mdev_sysfs.o mdev_driver.o
 
-obj-$(CONFIG_VFIO_MDEV) += mdev.o
+obj-$(CONFIG_MDEV) += mdev.o
+obj-$(CONFIG_VFIO_MDEV) += mdev_vfio.o
 obj-$(CONFIG_VFIO_MDEV_DEVICE) += vfio_mdev.o
diff --git a/drivers/vfio/mdev/mdev_core.c b/drivers/vfio/mdev/mdev_core.c
index b558d4cfd082..e1272a40c521 100644
--- a/drivers/vfio/mdev/mdev_core.c
+++ b/drivers/vfio/mdev/mdev_core.c
@@ -22,11 +22,13 @@ 
 
 static LIST_HEAD(parent_list);
 static DEFINE_MUTEX(parent_list_lock);
-static struct class_compat *mdev_bus_compat_class;
 
 static LIST_HEAD(mdev_list);
 static DEFINE_MUTEX(mdev_list_lock);
 
+static LIST_HEAD(class_compat_list);
+static DEFINE_MUTEX(compat_list_lock);
+
 struct device *mdev_parent_dev(struct mdev_device *mdev)
 {
 	return mdev->parent->dev;
@@ -51,9 +53,9 @@  struct device *mdev_dev(struct mdev_device *mdev)
 }
 EXPORT_SYMBOL(mdev_dev);
 
-struct mdev_device *mdev_from_dev(struct device *dev)
+struct mdev_device *mdev_from_dev(struct device *dev, struct bus_type *bus)
 {
-	return dev_is_mdev(dev) ? to_mdev_device(dev) : NULL;
+	return dev_is_mdev(dev, bus) ? to_mdev_device(dev) : NULL;
 }
 EXPORT_SYMBOL(mdev_from_dev);
 
@@ -122,7 +124,9 @@  static void mdev_device_remove_common(struct mdev_device *mdev)
 
 static int mdev_device_remove_cb(struct device *dev, void *data)
 {
-	if (dev_is_mdev(dev)) {
+	struct bus_type *bus = data;
+
+	if (dev_is_mdev(dev, bus)) {
 		struct mdev_device *mdev;
 
 		mdev = to_mdev_device(dev);
@@ -131,6 +135,41 @@  static int mdev_device_remove_cb(struct device *dev, void *data)
 	return 0;
 }
 
+static struct mdev_class_compat *get_class_compat(struct bus_type *bus)
+{
+	struct mdev_class_compat *mdev_class_compat;
+
+	list_for_each_entry(mdev_class_compat, &class_compat_list, next) {
+		if (mdev_class_compat->bus == bus)
+			return mdev_class_compat;
+	}
+
+	return NULL;
+}
+
+static struct class_compat *mdev_alloc_class_compat(struct bus_type *bus)
+{
+	struct mdev_class_compat *mdev_class_compat = get_class_compat(bus);
+	char class_name[64];
+
+	if (mdev_class_compat)
+		return mdev_class_compat->class_compat;
+
+	mdev_class_compat = kmalloc(sizeof(*mdev_class_compat), GFP_KERNEL);
+	if (!mdev_class_compat)
+		return NULL;
+	snprintf(class_name, 64, "%s_bus", bus->name);
+	mdev_class_compat->class_compat = class_compat_register(class_name);
+	if (!mdev_class_compat->class_compat) {
+		kfree(mdev_class_compat);
+		return NULL;
+	}
+	mdev_class_compat->bus = bus;
+	list_add(&mdev_class_compat->next, &class_compat_list);
+
+	return mdev_class_compat->class_compat;
+}
+
 /*
  * mdev_register_device : Register a device
  * @dev: device structure representing parent device.
@@ -139,12 +178,14 @@  static int mdev_device_remove_cb(struct device *dev, void *data)
  * Add device to list of registered parent devices.
  * Returns a negative value on error, otherwise 0.
  */
-int mdev_register_device(struct device *dev, const struct mdev_parent_ops *ops)
+int mdev_register_device(struct device *dev, const struct mdev_parent_ops *ops,
+			 struct bus_type *bus)
 {
 	int ret;
 	struct mdev_parent *parent;
 	char *env_string = "MDEV_STATE=registered";
 	char *envp[] = { env_string, NULL };
+	struct class_compat *class_compat;
 
 	/* check for mandatory ops */
 	if (!ops || !ops->create || !ops->remove || !ops->supported_type_groups)
@@ -175,20 +216,21 @@  int mdev_register_device(struct device *dev, const struct mdev_parent_ops *ops)
 
 	parent->dev = dev;
 	parent->ops = ops;
+	parent->bus = bus;
 
-	if (!mdev_bus_compat_class) {
-		mdev_bus_compat_class = class_compat_register("mdev_bus");
-		if (!mdev_bus_compat_class) {
-			ret = -ENOMEM;
-			goto add_dev_err;
-		}
+	mutex_lock(&compat_list_lock);
+	class_compat = mdev_alloc_class_compat(bus);
+	mutex_unlock(&compat_list_lock);
+	if (!class_compat) {
+		ret = -ENOMEM;
+		goto add_dev_err;
 	}
 
 	ret = parent_create_sysfs_files(parent);
 	if (ret)
 		goto add_dev_err;
 
-	ret = class_compat_create_link(mdev_bus_compat_class, dev, NULL);
+	ret = class_compat_create_link(class_compat, dev, NULL);
 	if (ret)
 		dev_warn(dev, "Failed to create compatibility class link\n");
 
@@ -223,6 +265,7 @@  void mdev_unregister_device(struct device *dev)
 	struct mdev_parent *parent;
 	char *env_string = "MDEV_STATE=unregistered";
 	char *envp[] = { env_string, NULL };
+	struct mdev_class_compat *mdev_class_compat;
 
 	mutex_lock(&parent_list_lock);
 	parent = __find_parent_device(dev);
@@ -238,9 +281,13 @@  void mdev_unregister_device(struct device *dev)
 
 	down_write(&parent->unreg_sem);
 
-	class_compat_remove_link(mdev_bus_compat_class, dev, NULL);
+	mutex_lock(&compat_list_lock);
+	mdev_class_compat = get_class_compat(parent->bus);
+	WARN_ON(!mdev_class_compat);
+	class_compat_remove_link(mdev_class_compat->class_compat, dev, NULL);
+	mutex_unlock(&compat_list_lock);
 
-	device_for_each_child(dev, NULL, mdev_device_remove_cb);
+	device_for_each_child(dev, parent->bus, mdev_device_remove_cb);
 
 	parent_remove_sysfs_files(parent);
 	up_write(&parent->unreg_sem);
@@ -314,7 +361,7 @@  int mdev_device_create(struct kobject *kobj,
 
 	device_initialize(&mdev->dev);
 	mdev->dev.parent  = dev;
-	mdev->dev.bus     = &mdev_bus_type;
+	mdev->dev.bus     = parent->bus;
 	mdev->dev.release = mdev_device_release;
 	dev_set_name(&mdev->dev, "%pUl", uuid);
 	mdev->dev.groups = parent->ops->mdev_attr_groups;
@@ -404,24 +451,29 @@  struct device *mdev_get_iommu_device(struct device *dev)
 }
 EXPORT_SYMBOL(mdev_get_iommu_device);
 
-static int __init mdev_init(void)
+int mdev_register_bus(struct bus_type *bus)
 {
-	return mdev_bus_register();
+	return bus_register(bus);
 }
+EXPORT_SYMBOL(mdev_register_bus);
 
-static void __exit mdev_exit(void)
+void mdev_unregister_bus(struct bus_type *bus)
 {
-	if (mdev_bus_compat_class)
-		class_compat_unregister(mdev_bus_compat_class);
-
-	mdev_bus_unregister();
+	struct mdev_class_compat *mdev_class_compat;
+
+	mutex_lock(&compat_list_lock);
+	mdev_class_compat = get_class_compat(bus);
+	if (mdev_class_compat) {
+		list_del(&mdev_class_compat->next);
+		class_compat_unregister(mdev_class_compat->class_compat);
+		kfree(mdev_class_compat);
+	}
+	bus_unregister(bus);
+	mutex_unlock(&compat_list_lock);
 }
-
-module_init(mdev_init)
-module_exit(mdev_exit)
+EXPORT_SYMBOL(mdev_unregister_bus);
 
 MODULE_VERSION(DRIVER_VERSION);
 MODULE_LICENSE("GPL v2");
 MODULE_AUTHOR(DRIVER_AUTHOR);
 MODULE_DESCRIPTION(DRIVER_DESC);
-MODULE_SOFTDEP("post: vfio_mdev");
diff --git a/drivers/vfio/mdev/mdev_driver.c b/drivers/vfio/mdev/mdev_driver.c
index 0d3223aee20b..c3a2ac023712 100644
--- a/drivers/vfio/mdev/mdev_driver.c
+++ b/drivers/vfio/mdev/mdev_driver.c
@@ -10,6 +10,7 @@ 
 #include <linux/device.h>
 #include <linux/iommu.h>
 #include <linux/mdev.h>
+#include <linux/mdev_vfio.h>
 
 #include "mdev_private.h"
 
@@ -37,7 +38,7 @@  static void mdev_detach_iommu(struct mdev_device *mdev)
 	dev_info(&mdev->dev, "MDEV: detaching iommu\n");
 }
 
-static int mdev_probe(struct device *dev)
+int mdev_probe(struct device *dev)
 {
 	struct mdev_driver *drv = to_mdev_driver(dev->driver);
 	struct mdev_device *mdev = to_mdev_device(dev);
@@ -55,8 +56,9 @@  static int mdev_probe(struct device *dev)
 
 	return ret;
 }
+EXPORT_SYMBOL(mdev_probe);
 
-static int mdev_remove(struct device *dev)
+int mdev_remove(struct device *dev)
 {
 	struct mdev_driver *drv = to_mdev_driver(dev->driver);
 	struct mdev_device *mdev = to_mdev_device(dev);
@@ -68,26 +70,22 @@  static int mdev_remove(struct device *dev)
 
 	return 0;
 }
-
-struct bus_type mdev_bus_type = {
-	.name		= "mdev",
-	.probe		= mdev_probe,
-	.remove		= mdev_remove,
-};
-EXPORT_SYMBOL_GPL(mdev_bus_type);
+EXPORT_SYMBOL(mdev_remove);
 
 /**
  * mdev_register_driver - register a new MDEV driver
  * @drv: the driver to register
  * @owner: module owner of driver to be registered
+ * @bus: but that the driver wants to attach
  *
  * Returns a negative value on error, otherwise 0.
  **/
-int mdev_register_driver(struct mdev_driver *drv, struct module *owner)
+int mdev_register_driver(struct mdev_driver *drv, struct module *owner,
+			 struct bus_type *bus)
 {
 	/* initialize common driver fields */
 	drv->driver.name = drv->name;
-	drv->driver.bus = &mdev_bus_type;
+	drv->driver.bus = bus;
 	drv->driver.owner = owner;
 
 	/* register with core */
@@ -105,12 +103,3 @@  void mdev_unregister_driver(struct mdev_driver *drv)
 }
 EXPORT_SYMBOL(mdev_unregister_driver);
 
-int mdev_bus_register(void)
-{
-	return bus_register(&mdev_bus_type);
-}
-
-void mdev_bus_unregister(void)
-{
-	bus_unregister(&mdev_bus_type);
-}
diff --git a/drivers/vfio/mdev/mdev_private.h b/drivers/vfio/mdev/mdev_private.h
index 7d922950caaf..298d7a0f493a 100644
--- a/drivers/vfio/mdev/mdev_private.h
+++ b/drivers/vfio/mdev/mdev_private.h
@@ -10,12 +10,10 @@ 
 #ifndef MDEV_PRIVATE_H
 #define MDEV_PRIVATE_H
 
-int  mdev_bus_register(void);
-void mdev_bus_unregister(void);
-
 struct mdev_parent {
 	struct device *dev;
 	const struct mdev_parent_ops *ops;
+	struct bus_type *bus;
 	struct kref ref;
 	struct list_head next;
 	struct kset *mdev_types_kset;
@@ -35,8 +33,15 @@  struct mdev_device {
 	bool active;
 };
 
+struct mdev_class_compat {
+	struct class_compat *class_compat;
+	struct bus_type *bus;
+	struct list_head next;
+};
+
+
 #define to_mdev_device(dev)	container_of(dev, struct mdev_device, dev)
-#define dev_is_mdev(d)		((d)->bus == &mdev_bus_type)
+#define dev_is_mdev(d, bus)	((d)->bus == bus)
 
 struct mdev_type {
 	struct kobject kobj;
diff --git a/drivers/vfio/mdev/mdev_vfio.c b/drivers/vfio/mdev/mdev_vfio.c
new file mode 100644
index 000000000000..f9d1191b9982
--- /dev/null
+++ b/drivers/vfio/mdev/mdev_vfio.c
@@ -0,0 +1,48 @@ 
+// SPDX-License-Identifier: GPL-2.0-only
+#include <linux/module.h>
+#include <linux/uuid.h>
+#include <linux/device.h>
+#include <linux/mdev_vfio.h>
+
+#define DRIVER_VERSION		"0.1"
+#define DRIVER_AUTHOR		"Jason Wang"
+#define DRIVER_DESC		"Mediated VFIO bus"
+
+struct bus_type mdev_vfio_bus_type = {
+	.name		= "mdev",
+	.probe		= mdev_probe,
+	.remove		= mdev_remove,
+};
+EXPORT_SYMBOL(mdev_vfio_bus_type);
+
+static int __init mdev_init(void)
+{
+	return mdev_register_bus(&mdev_vfio_bus_type);
+}
+
+static void __exit mdev_exit(void)
+{
+	mdev_unregister_bus(&mdev_vfio_bus_type);
+}
+
+int mdev_vfio_register_device(struct device *dev,
+			      const struct mdev_parent_ops *ops)
+{
+	return mdev_register_device(dev, ops, &mdev_vfio_bus_type);
+}
+EXPORT_SYMBOL(mdev_vfio_register_device);
+
+void mdev_vfio_unregister_device(struct device *dev)
+{
+	return mdev_unregister_device(dev);
+}
+EXPORT_SYMBOL(mdev_vfio_unregister_device);
+
+module_init(mdev_init)
+module_exit(mdev_exit)
+
+MODULE_VERSION(DRIVER_VERSION);
+MODULE_LICENSE("GPL v2");
+MODULE_AUTHOR(DRIVER_AUTHOR);
+MODULE_DESCRIPTION(DRIVER_DESC);
+MODULE_SOFTDEP("post: vfio_mdev");
diff --git a/drivers/vfio/mdev/vfio_mdev.c b/drivers/vfio/mdev/vfio_mdev.c
index 30964a4e0a28..16e9ebe30d4a 100644
--- a/drivers/vfio/mdev/vfio_mdev.c
+++ b/drivers/vfio/mdev/vfio_mdev.c
@@ -13,7 +13,7 @@ 
 #include <linux/kernel.h>
 #include <linux/slab.h>
 #include <linux/vfio.h>
-#include <linux/mdev.h>
+#include <linux/mdev_vfio.h>
 
 #include "mdev_private.h"
 
@@ -128,7 +128,8 @@  static struct mdev_driver vfio_mdev_driver = {
 
 static int __init vfio_mdev_init(void)
 {
-	return mdev_register_driver(&vfio_mdev_driver, THIS_MODULE);
+	return mdev_register_driver(&vfio_mdev_driver, THIS_MODULE,
+				    &mdev_vfio_bus_type);
 }
 
 static void __exit vfio_mdev_exit(void)
diff --git a/drivers/vfio/vfio_iommu_type1.c b/drivers/vfio/vfio_iommu_type1.c
index d864277ea16f..f35523f822eb 100644
--- a/drivers/vfio/vfio_iommu_type1.c
+++ b/drivers/vfio/vfio_iommu_type1.c
@@ -34,7 +34,7 @@ 
 #include <linux/uaccess.h>
 #include <linux/vfio.h>
 #include <linux/workqueue.h>
-#include <linux/mdev.h>
+#include <linux/mdev_vfio.h>
 #include <linux/notifier.h>
 #include <linux/dma-iommu.h>
 #include <linux/irqdomain.h>
@@ -1405,10 +1405,10 @@  static bool vfio_bus_is_mdev(struct bus_type *bus)
 	struct bus_type *mdev_bus;
 	bool ret = false;
 
-	mdev_bus = symbol_get(mdev_bus_type);
+	mdev_bus = symbol_get(mdev_vfio_bus_type);
 	if (mdev_bus) {
 		ret = (bus == mdev_bus);
-		symbol_put(mdev_bus_type);
+		symbol_put(mdev_vfio_bus_type);
 	}
 
 	return ret;
diff --git a/include/linux/mdev.h b/include/linux/mdev.h
index 0ce30ca78db0..ee2410246b3c 100644
--- a/include/linux/mdev.h
+++ b/include/linux/mdev.h
@@ -133,16 +133,22 @@  void *mdev_get_drvdata(struct mdev_device *mdev);
 void mdev_set_drvdata(struct mdev_device *mdev, void *data);
 const guid_t *mdev_uuid(struct mdev_device *mdev);
 
-extern struct bus_type mdev_bus_type;
-
-int mdev_register_device(struct device *dev, const struct mdev_parent_ops *ops);
+int mdev_register_device(struct device *dev, const struct mdev_parent_ops *ops,
+			 struct bus_type *bus);
 void mdev_unregister_device(struct device *dev);
 
-int mdev_register_driver(struct mdev_driver *drv, struct module *owner);
+int mdev_register_driver(struct mdev_driver *drv, struct module *owner,
+			 struct bus_type *bus);
 void mdev_unregister_driver(struct mdev_driver *drv);
 
 struct device *mdev_parent_dev(struct mdev_device *mdev);
 struct device *mdev_dev(struct mdev_device *mdev);
-struct mdev_device *mdev_from_dev(struct device *dev);
+struct mdev_device *mdev_from_dev(struct device *dev, struct bus_type *bus);
+
+int mdev_probe(struct device *dev);
+int mdev_remove(struct device *dev);
+
+int mdev_register_bus(struct bus_type *bus);
+void mdev_unregister_bus(struct bus_type *bus);
 
 #endif /* MDEV_H */
diff --git a/include/linux/mdev_vfio.h b/include/linux/mdev_vfio.h
new file mode 100644
index 000000000000..446a7537e3fb
--- /dev/null
+++ b/include/linux/mdev_vfio.h
@@ -0,0 +1,25 @@ 
+/* SPDX-License-Identifier: GPL-2.0-only */
+/*
+ * VFIO Mediated device definition
+ *
+ * Copyright (c) 2019, Red Hat. All rights reserved.
+ *     Author: Jason Wang <jasowang@redhat.com>
+ */
+
+#ifndef MDEV_VFIO_H
+#define MDEV_VFIO_H
+
+#include <linux/mdev.h>
+
+extern struct bus_type mdev_vfio_bus_type;
+
+int mdev_vfio_register_device(struct device *dev,
+			      const struct mdev_parent_ops *ops);
+void mdev_vfio_unregister_device(struct device *dev);
+
+static inline struct mdev_device *vfio_mdev_from_dev(struct device *dev)
+{
+	return mdev_from_dev(dev, &mdev_vfio_bus_type);
+}
+
+#endif
diff --git a/samples/vfio-mdev/mbochs.c b/samples/vfio-mdev/mbochs.c
index ac5c8c17b1ff..f041d58324b1 100644
--- a/samples/vfio-mdev/mbochs.c
+++ b/samples/vfio-mdev/mbochs.c
@@ -29,7 +29,7 @@ 
 #include <linux/vfio.h>
 #include <linux/iommu.h>
 #include <linux/sysfs.h>
-#include <linux/mdev.h>
+#include <linux/mdev_vfio.h>
 #include <linux/pci.h>
 #include <linux/dma-buf.h>
 #include <linux/highmem.h>
@@ -1332,7 +1332,7 @@  static ssize_t
 memory_show(struct device *dev, struct device_attribute *attr,
 	    char *buf)
 {
-	struct mdev_device *mdev = mdev_from_dev(dev);
+	struct mdev_device *mdev = vfio_mdev_from_dev(dev);
 	struct mdev_state *mdev_state = mdev_get_drvdata(mdev);
 
 	return sprintf(buf, "%d MB\n", mdev_state->type->mbytes);
@@ -1468,7 +1468,7 @@  static int __init mbochs_dev_init(void)
 	if (ret)
 		goto failed2;
 
-	ret = mdev_register_device(&mbochs_dev, &mdev_fops);
+	ret = mdev_vfio_register_device(&mbochs_dev, &mdev_fops);
 	if (ret)
 		goto failed3;
 
@@ -1487,7 +1487,7 @@  static int __init mbochs_dev_init(void)
 static void __exit mbochs_dev_exit(void)
 {
 	mbochs_dev.bus = NULL;
-	mdev_unregister_device(&mbochs_dev);
+	mdev_vfio_unregister_device(&mbochs_dev);
 
 	device_unregister(&mbochs_dev);
 	cdev_del(&mbochs_cdev);
diff --git a/samples/vfio-mdev/mdpy.c b/samples/vfio-mdev/mdpy.c
index cc86bf6566e4..9c32fe3795ad 100644
--- a/samples/vfio-mdev/mdpy.c
+++ b/samples/vfio-mdev/mdpy.c
@@ -25,7 +25,7 @@ 
 #include <linux/vfio.h>
 #include <linux/iommu.h>
 #include <linux/sysfs.h>
-#include <linux/mdev.h>
+#include <linux/mdev_vfio.h>
 #include <linux/pci.h>
 #include <drm/drm_fourcc.h>
 #include "mdpy-defs.h"
@@ -639,7 +639,7 @@  static ssize_t
 resolution_show(struct device *dev, struct device_attribute *attr,
 		char *buf)
 {
-	struct mdev_device *mdev = mdev_from_dev(dev);
+	struct mdev_device *mdev = vfio_mdev_from_dev(dev);
 	struct mdev_state *mdev_state = mdev_get_drvdata(mdev);
 
 	return sprintf(buf, "%dx%d\n",
@@ -775,7 +775,7 @@  static int __init mdpy_dev_init(void)
 	if (ret)
 		goto failed2;
 
-	ret = mdev_register_device(&mdpy_dev, &mdev_fops);
+	ret = mdev_vfio_register_device(&mdpy_dev, &mdev_fops);
 	if (ret)
 		goto failed3;
 
@@ -794,7 +794,7 @@  static int __init mdpy_dev_init(void)
 static void __exit mdpy_dev_exit(void)
 {
 	mdpy_dev.bus = NULL;
-	mdev_unregister_device(&mdpy_dev);
+	mdev_vfio_unregister_device(&mdpy_dev);
 
 	device_unregister(&mdpy_dev);
 	cdev_del(&mdpy_cdev);
diff --git a/samples/vfio-mdev/mtty.c b/samples/vfio-mdev/mtty.c
index ce84a300a4da..6e4e6339e0f1 100644
--- a/samples/vfio-mdev/mtty.c
+++ b/samples/vfio-mdev/mtty.c
@@ -26,7 +26,7 @@ 
 #include <linux/sysfs.h>
 #include <linux/ctype.h>
 #include <linux/file.h>
-#include <linux/mdev.h>
+#include <linux/mdev_vfio.h>
 #include <linux/pci.h>
 #include <linux/serial.h>
 #include <uapi/linux/serial_reg.h>
@@ -1285,7 +1285,7 @@  static ssize_t
 sample_mdev_dev_show(struct device *dev, struct device_attribute *attr,
 		     char *buf)
 {
-	if (mdev_from_dev(dev))
+	if (vfio_mdev_from_dev(dev))
 		return sprintf(buf, "This is MDEV %s\n", dev_name(dev));
 
 	return sprintf(buf, "\n");
@@ -1445,7 +1445,7 @@  static int __init mtty_dev_init(void)
 	if (ret)
 		goto failed2;
 
-	ret = mdev_register_device(&mtty_dev.dev, &mdev_fops);
+	ret = mdev_vfio_register_device(&mtty_dev.dev, &mdev_fops);
 	if (ret)
 		goto failed3;
 
@@ -1471,7 +1471,7 @@  static int __init mtty_dev_init(void)
 static void __exit mtty_dev_exit(void)
 {
 	mtty_dev.dev.bus = NULL;
-	mdev_unregister_device(&mtty_dev.dev);
+	mdev_vfio_unregister_device(&mtty_dev.dev);
 
 	device_unregister(&mtty_dev.dev);
 	idr_destroy(&mtty_dev.vd_idr);