diff mbox series

[03/13] vfio/mdev: simplify mdev_type handling

Message ID 20220614045428.278494-4-hch@lst.de (mailing list archive)
State New, archived
Headers show
Series [01/13] vfio/mdev: make mdev.h standalone includable | expand

Commit Message

Christoph Hellwig June 14, 2022, 4:54 a.m. UTC
Instead of abusing struct attribute_group to control initialization of
struct mdev_type, just define the actual attributes in the mdev_driver,
allocate the mdev_type structures in the caller and pass them to
mdev_register_parent.

This allows the caller to use container_of to get at the containing
struture and thus significantly simplify the code.

Signed-off-by: Christoph Hellwig <hch@lst.de>
---
 .../driver-api/vfio-mediated-device.rst       |   2 +-
 drivers/gpu/drm/i915/gvt/gvt.h                |   3 +-
 drivers/gpu/drm/i915/gvt/kvmgt.c              |  98 ++--------------
 drivers/gpu/drm/i915/gvt/vgpu.c               |  18 ++-
 drivers/s390/cio/cio.h                        |   2 +
 drivers/s390/cio/vfio_ccw_ops.c               |  19 +--
 drivers/s390/crypto/vfio_ap_ops.c             |  19 +--
 drivers/s390/crypto/vfio_ap_private.h         |   2 +
 drivers/vfio/mdev/mdev_core.c                 |  54 ++++-----
 drivers/vfio/mdev/mdev_driver.c               |   5 +-
 drivers/vfio/mdev/mdev_private.h              |  14 +--
 drivers/vfio/mdev/mdev_sysfs.c                | 111 ++----------------
 include/linux/mdev.h                          |  26 ++--
 samples/vfio-mdev/mbochs.c                    |  57 ++++-----
 samples/vfio-mdev/mdpy.c                      |  50 +++-----
 samples/vfio-mdev/mtty.c                      |  62 +++++-----
 16 files changed, 181 insertions(+), 361 deletions(-)

Comments

Yi Liu June 14, 2022, 6:14 a.m. UTC | #1
On 2022/6/14 12:54, Christoph Hellwig wrote:
> Instead of abusing struct attribute_group to control initialization of
> struct mdev_type, just define the actual attributes in the mdev_driver,
> allocate the mdev_type structures in the caller and pass them to
> mdev_register_parent.
> 
> This allows the caller to use container_of to get at the containing
> struture and thus significantly simplify the code.

s/struture/structure

> 
> Signed-off-by: Christoph Hellwig <hch@lst.de>
> ---
>   .../driver-api/vfio-mediated-device.rst       |   2 +-
>   drivers/gpu/drm/i915/gvt/gvt.h                |   3 +-
>   drivers/gpu/drm/i915/gvt/kvmgt.c              |  98 ++--------------
>   drivers/gpu/drm/i915/gvt/vgpu.c               |  18 ++-
>   drivers/s390/cio/cio.h                        |   2 +
>   drivers/s390/cio/vfio_ccw_ops.c               |  19 +--
>   drivers/s390/crypto/vfio_ap_ops.c             |  19 +--
>   drivers/s390/crypto/vfio_ap_private.h         |   2 +
>   drivers/vfio/mdev/mdev_core.c                 |  54 ++++-----
>   drivers/vfio/mdev/mdev_driver.c               |   5 +-
>   drivers/vfio/mdev/mdev_private.h              |  14 +--
>   drivers/vfio/mdev/mdev_sysfs.c                | 111 ++----------------
>   include/linux/mdev.h                          |  26 ++--
>   samples/vfio-mdev/mbochs.c                    |  57 ++++-----
>   samples/vfio-mdev/mdpy.c                      |  50 +++-----
>   samples/vfio-mdev/mtty.c                      |  62 +++++-----
>   16 files changed, 181 insertions(+), 361 deletions(-)
> 
> diff --git a/Documentation/driver-api/vfio-mediated-device.rst b/Documentation/driver-api/vfio-mediated-device.rst
> index 3749f59c855fa..599008bdc1dcb 100644
> --- a/Documentation/driver-api/vfio-mediated-device.rst
> +++ b/Documentation/driver-api/vfio-mediated-device.rst
> @@ -105,7 +105,7 @@ structure to represent a mediated device's driver::
>        struct mdev_driver {
>   	     int  (*probe)  (struct mdev_device *dev);
>   	     void (*remove) (struct mdev_device *dev);
> -	     struct attribute_group **supported_type_groups;
> +	     const struct attribute * const *types_attrs;
>   	     struct device_driver    driver;
>        };
>   
> diff --git a/drivers/gpu/drm/i915/gvt/gvt.h b/drivers/gpu/drm/i915/gvt/gvt.h
> index ddffd337f1c60..0ccb9bb7180cd 100644
> --- a/drivers/gpu/drm/i915/gvt/gvt.h
> +++ b/drivers/gpu/drm/i915/gvt/gvt.h
> @@ -298,7 +298,7 @@ struct intel_gvt_firmware {
>   
>   #define NR_MAX_INTEL_VGPU_TYPES 20
>   struct intel_vgpu_type {
> -	char name[16];
> +	struct mdev_type type;
>   	unsigned int avail_instance;
>   	unsigned int low_gm_size;
>   	unsigned int high_gm_size;
> @@ -329,6 +329,7 @@ struct intel_gvt {
>   	struct notifier_block shadow_ctx_notifier_block[I915_NUM_ENGINES];
>   	DECLARE_HASHTABLE(cmd_table, GVT_CMD_HASH_BITS);
>   	struct mdev_parent parent;
> +	struct mdev_type **mdev_types;
>   	struct intel_vgpu_type *types;
>   	unsigned int num_types;
>   	struct intel_vgpu *idle_vgpu;
> diff --git a/drivers/gpu/drm/i915/gvt/kvmgt.c b/drivers/gpu/drm/i915/gvt/kvmgt.c
> index 70401374c72bc..1c6b7e8bec4fb 100644
> --- a/drivers/gpu/drm/i915/gvt/kvmgt.c
> +++ b/drivers/gpu/drm/i915/gvt/kvmgt.c
> @@ -117,17 +117,10 @@ static ssize_t available_instances_show(struct mdev_type *mtype,
>   					struct mdev_type_attribute *attr,
>   					char *buf)
>   {
> -	struct intel_vgpu_type *type;
> -	unsigned int num = 0;
> -	struct intel_gvt *gvt = kdev_to_i915(mtype_get_parent_dev(mtype))->gvt;
> +	struct intel_vgpu_type *type =
> +		container_of(mtype, struct intel_vgpu_type, type);
>   
> -	type = &gvt->types[mtype_get_type_group_id(mtype)];
> -	if (!type)
> -		num = 0;
> -	else
> -		num = type->avail_instance;
> -
> -	return sprintf(buf, "%u\n", num);
> +	return sprintf(buf, "%u\n", type->avail_instance);
>   }
>   
>   static ssize_t device_api_show(struct mdev_type *mtype,
> @@ -139,12 +132,8 @@ static ssize_t device_api_show(struct mdev_type *mtype,
>   static ssize_t description_show(struct mdev_type *mtype,
>   				struct mdev_type_attribute *attr, char *buf)
>   {
> -	struct intel_vgpu_type *type;
> -	struct intel_gvt *gvt = kdev_to_i915(mtype_get_parent_dev(mtype))->gvt;
> -
> -	type = &gvt->types[mtype_get_type_group_id(mtype)];
> -	if (!type)
> -		return 0;
> +	struct intel_vgpu_type *type =
> +		container_of(mtype, struct intel_vgpu_type, type);
>   
>   	return sprintf(buf, "low_gm_size: %dMB\nhigh_gm_size: %dMB\n"
>   		       "fence: %d\nresolution: %s\n"
> @@ -158,14 +147,7 @@ static ssize_t description_show(struct mdev_type *mtype,
>   static ssize_t name_show(struct mdev_type *mtype,
>   			 struct mdev_type_attribute *attr, char *buf)
>   {
> -	struct intel_vgpu_type *type;
> -	struct intel_gvt *gvt = kdev_to_i915(mtype_get_parent_dev(mtype))->gvt;
> -
> -	type = &gvt->types[mtype_get_type_group_id(mtype)];
> -	if (!type)
> -		return 0;
> -
> -	return sprintf(buf, "%s\n", type->name);
> +	return sprintf(buf, "%s\n", mtype->sysfs_name);
>   }
>   
>   static MDEV_TYPE_ATTR_RO(available_instances);
> @@ -173,7 +155,7 @@ static MDEV_TYPE_ATTR_RO(device_api);
>   static MDEV_TYPE_ATTR_RO(description);
>   static MDEV_TYPE_ATTR_RO(name);
>   
> -static struct attribute *gvt_type_attrs[] = {
> +static const struct attribute *gvt_type_attrs[] = {
>   	&mdev_type_attr_available_instances.attr,
>   	&mdev_type_attr_device_api.attr,
>   	&mdev_type_attr_description.attr,
> @@ -181,51 +163,6 @@ static struct attribute *gvt_type_attrs[] = {
>   	NULL,
>   };
>   
> -static struct attribute_group *gvt_vgpu_type_groups[] = {
> -	[0 ... NR_MAX_INTEL_VGPU_TYPES - 1] = NULL,
> -};
> -
> -static int intel_gvt_init_vgpu_type_groups(struct intel_gvt *gvt)
> -{
> -	int i, j;
> -	struct intel_vgpu_type *type;
> -	struct attribute_group *group;
> -
> -	for (i = 0; i < gvt->num_types; i++) {
> -		type = &gvt->types[i];
> -
> -		group = kzalloc(sizeof(struct attribute_group), GFP_KERNEL);
> -		if (!group)
> -			goto unwind;
> -
> -		group->name = type->name;
> -		group->attrs = gvt_type_attrs;
> -		gvt_vgpu_type_groups[i] = group;
> -	}
> -
> -	return 0;
> -
> -unwind:
> -	for (j = 0; j < i; j++) {
> -		group = gvt_vgpu_type_groups[j];
> -		kfree(group);
> -	}
> -
> -	return -ENOMEM;
> -}
> -
> -static void intel_gvt_cleanup_vgpu_type_groups(struct intel_gvt *gvt)
> -{
> -	int i;
> -	struct attribute_group *group;
> -
> -	for (i = 0; i < gvt->num_types; i++) {
> -		group = gvt_vgpu_type_groups[i];
> -		gvt_vgpu_type_groups[i] = NULL;
> -		kfree(group);
> -	}
> -}
> -
>   static void gvt_unpin_guest_page(struct intel_vgpu *vgpu, unsigned long gfn,
>   		unsigned long size)
>   {
> @@ -1614,14 +1551,11 @@ static int intel_vgpu_probe(struct mdev_device *mdev)
>   {
>   	struct device *pdev = mdev_parent_dev(mdev);
>   	struct intel_gvt *gvt = kdev_to_i915(pdev)->gvt;
> -	struct intel_vgpu_type *type;
> +	struct intel_vgpu_type *type =
> +		container_of(mdev->type, struct intel_vgpu_type, type);
>   	struct intel_vgpu *vgpu;
>   	int ret;
>   
> -	type = &gvt->types[mdev_get_type_group_id(mdev)];
> -	if (!type)
> -		return -EINVAL;
> -
>   	vgpu = intel_gvt_create_vgpu(gvt, type);
>   	if (IS_ERR(vgpu)) {
>   		gvt_err("failed to create intel vgpu: %ld\n", PTR_ERR(vgpu));
> @@ -1660,7 +1594,7 @@ static struct mdev_driver intel_vgpu_mdev_driver = {
>   	},
>   	.probe		= intel_vgpu_probe,
>   	.remove		= intel_vgpu_remove,
> -	.supported_type_groups	= gvt_vgpu_type_groups,
> +	.types_attrs	= gvt_type_attrs,
>   };
>   
>   int intel_gvt_page_track_add(struct intel_vgpu *info, u64 gfn)
> @@ -1959,7 +1893,6 @@ static void intel_gvt_clean_device(struct drm_i915_private *i915)
>   		return;
>   
>   	mdev_unregister_parent(&gvt->parent);
> -	intel_gvt_cleanup_vgpu_type_groups(gvt);
>   	intel_gvt_destroy_idle_vgpu(gvt->idle_vgpu);
>   	intel_gvt_clean_vgpu_types(gvt);
>   
> @@ -2059,20 +1992,15 @@ static int intel_gvt_init_device(struct drm_i915_private *i915)
>   
>   	intel_gvt_debugfs_init(gvt);
>   
> -	ret = intel_gvt_init_vgpu_type_groups(gvt);
> -	if (ret)
> -		goto out_destroy_idle_vgpu;
> -
>   	ret = mdev_register_parent(&gvt->parent, i915->drm.dev,
> -				   &intel_vgpu_mdev_driver);
> +				   &intel_vgpu_mdev_driver,
> +				   gvt->mdev_types, gvt->num_types);
>   	if (ret)
> -		goto out_cleanup_vgpu_type_groups;
> +		goto out_destroy_idle_vgpu;
>   
>   	gvt_dbg_core("gvt device initialization is done\n");
>   	return 0;
>   
> -out_cleanup_vgpu_type_groups:
> -	intel_gvt_cleanup_vgpu_type_groups(gvt);
>   out_destroy_idle_vgpu:
>   	intel_gvt_destroy_idle_vgpu(gvt->idle_vgpu);
>   	intel_gvt_debugfs_clean(gvt);
> diff --git a/drivers/gpu/drm/i915/gvt/vgpu.c b/drivers/gpu/drm/i915/gvt/vgpu.c
> index 46da19b3225d2..2f38b90ecb8ac 100644
> --- a/drivers/gpu/drm/i915/gvt/vgpu.c
> +++ b/drivers/gpu/drm/i915/gvt/vgpu.c
> @@ -131,6 +131,13 @@ int intel_gvt_init_vgpu_types(struct intel_gvt *gvt)
>   	if (!gvt->types)
>   		return -ENOMEM;
>   
> +	gvt->mdev_types = kcalloc(num_types, sizeof(*gvt->mdev_types),
> +			     GFP_KERNEL);
> +	if (!gvt->mdev_types) {
> +		kfree(gvt->types);
> +		return -ENOMEM;
> +	}
> +
>   	min_low = MB_TO_BYTES(32);
>   	for (i = 0; i < num_types; ++i) {
>   		if (low_avail / vgpu_types[i].low_mm == 0)
> @@ -150,19 +157,21 @@ int intel_gvt_init_vgpu_types(struct intel_gvt *gvt)
>   						   high_avail / vgpu_types[i].high_mm);
>   
>   		if (GRAPHICS_VER(gvt->gt->i915) == 8)
> -			sprintf(gvt->types[i].name, "GVTg_V4_%s",
> +			sprintf(gvt->types[i].type.sysfs_name, "GVTg_V4_%s",
>   				vgpu_types[i].name);
>   		else if (GRAPHICS_VER(gvt->gt->i915) == 9)
> -			sprintf(gvt->types[i].name, "GVTg_V5_%s",
> +			sprintf(gvt->types[i].type.sysfs_name, "GVTg_V5_%s",
>   				vgpu_types[i].name);
>   
>   		gvt_dbg_core("type[%d]: %s avail %u low %u high %u fence %u weight %u res %s\n",
> -			     i, gvt->types[i].name,
> +			     i, gvt->types[i].type.sysfs_name,
>   			     gvt->types[i].avail_instance,
>   			     gvt->types[i].low_gm_size,
>   			     gvt->types[i].high_gm_size, gvt->types[i].fence,
>   			     gvt->types[i].weight,
>   			     vgpu_edid_str(gvt->types[i].resolution));
> +
> +		gvt->mdev_types[i] = &gvt->types[i].type;
>   	}
>   
>   	gvt->num_types = i;
> @@ -171,6 +180,7 @@ int intel_gvt_init_vgpu_types(struct intel_gvt *gvt)
>   
>   void intel_gvt_clean_vgpu_types(struct intel_gvt *gvt)
>   {
> +	kfree(gvt->mdev_types);
>   	kfree(gvt->types);
>   }
>   
> @@ -198,7 +208,7 @@ static void intel_gvt_update_vgpu_types(struct intel_gvt *gvt)
>   						   fence_min);
>   
>   		gvt_dbg_core("update type[%d]: %s avail %u low %u high %u fence %u\n",
> -		       i, gvt->types[i].name,
> +		       i, gvt->types[i].type.sysfs_name,
>   		       gvt->types[i].avail_instance, gvt->types[i].low_gm_size,
>   		       gvt->types[i].high_gm_size, gvt->types[i].fence);
>   	}
> diff --git a/drivers/s390/cio/cio.h b/drivers/s390/cio/cio.h
> index 22be5ac7d23c1..1da45307a1862 100644
> --- a/drivers/s390/cio/cio.h
> +++ b/drivers/s390/cio/cio.h
> @@ -110,6 +110,8 @@ struct subchannel {
>   	 */
>   	const char *driver_override;
>   	struct mdev_parent parent;
> +	struct mdev_type mdev_type;
> +	struct mdev_type *mdev_types[1];
>   } __attribute__ ((aligned(8)));
>   
>   DECLARE_PER_CPU_ALIGNED(struct irb, cio_irb);
> diff --git a/drivers/s390/cio/vfio_ccw_ops.c b/drivers/s390/cio/vfio_ccw_ops.c
> index 9192a21085ce4..25b8d42a522ac 100644
> --- a/drivers/s390/cio/vfio_ccw_ops.c
> +++ b/drivers/s390/cio/vfio_ccw_ops.c
> @@ -95,23 +95,13 @@ static ssize_t available_instances_show(struct mdev_type *mtype,
>   }
>   static MDEV_TYPE_ATTR_RO(available_instances);
>   
> -static struct attribute *mdev_types_attrs[] = {
> +static const struct attribute *mdev_types_attrs[] = {
>   	&mdev_type_attr_name.attr,
>   	&mdev_type_attr_device_api.attr,
>   	&mdev_type_attr_available_instances.attr,
>   	NULL,
>   };
>   
> -static struct attribute_group mdev_type_group = {
> -	.name  = "io",
> -	.attrs = mdev_types_attrs,
> -};
> -
> -static struct attribute_group *mdev_type_groups[] = {
> -	&mdev_type_group,
> -	NULL,
> -};
> -
>   static int vfio_ccw_mdev_probe(struct mdev_device *mdev)
>   {
>   	struct vfio_ccw_private *private = dev_get_drvdata(mdev->dev.parent);
> @@ -654,13 +644,16 @@ struct mdev_driver vfio_ccw_mdev_driver = {
>   	},
>   	.probe = vfio_ccw_mdev_probe,
>   	.remove = vfio_ccw_mdev_remove,
> -	.supported_type_groups  = mdev_type_groups,
> +	.types_attrs = mdev_types_attrs,
>   };
>   
>   int vfio_ccw_mdev_reg(struct subchannel *sch)
>   {
> +	sprintf(sch->mdev_type.sysfs_name, "io");
> +	sch->mdev_types[0] = &sch->mdev_type;
>   	return mdev_register_parent(&sch->parent, &sch->dev,
> -				    &vfio_ccw_mdev_driver);
> +				    &vfio_ccw_mdev_driver, sch->mdev_types,
> +				    1);
>   }
>   
>   void vfio_ccw_mdev_unreg(struct subchannel *sch)
> diff --git a/drivers/s390/crypto/vfio_ap_ops.c b/drivers/s390/crypto/vfio_ap_ops.c
> index 834945150dc9f..ff25858b2ebbe 100644
> --- a/drivers/s390/crypto/vfio_ap_ops.c
> +++ b/drivers/s390/crypto/vfio_ap_ops.c
> @@ -537,23 +537,13 @@ static ssize_t device_api_show(struct mdev_type *mtype,
>   
>   static MDEV_TYPE_ATTR_RO(device_api);
>   
> -static struct attribute *vfio_ap_mdev_type_attrs[] = {
> +static const struct attribute *vfio_ap_mdev_type_attrs[] = {
>   	&mdev_type_attr_name.attr,
>   	&mdev_type_attr_device_api.attr,
>   	&mdev_type_attr_available_instances.attr,
>   	NULL,
>   };
>   
> -static struct attribute_group vfio_ap_mdev_hwvirt_type_group = {
> -	.name = VFIO_AP_MDEV_TYPE_HWVIRT,
> -	.attrs = vfio_ap_mdev_type_attrs,
> -};
> -
> -static struct attribute_group *vfio_ap_mdev_type_groups[] = {
> -	&vfio_ap_mdev_hwvirt_type_group,
> -	NULL,
> -};
> -
>   struct vfio_ap_queue_reserved {
>   	unsigned long *apid;
>   	unsigned long *apqi;
> @@ -1472,7 +1462,7 @@ static struct mdev_driver vfio_ap_matrix_driver = {
>   	},
>   	.probe = vfio_ap_mdev_probe,
>   	.remove = vfio_ap_mdev_remove,
> -	.supported_type_groups = vfio_ap_mdev_type_groups,
> +	.types_attrs = vfio_ap_mdev_type_attrs,
>   };
>   
>   int vfio_ap_mdev_register(void)
> @@ -1485,8 +1475,11 @@ int vfio_ap_mdev_register(void)
>   	if (ret)
>   		return ret;
>   
> +	strcpy(matrix_dev->mdev_type.sysfs_name, VFIO_AP_MDEV_TYPE_HWVIRT);
> +	matrix_dev->mdev_types[0] = &matrix_dev->mdev_type;
>   	ret = mdev_register_parent(&matrix_dev->parent, &matrix_dev->device,
> -				   &vfio_ap_matrix_driver);
> +				   &vfio_ap_matrix_driver,
> +				   matrix_dev->mdev_types, 1);
>   	if (ret)
>   		goto err_driver;
>   	return 0;
> diff --git a/drivers/s390/crypto/vfio_ap_private.h b/drivers/s390/crypto/vfio_ap_private.h
> index 0191f6bc973a4..5dc5050d03791 100644
> --- a/drivers/s390/crypto/vfio_ap_private.h
> +++ b/drivers/s390/crypto/vfio_ap_private.h
> @@ -46,6 +46,8 @@ struct ap_matrix_dev {
>   	struct mutex lock;
>   	struct ap_driver  *vfio_ap_drv;
>   	struct mdev_parent parent;
> +	struct mdev_type mdev_type;
> +	struct mdev_type *mdev_types[];
>   };
>   
>   extern struct ap_matrix_dev *matrix_dev;
> diff --git a/drivers/vfio/mdev/mdev_core.c b/drivers/vfio/mdev/mdev_core.c
> index d38faed14c689..71c7f4e521a74 100644
> --- a/drivers/vfio/mdev/mdev_core.c
> +++ b/drivers/vfio/mdev/mdev_core.c
> @@ -29,26 +29,6 @@ struct device *mdev_parent_dev(struct mdev_device *mdev)
>   }
>   EXPORT_SYMBOL(mdev_parent_dev);
>   
> -/*
> - * Return the index in supported_type_groups that this mdev_device was created
> - * from.
> - */
> -unsigned int mdev_get_type_group_id(struct mdev_device *mdev)
> -{
> -	return mdev->type->type_group_id;
> -}
> -EXPORT_SYMBOL(mdev_get_type_group_id);
> -
> -/*
> - * Used in mdev_type_attribute sysfs functions to return the index in the
> - * supported_type_groups that the sysfs is called from.
> - */
> -unsigned int mtype_get_type_group_id(struct mdev_type *mtype)
> -{
> -	return mtype->type_group_id;
> -}
> -EXPORT_SYMBOL(mtype_get_type_group_id);
> -
>   /*
>    * Used in mdev_type_attribute sysfs functions to return the parent struct
>    * device
> @@ -85,19 +65,19 @@ static int mdev_device_remove_cb(struct device *dev, void *data)
>    * @parent: parent structure registered
>    * @dev: device structure representing parent device.
>    * @mdev_driver: Device driver to bind to the newly created mdev
> + * @types: Array of supported mdev types
> + * @nr_types: Number of entries in @types
>    *
>    * Returns a negative value on error, otherwise 0.
>    */
>   int mdev_register_parent(struct mdev_parent *parent, struct device *dev,
> -		struct mdev_driver *mdev_driver)
> +		struct mdev_driver *mdev_driver, struct mdev_type **types,
> +		unsigned int nr_types)
>   {
>   	char *env_string = "MDEV_STATE=registered";
>   	char *envp[] = { env_string, NULL };
>   	int ret;
> -
> -	/* check for mandatory ops */
> -	if (!mdev_driver->supported_type_groups)
> -		return -EINVAL;
> +	int i;
>   
>   	memset(parent, 0, sizeof(*parent));
>   	init_rwsem(&parent->unreg_sem);
> @@ -110,9 +90,23 @@ int mdev_register_parent(struct mdev_parent *parent, struct device *dev,
>   			return -ENOMEM;
>   	}
>   
> -	ret = parent_create_sysfs_files(parent);
> -	if (ret)
> +	parent->mdev_types_kset = kset_create_and_add("mdev_supported_types",
> +					       NULL, &parent->dev->kobj);
> +	if (!parent->mdev_types_kset)
> +		return -ENOMEM;
> +
> +	for (i = 0; i < nr_types; i++) {
> +		ret = mdev_type_add(parent, types[i]);
> +		if (ret)
> +			break;
> +	}
> +	parent->types = types;
> +	parent->nr_types = i;
> +
> +	if (ret) {
> +		mdev_unregister_parent(parent);
>   		return ret;
> +	}
>   
>   	ret = class_compat_create_link(mdev_bus_compat_class, dev, NULL);
>   	if (ret)
> @@ -132,13 +126,17 @@ void mdev_unregister_parent(struct mdev_parent *parent)
>   {
>   	char *env_string = "MDEV_STATE=unregistered";
>   	char *envp[] = { env_string, NULL };
> +	int i;
>   
>   	dev_info(parent->dev, "MDEV: Unregistering\n");
>   
> +	for (i = 0; i < parent->nr_types; i++)
> +		mdev_type_remove(parent->types[i]);
> +									

trailing spaces.

>   	down_write(&parent->unreg_sem);
>   	class_compat_remove_link(mdev_bus_compat_class, parent->dev, NULL);
>   	device_for_each_child(parent->dev, NULL, mdev_device_remove_cb);
> -	parent_remove_sysfs_files(parent);
> +	kset_unregister(parent->mdev_types_kset);
>   	up_write(&parent->unreg_sem);
>   
>   	kobject_uevent_env(&parent->dev->kobj, KOBJ_CHANGE, envp);
> diff --git a/drivers/vfio/mdev/mdev_driver.c b/drivers/vfio/mdev/mdev_driver.c
> index 7bd4bb9850e81..1da1ecf76a0d5 100644
> --- a/drivers/vfio/mdev/mdev_driver.c
> +++ b/drivers/vfio/mdev/mdev_driver.c
> @@ -56,10 +56,9 @@ EXPORT_SYMBOL_GPL(mdev_bus_type);
>    **/
>   int mdev_register_driver(struct mdev_driver *drv)
>   {
> -	/* initialize common driver fields */
> +	if (!drv->types_attrs)
> +		return -EINVAL;
>   	drv->driver.bus = &mdev_bus_type;
> -
> -	/* register with core */
>   	return driver_register(&drv->driver);
>   }
>   EXPORT_SYMBOL(mdev_register_driver);
> diff --git a/drivers/vfio/mdev/mdev_private.h b/drivers/vfio/mdev/mdev_private.h
> index 297f911fdc890..6980f504018f3 100644
> --- a/drivers/vfio/mdev/mdev_private.h
> +++ b/drivers/vfio/mdev/mdev_private.h
> @@ -13,14 +13,6 @@
>   int  mdev_bus_register(void);
>   void mdev_bus_unregister(void);
>   
> -struct mdev_type {
> -	struct kobject kobj;
> -	struct kobject *devices_kobj;
> -	struct mdev_parent *parent;
> -	struct list_head next;
> -	unsigned int type_group_id;
> -};
> -
>   extern const struct attribute_group *mdev_device_groups[];
>   
>   #define to_mdev_type_attr(_attr)	\
> @@ -28,12 +20,12 @@ extern const struct attribute_group *mdev_device_groups[];
>   #define to_mdev_type(_kobj)		\
>   	container_of(_kobj, struct mdev_type, kobj)
>   
> -int  parent_create_sysfs_files(struct mdev_parent *parent);
> -void parent_remove_sysfs_files(struct mdev_parent *parent);
> -
>   int  mdev_create_sysfs_files(struct mdev_device *mdev);
>   void mdev_remove_sysfs_files(struct mdev_device *mdev);
>   
> +int mdev_type_add(struct mdev_parent *parent, struct mdev_type *type);
> +void mdev_type_remove(struct mdev_type *type);
> +
>   int mdev_device_create(struct mdev_type *kobj, const guid_t *uuid);
>   int  mdev_device_remove(struct mdev_device *dev);
>   
> diff --git a/drivers/vfio/mdev/mdev_sysfs.c b/drivers/vfio/mdev/mdev_sysfs.c
> index b71ffc5594870..09745e8e120f9 100644
> --- a/drivers/vfio/mdev/mdev_sysfs.c
> +++ b/drivers/vfio/mdev/mdev_sysfs.c
> @@ -80,8 +80,6 @@ static void mdev_type_release(struct kobject *kobj)
>   	struct mdev_type *type = to_mdev_type(kobj);
>   
>   	pr_debug("Releasing group %s\n", kobj->name);
> -	/* Pairs with the get in add_mdev_supported_type() */
> -	put_device(type->parent->dev);
>   	kfree(type);
>   }
>   
> @@ -90,36 +88,17 @@ static struct kobj_type mdev_type_ktype = {
>   	.release = mdev_type_release,
>   };
>   
> -static struct mdev_type *add_mdev_supported_type(struct mdev_parent *parent,
> -						 unsigned int type_group_id)
> +int mdev_type_add(struct mdev_parent *parent, struct mdev_type *type)
>   {
> -	struct mdev_type *type;
> -	struct attribute_group *group =
> -		parent->mdev_driver->supported_type_groups[type_group_id];
>   	int ret;
>   
> -	if (!group->name) {
> -		pr_err("%s: Type name empty!\n", __func__);
> -		return ERR_PTR(-EINVAL);
> -	}
> -
> -	type = kzalloc(sizeof(*type), GFP_KERNEL);
> -	if (!type)
> -		return ERR_PTR(-ENOMEM);
> -
> -	type->kobj.kset = parent->mdev_types_kset;
>   	type->parent = parent;
> -	/* Pairs with the put in mdev_type_release() */
> -	get_device(parent->dev);
> -	type->type_group_id = type_group_id;
> -
> -	ret = kobject_init_and_add(&type->kobj, &mdev_type_ktype, NULL,
> -				   "%s-%s", dev_driver_string(parent->dev),
> -				   group->name);
> -	if (ret) {
> -		kobject_put(&type->kobj);
> -		return ERR_PTR(ret);
> -	}
> +	type->kobj.kset = parent->mdev_types_kset;
> +	ret = kobject_init_and_add(&type->kobj, &mdev_type_ktype, NULL, "%s-%s",
> +				   dev_driver_string(parent->dev),
> +				   type->sysfs_name);
> +	if (ret)
> +		return ret;
>   
>   	ret = sysfs_create_file(&type->kobj, &mdev_type_attr_create.attr);
>   	if (ret)
> @@ -131,13 +110,10 @@ static struct mdev_type *add_mdev_supported_type(struct mdev_parent *parent,
>   		goto attr_devices_failed;
>   	}
>   
> -	ret = sysfs_create_files(&type->kobj,
> -				 (const struct attribute **)group->attrs);
> -	if (ret) {
> -		ret = -ENOMEM;
> +	ret = sysfs_create_files(&type->kobj, parent->mdev_driver->types_attrs);
> +	if (ret)
>   		goto attrs_failed;
> -	}
> -	return type;
> +	return 0;
>   
>   attrs_failed:
>   	kobject_put(type->devices_kobj);
> @@ -146,80 +122,19 @@ static struct mdev_type *add_mdev_supported_type(struct mdev_parent *parent,
>   attr_create_failed:
>   	kobject_del(&type->kobj);
>   	kobject_put(&type->kobj);
> -	return ERR_PTR(ret);
> +	return ret;
>   }
>   
> -static void remove_mdev_supported_type(struct mdev_type *type)
> +void mdev_type_remove(struct mdev_type *type)
>   {
> -	struct attribute_group *group =
> -		type->parent->mdev_driver->supported_type_groups[type->type_group_id];
> +	sysfs_remove_files(&type->kobj, type->parent->mdev_driver->types_attrs);
>   
> -	sysfs_remove_files(&type->kobj,
> -			   (const struct attribute **)group->attrs);
>   	kobject_put(type->devices_kobj);
>   	sysfs_remove_file(&type->kobj, &mdev_type_attr_create.attr);
>   	kobject_del(&type->kobj);
>   	kobject_put(&type->kobj);
>   }
>   
> -static int add_mdev_supported_type_groups(struct mdev_parent *parent)
> -{
> -	int i;
> -
> -	for (i = 0; parent->mdev_driver->supported_type_groups[i]; i++) {
> -		struct mdev_type *type;
> -
> -		type = add_mdev_supported_type(parent, i);
> -		if (IS_ERR(type)) {
> -			struct mdev_type *ltype, *tmp;
> -
> -			list_for_each_entry_safe(ltype, tmp, &parent->type_list,
> -						  next) {
> -				list_del(&ltype->next);
> -				remove_mdev_supported_type(ltype);
> -			}
> -			return PTR_ERR(type);
> -		}
> -		list_add(&type->next, &parent->type_list);
> -	}
> -	return 0;
> -}
> -
> -/* mdev sysfs functions */
> -void parent_remove_sysfs_files(struct mdev_parent *parent)
> -{
> -	struct mdev_type *type, *tmp;
> -
> -	list_for_each_entry_safe(type, tmp, &parent->type_list, next) {
> -		list_del(&type->next);
> -		remove_mdev_supported_type(type);
> -	}
> -
> -	kset_unregister(parent->mdev_types_kset);
> -}
> -
> -int parent_create_sysfs_files(struct mdev_parent *parent)
> -{
> -	int ret;
> -
> -	parent->mdev_types_kset = kset_create_and_add("mdev_supported_types",
> -					       NULL, &parent->dev->kobj);
> -
> -	if (!parent->mdev_types_kset)
> -		return -ENOMEM;
> -
> -	INIT_LIST_HEAD(&parent->type_list);
> -
> -	ret = add_mdev_supported_type_groups(parent);
> -	if (ret)
> -		goto create_err;
> -	return 0;
> -
> -create_err:
> -	kset_unregister(parent->mdev_types_kset);
> -	return ret;
> -}
> -
>   static ssize_t remove_store(struct device *dev, struct device_attribute *attr,
>   			    const char *buf, size_t count)
>   {
> diff --git a/include/linux/mdev.h b/include/linux/mdev.h
> index 327ce3e5c6b5f..d28ddf0f561a0 100644
> --- a/include/linux/mdev.h
> +++ b/include/linux/mdev.h
> @@ -23,14 +23,27 @@ struct mdev_device {
>   	bool active;
>   };
>   
> +struct mdev_type {
> +	/* set by the driver before calling mdev_register parent: */
> +	char sysfs_name[32];
> +
> +	/* set by the core, can be used drivers */
> +	struct mdev_parent *parent;
> +
> +	/* internal only */
> +	struct kobject kobj;
> +	struct kobject *devices_kobj;
> +};
> +
>   /* embedded into the struct device that the mdev devices hang off */
>   struct mdev_parent {
>   	struct device *dev;
>   	struct mdev_driver *mdev_driver;
>   	struct kset *mdev_types_kset;
> -	struct list_head type_list;
>   	/* Synchronize device creation/removal with parent unregistration */
>   	struct rw_semaphore unreg_sem;
> +	struct mdev_type **types;
> +	unsigned int nr_types;
>   };
>   
>   static inline struct mdev_device *to_mdev_device(struct device *dev)
> @@ -38,8 +51,6 @@ static inline struct mdev_device *to_mdev_device(struct device *dev)
>   	return container_of(dev, struct mdev_device, dev);
>   }
>   
> -unsigned int mdev_get_type_group_id(struct mdev_device *mdev);
> -unsigned int mtype_get_type_group_id(struct mdev_type *mtype);
>   struct device *mtype_get_parent_dev(struct mdev_type *mtype);
>   
>   /* interface for exporting mdev supported type attributes */
> @@ -66,15 +77,13 @@ struct mdev_type_attribute mdev_type_attr_##_name =		\
>    * struct mdev_driver - Mediated device driver
>    * @probe: called when new device created
>    * @remove: called when device removed
> - * @supported_type_groups: Attributes to define supported types. It is mandatory
> - *			to provide supported types.
> + * @types_attrs: attributes to the type kobjects.
>    * @driver: device driver structure
> - *
>    **/
>   struct mdev_driver {
>   	int (*probe)(struct mdev_device *dev);
>   	void (*remove)(struct mdev_device *dev);
> -	struct attribute_group **supported_type_groups;
> +	const struct attribute * const *types_attrs;
>   	struct device_driver driver;
>   };
>   
> @@ -86,7 +95,8 @@ static inline const guid_t *mdev_uuid(struct mdev_device *mdev)
>   extern struct bus_type mdev_bus_type;
>   
>   int mdev_register_parent(struct mdev_parent *parent, struct device *dev,
> -		struct mdev_driver *mdev_driver);
> +		struct mdev_driver *mdev_driver, struct mdev_type **types,
> +		unsigned int nr_types);
>   void mdev_unregister_parent(struct mdev_parent *parent);
>   
>   int mdev_register_driver(struct mdev_driver *drv);
> diff --git a/samples/vfio-mdev/mbochs.c b/samples/vfio-mdev/mbochs.c
> index 30b3643b3b389..1069f561cb012 100644
> --- a/samples/vfio-mdev/mbochs.c
> +++ b/samples/vfio-mdev/mbochs.c
> @@ -99,23 +99,27 @@ MODULE_PARM_DESC(mem, "megabytes available to " MBOCHS_NAME " devices");
>   #define MBOCHS_TYPE_2 "medium"
>   #define MBOCHS_TYPE_3 "large"
>   
> -static const struct mbochs_type {
> +static struct mbochs_type {
> +	struct mdev_type type;
>   	const char *name;
>   	u32 mbytes;
>   	u32 max_x;
>   	u32 max_y;
>   } mbochs_types[] = {
>   	{
> +		.type.sysfs_name	= MBOCHS_TYPE_1,
>   		.name	= MBOCHS_CLASS_NAME "-" MBOCHS_TYPE_1,
>   		.mbytes = 4,
>   		.max_x  = 800,
>   		.max_y  = 600,
>   	}, {
> +		.type.sysfs_name	= MBOCHS_TYPE_2,
>   		.name	= MBOCHS_CLASS_NAME "-" MBOCHS_TYPE_2,
>   		.mbytes = 16,
>   		.max_x  = 1920,
>   		.max_y  = 1440,
>   	}, {
> +		.type.sysfs_name	= MBOCHS_TYPE_3,
>   		.name	= MBOCHS_CLASS_NAME "-" MBOCHS_TYPE_3,
>   		.mbytes = 64,
>   		.max_x  = 0,
> @@ -123,6 +127,11 @@ static const struct mbochs_type {
>   	},
>   };
>   
> +static struct mdev_type *mbochs_mdev_types[] = {
> +	&mbochs_types[0].type,
> +	&mbochs_types[1].type,
> +	&mbochs_types[2].type,
> +};
>   
>   static dev_t		mbochs_devt;
>   static struct class	*mbochs_class;
> @@ -508,8 +517,8 @@ static int mbochs_reset(struct mdev_state *mdev_state)
>   static int mbochs_probe(struct mdev_device *mdev)
>   {
>   	int avail_mbytes = atomic_read(&mbochs_avail_mbytes);
> -	const struct mbochs_type *type =
> -		&mbochs_types[mdev_get_type_group_id(mdev)];
> +	struct mbochs_type *type =
> +		container_of(mdev->type, struct mbochs_type, type);
>   	struct device *dev = mdev_dev(mdev);
>   	struct mdev_state *mdev_state;
>   	int ret = -ENOMEM;
> @@ -1328,8 +1337,8 @@ static const struct attribute_group *mdev_dev_groups[] = {
>   static ssize_t name_show(struct mdev_type *mtype,
>   			 struct mdev_type_attribute *attr, char *buf)
>   {
> -	const struct mbochs_type *type =
> -		&mbochs_types[mtype_get_type_group_id(mtype)];
> +	struct mbochs_type *type =
> +		container_of(mtype, struct mbochs_type, type);
>   
>   	return sprintf(buf, "%s\n", type->name);
>   }
> @@ -1338,8 +1347,8 @@ static MDEV_TYPE_ATTR_RO(name);
>   static ssize_t description_show(struct mdev_type *mtype,
>   				struct mdev_type_attribute *attr, char *buf)
>   {
> -	const struct mbochs_type *type =
> -		&mbochs_types[mtype_get_type_group_id(mtype)];
> +	struct mbochs_type *type =
> +		container_of(mtype, struct mbochs_type, type);
>   
>   	return sprintf(buf, "virtual display, %d MB video memory\n",
>   		       type ? type->mbytes  : 0);
> @@ -1350,8 +1359,8 @@ static ssize_t available_instances_show(struct mdev_type *mtype,
>   					struct mdev_type_attribute *attr,
>   					char *buf)
>   {
> -	const struct mbochs_type *type =
> -		&mbochs_types[mtype_get_type_group_id(mtype)];
> +	struct mbochs_type *type =
> +		container_of(mtype, struct mbochs_type, type);
>   	int count = atomic_read(&mbochs_avail_mbytes) / type->mbytes;
>   
>   	return sprintf(buf, "%d\n", count);
> @@ -1365,7 +1374,7 @@ static ssize_t device_api_show(struct mdev_type *mtype,
>   }
>   static MDEV_TYPE_ATTR_RO(device_api);
>   
> -static struct attribute *mdev_types_attrs[] = {
> +static const struct attribute *mdev_types_attrs[] = {
>   	&mdev_type_attr_name.attr,
>   	&mdev_type_attr_description.attr,
>   	&mdev_type_attr_device_api.attr,
> @@ -1373,28 +1382,6 @@ static struct attribute *mdev_types_attrs[] = {
>   	NULL,
>   };
>   
> -static struct attribute_group mdev_type_group1 = {
> -	.name  = MBOCHS_TYPE_1,
> -	.attrs = mdev_types_attrs,
> -};
> -
> -static struct attribute_group mdev_type_group2 = {
> -	.name  = MBOCHS_TYPE_2,
> -	.attrs = mdev_types_attrs,
> -};
> -
> -static struct attribute_group mdev_type_group3 = {
> -	.name  = MBOCHS_TYPE_3,
> -	.attrs = mdev_types_attrs,
> -};
> -
> -static struct attribute_group *mdev_type_groups[] = {
> -	&mdev_type_group1,
> -	&mdev_type_group2,
> -	&mdev_type_group3,
> -	NULL,
> -};
> -
>   static const struct vfio_device_ops mbochs_dev_ops = {
>   	.close_device = mbochs_close_device,
>   	.read = mbochs_read,
> @@ -1412,7 +1399,7 @@ static struct mdev_driver mbochs_driver = {
>   	},
>   	.probe = mbochs_probe,
>   	.remove	= mbochs_remove,
> -	.supported_type_groups = mdev_type_groups,
> +	.types_attrs = mdev_types_attrs,
>   };
>   
>   static const struct file_operations vd_fops = {
> @@ -1457,7 +1444,9 @@ static int __init mbochs_dev_init(void)
>   	if (ret)
>   		goto err_class;
>   
> -	ret = mdev_register_parent(&mbochs_parent, &mbochs_dev, &mbochs_driver);
> +	ret = mdev_register_parent(&mbochs_parent, &mbochs_dev, &mbochs_driver,
> +				   mbochs_mdev_types,
> +				   ARRAY_SIZE(mbochs_mdev_types));
>   	if (ret)
>   		goto err_device;
>   
> diff --git a/samples/vfio-mdev/mdpy.c b/samples/vfio-mdev/mdpy.c
> index 132bb055628a6..40b1c8a58157c 100644
> --- a/samples/vfio-mdev/mdpy.c
> +++ b/samples/vfio-mdev/mdpy.c
> @@ -51,7 +51,8 @@ MODULE_PARM_DESC(count, "number of " MDPY_NAME " devices");
>   #define MDPY_TYPE_2 "xga"
>   #define MDPY_TYPE_3 "hd"
>   
> -static const struct mdpy_type {
> +static struct mdpy_type {
> +	struct mdev_type type;
>   	const char *name;
>   	u32 format;
>   	u32 bytepp;
> @@ -59,18 +60,21 @@ static const struct mdpy_type {
>   	u32 height;
>   } mdpy_types[] = {
>   	{
> +		.type.sysfs_name 	= MDPY_TYPE_1,
>   		.name	= MDPY_CLASS_NAME "-" MDPY_TYPE_1,
>   		.format = DRM_FORMAT_XRGB8888,
>   		.bytepp = 4,
>   		.width	= 640,
>   		.height = 480,
>   	}, {
> +		.type.sysfs_name 	= MDPY_TYPE_2,
>   		.name	= MDPY_CLASS_NAME "-" MDPY_TYPE_2,
>   		.format = DRM_FORMAT_XRGB8888,
>   		.bytepp = 4,
>   		.width	= 1024,
>   		.height = 768,
>   	}, {
> +		.type.sysfs_name 	= MDPY_TYPE_3,
>   		.name	= MDPY_CLASS_NAME "-" MDPY_TYPE_3,
>   		.format = DRM_FORMAT_XRGB8888,
>   		.bytepp = 4,
> @@ -78,6 +82,12 @@ static const struct mdpy_type {
>   		.height = 1080,
>   	},
>   };
> +	

trailing spaces.

> +static struct mdev_type *mdpy_mdev_types[] = {
> +	&mdpy_types[0].type,
> +	&mdpy_types[1].type,
> +	&mdpy_types[2].type,
> +};
>   
>   static dev_t		mdpy_devt;
>   static struct class	*mdpy_class;
> @@ -219,7 +229,7 @@ static int mdpy_reset(struct mdev_state *mdev_state)
>   static int mdpy_probe(struct mdev_device *mdev)
>   {
>   	const struct mdpy_type *type =
> -		&mdpy_types[mdev_get_type_group_id(mdev)];
> +		container_of(mdev->type, struct mdpy_type, type);
>   	struct device *dev = mdev_dev(mdev);
>   	struct mdev_state *mdev_state;
>   	u32 fbsize;
> @@ -644,8 +654,7 @@ static const struct attribute_group *mdev_dev_groups[] = {
>   static ssize_t name_show(struct mdev_type *mtype,
>   			 struct mdev_type_attribute *attr, char *buf)
>   {
> -	const struct mdpy_type *type =
> -		&mdpy_types[mtype_get_type_group_id(mtype)];
> +	struct mdpy_type *type = container_of(mtype, struct mdpy_type, type);
>   
>   	return sprintf(buf, "%s\n", type->name);
>   }
> @@ -654,8 +663,7 @@ static MDEV_TYPE_ATTR_RO(name);
>   static ssize_t description_show(struct mdev_type *mtype,
>   				struct mdev_type_attribute *attr, char *buf)
>   {
> -	const struct mdpy_type *type =
> -		&mdpy_types[mtype_get_type_group_id(mtype)];
> +	struct mdpy_type *type = container_of(mtype, struct mdpy_type, type);
>   
>   	return sprintf(buf, "virtual display, %dx%d framebuffer\n",
>   		       type->width, type->height);
> @@ -677,7 +685,7 @@ static ssize_t device_api_show(struct mdev_type *mtype,
>   }
>   static MDEV_TYPE_ATTR_RO(device_api);
>   
> -static struct attribute *mdev_types_attrs[] = {
> +static const struct attribute *mdev_types_attrs[] = {
>   	&mdev_type_attr_name.attr,
>   	&mdev_type_attr_description.attr,
>   	&mdev_type_attr_device_api.attr,
> @@ -685,28 +693,6 @@ static struct attribute *mdev_types_attrs[] = {
>   	NULL,
>   };
>   
> -static struct attribute_group mdev_type_group1 = {
> -	.name  = MDPY_TYPE_1,
> -	.attrs = mdev_types_attrs,
> -};
> -
> -static struct attribute_group mdev_type_group2 = {
> -	.name  = MDPY_TYPE_2,
> -	.attrs = mdev_types_attrs,
> -};
> -
> -static struct attribute_group mdev_type_group3 = {
> -	.name  = MDPY_TYPE_3,
> -	.attrs = mdev_types_attrs,
> -};
> -
> -static struct attribute_group *mdev_type_groups[] = {
> -	&mdev_type_group1,
> -	&mdev_type_group2,
> -	&mdev_type_group3,
> -	NULL,
> -};
> -
>   static const struct vfio_device_ops mdpy_dev_ops = {
>   	.read = mdpy_read,
>   	.write = mdpy_write,
> @@ -723,7 +709,7 @@ static struct mdev_driver mdpy_driver = {
>   	},
>   	.probe = mdpy_probe,
>   	.remove	= mdpy_remove,
> -	.supported_type_groups = mdev_type_groups,
> +	.types_attrs = mdev_types_attrs,
>   };
>   
>   static const struct file_operations vd_fops = {
> @@ -766,7 +752,9 @@ static int __init mdpy_dev_init(void)
>   	if (ret)
>   		goto err_class;
>   
> -	ret = mdev_register_parent(&mdpy_parent, &mdpy_dev, &mdpy_driver);
> +	ret = mdev_register_parent(&mdpy_parent, &mdpy_dev, &mdpy_driver,
> +				   mdpy_mdev_types,
> +				   ARRAY_SIZE(mdpy_mdev_types));
>   	if (ret)
>   		goto err_device;
>   
> diff --git a/samples/vfio-mdev/mtty.c b/samples/vfio-mdev/mtty.c
> index 8ba5f6084a093..029a19ef8ce7b 100644
> --- a/samples/vfio-mdev/mtty.c
> +++ b/samples/vfio-mdev/mtty.c
> @@ -143,6 +143,20 @@ struct mdev_state {
>   	int nr_ports;
>   };
>   
> +static struct mtty_type {
> +	struct mdev_type type;
> +	int nr_ports;
> +	const char *name;
> +} mtty_types[2] = {
> +	{ .nr_ports = 1, .type.sysfs_name = "1", .name = "Single port serial" },
> +	{ .nr_ports = 2, .type.sysfs_name = "2", .name = "Dual port serial" },
> +};
> +
> +static struct mdev_type *mtty_mdev_types[] = {
> +	&mtty_types[0].type,
> +	&mtty_types[1].type,
> +};
> +
>   static atomic_t mdev_avail_ports = ATOMIC_INIT(MAX_MTTYS);
>   
>   static const struct file_operations vd_fops = {
> @@ -704,16 +718,18 @@ static ssize_t mdev_access(struct mdev_state *mdev_state, u8 *buf, size_t count,
>   
>   static int mtty_probe(struct mdev_device *mdev)
>   {
> +	struct mtty_type *type =
> +		container_of(mdev->type, struct mtty_type, type);
>   	struct mdev_state *mdev_state;
> -	int nr_ports = mdev_get_type_group_id(mdev) + 1;
>   	int avail_ports = atomic_read(&mdev_avail_ports);
>   	int ret;
>   
>   	do {
> -		if (avail_ports < nr_ports)
> +		if (avail_ports < type->nr_ports)
>   			return -ENOSPC;
>   	} while (!atomic_try_cmpxchg(&mdev_avail_ports,
> -				     &avail_ports, avail_ports - nr_ports));
> +				     &avail_ports,
> +				     avail_ports - type->nr_ports));
>   
>   	mdev_state = kzalloc(sizeof(struct mdev_state), GFP_KERNEL);
>   	if (mdev_state == NULL) {
> @@ -723,13 +739,13 @@ static int mtty_probe(struct mdev_device *mdev)
>   
>   	vfio_init_group_dev(&mdev_state->vdev, &mdev->dev, &mtty_dev_ops);
>   
> -	mdev_state->nr_ports = nr_ports;
> +	mdev_state->nr_ports = type->nr_ports;
>   	mdev_state->irq_index = -1;
>   	mdev_state->s[0].max_fifo_size = MAX_FIFO_SIZE;
>   	mdev_state->s[1].max_fifo_size = MAX_FIFO_SIZE;
>   	mutex_init(&mdev_state->rxtx_lock);
> -	mdev_state->vconfig = kzalloc(MTTY_CONFIG_SPACE_SIZE, GFP_KERNEL);
>   
> +	mdev_state->vconfig = kzalloc(MTTY_CONFIG_SPACE_SIZE, GFP_KERNEL);
>   	if (mdev_state->vconfig == NULL) {
>   		ret = -ENOMEM;
>   		goto err_state;
> @@ -752,7 +768,7 @@ static int mtty_probe(struct mdev_device *mdev)
>   	vfio_uninit_group_dev(&mdev_state->vdev);
>   	kfree(mdev_state);
>   err_nr_ports:
> -	atomic_add(nr_ports, &mdev_avail_ports);
> +	atomic_add(type->nr_ports, &mdev_avail_ports);
>   	return ret;
>   }
>   
> @@ -1233,11 +1249,9 @@ static const struct attribute_group *mdev_dev_groups[] = {
>   static ssize_t name_show(struct mdev_type *mtype,
>   			 struct mdev_type_attribute *attr, char *buf)
>   {
> -	static const char *name_str[2] = { "Single port serial",
> -					   "Dual port serial" };
> +	struct mtty_type *type = container_of(mtype, struct mtty_type, type);
>   
> -	return sysfs_emit(buf, "%s\n",
> -			  name_str[mtype_get_type_group_id(mtype)]);
> +	return sysfs_emit(buf, "%s\n", type->name);
>   }
>   
>   static MDEV_TYPE_ATTR_RO(name);
> @@ -1246,9 +1260,10 @@ static ssize_t available_instances_show(struct mdev_type *mtype,
>   					struct mdev_type_attribute *attr,
>   					char *buf)
>   {
> -	unsigned int ports = mtype_get_type_group_id(mtype) + 1;
> +	struct mtty_type *type = container_of(mtype, struct mtty_type, type);
>   
> -	return sprintf(buf, "%d\n", atomic_read(&mdev_avail_ports) / ports);
> +	return sprintf(buf, "%d\n", atomic_read(&mdev_avail_ports) /
> +			type->nr_ports);
>   }
>   
>   static MDEV_TYPE_ATTR_RO(available_instances);
> @@ -1261,29 +1276,13 @@ static ssize_t device_api_show(struct mdev_type *mtype,
>   
>   static MDEV_TYPE_ATTR_RO(device_api);
>   
> -static struct attribute *mdev_types_attrs[] = {
> +static const struct attribute *mdev_types_attrs[] = {
>   	&mdev_type_attr_name.attr,
>   	&mdev_type_attr_device_api.attr,
>   	&mdev_type_attr_available_instances.attr,
>   	NULL,
>   };
>   
> -static struct attribute_group mdev_type_group1 = {
> -	.name  = "1",
> -	.attrs = mdev_types_attrs,
> -};
> -
> -static struct attribute_group mdev_type_group2 = {
> -	.name  = "2",
> -	.attrs = mdev_types_attrs,
> -};
> -
> -static struct attribute_group *mdev_type_groups[] = {
> -	&mdev_type_group1,
> -	&mdev_type_group2,
> -	NULL,
> -};
> -
>   static const struct vfio_device_ops mtty_dev_ops = {
>   	.name = "vfio-mtty",
>   	.read = mtty_read,
> @@ -1300,7 +1299,7 @@ static struct mdev_driver mtty_driver = {
>   	},
>   	.probe = mtty_probe,
>   	.remove	= mtty_remove,
> -	.supported_type_groups = mdev_type_groups,
> +	.types_attrs = mdev_types_attrs,
>   };
>   
>   static void mtty_device_release(struct device *dev)
> @@ -1352,7 +1351,8 @@ static int __init mtty_dev_init(void)
>   		goto err_class;
>   
>   	ret = mdev_register_parent(&mtty_dev.parent, &mtty_dev.dev,
> -				   &mtty_driver);
> +				   &mtty_driver, mtty_mdev_types,
> +				   ARRAY_SIZE(mtty_mdev_types));
>   	if (ret)
>   		goto err_device;
>   	return 0;
Tian, Kevin June 14, 2022, 10:06 a.m. UTC | #2
> From: Christoph Hellwig
> Sent: Tuesday, June 14, 2022 12:54 PM
> @@ -131,6 +131,13 @@ int intel_gvt_init_vgpu_types(struct intel_gvt *gvt)
>  	if (!gvt->types)
>  		return -ENOMEM;
> 
> +	gvt->mdev_types = kcalloc(num_types, sizeof(*gvt->mdev_types),
> +			     GFP_KERNEL);
> +	if (!gvt->mdev_types) {
> +		kfree(gvt->types);
> +		return -ENOMEM;
> +	}
> +
>  	min_low = MB_TO_BYTES(32);
>  	for (i = 0; i < num_types; ++i) {
>  		if (low_avail / vgpu_types[i].low_mm == 0)
> @@ -150,19 +157,21 @@ int intel_gvt_init_vgpu_types(struct intel_gvt *gvt)
>  						   high_avail /
> vgpu_types[i].high_mm);

there is a memory leak a few lines above:

if (vgpu_types[i].weight < 1 ||
	vgpu_types[i].weight > VGPU_MAX_WEIGHT)
	return -EINVAL;

both old code and this patch forgot to free the buffers upon error.
Zhenyu Wang June 15, 2022, 6:11 a.m. UTC | #3
On 2022.06.15 08:28:40 +0200, Christoph Hellwig wrote:
> On Tue, Jun 14, 2022 at 10:06:16AM +0000, Tian, Kevin wrote:
> > > +	gvt->mdev_types = kcalloc(num_types, sizeof(*gvt->mdev_types),
> > > +			     GFP_KERNEL);
> > > +	if (!gvt->mdev_types) {
> > > +		kfree(gvt->types);
> > > +		return -ENOMEM;
> > > +	}
> > > +
> > >  	min_low = MB_TO_BYTES(32);
> > >  	for (i = 0; i < num_types; ++i) {
> > >  		if (low_avail / vgpu_types[i].low_mm == 0)
> > > @@ -150,19 +157,21 @@ int intel_gvt_init_vgpu_types(struct intel_gvt *gvt)
> > >  						   high_avail /
> > > vgpu_types[i].high_mm);
> > 
> > there is a memory leak a few lines above:
> > 
> > if (vgpu_types[i].weight < 1 ||
> > 	vgpu_types[i].weight > VGPU_MAX_WEIGHT)
> > 	return -EINVAL;
> > 
> > both old code and this patch forgot to free the buffers upon error.

Sorry about that! It should be my original fault...

> 
> Yeah.  I'll add a patch to the beginning of the series to fix the
> existing leak and will then make sure to not leak the new allocation
> either.

Thanks a lot!
Christoph Hellwig June 15, 2022, 6:28 a.m. UTC | #4
On Tue, Jun 14, 2022 at 10:06:16AM +0000, Tian, Kevin wrote:
> > +	gvt->mdev_types = kcalloc(num_types, sizeof(*gvt->mdev_types),
> > +			     GFP_KERNEL);
> > +	if (!gvt->mdev_types) {
> > +		kfree(gvt->types);
> > +		return -ENOMEM;
> > +	}
> > +
> >  	min_low = MB_TO_BYTES(32);
> >  	for (i = 0; i < num_types; ++i) {
> >  		if (low_avail / vgpu_types[i].low_mm == 0)
> > @@ -150,19 +157,21 @@ int intel_gvt_init_vgpu_types(struct intel_gvt *gvt)
> >  						   high_avail /
> > vgpu_types[i].high_mm);
> 
> there is a memory leak a few lines above:
> 
> if (vgpu_types[i].weight < 1 ||
> 	vgpu_types[i].weight > VGPU_MAX_WEIGHT)
> 	return -EINVAL;
> 
> both old code and this patch forgot to free the buffers upon error.

Yeah.  I'll add a patch to the beginning of the series to fix the
existing leak and will then make sure to not leak the new allocation
either.
Kirti Wankhede June 15, 2022, 7:55 p.m. UTC | #5
On 6/14/2022 10:24 AM, Christoph Hellwig wrote:
<snip>

    /*
>    * Used in mdev_type_attribute sysfs functions to return the parent struct
>    * device
> @@ -85,19 +65,19 @@ static int mdev_device_remove_cb(struct device *dev, void *data)
>    * @parent: parent structure registered
>    * @dev: device structure representing parent device.
>    * @mdev_driver: Device driver to bind to the newly created mdev
> + * @types: Array of supported mdev types
> + * @nr_types: Number of entries in @types
>    *
>    * Returns a negative value on error, otherwise 0.
>    */
>   int mdev_register_parent(struct mdev_parent *parent, struct device *dev,
> -		struct mdev_driver *mdev_driver)
> +		struct mdev_driver *mdev_driver, struct mdev_type **types,
> +		unsigned int nr_types)
>   {
>   	char *env_string = "MDEV_STATE=registered";
>   	char *envp[] = { env_string, NULL };
>   	int ret;
> -
> -	/* check for mandatory ops */
> -	if (!mdev_driver->supported_type_groups)
> -		return -EINVAL;
> +	int i;
>   
>   	memset(parent, 0, sizeof(*parent));
>   	init_rwsem(&parent->unreg_sem);
> @@ -110,9 +90,23 @@ int mdev_register_parent(struct mdev_parent *parent, struct device *dev,
>   			return -ENOMEM;
>   	}
>   
> -	ret = parent_create_sysfs_files(parent);
> -	if (ret)
> +	parent->mdev_types_kset = kset_create_and_add("mdev_supported_types",
> +					       NULL, &parent->dev->kobj);
> +	if (!parent->mdev_types_kset)
> +		return -ENOMEM;
> +
> +	for (i = 0; i < nr_types; i++) {
> +		ret = mdev_type_add(parent, types[i]);
> +		if (ret)
> +			break;
> +	}

Above code should be in parent_create_sysfs_files(), all sysfs related 
stuff should be placed in mdev_sysfs.c

> +	parent->types = types;
> +	parent->nr_types = i;
> +
> +	if (ret) {
> +		mdev_unregister_parent(parent);
>   		return ret;
> +	}
>   
>   	ret = class_compat_create_link(mdev_bus_compat_class, dev, NULL);
>   	if (ret)
> @@ -132,13 +126,17 @@ void mdev_unregister_parent(struct mdev_parent *parent)
>   {
>   	char *env_string = "MDEV_STATE=unregistered";
>   	char *envp[] = { env_string, NULL };
> +	int i;
>   
>   	dev_info(parent->dev, "MDEV: Unregistering\n");
>   
> +	for (i = 0; i < parent->nr_types; i++)
> +		mdev_type_remove(parent->types[i]);
> +									
>   	down_write(&parent->unreg_sem);
>   	class_compat_remove_link(mdev_bus_compat_class, parent->dev, NULL);
>   	device_for_each_child(parent->dev, NULL, mdev_device_remove_cb);
> -	parent_remove_sysfs_files(parent);
> +	kset_unregister(parent->mdev_types_kset);
>   	up_write(&parent->unreg_sem);
>

Same as above, parent_remove_sysfs_files() can be updated accordingly 
instead of moving it here.

Rest looks fine to me.

Thanks,
Kirti
Christoph Hellwig June 19, 2022, 7:42 a.m. UTC | #6
On Thu, Jun 16, 2022 at 01:25:25AM +0530, Kirti Wankhede wrote:
>> +	parent->mdev_types_kset = kset_create_and_add("mdev_supported_types",
>> +					       NULL, &parent->dev->kobj);
>> +	if (!parent->mdev_types_kset)
>> +		return -ENOMEM;
>> +
>> +	for (i = 0; i < nr_types; i++) {
>> +		ret = mdev_type_add(parent, types[i]);
>> +		if (ret)
>> +			break;
>> +	}
>
> Above code should be in parent_create_sysfs_files(), all sysfs related 
> stuff should be placed in mdev_sysfs.c

Yes, it could. But why?  This is core logic of the interface and has
nothing to do with sysfs.
Jason Gunthorpe June 23, 2022, 9:27 p.m. UTC | #7
On Tue, Jun 14, 2022 at 06:54:18AM +0200, Christoph Hellwig wrote:

> diff --git a/drivers/vfio/mdev/mdev_sysfs.c b/drivers/vfio/mdev/mdev_sysfs.c
> index b71ffc5594870..09745e8e120f9 100644
> --- a/drivers/vfio/mdev/mdev_sysfs.c
> +++ b/drivers/vfio/mdev/mdev_sysfs.c
> @@ -80,8 +80,6 @@ static void mdev_type_release(struct kobject *kobj)
>  	struct mdev_type *type = to_mdev_type(kobj);
>  
>  	pr_debug("Releasing group %s\n", kobj->name);
> -	/* Pairs with the get in add_mdev_supported_type() */
> -	put_device(type->parent->dev);

I couldn't figure out why is it now OK to delete this? 

It still looks required because mdev_core.c continues to use
mdev-type->parent in various places and that pointer was being
protected by the

   kobject_get(&type->kobj);

in mdev_device_create() through the above kref..

eg after the whole series is applied this looks troubled:

	/* Pairs with the get in mdev_device_create() */
	kobject_put(&mdev->type->kobj);

	mutex_lock(&mdev_list_lock);
	list_del(&mdev->next);
	mdev->type->parent->available_instances++;
        ^^^^^^^^^^^^^^^^^^^^^

As there is now nothing keeping parent or type alive?

I think what would be good here is to directly
get_device(type->parent->dev) in mdev_device_create() and put in
mdev_device_release() then it is really clear how parent and
parent->dev are kept alive.

It looks like we also have to keep the type around and ref'd too for
the sysfs manipulation.

But overall I think this is fine

Jason
Kirti Wankhede June 24, 2022, 12:32 p.m. UTC | #8
On 6/24/2022 2:57 AM, Jason Gunthorpe wrote:
> On Tue, Jun 14, 2022 at 06:54:18AM +0200, Christoph Hellwig wrote:
> 
>> diff --git a/drivers/vfio/mdev/mdev_sysfs.c b/drivers/vfio/mdev/mdev_sysfs.c
>> index b71ffc5594870..09745e8e120f9 100644
>> --- a/drivers/vfio/mdev/mdev_sysfs.c
>> +++ b/drivers/vfio/mdev/mdev_sysfs.c
>> @@ -80,8 +80,6 @@ static void mdev_type_release(struct kobject *kobj)
>>   	struct mdev_type *type = to_mdev_type(kobj);
>>   
>>   	pr_debug("Releasing group %s\n", kobj->name);
>> -	/* Pairs with the get in add_mdev_supported_type() */
>> -	put_device(type->parent->dev);
> 
> I couldn't figure out why is it now OK to delete this?
> 
> It still looks required because mdev_core.c continues to use
> mdev-type->parent in various places and that pointer was being
> protected by the
> 
>     kobject_get(&type->kobj);
> 
> in mdev_device_create() through the above kref..
> 
> eg after the whole series is applied this looks troubled:
> 
> 	/* Pairs with the get in mdev_device_create() */
> 	kobject_put(&mdev->type->kobj);
> 
> 	mutex_lock(&mdev_list_lock);
> 	list_del(&mdev->next);
> 	mdev->type->parent->available_instances++;
>          ^^^^^^^^^^^^^^^^^^^^^
> 
> As there is now nothing keeping parent or type alive?
> 
> I think what would be good here is to directly
> get_device(type->parent->dev) in mdev_device_create() and put in
> mdev_device_release() then it is really clear how parent and
> parent->dev are kept alive.
> 
> It looks like we also have to keep the type around and ref'd too for
> the sysfs manipulation.
> 
> But overall I think this is fine
> 

See my reply on [PATCH 2/13].

Thanks,
Kirti



> Jason
diff mbox series

Patch

diff --git a/Documentation/driver-api/vfio-mediated-device.rst b/Documentation/driver-api/vfio-mediated-device.rst
index 3749f59c855fa..599008bdc1dcb 100644
--- a/Documentation/driver-api/vfio-mediated-device.rst
+++ b/Documentation/driver-api/vfio-mediated-device.rst
@@ -105,7 +105,7 @@  structure to represent a mediated device's driver::
      struct mdev_driver {
 	     int  (*probe)  (struct mdev_device *dev);
 	     void (*remove) (struct mdev_device *dev);
-	     struct attribute_group **supported_type_groups;
+	     const struct attribute * const *types_attrs;
 	     struct device_driver    driver;
      };
 
diff --git a/drivers/gpu/drm/i915/gvt/gvt.h b/drivers/gpu/drm/i915/gvt/gvt.h
index ddffd337f1c60..0ccb9bb7180cd 100644
--- a/drivers/gpu/drm/i915/gvt/gvt.h
+++ b/drivers/gpu/drm/i915/gvt/gvt.h
@@ -298,7 +298,7 @@  struct intel_gvt_firmware {
 
 #define NR_MAX_INTEL_VGPU_TYPES 20
 struct intel_vgpu_type {
-	char name[16];
+	struct mdev_type type;
 	unsigned int avail_instance;
 	unsigned int low_gm_size;
 	unsigned int high_gm_size;
@@ -329,6 +329,7 @@  struct intel_gvt {
 	struct notifier_block shadow_ctx_notifier_block[I915_NUM_ENGINES];
 	DECLARE_HASHTABLE(cmd_table, GVT_CMD_HASH_BITS);
 	struct mdev_parent parent;
+	struct mdev_type **mdev_types;
 	struct intel_vgpu_type *types;
 	unsigned int num_types;
 	struct intel_vgpu *idle_vgpu;
diff --git a/drivers/gpu/drm/i915/gvt/kvmgt.c b/drivers/gpu/drm/i915/gvt/kvmgt.c
index 70401374c72bc..1c6b7e8bec4fb 100644
--- a/drivers/gpu/drm/i915/gvt/kvmgt.c
+++ b/drivers/gpu/drm/i915/gvt/kvmgt.c
@@ -117,17 +117,10 @@  static ssize_t available_instances_show(struct mdev_type *mtype,
 					struct mdev_type_attribute *attr,
 					char *buf)
 {
-	struct intel_vgpu_type *type;
-	unsigned int num = 0;
-	struct intel_gvt *gvt = kdev_to_i915(mtype_get_parent_dev(mtype))->gvt;
+	struct intel_vgpu_type *type =
+		container_of(mtype, struct intel_vgpu_type, type);
 
-	type = &gvt->types[mtype_get_type_group_id(mtype)];
-	if (!type)
-		num = 0;
-	else
-		num = type->avail_instance;
-
-	return sprintf(buf, "%u\n", num);
+	return sprintf(buf, "%u\n", type->avail_instance);
 }
 
 static ssize_t device_api_show(struct mdev_type *mtype,
@@ -139,12 +132,8 @@  static ssize_t device_api_show(struct mdev_type *mtype,
 static ssize_t description_show(struct mdev_type *mtype,
 				struct mdev_type_attribute *attr, char *buf)
 {
-	struct intel_vgpu_type *type;
-	struct intel_gvt *gvt = kdev_to_i915(mtype_get_parent_dev(mtype))->gvt;
-
-	type = &gvt->types[mtype_get_type_group_id(mtype)];
-	if (!type)
-		return 0;
+	struct intel_vgpu_type *type =
+		container_of(mtype, struct intel_vgpu_type, type);
 
 	return sprintf(buf, "low_gm_size: %dMB\nhigh_gm_size: %dMB\n"
 		       "fence: %d\nresolution: %s\n"
@@ -158,14 +147,7 @@  static ssize_t description_show(struct mdev_type *mtype,
 static ssize_t name_show(struct mdev_type *mtype,
 			 struct mdev_type_attribute *attr, char *buf)
 {
-	struct intel_vgpu_type *type;
-	struct intel_gvt *gvt = kdev_to_i915(mtype_get_parent_dev(mtype))->gvt;
-
-	type = &gvt->types[mtype_get_type_group_id(mtype)];
-	if (!type)
-		return 0;
-
-	return sprintf(buf, "%s\n", type->name);
+	return sprintf(buf, "%s\n", mtype->sysfs_name);
 }
 
 static MDEV_TYPE_ATTR_RO(available_instances);
@@ -173,7 +155,7 @@  static MDEV_TYPE_ATTR_RO(device_api);
 static MDEV_TYPE_ATTR_RO(description);
 static MDEV_TYPE_ATTR_RO(name);
 
-static struct attribute *gvt_type_attrs[] = {
+static const struct attribute *gvt_type_attrs[] = {
 	&mdev_type_attr_available_instances.attr,
 	&mdev_type_attr_device_api.attr,
 	&mdev_type_attr_description.attr,
@@ -181,51 +163,6 @@  static struct attribute *gvt_type_attrs[] = {
 	NULL,
 };
 
-static struct attribute_group *gvt_vgpu_type_groups[] = {
-	[0 ... NR_MAX_INTEL_VGPU_TYPES - 1] = NULL,
-};
-
-static int intel_gvt_init_vgpu_type_groups(struct intel_gvt *gvt)
-{
-	int i, j;
-	struct intel_vgpu_type *type;
-	struct attribute_group *group;
-
-	for (i = 0; i < gvt->num_types; i++) {
-		type = &gvt->types[i];
-
-		group = kzalloc(sizeof(struct attribute_group), GFP_KERNEL);
-		if (!group)
-			goto unwind;
-
-		group->name = type->name;
-		group->attrs = gvt_type_attrs;
-		gvt_vgpu_type_groups[i] = group;
-	}
-
-	return 0;
-
-unwind:
-	for (j = 0; j < i; j++) {
-		group = gvt_vgpu_type_groups[j];
-		kfree(group);
-	}
-
-	return -ENOMEM;
-}
-
-static void intel_gvt_cleanup_vgpu_type_groups(struct intel_gvt *gvt)
-{
-	int i;
-	struct attribute_group *group;
-
-	for (i = 0; i < gvt->num_types; i++) {
-		group = gvt_vgpu_type_groups[i];
-		gvt_vgpu_type_groups[i] = NULL;
-		kfree(group);
-	}
-}
-
 static void gvt_unpin_guest_page(struct intel_vgpu *vgpu, unsigned long gfn,
 		unsigned long size)
 {
@@ -1614,14 +1551,11 @@  static int intel_vgpu_probe(struct mdev_device *mdev)
 {
 	struct device *pdev = mdev_parent_dev(mdev);
 	struct intel_gvt *gvt = kdev_to_i915(pdev)->gvt;
-	struct intel_vgpu_type *type;
+	struct intel_vgpu_type *type =
+		container_of(mdev->type, struct intel_vgpu_type, type);
 	struct intel_vgpu *vgpu;
 	int ret;
 
-	type = &gvt->types[mdev_get_type_group_id(mdev)];
-	if (!type)
-		return -EINVAL;
-
 	vgpu = intel_gvt_create_vgpu(gvt, type);
 	if (IS_ERR(vgpu)) {
 		gvt_err("failed to create intel vgpu: %ld\n", PTR_ERR(vgpu));
@@ -1660,7 +1594,7 @@  static struct mdev_driver intel_vgpu_mdev_driver = {
 	},
 	.probe		= intel_vgpu_probe,
 	.remove		= intel_vgpu_remove,
-	.supported_type_groups	= gvt_vgpu_type_groups,
+	.types_attrs	= gvt_type_attrs,
 };
 
 int intel_gvt_page_track_add(struct intel_vgpu *info, u64 gfn)
@@ -1959,7 +1893,6 @@  static void intel_gvt_clean_device(struct drm_i915_private *i915)
 		return;
 
 	mdev_unregister_parent(&gvt->parent);
-	intel_gvt_cleanup_vgpu_type_groups(gvt);
 	intel_gvt_destroy_idle_vgpu(gvt->idle_vgpu);
 	intel_gvt_clean_vgpu_types(gvt);
 
@@ -2059,20 +1992,15 @@  static int intel_gvt_init_device(struct drm_i915_private *i915)
 
 	intel_gvt_debugfs_init(gvt);
 
-	ret = intel_gvt_init_vgpu_type_groups(gvt);
-	if (ret)
-		goto out_destroy_idle_vgpu;
-
 	ret = mdev_register_parent(&gvt->parent, i915->drm.dev,
-				   &intel_vgpu_mdev_driver);
+				   &intel_vgpu_mdev_driver,
+				   gvt->mdev_types, gvt->num_types);
 	if (ret)
-		goto out_cleanup_vgpu_type_groups;
+		goto out_destroy_idle_vgpu;
 
 	gvt_dbg_core("gvt device initialization is done\n");
 	return 0;
 
-out_cleanup_vgpu_type_groups:
-	intel_gvt_cleanup_vgpu_type_groups(gvt);
 out_destroy_idle_vgpu:
 	intel_gvt_destroy_idle_vgpu(gvt->idle_vgpu);
 	intel_gvt_debugfs_clean(gvt);
diff --git a/drivers/gpu/drm/i915/gvt/vgpu.c b/drivers/gpu/drm/i915/gvt/vgpu.c
index 46da19b3225d2..2f38b90ecb8ac 100644
--- a/drivers/gpu/drm/i915/gvt/vgpu.c
+++ b/drivers/gpu/drm/i915/gvt/vgpu.c
@@ -131,6 +131,13 @@  int intel_gvt_init_vgpu_types(struct intel_gvt *gvt)
 	if (!gvt->types)
 		return -ENOMEM;
 
+	gvt->mdev_types = kcalloc(num_types, sizeof(*gvt->mdev_types),
+			     GFP_KERNEL);
+	if (!gvt->mdev_types) {
+		kfree(gvt->types);
+		return -ENOMEM;
+	}
+
 	min_low = MB_TO_BYTES(32);
 	for (i = 0; i < num_types; ++i) {
 		if (low_avail / vgpu_types[i].low_mm == 0)
@@ -150,19 +157,21 @@  int intel_gvt_init_vgpu_types(struct intel_gvt *gvt)
 						   high_avail / vgpu_types[i].high_mm);
 
 		if (GRAPHICS_VER(gvt->gt->i915) == 8)
-			sprintf(gvt->types[i].name, "GVTg_V4_%s",
+			sprintf(gvt->types[i].type.sysfs_name, "GVTg_V4_%s",
 				vgpu_types[i].name);
 		else if (GRAPHICS_VER(gvt->gt->i915) == 9)
-			sprintf(gvt->types[i].name, "GVTg_V5_%s",
+			sprintf(gvt->types[i].type.sysfs_name, "GVTg_V5_%s",
 				vgpu_types[i].name);
 
 		gvt_dbg_core("type[%d]: %s avail %u low %u high %u fence %u weight %u res %s\n",
-			     i, gvt->types[i].name,
+			     i, gvt->types[i].type.sysfs_name,
 			     gvt->types[i].avail_instance,
 			     gvt->types[i].low_gm_size,
 			     gvt->types[i].high_gm_size, gvt->types[i].fence,
 			     gvt->types[i].weight,
 			     vgpu_edid_str(gvt->types[i].resolution));
+
+		gvt->mdev_types[i] = &gvt->types[i].type;
 	}
 
 	gvt->num_types = i;
@@ -171,6 +180,7 @@  int intel_gvt_init_vgpu_types(struct intel_gvt *gvt)
 
 void intel_gvt_clean_vgpu_types(struct intel_gvt *gvt)
 {
+	kfree(gvt->mdev_types);
 	kfree(gvt->types);
 }
 
@@ -198,7 +208,7 @@  static void intel_gvt_update_vgpu_types(struct intel_gvt *gvt)
 						   fence_min);
 
 		gvt_dbg_core("update type[%d]: %s avail %u low %u high %u fence %u\n",
-		       i, gvt->types[i].name,
+		       i, gvt->types[i].type.sysfs_name,
 		       gvt->types[i].avail_instance, gvt->types[i].low_gm_size,
 		       gvt->types[i].high_gm_size, gvt->types[i].fence);
 	}
diff --git a/drivers/s390/cio/cio.h b/drivers/s390/cio/cio.h
index 22be5ac7d23c1..1da45307a1862 100644
--- a/drivers/s390/cio/cio.h
+++ b/drivers/s390/cio/cio.h
@@ -110,6 +110,8 @@  struct subchannel {
 	 */
 	const char *driver_override;
 	struct mdev_parent parent;
+	struct mdev_type mdev_type;
+	struct mdev_type *mdev_types[1];
 } __attribute__ ((aligned(8)));
 
 DECLARE_PER_CPU_ALIGNED(struct irb, cio_irb);
diff --git a/drivers/s390/cio/vfio_ccw_ops.c b/drivers/s390/cio/vfio_ccw_ops.c
index 9192a21085ce4..25b8d42a522ac 100644
--- a/drivers/s390/cio/vfio_ccw_ops.c
+++ b/drivers/s390/cio/vfio_ccw_ops.c
@@ -95,23 +95,13 @@  static ssize_t available_instances_show(struct mdev_type *mtype,
 }
 static MDEV_TYPE_ATTR_RO(available_instances);
 
-static struct attribute *mdev_types_attrs[] = {
+static const struct attribute *mdev_types_attrs[] = {
 	&mdev_type_attr_name.attr,
 	&mdev_type_attr_device_api.attr,
 	&mdev_type_attr_available_instances.attr,
 	NULL,
 };
 
-static struct attribute_group mdev_type_group = {
-	.name  = "io",
-	.attrs = mdev_types_attrs,
-};
-
-static struct attribute_group *mdev_type_groups[] = {
-	&mdev_type_group,
-	NULL,
-};
-
 static int vfio_ccw_mdev_probe(struct mdev_device *mdev)
 {
 	struct vfio_ccw_private *private = dev_get_drvdata(mdev->dev.parent);
@@ -654,13 +644,16 @@  struct mdev_driver vfio_ccw_mdev_driver = {
 	},
 	.probe = vfio_ccw_mdev_probe,
 	.remove = vfio_ccw_mdev_remove,
-	.supported_type_groups  = mdev_type_groups,
+	.types_attrs = mdev_types_attrs,
 };
 
 int vfio_ccw_mdev_reg(struct subchannel *sch)
 {
+	sprintf(sch->mdev_type.sysfs_name, "io");
+	sch->mdev_types[0] = &sch->mdev_type;
 	return mdev_register_parent(&sch->parent, &sch->dev,
-				    &vfio_ccw_mdev_driver);
+				    &vfio_ccw_mdev_driver, sch->mdev_types,
+				    1);
 }
 
 void vfio_ccw_mdev_unreg(struct subchannel *sch)
diff --git a/drivers/s390/crypto/vfio_ap_ops.c b/drivers/s390/crypto/vfio_ap_ops.c
index 834945150dc9f..ff25858b2ebbe 100644
--- a/drivers/s390/crypto/vfio_ap_ops.c
+++ b/drivers/s390/crypto/vfio_ap_ops.c
@@ -537,23 +537,13 @@  static ssize_t device_api_show(struct mdev_type *mtype,
 
 static MDEV_TYPE_ATTR_RO(device_api);
 
-static struct attribute *vfio_ap_mdev_type_attrs[] = {
+static const struct attribute *vfio_ap_mdev_type_attrs[] = {
 	&mdev_type_attr_name.attr,
 	&mdev_type_attr_device_api.attr,
 	&mdev_type_attr_available_instances.attr,
 	NULL,
 };
 
-static struct attribute_group vfio_ap_mdev_hwvirt_type_group = {
-	.name = VFIO_AP_MDEV_TYPE_HWVIRT,
-	.attrs = vfio_ap_mdev_type_attrs,
-};
-
-static struct attribute_group *vfio_ap_mdev_type_groups[] = {
-	&vfio_ap_mdev_hwvirt_type_group,
-	NULL,
-};
-
 struct vfio_ap_queue_reserved {
 	unsigned long *apid;
 	unsigned long *apqi;
@@ -1472,7 +1462,7 @@  static struct mdev_driver vfio_ap_matrix_driver = {
 	},
 	.probe = vfio_ap_mdev_probe,
 	.remove = vfio_ap_mdev_remove,
-	.supported_type_groups = vfio_ap_mdev_type_groups,
+	.types_attrs = vfio_ap_mdev_type_attrs,
 };
 
 int vfio_ap_mdev_register(void)
@@ -1485,8 +1475,11 @@  int vfio_ap_mdev_register(void)
 	if (ret)
 		return ret;
 
+	strcpy(matrix_dev->mdev_type.sysfs_name, VFIO_AP_MDEV_TYPE_HWVIRT);
+	matrix_dev->mdev_types[0] = &matrix_dev->mdev_type;
 	ret = mdev_register_parent(&matrix_dev->parent, &matrix_dev->device,
-				   &vfio_ap_matrix_driver);
+				   &vfio_ap_matrix_driver,
+				   matrix_dev->mdev_types, 1);
 	if (ret)
 		goto err_driver;
 	return 0;
diff --git a/drivers/s390/crypto/vfio_ap_private.h b/drivers/s390/crypto/vfio_ap_private.h
index 0191f6bc973a4..5dc5050d03791 100644
--- a/drivers/s390/crypto/vfio_ap_private.h
+++ b/drivers/s390/crypto/vfio_ap_private.h
@@ -46,6 +46,8 @@  struct ap_matrix_dev {
 	struct mutex lock;
 	struct ap_driver  *vfio_ap_drv;
 	struct mdev_parent parent;
+	struct mdev_type mdev_type;
+	struct mdev_type *mdev_types[];
 };
 
 extern struct ap_matrix_dev *matrix_dev;
diff --git a/drivers/vfio/mdev/mdev_core.c b/drivers/vfio/mdev/mdev_core.c
index d38faed14c689..71c7f4e521a74 100644
--- a/drivers/vfio/mdev/mdev_core.c
+++ b/drivers/vfio/mdev/mdev_core.c
@@ -29,26 +29,6 @@  struct device *mdev_parent_dev(struct mdev_device *mdev)
 }
 EXPORT_SYMBOL(mdev_parent_dev);
 
-/*
- * Return the index in supported_type_groups that this mdev_device was created
- * from.
- */
-unsigned int mdev_get_type_group_id(struct mdev_device *mdev)
-{
-	return mdev->type->type_group_id;
-}
-EXPORT_SYMBOL(mdev_get_type_group_id);
-
-/*
- * Used in mdev_type_attribute sysfs functions to return the index in the
- * supported_type_groups that the sysfs is called from.
- */
-unsigned int mtype_get_type_group_id(struct mdev_type *mtype)
-{
-	return mtype->type_group_id;
-}
-EXPORT_SYMBOL(mtype_get_type_group_id);
-
 /*
  * Used in mdev_type_attribute sysfs functions to return the parent struct
  * device
@@ -85,19 +65,19 @@  static int mdev_device_remove_cb(struct device *dev, void *data)
  * @parent: parent structure registered
  * @dev: device structure representing parent device.
  * @mdev_driver: Device driver to bind to the newly created mdev
+ * @types: Array of supported mdev types
+ * @nr_types: Number of entries in @types
  *
  * Returns a negative value on error, otherwise 0.
  */
 int mdev_register_parent(struct mdev_parent *parent, struct device *dev,
-		struct mdev_driver *mdev_driver)
+		struct mdev_driver *mdev_driver, struct mdev_type **types,
+		unsigned int nr_types)
 {
 	char *env_string = "MDEV_STATE=registered";
 	char *envp[] = { env_string, NULL };
 	int ret;
-
-	/* check for mandatory ops */
-	if (!mdev_driver->supported_type_groups)
-		return -EINVAL;
+	int i;
 
 	memset(parent, 0, sizeof(*parent));
 	init_rwsem(&parent->unreg_sem);
@@ -110,9 +90,23 @@  int mdev_register_parent(struct mdev_parent *parent, struct device *dev,
 			return -ENOMEM;
 	}
 
-	ret = parent_create_sysfs_files(parent);
-	if (ret)
+	parent->mdev_types_kset = kset_create_and_add("mdev_supported_types",
+					       NULL, &parent->dev->kobj);
+	if (!parent->mdev_types_kset)
+		return -ENOMEM;
+
+	for (i = 0; i < nr_types; i++) {
+		ret = mdev_type_add(parent, types[i]);
+		if (ret)
+			break;
+	}
+	parent->types = types;
+	parent->nr_types = i;
+
+	if (ret) {
+		mdev_unregister_parent(parent);
 		return ret;
+	}
 
 	ret = class_compat_create_link(mdev_bus_compat_class, dev, NULL);
 	if (ret)
@@ -132,13 +126,17 @@  void mdev_unregister_parent(struct mdev_parent *parent)
 {
 	char *env_string = "MDEV_STATE=unregistered";
 	char *envp[] = { env_string, NULL };
+	int i;
 
 	dev_info(parent->dev, "MDEV: Unregistering\n");
 
+	for (i = 0; i < parent->nr_types; i++)
+		mdev_type_remove(parent->types[i]);
+									  
 	down_write(&parent->unreg_sem);
 	class_compat_remove_link(mdev_bus_compat_class, parent->dev, NULL);
 	device_for_each_child(parent->dev, NULL, mdev_device_remove_cb);
-	parent_remove_sysfs_files(parent);
+	kset_unregister(parent->mdev_types_kset);
 	up_write(&parent->unreg_sem);
 
 	kobject_uevent_env(&parent->dev->kobj, KOBJ_CHANGE, envp);
diff --git a/drivers/vfio/mdev/mdev_driver.c b/drivers/vfio/mdev/mdev_driver.c
index 7bd4bb9850e81..1da1ecf76a0d5 100644
--- a/drivers/vfio/mdev/mdev_driver.c
+++ b/drivers/vfio/mdev/mdev_driver.c
@@ -56,10 +56,9 @@  EXPORT_SYMBOL_GPL(mdev_bus_type);
  **/
 int mdev_register_driver(struct mdev_driver *drv)
 {
-	/* initialize common driver fields */
+	if (!drv->types_attrs)
+		return -EINVAL;
 	drv->driver.bus = &mdev_bus_type;
-
-	/* register with core */
 	return driver_register(&drv->driver);
 }
 EXPORT_SYMBOL(mdev_register_driver);
diff --git a/drivers/vfio/mdev/mdev_private.h b/drivers/vfio/mdev/mdev_private.h
index 297f911fdc890..6980f504018f3 100644
--- a/drivers/vfio/mdev/mdev_private.h
+++ b/drivers/vfio/mdev/mdev_private.h
@@ -13,14 +13,6 @@ 
 int  mdev_bus_register(void);
 void mdev_bus_unregister(void);
 
-struct mdev_type {
-	struct kobject kobj;
-	struct kobject *devices_kobj;
-	struct mdev_parent *parent;
-	struct list_head next;
-	unsigned int type_group_id;
-};
-
 extern const struct attribute_group *mdev_device_groups[];
 
 #define to_mdev_type_attr(_attr)	\
@@ -28,12 +20,12 @@  extern const struct attribute_group *mdev_device_groups[];
 #define to_mdev_type(_kobj)		\
 	container_of(_kobj, struct mdev_type, kobj)
 
-int  parent_create_sysfs_files(struct mdev_parent *parent);
-void parent_remove_sysfs_files(struct mdev_parent *parent);
-
 int  mdev_create_sysfs_files(struct mdev_device *mdev);
 void mdev_remove_sysfs_files(struct mdev_device *mdev);
 
+int mdev_type_add(struct mdev_parent *parent, struct mdev_type *type);
+void mdev_type_remove(struct mdev_type *type);
+
 int mdev_device_create(struct mdev_type *kobj, const guid_t *uuid);
 int  mdev_device_remove(struct mdev_device *dev);
 
diff --git a/drivers/vfio/mdev/mdev_sysfs.c b/drivers/vfio/mdev/mdev_sysfs.c
index b71ffc5594870..09745e8e120f9 100644
--- a/drivers/vfio/mdev/mdev_sysfs.c
+++ b/drivers/vfio/mdev/mdev_sysfs.c
@@ -80,8 +80,6 @@  static void mdev_type_release(struct kobject *kobj)
 	struct mdev_type *type = to_mdev_type(kobj);
 
 	pr_debug("Releasing group %s\n", kobj->name);
-	/* Pairs with the get in add_mdev_supported_type() */
-	put_device(type->parent->dev);
 	kfree(type);
 }
 
@@ -90,36 +88,17 @@  static struct kobj_type mdev_type_ktype = {
 	.release = mdev_type_release,
 };
 
-static struct mdev_type *add_mdev_supported_type(struct mdev_parent *parent,
-						 unsigned int type_group_id)
+int mdev_type_add(struct mdev_parent *parent, struct mdev_type *type)
 {
-	struct mdev_type *type;
-	struct attribute_group *group =
-		parent->mdev_driver->supported_type_groups[type_group_id];
 	int ret;
 
-	if (!group->name) {
-		pr_err("%s: Type name empty!\n", __func__);
-		return ERR_PTR(-EINVAL);
-	}
-
-	type = kzalloc(sizeof(*type), GFP_KERNEL);
-	if (!type)
-		return ERR_PTR(-ENOMEM);
-
-	type->kobj.kset = parent->mdev_types_kset;
 	type->parent = parent;
-	/* Pairs with the put in mdev_type_release() */
-	get_device(parent->dev);
-	type->type_group_id = type_group_id;
-
-	ret = kobject_init_and_add(&type->kobj, &mdev_type_ktype, NULL,
-				   "%s-%s", dev_driver_string(parent->dev),
-				   group->name);
-	if (ret) {
-		kobject_put(&type->kobj);
-		return ERR_PTR(ret);
-	}
+	type->kobj.kset = parent->mdev_types_kset;
+	ret = kobject_init_and_add(&type->kobj, &mdev_type_ktype, NULL, "%s-%s",
+				   dev_driver_string(parent->dev),
+				   type->sysfs_name);
+	if (ret)
+		return ret;
 
 	ret = sysfs_create_file(&type->kobj, &mdev_type_attr_create.attr);
 	if (ret)
@@ -131,13 +110,10 @@  static struct mdev_type *add_mdev_supported_type(struct mdev_parent *parent,
 		goto attr_devices_failed;
 	}
 
-	ret = sysfs_create_files(&type->kobj,
-				 (const struct attribute **)group->attrs);
-	if (ret) {
-		ret = -ENOMEM;
+	ret = sysfs_create_files(&type->kobj, parent->mdev_driver->types_attrs);
+	if (ret)
 		goto attrs_failed;
-	}
-	return type;
+	return 0;
 
 attrs_failed:
 	kobject_put(type->devices_kobj);
@@ -146,80 +122,19 @@  static struct mdev_type *add_mdev_supported_type(struct mdev_parent *parent,
 attr_create_failed:
 	kobject_del(&type->kobj);
 	kobject_put(&type->kobj);
-	return ERR_PTR(ret);
+	return ret;
 }
 
-static void remove_mdev_supported_type(struct mdev_type *type)
+void mdev_type_remove(struct mdev_type *type)
 {
-	struct attribute_group *group =
-		type->parent->mdev_driver->supported_type_groups[type->type_group_id];
+	sysfs_remove_files(&type->kobj, type->parent->mdev_driver->types_attrs);
 
-	sysfs_remove_files(&type->kobj,
-			   (const struct attribute **)group->attrs);
 	kobject_put(type->devices_kobj);
 	sysfs_remove_file(&type->kobj, &mdev_type_attr_create.attr);
 	kobject_del(&type->kobj);
 	kobject_put(&type->kobj);
 }
 
-static int add_mdev_supported_type_groups(struct mdev_parent *parent)
-{
-	int i;
-
-	for (i = 0; parent->mdev_driver->supported_type_groups[i]; i++) {
-		struct mdev_type *type;
-
-		type = add_mdev_supported_type(parent, i);
-		if (IS_ERR(type)) {
-			struct mdev_type *ltype, *tmp;
-
-			list_for_each_entry_safe(ltype, tmp, &parent->type_list,
-						  next) {
-				list_del(&ltype->next);
-				remove_mdev_supported_type(ltype);
-			}
-			return PTR_ERR(type);
-		}
-		list_add(&type->next, &parent->type_list);
-	}
-	return 0;
-}
-
-/* mdev sysfs functions */
-void parent_remove_sysfs_files(struct mdev_parent *parent)
-{
-	struct mdev_type *type, *tmp;
-
-	list_for_each_entry_safe(type, tmp, &parent->type_list, next) {
-		list_del(&type->next);
-		remove_mdev_supported_type(type);
-	}
-
-	kset_unregister(parent->mdev_types_kset);
-}
-
-int parent_create_sysfs_files(struct mdev_parent *parent)
-{
-	int ret;
-
-	parent->mdev_types_kset = kset_create_and_add("mdev_supported_types",
-					       NULL, &parent->dev->kobj);
-
-	if (!parent->mdev_types_kset)
-		return -ENOMEM;
-
-	INIT_LIST_HEAD(&parent->type_list);
-
-	ret = add_mdev_supported_type_groups(parent);
-	if (ret)
-		goto create_err;
-	return 0;
-
-create_err:
-	kset_unregister(parent->mdev_types_kset);
-	return ret;
-}
-
 static ssize_t remove_store(struct device *dev, struct device_attribute *attr,
 			    const char *buf, size_t count)
 {
diff --git a/include/linux/mdev.h b/include/linux/mdev.h
index 327ce3e5c6b5f..d28ddf0f561a0 100644
--- a/include/linux/mdev.h
+++ b/include/linux/mdev.h
@@ -23,14 +23,27 @@  struct mdev_device {
 	bool active;
 };
 
+struct mdev_type {
+	/* set by the driver before calling mdev_register parent: */
+	char sysfs_name[32];
+
+	/* set by the core, can be used drivers */
+	struct mdev_parent *parent;
+
+	/* internal only */
+	struct kobject kobj;
+	struct kobject *devices_kobj;
+};
+
 /* embedded into the struct device that the mdev devices hang off */
 struct mdev_parent {
 	struct device *dev;
 	struct mdev_driver *mdev_driver;
 	struct kset *mdev_types_kset;
-	struct list_head type_list;
 	/* Synchronize device creation/removal with parent unregistration */
 	struct rw_semaphore unreg_sem;
+	struct mdev_type **types;
+	unsigned int nr_types;
 };
 
 static inline struct mdev_device *to_mdev_device(struct device *dev)
@@ -38,8 +51,6 @@  static inline struct mdev_device *to_mdev_device(struct device *dev)
 	return container_of(dev, struct mdev_device, dev);
 }
 
-unsigned int mdev_get_type_group_id(struct mdev_device *mdev);
-unsigned int mtype_get_type_group_id(struct mdev_type *mtype);
 struct device *mtype_get_parent_dev(struct mdev_type *mtype);
 
 /* interface for exporting mdev supported type attributes */
@@ -66,15 +77,13 @@  struct mdev_type_attribute mdev_type_attr_##_name =		\
  * struct mdev_driver - Mediated device driver
  * @probe: called when new device created
  * @remove: called when device removed
- * @supported_type_groups: Attributes to define supported types. It is mandatory
- *			to provide supported types.
+ * @types_attrs: attributes to the type kobjects.
  * @driver: device driver structure
- *
  **/
 struct mdev_driver {
 	int (*probe)(struct mdev_device *dev);
 	void (*remove)(struct mdev_device *dev);
-	struct attribute_group **supported_type_groups;
+	const struct attribute * const *types_attrs;
 	struct device_driver driver;
 };
 
@@ -86,7 +95,8 @@  static inline const guid_t *mdev_uuid(struct mdev_device *mdev)
 extern struct bus_type mdev_bus_type;
 
 int mdev_register_parent(struct mdev_parent *parent, struct device *dev,
-		struct mdev_driver *mdev_driver);
+		struct mdev_driver *mdev_driver, struct mdev_type **types,
+		unsigned int nr_types);
 void mdev_unregister_parent(struct mdev_parent *parent);
 
 int mdev_register_driver(struct mdev_driver *drv);
diff --git a/samples/vfio-mdev/mbochs.c b/samples/vfio-mdev/mbochs.c
index 30b3643b3b389..1069f561cb012 100644
--- a/samples/vfio-mdev/mbochs.c
+++ b/samples/vfio-mdev/mbochs.c
@@ -99,23 +99,27 @@  MODULE_PARM_DESC(mem, "megabytes available to " MBOCHS_NAME " devices");
 #define MBOCHS_TYPE_2 "medium"
 #define MBOCHS_TYPE_3 "large"
 
-static const struct mbochs_type {
+static struct mbochs_type {
+	struct mdev_type type;
 	const char *name;
 	u32 mbytes;
 	u32 max_x;
 	u32 max_y;
 } mbochs_types[] = {
 	{
+		.type.sysfs_name	= MBOCHS_TYPE_1,
 		.name	= MBOCHS_CLASS_NAME "-" MBOCHS_TYPE_1,
 		.mbytes = 4,
 		.max_x  = 800,
 		.max_y  = 600,
 	}, {
+		.type.sysfs_name	= MBOCHS_TYPE_2,
 		.name	= MBOCHS_CLASS_NAME "-" MBOCHS_TYPE_2,
 		.mbytes = 16,
 		.max_x  = 1920,
 		.max_y  = 1440,
 	}, {
+		.type.sysfs_name	= MBOCHS_TYPE_3,
 		.name	= MBOCHS_CLASS_NAME "-" MBOCHS_TYPE_3,
 		.mbytes = 64,
 		.max_x  = 0,
@@ -123,6 +127,11 @@  static const struct mbochs_type {
 	},
 };
 
+static struct mdev_type *mbochs_mdev_types[] = {
+	&mbochs_types[0].type,
+	&mbochs_types[1].type,
+	&mbochs_types[2].type,
+};
 
 static dev_t		mbochs_devt;
 static struct class	*mbochs_class;
@@ -508,8 +517,8 @@  static int mbochs_reset(struct mdev_state *mdev_state)
 static int mbochs_probe(struct mdev_device *mdev)
 {
 	int avail_mbytes = atomic_read(&mbochs_avail_mbytes);
-	const struct mbochs_type *type =
-		&mbochs_types[mdev_get_type_group_id(mdev)];
+	struct mbochs_type *type =
+		container_of(mdev->type, struct mbochs_type, type);
 	struct device *dev = mdev_dev(mdev);
 	struct mdev_state *mdev_state;
 	int ret = -ENOMEM;
@@ -1328,8 +1337,8 @@  static const struct attribute_group *mdev_dev_groups[] = {
 static ssize_t name_show(struct mdev_type *mtype,
 			 struct mdev_type_attribute *attr, char *buf)
 {
-	const struct mbochs_type *type =
-		&mbochs_types[mtype_get_type_group_id(mtype)];
+	struct mbochs_type *type =
+		container_of(mtype, struct mbochs_type, type);
 
 	return sprintf(buf, "%s\n", type->name);
 }
@@ -1338,8 +1347,8 @@  static MDEV_TYPE_ATTR_RO(name);
 static ssize_t description_show(struct mdev_type *mtype,
 				struct mdev_type_attribute *attr, char *buf)
 {
-	const struct mbochs_type *type =
-		&mbochs_types[mtype_get_type_group_id(mtype)];
+	struct mbochs_type *type =
+		container_of(mtype, struct mbochs_type, type);
 
 	return sprintf(buf, "virtual display, %d MB video memory\n",
 		       type ? type->mbytes  : 0);
@@ -1350,8 +1359,8 @@  static ssize_t available_instances_show(struct mdev_type *mtype,
 					struct mdev_type_attribute *attr,
 					char *buf)
 {
-	const struct mbochs_type *type =
-		&mbochs_types[mtype_get_type_group_id(mtype)];
+	struct mbochs_type *type =
+		container_of(mtype, struct mbochs_type, type);
 	int count = atomic_read(&mbochs_avail_mbytes) / type->mbytes;
 
 	return sprintf(buf, "%d\n", count);
@@ -1365,7 +1374,7 @@  static ssize_t device_api_show(struct mdev_type *mtype,
 }
 static MDEV_TYPE_ATTR_RO(device_api);
 
-static struct attribute *mdev_types_attrs[] = {
+static const struct attribute *mdev_types_attrs[] = {
 	&mdev_type_attr_name.attr,
 	&mdev_type_attr_description.attr,
 	&mdev_type_attr_device_api.attr,
@@ -1373,28 +1382,6 @@  static struct attribute *mdev_types_attrs[] = {
 	NULL,
 };
 
-static struct attribute_group mdev_type_group1 = {
-	.name  = MBOCHS_TYPE_1,
-	.attrs = mdev_types_attrs,
-};
-
-static struct attribute_group mdev_type_group2 = {
-	.name  = MBOCHS_TYPE_2,
-	.attrs = mdev_types_attrs,
-};
-
-static struct attribute_group mdev_type_group3 = {
-	.name  = MBOCHS_TYPE_3,
-	.attrs = mdev_types_attrs,
-};
-
-static struct attribute_group *mdev_type_groups[] = {
-	&mdev_type_group1,
-	&mdev_type_group2,
-	&mdev_type_group3,
-	NULL,
-};
-
 static const struct vfio_device_ops mbochs_dev_ops = {
 	.close_device = mbochs_close_device,
 	.read = mbochs_read,
@@ -1412,7 +1399,7 @@  static struct mdev_driver mbochs_driver = {
 	},
 	.probe = mbochs_probe,
 	.remove	= mbochs_remove,
-	.supported_type_groups = mdev_type_groups,
+	.types_attrs = mdev_types_attrs,
 };
 
 static const struct file_operations vd_fops = {
@@ -1457,7 +1444,9 @@  static int __init mbochs_dev_init(void)
 	if (ret)
 		goto err_class;
 
-	ret = mdev_register_parent(&mbochs_parent, &mbochs_dev, &mbochs_driver);
+	ret = mdev_register_parent(&mbochs_parent, &mbochs_dev, &mbochs_driver,
+				   mbochs_mdev_types,
+				   ARRAY_SIZE(mbochs_mdev_types));
 	if (ret)
 		goto err_device;
 
diff --git a/samples/vfio-mdev/mdpy.c b/samples/vfio-mdev/mdpy.c
index 132bb055628a6..40b1c8a58157c 100644
--- a/samples/vfio-mdev/mdpy.c
+++ b/samples/vfio-mdev/mdpy.c
@@ -51,7 +51,8 @@  MODULE_PARM_DESC(count, "number of " MDPY_NAME " devices");
 #define MDPY_TYPE_2 "xga"
 #define MDPY_TYPE_3 "hd"
 
-static const struct mdpy_type {
+static struct mdpy_type {
+	struct mdev_type type;
 	const char *name;
 	u32 format;
 	u32 bytepp;
@@ -59,18 +60,21 @@  static const struct mdpy_type {
 	u32 height;
 } mdpy_types[] = {
 	{
+		.type.sysfs_name 	= MDPY_TYPE_1,
 		.name	= MDPY_CLASS_NAME "-" MDPY_TYPE_1,
 		.format = DRM_FORMAT_XRGB8888,
 		.bytepp = 4,
 		.width	= 640,
 		.height = 480,
 	}, {
+		.type.sysfs_name 	= MDPY_TYPE_2,
 		.name	= MDPY_CLASS_NAME "-" MDPY_TYPE_2,
 		.format = DRM_FORMAT_XRGB8888,
 		.bytepp = 4,
 		.width	= 1024,
 		.height = 768,
 	}, {
+		.type.sysfs_name 	= MDPY_TYPE_3,
 		.name	= MDPY_CLASS_NAME "-" MDPY_TYPE_3,
 		.format = DRM_FORMAT_XRGB8888,
 		.bytepp = 4,
@@ -78,6 +82,12 @@  static const struct mdpy_type {
 		.height = 1080,
 	},
 };
+	
+static struct mdev_type *mdpy_mdev_types[] = {
+	&mdpy_types[0].type,
+	&mdpy_types[1].type,
+	&mdpy_types[2].type,
+};
 
 static dev_t		mdpy_devt;
 static struct class	*mdpy_class;
@@ -219,7 +229,7 @@  static int mdpy_reset(struct mdev_state *mdev_state)
 static int mdpy_probe(struct mdev_device *mdev)
 {
 	const struct mdpy_type *type =
-		&mdpy_types[mdev_get_type_group_id(mdev)];
+		container_of(mdev->type, struct mdpy_type, type);
 	struct device *dev = mdev_dev(mdev);
 	struct mdev_state *mdev_state;
 	u32 fbsize;
@@ -644,8 +654,7 @@  static const struct attribute_group *mdev_dev_groups[] = {
 static ssize_t name_show(struct mdev_type *mtype,
 			 struct mdev_type_attribute *attr, char *buf)
 {
-	const struct mdpy_type *type =
-		&mdpy_types[mtype_get_type_group_id(mtype)];
+	struct mdpy_type *type = container_of(mtype, struct mdpy_type, type);
 
 	return sprintf(buf, "%s\n", type->name);
 }
@@ -654,8 +663,7 @@  static MDEV_TYPE_ATTR_RO(name);
 static ssize_t description_show(struct mdev_type *mtype,
 				struct mdev_type_attribute *attr, char *buf)
 {
-	const struct mdpy_type *type =
-		&mdpy_types[mtype_get_type_group_id(mtype)];
+	struct mdpy_type *type = container_of(mtype, struct mdpy_type, type);
 
 	return sprintf(buf, "virtual display, %dx%d framebuffer\n",
 		       type->width, type->height);
@@ -677,7 +685,7 @@  static ssize_t device_api_show(struct mdev_type *mtype,
 }
 static MDEV_TYPE_ATTR_RO(device_api);
 
-static struct attribute *mdev_types_attrs[] = {
+static const struct attribute *mdev_types_attrs[] = {
 	&mdev_type_attr_name.attr,
 	&mdev_type_attr_description.attr,
 	&mdev_type_attr_device_api.attr,
@@ -685,28 +693,6 @@  static struct attribute *mdev_types_attrs[] = {
 	NULL,
 };
 
-static struct attribute_group mdev_type_group1 = {
-	.name  = MDPY_TYPE_1,
-	.attrs = mdev_types_attrs,
-};
-
-static struct attribute_group mdev_type_group2 = {
-	.name  = MDPY_TYPE_2,
-	.attrs = mdev_types_attrs,
-};
-
-static struct attribute_group mdev_type_group3 = {
-	.name  = MDPY_TYPE_3,
-	.attrs = mdev_types_attrs,
-};
-
-static struct attribute_group *mdev_type_groups[] = {
-	&mdev_type_group1,
-	&mdev_type_group2,
-	&mdev_type_group3,
-	NULL,
-};
-
 static const struct vfio_device_ops mdpy_dev_ops = {
 	.read = mdpy_read,
 	.write = mdpy_write,
@@ -723,7 +709,7 @@  static struct mdev_driver mdpy_driver = {
 	},
 	.probe = mdpy_probe,
 	.remove	= mdpy_remove,
-	.supported_type_groups = mdev_type_groups,
+	.types_attrs = mdev_types_attrs,
 };
 
 static const struct file_operations vd_fops = {
@@ -766,7 +752,9 @@  static int __init mdpy_dev_init(void)
 	if (ret)
 		goto err_class;
 
-	ret = mdev_register_parent(&mdpy_parent, &mdpy_dev, &mdpy_driver);
+	ret = mdev_register_parent(&mdpy_parent, &mdpy_dev, &mdpy_driver,
+				   mdpy_mdev_types,
+				   ARRAY_SIZE(mdpy_mdev_types));
 	if (ret)
 		goto err_device;
 
diff --git a/samples/vfio-mdev/mtty.c b/samples/vfio-mdev/mtty.c
index 8ba5f6084a093..029a19ef8ce7b 100644
--- a/samples/vfio-mdev/mtty.c
+++ b/samples/vfio-mdev/mtty.c
@@ -143,6 +143,20 @@  struct mdev_state {
 	int nr_ports;
 };
 
+static struct mtty_type {
+	struct mdev_type type;
+	int nr_ports;
+	const char *name;
+} mtty_types[2] = {
+	{ .nr_ports = 1, .type.sysfs_name = "1", .name = "Single port serial" },
+	{ .nr_ports = 2, .type.sysfs_name = "2", .name = "Dual port serial" },
+};
+
+static struct mdev_type *mtty_mdev_types[] = {
+	&mtty_types[0].type,
+	&mtty_types[1].type,
+};
+
 static atomic_t mdev_avail_ports = ATOMIC_INIT(MAX_MTTYS);
 
 static const struct file_operations vd_fops = {
@@ -704,16 +718,18 @@  static ssize_t mdev_access(struct mdev_state *mdev_state, u8 *buf, size_t count,
 
 static int mtty_probe(struct mdev_device *mdev)
 {
+	struct mtty_type *type =
+		container_of(mdev->type, struct mtty_type, type);
 	struct mdev_state *mdev_state;
-	int nr_ports = mdev_get_type_group_id(mdev) + 1;
 	int avail_ports = atomic_read(&mdev_avail_ports);
 	int ret;
 
 	do {
-		if (avail_ports < nr_ports)
+		if (avail_ports < type->nr_ports)
 			return -ENOSPC;
 	} while (!atomic_try_cmpxchg(&mdev_avail_ports,
-				     &avail_ports, avail_ports - nr_ports));
+				     &avail_ports,
+				     avail_ports - type->nr_ports));
 
 	mdev_state = kzalloc(sizeof(struct mdev_state), GFP_KERNEL);
 	if (mdev_state == NULL) {
@@ -723,13 +739,13 @@  static int mtty_probe(struct mdev_device *mdev)
 
 	vfio_init_group_dev(&mdev_state->vdev, &mdev->dev, &mtty_dev_ops);
 
-	mdev_state->nr_ports = nr_ports;
+	mdev_state->nr_ports = type->nr_ports;
 	mdev_state->irq_index = -1;
 	mdev_state->s[0].max_fifo_size = MAX_FIFO_SIZE;
 	mdev_state->s[1].max_fifo_size = MAX_FIFO_SIZE;
 	mutex_init(&mdev_state->rxtx_lock);
-	mdev_state->vconfig = kzalloc(MTTY_CONFIG_SPACE_SIZE, GFP_KERNEL);
 
+	mdev_state->vconfig = kzalloc(MTTY_CONFIG_SPACE_SIZE, GFP_KERNEL);
 	if (mdev_state->vconfig == NULL) {
 		ret = -ENOMEM;
 		goto err_state;
@@ -752,7 +768,7 @@  static int mtty_probe(struct mdev_device *mdev)
 	vfio_uninit_group_dev(&mdev_state->vdev);
 	kfree(mdev_state);
 err_nr_ports:
-	atomic_add(nr_ports, &mdev_avail_ports);
+	atomic_add(type->nr_ports, &mdev_avail_ports);
 	return ret;
 }
 
@@ -1233,11 +1249,9 @@  static const struct attribute_group *mdev_dev_groups[] = {
 static ssize_t name_show(struct mdev_type *mtype,
 			 struct mdev_type_attribute *attr, char *buf)
 {
-	static const char *name_str[2] = { "Single port serial",
-					   "Dual port serial" };
+	struct mtty_type *type = container_of(mtype, struct mtty_type, type);
 
-	return sysfs_emit(buf, "%s\n",
-			  name_str[mtype_get_type_group_id(mtype)]);
+	return sysfs_emit(buf, "%s\n", type->name);
 }
 
 static MDEV_TYPE_ATTR_RO(name);
@@ -1246,9 +1260,10 @@  static ssize_t available_instances_show(struct mdev_type *mtype,
 					struct mdev_type_attribute *attr,
 					char *buf)
 {
-	unsigned int ports = mtype_get_type_group_id(mtype) + 1;
+	struct mtty_type *type = container_of(mtype, struct mtty_type, type);
 
-	return sprintf(buf, "%d\n", atomic_read(&mdev_avail_ports) / ports);
+	return sprintf(buf, "%d\n", atomic_read(&mdev_avail_ports) /
+			type->nr_ports);
 }
 
 static MDEV_TYPE_ATTR_RO(available_instances);
@@ -1261,29 +1276,13 @@  static ssize_t device_api_show(struct mdev_type *mtype,
 
 static MDEV_TYPE_ATTR_RO(device_api);
 
-static struct attribute *mdev_types_attrs[] = {
+static const struct attribute *mdev_types_attrs[] = {
 	&mdev_type_attr_name.attr,
 	&mdev_type_attr_device_api.attr,
 	&mdev_type_attr_available_instances.attr,
 	NULL,
 };
 
-static struct attribute_group mdev_type_group1 = {
-	.name  = "1",
-	.attrs = mdev_types_attrs,
-};
-
-static struct attribute_group mdev_type_group2 = {
-	.name  = "2",
-	.attrs = mdev_types_attrs,
-};
-
-static struct attribute_group *mdev_type_groups[] = {
-	&mdev_type_group1,
-	&mdev_type_group2,
-	NULL,
-};
-
 static const struct vfio_device_ops mtty_dev_ops = {
 	.name = "vfio-mtty",
 	.read = mtty_read,
@@ -1300,7 +1299,7 @@  static struct mdev_driver mtty_driver = {
 	},
 	.probe = mtty_probe,
 	.remove	= mtty_remove,
-	.supported_type_groups = mdev_type_groups,
+	.types_attrs = mdev_types_attrs,
 };
 
 static void mtty_device_release(struct device *dev)
@@ -1352,7 +1351,8 @@  static int __init mtty_dev_init(void)
 		goto err_class;
 
 	ret = mdev_register_parent(&mtty_dev.parent, &mtty_dev.dev,
-				   &mtty_driver);
+				   &mtty_driver, mtty_mdev_types,
+				   ARRAY_SIZE(mtty_mdev_types));
 	if (ret)
 		goto err_device;
 	return 0;