[20/20] mm/mshare: associate a mem cgroup with an mshare file

Message ID: 20250124235454.84587-21-anthony.yznaga@oracle.com
State: New
Series: Add support for shared PTEs across processes

Commit Message

Anthony Yznaga Jan. 24, 2025, 11:54 p.m. UTC
This patch shows one approach to associating a specific mem cgroup with
an mshare file and was inspired by code in mem_cgroup_sk_alloc().
Essentially, when a process creates an mshare region, a reference is
taken on the mem cgroup that the process belongs to and a pointer to
the memcg is saved. At fault time, set_active_memcg() is used to
temporarily enable charging of __GFP_ACCOUNT allocations to the saved
memcg. This does consolidate page table charges in a single memcg, but
there are issues still to address, such as how to handle the case where
the memcg is deleted but lingers as a hidden, zombie memcg because the
mshare file still holds a reference to it.
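
For reference, the charging pattern described above (borrowed from the
way mem_cgroup_sk_alloc() pins a memcg, combined with set_active_memcg())
looks roughly like the sketch below. This is illustration only, not part
of the patch; alloc_charged_to() and target_memcg are made-up placeholder
names standing in for the memcg saved in mshare_data:

#include <linux/memcontrol.h>
#include <linux/sched/mm.h>	/* set_active_memcg() */
#include <linux/slab.h>

/*
 * Minimal sketch of the pattern the patch applies around
 * handle_mm_fault(): redirect __GFP_ACCOUNT charges to a memcg we
 * already hold a reference on, then restore the previous active
 * memcg. alloc_charged_to() is a hypothetical helper for
 * illustration only.
 */
static void *alloc_charged_to(struct mem_cgroup *target_memcg, size_t size)
{
	struct mem_cgroup *old_memcg;
	void *p;

	/* Charge __GFP_ACCOUNT allocations to target_memcg, not current's memcg. */
	old_memcg = set_active_memcg(target_memcg);
	p = kmalloc(size, GFP_KERNEL | __GFP_ACCOUNT);
	/* Always restore the previous active memcg. */
	set_active_memcg(old_memcg);

	return p;
}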

Signed-off-by: Anthony Yznaga <anthony.yznaga@oracle.com>
---
 arch/x86/mm/fault.c | 11 +++++++++++
 include/linux/mm.h  |  5 +++++
 mm/mshare.c         | 33 +++++++++++++++++++++++++++++++++
 3 files changed, 49 insertions(+)

Patch

diff --git a/arch/x86/mm/fault.c b/arch/x86/mm/fault.c
index 4b55ade61a01..1b50417f68ad 100644
--- a/arch/x86/mm/fault.c
+++ b/arch/x86/mm/fault.c
@@ -21,6 +21,7 @@ 
 #include <linux/mm_types.h>
 #include <linux/mm.h>			/* find_and_lock_vma() */
 #include <linux/vmalloc.h>
+#include <linux/memcontrol.h>
 
 #include <asm/cpufeature.h>		/* boot_cpu_has, ...		*/
 #include <asm/traps.h>			/* dotraplinkage, ...		*/
@@ -1219,6 +1220,8 @@ void do_user_addr_fault(struct pt_regs *regs,
 	unsigned int flags = FAULT_FLAG_DEFAULT;
 	bool is_shared_vma;
 	unsigned long addr;
+	struct mem_cgroup *mshare_memcg;
+	struct mem_cgroup *memcg;
 
 	tsk = current;
 	mm = tsk->mm;
@@ -1375,6 +1378,8 @@ void do_user_addr_fault(struct pt_regs *regs,
 	}
 
 	if (unlikely(vma_is_mshare(vma))) {
+		mshare_memcg = get_mshare_memcg(vma);
+
 		fault = find_shared_vma(&vma, &addr);
 
 		if (fault) {
@@ -1402,6 +1407,9 @@ void do_user_addr_fault(struct pt_regs *regs,
 		return;
 	}
 
+	if (is_shared_vma && mshare_memcg)
+		memcg = set_active_memcg(mshare_memcg);
+
 	/*
 	 * If for any reason at all we couldn't handle the fault,
 	 * make sure we exit gracefully rather than endlessly redo
@@ -1417,6 +1425,9 @@ void do_user_addr_fault(struct pt_regs *regs,
 	 */
 	fault = handle_mm_fault(vma, addr, flags, regs);
 
+	if (is_shared_vma && mshare_memcg)
+		set_active_memcg(memcg);
+
 	if (unlikely(is_shared_vma) && ((fault & VM_FAULT_COMPLETED) ||
 	    (fault & VM_FAULT_RETRY) || fault_signal_pending(fault, regs)))
 		mmap_read_unlock(mm);
diff --git a/include/linux/mm.h b/include/linux/mm.h
index 80429d1a6ae4..eaa304d22a9d 100644
--- a/include/linux/mm.h
+++ b/include/linux/mm.h
@@ -1110,12 +1110,17 @@ static inline bool vma_is_anon_shmem(struct vm_area_struct *vma) { return false;
 int vma_is_stack_for_current(struct vm_area_struct *vma);
 
 #ifdef CONFIG_MSHARE
+struct mem_cgroup *get_mshare_memcg(struct vm_area_struct *vma);
 vm_fault_t find_shared_vma(struct vm_area_struct **vma, unsigned long *addrp);
 static inline bool vma_is_mshare(const struct vm_area_struct *vma)
 {
 	return vma->vm_flags & VM_MSHARE;
 }
 #else
+static inline struct mem_cgroup *get_mshare_memcg(struct vm_area_struct *vma)
+{
+	return NULL;
+}
 static inline vm_fault_t find_shared_vma(struct vm_area_struct **vma, unsigned long *addrp)
 {
 	WARN_ON_ONCE(1);
diff --git a/mm/mshare.c b/mm/mshare.c
index 5cc416cfd78c..a56e56c90aaa 100644
--- a/mm/mshare.c
+++ b/mm/mshare.c
@@ -16,6 +16,7 @@ 
 
 #include <linux/fs.h>
 #include <linux/fs_context.h>
+#include <linux/memcontrol.h>
 #include <linux/mman.h>
 #include <linux/mmu_notifier.h>
 #include <linux/spinlock_types.h>
@@ -30,8 +31,22 @@ struct mshare_data {
 	spinlock_t m_lock;
 	struct mshare_info minfo;
 	struct mmu_notifier mn;
+#ifdef CONFIG_MEMCG
+	struct mem_cgroup *memcg;
+#endif
 };
 
+struct mem_cgroup *get_mshare_memcg(struct vm_area_struct *vma)
+{
+	struct mshare_data *m_data = vma->vm_private_data;
+
+#ifdef CONFIG_MEMCG
+	return m_data->memcg;
+#else
+	return NULL;
+#endif
+}
+
 static void mshare_invalidate_tlbs(struct mmu_notifier *mn, struct mm_struct *mm,
 				   unsigned long start, unsigned long end)
 {
@@ -358,6 +373,9 @@ msharefs_fill_mm(struct inode *inode)
 	struct mm_struct *mm;
 	struct mshare_data *m_data = NULL;
 	int ret = 0;
+#ifdef CONFIG_MEMCG
+	struct mem_cgroup *memcg;
+#endif
 
 	mm = mm_alloc();
 	if (!mm) {
@@ -383,6 +401,17 @@ msharefs_fill_mm(struct inode *inode)
 
 #ifdef CONFIG_MEMCG
 	mm->owner = NULL;
+
+	rcu_read_lock();
+	memcg = mem_cgroup_from_task(current);
+	if (mem_cgroup_is_root(memcg))
+		goto out;
+	if (!cgroup_subsys_on_dfl(memory_cgrp_subsys))
+		goto out;
+	if (css_tryget(&memcg->css))
+		m_data->memcg = memcg;
+out:
+	rcu_read_unlock();
 #endif
 	return 0;
 
@@ -396,6 +425,10 @@ msharefs_fill_mm(struct inode *inode)
 static void
 msharefs_delmm(struct mshare_data *m_data)
 {
+#ifdef CONFIG_MEMCG
+	if (m_data->memcg)
+		css_put(&m_data->memcg->css);
+#endif
 	mmput(m_data->mm);
 	kfree(m_data);
 }