diff mbox series

[v2,3/5] vfio: Don't leak a group reference if the group already exists

Message ID 3-v2-fd9627d27b2b+26c-vfio_group_cdev_jgg@nvidia.com (mailing list archive)
State New, archived
Headers show
Series Update vfio_group to use the modern cdev lifecycle | expand

Commit Message

Jason Gunthorpe Oct. 13, 2021, 2:27 p.m. UTC
If vfio_create_group() searches the group list and returns an already
existing group it does not put back the iommu_group reference that the
caller passed in.

Change the semantic of vfio_create_group() to not move the reference in
from the caller, but instead obtain a new reference inside and leave the
caller's reference alone. The two callers must now call iommu_group_put().

This is an unlikely race as the only caller that could hit it has already
searched the group list before attempting to create the group.

Fixes: cba3345cc494 ("vfio: VFIO core")
Signed-off-by: Jason Gunthorpe <jgg@nvidia.com>
---
 drivers/vfio/vfio.c | 14 +++++---------
 1 file changed, 5 insertions(+), 9 deletions(-)

Comments

Christoph Hellwig Oct. 13, 2021, 4:09 p.m. UTC | #1
> @@ -775,12 +776,7 @@ static struct vfio_group *vfio_group_find_or_alloc(struct device *dev)
>  	if (group)
>  		goto out_put;
>  
> -	/* a newly created vfio_group keeps the reference. */
>  	group = vfio_create_group(iommu_group, VFIO_IOMMU);
> -	if (IS_ERR(group))
> -		goto out_put;
> -	return group;
> -
>  out_put:
>  	iommu_group_put(iommu_group);
>  	return group;

I'd simplify this down to:

	group = vfio_group_get_from_iommu(iommu_group);
	if (!group)
		group = vfio_create_group(iommu_group, VFIO_IOMMU);

but otherwise this looks good:

Reviewed-by: Christoph Hellwig <hch@lst.de>
Jason Gunthorpe Oct. 13, 2021, 4:18 p.m. UTC | #2
On Wed, Oct 13, 2021 at 06:09:10PM +0200, Christoph Hellwig wrote:
> > @@ -775,12 +776,7 @@ static struct vfio_group *vfio_group_find_or_alloc(struct device *dev)
> >  	if (group)
> >  		goto out_put;
> >  
> > -	/* a newly created vfio_group keeps the reference. */
> >  	group = vfio_create_group(iommu_group, VFIO_IOMMU);
> > -	if (IS_ERR(group))
> > -		goto out_put;
> > -	return group;
> > -
> >  out_put:
> >  	iommu_group_put(iommu_group);
> >  	return group;
> 
> I'd simplify this down to:
> 
> 	group = vfio_group_get_from_iommu(iommu_group);
> 	if (!group)
> 		group = vfio_create_group(iommu_group, VFIO_IOMMU);

Yes, OK,  I changed it into this:

	group = vfio_group_get_from_iommu(iommu_group);
	if (!group)
		group = vfio_create_group(iommu_group, VFIO_IOMMU);

	/* The vfio_group holds a reference to the iommu_group */
	iommu_group_put(iommu_group);
	return group;
}

Which I think is clearer on the comment too

Thanks,
Jason
Tian, Kevin Oct. 14, 2021, 2:08 a.m. UTC | #3
> From: Jason Gunthorpe <jgg@nvidia.com>
> Sent: Thursday, October 14, 2021 12:19 AM
> 
> On Wed, Oct 13, 2021 at 06:09:10PM +0200, Christoph Hellwig wrote:
> > > @@ -775,12 +776,7 @@ static struct vfio_group
> *vfio_group_find_or_alloc(struct device *dev)
> > >  	if (group)
> > >  		goto out_put;
> > >
> > > -	/* a newly created vfio_group keeps the reference. */
> > >  	group = vfio_create_group(iommu_group, VFIO_IOMMU);
> > > -	if (IS_ERR(group))
> > > -		goto out_put;
> > > -	return group;
> > > -
> > >  out_put:
> > >  	iommu_group_put(iommu_group);
> > >  	return group;
> >
> > I'd simplify this down to:
> >
> > 	group = vfio_group_get_from_iommu(iommu_group);
> > 	if (!group)
> > 		group = vfio_create_group(iommu_group, VFIO_IOMMU);
> 
> Yes, OK,  I changed it into this:
> 
> 	group = vfio_group_get_from_iommu(iommu_group);
> 	if (!group)
> 		group = vfio_create_group(iommu_group, VFIO_IOMMU);
> 
> 	/* The vfio_group holds a reference to the iommu_group */
> 	iommu_group_put(iommu_group);
> 	return group;
> }
> 
> Which I think is clearer on the comment too
> 

with above:

Reviewed-by: Kevin Tian <kevin.tian@intel.com>
diff mbox series

Patch

diff --git a/drivers/vfio/vfio.c b/drivers/vfio/vfio.c
index 513fb5a4c102db..fd39eae9516ff6 100644
--- a/drivers/vfio/vfio.c
+++ b/drivers/vfio/vfio.c
@@ -334,6 +334,7 @@  static void vfio_group_unlock_and_free(struct vfio_group *group)
 		list_del(&unbound->unbound_next);
 		kfree(unbound);
 	}
+	iommu_group_put(group->iommu_group);
 	kfree(group);
 }
 
@@ -385,12 +386,15 @@  static struct vfio_group *vfio_create_group(struct iommu_group *iommu_group,
 	atomic_set(&group->opened, 0);
 	init_waitqueue_head(&group->container_q);
 	group->iommu_group = iommu_group;
+	/* put in vfio_group_unlock_and_free() */
+	iommu_group_ref_get(iommu_group);
 	group->type = type;
 	BLOCKING_INIT_NOTIFIER_HEAD(&group->notifier);
 
 	group->nb.notifier_call = vfio_iommu_group_notifier;
 	ret = iommu_group_register_notifier(iommu_group, &group->nb);
 	if (ret) {
+		iommu_group_put(iommu_group);
 		kfree(group);
 		return ERR_PTR(ret);
 	}
@@ -426,7 +430,6 @@  static struct vfio_group *vfio_create_group(struct iommu_group *iommu_group,
 	list_add(&group->vfio_next, &vfio.group_list);
 
 	mutex_unlock(&vfio.group_lock);
-
 	return group;
 }
 
@@ -434,7 +437,6 @@  static struct vfio_group *vfio_create_group(struct iommu_group *iommu_group,
 static void vfio_group_release(struct kref *kref)
 {
 	struct vfio_group *group = container_of(kref, struct vfio_group, kref);
-	struct iommu_group *iommu_group = group->iommu_group;
 
 	/*
 	 * These data structures all have paired operations that can only be
@@ -450,7 +452,6 @@  static void vfio_group_release(struct kref *kref)
 	list_del(&group->vfio_next);
 	vfio_free_group_minor(group->minor);
 	vfio_group_unlock_and_free(group);
-	iommu_group_put(iommu_group);
 }
 
 static void vfio_group_put(struct vfio_group *group)
@@ -735,7 +736,7 @@  static struct vfio_group *vfio_noiommu_group_alloc(struct device *dev,
 		ret = PTR_ERR(group);
 		goto out_remove_device;
 	}
-
+	iommu_group_put(iommu_group);
 	return group;
 
 out_remove_device:
@@ -775,12 +776,7 @@  static struct vfio_group *vfio_group_find_or_alloc(struct device *dev)
 	if (group)
 		goto out_put;
 
-	/* a newly created vfio_group keeps the reference. */
 	group = vfio_create_group(iommu_group, VFIO_IOMMU);
-	if (IS_ERR(group))
-		goto out_put;
-	return group;
-
 out_put:
 	iommu_group_put(iommu_group);
 	return group;