diff mbox series

[v5,2/2] arm64: mm: install KPTI nG mappings with MMU enabled

Message ID 20220609174320.4035379-3-ardb@kernel.org (mailing list archive)
State New, archived
Headers show
Series arm64: apply G-to-nG conversion for KPTI with MMU enabled | expand

Commit Message

Ard Biesheuvel June 9, 2022, 5:43 p.m. UTC
In cases where we unmap the kernel while running in user space, we rely
on ASIDs to distinguish the minimal trampoline from the full kernel
mapping, and this means we must use non-global attributes for those
mappings, to ensure they are scoped by ASID and will not hit in the TLB
inadvertently.

We only do this when needed, as this is generally more costly in terms
of TLB pressure, and so we boot without these non-global attributes, and
apply them to all existing kernel mappings once all CPUs are up and we
know whether or not the non-global attributes are needed. At this point,
we cannot simply unmap and remap the entire address space, so we have to
update all existing block and page descriptors in place.

Currently, we go through a lot of trouble to perform these updates with
the MMU and caches off, to avoid violating break before make (BBM) rules
imposed by the architecture. Since we make changes to page tables that
are not covered by the ID map, we gain access to those descriptors by
disabling translations altogether. This means that the stores to memory
are issued with device attributes, and require extra care in terms of
coherency, which is costly. We also rely on the ID map to access a
shared flag, which requires the ID map to be executable and writable at
the same time, which is another thing we'd prefer to avoid.

So let's switch to an approach where we replace the kernel mapping with
a minimal mapping of a few pages that can be used for a minimal, ad-hoc
fixmap that we can use to map each page table in turn as we traverse the
hierarchy.

Signed-off-by: Ard Biesheuvel <ardb@kernel.org>
---
 arch/arm64/kernel/cpufeature.c | 54 ++++++++++++-
 arch/arm64/mm/mmu.c            |  7 ++
 arch/arm64/mm/proc.S           | 81 +++++++++++++-------
 3 files changed, 113 insertions(+), 29 deletions(-)

Comments

Mark Rutland June 14, 2022, 9:01 a.m. UTC | #1
On Thu, Jun 09, 2022 at 07:43:20PM +0200, Ard Biesheuvel wrote:
> In cases where we unmap the kernel while running in user space, we rely
> on ASIDs to distinguish the minimal trampoline from the full kernel
> mapping, and this means we must use non-global attributes for those
> mappings, to ensure they are scoped by ASID and will not hit in the TLB
> inadvertently.
> 
> We only do this when needed, as this is generally more costly in terms
> of TLB pressure, and so we boot without these non-global attributes, and
> apply them to all existing kernel mappings once all CPUs are up and we
> know whether or not the non-global attributes are needed. At this point,
> we cannot simply unmap and remap the entire address space, so we have to
> update all existing block and page descriptors in place.
> 
> Currently, we go through a lot of trouble to perform these updates with
> the MMU and caches off, to avoid violating break before make (BBM) rules
> imposed by the architecture. Since we make changes to page tables that
> are not covered by the ID map, we gain access to those descriptors by
> disabling translations altogether. This means that the stores to memory
> are issued with device attributes, and require extra care in terms of
> coherency, which is costly. We also rely on the ID map to access a
> shared flag, which requires the ID map to be executable and writable at
> the same time, which is another thing we'd prefer to avoid.
> 
> So let's switch to an approach where we replace the kernel mapping with
> a minimal mapping of a few pages that can be used for a minimal, ad-hoc
> fixmap that we can use to map each page table in turn as we traverse the
> hierarchy.
> 
> Signed-off-by: Ard Biesheuvel <ardb@kernel.org>

This addresses all my concerns, so FWIW:

  Reviewed-by: Mark Rutland <mark.rutland@arm.com>

Thanks for respinning this, and sorry I didn't post the version I promised!

Mark.

> ---
>  arch/arm64/kernel/cpufeature.c | 54 ++++++++++++-
>  arch/arm64/mm/mmu.c            |  7 ++
>  arch/arm64/mm/proc.S           | 81 +++++++++++++-------
>  3 files changed, 113 insertions(+), 29 deletions(-)
> 
> diff --git a/arch/arm64/kernel/cpufeature.c b/arch/arm64/kernel/cpufeature.c
> index 42ea2bd856c6..c2a64c9e451e 100644
> --- a/arch/arm64/kernel/cpufeature.c
> +++ b/arch/arm64/kernel/cpufeature.c
> @@ -1645,14 +1645,34 @@ static bool unmap_kernel_at_el0(const struct arm64_cpu_capabilities *entry,
>  }
>  
>  #ifdef CONFIG_UNMAP_KERNEL_AT_EL0
> +#define KPTI_NG_TEMP_VA		(-(1UL << PMD_SHIFT))
> +
> +extern
> +void create_kpti_ng_temp_pgd(pgd_t *pgdir, phys_addr_t phys, unsigned long virt,
> +			     phys_addr_t size, pgprot_t prot,
> +			     phys_addr_t (*pgtable_alloc)(int), int flags);
> +
> +static phys_addr_t kpti_ng_temp_alloc;
> +
> +static phys_addr_t kpti_ng_pgd_alloc(int shift)
> +{
> +	kpti_ng_temp_alloc -= PAGE_SIZE;
> +	return kpti_ng_temp_alloc;
> +}
> +
>  static void __nocfi
>  kpti_install_ng_mappings(const struct arm64_cpu_capabilities *__unused)
>  {
> -	typedef void (kpti_remap_fn)(int, int, phys_addr_t);
> +	typedef void (kpti_remap_fn)(int, int, phys_addr_t, unsigned long);
>  	extern kpti_remap_fn idmap_kpti_install_ng_mappings;
>  	kpti_remap_fn *remap_fn;
>  
>  	int cpu = smp_processor_id();
> +	int levels = CONFIG_PGTABLE_LEVELS;
> +	int order = order_base_2(levels);
> +	u64 kpti_ng_temp_pgd_pa = 0;
> +	pgd_t *kpti_ng_temp_pgd;
> +	u64 alloc = 0;
>  
>  	if (__this_cpu_read(this_cpu_vector) == vectors) {
>  		const char *v = arm64_get_bp_hardening_vector(EL1_VECTOR_KPTI);
> @@ -1670,12 +1690,40 @@ kpti_install_ng_mappings(const struct arm64_cpu_capabilities *__unused)
>  
>  	remap_fn = (void *)__pa_symbol(function_nocfi(idmap_kpti_install_ng_mappings));
>  
> +	if (!cpu) {
> +		alloc = __get_free_pages(GFP_ATOMIC | __GFP_ZERO, order);
> +		kpti_ng_temp_pgd = (pgd_t *)(alloc + (levels - 1) * PAGE_SIZE);
> +		kpti_ng_temp_alloc = kpti_ng_temp_pgd_pa = __pa(kpti_ng_temp_pgd);
> +
> +		//
> +		// Create a minimal page table hierarchy that permits us to map
> +		// the swapper page tables temporarily as we traverse them.
> +		//
> +		// The physical pages are laid out as follows:
> +		//
> +		// +--------+-/-------+-/------ +-\\--------+
> +		// :  PTE[] : | PMD[] : | PUD[] : || PGD[]  :
> +		// +--------+-\-------+-\------ +-//--------+
> +		//      ^
> +		// The first page is mapped into this hierarchy at a PMD_SHIFT
> +		// aligned virtual address, so that we can manipulate the PTE
> +		// level entries while the mapping is active. The first entry
> +		// covers the PTE[] page itself, the remaining entries are free
> +		// to be used as a ad-hoc fixmap.
> +		//
> +		create_kpti_ng_temp_pgd(kpti_ng_temp_pgd, __pa(alloc),
> +					KPTI_NG_TEMP_VA, PAGE_SIZE, PAGE_KERNEL,
> +					kpti_ng_pgd_alloc, 0);
> +	}
> +
>  	cpu_install_idmap();
> -	remap_fn(cpu, num_online_cpus(), __pa_symbol(swapper_pg_dir));
> +	remap_fn(cpu, num_online_cpus(), kpti_ng_temp_pgd_pa, KPTI_NG_TEMP_VA);
>  	cpu_uninstall_idmap();
>  
> -	if (!cpu)
> +	if (!cpu) {
> +		free_pages(alloc, order);
>  		arm64_use_ng_mappings = true;
> +	}
>  }
>  #else
>  static void
> diff --git a/arch/arm64/mm/mmu.c b/arch/arm64/mm/mmu.c
> index be4d6c3f5692..c5563ff990da 100644
> --- a/arch/arm64/mm/mmu.c
> +++ b/arch/arm64/mm/mmu.c
> @@ -388,6 +388,13 @@ static void __create_pgd_mapping(pgd_t *pgdir, phys_addr_t phys,
>  	} while (pgdp++, addr = next, addr != end);
>  }
>  
> +#ifdef CONFIG_UNMAP_KERNEL_AT_EL0
> +extern __alias(__create_pgd_mapping)
> +void create_kpti_ng_temp_pgd(pgd_t *pgdir, phys_addr_t phys, unsigned long virt,
> +			     phys_addr_t size, pgprot_t prot,
> +			     phys_addr_t (*pgtable_alloc)(int), int flags);
> +#endif
> +
>  static phys_addr_t __pgd_pgtable_alloc(int shift)
>  {
>  	void *ptr = (void *)__get_free_page(GFP_PGTABLE_KERNEL);
> diff --git a/arch/arm64/mm/proc.S b/arch/arm64/mm/proc.S
> index 660887152dba..972ce8d7f2c5 100644
> --- a/arch/arm64/mm/proc.S
> +++ b/arch/arm64/mm/proc.S
> @@ -14,6 +14,7 @@
>  #include <asm/asm-offsets.h>
>  #include <asm/asm_pointer_auth.h>
>  #include <asm/hwcap.h>
> +#include <asm/kernel-pgtable.h>
>  #include <asm/pgtable-hwdef.h>
>  #include <asm/cpufeature.h>
>  #include <asm/alternative.h>
> @@ -200,20 +201,19 @@ SYM_FUNC_END(idmap_cpu_replace_ttbr1)
>  	.popsection
>  
>  #ifdef CONFIG_UNMAP_KERNEL_AT_EL0
> +
> +#define KPTI_NG_PTE_FLAGS	(PTE_ATTRINDX(MT_NORMAL) | SWAPPER_PTE_FLAGS)
> +
>  	.pushsection ".idmap.text", "awx"
>  
>  	.macro	kpti_mk_tbl_ng, type, num_entries
>  	add	end_\type\()p, cur_\type\()p, #\num_entries * 8
>  .Ldo_\type:
> -	dc	cvac, cur_\type\()p		// Ensure any existing dirty
> -	dmb	sy				// lines are written back before
> -	ldr	\type, [cur_\type\()p]		// loading the entry
> +	ldr	\type, [cur_\type\()p]		// Load the entry
>  	tbz	\type, #0, .Lnext_\type		// Skip invalid and
>  	tbnz	\type, #11, .Lnext_\type	// non-global entries
>  	orr	\type, \type, #PTE_NG		// Same bit for blocks and pages
> -	str	\type, [cur_\type\()p]		// Update the entry and ensure
> -	dmb	sy				// that it is visible to all
> -	dc	civac, cur_\()\type\()p		// CPUs.
> +	str	\type, [cur_\type\()p]		// Update the entry
>  	.ifnc	\type, pte
>  	tbnz	\type, #1, .Lderef_\type
>  	.endif
> @@ -223,8 +223,29 @@ SYM_FUNC_END(idmap_cpu_replace_ttbr1)
>  	b.ne	.Ldo_\type
>  	.endm
>  
> +	/*
> +	 * Dereference the current table entry and map it into the temporary
> +	 * fixmap slot associated with the current level.
> +	 */
> +	.macro	kpti_map_pgtbl, type, level
> +	str	xzr, [temp_pte, #8 * (\level + 1)]	// break before make
> +	dsb	nshst
> +	add	pte, temp_pte, #PAGE_SIZE * (\level + 1)
> +	lsr	pte, pte, #12
> +	tlbi	vaae1, pte
> +	dsb	nsh
> +	isb
> +
> +	phys_to_pte pte, cur_\type\()p
> +	add	cur_\type\()p, temp_pte, #PAGE_SIZE * (\level + 1)
> +	orr	pte, pte, pte_flags
> +	str	pte, [temp_pte, #8 * (\level + 1)]
> +	dsb	nshst
> +	.endm
> +
>  /*
> - * void __kpti_install_ng_mappings(int cpu, int num_cpus, phys_addr_t swapper)
> + * void __kpti_install_ng_mappings(int cpu, int num_secondaries, phys_addr_t temp_pgd,
> + *				   unsigned long temp_pte_va)
>   *
>   * Called exactly once from stop_machine context by each CPU found during boot.
>   */
> @@ -232,8 +253,10 @@ __idmap_kpti_flag:
>  	.long	1
>  SYM_FUNC_START(idmap_kpti_install_ng_mappings)
>  	cpu		.req	w0
> +	temp_pte	.req	x0
>  	num_cpus	.req	w1
> -	swapper_pa	.req	x2
> +	pte_flags	.req	x1
> +	temp_pgd_phys	.req	x2
>  	swapper_ttb	.req	x3
>  	flag_ptr	.req	x4
>  	cur_pgdp	.req	x5
> @@ -246,9 +269,10 @@ SYM_FUNC_START(idmap_kpti_install_ng_mappings)
>  	cur_ptep	.req	x14
>  	end_ptep	.req	x15
>  	pte		.req	x16
> +	valid		.req	x17
>  
> +	mov	x5, x3				// preserve temp_pte arg
>  	mrs	swapper_ttb, ttbr1_el1
> -	restore_ttbr1	swapper_ttb
>  	adr	flag_ptr, __idmap_kpti_flag
>  
>  	cbnz	cpu, __idmap_kpti_secondary
> @@ -260,28 +284,28 @@ SYM_FUNC_START(idmap_kpti_install_ng_mappings)
>  	eor	w17, w17, num_cpus
>  	cbnz	w17, 1b
>  
> -	/* We need to walk swapper, so turn off the MMU. */
> -	pre_disable_mmu_workaround
> -	mrs	x17, sctlr_el1
> -	bic	x17, x17, #SCTLR_ELx_M
> -	msr	sctlr_el1, x17
> +	/* Switch to the temporary page tables on this CPU only */
> +	__idmap_cpu_set_reserved_ttbr1 x8, x9
> +	offset_ttbr1 temp_pgd_phys, x8
> +	msr	ttbr1_el1, temp_pgd_phys
>  	isb
>  
> +	mov	temp_pte, x5
> +	mov	pte_flags, #KPTI_NG_PTE_FLAGS
> +
>  	/* Everybody is enjoying the idmap, so we can rewrite swapper. */
>  	/* PGD */
> -	mov		cur_pgdp, swapper_pa
> +	adrp		cur_pgdp, swapper_pg_dir
> +	kpti_map_pgtbl	pgd, 0
>  	kpti_mk_tbl_ng	pgd, PTRS_PER_PGD
>  
> -	/* Publish the updated tables and nuke all the TLBs */
> -	dsb	sy
> -	tlbi	vmalle1is
> -	dsb	ish
> -	isb
> +	/* Ensure all the updated entries are visible to secondary CPUs */
> +	dsb	ishst
>  
> -	/* We're done: fire up the MMU again */
> -	mrs	x17, sctlr_el1
> -	orr	x17, x17, #SCTLR_ELx_M
> -	set_sctlr_el1	x17
> +	/* We're done: fire up swapper_pg_dir again */
> +	__idmap_cpu_set_reserved_ttbr1 x8, x9
> +	msr	ttbr1_el1, swapper_ttb
> +	isb
>  
>  	/* Set the flag to zero to indicate that we're all done */
>  	str	wzr, [flag_ptr]
> @@ -292,6 +316,7 @@ SYM_FUNC_START(idmap_kpti_install_ng_mappings)
>  	.if		CONFIG_PGTABLE_LEVELS > 3
>  	pud		.req	x10
>  	pte_to_phys	cur_pudp, pgd
> +	kpti_map_pgtbl	pud, 1
>  	kpti_mk_tbl_ng	pud, PTRS_PER_PUD
>  	b		.Lnext_pgd
>  	.else		/* CONFIG_PGTABLE_LEVELS <= 3 */
> @@ -304,6 +329,7 @@ SYM_FUNC_START(idmap_kpti_install_ng_mappings)
>  	.if		CONFIG_PGTABLE_LEVELS > 2
>  	pmd		.req	x13
>  	pte_to_phys	cur_pmdp, pud
> +	kpti_map_pgtbl	pmd, 2
>  	kpti_mk_tbl_ng	pmd, PTRS_PER_PMD
>  	b		.Lnext_pud
>  	.else		/* CONFIG_PGTABLE_LEVELS <= 2 */
> @@ -314,12 +340,15 @@ SYM_FUNC_START(idmap_kpti_install_ng_mappings)
>  .Lderef_pmd:
>  	/* PTE */
>  	pte_to_phys	cur_ptep, pmd
> +	kpti_map_pgtbl	pte, 3
>  	kpti_mk_tbl_ng	pte, PTRS_PER_PTE
>  	b		.Lnext_pmd
>  
>  	.unreq	cpu
> +	.unreq	temp_pte
>  	.unreq	num_cpus
> -	.unreq	swapper_pa
> +	.unreq	pte_flags
> +	.unreq	temp_pgd_phys
>  	.unreq	cur_pgdp
>  	.unreq	end_pgdp
>  	.unreq	pgd
> @@ -332,6 +361,7 @@ SYM_FUNC_START(idmap_kpti_install_ng_mappings)
>  	.unreq	cur_ptep
>  	.unreq	end_ptep
>  	.unreq	pte
> +	.unreq	valid
>  
>  	/* Secondary CPUs end up here */
>  __idmap_kpti_secondary:
> @@ -351,7 +381,6 @@ __idmap_kpti_secondary:
>  	cbnz	w16, 1b
>  
>  	/* All done, act like nothing happened */
> -	offset_ttbr1 swapper_ttb, x16
>  	msr	ttbr1_el1, swapper_ttb
>  	isb
>  	ret
> -- 
> 2.30.2
>
Ard Biesheuvel June 14, 2022, 9:09 a.m. UTC | #2
On Tue, 14 Jun 2022 at 11:01, Mark Rutland <mark.rutland@arm.com> wrote:
>
> On Thu, Jun 09, 2022 at 07:43:20PM +0200, Ard Biesheuvel wrote:
> > In cases where we unmap the kernel while running in user space, we rely
> > on ASIDs to distinguish the minimal trampoline from the full kernel
> > mapping, and this means we must use non-global attributes for those
> > mappings, to ensure they are scoped by ASID and will not hit in the TLB
> > inadvertently.
> >
> > We only do this when needed, as this is generally more costly in terms
> > of TLB pressure, and so we boot without these non-global attributes, and
> > apply them to all existing kernel mappings once all CPUs are up and we
> > know whether or not the non-global attributes are needed. At this point,
> > we cannot simply unmap and remap the entire address space, so we have to
> > update all existing block and page descriptors in place.
> >
> > Currently, we go through a lot of trouble to perform these updates with
> > the MMU and caches off, to avoid violating break before make (BBM) rules
> > imposed by the architecture. Since we make changes to page tables that
> > are not covered by the ID map, we gain access to those descriptors by
> > disabling translations altogether. This means that the stores to memory
> > are issued with device attributes, and require extra care in terms of
> > coherency, which is costly. We also rely on the ID map to access a
> > shared flag, which requires the ID map to be executable and writable at
> > the same time, which is another thing we'd prefer to avoid.
> >
> > So let's switch to an approach where we replace the kernel mapping with
> > a minimal mapping of a few pages that can be used for a minimal, ad-hoc
> > fixmap that we can use to map each page table in turn as we traverse the
> > hierarchy.
> >
> > Signed-off-by: Ard Biesheuvel <ardb@kernel.org>
>
> This addresses all my concerns, so FWIW:
>
>   Reviewed-by: Mark Rutland <mark.rutland@arm.com>
>
> Thanks for respinning this, and sorry I didn't post the version I promised!
>

No worries - it seemed to me that we've both spent more time on this
than we should have, so I just went back to a more incremental
approach.
diff mbox series

Patch

diff --git a/arch/arm64/kernel/cpufeature.c b/arch/arm64/kernel/cpufeature.c
index 42ea2bd856c6..c2a64c9e451e 100644
--- a/arch/arm64/kernel/cpufeature.c
+++ b/arch/arm64/kernel/cpufeature.c
@@ -1645,14 +1645,34 @@  static bool unmap_kernel_at_el0(const struct arm64_cpu_capabilities *entry,
 }
 
 #ifdef CONFIG_UNMAP_KERNEL_AT_EL0
+#define KPTI_NG_TEMP_VA		(-(1UL << PMD_SHIFT))
+
+extern
+void create_kpti_ng_temp_pgd(pgd_t *pgdir, phys_addr_t phys, unsigned long virt,
+			     phys_addr_t size, pgprot_t prot,
+			     phys_addr_t (*pgtable_alloc)(int), int flags);
+
+static phys_addr_t kpti_ng_temp_alloc;
+
+static phys_addr_t kpti_ng_pgd_alloc(int shift)
+{
+	kpti_ng_temp_alloc -= PAGE_SIZE;
+	return kpti_ng_temp_alloc;
+}
+
 static void __nocfi
 kpti_install_ng_mappings(const struct arm64_cpu_capabilities *__unused)
 {
-	typedef void (kpti_remap_fn)(int, int, phys_addr_t);
+	typedef void (kpti_remap_fn)(int, int, phys_addr_t, unsigned long);
 	extern kpti_remap_fn idmap_kpti_install_ng_mappings;
 	kpti_remap_fn *remap_fn;
 
 	int cpu = smp_processor_id();
+	int levels = CONFIG_PGTABLE_LEVELS;
+	int order = order_base_2(levels);
+	u64 kpti_ng_temp_pgd_pa = 0;
+	pgd_t *kpti_ng_temp_pgd;
+	u64 alloc = 0;
 
 	if (__this_cpu_read(this_cpu_vector) == vectors) {
 		const char *v = arm64_get_bp_hardening_vector(EL1_VECTOR_KPTI);
@@ -1670,12 +1690,40 @@  kpti_install_ng_mappings(const struct arm64_cpu_capabilities *__unused)
 
 	remap_fn = (void *)__pa_symbol(function_nocfi(idmap_kpti_install_ng_mappings));
 
+	if (!cpu) {
+		alloc = __get_free_pages(GFP_ATOMIC | __GFP_ZERO, order);
+		kpti_ng_temp_pgd = (pgd_t *)(alloc + (levels - 1) * PAGE_SIZE);
+		kpti_ng_temp_alloc = kpti_ng_temp_pgd_pa = __pa(kpti_ng_temp_pgd);
+
+		//
+		// Create a minimal page table hierarchy that permits us to map
+		// the swapper page tables temporarily as we traverse them.
+		//
+		// The physical pages are laid out as follows:
+		//
+		// +--------+-/-------+-/------ +-\\--------+
+		// :  PTE[] : | PMD[] : | PUD[] : || PGD[]  :
+		// +--------+-\-------+-\------ +-//--------+
+		//      ^
+		// The first page is mapped into this hierarchy at a PMD_SHIFT
+		// aligned virtual address, so that we can manipulate the PTE
+		// level entries while the mapping is active. The first entry
+		// covers the PTE[] page itself, the remaining entries are free
+		// to be used as a ad-hoc fixmap.
+		//
+		create_kpti_ng_temp_pgd(kpti_ng_temp_pgd, __pa(alloc),
+					KPTI_NG_TEMP_VA, PAGE_SIZE, PAGE_KERNEL,
+					kpti_ng_pgd_alloc, 0);
+	}
+
 	cpu_install_idmap();
-	remap_fn(cpu, num_online_cpus(), __pa_symbol(swapper_pg_dir));
+	remap_fn(cpu, num_online_cpus(), kpti_ng_temp_pgd_pa, KPTI_NG_TEMP_VA);
 	cpu_uninstall_idmap();
 
-	if (!cpu)
+	if (!cpu) {
+		free_pages(alloc, order);
 		arm64_use_ng_mappings = true;
+	}
 }
 #else
 static void
diff --git a/arch/arm64/mm/mmu.c b/arch/arm64/mm/mmu.c
index be4d6c3f5692..c5563ff990da 100644
--- a/arch/arm64/mm/mmu.c
+++ b/arch/arm64/mm/mmu.c
@@ -388,6 +388,13 @@  static void __create_pgd_mapping(pgd_t *pgdir, phys_addr_t phys,
 	} while (pgdp++, addr = next, addr != end);
 }
 
+#ifdef CONFIG_UNMAP_KERNEL_AT_EL0
+extern __alias(__create_pgd_mapping)
+void create_kpti_ng_temp_pgd(pgd_t *pgdir, phys_addr_t phys, unsigned long virt,
+			     phys_addr_t size, pgprot_t prot,
+			     phys_addr_t (*pgtable_alloc)(int), int flags);
+#endif
+
 static phys_addr_t __pgd_pgtable_alloc(int shift)
 {
 	void *ptr = (void *)__get_free_page(GFP_PGTABLE_KERNEL);
diff --git a/arch/arm64/mm/proc.S b/arch/arm64/mm/proc.S
index 660887152dba..972ce8d7f2c5 100644
--- a/arch/arm64/mm/proc.S
+++ b/arch/arm64/mm/proc.S
@@ -14,6 +14,7 @@ 
 #include <asm/asm-offsets.h>
 #include <asm/asm_pointer_auth.h>
 #include <asm/hwcap.h>
+#include <asm/kernel-pgtable.h>
 #include <asm/pgtable-hwdef.h>
 #include <asm/cpufeature.h>
 #include <asm/alternative.h>
@@ -200,20 +201,19 @@  SYM_FUNC_END(idmap_cpu_replace_ttbr1)
 	.popsection
 
 #ifdef CONFIG_UNMAP_KERNEL_AT_EL0
+
+#define KPTI_NG_PTE_FLAGS	(PTE_ATTRINDX(MT_NORMAL) | SWAPPER_PTE_FLAGS)
+
 	.pushsection ".idmap.text", "awx"
 
 	.macro	kpti_mk_tbl_ng, type, num_entries
 	add	end_\type\()p, cur_\type\()p, #\num_entries * 8
 .Ldo_\type:
-	dc	cvac, cur_\type\()p		// Ensure any existing dirty
-	dmb	sy				// lines are written back before
-	ldr	\type, [cur_\type\()p]		// loading the entry
+	ldr	\type, [cur_\type\()p]		// Load the entry
 	tbz	\type, #0, .Lnext_\type		// Skip invalid and
 	tbnz	\type, #11, .Lnext_\type	// non-global entries
 	orr	\type, \type, #PTE_NG		// Same bit for blocks and pages
-	str	\type, [cur_\type\()p]		// Update the entry and ensure
-	dmb	sy				// that it is visible to all
-	dc	civac, cur_\()\type\()p		// CPUs.
+	str	\type, [cur_\type\()p]		// Update the entry
 	.ifnc	\type, pte
 	tbnz	\type, #1, .Lderef_\type
 	.endif
@@ -223,8 +223,29 @@  SYM_FUNC_END(idmap_cpu_replace_ttbr1)
 	b.ne	.Ldo_\type
 	.endm
 
+	/*
+	 * Dereference the current table entry and map it into the temporary
+	 * fixmap slot associated with the current level.
+	 */
+	.macro	kpti_map_pgtbl, type, level
+	str	xzr, [temp_pte, #8 * (\level + 1)]	// break before make
+	dsb	nshst
+	add	pte, temp_pte, #PAGE_SIZE * (\level + 1)
+	lsr	pte, pte, #12
+	tlbi	vaae1, pte
+	dsb	nsh
+	isb
+
+	phys_to_pte pte, cur_\type\()p
+	add	cur_\type\()p, temp_pte, #PAGE_SIZE * (\level + 1)
+	orr	pte, pte, pte_flags
+	str	pte, [temp_pte, #8 * (\level + 1)]
+	dsb	nshst
+	.endm
+
 /*
- * void __kpti_install_ng_mappings(int cpu, int num_cpus, phys_addr_t swapper)
+ * void __kpti_install_ng_mappings(int cpu, int num_secondaries, phys_addr_t temp_pgd,
+ *				   unsigned long temp_pte_va)
  *
  * Called exactly once from stop_machine context by each CPU found during boot.
  */
@@ -232,8 +253,10 @@  __idmap_kpti_flag:
 	.long	1
 SYM_FUNC_START(idmap_kpti_install_ng_mappings)
 	cpu		.req	w0
+	temp_pte	.req	x0
 	num_cpus	.req	w1
-	swapper_pa	.req	x2
+	pte_flags	.req	x1
+	temp_pgd_phys	.req	x2
 	swapper_ttb	.req	x3
 	flag_ptr	.req	x4
 	cur_pgdp	.req	x5
@@ -246,9 +269,10 @@  SYM_FUNC_START(idmap_kpti_install_ng_mappings)
 	cur_ptep	.req	x14
 	end_ptep	.req	x15
 	pte		.req	x16
+	valid		.req	x17
 
+	mov	x5, x3				// preserve temp_pte arg
 	mrs	swapper_ttb, ttbr1_el1
-	restore_ttbr1	swapper_ttb
 	adr	flag_ptr, __idmap_kpti_flag
 
 	cbnz	cpu, __idmap_kpti_secondary
@@ -260,28 +284,28 @@  SYM_FUNC_START(idmap_kpti_install_ng_mappings)
 	eor	w17, w17, num_cpus
 	cbnz	w17, 1b
 
-	/* We need to walk swapper, so turn off the MMU. */
-	pre_disable_mmu_workaround
-	mrs	x17, sctlr_el1
-	bic	x17, x17, #SCTLR_ELx_M
-	msr	sctlr_el1, x17
+	/* Switch to the temporary page tables on this CPU only */
+	__idmap_cpu_set_reserved_ttbr1 x8, x9
+	offset_ttbr1 temp_pgd_phys, x8
+	msr	ttbr1_el1, temp_pgd_phys
 	isb
 
+	mov	temp_pte, x5
+	mov	pte_flags, #KPTI_NG_PTE_FLAGS
+
 	/* Everybody is enjoying the idmap, so we can rewrite swapper. */
 	/* PGD */
-	mov		cur_pgdp, swapper_pa
+	adrp		cur_pgdp, swapper_pg_dir
+	kpti_map_pgtbl	pgd, 0
 	kpti_mk_tbl_ng	pgd, PTRS_PER_PGD
 
-	/* Publish the updated tables and nuke all the TLBs */
-	dsb	sy
-	tlbi	vmalle1is
-	dsb	ish
-	isb
+	/* Ensure all the updated entries are visible to secondary CPUs */
+	dsb	ishst
 
-	/* We're done: fire up the MMU again */
-	mrs	x17, sctlr_el1
-	orr	x17, x17, #SCTLR_ELx_M
-	set_sctlr_el1	x17
+	/* We're done: fire up swapper_pg_dir again */
+	__idmap_cpu_set_reserved_ttbr1 x8, x9
+	msr	ttbr1_el1, swapper_ttb
+	isb
 
 	/* Set the flag to zero to indicate that we're all done */
 	str	wzr, [flag_ptr]
@@ -292,6 +316,7 @@  SYM_FUNC_START(idmap_kpti_install_ng_mappings)
 	.if		CONFIG_PGTABLE_LEVELS > 3
 	pud		.req	x10
 	pte_to_phys	cur_pudp, pgd
+	kpti_map_pgtbl	pud, 1
 	kpti_mk_tbl_ng	pud, PTRS_PER_PUD
 	b		.Lnext_pgd
 	.else		/* CONFIG_PGTABLE_LEVELS <= 3 */
@@ -304,6 +329,7 @@  SYM_FUNC_START(idmap_kpti_install_ng_mappings)
 	.if		CONFIG_PGTABLE_LEVELS > 2
 	pmd		.req	x13
 	pte_to_phys	cur_pmdp, pud
+	kpti_map_pgtbl	pmd, 2
 	kpti_mk_tbl_ng	pmd, PTRS_PER_PMD
 	b		.Lnext_pud
 	.else		/* CONFIG_PGTABLE_LEVELS <= 2 */
@@ -314,12 +340,15 @@  SYM_FUNC_START(idmap_kpti_install_ng_mappings)
 .Lderef_pmd:
 	/* PTE */
 	pte_to_phys	cur_ptep, pmd
+	kpti_map_pgtbl	pte, 3
 	kpti_mk_tbl_ng	pte, PTRS_PER_PTE
 	b		.Lnext_pmd
 
 	.unreq	cpu
+	.unreq	temp_pte
 	.unreq	num_cpus
-	.unreq	swapper_pa
+	.unreq	pte_flags
+	.unreq	temp_pgd_phys
 	.unreq	cur_pgdp
 	.unreq	end_pgdp
 	.unreq	pgd
@@ -332,6 +361,7 @@  SYM_FUNC_START(idmap_kpti_install_ng_mappings)
 	.unreq	cur_ptep
 	.unreq	end_ptep
 	.unreq	pte
+	.unreq	valid
 
 	/* Secondary CPUs end up here */
 __idmap_kpti_secondary:
@@ -351,7 +381,6 @@  __idmap_kpti_secondary:
 	cbnz	w16, 1b
 
 	/* All done, act like nothing happened */
-	offset_ttbr1 swapper_ttb, x16
 	msr	ttbr1_el1, swapper_ttb
 	isb
 	ret