diff mbox series

[1/2] padata: add pd get/put refcnt helper

Message ID 20241123080509.2573987-2-chenridong@huaweicloud.com (mailing list archive)
State New
Headers show
Series padata: fix UAF in padata_reorder | expand

Commit Message

Chen Ridong Nov. 23, 2024, 8:05 a.m. UTC
From: Chen Ridong <chenridong@huawei.com>

Add helpers for pd to get/put refcnt to make code consice.

Signed-off-by: Chen Ridong <chenridong@huawei.com>
---
 kernel/padata.c | 28 +++++++++++++++++++++-------
 1 file changed, 21 insertions(+), 7 deletions(-)
diff mbox series

Patch

diff --git a/kernel/padata.c b/kernel/padata.c
index d51bbc76b227..5d8e18cdcb25 100644
--- a/kernel/padata.c
+++ b/kernel/padata.c
@@ -47,6 +47,23 @@  struct padata_mt_job_state {
 static void padata_free_pd(struct parallel_data *pd);
 static void __init padata_mt_helper(struct work_struct *work);
 
+static inline void padata_get_pd(struct parallel_data *pd)
+{
+	refcount_inc(&pd->refcnt);
+}
+
+static inline void padata_put_pd(struct parallel_data *pd)
+{
+	if (refcount_dec_and_test(&pd->refcnt))
+		padata_free_pd(pd);
+}
+
+static inline void padata_put_pd_cnt(struct parallel_data *pd, int cnt)
+{
+	if (refcount_sub_and_test(cnt, &pd->refcnt))
+		padata_free_pd(pd);
+}
+
 static int padata_index_to_cpu(struct parallel_data *pd, int cpu_index)
 {
 	int cpu, target_cpu;
@@ -206,7 +223,7 @@  int padata_do_parallel(struct padata_shell *ps,
 	if ((pinst->flags & PADATA_RESET))
 		goto out;
 
-	refcount_inc(&pd->refcnt);
+	padata_get_pd(pd);
 	padata->pd = pd;
 	padata->cb_cpu = *cb_cpu;
 
@@ -380,8 +397,7 @@  static void padata_serial_worker(struct work_struct *serial_work)
 	}
 	local_bh_enable();
 
-	if (refcount_sub_and_test(cnt, &pd->refcnt))
-		padata_free_pd(pd);
+	padata_put_pd_cnt(pd, cnt);
 }
 
 /**
@@ -681,8 +697,7 @@  static int padata_replace(struct padata_instance *pinst)
 	synchronize_rcu();
 
 	list_for_each_entry_continue_reverse(ps, &pinst->pslist, list)
-		if (refcount_dec_and_test(&ps->opd->refcnt))
-			padata_free_pd(ps->opd);
+		padata_put_pd(ps->opd);
 
 	pinst->flags &= ~PADATA_RESET;
 
@@ -1124,8 +1139,7 @@  void padata_free_shell(struct padata_shell *ps)
 	mutex_lock(&ps->pinst->lock);
 	list_del(&ps->list);
 	pd = rcu_dereference_protected(ps->pd, 1);
-	if (refcount_dec_and_test(&pd->refcnt))
-		padata_free_pd(pd);
+	padata_put_pd(ps->pd);
 	mutex_unlock(&ps->pinst->lock);
 
 	kfree(ps);