diff mbox series

[v4,for-next,06/12] RDMA/core: Use refcount_t instead of atomic_t on refcount of mcast_group

Message ID 1622194663-2383-7-git-send-email-liweihang@huawei.com (mailing list archive)
State Changes Requested
Headers show
Series RDMA: Use refcount_t for reference counting | expand

Commit Message

Weihang Li May 28, 2021, 9:37 a.m. UTC
The refcount_t API WARNs on underflow and overflow of a reference
counter, which helps avoid use-after-free risks. Incrementing a refcount_t
from 0 to 1 is treated as a potential use-after-free, so the counter should
be set to 1 directly during initialization.

Signed-off-by: Weihang Li <liweihang@huawei.com>
---
 drivers/infiniband/core/multicast.c | 17 +++++++++++------
 1 file changed, 11 insertions(+), 6 deletions(-)

Comments

Jason Gunthorpe June 8, 2021, 5:55 p.m. UTC | #1
On Fri, May 28, 2021 at 05:37:37PM +0800, Weihang Li wrote:
> @@ -565,8 +565,11 @@ static struct mcast_group *acquire_group(struct mcast_port *port,
>  	if (!is_mgid0) {
>  		spin_lock_irqsave(&port->lock, flags);
>  		group = mcast_find(port, mgid);
> -		if (group)
> +		if (group) {
> +			refcount_inc(&group->refcount);
>  			goto found;
> +		}
> +
>  		spin_unlock_irqrestore(&port->lock, flags);
>  	}
>  
> @@ -590,8 +593,10 @@ static struct mcast_group *acquire_group(struct mcast_port *port,
>  		group = cur_group;
>  	} else
>  		refcount_inc(&port->refcount);
> +
> +	refcount_set(&group->refcount, 1);
> +

This isn't right: when mcast_insert() returns an existing group we
need to increment, not set, the refcount. Change it like this:

diff --git a/drivers/infiniband/core/multicast.c b/drivers/infiniband/core/multicast.c
index 17abc212b87d05..cf99e17b81ce79 100644
--- a/drivers/infiniband/core/multicast.c
+++ b/drivers/infiniband/core/multicast.c
@@ -585,17 +585,17 @@ static struct mcast_group *acquire_group(struct mcast_port *port,
 	INIT_LIST_HEAD(&group->active_list);
 	INIT_WORK(&group->work, mcast_work_handler);
 	spin_lock_init(&group->lock);
+	refcount_set(&group->refcount, 1);
 
 	spin_lock_irqsave(&port->lock, flags);
 	cur_group = mcast_insert(port, group, is_mgid0);
 	if (cur_group) {
 		kfree(group);
 		group = cur_group;
+		refcount_inc(&group->refcount);
 	} else
 		refcount_inc(&port->refcount);
 
-	refcount_set(&group->refcount, 1);
-
 found:
 	spin_unlock_irqrestore(&port->lock, flags);
 	return group;
Weihang Li June 9, 2021, 3:45 a.m. UTC | #2
On 2021/6/9 1:55, Jason Gunthorpe wrote:
> On Fri, May 28, 2021 at 05:37:37PM +0800, Weihang Li wrote:
>> @@ -565,8 +565,11 @@ static struct mcast_group *acquire_group(struct mcast_port *port,
>>  	if (!is_mgid0) {
>>  		spin_lock_irqsave(&port->lock, flags);
>>  		group = mcast_find(port, mgid);
>> -		if (group)
>> +		if (group) {
>> +			refcount_inc(&group->refcount);
>>  			goto found;
>> +		}
>> +
>>  		spin_unlock_irqrestore(&port->lock, flags);
>>  	}
>>  
>> @@ -590,8 +593,10 @@ static struct mcast_group *acquire_group(struct mcast_port *port,
>>  		group = cur_group;
>>  	} else
>>  		refcount_inc(&port->refcount);
>> +
>> +	refcount_set(&group->refcount, 1);
>> +
> 
> This isn't right, when mcast_insert() returns an existing group we
> need to incr not set the refcount. Change it like this:
> 

Thanks, I will modify it.

Weihang

> diff --git a/drivers/infiniband/core/multicast.c b/drivers/infiniband/core/multicast.c
> index 17abc212b87d05..cf99e17b81ce79 100644
> --- a/drivers/infiniband/core/multicast.c
> +++ b/drivers/infiniband/core/multicast.c
> @@ -585,17 +585,17 @@ static struct mcast_group *acquire_group(struct mcast_port *port,
>  	INIT_LIST_HEAD(&group->active_list);
>  	INIT_WORK(&group->work, mcast_work_handler);
>  	spin_lock_init(&group->lock);
> +	refcount_set(&group->refcount, 1);
>  
>  	spin_lock_irqsave(&port->lock, flags);
>  	cur_group = mcast_insert(port, group, is_mgid0);
>  	if (cur_group) {
>  		kfree(group);
>  		group = cur_group;
> +		refcount_inc(&group->refcount);
>  	} else
>  		refcount_inc(&port->refcount);
>  
> -	refcount_set(&group->refcount, 1);
> -
>  found:
>  	spin_unlock_irqrestore(&port->lock, flags);
>  	return group;
>
diff mbox series

Patch

diff --git a/drivers/infiniband/core/multicast.c b/drivers/infiniband/core/multicast.c
index a236532..17abc21 100644
--- a/drivers/infiniband/core/multicast.c
+++ b/drivers/infiniband/core/multicast.c
@@ -103,7 +103,7 @@  struct mcast_group {
 	struct list_head	active_list;
 	struct mcast_member	*last_join;
 	int			members[NUM_JOIN_MEMBERSHIP_TYPES];
-	atomic_t		refcount;
+	refcount_t		refcount;
 	enum mcast_group_state	state;
 	struct ib_sa_query	*query;
 	u16			pkey_index;
@@ -188,7 +188,7 @@  static void release_group(struct mcast_group *group)
 	unsigned long flags;
 
 	spin_lock_irqsave(&port->lock, flags);
-	if (atomic_dec_and_test(&group->refcount)) {
+	if (refcount_dec_and_test(&group->refcount)) {
 		rb_erase(&group->node, &port->table);
 		spin_unlock_irqrestore(&port->lock, flags);
 		kfree(group);
@@ -212,7 +212,7 @@  static void queue_join(struct mcast_member *member)
 	list_add_tail(&member->list, &group->pending_list);
 	if (group->state == MCAST_IDLE) {
 		group->state = MCAST_BUSY;
-		atomic_inc(&group->refcount);
+		refcount_inc(&group->refcount);
 		queue_work(mcast_wq, &group->work);
 	}
 	spin_unlock_irqrestore(&group->lock, flags);
@@ -565,8 +565,11 @@  static struct mcast_group *acquire_group(struct mcast_port *port,
 	if (!is_mgid0) {
 		spin_lock_irqsave(&port->lock, flags);
 		group = mcast_find(port, mgid);
-		if (group)
+		if (group) {
+			refcount_inc(&group->refcount);
 			goto found;
+		}
+
 		spin_unlock_irqrestore(&port->lock, flags);
 	}
 
@@ -590,8 +593,10 @@  static struct mcast_group *acquire_group(struct mcast_port *port,
 		group = cur_group;
 	} else
 		refcount_inc(&port->refcount);
+
+	refcount_set(&group->refcount, 1);
+
 found:
-	atomic_inc(&group->refcount);
 	spin_unlock_irqrestore(&port->lock, flags);
 	return group;
 }
@@ -780,7 +785,7 @@  static void mcast_groups_event(struct mcast_port *port,
 		group = rb_entry(node, struct mcast_group, node);
 		spin_lock(&group->lock);
 		if (group->state == MCAST_IDLE) {
-			atomic_inc(&group->refcount);
+			refcount_inc(&group->refcount);
 			queue_work(mcast_wq, &group->work);
 		}
 		if (group->state != MCAST_GROUP_ERROR)