[mm-unstable,v7,08/18] mm/khugepaged: record SCAN_PMD_MAPPED when scan_pmd() finds hugepage

Message ID 20220706235936.2197195-9-zokeefe@google.com
State New
Series mm: userspace hugepage collapse

Commit Message

Zach O'Keefe July 6, 2022, 11:59 p.m. UTC
When scanning an anon pmd to see if it's eligible for collapse, return
SCAN_PMD_MAPPED if the pmd already maps a hugepage.  Note that
SCAN_PMD_MAPPED is different from SCAN_PAGE_COMPOUND used in the
file-collapse path, since the latter might identify pte-mapped compound
pages.  This is required by MADV_COLLAPSE, which needs to know which
hugepage-aligned/sized regions are already pmd-mapped.

In order to determine if a pmd already maps a hugepage, refactor
mm_find_pmd():

Return mm_find_pmd() to its pre-commit f72e7dcdd252 ("mm: let mm_find_pmd
fix buggy race with THP fault") behavior.  ksm was the only caller that
explicitly wanted a pte-mapping pmd, so open code the pte-mapping logic
there (pmd_present() and pmd_trans_huge() checks).

Undo the change from commit f72e7dcdd252 ("mm: let mm_find_pmd fix buggy
race with THP fault") that open-coded the split_huge_pmd_address() pmd
lookup, and use mm_find_pmd() instead.

Signed-off-by: Zach O'Keefe <zokeefe@google.com>
---
 include/trace/events/huge_memory.h |  1 +
 mm/huge_memory.c                   | 18 +--------
 mm/internal.h                      |  2 +-
 mm/khugepaged.c                    | 60 ++++++++++++++++++++++++------
 mm/ksm.c                           | 10 +++++
 mm/rmap.c                          | 15 +++-----
 6 files changed, 67 insertions(+), 39 deletions(-)
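
To make the new return-code semantics concrete, here is a minimal sketch
(not part of this patch) of how a MADV_COLLAPSE-style caller might consume
the scan result.  The helper region_needs_collapse() is hypothetical;
find_pmd_or_thp_or_none() and the SCAN_* codes are the ones introduced by
the patch below.

static bool region_needs_collapse(struct mm_struct *mm, unsigned long addr,
				  pmd_t **pmd)
{
	switch (find_pmd_or_thp_or_none(mm, addr, pmd)) {
	case SCAN_SUCCEED:
		/* pte-mapped and usable: worth attempting collapse */
		return true;
	case SCAN_PMD_MAPPED:
		/* already maps a hugepage: nothing to do, not a failure */
		return false;
	default:
		/* SCAN_PMD_NULL etc.: no usable pmd here, skip the region */
		return false;
	}
}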

Comments

Yang Shi July 11, 2022, 9:03 p.m. UTC | #1
On Wed, Jul 6, 2022 at 5:06 PM Zach O'Keefe <zokeefe@google.com> wrote:
>
> When scanning an anon pmd to see if it's eligible for collapse, return
> SCAN_PMD_MAPPED if the pmd already maps a hugepage.  Note that
> SCAN_PMD_MAPPED is different from SCAN_PAGE_COMPOUND used in the
> file-collapse path, since the latter might identify pte-mapped compound
> pages.  This is required by MADV_COLLAPSE, which needs to know which
> hugepage-aligned/sized regions are already pmd-mapped.
>
> In order to determine if a pmd already maps a hugepage, refactor
> mm_find_pmd():
>
> Return mm_find_pmd() to its pre-commit f72e7dcdd252 ("mm: let mm_find_pmd
> fix buggy race with THP fault") behavior.  ksm was the only caller that
> explicitly wanted a pte-mapping pmd, so open code the pte-mapping logic
> there (pmd_present() and pmd_trans_huge() checks).
>
> Undo the change from commit f72e7dcdd252 ("mm: let mm_find_pmd fix buggy
> race with THP fault") that open-coded the split_huge_pmd_address() pmd
> lookup, and use mm_find_pmd() instead.
>
> Signed-off-by: Zach O'Keefe <zokeefe@google.com>

Reviewed-by: Yang Shi <shy828301@gmail.com>

Zach O'Keefe July 12, 2022, 4:50 p.m. UTC | #2
On Jul 11 14:03, Yang Shi wrote:
> On Wed, Jul 6, 2022 at 5:06 PM Zach O'Keefe <zokeefe@google.com> wrote:
> >
> > When scanning an anon pmd to see if it's eligible for collapse, return
> > SCAN_PMD_MAPPED if the pmd already maps a hugepage.  Note that
> > SCAN_PMD_MAPPED is different from SCAN_PAGE_COMPOUND used in the
> > file-collapse path, since the latter might identify pte-mapped compound
> > pages.  This is required by MADV_COLLAPSE, which needs to know which
> > hugepage-aligned/sized regions are already pmd-mapped.
> >
> > In order to determine if a pmd already maps a hugepage, refactor
> > mm_find_pmd():
> >
> > Return mm_find_pmd() to its pre-commit f72e7dcdd252 ("mm: let mm_find_pmd
> > fix buggy race with THP fault") behavior.  ksm was the only caller that
> > explicitly wanted a pte-mapping pmd, so open code the pte-mapping logic
> > there (pmd_present() and pmd_trans_huge() checks).
> >
> > Undo the change from commit f72e7dcdd252 ("mm: let mm_find_pmd fix buggy
> > race with THP fault") that open-coded the split_huge_pmd_address() pmd
> > lookup, and use mm_find_pmd() instead.
> >
> > Signed-off-by: Zach O'Keefe <zokeefe@google.com>
> 
> Reviewed-by: Yang Shi <shy828301@gmail.com>
> 

Thanks for taking the time to review!

Zach

Patch

diff --git a/include/trace/events/huge_memory.h b/include/trace/events/huge_memory.h
index d651f3437367..55392bf30a03 100644
--- a/include/trace/events/huge_memory.h
+++ b/include/trace/events/huge_memory.h
@@ -11,6 +11,7 @@ 
 	EM( SCAN_FAIL,			"failed")			\
 	EM( SCAN_SUCCEED,		"succeeded")			\
 	EM( SCAN_PMD_NULL,		"pmd_null")			\
+	EM( SCAN_PMD_MAPPED,		"page_pmd_mapped")		\
 	EM( SCAN_EXCEED_NONE_PTE,	"exceed_none_pte")		\
 	EM( SCAN_EXCEED_SWAP_PTE,	"exceed_swap_pte")		\
 	EM( SCAN_EXCEED_SHARED_PTE,	"exceed_shared_pte")		\
diff --git a/mm/huge_memory.c b/mm/huge_memory.c
index 4fbe43dc1568..fb76db6c703e 100644
--- a/mm/huge_memory.c
+++ b/mm/huge_memory.c
@@ -2363,25 +2363,11 @@  void __split_huge_pmd(struct vm_area_struct *vma, pmd_t *pmd,
 void split_huge_pmd_address(struct vm_area_struct *vma, unsigned long address,
 		bool freeze, struct folio *folio)
 {
-	pgd_t *pgd;
-	p4d_t *p4d;
-	pud_t *pud;
-	pmd_t *pmd;
+	pmd_t *pmd = mm_find_pmd(vma->vm_mm, address);
 
-	pgd = pgd_offset(vma->vm_mm, address);
-	if (!pgd_present(*pgd))
+	if (!pmd)
 		return;
 
-	p4d = p4d_offset(pgd, address);
-	if (!p4d_present(*p4d))
-		return;
-
-	pud = pud_offset(p4d, address);
-	if (!pud_present(*pud))
-		return;
-
-	pmd = pmd_offset(pud, address);
-
 	__split_huge_pmd(vma, pmd, address, freeze, folio);
 }
 
diff --git a/mm/internal.h b/mm/internal.h
index 6e14749ad1e5..ef8c23fb678f 100644
--- a/mm/internal.h
+++ b/mm/internal.h
@@ -188,7 +188,7 @@  extern void reclaim_throttle(pg_data_t *pgdat, enum vmscan_throttle_state reason
 /*
  * in mm/rmap.c:
  */
-extern pmd_t *mm_find_pmd(struct mm_struct *mm, unsigned long address);
+pmd_t *mm_find_pmd(struct mm_struct *mm, unsigned long address);
 
 /*
  * in mm/page_alloc.c
diff --git a/mm/khugepaged.c b/mm/khugepaged.c
index b0e20db3f805..c7a09cc9a0e8 100644
--- a/mm/khugepaged.c
+++ b/mm/khugepaged.c
@@ -28,6 +28,7 @@  enum scan_result {
 	SCAN_FAIL,
 	SCAN_SUCCEED,
 	SCAN_PMD_NULL,
+	SCAN_PMD_MAPPED,
 	SCAN_EXCEED_NONE_PTE,
 	SCAN_EXCEED_SWAP_PTE,
 	SCAN_EXCEED_SHARED_PTE,
@@ -871,6 +872,45 @@  static int hugepage_vma_revalidate(struct mm_struct *mm, unsigned long address,
 	return SCAN_SUCCEED;
 }
 
+static int find_pmd_or_thp_or_none(struct mm_struct *mm,
+				   unsigned long address,
+				   pmd_t **pmd)
+{
+	pmd_t pmde;
+
+	*pmd = mm_find_pmd(mm, address);
+	if (!*pmd)
+		return SCAN_PMD_NULL;
+
+	pmde = pmd_read_atomic(*pmd);
+
+#ifdef CONFIG_TRANSPARENT_HUGEPAGE
+	/* See comments in pmd_none_or_trans_huge_or_clear_bad() */
+	barrier();
+#endif
+	if (!pmd_present(pmde))
+		return SCAN_PMD_NULL;
+	if (pmd_trans_huge(pmde))
+		return SCAN_PMD_MAPPED;
+	if (pmd_bad(pmde))
+		return SCAN_PMD_NULL;
+	return SCAN_SUCCEED;
+}
+
+static int check_pmd_still_valid(struct mm_struct *mm,
+				 unsigned long address,
+				 pmd_t *pmd)
+{
+	pmd_t *new_pmd;
+	int result = find_pmd_or_thp_or_none(mm, address, &new_pmd);
+
+	if (result != SCAN_SUCCEED)
+		return result;
+	if (new_pmd != pmd)
+		return SCAN_FAIL;
+	return SCAN_SUCCEED;
+}
+
 /*
  * Bring missing pages in from swap, to complete THP collapse.
  * Only done if khugepaged_scan_pmd believes it is worthwhile.
@@ -982,9 +1022,8 @@  static int collapse_huge_page(struct mm_struct *mm, unsigned long address,
 		goto out_nolock;
 	}
 
-	pmd = mm_find_pmd(mm, address);
-	if (!pmd) {
-		result = SCAN_PMD_NULL;
+	result = find_pmd_or_thp_or_none(mm, address, &pmd);
+	if (result != SCAN_SUCCEED) {
 		mmap_read_unlock(mm);
 		goto out_nolock;
 	}
@@ -1012,7 +1051,8 @@  static int collapse_huge_page(struct mm_struct *mm, unsigned long address,
 	if (result != SCAN_SUCCEED)
 		goto out_up_write;
 	/* check if the pmd is still valid */
-	if (mm_find_pmd(mm, address) != pmd)
+	result = check_pmd_still_valid(mm, address, pmd);
+	if (result != SCAN_SUCCEED)
 		goto out_up_write;
 
 	anon_vma_lock_write(vma->anon_vma);
@@ -1115,11 +1155,9 @@  static int khugepaged_scan_pmd(struct mm_struct *mm, struct vm_area_struct *vma,
 
 	VM_BUG_ON(address & ~HPAGE_PMD_MASK);
 
-	pmd = mm_find_pmd(mm, address);
-	if (!pmd) {
-		result = SCAN_PMD_NULL;
+	result = find_pmd_or_thp_or_none(mm, address, &pmd);
+	if (result != SCAN_SUCCEED)
 		goto out;
-	}
 
 	memset(cc->node_load, 0, sizeof(cc->node_load));
 	pte = pte_offset_map_lock(mm, pmd, address, &ptl);
@@ -1373,8 +1411,7 @@  void collapse_pte_mapped_thp(struct mm_struct *mm, unsigned long addr)
 	if (!PageHead(hpage))
 		goto drop_hpage;
 
-	pmd = mm_find_pmd(mm, haddr);
-	if (!pmd)
+	if (find_pmd_or_thp_or_none(mm, haddr, &pmd) != SCAN_SUCCEED)
 		goto drop_hpage;
 
 	start_pte = pte_offset_map_lock(mm, pmd, haddr, &ptl);
@@ -1492,8 +1529,7 @@  static void retract_page_tables(struct address_space *mapping, pgoff_t pgoff)
 		if (vma->vm_end < addr + HPAGE_PMD_SIZE)
 			continue;
 		mm = vma->vm_mm;
-		pmd = mm_find_pmd(mm, addr);
-		if (!pmd)
+		if (find_pmd_or_thp_or_none(mm, addr, &pmd) != SCAN_SUCCEED)
 			continue;
 		/*
 		 * We need exclusive mmap_lock to retract page table.
diff --git a/mm/ksm.c b/mm/ksm.c
index 075123602bd0..3e0a0a42fa1f 100644
--- a/mm/ksm.c
+++ b/mm/ksm.c
@@ -1136,6 +1136,7 @@  static int replace_page(struct vm_area_struct *vma, struct page *page,
 {
 	struct mm_struct *mm = vma->vm_mm;
 	pmd_t *pmd;
+	pmd_t pmde;
 	pte_t *ptep;
 	pte_t newpte;
 	spinlock_t *ptl;
@@ -1150,6 +1151,15 @@  static int replace_page(struct vm_area_struct *vma, struct page *page,
 	pmd = mm_find_pmd(mm, addr);
 	if (!pmd)
 		goto out;
+	/*
+	 * Some THP functions use the sequence pmdp_huge_clear_flush(), set_pmd_at()
+	 * without holding anon_vma lock for write.  So when looking for a
+	 * genuine pmde (in which to find pte), test present and !THP together.
+	 */
+	pmde = *pmd;
+	barrier();
+	if (!pmd_present(pmde) || pmd_trans_huge(pmde))
+		goto out;
 
 	mmu_notifier_range_init(&range, MMU_NOTIFY_CLEAR, 0, vma, mm, addr,
 				addr + PAGE_SIZE);
diff --git a/mm/rmap.c b/mm/rmap.c
index edc06c52bc82..af775855e58f 100644
--- a/mm/rmap.c
+++ b/mm/rmap.c
@@ -767,13 +767,17 @@  unsigned long page_address_in_vma(struct page *page, struct vm_area_struct *vma)
 	return vma_address(page, vma);
 }
 
+/*
+ * Returns the actual pmd_t* where we expect 'address' to be mapped from, or
+ * NULL if it doesn't exist.  No guarantees / checks on what the pmd_t*
+ * represents.
+ */
 pmd_t *mm_find_pmd(struct mm_struct *mm, unsigned long address)
 {
 	pgd_t *pgd;
 	p4d_t *p4d;
 	pud_t *pud;
 	pmd_t *pmd = NULL;
-	pmd_t pmde;
 
 	pgd = pgd_offset(mm, address);
 	if (!pgd_present(*pgd))
@@ -788,15 +792,6 @@  pmd_t *mm_find_pmd(struct mm_struct *mm, unsigned long address)
 		goto out;
 
 	pmd = pmd_offset(pud, address);
-	/*
-	 * Some THP functions use the sequence pmdp_huge_clear_flush(), set_pmd_at()
-	 * without holding anon_vma lock for write.  So when looking for a
-	 * genuine pmde (in which to find pte), test present and !THP together.
-	 */
-	pmde = *pmd;
-	barrier();
-	if (!pmd_present(pmde) || pmd_trans_huge(pmde))
-		pmd = NULL;
 out:
 	return pmd;
 }
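
For completeness, a hedged userspace sketch of the eventual consumer: the
MADV_COLLAPSE advice itself is added later in this series, so the define
fallback below (value 25, matching the series' uapi addition) is an
assumption, shown only to motivate why the scan must distinguish
already-pmd-mapped regions from genuine failures.

#include <sys/mman.h>

#ifndef MADV_COLLAPSE
#define MADV_COLLAPSE 25	/* assumed value from this series' uapi patch */
#endif

/*
 * Ask the kernel to collapse [addr, addr + len) into hugepages.  Ranges
 * that are already pmd-mapped (SCAN_PMD_MAPPED internally) count as
 * "nothing to do" rather than as errors.
 */
static int try_collapse(void *addr, size_t len)
{
	return madvise(addr, len, MADV_COLLAPSE);
}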