diff mbox series

[v4,07/11] mm: multigenerational lru: aging

Message ID 20210818063107.2696454-8-yuzhao@google.com (mailing list archive)
State New
Headers show
Series Multigenerational LRU Framework | expand

Commit Message

Yu Zhao Aug. 18, 2021, 6:31 a.m. UTC
The aging produces young generations. Given an lruvec, the aging
traverses lruvec_memcg()->mm_list and calls walk_page_range() to scan
PTEs for accessed pages. Upon finding one, the aging updates its
generation number to max_seq (modulo MAX_NR_GENS). After each round of
traversal, the aging increments max_seq. The aging is due when both
min_seq[2] have caught up with max_seq-1.

The aging uses the following optimizations when walking page tables:
  1) It skips page tables of processes that have been sleeping since
  the last walk.
  2) It skips non-leaf PMD entries that have the accessed bit cleared
  when CONFIG_HAVE_ARCH_PARENT_PMD_YOUNG=y.
  3) It does not zigzag between a PGD table and the same PMD or PTE
  table spanning multiple VMAs. In other words, it finishes all the
  VMAs within the range of the same PMD or PTE table before it returns
  to this PGD table.

Signed-off-by: Yu Zhao <yuzhao@google.com>
Tested-by: Konstantin Kharlamov <Hi-Angel@yandex.ru>
---
 include/linux/memcontrol.h |   3 +
 include/linux/mmzone.h     |  11 +
 include/linux/oom.h        |  16 +
 include/linux/swap.h       |   1 +
 mm/oom_kill.c              |   4 +-
 mm/rmap.c                  |   7 +
 mm/swap.c                  |   4 +-
 mm/vmscan.c                | 903 +++++++++++++++++++++++++++++++++++++
 8 files changed, 945 insertions(+), 4 deletions(-)
diff mbox series

Patch

diff --git a/include/linux/memcontrol.h b/include/linux/memcontrol.h
index 5e223cecb5c2..657d94344dfc 100644
--- a/include/linux/memcontrol.h
+++ b/include/linux/memcontrol.h
@@ -1346,10 +1346,13 @@  mem_cgroup_print_oom_meminfo(struct mem_cgroup *memcg)
 
 static inline void lock_page_memcg(struct page *page)
 {
+	/* to match page_memcg_rcu() */
+	rcu_read_lock();
 }
 
 static inline void unlock_page_memcg(struct page *page)
 {
+	rcu_read_unlock();
 }
 
 static inline void mem_cgroup_handle_over_high(void)
diff --git a/include/linux/mmzone.h b/include/linux/mmzone.h
index d6c2c3a4ba43..b6005e881862 100644
--- a/include/linux/mmzone.h
+++ b/include/linux/mmzone.h
@@ -295,6 +295,7 @@  enum lruvec_flags {
 };
 
 struct lruvec;
+struct page_vma_mapped_walk;
 
 #define LRU_GEN_MASK		((BIT(LRU_GEN_WIDTH) - 1) << LRU_GEN_PGOFF)
 #define LRU_USAGE_MASK		((BIT(LRU_USAGE_WIDTH) - 1) << LRU_USAGE_PGOFF)
@@ -369,6 +370,7 @@  struct lrugen {
 
 void lru_gen_init_lrugen(struct lruvec *lruvec);
 void lru_gen_set_state(bool enable, bool main, bool swap);
+void lru_gen_scan_around(struct page_vma_mapped_walk *pvmw);
 
 #else /* CONFIG_LRU_GEN */
 
@@ -380,6 +382,10 @@  static inline void lru_gen_set_state(bool enable, bool main, bool swap)
 {
 }
 
+static inline void lru_gen_scan_around(struct page_vma_mapped_walk *pvmw)
+{
+}
+
 #endif /* CONFIG_LRU_GEN */
 
 struct lruvec {
@@ -874,6 +880,8 @@  struct deferred_split {
 };
 #endif
 
+struct mm_walk_args;
+
 /*
  * On NUMA machines, each NUMA node would have a pg_data_t to describe
  * it's memory layout. On UMA machines there is a single pglist_data which
@@ -979,6 +987,9 @@  typedef struct pglist_data {
 
 	unsigned long		flags;
 
+#ifdef CONFIG_LRU_GEN
+	struct mm_walk_args	*mm_walk_args;
+#endif
 	ZONE_PADDING(_pad2_)
 
 	/* Per-node vmstats */
diff --git a/include/linux/oom.h b/include/linux/oom.h
index 2db9a1432511..c4c8c7e71099 100644
--- a/include/linux/oom.h
+++ b/include/linux/oom.h
@@ -57,6 +57,22 @@  struct oom_control {
 extern struct mutex oom_lock;
 extern struct mutex oom_adj_mutex;
 
+#ifdef CONFIG_MMU
+extern struct task_struct *oom_reaper_list;
+extern struct wait_queue_head oom_reaper_wait;
+
+static inline bool oom_reaping_in_progress(void)
+{
+	/* racy check to see if oom reaping could be in progress */
+	return READ_ONCE(oom_reaper_list) || !waitqueue_active(&oom_reaper_wait);
+}
+#else
+static inline bool oom_reaping_in_progress(void)
+{
+	return false;
+}
+#endif
+
 static inline void set_current_oom_origin(void)
 {
 	current->signal->oom_flag_origin = true;
diff --git a/include/linux/swap.h b/include/linux/swap.h
index 6f5a43251593..c838e67dfa3a 100644
--- a/include/linux/swap.h
+++ b/include/linux/swap.h
@@ -368,6 +368,7 @@  extern void lru_add_drain_all(void);
 extern void rotate_reclaimable_page(struct page *page);
 extern void deactivate_file_page(struct page *page);
 extern void deactivate_page(struct page *page);
+extern void activate_page(struct page *page);
 extern void mark_page_lazyfree(struct page *page);
 extern void swap_setup(void);
 
diff --git a/mm/oom_kill.c b/mm/oom_kill.c
index c729a4c4a1ac..eca484ee3a3d 100644
--- a/mm/oom_kill.c
+++ b/mm/oom_kill.c
@@ -507,8 +507,8 @@  bool process_shares_mm(struct task_struct *p, struct mm_struct *mm)
  * victim (if that is possible) to help the OOM killer to move on.
  */
 static struct task_struct *oom_reaper_th;
-static DECLARE_WAIT_QUEUE_HEAD(oom_reaper_wait);
-static struct task_struct *oom_reaper_list;
+DECLARE_WAIT_QUEUE_HEAD(oom_reaper_wait);
+struct task_struct *oom_reaper_list;
 static DEFINE_SPINLOCK(oom_reaper_lock);
 
 bool __oom_reap_task_mm(struct mm_struct *mm)
diff --git a/mm/rmap.c b/mm/rmap.c
index b9eb5c12f3fe..f4963d60ff68 100644
--- a/mm/rmap.c
+++ b/mm/rmap.c
@@ -72,6 +72,7 @@ 
 #include <linux/page_idle.h>
 #include <linux/memremap.h>
 #include <linux/userfaultfd_k.h>
+#include <linux/mm_inline.h>
 
 #include <asm/tlbflush.h>
 
@@ -789,6 +790,12 @@  static bool page_referenced_one(struct page *page, struct vm_area_struct *vma,
 		}
 
 		if (pvmw.pte) {
+			/* the multigenerational lru exploits the spatial locality */
+			if (lru_gen_enabled() && pte_young(*pvmw.pte) &&
+			    !(vma->vm_flags & VM_SEQ_READ)) {
+				lru_gen_scan_around(&pvmw);
+				referenced++;
+			}
 			if (ptep_clear_flush_young_notify(vma, address,
 						pvmw.pte)) {
 				/*
diff --git a/mm/swap.c b/mm/swap.c
index 0d3fb2ee3fd6..0315cfa9fa41 100644
--- a/mm/swap.c
+++ b/mm/swap.c
@@ -347,7 +347,7 @@  static bool need_activate_page_drain(int cpu)
 	return pagevec_count(&per_cpu(lru_pvecs.activate_page, cpu)) != 0;
 }
 
-static void activate_page(struct page *page)
+void activate_page(struct page *page)
 {
 	page = compound_head(page);
 	if (PageLRU(page) && !PageActive(page) && !PageUnevictable(page)) {
@@ -367,7 +367,7 @@  static inline void activate_page_drain(int cpu)
 {
 }
 
-static void activate_page(struct page *page)
+void activate_page(struct page *page)
 {
 	struct lruvec *lruvec;
 
diff --git a/mm/vmscan.c b/mm/vmscan.c
index 15eadf2a135e..757ba4f415cc 100644
--- a/mm/vmscan.c
+++ b/mm/vmscan.c
@@ -50,6 +50,8 @@ 
 #include <linux/dax.h>
 #include <linux/psi.h>
 #include <linux/memory.h>
+#include <linux/pagewalk.h>
+#include <linux/shmem_fs.h>
 
 #include <asm/tlbflush.h>
 #include <asm/div64.h>
@@ -3208,6 +3210,883 @@  static bool get_next_mm(struct mm_walk_args *args, struct mm_struct **iter)
 	return last;
 }
 
+/******************************************************************************
+ *                          the aging
+ ******************************************************************************/
+
+static int page_update_gen(struct page *page, int gen)
+{
+	unsigned long old_flags, new_flags;
+
+	VM_BUG_ON(gen >= MAX_NR_GENS);
+
+	do {
+		new_flags = old_flags = READ_ONCE(page->flags);
+
+		if (!(new_flags & LRU_GEN_MASK)) {
+			new_flags |= BIT(PG_referenced);
+			continue;
+		}
+
+		new_flags &= ~(LRU_GEN_MASK | LRU_USAGE_MASK | LRU_TIER_FLAGS);
+		new_flags |= (gen + 1UL) << LRU_GEN_PGOFF;
+	} while (new_flags != old_flags &&
+		 cmpxchg(&page->flags, old_flags, new_flags) != old_flags);
+
+	return ((old_flags & LRU_GEN_MASK) >> LRU_GEN_PGOFF) - 1;
+}
+
+static void page_inc_gen(struct page *page, struct lruvec *lruvec, bool reclaiming)
+{
+	int old_gen, new_gen;
+	unsigned long old_flags, new_flags;
+	int type = page_is_file_lru(page);
+	int zone = page_zonenum(page);
+	struct lrugen *lrugen = &lruvec->evictable;
+
+	old_gen = lru_gen_from_seq(lrugen->min_seq[type]);
+
+	do {
+		new_flags = old_flags = READ_ONCE(page->flags);
+		VM_BUG_ON_PAGE(!(new_flags & LRU_GEN_MASK), page);
+
+		new_gen = ((new_flags & LRU_GEN_MASK) >> LRU_GEN_PGOFF) - 1;
+		/* page_update_gen() has updated this page? */
+		if (new_gen >= 0 && new_gen != old_gen) {
+			list_move(&page->lru, &lrugen->lists[new_gen][type][zone]);
+			return;
+		}
+
+		new_gen = (old_gen + 1) % MAX_NR_GENS;
+
+		new_flags &= ~(LRU_GEN_MASK | LRU_USAGE_MASK | LRU_TIER_FLAGS);
+		new_flags |= (new_gen + 1UL) << LRU_GEN_PGOFF;
+		/* for rotate_reclaimable_page() */
+		if (reclaiming)
+			new_flags |= BIT(PG_reclaim);
+	} while (cmpxchg(&page->flags, old_flags, new_flags) != old_flags);
+
+	lru_gen_update_size(page, lruvec, old_gen, new_gen);
+	if (reclaiming)
+		list_move(&page->lru, &lrugen->lists[new_gen][type][zone]);
+	else
+		list_move_tail(&page->lru, &lrugen->lists[new_gen][type][zone]);
+}
+
+static void update_batch_size(struct page *page, int old_gen, int new_gen,
+			      struct mm_walk_args *args)
+{
+	int type = page_is_file_lru(page);
+	int zone = page_zonenum(page);
+	int delta = thp_nr_pages(page);
+
+	VM_BUG_ON(old_gen >= MAX_NR_GENS);
+	VM_BUG_ON(new_gen >= MAX_NR_GENS);
+
+	args->batch_size++;
+
+	args->nr_pages[old_gen][type][zone] -= delta;
+	args->nr_pages[new_gen][type][zone] += delta;
+}
+
+static void reset_batch_size(struct lruvec *lruvec, struct mm_walk_args *args)
+{
+	int gen, type, zone;
+	struct lrugen *lrugen = &lruvec->evictable;
+
+	if (!args->batch_size)
+		return;
+
+	args->batch_size = 0;
+
+	spin_lock_irq(&lruvec->lru_lock);
+
+	for_each_gen_type_zone(gen, type, zone) {
+		enum lru_list lru = type * LRU_FILE;
+		int total = args->nr_pages[gen][type][zone];
+
+		if (!total)
+			continue;
+
+		args->nr_pages[gen][type][zone] = 0;
+		WRITE_ONCE(lrugen->sizes[gen][type][zone],
+			   lrugen->sizes[gen][type][zone] + total);
+
+		if (lru_gen_is_active(lruvec, gen))
+			lru += LRU_ACTIVE;
+		update_lru_size(lruvec, lru, zone, total);
+	}
+
+	spin_unlock_irq(&lruvec->lru_lock);
+}
+
+static int should_skip_vma(unsigned long start, unsigned long end, struct mm_walk *walk)
+{
+	struct address_space *mapping;
+	struct vm_area_struct *vma = walk->vma;
+	struct mm_walk_args *args = walk->private;
+
+	if (!vma_is_accessible(vma) || is_vm_hugetlb_page(vma) ||
+	    (vma->vm_flags & (VM_LOCKED | VM_SPECIAL | VM_SEQ_READ)))
+		return true;
+
+	if (vma_is_anonymous(vma))
+		return !args->swappiness;
+
+	if (WARN_ON_ONCE(!vma->vm_file || !vma->vm_file->f_mapping))
+		return true;
+
+	mapping = vma->vm_file->f_mapping;
+	if (!mapping->a_ops->writepage)
+		return true;
+
+	return (shmem_mapping(mapping) && !args->swappiness) || mapping_unevictable(mapping);
+}
+
+/*
+ * Some userspace memory allocators create many single-page VMAs. So instead of
+ * returning back to the PGD table for each of such VMAs, we finish at least an
+ * entire PMD table and therefore avoid many zigzags.
+ */
+static bool get_next_vma(struct mm_walk *walk, unsigned long mask, unsigned long size,
+			 unsigned long *start, unsigned long *end)
+{
+	unsigned long next = round_up(*end, size);
+	struct mm_walk_args *args = walk->private;
+
+	VM_BUG_ON(mask & size);
+	VM_BUG_ON(*start >= *end);
+	VM_BUG_ON((next & mask) != (*start & mask));
+
+	while (walk->vma) {
+		if (next >= walk->vma->vm_end) {
+			walk->vma = walk->vma->vm_next;
+			continue;
+		}
+
+		if ((next & mask) != (walk->vma->vm_start & mask))
+			return false;
+
+		if (should_skip_vma(walk->vma->vm_start, walk->vma->vm_end, walk)) {
+			walk->vma = walk->vma->vm_next;
+			continue;
+		}
+
+		*start = max(next, walk->vma->vm_start);
+		next = (next | ~mask) + 1;
+		/* rounded-up boundaries can wrap to 0 */
+		*end = next && next < walk->vma->vm_end ? next : walk->vma->vm_end;
+
+		args->mm_stats[MM_VMA_INTERVAL]++;
+
+		return true;
+	}
+
+	return false;
+}
+
+static bool walk_pte_range(pmd_t *pmd, unsigned long start, unsigned long end,
+			   struct mm_walk *walk)
+{
+	int i;
+	pte_t *pte;
+	spinlock_t *ptl;
+	unsigned long addr;
+	int remote = 0;
+	struct mm_walk_args *args = walk->private;
+	int old_gen, new_gen = lru_gen_from_seq(args->max_seq);
+
+	VM_BUG_ON(pmd_leaf(*pmd));
+
+	pte = pte_offset_map_lock(walk->mm, pmd, start & PMD_MASK, &ptl);
+	arch_enter_lazy_mmu_mode();
+restart:
+	for (i = pte_index(start), addr = start; addr != end; i++, addr += PAGE_SIZE) {
+		struct page *page;
+		unsigned long pfn = pte_pfn(pte[i]);
+
+		if (!pte_present(pte[i]) || is_zero_pfn(pfn)) {
+			args->mm_stats[MM_LEAF_HOLE]++;
+			continue;
+		}
+
+		if (WARN_ON_ONCE(pte_devmap(pte[i]) || pte_special(pte[i])))
+			continue;
+
+		if (!pte_young(pte[i])) {
+			args->mm_stats[MM_LEAF_OLD]++;
+			continue;
+		}
+
+		VM_BUG_ON(!pfn_valid(pfn));
+		if (pfn < args->start_pfn || pfn >= args->end_pfn) {
+			args->mm_stats[MM_LEAF_OTHER_NODE]++;
+			remote++;
+			continue;
+		}
+
+		page = compound_head(pfn_to_page(pfn));
+		if (page_to_nid(page) != args->node_id) {
+			args->mm_stats[MM_LEAF_OTHER_NODE]++;
+			remote++;
+			continue;
+		}
+
+		if (page_memcg_rcu(page) != args->memcg) {
+			args->mm_stats[MM_LEAF_OTHER_MEMCG]++;
+			continue;
+		}
+
+		VM_BUG_ON(addr < walk->vma->vm_start || addr >= walk->vma->vm_end);
+		if (!ptep_test_and_clear_young(walk->vma, addr, pte + i))
+			continue;
+
+		if (pte_dirty(pte[i]) && !PageDirty(page) &&
+		    !(PageAnon(page) && PageSwapBacked(page) && !PageSwapCache(page))) {
+			set_page_dirty(page);
+			args->mm_stats[MM_LEAF_DIRTY]++;
+		}
+
+		old_gen = page_update_gen(page, new_gen);
+		if (old_gen >= 0 && old_gen != new_gen)
+			update_batch_size(page, old_gen, new_gen, args);
+		args->mm_stats[MM_LEAF_YOUNG]++;
+	}
+
+	if (i < PTRS_PER_PTE && get_next_vma(walk, PMD_MASK, PAGE_SIZE, &start, &end))
+		goto restart;
+
+	arch_leave_lazy_mmu_mode();
+	pte_unmap_unlock(pte, ptl);
+
+	return IS_ENABLED(CONFIG_ARCH_HAS_NONLEAF_PMD_YOUNG) && !remote;
+}
+
+/*
+ * We scan PMD entries in two passes. The first pass reaches to PTE tables and
+ * doesn't take the PMD lock. The second pass clears the accessed bit on PMD
+ * entries and needs to take the PMD lock.
+ */
+#if defined(CONFIG_TRANSPARENT_HUGEPAGE) || defined(CONFIG_ARCH_HAS_NONLEAF_PMD_YOUNG)
+static void walk_pmd_range_locked(pud_t *pud, unsigned long start,
+				  struct vm_area_struct *vma, struct mm_walk *walk)
+{
+	int i;
+	pmd_t *pmd;
+	spinlock_t *ptl;
+	struct mm_walk_args *args = walk->private;
+	int old_gen, new_gen = lru_gen_from_seq(args->max_seq);
+
+	VM_BUG_ON(pud_leaf(*pud));
+
+	start &= PUD_MASK;
+	pmd = pmd_offset(pud, start);
+	ptl = pmd_lock(walk->mm, pmd);
+	arch_enter_lazy_mmu_mode();
+
+	for_each_set_bit(i, args->bitmap, PTRS_PER_PMD) {
+		struct page *page;
+		unsigned long pfn = pmd_pfn(pmd[i]);
+		unsigned long addr = start + i * PMD_SIZE;
+
+		if (!pmd_present(pmd[i]) || is_huge_zero_pmd(pmd[i])) {
+			args->mm_stats[MM_LEAF_HOLE]++;
+			continue;
+		}
+
+		if (WARN_ON_ONCE(pmd_devmap(pmd[i])))
+			continue;
+
+		if (!pmd_young(pmd[i])) {
+			args->mm_stats[MM_LEAF_OLD]++;
+			continue;
+		}
+
+		if (!pmd_trans_huge(pmd[i])) {
+			if (IS_ENABLED(CONFIG_ARCH_HAS_NONLEAF_PMD_YOUNG) &&
+			    pmdp_test_and_clear_young(vma, addr, pmd + i))
+				args->mm_stats[MM_NONLEAF_YOUNG]++;
+			continue;
+		}
+
+		VM_BUG_ON(!pfn_valid(pfn));
+		if (pfn < args->start_pfn || pfn >= args->end_pfn) {
+			args->mm_stats[MM_LEAF_OTHER_NODE]++;
+			continue;
+		}
+
+		page = pfn_to_page(pfn);
+		VM_BUG_ON_PAGE(PageTail(page), page);
+		if (page_to_nid(page) != args->node_id) {
+			args->mm_stats[MM_LEAF_OTHER_NODE]++;
+			continue;
+		}
+
+		if (page_memcg_rcu(page) != args->memcg) {
+			args->mm_stats[MM_LEAF_OTHER_MEMCG]++;
+			continue;
+		}
+
+		VM_BUG_ON(addr < vma->vm_start || addr >= vma->vm_end);
+		if (!pmdp_test_and_clear_young(vma, addr, pmd + i))
+			continue;
+
+		if (pmd_dirty(pmd[i]) && !PageDirty(page) &&
+		    !(PageAnon(page) && PageSwapBacked(page) && !PageSwapCache(page))) {
+			set_page_dirty(page);
+			args->mm_stats[MM_LEAF_DIRTY]++;
+		}
+
+		old_gen = page_update_gen(page, new_gen);
+		if (old_gen >= 0 && old_gen != new_gen)
+			update_batch_size(page, old_gen, new_gen, args);
+		args->mm_stats[MM_LEAF_YOUNG]++;
+	}
+
+	arch_leave_lazy_mmu_mode();
+	spin_unlock(ptl);
+
+	bitmap_zero(args->bitmap, PTRS_PER_PMD);
+}
+#else
+static void walk_pmd_range_locked(pud_t *pud, unsigned long start,
+				  struct vm_area_struct *vma, struct mm_walk *walk)
+{
+}
+#endif
+
+static void walk_pmd_range(pud_t *pud, unsigned long start, unsigned long end,
+			   struct mm_walk *walk)
+{
+	int i;
+	pmd_t *pmd;
+	unsigned long next;
+	unsigned long addr;
+	struct vm_area_struct *vma;
+	int leaf = 0;
+	int nonleaf = 0;
+	struct mm_walk_args *args = walk->private;
+
+	VM_BUG_ON(pud_leaf(*pud));
+
+	pmd = pmd_offset(pud, start & PUD_MASK);
+restart:
+	vma = walk->vma;
+	for (i = pmd_index(start), addr = start; addr != end; i++, addr = next) {
+		pmd_t val = pmd_read_atomic(pmd + i);
+
+		/* for pmd_read_atomic() */
+		barrier();
+
+		next = pmd_addr_end(addr, end);
+
+		if (!pmd_present(val)) {
+			args->mm_stats[MM_LEAF_HOLE]++;
+			continue;
+		}
+
+#ifdef CONFIG_TRANSPARENT_HUGEPAGE
+		if (pmd_trans_huge(val)) {
+			unsigned long pfn = pmd_pfn(val);
+
+			if (is_huge_zero_pmd(val)) {
+				args->mm_stats[MM_LEAF_HOLE]++;
+				continue;
+			}
+
+			if (!pmd_young(val)) {
+				args->mm_stats[MM_LEAF_OLD]++;
+				continue;
+			}
+
+			if (pfn < args->start_pfn || pfn >= args->end_pfn) {
+				args->mm_stats[MM_LEAF_OTHER_NODE]++;
+				continue;
+			}
+
+			__set_bit(i, args->bitmap);
+			leaf++;
+			continue;
+		}
+#endif
+
+#ifdef CONFIG_ARCH_HAS_NONLEAF_PMD_YOUNG
+		if (!pmd_young(val)) {
+			args->mm_stats[MM_NONLEAF_OLD]++;
+			continue;
+		}
+#endif
+		if (walk_pte_range(&val, addr, next, walk)) {
+			__set_bit(i, args->bitmap);
+			nonleaf++;
+		}
+	}
+
+	if (leaf) {
+		walk_pmd_range_locked(pud, start, vma, walk);
+		leaf = nonleaf = 0;
+	}
+
+	if (i < PTRS_PER_PMD && get_next_vma(walk, PUD_MASK, PMD_SIZE, &start, &end))
+		goto restart;
+
+	if (nonleaf)
+		walk_pmd_range_locked(pud, start, vma, walk);
+}
+
+static int walk_pud_range(p4d_t *p4d, unsigned long start, unsigned long end,
+			  struct mm_walk *walk)
+{
+	int i;
+	pud_t *pud;
+	unsigned long addr;
+	unsigned long next;
+	struct mm_walk_args *args = walk->private;
+
+	VM_BUG_ON(p4d_leaf(*p4d));
+
+	pud = pud_offset(p4d, start & P4D_MASK);
+restart:
+	for (i = pud_index(start), addr = start; addr != end; i++, addr = next) {
+		pud_t val = READ_ONCE(pud[i]);
+
+		next = pud_addr_end(addr, end);
+
+		if (!pud_present(val) || WARN_ON_ONCE(pud_leaf(val)))
+			continue;
+
+		walk_pmd_range(&val, addr, next, walk);
+
+		if (args->batch_size >= MAX_BATCH_SIZE) {
+			end = (addr | ~PUD_MASK) + 1;
+			goto done;
+		}
+	}
+
+	if (i < PTRS_PER_PUD && get_next_vma(walk, P4D_MASK, PUD_SIZE, &start, &end))
+		goto restart;
+
+	end = round_up(end, P4D_SIZE);
+done:
+	/* rounded-up boundaries can wrap to 0 */
+	args->next_addr = end && walk->vma ? max(end, walk->vma->vm_start) : 0;
+
+	return -EAGAIN;
+}
+
+static void walk_mm(struct mm_walk_args *args, struct mm_struct *mm)
+{
+	static const struct mm_walk_ops mm_walk_ops = {
+		.test_walk = should_skip_vma,
+		.p4d_entry = walk_pud_range,
+	};
+
+	int err;
+	struct mem_cgroup *memcg = args->memcg;
+	struct lruvec *lruvec = mem_cgroup_lruvec(memcg, NODE_DATA(args->node_id));
+
+	args->next_addr = FIRST_USER_ADDRESS;
+
+	do {
+		unsigned long start = args->next_addr;
+		unsigned long end = mm->highest_vm_end;
+
+		err = -EBUSY;
+
+		rcu_read_lock();
+#ifdef CONFIG_MEMCG
+		if (memcg && atomic_read(&memcg->moving_account)) {
+			args->mm_stats[MM_LOCK_CONTENTION]++;
+			goto contended;
+		}
+#endif
+		if (!mmap_read_trylock(mm)) {
+			args->mm_stats[MM_LOCK_CONTENTION]++;
+			goto contended;
+		}
+
+		err = walk_page_range(mm, start, end, &mm_walk_ops, args);
+
+		mmap_read_unlock(mm);
+
+		reset_batch_size(lruvec, args);
+contended:
+		rcu_read_unlock();
+
+		cond_resched();
+	} while (err == -EAGAIN && args->next_addr &&
+		 !mm_is_oom_victim(mm) && !mm_is_migrated(mm, memcg));
+}
+
+static struct mm_walk_args *alloc_mm_walk_args(int nid)
+{
+	struct pglist_data *pgdat;
+	int size = sizeof(struct mm_walk_args);
+
+	if (IS_ENABLED(CONFIG_TRANSPARENT_HUGEPAGE) ||
+	    IS_ENABLED(CONFIG_ARCH_HAS_NONLEAF_PMD_YOUNG))
+		size += sizeof(unsigned long) * BITS_TO_LONGS(PTRS_PER_PMD);
+
+	if (!current_is_kswapd())
+		return kvzalloc_node(size, GFP_KERNEL, nid);
+
+	VM_BUG_ON(nid == NUMA_NO_NODE);
+
+	pgdat = NODE_DATA(nid);
+	if (!pgdat->mm_walk_args)
+		pgdat->mm_walk_args = kvzalloc_node(size, GFP_KERNEL, nid);
+
+	return pgdat->mm_walk_args;
+}
+
+static void free_mm_walk_args(struct mm_walk_args *args)
+{
+	if (!current_is_kswapd())
+		kvfree(args);
+}
+
+static bool inc_min_seq(struct lruvec *lruvec, int type)
+{
+	int gen, zone;
+	int remaining = MAX_BATCH_SIZE;
+	struct lrugen *lrugen = &lruvec->evictable;
+
+	VM_BUG_ON(!seq_is_valid(lruvec));
+
+	if (get_nr_gens(lruvec, type) != MAX_NR_GENS)
+		return true;
+
+	gen = lru_gen_from_seq(lrugen->min_seq[type]);
+
+	for (zone = 0; zone < MAX_NR_ZONES; zone++) {
+		struct list_head *head = &lrugen->lists[gen][type][zone];
+
+		while (!list_empty(head)) {
+			struct page *page = lru_to_page(head);
+
+			VM_BUG_ON_PAGE(PageTail(page), page);
+			VM_BUG_ON_PAGE(PageUnevictable(page), page);
+			VM_BUG_ON_PAGE(PageActive(page), page);
+			VM_BUG_ON_PAGE(page_is_file_lru(page) != type, page);
+			VM_BUG_ON_PAGE(page_zonenum(page) != zone, page);
+
+			prefetchw_prev_lru_page(page, head, flags);
+
+			page_inc_gen(page, lruvec, false);
+
+			if (!--remaining)
+				return false;
+		}
+
+		VM_BUG_ON(lrugen->sizes[gen][type][zone]);
+	}
+
+	reset_controller_pos(lruvec, gen, type);
+	WRITE_ONCE(lrugen->min_seq[type], lrugen->min_seq[type] + 1);
+
+	return true;
+}
+
+static bool try_to_inc_min_seq(struct lruvec *lruvec, int type)
+{
+	int gen, zone;
+	bool success = false;
+	struct lrugen *lrugen = &lruvec->evictable;
+
+	VM_BUG_ON(!seq_is_valid(lruvec));
+
+	while (get_nr_gens(lruvec, type) > MIN_NR_GENS) {
+		gen = lru_gen_from_seq(lrugen->min_seq[type]);
+
+		for (zone = 0; zone < MAX_NR_ZONES; zone++) {
+			if (!list_empty(&lrugen->lists[gen][type][zone]))
+				return success;
+		}
+
+		reset_controller_pos(lruvec, gen, type);
+		WRITE_ONCE(lrugen->min_seq[type], lrugen->min_seq[type] + 1);
+
+		success = true;
+	}
+
+	return success;
+}
+
+static void inc_max_seq(struct lruvec *lruvec, unsigned long max_seq)
+{
+	int gen, type, zone;
+	struct lrugen *lrugen = &lruvec->evictable;
+
+	spin_lock_irq(&lruvec->lru_lock);
+
+	VM_BUG_ON(!seq_is_valid(lruvec));
+
+	if (max_seq != lrugen->max_seq)
+		goto unlock;
+
+	for (type = 0; type < ANON_AND_FILE; type++) {
+		if (try_to_inc_min_seq(lruvec, type))
+			continue;
+
+		while (!inc_min_seq(lruvec, type)) {
+			spin_unlock_irq(&lruvec->lru_lock);
+			cond_resched();
+			spin_lock_irq(&lruvec->lru_lock);
+		}
+	}
+
+	gen = lru_gen_from_seq(lrugen->max_seq - 1);
+	for_each_type_zone(type, zone) {
+		enum lru_list lru = type * LRU_FILE;
+		long total = lrugen->sizes[gen][type][zone];
+
+		if (!total)
+			continue;
+
+		WARN_ON_ONCE(total != (int)total);
+
+		update_lru_size(lruvec, lru, zone, total);
+		update_lru_size(lruvec, lru + LRU_ACTIVE, zone, -total);
+	}
+
+	gen = lru_gen_from_seq(lrugen->max_seq + 1);
+	for_each_type_zone(type, zone) {
+		VM_BUG_ON(lrugen->sizes[gen][type][zone]);
+		VM_BUG_ON(!list_empty(&lrugen->lists[gen][type][zone]));
+	}
+
+	for (type = 0; type < ANON_AND_FILE; type++)
+		reset_controller_pos(lruvec, gen, type);
+
+	WRITE_ONCE(lrugen->timestamps[gen], jiffies);
+	/* make sure all preceding modifications appear first */
+	smp_store_release(&lrugen->max_seq, lrugen->max_seq + 1);
+unlock:
+	spin_unlock_irq(&lruvec->lru_lock);
+}
+
+/* Main function used by the foreground, the background and the user-triggered aging. */
+static bool try_to_inc_max_seq(struct lruvec *lruvec, unsigned long max_seq,
+			       struct scan_control *sc, int swappiness)
+{
+	bool last;
+	struct mm_walk_args *args;
+	struct mm_struct *mm = NULL;
+	struct lrugen *lrugen = &lruvec->evictable;
+	struct mem_cgroup *memcg = lruvec_memcg(lruvec);
+	struct lru_gen_mm_list *mm_list = get_mm_list(memcg);
+	struct pglist_data *pgdat = lruvec_pgdat(lruvec);
+	int nid = pgdat->node_id;
+
+	VM_BUG_ON(max_seq > READ_ONCE(lrugen->max_seq));
+
+	/*
+	 * If we are not from run_aging() and clearing the accessed bit may
+	 * trigger page faults, then don't proceed to clearing all accessed
+	 * PTEs. Instead, fallback to lru_gen_scan_around(), which only clears a
+	 * handful of accessed PTEs. This is less efficient but causes fewer
+	 * page faults on CPUs that don't have the capability.
+	 */
+	if ((current->flags & PF_MEMALLOC) && !arch_has_hw_pte_young()) {
+		inc_max_seq(lruvec, max_seq);
+		return true;
+	}
+
+	args = alloc_mm_walk_args(nid);
+	if (!args)
+		return false;
+
+	args->memcg = memcg;
+	args->max_seq = max_seq;
+	args->start_pfn = pgdat->node_start_pfn;
+	args->end_pfn = pgdat_end_pfn(pgdat);
+	args->node_id = nid;
+	args->swappiness = swappiness;
+
+	do {
+		last = get_next_mm(args, &mm);
+		if (mm)
+			walk_mm(args, mm);
+
+		cond_resched();
+	} while (mm);
+
+	free_mm_walk_args(args);
+
+	if (!last) {
+		/* don't wait unless we may have trouble reclaiming */
+		if (!current_is_kswapd() && sc->priority < DEF_PRIORITY - 2)
+			wait_event_killable(mm_list->nodes[nid].wait,
+					    max_seq < READ_ONCE(lrugen->max_seq));
+
+		return max_seq < READ_ONCE(lrugen->max_seq);
+	}
+
+	VM_BUG_ON(max_seq != READ_ONCE(lrugen->max_seq));
+
+	inc_max_seq(lruvec, max_seq);
+	/* either we see any waiters or they will see updated max_seq */
+	if (wq_has_sleeper(&mm_list->nodes[nid].wait))
+		wake_up_all(&mm_list->nodes[nid].wait);
+
+	wakeup_flusher_threads(WB_REASON_VMSCAN);
+
+	return true;
+}
+
+/* Protect the working set accessed within the last N milliseconds. */
+static unsigned long lru_gen_min_ttl;
+
+static void lru_gen_age_node(struct pglist_data *pgdat, struct scan_control *sc)
+{
+	struct mem_cgroup *memcg;
+
+	VM_BUG_ON(!current_is_kswapd());
+
+	if (sc->file_is_tiny && mutex_trylock(&oom_lock)) {
+		struct oom_control oc = {
+			.gfp_mask = sc->gfp_mask,
+			.order = sc->order,
+		};
+
+		/* to avoid overkilling */
+		if (!oom_reaping_in_progress())
+			out_of_memory(&oc);
+
+		mutex_unlock(&oom_lock);
+	}
+
+	if (READ_ONCE(lru_gen_min_ttl))
+		sc->file_is_tiny = 1;
+
+	if (!mem_cgroup_disabled() && !sc->force_deactivate) {
+		sc->force_deactivate = 1;
+		return;
+	}
+
+	sc->force_deactivate = 0;
+
+	memcg = mem_cgroup_iter(NULL, NULL, NULL);
+	do {
+		struct lruvec *lruvec = mem_cgroup_lruvec(memcg, pgdat);
+		int swappiness = get_swappiness(memcg);
+		DEFINE_MAX_SEQ(lruvec);
+		DEFINE_MIN_SEQ(lruvec);
+
+		if (get_lo_wmark(max_seq, min_seq, swappiness) == MIN_NR_GENS)
+			try_to_inc_max_seq(lruvec, max_seq, sc, swappiness);
+
+		cond_resched();
+	} while ((memcg = mem_cgroup_iter(NULL, memcg, NULL)));
+}
+
+#define NR_TO_SCAN	(SWAP_CLUSTER_MAX * 2)
+#define SIZE_TO_SCAN	(NR_TO_SCAN * PAGE_SIZE)
+
+/* Scan the vicinity of an accessed PTE when shrink_page_list() uses the rmap. */
+void lru_gen_scan_around(struct page_vma_mapped_walk *pvmw)
+{
+	int i;
+	pte_t *pte;
+	struct page *page;
+	int old_gen, new_gen;
+	unsigned long start;
+	unsigned long end;
+	unsigned long addr;
+	struct mem_cgroup *memcg = page_memcg(pvmw->page);
+	struct pglist_data *pgdat = page_pgdat(pvmw->page);
+	struct lruvec *lruvec = mem_cgroup_lruvec(memcg, pgdat);
+	unsigned long bitmap[BITS_TO_LONGS(NR_TO_SCAN)] = {};
+
+	lockdep_assert_held(pvmw->ptl);
+	VM_BUG_ON_PAGE(PageLRU(pvmw->page), pvmw->page);
+
+	start = max(pvmw->address & PMD_MASK, pvmw->vma->vm_start);
+	end = pmd_addr_end(pvmw->address, pvmw->vma->vm_end);
+
+	if (end - start > SIZE_TO_SCAN) {
+		if (pvmw->address - start < SIZE_TO_SCAN / 2)
+			end = start + SIZE_TO_SCAN;
+		else if (end - pvmw->address < SIZE_TO_SCAN / 2)
+			start = end - SIZE_TO_SCAN;
+		else {
+			start = pvmw->address - SIZE_TO_SCAN / 2;
+			end = pvmw->address + SIZE_TO_SCAN / 2;
+		}
+	}
+
+	pte = pvmw->pte - (pvmw->address - start) / PAGE_SIZE;
+	new_gen = lru_gen_from_seq(READ_ONCE(lruvec->evictable.max_seq));
+
+	rcu_read_lock();
+	arch_enter_lazy_mmu_mode();
+
+	for (i = 0, addr = start; addr != end; i++, addr += PAGE_SIZE) {
+		unsigned long pfn = pte_pfn(pte[i]);
+
+		if (!pte_present(pte[i]) || is_zero_pfn(pfn))
+			continue;
+
+		if (WARN_ON_ONCE(pte_devmap(pte[i]) || pte_special(pte[i])))
+			continue;
+
+		if (!pte_young(pte[i]))
+			continue;
+
+		VM_BUG_ON(!pfn_valid(pfn));
+		if (pfn < pgdat->node_start_pfn || pfn >= pgdat_end_pfn(pgdat))
+			continue;
+
+		page = compound_head(pfn_to_page(pfn));
+		if (page_to_nid(page) != pgdat->node_id)
+			continue;
+
+		if (page_memcg_rcu(page) != memcg)
+			continue;
+
+		VM_BUG_ON(addr < pvmw->vma->vm_start || addr >= pvmw->vma->vm_end);
+		if (!ptep_test_and_clear_young(pvmw->vma, addr, pte + i))
+			continue;
+
+		old_gen = page_lru_gen(page);
+		if (old_gen < 0)
+			SetPageReferenced(page);
+		else if (old_gen != new_gen)
+			__set_bit(i, bitmap);
+
+		if (pte_dirty(pte[i]) && !PageDirty(page) &&
+		    !(PageAnon(page) && PageSwapBacked(page) && !PageSwapCache(page)))
+			set_page_dirty(page);
+	}
+
+	arch_leave_lazy_mmu_mode();
+	rcu_read_unlock();
+
+	if (bitmap_weight(bitmap, NR_TO_SCAN) < PAGEVEC_SIZE) {
+		for_each_set_bit(i, bitmap, NR_TO_SCAN)
+			activate_page(pte_page(pte[i]));
+		return;
+	}
+
+	lock_page_memcg(pvmw->page);
+	spin_lock_irq(&lruvec->lru_lock);
+
+	new_gen = lru_gen_from_seq(lruvec->evictable.max_seq);
+
+	for_each_set_bit(i, bitmap, NR_TO_SCAN) {
+		page = compound_head(pte_page(pte[i]));
+		if (page_memcg_rcu(page) != memcg)
+			continue;
+
+		old_gen = page_update_gen(page, new_gen);
+		if (old_gen >= 0 && old_gen != new_gen)
+			lru_gen_update_size(page, lruvec, old_gen, new_gen);
+	}
+
+	spin_unlock_irq(&lruvec->lru_lock);
+	unlock_page_memcg(pvmw->page);
+}
+
 /******************************************************************************
  *                          state change
  ******************************************************************************/
@@ -3392,9 +4271,18 @@  static int __meminit __maybe_unused mem_notifier(struct notifier_block *self,
 
 	pgdat = NODE_DATA(nid);
 
+	if (action == MEM_CANCEL_ONLINE || action == MEM_OFFLINE) {
+		free_mm_walk_args(pgdat->mm_walk_args);
+		pgdat->mm_walk_args = NULL;
+		return NOTIFY_DONE;
+	}
+
 	if (action != MEM_GOING_ONLINE)
 		return NOTIFY_DONE;
 
+	if (!WARN_ON_ONCE(pgdat->mm_walk_args))
+		pgdat->mm_walk_args = alloc_mm_walk_args(NUMA_NO_NODE);
+
 	mutex_lock(&lru_gen_state_mutex);
 	cgroup_lock();
 
@@ -3443,6 +4331,10 @@  static int __init init_lru_gen(void)
 	BUILD_BUG_ON(BIT(LRU_GEN_WIDTH) <= MAX_NR_GENS);
 	BUILD_BUG_ON(sizeof(MM_STAT_CODES) != NR_MM_STATS + 1);
 
+	VM_BUG_ON(PMD_SIZE / PAGE_SIZE != PTRS_PER_PTE);
+	VM_BUG_ON(PUD_SIZE / PMD_SIZE != PTRS_PER_PMD);
+	VM_BUG_ON(P4D_SIZE / PUD_SIZE != PTRS_PER_PUD);
+
 	if (mem_cgroup_disabled()) {
 		global_mm_list = alloc_mm_list();
 		if (!global_mm_list)
@@ -3460,6 +4352,12 @@  static int __init init_lru_gen(void)
  */
 arch_initcall(init_lru_gen);
 
+#else /* CONFIG_LRU_GEN */
+
+static void lru_gen_age_node(struct pglist_data *pgdat, struct scan_control *sc)
+{
+}
+
 #endif /* CONFIG_LRU_GEN */
 
 static void shrink_lruvec(struct lruvec *lruvec, struct scan_control *sc)
@@ -4313,6 +5211,11 @@  static void age_active_anon(struct pglist_data *pgdat,
 	struct mem_cgroup *memcg;
 	struct lruvec *lruvec;
 
+	if (lru_gen_enabled()) {
+		lru_gen_age_node(pgdat, sc);
+		return;
+	}
+
 	if (!total_swap_pages)
 		return;