[RFC,06/24] MM locking API: implement fine grained range locks

Message ID 20200224203057.162467-7-walken@google.com (mailing list archive)
State New, archived
Series Fine grained MM locking

Commit Message

Michel Lespinasse Feb. 24, 2020, 8:30 p.m. UTC
This change implements fine-grained reader-writer range locks.

Locked ranges are stored as nodes of an augmented rbtree protected
by a mutex. Each node carries augmentation data for two overlapping
interval trees, covering the reader and writer locks respectively.
This data structure allows quickly searching for existing readers,
writers, or both that intersect a given address range.
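
For illustration, a caller that only needs to operate on a specific
address range is expected to look roughly as follows (sketch; start
and end stand for the boundaries of interest):

	struct mm_lock_range range;

	/* Describe the address range to be locked; end is exclusive. */
	mm_init_lock_range(&range, start, end);

	/* Sleeps until no conflicting writer range intersects [start, end). */
	mm_read_range_lock(mm, &range);

	/* ... operate on vmas within [start, end) ... */

	mm_read_range_unlock(mm, &range);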

When locking a range, the number of existing conflicting ranges
(whether already locked or still queued) is recorded in the
mm_lock_range struct. If that count is non-zero, the locking task is
put to sleep until all conflicting lock ranges are released.
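
Condensed, the write-lock slow path below (mm_write_range_lock(),
using queue_write() and wait(); lockdep annotations omitted here)
does:

	bool contended;

	mutex_lock(&mm->mmap_sem.mutex);
	/*
	 * flags_count starts at MM_LOCK_RANGE_WRITE and is debited by
	 * MM_LOCK_RANGE_COUNT_ONE for every existing range intersecting
	 * [start, end); a negative value means we have to wait.
	 */
	contended = queue_write(mm, range);
	if (contended)
		prepare_wait(range, mm->mmap_sem.seq);	/* record current task */
	mutex_unlock(&mm->mmap_sem.mutex);

	if (contended)
		wait(range);	/* sleep until flags_count becomes non-negative */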

When unlocking a range, the conflict count of every queued range that
conflicts with it is decremented. When a queued range's count reaches
zero, its owner task is woken up - it now holds a lock on its desired
address range.
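
The unlock side, condensed from mm_write_range_unlock() below, removes
the range and credits every queued range it was conflicting with,
waking up the tasks whose last conflict just went away:

	struct mm_lock_range *conflict;
	DEFINE_WAKE_Q(wake_q);

	mutex_lock(&mm->mmap_sem.mutex);
	remove(range, &mm->mmap_sem.rb_root);
	FOR_EACH_RANGE(mm, range->start, range->end, conflict)
		unlock_conflict(conflict, &wake_q);	/* credit, wake if no conflicts left */
	mutex_unlock(&mm->mmap_sem.mutex);

	wake_up_q(&wake_q);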

The general approach for this range locking implementation was first
proposed by Jan Kara back in 2013, and later worked on by at least
Laurent Dufour and Davidlohr Bueso. I have extended the approach
by using separate indexes for the reader and writer range locks.

Signed-off-by: Michel Lespinasse <walken@google.com>
---
 arch/x86/kernel/tboot.c       |   2 +-
 drivers/firmware/efi/efi.c    |   2 +-
 include/linux/mm_lock.h       |  96 ++++-
 include/linux/mm_types.h      |  20 +
 include/linux/mm_types_task.h |  15 +
 mm/Kconfig                    |   9 +-
 mm/Makefile                   |   1 +
 mm/init-mm.c                  |   3 +-
 mm/mm_lock_range.c            | 691 ++++++++++++++++++++++++++++++++++
 9 files changed, 827 insertions(+), 12 deletions(-)
 create mode 100644 mm/mm_lock_range.c

Patch

diff --git arch/x86/kernel/tboot.c arch/x86/kernel/tboot.c
index 4c61f0713832..68bb5e9b0324 100644
--- arch/x86/kernel/tboot.c
+++ arch/x86/kernel/tboot.c
@@ -90,7 +90,7 @@  static struct mm_struct tboot_mm = {
 	.pgd            = swapper_pg_dir,
 	.mm_users       = ATOMIC_INIT(2),
 	.mm_count       = ATOMIC_INIT(1),
-	.mmap_sem       = __RWSEM_INITIALIZER(init_mm.mmap_sem),
+	.mmap_sem       = MM_LOCK_INITIALIZER(init_mm.mmap_sem),
 	.page_table_lock =  __SPIN_LOCK_UNLOCKED(init_mm.page_table_lock),
 	.mmlist         = LIST_HEAD_INIT(init_mm.mmlist),
 };
diff --git drivers/firmware/efi/efi.c drivers/firmware/efi/efi.c
index 2b02cb165f16..fb5c9d53ceb2 100644
--- drivers/firmware/efi/efi.c
+++ drivers/firmware/efi/efi.c
@@ -60,7 +60,7 @@  struct mm_struct efi_mm = {
 	.mm_rb			= RB_ROOT,
 	.mm_users		= ATOMIC_INIT(2),
 	.mm_count		= ATOMIC_INIT(1),
-	.mmap_sem		= __RWSEM_INITIALIZER(efi_mm.mmap_sem),
+	.mmap_sem		= MM_LOCK_INITIALIZER(efi_mm.mmap_sem),
 	.page_table_lock	= __SPIN_LOCK_UNLOCKED(efi_mm.page_table_lock),
 	.mmlist			= LIST_HEAD_INIT(efi_mm.mmlist),
 	.cpu_bitmap		= { [BITS_TO_LONGS(NR_CPUS)] = 0},
diff --git include/linux/mm_lock.h include/linux/mm_lock.h
index 8ed92ebe58a1..a4d60bd56899 100644
--- include/linux/mm_lock.h
+++ include/linux/mm_lock.h
@@ -2,17 +2,26 @@ 
 #define _LINUX_MM_LOCK_H
 
 #include <linux/sched.h>
-
-static inline void mm_init_lock(struct mm_struct *mm)
-{
-	init_rwsem(&mm->mmap_sem);
-}
+#include <linux/lockdep.h>
 
 #ifdef CONFIG_MM_LOCK_RWSEM_INLINE
 
+#define MM_LOCK_INITIALIZER __RWSEM_INITIALIZER
 #define MM_COARSE_LOCK_RANGE_INITIALIZER {}
 
+static inline void mm_init_lock(struct mm_struct *mm)
+{
+	init_rwsem(&mm->mmap_sem);
+}
+
 static inline void mm_init_coarse_lock_range(struct mm_lock_range *range) {}
+static inline void mm_init_lock_range(struct mm_lock_range *range,
+		unsigned long start, unsigned long end) {}
+
+static inline bool mm_range_is_coarse(struct mm_lock_range *range)
+{
+	return true;
+}
 
 static inline void mm_write_range_lock(struct mm_struct *mm,
 				       struct mm_lock_range *range)
@@ -86,15 +95,80 @@  static inline struct mm_lock_range *mm_coarse_lock_range(void)
 	return NULL;
 }
 
-#else /* CONFIG_MM_LOCK_RWSEM_CHECKED */
+#else	/* !CONFIG_MM_LOCK_RWSEM_INLINE */
+
+#ifdef CONFIG_MM_LOCK_RWSEM_CHECKED
 
+#define MM_LOCK_INITIALIZER __RWSEM_INITIALIZER
 #define MM_COARSE_LOCK_RANGE_INITIALIZER { .mm = NULL }
 
+static inline void mm_init_lock(struct mm_struct *mm)
+{
+	init_rwsem(&mm->mmap_sem);
+}
+
 static inline void mm_init_coarse_lock_range(struct mm_lock_range *range)
 {
 	range->mm = NULL;
 }
 
+static inline void mm_init_lock_range(struct mm_lock_range *range,
+		unsigned long start, unsigned long end) {
+	mm_init_coarse_lock_range(range);
+}
+
+static inline bool mm_range_is_coarse(struct mm_lock_range *range)
+{
+	return true;
+}
+
+#else	/* CONFIG_MM_LOCK_RANGE */
+
+#ifdef CONFIG_DEBUG_LOCK_ALLOC
+#define __DEP_MAP_MM_LOCK_INITIALIZER(lockname)		\
+	.dep_map = { .name = #lockname },
+#else
+#define __DEP_MAP_MM_LOCK_INITIALIZER(lockname)
+#endif
+
+#define MM_LOCK_INITIALIZER(name) {			\
+	.mutex = __MUTEX_INITIALIZER(name.mutex),	\
+	.rb_root = RB_ROOT,				\
+	__DEP_MAP_MM_LOCK_INITIALIZER(name)		\
+}
+
+#define MM_COARSE_LOCK_RANGE_INITIALIZER {		\
+	.start = 0,					\
+	.end = ~0UL,					\
+}
+
+static inline void mm_init_lock(struct mm_struct *mm)
+{
+	static struct lock_class_key __key;
+
+	mutex_init(&mm->mmap_sem.mutex);
+	mm->mmap_sem.rb_root = RB_ROOT;
+	lockdep_init_map(&mm->mmap_sem.dep_map, "&mm->mmap_sem", &__key, 0);
+}
+
+static inline void mm_init_lock_range(struct mm_lock_range *range,
+		unsigned long start, unsigned long end) {
+	range->start = start;
+	range->end = end;
+}
+
+static inline void mm_init_coarse_lock_range(struct mm_lock_range *range)
+{
+	mm_init_lock_range(range, 0, ~0UL);
+}
+
+static inline bool mm_range_is_coarse(struct mm_lock_range *range)
+{
+	return range->start == 0 && range->end == ~0UL;
+}
+
+#endif	/* CONFIG_MM_LOCK_RANGE */
+
 extern void mm_write_range_lock(struct mm_struct *mm,
 				struct mm_lock_range *range);
 #ifdef CONFIG_LOCKDEP
@@ -129,11 +203,11 @@  static inline struct mm_lock_range *mm_coarse_lock_range(void)
 	return &current->mm_coarse_lock_range;
 }
 
-#endif
+#endif	/* !CONFIG_MM_LOCK_RWSEM_INLINE */
 
 static inline void mm_read_release(struct mm_struct *mm, unsigned long ip)
 {
-	rwsem_release(&mm->mmap_sem.dep_map, ip);
+	lock_release(&mm->mmap_sem.dep_map, ip);
 }
 
 static inline void mm_write_lock(struct mm_struct *mm)
@@ -183,7 +257,13 @@  static inline void mm_read_unlock(struct mm_struct *mm)
 
 static inline bool mm_is_locked(struct mm_struct *mm)
 {
+#ifndef CONFIG_MM_LOCK_RANGE
 	return rwsem_is_locked(&mm->mmap_sem) != 0;
+#elif defined(CONFIG_LOCKDEP)
+	return lockdep_is_held(&mm->mmap_sem);	/* Close enough for asserts */
+#else
+	return true;
+#endif
 }
 
 #endif /* _LINUX_MM_LOCK_H */
diff --git include/linux/mm_types.h include/linux/mm_types.h
index 270aa8fd2800..941610c906b3 100644
--- include/linux/mm_types.h
+++ include/linux/mm_types.h
@@ -283,6 +283,21 @@  struct vm_userfaultfd_ctx {
 struct vm_userfaultfd_ctx {};
 #endif /* CONFIG_USERFAULTFD */
 
+/*
+ * struct mm_lock stores locked address ranges for a given mm,
+ * implementing a fine-grained replacement for the mmap_sem rwsem.
+ */
+#ifdef CONFIG_MM_LOCK_RANGE
+struct mm_lock {
+	struct mutex mutex;
+	struct rb_root rb_root;
+	unsigned long seq;
+#ifdef CONFIG_DEBUG_LOCK_ALLOC
+	struct lockdep_map dep_map;
+#endif
+};
+#endif
+
 /*
  * This struct defines a memory VMM memory area. There is one of these
  * per VM-area/task.  A VM area is any part of the process virtual memory
@@ -426,7 +441,12 @@  struct mm_struct {
 		spinlock_t page_table_lock; /* Protects page tables and some
 					     * counters
 					     */
+
+#ifndef CONFIG_MM_LOCK_RANGE
 		struct rw_semaphore mmap_sem;
+#else
+		struct mm_lock mmap_sem;
+#endif
 
 		struct list_head mmlist; /* List of maybe swapped mm's.	These
 					  * are globally strung together off
diff --git include/linux/mm_types_task.h include/linux/mm_types_task.h
index d98c2a2293c1..e5652fe6a53c 100644
--- include/linux/mm_types_task.h
+++ include/linux/mm_types_task.h
@@ -12,6 +12,7 @@ 
 #include <linux/threads.h>
 #include <linux/atomic.h>
 #include <linux/cpumask.h>
+#include <linux/rbtree.h>
 
 #include <asm/page.h>
 
@@ -100,6 +101,20 @@  struct mm_lock_range {
 #ifdef CONFIG_MM_LOCK_RWSEM_CHECKED
 	struct mm_struct *mm;
 #endif
+#ifdef CONFIG_MM_LOCK_RANGE
+	/* First cache line - used in insert / remove / iter */
+	struct rb_node rb;
+	long flags_count;
+	unsigned long start;		/* First address of the range. */
+	unsigned long end;		/* First address after the range. */
+	struct {
+		unsigned long read_end;	  /* Largest end in reader nodes. */
+		unsigned long write_end;  /* Largest end in writer nodes. */
+	} __subtree;			/* Subtree augmented information. */
+	/* Second cache line - used in wait and wake. */
+	unsigned long seq;		/* Killable wait sequence number. */
+	struct task_struct *task;	/* Task trying to lock this range. */
+#endif
 };
 
 #endif /* _LINUX_MM_TYPES_TASK_H */
diff --git mm/Kconfig mm/Kconfig
index 574fb51789a5..3273ddb5839f 100644
--- mm/Kconfig
+++ mm/Kconfig
@@ -741,7 +741,7 @@  config MAPPING_DIRTY_HELPERS
 
 choice
 	prompt "MM lock implementation (mmap_sem)"
-	default MM_LOCK_RWSEM_CHECKED
+	default MM_LOCK_RANGE
 
 config MM_LOCK_RWSEM_INLINE
 	bool "rwsem, inline"
@@ -755,6 +755,13 @@  config MM_LOCK_RWSEM_CHECKED
 	  This option implements the MM lock using a read-write semaphore,
 	  ignoring the passed address range but checking its validity.
 
+config MM_LOCK_RANGE
+	bool "range lock"
+	help
+	  This option implements the MM lock as a read-write range lock,
+	  thus avoiding false conflicts between operations that operate
+	  on non-overlapping address ranges.
+
 endchoice
 
 endmenu
diff --git mm/Makefile mm/Makefile
index 9f46376c6407..71197fc20eda 100644
--- mm/Makefile
+++ mm/Makefile
@@ -109,3 +109,4 @@  obj-$(CONFIG_HMM_MIRROR) += hmm.o
 obj-$(CONFIG_MEMFD_CREATE) += memfd.o
 obj-$(CONFIG_MAPPING_DIRTY_HELPERS) += mapping_dirty_helpers.o
 obj-$(CONFIG_MM_LOCK_RWSEM_CHECKED) += mm_lock_rwsem_checked.o
+obj-$(CONFIG_MM_LOCK_RANGE) += mm_lock_range.o
diff --git mm/init-mm.c mm/init-mm.c
index 19603302a77f..0ba8ba5c07f4 100644
--- mm/init-mm.c
+++ mm/init-mm.c
@@ -1,5 +1,6 @@ 
 // SPDX-License-Identifier: GPL-2.0
 #include <linux/mm_types.h>
+#include <linux/mm_lock.h>
 #include <linux/rbtree.h>
 #include <linux/rwsem.h>
 #include <linux/spinlock.h>
@@ -31,7 +32,7 @@  struct mm_struct init_mm = {
 	.pgd		= swapper_pg_dir,
 	.mm_users	= ATOMIC_INIT(2),
 	.mm_count	= ATOMIC_INIT(1),
-	.mmap_sem	= __RWSEM_INITIALIZER(init_mm.mmap_sem),
+	.mmap_sem	= MM_LOCK_INITIALIZER(init_mm.mmap_sem),
 	.page_table_lock =  __SPIN_LOCK_UNLOCKED(init_mm.page_table_lock),
 	.arg_lock	=  __SPIN_LOCK_UNLOCKED(init_mm.arg_lock),
 	.mmlist		= LIST_HEAD_INIT(init_mm.mmlist),
diff --git mm/mm_lock_range.c mm/mm_lock_range.c
new file mode 100644
index 000000000000..da3c70e0809a
--- /dev/null
+++ mm/mm_lock_range.c
@@ -0,0 +1,691 @@ 
+#include <linux/mm_lock.h>
+#include <linux/rbtree_augmented.h>
+#include <linux/mutex.h>
+#include <linux/lockdep.h>
+#include <linux/sched.h>
+#include <linux/sched/signal.h>
+#include <linux/sched/wake_q.h>
+
+/* range->flags_count definitions */
+#define MM_LOCK_RANGE_WRITE 1
+#define MM_LOCK_RANGE_COUNT_ONE 2
+
+static inline bool rbcompute(struct mm_lock_range *range, bool exit)
+{
+	struct mm_lock_range *child;
+	unsigned long subtree_read_end = range->end, subtree_write_end = 0;
+	if (range->flags_count & MM_LOCK_RANGE_WRITE) {
+		subtree_read_end = 0;
+		subtree_write_end = range->end;
+	}
+	if (range->rb.rb_left) {
+		child = rb_entry(range->rb.rb_left, struct mm_lock_range, rb);
+		if (child->__subtree.read_end > subtree_read_end)
+			subtree_read_end = child->__subtree.read_end;
+		if (child->__subtree.write_end > subtree_write_end)
+			subtree_write_end = child->__subtree.write_end;
+	}
+	if (range->rb.rb_right) {
+		child = rb_entry(range->rb.rb_right, struct mm_lock_range, rb);
+		if (child->__subtree.read_end > subtree_read_end)
+			subtree_read_end = child->__subtree.read_end;
+		if (child->__subtree.write_end > subtree_write_end)
+			subtree_write_end = child->__subtree.write_end;
+	}
+	if (exit && range->__subtree.read_end == subtree_read_end &&
+		range->__subtree.write_end == subtree_write_end)
+		return true;
+	range->__subtree.read_end = subtree_read_end;
+	range->__subtree.write_end = subtree_write_end;
+	return false;
+}
+
+RB_DECLARE_CALLBACKS(static, augment, struct mm_lock_range, rb,
+		     __subtree, rbcompute);
+
+static void insert_read(struct mm_lock_range *range, struct rb_root *root)
+{
+	struct rb_node **link = &root->rb_node, *rb_parent = NULL;
+	unsigned long start = range->start, end = range->end;
+	struct mm_lock_range *parent;
+
+	while (*link) {
+		rb_parent = *link;
+		parent = rb_entry(rb_parent, struct mm_lock_range, rb);
+		if (parent->__subtree.read_end < end)
+			parent->__subtree.read_end = end;
+		if (start < parent->start)
+			link = &parent->rb.rb_left;
+		else
+			link = &parent->rb.rb_right;
+	}
+
+	range->__subtree.read_end = end;
+	range->__subtree.write_end = 0;
+	rb_link_node(&range->rb, rb_parent, link);
+	rb_insert_augmented(&range->rb, root, &augment);
+}
+
+static void insert_write(struct mm_lock_range *range, struct rb_root *root)
+{
+	struct rb_node **link = &root->rb_node, *rb_parent = NULL;
+	unsigned long start = range->start, end = range->end;
+	struct mm_lock_range *parent;
+
+	while (*link) {
+		rb_parent = *link;
+		parent = rb_entry(rb_parent, struct mm_lock_range, rb);
+		if (parent->__subtree.write_end < end)
+			parent->__subtree.write_end = end;
+		if (start < parent->start)
+			link = &parent->rb.rb_left;
+		else
+			link = &parent->rb.rb_right;
+	}
+
+	range->__subtree.read_end = 0;
+	range->__subtree.write_end = end;
+	rb_link_node(&range->rb, rb_parent, link);
+	rb_insert_augmented(&range->rb, root, &augment);
+}
+
+static void remove(struct mm_lock_range *range, struct rb_root *root)
+{
+	rb_erase_augmented(&range->rb, root, &augment);
+}
+
+/*
+ * Iterate over ranges intersecting [start;end)
+ *
+ * Note that a range intersects [start;end) iff:
+ *   Cond1: range->start < end
+ * and
+ *   Cond2: start < range->end
+ */
+
+static struct mm_lock_range *
+subtree_search(struct mm_lock_range *range,
+	       unsigned long start, unsigned long end)
+{
+	while (true) {
+		/*
+		 * Loop invariant: start < range->__subtree.read_end
+		 *              or start < range->__subtree.write_end
+		 * (Cond2 is satisfied by one of the subtree ranges)
+		 */
+		if (range->rb.rb_left) {
+			struct mm_lock_range *left = rb_entry(
+				range->rb.rb_left, struct mm_lock_range, rb);
+			if (start < left->__subtree.read_end ||
+			    start < left->__subtree.write_end) {
+				/*
+				 * Some ranges in left subtree satisfy Cond2.
+				 * Iterate to find the leftmost such range R.
+				 * If it also satisfies Cond1, that's the
+				 * match we are looking for. Otherwise, there
+				 * is no matching interval as ranges to the
+				 * right of R can't satisfy Cond1 either.
+				 */
+				range = left;
+				continue;
+			}
+		}
+		if (range->start < end) {		/* Cond1 */
+			if (start < range->end)		/* Cond2 */
+				return range;	/* range is leftmost match */
+			if (range->rb.rb_right) {
+				range = rb_entry(range->rb.rb_right,
+						 struct mm_lock_range, rb);
+				if (start < range->__subtree.read_end ||
+				    start < range->__subtree.write_end)
+					continue;
+			}
+		}
+		return NULL;	/* No match */
+	}
+}
+
+static struct mm_lock_range *
+iter_first(struct rb_root *root, unsigned long start, unsigned long end)
+{
+	struct mm_lock_range *range;
+
+	if (!root->rb_node)
+		return NULL;
+	range = rb_entry(root->rb_node, struct mm_lock_range, rb);
+	if (range->__subtree.read_end <= start &&
+	    range->__subtree.write_end <= start)
+		return NULL;
+	return subtree_search(range, start, end);
+}
+
+static struct mm_lock_range *
+iter_next(struct mm_lock_range *range, unsigned long start, unsigned long end)
+{
+	struct rb_node *rb = range->rb.rb_right, *prev;
+
+	while (true) {
+		/*
+		 * Loop invariants:
+		 *   Cond1: range->start < end
+		 *   rb == range->rb.rb_right
+		 *
+		 * First, search right subtree if suitable
+		 */
+		if (rb) {
+			struct mm_lock_range *right = rb_entry(
+				rb, struct mm_lock_range, rb);
+			if (start < right->__subtree.read_end ||
+			    start < right->__subtree.write_end)
+				return subtree_search(right, start, end);
+		}
+
+		/* Move up the tree until we come from a range's left child */
+		do {
+			rb = rb_parent(&range->rb);
+			if (!rb)
+				return NULL;
+			prev = &range->rb;
+			range = rb_entry(rb, struct mm_lock_range, rb);
+			rb = range->rb.rb_right;
+		} while (prev == rb);
+
+		/* Check if the range intersects [start;end) */
+		if (end <= range->start)		/* !Cond1 */
+			return NULL;
+		else if (start < range->end)		/* Cond2 */
+			return range;
+	}
+}
+
+#define FOR_EACH_RANGE(mm, start, end, tmp)				\
+for (tmp = iter_first(&mm->mmap_sem.rb_root, start, end); tmp;		\
+     tmp = iter_next(tmp, start, end))
+
+static struct mm_lock_range *
+subtree_search_read(struct mm_lock_range *range,
+		    unsigned long start, unsigned long end)
+{
+	while (true) {
+		/*
+		 * Loop invariant: start < range->__subtree.read_end
+		 * (Cond2 is satisfied by one of the subtree ranges)
+		 */
+		if (range->rb.rb_left) {
+			struct mm_lock_range *left = rb_entry(
+				range->rb.rb_left, struct mm_lock_range, rb);
+			if (start < left->__subtree.read_end) {
+				/*
+				 * Some ranges in left subtree satisfy Cond2.
+				 * Iterate to find the leftmost such range R.
+				 * If it also satisfies Cond1, that's the
+				 * match we are looking for. Otherwise, there
+				 * is no matching interval as ranges to the
+				 * right of R can't satisfy Cond1 either.
+				 */
+				range = left;
+				continue;
+			}
+		}
+		if (range->start < end) {		/* Cond1 */
+			if (start < range->end &&	/* Cond2 */
+			    !(range->flags_count & MM_LOCK_RANGE_WRITE))
+				return range;	/* range is leftmost match */
+			if (range->rb.rb_right) {
+				range = rb_entry(range->rb.rb_right,
+						 struct mm_lock_range, rb);
+				if (start < range->__subtree.read_end)
+					continue;
+			}
+		}
+		return NULL;	/* No match */
+	}
+}
+
+static struct mm_lock_range *
+iter_first_read(struct rb_root *root, unsigned long start, unsigned long end)
+{
+	struct mm_lock_range *range;
+
+	if (!root->rb_node)
+		return NULL;
+	range = rb_entry(root->rb_node, struct mm_lock_range, rb);
+	if (range->__subtree.read_end <= start)
+		return NULL;
+	return subtree_search_read(range, start, end);
+}
+
+static struct mm_lock_range *
+iter_next_read(struct mm_lock_range *range,
+	       unsigned long start, unsigned long end)
+{
+	struct rb_node *rb = range->rb.rb_right, *prev;
+
+	while (true) {
+		/*
+		 * Loop invariants:
+		 *   Cond1: range->start < end
+		 *   rb == range->rb.rb_right
+		 *
+		 * First, search right subtree if suitable
+		 */
+		if (rb) {
+			struct mm_lock_range *right = rb_entry(
+				rb, struct mm_lock_range, rb);
+			if (start < right->__subtree.read_end)
+				return subtree_search_read(right, start, end);
+		}
+
+		/* Move up the tree until we come from a range's left child */
+		do {
+			rb = rb_parent(&range->rb);
+			if (!rb)
+				return NULL;
+			prev = &range->rb;
+			range = rb_entry(rb, struct mm_lock_range, rb);
+			rb = range->rb.rb_right;
+		} while (prev == rb);
+
+		/* Check if the range intersects [start;end) */
+		if (end <= range->start)		/* !Cond1 */
+			return NULL;
+		else if (start < range->end &&		/* Cond2 */
+			 !(range->flags_count & MM_LOCK_RANGE_WRITE))
+			return range;
+	}
+}
+
+#define FOR_EACH_RANGE_READ(mm, start, end, tmp)			\
+for (tmp = iter_first_read(&mm->mmap_sem.rb_root, start, end); tmp;	\
+     tmp = iter_next_read(tmp, start, end))
+
+static struct mm_lock_range *
+subtree_search_write(struct mm_lock_range *range,
+		     unsigned long start, unsigned long end)
+{
+	while (true) {
+		/*
+		 * Loop invariant: start < range->__subtree.write_end
+		 * (Cond2 is satisfied by one of the subtree ranges)
+		 */
+		if (range->rb.rb_left) {
+			struct mm_lock_range *left = rb_entry(
+				range->rb.rb_left, struct mm_lock_range, rb);
+			if (start < left->__subtree.write_end) {
+				/*
+				 * Some ranges in left subtree satisfy Cond2.
+				 * Iterate to find the leftmost such range R.
+				 * If it also satisfies Cond1, that's the
+				 * match we are looking for. Otherwise, there
+				 * is no matching interval as ranges to the
+				 * right of R can't satisfy Cond1 either.
+				 */
+				range = left;
+				continue;
+			}
+		}
+		if (range->start < end) {		/* Cond1 */
+			if (start < range->end &&	/* Cond2 */
+			    range->flags_count & MM_LOCK_RANGE_WRITE)
+				return range;	/* range is leftmost match */
+			if (range->rb.rb_right) {
+				range = rb_entry(range->rb.rb_right,
+						 struct mm_lock_range, rb);
+				if (start < range->__subtree.write_end)
+					continue;
+			}
+		}
+		return NULL;	/* No match */
+	}
+}
+
+static struct mm_lock_range *
+iter_first_write(struct rb_root *root, unsigned long start, unsigned long end)
+{
+	struct mm_lock_range *range;
+
+	if (!root->rb_node)
+		return NULL;
+	range = rb_entry(root->rb_node, struct mm_lock_range, rb);
+	if (range->__subtree.write_end <= start)
+		return NULL;
+	return subtree_search_write(range, start, end);
+}
+
+static struct mm_lock_range *
+iter_next_write(struct mm_lock_range *range,
+		unsigned long start, unsigned long end)
+{
+	struct rb_node *rb = range->rb.rb_right, *prev;
+
+	while (true) {
+		/*
+		 * Loop invariants:
+		 *   Cond1: range->start < end
+		 *   rb == range->rb.rb_right
+		 *
+		 * First, search right subtree if suitable
+		 */
+		if (rb) {
+			struct mm_lock_range *right = rb_entry(
+				rb, struct mm_lock_range, rb);
+			if (start < right->__subtree.write_end)
+				return subtree_search_write(right, start, end);
+		}
+
+		/* Move up the tree until we come from a range's left child */
+		do {
+			rb = rb_parent(&range->rb);
+			if (!rb)
+				return NULL;
+			prev = &range->rb;
+			range = rb_entry(rb, struct mm_lock_range, rb);
+			rb = range->rb.rb_right;
+		} while (prev == rb);
+
+		/* Check if the range intersects [start;end) */
+		if (end <= range->start)		/* !Cond1 */
+			return NULL;
+		else if (start < range->end &&		/* Cond2 */
+			 range->flags_count & MM_LOCK_RANGE_WRITE)
+			return range;
+	}
+}
+
+#define FOR_EACH_RANGE_WRITE(mm, start, end, tmp)			\
+for (tmp = iter_first_write(&mm->mmap_sem.rb_root, start, end); tmp;	\
+     tmp = iter_next_write(tmp, start, end))
+
+static bool queue_read(struct mm_struct *mm, struct mm_lock_range *range)
+{
+	struct mm_lock_range *conflict;
+	long flags_count = 0;
+
+	FOR_EACH_RANGE_WRITE(mm, range->start, range->end, conflict)
+		flags_count -= MM_LOCK_RANGE_COUNT_ONE;
+	range->flags_count = flags_count;
+	insert_read(range, &mm->mmap_sem.rb_root);
+	return flags_count < 0;
+}
+
+static bool queue_write(struct mm_struct *mm, struct mm_lock_range *range)
+{
+	struct mm_lock_range *conflict;
+	long flags_count = MM_LOCK_RANGE_WRITE;
+
+	FOR_EACH_RANGE(mm, range->start, range->end, conflict)
+		flags_count -= MM_LOCK_RANGE_COUNT_ONE;
+	range->flags_count = flags_count;
+	insert_write(range, &mm->mmap_sem.rb_root);
+	return flags_count < 0;
+}
+
+static inline void prepare_wait(struct mm_lock_range *range, unsigned long seq)
+{
+	range->seq = seq;
+	range->task = current;
+}
+
+static void wait(struct mm_lock_range *range)
+{
+	while (true) {
+		set_current_state(TASK_UNINTERRUPTIBLE);
+		if (range->flags_count >= 0)
+			break;
+		schedule();
+	}
+	__set_current_state(TASK_RUNNING);
+}
+
+static bool wait_killable(struct mm_lock_range *range)
+{
+	while (true) {
+		set_current_state(TASK_INTERRUPTIBLE);
+		if (range->flags_count >= 0) {
+			__set_current_state(TASK_RUNNING);
+			return true;
+		}
+		if (signal_pending(current)) {
+			__set_current_state(TASK_RUNNING);
+			return false;
+		}
+		schedule();
+	}
+}
+
+static inline void unlock_conflict(struct mm_lock_range *range,
+				   struct wake_q_head *wake_q)
+{
+	if ((range->flags_count += MM_LOCK_RANGE_COUNT_ONE) >= 0)
+		wake_q_add(wake_q, range->task);
+}
+
+void mm_write_range_lock(struct mm_struct *mm, struct mm_lock_range *range)
+{
+	bool contended;
+
+	lock_acquire_exclusive(&mm->mmap_sem.dep_map, 0, 0, NULL, _RET_IP_);
+
+	mutex_lock(&mm->mmap_sem.mutex);
+	if ((contended = queue_write(mm, range)))
+		prepare_wait(range, mm->mmap_sem.seq);
+	mutex_unlock(&mm->mmap_sem.mutex);
+
+	if (contended) {
+		lock_contended(&mm->mmap_sem.dep_map, _RET_IP_);
+		wait(range);
+	}
+	lock_acquired(&mm->mmap_sem.dep_map, _RET_IP_);
+}
+EXPORT_SYMBOL(mm_write_range_lock);
+
+#ifdef CONFIG_LOCKDEP
+void mm_write_range_lock_nested(struct mm_struct *mm,
+				struct mm_lock_range *range, int subclass)
+{
+	bool contended;
+
+	lock_acquire_exclusive(&mm->mmap_sem.dep_map, subclass, 0, NULL,
+			       _RET_IP_);
+
+	mutex_lock(&mm->mmap_sem.mutex);
+	if ((contended = queue_write(mm, range)))
+		prepare_wait(range, mm->mmap_sem.seq);
+	mutex_unlock(&mm->mmap_sem.mutex);
+
+	if (contended) {
+		lock_contended(&mm->mmap_sem.dep_map, _RET_IP_);
+		wait(range);
+	}
+	lock_acquired(&mm->mmap_sem.dep_map, _RET_IP_);
+}
+EXPORT_SYMBOL(mm_write_range_lock_nested);
+#endif
+
+int mm_write_range_lock_killable(struct mm_struct *mm,
+				 struct mm_lock_range *range)
+{
+	bool contended;
+
+	lock_acquire_exclusive(&mm->mmap_sem.dep_map, 0, 0, NULL, _RET_IP_);
+
+	mutex_lock(&mm->mmap_sem.mutex);
+	if ((contended = queue_write(mm, range)))
+		prepare_wait(range, ++(mm->mmap_sem.seq));
+	mutex_unlock(&mm->mmap_sem.mutex);
+
+	if (contended) {
+		lock_contended(&mm->mmap_sem.dep_map, _RET_IP_);
+		if (!wait_killable(range)) {
+			struct mm_lock_range *conflict;
+			DEFINE_WAKE_Q(wake_q);
+
+			mutex_lock(&mm->mmap_sem.mutex);
+			remove(range, &mm->mmap_sem.rb_root);
+			FOR_EACH_RANGE(mm, range->start, range->end, conflict)
+				if (conflict->flags_count < 0 &&
+				    conflict->seq - range->seq <= (~0UL >> 1))
+					unlock_conflict(conflict, &wake_q);
+			mutex_unlock(&mm->mmap_sem.mutex);
+
+			wake_up_q(&wake_q);
+			lock_release(&mm->mmap_sem.dep_map, _RET_IP_);
+			return -EINTR;
+		}
+	}
+	lock_acquired(&mm->mmap_sem.dep_map, _RET_IP_);
+	return 0;
+}
+EXPORT_SYMBOL(mm_write_range_lock_killable);
+
+bool mm_write_range_trylock(struct mm_struct *mm, struct mm_lock_range *range)
+{
+	bool locked = false;
+
+	if (!mutex_trylock(&mm->mmap_sem.mutex))
+		goto exit;
+	if (iter_first(&mm->mmap_sem.rb_root, range->start, range->end))
+		goto unlock;
+	lock_acquire_exclusive(&mm->mmap_sem.dep_map, 0, 1, NULL,
+			       _RET_IP_);
+	range->flags_count = MM_LOCK_RANGE_WRITE;
+	insert_write(range, &mm->mmap_sem.rb_root);
+	locked = true;
+unlock:
+	mutex_unlock(&mm->mmap_sem.mutex);
+exit:
+	return locked;
+}
+EXPORT_SYMBOL(mm_write_range_trylock);
+
+void mm_write_range_unlock(struct mm_struct *mm, struct mm_lock_range *range)
+{
+	struct mm_lock_range *conflict;
+	DEFINE_WAKE_Q(wake_q);
+
+	mutex_lock(&mm->mmap_sem.mutex);
+	remove(range, &mm->mmap_sem.rb_root);
+	FOR_EACH_RANGE(mm, range->start, range->end, conflict)
+		unlock_conflict(conflict, &wake_q);
+	mutex_unlock(&mm->mmap_sem.mutex);
+
+	wake_up_q(&wake_q);
+	lock_release(&mm->mmap_sem.dep_map, _RET_IP_);
+}
+EXPORT_SYMBOL(mm_write_range_unlock);
+
+void mm_downgrade_write_range_lock(struct mm_struct *mm,
+				   struct mm_lock_range *range)
+{
+	struct mm_lock_range *conflict;
+	DEFINE_WAKE_Q(wake_q);
+
+	mutex_lock(&mm->mmap_sem.mutex);
+	FOR_EACH_RANGE_READ(mm, range->start, range->end, conflict)
+		unlock_conflict(conflict, &wake_q);
+	range->flags_count -= MM_LOCK_RANGE_WRITE;
+	augment_propagate(&range->rb, NULL);
+	mutex_unlock(&mm->mmap_sem.mutex);
+
+	wake_up_q(&wake_q);
+	lock_downgrade(&mm->mmap_sem.dep_map, _RET_IP_);
+}
+EXPORT_SYMBOL(mm_downgrade_write_range_lock);
+
+void mm_read_range_lock(struct mm_struct *mm, struct mm_lock_range *range)
+{
+	bool contended;
+
+	lock_acquire_shared(&mm->mmap_sem.dep_map, 0, 0, NULL, _RET_IP_);
+
+	mutex_lock(&mm->mmap_sem.mutex);
+	if ((contended = queue_read(mm, range)))
+		prepare_wait(range, mm->mmap_sem.seq);
+	mutex_unlock(&mm->mmap_sem.mutex);
+
+	if (contended) {
+		lock_contended(&mm->mmap_sem.dep_map, _RET_IP_);
+		wait(range);
+	}
+	lock_acquired(&mm->mmap_sem.dep_map, _RET_IP_);
+}
+EXPORT_SYMBOL(mm_read_range_lock);
+
+int mm_read_range_lock_killable(struct mm_struct *mm,
+				struct mm_lock_range *range)
+{
+	bool contended;
+
+	lock_acquire_shared(&mm->mmap_sem.dep_map, 0, 0, NULL, _RET_IP_);
+
+	mutex_lock(&mm->mmap_sem.mutex);
+	if ((contended = queue_read(mm, range)))
+		prepare_wait(range, ++(mm->mmap_sem.seq));
+	mutex_unlock(&mm->mmap_sem.mutex);
+
+	if (contended) {
+		lock_contended(&mm->mmap_sem.dep_map, _RET_IP_);
+		if (!wait_killable(range)) {
+			struct mm_lock_range *conflict;
+			DEFINE_WAKE_Q(wake_q);
+
+			mutex_lock(&mm->mmap_sem.mutex);
+			remove(range, &mm->mmap_sem.rb_root);
+			FOR_EACH_RANGE_WRITE(mm, range->start, range->end,
+					     conflict)
+				if (conflict->flags_count < 0 &&
+				    conflict->seq - range->seq <= (~0UL >> 1))
+					unlock_conflict(conflict, &wake_q);
+			mutex_unlock(&mm->mmap_sem.mutex);
+
+			wake_up_q(&wake_q);
+			lock_release(&mm->mmap_sem.dep_map, _RET_IP_);
+			return -EINTR;
+		}
+	}
+	lock_acquired(&mm->mmap_sem.dep_map, _RET_IP_);
+	return 0;
+}
+EXPORT_SYMBOL(mm_read_range_lock_killable);
+
+bool mm_read_range_trylock(struct mm_struct *mm, struct mm_lock_range *range)
+{
+	bool locked = false;
+
+	if (!mutex_trylock(&mm->mmap_sem.mutex))
+		goto exit;
+	if (iter_first_write(&mm->mmap_sem.rb_root, range->start, range->end))
+		goto unlock;
+	lock_acquire_shared(&mm->mmap_sem.dep_map, 0, 1, NULL, _RET_IP_);
+	range->flags_count = 0;
+	insert_read(range, &mm->mmap_sem.rb_root);
+	locked = true;
+unlock:
+	mutex_unlock(&mm->mmap_sem.mutex);
+exit:
+	return locked;
+}
+EXPORT_SYMBOL(mm_read_range_trylock);
+
+void mm_read_range_unlock_non_owner(struct mm_struct *mm,
+				    struct mm_lock_range *range)
+{
+	struct mm_lock_range *conflict;
+	DEFINE_WAKE_Q(wake_q);
+
+	mutex_lock(&mm->mmap_sem.mutex);
+	remove(range, &mm->mmap_sem.rb_root);
+	FOR_EACH_RANGE_WRITE(mm, range->start, range->end, conflict)
+		unlock_conflict(conflict, &wake_q);
+	mutex_unlock(&mm->mmap_sem.mutex);
+
+	wake_up_q(&wake_q);
+}
+EXPORT_SYMBOL(mm_read_range_unlock_non_owner);
+
+void mm_read_range_unlock(struct mm_struct *mm, struct mm_lock_range *range)
+{
+	mm_read_range_unlock_non_owner(mm, range);
+	lock_release(&mm->mmap_sem.dep_map, _RET_IP_);
+}
+EXPORT_SYMBOL(mm_read_range_unlock);