diff mbox series

[v5,05/12] iommufd: Pass pasid through the device attach/replace path

Message ID 20241104132513.15890-6-yi.l.liu@intel.com (mailing list archive)
State New
Headers show
Series iommufd support pasid attach/replace | expand

Commit Message

Yi Liu Nov. 4, 2024, 1:25 p.m. UTC
Most of the core logic before conducting the actual device attach/
replace operation can be shared with pasid attach/replace. So pass
pasid through the device attach/replace helpers to prepare adding
pasid attach/replace.

So far the @pasid should only be IOMMU_NO_PASID. No functional change.

Signed-off-by: Kevin Tian <kevin.tian@intel.com>
Signed-off-by: Yi Liu <yi.l.liu@intel.com>
---
 drivers/iommu/iommufd/device.c          | 55 ++++++++++++++-----------
 drivers/iommu/iommufd/fault.c           | 16 ++++---
 drivers/iommu/iommufd/hw_pagetable.c    |  5 ++-
 drivers/iommu/iommufd/iommufd_private.h | 41 +++++++++++-------
 4 files changed, 71 insertions(+), 46 deletions(-)
diff mbox series

Patch

diff --git a/drivers/iommu/iommufd/device.c b/drivers/iommu/iommufd/device.c
index 823c81145214..0b3f2094af4a 100644
--- a/drivers/iommu/iommufd/device.c
+++ b/drivers/iommu/iommufd/device.c
@@ -294,11 +294,12 @@  u32 iommufd_device_to_id(struct iommufd_device *idev)
 EXPORT_SYMBOL_NS_GPL(iommufd_device_to_id, IOMMUFD);
 
 struct iommufd_attach_handle *
-iommufd_device_get_attach_handle(struct iommufd_device *idev)
+iommufd_device_get_attach_handle(struct iommufd_device *idev, ioasid_t pasid)
 {
 	struct iommu_attach_handle *handle;
 
-	handle = iommu_attach_handle_get(idev->igroup->group, IOMMU_NO_PASID, 0);
+	WARN_ON(pasid != IOMMU_NO_PASID);
+	handle = iommu_attach_handle_get(idev->igroup->group, pasid, 0);
 	if (IS_ERR(handle))
 		return NULL;
 
@@ -306,7 +307,8 @@  iommufd_device_get_attach_handle(struct iommufd_device *idev)
 }
 
 int iommufd_dev_attach_handle(struct iommufd_hw_pagetable *hwpt,
-			      struct iommufd_device *idev)
+			      struct iommufd_device *idev,
+			      ioasid_t pasid)
 {
 	struct iommufd_attach_handle *handle;
 	int ret;
@@ -316,6 +318,7 @@  int iommufd_dev_attach_handle(struct iommufd_hw_pagetable *hwpt,
 		return -ENOMEM;
 
 	handle->idev = idev;
+	WARN_ON(pasid != IOMMU_NO_PASID);
 	ret = iommu_attach_group_handle(hwpt->domain, idev->igroup->group,
 					&handle->handle);
 	if (ret)
@@ -325,6 +328,7 @@  int iommufd_dev_attach_handle(struct iommufd_hw_pagetable *hwpt,
 }
 
 int iommufd_dev_replace_handle(struct iommufd_device *idev,
+			       ioasid_t pasid,
 			       struct iommufd_hw_pagetable *hwpt,
 			       struct iommufd_hw_pagetable *old)
 {
@@ -336,6 +340,7 @@  int iommufd_dev_replace_handle(struct iommufd_device *idev,
 		return -ENOMEM;
 
 	handle->idev = idev;
+	WARN_ON(pasid != IOMMU_NO_PASID);
 	ret = iommu_replace_group_handle(idev->igroup->group,
 					 hwpt->domain, &handle->handle);
 	if (ret)
@@ -404,7 +409,8 @@  iommufd_device_attach_reserved_iova(struct iommufd_device *idev,
 }
 
 int iommufd_hw_pagetable_attach(struct iommufd_hw_pagetable *hwpt,
-				struct iommufd_device *idev)
+				struct iommufd_device *idev,
+				ioasid_t pasid)
 {
 	struct iommufd_hwpt_paging *hwpt_paging = find_hwpt_paging(hwpt);
 	int rc;
@@ -430,7 +436,7 @@  int iommufd_hw_pagetable_attach(struct iommufd_hw_pagetable *hwpt,
 	 * attachment.
 	 */
 	if (list_empty(&idev->igroup->device_list)) {
-		rc = iommufd_hwpt_attach_device(hwpt, idev);
+		rc = iommufd_hwpt_attach_device(hwpt, idev, pasid);
 		if (rc)
 			goto err_unresv;
 		idev->igroup->hwpt = hwpt;
@@ -448,7 +454,7 @@  int iommufd_hw_pagetable_attach(struct iommufd_hw_pagetable *hwpt,
 }
 
 struct iommufd_hw_pagetable *
-iommufd_hw_pagetable_detach(struct iommufd_device *idev)
+iommufd_hw_pagetable_detach(struct iommufd_device *idev, ioasid_t pasid)
 {
 	struct iommufd_hw_pagetable *hwpt = idev->igroup->hwpt;
 	struct iommufd_hwpt_paging *hwpt_paging = find_hwpt_paging(hwpt);
@@ -456,7 +462,7 @@  iommufd_hw_pagetable_detach(struct iommufd_device *idev)
 	mutex_lock(&idev->igroup->lock);
 	list_del(&idev->group_item);
 	if (list_empty(&idev->igroup->device_list)) {
-		iommufd_hwpt_detach_device(hwpt, idev);
+		iommufd_hwpt_detach_device(hwpt, idev, pasid);
 		idev->igroup->hwpt = NULL;
 	}
 	if (hwpt_paging)
@@ -468,12 +474,12 @@  iommufd_hw_pagetable_detach(struct iommufd_device *idev)
 }
 
 static struct iommufd_hw_pagetable *
-iommufd_device_do_attach(struct iommufd_device *idev,
+iommufd_device_do_attach(struct iommufd_device *idev, ioasid_t pasid,
 			 struct iommufd_hw_pagetable *hwpt)
 {
 	int rc;
 
-	rc = iommufd_hw_pagetable_attach(hwpt, idev);
+	rc = iommufd_hw_pagetable_attach(hwpt, idev, pasid);
 	if (rc)
 		return ERR_PTR(rc);
 	return NULL;
@@ -522,7 +528,7 @@  iommufd_group_do_replace_reserved_iova(struct iommufd_group *igroup,
 }
 
 static struct iommufd_hw_pagetable *
-iommufd_device_do_replace(struct iommufd_device *idev,
+iommufd_device_do_replace(struct iommufd_device *idev, ioasid_t pasid,
 			  struct iommufd_hw_pagetable *hwpt)
 {
 	struct iommufd_hwpt_paging *hwpt_paging = find_hwpt_paging(hwpt);
@@ -551,7 +557,7 @@  iommufd_device_do_replace(struct iommufd_device *idev,
 			goto err_unlock;
 	}
 
-	rc = iommufd_hwpt_replace_device(idev, hwpt, old_hwpt);
+	rc = iommufd_hwpt_replace_device(idev, pasid, hwpt, old_hwpt);
 	if (rc)
 		goto err_unresv;
 
@@ -584,7 +590,8 @@  iommufd_device_do_replace(struct iommufd_device *idev,
 }
 
 typedef struct iommufd_hw_pagetable *(*attach_fn)(
-	struct iommufd_device *idev, struct iommufd_hw_pagetable *hwpt);
+			struct iommufd_device *idev, ioasid_t pasid,
+			struct iommufd_hw_pagetable *hwpt);
 
 /*
  * When automatically managing the domains we search for a compatible domain in
@@ -592,7 +599,7 @@  typedef struct iommufd_hw_pagetable *(*attach_fn)(
  * Automatic domain selection will never pick a manually created domain.
  */
 static struct iommufd_hw_pagetable *
-iommufd_device_auto_get_domain(struct iommufd_device *idev,
+iommufd_device_auto_get_domain(struct iommufd_device *idev, ioasid_t pasid,
 			       struct iommufd_ioas *ioas, u32 *pt_id,
 			       attach_fn do_attach)
 {
@@ -621,7 +628,7 @@  iommufd_device_auto_get_domain(struct iommufd_device *idev,
 		hwpt = &hwpt_paging->common;
 		if (!iommufd_lock_obj(&hwpt->obj))
 			continue;
-		destroy_hwpt = (*do_attach)(idev, hwpt);
+		destroy_hwpt = (*do_attach)(idev, pasid, hwpt);
 		if (IS_ERR(destroy_hwpt)) {
 			iommufd_put_object(idev->ictx, &hwpt->obj);
 			/*
@@ -648,7 +655,7 @@  iommufd_device_auto_get_domain(struct iommufd_device *idev,
 	hwpt = &hwpt_paging->common;
 
 	if (!immediate_attach) {
-		destroy_hwpt = (*do_attach)(idev, hwpt);
+		destroy_hwpt = (*do_attach)(idev, pasid, hwpt);
 		if (IS_ERR(destroy_hwpt))
 			goto out_abort;
 	} else {
@@ -669,8 +676,9 @@  iommufd_device_auto_get_domain(struct iommufd_device *idev,
 	return destroy_hwpt;
 }
 
-static int iommufd_device_change_pt(struct iommufd_device *idev, u32 *pt_id,
-				    attach_fn do_attach)
+static int iommufd_device_change_pt(struct iommufd_device *idev,
+				    ioasid_t pasid,
+				    u32 *pt_id, attach_fn do_attach)
 {
 	struct iommufd_hw_pagetable *destroy_hwpt;
 	struct iommufd_object *pt_obj;
@@ -685,7 +693,7 @@  static int iommufd_device_change_pt(struct iommufd_device *idev, u32 *pt_id,
 		struct iommufd_hw_pagetable *hwpt =
 			container_of(pt_obj, struct iommufd_hw_pagetable, obj);
 
-		destroy_hwpt = (*do_attach)(idev, hwpt);
+		destroy_hwpt = (*do_attach)(idev, pasid, hwpt);
 		if (IS_ERR(destroy_hwpt))
 			goto out_put_pt_obj;
 		break;
@@ -694,8 +702,8 @@  static int iommufd_device_change_pt(struct iommufd_device *idev, u32 *pt_id,
 		struct iommufd_ioas *ioas =
 			container_of(pt_obj, struct iommufd_ioas, obj);
 
-		destroy_hwpt = iommufd_device_auto_get_domain(idev, ioas, pt_id,
-							      do_attach);
+		destroy_hwpt = iommufd_device_auto_get_domain(idev, pasid, ioas,
+							      pt_id, do_attach);
 		if (IS_ERR(destroy_hwpt))
 			goto out_put_pt_obj;
 		break;
@@ -732,7 +740,8 @@  int iommufd_device_attach(struct iommufd_device *idev, u32 *pt_id)
 {
 	int rc;
 
-	rc = iommufd_device_change_pt(idev, pt_id, &iommufd_device_do_attach);
+	rc = iommufd_device_change_pt(idev, IOMMU_NO_PASID, pt_id,
+				      &iommufd_device_do_attach);
 	if (rc)
 		return rc;
 
@@ -762,7 +771,7 @@  EXPORT_SYMBOL_NS_GPL(iommufd_device_attach, IOMMUFD);
  */
 int iommufd_device_replace(struct iommufd_device *idev, u32 *pt_id)
 {
-	return iommufd_device_change_pt(idev, pt_id,
+	return iommufd_device_change_pt(idev, IOMMU_NO_PASID, pt_id,
 					&iommufd_device_do_replace);
 }
 EXPORT_SYMBOL_NS_GPL(iommufd_device_replace, IOMMUFD);
@@ -778,7 +787,7 @@  void iommufd_device_detach(struct iommufd_device *idev)
 {
 	struct iommufd_hw_pagetable *hwpt;
 
-	hwpt = iommufd_hw_pagetable_detach(idev);
+	hwpt = iommufd_hw_pagetable_detach(idev, IOMMU_NO_PASID);
 	iommufd_hw_pagetable_put(idev->ictx, hwpt);
 	refcount_dec(&idev->obj.users);
 }
diff --git a/drivers/iommu/iommufd/fault.c b/drivers/iommu/iommufd/fault.c
index 55418a067869..3b60349e2913 100644
--- a/drivers/iommu/iommufd/fault.c
+++ b/drivers/iommu/iommufd/fault.c
@@ -56,7 +56,8 @@  static void iommufd_fault_iopf_disable(struct iommufd_device *idev)
 }
 
 int iommufd_fault_domain_attach_dev(struct iommufd_hw_pagetable *hwpt,
-				    struct iommufd_device *idev)
+				    struct iommufd_device *idev,
+				    ioasid_t pasid)
 {
 	int ret;
 
@@ -67,7 +68,7 @@  int iommufd_fault_domain_attach_dev(struct iommufd_hw_pagetable *hwpt,
 	if (ret)
 		return ret;
 
-	ret = iommufd_dev_attach_handle(hwpt, idev);
+	ret = iommufd_dev_attach_handle(hwpt, idev, pasid);
 	if (ret)
 		iommufd_fault_iopf_disable(idev);
 
@@ -104,11 +105,13 @@  static void iommufd_auto_response_faults(struct iommufd_hw_pagetable *hwpt,
 }
 
 void iommufd_fault_domain_detach_dev(struct iommufd_hw_pagetable *hwpt,
-				     struct iommufd_device *idev)
+				     struct iommufd_device *idev,
+				     ioasid_t pasid)
 {
 	struct iommufd_attach_handle *handle;
 
-	handle = iommufd_device_get_attach_handle(idev);
+	handle = iommufd_device_get_attach_handle(idev, pasid);
+	WARN_ON(pasid != IOMMU_NO_PASID);
 	iommu_detach_group_handle(hwpt->domain, idev->igroup->group);
 	iommufd_auto_response_faults(hwpt, handle);
 	iommufd_fault_iopf_disable(idev);
@@ -116,6 +119,7 @@  void iommufd_fault_domain_detach_dev(struct iommufd_hw_pagetable *hwpt,
 }
 
 int iommufd_fault_domain_replace_dev(struct iommufd_device *idev,
+				     ioasid_t pasid,
 				     struct iommufd_hw_pagetable *hwpt,
 				     struct iommufd_hw_pagetable *old)
 {
@@ -130,9 +134,9 @@  int iommufd_fault_domain_replace_dev(struct iommufd_device *idev,
 			return ret;
 	}
 
-	curr = iommufd_device_get_attach_handle(idev);
+	curr = iommufd_device_get_attach_handle(idev, pasid);
 
-	ret = iommufd_dev_replace_handle(idev, hwpt, old);
+	ret = iommufd_dev_replace_handle(idev, pasid, hwpt, old);
 	if (ret) {
 		if (iopf_on)
 			iommufd_fault_iopf_disable(idev);
diff --git a/drivers/iommu/iommufd/hw_pagetable.c b/drivers/iommu/iommufd/hw_pagetable.c
index d06bf6e6c19f..48639427749b 100644
--- a/drivers/iommu/iommufd/hw_pagetable.c
+++ b/drivers/iommu/iommufd/hw_pagetable.c
@@ -180,7 +180,8 @@  iommufd_hwpt_paging_alloc(struct iommufd_ctx *ictx, struct iommufd_ioas *ioas,
 	 * sequence. Once those drivers are fixed this should be removed.
 	 */
 	if (immediate_attach) {
-		rc = iommufd_hw_pagetable_attach(hwpt, idev);
+		/* Sinc this is just a trick, so passing IOMMU_NO_PASID is enough */
+		rc = iommufd_hw_pagetable_attach(hwpt, idev, IOMMU_NO_PASID);
 		if (rc)
 			goto out_abort;
 	}
@@ -193,7 +194,7 @@  iommufd_hwpt_paging_alloc(struct iommufd_ctx *ictx, struct iommufd_ioas *ioas,
 
 out_detach:
 	if (immediate_attach)
-		iommufd_hw_pagetable_detach(idev);
+		iommufd_hw_pagetable_detach(idev, IOMMU_NO_PASID);
 out_abort:
 	iommufd_object_abort_and_destroy(ictx, &hwpt->obj);
 	return ERR_PTR(rc);
diff --git a/drivers/iommu/iommufd/iommufd_private.h b/drivers/iommu/iommufd/iommufd_private.h
index 19870b08056e..8e7265885f36 100644
--- a/drivers/iommu/iommufd/iommufd_private.h
+++ b/drivers/iommu/iommufd/iommufd_private.h
@@ -369,9 +369,10 @@  iommufd_hwpt_paging_alloc(struct iommufd_ctx *ictx, struct iommufd_ioas *ioas,
 			  bool immediate_attach,
 			  const struct iommu_user_data *user_data);
 int iommufd_hw_pagetable_attach(struct iommufd_hw_pagetable *hwpt,
-				struct iommufd_device *idev);
+				struct iommufd_device *idev,
+				ioasid_t pasid);
 struct iommufd_hw_pagetable *
-iommufd_hw_pagetable_detach(struct iommufd_device *idev);
+iommufd_hw_pagetable_detach(struct iommufd_device *idev, ioasid_t pasid);
 void iommufd_hwpt_paging_destroy(struct iommufd_object *obj);
 void iommufd_hwpt_paging_abort(struct iommufd_object *obj);
 void iommufd_hwpt_nested_destroy(struct iommufd_object *obj);
@@ -479,10 +480,12 @@  struct iommufd_attach_handle {
 #define to_iommufd_handle(hdl)	container_of(hdl, struct iommufd_attach_handle, handle)
 
 struct iommufd_attach_handle *
-iommufd_device_get_attach_handle(struct iommufd_device *idev);
+iommufd_device_get_attach_handle(struct iommufd_device *idev, ioasid_t pasid);
 int iommufd_dev_attach_handle(struct iommufd_hw_pagetable *hwpt,
-			      struct iommufd_device *idev);
+			      struct iommufd_device *idev,
+			      ioasid_t pasid);
 int iommufd_dev_replace_handle(struct iommufd_device *idev,
+			       ioasid_t pasid,
 			       struct iommufd_hw_pagetable *hwpt,
 			       struct iommufd_hw_pagetable *old);
 
@@ -499,38 +502,45 @@  void iommufd_fault_destroy(struct iommufd_object *obj);
 int iommufd_fault_iopf_handler(struct iopf_group *group);
 
 int iommufd_fault_domain_attach_dev(struct iommufd_hw_pagetable *hwpt,
-				    struct iommufd_device *idev);
+				    struct iommufd_device *idev,
+				    ioasid_t pasid);
 void iommufd_fault_domain_detach_dev(struct iommufd_hw_pagetable *hwpt,
-				     struct iommufd_device *idev);
+				     struct iommufd_device *idev,
+				     ioasid_t pasid);
 int iommufd_fault_domain_replace_dev(struct iommufd_device *idev,
+				     ioasid_t pasid,
 				     struct iommufd_hw_pagetable *hwpt,
 				     struct iommufd_hw_pagetable *old);
 
 static inline int iommufd_hwpt_attach_device(struct iommufd_hw_pagetable *hwpt,
-					     struct iommufd_device *idev)
+					     struct iommufd_device *idev,
+					     ioasid_t pasid)
 {
 	if (hwpt->fault)
-		return iommufd_fault_domain_attach_dev(hwpt, idev);
+		return iommufd_fault_domain_attach_dev(hwpt, idev, pasid);
 
-	return iommufd_dev_attach_handle(hwpt, idev);
+	return iommufd_dev_attach_handle(hwpt, idev, pasid);
 }
 
 static inline void iommufd_hwpt_detach_device(struct iommufd_hw_pagetable *hwpt,
-					      struct iommufd_device *idev)
+					      struct iommufd_device *idev,
+					      ioasid_t pasid)
 {
 	struct iommufd_attach_handle *handle;
 
 	if (hwpt->fault) {
-		iommufd_fault_domain_detach_dev(hwpt, idev);
+		iommufd_fault_domain_detach_dev(hwpt, idev, pasid);
 		return;
 	}
 
-	handle = iommufd_device_get_attach_handle(idev);
+	handle = iommufd_device_get_attach_handle(idev, pasid);
+	WARN_ON(pasid != IOMMU_NO_PASID);
 	iommu_detach_group_handle(hwpt->domain, idev->igroup->group);
 	kfree(handle);
 }
 
 static inline int iommufd_hwpt_replace_device(struct iommufd_device *idev,
+					      ioasid_t pasid,
 					      struct iommufd_hw_pagetable *hwpt,
 					      struct iommufd_hw_pagetable *old)
 {
@@ -538,11 +548,12 @@  static inline int iommufd_hwpt_replace_device(struct iommufd_device *idev,
 	int ret;
 
 	if (old->fault || hwpt->fault)
-		return iommufd_fault_domain_replace_dev(idev, hwpt, old);
+		return iommufd_fault_domain_replace_dev(idev, pasid,
+							hwpt, old);
 
-	curr = iommufd_device_get_attach_handle(idev);
+	curr = iommufd_device_get_attach_handle(idev, pasid);
 
-	ret = iommufd_dev_replace_handle(idev, hwpt, old);
+	ret = iommufd_dev_replace_handle(idev, pasid, hwpt, old);
 	if (ret)
 		return ret;