[RFC,6/6] mm: Expand Copy-On-Write to PTE table

Message ID 20220519183127.3909598-7-shiyn.lin@gmail.com (mailing list archive)
State New
Series Introduce Copy-On-Write to Page Table

Commit Message

Chih-En Lin May 19, 2022, 6:31 p.m. UTC
This patch adds the Copy-On-Write (COW) mechanism to the PTE table.
To enable a COW page table, use the clone3() system call with the
CLONE_COW_PGTABLE flag, which sets the MMF_COW_PGTABLE flag on the
process.

The MMF_COW_PGTABLE flag distinguishes the default page table from the
COW one. Moreover, it is difficult to tell when the entire page table
has left the COW state, so the MMF_COW_PGTABLE flag is never cleared
once it has been set.

Since each process has its own page table in kernel space, the address
of the PMD entry is used as the owner of the PTE table, to identify
which process needs to update the page table state. In other words,
only the owner updates the COW PTE state, such as the RSS and
pgtable_bytes.

A reference count controls the lifetime of the COW PTE table. When a
process breaks COW, it copies the COW PTE table and decreases the
reference count. But if the reference count is one before the break,
it reuses the COW PTE table instead.

This patch modifies the page table copying code to do the basic COW.
For breaking COW, it modifies the page fault, page table zapping,
unmapping, and remapping paths.
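
For example, a child with a COW page table can be created roughly like
this (sketch; error handling omitted, and it assumes CLONE_COW_PGTABLE
is exported through the uapi headers by the earlier patches in this
series):

        #define _GNU_SOURCE
        #include <linux/sched.h>        /* struct clone_args, CLONE_* flags */
        #include <sys/syscall.h>
        #include <signal.h>
        #include <string.h>
        #include <unistd.h>

        static long fork_cow_pgtable(void)
        {
                struct clone_args args;

                memset(&args, 0, sizeof(args));
                args.flags = CLONE_COW_PGTABLE; /* proposed by this series */
                args.exit_signal = SIGCHLD;

                return syscall(__NR_clone3, &args, sizeof(args));
        }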

Signed-off-by: Chih-En Lin <shiyn.lin@gmail.com>
---
 include/linux/pgtable.h |   3 +
 mm/memory.c             | 262 ++++++++++++++++++++++++++++++++++++----
 mm/mmap.c               |   4 +
 mm/mremap.c             |   5 +
 4 files changed, 251 insertions(+), 23 deletions(-)

Comments

Christophe Leroy May 20, 2022, 2:49 p.m. UTC | #1
Le 19/05/2022 à 20:31, Chih-En Lin a écrit :
> This patch adds the Copy-On-Write (COW) mechanism to the PTE table.
> To enable the COW page table use the clone3() system call with the
> CLONE_COW_PGTABLE flag. It will set the MMF_COW_PGTABLE flag to the
> processes.
> 
> It uses the MMF_COW_PGTABLE flag to distinguish the default page table
> and the COW one. Moreover, it is difficult to distinguish whether the
> entire page table is out of COW state. So the MMF_COW_PGTABLE flag won't
> be disabled after its setup.
> 
> Since the memory space of the page table is distinctive for each process
> in kernel space. It uses the address of the PMD index for the ownership
> of the PTE table to identify which one of the processes needs to update
> the page table state. In other words, only the owner will update COW PTE
> state, like the RSS and pgtable_bytes.
> 
> It uses the reference count to control the lifetime of COW PTE table.
> When someone breaks COW, it will copy the COW PTE table and decrease the
> reference count. But if the reference count is equal to one before the
> break COW, it will reuse the COW PTE table.
> 
> This patch modifies the part of the copy page table to do the basic COW.
> For the break COW, it modifies the part of a page fault, zaps page table
> , unmapping, and remapping.
> 
> Signed-off-by: Chih-En Lin <shiyn.lin@gmail.com>
> ---
>   include/linux/pgtable.h |   3 +
>   mm/memory.c             | 262 ++++++++++++++++++++++++++++++++++++----
>   mm/mmap.c               |   4 +
>   mm/mremap.c             |   5 +
>   4 files changed, 251 insertions(+), 23 deletions(-)
> 
> diff --git a/include/linux/pgtable.h b/include/linux/pgtable.h
> index 33c01fec7b92..357ce3722ee8 100644
> --- a/include/linux/pgtable.h
> +++ b/include/linux/pgtable.h
> @@ -631,6 +631,9 @@ static inline int cow_pte_refcount_read(pmd_t *pmd)
>          return atomic_read(&pmd_page(*pmd)->cow_pgtable_refcount);
>   }
> 
> +extern int handle_cow_pte(struct vm_area_struct *vma, pmd_t *pmd,
> +               unsigned long addr, bool alloc);
> +
>   #ifndef pte_access_permitted
>   #define pte_access_permitted(pte, write) \
>          (pte_present(pte) && (!(write) || pte_write(pte)))
> diff --git a/mm/memory.c b/mm/memory.c
> index aa66af76e214..ff3fcbe4dfb5 100644
> --- a/mm/memory.c
> +++ b/mm/memory.c
> @@ -247,6 +247,8 @@ static inline void free_pmd_range(struct mmu_gather *tlb, pud_t *pud,
>                  next = pmd_addr_end(addr, end);
>                  if (pmd_none_or_clear_bad(pmd))
>                          continue;
> +               BUG_ON(cow_pte_refcount_read(pmd) != 1);
> +               BUG_ON(!cow_pte_owner_is_same(pmd, NULL));

See my comment on a previous patch of this series; there seems to be a huge 
number of new BUG_ONs.

>                  free_pte_range(tlb, pmd, addr);
>          } while (pmd++, addr = next, addr != end);
> 
> @@ -1031,7 +1033,7 @@ static inline void cow_pte_rss(struct mm_struct *mm, struct vm_area_struct *vma,
>   static int
>   copy_pte_range(struct vm_area_struct *dst_vma, struct vm_area_struct *src_vma,
>                 pmd_t *dst_pmd, pmd_t *src_pmd, unsigned long addr,
> -              unsigned long end)
> +              unsigned long end, bool is_src_pte_locked)
>   {
>          struct mm_struct *dst_mm = dst_vma->vm_mm;
>          struct mm_struct *src_mm = src_vma->vm_mm;
> @@ -1053,8 +1055,10 @@ copy_pte_range(struct vm_area_struct *dst_vma, struct vm_area_struct *src_vma,
>                  goto out;
>          }
>          src_pte = pte_offset_map(src_pmd, addr);
> -       src_ptl = pte_lockptr(src_mm, src_pmd);
> -       spin_lock_nested(src_ptl, SINGLE_DEPTH_NESTING);
> +       if (!is_src_pte_locked) {
> +               src_ptl = pte_lockptr(src_mm, src_pmd);
> +               spin_lock_nested(src_ptl, SINGLE_DEPTH_NESTING);
> +       }

Odd construct; that kind of construct often leads to messy errors.

Could you structure things differently by refactoring the code?
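
For instance (very rough sketch, ignoring the prealloc/retry logic; the
helper name below is made up), the locked part could be split into a
helper so a flag doesn't have to be threaded through the copy loop:

        /* Caller already holds src_ptl. */
        static int __copy_pte_range_locked(...);

        static int
        copy_pte_range(struct vm_area_struct *dst_vma, struct vm_area_struct *src_vma,
                       pmd_t *dst_pmd, pmd_t *src_pmd, unsigned long addr,
                       unsigned long end)
        {
                ...
                src_ptl = pte_lockptr(src_mm, src_pmd);
                spin_lock_nested(src_ptl, SINGLE_DEPTH_NESTING);
                ret = __copy_pte_range_locked(...);
                spin_unlock(src_ptl);
                ...
        }

and break_cow_pte() would then call __copy_pte_range_locked() directly
with the lock it already holds.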

>          orig_src_pte = src_pte;
>          orig_dst_pte = dst_pte;
>          arch_enter_lazy_mmu_mode();
> @@ -1067,7 +1071,8 @@ copy_pte_range(struct vm_area_struct *dst_vma, struct vm_area_struct *src_vma,
>                  if (progress >= 32) {
>                          progress = 0;
>                          if (need_resched() ||
> -                           spin_needbreak(src_ptl) || spin_needbreak(dst_ptl))
> +                           (!is_src_pte_locked && spin_needbreak(src_ptl)) ||
> +                           spin_needbreak(dst_ptl))
>                                  break;
>                  }
>                  if (pte_none(*src_pte)) {
> @@ -1118,7 +1123,8 @@ copy_pte_range(struct vm_area_struct *dst_vma, struct vm_area_struct *src_vma,
>          } while (dst_pte++, src_pte++, addr += PAGE_SIZE, addr != end);
> 
>          arch_leave_lazy_mmu_mode();
> -       spin_unlock(src_ptl);
> +       if (!is_src_pte_locked)
> +               spin_unlock(src_ptl);
>          pte_unmap(orig_src_pte);
>          add_mm_rss_vec(dst_mm, rss);
>          pte_unmap_unlock(orig_dst_pte, dst_ptl);
> @@ -1180,11 +1186,55 @@ copy_pmd_range(struct vm_area_struct *dst_vma, struct vm_area_struct *src_vma,
>                                  continue;
>                          /* fall through */
>                  }
> -               if (pmd_none_or_clear_bad(src_pmd))
> -                       continue;
> -               if (copy_pte_range(dst_vma, src_vma, dst_pmd, src_pmd,
> -                                  addr, next))
> +
> +               if (test_bit(MMF_COW_PGTABLE, &src_mm->flags)) {
> +
> +                        if (pmd_none(*src_pmd))
> +                               continue;

Why not keep the pmd_none_or_clear_bad(src_pmd) check instead?

> +
> +                       /* XXX: Skip if the PTE already COW this time. */
> +                       if (!pmd_none(*dst_pmd) &&

Shouldn't it be pmd_none_or_clear_bad()?

> +                           cow_pte_refcount_read(src_pmd) > 1)
> +                               continue;
> +
> +                       /* If PTE doesn't have an owner, the parent needs to
> +                        * take this PTE.
> +                        */
> +                       if (cow_pte_owner_is_same(src_pmd, NULL)) {
> +                               set_cow_pte_owner(src_pmd, src_pmd);
> +                               /* XXX: The process may COW PTE fork two times.
> +                                * But in some situations, owner has cleared.
> +                                * Previously Child (This time is the parent)
> +                                * COW PTE forking, but previously parent, owner
> +                                * , break COW. So it needs to add back the RSS
> +                                * state and pgtable bytes.
> +                                */
> +                               if (!pmd_write(*src_pmd)) {
> +                                       unsigned long pte_start =
> +                                               addr & PMD_MASK;
> +                                       unsigned long pte_end =
> +                                               (addr + PMD_SIZE) & PMD_MASK;
> +                                       cow_pte_rss(src_mm, src_vma, src_pmd,
> +                                           pte_start, pte_end, true /* inc */);
> +                                       mm_inc_nr_ptes(src_mm);
> +                                       smp_wmb();
> +                                       pmd_populate(src_mm, src_pmd,
> +                                                       pmd_page(*src_pmd));
> +                               }
> +                       }
> +
> +                       pmdp_set_wrprotect(src_mm, addr, src_pmd);
> +
> +                       /* Child reference count */
> +                       pmd_get_pte(src_pmd);
> +
> +                       /* COW for PTE table */
> +                       set_pmd_at(dst_mm, addr, dst_pmd, *src_pmd);
> +               } else if (!pmd_none_or_clear_bad(src_pmd) &&

Can't we keep pmd_none_or_clear_bad(src_pmd) common to both cases?
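
i.e. something like (sketch):

        if (pmd_none_or_clear_bad(src_pmd))
                continue;
        if (test_bit(MMF_COW_PGTABLE, &src_mm->flags)) {
                /* COW the PTE table */
                ...
        } else if (copy_pte_range(dst_vma, src_vma, dst_pmd, src_pmd,
                                  addr, next, false)) {
                return -ENOMEM;
        }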


> +                           copy_pte_range(dst_vma, src_vma, dst_pmd, src_pmd,
> +                                   addr, next, false)) {
>                          return -ENOMEM;
> +               }
>          } while (dst_pmd++, src_pmd++, addr = next, addr != end);
>          return 0;
>   }
> @@ -1336,6 +1386,7 @@ copy_page_range(struct vm_area_struct *dst_vma, struct vm_area_struct *src_vma)
>   struct zap_details {
>          struct folio *single_folio;     /* Locked folio to be unmapped */
>          bool even_cows;                 /* Zap COWed private pages too? */
> +       bool cow_pte;                   /* Do not free COW PTE */
>   };
> 
>   /* Whether we should zap all COWed (private) pages too */
> @@ -1398,8 +1449,9 @@ static unsigned long zap_pte_range(struct mmu_gather *tlb,
>                          page = vm_normal_page(vma, addr, ptent);
>                          if (unlikely(!should_zap_page(details, page)))
>                                  continue;
> -                       ptent = ptep_get_and_clear_full(mm, addr, pte,
> -                                                       tlb->fullmm);
> +                       if (!details || !details->cow_pte)
> +                               ptent = ptep_get_and_clear_full(mm, addr, pte,
> +                                                               tlb->fullmm);
>                          tlb_remove_tlb_entry(tlb, pte, addr);
>                          if (unlikely(!page))
>                                  continue;
> @@ -1413,8 +1465,11 @@ static unsigned long zap_pte_range(struct mmu_gather *tlb,
>                                      likely(!(vma->vm_flags & VM_SEQ_READ)))
>                                          mark_page_accessed(page);
>                          }
> -                       rss[mm_counter(page)]--;
> -                       page_remove_rmap(page, vma, false);
> +                       if (!details || !details->cow_pte) {
> +                               rss[mm_counter(page)]--;
> +                               page_remove_rmap(page, vma, false);
> +                       } else
> +                               continue;

Can you do the reverse:

			if (details && details->cow_pte)
				continue;

			rss[mm_counter(page)]--;
			page_remove_rmap(page, vma, false);


>                          if (unlikely(page_mapcount(page) < 0))
>                                  print_bad_pte(vma, addr, ptent, page);
>                          if (unlikely(__tlb_remove_page(tlb, page))) {
> @@ -1425,6 +1480,8 @@ static unsigned long zap_pte_range(struct mmu_gather *tlb,
>                          continue;
>                  }
> 
> +               // TODO: Deal COW PTE with swap
> +
>                  entry = pte_to_swp_entry(ptent);
>                  if (is_device_private_entry(entry) ||
>                      is_device_exclusive_entry(entry)) {
> @@ -1513,16 +1570,34 @@ static inline unsigned long zap_pmd_range(struct mmu_gather *tlb,
>                          spin_unlock(ptl);
>                  }
> 
> -               /*
> -                * Here there can be other concurrent MADV_DONTNEED or
> -                * trans huge page faults running, and if the pmd is
> -                * none or trans huge it can change under us. This is
> -                * because MADV_DONTNEED holds the mmap_lock in read
> -                * mode.
> -                */
> -               if (pmd_none_or_trans_huge_or_clear_bad(pmd))
> -                       goto next;
> -               next = zap_pte_range(tlb, vma, pmd, addr, next, details);
> +
> +               if (test_bit(MMF_COW_PGTABLE, &tlb->mm->flags) &&
> +                   !pmd_none(*pmd) && !pmd_write(*pmd)) {

Can't you use pmd_none_or_trans_huge_or_clear_bad() and keep it common? ...

> +                       struct zap_details cow_pte_details = {0};
> +                       if (details)
> +                               cow_pte_details = *details;
> +                       cow_pte_details.cow_pte = true;
> +                       /* Flush the TLB but do not free the COW PTE */
> +                       next = zap_pte_range(tlb, vma, pmd, addr,
> +                                               next, &cow_pte_details);
> +                       if (details)
> +                               *details = cow_pte_details;
> +                       handle_cow_pte(vma, pmd, addr, false);

Or add a continue; here and avoid the else below
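
i.e. keep the common check first, e.g. (rough sketch, using "goto next"
rather than a bare continue so cond_resched() is still reached):

        /*
         * Here there can be other concurrent MADV_DONTNEED or
         * trans huge page faults running, and if the pmd is
         * none or trans huge it can change under us. This is
         * because MADV_DONTNEED holds the mmap_lock in read
         * mode.
         */
        if (pmd_none_or_trans_huge_or_clear_bad(pmd))
                goto next;

        if (test_bit(MMF_COW_PGTABLE, &tlb->mm->flags) && !pmd_write(*pmd)) {
                struct zap_details cow_pte_details = {0};

                if (details)
                        cow_pte_details = *details;
                cow_pte_details.cow_pte = true;
                /* Flush the TLB but do not free the COW PTE */
                next = zap_pte_range(tlb, vma, pmd, addr, next,
                                     &cow_pte_details);
                handle_cow_pte(vma, pmd, addr, false);
                goto next;
        }

        next = zap_pte_range(tlb, vma, pmd, addr, next, details);

With a private cow_pte_details copy there is also no need to write
cow_pte back into *details or to reset it in the non-COW case.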

> +               } else {
> +                       if (details)
> +                               details->cow_pte = false;
> +                       /*
> +                        * Here there can be other concurrent MADV_DONTNEED or
> +                        * trans huge page faults running, and if the pmd is
> +                        * none or trans huge it can change under us. This is
> +                        * because MADV_DONTNEED holds the mmap_lock in read
> +                        * mode.
> +                        */
> +                       if (pmd_none_or_trans_huge_or_clear_bad(pmd))
> +                               goto next;
> +                       next = zap_pte_range(tlb, vma, pmd, addr, next,
> +                                       details);
> +               }
>   next:
>                  cond_resched();
>          } while (pmd++, addr = next, addr != end);
> @@ -4621,6 +4696,134 @@ void cow_pte_fallback(struct vm_area_struct *vma, pmd_t *pmd,
>          BUG_ON(pmd_page(*pmd)->cow_pte_owner);
>   }
> 
> +/* Break COW PTE:
> + * - two state here
> + *   - After fork :   [parent, rss=1, ref=2, write=NO , owner=parent]
> + *                 to [parent, rss=1, ref=1, write=YES, owner=NULL  ]
> + *                    COW PTE become [ref=1, write=NO , owner=NULL  ]
> + *                    [child , rss=0, ref=2, write=NO , owner=parent]
> + *                 to [child , rss=1, ref=1, write=YES, owner=NULL  ]
> + *                    COW PTE become [ref=1, write=NO , owner=parent]
> + *   NOTE
> + *     - Copy the COW PTE to new PTE.
> + *     - Clear the owner of COW PTE and set PMD entry writable when it is owner.
> + *     - Increase RSS if it is not owner.
> + */
> +static int break_cow_pte(struct vm_area_struct *vma, pmd_t *pmd,
> +               unsigned long addr)
> +{
> +       struct mm_struct *mm = vma->vm_mm;
> +       unsigned long start, end;
> +       pmd_t cowed_entry = *pmd;
> +
> +       if (cow_pte_refcount_read(&cowed_entry) == 1) {
> +               cow_pte_fallback(vma, pmd, addr);
> +               return 1;
> +       }
> +
> +       BUG_ON(pmd_write(cowed_entry));
> +
> +       start = addr & PMD_MASK;
> +       end = (addr + PMD_SIZE) & PMD_MASK;
> +
> +       pmd_clear(pmd);
> +       if (copy_pte_range(vma, vma, pmd, &cowed_entry,
> +                               start, end, true))
> +               return -ENOMEM;
> +
> +       /* Here, it is the owner, so clear the ownership. To keep RSS state and
> +        * page table bytes correct, it needs to decrease them.
> +        */
> +       if (cow_pte_owner_is_same(&cowed_entry, pmd)) {
> +               set_cow_pte_owner(&cowed_entry, NULL);
> +               cow_pte_rss(mm, vma, pmd, start, end, false /* dec */);
> +               mm_dec_nr_ptes(mm);
> +       }
> +
> +       pmd_put_pte(vma, &cowed_entry, addr);
> +
> +       BUG_ON(!pmd_write(*pmd));
> +       BUG_ON(cow_pte_refcount_read(pmd) != 1);
> +
> +       return 0;
> +}
> +
> +static int zap_cow_pte(struct vm_area_struct *vma, pmd_t *pmd,
> +               unsigned long addr)
> +{
> +       struct mm_struct *mm = vma->vm_mm;
> +       unsigned long start, end;
> +
> +       if (pmd_put_pte(vma, pmd, addr)) {
> +               // fallback
> +               return 1;
> +       }

No { } for a single line if. The comment could go just before the if.
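
i.e.:

        /* fallback */
        if (pmd_put_pte(vma, pmd, addr))
                return 1;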

> +
> +       start = addr & PMD_MASK;
> +       end = (addr + PMD_SIZE) & PMD_MASK;
> +
> +       /* If PMD entry is owner, clear the ownership, and decrease RSS state
> +        * and pgtable_bytes.
> +        */

Please follow the standard comment style:

/*
 * Some text
 * More text
 */

> +       if (cow_pte_owner_is_same(pmd, pmd)) {
> +               set_cow_pte_owner(pmd, NULL);
> +               cow_pte_rss(mm, vma, pmd, start, end, false /* dec */);
> +               mm_dec_nr_ptes(mm);
> +       }
> +
> +       pmd_clear(pmd);
> +       return 0;
> +}
> +
> +/* If alloc set means it won't break COW. For this case, it will just decrease
> + * the reference count. The address needs to be at the beginning of the PTE page
> + * since COW PTE is copy-on-write the entire PTE.
> + * If pmd is NULL, it will get the pmd from vma and check it is cowing.
> + */
> +int handle_cow_pte(struct vm_area_struct *vma, pmd_t *pmd,
> +               unsigned long addr, bool alloc)
> +{
> +       pgd_t *pgd;
> +       p4d_t *p4d;
> +       pud_t *pud;
> +       struct mm_struct *mm = vma->vm_mm;
> +       int ret = 0;
> +       spinlock_t *ptl = NULL;
> +
> +       if (!pmd) {
> +               pgd = pgd_offset(mm, addr);
> +               if (pgd_none_or_clear_bad(pgd))
> +                       return 0;
> +               p4d = p4d_offset(pgd, addr);
> +               if (p4d_none_or_clear_bad(p4d))
> +                       return 0;
> +               pud = pud_offset(p4d, addr);
> +               if (pud_none_or_clear_bad(pud))
> +                       return 0;
> +               pmd = pmd_offset(pud, addr);
> +               if (pmd_none(*pmd) || pmd_write(*pmd))
> +                       return 0;
> +       }
> +
> +       // TODO: handle COW PTE with swap
> +       BUG_ON(is_swap_pmd(*pmd));
> +       BUG_ON(pmd_trans_huge(*pmd));
> +       BUG_ON(pmd_devmap(*pmd));
> +
> +       BUG_ON(pmd_none(*pmd));
> +       BUG_ON(pmd_write(*pmd));

So many BUG_ONs? All this has a cost during execution.
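
If these are only meant to catch development bugs, VM_BUG_ON() (or
VM_WARN_ON_ONCE()) would at least be compiled out without
CONFIG_DEBUG_VM, e.g. (sketch):

        /* TODO: handle COW PTE with swap */
        VM_BUG_ON(is_swap_pmd(*pmd));
        VM_BUG_ON(pmd_trans_huge(*pmd));
        VM_BUG_ON(pmd_devmap(*pmd));
        VM_BUG_ON(pmd_none(*pmd));
        VM_BUG_ON(pmd_write(*pmd));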

> +
> +       ptl = pte_lockptr(mm, pmd);
> +       spin_lock(ptl);
> +       if (!alloc)
> +               ret = zap_cow_pte(vma, pmd, addr);
> +       else
> +               ret = break_cow_pte(vma, pmd, addr);

Better as

	if (alloc)
		break_cow_pte()
	else
		zap_cow_pte()

> +       spin_unlock(ptl);
> +
> +       return ret;
> +}
> +
>   /*
>    * These routines also need to handle stuff like marking pages dirty
>    * and/or accessed for architectures that don't do it in hardware (most
> @@ -4825,6 +5028,19 @@ static vm_fault_t __handle_mm_fault(struct vm_area_struct *vma,
>                                  return 0;
>                          }
>                  }
> +
> +               /* When the PMD entry is set with write protection, it needs to
> +                * handle the on-demand PTE. It will allocate a new PTE and copy
> +                * the old one, then set this entry writeable and decrease the
> +                * reference count at COW PTE.
> +                */
> +               if (test_bit(MMF_COW_PGTABLE, &mm->flags) &&
> +                   !pmd_none(vmf.orig_pmd) && !pmd_write(vmf.orig_pmd)) {
> +                       if (handle_cow_pte(vmf.vma, vmf.pmd, vmf.real_address,
> +                          (cow_pte_refcount_read(&vmf.orig_pmd) > 1) ?
> +                          true : false) < 0)

(condition ? true : false) is exactly the same as (condition)
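
i.e. simply:

        if (handle_cow_pte(vmf.vma, vmf.pmd, vmf.real_address,
                           cow_pte_refcount_read(&vmf.orig_pmd) > 1) < 0)
                return VM_FAULT_OOM;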


> +                               return VM_FAULT_OOM;
> +               }
>          }
> 
>          return handle_pte_fault(&vmf);
> diff --git a/mm/mmap.c b/mm/mmap.c
> index 313b57d55a63..e3a9c38e87e8 100644
> --- a/mm/mmap.c
> +++ b/mm/mmap.c
> @@ -2709,6 +2709,10 @@ int __split_vma(struct mm_struct *mm, struct vm_area_struct *vma,
>                          return err;
>          }
> 
> +       if (test_bit(MMF_COW_PGTABLE, &vma->vm_mm->flags) &&
> +           handle_cow_pte(vma, NULL, addr, true) < 0)
> +               return -ENOMEM;
> +
>          new = vm_area_dup(vma);
>          if (!new)
>                  return -ENOMEM;
> diff --git a/mm/mremap.c b/mm/mremap.c
> index 303d3290b938..01aefdfc61b7 100644
> --- a/mm/mremap.c
> +++ b/mm/mremap.c
> @@ -532,6 +532,11 @@ unsigned long move_page_tables(struct vm_area_struct *vma,
>                  old_pmd = get_old_pmd(vma->vm_mm, old_addr);
>                  if (!old_pmd)
>                          continue;
> +
> +               if (test_bit(MMF_COW_PGTABLE, &vma->vm_mm->flags) &&
> +                   !pmd_none(*old_pmd) && !pmd_write(*old_pmd))
> +                       handle_cow_pte(vma, old_pmd, old_addr, true);
> +
>                  new_pmd = alloc_new_pmd(vma->vm_mm, vma, new_addr);
>                  if (!new_pmd)
>                          break;
> --
> 2.36.1
>
Chih-En Lin May 21, 2022, 4:38 a.m. UTC | #2
On Fri, May 20, 2022 at 02:49:31PM +0000, Christophe Leroy wrote:
> 
> 
> Le 19/05/2022 à 20:31, Chih-En Lin a écrit :
> > This patch adds the Copy-On-Write (COW) mechanism to the PTE table.
> > To enable the COW page table use the clone3() system call with the
> > CLONE_COW_PGTABLE flag. It will set the MMF_COW_PGTABLE flag to the
> > processes.
> > 
> > It uses the MMF_COW_PGTABLE flag to distinguish the default page table
> > and the COW one. Moreover, it is difficult to distinguish whether the
> > entire page table is out of COW state. So the MMF_COW_PGTABLE flag won't
> > be disabled after its setup.
> > 
> > Since the memory space of the page table is distinctive for each process
> > in kernel space. It uses the address of the PMD index for the ownership
> > of the PTE table to identify which one of the processes needs to update
> > the page table state. In other words, only the owner will update COW PTE
> > state, like the RSS and pgtable_bytes.
> > 
> > It uses the reference count to control the lifetime of COW PTE table.
> > When someone breaks COW, it will copy the COW PTE table and decrease the
> > reference count. But if the reference count is equal to one before the
> > break COW, it will reuse the COW PTE table.
> > 
> > This patch modifies the part of the copy page table to do the basic COW.
> > For the break COW, it modifies the part of a page fault, zaps page table
> > , unmapping, and remapping.
> > 
> > Signed-off-by: Chih-En Lin <shiyn.lin@gmail.com>
> > ---
> >   include/linux/pgtable.h |   3 +
> >   mm/memory.c             | 262 ++++++++++++++++++++++++++++++++++++----
> >   mm/mmap.c               |   4 +
> >   mm/mremap.c             |   5 +
> >   4 files changed, 251 insertions(+), 23 deletions(-)
> > 
> > diff --git a/include/linux/pgtable.h b/include/linux/pgtable.h
> > index 33c01fec7b92..357ce3722ee8 100644
> > --- a/include/linux/pgtable.h
> > +++ b/include/linux/pgtable.h
> > @@ -631,6 +631,9 @@ static inline int cow_pte_refcount_read(pmd_t *pmd)
> >          return atomic_read(&pmd_page(*pmd)->cow_pgtable_refcount);
> >   }
> > 
> > +extern int handle_cow_pte(struct vm_area_struct *vma, pmd_t *pmd,
> > +               unsigned long addr, bool alloc);
> > +
> >   #ifndef pte_access_permitted
> >   #define pte_access_permitted(pte, write) \
> >          (pte_present(pte) && (!(write) || pte_write(pte)))
> > diff --git a/mm/memory.c b/mm/memory.c
> > index aa66af76e214..ff3fcbe4dfb5 100644
> > --- a/mm/memory.c
> > +++ b/mm/memory.c
> > @@ -247,6 +247,8 @@ static inline void free_pmd_range(struct mmu_gather *tlb, pud_t *pud,
> >                  next = pmd_addr_end(addr, end);
> >                  if (pmd_none_or_clear_bad(pmd))
> >                          continue;
> > +               BUG_ON(cow_pte_refcount_read(pmd) != 1);
> > +               BUG_ON(!cow_pte_owner_is_same(pmd, NULL));
> 
> See comment on a previous patch of this series, there seem to be a huge 
> number of new BUG_ONs.

Got it.

> >                  free_pte_range(tlb, pmd, addr);
> >          } while (pmd++, addr = next, addr != end);
> > 
> > @@ -1031,7 +1033,7 @@ static inline void cow_pte_rss(struct mm_struct *mm, struct vm_area_struct *vma,
> >   static int
> >   copy_pte_range(struct vm_area_struct *dst_vma, struct vm_area_struct *src_vma,
> >                 pmd_t *dst_pmd, pmd_t *src_pmd, unsigned long addr,
> > -              unsigned long end)
> > +              unsigned long end, bool is_src_pte_locked)
> >   {
> >          struct mm_struct *dst_mm = dst_vma->vm_mm;
> >          struct mm_struct *src_mm = src_vma->vm_mm;
> > @@ -1053,8 +1055,10 @@ copy_pte_range(struct vm_area_struct *dst_vma, struct vm_area_struct *src_vma,
> >                  goto out;
> >          }
> >          src_pte = pte_offset_map(src_pmd, addr);
> > -       src_ptl = pte_lockptr(src_mm, src_pmd);
> > -       spin_lock_nested(src_ptl, SINGLE_DEPTH_NESTING);
> > +       if (!is_src_pte_locked) {
> > +               src_ptl = pte_lockptr(src_mm, src_pmd);
> > +               spin_lock_nested(src_ptl, SINGLE_DEPTH_NESTING);
> > +       }
> 
> Odd construct, that kind of construct often leads to messy errors.
> 
> Could you construct things differently by refactoring the code ?

Sure, I will try my best.
It's probably why there's a bug here when doing the stress testing.

> > @@ -1180,11 +1186,55 @@ copy_pmd_range(struct vm_area_struct *dst_vma, struct vm_area_struct *src_vma,
> >                                  continue;
> >                          /* fall through */
> >                  }
> > -               if (pmd_none_or_clear_bad(src_pmd))
> > -                       continue;
> > -               if (copy_pte_range(dst_vma, src_vma, dst_pmd, src_pmd,
> > -                                  addr, next))
> > +
> > +               if (test_bit(MMF_COW_PGTABLE, &src_mm->flags)) {
> > +
> > +                        if (pmd_none(*src_pmd))
> > +                               continue;
> 
> Why not keep the pmd_none_or_clear_bad(src_pmd) instead ?
> 
> > +
> > +                       /* XXX: Skip if the PTE already COW this time. */
> > +                       if (!pmd_none(*dst_pmd) &&
> 
> Shouldn't is be a pmd_none_or_clear_bad() ?
> 
> > +                           cow_pte_refcount_read(src_pmd) > 1)
> > +                               continue;
> > +
> > +                       /* If PTE doesn't have an owner, the parent needs to
> > +                        * take this PTE.
> > +                        */
> > +                       if (cow_pte_owner_is_same(src_pmd, NULL)) {
> > +                               set_cow_pte_owner(src_pmd, src_pmd);
> > +                               /* XXX: The process may COW PTE fork two times.
> > +                                * But in some situations, owner has cleared.
> > +                                * Previously Child (This time is the parent)
> > +                                * COW PTE forking, but previously parent, owner
> > +                                * , break COW. So it needs to add back the RSS
> > +                                * state and pgtable bytes.
> > +                                */
> > +                               if (!pmd_write(*src_pmd)) {
> > +                                       unsigned long pte_start =
> > +                                               addr & PMD_MASK;
> > +                                       unsigned long pte_end =
> > +                                               (addr + PMD_SIZE) & PMD_MASK;
> > +                                       cow_pte_rss(src_mm, src_vma, src_pmd,
> > +                                           pte_start, pte_end, true /* inc */);
> > +                                       mm_inc_nr_ptes(src_mm);
> > +                                       smp_wmb();
> > +                                       pmd_populate(src_mm, src_pmd,
> > +                                                       pmd_page(*src_pmd));
> > +                               }
> > +                       }
> > +
> > +                       pmdp_set_wrprotect(src_mm, addr, src_pmd);
> > +
> > +                       /* Child reference count */
> > +                       pmd_get_pte(src_pmd);
> > +
> > +                       /* COW for PTE table */
> > +                       set_pmd_at(dst_mm, addr, dst_pmd, *src_pmd);
> > +               } else if (!pmd_none_or_clear_bad(src_pmd) &&
> 
> Can't we keep pmd_none_or_clear_bad(src_pmd) common to both cases ?
> 

You are right.
I will change to pmd_none_or_clear_bad().

> > +                           copy_pte_range(dst_vma, src_vma, dst_pmd, src_pmd,
> > +                                   addr, next, false)) {
> >                          return -ENOMEM;
> > +               }
> >          } while (dst_pmd++, src_pmd++, addr = next, addr != end);
> >          return 0;
> >   }
> > @@ -1336,6 +1386,7 @@ copy_page_range(struct vm_area_struct *dst_vma, struct vm_area_struct *src_vma)
> >   struct zap_details {
> >          struct folio *single_folio;     /* Locked folio to be unmapped */
> >          bool even_cows;                 /* Zap COWed private pages too? */
> > +       bool cow_pte;                   /* Do not free COW PTE */
> >   };
> > 
> >   /* Whether we should zap all COWed (private) pages too */
> > @@ -1398,8 +1449,9 @@ static unsigned long zap_pte_range(struct mmu_gather *tlb,
> >                          page = vm_normal_page(vma, addr, ptent);
> >                          if (unlikely(!should_zap_page(details, page)))
> >                                  continue;
> > -                       ptent = ptep_get_and_clear_full(mm, addr, pte,
> > -                                                       tlb->fullmm);
> > +                       if (!details || !details->cow_pte)
> > +                               ptent = ptep_get_and_clear_full(mm, addr, pte,
> > +                                                               tlb->fullmm);
> >                          tlb_remove_tlb_entry(tlb, pte, addr);
> >                          if (unlikely(!page))
> >                                  continue;
> > @@ -1413,8 +1465,11 @@ static unsigned long zap_pte_range(struct mmu_gather *tlb,
> >                                      likely(!(vma->vm_flags & VM_SEQ_READ)))
> >                                          mark_page_accessed(page);
> >                          }
> > -                       rss[mm_counter(page)]--;
> > -                       page_remove_rmap(page, vma, false);
> > +                       if (!details || !details->cow_pte) {
> > +                               rss[mm_counter(page)]--;
> > +                               page_remove_rmap(page, vma, false);
> > +                       } else
> > +                               continue;
> 
> Can you do the reverse:
> 
> 			if (details && details->cow_pte)
> 				continue;
> 
> 			rss[mm_counter(page)]--;
> 			page_remove_rmap(page, vma, false);

That's better than what I wrote.
Thanks.

> 
> >                          if (unlikely(page_mapcount(page) < 0))
> >                                  print_bad_pte(vma, addr, ptent, page);
> >                          if (unlikely(__tlb_remove_page(tlb, page))) {
> > @@ -1425,6 +1480,8 @@ static unsigned long zap_pte_range(struct mmu_gather *tlb,
> >                          continue;
> >                  }
> > 
> > +               // TODO: Deal COW PTE with swap
> > +
> >                  entry = pte_to_swp_entry(ptent);
> >                  if (is_device_private_entry(entry) ||
> >                      is_device_exclusive_entry(entry)) {
> > @@ -1513,16 +1570,34 @@ static inline unsigned long zap_pmd_range(struct mmu_gather *tlb,
> >                          spin_unlock(ptl);
> >                  }
> > 
> > -               /*
> > -                * Here there can be other concurrent MADV_DONTNEED or
> > -                * trans huge page faults running, and if the pmd is
> > -                * none or trans huge it can change under us. This is
> > -                * because MADV_DONTNEED holds the mmap_lock in read
> > -                * mode.
> > -                */
> > -               if (pmd_none_or_trans_huge_or_clear_bad(pmd))
> > -                       goto next;
> > -               next = zap_pte_range(tlb, vma, pmd, addr, next, details);
> > +
> > +               if (test_bit(MMF_COW_PGTABLE, &tlb->mm->flags) &&
> > +                   !pmd_none(*pmd) && !pmd_write(*pmd)) {
> 
> Can't you use pmd_none_or_trans_huge_or_clear_bad() and keep it common ? ...

Sure.

> > +static int zap_cow_pte(struct vm_area_struct *vma, pmd_t *pmd,
> > +               unsigned long addr)
> > +{
> > +       struct mm_struct *mm = vma->vm_mm;
> > +       unsigned long start, end;
> > +
> > +       if (pmd_put_pte(vma, pmd, addr)) {
> > +               // fallback
> > +               return 1;
> > +       }
> 
> No { } for a single line if. The comment could go just before the if.
> 
> > +
> > +       start = addr & PMD_MASK;
> > +       end = (addr + PMD_SIZE) & PMD_MASK;
> > +
> > +       /* If PMD entry is owner, clear the ownership, and decrease RSS state
> > +        * and pgtable_bytes.
> > +        */
> 
> Please follow the standard comments style:
> 
> /*
>   * Some text
>   * More text
>   */
> 

Got it.

> > +       if (cow_pte_owner_is_same(pmd, pmd)) {
> > +               set_cow_pte_owner(pmd, NULL);
> > +               cow_pte_rss(mm, vma, pmd, start, end, false /* dec */);
> > +               mm_dec_nr_ptes(mm);
> > +       }
> > +
> > +       pmd_clear(pmd);
> > +       return 0;
> > +}
> > +
> > +/* If alloc set means it won't break COW. For this case, it will just decrease
> > + * the reference count. The address needs to be at the beginning of the PTE page
> > + * since COW PTE is copy-on-write the entire PTE.
> > + * If pmd is NULL, it will get the pmd from vma and check it is cowing.
> > + */
> > +int handle_cow_pte(struct vm_area_struct *vma, pmd_t *pmd,
> > +               unsigned long addr, bool alloc)
> > +{
> > +       pgd_t *pgd;
> > +       p4d_t *p4d;
> > +       pud_t *pud;
> > +       struct mm_struct *mm = vma->vm_mm;
> > +       int ret = 0;
> > +       spinlock_t *ptl = NULL;
> > +
> > +       if (!pmd) {
> > +               pgd = pgd_offset(mm, addr);
> > +               if (pgd_none_or_clear_bad(pgd))
> > +                       return 0;
> > +               p4d = p4d_offset(pgd, addr);
> > +               if (p4d_none_or_clear_bad(p4d))
> > +                       return 0;
> > +               pud = pud_offset(p4d, addr);
> > +               if (pud_none_or_clear_bad(pud))
> > +                       return 0;
> > +               pmd = pmd_offset(pud, addr);
> > +               if (pmd_none(*pmd) || pmd_write(*pmd))
> > +                       return 0;
> > +       }
> > +
> > +       // TODO: handle COW PTE with swap
> > +       BUG_ON(is_swap_pmd(*pmd));
> > +       BUG_ON(pmd_trans_huge(*pmd));
> > +       BUG_ON(pmd_devmap(*pmd));
> > +
> > +       BUG_ON(pmd_none(*pmd));
> > +       BUG_ON(pmd_write(*pmd));
> 
> So many BUG_ON ? All this has a cost during the execution.

I will consider it again.

> > +
> > +       ptl = pte_lockptr(mm, pmd);
> > +       spin_lock(ptl);
> > +       if (!alloc)
> > +               ret = zap_cow_pte(vma, pmd, addr);
> > +       else
> > +               ret = break_cow_pte(vma, pmd, addr);
> 
> Better as
> 
> 	if (alloc)
> 		break_cow_pte()
> 	else
> 		zap_cow_pte()

Great!
Thanks.

> > +       spin_unlock(ptl);
> > +
> > +       return ret;
> > +}
> > +
> >   /*
> >    * These routines also need to handle stuff like marking pages dirty
> >    * and/or accessed for architectures that don't do it in hardware (most
> > @@ -4825,6 +5028,19 @@ static vm_fault_t __handle_mm_fault(struct vm_area_struct *vma,
> >                                  return 0;
> >                          }
> >                  }
> > +
> > +               /* When the PMD entry is set with write protection, it needs to
> > +                * handle the on-demand PTE. It will allocate a new PTE and copy
> > +                * the old one, then set this entry writeable and decrease the
> > +                * reference count at COW PTE.
> > +                */
> > +               if (test_bit(MMF_COW_PGTABLE, &mm->flags) &&
> > +                   !pmd_none(vmf.orig_pmd) && !pmd_write(vmf.orig_pmd)) {
> > +                       if (handle_cow_pte(vmf.vma, vmf.pmd, vmf.real_address,
> > +                          (cow_pte_refcount_read(&vmf.orig_pmd) > 1) ?
> > +                          true : false) < 0)
> 
> (condition ? true : false) is exactly the same as (condition)
> 

I knew. ;-)

Again, thanks!

Patch

diff --git a/include/linux/pgtable.h b/include/linux/pgtable.h
index 33c01fec7b92..357ce3722ee8 100644
--- a/include/linux/pgtable.h
+++ b/include/linux/pgtable.h
@@ -631,6 +631,9 @@  static inline int cow_pte_refcount_read(pmd_t *pmd)
 	return atomic_read(&pmd_page(*pmd)->cow_pgtable_refcount);
 }
 
+extern int handle_cow_pte(struct vm_area_struct *vma, pmd_t *pmd,
+		unsigned long addr, bool alloc);
+
 #ifndef pte_access_permitted
 #define pte_access_permitted(pte, write) \
 	(pte_present(pte) && (!(write) || pte_write(pte)))
diff --git a/mm/memory.c b/mm/memory.c
index aa66af76e214..ff3fcbe4dfb5 100644
--- a/mm/memory.c
+++ b/mm/memory.c
@@ -247,6 +247,8 @@  static inline void free_pmd_range(struct mmu_gather *tlb, pud_t *pud,
 		next = pmd_addr_end(addr, end);
 		if (pmd_none_or_clear_bad(pmd))
 			continue;
+		BUG_ON(cow_pte_refcount_read(pmd) != 1);
+		BUG_ON(!cow_pte_owner_is_same(pmd, NULL));
 		free_pte_range(tlb, pmd, addr);
 	} while (pmd++, addr = next, addr != end);
 
@@ -1031,7 +1033,7 @@  static inline void cow_pte_rss(struct mm_struct *mm, struct vm_area_struct *vma,
 static int
 copy_pte_range(struct vm_area_struct *dst_vma, struct vm_area_struct *src_vma,
 	       pmd_t *dst_pmd, pmd_t *src_pmd, unsigned long addr,
-	       unsigned long end)
+	       unsigned long end, bool is_src_pte_locked)
 {
 	struct mm_struct *dst_mm = dst_vma->vm_mm;
 	struct mm_struct *src_mm = src_vma->vm_mm;
@@ -1053,8 +1055,10 @@  copy_pte_range(struct vm_area_struct *dst_vma, struct vm_area_struct *src_vma,
 		goto out;
 	}
 	src_pte = pte_offset_map(src_pmd, addr);
-	src_ptl = pte_lockptr(src_mm, src_pmd);
-	spin_lock_nested(src_ptl, SINGLE_DEPTH_NESTING);
+	if (!is_src_pte_locked) {
+		src_ptl = pte_lockptr(src_mm, src_pmd);
+		spin_lock_nested(src_ptl, SINGLE_DEPTH_NESTING);
+	}
 	orig_src_pte = src_pte;
 	orig_dst_pte = dst_pte;
 	arch_enter_lazy_mmu_mode();
@@ -1067,7 +1071,8 @@  copy_pte_range(struct vm_area_struct *dst_vma, struct vm_area_struct *src_vma,
 		if (progress >= 32) {
 			progress = 0;
 			if (need_resched() ||
-			    spin_needbreak(src_ptl) || spin_needbreak(dst_ptl))
+			    (!is_src_pte_locked && spin_needbreak(src_ptl)) ||
+			    spin_needbreak(dst_ptl))
 				break;
 		}
 		if (pte_none(*src_pte)) {
@@ -1118,7 +1123,8 @@  copy_pte_range(struct vm_area_struct *dst_vma, struct vm_area_struct *src_vma,
 	} while (dst_pte++, src_pte++, addr += PAGE_SIZE, addr != end);
 
 	arch_leave_lazy_mmu_mode();
-	spin_unlock(src_ptl);
+	if (!is_src_pte_locked)
+		spin_unlock(src_ptl);
 	pte_unmap(orig_src_pte);
 	add_mm_rss_vec(dst_mm, rss);
 	pte_unmap_unlock(orig_dst_pte, dst_ptl);
@@ -1180,11 +1186,55 @@  copy_pmd_range(struct vm_area_struct *dst_vma, struct vm_area_struct *src_vma,
 				continue;
 			/* fall through */
 		}
-		if (pmd_none_or_clear_bad(src_pmd))
-			continue;
-		if (copy_pte_range(dst_vma, src_vma, dst_pmd, src_pmd,
-				   addr, next))
+
+		if (test_bit(MMF_COW_PGTABLE, &src_mm->flags)) {
+
+			 if (pmd_none(*src_pmd))
+				continue;
+
+			/* XXX: Skip if the PTE already COW this time. */
+			if (!pmd_none(*dst_pmd) &&
+			    cow_pte_refcount_read(src_pmd) > 1)
+				continue;
+
+			/* If PTE doesn't have an owner, the parent needs to
+			 * take this PTE.
+			 */
+			if (cow_pte_owner_is_same(src_pmd, NULL)) {
+				set_cow_pte_owner(src_pmd, src_pmd);
+				/* XXX: The process may COW PTE fork two times.
+				 * But in some situations, owner has cleared.
+				 * Previously Child (This time is the parent)
+				 * COW PTE forking, but previously parent, owner
+				 * , break COW. So it needs to add back the RSS
+				 * state and pgtable bytes.
+				 */
+				if (!pmd_write(*src_pmd)) {
+					unsigned long pte_start =
+						addr & PMD_MASK;
+					unsigned long pte_end =
+						(addr + PMD_SIZE) & PMD_MASK;
+					cow_pte_rss(src_mm, src_vma, src_pmd,
+					    pte_start, pte_end, true /* inc */);
+					mm_inc_nr_ptes(src_mm);
+					smp_wmb();
+					pmd_populate(src_mm, src_pmd,
+							pmd_page(*src_pmd));
+				}
+			}
+
+			pmdp_set_wrprotect(src_mm, addr, src_pmd);
+
+			/* Child reference count */
+			pmd_get_pte(src_pmd);
+
+			/* COW for PTE table */
+			set_pmd_at(dst_mm, addr, dst_pmd, *src_pmd);
+		} else if (!pmd_none_or_clear_bad(src_pmd) &&
+			    copy_pte_range(dst_vma, src_vma, dst_pmd, src_pmd,
+				    addr, next, false)) {
 			return -ENOMEM;
+		}
 	} while (dst_pmd++, src_pmd++, addr = next, addr != end);
 	return 0;
 }
@@ -1336,6 +1386,7 @@  copy_page_range(struct vm_area_struct *dst_vma, struct vm_area_struct *src_vma)
 struct zap_details {
 	struct folio *single_folio;	/* Locked folio to be unmapped */
 	bool even_cows;			/* Zap COWed private pages too? */
+	bool cow_pte;			/* Do not free COW PTE */
 };
 
 /* Whether we should zap all COWed (private) pages too */
@@ -1398,8 +1449,9 @@  static unsigned long zap_pte_range(struct mmu_gather *tlb,
 			page = vm_normal_page(vma, addr, ptent);
 			if (unlikely(!should_zap_page(details, page)))
 				continue;
-			ptent = ptep_get_and_clear_full(mm, addr, pte,
-							tlb->fullmm);
+			if (!details || !details->cow_pte)
+				ptent = ptep_get_and_clear_full(mm, addr, pte,
+								tlb->fullmm);
 			tlb_remove_tlb_entry(tlb, pte, addr);
 			if (unlikely(!page))
 				continue;
@@ -1413,8 +1465,11 @@  static unsigned long zap_pte_range(struct mmu_gather *tlb,
 				    likely(!(vma->vm_flags & VM_SEQ_READ)))
 					mark_page_accessed(page);
 			}
-			rss[mm_counter(page)]--;
-			page_remove_rmap(page, vma, false);
+			if (!details || !details->cow_pte) {
+				rss[mm_counter(page)]--;
+				page_remove_rmap(page, vma, false);
+			} else
+				continue;
 			if (unlikely(page_mapcount(page) < 0))
 				print_bad_pte(vma, addr, ptent, page);
 			if (unlikely(__tlb_remove_page(tlb, page))) {
@@ -1425,6 +1480,8 @@  static unsigned long zap_pte_range(struct mmu_gather *tlb,
 			continue;
 		}
 
+		// TODO: Deal COW PTE with swap
+
 		entry = pte_to_swp_entry(ptent);
 		if (is_device_private_entry(entry) ||
 		    is_device_exclusive_entry(entry)) {
@@ -1513,16 +1570,34 @@  static inline unsigned long zap_pmd_range(struct mmu_gather *tlb,
 			spin_unlock(ptl);
 		}
 
-		/*
-		 * Here there can be other concurrent MADV_DONTNEED or
-		 * trans huge page faults running, and if the pmd is
-		 * none or trans huge it can change under us. This is
-		 * because MADV_DONTNEED holds the mmap_lock in read
-		 * mode.
-		 */
-		if (pmd_none_or_trans_huge_or_clear_bad(pmd))
-			goto next;
-		next = zap_pte_range(tlb, vma, pmd, addr, next, details);
+
+		if (test_bit(MMF_COW_PGTABLE, &tlb->mm->flags) &&
+		    !pmd_none(*pmd) && !pmd_write(*pmd)) {
+			struct zap_details cow_pte_details = {0};
+			if (details)
+				cow_pte_details = *details;
+			cow_pte_details.cow_pte = true;
+			/* Flush the TLB but do not free the COW PTE */
+			next = zap_pte_range(tlb, vma, pmd, addr,
+						next, &cow_pte_details);
+			if (details)
+				*details = cow_pte_details;
+			handle_cow_pte(vma, pmd, addr, false);
+		} else {
+			if (details)
+				details->cow_pte = false;
+			/*
+			 * Here there can be other concurrent MADV_DONTNEED or
+			 * trans huge page faults running, and if the pmd is
+			 * none or trans huge it can change under us. This is
+			 * because MADV_DONTNEED holds the mmap_lock in read
+			 * mode.
+			 */
+			if (pmd_none_or_trans_huge_or_clear_bad(pmd))
+				goto next;
+			next = zap_pte_range(tlb, vma, pmd, addr, next,
+					details);
+		}
 next:
 		cond_resched();
 	} while (pmd++, addr = next, addr != end);
@@ -4621,6 +4696,134 @@  void cow_pte_fallback(struct vm_area_struct *vma, pmd_t *pmd,
 	BUG_ON(pmd_page(*pmd)->cow_pte_owner);
 }
 
+/* Break COW PTE:
+ * - two state here
+ *   - After fork :   [parent, rss=1, ref=2, write=NO , owner=parent]
+ *                 to [parent, rss=1, ref=1, write=YES, owner=NULL  ]
+ *                    COW PTE become [ref=1, write=NO , owner=NULL  ]
+ *                    [child , rss=0, ref=2, write=NO , owner=parent]
+ *                 to [child , rss=1, ref=1, write=YES, owner=NULL  ]
+ *                    COW PTE become [ref=1, write=NO , owner=parent]
+ *   NOTE
+ *     - Copy the COW PTE to new PTE.
+ *     - Clear the owner of COW PTE and set PMD entry writable when it is owner.
+ *     - Increase RSS if it is not owner.
+ */
+static int break_cow_pte(struct vm_area_struct *vma, pmd_t *pmd,
+		unsigned long addr)
+{
+	struct mm_struct *mm = vma->vm_mm;
+	unsigned long start, end;
+	pmd_t cowed_entry = *pmd;
+
+	if (cow_pte_refcount_read(&cowed_entry) == 1) {
+		cow_pte_fallback(vma, pmd, addr);
+		return 1;
+	}
+
+	BUG_ON(pmd_write(cowed_entry));
+
+	start = addr & PMD_MASK;
+	end = (addr + PMD_SIZE) & PMD_MASK;
+
+	pmd_clear(pmd);
+	if (copy_pte_range(vma, vma, pmd, &cowed_entry,
+				start, end, true))
+		return -ENOMEM;
+
+	/* Here, it is the owner, so clear the ownership. To keep RSS state and
+	 * page table bytes correct, it needs to decrease them.
+	 */
+	if (cow_pte_owner_is_same(&cowed_entry, pmd)) {
+		set_cow_pte_owner(&cowed_entry, NULL);
+		cow_pte_rss(mm, vma, pmd, start, end, false /* dec */);
+		mm_dec_nr_ptes(mm);
+	}
+
+	pmd_put_pte(vma, &cowed_entry, addr);
+
+	BUG_ON(!pmd_write(*pmd));
+	BUG_ON(cow_pte_refcount_read(pmd) != 1);
+
+	return 0;
+}
+
+static int zap_cow_pte(struct vm_area_struct *vma, pmd_t *pmd,
+		unsigned long addr)
+{
+	struct mm_struct *mm = vma->vm_mm;
+	unsigned long start, end;
+
+	if (pmd_put_pte(vma, pmd, addr)) {
+		// fallback
+		return 1;
+	}
+
+	start = addr & PMD_MASK;
+	end = (addr + PMD_SIZE) & PMD_MASK;
+
+	/* If PMD entry is owner, clear the ownership, and decrease RSS state
+	 * and pgtable_bytes.
+	 */
+	if (cow_pte_owner_is_same(pmd, pmd)) {
+		set_cow_pte_owner(pmd, NULL);
+		cow_pte_rss(mm, vma, pmd, start, end, false /* dec */);
+		mm_dec_nr_ptes(mm);
+	}
+
+	pmd_clear(pmd);
+	return 0;
+}
+
+/* If alloc set means it won't break COW. For this case, it will just decrease
+ * the reference count. The address needs to be at the beginning of the PTE page
+ * since COW PTE is copy-on-write the entire PTE.
+ * If pmd is NULL, it will get the pmd from vma and check it is cowing.
+ */
+int handle_cow_pte(struct vm_area_struct *vma, pmd_t *pmd,
+		unsigned long addr, bool alloc)
+{
+	pgd_t *pgd;
+	p4d_t *p4d;
+	pud_t *pud;
+	struct mm_struct *mm = vma->vm_mm;
+	int ret = 0;
+	spinlock_t *ptl = NULL;
+
+	if (!pmd) {
+		pgd = pgd_offset(mm, addr);
+		if (pgd_none_or_clear_bad(pgd))
+			return 0;
+		p4d = p4d_offset(pgd, addr);
+		if (p4d_none_or_clear_bad(p4d))
+			return 0;
+		pud = pud_offset(p4d, addr);
+		if (pud_none_or_clear_bad(pud))
+			return 0;
+		pmd = pmd_offset(pud, addr);
+		if (pmd_none(*pmd) || pmd_write(*pmd))
+			return 0;
+	}
+
+	// TODO: handle COW PTE with swap
+	BUG_ON(is_swap_pmd(*pmd));
+	BUG_ON(pmd_trans_huge(*pmd));
+	BUG_ON(pmd_devmap(*pmd));
+
+	BUG_ON(pmd_none(*pmd));
+	BUG_ON(pmd_write(*pmd));
+
+	ptl = pte_lockptr(mm, pmd);
+	spin_lock(ptl);
+	if (!alloc)
+		ret = zap_cow_pte(vma, pmd, addr);
+	else
+		ret = break_cow_pte(vma, pmd, addr);
+	spin_unlock(ptl);
+
+	return ret;
+}
+
 /*
  * These routines also need to handle stuff like marking pages dirty
  * and/or accessed for architectures that don't do it in hardware (most
@@ -4825,6 +5028,19 @@  static vm_fault_t __handle_mm_fault(struct vm_area_struct *vma,
 				return 0;
 			}
 		}
+
+		/* When the PMD entry is set with write protection, it needs to
+		 * handle the on-demand PTE. It will allocate a new PTE and copy
+		 * the old one, then set this entry writeable and decrease the
+		 * reference count at COW PTE.
+		 */
+		if (test_bit(MMF_COW_PGTABLE, &mm->flags) &&
+		    !pmd_none(vmf.orig_pmd) && !pmd_write(vmf.orig_pmd)) {
+			if (handle_cow_pte(vmf.vma, vmf.pmd, vmf.real_address,
+			   (cow_pte_refcount_read(&vmf.orig_pmd) > 1) ?
+			   true : false) < 0)
+				return VM_FAULT_OOM;
+		}
 	}
 
 	return handle_pte_fault(&vmf);
diff --git a/mm/mmap.c b/mm/mmap.c
index 313b57d55a63..e3a9c38e87e8 100644
--- a/mm/mmap.c
+++ b/mm/mmap.c
@@ -2709,6 +2709,10 @@  int __split_vma(struct mm_struct *mm, struct vm_area_struct *vma,
 			return err;
 	}
 
+	if (test_bit(MMF_COW_PGTABLE, &vma->vm_mm->flags) &&
+	    handle_cow_pte(vma, NULL, addr, true) < 0)
+		return -ENOMEM;
+
 	new = vm_area_dup(vma);
 	if (!new)
 		return -ENOMEM;
diff --git a/mm/mremap.c b/mm/mremap.c
index 303d3290b938..01aefdfc61b7 100644
--- a/mm/mremap.c
+++ b/mm/mremap.c
@@ -532,6 +532,11 @@  unsigned long move_page_tables(struct vm_area_struct *vma,
 		old_pmd = get_old_pmd(vma->vm_mm, old_addr);
 		if (!old_pmd)
 			continue;
+
+		if (test_bit(MMF_COW_PGTABLE, &vma->vm_mm->flags) &&
+		    !pmd_none(*old_pmd) && !pmd_write(*old_pmd))
+			handle_cow_pte(vma, old_pmd, old_addr, true);
+
 		new_pmd = alloc_new_pmd(vma->vm_mm, vma, new_addr);
 		if (!new_pmd)
 			break;