diff mbox series

[Fix,2/3] maple_tree: Change spanning store to work on larger trees

Message ID 20220615141921.417598-3-Liam.Howlett@oracle.com (mailing list archive)
State New
Headers show
Series Maple tree spanning fixes | expand

Commit Message

Liam R. Howlett June 15, 2022, 2:19 p.m. UTC
Spanning store had an issue which could lead to double free during a
large tree modification.  Fix this by being more careful about how nodes
are added to the to-be-freed and to-be-destroyed list on this operation.

Reported-by: Qian Cai <quic_qiancai@quicinc.com>
Signed-off-by: Liam R. Howlett <Liam.Howlett@oracle.com>
---
 lib/maple_tree.c | 325 ++++++++++++++++++++++++++++++-----------------
 1 file changed, 206 insertions(+), 119 deletions(-)
diff mbox series

Patch

diff --git a/lib/maple_tree.c b/lib/maple_tree.c
index a1035963ae0d..f413b6f0da2b 100644
--- a/lib/maple_tree.c
+++ b/lib/maple_tree.c
@@ -1388,7 +1388,6 @@  static inline unsigned char ma_data_end(struct maple_node *node,
 /*
  * mas_data_end() - Find the end of the data (slot).
  * @mas: the maple state
- * @type: the type of maple node
  *
  * This method is optimized to check the metadata of a node if the node type
  * supports data end metadata.
@@ -2272,6 +2271,31 @@  static inline void mas_wr_node_walk(struct ma_wr_state *wr_mas)
 	wr_mas->offset_end = mas->offset = offset;
 }
 
+/*
+ * mas_topiary_range() - Add a range of slots to the topiary.
+ * @mas: The maple state
+ * @destroy: The topiary to add the slots (usually destroy)
+ * @start: The starting slot inclusively
+ * @end: The end slot inclusively
+ */
+static inline void mas_topiary_range(struct ma_state *mas,
+	struct ma_topiary *destroy, unsigned char start, unsigned char end)
+{
+	void __rcu **slots;
+	unsigned offset;
+
+	MT_BUG_ON(mas->tree, mte_is_leaf(mas->node));
+	slots = ma_slots(mas_mn(mas), mte_node_type(mas->node));
+	for (offset = start; offset <= end; offset++) {
+		struct maple_enode *enode = mas_slot_locked(mas, slots, offset);
+
+		if (mte_dead_node(enode))
+			continue;
+
+		mat_add(destroy, enode);
+	}
+}
+
 /*
  * mast_topiary() - Add the portions of the tree to the removal list; either to
  * be freed or discarded (destroy walk).
@@ -2280,48 +2304,62 @@  static inline void mas_wr_node_walk(struct ma_wr_state *wr_mas)
 static inline void mast_topiary(struct maple_subtree_state *mast)
 {
 	MA_WR_STATE(wr_mas, mast->orig_l, NULL);
-	unsigned char l_off, r_off, offset;
-	unsigned long l_index;
-	struct maple_enode *child;
-	void __rcu **slots;
+	unsigned char r_start, r_end;
+	unsigned char l_start, l_end;
+	void **l_slots, **r_slots;
 
 	wr_mas.type = mte_node_type(mast->orig_l->node);
-	/* The left node is consumed, so add to the free list. */
-	l_index = mast->orig_l->index;
 	mast->orig_l->index = mast->orig_l->last;
 	mas_wr_node_walk(&wr_mas);
-	mast->orig_l->index = l_index;
-	l_off = mast->orig_l->offset;
-	r_off = mast->orig_r->offset;
-	if (mast->orig_l->node == mast->orig_r->node) {
-		slots = ma_slots(mte_to_node(mast->orig_l->node), wr_mas.type);
-		for (offset = l_off + 1; offset < r_off; offset++)
-			mat_add(mast->destroy, mas_slot_locked(mast->orig_l,
-							slots, offset));
+	l_start = mast->orig_l->offset + 1;
+	l_end = mas_data_end(mast->orig_l);
+	r_start = 0;
+	r_end = mast->orig_r->offset;
+
+	if (r_end)
+		r_end--;
+
+	l_slots = ma_slots(mas_mn(mast->orig_l),
+			   mte_node_type(mast->orig_l->node));
+
+	r_slots = ma_slots(mas_mn(mast->orig_r),
+			   mte_node_type(mast->orig_r->node));
 
+	if ((l_start < l_end) &&
+	    mte_dead_node(mas_slot_locked(mast->orig_l, l_slots, l_start))) {
+		l_start++;
+	}
+
+	if (mte_dead_node(mas_slot_locked(mast->orig_r, r_slots, r_end))) {
+		if (r_end)
+			r_end--;
+	}
+
+	if ((l_start > r_end) && (mast->orig_l->node == mast->orig_r->node))
 		return;
+
+	/* At the node where left and right sides meet, add the parts between */
+	if (mast->orig_l->node == mast->orig_r->node) {
+		return mas_topiary_range(mast->orig_l, mast->destroy,
+					     l_start, r_end);
 	}
 
 	/* mast->orig_r is different and consumed. */
 	if (mte_is_leaf(mast->orig_r->node))
 		return;
 
-	/* Now destroy l_off + 1 -> end and 0 -> r_off - 1 */
-	offset = l_off + 1;
-	slots = ma_slots(mte_to_node(mast->orig_l->node), wr_mas.type);
-	while (offset < mt_slots[wr_mas.type]) {
-		child = mas_slot_locked(mast->orig_l, slots, offset++);
-		if (!child)
-			break;
+	if (mte_dead_node(mas_slot_locked(mast->orig_l, l_slots, l_end)))
+		l_end--;
 
-		mat_add(mast->destroy, child);
-	}
 
-	slots = ma_slots(mte_to_node(mast->orig_r->node),
-			     mte_node_type(mast->orig_r->node));
-	for (offset = 0; offset < r_off; offset++)
-		mat_add(mast->destroy,
-				mas_slot_locked(mast->orig_l, slots, offset));
+	if (l_start <= l_end)
+		mas_topiary_range(mast->orig_l, mast->destroy, l_start, l_end);
+
+	if (mte_dead_node(mas_slot_locked(mast->orig_r, r_slots, r_start)))
+		r_start++;
+
+	if (r_start <= r_end)
+		mas_topiary_range(mast->orig_r, mast->destroy, 0, r_end);
 }
 
 /*
@@ -2329,19 +2367,13 @@  static inline void mast_topiary(struct maple_subtree_state *mast)
  * @mast: The maple subtree state
  * @old_r: The encoded maple node to the right (next node).
  */
-static inline void mast_rebalance_next(struct maple_subtree_state *mast,
-				       struct maple_enode *old_r, bool free)
+static inline void mast_rebalance_next(struct maple_subtree_state *mast)
 {
 	unsigned char b_end = mast->bn->b_end;
 
 	mas_mab_cp(mast->orig_r, 0, mt_slot_count(mast->orig_r->node),
 		   mast->bn, b_end);
-	if (free)
-		mat_add(mast->free, old_r);
-
 	mast->orig_r->last = mast->orig_r->max;
-	if (old_r == mast->orig_l->node)
-		mast->orig_l->node = mast->orig_r->node;
 }
 
 /*
@@ -2349,17 +2381,13 @@  static inline void mast_rebalance_next(struct maple_subtree_state *mast,
  * @mast: The maple subtree state
  * @old_l: The encoded maple node to the left (previous node)
  */
-static inline void mast_rebalance_prev(struct maple_subtree_state *mast,
-				       struct maple_enode *old_l)
+static inline void mast_rebalance_prev(struct maple_subtree_state *mast)
 {
 	unsigned char end = mas_data_end(mast->orig_l) + 1;
 	unsigned char b_end = mast->bn->b_end;
 
 	mab_shift_right(mast->bn, end);
 	mas_mab_cp(mast->orig_l, 0, end - 1, mast->bn, 0);
-	mat_add(mast->free, old_l);
-	if (mast->orig_r->node == old_l)
-		mast->orig_r->node = mast->orig_l->node;
 	mast->l->min = mast->orig_l->min;
 	mast->orig_l->index = mast->orig_l->min;
 	mast->bn->b_end = end + b_end;
@@ -2367,68 +2395,116 @@  static inline void mast_rebalance_prev(struct maple_subtree_state *mast,
 }
 
 /*
- * mast_sibling_rebalance_right() - Rebalance from nodes with the same parents.
- * Check the right side, then the left.  Data is copied into the @mast->bn.
+ * mast_spanning_rebalance() - Rebalance nodes with nearest neighbour favouring
+ * the node to the right.  Checking the nodes to the right then the left at each
+ * level upwards until root is reached.  Free and destroy as needed.
+ * Data is copied into the @mast->bn.
  * @mast: The maple_subtree_state.
  */
 static inline
-bool mast_sibling_rebalance_right(struct maple_subtree_state *mast, bool free)
+bool mast_spanning_rebalance(struct maple_subtree_state *mast)
 {
-	struct maple_enode *old_r;
-	struct maple_enode *old_l;
+	struct ma_state r_tmp = *mast->orig_r;
+	struct ma_state l_tmp = *mast->orig_l;
+	struct maple_enode *ancestor = NULL;
+	unsigned char start, end;
+	unsigned char depth = 0;
 
-	old_r = mast->orig_r->node;
-	if (mas_next_sibling(mast->orig_r)) {
-		mast_rebalance_next(mast, old_r, free);
-		return true;
-	}
+	r_tmp = *mast->orig_r;
+	l_tmp = *mast->orig_l;
+	do {
+		mas_ascend(mast->orig_r);
+		mas_ascend(mast->orig_l);
+		depth++;
+		if (!ancestor &&
+		    (mast->orig_r->node == mast->orig_l->node)) {
+			ancestor = mast->orig_r->node;
+			end = mast->orig_r->offset - 1;
+			start = mast->orig_l->offset + 1;
+		}
 
-	old_l = mast->orig_l->node;
-	if (mas_prev_sibling(mast->orig_l)) {
-		mast->bn->type = mte_node_type(mast->orig_l->node);
-		mast_rebalance_prev(mast, old_l);
-		return true;
-	}
+		if (mast->orig_r->offset < mas_data_end(mast->orig_r)) {
+			if (!ancestor) {
+				ancestor = mast->orig_r->node;
+				start = 0;
+			}
 
-	return false;
-}
+			mast->orig_r->offset++;
+			do {
+				mas_descend(mast->orig_r);
+				mast->orig_r->offset = 0;
+				depth--;
+			} while (depth);
 
-static inline int mas_prev_node(struct ma_state *mas, unsigned long min);
-static inline int mas_next_node(struct ma_state *mas, struct maple_node *node,
-				unsigned long max);
-/*
- * mast_cousin_rebalance_right() - Rebalance from nodes with different parents.
- * Check the right side, then the left.  Data is copied into the @mast->bn.
- * @mast: The maple_subtree_state.
- */
-static inline
-bool mast_cousin_rebalance_right(struct maple_subtree_state *mast, bool free)
-{
-	struct maple_enode *old_l = mast->orig_l->node;
-	struct maple_enode *old_r = mast->orig_r->node;
+			mast_rebalance_next(mast);
+			do {
+				unsigned char l_off = 0;
+				struct maple_enode *child = r_tmp.node;
 
-	MA_STATE(tmp, mast->orig_r->tree, mast->orig_r->index, mast->orig_r->last);
+				mas_ascend(&r_tmp);
+				if (ancestor == r_tmp.node)
+					l_off = start;
 
-	tmp = *mast->orig_r;
-	mas_next_node(mast->orig_r, mas_mn(mast->orig_r), ULONG_MAX);
-	if (!mas_is_none(mast->orig_r)) {
-		mast_rebalance_next(mast, old_r, free);
-		return true;
-	}
+				if (r_tmp.offset)
+					r_tmp.offset--;
 
-	*mast->orig_r = *mast->orig_l;
-	*mast->r = *mast->l;
-	mas_prev_node(mast->orig_l, 0);
-	if (mas_is_none(mast->orig_l)) {
-		/* Making a new root with the contents of mast->bn */
-		*mast->orig_l = *mast->orig_r;
-		*mast->orig_r = tmp;
-		return false;
-	}
+				if (l_off < r_tmp.offset)
+					mas_topiary_range(&r_tmp, mast->destroy,
+							  l_off, r_tmp.offset);
 
-	mast->orig_l->offset = 0;
-	mast_rebalance_prev(mast, old_l);
-	return true;
+				if (l_tmp.node != child)
+					mat_add(mast->free, child);
+
+			} while (r_tmp.node != ancestor);
+
+			*mast->orig_l = l_tmp;
+			return true;
+
+		} else if (mast->orig_l->offset != 0) {
+			if (!ancestor) {
+				ancestor = mast->orig_l->node;
+				end = mas_data_end(mast->orig_l);
+			}
+
+			mast->orig_l->offset--;
+			do {
+				mas_descend(mast->orig_l);
+				mast->orig_l->offset =
+					mas_data_end(mast->orig_l);
+				depth--;
+			} while (depth);
+
+			mast_rebalance_prev(mast);
+			do {
+				unsigned char r_off;
+				struct maple_enode *child = l_tmp.node;
+
+				mas_ascend(&l_tmp);
+				if (ancestor == l_tmp.node)
+					r_off = end;
+				else
+					r_off = mas_data_end(&l_tmp);
+
+				if (l_tmp.offset < r_off)
+					l_tmp.offset++;
+
+				if (l_tmp.offset < r_off)
+					mas_topiary_range(&l_tmp, mast->destroy,
+							  l_tmp.offset, r_off);
+
+				if (r_tmp.node != child)
+					mat_add(mast->free, child);
+
+			} while (l_tmp.node != ancestor);
+
+			*mast->orig_r = r_tmp;
+			return true;
+		}
+	} while (!mte_is_root(mast->orig_r->node));
+
+	*mast->orig_r = r_tmp;
+	*mast->orig_l = l_tmp;
+	return false;
 }
 
 /*
@@ -2462,18 +2538,16 @@  mast_ascend_free(struct maple_subtree_state *mast)
 	 * The node may not contain the value so set slot to ensure all
 	 * of the nodes contents are freed or destroyed.
 	 */
-	if (mast->orig_r->max < mast->orig_r->last)
-		mast->orig_r->offset = mas_data_end(mast->orig_r) + 1;
-	else {
-		wr_mas.type = mte_node_type(mast->orig_r->node);
-		mas_wr_node_walk(&wr_mas);
-	}
+	wr_mas.type = mte_node_type(mast->orig_r->node);
+	mas_wr_node_walk(&wr_mas);
 	/* Set up the left side of things */
 	mast->orig_l->offset = 0;
 	mast->orig_l->index = mast->l->min;
 	wr_mas.mas = mast->orig_l;
 	wr_mas.type = mte_node_type(mast->orig_l->node);
 	mas_wr_node_walk(&wr_mas);
+
+	mast->bn->type = wr_mas.type;
 }
 
 /*
@@ -2881,7 +2955,7 @@  static int mas_spanning_rebalance(struct ma_state *mas,
 	struct maple_enode *left = NULL, *middle = NULL, *right = NULL;
 
 	MA_STATE(l_mas, mas->tree, mas->index, mas->index);
-	MA_STATE(r_mas, mas->tree, mas->index, mas->index);
+	MA_STATE(r_mas, mas->tree, mas->index, mas->last);
 	MA_STATE(m_mas, mas->tree, mas->index, mas->index);
 	MA_TOPIARY(free, mas->tree);
 	MA_TOPIARY(destroy, mas->tree);
@@ -2897,14 +2971,9 @@  static int mas_spanning_rebalance(struct ma_state *mas,
 	mast->destroy = &destroy;
 	l_mas.node = r_mas.node = m_mas.node = MAS_NONE;
 	if (!mas_is_root_limits(mast->orig_l) &&
-	    unlikely(mast->bn->b_end <= mt_min_slots[mast->bn->type])) {
-		/*
-		 * Do not free the current node as it may be freed in a bulk
-		 * free.
-		 */
-		if (!mast_sibling_rebalance_right(mast, false))
-			mast_cousin_rebalance_right(mast, false);
-	}
+	    unlikely(mast->bn->b_end <= mt_min_slots[mast->bn->type]))
+		mast_spanning_rebalance(mast);
+
 	mast->orig_l->depth = 0;
 
 	/*
@@ -2948,6 +3017,15 @@  static int mas_spanning_rebalance(struct ma_state *mas,
 
 		/* Copy anything necessary out of the right node. */
 		mast_combine_cp_right(mast);
+		if (mte_dead_node(mast->orig_l->node) ||
+		    mte_dead_node(mast->orig_r->node)) {
+			printk("FUCKED.  l %p is %s and r %p is %s\n",
+			       mas_mn(mast->orig_l),
+			       mte_dead_node(mast->orig_l->node) ? "dead" : "alive",
+			       mas_mn(mast->orig_r),
+			       mte_dead_node(mast->orig_r->node) ? "dead" : "alive");
+			printk("Writing %lu-%lu\n", mas->index, mas->last);
+		}
 		mast_topiary(mast);
 		mast->orig_l->last = mast->orig_l->max;
 
@@ -2961,15 +3039,14 @@  static int mas_spanning_rebalance(struct ma_state *mas,
 		if (mas_is_root_limits(mast->orig_l))
 			break;
 
-		/* Try to get enough data for the next iteration. */
-		if (!mast_sibling_rebalance_right(mast, true))
-			if (!mast_cousin_rebalance_right(mast, true))
-				break;
+		if (!mast_spanning_rebalance(mast))
+			break;
 
 		/* rebalancing from other nodes may require another loop. */
 		if (!count)
 			count++;
 	}
+
 	l_mas.node = mt_mk_node(ma_mnode_ptr(mas_pop_node(mas)),
 				mte_node_type(mast->orig_l->node));
 	mast->orig_l->depth++;
@@ -3042,6 +3119,7 @@  static inline int mas_rebalance(struct ma_state *mas,
 	mast.orig_l = &l_mas;
 	mast.orig_r = &r_mas;
 	mast.bn = b_node;
+	mast.bn->type = mte_node_type(mas->node);
 
 	l_mas = r_mas = *mas;
 
@@ -3855,7 +3933,7 @@  static inline int mas_new_root(struct ma_state *mas, void *entry)
 	return 1;
 }
 /*
- * mas_spanning_store() - Create a subtree with the store operation completed
+ * mas_wr_spanning_store() - Create a subtree with the store operation completed
  * and new nodes where necessary, then place the sub-tree in the actual tree.
  * Note that mas is expected to point to the node which caused the store to
  * span.
@@ -3941,6 +4019,13 @@  static inline int mas_wr_spanning_store(struct ma_wr_state *wr_mas)
 	mast.bn = &b_node;
 	mast.orig_l = &l_mas;
 	mast.orig_r = &r_mas;
+		if (mte_dead_node(mast.orig_l->node) ||
+		    mte_dead_node(mast.orig_r->node)) {
+			printk("FUCKED.  l is %s and r is %s\n",
+		       mte_dead_node(mast.orig_l->node) ? "dead" : "alive",
+		       mte_dead_node(mast.orig_r->node) ? "dead" : "alive");
+			printk("Writing %lu-%lu\n", mas->index, mas->last);
+		}
 	/* Combine l_mas and r_mas and split them up evenly again. */
 	return mas_spanning_rebalance(mas, &mast, height + 1);
 }
@@ -5387,6 +5472,9 @@  static inline void __rcu **mas_destroy_descend(struct ma_state *mas,
 		node = mas_mn(mas);
 		slots = ma_slots(node, mte_node_type(mas->node));
 		next = mas_slot_locked(mas, slots, 0);
+		if ((mte_dead_node(next)))
+			next = mas_slot_locked(mas, slots, 1);
+
 		mte_set_node_dead(mas->node);
 		node->type = mte_node_type(mas->node);
 		node->piv_parent = prev;
@@ -5394,6 +5482,7 @@  static inline void __rcu **mas_destroy_descend(struct ma_state *mas,
 		offset = 0;
 		prev = mas->node;
 	} while (!mte_is_leaf(next));
+
 	return slots;
 }
 
@@ -5427,17 +5516,15 @@  static void mt_destroy_walk(struct maple_enode *enode, unsigned char ma_flags,
 		mas.node = node->piv_parent;
 		if (mas_mn(&mas) == node)
 			goto start_slots_free;
+
 		type = mte_node_type(mas.node);
 		slots = ma_slots(mte_to_node(mas.node), type);
-		if ((offset < mt_slots[type])) {
-			struct maple_enode *next = slots[offset];
+		if ((offset < mt_slots[type]) && mte_node_type(slots[offset]) &&
+		    mte_to_node(slots[offset])) {
+			struct maple_enode *parent = mas.node;
 
-			if (mte_node_type(next) && mte_to_node(next)) {
-				struct maple_enode *parent = mas.node;
-
-				mas.node = mas_slot_locked(&mas, slots, offset);
-				slots = mas_destroy_descend(&mas, parent, offset);
-			}
+			mas.node = mas_slot_locked(&mas, slots, offset);
+			slots = mas_destroy_descend(&mas, parent, offset);
 		}
 		node = mas_mn(&mas);
 	} while (start != mas.node);