diff mbox series

[v8,10/14] iommu/arm-smmu-v3: Put the SVA mmu notifier in the smmu_domain

Message ID 10-v8-6f85cdc10ce7+563e-smmuv3_newapi_p2b_jgg@nvidia.com (mailing list archive)
State New, archived
Headers show
Series Update SMMUv3 to the modern iommu API (part 2b/3) | expand

Commit Message

Jason Gunthorpe June 4, 2024, 12:15 a.m. UTC
This removes all the notifier de-duplication logic in the driver and
relies on the core code to de-duplicate and allocate only one SVA domain
per mm per smmu instance. This naturally gives a 1:1 relationship between
SVA domain and mmu notifier.

It is a significant simplication of the flow, as we end up with a single
struct arm_smmu_domain for each MM and the invalidation can then be
shifted to properly use the masters list like S1/S2 do.

Remove all of the previous mmu_notifier, bond, shared cd, and cd refcount
logic entirely.

The logic here is tightly wound together with the unusued BTM
support. Since the BTM logic requires holding all the iommu_domains in a
global ASID xarray it conflicts with the design to have a single SVA
domain per PASID, as multiple SMMU instances will need to have different
domains.

Following patches resolve this by making the ASID xarray per-instance
instead of global. However, converting the BTM code over to this
methodology requires many changes.

Thus, since ARM_SMMU_FEAT_BTM is never enabled, remove the parts of the
BTM support for ASID sharing that interact with SVA as well.

A followup series is already working on fully enabling the BTM support,
that requires iommufd's VIOMMU feature to bring in the KVM's VMID as
well. It will come with an already written patch to bring back the ASID
sharing using a per-instance ASID xarray.

https://lore.kernel.org/linux-iommu/20240208151837.35068-1-shameerali.kolothum.thodi@huawei.com/
https://lore.kernel.org/linux-iommu/26-v6-228e7adf25eb+4155-smmuv3_newapi_p2_jgg@nvidia.com/

Tested-by: Nicolin Chen <nicolinc@nvidia.com>
Tested-by: Shameer Kolothum <shameerali.kolothum.thodi@huawei.com>
Reviewed-by: Nicolin Chen <nicolinc@nvidia.com>
Signed-off-by: Jason Gunthorpe <jgg@nvidia.com>
---
 .../iommu/arm/arm-smmu-v3/arm-smmu-v3-sva.c   | 395 +++---------------
 drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3.c   |  69 +--
 drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3.h   |  15 +-
 3 files changed, 86 insertions(+), 393 deletions(-)

Comments

Michael Shavit June 24, 2024, 9:54 a.m. UTC | #1
On Tue, Jun 4, 2024 at 8:16 AM Jason Gunthorpe <jgg@nvidia.com> wrote:
>
> This removes all the notifier de-duplication logic in the driver and
> relies on the core code to de-duplicate and allocate only one SVA domain
> per mm per smmu instance. This naturally gives a 1:1 relationship between
> SVA domain and mmu notifier.
>
> It is a significant simplication of the flow, as we end up with a single
> struct arm_smmu_domain for each MM and the invalidation can then be
> shifted to properly use the masters list like S1/S2 do.
>
> Remove all of the previous mmu_notifier, bond, shared cd, and cd refcount
> logic entirely.
>
> The logic here is tightly wound together with the unusued BTM
> support. Since the BTM logic requires holding all the iommu_domains in a
> global ASID xarray it conflicts with the design to have a single SVA
> domain per PASID, as multiple SMMU instances will need to have different
> domains.
>
> Following patches resolve this by making the ASID xarray per-instance
> instead of global. However, converting the BTM code over to this
> methodology requires many changes.
>
> Thus, since ARM_SMMU_FEAT_BTM is never enabled, remove the parts of the
> BTM support for ASID sharing that interact with SVA as well.
>
> A followup series is already working on fully enabling the BTM support,
> that requires iommufd's VIOMMU feature to bring in the KVM's VMID as
> well. It will come with an already written patch to bring back the ASID
> sharing using a per-instance ASID xarray.
>
> https://lore.kernel.org/linux-iommu/20240208151837.35068-1-shameerali.kolothum.thodi@huawei.com/
> https://lore.kernel.org/linux-iommu/26-v6-228e7adf25eb+4155-smmuv3_newapi_p2_jgg@nvidia.com/
>
> Tested-by: Nicolin Chen <nicolinc@nvidia.com>
> Tested-by: Shameer Kolothum <shameerali.kolothum.thodi@huawei.com>
> Reviewed-by: Nicolin Chen <nicolinc@nvidia.com>
> Signed-off-by: Jason Gunthorpe <jgg@nvidia.com>
> ---
>  .../iommu/arm/arm-smmu-v3/arm-smmu-v3-sva.c   | 395 +++---------------
>  drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3.c   |  69 +--
>  drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3.h   |  15 +-
>  3 files changed, 86 insertions(+), 393 deletions(-)
>
> diff --git a/drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3-sva.c b/drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3-sva.c
> index aa033cd65adc5a..a7c36654dee5a5 100644
> --- a/drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3-sva.c
> +++ b/drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3-sva.c
> @@ -13,29 +13,9 @@
>  #include "arm-smmu-v3.h"
>  #include "../../io-pgtable-arm.h"
>
> -struct arm_smmu_mmu_notifier {
> -       struct mmu_notifier             mn;
> -       struct arm_smmu_ctx_desc        *cd;
> -       bool                            cleared;
> -       refcount_t                      refs;
> -       struct list_head                list;
> -       struct arm_smmu_domain          *domain;
> -};
> -
> -#define mn_to_smmu(mn) container_of(mn, struct arm_smmu_mmu_notifier, mn)
> -
> -struct arm_smmu_bond {
> -       struct mm_struct                *mm;
> -       struct arm_smmu_mmu_notifier    *smmu_mn;
> -       struct list_head                list;
> -};
> -
> -#define sva_to_bond(handle) \
> -       container_of(handle, struct arm_smmu_bond, sva)
> -
>  static DEFINE_MUTEX(sva_lock);
>
> -static void
> +static void __maybe_unused
>  arm_smmu_update_s1_domain_cd_entry(struct arm_smmu_domain *smmu_domain)
>  {
>         struct arm_smmu_master_domain *master_domain;
> @@ -58,58 +38,6 @@ arm_smmu_update_s1_domain_cd_entry(struct arm_smmu_domain *smmu_domain)
>         spin_unlock_irqrestore(&smmu_domain->devices_lock, flags);
>  }
>
> -/*
> - * Check if the CPU ASID is available on the SMMU side. If a private context
> - * descriptor is using it, try to replace it.
> - */
> -static struct arm_smmu_ctx_desc *
> -arm_smmu_share_asid(struct mm_struct *mm, u16 asid)
> -{
> -       int ret;
> -       u32 new_asid;
> -       struct arm_smmu_ctx_desc *cd;
> -       struct arm_smmu_device *smmu;
> -       struct arm_smmu_domain *smmu_domain;
> -
> -       cd = xa_load(&arm_smmu_asid_xa, asid);
> -       if (!cd)
> -               return NULL;
> -
> -       if (cd->mm) {
> -               if (WARN_ON(cd->mm != mm))
> -                       return ERR_PTR(-EINVAL);
> -               /* All devices bound to this mm use the same cd struct. */
> -               refcount_inc(&cd->refs);
> -               return cd;
> -       }
> -
> -       smmu_domain = container_of(cd, struct arm_smmu_domain, cd);
> -       smmu = smmu_domain->smmu;
> -
> -       ret = xa_alloc(&arm_smmu_asid_xa, &new_asid, cd,
> -                      XA_LIMIT(1, (1 << smmu->asid_bits) - 1), GFP_KERNEL);
> -       if (ret)
> -               return ERR_PTR(-ENOSPC);
> -       /*
> -        * Race with unmap: TLB invalidations will start targeting the new ASID,
> -        * which isn't assigned yet. We'll do an invalidate-all on the old ASID
> -        * later, so it doesn't matter.
> -        */
> -       cd->asid = new_asid;
> -       /*
> -        * Update ASID and invalidate CD in all associated masters. There will
> -        * be some overlap between use of both ASIDs, until we invalidate the
> -        * TLB.
> -        */
> -       arm_smmu_update_s1_domain_cd_entry(smmu_domain);
> -
> -       /* Invalidate TLB entries previously associated with that context */
> -       arm_smmu_tlb_inv_asid(smmu, asid);
> -
> -       xa_erase(&arm_smmu_asid_xa, asid);
> -       return NULL;
> -}
> -

Can we leave a comment on ASID sharing in the code since it isn't
added back until the next patch series? There are references to ASID
sharing remaining (and even added in this commit) that don't make
sense without this function (e.g "Prevent arm_smmu_share_asid() from
trying to change the ASID").


>  static u64 page_size_to_cd(void)
>  {
>         static_assert(PAGE_SIZE == SZ_4K || PAGE_SIZE == SZ_16K ||
> @@ -187,69 +115,6 @@ void arm_smmu_make_sva_cd(struct arm_smmu_cd *target,
>  }
>  EXPORT_SYMBOL_IF_KUNIT(arm_smmu_make_sva_cd);
>
> -static struct arm_smmu_ctx_desc *arm_smmu_alloc_shared_cd(struct mm_struct *mm)
> -{
> -       u16 asid;
> -       int err = 0;
> -       struct arm_smmu_ctx_desc *cd;
> -       struct arm_smmu_ctx_desc *ret = NULL;
> -
> -       /* Don't free the mm until we release the ASID */
> -       mmgrab(mm);
> -
> -       asid = arm64_mm_context_get(mm);
> -       if (!asid) {
> -               err = -ESRCH;
> -               goto out_drop_mm;
> -       }
> -
> -       cd = kzalloc(sizeof(*cd), GFP_KERNEL);
> -       if (!cd) {
> -               err = -ENOMEM;
> -               goto out_put_context;
> -       }
> -
> -       refcount_set(&cd->refs, 1);
> -
> -       mutex_lock(&arm_smmu_asid_lock);
> -       ret = arm_smmu_share_asid(mm, asid);
> -       if (ret) {
> -               mutex_unlock(&arm_smmu_asid_lock);
> -               goto out_free_cd;
> -       }
> -
> -       err = xa_insert(&arm_smmu_asid_xa, asid, cd, GFP_KERNEL);
> -       mutex_unlock(&arm_smmu_asid_lock);
> -
> -       if (err)
> -               goto out_free_asid;
> -
> -       cd->asid = asid;
> -       cd->mm = mm;
> -
> -       return cd;
> -
> -out_free_asid:
> -       arm_smmu_free_asid(cd);
> -out_free_cd:
> -       kfree(cd);
> -out_put_context:
> -       arm64_mm_context_put(mm);
> -out_drop_mm:
> -       mmdrop(mm);
> -       return err < 0 ? ERR_PTR(err) : ret;
> -}
> -
> -static void arm_smmu_free_shared_cd(struct arm_smmu_ctx_desc *cd)
> -{
> -       if (arm_smmu_free_asid(cd)) {
> -               /* Unpin ASID */
> -               arm64_mm_context_put(cd->mm);
> -               mmdrop(cd->mm);
> -               kfree(cd);
> -       }
> -}
> -
>  /*
>   * Cloned from the MAX_TLBI_OPS in arch/arm64/include/asm/tlbflush.h, this
>   * is used as a threshold to replace per-page TLBI commands to issue in the
> @@ -264,8 +129,8 @@ static void arm_smmu_mm_arch_invalidate_secondary_tlbs(struct mmu_notifier *mn,
>                                                 unsigned long start,
>                                                 unsigned long end)
>  {
> -       struct arm_smmu_mmu_notifier *smmu_mn = mn_to_smmu(mn);
> -       struct arm_smmu_domain *smmu_domain = smmu_mn->domain;
> +       struct arm_smmu_domain *smmu_domain =
> +               container_of(mn, struct arm_smmu_domain, mmu_notifier);
>         size_t size;
>
>         /*
> @@ -282,34 +147,22 @@ static void arm_smmu_mm_arch_invalidate_secondary_tlbs(struct mmu_notifier *mn,
>                         size = 0;
>         }
>
> -       if (!(smmu_domain->smmu->features & ARM_SMMU_FEAT_BTM)) {
> -               if (!size)
> -                       arm_smmu_tlb_inv_asid(smmu_domain->smmu,
> -                                             smmu_mn->cd->asid);
> -               else
> -                       arm_smmu_tlb_inv_range_asid(start, size,
> -                                                   smmu_mn->cd->asid,
> -                                                   PAGE_SIZE, false,
> -                                                   smmu_domain);
> -       }
> +       if (!size)
> +               arm_smmu_tlb_inv_asid(smmu_domain->smmu, smmu_domain->cd.asid);
> +       else
> +               arm_smmu_tlb_inv_range_asid(start, size, smmu_domain->cd.asid,
> +                                           PAGE_SIZE, false, smmu_domain);
>
> -       arm_smmu_atc_inv_domain_sva(smmu_domain, mm_get_enqcmd_pasid(mm), start,
> -                                   size);
> +       arm_smmu_atc_inv_domain(smmu_domain, start, size);
>  }
>
>  static void arm_smmu_mm_release(struct mmu_notifier *mn, struct mm_struct *mm)
>  {
> -       struct arm_smmu_mmu_notifier *smmu_mn = mn_to_smmu(mn);
> -       struct arm_smmu_domain *smmu_domain = smmu_mn->domain;
> +       struct arm_smmu_domain *smmu_domain =
> +               container_of(mn, struct arm_smmu_domain, mmu_notifier);
>         struct arm_smmu_master_domain *master_domain;
>         unsigned long flags;
>
> -       mutex_lock(&sva_lock);
> -       if (smmu_mn->cleared) {
> -               mutex_unlock(&sva_lock);
> -               return;
> -       }
> -
>         /*
>          * DMA may still be running. Keep the cd valid to avoid C_BAD_CD events,
>          * but disable translation.
> @@ -321,25 +174,23 @@ static void arm_smmu_mm_release(struct mmu_notifier *mn, struct mm_struct *mm)
>                 struct arm_smmu_cd target;
>                 struct arm_smmu_cd *cdptr;
>
> -               cdptr = arm_smmu_get_cd_ptr(master, mm_get_enqcmd_pasid(mm));
> +               cdptr = arm_smmu_get_cd_ptr(master, master_domain->ssid);
>                 if (WARN_ON(!cdptr))
>                         continue;
> -               arm_smmu_make_sva_cd(&target, master, NULL, smmu_mn->cd->asid);
> -               arm_smmu_write_cd_entry(master, mm_get_enqcmd_pasid(mm), cdptr,
> +               arm_smmu_make_sva_cd(&target, master, NULL,
> +                                    smmu_domain->cd.asid);
> +               arm_smmu_write_cd_entry(master, master_domain->ssid, cdptr,
>                                         &target);
>         }
>         spin_unlock_irqrestore(&smmu_domain->devices_lock, flags);
>
> -       arm_smmu_tlb_inv_asid(smmu_domain->smmu, smmu_mn->cd->asid);
> -       arm_smmu_atc_inv_domain_sva(smmu_domain, mm_get_enqcmd_pasid(mm), 0, 0);
> -
> -       smmu_mn->cleared = true;
> -       mutex_unlock(&sva_lock);
> +       arm_smmu_tlb_inv_asid(smmu_domain->smmu, smmu_domain->cd.asid);
> +       arm_smmu_atc_inv_domain(smmu_domain, 0, 0);
>  }
>
>  static void arm_smmu_mmu_notifier_free(struct mmu_notifier *mn)
>  {
> -       kfree(mn_to_smmu(mn));
> +       kfree(container_of(mn, struct arm_smmu_domain, mmu_notifier));
>  }
>
>  static const struct mmu_notifier_ops arm_smmu_mmu_notifier_ops = {
> @@ -348,115 +199,6 @@ static const struct mmu_notifier_ops arm_smmu_mmu_notifier_ops = {
>         .free_notifier                  = arm_smmu_mmu_notifier_free,
>  };
>
> -/* Allocate or get existing MMU notifier for this {domain, mm} pair */
> -static struct arm_smmu_mmu_notifier *
> -arm_smmu_mmu_notifier_get(struct arm_smmu_domain *smmu_domain,
> -                         struct mm_struct *mm)
> -{
> -       int ret;
> -       struct arm_smmu_ctx_desc *cd;
> -       struct arm_smmu_mmu_notifier *smmu_mn;
> -
> -       list_for_each_entry(smmu_mn, &smmu_domain->mmu_notifiers, list) {
> -               if (smmu_mn->mn.mm == mm) {
> -                       refcount_inc(&smmu_mn->refs);
> -                       return smmu_mn;
> -               }
> -       }
> -
> -       cd = arm_smmu_alloc_shared_cd(mm);
> -       if (IS_ERR(cd))
> -               return ERR_CAST(cd);
> -
> -       smmu_mn = kzalloc(sizeof(*smmu_mn), GFP_KERNEL);
> -       if (!smmu_mn) {
> -               ret = -ENOMEM;
> -               goto err_free_cd;
> -       }
> -
> -       refcount_set(&smmu_mn->refs, 1);
> -       smmu_mn->cd = cd;
> -       smmu_mn->domain = smmu_domain;
> -       smmu_mn->mn.ops = &arm_smmu_mmu_notifier_ops;
> -
> -       ret = mmu_notifier_register(&smmu_mn->mn, mm);
> -       if (ret) {
> -               kfree(smmu_mn);
> -               goto err_free_cd;
> -       }
> -
> -       list_add(&smmu_mn->list, &smmu_domain->mmu_notifiers);
> -       return smmu_mn;
> -
> -err_free_cd:
> -       arm_smmu_free_shared_cd(cd);
> -       return ERR_PTR(ret);
> -}
> -
> -static void arm_smmu_mmu_notifier_put(struct arm_smmu_mmu_notifier *smmu_mn)
> -{
> -       struct mm_struct *mm = smmu_mn->mn.mm;
> -       struct arm_smmu_ctx_desc *cd = smmu_mn->cd;
> -       struct arm_smmu_domain *smmu_domain = smmu_mn->domain;
> -
> -       if (!refcount_dec_and_test(&smmu_mn->refs))
> -               return;
> -
> -       list_del(&smmu_mn->list);
> -
> -       /*
> -        * If we went through clear(), we've already invalidated, and no
> -        * new TLB entry can have been formed.
> -        */
> -       if (!smmu_mn->cleared) {
> -               arm_smmu_tlb_inv_asid(smmu_domain->smmu, cd->asid);
> -               arm_smmu_atc_inv_domain_sva(smmu_domain,
> -                                           mm_get_enqcmd_pasid(mm), 0, 0);
> -       }
> -
> -       /* Frees smmu_mn */
> -       mmu_notifier_put(&smmu_mn->mn);
> -       arm_smmu_free_shared_cd(cd);
> -}
> -
> -static struct arm_smmu_bond *__arm_smmu_sva_bind(struct device *dev,
> -                                                struct mm_struct *mm)
> -{
> -       int ret;
> -       struct arm_smmu_bond *bond;
> -       struct arm_smmu_master *master = dev_iommu_priv_get(dev);
> -       struct iommu_domain *domain = iommu_get_domain_for_dev(dev);
> -       struct arm_smmu_domain *smmu_domain;
> -
> -       if (!(domain->type & __IOMMU_DOMAIN_PAGING))
> -               return ERR_PTR(-ENODEV);
> -       smmu_domain = to_smmu_domain(domain);
> -       if (smmu_domain->stage != ARM_SMMU_DOMAIN_S1)
> -               return ERR_PTR(-ENODEV);
> -
> -       if (!master || !master->sva_enabled)
> -               return ERR_PTR(-ENODEV);
> -
> -       bond = kzalloc(sizeof(*bond), GFP_KERNEL);
> -       if (!bond)
> -               return ERR_PTR(-ENOMEM);
> -
> -       bond->mm = mm;
> -
> -       bond->smmu_mn = arm_smmu_mmu_notifier_get(smmu_domain, mm);
> -       if (IS_ERR(bond->smmu_mn)) {
> -               ret = PTR_ERR(bond->smmu_mn);
> -               goto err_free_bond;
> -       }
> -
> -       list_add(&bond->list, &master->bonds);
> -       return bond;
> -
> -err_free_bond:
> -       kfree(bond);
> -       return ERR_PTR(ret);
> -}
> -
>  bool arm_smmu_sva_supported(struct arm_smmu_device *smmu)
>  {
>         unsigned long reg, fld;
> @@ -573,11 +315,6 @@ int arm_smmu_master_enable_sva(struct arm_smmu_master *master)
>  int arm_smmu_master_disable_sva(struct arm_smmu_master *master)
>  {
>         mutex_lock(&sva_lock);
> -       if (!list_empty(&master->bonds)) {
> -               dev_err(master->dev, "cannot disable SVA, device is bound\n");
> -               mutex_unlock(&sva_lock);
> -               return -EBUSY;
> -       }
>         arm_smmu_master_sva_disable_iopf(master);
>         master->sva_enabled = false;
>         mutex_unlock(&sva_lock);
> @@ -594,66 +331,51 @@ void arm_smmu_sva_notifier_synchronize(void)
>         mmu_notifier_synchronize();
>  }
>
> -void arm_smmu_sva_remove_dev_pasid(struct iommu_domain *domain,
> -                                  struct device *dev, ioasid_t id)
> -{
> -       struct mm_struct *mm = domain->mm;
> -       struct arm_smmu_bond *bond = NULL, *t;
> -       struct arm_smmu_master *master = dev_iommu_priv_get(dev);
> -
> -       arm_smmu_remove_pasid(master, to_smmu_domain(domain), id);
> -
> -       mutex_lock(&sva_lock);
> -       list_for_each_entry(t, &master->bonds, list) {
> -               if (t->mm == mm) {
> -                       bond = t;
> -                       break;
> -               }
> -       }
> -
> -       if (!WARN_ON(!bond)) {
> -               list_del(&bond->list);
> -               arm_smmu_mmu_notifier_put(bond->smmu_mn);
> -               kfree(bond);
> -       }
> -       mutex_unlock(&sva_lock);
> -}
> -
>  static int arm_smmu_sva_set_dev_pasid(struct iommu_domain *domain,
>                                       struct device *dev, ioasid_t id)
>  {
> +       struct arm_smmu_domain *smmu_domain = to_smmu_domain(domain);
>         struct arm_smmu_master *master = dev_iommu_priv_get(dev);
> -       struct mm_struct *mm = domain->mm;
> -       struct arm_smmu_bond *bond;
>         struct arm_smmu_cd target;
>         int ret;
>
> -       if (mm_get_enqcmd_pasid(mm) != id)
> +       /* Prevent arm_smmu_mm_release from being called while we are attaching */
> +       if (!mmget_not_zero(domain->mm))
>                 return -EINVAL;
>
> -       mutex_lock(&sva_lock);
> -       bond = __arm_smmu_sva_bind(dev, mm);
> -       if (IS_ERR(bond)) {
> -               mutex_unlock(&sva_lock);
> -               return PTR_ERR(bond);
> -       }
> +       /*
> +        * This does not need the arm_smmu_asid_lock because SVA domains never
> +        * get reassigned
> +        */
> +       arm_smmu_make_sva_cd(&target, master, domain->mm, smmu_domain->cd.asid);
> +       ret = arm_smmu_set_pasid(master, smmu_domain, id, &target);
>
> -       arm_smmu_make_sva_cd(&target, master, mm, bond->smmu_mn->cd->asid);
> -       ret = arm_smmu_set_pasid(master, to_smmu_domain(domain), id, &target);
> -       if (ret) {
> -               list_del(&bond->list);
> -               arm_smmu_mmu_notifier_put(bond->smmu_mn);
> -               kfree(bond);
> -               mutex_unlock(&sva_lock);
> -               return ret;
> -       }
> -       mutex_unlock(&sva_lock);
> -       return 0;
> +       mmput(domain->mm);
> +       return ret;
>  }
>
>  static void arm_smmu_sva_domain_free(struct iommu_domain *domain)
>  {
> -       kfree(to_smmu_domain(domain));
> +       struct arm_smmu_domain *smmu_domain = to_smmu_domain(domain);
> +
> +       /*
> +        * Ensure the ASID is empty in the iommu cache before allowing reuse.
> +        */
> +       arm_smmu_tlb_inv_asid(smmu_domain->smmu, smmu_domain->cd.asid);
> +
> +       /*
> +        * Notice that the arm_smmu_mm_arch_invalidate_secondary_tlbs op can
> +        * still be called/running at this point. We allow the ASID to be
> +        * reused, and if there is a race then it just suffers harmless
> +        * unnecessary invalidation.
> +        */
> +       xa_erase(&arm_smmu_asid_xa, smmu_domain->cd.asid);
> +
> +       /*
> +        * Actual free is defered to the SRCU callback
> +        * arm_smmu_mmu_notifier_free()
> +        */
> +       mmu_notifier_put(&smmu_domain->mmu_notifier);
>  }
>
>  static const struct iommu_domain_ops arm_smmu_sva_domain_ops = {
> @@ -667,6 +389,8 @@ struct iommu_domain *arm_smmu_sva_domain_alloc(struct device *dev,
>         struct arm_smmu_master *master = dev_iommu_priv_get(dev);
>         struct arm_smmu_device *smmu = master->smmu;
>         struct arm_smmu_domain *smmu_domain;
> +       u32 asid;
> +       int ret;
>
>         smmu_domain = arm_smmu_domain_alloc();
>         if (IS_ERR(smmu_domain))
> @@ -675,5 +399,22 @@ struct iommu_domain *arm_smmu_sva_domain_alloc(struct device *dev,
>         smmu_domain->domain.ops = &arm_smmu_sva_domain_ops;
>         smmu_domain->smmu = smmu;
>
> +       ret = xa_alloc(&arm_smmu_asid_xa, &asid, smmu_domain,
> +                      XA_LIMIT(1, (1 << smmu->asid_bits) - 1), GFP_KERNEL);
> +       if (ret)
> +               goto err_free;
> +
> +       smmu_domain->cd.asid = asid;
> +       smmu_domain->mmu_notifier.ops = &arm_smmu_mmu_notifier_ops;
> +       ret = mmu_notifier_register(&smmu_domain->mmu_notifier, mm);
> +       if (ret)
> +               goto err_asid;
> +
>         return &smmu_domain->domain;
> +
> +err_asid:
> +       xa_erase(&arm_smmu_asid_xa, smmu_domain->cd.asid);
> +err_free:
> +       kfree(smmu_domain);
> +       return ERR_PTR(ret);
>  }
> diff --git a/drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3.c b/drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3.c
> index 24000027253de8..2a845ab6d53b57 100644
> --- a/drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3.c
> +++ b/drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3.c
> @@ -1439,22 +1439,6 @@ static void arm_smmu_free_cd_tables(struct arm_smmu_master *master)
>         cd_table->cdtab = NULL;
>  }
>
> -bool arm_smmu_free_asid(struct arm_smmu_ctx_desc *cd)
> -{
> -       bool free;
> -       struct arm_smmu_ctx_desc *old_cd;
> -
> -       if (!cd->asid)
> -               return false;
> -
> -       free = refcount_dec_and_test(&cd->refs);
> -       if (free) {
> -               old_cd = xa_erase(&arm_smmu_asid_xa, cd->asid);
> -               WARN_ON(old_cd != cd);
> -       }
> -       return free;
> -}
> -
>  /* Stream table manipulation functions */
>  static void
>  arm_smmu_write_strtab_l1_desc(__le64 *dst, struct arm_smmu_strtab_l1_desc *desc)
> @@ -2023,8 +2007,8 @@ static int arm_smmu_atc_inv_master(struct arm_smmu_master *master,
>         return arm_smmu_cmdq_batch_submit(master->smmu, &cmds);
>  }
>
> -static int __arm_smmu_atc_inv_domain(struct arm_smmu_domain *smmu_domain,
> -                                    ioasid_t ssid, unsigned long iova, size_t size)
> +int arm_smmu_atc_inv_domain(struct arm_smmu_domain *smmu_domain,
> +                           unsigned long iova, size_t size)
>  {
>         struct arm_smmu_master_domain *master_domain;
>         int i;
> @@ -2062,15 +2046,7 @@ static int __arm_smmu_atc_inv_domain(struct arm_smmu_domain *smmu_domain,
>                 if (!master->ats_enabled)
>                         continue;
>
> -               /*
> -                * Non-zero ssid means SVA is co-opting the S1 domain to issue
> -                * invalidations for SVA PASIDs.
> -                */
> -               if (ssid != IOMMU_NO_PASID)
> -                       arm_smmu_atc_inv_to_cmd(ssid, iova, size, &cmd);
> -               else
> -                       arm_smmu_atc_inv_to_cmd(master_domain->ssid, iova, size,
> -                                               &cmd);
> +               arm_smmu_atc_inv_to_cmd(master_domain->ssid, iova, size, &cmd);
>
>                 for (i = 0; i < master->num_streams; i++) {
>                         cmd.atc.sid = master->streams[i].id;
> @@ -2082,19 +2058,6 @@ static int __arm_smmu_atc_inv_domain(struct arm_smmu_domain *smmu_domain,
>         return arm_smmu_cmdq_batch_submit(smmu_domain->smmu, &cmds);
>  }
>
> -static int arm_smmu_atc_inv_domain(struct arm_smmu_domain *smmu_domain,
> -                                  unsigned long iova, size_t size)
> -{
> -       return __arm_smmu_atc_inv_domain(smmu_domain, IOMMU_NO_PASID, iova,
> -                                        size);
> -}
> -
> -int arm_smmu_atc_inv_domain_sva(struct arm_smmu_domain *smmu_domain,
> -                               ioasid_t ssid, unsigned long iova, size_t size)
> -{
> -       return __arm_smmu_atc_inv_domain(smmu_domain, ssid, iova, size);
> -}
> -
>  /* IO_PGTABLE API */
>  static void arm_smmu_tlb_inv_context(void *cookie)
>  {
> @@ -2283,7 +2246,6 @@ struct arm_smmu_domain *arm_smmu_domain_alloc(void)
>         mutex_init(&smmu_domain->init_mutex);
>         INIT_LIST_HEAD(&smmu_domain->devices);
>         spin_lock_init(&smmu_domain->devices_lock);
> -       INIT_LIST_HEAD(&smmu_domain->mmu_notifiers);
>
>         return smmu_domain;
>  }
> @@ -2325,7 +2287,7 @@ static void arm_smmu_domain_free_paging(struct iommu_domain *domain)
>         if (smmu_domain->stage == ARM_SMMU_DOMAIN_S1) {
>                 /* Prevent SVA from touching the CD while we're freeing it */
>                 mutex_lock(&arm_smmu_asid_lock);
> -               arm_smmu_free_asid(&smmu_domain->cd);
> +               xa_erase(&arm_smmu_asid_xa, smmu_domain->cd.asid);
>                 mutex_unlock(&arm_smmu_asid_lock);
>         } else {
>                 struct arm_smmu_s2_cfg *cfg = &smmu_domain->s2_cfg;
> @@ -2343,11 +2305,9 @@ static int arm_smmu_domain_finalise_s1(struct arm_smmu_device *smmu,
>         u32 asid;
>         struct arm_smmu_ctx_desc *cd = &smmu_domain->cd;
>
> -       refcount_set(&cd->refs, 1);
> -
>         /* Prevent SVA from modifying the ASID until it is written to the CD */
>         mutex_lock(&arm_smmu_asid_lock);
> -       ret = xa_alloc(&arm_smmu_asid_xa, &asid, cd,
> +       ret = xa_alloc(&arm_smmu_asid_xa, &asid, smmu_domain,
>                        XA_LIMIT(1, (1 << smmu->asid_bits) - 1), GFP_KERNEL);
>         cd->asid        = (u16)asid;
>         mutex_unlock(&arm_smmu_asid_lock);
> @@ -2835,6 +2795,9 @@ int arm_smmu_set_pasid(struct arm_smmu_master *master,
>
>         /* The core code validates pasid */
>
> +       if (smmu_domain->smmu != master->smmu)
> +               return -EINVAL;
> +
>         if (!master->cd_table.in_ste)
>                 return -ENODEV;
>
> @@ -2856,9 +2819,14 @@ int arm_smmu_set_pasid(struct arm_smmu_master *master,
>         return ret;
>  }
>
> -void arm_smmu_remove_pasid(struct arm_smmu_master *master,
> -                          struct arm_smmu_domain *smmu_domain, ioasid_t pasid)
> +static void arm_smmu_remove_dev_pasid(struct device *dev, ioasid_t pasid,
> +                                     struct iommu_domain *domain)
>  {
> +       struct arm_smmu_master *master = dev_iommu_priv_get(dev);
> +       struct arm_smmu_domain *smmu_domain;
> +
> +       smmu_domain = to_smmu_domain(domain);
> +
>         mutex_lock(&arm_smmu_asid_lock);
>         arm_smmu_clear_cd(master, pasid);
>         if (master->ats_enabled)
> @@ -3129,7 +3097,6 @@ static struct iommu_device *arm_smmu_probe_device(struct device *dev)
>
>         master->dev = dev;
>         master->smmu = smmu;
> -       INIT_LIST_HEAD(&master->bonds);
>         dev_iommu_priv_set(dev, master);
>
>         ret = arm_smmu_insert_master(smmu, master);
> @@ -3311,12 +3278,6 @@ static int arm_smmu_def_domain_type(struct device *dev)
>         return 0;
>  }
>
> -static void arm_smmu_remove_dev_pasid(struct device *dev, ioasid_t pasid,
> -                                     struct iommu_domain *domain)
> -{
> -       arm_smmu_sva_remove_dev_pasid(domain, dev, pasid);
> -}
> -
>  static struct iommu_ops arm_smmu_ops = {
>         .identity_domain        = &arm_smmu_identity_domain,
>         .blocked_domain         = &arm_smmu_blocked_domain,
> diff --git a/drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3.h b/drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3.h
> index 212c18c70fa03e..d175d9eee6c61b 100644
> --- a/drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3.h
> +++ b/drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3.h
> @@ -587,9 +587,6 @@ struct arm_smmu_strtab_l1_desc {
>
>  struct arm_smmu_ctx_desc {
>         u16                             asid;
> -
> -       refcount_t                      refs;
> -       struct mm_struct                *mm;
>  };
>
>  struct arm_smmu_l1_ctx_desc {
> @@ -712,7 +709,6 @@ struct arm_smmu_master {
>         bool                            stall_enabled;
>         bool                            sva_enabled;
>         bool                            iopf_enabled;
> -       struct list_head                bonds;
>         unsigned int                    ssid_bits;
>  };
>
> @@ -741,7 +737,7 @@ struct arm_smmu_domain {
>         struct list_head                devices;
>         spinlock_t                      devices_lock;
>
> -       struct list_head                mmu_notifiers;
> +       struct mmu_notifier             mmu_notifier;
>  };


>
>  /* The following are exposed for testing purposes. */
> @@ -805,16 +801,13 @@ void arm_smmu_write_cd_entry(struct arm_smmu_master *master, int ssid,
>  int arm_smmu_set_pasid(struct arm_smmu_master *master,
>                        struct arm_smmu_domain *smmu_domain, ioasid_t pasid,
>                        const struct arm_smmu_cd *cd);
> -void arm_smmu_remove_pasid(struct arm_smmu_master *master,
> -                          struct arm_smmu_domain *smmu_domain, ioasid_t pasid);
>
>  void arm_smmu_tlb_inv_asid(struct arm_smmu_device *smmu, u16 asid);
>  void arm_smmu_tlb_inv_range_asid(unsigned long iova, size_t size, int asid,
>                                  size_t granule, bool leaf,
>                                  struct arm_smmu_domain *smmu_domain);
> -bool arm_smmu_free_asid(struct arm_smmu_ctx_desc *cd);
> -int arm_smmu_atc_inv_domain_sva(struct arm_smmu_domain *smmu_domain,
> -                               ioasid_t ssid, unsigned long iova, size_t size);
> +int arm_smmu_atc_inv_domain(struct arm_smmu_domain *smmu_domain,
> +                           unsigned long iova, size_t size);
>
>  #ifdef CONFIG_ARM_SMMU_V3_SVA
>  bool arm_smmu_sva_supported(struct arm_smmu_device *smmu);
> @@ -826,8 +819,6 @@ bool arm_smmu_master_iopf_supported(struct arm_smmu_master *master);
>  void arm_smmu_sva_notifier_synchronize(void);
>  struct iommu_domain *arm_smmu_sva_domain_alloc(struct device *dev,
>                                                struct mm_struct *mm);
> -void arm_smmu_sva_remove_dev_pasid(struct iommu_domain *domain,
> -                                  struct device *dev, ioasid_t id);
>  #else /* CONFIG_ARM_SMMU_V3_SVA */
>  static inline bool arm_smmu_sva_supported(struct arm_smmu_device *smmu)
>  {
> --
> 2.45.2
>

Reviewed-by: Michael Shavit <mshavit@google.com>
Jason Gunthorpe June 24, 2024, 5:01 p.m. UTC | #2
On Mon, Jun 24, 2024 at 05:54:42PM +0800, Michael Shavit wrote:

> Can we leave a comment on ASID sharing in the code since it isn't
> added back until the next patch series? There are references to ASID
> sharing remaining (and even added in this commit) that don't make
> sense without this function (e.g "Prevent arm_smmu_share_asid() from
> trying to change the ASID").

Yes, I left the comment references because I really do expect it to
come back soon.

My plan, broadly, is to allow the domain's to be shared across smmu
instances which should introduce the infrastructure to avoid the
invalidation race in unshare by letting the domain have multiple ASIDs
at the same time.

After that we would add in vBTM support, this is BTM on systems that
only support S1 with no S2. This avoids the VMID issue that is
blocking it while still being useful.

pBTM would come after the IOMMUFD VIOMMU support that Nicolin is
working on as the VIOMMU would be the vehicle to bring in the KVM VMID
binding from userspace.

I can delete the comments too, but then someone will ask why not
delete all the locking as well. :\

Thanks,
Jason
diff mbox series

Patch

diff --git a/drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3-sva.c b/drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3-sva.c
index aa033cd65adc5a..a7c36654dee5a5 100644
--- a/drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3-sva.c
+++ b/drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3-sva.c
@@ -13,29 +13,9 @@ 
 #include "arm-smmu-v3.h"
 #include "../../io-pgtable-arm.h"
 
-struct arm_smmu_mmu_notifier {
-	struct mmu_notifier		mn;
-	struct arm_smmu_ctx_desc	*cd;
-	bool				cleared;
-	refcount_t			refs;
-	struct list_head		list;
-	struct arm_smmu_domain		*domain;
-};
-
-#define mn_to_smmu(mn) container_of(mn, struct arm_smmu_mmu_notifier, mn)
-
-struct arm_smmu_bond {
-	struct mm_struct		*mm;
-	struct arm_smmu_mmu_notifier	*smmu_mn;
-	struct list_head		list;
-};
-
-#define sva_to_bond(handle) \
-	container_of(handle, struct arm_smmu_bond, sva)
-
 static DEFINE_MUTEX(sva_lock);
 
-static void
+static void __maybe_unused
 arm_smmu_update_s1_domain_cd_entry(struct arm_smmu_domain *smmu_domain)
 {
 	struct arm_smmu_master_domain *master_domain;
@@ -58,58 +38,6 @@  arm_smmu_update_s1_domain_cd_entry(struct arm_smmu_domain *smmu_domain)
 	spin_unlock_irqrestore(&smmu_domain->devices_lock, flags);
 }
 
-/*
- * Check if the CPU ASID is available on the SMMU side. If a private context
- * descriptor is using it, try to replace it.
- */
-static struct arm_smmu_ctx_desc *
-arm_smmu_share_asid(struct mm_struct *mm, u16 asid)
-{
-	int ret;
-	u32 new_asid;
-	struct arm_smmu_ctx_desc *cd;
-	struct arm_smmu_device *smmu;
-	struct arm_smmu_domain *smmu_domain;
-
-	cd = xa_load(&arm_smmu_asid_xa, asid);
-	if (!cd)
-		return NULL;
-
-	if (cd->mm) {
-		if (WARN_ON(cd->mm != mm))
-			return ERR_PTR(-EINVAL);
-		/* All devices bound to this mm use the same cd struct. */
-		refcount_inc(&cd->refs);
-		return cd;
-	}
-
-	smmu_domain = container_of(cd, struct arm_smmu_domain, cd);
-	smmu = smmu_domain->smmu;
-
-	ret = xa_alloc(&arm_smmu_asid_xa, &new_asid, cd,
-		       XA_LIMIT(1, (1 << smmu->asid_bits) - 1), GFP_KERNEL);
-	if (ret)
-		return ERR_PTR(-ENOSPC);
-	/*
-	 * Race with unmap: TLB invalidations will start targeting the new ASID,
-	 * which isn't assigned yet. We'll do an invalidate-all on the old ASID
-	 * later, so it doesn't matter.
-	 */
-	cd->asid = new_asid;
-	/*
-	 * Update ASID and invalidate CD in all associated masters. There will
-	 * be some overlap between use of both ASIDs, until we invalidate the
-	 * TLB.
-	 */
-	arm_smmu_update_s1_domain_cd_entry(smmu_domain);
-
-	/* Invalidate TLB entries previously associated with that context */
-	arm_smmu_tlb_inv_asid(smmu, asid);
-
-	xa_erase(&arm_smmu_asid_xa, asid);
-	return NULL;
-}
-
 static u64 page_size_to_cd(void)
 {
 	static_assert(PAGE_SIZE == SZ_4K || PAGE_SIZE == SZ_16K ||
@@ -187,69 +115,6 @@  void arm_smmu_make_sva_cd(struct arm_smmu_cd *target,
 }
 EXPORT_SYMBOL_IF_KUNIT(arm_smmu_make_sva_cd);
 
-static struct arm_smmu_ctx_desc *arm_smmu_alloc_shared_cd(struct mm_struct *mm)
-{
-	u16 asid;
-	int err = 0;
-	struct arm_smmu_ctx_desc *cd;
-	struct arm_smmu_ctx_desc *ret = NULL;
-
-	/* Don't free the mm until we release the ASID */
-	mmgrab(mm);
-
-	asid = arm64_mm_context_get(mm);
-	if (!asid) {
-		err = -ESRCH;
-		goto out_drop_mm;
-	}
-
-	cd = kzalloc(sizeof(*cd), GFP_KERNEL);
-	if (!cd) {
-		err = -ENOMEM;
-		goto out_put_context;
-	}
-
-	refcount_set(&cd->refs, 1);
-
-	mutex_lock(&arm_smmu_asid_lock);
-	ret = arm_smmu_share_asid(mm, asid);
-	if (ret) {
-		mutex_unlock(&arm_smmu_asid_lock);
-		goto out_free_cd;
-	}
-
-	err = xa_insert(&arm_smmu_asid_xa, asid, cd, GFP_KERNEL);
-	mutex_unlock(&arm_smmu_asid_lock);
-
-	if (err)
-		goto out_free_asid;
-
-	cd->asid = asid;
-	cd->mm = mm;
-
-	return cd;
-
-out_free_asid:
-	arm_smmu_free_asid(cd);
-out_free_cd:
-	kfree(cd);
-out_put_context:
-	arm64_mm_context_put(mm);
-out_drop_mm:
-	mmdrop(mm);
-	return err < 0 ? ERR_PTR(err) : ret;
-}
-
-static void arm_smmu_free_shared_cd(struct arm_smmu_ctx_desc *cd)
-{
-	if (arm_smmu_free_asid(cd)) {
-		/* Unpin ASID */
-		arm64_mm_context_put(cd->mm);
-		mmdrop(cd->mm);
-		kfree(cd);
-	}
-}
-
 /*
  * Cloned from the MAX_TLBI_OPS in arch/arm64/include/asm/tlbflush.h, this
  * is used as a threshold to replace per-page TLBI commands to issue in the
@@ -264,8 +129,8 @@  static void arm_smmu_mm_arch_invalidate_secondary_tlbs(struct mmu_notifier *mn,
 						unsigned long start,
 						unsigned long end)
 {
-	struct arm_smmu_mmu_notifier *smmu_mn = mn_to_smmu(mn);
-	struct arm_smmu_domain *smmu_domain = smmu_mn->domain;
+	struct arm_smmu_domain *smmu_domain =
+		container_of(mn, struct arm_smmu_domain, mmu_notifier);
 	size_t size;
 
 	/*
@@ -282,34 +147,22 @@  static void arm_smmu_mm_arch_invalidate_secondary_tlbs(struct mmu_notifier *mn,
 			size = 0;
 	}
 
-	if (!(smmu_domain->smmu->features & ARM_SMMU_FEAT_BTM)) {
-		if (!size)
-			arm_smmu_tlb_inv_asid(smmu_domain->smmu,
-					      smmu_mn->cd->asid);
-		else
-			arm_smmu_tlb_inv_range_asid(start, size,
-						    smmu_mn->cd->asid,
-						    PAGE_SIZE, false,
-						    smmu_domain);
-	}
+	if (!size)
+		arm_smmu_tlb_inv_asid(smmu_domain->smmu, smmu_domain->cd.asid);
+	else
+		arm_smmu_tlb_inv_range_asid(start, size, smmu_domain->cd.asid,
+					    PAGE_SIZE, false, smmu_domain);
 
-	arm_smmu_atc_inv_domain_sva(smmu_domain, mm_get_enqcmd_pasid(mm), start,
-				    size);
+	arm_smmu_atc_inv_domain(smmu_domain, start, size);
 }
 
 static void arm_smmu_mm_release(struct mmu_notifier *mn, struct mm_struct *mm)
 {
-	struct arm_smmu_mmu_notifier *smmu_mn = mn_to_smmu(mn);
-	struct arm_smmu_domain *smmu_domain = smmu_mn->domain;
+	struct arm_smmu_domain *smmu_domain =
+		container_of(mn, struct arm_smmu_domain, mmu_notifier);
 	struct arm_smmu_master_domain *master_domain;
 	unsigned long flags;
 
-	mutex_lock(&sva_lock);
-	if (smmu_mn->cleared) {
-		mutex_unlock(&sva_lock);
-		return;
-	}
-
 	/*
 	 * DMA may still be running. Keep the cd valid to avoid C_BAD_CD events,
 	 * but disable translation.
@@ -321,25 +174,23 @@  static void arm_smmu_mm_release(struct mmu_notifier *mn, struct mm_struct *mm)
 		struct arm_smmu_cd target;
 		struct arm_smmu_cd *cdptr;
 
-		cdptr = arm_smmu_get_cd_ptr(master, mm_get_enqcmd_pasid(mm));
+		cdptr = arm_smmu_get_cd_ptr(master, master_domain->ssid);
 		if (WARN_ON(!cdptr))
 			continue;
-		arm_smmu_make_sva_cd(&target, master, NULL, smmu_mn->cd->asid);
-		arm_smmu_write_cd_entry(master, mm_get_enqcmd_pasid(mm), cdptr,
+		arm_smmu_make_sva_cd(&target, master, NULL,
+				     smmu_domain->cd.asid);
+		arm_smmu_write_cd_entry(master, master_domain->ssid, cdptr,
 					&target);
 	}
 	spin_unlock_irqrestore(&smmu_domain->devices_lock, flags);
 
-	arm_smmu_tlb_inv_asid(smmu_domain->smmu, smmu_mn->cd->asid);
-	arm_smmu_atc_inv_domain_sva(smmu_domain, mm_get_enqcmd_pasid(mm), 0, 0);
-
-	smmu_mn->cleared = true;
-	mutex_unlock(&sva_lock);
+	arm_smmu_tlb_inv_asid(smmu_domain->smmu, smmu_domain->cd.asid);
+	arm_smmu_atc_inv_domain(smmu_domain, 0, 0);
 }
 
 static void arm_smmu_mmu_notifier_free(struct mmu_notifier *mn)
 {
-	kfree(mn_to_smmu(mn));
+	kfree(container_of(mn, struct arm_smmu_domain, mmu_notifier));
 }
 
 static const struct mmu_notifier_ops arm_smmu_mmu_notifier_ops = {
@@ -348,115 +199,6 @@  static const struct mmu_notifier_ops arm_smmu_mmu_notifier_ops = {
 	.free_notifier			= arm_smmu_mmu_notifier_free,
 };
 
-/* Allocate or get existing MMU notifier for this {domain, mm} pair */
-static struct arm_smmu_mmu_notifier *
-arm_smmu_mmu_notifier_get(struct arm_smmu_domain *smmu_domain,
-			  struct mm_struct *mm)
-{
-	int ret;
-	struct arm_smmu_ctx_desc *cd;
-	struct arm_smmu_mmu_notifier *smmu_mn;
-
-	list_for_each_entry(smmu_mn, &smmu_domain->mmu_notifiers, list) {
-		if (smmu_mn->mn.mm == mm) {
-			refcount_inc(&smmu_mn->refs);
-			return smmu_mn;
-		}
-	}
-
-	cd = arm_smmu_alloc_shared_cd(mm);
-	if (IS_ERR(cd))
-		return ERR_CAST(cd);
-
-	smmu_mn = kzalloc(sizeof(*smmu_mn), GFP_KERNEL);
-	if (!smmu_mn) {
-		ret = -ENOMEM;
-		goto err_free_cd;
-	}
-
-	refcount_set(&smmu_mn->refs, 1);
-	smmu_mn->cd = cd;
-	smmu_mn->domain = smmu_domain;
-	smmu_mn->mn.ops = &arm_smmu_mmu_notifier_ops;
-
-	ret = mmu_notifier_register(&smmu_mn->mn, mm);
-	if (ret) {
-		kfree(smmu_mn);
-		goto err_free_cd;
-	}
-
-	list_add(&smmu_mn->list, &smmu_domain->mmu_notifiers);
-	return smmu_mn;
-
-err_free_cd:
-	arm_smmu_free_shared_cd(cd);
-	return ERR_PTR(ret);
-}
-
-static void arm_smmu_mmu_notifier_put(struct arm_smmu_mmu_notifier *smmu_mn)
-{
-	struct mm_struct *mm = smmu_mn->mn.mm;
-	struct arm_smmu_ctx_desc *cd = smmu_mn->cd;
-	struct arm_smmu_domain *smmu_domain = smmu_mn->domain;
-
-	if (!refcount_dec_and_test(&smmu_mn->refs))
-		return;
-
-	list_del(&smmu_mn->list);
-
-	/*
-	 * If we went through clear(), we've already invalidated, and no
-	 * new TLB entry can have been formed.
-	 */
-	if (!smmu_mn->cleared) {
-		arm_smmu_tlb_inv_asid(smmu_domain->smmu, cd->asid);
-		arm_smmu_atc_inv_domain_sva(smmu_domain,
-					    mm_get_enqcmd_pasid(mm), 0, 0);
-	}
-
-	/* Frees smmu_mn */
-	mmu_notifier_put(&smmu_mn->mn);
-	arm_smmu_free_shared_cd(cd);
-}
-
-static struct arm_smmu_bond *__arm_smmu_sva_bind(struct device *dev,
-						 struct mm_struct *mm)
-{
-	int ret;
-	struct arm_smmu_bond *bond;
-	struct arm_smmu_master *master = dev_iommu_priv_get(dev);
-	struct iommu_domain *domain = iommu_get_domain_for_dev(dev);
-	struct arm_smmu_domain *smmu_domain;
-
-	if (!(domain->type & __IOMMU_DOMAIN_PAGING))
-		return ERR_PTR(-ENODEV);
-	smmu_domain = to_smmu_domain(domain);
-	if (smmu_domain->stage != ARM_SMMU_DOMAIN_S1)
-		return ERR_PTR(-ENODEV);
-
-	if (!master || !master->sva_enabled)
-		return ERR_PTR(-ENODEV);
-
-	bond = kzalloc(sizeof(*bond), GFP_KERNEL);
-	if (!bond)
-		return ERR_PTR(-ENOMEM);
-
-	bond->mm = mm;
-
-	bond->smmu_mn = arm_smmu_mmu_notifier_get(smmu_domain, mm);
-	if (IS_ERR(bond->smmu_mn)) {
-		ret = PTR_ERR(bond->smmu_mn);
-		goto err_free_bond;
-	}
-
-	list_add(&bond->list, &master->bonds);
-	return bond;
-
-err_free_bond:
-	kfree(bond);
-	return ERR_PTR(ret);
-}
-
 bool arm_smmu_sva_supported(struct arm_smmu_device *smmu)
 {
 	unsigned long reg, fld;
@@ -573,11 +315,6 @@  int arm_smmu_master_enable_sva(struct arm_smmu_master *master)
 int arm_smmu_master_disable_sva(struct arm_smmu_master *master)
 {
 	mutex_lock(&sva_lock);
-	if (!list_empty(&master->bonds)) {
-		dev_err(master->dev, "cannot disable SVA, device is bound\n");
-		mutex_unlock(&sva_lock);
-		return -EBUSY;
-	}
 	arm_smmu_master_sva_disable_iopf(master);
 	master->sva_enabled = false;
 	mutex_unlock(&sva_lock);
@@ -594,66 +331,51 @@  void arm_smmu_sva_notifier_synchronize(void)
 	mmu_notifier_synchronize();
 }
 
-void arm_smmu_sva_remove_dev_pasid(struct iommu_domain *domain,
-				   struct device *dev, ioasid_t id)
-{
-	struct mm_struct *mm = domain->mm;
-	struct arm_smmu_bond *bond = NULL, *t;
-	struct arm_smmu_master *master = dev_iommu_priv_get(dev);
-
-	arm_smmu_remove_pasid(master, to_smmu_domain(domain), id);
-
-	mutex_lock(&sva_lock);
-	list_for_each_entry(t, &master->bonds, list) {
-		if (t->mm == mm) {
-			bond = t;
-			break;
-		}
-	}
-
-	if (!WARN_ON(!bond)) {
-		list_del(&bond->list);
-		arm_smmu_mmu_notifier_put(bond->smmu_mn);
-		kfree(bond);
-	}
-	mutex_unlock(&sva_lock);
-}
-
 static int arm_smmu_sva_set_dev_pasid(struct iommu_domain *domain,
 				      struct device *dev, ioasid_t id)
 {
+	struct arm_smmu_domain *smmu_domain = to_smmu_domain(domain);
 	struct arm_smmu_master *master = dev_iommu_priv_get(dev);
-	struct mm_struct *mm = domain->mm;
-	struct arm_smmu_bond *bond;
 	struct arm_smmu_cd target;
 	int ret;
 
-	if (mm_get_enqcmd_pasid(mm) != id)
+	/* Prevent arm_smmu_mm_release from being called while we are attaching */
+	if (!mmget_not_zero(domain->mm))
 		return -EINVAL;
 
-	mutex_lock(&sva_lock);
-	bond = __arm_smmu_sva_bind(dev, mm);
-	if (IS_ERR(bond)) {
-		mutex_unlock(&sva_lock);
-		return PTR_ERR(bond);
-	}
+	/*
+	 * This does not need the arm_smmu_asid_lock because SVA domains never
+	 * get reassigned
+	 */
+	arm_smmu_make_sva_cd(&target, master, domain->mm, smmu_domain->cd.asid);
+	ret = arm_smmu_set_pasid(master, smmu_domain, id, &target);
 
-	arm_smmu_make_sva_cd(&target, master, mm, bond->smmu_mn->cd->asid);
-	ret = arm_smmu_set_pasid(master, to_smmu_domain(domain), id, &target);
-	if (ret) {
-		list_del(&bond->list);
-		arm_smmu_mmu_notifier_put(bond->smmu_mn);
-		kfree(bond);
-		mutex_unlock(&sva_lock);
-		return ret;
-	}
-	mutex_unlock(&sva_lock);
-	return 0;
+	mmput(domain->mm);
+	return ret;
 }
 
 static void arm_smmu_sva_domain_free(struct iommu_domain *domain)
 {
-	kfree(to_smmu_domain(domain));
+	struct arm_smmu_domain *smmu_domain = to_smmu_domain(domain);
+
+	/*
+	 * Ensure the ASID is empty in the iommu cache before allowing reuse.
+	 */
+	arm_smmu_tlb_inv_asid(smmu_domain->smmu, smmu_domain->cd.asid);
+
+	/*
+	 * Notice that the arm_smmu_mm_arch_invalidate_secondary_tlbs op can
+	 * still be called/running at this point. We allow the ASID to be
+	 * reused, and if there is a race then it just suffers harmless
+	 * unnecessary invalidation.
+	 */
+	xa_erase(&arm_smmu_asid_xa, smmu_domain->cd.asid);
+
+	/*
+	 * Actual free is defered to the SRCU callback
+	 * arm_smmu_mmu_notifier_free()
+	 */
+	mmu_notifier_put(&smmu_domain->mmu_notifier);
 }
 
 static const struct iommu_domain_ops arm_smmu_sva_domain_ops = {
@@ -667,6 +389,8 @@  struct iommu_domain *arm_smmu_sva_domain_alloc(struct device *dev,
 	struct arm_smmu_master *master = dev_iommu_priv_get(dev);
 	struct arm_smmu_device *smmu = master->smmu;
 	struct arm_smmu_domain *smmu_domain;
+	u32 asid;
+	int ret;
 
 	smmu_domain = arm_smmu_domain_alloc();
 	if (IS_ERR(smmu_domain))
@@ -675,5 +399,22 @@  struct iommu_domain *arm_smmu_sva_domain_alloc(struct device *dev,
 	smmu_domain->domain.ops = &arm_smmu_sva_domain_ops;
 	smmu_domain->smmu = smmu;
 
+	ret = xa_alloc(&arm_smmu_asid_xa, &asid, smmu_domain,
+		       XA_LIMIT(1, (1 << smmu->asid_bits) - 1), GFP_KERNEL);
+	if (ret)
+		goto err_free;
+
+	smmu_domain->cd.asid = asid;
+	smmu_domain->mmu_notifier.ops = &arm_smmu_mmu_notifier_ops;
+	ret = mmu_notifier_register(&smmu_domain->mmu_notifier, mm);
+	if (ret)
+		goto err_asid;
+
 	return &smmu_domain->domain;
+
+err_asid:
+	xa_erase(&arm_smmu_asid_xa, smmu_domain->cd.asid);
+err_free:
+	kfree(smmu_domain);
+	return ERR_PTR(ret);
 }
diff --git a/drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3.c b/drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3.c
index 24000027253de8..2a845ab6d53b57 100644
--- a/drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3.c
+++ b/drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3.c
@@ -1439,22 +1439,6 @@  static void arm_smmu_free_cd_tables(struct arm_smmu_master *master)
 	cd_table->cdtab = NULL;
 }
 
-bool arm_smmu_free_asid(struct arm_smmu_ctx_desc *cd)
-{
-	bool free;
-	struct arm_smmu_ctx_desc *old_cd;
-
-	if (!cd->asid)
-		return false;
-
-	free = refcount_dec_and_test(&cd->refs);
-	if (free) {
-		old_cd = xa_erase(&arm_smmu_asid_xa, cd->asid);
-		WARN_ON(old_cd != cd);
-	}
-	return free;
-}
-
 /* Stream table manipulation functions */
 static void
 arm_smmu_write_strtab_l1_desc(__le64 *dst, struct arm_smmu_strtab_l1_desc *desc)
@@ -2023,8 +2007,8 @@  static int arm_smmu_atc_inv_master(struct arm_smmu_master *master,
 	return arm_smmu_cmdq_batch_submit(master->smmu, &cmds);
 }
 
-static int __arm_smmu_atc_inv_domain(struct arm_smmu_domain *smmu_domain,
-				     ioasid_t ssid, unsigned long iova, size_t size)
+int arm_smmu_atc_inv_domain(struct arm_smmu_domain *smmu_domain,
+			    unsigned long iova, size_t size)
 {
 	struct arm_smmu_master_domain *master_domain;
 	int i;
@@ -2062,15 +2046,7 @@  static int __arm_smmu_atc_inv_domain(struct arm_smmu_domain *smmu_domain,
 		if (!master->ats_enabled)
 			continue;
 
-		/*
-		 * Non-zero ssid means SVA is co-opting the S1 domain to issue
-		 * invalidations for SVA PASIDs.
-		 */
-		if (ssid != IOMMU_NO_PASID)
-			arm_smmu_atc_inv_to_cmd(ssid, iova, size, &cmd);
-		else
-			arm_smmu_atc_inv_to_cmd(master_domain->ssid, iova, size,
-						&cmd);
+		arm_smmu_atc_inv_to_cmd(master_domain->ssid, iova, size, &cmd);
 
 		for (i = 0; i < master->num_streams; i++) {
 			cmd.atc.sid = master->streams[i].id;
@@ -2082,19 +2058,6 @@  static int __arm_smmu_atc_inv_domain(struct arm_smmu_domain *smmu_domain,
 	return arm_smmu_cmdq_batch_submit(smmu_domain->smmu, &cmds);
 }
 
-static int arm_smmu_atc_inv_domain(struct arm_smmu_domain *smmu_domain,
-				   unsigned long iova, size_t size)
-{
-	return __arm_smmu_atc_inv_domain(smmu_domain, IOMMU_NO_PASID, iova,
-					 size);
-}
-
-int arm_smmu_atc_inv_domain_sva(struct arm_smmu_domain *smmu_domain,
-				ioasid_t ssid, unsigned long iova, size_t size)
-{
-	return __arm_smmu_atc_inv_domain(smmu_domain, ssid, iova, size);
-}
-
 /* IO_PGTABLE API */
 static void arm_smmu_tlb_inv_context(void *cookie)
 {
@@ -2283,7 +2246,6 @@  struct arm_smmu_domain *arm_smmu_domain_alloc(void)
 	mutex_init(&smmu_domain->init_mutex);
 	INIT_LIST_HEAD(&smmu_domain->devices);
 	spin_lock_init(&smmu_domain->devices_lock);
-	INIT_LIST_HEAD(&smmu_domain->mmu_notifiers);
 
 	return smmu_domain;
 }
@@ -2325,7 +2287,7 @@  static void arm_smmu_domain_free_paging(struct iommu_domain *domain)
 	if (smmu_domain->stage == ARM_SMMU_DOMAIN_S1) {
 		/* Prevent SVA from touching the CD while we're freeing it */
 		mutex_lock(&arm_smmu_asid_lock);
-		arm_smmu_free_asid(&smmu_domain->cd);
+		xa_erase(&arm_smmu_asid_xa, smmu_domain->cd.asid);
 		mutex_unlock(&arm_smmu_asid_lock);
 	} else {
 		struct arm_smmu_s2_cfg *cfg = &smmu_domain->s2_cfg;
@@ -2343,11 +2305,9 @@  static int arm_smmu_domain_finalise_s1(struct arm_smmu_device *smmu,
 	u32 asid;
 	struct arm_smmu_ctx_desc *cd = &smmu_domain->cd;
 
-	refcount_set(&cd->refs, 1);
-
 	/* Prevent SVA from modifying the ASID until it is written to the CD */
 	mutex_lock(&arm_smmu_asid_lock);
-	ret = xa_alloc(&arm_smmu_asid_xa, &asid, cd,
+	ret = xa_alloc(&arm_smmu_asid_xa, &asid, smmu_domain,
 		       XA_LIMIT(1, (1 << smmu->asid_bits) - 1), GFP_KERNEL);
 	cd->asid	= (u16)asid;
 	mutex_unlock(&arm_smmu_asid_lock);
@@ -2835,6 +2795,9 @@  int arm_smmu_set_pasid(struct arm_smmu_master *master,
 
 	/* The core code validates pasid */
 
+	if (smmu_domain->smmu != master->smmu)
+		return -EINVAL;
+
 	if (!master->cd_table.in_ste)
 		return -ENODEV;
 
@@ -2856,9 +2819,14 @@  int arm_smmu_set_pasid(struct arm_smmu_master *master,
 	return ret;
 }
 
-void arm_smmu_remove_pasid(struct arm_smmu_master *master,
-			   struct arm_smmu_domain *smmu_domain, ioasid_t pasid)
+static void arm_smmu_remove_dev_pasid(struct device *dev, ioasid_t pasid,
+				      struct iommu_domain *domain)
 {
+	struct arm_smmu_master *master = dev_iommu_priv_get(dev);
+	struct arm_smmu_domain *smmu_domain;
+
+	smmu_domain = to_smmu_domain(domain);
+
 	mutex_lock(&arm_smmu_asid_lock);
 	arm_smmu_clear_cd(master, pasid);
 	if (master->ats_enabled)
@@ -3129,7 +3097,6 @@  static struct iommu_device *arm_smmu_probe_device(struct device *dev)
 
 	master->dev = dev;
 	master->smmu = smmu;
-	INIT_LIST_HEAD(&master->bonds);
 	dev_iommu_priv_set(dev, master);
 
 	ret = arm_smmu_insert_master(smmu, master);
@@ -3311,12 +3278,6 @@  static int arm_smmu_def_domain_type(struct device *dev)
 	return 0;
 }
 
-static void arm_smmu_remove_dev_pasid(struct device *dev, ioasid_t pasid,
-				      struct iommu_domain *domain)
-{
-	arm_smmu_sva_remove_dev_pasid(domain, dev, pasid);
-}
-
 static struct iommu_ops arm_smmu_ops = {
 	.identity_domain	= &arm_smmu_identity_domain,
 	.blocked_domain		= &arm_smmu_blocked_domain,
diff --git a/drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3.h b/drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3.h
index 212c18c70fa03e..d175d9eee6c61b 100644
--- a/drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3.h
+++ b/drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3.h
@@ -587,9 +587,6 @@  struct arm_smmu_strtab_l1_desc {
 
 struct arm_smmu_ctx_desc {
 	u16				asid;
-
-	refcount_t			refs;
-	struct mm_struct		*mm;
 };
 
 struct arm_smmu_l1_ctx_desc {
@@ -712,7 +709,6 @@  struct arm_smmu_master {
 	bool				stall_enabled;
 	bool				sva_enabled;
 	bool				iopf_enabled;
-	struct list_head		bonds;
 	unsigned int			ssid_bits;
 };
 
@@ -741,7 +737,7 @@  struct arm_smmu_domain {
 	struct list_head		devices;
 	spinlock_t			devices_lock;
 
-	struct list_head		mmu_notifiers;
+	struct mmu_notifier		mmu_notifier;
 };
 
 /* The following are exposed for testing purposes. */
@@ -805,16 +801,13 @@  void arm_smmu_write_cd_entry(struct arm_smmu_master *master, int ssid,
 int arm_smmu_set_pasid(struct arm_smmu_master *master,
 		       struct arm_smmu_domain *smmu_domain, ioasid_t pasid,
 		       const struct arm_smmu_cd *cd);
-void arm_smmu_remove_pasid(struct arm_smmu_master *master,
-			   struct arm_smmu_domain *smmu_domain, ioasid_t pasid);
 
 void arm_smmu_tlb_inv_asid(struct arm_smmu_device *smmu, u16 asid);
 void arm_smmu_tlb_inv_range_asid(unsigned long iova, size_t size, int asid,
 				 size_t granule, bool leaf,
 				 struct arm_smmu_domain *smmu_domain);
-bool arm_smmu_free_asid(struct arm_smmu_ctx_desc *cd);
-int arm_smmu_atc_inv_domain_sva(struct arm_smmu_domain *smmu_domain,
-				ioasid_t ssid, unsigned long iova, size_t size);
+int arm_smmu_atc_inv_domain(struct arm_smmu_domain *smmu_domain,
+			    unsigned long iova, size_t size);
 
 #ifdef CONFIG_ARM_SMMU_V3_SVA
 bool arm_smmu_sva_supported(struct arm_smmu_device *smmu);
@@ -826,8 +819,6 @@  bool arm_smmu_master_iopf_supported(struct arm_smmu_master *master);
 void arm_smmu_sva_notifier_synchronize(void);
 struct iommu_domain *arm_smmu_sva_domain_alloc(struct device *dev,
 					       struct mm_struct *mm);
-void arm_smmu_sva_remove_dev_pasid(struct iommu_domain *domain,
-				   struct device *dev, ioasid_t id);
 #else /* CONFIG_ARM_SMMU_V3_SVA */
 static inline bool arm_smmu_sva_supported(struct arm_smmu_device *smmu)
 {