diff mbox series

[v2,1/2] mm/memfd: Refactor and cleanup the logic in memfd_create()

Message ID 20250107184804.4074147-2-isaacmanjarres@google.com (mailing list archive)
State New
Headers show
Series Cleanup for memfd_create() | expand

Commit Message

Isaac Manjarres Jan. 7, 2025, 6:48 p.m. UTC
memfd_create() is a pretty busy function that could be easier to read
if some of the logic was split out into helper functions.

Therefore, split the flags check, name creation, and file creation into
their own helper functions, and create the file structure before
creating the memfd. This allows for simplifying the error handling path
in memfd_create().

No functional change.

Signed-off-by: Isaac J. Manjarres <isaacmanjarres@google.com>
---
 mm/memfd.c | 87 +++++++++++++++++++++++++++++++++++-------------------
 1 file changed, 56 insertions(+), 31 deletions(-)

Comments

Alice Ryhl Jan. 8, 2025, 1:31 p.m. UTC | #1
On Tue, Jan 7, 2025 at 7:48 PM Isaac J. Manjarres
<isaacmanjarres@google.com> wrote:
> +SYSCALL_DEFINE2(memfd_create,
> +               const char __user *, uname,
> +               unsigned int, flags)
> +{
> +       struct file *file;
> +       int fd;
> +       char *name;
> +
> +       name = memfd_create_name(uname);
> +       if (IS_ERR(name))
> +               return PTR_ERR(name);
> +
> +       file = memfd_file_create(name, flags);
> +       /* name is not needed beyond this point. */
>         kfree(name);
> -       return error;
> +       if (IS_ERR(file))
> +               return PTR_ERR(file);
> +
> +       fd = get_unused_fd_flags((flags & MFD_CLOEXEC) ? O_CLOEXEC : 0);
> +       if (fd >= 0)
> +               fd_install(fd, file);
> +       else
> +               fput(file);

You changed the order so that get_unused_fd_flags() happens after
creating the file, so the error path now does fput(file) instead of
put_unused_fd(fd). Is there a reason for this? I would generally
assume that calling get_unused_fd_flags() first is better.

Otherwise this LGTM.


Alice
Lorenzo Stoakes Jan. 8, 2025, 6:30 p.m. UTC | #2
On Tue, Jan 07, 2025 at 10:48:01AM -0800, Isaac J. Manjarres wrote:
> memfd_create() is a pretty busy function that could be easier to read
> if some of the logic was split out into helper functions.
>
> Therefore, split the flags check, name creation, and file creation into
> their own helper functions, and create the file structure before
> creating the memfd. This allows for simplifying the error handling path
> in memfd_create().

I do like the intent of this change, but I think this needs some tweaking.

I wish the diff algorithm would do a little better here, because it's quite
hard to follow. That is in no way your fault :) Hopefully difftastic can
help me here...

>
> No functional change.
>
> Signed-off-by: Isaac J. Manjarres <isaacmanjarres@google.com>
> ---
>  mm/memfd.c | 87 +++++++++++++++++++++++++++++++++++-------------------
>  1 file changed, 56 insertions(+), 31 deletions(-)
>
> diff --git a/mm/memfd.c b/mm/memfd.c
> index 5f5a23c9051d..a9430090bb20 100644
> --- a/mm/memfd.c
> +++ b/mm/memfd.c
> @@ -369,16 +369,8 @@ int memfd_check_seals_mmap(struct file *file, unsigned long *vm_flags_ptr)
>  	return err;
>  }
>
> -SYSCALL_DEFINE2(memfd_create,
> -		const char __user *, uname,
> -		unsigned int, flags)
> +static int memfd_validate_flags(unsigned int flags)

For static functions the memfd_ prefix is redundant, please strip them. We
know we're in mm/memfd.c which is context enough for these internal
helpers!

>  {
> -	unsigned int *file_seals;
> -	struct file *file;
> -	int fd, error;
> -	char *name;
> -	long len;
> -
>  	if (!(flags & MFD_HUGETLB)) {
>  		if (flags & ~(unsigned int)MFD_ALL_FLAGS)
>  			return -EINVAL;
> @@ -393,20 +385,25 @@ SYSCALL_DEFINE2(memfd_create,
>  	if ((flags & MFD_EXEC) && (flags & MFD_NOEXEC_SEAL))
>  		return -EINVAL;
>
> -	error = check_sysctl_memfd_noexec(&flags);
> -	if (error < 0)
> -		return error;
> +	return check_sysctl_memfd_noexec(&flags);

More importantly - this is broken...

The check_sysctl_memfd_noexec() function is _changing_ flags, which you now
discard.

This also renders 'validate' in the name a little inaccurate (hey naming is
hard :), perhaps 'sanitise_flags()'?

Anyway you should pass flags as a pointer (even if that's yuck) and rename.

> +}
> +
> +static char *memfd_create_name(const char __user *uname)

Again, strip memfd_ prefix please.

Also I don't know what 'create' means here. Given the function you're
interacting with is memfd_create() it's rendered a little vague.

I'd say 'alloc_name()' would be better.

> +{
> +	int error;
> +	char *name;
> +	long len;
>
>  	/* length includes terminating zero */
>  	len = strnlen_user(uname, MFD_NAME_MAX_LEN + 1);
>  	if (len <= 0)
> -		return -EFAULT;
> +		return ERR_PTR(-EFAULT);

Not sure if this is necessary, I mean I guess technically... but feels like
it's adding a bunch of noise.

I know you refactor this whole thing in the next commit so maybe to reduce
the size of this commit you could drop these changes here and keep the bare
minimum before you change the function again?


>  	if (len > MFD_NAME_MAX_LEN + 1)
> -		return -EINVAL;
> +		return ERR_PTR(-EINVAL);
>
>  	name = kmalloc(len + MFD_NAME_PREFIX_LEN, GFP_KERNEL);
>  	if (!name)
> -		return -ENOMEM;
> +		return ERR_PTR(-ENOMEM);
>
>  	strcpy(name, MFD_NAME_PREFIX);
>  	if (copy_from_user(&name[MFD_NAME_PREFIX_LEN], uname, len)) {
> @@ -420,11 +417,22 @@ SYSCALL_DEFINE2(memfd_create,
>  		goto err_name;
>  	}

>
> -	fd = get_unused_fd_flags((flags & MFD_CLOEXEC) ? O_CLOEXEC : 0);
> -	if (fd < 0) {
> -		error = fd;
> -		goto err_name;
> -	}
> +	return name;
> +
> +err_name:
> +	kfree(name);
> +	return ERR_PTR(error);
> +}
> +
> +static struct file *memfd_file_create(const char *name, unsigned int flags)

I really am not a great fan of this name. The memfd_ prefix obviously has to
go, but 'file_create' is confusing when the actual system call is
'memfd_create'.

Again, naming eh? It is hard :)

alloc_file() probably works best as you are in fact allocating memory for
this.

Then, as mentioned below, restore the original ordering of fd assignment in
the syscall, do so before file allocation, and install the file into the fd
afterwards.

> +{
> +	unsigned int *file_seals;
> +	struct file *file;
> +	int error;
> +
> +	error = memfd_validate_flags(flags);
> +	if (error < 0)
> +		return ERR_PTR(error);

I'm not actually sure why you put this here, it seems quite
arbitrary. Let's invoke this from the syscall so we neatly divide the logic
into each part rather than dividing into different parts one of which is
invoked by another.

And obviously make sure the original ordering is restored.

>
>  	if (flags & MFD_HUGETLB) {
>  		file = hugetlb_file_setup(name, 0, VM_NORESERVE,
> @@ -433,10 +441,8 @@ SYSCALL_DEFINE2(memfd_create,
>  					MFD_HUGE_MASK);
>  	} else
>  		file = shmem_file_setup(name, 0, VM_NORESERVE);

I know this has nothing to do with you, and strictly speaking it is probably
in line with kernel code style, but this dangling else _kills_ me.

Could we put { } around it? Risking invoking the ire of the strict
adherents of the coding style perhaps here :P

> -	if (IS_ERR(file)) {
> -		error = PTR_ERR(file);
> -		goto err_fd;
> -	}
> +	if (IS_ERR(file))
> +		return file;
>  	file->f_mode |= FMODE_LSEEK | FMODE_PREAD | FMODE_PWRITE;
>  	file->f_flags |= O_LARGEFILE;
>
> @@ -456,13 +462,32 @@ SYSCALL_DEFINE2(memfd_create,
>  			*file_seals &= ~F_SEAL_SEAL;
>  	}
>
> -	fd_install(fd, file);
> -	kfree(name);
> -	return fd;
> +	return file;
> +}
>
> -err_fd:
> -	put_unused_fd(fd);
> -err_name:
> +SYSCALL_DEFINE2(memfd_create,
> +		const char __user *, uname,
> +		unsigned int, flags)
> +{
> +	struct file *file;
> +	int fd;
> +	char *name;
> +
> +	name = memfd_create_name(uname);
> +	if (IS_ERR(name))
> +		return PTR_ERR(name);

You're changing the ordering of checks, which is user-visible. Previously
the flags would be validated first; now you're 'creating' the name first
(not sure this is a great name - naming is hard, obviously, but this doesn't
really tell me what you intend here; I'll highlight that in the relevant bit
of code).

Please try to keep the order of checks the same and validate the flags first.

> +
> +	file = memfd_file_create(name, flags);
> +	/* name is not needed beyond this point. */

This is nice to highlight! Though it absolutely SUCKS to kmalloc() some
memory then instantly discard it like this. But I suppose we have no choice
here.

>  	kfree(name);
> -	return error;
> +	if (IS_ERR(file))
> +		return PTR_ERR(file);
> +
> +	fd = get_unused_fd_flags((flags & MFD_CLOEXEC) ? O_CLOEXEC : 0);
> +	if (fd >= 0)
> +		fd_install(fd, file);
> +	else
> +		fput(file);

You've changed the ordering of this again, and you're not doing anything to
free file if something goes wrong with fd allocation? That's a leak no?

Please reinstate the original ordering.

It's strange to open code this when we don't open code other things... but
perhaps for a few lines it's ok.

> +
> +	return fd;
>  }
> --
> 2.47.1.613.gc27f4b7a9f-goog
>
Isaac Manjarres Jan. 8, 2025, 6:40 p.m. UTC | #3
On Wed, Jan 08, 2025 at 02:31:58PM +0100, Alice Ryhl wrote:
> On Tue, Jan 7, 2025 at 7:48 PM Isaac J. Manjarres
> <isaacmanjarres@google.com> wrote:
> > +SYSCALL_DEFINE2(memfd_create,
> > +               const char __user *, uname,
> > +               unsigned int, flags)
> > +{
> > +       struct file *file;
> > +       int fd;
> > +       char *name;
> > +
> > +       name = memfd_create_name(uname);
> > +       if (IS_ERR(name))
> > +               return PTR_ERR(name);
> > +
> > +       file = memfd_file_create(name, flags);
> > +       /* name is not needed beyond this point. */
> >         kfree(name);
> > -       return error;
> > +       if (IS_ERR(file))
> > +               return PTR_ERR(file);
> > +
> > +       fd = get_unused_fd_flags((flags & MFD_CLOEXEC) ? O_CLOEXEC : 0);
> > +       if (fd >= 0)
> > +               fd_install(fd, file);
> > +       else
> > +               fput(file);
> 
> You changed the order so that get_unused_fd_flags() happens after
> creating the file, so the error path now does fput(file) instead of
> put_unused_fd(fd). Is there a reason for this? I would generally
> assume that calling get_unused_fd_flags() first is better.

Thanks for taking a look at this, Alice!

I changed the order so that the code had a more logical structure
where we create objects and use them right away, as opposed to getting
an fd, then creating the file to associate with that file descriptor,
and then actually associating it.

I also structured the code to get rid of the gotos in this function to
make it easier to follow. It also made sense to me to fold the flags
validation into memfd_file_create() since that's where the flags are
used the most anyway, and it also makes sense to validate the flags
first, so reordering the file creation and fd creation allowed me to
do that.

I'm open to restoring the order back to how it was though. Is there a
reason why calling get_unused_fd_flags() first is better?

Thanks,
Isaac
diff mbox series

Patch

diff --git a/mm/memfd.c b/mm/memfd.c
index 5f5a23c9051d..a9430090bb20 100644
--- a/mm/memfd.c
+++ b/mm/memfd.c
@@ -369,16 +369,8 @@  int memfd_check_seals_mmap(struct file *file, unsigned long *vm_flags_ptr)
 	return err;
 }
 
-SYSCALL_DEFINE2(memfd_create,
-		const char __user *, uname,
-		unsigned int, flags)
+static int memfd_validate_flags(unsigned int flags)
 {
-	unsigned int *file_seals;
-	struct file *file;
-	int fd, error;
-	char *name;
-	long len;
-
 	if (!(flags & MFD_HUGETLB)) {
 		if (flags & ~(unsigned int)MFD_ALL_FLAGS)
 			return -EINVAL;
@@ -393,20 +385,25 @@  SYSCALL_DEFINE2(memfd_create,
 	if ((flags & MFD_EXEC) && (flags & MFD_NOEXEC_SEAL))
 		return -EINVAL;
 
-	error = check_sysctl_memfd_noexec(&flags);
-	if (error < 0)
-		return error;
+	return check_sysctl_memfd_noexec(&flags);
+}
+
+static char *memfd_create_name(const char __user *uname)
+{
+	int error;
+	char *name;
+	long len;
 
 	/* length includes terminating zero */
 	len = strnlen_user(uname, MFD_NAME_MAX_LEN + 1);
 	if (len <= 0)
-		return -EFAULT;
+		return ERR_PTR(-EFAULT);
 	if (len > MFD_NAME_MAX_LEN + 1)
-		return -EINVAL;
+		return ERR_PTR(-EINVAL);
 
 	name = kmalloc(len + MFD_NAME_PREFIX_LEN, GFP_KERNEL);
 	if (!name)
-		return -ENOMEM;
+		return ERR_PTR(-ENOMEM);
 
 	strcpy(name, MFD_NAME_PREFIX);
 	if (copy_from_user(&name[MFD_NAME_PREFIX_LEN], uname, len)) {
@@ -420,11 +417,22 @@  SYSCALL_DEFINE2(memfd_create,
 		goto err_name;
 	}
 
-	fd = get_unused_fd_flags((flags & MFD_CLOEXEC) ? O_CLOEXEC : 0);
-	if (fd < 0) {
-		error = fd;
-		goto err_name;
-	}
+	return name;
+
+err_name:
+	kfree(name);
+	return ERR_PTR(error);
+}
+
+static struct file *memfd_file_create(const char *name, unsigned int flags)
+{
+	unsigned int *file_seals;
+	struct file *file;
+	int error;
+
+	error = memfd_validate_flags(flags);
+	if (error < 0)
+		return ERR_PTR(error);
 
 	if (flags & MFD_HUGETLB) {
 		file = hugetlb_file_setup(name, 0, VM_NORESERVE,
@@ -433,10 +441,8 @@  SYSCALL_DEFINE2(memfd_create,
 					MFD_HUGE_MASK);
 	} else
 		file = shmem_file_setup(name, 0, VM_NORESERVE);
-	if (IS_ERR(file)) {
-		error = PTR_ERR(file);
-		goto err_fd;
-	}
+	if (IS_ERR(file))
+		return file;
 	file->f_mode |= FMODE_LSEEK | FMODE_PREAD | FMODE_PWRITE;
 	file->f_flags |= O_LARGEFILE;
 
@@ -456,13 +462,32 @@  SYSCALL_DEFINE2(memfd_create,
 			*file_seals &= ~F_SEAL_SEAL;
 	}
 
-	fd_install(fd, file);
-	kfree(name);
-	return fd;
+	return file;
+}
 
-err_fd:
-	put_unused_fd(fd);
-err_name:
+SYSCALL_DEFINE2(memfd_create,
+		const char __user *, uname,
+		unsigned int, flags)
+{
+	struct file *file;
+	int fd;
+	char *name;
+
+	name = memfd_create_name(uname);
+	if (IS_ERR(name))
+		return PTR_ERR(name);
+
+	file = memfd_file_create(name, flags);
+	/* name is not needed beyond this point. */
 	kfree(name);
-	return error;
+	if (IS_ERR(file))
+		return PTR_ERR(file);
+
+	fd = get_unused_fd_flags((flags & MFD_CLOEXEC) ? O_CLOEXEC : 0);
+	if (fd >= 0)
+		fd_install(fd, file);
+	else
+		fput(file);
+
+	return fd;
 }