diff mbox series

[RFC,18/47] mm: asi: Support for pre-ASI-init local non-sensitive allocations

Message ID 20220223052223.1202152-19-junaids@google.com (mailing list archive)
State New
Headers show
Series Address Space Isolation for KVM | expand

Commit Message

Junaid Shahid Feb. 23, 2022, 5:21 a.m. UTC
Local non-sensitive allocations can be made before an actual ASI
instance is initialized. To support this, a process-wide pseudo-PGD
is created, which contains mappings for all locally non-sensitive
allocations. Memory can be mapped into this pseudo-PGD by using
ASI_LOCAL_NONSENSITIVE when calling asi_map(). The mappings will be
copied to an actual ASI PGD when an ASI instance is initialized in
that process, by copying all the PGD entries in the local
non-sensitive range from the pseudo-PGD to the ASI PGD. In addition,
the page fault handler will copy any new PGD entries that get added
after the initialization of the ASI instance.

Signed-off-by: Junaid Shahid <junaids@google.com>


---
 arch/x86/include/asm/asi.h |  6 +++-
 arch/x86/mm/asi.c          | 74 +++++++++++++++++++++++++++++++++++++-
 arch/x86/mm/fault.c        |  7 ++++
 include/asm-generic/asi.h  | 12 ++++++-
 kernel/fork.c              |  8 +++--
 5 files changed, 102 insertions(+), 5 deletions(-)
diff mbox series

Patch

diff --git a/arch/x86/include/asm/asi.h b/arch/x86/include/asm/asi.h
index f69e1f2f09a4..f11010c0334b 100644
--- a/arch/x86/include/asm/asi.h
+++ b/arch/x86/include/asm/asi.h
@@ -16,6 +16,7 @@ 
 #define ASI_MAX_NUM		(1 << ASI_MAX_NUM_ORDER)
 
 #define ASI_GLOBAL_NONSENSITIVE	(&init_mm.asi[0])
+#define ASI_LOCAL_NONSENSITIVE	(&current->mm->asi[0])
 
 struct asi_state {
 	struct asi *curr_asi;
@@ -45,7 +46,8 @@  DECLARE_PER_CPU_ALIGNED(struct asi_state, asi_cpu_state);
 
 extern pgd_t asi_global_nonsensitive_pgd[];
 
-void asi_init_mm_state(struct mm_struct *mm);
+int  asi_init_mm_state(struct mm_struct *mm);
+void asi_free_mm_state(struct mm_struct *mm);
 
 int  asi_register_class(const char *name, uint flags,
 			const struct asi_hooks *ops);
@@ -61,6 +63,8 @@  int  asi_map_gfp(struct asi *asi, void *addr, size_t len, gfp_t gfp_flags);
 int  asi_map(struct asi *asi, void *addr, size_t len);
 void asi_unmap(struct asi *asi, void *addr, size_t len, bool flush_tlb);
 void asi_flush_tlb_range(struct asi *asi, void *addr, size_t len);
+void asi_sync_mapping(struct asi *asi, void *addr, size_t len);
+void asi_do_lazy_map(struct asi *asi, size_t addr);
 
 static inline void asi_init_thread_state(struct thread_struct *thread)
 {
diff --git a/arch/x86/mm/asi.c b/arch/x86/mm/asi.c
index 38eaa650bac1..3ba0971a318d 100644
--- a/arch/x86/mm/asi.c
+++ b/arch/x86/mm/asi.c
@@ -73,6 +73,17 @@  void asi_unregister_class(int index)
 }
 EXPORT_SYMBOL_GPL(asi_unregister_class);
 
+static void asi_clone_pgd(pgd_t *dst_table, pgd_t *src_table, size_t addr)
+{
+	pgd_t *src = pgd_offset_pgd(src_table, addr);
+	pgd_t *dst = pgd_offset_pgd(dst_table, addr);
+
+	if (!pgd_val(*dst))
+		set_pgd(dst, *src);
+	else
+		VM_BUG_ON(pgd_val(*dst) != pgd_val(*src));
+}
+
 #ifndef mm_inc_nr_p4ds
 #define mm_inc_nr_p4ds(mm)	do {} while (false)
 #endif
@@ -291,6 +302,11 @@  int asi_init(struct mm_struct *mm, int asi_index, struct asi **out_asi)
 		for (i = KERNEL_PGD_BOUNDARY; i < pgd_index(ASI_LOCAL_MAP); i++)
 			set_pgd(asi->pgd + i, asi_global_nonsensitive_pgd[i]);
 
+		for (i = pgd_index(ASI_LOCAL_MAP);
+		     i <= pgd_index(ASI_LOCAL_MAP + PFN_PHYS(max_possible_pfn));
+		     i++)
+			set_pgd(asi->pgd + i, mm->asi[0].pgd[i]);
+
 		for (i = pgd_index(VMALLOC_GLOBAL_NONSENSITIVE_START);
 		     i < PTRS_PER_PGD; i++)
 			set_pgd(asi->pgd + i, asi_global_nonsensitive_pgd[i]);
@@ -379,7 +395,7 @@  void asi_exit(void)
 }
 EXPORT_SYMBOL_GPL(asi_exit);
 
-void asi_init_mm_state(struct mm_struct *mm)
+int asi_init_mm_state(struct mm_struct *mm)
 {
 	struct mem_cgroup *memcg = get_mem_cgroup_from_mm(mm);
 
@@ -395,6 +411,28 @@  void asi_init_mm_state(struct mm_struct *mm)
 				  memcg->use_asi;
 		css_put(&memcg->css);
 	}
+
+	if (!mm->asi_enabled)
+		return 0;
+
+	mm->asi[0].mm = mm;
+	mm->asi[0].pgd = (pgd_t *)__get_free_page(GFP_PGTABLE_USER);
+	if (!mm->asi[0].pgd)
+		return -ENOMEM;
+
+	return 0;
+}
+
+void asi_free_mm_state(struct mm_struct *mm)
+{
+	if (!boot_cpu_has(X86_FEATURE_ASI) || !mm->asi_enabled)
+		return;
+
+	asi_free_pgd_range(&mm->asi[0], pgd_index(ASI_LOCAL_MAP),
+			   pgd_index(ASI_LOCAL_MAP +
+				     PFN_PHYS(max_possible_pfn)) + 1);
+
+	free_page((ulong)mm->asi[0].pgd);
 }
 
 static bool is_page_within_range(size_t addr, size_t page_size,
@@ -599,3 +637,37 @@  void *asi_va(unsigned long pa)
 			      ? ASI_LOCAL_MAP : PAGE_OFFSET));
 }
 EXPORT_SYMBOL(asi_va);
+
+static bool is_addr_in_local_nonsensitive_range(size_t addr)
+{
+	return addr >= ASI_LOCAL_MAP &&
+	       addr < VMALLOC_GLOBAL_NONSENSITIVE_START;
+}
+
+void asi_do_lazy_map(struct asi *asi, size_t addr)
+{
+	if (!static_cpu_has(X86_FEATURE_ASI) || !asi)
+		return;
+
+	if ((asi->class->flags & ASI_MAP_STANDARD_NONSENSITIVE) &&
+	    is_addr_in_local_nonsensitive_range(addr))
+		asi_clone_pgd(asi->pgd, asi->mm->asi[0].pgd, addr);
+}
+
+/*
+ * Should be called after asi_map(ASI_LOCAL_NONSENSITIVE,...) for any mapping
+ * that is required to exist prior to asi_enter() (e.g. thread stacks)
+ */
+void asi_sync_mapping(struct asi *asi, void *start, size_t len)
+{
+	size_t addr = (size_t)start;
+	size_t end = addr + len;
+
+	if (!static_cpu_has(X86_FEATURE_ASI) || !asi)
+		return;
+
+	if ((asi->class->flags & ASI_MAP_STANDARD_NONSENSITIVE) &&
+	    is_addr_in_local_nonsensitive_range(addr))
+		for (; addr < end; addr = pgd_addr_end(addr, end))
+			asi_clone_pgd(asi->pgd, asi->mm->asi[0].pgd, addr);
+}
diff --git a/arch/x86/mm/fault.c b/arch/x86/mm/fault.c
index 4bfed53e210e..8692eb50f4a5 100644
--- a/arch/x86/mm/fault.c
+++ b/arch/x86/mm/fault.c
@@ -1498,6 +1498,12 @@  DEFINE_IDTENTRY_RAW_ERRORCODE(exc_page_fault)
 {
 	unsigned long address = read_cr2();
 	irqentry_state_t state;
+	/*
+	 * There is a very small chance that an NMI could cause an asi_exit()
+	 * before this asi_get_current(), but that is ok, we will just do
+	 * the fixup on the next page fault.
+	 */
+	struct asi *asi = asi_get_current();
 
 	prefetchw(&current->mm->mmap_lock);
 
@@ -1539,6 +1545,7 @@  DEFINE_IDTENTRY_RAW_ERRORCODE(exc_page_fault)
 
 	instrumentation_begin();
 	handle_page_fault(regs, error_code, address);
+	asi_do_lazy_map(asi, address);
 	instrumentation_end();
 
 	irqentry_exit(regs, state);
diff --git a/include/asm-generic/asi.h b/include/asm-generic/asi.h
index 51c9c4a488e8..a1c8ebff70e8 100644
--- a/include/asm-generic/asi.h
+++ b/include/asm-generic/asi.h
@@ -13,6 +13,7 @@ 
 #define ASI_MAX_NUM			0
 
 #define ASI_GLOBAL_NONSENSITIVE		NULL
+#define ASI_LOCAL_NONSENSITIVE		NULL
 
 #define VMALLOC_GLOBAL_NONSENSITIVE_START	VMALLOC_START
 #define VMALLOC_GLOBAL_NONSENSITIVE_END		VMALLOC_END
@@ -31,7 +32,9 @@  int asi_register_class(const char *name, uint flags,
 
 static inline void asi_unregister_class(int asi_index) { }
 
-static inline void asi_init_mm_state(struct mm_struct *mm) { }
+static inline int asi_init_mm_state(struct mm_struct *mm) { return 0; }
+
+static inline void asi_free_mm_state(struct mm_struct *mm) { }
 
 static inline
 int asi_init(struct mm_struct *mm, int asi_index, struct asi **out_asi)
@@ -67,9 +70,16 @@  static inline int asi_map(struct asi *asi, void *addr, size_t len)
 	return 0;
 }
 
+static inline
+void asi_sync_mapping(struct asi *asi, void *addr, size_t len) { }
+
 static inline
 void asi_unmap(struct asi *asi, void *addr, size_t len, bool flush_tlb) { }
 
+
+static inline
+void asi_do_lazy_map(struct asi *asi, size_t addr) { }
+
 static inline
 void asi_flush_tlb_range(struct asi *asi, void *addr, size_t len) { }
 
diff --git a/kernel/fork.c b/kernel/fork.c
index 3695a32ee9bd..dd5a86e913ea 100644
--- a/kernel/fork.c
+++ b/kernel/fork.c
@@ -699,6 +699,7 @@  void __mmdrop(struct mm_struct *mm)
 	mm_free_pgd(mm);
 	destroy_context(mm);
 	mmu_notifier_subscriptions_destroy(mm);
+	asi_free_mm_state(mm);
 	check_mm(mm);
 	put_user_ns(mm->user_ns);
 	free_mm(mm);
@@ -1072,17 +1073,20 @@  static struct mm_struct *mm_init(struct mm_struct *mm, struct task_struct *p,
 		mm->def_flags = 0;
 	}
 
-	asi_init_mm_state(mm);
-
 	if (mm_alloc_pgd(mm))
 		goto fail_nopgd;
 
 	if (init_new_context(p, mm))
 		goto fail_nocontext;
 
+	if (asi_init_mm_state(mm))
+		goto fail_noasi;
+
 	mm->user_ns = get_user_ns(user_ns);
+
 	return mm;
 
+fail_noasi:
 fail_nocontext:
 	mm_free_pgd(mm);
 fail_nopgd: