diff mbox series

[19/30] iommu/mtk: Move to iommu_fw_alloc_per_device_ids()

Message ID 19-v1-f82a05539a64+5042-iommu_fwspec_p2_jgg@nvidia.com (mailing list archive)
State Handled Elsewhere
Headers show
Series Make a new API for drivers to use to get their FW | expand

Commit Message

Jason Gunthorpe Nov. 30, 2023, 1:10 a.m. UTC
mtk was doing a lot of stuff under of_xlate, and it looked kind of like it
might support multi-instances. But the dt files don't do that, and the
driver has no way to keep track of which instance the ids are for.

Enforce single instance with iommu_of_get_single_iommu().

Introduce a per-device data to store the iommu and ids list. Allocate and
initialize it with iommu_fw_alloc_per_device_ids(). Remove
mtk_iommu_of_xlate().

Convert the rest of the funcs from calling dev_iommu_fwspec_get() to using
the per-device data and remove all use of fwspec.

Covnert the places using dev_iommu_priv_get() to use the per-device data
not the iommu.

Signed-off-by: Jason Gunthorpe <jgg@nvidia.com>
---
 drivers/iommu/mtk_iommu.c | 116 ++++++++++++++++++++------------------
 1 file changed, 62 insertions(+), 54 deletions(-)
diff mbox series

Patch

diff --git a/drivers/iommu/mtk_iommu.c b/drivers/iommu/mtk_iommu.c
index 7abe9e85a57063..477171e83eaa6e 100644
--- a/drivers/iommu/mtk_iommu.c
+++ b/drivers/iommu/mtk_iommu.c
@@ -13,6 +13,7 @@ 
 #include <linux/interrupt.h>
 #include <linux/io.h>
 #include <linux/iommu.h>
+#include <linux/iommu-driver.h>
 #include <linux/iopoll.h>
 #include <linux/io-pgtable.h>
 #include <linux/list.h>
@@ -277,6 +278,12 @@  struct mtk_iommu_data {
 	struct mtk_smi_larb_iommu	larb_imu[MTK_LARB_NR_MAX];
 };
 
+struct mtk_iommu_device {
+	struct mtk_iommu_data *iommu;
+	unsigned int num_ids;
+	u32 ids[] __counted_by(num_ids);
+};
+
 struct mtk_iommu_domain {
 	struct io_pgtable_cfg		cfg;
 	struct io_pgtable_ops		*iop;
@@ -526,14 +533,14 @@  static irqreturn_t mtk_iommu_isr(int irq, void *dev_id)
 static unsigned int mtk_iommu_get_bank_id(struct device *dev,
 					  const struct mtk_iommu_plat_data *plat_data)
 {
-	struct iommu_fwspec *fwspec = dev_iommu_fwspec_get(dev);
+	struct mtk_iommu_device *mtkdev = dev_iommu_priv_get(dev);
 	unsigned int i, portmsk = 0, bankid = 0;
 
 	if (plat_data->banks_num == 1)
 		return bankid;
 
-	for (i = 0; i < fwspec->num_ids; i++)
-		portmsk |= BIT(MTK_M4U_TO_PORT(fwspec->ids[i]));
+	for (i = 0; i < mtkdev->num_ids; i++)
+		portmsk |= BIT(MTK_M4U_TO_PORT(mtkdev->ids[i]));
 
 	for (i = 0; i < plat_data->banks_num && i < MTK_IOMMU_BANK_MAX; i++) {
 		if (!plat_data->banks_enable[i])
@@ -550,7 +557,7 @@  static unsigned int mtk_iommu_get_bank_id(struct device *dev,
 static int mtk_iommu_get_iova_region_id(struct device *dev,
 					const struct mtk_iommu_plat_data *plat_data)
 {
-	struct iommu_fwspec *fwspec = dev_iommu_fwspec_get(dev);
+	struct mtk_iommu_device *mtkdev = dev_iommu_priv_get(dev);
 	unsigned int portidmsk = 0, larbid;
 	const u32 *rgn_larb_msk;
 	int i;
@@ -558,9 +565,9 @@  static int mtk_iommu_get_iova_region_id(struct device *dev,
 	if (plat_data->iova_region_nr == 1)
 		return 0;
 
-	larbid = MTK_M4U_TO_LARB(fwspec->ids[0]);
-	for (i = 0; i < fwspec->num_ids; i++)
-		portidmsk |= BIT(MTK_M4U_TO_PORT(fwspec->ids[i]));
+	larbid = MTK_M4U_TO_LARB(mtkdev->ids[0]);
+	for (i = 0; i < mtkdev->num_ids; i++)
+		portidmsk |= BIT(MTK_M4U_TO_PORT(mtkdev->ids[i]));
 
 	for (i = 0; i < plat_data->iova_region_nr; i++) {
 		rgn_larb_msk = plat_data->iova_region_larb_msk[i];
@@ -579,22 +586,22 @@  static int mtk_iommu_get_iova_region_id(struct device *dev,
 static int mtk_iommu_config(struct mtk_iommu_data *data, struct device *dev,
 			    bool enable, unsigned int regionid)
 {
+	struct mtk_iommu_device *mtkdev = dev_iommu_priv_get(dev);
 	struct mtk_smi_larb_iommu    *larb_mmu;
 	unsigned int                 larbid, portid;
-	struct iommu_fwspec *fwspec = dev_iommu_fwspec_get(dev);
 	const struct mtk_iommu_iova_region *region;
 	unsigned long portid_msk = 0;
 	struct arm_smccc_res res;
 	int i, ret = 0;
 
-	for (i = 0; i < fwspec->num_ids; ++i) {
-		portid = MTK_M4U_TO_PORT(fwspec->ids[i]);
+	for (i = 0; i < mtkdev->num_ids; ++i) {
+		portid = MTK_M4U_TO_PORT(mtkdev->ids[i]);
 		portid_msk |= BIT(portid);
 	}
 
 	if (MTK_IOMMU_IS_TYPE(data->plat_data, MTK_IOMMU_TYPE_MM)) {
 		/* All ports should be in the same larb. just use 0 here */
-		larbid = MTK_M4U_TO_LARB(fwspec->ids[0]);
+		larbid = MTK_M4U_TO_LARB(mtkdev->ids[0]);
 		larb_mmu = &data->larb_imu[larbid];
 		region = data->plat_data->iova_region + regionid;
 
@@ -618,7 +625,7 @@  static int mtk_iommu_config(struct mtk_iommu_data *data, struct device *dev,
 		} else {
 			/* PCI dev has only one output id, enable the next writing bit for PCIe */
 			if (dev_is_pci(dev)) {
-				if (fwspec->num_ids != 1) {
+				if (mtkdev->num_ids != 1) {
 					dev_err(dev, "PCI dev can only have one port.\n");
 					return -ENODEV;
 				}
@@ -708,7 +715,9 @@  static void mtk_iommu_domain_free(struct iommu_domain *domain)
 static int mtk_iommu_attach_device(struct iommu_domain *domain,
 				   struct device *dev)
 {
-	struct mtk_iommu_data *data = dev_iommu_priv_get(dev), *frstdata;
+	struct mtk_iommu_device *mtkdev = dev_iommu_priv_get(dev);
+	struct mtk_iommu_data *data = mtkdev->iommu;
+	struct mtk_iommu_data *frstdata;
 	struct mtk_iommu_domain *dom = to_mtk_domain(domain);
 	struct list_head *hw_list = data->hw_list;
 	struct device *m4udev = data->dev;
@@ -777,12 +786,12 @@  static int mtk_iommu_identity_attach(struct iommu_domain *identity_domain,
 				     struct device *dev)
 {
 	struct iommu_domain *domain = iommu_get_domain_for_dev(dev);
-	struct mtk_iommu_data *data = dev_iommu_priv_get(dev);
+	struct mtk_iommu_device *mtkdev = dev_iommu_priv_get(dev);
 
 	if (domain == identity_domain || !domain)
 		return 0;
 
-	mtk_iommu_config(data, dev, false, 0);
+	mtk_iommu_config(mtkdev->iommu, dev, false, 0);
 	return 0;
 }
 
@@ -860,14 +869,28 @@  static phys_addr_t mtk_iommu_iova_to_phys(struct iommu_domain *domain,
 	return pa;
 }
 
-static struct iommu_device *mtk_iommu_probe_device(struct device *dev)
+static struct iommu_device *
+mtk_iommu_probe_device(struct iommu_probe_info *pinf)
 {
-	struct iommu_fwspec *fwspec = dev_iommu_fwspec_get(dev);
-	struct mtk_iommu_data *data = dev_iommu_priv_get(dev);
+	struct mtk_iommu_device *mtkdev;
+	struct device *dev = pinf->dev;
+	struct mtk_iommu_data *data;
 	struct device_link *link;
 	struct device *larbdev;
 	unsigned int larbid, larbidx, i;
 
+	data = iommu_of_get_single_iommu(pinf, &mtk_iommu_ops, 1,
+					 struct mtk_iommu_data, iommu);
+	if (IS_ERR(data))
+		return ERR_CAST(data);
+
+	mtkdev = iommu_fw_alloc_per_device_ids(pinf, mtkdev);
+	if (IS_ERR(mtkdev))
+		return ERR_CAST(mtkdev);
+	mtkdev->iommu = data;
+
+	dev_iommu_priv_set(dev, mtkdev);
+
 	if (!MTK_IOMMU_IS_TYPE(data->plat_data, MTK_IOMMU_TYPE_MM))
 		return &data->iommu;
 
@@ -876,42 +899,46 @@  static struct iommu_device *mtk_iommu_probe_device(struct device *dev)
 	 * The device that connects with each a larb is a independent HW.
 	 * All the ports in each a device should be in the same larbs.
 	 */
-	larbid = MTK_M4U_TO_LARB(fwspec->ids[0]);
+	larbid = MTK_M4U_TO_LARB(mtkdev->ids[0]);
 	if (larbid >= MTK_LARB_NR_MAX)
-		return ERR_PTR(-EINVAL);
+		goto err_out;
 
-	for (i = 1; i < fwspec->num_ids; i++) {
-		larbidx = MTK_M4U_TO_LARB(fwspec->ids[i]);
+	for (i = 1; i < mtkdev->num_ids; i++) {
+		larbidx = MTK_M4U_TO_LARB(mtkdev->ids[i]);
 		if (larbid != larbidx) {
 			dev_err(dev, "Can only use one larb. Fail@larb%d-%d.\n",
 				larbid, larbidx);
-			return ERR_PTR(-EINVAL);
+			goto err_out;
 		}
 	}
 	larbdev = data->larb_imu[larbid].dev;
 	if (!larbdev)
-		return ERR_PTR(-EINVAL);
+		goto err_out;
 
 	link = device_link_add(dev, larbdev,
 			       DL_FLAG_PM_RUNTIME | DL_FLAG_STATELESS);
 	if (!link)
 		dev_err(dev, "Unable to link %s\n", dev_name(larbdev));
 	return &data->iommu;
+
+err_out:
+	kfree(mtkdev);
+	return ERR_PTR(-EINVAL);
 }
 
 static void mtk_iommu_release_device(struct device *dev)
 {
-	struct iommu_fwspec *fwspec = dev_iommu_fwspec_get(dev);
-	struct mtk_iommu_data *data;
+	struct mtk_iommu_device *mtkdev = dev_iommu_priv_get(dev);
+	struct mtk_iommu_data *data = mtkdev->iommu;
 	struct device *larbdev;
 	unsigned int larbid;
 
-	data = dev_iommu_priv_get(dev);
 	if (MTK_IOMMU_IS_TYPE(data->plat_data, MTK_IOMMU_TYPE_MM)) {
-		larbid = MTK_M4U_TO_LARB(fwspec->ids[0]);
+		larbid = MTK_M4U_TO_LARB(mtkdev->ids[0]);
 		larbdev = data->larb_imu[larbid].dev;
 		device_link_remove(dev, larbdev);
 	}
+	kfree(mtkdev);
 }
 
 static int mtk_iommu_get_group_id(struct device *dev, const struct mtk_iommu_plat_data *plat_data)
@@ -931,7 +958,9 @@  static int mtk_iommu_get_group_id(struct device *dev, const struct mtk_iommu_pla
 
 static struct iommu_group *mtk_iommu_device_group(struct device *dev)
 {
-	struct mtk_iommu_data *c_data = dev_iommu_priv_get(dev), *data;
+	struct mtk_iommu_device *mtkdev = dev_iommu_priv_get(dev);
+	struct mtk_iommu_data *c_data = mtkdev->iommu;
+	struct mtk_iommu_data *data;
 	struct list_head *hw_list = c_data->hw_list;
 	struct iommu_group *group;
 	int groupid;
@@ -957,32 +986,11 @@  static struct iommu_group *mtk_iommu_device_group(struct device *dev)
 	return group;
 }
 
-static int mtk_iommu_of_xlate(struct device *dev, struct of_phandle_args *args)
-{
-	struct platform_device *m4updev;
-
-	if (args->args_count != 1) {
-		dev_err(dev, "invalid #iommu-cells(%d) property for IOMMU\n",
-			args->args_count);
-		return -EINVAL;
-	}
-
-	if (!dev_iommu_priv_get(dev)) {
-		/* Get the m4u device */
-		m4updev = of_find_device_by_node(args->np);
-		if (WARN_ON(!m4updev))
-			return -EINVAL;
-
-		dev_iommu_priv_set(dev, platform_get_drvdata(m4updev));
-	}
-
-	return iommu_fwspec_add_ids(dev, args->args, 1);
-}
-
 static void mtk_iommu_get_resv_regions(struct device *dev,
 				       struct list_head *head)
 {
-	struct mtk_iommu_data *data = dev_iommu_priv_get(dev);
+	struct mtk_iommu_device *mtkdev = dev_iommu_priv_get(dev);
+	struct mtk_iommu_data *data = mtkdev->iommu;
 	unsigned int regionid = mtk_iommu_get_iova_region_id(dev, data->plat_data), i;
 	const struct mtk_iommu_iova_region *resv, *curdom;
 	struct iommu_resv_region *region;
@@ -1012,10 +1020,10 @@  static void mtk_iommu_get_resv_regions(struct device *dev,
 static const struct iommu_ops mtk_iommu_ops = {
 	.identity_domain = &mtk_iommu_identity_domain,
 	.domain_alloc_paging = mtk_iommu_domain_alloc_paging,
-	.probe_device	= mtk_iommu_probe_device,
+	.probe_device_pinf = mtk_iommu_probe_device,
 	.release_device	= mtk_iommu_release_device,
 	.device_group	= mtk_iommu_device_group,
-	.of_xlate	= mtk_iommu_of_xlate,
+	.of_xlate = iommu_dummy_of_xlate,
 	.get_resv_regions = mtk_iommu_get_resv_regions,
 	.pgsize_bitmap	= SZ_4K | SZ_64K | SZ_1M | SZ_16M,
 	.owner		= THIS_MODULE,