diff mbox series

[RFC,04/11] mm/mempolicy: modify get_mempolicy call stack to take a task argument

Message ID 20231122211200.31620-5-gregory.price@memverge.com (mailing list archive)
State New
Headers show
Series mm/mempolicy: Make task->mempolicy externally modifiable via syscall and procfs | expand

Commit Message

Gregory Price Nov. 22, 2023, 9:11 p.m. UTC
To make mempolicy fetchable by external tasks, we must first change
the callstack to take a task as an argument.

Modify the following functions to require a task argument:
	do_get_mempolicy
	kernel_get_mempolicy

The way the task->mm is acquired must change slightly to enable this.
Originally, do_get_mempolicy would acquire the task->mm
directly via (current->mm).  This is unsafe to do in a non-current
context.  However, utilizing get_task_mm would break the original
functionality of do_get_mempolicy due to the following check
in get_task_mm:

  if (mm) {
    if (task->flags & PF_KTHREAD)
      mm = NULL;
    else
      mmget(mm);
  }

To retain the original behavior, if (task == current) we access
the task->mm directly, but if (task != current) we will utilize
get_task_mm to safely access the mm.

We simplify the get/put mechanics by always taking a reference to
the mm, even if we are in the context of (task == current).

Additionally, since the mempolicy will become externally modifiable,
we need to take the task lock to acquire task->mempolicy safely,
regardless of whether we are operating on current or not.

Signed-off-by: Gregory Price <gregory.price@memverge.com>
---
 mm/mempolicy.c | 43 +++++++++++++++++++++++++++++--------------
 1 file changed, 29 insertions(+), 14 deletions(-)

Comments

Michal Hocko Nov. 28, 2023, 2:07 p.m. UTC | #1
On Wed 22-11-23 16:11:53, Gregory Price wrote:
[...]
> @@ -928,7 +929,16 @@ static long do_get_mempolicy(int *policy, nodemask_t *nmask,
>  		 * vma/shared policy at addr is NULL.  We
>  		 * want to return MPOL_DEFAULT in this case.
>  		 */
> -		mm = current->mm;
> +		if (task == current) {
> +			/*
> +			 * original behavior allows a kernel task changing its
> +			 * own policy to avoid the condition in get_task_mm,
> +			 * so we'll directly access
> +			 */
> +			mm = task->mm;
> +			mmget(mm);

Do we actually have any kernel thread that would call this? Does it
actually make sense to support?

> +		} else
> +			mm = get_task_mm(task);
>  		mmap_read_lock(mm);
>  		vma = vma_lookup(mm, addr);
>  		if (!vma) {
> @@ -947,8 +957,10 @@ static long do_get_mempolicy(int *policy, nodemask_t *nmask,
>  		return -EINVAL;
>  	else {
>  		/* take a reference of the task policy now */
> -		pol = current->mempolicy;
> +		task_lock(task);
> +		pol = task->mempolicy;
>  		mpol_get(pol);
> +		task_unlock(task);
>  	}
>  
>  	if (!pol) {
> @@ -962,12 +974,13 @@ static long do_get_mempolicy(int *policy, nodemask_t *nmask,
>  			vma = NULL;
>  			mmap_read_unlock(mm);
>  			err = lookup_node(mm, addr);
> +			mmput(mm);
>  			if (err < 0)
>  				goto out;
>  			*policy = err;
> -		} else if (pol == current->mempolicy &&
> +		} else if (pol == task->mempolicy &&
>  				pol->mode == MPOL_INTERLEAVE) {
> -			*policy = next_node_in(current->il_prev, pol->nodes);
> +			*policy = next_node_in(task->il_prev, pol->nodes);

This is racy without task_lock, which I do not think is held, but it also
seems this is not a big deal. pol is ref. counted so it won't go away,
and if the task->mempolicy changes then the return value could be bogus,
but this seems acceptable. It would be good to put a comment here that
this is actually deliberate.
Michal Hocko Nov. 28, 2023, 2:49 p.m. UTC | #2
[restoring the CC list as I believe this was not meant to be a private
response]

On Tue 28-11-23 09:12:35, Gregory Price wrote:
> On Tue, Nov 28, 2023 at 03:07:28PM +0100, Michal Hocko wrote:
> > On Wed 22-11-23 16:11:53, Gregory Price wrote:
> > [...]
> > > @@ -928,7 +929,16 @@ static long do_get_mempolicy(int *policy, nodemask_t *nmask,
> > >  		 * vma/shared policy at addr is NULL.  We
> > >  		 * want to return MPOL_DEFAULT in this case.
> > >  		 */
> > > -		mm = current->mm;
> > > +		if (task == current) {
> > > +			/*
> > > +			 * original behavior allows a kernel task changing its
> > > +			 * own policy to avoid the condition in get_task_mm,
> > > +			 * so we'll directly access
> > > +			 */
> > > +			mm = task->mm;
> > > +			mmget(mm);
> > 
> > Do we actually have any kernel thread that would call this? Does it
> > actually make sense to support?
> > 
> 
> This was changed in the upcoming v2 by using the pidfd interface for
> referencing both the task and the mm, so this code is a bit dead.

OK, that is the right thing to do IMHO. Allowing modifications on memory
policies on borrowed mms sounds rather weird and if we do not have any
actual usecases that would require that support then I would rather not
open that possibility at all.
diff mbox series

Patch

diff --git a/mm/mempolicy.c b/mm/mempolicy.c
index 9ea3e1bfc002..4519f39b1a07 100644
--- a/mm/mempolicy.c
+++ b/mm/mempolicy.c
@@ -899,8 +899,9 @@  static int lookup_node(struct mm_struct *mm, unsigned long addr)
 }
 
 /* Retrieve NUMA policy */
-static long do_get_mempolicy(int *policy, nodemask_t *nmask,
-			     unsigned long addr, unsigned long flags)
+static long do_get_mempolicy(struct task_struct *task, int *policy,
+			     nodemask_t *nmask, unsigned long addr,
+			     unsigned long flags)
 {
 	int err;
 	struct mm_struct *mm;
@@ -915,9 +916,9 @@  static long do_get_mempolicy(int *policy, nodemask_t *nmask,
 		if (flags & (MPOL_F_NODE|MPOL_F_ADDR))
 			return -EINVAL;
 		*policy = 0;	/* just so it's initialized */
-		task_lock(current);
-		*nmask  = cpuset_current_mems_allowed;
-		task_unlock(current);
+		task_lock(task);
+		*nmask = task->mems_allowed;
+		task_unlock(task);
 		return 0;
 	}
 
@@ -928,7 +929,16 @@  static long do_get_mempolicy(int *policy, nodemask_t *nmask,
 		 * vma/shared policy at addr is NULL.  We
 		 * want to return MPOL_DEFAULT in this case.
 		 */
-		mm = current->mm;
+		if (task == current) {
+			/*
+			 * original behavior allows a kernel task changing its
+			 * own policy to avoid the condition in get_task_mm,
+			 * so we'll directly access
+			 */
+			mm = task->mm;
+			mmget(mm);
+		} else
+			mm = get_task_mm(task);
 		mmap_read_lock(mm);
 		vma = vma_lookup(mm, addr);
 		if (!vma) {
@@ -947,8 +957,10 @@  static long do_get_mempolicy(int *policy, nodemask_t *nmask,
 		return -EINVAL;
 	else {
 		/* take a reference of the task policy now */
-		pol = current->mempolicy;
+		task_lock(task);
+		pol = task->mempolicy;
 		mpol_get(pol);
+		task_unlock(task);
 	}
 
 	if (!pol) {
@@ -962,12 +974,13 @@  static long do_get_mempolicy(int *policy, nodemask_t *nmask,
 			vma = NULL;
 			mmap_read_unlock(mm);
 			err = lookup_node(mm, addr);
+			mmput(mm);
 			if (err < 0)
 				goto out;
 			*policy = err;
-		} else if (pol == current->mempolicy &&
+		} else if (pol == task->mempolicy &&
 				pol->mode == MPOL_INTERLEAVE) {
-			*policy = next_node_in(current->il_prev, pol->nodes);
+			*policy = next_node_in(task->il_prev, pol->nodes);
 		} else {
 			err = -EINVAL;
 			goto out;
@@ -987,9 +1000,9 @@  static long do_get_mempolicy(int *policy, nodemask_t *nmask,
 		if (mpol_store_user_nodemask(pol)) {
 			*nmask = pol->w.user_nodemask;
 		} else {
-			task_lock(current);
+			task_lock(task);
 			get_policy_nodemask(pol, nmask);
-			task_unlock(current);
+			task_unlock(task);
 		}
 	}
 
@@ -1704,7 +1717,8 @@  SYSCALL_DEFINE4(migrate_pages, pid_t, pid, unsigned long, maxnode,
 }
 
 /* Retrieve NUMA policy */
-static int kernel_get_mempolicy(int __user *policy,
+static int kernel_get_mempolicy(struct task_struct *task,
+				int __user *policy,
 				unsigned long __user *nmask,
 				unsigned long maxnode,
 				unsigned long addr,
@@ -1719,7 +1733,7 @@  static int kernel_get_mempolicy(int __user *policy,
 
 	addr = untagged_addr(addr);
 
-	err = do_get_mempolicy(&pval, &nodes, addr, flags);
+	err = do_get_mempolicy(task, &pval, &nodes, addr, flags);
 
 	if (err)
 		return err;
@@ -1737,7 +1751,8 @@  SYSCALL_DEFINE5(get_mempolicy, int __user *, policy,
 		unsigned long __user *, nmask, unsigned long, maxnode,
 		unsigned long, addr, unsigned long, flags)
 {
-	return kernel_get_mempolicy(policy, nmask, maxnode, addr, flags);
+	return kernel_get_mempolicy(current, policy, nmask, maxnode, addr,
+				    flags);
 }
 
 bool vma_migratable(struct vm_area_struct *vma)