
[RESEND,v2,1/9] riscv: Restore the pfn in a NAPOT pte when manipulated by core mm code

Message ID 20240508113419.18620-2-alexghiti@rivosinc.com (mailing list archive)
State New
Series Merge arm64/riscv hugetlbfs contpte support

Commit Message

Alexandre Ghiti May 8, 2024, 11:34 a.m. UTC
The core mm code expects to be able to extract the pfn from a pte. NAPOT
mappings work differently: all the ptes of a NAPOT mapping point to the
first pfn of the mapping, and the pfn LSBs encode the size of the mapping.

So modify ptep_get() so that it returns a pte value that contains the
*real* pfn (which is then different from what the HW expects) and, right
before storing the ptes to the page table, reset the pfn LSBs to the
size of the mapping.

And make sure that all NAPOT mappings are set using set_ptes().

Signed-off-by: Alexandre Ghiti <alexghiti@rivosinc.com>
---
 arch/riscv/include/asm/pgtable-64.h |  11 +++
 arch/riscv/include/asm/pgtable.h    | 105 ++++++++++++++++++++++++++--
 arch/riscv/mm/hugetlbpage.c         |  38 +++++-----
 3 files changed, 128 insertions(+), 26 deletions(-)
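
For readers unfamiliar with Svnapot, here is a standalone sketch of the pfn
encoding described above. It is not kernel code: the constants are
simplified (on riscv64, _PAGE_PFN_SHIFT is 10 and _PAGE_NAPOT is pte bit
63), but the bit manipulation mirrors pte_mknapot() and napot_cont_order():

#include <stdio.h>
#include <stdint.h>

#define _PAGE_PFN_SHIFT	10
#define _PAGE_NAPOT	(1ULL << 63)

/*
 * Encode a NAPOT mapping of 2^order pages: clear pfn bits [order-1:0]
 * and set bit (order - 1), as pte_mknapot() does.
 */
static uint64_t mknapot(uint64_t base_pfn, unsigned int order)
{
	uint64_t pfn = (base_pfn & ~((1ULL << order) - 1)) | (1ULL << (order - 1));

	return (pfn << _PAGE_PFN_SHIFT) | _PAGE_NAPOT;
}

/* The lowest set pfn bit sits at position (order - 1). */
static unsigned int cont_order(uint64_t pte)
{
	return __builtin_ctzll((pte & ~_PAGE_NAPOT) >> _PAGE_PFN_SHIFT) + 1;
}

int main(void)
{
	uint64_t pte = mknapot(0x12340, 4);	/* a 16-page (64K) mapping */
	uint64_t pfn = (pte & ~_PAGE_NAPOT) >> _PAGE_PFN_SHIFT;

	/* clearing the lowest set bit recovers the first pfn: 0x12340 */
	printf("order=%u, first pfn=0x%llx\n", cont_order(pte),
	       (unsigned long long)(pfn & (pfn - 1)));
	return 0;
}

Note the consequence: every pte of the region carries the same encoded pfn,
which is what the core mm code cannot digest and what this patch works
around in ptep_get()/set_ptes().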

Comments

Ryan Roberts May 10, 2024, 12:20 p.m. UTC | #1
On 08/05/2024 12:34, Alexandre Ghiti wrote:
> The core mm code expects to be able to extract the pfn from a pte. NAPOT
> mappings work differently: all the ptes of a NAPOT mapping point to the
> first pfn of the mapping, and the pfn LSBs encode the size of the mapping.
> 
> So modify ptep_get() so that it returns a pte value that contains the
> *real* pfn (which is then different from what the HW expects) and, right
> before storing the ptes to the page table, reset the pfn LSBs to the
> size of the mapping.

Did you consider leaving the pte as is and instead modifying your pte_pfn()
implementation?

For arm64 at least, it is beneficial to keep the pte marked as contiguous when
passing it up to core-mm because there are other helpers which need to parse the
contiguous bit (e.g. pte_leaf_size()). If we were to clear the cont bit in
ptep_get() that info would be lost and perf_get_pgtable_size() would always
conclude the leaf size is 4K even when it is actually 64K.
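
For reference, arm64's helper is (roughly, at the time of writing) defined
as follows; the contiguous bit directly selects the reported leaf size:

#define pte_leaf_size(pte)	(pte_cont(pte) ? CONT_PTE_SIZE : PAGE_SIZE)

A riscv equivalent keyed on the NAPOT bit would suffer the same information
loss if ptep_get() stripped the bit.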

> 
> And make sure that all NAPOT mappings are set using set_ptes().
> 
> Signed-off-by: Alexandre Ghiti <alexghiti@rivosinc.com>
> ---
>  arch/riscv/include/asm/pgtable-64.h |  11 +++
>  arch/riscv/include/asm/pgtable.h    | 105 ++++++++++++++++++++++++++--
>  arch/riscv/mm/hugetlbpage.c         |  38 +++++-----
>  3 files changed, 128 insertions(+), 26 deletions(-)
> 
> diff --git a/arch/riscv/include/asm/pgtable-64.h b/arch/riscv/include/asm/pgtable-64.h
> index 221a5c1ee287..9fe076fc503e 100644
> --- a/arch/riscv/include/asm/pgtable-64.h
> +++ b/arch/riscv/include/asm/pgtable-64.h
> @@ -106,6 +106,17 @@ enum napot_cont_order {
>  #define napot_cont_mask(order)	(~(napot_cont_size(order) - 1UL))
>  #define napot_pte_num(order)	BIT(order)
>  
> +static inline bool is_napot_order(unsigned int order)
> +{
> +	unsigned int napot_order;
> +
> +	for_each_napot_order(napot_order)
> +		if (order == napot_order)
> +			return true;
> +
> +	return false;
> +}
> +
>  #ifdef CONFIG_RISCV_ISA_SVNAPOT
>  #define HUGE_MAX_HSTATE		(2 + (NAPOT_ORDER_MAX - NAPOT_CONT_ORDER_BASE))
>  #else
> diff --git a/arch/riscv/include/asm/pgtable.h b/arch/riscv/include/asm/pgtable.h
> index 9f8ea0e33eb1..268c828f5152 100644
> --- a/arch/riscv/include/asm/pgtable.h
> +++ b/arch/riscv/include/asm/pgtable.h
> @@ -297,6 +297,8 @@ static inline unsigned long pte_napot(pte_t pte)
>  	return pte_val(pte) & _PAGE_NAPOT;
>  }
>  
> +#define pte_valid_napot(pte)	(pte_present(pte) && pte_napot(pte))
> +
>  static inline pte_t pte_mknapot(pte_t pte, unsigned int order)
>  {
>  	int pos = order - 1 + _PAGE_PFN_SHIFT;
> @@ -306,6 +308,12 @@ static inline pte_t pte_mknapot(pte_t pte, unsigned int order)
>  	return __pte((pte_val(pte) & napot_mask) | napot_bit | _PAGE_NAPOT);
>  }
>  
> +/* pte at entry must *not* encode the mapping size in the pfn LSBs. */
> +static inline pte_t pte_clear_napot(pte_t pte)
> +{
> +	return __pte(pte_val(pte) & ~_PAGE_NAPOT);
> +}
> +
>  #else
>  
>  static __always_inline bool has_svnapot(void) { return false; }
> @@ -315,17 +323,14 @@ static inline unsigned long pte_napot(pte_t pte)
>  	return 0;
>  }
>  
> +#define pte_valid_napot(pte)	false
> +
>  #endif /* CONFIG_RISCV_ISA_SVNAPOT */
>  
>  /* Yields the page frame number (PFN) of a page table entry */
>  static inline unsigned long pte_pfn(pte_t pte)
>  {
> -	unsigned long res  = __page_val_to_pfn(pte_val(pte));
> -
> -	if (has_svnapot() && pte_napot(pte))
> -		res = res & (res - 1UL);
> -
> -	return res;
> +	return __page_val_to_pfn(pte_val(pte));
>  }
>  
>  #define pte_page(x)     pfn_to_page(pte_pfn(x))
> @@ -525,9 +530,91 @@ static inline void __set_pte_at(struct mm_struct *mm, pte_t *ptep, pte_t pteval)
>  
>  #define PFN_PTE_SHIFT		_PAGE_PFN_SHIFT
>  
> +#ifdef CONFIG_RISCV_ISA_SVNAPOT
> +static inline int arch_contpte_get_num_contig(pte_t *ptep, unsigned long size,
> +					      size_t *pgsize)
> +{
> +	pte_t __pte;
> +
> +	/* We must read the raw value of the pte to get the size of the mapping */
> +	__pte = READ_ONCE(*ptep);
> +
> +	if (pgsize) {
> +		if (size >= PGDIR_SIZE)
> +			*pgsize = PGDIR_SIZE;
> +		else if (size >= P4D_SIZE)
> +			*pgsize = P4D_SIZE;
> +		else if (size >= PUD_SIZE)
> +			*pgsize = PUD_SIZE;
> +		else if (size >= PMD_SIZE)
> +			*pgsize = PMD_SIZE;
> +		else
> +			*pgsize = PAGE_SIZE;
> +	}
> +
> +	/* Make sure __pte is not a swap entry */
> +	if (pte_valid_napot(__pte))
> +		return napot_pte_num(napot_cont_order(__pte));
> +
> +	return 1;
> +}
> +#endif
> +
> +static inline pte_t ptep_get(pte_t *ptep)
> +{
> +	pte_t pte = READ_ONCE(*ptep);
> +
> +#ifdef CONFIG_RISCV_ISA_SVNAPOT
> +	/*
> +	 * The pte we load has the N bit set and the size of the mapping in
> +	 * the pfn LSBs: keep the N bit and replace the mapping size with
> +	 * the *real* pfn since the core mm code expects to find it there.
> +	 * The mapping size will be reset just before being written to the
> +	 * page table in set_ptes().
> +	 */
> +	if (unlikely(pte_valid_napot(pte))) {
> +		unsigned int order = napot_cont_order(pte);
> +		int pos = order - 1 + _PAGE_PFN_SHIFT;
> +		unsigned long napot_mask = ~GENMASK(pos, _PAGE_PFN_SHIFT);
> +		pte_t *orig_ptep = PTR_ALIGN_DOWN(ptep, sizeof(*ptep) * napot_pte_num(order));
> +
> +		pte = __pte((pte_val(pte) & napot_mask) + ((ptep - orig_ptep) << _PAGE_PFN_SHIFT));
> +	}
> +#endif
> +
> +	return pte;
> +}
> +#define ptep_get	ptep_get
> +
>  static inline void set_ptes(struct mm_struct *mm, unsigned long addr,
>  		pte_t *ptep, pte_t pteval, unsigned int nr)
>  {
> +#ifdef CONFIG_RISCV_ISA_SVNAPOT
> +	if (unlikely(pte_valid_napot(pteval))) {
> +		unsigned int order = ilog2(nr);
> +
> +		if (!is_napot_order(order)) {
> +			/*
> +			 * Something's weird, we are given a NAPOT pte but the
> +			 * size of the mapping is not a known NAPOT mapping
> +			 * size, so clear the NAPOT bit and map this without
> +			 * NAPOT support: core mm only manipulates pte with the
> +			 * real pfn so we know the pte is valid without the N
> +			 * bit.
> +			 */
> +			pr_err("Incorrect NAPOT mapping, resetting.\n");
> +			pteval = pte_clear_napot(pteval);
> +		} else {
> +			/*
> +			 * NAPOT ptes that arrive here only have the N bit set
> +			 * and their pfn does not contain the mapping size, so
> +			 * set that here.
> +			 */
> +			pteval = pte_mknapot(pteval, order);
> +		}
> +	}
> +#endif

I think all this complexity comes along due to using this function both as a
public interface that the core-mm uses (which never sets napot) and also using
it as an internal interface that riscv-hugetlb uses (which does set napot)? It
might be more understandable if you layer it into a lower level/internal API and
a higher level/public API (similar to arm64)?
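
One hypothetical shape for that layering (names are illustrative and
loosely modelled on the arm64 contpte split; a sketch, not the eventual
implementation):

static inline void __set_ptes(struct mm_struct *mm, unsigned long addr,
			      pte_t *ptep, pte_t pteval, unsigned int nr)
{
	page_table_check_ptes_set(mm, ptep, pteval, nr);

	for (;;) {
		__set_pte_at(mm, ptep, pteval);
		if (--nr == 0)
			break;
		ptep++;
		/* every pte of a NAPOT region stores the same encoded pfn */
		if (!pte_valid_napot(pteval))
			pte_val(pteval) += 1 << _PAGE_PFN_SHIFT;
	}
}

/*
 * Public API: core mm passes ptes whose pfn field holds the real pfn
 * (see ptep_get()), so re-encode the mapping size before writing.
 */
static inline void set_ptes(struct mm_struct *mm, unsigned long addr,
			    pte_t *ptep, pte_t pteval, unsigned int nr)
{
	if (unlikely(pte_valid_napot(pteval)))
		pteval = pte_mknapot(pteval, ilog2(nr));

	__set_ptes(mm, addr, ptep, pteval, nr);
}

The riscv hugetlb code could then call __set_ptes() directly with a pte it
has already encoded via pte_mknapot(), keeping the is_napot_order() sanity
check out of the internal path.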

> +
>  	page_table_check_ptes_set(mm, ptep, pteval, nr);
>  
>  	for (;;) {
> @@ -535,6 +622,12 @@ static inline void set_ptes(struct mm_struct *mm, unsigned long addr,
>  		if (--nr == 0)
>  			break;
>  		ptep++;
> +
> +#ifdef CONFIG_RISCV_ISA_SVNAPOT
> +		if (unlikely(pte_valid_napot(pteval)))
> +			continue;
> +#endif
> +
>  		pte_val(pteval) += 1 << _PAGE_PFN_SHIFT;
>  	}
>  }
> diff --git a/arch/riscv/mm/hugetlbpage.c b/arch/riscv/mm/hugetlbpage.c
> index 5ef2a6891158..fe8067ee71b4 100644
> --- a/arch/riscv/mm/hugetlbpage.c
> +++ b/arch/riscv/mm/hugetlbpage.c
> @@ -256,8 +256,7 @@ void set_huge_pte_at(struct mm_struct *mm,
>  
>  	clear_flush(mm, addr, ptep, pgsize, pte_num);
>  
> -	for (i = 0; i < pte_num; i++, ptep++, addr += pgsize)
> -		set_pte_at(mm, addr, ptep, pte);
> +	set_ptes(mm, addr, ptep, pte, pte_num);
>  }
>  
>  int huge_ptep_set_access_flags(struct vm_area_struct *vma,
> @@ -267,16 +266,16 @@ int huge_ptep_set_access_flags(struct vm_area_struct *vma,
>  			       int dirty)
>  {
>  	struct mm_struct *mm = vma->vm_mm;
> -	unsigned long order;
> +	size_t pgsize;
>  	pte_t orig_pte;
> -	int i, pte_num;
> +	int pte_num;
>  
>  	if (!pte_napot(pte))
>  		return ptep_set_access_flags(vma, addr, ptep, pte, dirty);
>  
> -	order = napot_cont_order(pte);
> -	pte_num = napot_pte_num(order);
> -	ptep = huge_pte_offset(mm, addr, napot_cont_size(order));
> +	pte_num = arch_contpte_get_num_contig(ptep, 0, &pgsize);
> +	ptep = huge_pte_offset(mm, addr, pte_num * pgsize);
> +
>  	orig_pte = get_clear_contig_flush(mm, addr, ptep, pte_num);
>  
>  	if (pte_dirty(orig_pte))
> @@ -285,8 +284,7 @@ int huge_ptep_set_access_flags(struct vm_area_struct *vma,
>  	if (pte_young(orig_pte))
>  		pte = pte_mkyoung(pte);
>  
> -	for (i = 0; i < pte_num; i++, addr += PAGE_SIZE, ptep++)
> -		set_pte_at(mm, addr, ptep, pte);
> +	set_ptes(mm, addr, ptep, pte, pte_num);
>  
>  	return true;
>  }
> @@ -301,7 +299,7 @@ pte_t huge_ptep_get_and_clear(struct mm_struct *mm,
>  	if (!pte_napot(orig_pte))
>  		return ptep_get_and_clear(mm, addr, ptep);
>  
> -	pte_num = napot_pte_num(napot_cont_order(orig_pte));
> +	pte_num = arch_contpte_get_num_contig(ptep, 0, NULL);
>  
>  	return get_clear_contig(mm, addr, ptep, pte_num);
>  }
> @@ -311,24 +309,23 @@ void huge_ptep_set_wrprotect(struct mm_struct *mm,
>  			     pte_t *ptep)
>  {
>  	pte_t pte = ptep_get(ptep);
> -	unsigned long order;
> +	size_t pgsize;
>  	pte_t orig_pte;
> -	int i, pte_num;
> +	int pte_num;
>  
>  	if (!pte_napot(pte)) {
>  		ptep_set_wrprotect(mm, addr, ptep);
>  		return;
>  	}
>  
> -	order = napot_cont_order(pte);
> -	pte_num = napot_pte_num(order);
> -	ptep = huge_pte_offset(mm, addr, napot_cont_size(order));
> +	pte_num = arch_contpte_get_num_contig(ptep, 0, &pgsize);
> +	ptep = huge_pte_offset(mm, addr, pte_num * pgsize);
> +
>  	orig_pte = get_clear_contig_flush(mm, addr, ptep, pte_num);
>  
>  	orig_pte = pte_wrprotect(orig_pte);
>  
> -	for (i = 0; i < pte_num; i++, addr += PAGE_SIZE, ptep++)
> -		set_pte_at(mm, addr, ptep, orig_pte);
> +	set_ptes(mm, addr, ptep, orig_pte, pte_num);
>  }
>  
>  pte_t huge_ptep_clear_flush(struct vm_area_struct *vma,
> @@ -341,7 +338,7 @@ pte_t huge_ptep_clear_flush(struct vm_area_struct *vma,
>  	if (!pte_napot(pte))
>  		return ptep_clear_flush(vma, addr, ptep);
>  
> -	pte_num = napot_pte_num(napot_cont_order(pte));
> +	pte_num = arch_contpte_get_num_contig(ptep, 0, NULL);
>  
>  	return get_clear_contig_flush(vma->vm_mm, addr, ptep, pte_num);
>  }
> @@ -351,6 +348,7 @@ void huge_pte_clear(struct mm_struct *mm,
>  		    pte_t *ptep,
>  		    unsigned long sz)
>  {
> +	size_t pgsize;
>  	pte_t pte = ptep_get(ptep);
>  	int i, pte_num;
>  
> @@ -359,8 +357,8 @@ void huge_pte_clear(struct mm_struct *mm,
>  		return;
>  	}
>  
> -	pte_num = napot_pte_num(napot_cont_order(pte));
> -	for (i = 0; i < pte_num; i++, addr += PAGE_SIZE, ptep++)
> +	pte_num = arch_contpte_get_num_contig(ptep, 0, &pgsize);
> +	for (i = 0; i < pte_num; i++, addr += pgsize, ptep++)
>  		pte_clear(mm, addr, ptep);
>  }
>
Alexandre Ghiti May 13, 2024, 1:06 p.m. UTC | #2
Hi Ryan,

On Fri, May 10, 2024 at 2:20 PM Ryan Roberts <ryan.roberts@arm.com> wrote:
>
> On 08/05/2024 12:34, Alexandre Ghiti wrote:
> > The core mm code expects to be able to extract the pfn from a pte. NAPOT
> > mappings work differently: all the ptes of a NAPOT mapping point to the
> > first pfn of the mapping, and the pfn LSBs encode the size of the mapping.
> >
> > So modify ptep_get() so that it returns a pte value that contains the
> > *real* pfn (which is then different from what the HW expects) and, right
> > before storing the ptes to the page table, reset the pfn LSBs to the
> > size of the mapping.
>
> Did you consider leaving the pte as is and instead modifying your pte_pfn()
> implementation?
>
> For arm64 at least, it is beneficial to keep the pte marked as contiguous when
> passing it up to core-mm because there are other helpers which need to parse the
> contiguous bit (e.g. pte_leaf_size()). If we were to clear the cont bit in
> ptep_get() that info would be lost and perf_get_pgtable_size() would always
> conclude the leaf size is 4K even when it is actually 64K.

I don't clear the contpte bit here (i.e. the napot bit), I'm just
setting the right pfn so that the core-mm code knows exactly which
page is targeted by each pte of a contpte region (remember the riscv
napot extension uses the LSBs of the pte pfn to encode the size of the
mapping, so all ptes of a contpte region will return the same pfn).

And from pte_pfn(), we have no way of knowing from the pte value alone
which page is targeted; we need to know the pte's position in the page
table to "guess" the right pfn.

>
> >
> > And make sure that all NAPOT mappings are set using set_ptes().
> >
> > Signed-off-by: Alexandre Ghiti <alexghiti@rivosinc.com>
> > ---
> >  arch/riscv/include/asm/pgtable-64.h |  11 +++
> >  arch/riscv/include/asm/pgtable.h    | 105 ++++++++++++++++++++++++++--
> >  arch/riscv/mm/hugetlbpage.c         |  38 +++++-----
> >  3 files changed, 128 insertions(+), 26 deletions(-)
> >
> > diff --git a/arch/riscv/include/asm/pgtable-64.h b/arch/riscv/include/asm/pgtable-64.h
> > index 221a5c1ee287..9fe076fc503e 100644
> > --- a/arch/riscv/include/asm/pgtable-64.h
> > +++ b/arch/riscv/include/asm/pgtable-64.h
> > @@ -106,6 +106,17 @@ enum napot_cont_order {
> >  #define napot_cont_mask(order)       (~(napot_cont_size(order) - 1UL))
> >  #define napot_pte_num(order) BIT(order)
> >
> > +static inline bool is_napot_order(unsigned int order)
> > +{
> > +     unsigned int napot_order;
> > +
> > +     for_each_napot_order(napot_order)
> > +             if (order == napot_order)
> > +                     return true;
> > +
> > +     return false;
> > +}
> > +
> >  #ifdef CONFIG_RISCV_ISA_SVNAPOT
> >  #define HUGE_MAX_HSTATE              (2 + (NAPOT_ORDER_MAX - NAPOT_CONT_ORDER_BASE))
> >  #else
> > diff --git a/arch/riscv/include/asm/pgtable.h b/arch/riscv/include/asm/pgtable.h
> > index 9f8ea0e33eb1..268c828f5152 100644
> > --- a/arch/riscv/include/asm/pgtable.h
> > +++ b/arch/riscv/include/asm/pgtable.h
> > @@ -297,6 +297,8 @@ static inline unsigned long pte_napot(pte_t pte)
> >       return pte_val(pte) & _PAGE_NAPOT;
> >  }
> >
> > +#define pte_valid_napot(pte) (pte_present(pte) && pte_napot(pte))
> > +
> >  static inline pte_t pte_mknapot(pte_t pte, unsigned int order)
> >  {
> >       int pos = order - 1 + _PAGE_PFN_SHIFT;
> > @@ -306,6 +308,12 @@ static inline pte_t pte_mknapot(pte_t pte, unsigned int order)
> >       return __pte((pte_val(pte) & napot_mask) | napot_bit | _PAGE_NAPOT);
> >  }
> >
> > +/* pte at entry must *not* encode the mapping size in the pfn LSBs. */
> > +static inline pte_t pte_clear_napot(pte_t pte)
> > +{
> > +     return __pte(pte_val(pte) & ~_PAGE_NAPOT);
> > +}
> > +
> >  #else
> >
> >  static __always_inline bool has_svnapot(void) { return false; }
> > @@ -315,17 +323,14 @@ static inline unsigned long pte_napot(pte_t pte)
> >       return 0;
> >  }
> >
> > +#define pte_valid_napot(pte) false
> > +
> >  #endif /* CONFIG_RISCV_ISA_SVNAPOT */
> >
> >  /* Yields the page frame number (PFN) of a page table entry */
> >  static inline unsigned long pte_pfn(pte_t pte)
> >  {
> > -     unsigned long res  = __page_val_to_pfn(pte_val(pte));
> > -
> > -     if (has_svnapot() && pte_napot(pte))
> > -             res = res & (res - 1UL);
> > -
> > -     return res;
> > +     return __page_val_to_pfn(pte_val(pte));
> >  }
> >
> >  #define pte_page(x)     pfn_to_page(pte_pfn(x))
> > @@ -525,9 +530,91 @@ static inline void __set_pte_at(struct mm_struct *mm, pte_t *ptep, pte_t pteval)
> >
> >  #define PFN_PTE_SHIFT                _PAGE_PFN_SHIFT
> >
> > +#ifdef CONFIG_RISCV_ISA_SVNAPOT
> > +static inline int arch_contpte_get_num_contig(pte_t *ptep, unsigned long size,
> > +                                           size_t *pgsize)
> > +{
> > +     pte_t __pte;
> > +
> > +     /* We must read the raw value of the pte to get the size of the mapping */
> > +     __pte = READ_ONCE(*ptep);
> > +
> > +     if (pgsize) {
> > +             if (size >= PGDIR_SIZE)
> > +                     *pgsize = PGDIR_SIZE;
> > +             else if (size >= P4D_SIZE)
> > +                     *pgsize = P4D_SIZE;
> > +             else if (size >= PUD_SIZE)
> > +                     *pgsize = PUD_SIZE;
> > +             else if (size >= PMD_SIZE)
> > +                     *pgsize = PMD_SIZE;
> > +             else
> > +                     *pgsize = PAGE_SIZE;
> > +     }
> > +
> > +     /* Make sure __pte is not a swap entry */
> > +     if (pte_valid_napot(__pte))
> > +             return napot_pte_num(napot_cont_order(__pte));
> > +
> > +     return 1;
> > +}
> > +#endif
> > +
> > +static inline pte_t ptep_get(pte_t *ptep)
> > +{
> > +     pte_t pte = READ_ONCE(*ptep);
> > +
> > +#ifdef CONFIG_RISCV_ISA_SVNAPOT
> > +     /*
> > +      * The pte we load has the N bit set and the size of the mapping in
> > +      * the pfn LSBs: keep the N bit and replace the mapping size with
> > +      * the *real* pfn since the core mm code expects to find it there.
> > +      * The mapping size will be reset just before being written to the
> > +      * page table in set_ptes().
> > +      */
> > +     if (unlikely(pte_valid_napot(pte))) {
> > +             unsigned int order = napot_cont_order(pte);
> > +             int pos = order - 1 + _PAGE_PFN_SHIFT;
> > +             unsigned long napot_mask = ~GENMASK(pos, _PAGE_PFN_SHIFT);
> > +             pte_t *orig_ptep = PTR_ALIGN_DOWN(ptep, sizeof(*ptep) * napot_pte_num(order));
> > +
> > +             pte = __pte((pte_val(pte) & napot_mask) + ((ptep - orig_ptep) << _PAGE_PFN_SHIFT));
> > +     }
> > +#endif
> > +
> > +     return pte;
> > +}
> > +#define ptep_get     ptep_get
> > +
> >  static inline void set_ptes(struct mm_struct *mm, unsigned long addr,
> >               pte_t *ptep, pte_t pteval, unsigned int nr)
> >  {
> > +#ifdef CONFIG_RISCV_ISA_SVNAPOT
> > +     if (unlikely(pte_valid_napot(pteval))) {
> > +             unsigned int order = ilog2(nr);
> > +
> > +             if (!is_napot_order(order)) {
> > +                     /*
> > +                      * Something's weird, we are given a NAPOT pte but the
> > +                      * size of the mapping is not a known NAPOT mapping
> > +                      * size, so clear the NAPOT bit and map this without
> > +                      * NAPOT support: core mm only manipulates pte with the
> > +                      * real pfn so we know the pte is valid without the N
> > +                      * bit.
> > +                      */
> > +                     pr_err("Incorrect NAPOT mapping, resetting.\n");
> > +                     pteval = pte_clear_napot(pteval);
> > +             } else {
> > +                     /*
> > +                      * NAPOT ptes that arrive here only have the N bit set
> > +                      * and their pfn does not contain the mapping size, so
> > +                      * set that here.
> > +                      */
> > +                     pteval = pte_mknapot(pteval, order);
> > +             }
> > +     }
> > +#endif
>
> I think all this complexity comes along due to using this function both as a
> public interface that the core-mm uses (which never sets napot)
> and also using
> it as an internal interface that riscv-hugetlb uses (which does set napot)? It
> might be more understandable if you layer it into a lower level/internal API and
> a higher level/public API (similar to arm64)?

I think you're right here, I'll try to do that too.

Thanks for your comments,

Alex

>
> > +
> >       page_table_check_ptes_set(mm, ptep, pteval, nr);
> >
> >       for (;;) {
> > @@ -535,6 +622,12 @@ static inline void set_ptes(struct mm_struct *mm, unsigned long addr,
> >               if (--nr == 0)
> >                       break;
> >               ptep++;
> > +
> > +#ifdef CONFIG_RISCV_ISA_SVNAPOT
> > +             if (unlikely(pte_valid_napot(pteval)))
> > +                     continue;
> > +#endif
> > +
> >               pte_val(pteval) += 1 << _PAGE_PFN_SHIFT;
> >       }
> >  }
> > diff --git a/arch/riscv/mm/hugetlbpage.c b/arch/riscv/mm/hugetlbpage.c
> > index 5ef2a6891158..fe8067ee71b4 100644
> > --- a/arch/riscv/mm/hugetlbpage.c
> > +++ b/arch/riscv/mm/hugetlbpage.c
> > @@ -256,8 +256,7 @@ void set_huge_pte_at(struct mm_struct *mm,
> >
> >       clear_flush(mm, addr, ptep, pgsize, pte_num);
> >
> > -     for (i = 0; i < pte_num; i++, ptep++, addr += pgsize)
> > -             set_pte_at(mm, addr, ptep, pte);
> > +     set_ptes(mm, addr, ptep, pte, pte_num);
> >  }
> >
> >  int huge_ptep_set_access_flags(struct vm_area_struct *vma,
> > @@ -267,16 +266,16 @@ int huge_ptep_set_access_flags(struct vm_area_struct *vma,
> >                              int dirty)
> >  {
> >       struct mm_struct *mm = vma->vm_mm;
> > -     unsigned long order;
> > +     size_t pgsize;
> >       pte_t orig_pte;
> > -     int i, pte_num;
> > +     int pte_num;
> >
> >       if (!pte_napot(pte))
> >               return ptep_set_access_flags(vma, addr, ptep, pte, dirty);
> >
> > -     order = napot_cont_order(pte);
> > -     pte_num = napot_pte_num(order);
> > -     ptep = huge_pte_offset(mm, addr, napot_cont_size(order));
> > +     pte_num = arch_contpte_get_num_contig(ptep, 0, &pgsize);
> > +     ptep = huge_pte_offset(mm, addr, pte_num * pgsize);
> > +
> >       orig_pte = get_clear_contig_flush(mm, addr, ptep, pte_num);
> >
> >       if (pte_dirty(orig_pte))
> > @@ -285,8 +284,7 @@ int huge_ptep_set_access_flags(struct vm_area_struct *vma,
> >       if (pte_young(orig_pte))
> >               pte = pte_mkyoung(pte);
> >
> > -     for (i = 0; i < pte_num; i++, addr += PAGE_SIZE, ptep++)
> > -             set_pte_at(mm, addr, ptep, pte);
> > +     set_ptes(mm, addr, ptep, pte, pte_num);
> >
> >       return true;
> >  }
> > @@ -301,7 +299,7 @@ pte_t huge_ptep_get_and_clear(struct mm_struct *mm,
> >       if (!pte_napot(orig_pte))
> >               return ptep_get_and_clear(mm, addr, ptep);
> >
> > -     pte_num = napot_pte_num(napot_cont_order(orig_pte));
> > +     pte_num = arch_contpte_get_num_contig(ptep, 0, NULL);
> >
> >       return get_clear_contig(mm, addr, ptep, pte_num);
> >  }
> > @@ -311,24 +309,23 @@ void huge_ptep_set_wrprotect(struct mm_struct *mm,
> >                            pte_t *ptep)
> >  {
> >       pte_t pte = ptep_get(ptep);
> > -     unsigned long order;
> > +     size_t pgsize;
> >       pte_t orig_pte;
> > -     int i, pte_num;
> > +     int pte_num;
> >
> >       if (!pte_napot(pte)) {
> >               ptep_set_wrprotect(mm, addr, ptep);
> >               return;
> >       }
> >
> > -     order = napot_cont_order(pte);
> > -     pte_num = napot_pte_num(order);
> > -     ptep = huge_pte_offset(mm, addr, napot_cont_size(order));
> > +     pte_num = arch_contpte_get_num_contig(ptep, 0, &pgsize);
> > +     ptep = huge_pte_offset(mm, addr, pte_num * pgsize);
> > +
> >       orig_pte = get_clear_contig_flush(mm, addr, ptep, pte_num);
> >
> >       orig_pte = pte_wrprotect(orig_pte);
> >
> > -     for (i = 0; i < pte_num; i++, addr += PAGE_SIZE, ptep++)
> > -             set_pte_at(mm, addr, ptep, orig_pte);
> > +     set_ptes(mm, addr, ptep, orig_pte, pte_num);
> >  }
> >
> >  pte_t huge_ptep_clear_flush(struct vm_area_struct *vma,
> > @@ -341,7 +338,7 @@ pte_t huge_ptep_clear_flush(struct vm_area_struct *vma,
> >       if (!pte_napot(pte))
> >               return ptep_clear_flush(vma, addr, ptep);
> >
> > -     pte_num = napot_pte_num(napot_cont_order(pte));
> > +     pte_num = arch_contpte_get_num_contig(ptep, 0, NULL);
> >
> >       return get_clear_contig_flush(vma->vm_mm, addr, ptep, pte_num);
> >  }
> > @@ -351,6 +348,7 @@ void huge_pte_clear(struct mm_struct *mm,
> >                   pte_t *ptep,
> >                   unsigned long sz)
> >  {
> > +     size_t pgsize;
> >       pte_t pte = ptep_get(ptep);
> >       int i, pte_num;
> >
> > @@ -359,8 +357,8 @@ void huge_pte_clear(struct mm_struct *mm,
> >               return;
> >       }
> >
> > -     pte_num = napot_pte_num(napot_cont_order(pte));
> > -     for (i = 0; i < pte_num; i++, addr += PAGE_SIZE, ptep++)
> > +     pte_num = arch_contpte_get_num_contig(ptep, 0, &pgsize);
> > +     for (i = 0; i < pte_num; i++, addr += pgsize, ptep++)
> >               pte_clear(mm, addr, ptep);
> >  }
> >
>
Ryan Roberts May 13, 2024, 2:52 p.m. UTC | #3
On 13/05/2024 14:06, Alexandre Ghiti wrote:
> Hi Ryan,
> 
> On Fri, May 10, 2024 at 2:20 PM Ryan Roberts <ryan.roberts@arm.com> wrote:
>>
>> On 08/05/2024 12:34, Alexandre Ghiti wrote:
>>> The core mm code expects to be able to extract the pfn from a pte. NAPOT
>>> mappings work differently: all the ptes of a NAPOT mapping point to the
>>> first pfn of the mapping, and the pfn LSBs encode the size of the mapping.
>>>
>>> So modify ptep_get() so that it returns a pte value that contains the
>>> *real* pfn (which is then different from what the HW expects) and, right
>>> before storing the ptes to the page table, reset the pfn LSBs to the
>>> size of the mapping.
>>
>> Did you consider leaving the pte as is and instead modifying your pte_pfn()
>> implementation?
>>
>> For arm64 at least, it is beneficial to keep the pte marked as contiguous when
>> passing it up to core-mm because there are other helpers which need to parse the
>> contiguous bit (e.g. pte_leaf_size()). If we were to clear the cont bit in
>> ptep_get() that info would be lost and perf_get_pgtable_size() would always
>> conclude the leaf size is 4K even when it is actually 64K.
> 
> I don't clear the contpte bit here (i.e. the napot bit), I'm just
> setting the right pfn so that the core-mm code knows exactly which
> page is targeted by each pte of a contpte region (remember the riscv
> napot extension uses the LSBs of the pte pfn to encode the size of the
> mapping, so all ptes of a contpte region will return the same pfn).
> 
> And from pte_pfn(), we have no way of knowing from the pte value alone
> which page is targeted; we need to know the pte's position in the page
> table to "guess" the right pfn.

Ahh yes - good point!

> 
>>
>>>
>>> And make sure that all NAPOT mappings are set using set_ptes().
>>>
>>> Signed-off-by: Alexandre Ghiti <alexghiti@rivosinc.com>
>>> ---
>>>  arch/riscv/include/asm/pgtable-64.h |  11 +++
>>>  arch/riscv/include/asm/pgtable.h    | 105 ++++++++++++++++++++++++++--
>>>  arch/riscv/mm/hugetlbpage.c         |  38 +++++-----
>>>  3 files changed, 128 insertions(+), 26 deletions(-)
>>>
>>> diff --git a/arch/riscv/include/asm/pgtable-64.h b/arch/riscv/include/asm/pgtable-64.h
>>> index 221a5c1ee287..9fe076fc503e 100644
>>> --- a/arch/riscv/include/asm/pgtable-64.h
>>> +++ b/arch/riscv/include/asm/pgtable-64.h
>>> @@ -106,6 +106,17 @@ enum napot_cont_order {
>>>  #define napot_cont_mask(order)       (~(napot_cont_size(order) - 1UL))
>>>  #define napot_pte_num(order) BIT(order)
>>>
>>> +static inline bool is_napot_order(unsigned int order)
>>> +{
>>> +     unsigned int napot_order;
>>> +
>>> +     for_each_napot_order(napot_order)
>>> +             if (order == napot_order)
>>> +                     return true;
>>> +
>>> +     return false;
>>> +}
>>> +
>>>  #ifdef CONFIG_RISCV_ISA_SVNAPOT
>>>  #define HUGE_MAX_HSTATE              (2 + (NAPOT_ORDER_MAX - NAPOT_CONT_ORDER_BASE))
>>>  #else
>>> diff --git a/arch/riscv/include/asm/pgtable.h b/arch/riscv/include/asm/pgtable.h
>>> index 9f8ea0e33eb1..268c828f5152 100644
>>> --- a/arch/riscv/include/asm/pgtable.h
>>> +++ b/arch/riscv/include/asm/pgtable.h
>>> @@ -297,6 +297,8 @@ static inline unsigned long pte_napot(pte_t pte)
>>>       return pte_val(pte) & _PAGE_NAPOT;
>>>  }
>>>
>>> +#define pte_valid_napot(pte) (pte_present(pte) && pte_napot(pte))
>>> +
>>>  static inline pte_t pte_mknapot(pte_t pte, unsigned int order)
>>>  {
>>>       int pos = order - 1 + _PAGE_PFN_SHIFT;
>>> @@ -306,6 +308,12 @@ static inline pte_t pte_mknapot(pte_t pte, unsigned int order)
>>>       return __pte((pte_val(pte) & napot_mask) | napot_bit | _PAGE_NAPOT);
>>>  }
>>>
>>> +/* pte at entry must *not* encode the mapping size in the pfn LSBs. */
>>> +static inline pte_t pte_clear_napot(pte_t pte)
>>> +{
>>> +     return __pte(pte_val(pte) & ~_PAGE_NAPOT);
>>> +}
>>> +
>>>  #else
>>>
>>>  static __always_inline bool has_svnapot(void) { return false; }
>>> @@ -315,17 +323,14 @@ static inline unsigned long pte_napot(pte_t pte)
>>>       return 0;
>>>  }
>>>
>>> +#define pte_valid_napot(pte) false
>>> +
>>>  #endif /* CONFIG_RISCV_ISA_SVNAPOT */
>>>
>>>  /* Yields the page frame number (PFN) of a page table entry */
>>>  static inline unsigned long pte_pfn(pte_t pte)
>>>  {
>>> -     unsigned long res  = __page_val_to_pfn(pte_val(pte));
>>> -
>>> -     if (has_svnapot() && pte_napot(pte))
>>> -             res = res & (res - 1UL);
>>> -
>>> -     return res;
>>> +     return __page_val_to_pfn(pte_val(pte));
>>>  }
>>>
>>>  #define pte_page(x)     pfn_to_page(pte_pfn(x))
>>> @@ -525,9 +530,91 @@ static inline void __set_pte_at(struct mm_struct *mm, pte_t *ptep, pte_t pteval)
>>>
>>>  #define PFN_PTE_SHIFT                _PAGE_PFN_SHIFT
>>>
>>> +#ifdef CONFIG_RISCV_ISA_SVNAPOT
>>> +static inline int arch_contpte_get_num_contig(pte_t *ptep, unsigned long size,
>>> +                                           size_t *pgsize)
>>> +{
>>> +     pte_t __pte;
>>> +
>>> +     /* We must read the raw value of the pte to get the size of the mapping */
>>> +     __pte = READ_ONCE(*ptep);
>>> +
>>> +     if (pgsize) {
>>> +             if (size >= PGDIR_SIZE)
>>> +                     *pgsize = PGDIR_SIZE;
>>> +             else if (size >= P4D_SIZE)
>>> +                     *pgsize = P4D_SIZE;
>>> +             else if (size >= PUD_SIZE)
>>> +                     *pgsize = PUD_SIZE;
>>> +             else if (size >= PMD_SIZE)
>>> +                     *pgsize = PMD_SIZE;
>>> +             else
>>> +                     *pgsize = PAGE_SIZE;
>>> +     }
>>> +
>>> +     /* Make sure __pte is not a swap entry */
>>> +     if (pte_valid_napot(__pte))
>>> +             return napot_pte_num(napot_cont_order(__pte));
>>> +
>>> +     return 1;
>>> +}
>>> +#endif
>>> +
>>> +static inline pte_t ptep_get(pte_t *ptep)
>>> +{
>>> +     pte_t pte = READ_ONCE(*ptep);
>>> +
>>> +#ifdef CONFIG_RISCV_ISA_SVNAPOT
>>> +     /*
>>> +      * The pte we load has the N bit set and the size of the mapping in
>>> +      * the pfn LSBs: keep the N bit and replace the mapping size with
>>> +      * the *real* pfn since the core mm code expects to find it there.
>>> +      * The mapping size will be reset just before being written to the
>>> +      * page table in set_ptes().
>>> +      */
>>> +     if (unlikely(pte_valid_napot(pte))) {
>>> +             unsigned int order = napot_cont_order(pte);
>>> +             int pos = order - 1 + _PAGE_PFN_SHIFT;
>>> +             unsigned long napot_mask = ~GENMASK(pos, _PAGE_PFN_SHIFT);
>>> +             pte_t *orig_ptep = PTR_ALIGN_DOWN(ptep, sizeof(*ptep) * napot_pte_num(order));
>>> +
>>> +             pte = __pte((pte_val(pte) & napot_mask) + ((ptep - orig_ptep) << _PAGE_PFN_SHIFT));
>>> +     }
>>> +#endif
>>> +
>>> +     return pte;
>>> +}
>>> +#define ptep_get     ptep_get
>>> +
>>>  static inline void set_ptes(struct mm_struct *mm, unsigned long addr,
>>>               pte_t *ptep, pte_t pteval, unsigned int nr)
>>>  {
>>> +#ifdef CONFIG_RISCV_ISA_SVNAPOT
>>> +     if (unlikely(pte_valid_napot(pteval))) {
>>> +             unsigned int order = ilog2(nr);
>>> +
>>> +             if (!is_napot_order(order)) {
>>> +                     /*
>>> +                      * Something's weird, we are given a NAPOT pte but the
>>> +                      * size of the mapping is not a known NAPOT mapping
>>> +                      * size, so clear the NAPOT bit and map this without
>>> +                      * NAPOT support: core mm only manipulates pte with the
>>> +                      * real pfn so we know the pte is valid without the N
>>> +                      * bit.
>>> +                      */
>>> +                     pr_err("Incorrect NAPOT mapping, resetting.\n");
>>> +                     pteval = pte_clear_napot(pteval);
>>> +             } else {
>>> +                     /*
>>> +                      * NAPOT ptes that arrive here only have the N bit set
>>> +                      * and their pfn does not contain the mapping size, so
>>> +                      * set that here.
>>> +                      */
>>> +                     pteval = pte_mknapot(pteval, order);
>>> +             }
>>> +     }
>>> +#endif
>>
>> I think all this complexity comes along due to using this function both as a
>> public interface that the core-mm uses (which never sets napot)
>> and also using
>> it as an internal interface that riscv-hugetlb uses (which does set napot)? It
>> might be more understandable if you layer it into a lower level/internal API and
>> a higher level/public API (similar to arm64)?
> 
> I think you're right here, I'll try to do that too.
> 
> Thanks for your comments,
> 
> Alex
> 
>>
>>> +
>>>       page_table_check_ptes_set(mm, ptep, pteval, nr);
>>>
>>>       for (;;) {
>>> @@ -535,6 +622,12 @@ static inline void set_ptes(struct mm_struct *mm, unsigned long addr,
>>>               if (--nr == 0)
>>>                       break;
>>>               ptep++;
>>> +
>>> +#ifdef CONFIG_RISCV_ISA_SVNAPOT
>>> +             if (unlikely(pte_valid_napot(pteval)))
>>> +                     continue;
>>> +#endif
>>> +
>>>               pte_val(pteval) += 1 << _PAGE_PFN_SHIFT;
>>>       }
>>>  }
>>> diff --git a/arch/riscv/mm/hugetlbpage.c b/arch/riscv/mm/hugetlbpage.c
>>> index 5ef2a6891158..fe8067ee71b4 100644
>>> --- a/arch/riscv/mm/hugetlbpage.c
>>> +++ b/arch/riscv/mm/hugetlbpage.c
>>> @@ -256,8 +256,7 @@ void set_huge_pte_at(struct mm_struct *mm,
>>>
>>>       clear_flush(mm, addr, ptep, pgsize, pte_num);
>>>
>>> -     for (i = 0; i < pte_num; i++, ptep++, addr += pgsize)
>>> -             set_pte_at(mm, addr, ptep, pte);
>>> +     set_ptes(mm, addr, ptep, pte, pte_num);
>>>  }
>>>
>>>  int huge_ptep_set_access_flags(struct vm_area_struct *vma,
>>> @@ -267,16 +266,16 @@ int huge_ptep_set_access_flags(struct vm_area_struct *vma,
>>>                              int dirty)
>>>  {
>>>       struct mm_struct *mm = vma->vm_mm;
>>> -     unsigned long order;
>>> +     size_t pgsize;
>>>       pte_t orig_pte;
>>> -     int i, pte_num;
>>> +     int pte_num;
>>>
>>>       if (!pte_napot(pte))
>>>               return ptep_set_access_flags(vma, addr, ptep, pte, dirty);
>>>
>>> -     order = napot_cont_order(pte);
>>> -     pte_num = napot_pte_num(order);
>>> -     ptep = huge_pte_offset(mm, addr, napot_cont_size(order));
>>> +     pte_num = arch_contpte_get_num_contig(ptep, 0, &pgsize);
>>> +     ptep = huge_pte_offset(mm, addr, pte_num * pgsize);
>>> +
>>>       orig_pte = get_clear_contig_flush(mm, addr, ptep, pte_num);
>>>
>>>       if (pte_dirty(orig_pte))
>>> @@ -285,8 +284,7 @@ int huge_ptep_set_access_flags(struct vm_area_struct *vma,
>>>       if (pte_young(orig_pte))
>>>               pte = pte_mkyoung(pte);
>>>
>>> -     for (i = 0; i < pte_num; i++, addr += PAGE_SIZE, ptep++)
>>> -             set_pte_at(mm, addr, ptep, pte);
>>> +     set_ptes(mm, addr, ptep, pte, pte_num);
>>>
>>>       return true;
>>>  }
>>> @@ -301,7 +299,7 @@ pte_t huge_ptep_get_and_clear(struct mm_struct *mm,
>>>       if (!pte_napot(orig_pte))
>>>               return ptep_get_and_clear(mm, addr, ptep);
>>>
>>> -     pte_num = napot_pte_num(napot_cont_order(orig_pte));
>>> +     pte_num = arch_contpte_get_num_contig(ptep, 0, NULL);
>>>
>>>       return get_clear_contig(mm, addr, ptep, pte_num);
>>>  }
>>> @@ -311,24 +309,23 @@ void huge_ptep_set_wrprotect(struct mm_struct *mm,
>>>                            pte_t *ptep)
>>>  {
>>>       pte_t pte = ptep_get(ptep);
>>> -     unsigned long order;
>>> +     size_t pgsize;
>>>       pte_t orig_pte;
>>> -     int i, pte_num;
>>> +     int pte_num;
>>>
>>>       if (!pte_napot(pte)) {
>>>               ptep_set_wrprotect(mm, addr, ptep);
>>>               return;
>>>       }
>>>
>>> -     order = napot_cont_order(pte);
>>> -     pte_num = napot_pte_num(order);
>>> -     ptep = huge_pte_offset(mm, addr, napot_cont_size(order));
>>> +     pte_num = arch_contpte_get_num_contig(ptep, 0, &pgsize);
>>> +     ptep = huge_pte_offset(mm, addr, pte_num * pgsize);
>>> +
>>>       orig_pte = get_clear_contig_flush(mm, addr, ptep, pte_num);
>>>
>>>       orig_pte = pte_wrprotect(orig_pte);
>>>
>>> -     for (i = 0; i < pte_num; i++, addr += PAGE_SIZE, ptep++)
>>> -             set_pte_at(mm, addr, ptep, orig_pte);
>>> +     set_ptes(mm, addr, ptep, orig_pte, pte_num);
>>>  }
>>>
>>>  pte_t huge_ptep_clear_flush(struct vm_area_struct *vma,
>>> @@ -341,7 +338,7 @@ pte_t huge_ptep_clear_flush(struct vm_area_struct *vma,
>>>       if (!pte_napot(pte))
>>>               return ptep_clear_flush(vma, addr, ptep);
>>>
>>> -     pte_num = napot_pte_num(napot_cont_order(pte));
>>> +     pte_num = arch_contpte_get_num_contig(ptep, 0, NULL);
>>>
>>>       return get_clear_contig_flush(vma->vm_mm, addr, ptep, pte_num);
>>>  }
>>> @@ -351,6 +348,7 @@ void huge_pte_clear(struct mm_struct *mm,
>>>                   pte_t *ptep,
>>>                   unsigned long sz)
>>>  {
>>> +     size_t pgsize;
>>>       pte_t pte = ptep_get(ptep);
>>>       int i, pte_num;
>>>
>>> @@ -359,8 +357,8 @@ void huge_pte_clear(struct mm_struct *mm,
>>>               return;
>>>       }
>>>
>>> -     pte_num = napot_pte_num(napot_cont_order(pte));
>>> -     for (i = 0; i < pte_num; i++, addr += PAGE_SIZE, ptep++)
>>> +     pte_num = arch_contpte_get_num_contig(ptep, 0, &pgsize);
>>> +     for (i = 0; i < pte_num; i++, addr += pgsize, ptep++)
>>>               pte_clear(mm, addr, ptep);
>>>  }
>>>
>>

Patch

diff --git a/arch/riscv/include/asm/pgtable-64.h b/arch/riscv/include/asm/pgtable-64.h
index 221a5c1ee287..9fe076fc503e 100644
--- a/arch/riscv/include/asm/pgtable-64.h
+++ b/arch/riscv/include/asm/pgtable-64.h
@@ -106,6 +106,17 @@  enum napot_cont_order {
 #define napot_cont_mask(order)	(~(napot_cont_size(order) - 1UL))
 #define napot_pte_num(order)	BIT(order)
 
+static inline bool is_napot_order(unsigned int order)
+{
+	unsigned int napot_order;
+
+	for_each_napot_order(napot_order)
+		if (order == napot_order)
+			return true;
+
+	return false;
+}
+
 #ifdef CONFIG_RISCV_ISA_SVNAPOT
 #define HUGE_MAX_HSTATE		(2 + (NAPOT_ORDER_MAX - NAPOT_CONT_ORDER_BASE))
 #else
diff --git a/arch/riscv/include/asm/pgtable.h b/arch/riscv/include/asm/pgtable.h
index 9f8ea0e33eb1..268c828f5152 100644
--- a/arch/riscv/include/asm/pgtable.h
+++ b/arch/riscv/include/asm/pgtable.h
@@ -297,6 +297,8 @@  static inline unsigned long pte_napot(pte_t pte)
 	return pte_val(pte) & _PAGE_NAPOT;
 }
 
+#define pte_valid_napot(pte)	(pte_present(pte) && pte_napot(pte))
+
 static inline pte_t pte_mknapot(pte_t pte, unsigned int order)
 {
 	int pos = order - 1 + _PAGE_PFN_SHIFT;
@@ -306,6 +308,12 @@  static inline pte_t pte_mknapot(pte_t pte, unsigned int order)
 	return __pte((pte_val(pte) & napot_mask) | napot_bit | _PAGE_NAPOT);
 }
 
+/* pte at entry must *not* encode the mapping size in the pfn LSBs. */
+static inline pte_t pte_clear_napot(pte_t pte)
+{
+	return __pte(pte_val(pte) & ~_PAGE_NAPOT);
+}
+
 #else
 
 static __always_inline bool has_svnapot(void) { return false; }
@@ -315,17 +323,14 @@  static inline unsigned long pte_napot(pte_t pte)
 	return 0;
 }
 
+#define pte_valid_napot(pte)	false
+
 #endif /* CONFIG_RISCV_ISA_SVNAPOT */
 
 /* Yields the page frame number (PFN) of a page table entry */
 static inline unsigned long pte_pfn(pte_t pte)
 {
-	unsigned long res  = __page_val_to_pfn(pte_val(pte));
-
-	if (has_svnapot() && pte_napot(pte))
-		res = res & (res - 1UL);
-
-	return res;
+	return __page_val_to_pfn(pte_val(pte));
 }
 
 #define pte_page(x)     pfn_to_page(pte_pfn(x))
@@ -525,9 +530,91 @@  static inline void __set_pte_at(struct mm_struct *mm, pte_t *ptep, pte_t pteval)
 
 #define PFN_PTE_SHIFT		_PAGE_PFN_SHIFT
 
+#ifdef CONFIG_RISCV_ISA_SVNAPOT
+static inline int arch_contpte_get_num_contig(pte_t *ptep, unsigned long size,
+					      size_t *pgsize)
+{
+	pte_t __pte;
+
+	/* We must read the raw value of the pte to get the size of the mapping */
+	__pte = READ_ONCE(*ptep);
+
+	if (pgsize) {
+		if (size >= PGDIR_SIZE)
+			*pgsize = PGDIR_SIZE;
+		else if (size >= P4D_SIZE)
+			*pgsize = P4D_SIZE;
+		else if (size >= PUD_SIZE)
+			*pgsize = PUD_SIZE;
+		else if (size >= PMD_SIZE)
+			*pgsize = PMD_SIZE;
+		else
+			*pgsize = PAGE_SIZE;
+	}
+
+	/* Make sure __pte is not a swap entry */
+	if (pte_valid_napot(__pte))
+		return napot_pte_num(napot_cont_order(__pte));
+
+	return 1;
+}
+#endif
+
+static inline pte_t ptep_get(pte_t *ptep)
+{
+	pte_t pte = READ_ONCE(*ptep);
+
+#ifdef CONFIG_RISCV_ISA_SVNAPOT
+	/*
+	 * The pte we load has the N bit set and the size of the mapping in
+	 * the pfn LSBs: keep the N bit and replace the mapping size with
+	 * the *real* pfn since the core mm code expects to find it there.
+	 * The mapping size will be reset just before being written to the
+	 * page table in set_ptes().
+	 */
+	if (unlikely(pte_valid_napot(pte))) {
+		unsigned int order = napot_cont_order(pte);
+		int pos = order - 1 + _PAGE_PFN_SHIFT;
+		unsigned long napot_mask = ~GENMASK(pos, _PAGE_PFN_SHIFT);
+		pte_t *orig_ptep = PTR_ALIGN_DOWN(ptep, sizeof(*ptep) * napot_pte_num(order));
+
+		pte = __pte((pte_val(pte) & napot_mask) + ((ptep - orig_ptep) << _PAGE_PFN_SHIFT));
+	}
+#endif
+
+	return pte;
+}
+#define ptep_get	ptep_get
+
 static inline void set_ptes(struct mm_struct *mm, unsigned long addr,
 		pte_t *ptep, pte_t pteval, unsigned int nr)
 {
+#ifdef CONFIG_RISCV_ISA_SVNAPOT
+	if (unlikely(pte_valid_napot(pteval))) {
+		unsigned int order = ilog2(nr);
+
+		if (!is_napot_order(order)) {
+			/*
+			 * Something's weird, we are given a NAPOT pte but the
+			 * size of the mapping is not a known NAPOT mapping
+			 * size, so clear the NAPOT bit and map this without
+			 * NAPOT support: core mm only manipulates pte with the
+			 * real pfn so we know the pte is valid without the N
+			 * bit.
+			 */
+			pr_err("Incorrect NAPOT mapping, resetting.\n");
+			pteval = pte_clear_napot(pteval);
+		} else {
+			/*
+			 * NAPOT ptes that arrive here only have the N bit set
+			 * and their pfn does not contain the mapping size, so
+			 * set that here.
+			 */
+			pteval = pte_mknapot(pteval, order);
+		}
+	}
+#endif
+
 	page_table_check_ptes_set(mm, ptep, pteval, nr);
 
 	for (;;) {
@@ -535,6 +622,12 @@  static inline void set_ptes(struct mm_struct *mm, unsigned long addr,
 		if (--nr == 0)
 			break;
 		ptep++;
+
+#ifdef CONFIG_RISCV_ISA_SVNAPOT
+		if (unlikely(pte_valid_napot(pteval)))
+			continue;
+#endif
+
 		pte_val(pteval) += 1 << _PAGE_PFN_SHIFT;
 	}
 }
diff --git a/arch/riscv/mm/hugetlbpage.c b/arch/riscv/mm/hugetlbpage.c
index 5ef2a6891158..fe8067ee71b4 100644
--- a/arch/riscv/mm/hugetlbpage.c
+++ b/arch/riscv/mm/hugetlbpage.c
@@ -256,8 +256,7 @@  void set_huge_pte_at(struct mm_struct *mm,
 
 	clear_flush(mm, addr, ptep, pgsize, pte_num);
 
-	for (i = 0; i < pte_num; i++, ptep++, addr += pgsize)
-		set_pte_at(mm, addr, ptep, pte);
+	set_ptes(mm, addr, ptep, pte, pte_num);
 }
 
 int huge_ptep_set_access_flags(struct vm_area_struct *vma,
@@ -267,16 +266,16 @@  int huge_ptep_set_access_flags(struct vm_area_struct *vma,
 			       int dirty)
 {
 	struct mm_struct *mm = vma->vm_mm;
-	unsigned long order;
+	size_t pgsize;
 	pte_t orig_pte;
-	int i, pte_num;
+	int pte_num;
 
 	if (!pte_napot(pte))
 		return ptep_set_access_flags(vma, addr, ptep, pte, dirty);
 
-	order = napot_cont_order(pte);
-	pte_num = napot_pte_num(order);
-	ptep = huge_pte_offset(mm, addr, napot_cont_size(order));
+	pte_num = arch_contpte_get_num_contig(ptep, 0, &pgsize);
+	ptep = huge_pte_offset(mm, addr, pte_num * pgsize);
+
 	orig_pte = get_clear_contig_flush(mm, addr, ptep, pte_num);
 
 	if (pte_dirty(orig_pte))
@@ -285,8 +284,7 @@  int huge_ptep_set_access_flags(struct vm_area_struct *vma,
 	if (pte_young(orig_pte))
 		pte = pte_mkyoung(pte);
 
-	for (i = 0; i < pte_num; i++, addr += PAGE_SIZE, ptep++)
-		set_pte_at(mm, addr, ptep, pte);
+	set_ptes(mm, addr, ptep, pte, pte_num);
 
 	return true;
 }
@@ -301,7 +299,7 @@  pte_t huge_ptep_get_and_clear(struct mm_struct *mm,
 	if (!pte_napot(orig_pte))
 		return ptep_get_and_clear(mm, addr, ptep);
 
-	pte_num = napot_pte_num(napot_cont_order(orig_pte));
+	pte_num = arch_contpte_get_num_contig(ptep, 0, NULL);
 
 	return get_clear_contig(mm, addr, ptep, pte_num);
 }
@@ -311,24 +309,23 @@  void huge_ptep_set_wrprotect(struct mm_struct *mm,
 			     pte_t *ptep)
 {
 	pte_t pte = ptep_get(ptep);
-	unsigned long order;
+	size_t pgsize;
 	pte_t orig_pte;
-	int i, pte_num;
+	int pte_num;
 
 	if (!pte_napot(pte)) {
 		ptep_set_wrprotect(mm, addr, ptep);
 		return;
 	}
 
-	order = napot_cont_order(pte);
-	pte_num = napot_pte_num(order);
-	ptep = huge_pte_offset(mm, addr, napot_cont_size(order));
+	pte_num = arch_contpte_get_num_contig(ptep, 0, &pgsize);
+	ptep = huge_pte_offset(mm, addr, pte_num * pgsize);
+
 	orig_pte = get_clear_contig_flush(mm, addr, ptep, pte_num);
 
 	orig_pte = pte_wrprotect(orig_pte);
 
-	for (i = 0; i < pte_num; i++, addr += PAGE_SIZE, ptep++)
-		set_pte_at(mm, addr, ptep, orig_pte);
+	set_ptes(mm, addr, ptep, orig_pte, pte_num);
 }
 
 pte_t huge_ptep_clear_flush(struct vm_area_struct *vma,
@@ -341,7 +338,7 @@  pte_t huge_ptep_clear_flush(struct vm_area_struct *vma,
 	if (!pte_napot(pte))
 		return ptep_clear_flush(vma, addr, ptep);
 
-	pte_num = napot_pte_num(napot_cont_order(pte));
+	pte_num = arch_contpte_get_num_contig(ptep, 0, NULL);
 
 	return get_clear_contig_flush(vma->vm_mm, addr, ptep, pte_num);
 }
@@ -351,6 +348,7 @@  void huge_pte_clear(struct mm_struct *mm,
 		    pte_t *ptep,
 		    unsigned long sz)
 {
+	size_t pgsize;
 	pte_t pte = ptep_get(ptep);
 	int i, pte_num;
 
@@ -359,8 +357,8 @@  void huge_pte_clear(struct mm_struct *mm,
 		return;
 	}
 
-	pte_num = napot_pte_num(napot_cont_order(pte));
-	for (i = 0; i < pte_num; i++, addr += PAGE_SIZE, ptep++)
+	pte_num = arch_contpte_get_num_contig(ptep, 0, &pgsize);
+	for (i = 0; i < pte_num; i++, addr += pgsize, ptep++)
 		pte_clear(mm, addr, ptep);
 }