@@ -22,6 +22,10 @@ enum {
IOMMU_TEST_OP_MOCK_DOMAIN_FLAGS,
IOMMU_TEST_OP_DIRTY,
IOMMU_TEST_OP_MD_CHECK_IOTLB,
+ IOMMU_TEST_OP_PASID_ATTACH,
+ IOMMU_TEST_OP_PASID_REPLACE,
+ IOMMU_TEST_OP_PASID_DETACH,
+ IOMMU_TEST_OP_PASID_CHECK_DOMAIN,
};
enum {
@@ -127,6 +131,32 @@ struct iommu_test_cmd {
__u32 id;
__u32 iotlb;
} check_iotlb;
+ struct {
+ __u32 pasid;
+ __u32 pt_id;
+ /* @id is stdev_id for IOMMU_TEST_OP_PASID_ATTACH
+ * pasid#1024 is for special test, avoid use it
+ * in normal case.
+ */
+ } pasid_attach;
+ struct {
+ __u32 pasid;
+ __u32 pt_id;
+ /* @id is stdev_id for IOMMU_TEST_OP_PASID_ATTACH
+ * pasid#1024 is for special test, avoid use it
+ * in normal case.
+ */
+ } pasid_replace;
+ struct {
+ __u32 pasid;
+ /* @id is stdev_id for IOMMU_TEST_OP_PASID_DETACH */
+ } pasid_detach;
+ struct {
+ __u32 pasid;
+ __u32 hwpt_id;
+ __u64 out_result_ptr;
+ /* @id is stdev_id for IOMMU_TEST_OP_HWPT_GET_DOMAIN */
+ } pasid_check;
};
__u32 last;
};
@@ -516,6 +516,8 @@ static struct iommu_device *mock_probe_device(struct device *dev)
return &mock_iommu_device;
}
+static bool pasid_1024_attached;
+
static void mock_iommu_remove_dev_pasid(struct device *dev, ioasid_t pasid,
struct iommu_domain *domain)
{
@@ -524,6 +526,8 @@ static void mock_iommu_remove_dev_pasid(struct device *dev, ioasid_t pasid,
switch (domain->type) {
case IOMMU_DOMAIN_NESTED:
case IOMMU_DOMAIN_UNMANAGED:
+ if (pasid == 1024)
+ pasid_1024_attached = false;
break;
default:
/* should never reach here */
@@ -537,6 +541,20 @@ static int mock_domain_set_dev_pasid_nop(struct iommu_domain *domain,
struct device *dev, ioasid_t pasid,
struct iommu_domain *old)
{
+ /*
+ * First attach with pasid 1024 succ, second attach would fail,
+ * and so on. This is helpful to test the case in which the iommu
+ * layer needs to rollback to old domain due to driver failure.
+ */
+ if (pasid == 1024) {
+ if (pasid_1024_attached) {
+ pasid_1024_attached = false;
+ // Fake an error to fail the replacement
+ return -ENOMEM;
+ }
+ pasid_1024_attached = true;
+ }
+
return 0;
}
@@ -1414,6 +1432,114 @@ static int iommufd_test_dirty(struct iommufd_ucmd *ucmd, unsigned int mockpt_id,
return rc;
}
+static int iommufd_test_pasid_attach(struct iommufd_ucmd *ucmd,
+ struct iommu_test_cmd *cmd)
+{
+ struct selftest_obj *sobj;
+ int rc;
+
+ sobj = iommufd_test_get_self_test_device(ucmd->ictx, cmd->id);
+ if (IS_ERR(sobj))
+ return PTR_ERR(sobj);
+
+ rc = iommufd_device_pasid_attach(sobj->idev.idev,
+ cmd->pasid_attach.pasid,
+ &cmd->pasid_attach.pt_id);
+ iommufd_put_object(ucmd->ictx, &sobj->obj);
+ return rc;
+}
+
+static int iommufd_test_pasid_replace(struct iommufd_ucmd *ucmd,
+ struct iommu_test_cmd *cmd)
+{
+ struct selftest_obj *sobj;
+ int rc;
+
+ sobj = iommufd_test_get_self_test_device(ucmd->ictx, cmd->id);
+ if (IS_ERR(sobj))
+ return PTR_ERR(sobj);
+
+ rc = iommufd_device_pasid_replace(sobj->idev.idev,
+ cmd->pasid_attach.pasid,
+ &cmd->pasid_attach.pt_id);
+ iommufd_put_object(ucmd->ictx, &sobj->obj);
+ return rc;
+}
+
+static int iommufd_test_pasid_detach(struct iommufd_ucmd *ucmd,
+ struct iommu_test_cmd *cmd)
+{
+ struct selftest_obj *sobj;
+
+ sobj = iommufd_test_get_self_test_device(ucmd->ictx, cmd->id);
+ if (IS_ERR(sobj))
+ return PTR_ERR(sobj);
+
+ iommufd_device_pasid_detach(sobj->idev.idev,
+ cmd->pasid_detach.pasid);
+ iommufd_put_object(ucmd->ictx, &sobj->obj);
+ return 0;
+}
+
+static inline struct iommufd_hw_pagetable *
+iommufd_get_hwpt(struct iommufd_ucmd *ucmd, u32 id)
+{
+ struct iommufd_object *pt_obj;
+
+ pt_obj = iommufd_get_object(ucmd->ictx, id, IOMMUFD_OBJ_ANY);
+ if (IS_ERR(pt_obj))
+ return ERR_CAST(pt_obj);
+
+ if (pt_obj->type != IOMMUFD_OBJ_HWPT_NESTED &&
+ pt_obj->type != IOMMUFD_OBJ_HWPT_PAGING) {
+ iommufd_put_object(ucmd->ictx, pt_obj);
+ return ERR_PTR(-EINVAL);
+ }
+
+ return container_of(pt_obj, struct iommufd_hw_pagetable, obj);
+}
+
+static int iommufd_test_pasid_check_domain(struct iommufd_ucmd *ucmd,
+ struct iommu_test_cmd *cmd)
+{
+ struct iommu_domain *attached_domain, *expect_domain = NULL;
+ struct iommufd_hw_pagetable *hwpt = NULL;
+ struct selftest_obj *sobj;
+ struct mock_dev *mdev;
+ bool result;
+ int rc = 0;
+
+ sobj = iommufd_test_get_self_test_device(ucmd->ictx, cmd->id);
+ if (IS_ERR(sobj))
+ return PTR_ERR(sobj);
+
+ mdev = sobj->idev.mock_dev;
+
+ attached_domain = iommu_get_domain_for_dev_pasid(&mdev->dev,
+ cmd->pasid_check.pasid, 0);
+ if (IS_ERR(attached_domain))
+ attached_domain = NULL;
+
+ if (cmd->pasid_check.hwpt_id) {
+ hwpt = iommufd_get_hwpt(ucmd, cmd->pasid_check.hwpt_id);
+ if (IS_ERR(hwpt)) {
+ rc = PTR_ERR(hwpt);
+ goto out_put_dev;
+ }
+ expect_domain = hwpt->domain;
+ }
+
+ result = (attached_domain == expect_domain) ? 1 : 0;
+ if (copy_to_user(u64_to_user_ptr(cmd->pasid_check.out_result_ptr),
+ &result, sizeof(result)))
+ rc = -EFAULT;
+ if (hwpt)
+ iommufd_put_object(ucmd->ictx, &hwpt->obj);
+out_put_dev:
+ iommufd_put_object(ucmd->ictx, &sobj->obj);
+ return rc;
+}
+
void iommufd_selftest_destroy(struct iommufd_object *obj)
{
struct selftest_obj *sobj = container_of(obj, struct selftest_obj, obj);
@@ -1489,6 +1615,14 @@ int iommufd_test(struct iommufd_ucmd *ucmd)
cmd->dirty.page_size,
u64_to_user_ptr(cmd->dirty.uptr),
cmd->dirty.flags);
+ case IOMMU_TEST_OP_PASID_ATTACH:
+ return iommufd_test_pasid_attach(ucmd, cmd);
+ case IOMMU_TEST_OP_PASID_REPLACE:
+ return iommufd_test_pasid_replace(ucmd, cmd);
+ case IOMMU_TEST_OP_PASID_DETACH:
+ return iommufd_test_pasid_detach(ucmd, cmd);
+ case IOMMU_TEST_OP_PASID_CHECK_DOMAIN:
+ return iommufd_test_pasid_check_domain(ucmd, cmd);
default:
return -EOPNOTSUPP;
}
@@ -1532,6 +1666,7 @@ int __init iommufd_test_init(void)
goto err_sysfs;
mock_iommu_device.max_pasids = (1 << 20);//20 bits
+ pasid_1024_attached = false;
return 0;
err_sysfs:
This adds 4 test ops for pasid attach/replace/detach testing. There are ops to attach/detach pasid, and also op to check the attached domain of a pasid. Signed-off-by: Yi Liu <yi.l.liu@intel.com> --- drivers/iommu/iommufd/iommufd_test.h | 30 ++++++ drivers/iommu/iommufd/selftest.c | 135 +++++++++++++++++++++++++++ 2 files changed, 165 insertions(+)