diff mbox series

[02/18] vfio/mdev: Add missing typesafety around mdev_device

Message ID 2-v1-7dedf20b2b75+4f785-vfio2_jgg@nvidia.com (mailing list archive)
State New, archived
Headers show
Series Make vfio_mdev type safe | expand

Commit Message

Jason Gunthorpe March 23, 2021, 5:55 p.m. UTC
The mdev API should accept and pass a 'struct mdev_device *' in all
places, not pass a 'struct device *' and cast it internally with
to_mdev_device(). Particularly in its struct mdev_driver functions, the
whole point of a bus's struct device_driver wrapper is to provide type
safety compared to the default struct device_driver.

Further, the driver core standard is for bus drivers to expose their
device structure in their public headers, so that container_of() inlines
and '&foo->dev' can be used to go between the class levels, and
'&foo->dev' can be passed to dev_err() and the other driver core helper
functions. Move 'struct mdev_device' to mdev.h.

Once done, this allows moving some one-instruction exported functions to
static inlines, which in turn allows removing one of the two grotesque
symbol_get() calls related to mdev in the core code.

Signed-off-by: Jason Gunthorpe <jgg@nvidia.com>
---
 .../driver-api/vfio-mediated-device.rst       |  4 +-
 drivers/vfio/mdev/mdev_core.c                 | 64 ++-----------------
 drivers/vfio/mdev/mdev_driver.c               |  4 +-
 drivers/vfio/mdev/mdev_private.h              | 23 +------
 drivers/vfio/mdev/mdev_sysfs.c                | 26 ++++----
 drivers/vfio/mdev/vfio_mdev.c                 |  7 +-
 drivers/vfio/vfio_iommu_type1.c               | 25 ++------
 include/linux/mdev.h                          | 58 +++++++++++++----
 8 files changed, 83 insertions(+), 128 deletions(-)

Comments

Tian, Kevin March 26, 2021, 2:04 a.m. UTC | #1
> From: Jason Gunthorpe <jgg@nvidia.com>
> Sent: Wednesday, March 24, 2021 1:55 AM
> 
> The mdev API should accept and pass a 'struct mdev_device *' in all
> places, not pass a 'struct device *' and cast it internally with
> to_mdev_device(). Particularly in its struct mdev_driver functions, the
> whole point of a bus's struct device_driver wrapper is to provide type
> safety compared to the default struct device_driver.
> 
> Further, the driver core standard is for bus drivers to expose their
> device structure in their public headers that can be used with
> container_of() inlines and '&foo->dev' to go between the class levels, and
> '&foo->dev' to be used with dev_err/etc driver core helper functions. Move
> 'struct mdev_device' to mdev.h
> 
> Once done this allows moving some one instruction exported functions to
> static inlines, which in turns allows removing one of the two grotesque
> symbol_get()'s related to mdev in the core code.
> 
> Signed-off-by: Jason Gunthorpe <jgg@nvidia.com>

Reviewed-by: Kevin Tian <kevin.tian@intel.com>

> ---
>  .../driver-api/vfio-mediated-device.rst       |  4 +-
>  drivers/vfio/mdev/mdev_core.c                 | 64 ++-----------------
>  drivers/vfio/mdev/mdev_driver.c               |  4 +-
>  drivers/vfio/mdev/mdev_private.h              | 23 +------
>  drivers/vfio/mdev/mdev_sysfs.c                | 26 ++++----
>  drivers/vfio/mdev/vfio_mdev.c                 |  7 +-
>  drivers/vfio/vfio_iommu_type1.c               | 25 ++------
>  include/linux/mdev.h                          | 58 +++++++++++++----
>  8 files changed, 83 insertions(+), 128 deletions(-)
> 
> diff --git a/Documentation/driver-api/vfio-mediated-device.rst
> b/Documentation/driver-api/vfio-mediated-device.rst
> index 25eb7d5b834ba3..c43c1dc3333373 100644
> --- a/Documentation/driver-api/vfio-mediated-device.rst
> +++ b/Documentation/driver-api/vfio-mediated-device.rst
> @@ -105,8 +105,8 @@ structure to represent a mediated device's driver::
>        */
>       struct mdev_driver {
>  	     const char *name;
> -	     int  (*probe)  (struct device *dev);
> -	     void (*remove) (struct device *dev);
> +	     int  (*probe)  (struct mdev_device *dev);
> +	     void (*remove) (struct mdev_device *dev);
>  	     struct device_driver    driver;
>       };
> 
> diff --git a/drivers/vfio/mdev/mdev_core.c
> b/drivers/vfio/mdev/mdev_core.c
> index 6de97d25a3f87d..057922a1707e04 100644
> --- a/drivers/vfio/mdev/mdev_core.c
> +++ b/drivers/vfio/mdev/mdev_core.c
> @@ -33,36 +33,6 @@ struct device *mdev_parent_dev(struct mdev_device
> *mdev)
>  }
>  EXPORT_SYMBOL(mdev_parent_dev);
> 
> -void *mdev_get_drvdata(struct mdev_device *mdev)
> -{
> -	return mdev->driver_data;
> -}
> -EXPORT_SYMBOL(mdev_get_drvdata);
> -
> -void mdev_set_drvdata(struct mdev_device *mdev, void *data)
> -{
> -	mdev->driver_data = data;
> -}
> -EXPORT_SYMBOL(mdev_set_drvdata);
> -
> -struct device *mdev_dev(struct mdev_device *mdev)
> -{
> -	return &mdev->dev;
> -}
> -EXPORT_SYMBOL(mdev_dev);
> -
> -struct mdev_device *mdev_from_dev(struct device *dev)
> -{
> -	return dev_is_mdev(dev) ? to_mdev_device(dev) : NULL;
> -}
> -EXPORT_SYMBOL(mdev_from_dev);
> -
> -const guid_t *mdev_uuid(struct mdev_device *mdev)
> -{
> -	return &mdev->uuid;
> -}
> -EXPORT_SYMBOL(mdev_uuid);
> -
>  /* Should be called holding parent_list_lock */
>  static struct mdev_parent *__find_parent_device(struct device *dev)
>  {
> @@ -107,7 +77,7 @@ static void mdev_device_remove_common(struct
> mdev_device *mdev)
>  	int ret;
> 
>  	type = to_mdev_type(mdev->type_kobj);
> -	mdev_remove_sysfs_files(&mdev->dev, type);
> +	mdev_remove_sysfs_files(mdev, type);
>  	device_del(&mdev->dev);
>  	parent = mdev->parent;
>  	lockdep_assert_held(&parent->unreg_sem);
> @@ -122,12 +92,10 @@ static void mdev_device_remove_common(struct
> mdev_device *mdev)
> 
>  static int mdev_device_remove_cb(struct device *dev, void *data)
>  {
> -	if (dev_is_mdev(dev)) {
> -		struct mdev_device *mdev;
> +	struct mdev_device *mdev = mdev_from_dev(dev);
> 
> -		mdev = to_mdev_device(dev);
> +	if (mdev)
>  		mdev_device_remove_common(mdev);
> -	}
>  	return 0;
>  }
> 
> @@ -332,7 +300,7 @@ int mdev_device_create(struct kobject *kobj,
>  	if (ret)
>  		goto add_fail;
> 
> -	ret = mdev_create_sysfs_files(&mdev->dev, type);
> +	ret = mdev_create_sysfs_files(mdev, type);
>  	if (ret)
>  		goto sysfs_fail;
> 
> @@ -354,13 +322,11 @@ int mdev_device_create(struct kobject *kobj,
>  	return ret;
>  }
> 
> -int mdev_device_remove(struct device *dev)
> +int mdev_device_remove(struct mdev_device *mdev)
>  {
> -	struct mdev_device *mdev, *tmp;
> +	struct mdev_device *tmp;
>  	struct mdev_parent *parent;
> 
> -	mdev = to_mdev_device(dev);
> -
>  	mutex_lock(&mdev_list_lock);
>  	list_for_each_entry(tmp, &mdev_list, next) {
>  		if (tmp == mdev)
> @@ -390,24 +356,6 @@ int mdev_device_remove(struct device *dev)
>  	return 0;
>  }
> 
> -int mdev_set_iommu_device(struct device *dev, struct device
> *iommu_device)
> -{
> -	struct mdev_device *mdev = to_mdev_device(dev);
> -
> -	mdev->iommu_device = iommu_device;
> -
> -	return 0;
> -}
> -EXPORT_SYMBOL(mdev_set_iommu_device);
> -
> -struct device *mdev_get_iommu_device(struct device *dev)
> -{
> -	struct mdev_device *mdev = to_mdev_device(dev);
> -
> -	return mdev->iommu_device;
> -}
> -EXPORT_SYMBOL(mdev_get_iommu_device);
> -
>  static int __init mdev_init(void)
>  {
>  	return mdev_bus_register();
> diff --git a/drivers/vfio/mdev/mdev_driver.c
> b/drivers/vfio/mdev/mdev_driver.c
> index 0d3223aee20b83..44c3ba7e56d923 100644
> --- a/drivers/vfio/mdev/mdev_driver.c
> +++ b/drivers/vfio/mdev/mdev_driver.c
> @@ -48,7 +48,7 @@ static int mdev_probe(struct device *dev)
>  		return ret;
> 
>  	if (drv && drv->probe) {
> -		ret = drv->probe(dev);
> +		ret = drv->probe(mdev);
>  		if (ret)
>  			mdev_detach_iommu(mdev);
>  	}
> @@ -62,7 +62,7 @@ static int mdev_remove(struct device *dev)
>  	struct mdev_device *mdev = to_mdev_device(dev);
> 
>  	if (drv && drv->remove)
> -		drv->remove(dev);
> +		drv->remove(mdev);
> 
>  	mdev_detach_iommu(mdev);
> 
> diff --git a/drivers/vfio/mdev/mdev_private.h
> b/drivers/vfio/mdev/mdev_private.h
> index 74c2e541146999..bb60ec4a8d9d21 100644
> --- a/drivers/vfio/mdev/mdev_private.h
> +++ b/drivers/vfio/mdev/mdev_private.h
> @@ -24,23 +24,6 @@ struct mdev_parent {
>  	struct rw_semaphore unreg_sem;
>  };
> 
> -struct mdev_device {
> -	struct device dev;
> -	struct mdev_parent *parent;
> -	guid_t uuid;
> -	void *driver_data;
> -	struct list_head next;
> -	struct kobject *type_kobj;
> -	struct device *iommu_device;
> -	bool active;
> -};
> -
> -static inline struct mdev_device *to_mdev_device(struct device *dev)
> -{
> -	return container_of(dev, struct mdev_device, dev);
> -}
> -#define dev_is_mdev(d)		((d)->bus == &mdev_bus_type)
> -
>  struct mdev_type {
>  	struct kobject kobj;
>  	struct kobject *devices_kobj;
> @@ -57,11 +40,11 @@ struct mdev_type {
>  int  parent_create_sysfs_files(struct mdev_parent *parent);
>  void parent_remove_sysfs_files(struct mdev_parent *parent);
> 
> -int  mdev_create_sysfs_files(struct device *dev, struct mdev_type *type);
> -void mdev_remove_sysfs_files(struct device *dev, struct mdev_type *type);
> +int  mdev_create_sysfs_files(struct mdev_device *mdev, struct mdev_type
> *type);
> +void mdev_remove_sysfs_files(struct mdev_device *mdev, struct
> mdev_type *type);
> 
>  int  mdev_device_create(struct kobject *kobj,
>  			struct device *dev, const guid_t *uuid);
> -int  mdev_device_remove(struct device *dev);
> +int  mdev_device_remove(struct mdev_device *dev);
> 
>  #endif /* MDEV_PRIVATE_H */
> diff --git a/drivers/vfio/mdev/mdev_sysfs.c
> b/drivers/vfio/mdev/mdev_sysfs.c
> index 917fd84c1c6f24..6a5450587b79e9 100644
> --- a/drivers/vfio/mdev/mdev_sysfs.c
> +++ b/drivers/vfio/mdev/mdev_sysfs.c
> @@ -225,6 +225,7 @@ int parent_create_sysfs_files(struct mdev_parent
> *parent)
>  static ssize_t remove_store(struct device *dev, struct device_attribute *attr,
>  			    const char *buf, size_t count)
>  {
> +	struct mdev_device *mdev = to_mdev_device(dev);
>  	unsigned long val;
> 
>  	if (kstrtoul(buf, 0, &val) < 0)
> @@ -233,7 +234,7 @@ static ssize_t remove_store(struct device *dev, struct
> device_attribute *attr,
>  	if (val && device_remove_file_self(dev, attr)) {
>  		int ret;
> 
> -		ret = mdev_device_remove(dev);
> +		ret = mdev_device_remove(mdev);
>  		if (ret)
>  			return ret;
>  	}
> @@ -248,34 +249,37 @@ static const struct attribute *mdev_device_attrs[] =
> {
>  	NULL,
>  };
> 
> -int  mdev_create_sysfs_files(struct device *dev, struct mdev_type *type)
> +int mdev_create_sysfs_files(struct mdev_device *mdev, struct mdev_type
> *type)
>  {
> +	struct kobject *kobj = &mdev->dev.kobj;
>  	int ret;
> 
> -	ret = sysfs_create_link(type->devices_kobj, &dev->kobj,
> dev_name(dev));
> +	ret = sysfs_create_link(type->devices_kobj, kobj, dev_name(&mdev-
> >dev));
>  	if (ret)
>  		return ret;
> 
> -	ret = sysfs_create_link(&dev->kobj, &type->kobj, "mdev_type");
> +	ret = sysfs_create_link(kobj, &type->kobj, "mdev_type");
>  	if (ret)
>  		goto type_link_failed;
> 
> -	ret = sysfs_create_files(&dev->kobj, mdev_device_attrs);
> +	ret = sysfs_create_files(kobj, mdev_device_attrs);
>  	if (ret)
>  		goto create_files_failed;
> 
>  	return ret;
> 
>  create_files_failed:
> -	sysfs_remove_link(&dev->kobj, "mdev_type");
> +	sysfs_remove_link(kobj, "mdev_type");
>  type_link_failed:
> -	sysfs_remove_link(type->devices_kobj, dev_name(dev));
> +	sysfs_remove_link(type->devices_kobj, dev_name(&mdev->dev));
>  	return ret;
>  }
> 
> -void mdev_remove_sysfs_files(struct device *dev, struct mdev_type *type)
> +void mdev_remove_sysfs_files(struct mdev_device *mdev, struct
> mdev_type *type)
>  {
> -	sysfs_remove_files(&dev->kobj, mdev_device_attrs);
> -	sysfs_remove_link(&dev->kobj, "mdev_type");
> -	sysfs_remove_link(type->devices_kobj, dev_name(dev));
> +	struct kobject *kobj = &mdev->dev.kobj;
> +
> +	sysfs_remove_files(kobj, mdev_device_attrs);
> +	sysfs_remove_link(kobj, "mdev_type");
> +	sysfs_remove_link(type->devices_kobj, dev_name(&mdev->dev));
>  }
> diff --git a/drivers/vfio/mdev/vfio_mdev.c b/drivers/vfio/mdev/vfio_mdev.c
> index ae7e322fbe3c26..91b7b8b9eb9cb8 100644
> --- a/drivers/vfio/mdev/vfio_mdev.c
> +++ b/drivers/vfio/mdev/vfio_mdev.c
> @@ -124,9 +124,8 @@ static const struct vfio_device_ops
> vfio_mdev_dev_ops = {
>  	.request	= vfio_mdev_request,
>  };
> 
> -static int vfio_mdev_probe(struct device *dev)
> +static int vfio_mdev_probe(struct mdev_device *mdev)
>  {
> -	struct mdev_device *mdev = to_mdev_device(dev);
>  	struct vfio_device *vdev;
>  	int ret;
> 
> @@ -144,9 +143,9 @@ static int vfio_mdev_probe(struct device *dev)
>  	return 0;
>  }
> 
> -static void vfio_mdev_remove(struct device *dev)
> +static void vfio_mdev_remove(struct mdev_device *mdev)
>  {
> -	struct vfio_device *vdev = dev_get_drvdata(dev);
> +	struct vfio_device *vdev = dev_get_drvdata(&mdev->dev);
> 
>  	vfio_unregister_group_dev(vdev);
>  	kfree(vdev);
> diff --git a/drivers/vfio/vfio_iommu_type1.c
> b/drivers/vfio/vfio_iommu_type1.c
> index 4bb162c1d649b3..90b45ff1d87a7b 100644
> --- a/drivers/vfio/vfio_iommu_type1.c
> +++ b/drivers/vfio/vfio_iommu_type1.c
> @@ -1923,28 +1923,13 @@ static bool vfio_iommu_has_sw_msi(struct
> list_head *group_resv_regions,
>  	return ret;
>  }
> 
> -static struct device *vfio_mdev_get_iommu_device(struct device *dev)
> -{
> -	struct device *(*fn)(struct device *dev);
> -	struct device *iommu_device;
> -
> -	fn = symbol_get(mdev_get_iommu_device);
> -	if (fn) {
> -		iommu_device = fn(dev);
> -		symbol_put(mdev_get_iommu_device);
> -
> -		return iommu_device;
> -	}
> -
> -	return NULL;
> -}
> -
>  static int vfio_mdev_attach_domain(struct device *dev, void *data)
>  {
> +	struct mdev_device *mdev = to_mdev_device(dev);
>  	struct iommu_domain *domain = data;
>  	struct device *iommu_device;
> 
> -	iommu_device = vfio_mdev_get_iommu_device(dev);
> +	iommu_device = mdev_get_iommu_device(mdev);
>  	if (iommu_device) {
>  		if (iommu_dev_feature_enabled(iommu_device,
> IOMMU_DEV_FEAT_AUX))
>  			return iommu_aux_attach_device(domain,
> iommu_device);
> @@ -1957,10 +1942,11 @@ static int vfio_mdev_attach_domain(struct
> device *dev, void *data)
> 
>  static int vfio_mdev_detach_domain(struct device *dev, void *data)
>  {
> +	struct mdev_device *mdev = to_mdev_device(dev);
>  	struct iommu_domain *domain = data;
>  	struct device *iommu_device;
> 
> -	iommu_device = vfio_mdev_get_iommu_device(dev);
> +	iommu_device = mdev_get_iommu_device(mdev);
>  	if (iommu_device) {
>  		if (iommu_dev_feature_enabled(iommu_device,
> IOMMU_DEV_FEAT_AUX))
>  			iommu_aux_detach_device(domain, iommu_device);
> @@ -2008,9 +1994,10 @@ static bool vfio_bus_is_mdev(struct bus_type
> *bus)
> 
>  static int vfio_mdev_iommu_device(struct device *dev, void *data)
>  {
> +	struct mdev_device *mdev = to_mdev_device(dev);
>  	struct device **old = data, *new;
> 
> -	new = vfio_mdev_get_iommu_device(dev);
> +	new = mdev_get_iommu_device(mdev);
>  	if (!new || (*old && *old != new))
>  		return -EINVAL;
> 
> diff --git a/include/linux/mdev.h b/include/linux/mdev.h
> index 27eb383cb95de0..52f7ea19dd0f56 100644
> --- a/include/linux/mdev.h
> +++ b/include/linux/mdev.h
> @@ -10,7 +10,21 @@
>  #ifndef MDEV_H
>  #define MDEV_H
> 
> -struct mdev_device;
> +struct mdev_device {
> +	struct device dev;
> +	struct mdev_parent *parent;
> +	guid_t uuid;
> +	void *driver_data;
> +	struct list_head next;
> +	struct kobject *type_kobj;
> +	struct device *iommu_device;
> +	bool active;
> +};
> +
> +static inline struct mdev_device *to_mdev_device(struct device *dev)
> +{
> +	return container_of(dev, struct mdev_device, dev);
> +}
> 
>  /*
>   * Called by the parent device driver to set the device which represents
> @@ -19,12 +33,17 @@ struct mdev_device;
>   *
>   * @dev: the mediated device that iommu will isolate.
>   * @iommu_device: a pci device which represents the iommu for @dev.
> - *
> - * Return 0 for success, otherwise negative error value.
>   */
> -int mdev_set_iommu_device(struct device *dev, struct device
> *iommu_device);
> +static inline void mdev_set_iommu_device(struct mdev_device *mdev,
> +					 struct device *iommu_device)
> +{
> +	mdev->iommu_device = iommu_device;
> +}
> 
> -struct device *mdev_get_iommu_device(struct device *dev);
> +static inline struct device *mdev_get_iommu_device(struct mdev_device
> *mdev)
> +{
> +	return mdev->iommu_device;
> +}
> 
>  /**
>   * struct mdev_parent_ops - Structure to be registered for each parent
> device to
> @@ -126,16 +145,25 @@ struct mdev_type_attribute
> mdev_type_attr_##_name =		\
>   **/
>  struct mdev_driver {
>  	const char *name;
> -	int  (*probe)(struct device *dev);
> -	void (*remove)(struct device *dev);
> +	int (*probe)(struct mdev_device *dev);
> +	void (*remove)(struct mdev_device *dev);
>  	struct device_driver driver;
>  };
> 
>  #define to_mdev_driver(drv)	container_of(drv, struct mdev_driver, driver)
> 
> -void *mdev_get_drvdata(struct mdev_device *mdev);
> -void mdev_set_drvdata(struct mdev_device *mdev, void *data);
> -const guid_t *mdev_uuid(struct mdev_device *mdev);
> +static inline void *mdev_get_drvdata(struct mdev_device *mdev)
> +{
> +	return mdev->driver_data;
> +}
> +static inline void mdev_set_drvdata(struct mdev_device *mdev, void *data)
> +{
> +	mdev->driver_data = data;
> +}
> +static inline const guid_t *mdev_uuid(struct mdev_device *mdev)
> +{
> +	return &mdev->uuid;
> +}
> 
>  extern struct bus_type mdev_bus_type;
> 
> @@ -146,7 +174,13 @@ int mdev_register_driver(struct mdev_driver *drv,
> struct module *owner);
>  void mdev_unregister_driver(struct mdev_driver *drv);
> 
>  struct device *mdev_parent_dev(struct mdev_device *mdev);
> -struct device *mdev_dev(struct mdev_device *mdev);
> -struct mdev_device *mdev_from_dev(struct device *dev);
> +static inline struct device *mdev_dev(struct mdev_device *mdev)
> +{
> +	return &mdev->dev;
> +}
> +static inline struct mdev_device *mdev_from_dev(struct device *dev)
> +{
> +	return dev->bus == &mdev_bus_type ? to_mdev_device(dev) : NULL;
> +}
> 
>  #endif /* MDEV_H */
> --
> 2.31.0
Cornelia Huck March 30, 2021, 3:24 p.m. UTC | #2
On Tue, 23 Mar 2021 14:55:19 -0300
Jason Gunthorpe <jgg@nvidia.com> wrote:

> The mdev API should accept and pass a 'struct mdev_device *' in all
> places, not pass a 'struct device *' and cast it internally with
> to_mdev_device(). Particularly in its struct mdev_driver functions, the
> whole point of a bus's struct device_driver wrapper is to provide type
> safety compared to the default struct device_driver.
> 
> Further, the driver core standard is for bus drivers to expose their
> device structure in their public headers that can be used with
> container_of() inlines and '&foo->dev' to go between the class levels, and
> '&foo->dev' to be used with dev_err/etc driver core helper functions. Move
> 'struct mdev_device' to mdev.h
> 
> Once done this allows moving some one instruction exported functions to
> static inlines, which in turns allows removing one of the two grotesque
> symbol_get()'s related to mdev in the core code.
> 
> Signed-off-by: Jason Gunthorpe <jgg@nvidia.com>
> ---
>  .../driver-api/vfio-mediated-device.rst       |  4 +-
>  drivers/vfio/mdev/mdev_core.c                 | 64 ++-----------------
>  drivers/vfio/mdev/mdev_driver.c               |  4 +-
>  drivers/vfio/mdev/mdev_private.h              | 23 +------
>  drivers/vfio/mdev/mdev_sysfs.c                | 26 ++++----
>  drivers/vfio/mdev/vfio_mdev.c                 |  7 +-
>  drivers/vfio/vfio_iommu_type1.c               | 25 ++------
>  include/linux/mdev.h                          | 58 +++++++++++++----
>  8 files changed, 83 insertions(+), 128 deletions(-)

Reviewed-by: Cornelia Huck <cohuck@redhat.com>
diff mbox series

Patch

diff --git a/Documentation/driver-api/vfio-mediated-device.rst b/Documentation/driver-api/vfio-mediated-device.rst
index 25eb7d5b834ba3..c43c1dc3333373 100644
--- a/Documentation/driver-api/vfio-mediated-device.rst
+++ b/Documentation/driver-api/vfio-mediated-device.rst
@@ -105,8 +105,8 @@  structure to represent a mediated device's driver::
       */
      struct mdev_driver {
 	     const char *name;
-	     int  (*probe)  (struct device *dev);
-	     void (*remove) (struct device *dev);
+	     int  (*probe)  (struct mdev_device *dev);
+	     void (*remove) (struct mdev_device *dev);
 	     struct device_driver    driver;
      };
 
diff --git a/drivers/vfio/mdev/mdev_core.c b/drivers/vfio/mdev/mdev_core.c
index 6de97d25a3f87d..057922a1707e04 100644
--- a/drivers/vfio/mdev/mdev_core.c
+++ b/drivers/vfio/mdev/mdev_core.c
@@ -33,36 +33,6 @@  struct device *mdev_parent_dev(struct mdev_device *mdev)
 }
 EXPORT_SYMBOL(mdev_parent_dev);
 
-void *mdev_get_drvdata(struct mdev_device *mdev)
-{
-	return mdev->driver_data;
-}
-EXPORT_SYMBOL(mdev_get_drvdata);
-
-void mdev_set_drvdata(struct mdev_device *mdev, void *data)
-{
-	mdev->driver_data = data;
-}
-EXPORT_SYMBOL(mdev_set_drvdata);
-
-struct device *mdev_dev(struct mdev_device *mdev)
-{
-	return &mdev->dev;
-}
-EXPORT_SYMBOL(mdev_dev);
-
-struct mdev_device *mdev_from_dev(struct device *dev)
-{
-	return dev_is_mdev(dev) ? to_mdev_device(dev) : NULL;
-}
-EXPORT_SYMBOL(mdev_from_dev);
-
-const guid_t *mdev_uuid(struct mdev_device *mdev)
-{
-	return &mdev->uuid;
-}
-EXPORT_SYMBOL(mdev_uuid);
-
 /* Should be called holding parent_list_lock */
 static struct mdev_parent *__find_parent_device(struct device *dev)
 {
@@ -107,7 +77,7 @@  static void mdev_device_remove_common(struct mdev_device *mdev)
 	int ret;
 
 	type = to_mdev_type(mdev->type_kobj);
-	mdev_remove_sysfs_files(&mdev->dev, type);
+	mdev_remove_sysfs_files(mdev, type);
 	device_del(&mdev->dev);
 	parent = mdev->parent;
 	lockdep_assert_held(&parent->unreg_sem);
@@ -122,12 +92,10 @@  static void mdev_device_remove_common(struct mdev_device *mdev)
 
 static int mdev_device_remove_cb(struct device *dev, void *data)
 {
-	if (dev_is_mdev(dev)) {
-		struct mdev_device *mdev;
+	struct mdev_device *mdev = mdev_from_dev(dev);
 
-		mdev = to_mdev_device(dev);
+	if (mdev)
 		mdev_device_remove_common(mdev);
-	}
 	return 0;
 }
 
@@ -332,7 +300,7 @@  int mdev_device_create(struct kobject *kobj,
 	if (ret)
 		goto add_fail;
 
-	ret = mdev_create_sysfs_files(&mdev->dev, type);
+	ret = mdev_create_sysfs_files(mdev, type);
 	if (ret)
 		goto sysfs_fail;
 
@@ -354,13 +322,11 @@  int mdev_device_create(struct kobject *kobj,
 	return ret;
 }
 
-int mdev_device_remove(struct device *dev)
+int mdev_device_remove(struct mdev_device *mdev)
 {
-	struct mdev_device *mdev, *tmp;
+	struct mdev_device *tmp;
 	struct mdev_parent *parent;
 
-	mdev = to_mdev_device(dev);
-
 	mutex_lock(&mdev_list_lock);
 	list_for_each_entry(tmp, &mdev_list, next) {
 		if (tmp == mdev)
@@ -390,24 +356,6 @@  int mdev_device_remove(struct device *dev)
 	return 0;
 }
 
-int mdev_set_iommu_device(struct device *dev, struct device *iommu_device)
-{
-	struct mdev_device *mdev = to_mdev_device(dev);
-
-	mdev->iommu_device = iommu_device;
-
-	return 0;
-}
-EXPORT_SYMBOL(mdev_set_iommu_device);
-
-struct device *mdev_get_iommu_device(struct device *dev)
-{
-	struct mdev_device *mdev = to_mdev_device(dev);
-
-	return mdev->iommu_device;
-}
-EXPORT_SYMBOL(mdev_get_iommu_device);
-
 static int __init mdev_init(void)
 {
 	return mdev_bus_register();
diff --git a/drivers/vfio/mdev/mdev_driver.c b/drivers/vfio/mdev/mdev_driver.c
index 0d3223aee20b83..44c3ba7e56d923 100644
--- a/drivers/vfio/mdev/mdev_driver.c
+++ b/drivers/vfio/mdev/mdev_driver.c
@@ -48,7 +48,7 @@  static int mdev_probe(struct device *dev)
 		return ret;
 
 	if (drv && drv->probe) {
-		ret = drv->probe(dev);
+		ret = drv->probe(mdev);
 		if (ret)
 			mdev_detach_iommu(mdev);
 	}
@@ -62,7 +62,7 @@  static int mdev_remove(struct device *dev)
 	struct mdev_device *mdev = to_mdev_device(dev);
 
 	if (drv && drv->remove)
-		drv->remove(dev);
+		drv->remove(mdev);
 
 	mdev_detach_iommu(mdev);
 
diff --git a/drivers/vfio/mdev/mdev_private.h b/drivers/vfio/mdev/mdev_private.h
index 74c2e541146999..bb60ec4a8d9d21 100644
--- a/drivers/vfio/mdev/mdev_private.h
+++ b/drivers/vfio/mdev/mdev_private.h
@@ -24,23 +24,6 @@  struct mdev_parent {
 	struct rw_semaphore unreg_sem;
 };
 
-struct mdev_device {
-	struct device dev;
-	struct mdev_parent *parent;
-	guid_t uuid;
-	void *driver_data;
-	struct list_head next;
-	struct kobject *type_kobj;
-	struct device *iommu_device;
-	bool active;
-};
-
-static inline struct mdev_device *to_mdev_device(struct device *dev)
-{
-	return container_of(dev, struct mdev_device, dev);
-}
-#define dev_is_mdev(d)		((d)->bus == &mdev_bus_type)
-
 struct mdev_type {
 	struct kobject kobj;
 	struct kobject *devices_kobj;
@@ -57,11 +40,11 @@  struct mdev_type {
 int  parent_create_sysfs_files(struct mdev_parent *parent);
 void parent_remove_sysfs_files(struct mdev_parent *parent);
 
-int  mdev_create_sysfs_files(struct device *dev, struct mdev_type *type);
-void mdev_remove_sysfs_files(struct device *dev, struct mdev_type *type);
+int  mdev_create_sysfs_files(struct mdev_device *mdev, struct mdev_type *type);
+void mdev_remove_sysfs_files(struct mdev_device *mdev, struct mdev_type *type);
 
 int  mdev_device_create(struct kobject *kobj,
 			struct device *dev, const guid_t *uuid);
-int  mdev_device_remove(struct device *dev);
+int  mdev_device_remove(struct mdev_device *dev);
 
 #endif /* MDEV_PRIVATE_H */
diff --git a/drivers/vfio/mdev/mdev_sysfs.c b/drivers/vfio/mdev/mdev_sysfs.c
index 917fd84c1c6f24..6a5450587b79e9 100644
--- a/drivers/vfio/mdev/mdev_sysfs.c
+++ b/drivers/vfio/mdev/mdev_sysfs.c
@@ -225,6 +225,7 @@  int parent_create_sysfs_files(struct mdev_parent *parent)
 static ssize_t remove_store(struct device *dev, struct device_attribute *attr,
 			    const char *buf, size_t count)
 {
+	struct mdev_device *mdev = to_mdev_device(dev);
 	unsigned long val;
 
 	if (kstrtoul(buf, 0, &val) < 0)
@@ -233,7 +234,7 @@  static ssize_t remove_store(struct device *dev, struct device_attribute *attr,
 	if (val && device_remove_file_self(dev, attr)) {
 		int ret;
 
-		ret = mdev_device_remove(dev);
+		ret = mdev_device_remove(mdev);
 		if (ret)
 			return ret;
 	}
@@ -248,34 +249,37 @@  static const struct attribute *mdev_device_attrs[] = {
 	NULL,
 };
 
-int  mdev_create_sysfs_files(struct device *dev, struct mdev_type *type)
+int mdev_create_sysfs_files(struct mdev_device *mdev, struct mdev_type *type)
 {
+	struct kobject *kobj = &mdev->dev.kobj;
 	int ret;
 
-	ret = sysfs_create_link(type->devices_kobj, &dev->kobj, dev_name(dev));
+	ret = sysfs_create_link(type->devices_kobj, kobj, dev_name(&mdev->dev));
 	if (ret)
 		return ret;
 
-	ret = sysfs_create_link(&dev->kobj, &type->kobj, "mdev_type");
+	ret = sysfs_create_link(kobj, &type->kobj, "mdev_type");
 	if (ret)
 		goto type_link_failed;
 
-	ret = sysfs_create_files(&dev->kobj, mdev_device_attrs);
+	ret = sysfs_create_files(kobj, mdev_device_attrs);
 	if (ret)
 		goto create_files_failed;
 
 	return ret;
 
 create_files_failed:
-	sysfs_remove_link(&dev->kobj, "mdev_type");
+	sysfs_remove_link(kobj, "mdev_type");
 type_link_failed:
-	sysfs_remove_link(type->devices_kobj, dev_name(dev));
+	sysfs_remove_link(type->devices_kobj, dev_name(&mdev->dev));
 	return ret;
 }
 
-void mdev_remove_sysfs_files(struct device *dev, struct mdev_type *type)
+void mdev_remove_sysfs_files(struct mdev_device *mdev, struct mdev_type *type)
 {
-	sysfs_remove_files(&dev->kobj, mdev_device_attrs);
-	sysfs_remove_link(&dev->kobj, "mdev_type");
-	sysfs_remove_link(type->devices_kobj, dev_name(dev));
+	struct kobject *kobj = &mdev->dev.kobj;
+
+	sysfs_remove_files(kobj, mdev_device_attrs);
+	sysfs_remove_link(kobj, "mdev_type");
+	sysfs_remove_link(type->devices_kobj, dev_name(&mdev->dev));
 }
diff --git a/drivers/vfio/mdev/vfio_mdev.c b/drivers/vfio/mdev/vfio_mdev.c
index ae7e322fbe3c26..91b7b8b9eb9cb8 100644
--- a/drivers/vfio/mdev/vfio_mdev.c
+++ b/drivers/vfio/mdev/vfio_mdev.c
@@ -124,9 +124,8 @@  static const struct vfio_device_ops vfio_mdev_dev_ops = {
 	.request	= vfio_mdev_request,
 };
 
-static int vfio_mdev_probe(struct device *dev)
+static int vfio_mdev_probe(struct mdev_device *mdev)
 {
-	struct mdev_device *mdev = to_mdev_device(dev);
 	struct vfio_device *vdev;
 	int ret;
 
@@ -144,9 +143,9 @@  static int vfio_mdev_probe(struct device *dev)
 	return 0;
 }
 
-static void vfio_mdev_remove(struct device *dev)
+static void vfio_mdev_remove(struct mdev_device *mdev)
 {
-	struct vfio_device *vdev = dev_get_drvdata(dev);
+	struct vfio_device *vdev = dev_get_drvdata(&mdev->dev);
 
 	vfio_unregister_group_dev(vdev);
 	kfree(vdev);
diff --git a/drivers/vfio/vfio_iommu_type1.c b/drivers/vfio/vfio_iommu_type1.c
index 4bb162c1d649b3..90b45ff1d87a7b 100644
--- a/drivers/vfio/vfio_iommu_type1.c
+++ b/drivers/vfio/vfio_iommu_type1.c
@@ -1923,28 +1923,13 @@  static bool vfio_iommu_has_sw_msi(struct list_head *group_resv_regions,
 	return ret;
 }
 
-static struct device *vfio_mdev_get_iommu_device(struct device *dev)
-{
-	struct device *(*fn)(struct device *dev);
-	struct device *iommu_device;
-
-	fn = symbol_get(mdev_get_iommu_device);
-	if (fn) {
-		iommu_device = fn(dev);
-		symbol_put(mdev_get_iommu_device);
-
-		return iommu_device;
-	}
-
-	return NULL;
-}
-
 static int vfio_mdev_attach_domain(struct device *dev, void *data)
 {
+	struct mdev_device *mdev = to_mdev_device(dev);
 	struct iommu_domain *domain = data;
 	struct device *iommu_device;
 
-	iommu_device = vfio_mdev_get_iommu_device(dev);
+	iommu_device = mdev_get_iommu_device(mdev);
 	if (iommu_device) {
 		if (iommu_dev_feature_enabled(iommu_device, IOMMU_DEV_FEAT_AUX))
 			return iommu_aux_attach_device(domain, iommu_device);
@@ -1957,10 +1942,11 @@  static int vfio_mdev_attach_domain(struct device *dev, void *data)
 
 static int vfio_mdev_detach_domain(struct device *dev, void *data)
 {
+	struct mdev_device *mdev = to_mdev_device(dev);
 	struct iommu_domain *domain = data;
 	struct device *iommu_device;
 
-	iommu_device = vfio_mdev_get_iommu_device(dev);
+	iommu_device = mdev_get_iommu_device(mdev);
 	if (iommu_device) {
 		if (iommu_dev_feature_enabled(iommu_device, IOMMU_DEV_FEAT_AUX))
 			iommu_aux_detach_device(domain, iommu_device);
@@ -2008,9 +1994,10 @@  static bool vfio_bus_is_mdev(struct bus_type *bus)
 
 static int vfio_mdev_iommu_device(struct device *dev, void *data)
 {
+	struct mdev_device *mdev = to_mdev_device(dev);
 	struct device **old = data, *new;
 
-	new = vfio_mdev_get_iommu_device(dev);
+	new = mdev_get_iommu_device(mdev);
 	if (!new || (*old && *old != new))
 		return -EINVAL;
 
diff --git a/include/linux/mdev.h b/include/linux/mdev.h
index 27eb383cb95de0..52f7ea19dd0f56 100644
--- a/include/linux/mdev.h
+++ b/include/linux/mdev.h
@@ -10,7 +10,21 @@ 
 #ifndef MDEV_H
 #define MDEV_H
 
-struct mdev_device;
+struct mdev_device {
+	struct device dev;
+	struct mdev_parent *parent;
+	guid_t uuid;
+	void *driver_data;
+	struct list_head next;
+	struct kobject *type_kobj;
+	struct device *iommu_device;
+	bool active;
+};
+
+static inline struct mdev_device *to_mdev_device(struct device *dev)
+{
+	return container_of(dev, struct mdev_device, dev);
+}
 
 /*
  * Called by the parent device driver to set the device which represents
@@ -19,12 +33,17 @@  struct mdev_device;
  *
  * @dev: the mediated device that iommu will isolate.
  * @iommu_device: a pci device which represents the iommu for @dev.
- *
- * Return 0 for success, otherwise negative error value.
  */
-int mdev_set_iommu_device(struct device *dev, struct device *iommu_device);
+static inline void mdev_set_iommu_device(struct mdev_device *mdev,
+					 struct device *iommu_device)
+{
+	mdev->iommu_device = iommu_device;
+}
 
-struct device *mdev_get_iommu_device(struct device *dev);
+static inline struct device *mdev_get_iommu_device(struct mdev_device *mdev)
+{
+	return mdev->iommu_device;
+}
 
 /**
  * struct mdev_parent_ops - Structure to be registered for each parent device to
@@ -126,16 +145,25 @@  struct mdev_type_attribute mdev_type_attr_##_name =		\
  **/
 struct mdev_driver {
 	const char *name;
-	int  (*probe)(struct device *dev);
-	void (*remove)(struct device *dev);
+	int (*probe)(struct mdev_device *dev);
+	void (*remove)(struct mdev_device *dev);
 	struct device_driver driver;
 };
 
 #define to_mdev_driver(drv)	container_of(drv, struct mdev_driver, driver)
 
-void *mdev_get_drvdata(struct mdev_device *mdev);
-void mdev_set_drvdata(struct mdev_device *mdev, void *data);
-const guid_t *mdev_uuid(struct mdev_device *mdev);
+static inline void *mdev_get_drvdata(struct mdev_device *mdev)
+{
+	return mdev->driver_data;
+}
+static inline void mdev_set_drvdata(struct mdev_device *mdev, void *data)
+{
+	mdev->driver_data = data;
+}
+static inline const guid_t *mdev_uuid(struct mdev_device *mdev)
+{
+	return &mdev->uuid;
+}
 
 extern struct bus_type mdev_bus_type;
 
@@ -146,7 +174,13 @@  int mdev_register_driver(struct mdev_driver *drv, struct module *owner);
 void mdev_unregister_driver(struct mdev_driver *drv);
 
 struct device *mdev_parent_dev(struct mdev_device *mdev);
-struct device *mdev_dev(struct mdev_device *mdev);
-struct mdev_device *mdev_from_dev(struct device *dev);
+static inline struct device *mdev_dev(struct mdev_device *mdev)
+{
+	return &mdev->dev;
+}
+static inline struct mdev_device *mdev_from_dev(struct device *dev)
+{
+	return dev->bus == &mdev_bus_type ? to_mdev_device(dev) : NULL;
+}
 
 #endif /* MDEV_H */