diff mbox series

[RFC,03/11] mm/mempolicy: refactor set_mempolicy stack to take a task argument

Message ID 20231122211200.31620-4-gregory.price@memverge.com (mailing list archive)
State New
Headers show
Series mm/mempolicy: Make task->mempolicy externally modifiable via syscall and procfs | expand

Commit Message

Gregory Price Nov. 22, 2023, 9:11 p.m. UTC
To make mempolicy modifiable by external tasks, we must refactor
the callstack to take a task as an argument.

Modify the following functions to require a task argument:
	mpol_set_nodemask
	replace_mempolicy
	do_set_mempolicy
	kernel_set_mempolicy

Since replace_mempolicy already takes the task lock, no locking
behavior needs to change; it now locks the task passed in rather
than current.

All other callers of mpol_set_nodemask (as of this patch) pass
current and hold either the task lock or the mmap lock, so no
other changes are required.

Signed-off-by: Gregory Price <gregory.price@memverge.com>
---
 mm/mempolicy.c | 51 +++++++++++++++++++++++++++-----------------------
 1 file changed, 28 insertions(+), 23 deletions(-)
diff mbox series

Patch

diff --git a/mm/mempolicy.c b/mm/mempolicy.c
index 37da712259d7..9ea3e1bfc002 100644
--- a/mm/mempolicy.c
+++ b/mm/mempolicy.c
@@ -226,8 +226,10 @@  static int mpol_new_preferred(struct mempolicy *pol, const nodemask_t *nodes)
  * Must be called holding task's alloc_lock to protect task's mems_allowed
  * and mempolicy.  May also be called holding the mmap_lock for write.
  */
-static int mpol_set_nodemask(struct mempolicy *pol,
-		     const nodemask_t *nodes, struct nodemask_scratch *nsc)
+static int mpol_set_nodemask(struct task_struct *tsk,
+			     struct mempolicy *pol,
+			     const nodemask_t *nodes,
+			     struct nodemask_scratch *nsc)
 {
 	int ret;
 
@@ -240,8 +242,7 @@  static int mpol_set_nodemask(struct mempolicy *pol,
 		return 0;
 
 	/* Check N_MEMORY */
-	nodes_and(nsc->mask1,
-		  cpuset_current_mems_allowed, node_states[N_MEMORY]);
+	nodes_and(nsc->mask1, tsk->mems_allowed, node_states[N_MEMORY]);
 
 	VM_BUG_ON(!nodes);
 
@@ -253,7 +254,7 @@  static int mpol_set_nodemask(struct mempolicy *pol,
 	if (mpol_store_user_nodemask(pol))
 		pol->w.user_nodemask = *nodes;
 	else
-		pol->w.cpuset_mems_allowed = cpuset_current_mems_allowed;
+		pol->w.cpuset_mems_allowed = tsk->mems_allowed;
 
 	ret = mpol_ops[pol->mode].create(pol, &nsc->mask2);
 	return ret;
@@ -810,7 +811,9 @@  static int mbind_range(struct vma_iterator *vmi, struct vm_area_struct *vma,
 }
 
 /* Attempt to replace mempolicy, release the old one if successful */
-static long replace_mempolicy(struct mempolicy *new, nodemask_t *nodes)
+static long replace_mempolicy(struct task_struct *task,
+			      struct mempolicy *new,
+			      nodemask_t *nodes)
 {
 	struct mempolicy *old = NULL;
 	NODEMASK_SCRATCH(scratch);
@@ -819,19 +822,19 @@  static long replace_mempolicy(struct mempolicy *new, nodemask_t *nodes)
 	if (!scratch)
 		return -ENOMEM;
 
-	task_lock(current);
-	ret = mpol_set_nodemask(new, nodes, scratch);
+	task_lock(task);
+	ret = mpol_set_nodemask(task, new, nodes, scratch);
 	if (ret) {
-		task_unlock(current);
+		task_unlock(task);
 		goto out;
 	}
 
-	old = current->mempolicy;
-	current->mempolicy = new;
+	old = task->mempolicy;
+	task->mempolicy = new;
 	if (new && new->mode == MPOL_INTERLEAVE)
-		current->il_prev = MAX_NUMNODES-1;
+		task->il_prev = MAX_NUMNODES-1;
 out:
-	task_unlock(current);
+	task_unlock(task);
 	mpol_put(old);
 
 	NODEMASK_SCRATCH_FREE(scratch);
@@ -839,8 +842,8 @@  static long replace_mempolicy(struct mempolicy *new, nodemask_t *nodes)
 }
 
 /* Set the process memory policy */
-static long do_set_mempolicy(unsigned short mode, unsigned short flags,
-			     nodemask_t *nodes)
+static long do_set_mempolicy(struct task_struct *task, unsigned short mode,
+			     unsigned short flags, nodemask_t *nodes)
 {
 	struct mempolicy *new;
 	int ret;
@@ -849,7 +852,7 @@  static long do_set_mempolicy(unsigned short mode, unsigned short flags,
 	if (IS_ERR(new))
 		return PTR_ERR(new);
 
-	ret = replace_mempolicy(new, nodes);
+	ret = replace_mempolicy(task, new, nodes);
 	if (ret)
 		mpol_put(new);
 
@@ -1284,7 +1287,7 @@  static long do_mbind(unsigned long start, unsigned long len,
 		NODEMASK_SCRATCH(scratch);
 		if (scratch) {
 			mmap_write_lock(mm);
-			err = mpol_set_nodemask(new, nmask, scratch);
+			err = mpol_set_nodemask(current, new, nmask, scratch);
 			if (err)
 				mmap_write_unlock(mm);
 		} else
@@ -1580,7 +1583,8 @@  SYSCALL_DEFINE6(mbind, unsigned long, start, unsigned long, len,
 }
 
 /* Set the process memory policy */
-static long kernel_set_mempolicy(int mode, const unsigned long __user *nmask,
+static long kernel_set_mempolicy(struct task_struct *task, int mode,
+				 const unsigned long __user *nmask,
 				 unsigned long maxnode)
 {
 	unsigned short mode_flags;
@@ -1596,13 +1600,13 @@  static long kernel_set_mempolicy(int mode, const unsigned long __user *nmask,
 	if (err)
 		return err;
 
-	return do_set_mempolicy(lmode, mode_flags, &nodes);
+	return do_set_mempolicy(task, lmode, mode_flags, &nodes);
 }
 
 SYSCALL_DEFINE3(set_mempolicy, int, mode, const unsigned long __user *, nmask,
 		unsigned long, maxnode)
 {
-	return kernel_set_mempolicy(mode, nmask, maxnode);
+	return kernel_set_mempolicy(current, mode, nmask, maxnode);
 }
 
 static int kernel_migrate_pages(pid_t pid, unsigned long maxnode,
@@ -2722,7 +2726,8 @@  void mpol_shared_policy_init(struct shared_policy *sp, struct mempolicy *mpol)
 			goto free_scratch; /* no valid nodemask intersection */
 
 		task_lock(current);
-		ret = mpol_set_nodemask(npol, &mpol->w.user_nodemask, scratch);
+		ret = mpol_set_nodemask(current, npol, &mpol->w.user_nodemask,
+					scratch);
 		task_unlock(current);
 		if (ret)
 			goto put_npol;
@@ -2870,7 +2875,7 @@  void __init numa_policy_init(void)
 	if (unlikely(nodes_empty(interleave_nodes)))
 		node_set(prefer, interleave_nodes);
 
-	if (do_set_mempolicy(MPOL_INTERLEAVE, 0, &interleave_nodes))
+	if (do_set_mempolicy(current, MPOL_INTERLEAVE, 0, &interleave_nodes))
 		pr_err("%s: interleaving failed\n", __func__);
 
 	check_numabalancing_enable();
@@ -2879,7 +2884,7 @@  void __init numa_policy_init(void)
 /* Reset policy of current process to default */
 void numa_default_policy(void)
 {
-	do_set_mempolicy(MPOL_DEFAULT, 0, NULL);
+	do_set_mempolicy(current, MPOL_DEFAULT, 0, NULL);
 }
 
 /*