
[RFC,13/28] kvm: mmu: Add an iterator for concurrent paging structure walks

Message ID 20190926231824.149014-14-bgardon@google.com (mailing list archive)
State New, archived
Series kvm: mmu: Rework the x86 TDP direct mapped case

Commit Message

Ben Gardon Sept. 26, 2019, 11:18 p.m. UTC
Add a utility for concurrent paging structure traversals. This iterator
uses several mechanisms to ensure that its accesses to paging structure
memory are safe, and that memory can be freed safely in the face of
lockless access. The purpose of the iterator is to create a unified
pattern for concurrent paging structure traversals and simplify the
implementation of other MMU functions.

This iterator implements a pre-order traversal of PTEs for a given GFN
range within a given address space. The iterator abstracts away
bookkeeping on successful changes to PTEs, retrying on failed PTE
modifications, TLB flushing, and yielding during long operations.

Signed-off-by: Ben Gardon <bgardon@google.com>
---
 arch/x86/kvm/mmu.c      | 455 ++++++++++++++++++++++++++++++++++++++++
 arch/x86/kvm/mmutrace.h |  50 +++++
 2 files changed, 505 insertions(+)
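
For orientation, a minimal sketch of how a caller might drive the iterator using the helpers this patch adds; the zap-style loop is only an example, the series introduces real users in later patches:

	struct direct_walk_iterator iter;
	bool flush_needed;

	direct_walk_iterator_setup_walk(&iter, kvm, as_id, start, end,
			MMU_WRITE_LOCK | MMU_LOCK_MAY_RESCHED);
	while (direct_walk_iterator_next_present_pte(&iter)) {
		/*
		 * Try to clear the PTE. If the cmpxchg races with another
		 * thread, the iterator re-reads the same entry and the next
		 * loop iteration retries it.
		 */
		direct_walk_iterator_set_pte(&iter, 0);
	}
	flush_needed = direct_walk_iterator_end_traversal(&iter);
	if (flush_needed)
		kvm_flush_remote_tlbs(kvm);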

Comments

Sean Christopherson Dec. 3, 2019, 2:15 a.m. UTC | #1
On Thu, Sep 26, 2019 at 04:18:09PM -0700, Ben Gardon wrote:
> Add a utility for concurrent paging structure traversals. This iterator
> uses several mechanisms to ensure that its accesses to paging structure
> memory are safe, and that memory can be freed safely in the face of
> lockless access. The purpose of the iterator is to create a unified
> pattern for concurrent paging structure traversals and simplify the
> implementation of other MMU functions.
> 
> This iterator implements a pre-order traversal of PTEs for a given GFN
> range within a given address space. The iterator abstracts away
> bookkeeping on successful changes to PTEs, retrying on failed PTE
> modifications, TLB flushing, and yielding during long operations.
> 
> Signed-off-by: Ben Gardon <bgardon@google.com>
> ---
>  arch/x86/kvm/mmu.c      | 455 ++++++++++++++++++++++++++++++++++++++++
>  arch/x86/kvm/mmutrace.h |  50 +++++
>  2 files changed, 505 insertions(+)

...

> +/*
> + * Sets a direct walk iterator to seek the gfn range [start, end).
> + * If end is greater than the maximum possible GFN, it will be changed to the
> + * maximum possible gfn + 1. (Note that start/end is an inclusive/exclusive
> + * range, so the last gfn to be iterated over would be the largest possible
> + * GFN, in this scenario.)
> + */
> +__attribute__((unused))
> +static void direct_walk_iterator_setup_walk(struct direct_walk_iterator *iter,
> +	struct kvm *kvm, int as_id, gfn_t start, gfn_t end,
> +	enum mmu_lock_mode lock_mode)

Echoing earlier patches, please introduce variables/flags/functions along
with their users.  I have a feeling you're adding some of the unused
functions so that all flags/variables in struct direct_walk_iterator can
be in place from the get-go, but that actually makes everything much harder
to review.

> +{
> +	BUG_ON(!kvm->arch.direct_mmu_enabled);
> +	BUG_ON((lock_mode & MMU_WRITE_LOCK) && (lock_mode & MMU_READ_LOCK));
> +	BUG_ON(as_id < 0);
> +	BUG_ON(as_id >= KVM_ADDRESS_SPACE_NUM);
> +	BUG_ON(!VALID_PAGE(kvm->arch.direct_root_hpa[as_id]));
> +
> +	/* End cannot be greater than the maximum possible gfn. */
> +	end = min(end, 1ULL << (PT64_ROOT_4LEVEL * PT64_PT_BITS));
> +
> +	iter->as_id = as_id;
> +	iter->pt_path[PT64_ROOT_4LEVEL - 1] =
> +			(u64 *)__va(kvm->arch.direct_root_hpa[as_id]);
> +
> +	iter->walk_start = start;
> +	iter->walk_end = end;
> +	iter->target_gfn = start;
> +
> +	iter->lock_mode = lock_mode;
> +	iter->kvm = kvm;
> +	iter->tlbs_dirty = 0;
> +
> +	direct_walk_iterator_start_traversal(iter);
> +}

...

> +static void direct_walk_iterator_cond_resched(struct direct_walk_iterator *iter)
> +{
> +	if (!(iter->lock_mode & MMU_LOCK_MAY_RESCHED) || !need_resched())
> +		return;
> +
> +	direct_walk_iterator_prepare_cond_resched(iter);
> +	cond_resched();
> +	direct_walk_iterator_finish_cond_resched(iter);
> +}
> +
> +static bool direct_walk_iterator_next_pte(struct direct_walk_iterator *iter)
> +{
> +	/*
> +	 * This iterator could be iterating over a large number of PTEs, such
> +	 * that if this thread did not yield, it would cause scheduler
> +	 * problems. To avoid this, yield if needed. Note the check on
> +	 * MMU_LOCK_MAY_RESCHED in direct_walk_iterator_cond_resched. This
> +	 * iterator will not yield unless that flag is set in its lock_mode.
> +	 */
> +	direct_walk_iterator_cond_resched(iter);

This looks very fragile, e.g. one of the future patches even has to avoid
problems with this code by limiting the number of PTEs it processes.

> +
> +	while (true) {
> +		if (!direct_walk_iterator_next_pte_raw(iter))

Implicitly initializing the iterator during next_pte_raw() is asking for
problems, e.g. @walk_in_progress should not exist.  The standard kernel
pattern for fancy iterators is to wrap the initialization, deref, and
advancement operators in a macro, e.g. something like:

	for_each_direct_pte(...) {

	}

That might require additional control flow logic in the users of the
iterator, but if so that's probably a good thing in terms of readability
and robustness.  E.g. verifying that rcu_read_unlock() is guaranteed to
be called is extremely difficult as rcu_read_lock() is buried in this
low level helper but the iterator relies on the top-level caller to
terminate traversal.

See mem_cgroup_iter_break() for one example of handling an iter walk
where an action needs to be taken when the walk terminates early.

> +			return false;
> +
> +		direct_walk_iterator_recalculate_output_fields(iter);
> +		if (iter->old_pte != DISCONNECTED_PTE)
> +			break;
> +
> +		/*
> +		 * The iterator has encountered a disconnected pte, so it is in
> +		 * a page that has been disconnected from the root. Restart the
> +		 * traversal from the root in this case.
> +		 */
> +		direct_walk_iterator_reset_traversal(iter);

I understand wanting to hide details to eliminate copy-paste, but this
goes too far and makes it too difficult to understand the flow of the
top-level walks.  Ditto for burying retry_pte() in set_pte().  I'd say it
also applies to skip_step_down(), but AFAICT that's dead code.

Off-topic for a second, the super long direct_walk_iterator_... names
make me want to simply call this new MMU the "tdp MMU" and just live with
the discrepancy until the old shadow-based TDP MMU can be nuked.  Then we
could have tdp_iter_blah_blah_blah(), for_each_tdp_present_pte(), etc...

Back to the iterator, I think it can be massaged into a standard for loop
approach without polluting the top level walkers much.  The below code is
the basic idea, e.g. the macros won't compile, probably doesn't terminate
the walk correctly, rescheduling is missing, etc...

Note, open coding the down/sideways/up helpers is 50% personal preference,
50% because gfn_start and gfn_end are now local variables, and 50% because
it was the easiest way to learn the code.  I wouldn't argue too much about
having one or more of the helpers.


static void tdp_iter_break(struct tdp_iter *iter)
{
	/* TLB flush, RCU unlock, etc... */
}

static void tdp_iter_next(struct tdp_iter *iter, bool *retry)
{
	gfn_t gfn_start, gfn_end;
	u64 *child_pt;

	if (*retry) {
		*retry = false;
		return;
	}

	/*
	 * Reread the pte before stepping down to avoid traversing into page
	 * tables that are no longer linked from this entry. This is not
	 * needed for correctness - just a small optimization.
	 */
	iter->old_pte = READ_ONCE(*iter->ptep);

	/* Try to step down. */
	child_pt = pte_to_child_pt(iter->old_pte, iter->level);
	if (child_pt) {
		child_pt = rcu_dereference(child_pt);
		iter->level--;
		iter->pt_path[iter->level - 1] = child_pt;
		return;
	}

step_sideways:
	/* Try to step sideways. */
	gfn_start = ALIGN_DOWN(iter->target_gfn,
			       KVM_PAGES_PER_HPAGE(iter->level));
	gfn_end = gfn_start + KVM_PAGES_PER_HPAGE(iter->level);

	/*
	 * If the current gfn maps past the target gfn range, the next entry in
	 * the current page table will be outside the target range.
	 */
	if (gfn_end >= iter->walk_end ||
	    !(gfn_end % KVM_PAGES_PER_HPAGE(iter->level + 1))) {
		/* Try to step up. */
		iter->level++;

		if (iter->level > PT64_ROOT_4LEVEL) {
			/* This is ugly, there's probably a better solution. */
			tdp_iter_break(iter);
			return;
		}
		goto step_sideways;
	}

	iter->target_gfn = gfn_end;
	iter->ptep = iter->pt_path[iter->level - 1] +
			PT64_INDEX(iter->target_gfn << PAGE_SHIFT, iter->level);
	iter->old_pte = READ_ONCE(*iter->ptep);
}

#define for_each_tdp_pte(iter, start, end, retry)
	for (tdp_iter_start(&iter, start, end);
	     iter->level <= PT64_ROOT_4LEVEL;
	     tdp_iter_next(&iter, &retry))

#define for_each_tdp_present_pte(iter, start, end, retry)
	for_each_tdp_pte(iter, start, end, retry)
		if (!is_present_direct_pte(iter->old_pte)) {

		} else

#define for_each_tdp_present_leaf_pte(iter, start, end, retry)
	for_each_tdp_pte(iter, start, end, retry)
		if (!is_present_direct_pte(iter->old_pte) ||
		    !is_last_spte(iter->old_pte, iter->level))
		{

		} else

/*
 * Marks the range of gfns, [start, end), non-present.
 */
static bool zap_direct_gfn_range(struct kvm *kvm, int as_id, gfn_t start,
				 gfn_t end, enum mmu_lock_mode lock_mode)
{
	struct direct_walk_iterator iter;
	bool retry;

	tdp_iter_init(&iter, kvm, as_id, lock_mode);

restart:
	retry = false;
	for_each_tdp_present_pte(iter, start, end, retry) {
		if (tdp_iter_set_pte(&iter, 0))
			retry = true;

		if (tdp_iter_disconnected(&iter)) {
			tdp_iter_break(&iter);
			goto restart;
		}
	}
}
Ben Gardon Dec. 18, 2019, 6:25 p.m. UTC | #2
On Mon, Dec 2, 2019 at 6:15 PM Sean Christopherson
<sean.j.christopherson@intel.com> wrote:
>
> On Thu, Sep 26, 2019 at 04:18:09PM -0700, Ben Gardon wrote:
> > Add a utility for concurrent paging structure traversals. This iterator
> > uses several mechanisms to ensure that its accesses to paging structure
> > memory are safe, and that memory can be freed safely in the face of
> > lockless access. The purpose of the iterator is to create a unified
> > pattern for concurrent paging structure traversals and simplify the
> > implementation of other MMU functions.
> >
> > This iterator implements a pre-order traversal of PTEs for a given GFN
> > range within a given address space. The iterator abstracts away
> > bookkeeping on successful changes to PTEs, retrying on failed PTE
> > modifications, TLB flushing, and yielding during long operations.
> >
> > Signed-off-by: Ben Gardon <bgardon@google.com>
> > ---
> >  arch/x86/kvm/mmu.c      | 455 ++++++++++++++++++++++++++++++++++++++++
> >  arch/x86/kvm/mmutrace.h |  50 +++++
> >  2 files changed, 505 insertions(+)
>
> ...
>
> > +/*
> > + * Sets a direct walk iterator to seek the gfn range [start, end).
> > + * If end is greater than the maximum possible GFN, it will be changed to the
> > + * maximum possible gfn + 1. (Note that start/end is an inclusive/exclusive
> > + * range, so the last gfn to be iterated over would be the largest possible
> > + * GFN, in this scenario.)
> > + */
> > +__attribute__((unused))
> > +static void direct_walk_iterator_setup_walk(struct direct_walk_iterator *iter,
> > +     struct kvm *kvm, int as_id, gfn_t start, gfn_t end,
> > +     enum mmu_lock_mode lock_mode)
>
> Echoing earlier patches, please introduce variables/flags/functions along
> with their users.  I have a feeling you're adding some of the unused
> functions so that all flags/variables in struct direct_walk_iterator can
> be in place from the get-go, but that actually makes everything much harder
> to review.
>
> > +{
> > +     BUG_ON(!kvm->arch.direct_mmu_enabled);
> > +     BUG_ON((lock_mode & MMU_WRITE_LOCK) && (lock_mode & MMU_READ_LOCK));
> > +     BUG_ON(as_id < 0);
> > +     BUG_ON(as_id >= KVM_ADDRESS_SPACE_NUM);
> > +     BUG_ON(!VALID_PAGE(kvm->arch.direct_root_hpa[as_id]));
> > +
> > +     /* End cannot be greater than the maximum possible gfn. */
> > +     end = min(end, 1ULL << (PT64_ROOT_4LEVEL * PT64_PT_BITS));
> > +
> > +     iter->as_id = as_id;
> > +     iter->pt_path[PT64_ROOT_4LEVEL - 1] =
> > +                     (u64 *)__va(kvm->arch.direct_root_hpa[as_id]);
> > +
> > +     iter->walk_start = start;
> > +     iter->walk_end = end;
> > +     iter->target_gfn = start;
> > +
> > +     iter->lock_mode = lock_mode;
> > +     iter->kvm = kvm;
> > +     iter->tlbs_dirty = 0;
> > +
> > +     direct_walk_iterator_start_traversal(iter);
> > +}
>
> ...
>
> > +static void direct_walk_iterator_cond_resched(struct direct_walk_iterator *iter)
> > +{
> > +     if (!(iter->lock_mode & MMU_LOCK_MAY_RESCHED) || !need_resched())
> > +             return;
> > +
> > +     direct_walk_iterator_prepare_cond_resched(iter);
> > +     cond_resched();
> > +     direct_walk_iterator_finish_cond_resched(iter);
> > +}
> > +
> > +static bool direct_walk_iterator_next_pte(struct direct_walk_iterator *iter)
> > +{
> > +     /*
> > +      * This iterator could be iterating over a large number of PTEs, such
> > +      * that if this thread did not yield, it would cause scheduler
> > +      * problems. To avoid this, yield if needed. Note the check on
> > +      * MMU_LOCK_MAY_RESCHED in direct_walk_iterator_cond_resched. This
> > +      * iterator will not yield unless that flag is set in its lock_mode.
> > +      */
> > +     direct_walk_iterator_cond_resched(iter);
>
> This looks very fragile, e.g. one of the future patches even has to avoid
> problems with this code by limiting the number of PTEs it processes.
With this, functions either need to limit the number of PTEs they
process or pass the MMU_LOCK_MAY_RESCHED to the iterator. It would
probably be safer to invert the flag and make it
MMU_LOCK_MAY_NOT_RESCHED for functions that can self-regulate the
number of PTEs they process or have weird synchronization
requirements. For example, the page fault handler can't reschedule and
we know it won't process many entries, so we could pass
MMU_LOCK_MAY_NOT_RESCHED in there.
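
A rough sketch of that inversion, reusing the enum and cond_resched helper from the patch; the MMU_LOCK_MAY_NOT_RESCHED name and the example call site are hypothetical, not part of the posted series:

enum mmu_lock_mode {
	MMU_NO_LOCK = 0,
	MMU_READ_LOCK = 1,
	MMU_WRITE_LOCK = 2,
	/* Inverted sense: callers that cannot yield must opt out. */
	MMU_LOCK_MAY_NOT_RESCHED = 4
};

static void direct_walk_iterator_cond_resched(struct direct_walk_iterator *iter)
{
	/* Yield by default; skip only if the caller opted out. */
	if ((iter->lock_mode & MMU_LOCK_MAY_NOT_RESCHED) || !need_resched())
		return;

	direct_walk_iterator_prepare_cond_resched(iter);
	cond_resched();
	direct_walk_iterator_finish_cond_resched(iter);
}

/* The page fault path, which cannot yield, would then pass e.g.: */
direct_walk_iterator_setup_walk(&iter, vcpu->kvm, as_id, gfn, gfn + 1,
		MMU_READ_LOCK | MMU_LOCK_MAY_NOT_RESCHED);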


>
> > +
> > +     while (true) {
> > +             if (!direct_walk_iterator_next_pte_raw(iter))
>
> Implicitly initializing the iterator during next_pte_raw() is asking for
> problems, e.g. @walk_in_progress should not exist.  The standard kernel
> pattern for fancy iterators is to wrap the initialization, deref, and
> advancement operators in a macro, e.g. something like:
>
>         for_each_direct_pte(...) {
>
>         }
>
> That might require additional control flow logic in the users of the
> iterator, but if so that's probably a good thing in terms of readability
> and robustness.  E.g. verifying that rcu_read_unlock() is guaranteed to
> be called is extremely difficult as rcu_read_lock() is buried in this
> low level helper but the iterator relies on the top-level caller to
> terminate traversal.
>
> See mem_cgroup_iter_break() for one example of handling an iter walk
> where an action needs to be taken when the walk terminates early.
>
> > +                     return false;
> > +
> > +             direct_walk_iterator_recalculate_output_fields(iter);
> > +             if (iter->old_pte != DISCONNECTED_PTE)
> > +                     break;
> > +
> > +             /*
> > +              * The iterator has encountered a disconnected pte, so it is in
> > +              * a page that has been disconnected from the root. Restart the
> > +              * traversal from the root in this case.
> > +              */
> > +             direct_walk_iterator_reset_traversal(iter);
>
> I understand wanting to hide details to eliminate copy-paste, but this
> goes too far and makes it too difficult to understand the flow of the
> top-level walks.  Ditto for burying retry_pte() in set_pte().  I'd say it
> also applies to skip_step_down(), but AFAICT that's dead code.
>
> Off-topic for a second, the super long direct_walk_iterator_... names
> make me want to simply call this new MMU the "tdp MMU" and just live with
> the discrepancy until the old shadow-based TDP MMU can be nuked.  Then we
> could have tdp_iter_blah_blah_blah(), for_each_tdp_present_pte(), etc...
>
> Back to the iterator, I think it can be massaged into a standard for loop
> approach without polluting the top level walkers much.  The below code is
> the basic idea, e.g. the macros won't compile, probably doesn't terminate
> the walk correctly, rescheduling is missing, etc...
>
> Note, open coding the down/sideways/up helpers is 50% personal preference,
> 50% because gfn_start and gfn_end are now local variables, and 50% because
> it was the easiest way to learn the code.  I wouldn't argue too much about
> having one or more of the helpers.
>
>
> static void tdp_iter_break(struct tdp_iter *iter)
> {
>         /* TLB flush, RCU unlock, etc... */
> }
>
> static void tdp_iter_next(struct tdp_iter *iter, bool *retry)
> {
>         gfn_t gfn_start, gfn_end;
>         u64 *child_pt;
>
>         if (*retry) {
>                 *retry = false;
>                 return;
>         }
>
>         /*
>          * Reread the pte before stepping down to avoid traversing into page
>          * tables that are no longer linked from this entry. This is not
>          * needed for correctness - just a small optimization.
>          */
>         iter->old_pte = READ_ONCE(*iter->ptep);
>
>         /* Try to step down. */
>         child_pt = pte_to_child_pt(iter->old_pte, iter->level);
>         if (child_pt) {
>                 child_pt = rcu_dereference(child_pt);
>                 iter->level--;
>                 iter->pt_path[iter->level - 1] = child_pt;
>                 return;
>         }
>
> step_sideways:
>         /* Try to step sideways. */
>         gfn_start = ALIGN_DOWN(iter->target_gfn,
>                                KVM_PAGES_PER_HPAGE(iter->level));
>         gfn_end = gfn_start + KVM_PAGES_PER_HPAGE(iter->level);
>
>         /*
>          * If the current gfn maps past the target gfn range, the next entry in
>          * the current page table will be outside the target range.
>          */
>         if (gfn_end >= iter->walk_end ||
>             !(gfn_end % KVM_PAGES_PER_HPAGE(iter->level + 1))) {
>                 /* Try to step up. */
>                 iter->level++;
>
>                 if (iter->level > PT64_ROOT_4LEVEL) {
>                         /* This is ugly, there's probably a better solution. */
>                         tdp_iter_break(iter);
>                         return;
>                 }
>                 goto step_sideways;
>         }
>
>         iter->target_gfn = gfn_end;
>         iter->ptep = iter->pt_path[iter->level - 1] +
>                         PT64_INDEX(iter->target_gfn << PAGE_SHIFT, iter->level);
>         iter->old_pte = READ_ONCE(*iter->ptep);
> }
>
> #define for_each_tdp_pte(iter, start, end, retry)
>         for (tdp_iter_start(&iter, start, end);
>              iter->level <= PT64_ROOT_4LEVEL;
>              tdp_iter_next(&iter, &retry))
>
> #define for_each_tdp_present_pte(iter, start, end, retry)
>         for_each_tdp_pte(iter, start, end, retry)
>                 if (!is_present_direct_pte(iter->old_pte)) {
>
>                 } else
>
> #define for_each_tdp_present_leaf_pte(iter, start, end, retry)
>         for_each_tdp_pte(iter, start, end, retry)
>                 if (!is_present_direct_pte(iter->old_pte) ||
>                     !is_last_spte(iter->old_pte, iter->level))
>                 {
>
>                 } else
>
> /*
>  * Marks the range of gfns, [start, end), non-present.
>  */
> static bool zap_direct_gfn_range(struct kvm *kvm, int as_id, gfn_t start,
>                                  gfn_t end, enum mmu_lock_mode lock_mode)
> {
>         struct direct_walk_iterator iter;
>         bool retry;
>
>         tdp_iter_init(&iter, kvm, as_id, lock_mode);
>
> restart:
>         retry = false;
>         for_each_tdp_present_pte(iter, start, end, retry) {
>                 if (tdp_iter_set_pte(&iter, 0))
>                         retry = true;
>
>                 if (tdp_iter_disconnected(&iter)) {
>                         tdp_iter_break(&iter);
>                         goto restart;
>                 }
>         }
> }
>
Sean Christopherson Dec. 18, 2019, 7:14 p.m. UTC | #3
On Wed, Dec 18, 2019 at 10:25:45AM -0800, Ben Gardon wrote:
> On Mon, Dec 2, 2019 at 6:15 PM Sean Christopherson
> <sean.j.christopherson@intel.com> wrote:
> >
> > > +static bool direct_walk_iterator_next_pte(struct direct_walk_iterator *iter)
> > > +{
> > > +     /*
> > > +      * This iterator could be iterating over a large number of PTEs, such
> > > +      * that if this thread did not yield, it would cause scheduler
> > > +      * problems. To avoid this, yield if needed. Note the check on
> > > +      * MMU_LOCK_MAY_RESCHED in direct_walk_iterator_cond_resched. This
> > > +      * iterator will not yield unless that flag is set in its lock_mode.
> > > +      */
> > > +     direct_walk_iterator_cond_resched(iter);
> >
> > This looks very fragile, e.g. one of the future patches even has to avoid
> > problems with this code by limiting the number of PTEs it processes.
>
> With this, functions either need to limit the number of PTEs they
> process or pass the MMU_LOCK_MAY_RESCHED to the iterator. It would
> probably be safer to invert the flag and make it
> MMU_LOCK_MAY_NOT_RESCHED for functions that can self-regulate the
> number of PTEs they process or have weird synchronization
> requirements. For example, the page fault handler can't reschedule and
> we know it won't process many entries, so we could pass
> MMU_LOCK_MAY_NOT_RESCHED in there.

That doesn't address the underlying fragility of the iterator, i.e. relying
on callers to self-regulate.  Especially since the threshold is completely
arbitrary, e.g. in zap_direct_gfn_range(), what's to say PDPE and lower is
always safe, e.g. if should_resched() becomes true at the very start of the
walk?

The direct comparison to zap_direct_gfn_range() is slot_handle_level_range(),
which supports rescheduling regardless of what function is being invoked.
What prevents the TDP iterator from doing the same?  E.g. what's the worst
case scenario if a reschedule pops up at an inopportune time?
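
For comparison, the rescheduling pattern in slot_handle_level_range() that Sean refers to looks roughly like this (paraphrased and simplified from arch/x86/kvm/mmu.c of that era, which still runs under the spinlock-based mmu_lock rather than the rwlock this series introduces):

static bool
slot_handle_level_range(struct kvm *kvm, struct kvm_memory_slot *memslot,
			slot_level_handler fn, int start_level, int end_level,
			gfn_t start_gfn, gfn_t end_gfn, bool lock_flush_tlb)
{
	struct slot_rmap_walk_iterator iterator;
	bool flush = false;

	for_each_slot_rmap_range(memslot, start_level, end_level, start_gfn,
				 end_gfn, &iterator) {
		if (iterator.rmap)
			flush |= fn(kvm, iterator.rmap);

		/*
		 * Yield whenever the scheduler wants the CPU or the lock is
		 * contended, flushing first if the callbacks dirtied TLBs.
		 */
		if (need_resched() || spin_needbreak(&kvm->mmu_lock)) {
			if (flush && lock_flush_tlb) {
				kvm_flush_remote_tlbs(kvm);
				flush = false;
			}
			cond_resched_lock(&kvm->mmu_lock);
		}
	}

	if (flush && lock_flush_tlb) {
		kvm_flush_remote_tlbs(kvm);
		flush = false;
	}

	return flush;
}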

Patch

diff --git a/arch/x86/kvm/mmu.c b/arch/x86/kvm/mmu.c
index 263718d49f730..59d1866398c42 100644
--- a/arch/x86/kvm/mmu.c
+++ b/arch/x86/kvm/mmu.c
@@ -1948,6 +1948,461 @@  static void handle_changed_pte(struct kvm *kvm, int as_id, gfn_t gfn,
 	}
 }
 
+/*
+ * Given a host page table entry and its level, returns a pointer (the host
+ * virtual address) to the child page table referenced by that entry. Returns
+ * NULL if there is no such child page table.
+ */
+static u64 *pte_to_child_pt(u64 pte, int level)
+{
+	u64 *pt;
+	/* There's no child entry if this entry isn't present */
+	if (!is_present_direct_pte(pte))
+		return NULL;
+
+	/* There is no child page table if this is a leaf entry. */
+	if (is_last_spte(pte, level))
+		return NULL;
+
+	pt = (u64 *)__va(pte & PT64_BASE_ADDR_MASK);
+	return pt;
+}
+
+enum mmu_lock_mode {
+	MMU_NO_LOCK = 0,
+	MMU_READ_LOCK = 1,
+	MMU_WRITE_LOCK = 2,
+	MMU_LOCK_MAY_RESCHED = 4
+};
+
+/*
+ * A direct walk iterator encapsulates a walk through a direct paging structure.
+ * It handles ensuring that the walk uses RCU to safely access page table
+ * memory.
+ */
+struct direct_walk_iterator {
+	/* Internal */
+	gfn_t walk_start;
+	gfn_t walk_end;
+	gfn_t target_gfn;
+	long tlbs_dirty;
+
+	/* the address space id. */
+	int as_id;
+	u64 *pt_path[PT64_ROOT_4LEVEL];
+	bool walk_in_progress;
+
+	/*
+	 * If set, the next call to direct_walk_iterator_next_pte_raw will
+	 * simply reread the current pte and return. This is useful in cases
+	 * where a thread misses a race to set a pte and wants to retry. This
+	 * should be set with a call to direct_walk_iterator_retry_pte.
+	 */
+	bool retry_pte;
+
+	/*
+	 * If set, the next call to direct_walk_iterator_next_pte_raw will not
+	 * step down to a lower level on its next step, even if it is at a
+	 * present, non-leaf pte. This is useful when, for example, splitting
+	 * pages, since we know that the entries below the now split page don't
+	 * need to be handled again.
+	 */
+	bool skip_step_down;
+
+	enum mmu_lock_mode lock_mode;
+	struct kvm *kvm;
+
+	/* Output */
+
+	/* The iterator's current level within the paging structure */
+	int level;
+	/* A pointer to the current PTE */
+	u64 *ptep;
+	/* A snapshot of the PTE pointed to by ptep */
+	u64 old_pte;
+	/* The lowest GFN mapped by the current PTE */
+	gfn_t pte_gfn_start;
+	/* The highest GFN mapped by the current PTE, + 1 */
+	gfn_t pte_gfn_end;
+};
+
+static void direct_walk_iterator_start_traversal(
+		struct direct_walk_iterator *iter)
+{
+	int level;
+
+	/*
+	 * Only clear the levels below the root. The root level page table is
+	 * allocated at VM creation time and will never change for the life of
+	 * the VM.
+	 */
+	for (level = PT_PAGE_TABLE_LEVEL; level < PT64_ROOT_4LEVEL; level++)
+		iter->pt_path[level - 1] = NULL;
+	iter->level = 0;
+	iter->ptep = NULL;
+	iter->old_pte = 0;
+	iter->pte_gfn_start = 0;
+	iter->pte_gfn_end = 0;
+	iter->walk_in_progress = false;
+	iter->retry_pte = false;
+	iter->skip_step_down = false;
+}
+
+static bool direct_walk_iterator_flush_needed(struct direct_walk_iterator *iter)
+{
+	long tlbs_dirty;
+
+	if (iter->tlbs_dirty) {
+		tlbs_dirty = xadd(&iter->kvm->tlbs_dirty, iter->tlbs_dirty) +
+				iter->tlbs_dirty;
+		iter->tlbs_dirty = 0;
+	} else {
+		tlbs_dirty = READ_ONCE(iter->kvm->tlbs_dirty);
+	}
+
+	return (iter->lock_mode & MMU_WRITE_LOCK) && tlbs_dirty;
+}
+
+static bool direct_walk_iterator_end_traversal(
+		struct direct_walk_iterator *iter)
+{
+	if (iter->walk_in_progress)
+		rcu_read_unlock();
+	return direct_walk_iterator_flush_needed(iter);
+}
+
+/*
+ * Resets a direct walk iterator to the root of the paging structure and RCU
+ * unlocks. After calling this function, the traversal can be reattempted.
+ */
+static void direct_walk_iterator_reset_traversal(
+		struct direct_walk_iterator *iter)
+{
+	/*
+	 * It's okay to ignore the return value, indicating whether a TLB flush
+	 * is needed here because we are ending and then restarting the
+	 * traversal without releasing the MMU lock. At this point the
+	 * iterator tlbs_dirty will have been flushed to the kvm tlbs_dirty, so
+	 * the next end_traversal will return that a flush is needed, if there's
+	 * not an intervening flush for some other reason.
+	 */
+	direct_walk_iterator_end_traversal(iter);
+	direct_walk_iterator_start_traversal(iter);
+}
+
+/*
+ * Sets a direct walk iterator to seek the gfn range [start, end).
+ * If end is greater than the maximum possible GFN, it will be changed to the
+ * maximum possible gfn + 1. (Note that start/end is an inclusive/exclusive
+ * range, so the last gfn to be iterated over would be the largest possible
+ * GFN, in this scenario.)
+ */
+__attribute__((unused))
+static void direct_walk_iterator_setup_walk(struct direct_walk_iterator *iter,
+	struct kvm *kvm, int as_id, gfn_t start, gfn_t end,
+	enum mmu_lock_mode lock_mode)
+{
+	BUG_ON(!kvm->arch.direct_mmu_enabled);
+	BUG_ON((lock_mode & MMU_WRITE_LOCK) && (lock_mode & MMU_READ_LOCK));
+	BUG_ON(as_id < 0);
+	BUG_ON(as_id >= KVM_ADDRESS_SPACE_NUM);
+	BUG_ON(!VALID_PAGE(kvm->arch.direct_root_hpa[as_id]));
+
+	/* End cannot be greater than the maximum possible gfn. */
+	end = min(end, 1ULL << (PT64_ROOT_4LEVEL * PT64_PT_BITS));
+
+	iter->as_id = as_id;
+	iter->pt_path[PT64_ROOT_4LEVEL - 1] =
+			(u64 *)__va(kvm->arch.direct_root_hpa[as_id]);
+
+	iter->walk_start = start;
+	iter->walk_end = end;
+	iter->target_gfn = start;
+
+	iter->lock_mode = lock_mode;
+	iter->kvm = kvm;
+	iter->tlbs_dirty = 0;
+
+	direct_walk_iterator_start_traversal(iter);
+}
+
+__attribute__((unused))
+static void direct_walk_iterator_retry_pte(struct direct_walk_iterator *iter)
+{
+	BUG_ON(!iter->walk_in_progress);
+	iter->retry_pte = true;
+}
+
+__attribute__((unused))
+static void direct_walk_iterator_skip_step_down(
+		struct direct_walk_iterator *iter)
+{
+	BUG_ON(!iter->walk_in_progress);
+	iter->skip_step_down = true;
+}
+
+/*
+ * Steps down one level in the paging structure towards the previously set
+ * target gfn. Returns true if the iterator was able to step down a level,
+ * false otherwise.
+ */
+static bool direct_walk_iterator_try_step_down(
+		struct direct_walk_iterator *iter)
+{
+	u64 *child_pt;
+
+	/*
+	 * Reread the pte before stepping down to avoid traversing into page
+	 * tables that are no longer linked from this entry. This is not
+	 * needed for correctness - just a small optimization.
+	 */
+	iter->old_pte = READ_ONCE(*iter->ptep);
+
+	child_pt = pte_to_child_pt(iter->old_pte, iter->level);
+	if (child_pt == NULL)
+		return false;
+	child_pt = rcu_dereference(child_pt);
+
+	iter->level--;
+	iter->pt_path[iter->level - 1] = child_pt;
+	return true;
+}
+
+/*
+ * Steps to the next entry in the current page table, at the current page table
+ * level. The next entry could map a page of guest memory, another page table,
+ * or it could be non-present or invalid. Returns true if the iterator was able
+ * to step to the next entry in the page table, false otherwise.
+ */
+static bool direct_walk_iterator_try_step_side(
+		struct direct_walk_iterator *iter)
+{
+	/*
+	 * If the current gfn maps past the target gfn range, the next entry in
+	 * the current page table will be outside the target range.
+	 */
+	if (iter->pte_gfn_end >= iter->walk_end)
+		return false;
+
+	/*
+	 * Check if the iterator is already at the end of the current page
+	 * table.
+	 */
+	if (!(iter->pte_gfn_end % KVM_PAGES_PER_HPAGE(iter->level + 1)))
+		return false;
+
+	iter->target_gfn = iter->pte_gfn_end;
+	return true;
+}
+
+/*
+ * Tries to back up a level in the paging structure so that the walk can
+ * continue from the next entry in the parent page table. Returns true on a
+ * successful step up, false otherwise.
+ */
+static bool direct_walk_iterator_try_step_up(struct direct_walk_iterator *iter)
+{
+	if (iter->level == PT64_ROOT_4LEVEL)
+		return false;
+
+	iter->level++;
+	return true;
+}
+
+/*
+ * Step to the next pte in a pre-order traversal of the target gfn range.
+ * To get to the next pte, the iterator either steps down towards the current
+ * target gfn, if at a present, non-leaf pte, or over to a pte mapping a
+ * higher gfn, if there's room in the gfn range. If there is no step within
+ * the target gfn range, returns false.
+ */
+static bool direct_walk_iterator_next_pte_raw(struct direct_walk_iterator *iter)
+{
+	bool retry_pte = iter->retry_pte;
+	bool skip_step_down = iter->skip_step_down;
+
+	iter->retry_pte = false;
+	iter->skip_step_down = false;
+
+	if (iter->target_gfn >= iter->walk_end)
+		return false;
+
+	/* If the walk is just starting, set up initial values. */
+	if (!iter->walk_in_progress) {
+		rcu_read_lock();
+
+		iter->level = PT64_ROOT_4LEVEL;
+		iter->walk_in_progress = true;
+		return true;
+	}
+
+	if (retry_pte)
+		return true;
+
+	if (!skip_step_down && direct_walk_iterator_try_step_down(iter))
+		return true;
+
+	while (!direct_walk_iterator_try_step_side(iter))
+		if (!direct_walk_iterator_try_step_up(iter))
+			return false;
+	return true;
+}
+
+static void direct_walk_iterator_recalculate_output_fields(
+		struct direct_walk_iterator *iter)
+{
+	iter->ptep = iter->pt_path[iter->level - 1] +
+			PT64_INDEX(iter->target_gfn << PAGE_SHIFT, iter->level);
+	iter->old_pte = READ_ONCE(*iter->ptep);
+	iter->pte_gfn_start = ALIGN_DOWN(iter->target_gfn,
+			KVM_PAGES_PER_HPAGE(iter->level));
+	iter->pte_gfn_end = iter->pte_gfn_start +
+			KVM_PAGES_PER_HPAGE(iter->level);
+}
+
+static void direct_walk_iterator_prepare_cond_resched(
+		struct direct_walk_iterator *iter)
+{
+	if (direct_walk_iterator_end_traversal(iter))
+		kvm_flush_remote_tlbs(iter->kvm);
+
+	if (iter->lock_mode & MMU_WRITE_LOCK)
+		write_unlock(&iter->kvm->mmu_lock);
+	else if (iter->lock_mode & MMU_READ_LOCK)
+		read_unlock(&iter->kvm->mmu_lock);
+
+}
+
+static void direct_walk_iterator_finish_cond_resched(
+		struct direct_walk_iterator *iter)
+{
+	if (iter->lock_mode & MMU_WRITE_LOCK)
+		write_lock(&iter->kvm->mmu_lock);
+	else if (iter->lock_mode & MMU_READ_LOCK)
+		read_lock(&iter->kvm->mmu_lock);
+
+	direct_walk_iterator_start_traversal(iter);
+}
+
+static void direct_walk_iterator_cond_resched(struct direct_walk_iterator *iter)
+{
+	if (!(iter->lock_mode & MMU_LOCK_MAY_RESCHED) || !need_resched())
+		return;
+
+	direct_walk_iterator_prepare_cond_resched(iter);
+	cond_resched();
+	direct_walk_iterator_finish_cond_resched(iter);
+}
+
+static bool direct_walk_iterator_next_pte(struct direct_walk_iterator *iter)
+{
+	/*
+	 * This iterator could be iterating over a large number of PTEs, such
+	 * that if this thread did not yield, it would cause scheduler
+	 * problems. To avoid this, yield if needed. Note the check on
+	 * MMU_LOCK_MAY_RESCHED in direct_walk_iterator_cond_resched. This
+	 * iterator will not yield unless that flag is set in its lock_mode.
+	 */
+	direct_walk_iterator_cond_resched(iter);
+
+	while (true) {
+		if (!direct_walk_iterator_next_pte_raw(iter))
+			return false;
+
+		direct_walk_iterator_recalculate_output_fields(iter);
+		if (iter->old_pte != DISCONNECTED_PTE)
+			break;
+
+		/*
+		 * The iterator has encountered a disconnected pte, so it is in
+		 * a page that has been disconnected from the root. Restart the
+		 * traversal from the root in this case.
+		 */
+		direct_walk_iterator_reset_traversal(iter);
+	}
+
+	trace_kvm_mmu_direct_walk_iterator_step(iter->walk_start,
+			iter->walk_end, iter->pte_gfn_start,
+			iter->level, iter->old_pte);
+
+	return true;
+}
+
+/*
+ * As direct_walk_iterator_next_pte but skips over non-present ptes.
+ * (i.e. ptes that are 0 or invalidated.)
+ */
+static bool direct_walk_iterator_next_present_pte(
+		struct direct_walk_iterator *iter)
+{
+	while (direct_walk_iterator_next_pte(iter))
+		if (is_present_direct_pte(iter->old_pte))
+			return true;
+
+	return false;
+}
+
+/*
+ * As direct_walk_iterator_next_present_pte but skips over non-leaf ptes.
+ */
+__attribute__((unused))
+static bool direct_walk_iterator_next_present_leaf_pte(
+		struct direct_walk_iterator *iter)
+{
+	while (direct_walk_iterator_next_present_pte(iter))
+		if (is_last_spte(iter->old_pte, iter->level))
+			return true;
+
+	return false;
+}
+
+/*
+ * Performs an atomic compare / exchange of ptes.
+ * Returns true if the pte was successfully set to the new value, false if the
+ * there was a race and the compare exchange needs to be retried.
+ */
+static bool cmpxchg_pte(u64 *ptep, u64 old_pte, u64 new_pte, int level, u64 gfn)
+{
+	u64 r;
+
+	r = cmpxchg64(ptep, old_pte, new_pte);
+	if (r == old_pte)
+		trace_kvm_mmu_set_pte_atomic(gfn, level, old_pte, new_pte);
+
+	return r == old_pte;
+}
+
+__attribute__((unused))
+static bool direct_walk_iterator_set_pte(struct direct_walk_iterator *iter,
+					 u64 new_pte)
+{
+	bool r;
+
+	if (!(iter->lock_mode & (MMU_READ_LOCK | MMU_WRITE_LOCK))) {
+		BUG_ON(is_present_direct_pte(iter->old_pte) !=
+				is_present_direct_pte(new_pte));
+		BUG_ON(spte_to_pfn(iter->old_pte) != spte_to_pfn(new_pte));
+		BUG_ON(is_last_spte(iter->old_pte, iter->level) !=
+				is_last_spte(new_pte, iter->level));
+	}
+
+	if (iter->old_pte == new_pte)
+		return true;
+
+	r = cmpxchg_pte(iter->ptep, iter->old_pte, new_pte, iter->level,
+			iter->pte_gfn_start);
+	if (r) {
+		handle_changed_pte(iter->kvm, iter->as_id, iter->pte_gfn_start,
+				   iter->old_pte, new_pte, iter->level, false);
+
+		if (iter->lock_mode & (MMU_WRITE_LOCK | MMU_READ_LOCK))
+			iter->tlbs_dirty++;
+	} else
+		direct_walk_iterator_retry_pte(iter);
+
+	return r;
+}
+
 /**
  * kvm_mmu_write_protect_pt_masked - write protect selected PT level pages
  * @kvm: kvm instance
diff --git a/arch/x86/kvm/mmutrace.h b/arch/x86/kvm/mmutrace.h
index 7ca8831c7d1a2..530723038296a 100644
--- a/arch/x86/kvm/mmutrace.h
+++ b/arch/x86/kvm/mmutrace.h
@@ -166,6 +166,56 @@  TRACE_EVENT(
 		  __entry->created ? "new" : "existing")
 );
 
+TRACE_EVENT(
+	kvm_mmu_direct_walk_iterator_step,
+	TP_PROTO(u64 walk_start, u64 walk_end, u64 base_gfn, int level,
+		u64 pte),
+	TP_ARGS(walk_start, walk_end, base_gfn, level, pte),
+
+	TP_STRUCT__entry(
+		__field(u64, walk_start)
+		__field(u64, walk_end)
+		__field(u64, base_gfn)
+		__field(int, level)
+		__field(u64, pte)
+		),
+
+	TP_fast_assign(
+		__entry->walk_start = walk_start;
+		__entry->walk_end = walk_end;
+		__entry->base_gfn = base_gfn;
+		__entry->level = level;
+		__entry->pte = pte;
+		),
+
+	TP_printk("walk_start=%llx walk_end=%llx base_gfn=%llx lvl=%d pte=%llx",
+		__entry->walk_start, __entry->walk_end, __entry->base_gfn,
+		__entry->level, __entry->pte)
+);
+
+TRACE_EVENT(
+	kvm_mmu_set_pte_atomic,
+	TP_PROTO(u64 gfn, int level, u64 old_pte, u64 new_pte),
+	TP_ARGS(gfn, level, old_pte, new_pte),
+
+	TP_STRUCT__entry(
+		__field(u64, gfn)
+		__field(int, level)
+		__field(u64, old_pte)
+		__field(u64, new_pte)
+		),
+
+	TP_fast_assign(
+		__entry->gfn = gfn;
+		__entry->level = level;
+		__entry->old_pte = old_pte;
+		__entry->new_pte = new_pte;
+		),
+
+	TP_printk("gfn=%llx level=%d old_pte=%llx new_pte=%llx", __entry->gfn,
+		  __entry->level, __entry->old_pte, __entry->new_pte)
+);
+
 DECLARE_EVENT_CLASS(kvm_mmu_page_class,
 
 	TP_PROTO(struct kvm_mmu_page *sp),