
[17/22] kvm: mmu: Support dirty logging for the TDP MMU

Message ID: 20200925212302.3979661-18-bgardon@google.com
State: New, archived
Series: Introduce the TDP MMU

Commit Message

Ben Gardon Sept. 25, 2020, 9:22 p.m. UTC
Dirty logging is a key feature of the KVM MMU and must be supported by
the TDP MMU. Add support for both the write protection and PML dirty
logging modes.

Tested by running kvm-unit-tests and KVM selftests on an Intel Haswell
machine. This series introduced no new failures.

This series can be viewed in Gerrit at:
	https://linux-review.googlesource.com/c/virt/kvm/kvm/+/2538

Signed-off-by: Ben Gardon <bgardon@google.com>
---
 arch/x86/kvm/mmu/mmu.c          |  19 +-
 arch/x86/kvm/mmu/mmu_internal.h |   2 +
 arch/x86/kvm/mmu/tdp_iter.c     |  18 ++
 arch/x86/kvm/mmu/tdp_iter.h     |   1 +
 arch/x86/kvm/mmu/tdp_mmu.c      | 295 ++++++++++++++++++++++++++++++++
 arch/x86/kvm/mmu/tdp_mmu.h      |  10 ++
 6 files changed, 343 insertions(+), 2 deletions(-)

Comments

Paolo Bonzini Sept. 26, 2020, 1:04 a.m. UTC | #1
On 25/09/20 23:22, Ben Gardon wrote:
>  				start_level, KVM_MAX_HUGEPAGE_LEVEL, false);
> +	if (kvm->arch.tdp_mmu_enabled)
> +		flush = kvm_tdp_mmu_wrprot_slot(kvm, memslot, false) || flush;
>  	spin_unlock(&kvm->mmu_lock);
>  

In fact you can just pass down the end-level KVM_MAX_HUGEPAGE_LEVEL or
PG_LEVEL_4K here to kvm_tdp_mmu_wrprot_slot and from there to
wrprot_gfn_range.

> 
> +		/*
> +		 * Take a reference on the root so that it cannot be freed if
> +		 * this thread releases the MMU lock and yields in this loop.
> +		 */
> +		get_tdp_mmu_root(kvm, root);
> +
> +		spte_set = wrprot_gfn_range(kvm, root, slot->base_gfn,
> +				slot->base_gfn + slot->npages, skip_4k) ||
> +			   spte_set;
> +
> +		put_tdp_mmu_root(kvm, root);


Generally, using "|=" is the more common idiom in mmu.c.

> +static bool clear_dirty_gfn_range(struct kvm *kvm, struct kvm_mmu_page *root,
> +			   gfn_t start, gfn_t end)
> ...
> +		__handle_changed_spte(kvm, as_id, iter.gfn, iter.old_spte,
> +				      new_spte, iter.level);
> +		handle_changed_spte_acc_track(iter.old_spte, new_spte,
> +					      iter.level);

Is it worth not calling handle_changed_spte?  handle_changed_spte_dlog
obviously will never fire but duplicating the code is a bit ugly.

I guess this patch is the first one that really gives the "feeling" of
what the data structures look like.  The main difference with the shadow
MMU is that you have the tdp_iter instead of the callback-based code of
slot_handle_level_range, but otherwise it's not hard to follow one if
you know the other.  Reorganizing the code so that mmu.c is little more
than a wrapper around the two will help as well in this respect.

Paolo
Paolo Bonzini Sept. 29, 2020, 3:07 p.m. UTC | #2
On 25/09/20 23:22, Ben Gardon wrote:
> +	for_each_tdp_pte_root(iter, root, start, end) {
> +iteration_start:
> +		if (!is_shadow_present_pte(iter.old_spte))
> +			continue;
> +
> +		/*
> +		 * If this entry points to a page of 4K entries, and 4k entries
> +		 * should be skipped, skip the whole page. If the non-leaf
> +		 * entry is at a higher level, move on to the next
> +		 * (lower level) entry.
> +		 */
> +		if (!is_last_spte(iter.old_spte, iter.level)) {
> +			if (skip_4k && iter.level == PG_LEVEL_2M) {
> +				tdp_iter_next_no_step_down(&iter);
> +				if (iter.valid && iter.gfn >= end)
> +					goto iteration_start;
> +				else
> +					break;

The iteration_start label confuses me mightily. :)  That would be a case
where iter.gfn >= end (so for_each_tdp_pte_root would exit) but you want
to proceed anyway with the gfn that was found by
tdp_iter_next_no_step_down.  Are you sure you didn't mean

	if (iter.valid && iter.gfn < end)
		goto iteration_start;
	else
		break;

because that would make much more sense: basically a "continue" that
skips the tdp_iter_next.  With the min_level change I suggested on
Friday, it would become something like this:

        for_each_tdp_pte_root_level(iter, root, start, end, min_level) {
                if (!is_shadow_present_pte(iter.old_spte) ||
                    !is_last_spte(iter.old_spte, iter.level))
                        continue;

                new_spte = iter.old_spte & ~PT_WRITABLE_MASK;

                *iter.sptep = new_spte;
                handle_changed_spte(kvm, as_id, iter.gfn, iter.old_spte,
                                    new_spte, iter.level);

                spte_set = true;
                tdp_mmu_iter_cond_resched(kvm, &iter);
        }

which is all nice and understandable.

Also, related to this function, why ignore the return value of
tdp_mmu_iter_cond_resched?  It does make sense to assign spte_set =
true since, just like in kvm_mmu_slot_largepage_remove_write_access's
instance of slot_handle_large_level, you don't even need to flush on
cond_resched.  However, in order to do that you would have to add some
kind of "bool flush_on_resched" argument to tdp_mmu_iter_cond_resched,
or have two separate functions tdp_mmu_iter_cond_{flush_and_,}resched.

The same is true of clear_dirty_gfn_range and set_dirty_gfn_range.
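
Roughly, assuming the existing helper just flushes unconditionally
before rescheduling (only a sketch), something like:

	static bool tdp_mmu_iter_cond_resched(struct kvm *kvm,
					      struct tdp_iter *iter,
					      bool flush_on_resched)
	{
		if (need_resched() || spin_needbreak(&kvm->mmu_lock)) {
			/* Callers that do not need the flush pass false. */
			if (flush_on_resched)
				kvm_flush_remote_tlbs(kvm);
			cond_resched_lock(&kvm->mmu_lock);
			tdp_iter_refresh_walk(iter);
			return true;
		}

		return false;
	}

with the write-protection and dirty-clearing paths passing false.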

Paolo
Sean Christopherson Sept. 30, 2020, 6:04 p.m. UTC | #3
On Fri, Sep 25, 2020 at 02:22:57PM -0700, Ben Gardon wrote:
> +/*
> + * Remove write access from all the SPTEs mapping GFNs in the memslot. If
> + * skip_4k is set, SPTEs that map 4k pages will not be write-protected.
> + * Returns true if an SPTE has been changed and the TLBs need to be flushed.
> + */
> +bool kvm_tdp_mmu_wrprot_slot(struct kvm *kvm, struct kvm_memory_slot *slot,
> +			     bool skip_4k)
> +{
> +	struct kvm_mmu_page *root;
> +	int root_as_id;
> +	bool spte_set = false;
> +
> +	for_each_tdp_mmu_root(kvm, root) {
> +		root_as_id = kvm_mmu_page_as_id(root);
> +		if (root_as_id != slot->as_id)
> +			continue;

This pattern pops up quite a few times, probably worth adding

#define for_each_tdp_mmu_root_using_memslot(...)	\
	for_each_tdp_mmu_root(...)			\
		if (kvm_mmu_page_as_id(root) != slot->as_id) {
		} else

> +
> +		/*
> +		 * Take a reference on the root so that it cannot be freed if
> +		 * this thread releases the MMU lock and yields in this loop.
> +		 */
> +		get_tdp_mmu_root(kvm, root);
> +
> +		spte_set = wrprot_gfn_range(kvm, root, slot->base_gfn,
> +				slot->base_gfn + slot->npages, skip_4k) ||
> +			   spte_set;
> +
> +		put_tdp_mmu_root(kvm, root);
> +	}
> +
> +	return spte_set;
> +}
> +
> +/*
> + * Clear the dirty status of all the SPTEs mapping GFNs in the memslot. If
> + * AD bits are enabled, this will involve clearing the dirty bit on each SPTE.
> + * If AD bits are not enabled, this will require clearing the writable bit on
> + * each SPTE. Returns true if an SPTE has been changed and the TLBs need to
> + * be flushed.
> + */
> +static bool clear_dirty_gfn_range(struct kvm *kvm, struct kvm_mmu_page *root,
> +			   gfn_t start, gfn_t end)
> +{
> +	struct tdp_iter iter;
> +	u64 new_spte;
> +	bool spte_set = false;
> +	int as_id = kvm_mmu_page_as_id(root);
> +
> +	for_each_tdp_pte_root(iter, root, start, end) {
> +		if (!is_shadow_present_pte(iter.old_spte) ||
> +		    !is_last_spte(iter.old_spte, iter.level))
> +			continue;

Same thing here, extra wrappers would probably be helpful.  At least add one
for the present case, e.g.

  #define for_each_present_tdp_pte_using_root()

and maybe even

  #define for_each_leaf_tdp_pte_using_root()

since the "!present || !last" pops up 4 or 5 times.
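
Just as a sketch of the leaf variant, reusing the same if/else trick
(name and exact form obviously up for discussion):

  #define for_each_leaf_tdp_pte_using_root(iter, root, start, end)	\
	for_each_tdp_pte_root(iter, root, start, end)			\
		if (!is_shadow_present_pte((iter).old_spte) ||		\
		    !is_last_spte((iter).old_spte, (iter).level)) {	\
		} else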

> +
> +		if (spte_ad_need_write_protect(iter.old_spte)) {
> +			if (is_writable_pte(iter.old_spte))
> +				new_spte = iter.old_spte & ~PT_WRITABLE_MASK;
> +			else
> +				continue;
Paolo Bonzini Sept. 30, 2020, 6:08 p.m. UTC | #4
On 30/09/20 20:04, Sean Christopherson wrote:
>> +	for_each_tdp_mmu_root(kvm, root) {
>> +		root_as_id = kvm_mmu_page_as_id(root);
>> +		if (root_as_id != slot->as_id)
>> +			continue;
> This pattern pops up quite a few times, probably worth adding
> 
> #define for_each_tdp_mmu_root_using_memslot(...)	\
> 	for_each_tdp_mmu_root(...)			\
> 		if (kvm_mmu_page_as_id(root) != slot->as_id) {
> 		} else
> 

It's not really relevant that it's a memslot, but

	for_each_tdp_mmu_root_using_as_id

makes sense too.
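
E.g. (same trick, just keyed on the address space id; only a sketch):

	#define for_each_tdp_mmu_root_using_as_id(kvm, root, as_id)	\
		for_each_tdp_mmu_root(kvm, root)			\
			if (kvm_mmu_page_as_id(root) != (as_id)) {	\
			} else

and the callers simply pass slot->as_id.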

Paolo
Ben Gardon Oct. 8, 2020, 6:27 p.m. UTC | #5
On Fri, Sep 25, 2020 at 6:04 PM Paolo Bonzini <pbonzini@redhat.com> wrote:
>
> On 25/09/20 23:22, Ben Gardon wrote:
> >                               start_level, KVM_MAX_HUGEPAGE_LEVEL, false);
> > +     if (kvm->arch.tdp_mmu_enabled)
> > +             flush = kvm_tdp_mmu_wrprot_slot(kvm, memslot, false) || flush;
> >       spin_unlock(&kvm->mmu_lock);
> >
>
> In fact you can just pass down the end-level KVM_MAX_HUGEPAGE_LEVEL or
> PG_LEVEL_4K here to kvm_tdp_mmu_wrprot_slot and from there to
> wrprot_gfn_range.

That makes sense. My only worry there is the added complexity of error
handling for values besides PG_LEVEL_2M and PG_LEVEL_4K. Since there
are only two callers, I don't think that will be too much of a problem
though. I don't think KVM_MAX_HUGEPAGE_LEVEL would actually be a good
value to pass in, as I don't think that would write-protect 2M
mappings. KVM_MAX_HUGEPAGE_LEVEL is defined as PG_LEVEL_1G, or 3.
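
Just to illustrate, with a level parameter instead of the bool, the two
call sites in mmu.c would look something like (only a sketch):

	/* kvm_mmu_slot_remove_write_access(): everything down to 4K */
	if (kvm->arch.tdp_mmu_enabled)
		flush = kvm_tdp_mmu_wrprot_slot(kvm, memslot, PG_LEVEL_4K) ||
			flush;

	/* kvm_mmu_slot_largepage_remove_write_access(): huge mappings only */
	if (kvm->arch.tdp_mmu_enabled)
		flush = kvm_tdp_mmu_wrprot_slot(kvm, memslot, PG_LEVEL_2M) ||
			flush;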

>
> >
> > +             /*
> > +              * Take a reference on the root so that it cannot be freed if
> > +              * this thread releases the MMU lock and yields in this loop.
> > +              */
> > +             get_tdp_mmu_root(kvm, root);
> > +
> > +             spte_set = wrprot_gfn_range(kvm, root, slot->base_gfn,
> > +                             slot->base_gfn + slot->npages, skip_4k) ||
> > +                        spte_set;
> > +
> > +             put_tdp_mmu_root(kvm, root);
>
>
> Generally, using "|=" is the more common idiom in mmu.c.

I changed to this form in response to some feedback on the RFC about
mixing bitwise ops and bools, but I like the |= syntax more as well.

>
> > +static bool clear_dirty_gfn_range(struct kvm *kvm, struct kvm_mmu_page *root,
> > +                        gfn_t start, gfn_t end)
> > ...
> > +             __handle_changed_spte(kvm, as_id, iter.gfn, iter.old_spte,
> > +                                   new_spte, iter.level);
> > +             handle_changed_spte_acc_track(iter.old_spte, new_spte,
> > +                                           iter.level);
>
> Is it worth not calling handle_changed_spte?  handle_changed_spte_dlog
> obviously will never fire but duplicating the code is a bit ugly.
>
> I guess this patch is the first one that really gives the "feeling" of
> what the data structures look like.  The main difference with the shadow
> MMU is that you have the tdp_iter instead of the callback-based code of
> slot_handle_level_range, but otherwise it's not hard to follow one if
> you know the other.  Reorganizing the code so that mmu.c is little more
> than a wrapper around the two will help as well in this respect.
>
> Paolo
>

Patch

diff --git a/arch/x86/kvm/mmu/mmu.c b/arch/x86/kvm/mmu/mmu.c
index 0d80abe82ca93..b9074603f9df1 100644
--- a/arch/x86/kvm/mmu/mmu.c
+++ b/arch/x86/kvm/mmu/mmu.c
@@ -201,7 +201,7 @@  static u64 __read_mostly shadow_nx_mask;
 static u64 __read_mostly shadow_x_mask;	/* mutual exclusive with nx_mask */
 u64 __read_mostly shadow_user_mask;
 u64 __read_mostly shadow_accessed_mask;
-static u64 __read_mostly shadow_dirty_mask;
+u64 __read_mostly shadow_dirty_mask;
 static u64 __read_mostly shadow_mmio_value;
 static u64 __read_mostly shadow_mmio_access_mask;
 u64 __read_mostly shadow_present_mask;
@@ -324,7 +324,7 @@  inline bool spte_ad_enabled(u64 spte)
 	return (spte & SPTE_SPECIAL_MASK) != SPTE_AD_DISABLED_MASK;
 }
 
-static inline bool spte_ad_need_write_protect(u64 spte)
+inline bool spte_ad_need_write_protect(u64 spte)
 {
 	MMU_WARN_ON(is_mmio_spte(spte));
 	return (spte & SPTE_SPECIAL_MASK) != SPTE_AD_ENABLED_MASK;
@@ -1591,6 +1591,9 @@  static void kvm_mmu_write_protect_pt_masked(struct kvm *kvm,
 {
 	struct kvm_rmap_head *rmap_head;
 
+	if (kvm->arch.tdp_mmu_enabled)
+		kvm_tdp_mmu_clear_dirty_pt_masked(kvm, slot,
+				slot->base_gfn + gfn_offset, mask, true);
 	while (mask) {
 		rmap_head = __gfn_to_rmap(slot->base_gfn + gfn_offset + __ffs(mask),
 					  PG_LEVEL_4K, slot);
@@ -1617,6 +1620,9 @@  void kvm_mmu_clear_dirty_pt_masked(struct kvm *kvm,
 {
 	struct kvm_rmap_head *rmap_head;
 
+	if (kvm->arch.tdp_mmu_enabled)
+		kvm_tdp_mmu_clear_dirty_pt_masked(kvm, slot,
+				slot->base_gfn + gfn_offset, mask, false);
 	while (mask) {
 		rmap_head = __gfn_to_rmap(slot->base_gfn + gfn_offset + __ffs(mask),
 					  PG_LEVEL_4K, slot);
@@ -5954,6 +5960,8 @@  void kvm_mmu_slot_remove_write_access(struct kvm *kvm,
 	spin_lock(&kvm->mmu_lock);
 	flush = slot_handle_level(kvm, memslot, slot_rmap_write_protect,
 				start_level, KVM_MAX_HUGEPAGE_LEVEL, false);
+	if (kvm->arch.tdp_mmu_enabled)
+		flush = kvm_tdp_mmu_wrprot_slot(kvm, memslot, false) || flush;
 	spin_unlock(&kvm->mmu_lock);
 
 	/*
@@ -6034,6 +6042,7 @@  void kvm_arch_flush_remote_tlbs_memslot(struct kvm *kvm,
 	kvm_flush_remote_tlbs_with_address(kvm, memslot->base_gfn,
 					   memslot->npages);
 }
+EXPORT_SYMBOL_GPL(kvm_arch_flush_remote_tlbs_memslot);
 
 void kvm_mmu_slot_leaf_clear_dirty(struct kvm *kvm,
 				   struct kvm_memory_slot *memslot)
@@ -6042,6 +6051,8 @@  void kvm_mmu_slot_leaf_clear_dirty(struct kvm *kvm,
 
 	spin_lock(&kvm->mmu_lock);
 	flush = slot_handle_leaf(kvm, memslot, __rmap_clear_dirty, false);
+	if (kvm->arch.tdp_mmu_enabled)
+		flush = kvm_tdp_mmu_clear_dirty_slot(kvm, memslot) || flush;
 	spin_unlock(&kvm->mmu_lock);
 
 	/*
@@ -6063,6 +6074,8 @@  void kvm_mmu_slot_largepage_remove_write_access(struct kvm *kvm,
 	spin_lock(&kvm->mmu_lock);
 	flush = slot_handle_large_level(kvm, memslot, slot_rmap_write_protect,
 					false);
+	if (kvm->arch.tdp_mmu_enabled)
+		flush = kvm_tdp_mmu_wrprot_slot(kvm, memslot, true) || flush;
 	spin_unlock(&kvm->mmu_lock);
 
 	if (flush)
@@ -6077,6 +6090,8 @@  void kvm_mmu_slot_set_dirty(struct kvm *kvm,
 
 	spin_lock(&kvm->mmu_lock);
 	flush = slot_handle_all_level(kvm, memslot, __rmap_set_dirty, false);
+	if (kvm->arch.tdp_mmu_enabled)
+		flush = kvm_tdp_mmu_slot_set_dirty(kvm, memslot) || flush;
 	spin_unlock(&kvm->mmu_lock);
 
 	if (flush)
diff --git a/arch/x86/kvm/mmu/mmu_internal.h b/arch/x86/kvm/mmu/mmu_internal.h
index 8eaa6e4764bce..1a777ccfde44e 100644
--- a/arch/x86/kvm/mmu/mmu_internal.h
+++ b/arch/x86/kvm/mmu/mmu_internal.h
@@ -89,6 +89,7 @@  bool kvm_mmu_slot_gfn_write_protect(struct kvm *kvm,
 extern u64 shadow_user_mask;
 extern u64 shadow_accessed_mask;
 extern u64 shadow_present_mask;
+extern u64 shadow_dirty_mask;
 
 #define ACC_EXEC_MASK    1
 #define ACC_WRITE_MASK   PT_WRITABLE_MASK
@@ -112,6 +113,7 @@  bool is_access_track_spte(u64 spte);
 bool is_accessed_spte(u64 spte);
 bool spte_ad_enabled(u64 spte);
 bool is_executable_pte(u64 spte);
+bool spte_ad_need_write_protect(u64 spte);
 
 void kvm_flush_remote_tlbs_with_address(struct kvm *kvm, u64 start_gfn,
 					u64 pages);
diff --git a/arch/x86/kvm/mmu/tdp_iter.c b/arch/x86/kvm/mmu/tdp_iter.c
index 6c1a38429c81a..132e286150856 100644
--- a/arch/x86/kvm/mmu/tdp_iter.c
+++ b/arch/x86/kvm/mmu/tdp_iter.c
@@ -178,3 +178,21 @@  void tdp_iter_refresh_walk(struct tdp_iter *iter)
 	tdp_iter_start(iter, iter->pt_path[iter->root_level - 1],
 		       iter->root_level, goal_gfn);
 }
+
+/*
+ * Move on to the next SPTE, but do not move down into a child page table even
+ * if the current SPTE leads to one.
+ */
+void tdp_iter_next_no_step_down(struct tdp_iter *iter)
+{
+	bool done;
+
+	done = try_step_side(iter);
+	while (!done) {
+		if (!try_step_up(iter)) {
+			iter->valid = false;
+			break;
+		}
+		done = try_step_side(iter);
+	}
+}
diff --git a/arch/x86/kvm/mmu/tdp_iter.h b/arch/x86/kvm/mmu/tdp_iter.h
index 34da3bdada436..d0e65a62ea7d9 100644
--- a/arch/x86/kvm/mmu/tdp_iter.h
+++ b/arch/x86/kvm/mmu/tdp_iter.h
@@ -50,5 +50,6 @@  void tdp_iter_start(struct tdp_iter *iter, u64 *root_pt, int root_level,
 		    gfn_t goal_gfn);
 void tdp_iter_next(struct tdp_iter *iter);
 void tdp_iter_refresh_walk(struct tdp_iter *iter);
+void tdp_iter_next_no_step_down(struct tdp_iter *iter);
 
 #endif /* __KVM_X86_MMU_TDP_ITER_H */
diff --git a/arch/x86/kvm/mmu/tdp_mmu.c b/arch/x86/kvm/mmu/tdp_mmu.c
index bbe973d3f8084..e5cb7f0ec23e8 100644
--- a/arch/x86/kvm/mmu/tdp_mmu.c
+++ b/arch/x86/kvm/mmu/tdp_mmu.c
@@ -700,6 +700,7 @@  static int age_gfn_range(struct kvm *kvm, struct kvm_memory_slot *slot,
 
 			new_spte = mark_spte_for_access_track(new_spte);
 		}
+		new_spte &= ~shadow_dirty_mask;
 
 		*iter.sptep = new_spte;
 		__handle_changed_spte(kvm, as_id, iter.gfn, iter.old_spte,
@@ -804,3 +805,297 @@  int kvm_tdp_mmu_set_spte_hva(struct kvm *kvm, unsigned long address,
 					    set_tdp_spte);
 }
 
+/*
+ * Remove write access from all the SPTEs mapping GFNs [start, end). If
+ * skip_4k is set, SPTEs that map 4k pages will not be write-protected.
+ * Returns true if an SPTE has been changed and the TLBs need to be flushed.
+ */
+static bool wrprot_gfn_range(struct kvm *kvm, struct kvm_mmu_page *root,
+			     gfn_t start, gfn_t end, bool skip_4k)
+{
+	struct tdp_iter iter;
+	u64 new_spte;
+	bool spte_set = false;
+	int as_id = kvm_mmu_page_as_id(root);
+
+	for_each_tdp_pte_root(iter, root, start, end) {
+iteration_start:
+		if (!is_shadow_present_pte(iter.old_spte))
+			continue;
+
+		/*
+		 * If this entry points to a page of 4K entries, and 4k entries
+		 * should be skipped, skip the whole page. If the non-leaf
+		 * entry is at a higher level, move on to the next
+		 * (lower level) entry.
+		 */
+		if (!is_last_spte(iter.old_spte, iter.level)) {
+			if (skip_4k && iter.level == PG_LEVEL_2M) {
+				tdp_iter_next_no_step_down(&iter);
+				if (iter.valid && iter.gfn >= end)
+					goto iteration_start;
+				else
+					break;
+			} else {
+				continue;
+			}
+		}
+
+		WARN_ON(skip_4k && iter.level == PG_LEVEL_4K);
+
+		new_spte = iter.old_spte & ~PT_WRITABLE_MASK;
+
+		*iter.sptep = new_spte;
+		__handle_changed_spte(kvm, as_id, iter.gfn, iter.old_spte,
+				      new_spte, iter.level);
+		handle_changed_spte_acc_track(iter.old_spte, new_spte,
+					      iter.level);
+		spte_set = true;
+
+		tdp_mmu_iter_cond_resched(kvm, &iter);
+	}
+	return spte_set;
+}
+
+/*
+ * Remove write access from all the SPTEs mapping GFNs in the memslot. If
+ * skip_4k is set, SPTEs that map 4k pages will not be write-protected.
+ * Returns true if an SPTE has been changed and the TLBs need to be flushed.
+ */
+bool kvm_tdp_mmu_wrprot_slot(struct kvm *kvm, struct kvm_memory_slot *slot,
+			     bool skip_4k)
+{
+	struct kvm_mmu_page *root;
+	int root_as_id;
+	bool spte_set = false;
+
+	for_each_tdp_mmu_root(kvm, root) {
+		root_as_id = kvm_mmu_page_as_id(root);
+		if (root_as_id != slot->as_id)
+			continue;
+
+		/*
+		 * Take a reference on the root so that it cannot be freed if
+		 * this thread releases the MMU lock and yields in this loop.
+		 */
+		get_tdp_mmu_root(kvm, root);
+
+		spte_set = wrprot_gfn_range(kvm, root, slot->base_gfn,
+				slot->base_gfn + slot->npages, skip_4k) ||
+			   spte_set;
+
+		put_tdp_mmu_root(kvm, root);
+	}
+
+	return spte_set;
+}
+
+/*
+ * Clear the dirty status of all the SPTEs mapping GFNs in the memslot. If
+ * AD bits are enabled, this will involve clearing the dirty bit on each SPTE.
+ * If AD bits are not enabled, this will require clearing the writable bit on
+ * each SPTE. Returns true if an SPTE has been changed and the TLBs need to
+ * be flushed.
+ */
+static bool clear_dirty_gfn_range(struct kvm *kvm, struct kvm_mmu_page *root,
+			   gfn_t start, gfn_t end)
+{
+	struct tdp_iter iter;
+	u64 new_spte;
+	bool spte_set = false;
+	int as_id = kvm_mmu_page_as_id(root);
+
+	for_each_tdp_pte_root(iter, root, start, end) {
+		if (!is_shadow_present_pte(iter.old_spte) ||
+		    !is_last_spte(iter.old_spte, iter.level))
+			continue;
+
+		if (spte_ad_need_write_protect(iter.old_spte)) {
+			if (is_writable_pte(iter.old_spte))
+				new_spte = iter.old_spte & ~PT_WRITABLE_MASK;
+			else
+				continue;
+		} else {
+			if (iter.old_spte & shadow_dirty_mask)
+				new_spte = iter.old_spte & ~shadow_dirty_mask;
+			else
+				continue;
+		}
+
+		*iter.sptep = new_spte;
+		__handle_changed_spte(kvm, as_id, iter.gfn, iter.old_spte,
+				      new_spte, iter.level);
+		handle_changed_spte_acc_track(iter.old_spte, new_spte,
+					      iter.level);
+		spte_set = true;
+
+		tdp_mmu_iter_cond_resched(kvm, &iter);
+	}
+	return spte_set;
+}
+
+/*
+ * Clear the dirty status of all the SPTEs mapping GFNs in the memslot. If
+ * AD bits are enabled, this will involve clearing the dirty bit on each SPTE.
+ * If AD bits are not enabled, this will require clearing the writable bit on
+ * each SPTE. Returns true if an SPTE has been changed and the TLBs need to
+ * be flushed.
+ */
+bool kvm_tdp_mmu_clear_dirty_slot(struct kvm *kvm, struct kvm_memory_slot *slot)
+{
+	struct kvm_mmu_page *root;
+	int root_as_id;
+	bool spte_set = false;
+
+	for_each_tdp_mmu_root(kvm, root) {
+		root_as_id = kvm_mmu_page_as_id(root);
+		if (root_as_id != slot->as_id)
+			continue;
+
+		/*
+		 * Take a reference on the root so that it cannot be freed if
+		 * this thread releases the MMU lock and yields in this loop.
+		 */
+		get_tdp_mmu_root(kvm, root);
+
+		spte_set = clear_dirty_gfn_range(kvm, root, slot->base_gfn,
+				slot->base_gfn + slot->npages) || spte_set;
+
+		put_tdp_mmu_root(kvm, root);
+	}
+
+	return spte_set;
+}
+
+/*
+ * Clears the dirty status of all the 4k SPTEs mapping GFNs for which a bit is
+ * set in mask, starting at gfn. The given memslot is expected to contain all
+ * the GFNs represented by set bits in the mask. If AD bits are enabled,
+ * clearing the dirty status will involve clearing the dirty bit on each SPTE
+ * or, if AD bits are not enabled, clearing the writable bit on each SPTE.
+ */
+static void clear_dirty_pt_masked(struct kvm *kvm, struct kvm_mmu_page *root,
+				  gfn_t gfn, unsigned long mask, bool wrprot)
+{
+	struct tdp_iter iter;
+	u64 new_spte;
+	int as_id = kvm_mmu_page_as_id(root);
+
+	for_each_tdp_pte_root(iter, root, gfn + __ffs(mask),
+			      gfn + BITS_PER_LONG) {
+		if (!mask)
+			break;
+
+		if (!is_shadow_present_pte(iter.old_spte) ||
+		    !is_last_spte(iter.old_spte, iter.level) ||
+		    iter.level > PG_LEVEL_4K ||
+		    !(mask & (1UL << (iter.gfn - gfn))))
+			continue;
+
+		if (wrprot || spte_ad_need_write_protect(iter.old_spte)) {
+			if (is_writable_pte(iter.old_spte))
+				new_spte = iter.old_spte & ~PT_WRITABLE_MASK;
+			else
+				continue;
+		} else {
+			if (iter.old_spte & shadow_dirty_mask)
+				new_spte = iter.old_spte & ~shadow_dirty_mask;
+			else
+				continue;
+		}
+
+		*iter.sptep = new_spte;
+		__handle_changed_spte(kvm, as_id, iter.gfn, iter.old_spte,
+				      new_spte, iter.level);
+		handle_changed_spte_acc_track(iter.old_spte, new_spte,
+					      iter.level);
+
+		mask &= ~(1UL << (iter.gfn - gfn));
+	}
+}
+
+/*
+ * Clears the dirty status of all the 4k SPTEs mapping GFNs for which a bit is
+ * set in mask, starting at gfn. The given memslot is expected to contain all
+ * the GFNs represented by set bits in the mask. If AD bits are enabled,
+ * clearing the dirty status will involve clearing the dirty bit on each SPTE
+ * or, if AD bits are not enabled, clearing the writable bit on each SPTE.
+ */
+void kvm_tdp_mmu_clear_dirty_pt_masked(struct kvm *kvm,
+				       struct kvm_memory_slot *slot,
+				       gfn_t gfn, unsigned long mask,
+				       bool wrprot)
+{
+	struct kvm_mmu_page *root;
+	int root_as_id;
+
+	lockdep_assert_held(&kvm->mmu_lock);
+	for_each_tdp_mmu_root(kvm, root) {
+		root_as_id = kvm_mmu_page_as_id(root);
+		if (root_as_id != slot->as_id)
+			continue;
+
+		clear_dirty_pt_masked(kvm, root, gfn, mask, wrprot);
+	}
+}
+
+/*
+ * Set the dirty status of all the SPTEs mapping GFNs in the memslot. This is
+ * only used for PML, and so will involve setting the dirty bit on each SPTE.
+ * Returns true if an SPTE has been changed and the TLBs need to be flushed.
+ */
+static bool set_dirty_gfn_range(struct kvm *kvm, struct kvm_mmu_page *root,
+				gfn_t start, gfn_t end)
+{
+	struct tdp_iter iter;
+	u64 new_spte;
+	bool spte_set = false;
+	int as_id = kvm_mmu_page_as_id(root);
+
+	for_each_tdp_pte_root(iter, root, start, end) {
+		if (!is_shadow_present_pte(iter.old_spte))
+			continue;
+
+		new_spte = iter.old_spte | shadow_dirty_mask;
+
+		*iter.sptep = new_spte;
+		handle_changed_spte(kvm, as_id, iter.gfn, iter.old_spte,
+				    new_spte, iter.level);
+		spte_set = true;
+
+		tdp_mmu_iter_cond_resched(kvm, &iter);
+	}
+
+	return spte_set;
+}
+
+/*
+ * Set the dirty status of all the SPTEs mapping GFNs in the memslot. This is
+ * only used for PML, and so will involve setting the dirty bit on each SPTE.
+ * Returns true if an SPTE has been changed and the TLBs need to be flushed.
+ */
+bool kvm_tdp_mmu_slot_set_dirty(struct kvm *kvm, struct kvm_memory_slot *slot)
+{
+	struct kvm_mmu_page *root;
+	int root_as_id;
+	bool spte_set = false;
+
+	for_each_tdp_mmu_root(kvm, root) {
+		root_as_id = kvm_mmu_page_as_id(root);
+		if (root_as_id != slot->as_id)
+			continue;
+
+		/*
+		 * Take a reference on the root so that it cannot be freed if
+		 * this thread releases the MMU lock and yields in this loop.
+		 */
+		get_tdp_mmu_root(kvm, root);
+
+		spte_set = set_dirty_gfn_range(kvm, root, slot->base_gfn,
+				slot->base_gfn + slot->npages) || spte_set;
+
+		put_tdp_mmu_root(kvm, root);
+	}
+	return spte_set;
+}
+
diff --git a/arch/x86/kvm/mmu/tdp_mmu.h b/arch/x86/kvm/mmu/tdp_mmu.h
index 5a399aa60b8d8..2c9322ba3462b 100644
--- a/arch/x86/kvm/mmu/tdp_mmu.h
+++ b/arch/x86/kvm/mmu/tdp_mmu.h
@@ -28,4 +28,14 @@  int kvm_tdp_mmu_test_age_hva(struct kvm *kvm, unsigned long hva);
 
 int kvm_tdp_mmu_set_spte_hva(struct kvm *kvm, unsigned long address,
 			     pte_t *host_ptep);
+
+bool kvm_tdp_mmu_wrprot_slot(struct kvm *kvm, struct kvm_memory_slot *slot,
+			     bool skip_4k);
+bool kvm_tdp_mmu_clear_dirty_slot(struct kvm *kvm,
+				  struct kvm_memory_slot *slot);
+void kvm_tdp_mmu_clear_dirty_pt_masked(struct kvm *kvm,
+				       struct kvm_memory_slot *slot,
+				       gfn_t gfn, unsigned long mask,
+				       bool wrprot);
+bool kvm_tdp_mmu_slot_set_dirty(struct kvm *kvm, struct kvm_memory_slot *slot);
 #endif /* __KVM_X86_MMU_TDP_MMU_H */