
[v2,24/25] KVM: x86/mmu: initialize constant-value fields just once

Message ID 20220221162243.683208-25-pbonzini@redhat.com (mailing list archive)
State New, archived
Series KVM MMU refactoring part 2: role changes

Commit Message

Paolo Bonzini Feb. 21, 2022, 4:22 p.m. UTC
The get_guest_pgd, get_pdptr and inject_page_fault pointers are constant
for all three of root_mmu, guest_mmu and nested_mmu.  The guest_mmu
function pointers depend on the processor vendor and need to be retrieved
through three new nested_ops callbacks, but the others are exactly the same.

Opportunistically stop initializing get_pdptr for nested EPT, since it does
not have PDPTRs.

Signed-off-by: Paolo Bonzini <pbonzini@redhat.com>
---
 arch/x86/include/asm/kvm_host.h |  5 +++
 arch/x86/kvm/mmu/mmu.c          | 65 +++++++++++++++++----------------
 arch/x86/kvm/svm/nested.c       |  9 +++--
 arch/x86/kvm/vmx/nested.c       |  5 +--
 4 files changed, 46 insertions(+), 38 deletions(-)

Comments

Sean Christopherson March 8, 2022, 8:58 p.m. UTC | #1
On Mon, Feb 21, 2022, Paolo Bonzini wrote:
>  
> +	vcpu->arch.root_mmu.get_guest_pgd = kvm_get_guest_cr3;
> +	vcpu->arch.root_mmu.get_pdptr = kvm_pdptr_read;
> +
> +	if (tdp_enabled) {

Putting all this code in a separate helper reduces line lengths via early
returns.  And it'll allow us to do the same for the nested-specific MMUs if we
ever get smart and move "nested" to x86.c (preferably as enable_nested or
nested_enabled).

> +		vcpu->arch.root_mmu.inject_page_fault = kvm_inject_page_fault;
> +		vcpu->arch.root_mmu.page_fault = kvm_tdp_page_fault;
> +		vcpu->arch.root_mmu.sync_page = nonpaging_sync_page;
> +		vcpu->arch.root_mmu.invlpg = NULL;
> +		reset_tdp_shadow_zero_bits_mask(&vcpu->arch.root_mmu);
> +
> +		vcpu->arch.guest_mmu.get_guest_pgd = kvm_x86_ops.nested_ops->get_nested_pgd;
> +		vcpu->arch.guest_mmu.get_pdptr = kvm_x86_ops.nested_ops->get_nested_pdptr;
> +		vcpu->arch.guest_mmu.inject_page_fault = kvm_x86_ops.nested_ops->inject_nested_tdp_vmexit;

Using nested_ops is clever, but IMO unnecessary, especially since we can go even
further by adding a nEPT specific hook to initialize its constant shadow paging
stuff.

Here's what I had written spliced in with your code.  Compile tested only for
this version.


From: Paolo Bonzini <pbonzini@redhat.com>
Date: Mon, 21 Feb 2022 11:22:42 -0500
Subject: [PATCH] KVM: x86/mmu: initialize constant-value fields just once

The get_guest_pgd, get_pdptr and inject_page_fault pointers are constant
for all three of root_mmu, guest_mmu and nested_mmu.  The guest_mmu
function pointers depend on the processor vendor, but are otherwise
constant.

Opportunistically stop initializing get_pdptr for nested EPT, since it
does not have PDPTRs.

Opportunistically change kvm_mmu_create() to return '0' unconditionally
in its happy path to make it obvious that it's a happy path.

Co-developed-by: Sean Christopherson <seanjc@google.com>
Signed-off-by: Sean Christopherson <seanjc@google.com>
Signed-off-by: Paolo Bonzini <pbonzini@redhat.com>
---
 arch/x86/kvm/mmu.h        |  1 +
 arch/x86/kvm/mmu/mmu.c    | 85 ++++++++++++++++++++++++---------------
 arch/x86/kvm/svm/nested.c | 15 +++++--
 arch/x86/kvm/svm/svm.c    |  3 ++
 arch/x86/kvm/svm/svm.h    |  1 +
 arch/x86/kvm/vmx/nested.c | 13 ++++--
 arch/x86/kvm/vmx/nested.h |  1 +
 arch/x86/kvm/vmx/vmx.c    |  3 ++
 8 files changed, 82 insertions(+), 40 deletions(-)

diff --git a/arch/x86/kvm/mmu.h b/arch/x86/kvm/mmu.h
index 9517e56a0da1..bd2a6e20307c 100644
--- a/arch/x86/kvm/mmu.h
+++ b/arch/x86/kvm/mmu.h
@@ -71,6 +71,7 @@ void kvm_mmu_set_ept_masks(bool has_ad_bits, bool has_exec_only);
 void kvm_init_mmu(struct kvm_vcpu *vcpu);
 void kvm_init_shadow_npt_mmu(struct kvm_vcpu *vcpu, unsigned long cr0,
 			     unsigned long cr4, u64 efer, gpa_t nested_cr3);
+void kvm_init_shadow_ept_mmu_constants(struct kvm_vcpu *vcpu);
 void kvm_init_shadow_ept_mmu(struct kvm_vcpu *vcpu, bool execonly,
 			     int huge_page_level, bool accessed_dirty,
 			     gpa_t new_eptp);
diff --git a/arch/x86/kvm/mmu/mmu.c b/arch/x86/kvm/mmu/mmu.c
index 8c388add95cb..db2d88c59198 100644
--- a/arch/x86/kvm/mmu/mmu.c
+++ b/arch/x86/kvm/mmu/mmu.c
@@ -4778,12 +4778,6 @@ static void init_kvm_tdp_mmu(struct kvm_vcpu *vcpu,

 	context->cpu_mode.as_u64 = cpu_mode.as_u64;
 	context->root_role.word = root_role.word;
-	context->page_fault = kvm_tdp_page_fault;
-	context->sync_page = nonpaging_sync_page;
-	context->invlpg = NULL;
-	context->get_guest_pgd = kvm_get_guest_cr3;
-	context->get_pdptr = kvm_pdptr_read;
-	context->inject_page_fault = kvm_inject_page_fault;

 	if (!is_cr0_pg(context))
 		context->gva_to_gpa = nonpaging_gva_to_gpa;
@@ -4793,7 +4787,6 @@ static void init_kvm_tdp_mmu(struct kvm_vcpu *vcpu,
 		context->gva_to_gpa = paging32_gva_to_gpa;

 	reset_guest_paging_metadata(vcpu, context);
-	reset_tdp_shadow_zero_bits_mask(context);
 }

 static void shadow_mmu_init_context(struct kvm_vcpu *vcpu, struct kvm_mmu *context,
@@ -4818,8 +4811,8 @@ static void shadow_mmu_init_context(struct kvm_vcpu *vcpu, struct kvm_mmu *conte
 	reset_shadow_zero_bits_mask(vcpu, context);
 }

-static void kvm_init_shadow_mmu(struct kvm_vcpu *vcpu,
-				union kvm_mmu_paging_mode cpu_mode)
+static void init_kvm_softmmu(struct kvm_vcpu *vcpu,
+			     union kvm_mmu_paging_mode cpu_mode)
 {
 	struct kvm_mmu *context = &vcpu->arch.root_mmu;
 	union kvm_mmu_page_role root_role;
@@ -4891,6 +4884,17 @@ kvm_calc_shadow_ept_root_page_role(struct kvm_vcpu *vcpu, bool accessed_dirty,
 	return role;
 }

+void kvm_init_shadow_ept_mmu_constants(struct kvm_vcpu *vcpu)
+{
+	struct kvm_mmu *guest_mmu = &vcpu->arch.guest_mmu;
+
+	guest_mmu->page_fault = ept_page_fault;
+	guest_mmu->gva_to_gpa = ept_gva_to_gpa;
+	guest_mmu->sync_page  = ept_sync_page;
+	guest_mmu->invlpg     = ept_invlpg;
+}
+EXPORT_SYMBOL_GPL(kvm_init_shadow_ept_mmu_constants);
+
 void kvm_init_shadow_ept_mmu(struct kvm_vcpu *vcpu, bool execonly,
 			     int huge_page_level, bool accessed_dirty,
 			     gpa_t new_eptp)
@@ -4912,7 +4916,6 @@ void kvm_init_shadow_ept_mmu(struct kvm_vcpu *vcpu, bool execonly,
 		context->invlpg = ept_invlpg;

 		update_permission_bitmask(context, true);
-		context->pkru_mask = 0;
 		reset_rsvds_bits_mask_ept(vcpu, context, execonly, huge_page_level);
 		reset_ept_shadow_zero_bits_mask(context, execonly);
 	}
@@ -4921,18 +4924,6 @@ void kvm_init_shadow_ept_mmu(struct kvm_vcpu *vcpu, bool execonly,
 }
 EXPORT_SYMBOL_GPL(kvm_init_shadow_ept_mmu);

-static void init_kvm_softmmu(struct kvm_vcpu *vcpu,
-			     union kvm_mmu_paging_mode cpu_mode)
-{
-	struct kvm_mmu *context = &vcpu->arch.root_mmu;
-
-	kvm_init_shadow_mmu(vcpu, cpu_mode);
-
-	context->get_guest_pgd	   = kvm_get_guest_cr3;
-	context->get_pdptr         = kvm_pdptr_read;
-	context->inject_page_fault = kvm_inject_page_fault;
-}
-
 static void init_kvm_nested_mmu(struct kvm_vcpu *vcpu,
 				union kvm_mmu_paging_mode new_mode)
 {
@@ -4941,16 +4932,7 @@ static void init_kvm_nested_mmu(struct kvm_vcpu *vcpu,
 	if (new_mode.as_u64 == g_context->cpu_mode.as_u64)
 		return;

-	g_context->cpu_mode.as_u64   = new_mode.as_u64;
-	g_context->get_guest_pgd     = kvm_get_guest_cr3;
-	g_context->get_pdptr         = kvm_pdptr_read;
-	g_context->inject_page_fault = kvm_inject_page_fault;
-
-	/*
-	 * L2 page tables are never shadowed, so there is no need to sync
-	 * SPTEs.
-	 */
-	g_context->invlpg            = NULL;
+	g_context->cpu_mode.as_u64 = new_mode.as_u64;

 	/*
 	 * Note that arch.mmu->gva_to_gpa translates l2_gpa to l1_gpa using
@@ -5499,6 +5481,40 @@ static void free_mmu_pages(struct kvm_mmu *mmu)
 	free_page((unsigned long)mmu->pml5_root);
 }

+static void kvm_init_mmu_constants(struct kvm_vcpu *vcpu)
+{
+	struct kvm_mmu *nested_mmu = &vcpu->arch.nested_mmu;
+	struct kvm_mmu *root_mmu = &vcpu->arch.root_mmu;
+
+	root_mmu->get_guest_pgd	    = kvm_get_guest_cr3;
+	root_mmu->get_pdptr	    = kvm_pdptr_read;
+	root_mmu->inject_page_fault = kvm_inject_page_fault;
+
+	/*
+	 * When shadowing IA32 page tables, all other callbacks vary based
+	 * on paging mode, and the guest+nested MMUs are unused.
+	 */
+	if (!tdp_enabled)
+		return;
+
+	root_mmu->page_fault = kvm_tdp_page_fault;
+	root_mmu->sync_page  = nonpaging_sync_page;
+	root_mmu->invlpg     = NULL;
+	reset_tdp_shadow_zero_bits_mask(&vcpu->arch.root_mmu);
+
+	/*
+	 * Nested TDP MMU callbacks that are constant are vendor specific due
+	 * to the vast differences between EPT and NPT.  NPT in particular is
+	 * nasty because L1 may use 32-bit and/or 64-bit paging.
+	 */
+	nested_mmu->get_guest_pgd     = kvm_get_guest_cr3;
+	nested_mmu->get_pdptr         = kvm_pdptr_read;
+	nested_mmu->inject_page_fault = kvm_inject_page_fault;
+
+	/* L2 page tables are never shadowed, there's no need to sync SPTEs. */
+	nested_mmu->invlpg            = NULL;
+}
+
 static int __kvm_mmu_create(struct kvm_vcpu *vcpu, struct kvm_mmu *mmu)
 {
 	struct page *page;
@@ -5575,7 +5591,10 @@ int kvm_mmu_create(struct kvm_vcpu *vcpu)
 	if (ret)
 		goto fail_allocate_root;

-	return ret;
+	kvm_init_mmu_constants(vcpu);
+
+	return 0;
+
  fail_allocate_root:
 	free_mmu_pages(&vcpu->arch.guest_mmu);
 	return ret;
diff --git a/arch/x86/kvm/svm/nested.c b/arch/x86/kvm/svm/nested.c
index dd942c719cf6..c58c9d876a6c 100644
--- a/arch/x86/kvm/svm/nested.c
+++ b/arch/x86/kvm/svm/nested.c
@@ -96,6 +96,15 @@ static unsigned long nested_svm_get_tdp_cr3(struct kvm_vcpu *vcpu)
 	return svm->nested.ctl.nested_cr3;
 }

+void nested_svm_init_mmu_constants(struct kvm_vcpu *vcpu)
+{
+	struct kvm_mmu *guest_mmu = &vcpu->arch.guest_mmu;
+
+	guest_mmu->get_guest_pgd     = nested_svm_get_tdp_cr3;
+	guest_mmu->get_pdptr         = nested_svm_get_tdp_pdptr;
+	guest_mmu->inject_page_fault = nested_svm_inject_npf_exit;
+}
+
 static void nested_svm_init_mmu_context(struct kvm_vcpu *vcpu)
 {
 	struct vcpu_svm *svm = to_svm(vcpu);
@@ -112,10 +121,8 @@ static void nested_svm_init_mmu_context(struct kvm_vcpu *vcpu)
 	kvm_init_shadow_npt_mmu(vcpu, X86_CR0_PG, svm->vmcb01.ptr->save.cr4,
 				svm->vmcb01.ptr->save.efer,
 				svm->nested.ctl.nested_cr3);
-	vcpu->arch.mmu->get_guest_pgd     = nested_svm_get_tdp_cr3;
-	vcpu->arch.mmu->get_pdptr         = nested_svm_get_tdp_pdptr;
-	vcpu->arch.mmu->inject_page_fault = nested_svm_inject_npf_exit;
-	vcpu->arch.walk_mmu              = &vcpu->arch.nested_mmu;
+
+	vcpu->arch.walk_mmu = &vcpu->arch.nested_mmu;
 }

 static void nested_svm_uninit_mmu_context(struct kvm_vcpu *vcpu)
diff --git a/arch/x86/kvm/svm/svm.c b/arch/x86/kvm/svm/svm.c
index a8ee949b2403..db62b3e88317 100644
--- a/arch/x86/kvm/svm/svm.c
+++ b/arch/x86/kvm/svm/svm.c
@@ -1228,6 +1228,9 @@ static int svm_vcpu_create(struct kvm_vcpu *vcpu)

 	svm->guest_state_loaded = false;

+	if (npt_enabled && nested)
+		nested_svm_init_mmu_constants(vcpu);
+
 	return 0;

 error_free_vmsa_page:
diff --git a/arch/x86/kvm/svm/svm.h b/arch/x86/kvm/svm/svm.h
index e45b5645d5e0..99c5a57ab5dd 100644
--- a/arch/x86/kvm/svm/svm.h
+++ b/arch/x86/kvm/svm/svm.h
@@ -564,6 +564,7 @@ void nested_copy_vmcb_save_to_cache(struct vcpu_svm *svm,
 void nested_sync_control_from_vmcb02(struct vcpu_svm *svm);
 void nested_vmcb02_compute_g_pat(struct vcpu_svm *svm);
 void svm_switch_vmcb(struct vcpu_svm *svm, struct kvm_vmcb_info *target_vmcb);
+void nested_svm_init_mmu_constants(struct kvm_vcpu *vcpu);

 extern struct kvm_x86_nested_ops svm_nested_ops;

diff --git a/arch/x86/kvm/vmx/nested.c b/arch/x86/kvm/vmx/nested.c
index cc4c74339d35..385f60305555 100644
--- a/arch/x86/kvm/vmx/nested.c
+++ b/arch/x86/kvm/vmx/nested.c
@@ -407,15 +407,22 @@ static void nested_ept_new_eptp(struct kvm_vcpu *vcpu)
 				nested_ept_get_eptp(vcpu));
 }

+void nested_ept_init_mmu_constants(struct kvm_vcpu *vcpu)
+{
+	struct kvm_mmu *mmu = &vcpu->arch.guest_mmu;
+
+	mmu->get_guest_pgd	= nested_ept_get_eptp;
+	mmu->inject_page_fault	= nested_ept_inject_page_fault;
+
+	kvm_init_shadow_ept_mmu_constants(vcpu);
+}
+
 static void nested_ept_init_mmu_context(struct kvm_vcpu *vcpu)
 {
 	WARN_ON(mmu_is_nested(vcpu));

 	vcpu->arch.mmu = &vcpu->arch.guest_mmu;
 	nested_ept_new_eptp(vcpu);
-	vcpu->arch.mmu->get_guest_pgd     = nested_ept_get_eptp;
-	vcpu->arch.mmu->inject_page_fault = nested_ept_inject_page_fault;
-	vcpu->arch.mmu->get_pdptr         = kvm_pdptr_read;

 	vcpu->arch.walk_mmu              = &vcpu->arch.nested_mmu;
 }
diff --git a/arch/x86/kvm/vmx/nested.h b/arch/x86/kvm/vmx/nested.h
index c92cea0b8ccc..78e6d9ba5839 100644
--- a/arch/x86/kvm/vmx/nested.h
+++ b/arch/x86/kvm/vmx/nested.h
@@ -37,6 +37,7 @@ void nested_vmx_pmu_refresh(struct kvm_vcpu *vcpu,
 void nested_mark_vmcs12_pages_dirty(struct kvm_vcpu *vcpu);
 bool nested_vmx_check_io_bitmaps(struct kvm_vcpu *vcpu, unsigned int port,
 				 int size);
+void nested_ept_init_mmu_constants(struct kvm_vcpu *vcpu);

 static inline struct vmcs12 *get_vmcs12(struct kvm_vcpu *vcpu)
 {
diff --git a/arch/x86/kvm/vmx/vmx.c b/arch/x86/kvm/vmx/vmx.c
index 40e015e9b260..04edb8a761a8 100644
--- a/arch/x86/kvm/vmx/vmx.c
+++ b/arch/x86/kvm/vmx/vmx.c
@@ -7081,6 +7081,9 @@ static int vmx_vcpu_create(struct kvm_vcpu *vcpu)
 			goto free_vmcs;
 	}

+	if (enable_ept && nested)
+		nested_ept_init_mmu_constants(vcpu);
+
 	return 0;

 free_vmcs:

base-commit: 94fd8078bd4f838cf9ced265e6ac4237cbcba7a1
--
Paolo Bonzini March 9, 2022, 10:34 a.m. UTC | #2
On 3/8/22 21:58, Sean Christopherson wrote:
> Using nested_ops is clever, but IMO unnecessary, especially since we can go even
> further by adding a nEPT specific hook to initialize its constant shadow paging
> stuff.
> 
> Here's what I had written spliced in with your code.  Compile tested only for
> this version.

I'll do something in between, keeping the nested_ops but with three 
functions to initialize the various kvm_mmu structs.

Paolo
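
A rough sketch of what such an in-between arrangement could look like,
reusing the hook names from the v2 patch below; the helper names here are
hypothetical, not the actual follow-up:

static void kvm_mmu_init_root_mmu_constants(struct kvm_vcpu *vcpu)
{
	struct kvm_mmu *root_mmu = &vcpu->arch.root_mmu;

	root_mmu->get_guest_pgd = kvm_get_guest_cr3;
	root_mmu->get_pdptr     = kvm_pdptr_read;
	root_mmu->inject_page_fault = tdp_enabled ? kvm_inject_page_fault
						  : kvm_inject_page_fault_shadow;
}

static void kvm_mmu_init_guest_mmu_constants(struct kvm_vcpu *vcpu)
{
	struct kvm_mmu *guest_mmu = &vcpu->arch.guest_mmu;

	/* Vendor-specific callbacks, still retrieved via nested_ops. */
	guest_mmu->get_guest_pgd     = kvm_x86_ops.nested_ops->get_nested_pgd;
	guest_mmu->get_pdptr         = kvm_x86_ops.nested_ops->get_nested_pdptr;
	guest_mmu->inject_page_fault = kvm_x86_ops.nested_ops->inject_nested_tdp_vmexit;
}

static void kvm_mmu_init_nested_mmu_constants(struct kvm_vcpu *vcpu)
{
	struct kvm_mmu *nested_mmu = &vcpu->arch.nested_mmu;

	nested_mmu->get_guest_pgd     = kvm_get_guest_cr3;
	nested_mmu->get_pdptr         = kvm_pdptr_read;
	nested_mmu->inject_page_fault = kvm_inject_page_fault;

	/* L2 page tables are never shadowed, so there is nothing to sync. */
	nested_mmu->invlpg = NULL;
}

kvm_mmu_create() would then call the root_mmu and nested_mmu helpers
unconditionally and the guest_mmu one only when tdp_enabled, mirroring the
v2 hunk below.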

Patch

diff --git a/arch/x86/include/asm/kvm_host.h b/arch/x86/include/asm/kvm_host.h
index af90d0653139..b70965235c31 100644
--- a/arch/x86/include/asm/kvm_host.h
+++ b/arch/x86/include/asm/kvm_host.h
@@ -1503,6 +1503,11 @@  struct kvm_x86_nested_ops {
 	uint16_t (*get_evmcs_version)(struct kvm_vcpu *vcpu);
 	void (*inject_page_fault)(struct kvm_vcpu *vcpu,
 				  struct x86_exception *fault);
+	void (*inject_nested_tdp_vmexit)(struct kvm_vcpu *vcpu,
+					 struct x86_exception *fault);
+
+	unsigned long (*get_nested_pgd)(struct kvm_vcpu *vcpu);
+	u64 (*get_nested_pdptr)(struct kvm_vcpu *vcpu, int index);
 };
 
 struct kvm_x86_init_ops {
diff --git a/arch/x86/kvm/mmu/mmu.c b/arch/x86/kvm/mmu/mmu.c
index 8eb2c0373309..27cb6ba5a3b0 100644
--- a/arch/x86/kvm/mmu/mmu.c
+++ b/arch/x86/kvm/mmu/mmu.c
@@ -4743,12 +4743,6 @@  static void init_kvm_tdp_mmu(struct kvm_vcpu *vcpu, union kvm_mmu_paging_mode cp
 
 	context->cpu_mode.as_u64 = cpu_mode.as_u64;
 	context->root_role.word = root_role.word;
-	context->page_fault = kvm_tdp_page_fault;
-	context->sync_page = nonpaging_sync_page;
-	context->invlpg = NULL;
-	context->get_guest_pgd = kvm_get_guest_cr3;
-	context->get_pdptr = kvm_pdptr_read;
-	context->inject_page_fault = kvm_inject_page_fault;
 
 	if (!is_cr0_pg(context))
 		context->gva_to_gpa = nonpaging_gva_to_gpa;
@@ -4758,7 +4752,6 @@  static void init_kvm_tdp_mmu(struct kvm_vcpu *vcpu, union kvm_mmu_paging_mode cp
 		context->gva_to_gpa = paging32_gva_to_gpa;
 
 	reset_guest_paging_metadata(vcpu, context);
-	reset_tdp_shadow_zero_bits_mask(context);
 }
 
 static void shadow_mmu_init_context(struct kvm_vcpu *vcpu, struct kvm_mmu *context,
@@ -4783,8 +4776,8 @@  static void shadow_mmu_init_context(struct kvm_vcpu *vcpu, struct kvm_mmu *conte
 	reset_shadow_zero_bits_mask(vcpu, context);
 }
 
-static void kvm_init_shadow_mmu(struct kvm_vcpu *vcpu,
-				union kvm_mmu_paging_mode cpu_mode)
+static void init_kvm_softmmu(struct kvm_vcpu *vcpu,
+			     union kvm_mmu_paging_mode cpu_mode)
 {
 	struct kvm_mmu *context = &vcpu->arch.root_mmu;
 	union kvm_mmu_page_role root_role;
@@ -4880,18 +4873,6 @@  void kvm_init_shadow_ept_mmu(struct kvm_vcpu *vcpu, bool execonly,
 }
 EXPORT_SYMBOL_GPL(kvm_init_shadow_ept_mmu);
 
-static void init_kvm_softmmu(struct kvm_vcpu *vcpu,
-			     union kvm_mmu_paging_mode cpu_mode)
-{
-	struct kvm_mmu *context = &vcpu->arch.root_mmu;
-
-	kvm_init_shadow_mmu(vcpu, cpu_mode);
-
-	context->get_guest_pgd	   = kvm_get_guest_cr3;
-	context->get_pdptr         = kvm_pdptr_read;
-	context->inject_page_fault = kvm_inject_page_fault_shadow;
-}
-
 static void init_kvm_nested_mmu(struct kvm_vcpu *vcpu, union kvm_mmu_paging_mode new_mode)
 {
 	struct kvm_mmu *g_context = &vcpu->arch.nested_mmu;
@@ -4899,16 +4880,7 @@  static void init_kvm_nested_mmu(struct kvm_vcpu *vcpu, union kvm_mmu_paging_mode
 	if (new_mode.as_u64 == g_context->cpu_mode.as_u64)
 		return;
 
-	g_context->cpu_mode.as_u64   = new_mode.as_u64;
-	g_context->get_guest_pgd     = kvm_get_guest_cr3;
-	g_context->get_pdptr         = kvm_pdptr_read;
-	g_context->inject_page_fault = kvm_inject_page_fault;
-
-	/*
-	 * L2 page tables are never shadowed, so there is no need to sync
-	 * SPTEs.
-	 */
-	g_context->invlpg            = NULL;
+	g_context->cpu_mode.as_u64 = new_mode.as_u64;
 
 	/*
 	 * Note that arch.mmu->gva_to_gpa translates l2_gpa to l1_gpa using
@@ -5477,6 +5449,37 @@  int kvm_mmu_create(struct kvm_vcpu *vcpu)
 
 	vcpu->arch.mmu_shadow_page_cache.gfp_zero = __GFP_ZERO;
 
+	vcpu->arch.root_mmu.get_guest_pgd = kvm_get_guest_cr3;
+	vcpu->arch.root_mmu.get_pdptr = kvm_pdptr_read;
+
+	if (tdp_enabled) {
+		vcpu->arch.root_mmu.inject_page_fault = kvm_inject_page_fault;
+		vcpu->arch.root_mmu.page_fault = kvm_tdp_page_fault;
+		vcpu->arch.root_mmu.sync_page = nonpaging_sync_page;
+		vcpu->arch.root_mmu.invlpg = NULL;
+		reset_tdp_shadow_zero_bits_mask(&vcpu->arch.root_mmu);
+
+		vcpu->arch.guest_mmu.get_guest_pgd = kvm_x86_ops.nested_ops->get_nested_pgd;
+		vcpu->arch.guest_mmu.get_pdptr = kvm_x86_ops.nested_ops->get_nested_pdptr;
+		vcpu->arch.guest_mmu.inject_page_fault = kvm_x86_ops.nested_ops->inject_nested_tdp_vmexit;
+	} else {
+		vcpu->arch.root_mmu.inject_page_fault = kvm_inject_page_fault_shadow;
+		/*
+		 * page_fault, sync_page, invlpg are set at runtime depending
+		 * on the guest paging mode.
+		 */
+	}
+
+	vcpu->arch.nested_mmu.get_guest_pgd     = kvm_get_guest_cr3;
+	vcpu->arch.nested_mmu.get_pdptr         = kvm_pdptr_read;
+	vcpu->arch.nested_mmu.inject_page_fault = kvm_inject_page_fault;
+
+	/*
+	 * L2 page tables are never shadowed, so there is no need to sync
+	 * SPTEs.
+	 */
+	vcpu->arch.nested_mmu.invlpg = NULL;
+
 	vcpu->arch.mmu = &vcpu->arch.root_mmu;
 	vcpu->arch.walk_mmu = &vcpu->arch.root_mmu;
 
diff --git a/arch/x86/kvm/svm/nested.c b/arch/x86/kvm/svm/nested.c
index ff58c9ebc552..713c7531de99 100644
--- a/arch/x86/kvm/svm/nested.c
+++ b/arch/x86/kvm/svm/nested.c
@@ -109,10 +109,8 @@  static void nested_svm_init_mmu_context(struct kvm_vcpu *vcpu)
 	kvm_init_shadow_npt_mmu(vcpu, X86_CR0_PG, svm->vmcb01.ptr->save.cr4,
 				svm->vmcb01.ptr->save.efer,
 				svm->nested.ctl.nested_cr3);
-	vcpu->arch.mmu->get_guest_pgd     = nested_svm_get_tdp_cr3;
-	vcpu->arch.mmu->get_pdptr         = nested_svm_get_tdp_pdptr;
-	vcpu->arch.mmu->inject_page_fault = nested_svm_inject_npf_exit;
-	vcpu->arch.walk_mmu              = &vcpu->arch.nested_mmu;
+
+	vcpu->arch.walk_mmu = &vcpu->arch.nested_mmu;
 }
 
 static void nested_svm_uninit_mmu_context(struct kvm_vcpu *vcpu)
@@ -1569,4 +1567,7 @@  struct kvm_x86_nested_ops svm_nested_ops = {
 	.get_state = svm_get_nested_state,
 	.set_state = svm_set_nested_state,
 	.inject_page_fault = svm_inject_page_fault_nested,
+	.inject_nested_tdp_vmexit = nested_svm_inject_npf_exit,
+	.get_nested_pgd = nested_svm_get_tdp_cr3,
+	.get_nested_pdptr = nested_svm_get_tdp_pdptr,
 };
diff --git a/arch/x86/kvm/vmx/nested.c b/arch/x86/kvm/vmx/nested.c
index 564c60566da7..02df0f4fccef 100644
--- a/arch/x86/kvm/vmx/nested.c
+++ b/arch/x86/kvm/vmx/nested.c
@@ -414,9 +414,6 @@  static void nested_ept_init_mmu_context(struct kvm_vcpu *vcpu)
 
 	vcpu->arch.mmu = &vcpu->arch.guest_mmu;
 	nested_ept_new_eptp(vcpu);
-	vcpu->arch.mmu->get_guest_pgd     = nested_ept_get_eptp;
-	vcpu->arch.mmu->inject_page_fault = nested_ept_inject_page_fault;
-	vcpu->arch.mmu->get_pdptr         = kvm_pdptr_read;
 
 	vcpu->arch.walk_mmu              = &vcpu->arch.nested_mmu;
 }
@@ -6805,4 +6802,6 @@  struct kvm_x86_nested_ops vmx_nested_ops = {
 	.enable_evmcs = nested_enable_evmcs,
 	.get_evmcs_version = nested_get_evmcs_version,
 	.inject_page_fault = vmx_inject_page_fault_nested,
+	.inject_nested_tdp_vmexit = nested_ept_inject_page_fault,
+	.get_nested_pgd = nested_ept_get_eptp,
 };