diff mbox series

[v4,2/6] maple_tree: use height and depth consistently

Message ID 20250407184102.2155415-3-sidhartha.kumar@oracle.com (mailing list archive)
State New
Headers show
Series Track node vacancy to reduce worst case allocation counts | expand

Commit Message

Sidhartha Kumar April 7, 2025, 6:40 p.m. UTC
For the maple tree, the root node is defined to have a depth of 0 with a
height of 1. Each level down from the node, these values are incremented
by 1. Various code paths define a root with depth 1 which is inconsisent
with the definition. Modify the code to be consistent with this
definition.

Signed-off-by: Sidhartha Kumar <sidhartha.kumar@oracle.com>
---
 lib/maple_tree.c                 | 82 +++++++++++++++++---------------
 tools/testing/radix-tree/maple.c | 19 ++++++++
 2 files changed, 63 insertions(+), 38 deletions(-)

Comments

Liam R. Howlett April 8, 2025, 6:02 p.m. UTC | #1
* Sidhartha Kumar <sidhartha.kumar@oracle.com> [250407 14:41]:
> For the maple tree, the root node is defined to have a depth of 0 with a
> height of 1. Each level down from the node, these values are incremented
> by 1. Various code paths define a root with depth 1 which is inconsisent
> with the definition. Modify the code to be consistent with this
> definition.

Small nit below about adding more detail to this log, but otherwise it
looks good.

> 
> Signed-off-by: Sidhartha Kumar <sidhartha.kumar@oracle.com>
> ---
>  lib/maple_tree.c                 | 82 +++++++++++++++++---------------
>  tools/testing/radix-tree/maple.c | 19 ++++++++
>  2 files changed, 63 insertions(+), 38 deletions(-)
> 
> diff --git a/lib/maple_tree.c b/lib/maple_tree.c
> index f25ee210d495..236f0579ca53 100644
> --- a/lib/maple_tree.c
> +++ b/lib/maple_tree.c
> @@ -211,14 +211,14 @@ static void ma_free_rcu(struct maple_node *node)
>  	call_rcu(&node->rcu, mt_free_rcu);
>  }
>  
> -static void mas_set_height(struct ma_state *mas)
> +static void mt_set_height(struct maple_tree *mt, unsigned char height)
>  {
> -	unsigned int new_flags = mas->tree->ma_flags;
> +	unsigned int new_flags = mt->ma_flags;
>  
>  	new_flags &= ~MT_FLAGS_HEIGHT_MASK;
> -	MAS_BUG_ON(mas, mas->depth > MAPLE_HEIGHT_MAX);
> -	new_flags |= mas->depth << MT_FLAGS_HEIGHT_OFFSET;
> -	mas->tree->ma_flags = new_flags;
> +	MT_BUG_ON(mt, height > MAPLE_HEIGHT_MAX);
> +	new_flags |= height << MT_FLAGS_HEIGHT_OFFSET;
> +	mt->ma_flags = new_flags;
>  }
>  
>  static unsigned int mas_mt_height(struct ma_state *mas)
> @@ -1371,7 +1371,7 @@ static inline struct maple_enode *mas_start(struct ma_state *mas)
>  		root = mas_root(mas);
>  		/* Tree with nodes */
>  		if (likely(xa_is_node(root))) {
> -			mas->depth = 1;
> +			mas->depth = 0;
>  			mas->status = ma_active;
>  			mas->node = mte_safe_root(root);
>  			mas->offset = 0;
> @@ -1712,9 +1712,10 @@ static inline void mas_adopt_children(struct ma_state *mas,
>   * node as dead.
>   * @mas: the maple state with the new node
>   * @old_enode: The old maple encoded node to replace.
> + * @new_height: if we are inserting a root node, update the height of the tree
>   */
>  static inline void mas_put_in_tree(struct ma_state *mas,
> -		struct maple_enode *old_enode)
> +		struct maple_enode *old_enode, char new_height)
>  	__must_hold(mas->tree->ma_lock)
>  {
>  	unsigned char offset;
> @@ -1723,7 +1724,7 @@ static inline void mas_put_in_tree(struct ma_state *mas,
>  	if (mte_is_root(mas->node)) {
>  		mas_mn(mas)->parent = ma_parent_ptr(mas_tree_parent(mas));
>  		rcu_assign_pointer(mas->tree->ma_root, mte_mk_root(mas->node));
> -		mas_set_height(mas);
> +		mt_set_height(mas->tree, new_height);
>  	} else {
>  
>  		offset = mte_parent_slot(mas->node);
> @@ -1741,12 +1742,13 @@ static inline void mas_put_in_tree(struct ma_state *mas,
>   * the parent encoding to locate the maple node in the tree.
>   * @mas: the ma_state with @mas->node pointing to the new node.
>   * @old_enode: The old maple encoded node.
> + * @new_height: The new height of the tree as a result of the operation
>   */
>  static inline void mas_replace_node(struct ma_state *mas,
> -		struct maple_enode *old_enode)
> +		struct maple_enode *old_enode, unsigned char new_height)
>  	__must_hold(mas->tree->ma_lock)
>  {
> -	mas_put_in_tree(mas, old_enode);
> +	mas_put_in_tree(mas, old_enode, new_height);
>  	mas_free(mas, old_enode);
>  }
>  
> @@ -2536,10 +2538,11 @@ static inline void mas_topiary_node(struct ma_state *mas,
>   *
>   * @mas: The maple state pointing at the new data
>   * @old_enode: The maple encoded node being replaced
> + * @new_height: The new height of the tree as a result of the operation
>   *
>   */
>  static inline void mas_topiary_replace(struct ma_state *mas,
> -		struct maple_enode *old_enode)
> +		struct maple_enode *old_enode, unsigned char new_height)
>  {
>  	struct ma_state tmp[3], tmp_next[3];
>  	MA_TOPIARY(subtrees, mas->tree);
> @@ -2547,7 +2550,7 @@ static inline void mas_topiary_replace(struct ma_state *mas,
>  	int i, n;
>  
>  	/* Place data in tree & then mark node as old */
> -	mas_put_in_tree(mas, old_enode);
> +	mas_put_in_tree(mas, old_enode, new_height);
>  
>  	/* Update the parent pointers in the tree */
>  	tmp[0] = *mas;
> @@ -2631,14 +2634,15 @@ static inline void mas_topiary_replace(struct ma_state *mas,
>   * mas_wmb_replace() - Write memory barrier and replace
>   * @mas: The maple state
>   * @old_enode: The old maple encoded node that is being replaced.
> + * @new_height: The new height of the tree as a result of the operation
>   *
>   * Updates gap as necessary.
>   */
>  static inline void mas_wmb_replace(struct ma_state *mas,
> -		struct maple_enode *old_enode)
> +		struct maple_enode *old_enode, unsigned char new_height)
>  {
>  	/* Insert the new data in the tree */
> -	mas_topiary_replace(mas, old_enode);
> +	mas_topiary_replace(mas, old_enode, new_height);
>  
>  	if (mte_is_leaf(mas->node))
>  		return;
> @@ -2824,6 +2828,7 @@ static void mas_spanning_rebalance(struct ma_state *mas,
>  {
>  	unsigned char split, mid_split;
>  	unsigned char slot = 0;
> +	unsigned char new_height = 0; /* used if node is a new root */
>  	struct maple_enode *left = NULL, *middle = NULL, *right = NULL;
>  	struct maple_enode *old_enode;
>  
> @@ -2866,6 +2871,7 @@ static void mas_spanning_rebalance(struct ma_state *mas,
>  		mast_set_split_parents(mast, left, middle, right, split,
>  				       mid_split);
>  		mast_cp_to_nodes(mast, left, middle, right, split, mid_split);
> +		new_height++;
>  
>  		/*
>  		 * Copy data from next level in the tree to mast->bn from next
> @@ -2873,7 +2879,6 @@ static void mas_spanning_rebalance(struct ma_state *mas,
>  		 */
>  		memset(mast->bn, 0, sizeof(struct maple_big_node));
>  		mast->bn->type = mte_node_type(left);
> -		l_mas.depth++;
>  
>  		/* Root already stored in l->node. */
>  		if (mas_is_root_limits(mast->l))
> @@ -2909,8 +2914,9 @@ static void mas_spanning_rebalance(struct ma_state *mas,
>  
>  	l_mas.node = mt_mk_node(ma_mnode_ptr(mas_pop_node(mas)),
>  				mte_node_type(mast->orig_l->node));
> -	l_mas.depth++;
> +

Please add a comment in the git log as to why this moved so we can find
out later.

>  	mab_mas_cp(mast->bn, 0, mt_slots[mast->bn->type] - 1, &l_mas, true);
> +	new_height++;
>  	mas_set_parent(mas, left, l_mas.node, slot);
>  	if (middle)
>  		mas_set_parent(mas, middle, l_mas.node, ++slot);
> @@ -2933,7 +2939,7 @@ static void mas_spanning_rebalance(struct ma_state *mas,
>  	mas->min = l_mas.min;
>  	mas->max = l_mas.max;
>  	mas->offset = l_mas.offset;
> -	mas_wmb_replace(mas, old_enode);
> +	mas_wmb_replace(mas, old_enode, new_height);
>  	mtree_range_walk(mas);
>  	return;
>  }
> @@ -3009,6 +3015,7 @@ static inline void mas_destroy_rebalance(struct ma_state *mas, unsigned char end
>  	void __rcu **l_slots, **slots;
>  	unsigned long *l_pivs, *pivs, gap;
>  	bool in_rcu = mt_in_rcu(mas->tree);
> +	unsigned char new_height = mas_mt_height(mas);
>  
>  	MA_STATE(l_mas, mas->tree, mas->index, mas->last);
>  
> @@ -3103,7 +3110,7 @@ static inline void mas_destroy_rebalance(struct ma_state *mas, unsigned char end
>  	mas_ascend(mas);
>  
>  	if (in_rcu) {
> -		mas_replace_node(mas, old_eparent);
> +		mas_replace_node(mas, old_eparent, new_height);
>  		mas_adopt_children(mas, mas->node);
>  	}
>  
> @@ -3114,10 +3121,9 @@ static inline void mas_destroy_rebalance(struct ma_state *mas, unsigned char end
>   * mas_split_final_node() - Split the final node in a subtree operation.
>   * @mast: the maple subtree state
>   * @mas: The maple state
> - * @height: The height of the tree in case it's a new root.
>   */
>  static inline void mas_split_final_node(struct maple_subtree_state *mast,
> -					struct ma_state *mas, int height)
> +					struct ma_state *mas)
>  {
>  	struct maple_enode *ancestor;
>  
> @@ -3126,7 +3132,6 @@ static inline void mas_split_final_node(struct maple_subtree_state *mast,
>  			mast->bn->type = maple_arange_64;
>  		else
>  			mast->bn->type = maple_range_64;
> -		mas->depth = height;
>  	}
>  	/*
>  	 * Only a single node is used here, could be root.
> @@ -3214,7 +3219,6 @@ static inline void mast_split_data(struct maple_subtree_state *mast,
>   * mas_push_data() - Instead of splitting a node, it is beneficial to push the
>   * data to the right or left node if there is room.
>   * @mas: The maple state
> - * @height: The current height of the maple state
>   * @mast: The maple subtree state
>   * @left: Push left or not.
>   *
> @@ -3222,8 +3226,8 @@ static inline void mast_split_data(struct maple_subtree_state *mast,
>   *
>   * Return: True if pushed, false otherwise.
>   */
> -static inline bool mas_push_data(struct ma_state *mas, int height,
> -				 struct maple_subtree_state *mast, bool left)
> +static inline bool mas_push_data(struct ma_state *mas,
> +				struct maple_subtree_state *mast, bool left)
>  {
>  	unsigned char slot_total = mast->bn->b_end;
>  	unsigned char end, space, split;
> @@ -3280,7 +3284,7 @@ static inline bool mas_push_data(struct ma_state *mas, int height,
>  
>  	mast_split_data(mast, mas, split);
>  	mast_fill_bnode(mast, mas, 2);
> -	mas_split_final_node(mast, mas, height + 1);
> +	mas_split_final_node(mast, mas);
>  	return true;
>  }
>  
> @@ -3293,6 +3297,7 @@ static void mas_split(struct ma_state *mas, struct maple_big_node *b_node)
>  {
>  	struct maple_subtree_state mast;
>  	int height = 0;
> +	unsigned int orig_height = mas_mt_height(mas);
>  	unsigned char mid_split, split = 0;
>  	struct maple_enode *old;
>  
> @@ -3319,7 +3324,6 @@ static void mas_split(struct ma_state *mas, struct maple_big_node *b_node)
>  	MA_STATE(prev_r_mas, mas->tree, mas->index, mas->last);
>  
>  	trace_ma_op(__func__, mas);
> -	mas->depth = mas_mt_height(mas);
>  
>  	mast.l = &l_mas;
>  	mast.r = &r_mas;
> @@ -3327,9 +3331,9 @@ static void mas_split(struct ma_state *mas, struct maple_big_node *b_node)
>  	mast.orig_r = &prev_r_mas;
>  	mast.bn = b_node;
>  
> -	while (height++ <= mas->depth) {
> +	while (height++ <= orig_height) {
>  		if (mt_slots[b_node->type] > b_node->b_end) {
> -			mas_split_final_node(&mast, mas, height);
> +			mas_split_final_node(&mast, mas);
>  			break;
>  		}
>  
> @@ -3344,11 +3348,15 @@ static void mas_split(struct ma_state *mas, struct maple_big_node *b_node)
>  		 * is a significant savings.
>  		 */
>  		/* Try to push left. */
> -		if (mas_push_data(mas, height, &mast, true))
> +		if (mas_push_data(mas, &mast, true)) {
> +			height++;
>  			break;
> +		}
>  		/* Try to push right. */
> -		if (mas_push_data(mas, height, &mast, false))
> +		if (mas_push_data(mas, &mast, false)) {
> +			height++;
>  			break;
> +		}
>  
>  		split = mab_calc_split(mas, b_node, &mid_split);
>  		mast_split_data(&mast, mas, split);
> @@ -3365,7 +3373,7 @@ static void mas_split(struct ma_state *mas, struct maple_big_node *b_node)
>  	/* Set the original node as dead */
>  	old = mas->node;
>  	mas->node = l_mas.node;
> -	mas_wmb_replace(mas, old);
> +	mas_wmb_replace(mas, old, height);
>  	mtree_range_walk(mas);
>  	return;
>  }
> @@ -3424,8 +3432,7 @@ static inline void mas_root_expand(struct ma_state *mas, void *entry)
>  	if (mas->last != ULONG_MAX)
>  		pivots[++slot] = ULONG_MAX;
>  
> -	mas->depth = 1;
> -	mas_set_height(mas);
> +	mt_set_height(mas->tree, 1);
>  	ma_set_meta(node, maple_leaf_64, 0, slot);
>  	/* swap the new root into the tree */
>  	rcu_assign_pointer(mas->tree->ma_root, mte_mk_root(mas->node));
> @@ -3669,8 +3676,7 @@ static inline void mas_new_root(struct ma_state *mas, void *entry)
>  	WARN_ON_ONCE(mas->index || mas->last != ULONG_MAX);
>  
>  	if (!entry) {
> -		mas->depth = 0;
> -		mas_set_height(mas);
> +		mt_set_height(mas->tree, 0);
>  		rcu_assign_pointer(mas->tree->ma_root, entry);
>  		mas->status = ma_start;
>  		goto done;
> @@ -3684,8 +3690,7 @@ static inline void mas_new_root(struct ma_state *mas, void *entry)
>  	mas->status = ma_active;
>  	rcu_assign_pointer(slots[0], entry);
>  	pivots[0] = mas->last;
> -	mas->depth = 1;
> -	mas_set_height(mas);
> +	mt_set_height(mas->tree, 1);
>  	rcu_assign_pointer(mas->tree->ma_root, mte_mk_root(mas->node));
>  
>  done:
> @@ -3804,6 +3809,7 @@ static inline void mas_wr_node_store(struct ma_wr_state *wr_mas,
>  	struct maple_node reuse, *newnode;
>  	unsigned char copy_size, node_pivots = mt_pivots[wr_mas->type];
>  	bool in_rcu = mt_in_rcu(mas->tree);
> +	unsigned char height = mas_mt_height(mas);
>  
>  	if (mas->last == wr_mas->end_piv)
>  		offset_end++; /* don't copy this offset */
> @@ -3860,7 +3866,7 @@ static inline void mas_wr_node_store(struct ma_wr_state *wr_mas,
>  		struct maple_enode *old_enode = mas->node;
>  
>  		mas->node = mt_mk_node(newnode, wr_mas->type);
> -		mas_replace_node(mas, old_enode);
> +		mas_replace_node(mas, old_enode, height);
>  	} else {
>  		memcpy(wr_mas->node, newnode, sizeof(struct maple_node));
>  	}
> diff --git a/tools/testing/radix-tree/maple.c b/tools/testing/radix-tree/maple.c
> index bc30050227fd..e0f8fabe8821 100644
> --- a/tools/testing/radix-tree/maple.c
> +++ b/tools/testing/radix-tree/maple.c
> @@ -36248,6 +36248,21 @@ static noinline void __init check_mtree_dup(struct maple_tree *mt)
>  
>  extern void test_kmem_cache_bulk(void);
>  
> +static inline void check_spanning_store_height(struct maple_tree *mt)
> +{
> +	int index = 0;
> +	MA_STATE(mas, mt, 0, 0);
> +	mas_lock(&mas);
> +	while (mt_height(mt) != 3) {
> +		mas_store_gfp(&mas, xa_mk_value(index), GFP_KERNEL);
> +		mas_set(&mas, ++index);
> +	}
> +	mas_set_range(&mas, 90, 140);
> +	mas_store_gfp(&mas, xa_mk_value(index), GFP_KERNEL);
> +	MT_BUG_ON(mt, mas_mt_height(&mas) != 2);
> +	mas_unlock(&mas);
> +}
> +
>  /* callback function used for check_nomem_writer_race() */
>  static void writer2(void *maple_tree)
>  {
> @@ -36414,6 +36429,10 @@ void farmer_tests(void)
>  	check_spanning_write(&tree);
>  	mtree_destroy(&tree);
>  
> +	mt_init_flags(&tree, MT_FLAGS_ALLOC_RANGE);
> +	check_spanning_store_height(&tree);
> +	mtree_destroy(&tree);
> +
>  	mt_init_flags(&tree, MT_FLAGS_ALLOC_RANGE);
>  	check_null_expand(&tree);
>  	mtree_destroy(&tree);
> -- 
> 2.43.0
>
Sidhartha Kumar April 9, 2025, 5:37 p.m. UTC | #2
On 4/7/25 2:40 PM, Sidhartha Kumar wrote:
> For the maple tree, the root node is defined to have a depth of 0 with a
> height of 1. Each level down from the node, these values are incremented
> by 1. Various code paths define a root with depth 1 which is inconsisent
> with the definition. Modify the code to be consistent with this
> definition.
> 

Hi Andrew,
could you add the following to the commit message:

In mas_spanning_rebalance(), l_mas.depth was being used to track the
height based on the number of iterations done in the main loop. This
information was then used in mas_put_in_tree() to set the height. Rather 
than overload the l_mas.depth field to track height, simply keep track 
of height in the local variable new_height and directly pass this to 
mas_wmb_replace() which will be passed into mas_put_in_tree(). This
allows us to remove references to l_mas.depth.


> Signed-off-by: Sidhartha Kumar <sidhartha.kumar@oracle.com>
> ---
>   lib/maple_tree.c                 | 82 +++++++++++++++++---------------
>   tools/testing/radix-tree/maple.c | 19 ++++++++
>   2 files changed, 63 insertions(+), 38 deletions(-)
> 
> diff --git a/lib/maple_tree.c b/lib/maple_tree.c
> index f25ee210d495..236f0579ca53 100644
> --- a/lib/maple_tree.c
> +++ b/lib/maple_tree.c
> @@ -211,14 +211,14 @@ static void ma_free_rcu(struct maple_node *node)
>   	call_rcu(&node->rcu, mt_free_rcu);
>   }
>   
> -static void mas_set_height(struct ma_state *mas)
> +static void mt_set_height(struct maple_tree *mt, unsigned char height)
>   {
> -	unsigned int new_flags = mas->tree->ma_flags;
> +	unsigned int new_flags = mt->ma_flags;
>   
>   	new_flags &= ~MT_FLAGS_HEIGHT_MASK;
> -	MAS_BUG_ON(mas, mas->depth > MAPLE_HEIGHT_MAX);
> -	new_flags |= mas->depth << MT_FLAGS_HEIGHT_OFFSET;
> -	mas->tree->ma_flags = new_flags;
> +	MT_BUG_ON(mt, height > MAPLE_HEIGHT_MAX);
> +	new_flags |= height << MT_FLAGS_HEIGHT_OFFSET;
> +	mt->ma_flags = new_flags;
>   }
>   
>   static unsigned int mas_mt_height(struct ma_state *mas)
> @@ -1371,7 +1371,7 @@ static inline struct maple_enode *mas_start(struct ma_state *mas)
>   		root = mas_root(mas);
>   		/* Tree with nodes */
>   		if (likely(xa_is_node(root))) {
> -			mas->depth = 1;
> +			mas->depth = 0;
>   			mas->status = ma_active;
>   			mas->node = mte_safe_root(root);
>   			mas->offset = 0;
> @@ -1712,9 +1712,10 @@ static inline void mas_adopt_children(struct ma_state *mas,
>    * node as dead.
>    * @mas: the maple state with the new node
>    * @old_enode: The old maple encoded node to replace.
> + * @new_height: if we are inserting a root node, update the height of the tree
>    */
>   static inline void mas_put_in_tree(struct ma_state *mas,
> -		struct maple_enode *old_enode)
> +		struct maple_enode *old_enode, char new_height)
>   	__must_hold(mas->tree->ma_lock)
>   {
>   	unsigned char offset;
> @@ -1723,7 +1724,7 @@ static inline void mas_put_in_tree(struct ma_state *mas,
>   	if (mte_is_root(mas->node)) {
>   		mas_mn(mas)->parent = ma_parent_ptr(mas_tree_parent(mas));
>   		rcu_assign_pointer(mas->tree->ma_root, mte_mk_root(mas->node));
> -		mas_set_height(mas);
> +		mt_set_height(mas->tree, new_height);
>   	} else {
>   
>   		offset = mte_parent_slot(mas->node);
> @@ -1741,12 +1742,13 @@ static inline void mas_put_in_tree(struct ma_state *mas,
>    * the parent encoding to locate the maple node in the tree.
>    * @mas: the ma_state with @mas->node pointing to the new node.
>    * @old_enode: The old maple encoded node.
> + * @new_height: The new height of the tree as a result of the operation
>    */
>   static inline void mas_replace_node(struct ma_state *mas,
> -		struct maple_enode *old_enode)
> +		struct maple_enode *old_enode, unsigned char new_height)
>   	__must_hold(mas->tree->ma_lock)
>   {
> -	mas_put_in_tree(mas, old_enode);
> +	mas_put_in_tree(mas, old_enode, new_height);
>   	mas_free(mas, old_enode);
>   }
>   
> @@ -2536,10 +2538,11 @@ static inline void mas_topiary_node(struct ma_state *mas,
>    *
>    * @mas: The maple state pointing at the new data
>    * @old_enode: The maple encoded node being replaced
> + * @new_height: The new height of the tree as a result of the operation
>    *
>    */
>   static inline void mas_topiary_replace(struct ma_state *mas,
> -		struct maple_enode *old_enode)
> +		struct maple_enode *old_enode, unsigned char new_height)
>   {
>   	struct ma_state tmp[3], tmp_next[3];
>   	MA_TOPIARY(subtrees, mas->tree);
> @@ -2547,7 +2550,7 @@ static inline void mas_topiary_replace(struct ma_state *mas,
>   	int i, n;
>   
>   	/* Place data in tree & then mark node as old */
> -	mas_put_in_tree(mas, old_enode);
> +	mas_put_in_tree(mas, old_enode, new_height);
>   
>   	/* Update the parent pointers in the tree */
>   	tmp[0] = *mas;
> @@ -2631,14 +2634,15 @@ static inline void mas_topiary_replace(struct ma_state *mas,
>    * mas_wmb_replace() - Write memory barrier and replace
>    * @mas: The maple state
>    * @old_enode: The old maple encoded node that is being replaced.
> + * @new_height: The new height of the tree as a result of the operation
>    *
>    * Updates gap as necessary.
>    */
>   static inline void mas_wmb_replace(struct ma_state *mas,
> -		struct maple_enode *old_enode)
> +		struct maple_enode *old_enode, unsigned char new_height)
>   {
>   	/* Insert the new data in the tree */
> -	mas_topiary_replace(mas, old_enode);
> +	mas_topiary_replace(mas, old_enode, new_height);
>   
>   	if (mte_is_leaf(mas->node))
>   		return;
> @@ -2824,6 +2828,7 @@ static void mas_spanning_rebalance(struct ma_state *mas,
>   {
>   	unsigned char split, mid_split;
>   	unsigned char slot = 0;
> +	unsigned char new_height = 0; /* used if node is a new root */
>   	struct maple_enode *left = NULL, *middle = NULL, *right = NULL;
>   	struct maple_enode *old_enode;
>   
> @@ -2866,6 +2871,7 @@ static void mas_spanning_rebalance(struct ma_state *mas,
>   		mast_set_split_parents(mast, left, middle, right, split,
>   				       mid_split);
>   		mast_cp_to_nodes(mast, left, middle, right, split, mid_split);
> +		new_height++;
>   
>   		/*
>   		 * Copy data from next level in the tree to mast->bn from next
> @@ -2873,7 +2879,6 @@ static void mas_spanning_rebalance(struct ma_state *mas,
>   		 */
>   		memset(mast->bn, 0, sizeof(struct maple_big_node));
>   		mast->bn->type = mte_node_type(left);
> -		l_mas.depth++;
>   
>   		/* Root already stored in l->node. */
>   		if (mas_is_root_limits(mast->l))
> @@ -2909,8 +2914,9 @@ static void mas_spanning_rebalance(struct ma_state *mas,
>   
>   	l_mas.node = mt_mk_node(ma_mnode_ptr(mas_pop_node(mas)),
>   				mte_node_type(mast->orig_l->node));
> -	l_mas.depth++;
> +
>   	mab_mas_cp(mast->bn, 0, mt_slots[mast->bn->type] - 1, &l_mas, true);
> +	new_height++;
>   	mas_set_parent(mas, left, l_mas.node, slot);
>   	if (middle)
>   		mas_set_parent(mas, middle, l_mas.node, ++slot);
> @@ -2933,7 +2939,7 @@ static void mas_spanning_rebalance(struct ma_state *mas,
>   	mas->min = l_mas.min;
>   	mas->max = l_mas.max;
>   	mas->offset = l_mas.offset;
> -	mas_wmb_replace(mas, old_enode);
> +	mas_wmb_replace(mas, old_enode, new_height);
>   	mtree_range_walk(mas);
>   	return;
>   }
> @@ -3009,6 +3015,7 @@ static inline void mas_destroy_rebalance(struct ma_state *mas, unsigned char end
>   	void __rcu **l_slots, **slots;
>   	unsigned long *l_pivs, *pivs, gap;
>   	bool in_rcu = mt_in_rcu(mas->tree);
> +	unsigned char new_height = mas_mt_height(mas);
>   
>   	MA_STATE(l_mas, mas->tree, mas->index, mas->last);
>   
> @@ -3103,7 +3110,7 @@ static inline void mas_destroy_rebalance(struct ma_state *mas, unsigned char end
>   	mas_ascend(mas);
>   
>   	if (in_rcu) {
> -		mas_replace_node(mas, old_eparent);
> +		mas_replace_node(mas, old_eparent, new_height);
>   		mas_adopt_children(mas, mas->node);
>   	}
>   
> @@ -3114,10 +3121,9 @@ static inline void mas_destroy_rebalance(struct ma_state *mas, unsigned char end
>    * mas_split_final_node() - Split the final node in a subtree operation.
>    * @mast: the maple subtree state
>    * @mas: The maple state
> - * @height: The height of the tree in case it's a new root.
>    */
>   static inline void mas_split_final_node(struct maple_subtree_state *mast,
> -					struct ma_state *mas, int height)
> +					struct ma_state *mas)
>   {
>   	struct maple_enode *ancestor;
>   
> @@ -3126,7 +3132,6 @@ static inline void mas_split_final_node(struct maple_subtree_state *mast,
>   			mast->bn->type = maple_arange_64;
>   		else
>   			mast->bn->type = maple_range_64;
> -		mas->depth = height;
>   	}
>   	/*
>   	 * Only a single node is used here, could be root.
> @@ -3214,7 +3219,6 @@ static inline void mast_split_data(struct maple_subtree_state *mast,
>    * mas_push_data() - Instead of splitting a node, it is beneficial to push the
>    * data to the right or left node if there is room.
>    * @mas: The maple state
> - * @height: The current height of the maple state
>    * @mast: The maple subtree state
>    * @left: Push left or not.
>    *
> @@ -3222,8 +3226,8 @@ static inline void mast_split_data(struct maple_subtree_state *mast,
>    *
>    * Return: True if pushed, false otherwise.
>    */
> -static inline bool mas_push_data(struct ma_state *mas, int height,
> -				 struct maple_subtree_state *mast, bool left)
> +static inline bool mas_push_data(struct ma_state *mas,
> +				struct maple_subtree_state *mast, bool left)
>   {
>   	unsigned char slot_total = mast->bn->b_end;
>   	unsigned char end, space, split;
> @@ -3280,7 +3284,7 @@ static inline bool mas_push_data(struct ma_state *mas, int height,
>   
>   	mast_split_data(mast, mas, split);
>   	mast_fill_bnode(mast, mas, 2);
> -	mas_split_final_node(mast, mas, height + 1);
> +	mas_split_final_node(mast, mas);
>   	return true;
>   }
>   
> @@ -3293,6 +3297,7 @@ static void mas_split(struct ma_state *mas, struct maple_big_node *b_node)
>   {
>   	struct maple_subtree_state mast;
>   	int height = 0;
> +	unsigned int orig_height = mas_mt_height(mas);
>   	unsigned char mid_split, split = 0;
>   	struct maple_enode *old;
>   
> @@ -3319,7 +3324,6 @@ static void mas_split(struct ma_state *mas, struct maple_big_node *b_node)
>   	MA_STATE(prev_r_mas, mas->tree, mas->index, mas->last);
>   
>   	trace_ma_op(__func__, mas);
> -	mas->depth = mas_mt_height(mas);
>   
>   	mast.l = &l_mas;
>   	mast.r = &r_mas;
> @@ -3327,9 +3331,9 @@ static void mas_split(struct ma_state *mas, struct maple_big_node *b_node)
>   	mast.orig_r = &prev_r_mas;
>   	mast.bn = b_node;
>   
> -	while (height++ <= mas->depth) {
> +	while (height++ <= orig_height) {
>   		if (mt_slots[b_node->type] > b_node->b_end) {
> -			mas_split_final_node(&mast, mas, height);
> +			mas_split_final_node(&mast, mas);
>   			break;
>   		}
>   
> @@ -3344,11 +3348,15 @@ static void mas_split(struct ma_state *mas, struct maple_big_node *b_node)
>   		 * is a significant savings.
>   		 */
>   		/* Try to push left. */
> -		if (mas_push_data(mas, height, &mast, true))
> +		if (mas_push_data(mas, &mast, true)) {
> +			height++;
>   			break;
> +		}
>   		/* Try to push right. */
> -		if (mas_push_data(mas, height, &mast, false))
> +		if (mas_push_data(mas, &mast, false)) {
> +			height++;
>   			break;
> +		}
>   
>   		split = mab_calc_split(mas, b_node, &mid_split);
>   		mast_split_data(&mast, mas, split);
> @@ -3365,7 +3373,7 @@ static void mas_split(struct ma_state *mas, struct maple_big_node *b_node)
>   	/* Set the original node as dead */
>   	old = mas->node;
>   	mas->node = l_mas.node;
> -	mas_wmb_replace(mas, old);
> +	mas_wmb_replace(mas, old, height);
>   	mtree_range_walk(mas);
>   	return;
>   }
> @@ -3424,8 +3432,7 @@ static inline void mas_root_expand(struct ma_state *mas, void *entry)
>   	if (mas->last != ULONG_MAX)
>   		pivots[++slot] = ULONG_MAX;
>   
> -	mas->depth = 1;
> -	mas_set_height(mas);
> +	mt_set_height(mas->tree, 1);
>   	ma_set_meta(node, maple_leaf_64, 0, slot);
>   	/* swap the new root into the tree */
>   	rcu_assign_pointer(mas->tree->ma_root, mte_mk_root(mas->node));
> @@ -3669,8 +3676,7 @@ static inline void mas_new_root(struct ma_state *mas, void *entry)
>   	WARN_ON_ONCE(mas->index || mas->last != ULONG_MAX);
>   
>   	if (!entry) {
> -		mas->depth = 0;
> -		mas_set_height(mas);
> +		mt_set_height(mas->tree, 0);
>   		rcu_assign_pointer(mas->tree->ma_root, entry);
>   		mas->status = ma_start;
>   		goto done;
> @@ -3684,8 +3690,7 @@ static inline void mas_new_root(struct ma_state *mas, void *entry)
>   	mas->status = ma_active;
>   	rcu_assign_pointer(slots[0], entry);
>   	pivots[0] = mas->last;
> -	mas->depth = 1;
> -	mas_set_height(mas);
> +	mt_set_height(mas->tree, 1);
>   	rcu_assign_pointer(mas->tree->ma_root, mte_mk_root(mas->node));
>   
>   done:
> @@ -3804,6 +3809,7 @@ static inline void mas_wr_node_store(struct ma_wr_state *wr_mas,
>   	struct maple_node reuse, *newnode;
>   	unsigned char copy_size, node_pivots = mt_pivots[wr_mas->type];
>   	bool in_rcu = mt_in_rcu(mas->tree);
> +	unsigned char height = mas_mt_height(mas);
>   
>   	if (mas->last == wr_mas->end_piv)
>   		offset_end++; /* don't copy this offset */
> @@ -3860,7 +3866,7 @@ static inline void mas_wr_node_store(struct ma_wr_state *wr_mas,
>   		struct maple_enode *old_enode = mas->node;
>   
>   		mas->node = mt_mk_node(newnode, wr_mas->type);
> -		mas_replace_node(mas, old_enode);
> +		mas_replace_node(mas, old_enode, height);
>   	} else {
>   		memcpy(wr_mas->node, newnode, sizeof(struct maple_node));
>   	}
> diff --git a/tools/testing/radix-tree/maple.c b/tools/testing/radix-tree/maple.c
> index bc30050227fd..e0f8fabe8821 100644
> --- a/tools/testing/radix-tree/maple.c
> +++ b/tools/testing/radix-tree/maple.c
> @@ -36248,6 +36248,21 @@ static noinline void __init check_mtree_dup(struct maple_tree *mt)
>   
>   extern void test_kmem_cache_bulk(void);
>   
> +static inline void check_spanning_store_height(struct maple_tree *mt)
> +{
> +	int index = 0;
> +	MA_STATE(mas, mt, 0, 0);
> +	mas_lock(&mas);
> +	while (mt_height(mt) != 3) {
> +		mas_store_gfp(&mas, xa_mk_value(index), GFP_KERNEL);
> +		mas_set(&mas, ++index);
> +	}
> +	mas_set_range(&mas, 90, 140);
> +	mas_store_gfp(&mas, xa_mk_value(index), GFP_KERNEL);
> +	MT_BUG_ON(mt, mas_mt_height(&mas) != 2);
> +	mas_unlock(&mas);
> +}
> +
>   /* callback function used for check_nomem_writer_race() */
>   static void writer2(void *maple_tree)
>   {
> @@ -36414,6 +36429,10 @@ void farmer_tests(void)
>   	check_spanning_write(&tree);
>   	mtree_destroy(&tree);
>   
> +	mt_init_flags(&tree, MT_FLAGS_ALLOC_RANGE);
> +	check_spanning_store_height(&tree);
> +	mtree_destroy(&tree);
> +
>   	mt_init_flags(&tree, MT_FLAGS_ALLOC_RANGE);
>   	check_null_expand(&tree);
>   	mtree_destroy(&tree);
diff mbox series

Patch

diff --git a/lib/maple_tree.c b/lib/maple_tree.c
index f25ee210d495..236f0579ca53 100644
--- a/lib/maple_tree.c
+++ b/lib/maple_tree.c
@@ -211,14 +211,14 @@  static void ma_free_rcu(struct maple_node *node)
 	call_rcu(&node->rcu, mt_free_rcu);
 }
 
-static void mas_set_height(struct ma_state *mas)
+static void mt_set_height(struct maple_tree *mt, unsigned char height)
 {
-	unsigned int new_flags = mas->tree->ma_flags;
+	unsigned int new_flags = mt->ma_flags;
 
 	new_flags &= ~MT_FLAGS_HEIGHT_MASK;
-	MAS_BUG_ON(mas, mas->depth > MAPLE_HEIGHT_MAX);
-	new_flags |= mas->depth << MT_FLAGS_HEIGHT_OFFSET;
-	mas->tree->ma_flags = new_flags;
+	MT_BUG_ON(mt, height > MAPLE_HEIGHT_MAX);
+	new_flags |= height << MT_FLAGS_HEIGHT_OFFSET;
+	mt->ma_flags = new_flags;
 }
 
 static unsigned int mas_mt_height(struct ma_state *mas)
@@ -1371,7 +1371,7 @@  static inline struct maple_enode *mas_start(struct ma_state *mas)
 		root = mas_root(mas);
 		/* Tree with nodes */
 		if (likely(xa_is_node(root))) {
-			mas->depth = 1;
+			mas->depth = 0;
 			mas->status = ma_active;
 			mas->node = mte_safe_root(root);
 			mas->offset = 0;
@@ -1712,9 +1712,10 @@  static inline void mas_adopt_children(struct ma_state *mas,
  * node as dead.
  * @mas: the maple state with the new node
  * @old_enode: The old maple encoded node to replace.
+ * @new_height: if we are inserting a root node, update the height of the tree
  */
 static inline void mas_put_in_tree(struct ma_state *mas,
-		struct maple_enode *old_enode)
+		struct maple_enode *old_enode, char new_height)
 	__must_hold(mas->tree->ma_lock)
 {
 	unsigned char offset;
@@ -1723,7 +1724,7 @@  static inline void mas_put_in_tree(struct ma_state *mas,
 	if (mte_is_root(mas->node)) {
 		mas_mn(mas)->parent = ma_parent_ptr(mas_tree_parent(mas));
 		rcu_assign_pointer(mas->tree->ma_root, mte_mk_root(mas->node));
-		mas_set_height(mas);
+		mt_set_height(mas->tree, new_height);
 	} else {
 
 		offset = mte_parent_slot(mas->node);
@@ -1741,12 +1742,13 @@  static inline void mas_put_in_tree(struct ma_state *mas,
  * the parent encoding to locate the maple node in the tree.
  * @mas: the ma_state with @mas->node pointing to the new node.
  * @old_enode: The old maple encoded node.
+ * @new_height: The new height of the tree as a result of the operation
  */
 static inline void mas_replace_node(struct ma_state *mas,
-		struct maple_enode *old_enode)
+		struct maple_enode *old_enode, unsigned char new_height)
 	__must_hold(mas->tree->ma_lock)
 {
-	mas_put_in_tree(mas, old_enode);
+	mas_put_in_tree(mas, old_enode, new_height);
 	mas_free(mas, old_enode);
 }
 
@@ -2536,10 +2538,11 @@  static inline void mas_topiary_node(struct ma_state *mas,
  *
  * @mas: The maple state pointing at the new data
  * @old_enode: The maple encoded node being replaced
+ * @new_height: The new height of the tree as a result of the operation
  *
  */
 static inline void mas_topiary_replace(struct ma_state *mas,
-		struct maple_enode *old_enode)
+		struct maple_enode *old_enode, unsigned char new_height)
 {
 	struct ma_state tmp[3], tmp_next[3];
 	MA_TOPIARY(subtrees, mas->tree);
@@ -2547,7 +2550,7 @@  static inline void mas_topiary_replace(struct ma_state *mas,
 	int i, n;
 
 	/* Place data in tree & then mark node as old */
-	mas_put_in_tree(mas, old_enode);
+	mas_put_in_tree(mas, old_enode, new_height);
 
 	/* Update the parent pointers in the tree */
 	tmp[0] = *mas;
@@ -2631,14 +2634,15 @@  static inline void mas_topiary_replace(struct ma_state *mas,
  * mas_wmb_replace() - Write memory barrier and replace
  * @mas: The maple state
  * @old_enode: The old maple encoded node that is being replaced.
+ * @new_height: The new height of the tree as a result of the operation
  *
  * Updates gap as necessary.
  */
 static inline void mas_wmb_replace(struct ma_state *mas,
-		struct maple_enode *old_enode)
+		struct maple_enode *old_enode, unsigned char new_height)
 {
 	/* Insert the new data in the tree */
-	mas_topiary_replace(mas, old_enode);
+	mas_topiary_replace(mas, old_enode, new_height);
 
 	if (mte_is_leaf(mas->node))
 		return;
@@ -2824,6 +2828,7 @@  static void mas_spanning_rebalance(struct ma_state *mas,
 {
 	unsigned char split, mid_split;
 	unsigned char slot = 0;
+	unsigned char new_height = 0; /* used if node is a new root */
 	struct maple_enode *left = NULL, *middle = NULL, *right = NULL;
 	struct maple_enode *old_enode;
 
@@ -2866,6 +2871,7 @@  static void mas_spanning_rebalance(struct ma_state *mas,
 		mast_set_split_parents(mast, left, middle, right, split,
 				       mid_split);
 		mast_cp_to_nodes(mast, left, middle, right, split, mid_split);
+		new_height++;
 
 		/*
 		 * Copy data from next level in the tree to mast->bn from next
@@ -2873,7 +2879,6 @@  static void mas_spanning_rebalance(struct ma_state *mas,
 		 */
 		memset(mast->bn, 0, sizeof(struct maple_big_node));
 		mast->bn->type = mte_node_type(left);
-		l_mas.depth++;
 
 		/* Root already stored in l->node. */
 		if (mas_is_root_limits(mast->l))
@@ -2909,8 +2914,9 @@  static void mas_spanning_rebalance(struct ma_state *mas,
 
 	l_mas.node = mt_mk_node(ma_mnode_ptr(mas_pop_node(mas)),
 				mte_node_type(mast->orig_l->node));
-	l_mas.depth++;
+
 	mab_mas_cp(mast->bn, 0, mt_slots[mast->bn->type] - 1, &l_mas, true);
+	new_height++;
 	mas_set_parent(mas, left, l_mas.node, slot);
 	if (middle)
 		mas_set_parent(mas, middle, l_mas.node, ++slot);
@@ -2933,7 +2939,7 @@  static void mas_spanning_rebalance(struct ma_state *mas,
 	mas->min = l_mas.min;
 	mas->max = l_mas.max;
 	mas->offset = l_mas.offset;
-	mas_wmb_replace(mas, old_enode);
+	mas_wmb_replace(mas, old_enode, new_height);
 	mtree_range_walk(mas);
 	return;
 }
@@ -3009,6 +3015,7 @@  static inline void mas_destroy_rebalance(struct ma_state *mas, unsigned char end
 	void __rcu **l_slots, **slots;
 	unsigned long *l_pivs, *pivs, gap;
 	bool in_rcu = mt_in_rcu(mas->tree);
+	unsigned char new_height = mas_mt_height(mas);
 
 	MA_STATE(l_mas, mas->tree, mas->index, mas->last);
 
@@ -3103,7 +3110,7 @@  static inline void mas_destroy_rebalance(struct ma_state *mas, unsigned char end
 	mas_ascend(mas);
 
 	if (in_rcu) {
-		mas_replace_node(mas, old_eparent);
+		mas_replace_node(mas, old_eparent, new_height);
 		mas_adopt_children(mas, mas->node);
 	}
 
@@ -3114,10 +3121,9 @@  static inline void mas_destroy_rebalance(struct ma_state *mas, unsigned char end
  * mas_split_final_node() - Split the final node in a subtree operation.
  * @mast: the maple subtree state
  * @mas: The maple state
- * @height: The height of the tree in case it's a new root.
  */
 static inline void mas_split_final_node(struct maple_subtree_state *mast,
-					struct ma_state *mas, int height)
+					struct ma_state *mas)
 {
 	struct maple_enode *ancestor;
 
@@ -3126,7 +3132,6 @@  static inline void mas_split_final_node(struct maple_subtree_state *mast,
 			mast->bn->type = maple_arange_64;
 		else
 			mast->bn->type = maple_range_64;
-		mas->depth = height;
 	}
 	/*
 	 * Only a single node is used here, could be root.
@@ -3214,7 +3219,6 @@  static inline void mast_split_data(struct maple_subtree_state *mast,
  * mas_push_data() - Instead of splitting a node, it is beneficial to push the
  * data to the right or left node if there is room.
  * @mas: The maple state
- * @height: The current height of the maple state
  * @mast: The maple subtree state
  * @left: Push left or not.
  *
@@ -3222,8 +3226,8 @@  static inline void mast_split_data(struct maple_subtree_state *mast,
  *
  * Return: True if pushed, false otherwise.
  */
-static inline bool mas_push_data(struct ma_state *mas, int height,
-				 struct maple_subtree_state *mast, bool left)
+static inline bool mas_push_data(struct ma_state *mas,
+				struct maple_subtree_state *mast, bool left)
 {
 	unsigned char slot_total = mast->bn->b_end;
 	unsigned char end, space, split;
@@ -3280,7 +3284,7 @@  static inline bool mas_push_data(struct ma_state *mas, int height,
 
 	mast_split_data(mast, mas, split);
 	mast_fill_bnode(mast, mas, 2);
-	mas_split_final_node(mast, mas, height + 1);
+	mas_split_final_node(mast, mas);
 	return true;
 }
 
@@ -3293,6 +3297,7 @@  static void mas_split(struct ma_state *mas, struct maple_big_node *b_node)
 {
 	struct maple_subtree_state mast;
 	int height = 0;
+	unsigned int orig_height = mas_mt_height(mas);
 	unsigned char mid_split, split = 0;
 	struct maple_enode *old;
 
@@ -3319,7 +3324,6 @@  static void mas_split(struct ma_state *mas, struct maple_big_node *b_node)
 	MA_STATE(prev_r_mas, mas->tree, mas->index, mas->last);
 
 	trace_ma_op(__func__, mas);
-	mas->depth = mas_mt_height(mas);
 
 	mast.l = &l_mas;
 	mast.r = &r_mas;
@@ -3327,9 +3331,9 @@  static void mas_split(struct ma_state *mas, struct maple_big_node *b_node)
 	mast.orig_r = &prev_r_mas;
 	mast.bn = b_node;
 
-	while (height++ <= mas->depth) {
+	while (height++ <= orig_height) {
 		if (mt_slots[b_node->type] > b_node->b_end) {
-			mas_split_final_node(&mast, mas, height);
+			mas_split_final_node(&mast, mas);
 			break;
 		}
 
@@ -3344,11 +3348,15 @@  static void mas_split(struct ma_state *mas, struct maple_big_node *b_node)
 		 * is a significant savings.
 		 */
 		/* Try to push left. */
-		if (mas_push_data(mas, height, &mast, true))
+		if (mas_push_data(mas, &mast, true)) {
+			height++;
 			break;
+		}
 		/* Try to push right. */
-		if (mas_push_data(mas, height, &mast, false))
+		if (mas_push_data(mas, &mast, false)) {
+			height++;
 			break;
+		}
 
 		split = mab_calc_split(mas, b_node, &mid_split);
 		mast_split_data(&mast, mas, split);
@@ -3365,7 +3373,7 @@  static void mas_split(struct ma_state *mas, struct maple_big_node *b_node)
 	/* Set the original node as dead */
 	old = mas->node;
 	mas->node = l_mas.node;
-	mas_wmb_replace(mas, old);
+	mas_wmb_replace(mas, old, height);
 	mtree_range_walk(mas);
 	return;
 }
@@ -3424,8 +3432,7 @@  static inline void mas_root_expand(struct ma_state *mas, void *entry)
 	if (mas->last != ULONG_MAX)
 		pivots[++slot] = ULONG_MAX;
 
-	mas->depth = 1;
-	mas_set_height(mas);
+	mt_set_height(mas->tree, 1);
 	ma_set_meta(node, maple_leaf_64, 0, slot);
 	/* swap the new root into the tree */
 	rcu_assign_pointer(mas->tree->ma_root, mte_mk_root(mas->node));
@@ -3669,8 +3676,7 @@  static inline void mas_new_root(struct ma_state *mas, void *entry)
 	WARN_ON_ONCE(mas->index || mas->last != ULONG_MAX);
 
 	if (!entry) {
-		mas->depth = 0;
-		mas_set_height(mas);
+		mt_set_height(mas->tree, 0);
 		rcu_assign_pointer(mas->tree->ma_root, entry);
 		mas->status = ma_start;
 		goto done;
@@ -3684,8 +3690,7 @@  static inline void mas_new_root(struct ma_state *mas, void *entry)
 	mas->status = ma_active;
 	rcu_assign_pointer(slots[0], entry);
 	pivots[0] = mas->last;
-	mas->depth = 1;
-	mas_set_height(mas);
+	mt_set_height(mas->tree, 1);
 	rcu_assign_pointer(mas->tree->ma_root, mte_mk_root(mas->node));
 
 done:
@@ -3804,6 +3809,7 @@  static inline void mas_wr_node_store(struct ma_wr_state *wr_mas,
 	struct maple_node reuse, *newnode;
 	unsigned char copy_size, node_pivots = mt_pivots[wr_mas->type];
 	bool in_rcu = mt_in_rcu(mas->tree);
+	unsigned char height = mas_mt_height(mas);
 
 	if (mas->last == wr_mas->end_piv)
 		offset_end++; /* don't copy this offset */
@@ -3860,7 +3866,7 @@  static inline void mas_wr_node_store(struct ma_wr_state *wr_mas,
 		struct maple_enode *old_enode = mas->node;
 
 		mas->node = mt_mk_node(newnode, wr_mas->type);
-		mas_replace_node(mas, old_enode);
+		mas_replace_node(mas, old_enode, height);
 	} else {
 		memcpy(wr_mas->node, newnode, sizeof(struct maple_node));
 	}
diff --git a/tools/testing/radix-tree/maple.c b/tools/testing/radix-tree/maple.c
index bc30050227fd..e0f8fabe8821 100644
--- a/tools/testing/radix-tree/maple.c
+++ b/tools/testing/radix-tree/maple.c
@@ -36248,6 +36248,21 @@  static noinline void __init check_mtree_dup(struct maple_tree *mt)
 
 extern void test_kmem_cache_bulk(void);
 
+static inline void check_spanning_store_height(struct maple_tree *mt)
+{
+	int index = 0;
+	MA_STATE(mas, mt, 0, 0);
+	mas_lock(&mas);
+	while (mt_height(mt) != 3) {
+		mas_store_gfp(&mas, xa_mk_value(index), GFP_KERNEL);
+		mas_set(&mas, ++index);
+	}
+	mas_set_range(&mas, 90, 140);
+	mas_store_gfp(&mas, xa_mk_value(index), GFP_KERNEL);
+	MT_BUG_ON(mt, mas_mt_height(&mas) != 2);
+	mas_unlock(&mas);
+}
+
 /* callback function used for check_nomem_writer_race() */
 static void writer2(void *maple_tree)
 {
@@ -36414,6 +36429,10 @@  void farmer_tests(void)
 	check_spanning_write(&tree);
 	mtree_destroy(&tree);
 
+	mt_init_flags(&tree, MT_FLAGS_ALLOC_RANGE);
+	check_spanning_store_height(&tree);
+	mtree_destroy(&tree);
+
 	mt_init_flags(&tree, MT_FLAGS_ALLOC_RANGE);
 	check_null_expand(&tree);
 	mtree_destroy(&tree);