@@ -14,6 +14,8 @@
#define MM_CONTEXT_HAS_VSYSCALL 1
/* Do not allow changing LAM mode */
#define MM_CONTEXT_LOCK_LAM 2
+/* Allow LAM and SVA coexisting */
+#define MM_CONTEXT_FORCE_TAGGED_SVA 3
/*
* x86 has arch-specific MMU state beyond what lives in mm_struct.
@@ -114,6 +114,12 @@ static inline void mm_reset_untag_mask(struct mm_struct *mm)
mm->context.untag_mask = -1UL;
}
+#define arch_pgtable_dma_compat arch_pgtable_dma_compat
+static inline bool arch_pgtable_dma_compat(struct mm_struct *mm)
+{
+ return !mm_lam_cr3_mask(mm) ||
+ test_bit(MM_CONTEXT_FORCE_TAGGED_SVA, &mm->context.flags);
+}
#else
static inline unsigned long mm_lam_cr3_mask(struct mm_struct *mm)
@@ -23,5 +23,6 @@
#define ARCH_GET_UNTAG_MASK 0x4001
#define ARCH_ENABLE_TAGGED_ADDR 0x4002
#define ARCH_GET_MAX_TAG_BITS 0x4003
+#define ARCH_FORCE_TAGGED_SVA 0x4004
#endif /* _ASM_X86_PRCTL_H */
@@ -759,6 +759,10 @@ static int prctl_enable_tagged_addr(struct mm_struct *mm, unsigned long nr_bits)
if (current->mm != mm)
return -EINVAL;
+ if (mm_valid_pasid(mm) &&
+ !test_bit(MM_CONTEXT_FORCE_TAGGED_SVA, &mm->context.flags))
+ return -EINTR;
+
if (mmap_write_lock_killable(mm))
return -EINTR;
@@ -882,6 +886,9 @@ long do_arch_prctl_64(struct task_struct *task, int option, unsigned long arg2)
(unsigned long __user *)arg2);
case ARCH_ENABLE_TAGGED_ADDR:
return prctl_enable_tagged_addr(task->mm, arg2);
+ case ARCH_FORCE_TAGGED_SVA:
+ set_bit(MM_CONTEXT_FORCE_TAGGED_SVA, &task->mm->context.flags);
+ return 0;
case ARCH_GET_MAX_TAG_BITS:
if (!cpu_feature_enabled(X86_FEATURE_LAM))
return put_user(0, (unsigned long __user *)arg2);
@@ -2,6 +2,7 @@
/*
* Helpers for IOMMU drivers implementing SVA
*/
+#include <linux/mmu_context.h>
#include <linux/mutex.h>
#include <linux/sched/mm.h>
#include <linux/iommu.h>
@@ -32,6 +33,9 @@ int iommu_sva_alloc_pasid(struct mm_struct *mm, ioasid_t min, ioasid_t max)
min == 0 || max < min)
return -EINVAL;
+ if (!arch_pgtable_dma_compat(mm))
+ return -EBUSY;
+
mutex_lock(&iommu_sva_lock);
/* Is a PASID already associated with this mm? */
if (mm_valid_pasid(mm)) {
@@ -35,4 +35,11 @@ static inline unsigned long mm_untag_mask(struct mm_struct *mm)
}
#endif
+#ifndef arch_pgtable_dma_compat
+static inline bool arch_pgtable_dma_compat(struct mm_struct *mm)
+{
+ return true;
+}
+#endif
+
#endif