diff mbox series

[v5,08/12] iommufd: Enforce pasid compatible domain for PASID-capable device

Message ID 20241104132513.15890-9-yi.l.liu@intel.com (mailing list archive)
State New
Headers show
Series iommufd support pasid attach/replace | expand

Commit Message

Yi Liu Nov. 4, 2024, 1:25 p.m. UTC
iommu hw may have special requirement on the domain attached to PASID
capable device. e.g. AMD IOMMU requires the domain allocated with the
IOMMU_HWPT_ALLOC_PASID flag. Hence, iommufd should enforce it when the
domain is used by PASID-capable device.

Signed-off-by: Yi Liu <yi.l.liu@intel.com>
---
 drivers/iommu/intel/iommu.c             | 6 ++++--
 drivers/iommu/iommufd/hw_pagetable.c    | 7 +++++--
 drivers/iommu/iommufd/iommufd_private.h | 7 +++++++
 3 files changed, 16 insertions(+), 4 deletions(-)
diff mbox series

Patch

diff --git a/drivers/iommu/intel/iommu.c b/drivers/iommu/intel/iommu.c
index a1341078b962..d24e21a757ff 100644
--- a/drivers/iommu/intel/iommu.c
+++ b/drivers/iommu/intel/iommu.c
@@ -3545,13 +3545,15 @@  intel_iommu_domain_alloc_user(struct device *dev, u32 flags,
 
 	/* Must be NESTING domain */
 	if (parent) {
-		if (!nested_supported(iommu) || flags)
+		if (!nested_supported(iommu) ||
+		    flags & ~IOMMU_HWPT_ALLOC_PASID)
 			return ERR_PTR(-EOPNOTSUPP);
 		return intel_nested_domain_alloc(parent, user_data);
 	}
 
 	if (flags &
-	    (~(IOMMU_HWPT_ALLOC_NEST_PARENT | IOMMU_HWPT_ALLOC_DIRTY_TRACKING)))
+	    (~(IOMMU_HWPT_ALLOC_NEST_PARENT | IOMMU_HWPT_ALLOC_DIRTY_TRACKING |
+	       IOMMU_HWPT_ALLOC_PASID)))
 		return ERR_PTR(-EOPNOTSUPP);
 	if (nested_parent && !nested_supported(iommu))
 		return ERR_PTR(-EOPNOTSUPP);
diff --git a/drivers/iommu/iommufd/hw_pagetable.c b/drivers/iommu/iommufd/hw_pagetable.c
index 48639427749b..e4932a5a87ea 100644
--- a/drivers/iommu/iommufd/hw_pagetable.c
+++ b/drivers/iommu/iommufd/hw_pagetable.c
@@ -107,7 +107,8 @@  iommufd_hwpt_paging_alloc(struct iommufd_ctx *ictx, struct iommufd_ioas *ioas,
 			  const struct iommu_user_data *user_data)
 {
 	const u32 valid_flags = IOMMU_HWPT_ALLOC_NEST_PARENT |
-				IOMMU_HWPT_ALLOC_DIRTY_TRACKING;
+				IOMMU_HWPT_ALLOC_DIRTY_TRACKING |
+				IOMMU_HWPT_ALLOC_PASID;
 	const struct iommu_ops *ops = dev_iommu_ops(idev->dev);
 	struct iommufd_hwpt_paging *hwpt_paging;
 	struct iommufd_hw_pagetable *hwpt;
@@ -128,6 +129,7 @@  iommufd_hwpt_paging_alloc(struct iommufd_ctx *ictx, struct iommufd_ioas *ioas,
 	if (IS_ERR(hwpt_paging))
 		return ERR_CAST(hwpt_paging);
 	hwpt = &hwpt_paging->common;
+	hwpt->pasid_compat = flags & IOMMU_HWPT_ALLOC_PASID;
 
 	INIT_LIST_HEAD(&hwpt_paging->hwpt_item);
 	/* Pairs with iommufd_hw_pagetable_destroy() */
@@ -223,7 +225,7 @@  iommufd_hwpt_nested_alloc(struct iommufd_ctx *ictx,
 	struct iommufd_hw_pagetable *hwpt;
 	int rc;
 
-	if ((flags & ~IOMMU_HWPT_FAULT_ID_VALID) ||
+	if ((flags & ~(IOMMU_HWPT_FAULT_ID_VALID | IOMMU_HWPT_ALLOC_PASID)) ||
 	    !user_data->len || !ops->domain_alloc_user)
 		return ERR_PTR(-EOPNOTSUPP);
 	if (parent->auto_domain || !parent->nest_parent ||
@@ -235,6 +237,7 @@  iommufd_hwpt_nested_alloc(struct iommufd_ctx *ictx,
 	if (IS_ERR(hwpt_nested))
 		return ERR_CAST(hwpt_nested);
 	hwpt = &hwpt_nested->common;
+	hwpt->pasid_compat = flags & IOMMU_HWPT_ALLOC_PASID;
 
 	refcount_inc(&parent->common.obj.users);
 	hwpt_nested->parent = parent;
diff --git a/drivers/iommu/iommufd/iommufd_private.h b/drivers/iommu/iommufd/iommufd_private.h
index 11773cef5acc..81a95f869e10 100644
--- a/drivers/iommu/iommufd/iommufd_private.h
+++ b/drivers/iommu/iommufd/iommufd_private.h
@@ -296,6 +296,7 @@  struct iommufd_hw_pagetable {
 	struct iommufd_object obj;
 	struct iommu_domain *domain;
 	struct iommufd_fault *fault;
+	bool pasid_compat : 1;
 };
 
 struct iommufd_hwpt_paging {
@@ -531,6 +532,9 @@  static inline int iommufd_hwpt_attach_device(struct iommufd_hw_pagetable *hwpt,
 					     struct iommufd_device *idev,
 					     ioasid_t pasid)
 {
+	if (idev->dev->iommu->max_pasids && !hwpt->pasid_compat)
+		return -EINVAL;
+
 	if (hwpt->fault)
 		return iommufd_fault_domain_attach_dev(hwpt, idev, pasid);
 
@@ -564,6 +568,9 @@  static inline int iommufd_hwpt_replace_device(struct iommufd_device *idev,
 	struct iommufd_attach_handle *curr;
 	int ret;
 
+	if (idev->dev->iommu->max_pasids && !hwpt->pasid_compat)
+		return -EINVAL;
+
 	if (old->fault || hwpt->fault)
 		return iommufd_fault_domain_replace_dev(idev, pasid,
 							hwpt, old);