diff mbox series

[RFC,2/3] mm/mempolicy: Implement set_mempolicy2 and get_mempolicy2 syscalls

Message ID 20230914235457.482710-3-gregory.price@memverge.com
State New, archived
Headers show
Series mm/mempolicy: set/get_mempolicy2 | expand

Commit Message

Gregory Price Sept. 14, 2023, 11:54 p.m. UTC
sys_set_mempolicy is limited by its current argument structure
(mode, nodes, flags) to implementing policies that can be described
in that manner.

Implement set/get_mempolicy2 with a new mempolicy_args structure
which encapsulates the old behavior, and allows for new mempolicies
which may require additional information.

Signed-off-by: Gregory Price <gregory.price@memverge.com>
---
 arch/x86/entry/syscalls/syscall_32.tbl |   2 +
 arch/x86/entry/syscalls/syscall_64.tbl |   2 +
 include/linux/syscalls.h               |   2 +
 include/uapi/asm-generic/unistd.h      |  10 +-
 include/uapi/linux/mempolicy.h         |  32 ++++
 mm/mempolicy.c                         | 215 ++++++++++++++++++++++++-
 6 files changed, 261 insertions(+), 2 deletions(-)

Comments

Jonathan Cameron Oct. 2, 2023, 1:30 p.m. UTC | #1
On Thu, 14 Sep 2023 19:54:56 -0400
Gregory Price <gourry.memverge@gmail.com> wrote:

> sys_set_mempolicy is limited by its current argument structure
> (mode, nodes, flags) to implementing policies that can be described
> in that manner.
> 
> Implement set/get_mempolicy2 with a new mempolicy_args structure
> which encapsulates the old behavior, and allows for new mempolicies
> which may require additional information.
> 
> Signed-off-by: Gregory Price <gregory.price@memverge.com>
Some random comments inline.

Jonathan


> ---
>  arch/x86/entry/syscalls/syscall_32.tbl |   2 +
>  arch/x86/entry/syscalls/syscall_64.tbl |   2 +
>  include/linux/syscalls.h               |   2 +
>  include/uapi/asm-generic/unistd.h      |  10 +-
>  include/uapi/linux/mempolicy.h         |  32 ++++
>  mm/mempolicy.c                         | 215 ++++++++++++++++++++++++-
>  6 files changed, 261 insertions(+), 2 deletions(-)
> 
> diff --git a/arch/x86/entry/syscalls/syscall_32.tbl b/arch/x86/entry/syscalls/syscall_32.tbl
> index 2d0b1bd866ea..a72ef588a704 100644
> --- a/arch/x86/entry/syscalls/syscall_32.tbl
> +++ b/arch/x86/entry/syscalls/syscall_32.tbl
> @@ -457,3 +457,5 @@
>  450	i386	set_mempolicy_home_node		sys_set_mempolicy_home_node
>  451	i386	cachestat		sys_cachestat
>  452	i386	fchmodat2		sys_fchmodat2
> +454	i386	set_mempolicy2		sys_set_mempolicy2
> +455	i386	get_mempolicy2		sys_get_mempolicy2
> diff --git a/arch/x86/entry/syscalls/syscall_64.tbl b/arch/x86/entry/syscalls/syscall_64.tbl
> index 1d6eee30eceb..ec54064de8b3 100644
> --- a/arch/x86/entry/syscalls/syscall_64.tbl
> +++ b/arch/x86/entry/syscalls/syscall_64.tbl
> @@ -375,6 +375,8 @@
>  451	common	cachestat		sys_cachestat
>  452	common	fchmodat2		sys_fchmodat2
>  453	64	map_shadow_stack	sys_map_shadow_stack
> +454	common	set_mempolicy2		sys_set_mempolicy2
> +455	common	get_mempolicy2		sys_get_mempolicy2
>  
>  #
>  # Due to a historical design error, certain syscalls are numbered differently
> diff --git a/include/linux/syscalls.h b/include/linux/syscalls.h
> index 22bc6bc147f8..d50a452954ae 100644
> --- a/include/linux/syscalls.h
> +++ b/include/linux/syscalls.h
> @@ -813,6 +813,8 @@ asmlinkage long sys_get_mempolicy(int __user *policy,
>  				unsigned long addr, unsigned long flags);
>  asmlinkage long sys_set_mempolicy(int mode, const unsigned long __user *nmask,
>  				unsigned long maxnode);
> +asmlinkage long sys_get_mempolicy2(struct mempolicy_args __user *args);
> +asmlinkage long sys_set_mempolicy2(struct mempolicy_args __user *args);
>  asmlinkage long sys_migrate_pages(pid_t pid, unsigned long maxnode,
>  				const unsigned long __user *from,
>  				const unsigned long __user *to);
> diff --git a/include/uapi/asm-generic/unistd.h b/include/uapi/asm-generic/unistd.h
> index abe087c53b4b..397dcf804941 100644
> --- a/include/uapi/asm-generic/unistd.h
> +++ b/include/uapi/asm-generic/unistd.h
> @@ -823,8 +823,16 @@ __SYSCALL(__NR_cachestat, sys_cachestat)
>  #define __NR_fchmodat2 452
>  __SYSCALL(__NR_fchmodat2, sys_fchmodat2)
>  
> +/* CONFIG_MMU only */
> +#ifndef __ARCH_NOMMU
> +#define __NR_set_mempolicy 454
> +__SYSCALL(__NR_set_mempolicy2, sys_set_mempolicy2)
> +#define __NR_set_mempolicy 455
> +__SYSCALL(__NR_get_mempolicy2, sys_get_mempolicy2)
> +#endif
> +
>  #undef __NR_syscalls
> -#define __NR_syscalls 453
> +#define __NR_syscalls 456
+3 for 2 additions?

>  
>  /*
>   * 32 bit systems traditionally used different
> diff --git a/include/uapi/linux/mempolicy.h b/include/uapi/linux/mempolicy.h
> index 046d0ccba4cd..53650f69db2b 100644
> --- a/include/uapi/linux/mempolicy.h
> +++ b/include/uapi/linux/mempolicy.h
> @@ -23,9 +23,41 @@ enum {
>  	MPOL_INTERLEAVE,
>  	MPOL_LOCAL,
>  	MPOL_PREFERRED_MANY,
> +	MPOL_LEGACY,	/* set_mempolicy limited to above modes */
>  	MPOL_MAX,	/* always last member of enum */
>  };
>  
> +struct mempolicy_args {
> +	int err;
> +	unsigned short mode;
> +	unsigned long *nodemask;
> +	unsigned long maxnode;
> +	unsigned short flags;
> +	struct {
> +		/* Memory allowed */
> +		struct {
> +			int err;
> +			unsigned long maxnode;
> +			unsigned long *nodemask;
> +		} allowed;
> +		/* Address information */
> +		struct {
> +			int err;
> +			unsigned long addr;
> +			unsigned long node;
> +			unsigned short mode;
> +			unsigned short flags;
> +		} addr;
> +		/* Interleave */
> +	} get;
> +	/* Mode specific settings */
> +	union {
> +		struct {
> +			unsigned long next_node; /* get only */
> +		} interleave;
> +	};
> +};
> +
>  /* Flags for set_mempolicy */
>  #define MPOL_F_STATIC_NODES	(1 << 15)
>  #define MPOL_F_RELATIVE_NODES	(1 << 14)
> diff --git a/mm/mempolicy.c b/mm/mempolicy.c
> index f49337f6f300..1cf7709400f1 100644
> --- a/mm/mempolicy.c
> +++ b/mm/mempolicy.c
> @@ -1483,7 +1483,7 @@ static inline int sanitize_mpol_flags(int *mode, unsigned short *flags)
>  	*flags = *mode & MPOL_MODE_FLAGS;
>  	*mode &= ~MPOL_MODE_FLAGS;
>  
> -	if ((unsigned int)(*mode) >=  MPOL_MAX)
> +	if ((unsigned int)(*mode) >= MPOL_LEGACY)
>  		return -EINVAL;
>  	if ((*flags & MPOL_F_STATIC_NODES) && (*flags & MPOL_F_RELATIVE_NODES))
>  		return -EINVAL;
> @@ -1614,6 +1614,219 @@ SYSCALL_DEFINE3(set_mempolicy, int, mode, const unsigned long __user *, nmask,
>  	return kernel_set_mempolicy(mode, nmask, maxnode);
>  }
>  
> +static long do_set_mempolicy2(struct mempolicy_args *args)
> +{
> +	struct mempolicy *new = NULL;
> +	nodemask_t nodes;
> +	int err;
> +
> +	if (args->mode <= MPOL_LEGACY)
> +		return -EINVAL;
> +
> +	if (args->mode >= MPOL_MAX)
> +		return -EINVAL;
> +
> +	err = get_nodes(&nodes, args->nodemask, args->maxnode);
> +	if (err)
> +		return err;
> +
> +	new = mpol_new(args->mode, args->flags, &nodes);
> +	if (IS_ERR(new)) {
> +		err = PTR_ERR(new);
> +		goto out;

I'd expect mpol_new() to be side effect free on error,
so
		return PTR_ERR(new);
should be fine?

> +	}
> +
> +	switch (args->mode) {
> +	default:
> +		BUG();
> +	}
> +
> +	if (err)
> +		goto out;
> +
> +	err = swap_mempolicy(new, &nodes);
> +out:
> +	if (err && new)

When IS_ERR(new) is true, new is a non-NULL error pointer, so I think this
calls mpol_put() on it even though mpol_new() returned an error.  That seems unwise.

I'd push this block below a return 0 anyway, so as to avoid
error handling in the good path.

> +		mpol_put(new);
> +	return err;
> +};
> +
> +static bool mempolicy2_args_valid(struct mempolicy_args *kargs)
> +{
> +	/* Legacy modes are routed through the legacy interface */
> +	if (kargs->mode <= MPOL_LEGACY)
> +		return false;
> +
> +	if (kargs->mode >= MPOL_MAX)
> +		return false;
> +
> +	return true;

This is a range check, so I think equally clear (and shorter) as..
	/* Legacy modes are routed through the legacy interface */
	return kargs->mode > MPOL_LEGACY && kargs->mode < MPOL_MAX;

> +}
> +
> +static long kernel_set_mempolicy2(const struct mempolicy_args __user *uargs,
> +				  size_t usize)
> +{
> +	struct mempolicy_args kargs;
> +	int err;
> +
> +	if (usize != sizeof(kargs))

As below, maybe allow a bigger size, with the assumption that we'll ignore
whatever is in the extra space.

> +		return -EINVAL;
> +
> +	err = copy_struct_from_user(&kargs, sizeof(kargs), uargs, usize);
> +	if (err)
> +		return err;
> +
> +	/* If the mode is legacy, use the legacy path */
> +	if (kargs.mode < MPOL_LEGACY) {
> +		int legacy_mode = kargs.mode | kargs.flags;
> +		const unsigned long __user *lnmask = kargs.nodemask;
> +		unsigned long maxnode = kargs.maxnode;
> +
> +		return kernel_set_mempolicy(legacy_mode, lnmask, maxnode);
> +	}
> +
> +	if (!mempolicy2_args_valid(&kargs))
> +		return -EINVAL;
> +
> +	return do_set_mempolicy2(&kargs);
> +}
> +
> +SYSCALL_DEFINE2(set_mempolicy2, const struct mempolicy_args __user *, args,
> +		size_t, size)
> +{
> +	return kernel_set_mempolicy2(args, size);
> +}
> +
> +/* Gets extended mempolicy information */
> +static long do_get_mempolicy2(struct mempolicy_args *kargs)
> +{
> +	struct mempolicy *pol = current->mempolicy;
> +	nodemask_t knodes;
> +	int err = 0;
> +
> +	kargs->err = 0;
> +	kargs->mode = pol->mode;
> +	/* Mask off internal flags */
> +	kargs->flags = (pol->flags & MPOL_MODE_FLAGS);

Excessive brackets.

> +
> +	if (kargs->nodemask) {
> +		if (mpol_store_user_nodemask(pol)) {
> +			knodes = pol->w.user_nodemask;
> +		} else {
> +			task_lock(current);
> +			get_policy_nodemask(pol, &knodes);
> +			task_unlock(current);
> +		}
> +		err = copy_nodes_to_user(kargs->nodemask,
> +					 kargs->maxnode,
> +					 &knodes);
Can wrap this less.

> +		if (err)

return err ?

> +			return -EINVAL;
> +	}
> +
> +
> +	if (kargs->get.allowed.nodemask) {
> +		kargs->get.allowed.err = 0;
> +		task_lock(current);
> +		knodes = cpuset_current_mems_allowed;
> +		task_unlock(current);
> +		err = copy_nodes_to_user(kargs->get.allowed.nodemask,
> +					 kargs->get.allowed.maxnode,
> +					 &knodes);
> +		kargs->get.allowed.err = err ? err : 0;
> +		kargs->err |= err ? err : 1;
		if (err) {
			kargs->get.allowed.err = err;
			kargs->err |= err;
		} else {
			kargs->get.allowed.err = 0;
			kargs->err = 1;
	Not particularly obvious why 1, and if you get an error later it's going to be messy,
        as will 1 |= err_code
		}
> +	}
> +
> +	if (kargs->get.addr.addr) {
> +		struct mempolicy *addr_pol = NULL;

Why init here - I think it's always set before use.

> +		struct vm_area_struct *vma = NULL;

Why init here?

> +		struct mm_struct *mm = current->mm;
> +		unsigned long addr = kargs->get.addr.addr;
> +
> +		kargs->get.addr.err = 0;

I'd set this only in the good path. You overwrite it
in the bad paths anyway, so just move it down below the error
checks.

> +
> +		/*
> +		 * Do NOT fall back to task policy if the
> +		 * vma/shared policy at addr is NULL.  We
> +		 * want to return MPOL_DEFAULT in this case.
> +		 */
> +		mmap_read_lock(mm);
> +		vma = vma_lookup(mm, addr);
> +		if (!vma) {
> +			mmap_read_unlock(mm);
> +			kargs->get.addr.err = -EFAULT;
> +			kargs->err |= err ? err : 2;
> +			goto mode_info;
> +		}
> +		if (vma->vm_ops && vma->vm_ops->get_policy)
> +			addr_pol = vma->vm_ops->get_policy(vma, addr);
> +		else
> +			addr_pol = vma->vm_policy;
> +
> +		kargs->get.addr.mode = addr_pol->mode;
> +		/* Mask off internal flags */
> +		kargs->get.addr.flags = (pol->flags & MPOL_MODE_FLAGS);
> +
> +		/*
> +		 * Take a refcount on the mpol, because we are about to
> +		 * drop the mmap_lock, after which only "pol" remains
> +		 * valid, "vma" is stale.
> +		 */
> +		vma = NULL;
> +		mpol_get(addr_pol);
> +		mmap_read_unlock(mm);
> +		err = lookup_node(mm, addr);
> +		mpol_put(addr_pol);
> +		if (err < 0) {
> +			kargs->get.addr.err = err;
> +			kargs->err |= err ? err : 4;
> +			goto mode_info;
> +		}
> +		kargs->get.addr.node = err;

Confusing to call something that isn't an error, err. I'd use a different
local variable for this and set err = rc in error path only.

Could set the get.addr.err = 0; down here, as this is the only way it remains 0
if you set it earlier.


> +	}
> +
> +mode_info:
> +	switch (kargs->mode) {
> +	case MPOL_INTERLEAVE:
> +		kargs->interleave.next_node = next_node_in(current->il_prev,
> +							   pol->nodes);
> +		break;
> +	default:
> +		break;
> +	}
> +
> +	return err;
> +}
> +
> +static long kernel_get_mempolicy2(struct mempolicy_args __user *uargs,
> +				  size_t usize)
> +{
> +	struct mempolicy_args kargs;
> +	int err;
> +
> +	if (usize != sizeof(struct mempolicy_args))

Use sizeof(kargs), for the same reason as below.  I'm not sure of the convention
here, but is it wise to leave the option for a newer userspace to send a larger
struct, knowing that fields in it might be ignored by an older kernel?


> +		return -EINVAL;
> +
> +	err = copy_struct_from_user(&kargs, sizeof(kargs), uargs, usize);
> +	if (err)
> +		return err;
> +
> +	/* Get the extended memory policy information (kargs.ext) */
> +	err = do_get_mempolicy2(&kargs);
> +	if (err)
> +		return err;
> +
> +	err = copy_to_user(uargs, &kargs, sizeof(struct mempolicy_args));
> +
> +	return err;

return copy_to_user(uargs, &kargs, sizeof(kargs));
You are inconsistent on the sizeof.  Better to pick one style, and
given both are used, I'd go with using the sizeof(thing) rather
than sizeof(type) option + shorter lines ;)

> +}
> +
> +SYSCALL_DEFINE2(get_mempolicy2, struct mempolicy_args __user *, policy,
> +		size_t, size)
> +{
> +	return kernel_get_mempolicy2(policy, size);
> +}
> +
>  static int kernel_migrate_pages(pid_t pid, unsigned long maxnode,
>  				const unsigned long __user *old_nodes,
>  				const unsigned long __user *new_nodes)
Gregory Price Oct. 2, 2023, 3:30 p.m. UTC | #2
On Mon, Oct 02, 2023 at 02:30:08PM +0100, Jonathan Cameron wrote:
> On Thu, 14 Sep 2023 19:54:56 -0400
> Gregory Price <gourry.memverge@gmail.com> wrote:
> 
> > diff --git a/include/uapi/asm-generic/unistd.h b/include/uapi/asm-generic/unistd.h
> > index abe087c53b4b..397dcf804941 100644
> > --- a/include/uapi/asm-generic/unistd.h
> > +++ b/include/uapi/asm-generic/unistd.h
> > ...
> >  #undef __NR_syscalls
> > -#define __NR_syscalls 453
> > +#define __NR_syscalls 456
> +3 for 2 additions?
> 

When I'd originally written this, there was a partially merged syscall
colliding with 453, and this hadn't been incremented yet.  I did a quick
grep and it seems like that might have been reverted, so yeah, this would
drop down to 453/454 and __NR_syscalls = 455.

> > +	/* Legacy modes are routed through the legacy interface */
> > +	if (kargs->mode <= MPOL_LEGACY)
> > +		return false;
> > +
> > +	if (kargs->mode >= MPOL_MAX)
> > +		return false;
> > +
> > +	return true;
> 
> This is a range check, so I think equally clear (and shorter) as..
> 	/* Legacy modes are routed through the legacy interface */
> 	return kargs->mode > MPOL_LEGACY && kargs->mode < MPOL_MAX;
>

I'll combine the range, but I left the two true/false conditions
separate because it's intended that follow-on patches will add logic
before true is returned.

> > +		kargs->get.allowed.err = err ? err : 0;
> > +		kargs->err |= err ? err : 1;
> 		if (err) {
> 			kargs->get.allowed.err = err;
> 			kargs->err |= err;
> 		} else {
> 			kargs->get.allowed.err = 0;
> 			kargs->err = 1;
> 	Not particularly obvious why 1 and if you get an error later it's going to be messy
>         as will 1 |= err_code

My original intent was to just allow each section to error separately,
but honestly this seems overly complicated and somewhat against the
design of almost every other syscall, so I'm going to rip all these
error code spaces out and instead just have everything return on error.

Thanks!
Gregory
Gregory Price Oct. 2, 2023, 6:03 p.m. UTC | #3
On Mon, Oct 02, 2023 at 02:30:08PM +0100, Jonathan Cameron wrote:
> On Thu, 14 Sep 2023 19:54:56 -0400
> Gregory Price <gourry.memverge@gmail.com> wrote:
> 
> > diff --git a/arch/x86/entry/syscalls/syscall_64.tbl b/arch/x86/entry/syscalls/syscall_64.tbl
> > index 1d6eee30eceb..ec54064de8b3 100644
> > --- a/arch/x86/entry/syscalls/syscall_64.tbl
> > +++ b/arch/x86/entry/syscalls/syscall_64.tbl
> > @@ -375,6 +375,8 @@
> >  451	common	cachestat		sys_cachestat
> >  452	common	fchmodat2		sys_fchmodat2
> >  453	64	map_shadow_stack	sys_map_shadow_stack
> > +454	common	set_mempolicy2		sys_set_mempolicy2
> > +455	common	get_mempolicy2		sys_get_mempolicy2
> >  

^^ this is the discrepancy.  map_shadow_stack is at 453, so NR_syscalls
should already be 454, but map_shadow_stack has not been plumbed through
the rest of the kernel.

This needs to be addressed, but not in this RFC.

> >  #undef __NR_syscalls
> > -#define __NR_syscalls 453
> > +#define __NR_syscalls 456
> +3 for 2 additions?
> 

see above
diff mbox series

Patch

diff --git a/arch/x86/entry/syscalls/syscall_32.tbl b/arch/x86/entry/syscalls/syscall_32.tbl
index 2d0b1bd866ea..a72ef588a704 100644
--- a/arch/x86/entry/syscalls/syscall_32.tbl
+++ b/arch/x86/entry/syscalls/syscall_32.tbl
@@ -457,3 +457,5 @@ 
 450	i386	set_mempolicy_home_node		sys_set_mempolicy_home_node
 451	i386	cachestat		sys_cachestat
 452	i386	fchmodat2		sys_fchmodat2
+454	i386	set_mempolicy2		sys_set_mempolicy2
+455	i386	get_mempolicy2		sys_get_mempolicy2
diff --git a/arch/x86/entry/syscalls/syscall_64.tbl b/arch/x86/entry/syscalls/syscall_64.tbl
index 1d6eee30eceb..ec54064de8b3 100644
--- a/arch/x86/entry/syscalls/syscall_64.tbl
+++ b/arch/x86/entry/syscalls/syscall_64.tbl
@@ -375,6 +375,8 @@ 
 451	common	cachestat		sys_cachestat
 452	common	fchmodat2		sys_fchmodat2
 453	64	map_shadow_stack	sys_map_shadow_stack
+454	common	set_mempolicy2		sys_set_mempolicy2
+455	common	get_mempolicy2		sys_get_mempolicy2
 
 #
 # Due to a historical design error, certain syscalls are numbered differently
diff --git a/include/linux/syscalls.h b/include/linux/syscalls.h
index 22bc6bc147f8..d50a452954ae 100644
--- a/include/linux/syscalls.h
+++ b/include/linux/syscalls.h
@@ -813,6 +813,8 @@  asmlinkage long sys_get_mempolicy(int __user *policy,
 				unsigned long addr, unsigned long flags);
 asmlinkage long sys_set_mempolicy(int mode, const unsigned long __user *nmask,
 				unsigned long maxnode);
+asmlinkage long sys_get_mempolicy2(struct mempolicy_args __user *args);
+asmlinkage long sys_set_mempolicy2(struct mempolicy_args __user *args);
 asmlinkage long sys_migrate_pages(pid_t pid, unsigned long maxnode,
 				const unsigned long __user *from,
 				const unsigned long __user *to);
diff --git a/include/uapi/asm-generic/unistd.h b/include/uapi/asm-generic/unistd.h
index abe087c53b4b..397dcf804941 100644
--- a/include/uapi/asm-generic/unistd.h
+++ b/include/uapi/asm-generic/unistd.h
@@ -823,8 +823,16 @@  __SYSCALL(__NR_cachestat, sys_cachestat)
 #define __NR_fchmodat2 452
 __SYSCALL(__NR_fchmodat2, sys_fchmodat2)
 
+/* CONFIG_MMU only */
+#ifndef __ARCH_NOMMU
+#define __NR_set_mempolicy 454
+__SYSCALL(__NR_set_mempolicy2, sys_set_mempolicy2)
+#define __NR_set_mempolicy 455
+__SYSCALL(__NR_get_mempolicy2, sys_get_mempolicy2)
+#endif
+
 #undef __NR_syscalls
-#define __NR_syscalls 453
+#define __NR_syscalls 456
 
 /*
  * 32 bit systems traditionally used different
diff --git a/include/uapi/linux/mempolicy.h b/include/uapi/linux/mempolicy.h
index 046d0ccba4cd..53650f69db2b 100644
--- a/include/uapi/linux/mempolicy.h
+++ b/include/uapi/linux/mempolicy.h
@@ -23,9 +23,41 @@  enum {
 	MPOL_INTERLEAVE,
 	MPOL_LOCAL,
 	MPOL_PREFERRED_MANY,
+	MPOL_LEGACY,	/* set_mempolicy limited to above modes */
 	MPOL_MAX,	/* always last member of enum */
 };
 
+struct mempolicy_args {
+	int err;
+	unsigned short mode;
+	unsigned long *nodemask;
+	unsigned long maxnode;
+	unsigned short flags;
+	struct {
+		/* Memory allowed */
+		struct {
+			int err;
+			unsigned long maxnode;
+			unsigned long *nodemask;
+		} allowed;
+		/* Address information */
+		struct {
+			int err;
+			unsigned long addr;
+			unsigned long node;
+			unsigned short mode;
+			unsigned short flags;
+		} addr;
+		/* Interleave */
+	} get;
+	/* Mode specific settings */
+	union {
+		struct {
+			unsigned long next_node; /* get only */
+		} interleave;
+	};
+};
+
 /* Flags for set_mempolicy */
 #define MPOL_F_STATIC_NODES	(1 << 15)
 #define MPOL_F_RELATIVE_NODES	(1 << 14)
diff --git a/mm/mempolicy.c b/mm/mempolicy.c
index f49337f6f300..1cf7709400f1 100644
--- a/mm/mempolicy.c
+++ b/mm/mempolicy.c
@@ -1483,7 +1483,7 @@  static inline int sanitize_mpol_flags(int *mode, unsigned short *flags)
 	*flags = *mode & MPOL_MODE_FLAGS;
 	*mode &= ~MPOL_MODE_FLAGS;
 
-	if ((unsigned int)(*mode) >=  MPOL_MAX)
+	if ((unsigned int)(*mode) >= MPOL_LEGACY)
 		return -EINVAL;
 	if ((*flags & MPOL_F_STATIC_NODES) && (*flags & MPOL_F_RELATIVE_NODES))
 		return -EINVAL;
@@ -1614,6 +1614,219 @@  SYSCALL_DEFINE3(set_mempolicy, int, mode, const unsigned long __user *, nmask,
 	return kernel_set_mempolicy(mode, nmask, maxnode);
 }
 
+static long do_set_mempolicy2(struct mempolicy_args *args)
+{
+	struct mempolicy *new = NULL;
+	nodemask_t nodes;
+	int err;
+
+	if (args->mode <= MPOL_LEGACY)
+		return -EINVAL;
+
+	if (args->mode >= MPOL_MAX)
+		return -EINVAL;
+
+	err = get_nodes(&nodes, args->nodemask, args->maxnode);
+	if (err)
+		return err;
+
+	new = mpol_new(args->mode, args->flags, &nodes);
+	if (IS_ERR(new)) {
+		err = PTR_ERR(new);
+		goto out;
+	}
+
+	switch (args->mode) {
+	default:
+		BUG();
+	}
+
+	if (err)
+		goto out;
+
+	err = swap_mempolicy(new, &nodes);
+out:
+	if (err && new)
+		mpol_put(new);
+	return err;
+};
+
+static bool mempolicy2_args_valid(struct mempolicy_args *kargs)
+{
+	/* Legacy modes are routed through the legacy interface */
+	if (kargs->mode <= MPOL_LEGACY)
+		return false;
+
+	if (kargs->mode >= MPOL_MAX)
+		return false;
+
+	return true;
+}
+
+static long kernel_set_mempolicy2(const struct mempolicy_args __user *uargs,
+				  size_t usize)
+{
+	struct mempolicy_args kargs;
+	int err;
+
+	if (usize != sizeof(kargs))
+		return -EINVAL;
+
+	err = copy_struct_from_user(&kargs, sizeof(kargs), uargs, usize);
+	if (err)
+		return err;
+
+	/* If the mode is legacy, use the legacy path */
+	if (kargs.mode < MPOL_LEGACY) {
+		int legacy_mode = kargs.mode | kargs.flags;
+		const unsigned long __user *lnmask = kargs.nodemask;
+		unsigned long maxnode = kargs.maxnode;
+
+		return kernel_set_mempolicy(legacy_mode, lnmask, maxnode);
+	}
+
+	if (!mempolicy2_args_valid(&kargs))
+		return -EINVAL;
+
+	return do_set_mempolicy2(&kargs);
+}
+
+SYSCALL_DEFINE2(set_mempolicy2, const struct mempolicy_args __user *, args,
+		size_t, size)
+{
+	return kernel_set_mempolicy2(args, size);
+}
+
+/* Gets extended mempolicy information */
+static long do_get_mempolicy2(struct mempolicy_args *kargs)
+{
+	struct mempolicy *pol = current->mempolicy;
+	nodemask_t knodes;
+	int err = 0;
+
+	kargs->err = 0;
+	kargs->mode = pol->mode;
+	/* Mask off internal flags */
+	kargs->flags = (pol->flags & MPOL_MODE_FLAGS);
+
+	if (kargs->nodemask) {
+		if (mpol_store_user_nodemask(pol)) {
+			knodes = pol->w.user_nodemask;
+		} else {
+			task_lock(current);
+			get_policy_nodemask(pol, &knodes);
+			task_unlock(current);
+		}
+		err = copy_nodes_to_user(kargs->nodemask,
+					 kargs->maxnode,
+					 &knodes);
+		if (err)
+			return -EINVAL;
+	}
+
+
+	if (kargs->get.allowed.nodemask) {
+		kargs->get.allowed.err = 0;
+		task_lock(current);
+		knodes = cpuset_current_mems_allowed;
+		task_unlock(current);
+		err = copy_nodes_to_user(kargs->get.allowed.nodemask,
+					 kargs->get.allowed.maxnode,
+					 &knodes);
+		kargs->get.allowed.err = err ? err : 0;
+		kargs->err |= err ? err : 1;
+	}
+
+	if (kargs->get.addr.addr) {
+		struct mempolicy *addr_pol = NULL;
+		struct vm_area_struct *vma = NULL;
+		struct mm_struct *mm = current->mm;
+		unsigned long addr = kargs->get.addr.addr;
+
+		kargs->get.addr.err = 0;
+
+		/*
+		 * Do NOT fall back to task policy if the
+		 * vma/shared policy at addr is NULL.  We
+		 * want to return MPOL_DEFAULT in this case.
+		 */
+		mmap_read_lock(mm);
+		vma = vma_lookup(mm, addr);
+		if (!vma) {
+			mmap_read_unlock(mm);
+			kargs->get.addr.err = -EFAULT;
+			kargs->err |= err ? err : 2;
+			goto mode_info;
+		}
+		if (vma->vm_ops && vma->vm_ops->get_policy)
+			addr_pol = vma->vm_ops->get_policy(vma, addr);
+		else
+			addr_pol = vma->vm_policy;
+
+		kargs->get.addr.mode = addr_pol->mode;
+		/* Mask off internal flags */
+		kargs->get.addr.flags = (pol->flags & MPOL_MODE_FLAGS);
+
+		/*
+		 * Take a refcount on the mpol, because we are about to
+		 * drop the mmap_lock, after which only "pol" remains
+		 * valid, "vma" is stale.
+		 */
+		vma = NULL;
+		mpol_get(addr_pol);
+		mmap_read_unlock(mm);
+		err = lookup_node(mm, addr);
+		mpol_put(addr_pol);
+		if (err < 0) {
+			kargs->get.addr.err = err;
+			kargs->err |= err ? err : 4;
+			goto mode_info;
+		}
+		kargs->get.addr.node = err;
+	}
+
+mode_info:
+	switch (kargs->mode) {
+	case MPOL_INTERLEAVE:
+		kargs->interleave.next_node = next_node_in(current->il_prev,
+							   pol->nodes);
+		break;
+	default:
+		break;
+	}
+
+	return err;
+}
+
+static long kernel_get_mempolicy2(struct mempolicy_args __user *uargs,
+				  size_t usize)
+{
+	struct mempolicy_args kargs;
+	int err;
+
+	if (usize != sizeof(struct mempolicy_args))
+		return -EINVAL;
+
+	err = copy_struct_from_user(&kargs, sizeof(kargs), uargs, usize);
+	if (err)
+		return err;
+
+	/* Get the extended memory policy information (kargs.ext) */
+	err = do_get_mempolicy2(&kargs);
+	if (err)
+		return err;
+
+	err = copy_to_user(uargs, &kargs, sizeof(struct mempolicy_args));
+
+	return err;
+}
+
+SYSCALL_DEFINE2(get_mempolicy2, struct mempolicy_args __user *, policy,
+		size_t, size)
+{
+	return kernel_get_mempolicy2(policy, size);
+}
+
 static int kernel_migrate_pages(pid_t pid, unsigned long maxnode,
 				const unsigned long __user *old_nodes,
 				const unsigned long __user *new_nodes)