diff mbox series

[v8,12/12] iommu: Use refcount for fault data access

Message ID 20231207064308.313316-13-baolu.lu@linux.intel.com (mailing list archive)
State New, archived
Headers show
Series iommu: Prepare to deliver page faults to user space | expand

Commit Message

Baolu Lu Dec. 7, 2023, 6:43 a.m. UTC
The per-device fault data structure stores information about faults
occurring on a device. Its lifetime spans from IOPF enablement to
disablement. Multiple paths, including IOPF reporting, handling, and
responding, may access it concurrently.

Previously, a mutex protected the fault data from use after free. But
this is not performance friendly due to the critical nature of IOPF
handling paths.

Refine this with a refcount-based approach. The fault data pointer is
obtained within an RCU read region with a refcount. The fault data
pointer is returned for usage only when the pointer is valid and a
refcount is successfully obtained. The fault data is freed with
kfree_rcu(), ensuring data is only freed after all RCU critical regions
complete.

Suggested-by: Jason Gunthorpe <jgg@nvidia.com>
Signed-off-by: Lu Baolu <baolu.lu@linux.intel.com>
Tested-by: Yan Zhao <yan.y.zhao@intel.com>
---
 include/linux/iommu.h      |  4 ++
 drivers/iommu/io-pgfault.c | 81 +++++++++++++++++++++++++-------------
 2 files changed, 57 insertions(+), 28 deletions(-)

Comments

Jason Gunthorpe Dec. 11, 2023, 3:12 p.m. UTC | #1
On Thu, Dec 07, 2023 at 02:43:08PM +0800, Lu Baolu wrote:
> +/*
> + * Return the fault parameter of a device if it exists. Otherwise, return NULL.
> + * On a successful return, the caller takes a reference of this parameter and
> + * should put it after use by calling iopf_put_dev_fault_param().
> + */
> +static struct iommu_fault_param *iopf_get_dev_fault_param(struct device *dev)
> +{
> +	struct dev_iommu *param = dev->iommu;
> +	struct iommu_fault_param *fault_param;
> +
> +	if (!param)
> +		return NULL;

Is it actually possible to call this function on a device that does
not have an iommu driver probed? I'd be surprised by that, maybe this
should be WARN_ON

> +
> +	rcu_read_lock();
> +	fault_param = param->fault_param;

The RCU stuff is not right, like this:

diff --git a/drivers/iommu/io-pgfault.c b/drivers/iommu/io-pgfault.c
index 2ace32c6d13bf3..0258f79c8ddf98 100644
--- a/drivers/iommu/io-pgfault.c
+++ b/drivers/iommu/io-pgfault.c
@@ -40,7 +40,7 @@ static struct iommu_fault_param *iopf_get_dev_fault_param(struct device *dev)
 		return NULL;
 
 	rcu_read_lock();
-	fault_param = param->fault_param;
+	fault_param = rcu_dereference(param->fault_param);
 	if (fault_param && !refcount_inc_not_zero(&fault_param->users))
 		fault_param = NULL;
 	rcu_read_unlock();
@@ -51,17 +51,8 @@ static struct iommu_fault_param *iopf_get_dev_fault_param(struct device *dev)
 /* Caller must hold a reference of the fault parameter. */
 static void iopf_put_dev_fault_param(struct iommu_fault_param *fault_param)
 {
-	struct dev_iommu *param = fault_param->dev->iommu;
-
-	rcu_read_lock();
-	if (!refcount_dec_and_test(&fault_param->users)) {
-		rcu_read_unlock();
-		return;
-	}
-	rcu_read_unlock();
-
-	param->fault_param = NULL;
-	kfree_rcu(fault_param, rcu);
+	if (refcount_dec_and_test(&fault_param->users))
+		kfree_rcu(fault_param, rcu);
 }
 
 /**
@@ -174,7 +165,7 @@ static int iommu_handle_iopf(struct iommu_fault *fault,
 	}
 
 	mutex_unlock(&iopf_param->lock);
-	ret = domain->iopf_handler(group);
+	ret = domain->iopf_handler(iopf_param, group);
 	mutex_lock(&iopf_param->lock);
 	if (ret)
 		iopf_free_group(group);
@@ -398,7 +389,8 @@ int iopf_queue_add_device(struct iopf_queue *queue, struct device *dev)
 
 	mutex_lock(&queue->lock);
 	mutex_lock(&param->lock);
-	if (param->fault_param) {
+	if (rcu_dereference_check(param->fault_param,
+				  lockdep_is_held(&param->lock))) {
 		ret = -EBUSY;
 		goto done_unlock;
 	}
@@ -418,7 +410,7 @@ int iopf_queue_add_device(struct iopf_queue *queue, struct device *dev)
 	list_add(&fault_param->queue_list, &queue->devices);
 	fault_param->queue = queue;
 
-	param->fault_param = fault_param;
+	rcu_assign_pointer(param->fault_param, fault_param);
 
 done_unlock:
 	mutex_unlock(&param->lock);
@@ -442,10 +434,12 @@ int iopf_queue_remove_device(struct iopf_queue *queue, struct device *dev)
 	int ret = 0;
 	struct iopf_fault *iopf, *next;
 	struct dev_iommu *param = dev->iommu;
-	struct iommu_fault_param *fault_param = param->fault_param;
+	struct iommu_fault_param *fault_param;
 
 	mutex_lock(&queue->lock);
 	mutex_lock(&param->lock);
+	fault_param = rcu_dereference_check(param->fault_param,
+					    lockdep_is_held(&param->lock));
 	if (!fault_param) {
 		ret = -ENODEV;
 		goto unlock;
@@ -467,7 +461,10 @@ int iopf_queue_remove_device(struct iopf_queue *queue, struct device *dev)
 	list_for_each_entry_safe(iopf, next, &fault_param->partial, list)
 		kfree(iopf);
 
-	iopf_put_dev_fault_param(fault_param);
+	/* dec the ref owned by iopf_queue_add_device() */
+	rcu_assign_pointer(param->fault_param, NULL);
+	if (refcount_dec_and_test(&fault_param->users))
+		kfree_rcu(fault_param, rcu);
 unlock:
 	mutex_unlock(&param->lock);
 	mutex_unlock(&queue->lock);
diff --git a/drivers/iommu/iommu-sva.c b/drivers/iommu/iommu-sva.c
index 325d1810e133a1..63c1a233a7e91f 100644
--- a/drivers/iommu/iommu-sva.c
+++ b/drivers/iommu/iommu-sva.c
@@ -232,10 +232,9 @@ static void iommu_sva_handle_iopf(struct work_struct *work)
 	iopf_free_group(group);
 }
 
-static int iommu_sva_iopf_handler(struct iopf_group *group)
+static int iommu_sva_iopf_handler(struct iommu_fault_param *fault_param,
+				  struct iopf_group *group)
 {
-	struct iommu_fault_param *fault_param = group->dev->iommu->fault_param;
-
 	INIT_WORK(&group->work, iommu_sva_handle_iopf);
 	if (!queue_work(fault_param->queue->wq, &group->work))
 		return -EBUSY;
diff --git a/include/linux/iommu.h b/include/linux/iommu.h
index 8020bb44a64ab1..e16fa9811d5023 100644
--- a/include/linux/iommu.h
+++ b/include/linux/iommu.h
@@ -41,6 +41,7 @@ struct iommu_dirty_ops;
 struct notifier_block;
 struct iommu_sva;
 struct iommu_dma_cookie;
+struct iommu_fault_param;
 
 #define IOMMU_FAULT_PERM_READ	(1 << 0) /* read */
 #define IOMMU_FAULT_PERM_WRITE	(1 << 1) /* write */
@@ -210,7 +211,8 @@ struct iommu_domain {
 	unsigned long pgsize_bitmap;	/* Bitmap of page sizes in use */
 	struct iommu_domain_geometry geometry;
 	struct iommu_dma_cookie *iova_cookie;
-	int (*iopf_handler)(struct iopf_group *group);
+	int (*iopf_handler)(struct iommu_fault_param *fault_param,
+			    struct iopf_group *group);
 	void *fault_data;
 	union {
 		struct {
@@ -637,7 +639,7 @@ struct iommu_fault_param {
  */
 struct dev_iommu {
 	struct mutex lock;
-	struct iommu_fault_param	*fault_param;
+	struct iommu_fault_param __rcu	*fault_param;
 	struct iommu_fwspec		*fwspec;
 	struct iommu_device		*iommu_dev;
 	void				*priv;
Jason Gunthorpe Dec. 11, 2023, 3:24 p.m. UTC | #2
On Thu, Dec 07, 2023 at 02:43:08PM +0800, Lu Baolu wrote:
> @@ -217,12 +250,9 @@ int iommu_page_response(struct device *dev,
>  	if (!ops->page_response)
>  		return -ENODEV;
>  
> -	mutex_lock(&param->lock);
> -	fault_param = param->fault_param;
> -	if (!fault_param) {
> -		mutex_unlock(&param->lock);
> +	fault_param = iopf_get_dev_fault_param(dev);
> +	if (!fault_param)
>  		return -EINVAL;
> -	}

The refcounting should work by passing around the fault_param object,
not re-obtaining it from the dev from a work.

The work should be locked to the iommu_fault_param that was active
when the work was launched.

When we get to iommu_page_response it does this:

	/* Only send response if there is a fault report pending */
	mutex_lock(&fault_param->lock);
	if (list_empty(&fault_param->faults)) {
		dev_warn_ratelimited(dev, "no pending PRQ, drop response\n");
		goto done_unlock;
	}

Which determines that the iommu_fault_param is stale and pending
free..

Also iopf_queue_remove_device() is messed up - it returns an error
code but nothing ever does anything with it :( Remove functions like
this should never fail.

Removal should be like I explained earlier:
 - Disable new PRI reception
 - Ack all outstanding PRQ to the device
 - Disable PRI on the device
 - Tear down the iopf infrastructure

So under this model if the iopf_queue_remove_device() has been called
it should be sort of a 'disassociate' action where fault_param is
still floating out there but iommu_page_response() does nothing.

IOW pass the refcount from the iommu_report_device_fault() down into
the fault handler, into the work and then into iommu_page_response()
which will ultimately put it back.

> @@ -282,22 +313,15 @@ EXPORT_SYMBOL_GPL(iommu_page_response);
>   */
>  int iopf_queue_flush_dev(struct device *dev)
>  {
> -	int ret = 0;
> -	struct iommu_fault_param *iopf_param;
> -	struct dev_iommu *param = dev->iommu;
> +	struct iommu_fault_param *iopf_param = iopf_get_dev_fault_param(dev);
>  
> -	if (!param)
> +	if (!iopf_param)
>  		return -ENODEV;

And this also seems unnecessary, it is a bug to call this after
iopf_queue_remove_device() right? Just
rcu_dereference(param->fault_param, true) and WARN_ON NULL.

Jason
Baolu Lu Dec. 12, 2023, 3:44 a.m. UTC | #3
On 12/11/23 11:12 PM, Jason Gunthorpe wrote:
> On Thu, Dec 07, 2023 at 02:43:08PM +0800, Lu Baolu wrote:
>> +/*
>> + * Return the fault parameter of a device if it exists. Otherwise, return NULL.
>> + * On a successful return, the caller takes a reference of this parameter and
>> + * should put it after use by calling iopf_put_dev_fault_param().
>> + */
>> +static struct iommu_fault_param *iopf_get_dev_fault_param(struct device *dev)
>> +{
>> +	struct dev_iommu *param = dev->iommu;
>> +	struct iommu_fault_param *fault_param;
>> +
>> +	if (!param)
>> +		return NULL;
> 
> Is it actually possible to call this function on a device that does
> not have an iommu driver probed? I'd be surprised by that, maybe this
> should be WARN_ON

Above check seems to be unnecessary. This helper should only be used
during the iommu probe and release. We can just remove it as any drivers
that abuse this will generate a null-pointer reference warning.

> 
>> +
>> +	rcu_read_lock();
>> +	fault_param = param->fault_param;
> 
> The RCU stuff is not right, like this:
> 
> diff --git a/drivers/iommu/io-pgfault.c b/drivers/iommu/io-pgfault.c
> index 2ace32c6d13bf3..0258f79c8ddf98 100644
> --- a/drivers/iommu/io-pgfault.c
> +++ b/drivers/iommu/io-pgfault.c
> @@ -40,7 +40,7 @@ static struct iommu_fault_param *iopf_get_dev_fault_param(struct device *dev)
>   		return NULL;
>   
>   	rcu_read_lock();
> -	fault_param = param->fault_param;
> +	fault_param = rcu_dereference(param->fault_param);
>   	if (fault_param && !refcount_inc_not_zero(&fault_param->users))
>   		fault_param = NULL;
>   	rcu_read_unlock();
> @@ -51,17 +51,8 @@ static struct iommu_fault_param *iopf_get_dev_fault_param(struct device *dev)
>   /* Caller must hold a reference of the fault parameter. */
>   static void iopf_put_dev_fault_param(struct iommu_fault_param *fault_param)
>   {
> -	struct dev_iommu *param = fault_param->dev->iommu;
> -
> -	rcu_read_lock();
> -	if (!refcount_dec_and_test(&fault_param->users)) {
> -		rcu_read_unlock();
> -		return;
> -	}
> -	rcu_read_unlock();
> -
> -	param->fault_param = NULL;
> -	kfree_rcu(fault_param, rcu);
> +	if (refcount_dec_and_test(&fault_param->users))
> +		kfree_rcu(fault_param, rcu);
>   }
>   
>   /**
> @@ -174,7 +165,7 @@ static int iommu_handle_iopf(struct iommu_fault *fault,
>   	}
>   
>   	mutex_unlock(&iopf_param->lock);
> -	ret = domain->iopf_handler(group);
> +	ret = domain->iopf_handler(iopf_param, group);
>   	mutex_lock(&iopf_param->lock);
>   	if (ret)
>   		iopf_free_group(group);
> @@ -398,7 +389,8 @@ int iopf_queue_add_device(struct iopf_queue *queue, struct device *dev)
>   
>   	mutex_lock(&queue->lock);
>   	mutex_lock(&param->lock);
> -	if (param->fault_param) {
> +	if (rcu_dereference_check(param->fault_param,
> +				  lockdep_is_held(&param->lock))) {
>   		ret = -EBUSY;
>   		goto done_unlock;
>   	}
> @@ -418,7 +410,7 @@ int iopf_queue_add_device(struct iopf_queue *queue, struct device *dev)
>   	list_add(&fault_param->queue_list, &queue->devices);
>   	fault_param->queue = queue;
>   
> -	param->fault_param = fault_param;
> +	rcu_assign_pointer(param->fault_param, fault_param);
>   
>   done_unlock:
>   	mutex_unlock(&param->lock);
> @@ -442,10 +434,12 @@ int iopf_queue_remove_device(struct iopf_queue *queue, struct device *dev)
>   	int ret = 0;
>   	struct iopf_fault *iopf, *next;
>   	struct dev_iommu *param = dev->iommu;
> -	struct iommu_fault_param *fault_param = param->fault_param;
> +	struct iommu_fault_param *fault_param;
>   
>   	mutex_lock(&queue->lock);
>   	mutex_lock(&param->lock);
> +	fault_param = rcu_dereference_check(param->fault_param,
> +					    lockdep_is_held(&param->lock));
>   	if (!fault_param) {
>   		ret = -ENODEV;
>   		goto unlock;
> @@ -467,7 +461,10 @@ int iopf_queue_remove_device(struct iopf_queue *queue, struct device *dev)
>   	list_for_each_entry_safe(iopf, next, &fault_param->partial, list)
>   		kfree(iopf);
>   
> -	iopf_put_dev_fault_param(fault_param);
> +	/* dec the ref owned by iopf_queue_add_device() */
> +	rcu_assign_pointer(param->fault_param, NULL);
> +	if (refcount_dec_and_test(&fault_param->users))
> +		kfree_rcu(fault_param, rcu);
>   unlock:
>   	mutex_unlock(&param->lock);
>   	mutex_unlock(&queue->lock);
> diff --git a/drivers/iommu/iommu-sva.c b/drivers/iommu/iommu-sva.c
> index 325d1810e133a1..63c1a233a7e91f 100644
> --- a/drivers/iommu/iommu-sva.c
> +++ b/drivers/iommu/iommu-sva.c
> @@ -232,10 +232,9 @@ static void iommu_sva_handle_iopf(struct work_struct *work)
>   	iopf_free_group(group);
>   }
>   
> -static int iommu_sva_iopf_handler(struct iopf_group *group)
> +static int iommu_sva_iopf_handler(struct iommu_fault_param *fault_param,
> +				  struct iopf_group *group)
>   {
> -	struct iommu_fault_param *fault_param = group->dev->iommu->fault_param;
> -
>   	INIT_WORK(&group->work, iommu_sva_handle_iopf);
>   	if (!queue_work(fault_param->queue->wq, &group->work))
>   		return -EBUSY;
> diff --git a/include/linux/iommu.h b/include/linux/iommu.h
> index 8020bb44a64ab1..e16fa9811d5023 100644
> --- a/include/linux/iommu.h
> +++ b/include/linux/iommu.h
> @@ -41,6 +41,7 @@ struct iommu_dirty_ops;
>   struct notifier_block;
>   struct iommu_sva;
>   struct iommu_dma_cookie;
> +struct iommu_fault_param;
>   
>   #define IOMMU_FAULT_PERM_READ	(1 << 0) /* read */
>   #define IOMMU_FAULT_PERM_WRITE	(1 << 1) /* write */
> @@ -210,7 +211,8 @@ struct iommu_domain {
>   	unsigned long pgsize_bitmap;	/* Bitmap of page sizes in use */
>   	struct iommu_domain_geometry geometry;
>   	struct iommu_dma_cookie *iova_cookie;
> -	int (*iopf_handler)(struct iopf_group *group);
> +	int (*iopf_handler)(struct iommu_fault_param *fault_param,
> +			    struct iopf_group *group);

How about folding fault_param into iopf_group?

iopf_group is the central data around a iopf handling. The iopf_group
holds the reference count of the device's fault parameter structure
throughout its entire lifecycle.

>   	void *fault_data;
>   	union {
>   		struct {
> @@ -637,7 +639,7 @@ struct iommu_fault_param {
>    */
>   struct dev_iommu {
>   	struct mutex lock;
> -	struct iommu_fault_param	*fault_param;
> +	struct iommu_fault_param __rcu	*fault_param;
>   	struct iommu_fwspec		*fwspec;
>   	struct iommu_device		*iommu_dev;
>   	void				*priv;

The iommu_page_response() needs to change accordingly which is pointed
out in the next email.

Others look good to me. Thank you so much!

Best regards,
baolu
Baolu Lu Dec. 12, 2023, 5:07 a.m. UTC | #4
On 12/11/23 11:24 PM, Jason Gunthorpe wrote:
> On Thu, Dec 07, 2023 at 02:43:08PM +0800, Lu Baolu wrote:
>> @@ -217,12 +250,9 @@ int iommu_page_response(struct device *dev,
>>   	if (!ops->page_response)
>>   		return -ENODEV;
>>   
>> -	mutex_lock(&param->lock);
>> -	fault_param = param->fault_param;
>> -	if (!fault_param) {
>> -		mutex_unlock(&param->lock);
>> +	fault_param = iopf_get_dev_fault_param(dev);
>> +	if (!fault_param)
>>   		return -EINVAL;
>> -	}
> The refcounting should work by passing around the fault_param object,
> not re-obtaining it from the dev from a work.
> 
> The work should be locked to the iommu_fault_param that was active
> when the work was launched.
> 
> When we get to iommu_page_response it does this:
> 
> 	/* Only send response if there is a fault report pending */
> 	mutex_lock(&fault_param->lock);
> 	if (list_empty(&fault_param->faults)) {
> 		dev_warn_ratelimited(dev, "no pending PRQ, drop response\n");
> 		goto done_unlock;
> 	}
> 
> Which determines that the iommu_fault_param is stale and pending
> free..

Yes, agreed. The iopf_fault_param should be passed in together with the
iopf_group. The reference count should be released in the
iopf_free_group(). These two helpers could look like below:

int iommu_page_response(struct iopf_group *group,
			struct iommu_page_response *msg)
{
	bool needs_pasid;
	int ret = -EINVAL;
	struct iopf_fault *evt;
	struct iommu_fault_page_request *prm;
	struct device *dev = group->fault_param->dev;
	const struct iommu_ops *ops = dev_iommu_ops(dev);
	bool has_pasid = msg->flags & IOMMU_PAGE_RESP_PASID_VALID;
	struct iommu_fault_param *fault_param = group->fault_param;

	if (!ops->page_response)
		return -ENODEV;

	/* Only send response if there is a fault report pending */
	mutex_lock(&fault_param->lock);
	if (list_empty(&fault_param->faults)) {
		dev_warn_ratelimited(dev, "no pending PRQ, drop response\n");
		goto done_unlock;
	}
	/*
	 * Check if we have a matching page request pending to respond,
	 * otherwise return -EINVAL
	 */
	list_for_each_entry(evt, &fault_param->faults, list) {
		prm = &evt->fault.prm;
		if (prm->grpid != msg->grpid)
			continue;

		/*
		 * If the PASID is required, the corresponding request is
		 * matched using the group ID, the PASID valid bit and the PASID
		 * value. Otherwise only the group ID matches request and
		 * response.
		 */
		needs_pasid = prm->flags & IOMMU_FAULT_PAGE_RESPONSE_NEEDS_PASID;
		if (needs_pasid && (!has_pasid || msg->pasid != prm->pasid))
			continue;

		if (!needs_pasid && has_pasid) {
			/* No big deal, just clear it. */
			msg->flags &= ~IOMMU_PAGE_RESP_PASID_VALID;
			msg->pasid = 0;
		}

		ret = ops->page_response(dev, evt, msg);
		list_del(&evt->list);
		kfree(evt);
		break;
	}

done_unlock:
	mutex_unlock(&fault_param->lock);

	return ret;
}

...

void iopf_free_group(struct iopf_group *group)
{
	struct iopf_fault *iopf, *next;

	list_for_each_entry_safe(iopf, next, &group->faults, list) {
		if (!(iopf->fault.prm.flags & IOMMU_FAULT_PAGE_REQUEST_LAST_PAGE))
			kfree(iopf);
	}

	/* Pair with iommu_report_device_fault(). */
	iopf_put_dev_fault_param(group->fault_param);
	kfree(group);
}

Best regards,
baolu
Baolu Lu Dec. 12, 2023, 5:17 a.m. UTC | #5
On 12/11/23 11:24 PM, Jason Gunthorpe wrote:
> Also iopf_queue_remove_device() is messed up - it returns an error
> code but nothing ever does anything with it 
Baolu Lu Dec. 12, 2023, 5:23 a.m. UTC | #6
On 12/11/23 11:24 PM, Jason Gunthorpe wrote:
>> @@ -282,22 +313,15 @@ EXPORT_SYMBOL_GPL(iommu_page_response);
>>    */
>>   int iopf_queue_flush_dev(struct device *dev)
>>   {
>> -	int ret = 0;
>> -	struct iommu_fault_param *iopf_param;
>> -	struct dev_iommu *param = dev->iommu;
>> +	struct iommu_fault_param *iopf_param = iopf_get_dev_fault_param(dev);
>>   
>> -	if (!param)
>> +	if (!iopf_param)
>>   		return -ENODEV;
> And this also seems unnecessary, it is a bug to call this after
> iopf_queue_remove_device() right? Just

Yes. They both are called from the iommu driver. The iommu driver should
guarantee this.

> rcu_dereference(param->fault_param, true) and WARN_ON NULL.

Okay, sure.

Best regards,
baolu
Jason Gunthorpe Dec. 12, 2023, 3:14 p.m. UTC | #7
On Tue, Dec 12, 2023 at 01:17:47PM +0800, Baolu Lu wrote:
> On 12/11/23 11:24 PM, Jason Gunthorpe wrote:
> > Also iopf_queue_remove_device() is messed up - it returns an error
> > code but nothing ever does anything with it 
Jason Gunthorpe Dec. 12, 2023, 3:18 p.m. UTC | #8
On Tue, Dec 12, 2023 at 01:07:17PM +0800, Baolu Lu wrote:

> Yes, agreed. The iopf_fault_param should be passed in together with the
> iopf_group. The reference count should be released in the
> iopf_free_group(). These two helpers could look like below:
> 
> int iommu_page_response(struct iopf_group *group,
> 			struct iommu_page_response *msg)
> {
> 	bool needs_pasid;
> 	int ret = -EINVAL;
> 	struct iopf_fault *evt;
> 	struct iommu_fault_page_request *prm;
> 	struct device *dev = group->fault_param->dev;
> 	const struct iommu_ops *ops = dev_iommu_ops(dev);
> 	bool has_pasid = msg->flags & IOMMU_PAGE_RESP_PASID_VALID;
> 	struct iommu_fault_param *fault_param = group->fault_param;
>
> 	if (!ops->page_response)
> 		return -ENODEV;

We should never get here if this is the case, prevent the device from
being added in the first place

> 	/* Only send response if there is a fault report pending */
> 	mutex_lock(&fault_param->lock);
> 	if (list_empty(&fault_param->faults)) {
> 		dev_warn_ratelimited(dev, "no pending PRQ, drop response\n");
> 		goto done_unlock;
> 	}
> 	/*
> 	 * Check if we have a matching page request pending to respond,
> 	 * otherwise return -EINVAL
> 	 */
> 	list_for_each_entry(evt, &fault_param->faults, list) {
> 		prm = &evt->fault.prm;
> 		if (prm->grpid != msg->grpid)
> 			continue;
> 
> 		/*
> 		 * If the PASID is required, the corresponding request is
> 		 * matched using the group ID, the PASID valid bit and the PASID
> 		 * value. Otherwise only the group ID matches request and
> 		 * response.
> 		 */
> 		needs_pasid = prm->flags & IOMMU_FAULT_PAGE_RESPONSE_NEEDS_PASID;
> 		if (needs_pasid && (!has_pasid || msg->pasid != prm->pasid))
> 			continue;
> 
> 		if (!needs_pasid && has_pasid) {
> 			/* No big deal, just clear it. */
> 			msg->flags &= ~IOMMU_PAGE_RESP_PASID_VALID;
> 			msg->pasid = 0;
> 		}
> 
> 		ret = ops->page_response(dev, evt, msg);
> 		list_del(&evt->list);
> 		kfree(evt);
> 		break;
> 	}
> 
> done_unlock:
> 	mutex_unlock(&fault_param->lock);

I would have expected the group to free'd here? But regardless this
looks like a good direction

Jason
Jason Gunthorpe Dec. 12, 2023, 3:18 p.m. UTC | #9
On Tue, Dec 12, 2023 at 11:44:14AM +0800, Baolu Lu wrote:
> > @@ -210,7 +211,8 @@ struct iommu_domain {
> >   	unsigned long pgsize_bitmap;	/* Bitmap of page sizes in use */
> >   	struct iommu_domain_geometry geometry;
> >   	struct iommu_dma_cookie *iova_cookie;
> > -	int (*iopf_handler)(struct iopf_group *group);
> > +	int (*iopf_handler)(struct iommu_fault_param *fault_param,
> > +			    struct iopf_group *group);
> 
> How about folding fault_param into iopf_group?
> 
> iopf_group is the central data around a iopf handling. The iopf_group
> holds the reference count of the device's fault parameter structure
> throughout its entire lifecycle.

Yeah, I think that is the right thing to do

Jason
Baolu Lu Dec. 13, 2023, 2:14 a.m. UTC | #10
On 12/12/23 11:14 PM, Jason Gunthorpe wrote:
> On Tue, Dec 12, 2023 at 01:17:47PM +0800, Baolu Lu wrote:
>> On 12/11/23 11:24 PM, Jason Gunthorpe wrote:
>>> Also iopf_queue_remove_device() is messed up - it returns an error
>>> code but nothing ever does anything with it 
Baolu Lu Dec. 13, 2023, 2:19 a.m. UTC | #11
On 12/12/23 11:18 PM, Jason Gunthorpe wrote:
> On Tue, Dec 12, 2023 at 01:07:17PM +0800, Baolu Lu wrote:
> 
>> Yes, agreed. The iopf_fault_param should be passed in together with the
>> iopf_group. The reference count should be released in the
>> iopf_free_group(). These two helps could look like below:
>>
>> int iommu_page_response(struct iopf_group *group,
>> 			struct iommu_page_response *msg)
>> {
>> 	bool needs_pasid;
>> 	int ret = -EINVAL;
>> 	struct iopf_fault *evt;
>> 	struct iommu_fault_page_request *prm;
>> 	struct device *dev = group->fault_param->dev;
>> 	const struct iommu_ops *ops = dev_iommu_ops(dev);
>> 	bool has_pasid = msg->flags & IOMMU_PAGE_RESP_PASID_VALID;
>> 	struct iommu_fault_param *fault_param = group->fault_param;
>>
>> 	if (!ops->page_response)
>> 		return -ENODEV;
> 
> We should never get here if this is the case, prevent the device from
> being added in the first place

Yeah, could move it to iopf_queue_add_device(). WARN and return failure
there if the driver is not ready for page request handling.

> 
>> 	/* Only send response if there is a fault report pending */
>> 	mutex_lock(&fault_param->lock);
>> 	if (list_empty(&fault_param->faults)) {
>> 		dev_warn_ratelimited(dev, "no pending PRQ, drop response\n");
>> 		goto done_unlock;
>> 	}
>> 	/*
>> 	 * Check if we have a matching page request pending to respond,
>> 	 * otherwise return -EINVAL
>> 	 */
>> 	list_for_each_entry(evt, &fault_param->faults, list) {
>> 		prm = &evt->fault.prm;
>> 		if (prm->grpid != msg->grpid)
>> 			continue;
>>
>> 		/*
>> 		 * If the PASID is required, the corresponding request is
>> 		 * matched using the group ID, the PASID valid bit and the PASID
>> 		 * value. Otherwise only the group ID matches request and
>> 		 * response.
>> 		 */
>> 		needs_pasid = prm->flags & IOMMU_FAULT_PAGE_RESPONSE_NEEDS_PASID;
>> 		if (needs_pasid && (!has_pasid || msg->pasid != prm->pasid))
>> 			continue;
>>
>> 		if (!needs_pasid && has_pasid) {
>> 			/* No big deal, just clear it. */
>> 			msg->flags &= ~IOMMU_PAGE_RESP_PASID_VALID;
>> 			msg->pasid = 0;
>> 		}
>>
>> 		ret = ops->page_response(dev, evt, msg);
>> 		list_del(&evt->list);
>> 		kfree(evt);
>> 		break;
>> 	}
>>
>> done_unlock:
>> 	mutex_unlock(&fault_param->lock);
> 
> I would have expected the group to free'd here? But regardless this
> looks like a good direction

Both work for me. We can decide it according to the needs of code later.

> 
> Jason

Best regards,
baolu
diff mbox series

Patch

diff --git a/include/linux/iommu.h b/include/linux/iommu.h
index 63df77cc0b61..8020bb44a64a 100644
--- a/include/linux/iommu.h
+++ b/include/linux/iommu.h
@@ -597,6 +597,8 @@  struct iommu_device {
 /**
  * struct iommu_fault_param - per-device IOMMU fault data
  * @lock: protect pending faults list
+ * @users: user counter to manage the lifetime of the data
+ * @rcu: rcu head for kfree_rcu()
  * @dev: the device that owns this param
  * @queue: IOPF queue
  * @queue_list: index into queue->devices
@@ -606,6 +608,8 @@  struct iommu_device {
  */
 struct iommu_fault_param {
 	struct mutex lock;
+	refcount_t users;
+	struct rcu_head rcu;
 
 	struct device *dev;
 	struct iopf_queue *queue;
diff --git a/drivers/iommu/io-pgfault.c b/drivers/iommu/io-pgfault.c
index 9439eaf54928..2ace32c6d13b 100644
--- a/drivers/iommu/io-pgfault.c
+++ b/drivers/iommu/io-pgfault.c
@@ -26,6 +26,44 @@  void iopf_free_group(struct iopf_group *group)
 }
 EXPORT_SYMBOL_GPL(iopf_free_group);
 
+/*
+ * Return the fault parameter of a device if it exists. Otherwise, return NULL.
+ * On a successful return, the caller takes a reference of this parameter and
+ * should put it after use by calling iopf_put_dev_fault_param().
+ */
+static struct iommu_fault_param *iopf_get_dev_fault_param(struct device *dev)
+{
+	struct dev_iommu *param = dev->iommu;
+	struct iommu_fault_param *fault_param;
+
+	if (!param)
+		return NULL;
+
+	rcu_read_lock();
+	fault_param = param->fault_param;
+	if (fault_param && !refcount_inc_not_zero(&fault_param->users))
+		fault_param = NULL;
+	rcu_read_unlock();
+
+	return fault_param;
+}
+
+/* Caller must hold a reference of the fault parameter. */
+static void iopf_put_dev_fault_param(struct iommu_fault_param *fault_param)
+{
+	struct dev_iommu *param = fault_param->dev->iommu;
+
+	rcu_read_lock();
+	if (!refcount_dec_and_test(&fault_param->users)) {
+		rcu_read_unlock();
+		return;
+	}
+	rcu_read_unlock();
+
+	param->fault_param = NULL;
+	kfree_rcu(fault_param, rcu);
+}
+
 /**
  * iommu_handle_iopf - IO Page Fault handler
  * @fault: fault event
@@ -167,15 +205,11 @@  int iommu_report_device_fault(struct device *dev, struct iopf_fault *evt)
 {
 	struct iommu_fault_param *fault_param;
 	struct iopf_fault *evt_pending = NULL;
-	struct dev_iommu *param = dev->iommu;
 	int ret = 0;
 
-	mutex_lock(&param->lock);
-	fault_param = param->fault_param;
-	if (!fault_param) {
-		mutex_unlock(&param->lock);
+	fault_param = iopf_get_dev_fault_param(dev);
+	if (!fault_param)
 		return -EINVAL;
-	}
 
 	mutex_lock(&fault_param->lock);
 	if (evt->fault.type == IOMMU_FAULT_PAGE_REQ &&
@@ -196,7 +230,7 @@  int iommu_report_device_fault(struct device *dev, struct iopf_fault *evt)
 	}
 done_unlock:
 	mutex_unlock(&fault_param->lock);
-	mutex_unlock(&param->lock);
+	iopf_put_dev_fault_param(fault_param);
 
 	return ret;
 }
@@ -209,7 +243,6 @@  int iommu_page_response(struct device *dev,
 	int ret = -EINVAL;
 	struct iopf_fault *evt;
 	struct iommu_fault_page_request *prm;
-	struct dev_iommu *param = dev->iommu;
 	struct iommu_fault_param *fault_param;
 	const struct iommu_ops *ops = dev_iommu_ops(dev);
 	bool has_pasid = msg->flags & IOMMU_PAGE_RESP_PASID_VALID;
@@ -217,12 +250,9 @@  int iommu_page_response(struct device *dev,
 	if (!ops->page_response)
 		return -ENODEV;
 
-	mutex_lock(&param->lock);
-	fault_param = param->fault_param;
-	if (!fault_param) {
-		mutex_unlock(&param->lock);
+	fault_param = iopf_get_dev_fault_param(dev);
+	if (!fault_param)
 		return -EINVAL;
-	}
 
 	/* Only send response if there is a fault report pending */
 	mutex_lock(&fault_param->lock);
@@ -263,7 +293,8 @@  int iommu_page_response(struct device *dev,
 
 done_unlock:
 	mutex_unlock(&fault_param->lock);
-	mutex_unlock(&param->lock);
+	iopf_put_dev_fault_param(fault_param);
+
 	return ret;
 }
 EXPORT_SYMBOL_GPL(iommu_page_response);
@@ -282,22 +313,15 @@  EXPORT_SYMBOL_GPL(iommu_page_response);
  */
 int iopf_queue_flush_dev(struct device *dev)
 {
-	int ret = 0;
-	struct iommu_fault_param *iopf_param;
-	struct dev_iommu *param = dev->iommu;
+	struct iommu_fault_param *iopf_param = iopf_get_dev_fault_param(dev);
 
-	if (!param)
+	if (!iopf_param)
 		return -ENODEV;
 
-	mutex_lock(&param->lock);
-	iopf_param = param->fault_param;
-	if (iopf_param)
-		flush_workqueue(iopf_param->queue->wq);
-	else
-		ret = -ENODEV;
-	mutex_unlock(&param->lock);
+	flush_workqueue(iopf_param->queue->wq);
+	iopf_put_dev_fault_param(iopf_param);
 
-	return ret;
+	return 0;
 }
 EXPORT_SYMBOL_GPL(iopf_queue_flush_dev);
 
@@ -389,6 +413,8 @@  int iopf_queue_add_device(struct iopf_queue *queue, struct device *dev)
 	INIT_LIST_HEAD(&fault_param->faults);
 	INIT_LIST_HEAD(&fault_param->partial);
 	fault_param->dev = dev;
+	refcount_set(&fault_param->users, 1);
+	init_rcu_head(&fault_param->rcu);
 	list_add(&fault_param->queue_list, &queue->devices);
 	fault_param->queue = queue;
 
@@ -441,8 +467,7 @@  int iopf_queue_remove_device(struct iopf_queue *queue, struct device *dev)
 	list_for_each_entry_safe(iopf, next, &fault_param->partial, list)
 		kfree(iopf);
 
-	param->fault_param = NULL;
-	kfree(fault_param);
+	iopf_put_dev_fault_param(fault_param);
 unlock:
 	mutex_unlock(&param->lock);
 	mutex_unlock(&queue->lock);