diff mbox series

[v1,2/4] mm/mempolicy: unify the preprocessing for mbind and set_mempolicy

Message ID 1622005302-23027-3-git-send-email-feng.tang@intel.com (mailing list archive)
State New, archived
Headers show
Series mm/mempolicy: some fix and semantics cleanup | expand

Commit Message

Feng Tang May 26, 2021, 5:01 a.m. UTC
Currently the kernel_mbind() and kernel_set_mempolicy() do almost
the same operation for parameter sanity check and preprocessing.

Add a helper function to unify the code to reduce the redundancy,
and make it easier for changing the pre-processing code in future.

[thanks to David Rientjes for suggesting using helper function
instead of macro]

Signed-off-by: Feng Tang <feng.tang@intel.com>
---
 mm/mempolicy.c | 43 ++++++++++++++++++++++++-------------------
 1 file changed, 24 insertions(+), 19 deletions(-)

Comments

Michal Hocko May 27, 2021, 7:39 a.m. UTC | #1
On Wed 26-05-21 13:01:40, Feng Tang wrote:
> Currently the kernel_mbind() and kernel_set_mempolicy() do almost
> the same operation for parameter sanity check and preprocessing.
> 
> Add a helper function to unify the code to reduce the redundancy,
> and make it easier for changing the pre-processing code in future.
> 
> [thanks to David Rientjes for suggesting using helper function
> instead of macro]

I appreciate removing the code duplication but I am not really convinced
this is an improvement. You are conflating two things. One is the mpol
flags checking and node mask copying. While abstracting the first one
makes sense to me the later is already a single line of code that makes
your helper unnecessarily complex. So I would go with sanitize_mpol_flags
and put a flags handling there and leave get_nodes alone.
 
> Signed-off-by: Feng Tang <feng.tang@intel.com>
> ---
>  mm/mempolicy.c | 43 ++++++++++++++++++++++++-------------------
>  1 file changed, 24 insertions(+), 19 deletions(-)

Funny how removing code duplication adds more code than it removes ;)

> 
> diff --git a/mm/mempolicy.c b/mm/mempolicy.c
> index 1964cca..2830bb8 100644
> --- a/mm/mempolicy.c
> +++ b/mm/mempolicy.c
> @@ -1460,6 +1460,20 @@ static int copy_nodes_to_user(unsigned long __user *mask, unsigned long maxnode,
>  	return copy_to_user(mask, nodes_addr(*nodes), copy) ? -EFAULT : 0;
>  }
>  
> +static inline int mpol_pre_process(int *mode, const unsigned long __user *nmask, unsigned long maxnode, nodemask_t *nodes, unsigned short *flags)
> +{
> +	int ret;
> +
> +	*flags = *mode & MPOL_MODE_FLAGS;
> +	*mode &= ~MPOL_MODE_FLAGS;
> +	if ((unsigned int)(*mode) >= MPOL_MAX)
> +		return -EINVAL;
> +	if ((*flags & MPOL_F_STATIC_NODES) && (*flags & MPOL_F_RELATIVE_NODES))
> +		return -EINVAL;
> +	ret = get_nodes(nodes, nmask, maxnode);
> +	return ret;
> +}
> +
>  static long kernel_mbind(unsigned long start, unsigned long len,
>  			 unsigned long mode, const unsigned long __user *nmask,
>  			 unsigned long maxnode, unsigned int flags)
> @@ -1467,19 +1481,14 @@ static long kernel_mbind(unsigned long start, unsigned long len,
>  	nodemask_t nodes;
>  	int err;
>  	unsigned short mode_flags;
> +	int lmode = mode;
>  
> -	start = untagged_addr(start);
> -	mode_flags = mode & MPOL_MODE_FLAGS;
> -	mode &= ~MPOL_MODE_FLAGS;
> -	if (mode >= MPOL_MAX)
> -		return -EINVAL;
> -	if ((mode_flags & MPOL_F_STATIC_NODES) &&
> -	    (mode_flags & MPOL_F_RELATIVE_NODES))
> -		return -EINVAL;
> -	err = get_nodes(&nodes, nmask, maxnode);
> +	err = mpol_pre_process(&lmode, nmask, maxnode, &nodes, &mode_flags);
>  	if (err)
>  		return err;
> -	return do_mbind(start, len, mode, mode_flags, &nodes, flags);
> +
> +	start = untagged_addr(start);
> +	return do_mbind(start, len, lmode, mode_flags, &nodes, flags);
>  }
>  
>  SYSCALL_DEFINE6(mbind, unsigned long, start, unsigned long, len,
> @@ -1495,18 +1504,14 @@ static long kernel_set_mempolicy(int mode, const unsigned long __user *nmask,
>  {
>  	int err;
>  	nodemask_t nodes;
> -	unsigned short flags;
> +	unsigned short mode_flags;
> +	int lmode = mode;
>  
> -	flags = mode & MPOL_MODE_FLAGS;
> -	mode &= ~MPOL_MODE_FLAGS;
> -	if ((unsigned int)mode >= MPOL_MAX)
> -		return -EINVAL;
> -	if ((flags & MPOL_F_STATIC_NODES) && (flags & MPOL_F_RELATIVE_NODES))
> -		return -EINVAL;
> -	err = get_nodes(&nodes, nmask, maxnode);
> +	err = mpol_pre_process(&lmode, nmask, maxnode, &nodes, &mode_flags);
>  	if (err)
>  		return err;
> -	return do_set_mempolicy(mode, flags, &nodes);
> +
> +	return do_set_mempolicy(lmode, mode_flags, &nodes);
>  }
>  
>  SYSCALL_DEFINE3(set_mempolicy, int, mode, const unsigned long __user *, nmask,
> -- 
> 2.7.4
Feng Tang May 27, 2021, 12:31 p.m. UTC | #2
On Thu, May 27, 2021 at 09:39:49AM +0200, Michal Hocko wrote:
> On Wed 26-05-21 13:01:40, Feng Tang wrote:
> > Currently the kernel_mbind() and kernel_set_mempolicy() do almost
> > the same operation for parameter sanity check and preprocessing.
> > 
> > Add a helper function to unify the code to reduce the redundancy,
> > and make it easier for changing the pre-processing code in future.
> > 
> > [thanks to David Rientjes for suggesting using helper function
> > instead of macro]
> 
> I appreciate removing the code duplication but I am not really convinced
> this is an improvement. You are conflating two things. One is the mpol
> flags checking and node mask copying. While abstracting the first one
> makes sense to me the later is already a single line of code that makes
> your helper unnecessarily complex. So I would go with sanitize_mpol_flags
> and put a flags handling there and leave get_nodes alone.
>  
> > Signed-off-by: Feng Tang <feng.tang@intel.com>
> > ---
> >  mm/mempolicy.c | 43 ++++++++++++++++++++++++-------------------
> >  1 file changed, 24 insertions(+), 19 deletions(-)
> 
> Funny how removing code duplication adds more code than it removes ;)

Yes.

And in last verion which uses macro to unify the code: 
https://lore.kernel.org/lkml/1621499404-67756-3-git-send-email-feng.tang@intel.com/
it does save some lines :)

 mm/mempolicy.c | 46 +++++++++++++++++++---------------------------
 1 file changed, 19 insertions(+), 27 deletions(-)

Thanks,
Feng

> > 
> > diff --git a/mm/mempolicy.c b/mm/mempolicy.c
> > index 1964cca..2830bb8 100644
> > --- a/mm/mempolicy.c
> > +++ b/mm/mempolicy.c
> > @@ -1460,6 +1460,20 @@ static int copy_nodes_to_user(unsigned long __user *mask, unsigned long maxnode,
> >  	return copy_to_user(mask, nodes_addr(*nodes), copy) ? -EFAULT : 0;
> >  }
> >  
> > +static inline int mpol_pre_process(int *mode, const unsigned long __user *nmask, unsigned long maxnode, nodemask_t *nodes, unsigned short *flags)
> > +{
> > +	int ret;
> > +
> > +	*flags = *mode & MPOL_MODE_FLAGS;
> > +	*mode &= ~MPOL_MODE_FLAGS;
> > +	if ((unsigned int)(*mode) >= MPOL_MAX)
> > +		return -EINVAL;
> > +	if ((*flags & MPOL_F_STATIC_NODES) && (*flags & MPOL_F_RELATIVE_NODES))
> > +		return -EINVAL;
> > +	ret = get_nodes(nodes, nmask, maxnode);
> > +	return ret;
> > +}
> > +
> >  static long kernel_mbind(unsigned long start, unsigned long len,
> >  			 unsigned long mode, const unsigned long __user *nmask,
> >  			 unsigned long maxnode, unsigned int flags)
> > @@ -1467,19 +1481,14 @@ static long kernel_mbind(unsigned long start, unsigned long len,
> >  	nodemask_t nodes;
> >  	int err;
> >  	unsigned short mode_flags;
> > +	int lmode = mode;
> >  
> > -	start = untagged_addr(start);
> > -	mode_flags = mode & MPOL_MODE_FLAGS;
> > -	mode &= ~MPOL_MODE_FLAGS;
> > -	if (mode >= MPOL_MAX)
> > -		return -EINVAL;
> > -	if ((mode_flags & MPOL_F_STATIC_NODES) &&
> > -	    (mode_flags & MPOL_F_RELATIVE_NODES))
> > -		return -EINVAL;
> > -	err = get_nodes(&nodes, nmask, maxnode);
> > +	err = mpol_pre_process(&lmode, nmask, maxnode, &nodes, &mode_flags);
> >  	if (err)
> >  		return err;
> > -	return do_mbind(start, len, mode, mode_flags, &nodes, flags);
> > +
> > +	start = untagged_addr(start);
> > +	return do_mbind(start, len, lmode, mode_flags, &nodes, flags);
> >  }
> >  
> >  SYSCALL_DEFINE6(mbind, unsigned long, start, unsigned long, len,
> > @@ -1495,18 +1504,14 @@ static long kernel_set_mempolicy(int mode, const unsigned long __user *nmask,
> >  {
> >  	int err;
> >  	nodemask_t nodes;
> > -	unsigned short flags;
> > +	unsigned short mode_flags;
> > +	int lmode = mode;
> >  
> > -	flags = mode & MPOL_MODE_FLAGS;
> > -	mode &= ~MPOL_MODE_FLAGS;
> > -	if ((unsigned int)mode >= MPOL_MAX)
> > -		return -EINVAL;
> > -	if ((flags & MPOL_F_STATIC_NODES) && (flags & MPOL_F_RELATIVE_NODES))
> > -		return -EINVAL;
> > -	err = get_nodes(&nodes, nmask, maxnode);
> > +	err = mpol_pre_process(&lmode, nmask, maxnode, &nodes, &mode_flags);
> >  	if (err)
> >  		return err;
> > -	return do_set_mempolicy(mode, flags, &nodes);
> > +
> > +	return do_set_mempolicy(lmode, mode_flags, &nodes);
> >  }
> >  
> >  SYSCALL_DEFINE3(set_mempolicy, int, mode, const unsigned long __user *, nmask,
> > -- 
> > 2.7.4
> 
> -- 
> Michal Hocko
> SUSE Labs
diff mbox series

Patch

diff --git a/mm/mempolicy.c b/mm/mempolicy.c
index 1964cca..2830bb8 100644
--- a/mm/mempolicy.c
+++ b/mm/mempolicy.c
@@ -1460,6 +1460,20 @@  static int copy_nodes_to_user(unsigned long __user *mask, unsigned long maxnode,
 	return copy_to_user(mask, nodes_addr(*nodes), copy) ? -EFAULT : 0;
 }
 
+static inline int mpol_pre_process(int *mode, const unsigned long __user *nmask, unsigned long maxnode, nodemask_t *nodes, unsigned short *flags)
+{
+	int ret;
+
+	*flags = *mode & MPOL_MODE_FLAGS;
+	*mode &= ~MPOL_MODE_FLAGS;
+	if ((unsigned int)(*mode) >= MPOL_MAX)
+		return -EINVAL;
+	if ((*flags & MPOL_F_STATIC_NODES) && (*flags & MPOL_F_RELATIVE_NODES))
+		return -EINVAL;
+	ret = get_nodes(nodes, nmask, maxnode);
+	return ret;
+}
+
 static long kernel_mbind(unsigned long start, unsigned long len,
 			 unsigned long mode, const unsigned long __user *nmask,
 			 unsigned long maxnode, unsigned int flags)
@@ -1467,19 +1481,14 @@  static long kernel_mbind(unsigned long start, unsigned long len,
 	nodemask_t nodes;
 	int err;
 	unsigned short mode_flags;
+	int lmode = mode;
 
-	start = untagged_addr(start);
-	mode_flags = mode & MPOL_MODE_FLAGS;
-	mode &= ~MPOL_MODE_FLAGS;
-	if (mode >= MPOL_MAX)
-		return -EINVAL;
-	if ((mode_flags & MPOL_F_STATIC_NODES) &&
-	    (mode_flags & MPOL_F_RELATIVE_NODES))
-		return -EINVAL;
-	err = get_nodes(&nodes, nmask, maxnode);
+	err = mpol_pre_process(&lmode, nmask, maxnode, &nodes, &mode_flags);
 	if (err)
 		return err;
-	return do_mbind(start, len, mode, mode_flags, &nodes, flags);
+
+	start = untagged_addr(start);
+	return do_mbind(start, len, lmode, mode_flags, &nodes, flags);
 }
 
 SYSCALL_DEFINE6(mbind, unsigned long, start, unsigned long, len,
@@ -1495,18 +1504,14 @@  static long kernel_set_mempolicy(int mode, const unsigned long __user *nmask,
 {
 	int err;
 	nodemask_t nodes;
-	unsigned short flags;
+	unsigned short mode_flags;
+	int lmode = mode;
 
-	flags = mode & MPOL_MODE_FLAGS;
-	mode &= ~MPOL_MODE_FLAGS;
-	if ((unsigned int)mode >= MPOL_MAX)
-		return -EINVAL;
-	if ((flags & MPOL_F_STATIC_NODES) && (flags & MPOL_F_RELATIVE_NODES))
-		return -EINVAL;
-	err = get_nodes(&nodes, nmask, maxnode);
+	err = mpol_pre_process(&lmode, nmask, maxnode, &nodes, &mode_flags);
 	if (err)
 		return err;
-	return do_set_mempolicy(mode, flags, &nodes);
+
+	return do_set_mempolicy(lmode, mode_flags, &nodes);
 }
 
 SYSCALL_DEFINE3(set_mempolicy, int, mode, const unsigned long __user *, nmask,