diff mbox series

[12/20] mm/mshare: prepare for page table sharing support

Message ID 20250124235454.84587-13-anthony.yznaga@oracle.com (mailing list archive)
State New
Headers show
Series Add support for shared PTEs across processes | expand

Commit Message

Anthony Yznaga Jan. 24, 2025, 11:54 p.m. UTC
From: Khalid Aziz <khalid@kernel.org>

In preparation for enabling the handling of page faults in an mshare
region provide a way to link an mshare shared page table to a process
page table and otherwise find the actual vma in order to handle a page
fault. Modify the unmap path to ensure that page tables in mshare regions
are unlinked and kept intact when a process exits or an mshare region
is explicitly unmapped.

Signed-off-by: Khalid Aziz <khalid@kernel.org>
Signed-off-by: Matthew Wilcox (Oracle) <willy@infradead.org>
Signed-off-by: Anthony Yznaga <anthony.yznaga@oracle.com>
---
 include/linux/mm.h |  6 +++++
 mm/memory.c        | 38 ++++++++++++++++++++++------
 mm/mshare.c        | 62 ++++++++++++++++++++++++++++++++++++++++++++++
 3 files changed, 98 insertions(+), 8 deletions(-)
diff mbox series

Patch

diff --git a/include/linux/mm.h b/include/linux/mm.h
index 1314af11596d..9889c4757f45 100644
--- a/include/linux/mm.h
+++ b/include/linux/mm.h
@@ -1110,11 +1110,17 @@  static inline bool vma_is_anon_shmem(struct vm_area_struct *vma) { return false;
 int vma_is_stack_for_current(struct vm_area_struct *vma);
 
 #ifdef CONFIG_MSHARE
+vm_fault_t find_shared_vma(struct vm_area_struct **vma, unsigned long *addrp);
 static inline bool vma_is_mshare(const struct vm_area_struct *vma)
 {
 	return vma->vm_flags & VM_MSHARE;
 }
 #else
+static inline vm_fault_t find_shared_vma(struct vm_area_struct **vma, unsigned long *addrp)
+{
+	WARN_ON_ONCE(1);
+	return VM_FAULT_SIGBUS;
+}
 static inline bool vma_is_mshare(const struct vm_area_struct *vma)
 {
 	return false;
diff --git a/mm/memory.c b/mm/memory.c
index 20bafbb10ea7..9374bb184a5f 100644
--- a/mm/memory.c
+++ b/mm/memory.c
@@ -263,7 +263,8 @@  static inline void free_pud_range(struct mmu_gather *tlb, p4d_t *p4d,
 
 static inline void free_p4d_range(struct mmu_gather *tlb, pgd_t *pgd,
 				unsigned long addr, unsigned long end,
-				unsigned long floor, unsigned long ceiling)
+				unsigned long floor, unsigned long ceiling,
+				bool shared_pud)
 {
 	p4d_t *p4d;
 	unsigned long next;
@@ -275,7 +276,10 @@  static inline void free_p4d_range(struct mmu_gather *tlb, pgd_t *pgd,
 		next = p4d_addr_end(addr, end);
 		if (p4d_none_or_clear_bad(p4d))
 			continue;
-		free_pud_range(tlb, p4d, addr, next, floor, ceiling);
+		if (unlikely(shared_pud))
+			p4d_clear(p4d);
+		else
+			free_pud_range(tlb, p4d, addr, next, floor, ceiling);
 	} while (p4d++, addr = next, addr != end);
 
 	start &= PGDIR_MASK;
@@ -297,9 +301,10 @@  static inline void free_p4d_range(struct mmu_gather *tlb, pgd_t *pgd,
 /*
  * This function frees user-level page tables of a process.
  */
-void free_pgd_range(struct mmu_gather *tlb,
+static void __free_pgd_range(struct mmu_gather *tlb,
 			unsigned long addr, unsigned long end,
-			unsigned long floor, unsigned long ceiling)
+			unsigned long floor, unsigned long ceiling,
+			bool shared_pud)
 {
 	pgd_t *pgd;
 	unsigned long next;
@@ -355,10 +360,17 @@  void free_pgd_range(struct mmu_gather *tlb,
 		next = pgd_addr_end(addr, end);
 		if (pgd_none_or_clear_bad(pgd))
 			continue;
-		free_p4d_range(tlb, pgd, addr, next, floor, ceiling);
+		free_p4d_range(tlb, pgd, addr, next, floor, ceiling, shared_pud);
 	} while (pgd++, addr = next, addr != end);
 }
 
+void free_pgd_range(struct mmu_gather *tlb,
+			unsigned long addr, unsigned long end,
+			unsigned long floor, unsigned long ceiling)
+{
+	__free_pgd_range(tlb, addr, end, floor, ceiling, false);
+}
+
 void free_pgtables(struct mmu_gather *tlb, struct ma_state *mas,
 		   struct vm_area_struct *vma, unsigned long floor,
 		   unsigned long ceiling, bool mm_wr_locked)
@@ -395,9 +407,12 @@  void free_pgtables(struct mmu_gather *tlb, struct ma_state *mas,
 
 			/*
 			 * Optimization: gather nearby vmas into one call down
+			 *
+			 * Do not free the shared page tables of an mshare region.
 			 */
 			while (next && next->vm_start <= vma->vm_end + PMD_SIZE
-			       && !is_vm_hugetlb_page(next)) {
+			       && !is_vm_hugetlb_page(next)
+			       && !vma_is_mshare(next)) {
 				vma = next;
 				next = mas_find(mas, ceiling - 1);
 				if (unlikely(xa_is_zero(next)))
@@ -408,9 +423,11 @@  void free_pgtables(struct mmu_gather *tlb, struct ma_state *mas,
 				unlink_file_vma_batch_add(&vb, vma);
 			}
 			unlink_file_vma_batch_final(&vb);
-			free_pgd_range(tlb, addr, vma->vm_end,
-				floor, next ? next->vm_start : ceiling);
+			__free_pgd_range(tlb, addr, vma->vm_end,
+				floor, next ? next->vm_start : ceiling,
+				vma_is_mshare(vma));
 		}
+
 		vma = next;
 	} while (vma);
 }
@@ -6148,6 +6165,11 @@  vm_fault_t handle_mm_fault(struct vm_area_struct *vma, unsigned long address,
 	if (ret)
 		goto out;
 
+	if (unlikely(vma_is_mshare(vma))) {
+		WARN_ON_ONCE(1);
+		return VM_FAULT_SIGBUS;
+	}
+
 	if (!arch_vma_access_permitted(vma, flags & FAULT_FLAG_WRITE,
 					    flags & FAULT_FLAG_INSTRUCTION,
 					    flags & FAULT_FLAG_REMOTE)) {
diff --git a/mm/mshare.c b/mm/mshare.c
index 8dca4199dd01..9ada1544aeb1 100644
--- a/mm/mshare.c
+++ b/mm/mshare.c
@@ -42,6 +42,56 @@  static const struct mmu_notifier_ops mshare_mmu_ops = {
 	.arch_invalidate_secondary_tlbs = mshare_invalidate_tlbs,
 };
 
+static p4d_t *walk_to_p4d(struct mm_struct *mm, unsigned long addr)
+{
+	pgd_t *pgd;
+	p4d_t *p4d;
+
+	pgd = pgd_offset(mm, addr);
+	p4d = p4d_alloc(mm, pgd, addr);
+	if (!p4d)
+		return NULL;
+
+	return p4d;
+}
+
+/* Returns holding the host mm's lock for read.  Caller must release. */
+vm_fault_t
+find_shared_vma(struct vm_area_struct **vmap, unsigned long *addrp)
+{
+	struct vm_area_struct *vma, *guest = *vmap;
+	struct mshare_data *m_data = guest->vm_private_data;
+	struct mm_struct *host_mm = m_data->mm;
+	unsigned long host_addr;
+	p4d_t *p4d, *guest_p4d;
+
+	mmap_read_lock_nested(host_mm, SINGLE_DEPTH_NESTING);
+	host_addr = *addrp - guest->vm_start + host_mm->mmap_base;
+	p4d = walk_to_p4d(host_mm, host_addr);
+	guest_p4d = walk_to_p4d(guest->vm_mm, *addrp);
+	if (!p4d_same(*guest_p4d, *p4d)) {
+		set_p4d(guest_p4d, *p4d);
+		mmap_read_unlock(host_mm);
+		return VM_FAULT_NOPAGE;
+	}
+
+	*addrp = host_addr;
+	vma = find_vma(host_mm, host_addr);
+
+	/* XXX: expand stack? */
+	if (vma && vma->vm_start > host_addr)
+		vma = NULL;
+
+	*vmap = vma;
+
+	/*
+	 * release host mm lock unless a matching vma is found
+	 */
+	if (!vma)
+		mmap_read_unlock(host_mm);
+	return 0;
+}
+
 static int mshare_vm_op_split(struct vm_area_struct *vma, unsigned long addr)
 {
 	return -EINVAL;
@@ -53,9 +103,21 @@  static int mshare_vm_op_mprotect(struct vm_area_struct *vma, unsigned long start
 	return -EINVAL;
 }
 
+static void mshare_vm_op_unmap_page_range(struct mmu_gather *tlb,
+				struct vm_area_struct *vma,
+				unsigned long addr, unsigned long end,
+				struct zap_details *details)
+{
+	/*
+	 * The msharefs vma is being unmapped. Do not unmap pages in the
+	 * mshare region itself.
+	 */
+}
+
 static const struct vm_operations_struct msharefs_vm_ops = {
 	.may_split = mshare_vm_op_split,
 	.mprotect = mshare_vm_op_mprotect,
+	.unmap_page_range = mshare_vm_op_unmap_page_range,
 };
 
 /*