diff mbox series

[1/2] net: add __sys_socket_file()

Message ID 20220412202240.234207-2-axboe@kernel.dk (mailing list archive)
State Not Applicable
Delegated to: Netdev Maintainers
Headers show
Series Add io_uring socket(2) support | expand

Checks

Context Check Description
netdev/tree_selection success Guessing tree name failed - patch did not apply, async

Commit Message

Jens Axboe April 12, 2022, 8:22 p.m. UTC
This works like __sys_socket(), except instead of allocating and
returning a socket fd, it just returns the file associated with the
socket. No fd is installed into the process file table.

This is similar to do_accept(), and allows io_uring to use this without
instantiating a file descriptor in the process file table.

Signed-off-by: Jens Axboe <axboe@kernel.dk>
---
 include/linux/socket.h |  1 +
 net/socket.c           | 52 ++++++++++++++++++++++++++++++++++--------
 2 files changed, 43 insertions(+), 10 deletions(-)

Comments

David Miller April 13, 2022, 10:59 a.m. UTC | #1
From: Jens Axboe <axboe@kernel.dk>
Date: Tue, 12 Apr 2022 14:22:39 -0600

> This works like __sys_socket(), except instead of allocating and
> returning a socket fd, it just returns the file associated with the
> socket. No fd is installed into the process file table.
> 
> This is similar to do_accept(), and allows io_uring to use this without
> instantiating a file descriptor in the process file table.
> 
> Signed-off-by: Jens Axboe <axboe@kernel.dk>

Acked-by: David S. Miller <davem@davemloft.net>
diff mbox series

Patch

diff --git a/include/linux/socket.h b/include/linux/socket.h
index 6f85f5d957ef..a1882e1e71d2 100644
--- a/include/linux/socket.h
+++ b/include/linux/socket.h
@@ -434,6 +434,7 @@  extern struct file *do_accept(struct file *file, unsigned file_flags,
 extern int __sys_accept4(int fd, struct sockaddr __user *upeer_sockaddr,
 			 int __user *upeer_addrlen, int flags);
 extern int __sys_socket(int family, int type, int protocol);
+extern struct file *__sys_socket_file(int family, int type, int protocol);
 extern int __sys_bind(int fd, struct sockaddr __user *umyaddr, int addrlen);
 extern int __sys_connect_file(struct file *file, struct sockaddr_storage *addr,
 			      int addrlen, int file_flags);
diff --git a/net/socket.c b/net/socket.c
index 6887840682bb..bb6a1a12fbde 100644
--- a/net/socket.c
+++ b/net/socket.c
@@ -504,7 +504,7 @@  static int sock_map_fd(struct socket *sock, int flags)
 struct socket *sock_from_file(struct file *file)
 {
 	if (file->f_op == &socket_file_ops)
-		return file->private_data;	/* set in sock_map_fd */
+		return file->private_data;	/* set in sock_alloc_file */
 
 	return NULL;
 }
@@ -1538,11 +1538,10 @@  int sock_create_kern(struct net *net, int family, int type, int protocol, struct
 }
 EXPORT_SYMBOL(sock_create_kern);
 
-int __sys_socket(int family, int type, int protocol)
+static struct socket *__sys_socket_create(int family, int type, int protocol)
 {
-	int retval;
 	struct socket *sock;
-	int flags;
+	int retval;
 
 	/* Check the SOCK_* constants for consistency.  */
 	BUILD_BUG_ON(SOCK_CLOEXEC != O_CLOEXEC);
@@ -1550,17 +1549,50 @@  int __sys_socket(int family, int type, int protocol)
 	BUILD_BUG_ON(SOCK_CLOEXEC & SOCK_TYPE_MASK);
 	BUILD_BUG_ON(SOCK_NONBLOCK & SOCK_TYPE_MASK);
 
-	flags = type & ~SOCK_TYPE_MASK;
-	if (flags & ~(SOCK_CLOEXEC | SOCK_NONBLOCK))
-		return -EINVAL;
+	if ((type & ~SOCK_TYPE_MASK) & ~(SOCK_CLOEXEC | SOCK_NONBLOCK))
+		return ERR_PTR(-EINVAL);
 	type &= SOCK_TYPE_MASK;
 
+	retval = sock_create(family, type, protocol, &sock);
+	if (retval < 0)
+		return ERR_PTR(retval);
+
+	return sock;
+}
+
+struct file *__sys_socket_file(int family, int type, int protocol)
+{
+	struct socket *sock;
+	struct file *file;
+	int flags;
+
+	sock = __sys_socket_create(family, type, protocol);
+	if (IS_ERR(sock))
+		return ERR_CAST(sock);
+
+	flags = type & ~SOCK_TYPE_MASK;
 	if (SOCK_NONBLOCK != O_NONBLOCK && (flags & SOCK_NONBLOCK))
 		flags = (flags & ~SOCK_NONBLOCK) | O_NONBLOCK;
 
-	retval = sock_create(family, type, protocol, &sock);
-	if (retval < 0)
-		return retval;
+	file = sock_alloc_file(sock, flags, NULL);
+	if (IS_ERR(file))
+		sock_release(sock);
+
+	return file;
+}
+
+int __sys_socket(int family, int type, int protocol)
+{
+	struct socket *sock;
+	int flags;
+
+	sock = __sys_socket_create(family, type, protocol);
+	if (IS_ERR(sock))
+		return PTR_ERR(sock);
+
+	flags = type & ~SOCK_TYPE_MASK;
+	if (SOCK_NONBLOCK != O_NONBLOCK && (flags & SOCK_NONBLOCK))
+		flags = (flags & ~SOCK_NONBLOCK) | O_NONBLOCK;
 
 	return sock_map_fd(sock, flags & (O_CLOEXEC | O_NONBLOCK));
 }