[v3,net-next,11/15] socket: Remove kernel socket conversion.

Message ID 20241213092152.14057-12-kuniyu@amazon.com (mailing list archive)
State Changes Requested
Delegated to: Netdev Maintainers
Series treewide: socket: Clean up sock_create() and friends.

Checks

Context Check Description
netdev/series_format success Posting correctly formatted
netdev/tree_selection success Clearly marked for net-next, async
netdev/ynl success Generated files up to date; no warnings/errors; no diff in generated;
netdev/fixes_present success Fixes tag not required for -next series
netdev/header_inline success No static functions without inline keyword in header files
netdev/build_32bit success Errors and warnings before: 0 this patch: 0
netdev/build_tools success No tools touched, skip
netdev/cc_maintainers warning 22 maintainers not CCed: martineau@kernel.org samba-technical@lists.samba.org bharathsm@microsoft.com trondmy@kernel.org guwen@linux.alibaba.com tom@talpey.com linux-rdma@vger.kernel.org mptcp@lists.linux.dev linux-cifs@vger.kernel.org linux-s390@vger.kernel.org ronniesahlberg@gmail.com pc@manguebit.com linux-nfs@vger.kernel.org anna@kernel.org alibuda@linux.alibaba.com sprasad@microsoft.com geliang@kernel.org tonylu@linux.alibaba.com okorniev@redhat.com rds-devel@oss.oracle.com neilb@suse.de Dai.Ngo@oracle.com
netdev/build_clang success Errors and warnings before: 0 this patch: 0
netdev/verify_signedoff success Signed-off-by tag matches author and committer
netdev/deprecated_api warning Found: 'put_net(' was: 0 now: 3
netdev/check_selftest success No net selftest shell script
netdev/verify_fixes success No Fixes tag
netdev/build_allmodconfig_warn success Errors and warnings before: 3 this patch: 3
netdev/checkpatch warning WARNING: line length of 85 exceeds 80 columns WARNING: line length of 86 exceeds 80 columns
netdev/build_clang_rust success No Rust files in patch. Skipping build
netdev/kdoc success Errors and warnings before: 19 this patch: 19
netdev/source_inline success Was 0 now: 0
netdev/contest success net-next-2024-12-15--09-00 (tests: 795)

Commit Message

Kuniyuki Iwashima Dec. 13, 2024, 9:21 a.m. UTC
Since commit 26abe14379f8 ("net: Modify sk_alloc to not reference count
the netns of kernel sockets."), TCP kernel socket has caused many UAF.

We have converted such sockets to hold netns refcnt, and we have the
same pattern in cifs, mptcp, rds, smc, and sunrpc.

Let's drop the conversion and use sock_create_net() instead.

The changes for cifs, mptcp, and smc are straightforward.

For rds, we need to move maybe_get_net() before sock_create_net() and
sock->ops->accept().

For sunrpc, we call sock_create_net() for IPPROTO_TCP only and still
call sock_create_kern() for others.

Signed-off-by: Kuniyuki Iwashima <kuniyu@amazon.com>
Acked-by: Matthieu Baerts (NGI0) <matttbe@kernel.org>
Acked-by: Allison Henderson <allison.henderson@oracle.com>
---
v3: Add missing mutex_unlock in rds_tcp_conn_path_connect().
v2: Collect Acked-by from MPTCP and RDS maintainers

Cc: Steve French <sfrench@samba.org>
Cc: Wenjia Zhang <wenjia@linux.ibm.com>
Cc: Jan Karcher <jaka@linux.ibm.com>
Cc: Chuck Lever <chuck.lever@oracle.com>
Cc: Jeff Layton <jlayton@kernel.org>
---
 fs/smb/client/connect.c | 13 ++-----------
 net/mptcp/subflow.c     | 10 +---------
 net/rds/tcp.c           | 14 --------------
 net/rds/tcp_connect.c   | 21 +++++++++++++++------
 net/rds/tcp_listen.c    | 14 ++++++++++++--
 net/smc/af_smc.c        | 21 ++-------------------
 net/sunrpc/svcsock.c    | 12 ++++++------
 net/sunrpc/xprtsock.c   | 12 ++++--------
 8 files changed, 42 insertions(+), 75 deletions(-)
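
The rds ordering described in the commit message (take a temporary netns reference with maybe_get_net(), create the socket, then put_net()) can be sketched in userspace roughly as below. This is only a toy model, not kernel code: struct toy_net, toy_maybe_get_net(), toy_put_net() and toy_connect() are hypothetical names, the counter is a plain int rather than an atomic refcount, and the assumed rationale is that socket creation must not take a reference on a namespace whose count has already dropped to zero.

```c
#include <assert.h>
#include <stdbool.h>
#include <stddef.h>

/* Toy stand-in for struct net; only the reference count is modelled. */
struct toy_net {
	int refcnt;
};

/* Models maybe_get_net(): take a reference only if the namespace is
 * still alive (count not yet zero). Single-threaded sketch; the kernel
 * does this atomically.
 */
static struct toy_net *toy_maybe_get_net(struct toy_net *net)
{
	if (net->refcnt == 0)
		return NULL;
	net->refcnt++;
	return net;
}

/* Models put_net(): drop one reference. */
static void toy_put_net(struct toy_net *net)
{
	assert(net->refcnt > 0);
	net->refcnt--;
}

/* Hypothetical connect path following the ordering in the rds hunks:
 * temporary reference, socket creation (which takes the socket's own
 * reference), then release of the temporary reference.
 * Returns 0 on success, -1 if the namespace is already going away.
 */
static int toy_connect(struct toy_net *net, bool *sock_holds_ref)
{
	*sock_holds_ref = false;

	if (!toy_maybe_get_net(net))
		return -1;

	/* Stands in for socket creation that holds a netns reference. */
	net->refcnt++;
	*sock_holds_ref = true;

	/* The socket now pins the namespace; drop the temporary ref. */
	toy_put_net(net);
	return 0;
}
```

On a live namespace the call leaves exactly one extra reference (the socket's own); on a namespace whose count is already zero it fails without touching the count.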

Comments

Wenjia Zhang Dec. 13, 2024, 1:45 p.m. UTC | #1
On 13.12.24 10:21, Kuniyuki Iwashima wrote:
> Since commit 26abe14379f8 ("net: Modify sk_alloc to not reference count
> the netns of kernel sockets."), TCP kernel socket has caused many UAF.
> 
> We have converted such sockets to hold netns refcnt, and we have the
> same pattern in cifs, mptcp, rds, smc, and sunrpc.
> 
> Let's drop the conversion and use sock_create_net() instead.
> 
> The changes for cifs, mptcp, and smc are straightforward.
> 
> For rds, we need to move maybe_get_net() before sock_create_net() and
> sock->ops->accept().
> 
> For sunrpc, we call sock_create_net() for IPPROTO_TCP only and still
> call sock_create_kern() for others.
> 
> Signed-off-by: Kuniyuki Iwashima <kuniyu@amazon.com>
> Acked-by: Matthieu Baerts (NGI0) <matttbe@kernel.org>
> Acked-by: Allison Henderson <allison.henderson@oracle.com>
> ---
> v3: Add missing mutex_unlock in rds_tcp_conn_path_connect().
> v2: Collect Acked-by from MPTCP and RDS maintainers
> 
> Cc: Steve French <sfrench@samba.org>
> Cc: Wenjia Zhang <wenjia@linux.ibm.com>
> Cc: Jan Karcher <jaka@linux.ibm.com>
> Cc: Chuck Lever <chuck.lever@oracle.com>
> Cc: Jeff Layton <jlayton@kernel.org>
> ---
>   fs/smb/client/connect.c | 13 ++-----------
>   net/mptcp/subflow.c     | 10 +---------
>   net/rds/tcp.c           | 14 --------------
>   net/rds/tcp_connect.c   | 21 +++++++++++++++------
>   net/rds/tcp_listen.c    | 14 ++++++++++++--
>   net/smc/af_smc.c        | 21 ++-------------------
>   net/sunrpc/svcsock.c    | 12 ++++++------
>   net/sunrpc/xprtsock.c   | 12 ++++--------
>   8 files changed, 42 insertions(+), 75 deletions(-)
> 
> diff --git a/fs/smb/client/connect.c b/fs/smb/client/connect.c
> index c36c1b4ffe6e..7a67b86c0423 100644
> --- a/fs/smb/client/connect.c
> +++ b/fs/smb/client/connect.c
> @@ -3130,22 +3130,13 @@ generic_ip_connect(struct TCP_Server_Info *server)
>   	if (server->ssocket) {
>   		socket = server->ssocket;
>   	} else {
> -		struct net *net = cifs_net_ns(server);
> -		struct sock *sk;
> -
> -		rc = sock_create_kern(net, sfamily, SOCK_STREAM,
> -				      IPPROTO_TCP, &server->ssocket);
> +		rc = sock_create_net(cifs_net_ns(server), sfamily, SOCK_STREAM,
> +				     IPPROTO_TCP, &server->ssocket);
>   		if (rc < 0) {
>   			cifs_server_dbg(VFS, "Error %d creating socket\n", rc);
>   			return rc;
>   		}
>   
> -		sk = server->ssocket->sk;
> -		__netns_tracker_free(net, &sk->ns_tracker, false);
> -		sk->sk_net_refcnt = 1;
> -		get_net_track(net, &sk->ns_tracker, GFP_KERNEL);
> -		sock_inuse_add(net, 1);
> -
>   		/* BB other socket options to set KEEPALIVE, NODELAY? */
>   		cifs_dbg(FYI, "Socket created\n");
>   		socket = server->ssocket;
> diff --git a/net/mptcp/subflow.c b/net/mptcp/subflow.c
> index fd021cf8286e..e7e8972bdfca 100644
> --- a/net/mptcp/subflow.c
> +++ b/net/mptcp/subflow.c
> @@ -1755,7 +1755,7 @@ int mptcp_subflow_create_socket(struct sock *sk, unsigned short family,
>   	if (unlikely(!sk->sk_socket))
>   		return -EINVAL;
>   
> -	err = sock_create_kern(net, family, SOCK_STREAM, IPPROTO_TCP, &sf);
> +	err = sock_create_net(net, family, SOCK_STREAM, IPPROTO_TCP, &sf);
>   	if (err)
>   		return err;
>   
> @@ -1768,14 +1768,6 @@ int mptcp_subflow_create_socket(struct sock *sk, unsigned short family,
>   	/* the newly created socket has to be in the same cgroup as its parent */
>   	mptcp_attach_cgroup(sk, sf->sk);
>   
> -	/* kernel sockets do not by default acquire net ref, but TCP timer
> -	 * needs it.
> -	 * Update ns_tracker to current stack trace and refcounted tracker.
> -	 */
> -	__netns_tracker_free(net, &sf->sk->ns_tracker, false);
> -	sf->sk->sk_net_refcnt = 1;
> -	get_net_track(net, &sf->sk->ns_tracker, GFP_KERNEL);
> -	sock_inuse_add(net, 1);
>   	err = tcp_set_ulp(sf->sk, "mptcp");
>   	if (err)
>   		goto err_free;
> diff --git a/net/rds/tcp.c b/net/rds/tcp.c
> index 351ac1747224..4509900476f7 100644
> --- a/net/rds/tcp.c
> +++ b/net/rds/tcp.c
> @@ -494,21 +494,7 @@ bool rds_tcp_tune(struct socket *sock)
>   
>   	tcp_sock_set_nodelay(sock->sk);
>   	lock_sock(sk);
> -	/* TCP timer functions might access net namespace even after
> -	 * a process which created this net namespace terminated.
> -	 */
> -	if (!sk->sk_net_refcnt) {
> -		if (!maybe_get_net(net)) {
> -			release_sock(sk);
> -			return false;
> -		}
> -		/* Update ns_tracker to current stack trace and refcounted tracker */
> -		__netns_tracker_free(net, &sk->ns_tracker, false);
>   
> -		sk->sk_net_refcnt = 1;
> -		netns_tracker_alloc(net, &sk->ns_tracker, GFP_KERNEL);
> -		sock_inuse_add(net, 1);
> -	}
>   	rtn = net_generic(net, rds_tcp_netid);
>   	if (rtn->sndbuf_size > 0) {
>   		sk->sk_sndbuf = rtn->sndbuf_size;
> diff --git a/net/rds/tcp_connect.c b/net/rds/tcp_connect.c
> index a0046e99d6df..c9449780f952 100644
> --- a/net/rds/tcp_connect.c
> +++ b/net/rds/tcp_connect.c
> @@ -93,6 +93,7 @@ int rds_tcp_conn_path_connect(struct rds_conn_path *cp)
>   	struct sockaddr_in6 sin6;
>   	struct sockaddr_in sin;
>   	struct sockaddr *addr;
> +	struct net *net;
>   	int addrlen;
>   	bool isv6;
>   	int ret;
> @@ -107,20 +108,28 @@ int rds_tcp_conn_path_connect(struct rds_conn_path *cp)
>   
>   	mutex_lock(&tc->t_conn_path_lock);
>   
> +	net = rds_conn_net(conn);
> +
>   	if (rds_conn_path_up(cp)) {
> -		mutex_unlock(&tc->t_conn_path_lock);
> -		return 0;
> +		ret = 0;
> +		goto out;
>   	}
> +
> +	if (!maybe_get_net(net)) {
> +		ret = -EINVAL;
> +		goto out;
> +	}
> +
>   	if (ipv6_addr_v4mapped(&conn->c_laddr)) {
> -		ret = sock_create_kern(rds_conn_net(conn), PF_INET,
> -				       SOCK_STREAM, IPPROTO_TCP, &sock);
> +		ret = sock_create_net(net, PF_INET, SOCK_STREAM, IPPROTO_TCP, &sock);
>   		isv6 = false;
>   	} else {
> -		ret = sock_create_kern(rds_conn_net(conn), PF_INET6,
> -				       SOCK_STREAM, IPPROTO_TCP, &sock);
> +		ret = sock_create_net(net, PF_INET6, SOCK_STREAM, IPPROTO_TCP, &sock);
>   		isv6 = true;
>   	}
>   
> +	put_net(net);
> +
>   	if (ret < 0)
>   		goto out;
>   
> diff --git a/net/rds/tcp_listen.c b/net/rds/tcp_listen.c
> index 69aaf03ab93e..440ac9057148 100644
> --- a/net/rds/tcp_listen.c
> +++ b/net/rds/tcp_listen.c
> @@ -101,6 +101,7 @@ int rds_tcp_accept_one(struct socket *sock)
>   	struct rds_connection *conn;
>   	int ret;
>   	struct inet_sock *inet;
> +	struct net *net;
>   	struct rds_tcp_connection *rs_tcp = NULL;
>   	int conn_state;
>   	struct rds_conn_path *cp;
> @@ -108,7 +109,7 @@ int rds_tcp_accept_one(struct socket *sock)
>   	struct proto_accept_arg arg = {
>   		.flags = O_NONBLOCK,
>   		.kern = true,
> -		.hold_net = false,
> +		.hold_net = true,
>   	};
>   #if !IS_ENABLED(CONFIG_IPV6)
>   	struct in6_addr saddr, daddr;
> @@ -118,13 +119,22 @@ int rds_tcp_accept_one(struct socket *sock)
>   	if (!sock) /* module unload or netns delete in progress */
>   		return -ENETUNREACH;
>   
> +	net = sock_net(sock->sk);
> +
> +	if (!maybe_get_net(net))
> +		return -EINVAL;
> +
>   	ret = sock_create_lite(sock->sk->sk_family,
>   			       sock->sk->sk_type, sock->sk->sk_protocol,
>   			       &new_sock);
> -	if (ret)
> +	if (ret) {
> +		put_net(net);
>   		goto out;
> +	}
>   
>   	ret = sock->ops->accept(sock, new_sock, &arg);
> +	put_net(net);
> +
>   	if (ret < 0)
>   		goto out;
>   
> diff --git a/net/smc/af_smc.c b/net/smc/af_smc.c
> index 6e93f188a908..7b0de80b3aca 100644
> --- a/net/smc/af_smc.c
> +++ b/net/smc/af_smc.c
> @@ -3310,25 +3310,8 @@ static const struct proto_ops smc_sock_ops = {
>   
>   int smc_create_clcsk(struct net *net, struct sock *sk, int family)
>   {
> -	struct smc_sock *smc = smc_sk(sk);
> -	int rc;
> -
> -	rc = sock_create_kern(net, family, SOCK_STREAM, IPPROTO_TCP,
> -			      &smc->clcsock);
> -	if (rc)
> -		return rc;
> -
> -	/* smc_clcsock_release() does not wait smc->clcsock->sk's
> -	 * destruction;  its sk_state might not be TCP_CLOSE after
> -	 * smc->sk is close()d, and TCP timers can be fired later,
> -	 * which need net ref.
> -	 */
> -	sk = smc->clcsock->sk;
> -	__netns_tracker_free(net, &sk->ns_tracker, false);
> -	sk->sk_net_refcnt = 1;
> -	get_net_track(net, &sk->ns_tracker, GFP_KERNEL);
> -	sock_inuse_add(net, 1);
I don't think this line should be removed. Otherwise, the purpose here, to 
manage the per-namespace statistics in the case of network namespace 
isolation, would be lost.
@D. Wythe, could you please check it again? Maybe you have some good 
testing on this case.

> -	return 0;
> +	return sock_create_net(net, family, SOCK_STREAM, IPPROTO_TCP,
> +			       &smc_sk(sk)->clcsock);
>   }
>   
>   static int __smc_create(struct net *net, struct socket *sock, int protocol,
> diff --git a/net/sunrpc/svcsock.c b/net/sunrpc/svcsock.c
> index 9583bad3d150..cde5765f6f81 100644
> --- a/net/sunrpc/svcsock.c
> +++ b/net/sunrpc/svcsock.c
> @@ -1526,7 +1526,10 @@ static struct svc_xprt *svc_create_socket(struct svc_serv *serv,
>   		return ERR_PTR(-EINVAL);
>   	}
>   
> -	error = sock_create_kern(net, family, type, protocol, &sock);
> +	if (protocol == IPPROTO_TCP)
> +		error = sock_create_net(net, family, type, protocol, &sock);
> +	else
> +		error = sock_create_kern(net, family, type, protocol, &sock);
>   	if (error < 0)
>   		return ERR_PTR(error);
>   
> @@ -1551,11 +1554,8 @@ static struct svc_xprt *svc_create_socket(struct svc_serv *serv,
>   	newlen = error;
>   
>   	if (protocol == IPPROTO_TCP) {
> -		__netns_tracker_free(net, &sock->sk->ns_tracker, false);
> -		sock->sk->sk_net_refcnt = 1;
> -		get_net_track(net, &sock->sk->ns_tracker, GFP_KERNEL);
> -		sock_inuse_add(net, 1);
> -		if ((error = kernel_listen(sock, 64)) < 0)
> +		error = kernel_listen(sock, 64);
> +		if (error < 0)
>   			goto bummer;
>   	}
>   
> diff --git a/net/sunrpc/xprtsock.c b/net/sunrpc/xprtsock.c
> index feb1768e8a57..f3e139c30442 100644
> --- a/net/sunrpc/xprtsock.c
> +++ b/net/sunrpc/xprtsock.c
> @@ -1924,7 +1924,10 @@ static struct socket *xs_create_sock(struct rpc_xprt *xprt,
>   	struct socket *sock;
>   	int err;
>   
> -	err = sock_create_kern(xprt->xprt_net, family, type, protocol, &sock);
> +	if (protocol == IPPROTO_TCP)
> +		err = sock_create_net(xprt->xprt_net, family, type, protocol, &sock);
> +	else
> +		err = sock_create_kern(xprt->xprt_net, family, type, protocol, &sock);
>   	if (err < 0) {
>   		dprintk("RPC:       can't create %d transport socket (%d).\n",
>   				protocol, -err);
> @@ -1941,13 +1944,6 @@ static struct socket *xs_create_sock(struct rpc_xprt *xprt,
>   		goto out;
>   	}
>   
> -	if (protocol == IPPROTO_TCP) {
> -		__netns_tracker_free(xprt->xprt_net, &sock->sk->ns_tracker, false);
> -		sock->sk->sk_net_refcnt = 1;
> -		get_net_track(xprt->xprt_net, &sock->sk->ns_tracker, GFP_KERNEL);
> -		sock_inuse_add(xprt->xprt_net, 1);
> -	}
> -
>   	filp = sock_alloc_file(sock, O_NONBLOCK, NULL);
>   	if (IS_ERR(filp))
>   		return ERR_CAST(filp);

Kuniyuki Iwashima Dec. 13, 2024, 1:54 p.m. UTC | #2
From: Wenjia Zhang <wenjia@linux.ibm.com>
Date: Fri, 13 Dec 2024 14:45:20 +0100
> > diff --git a/net/smc/af_smc.c b/net/smc/af_smc.c
> > index 6e93f188a908..7b0de80b3aca 100644
> > --- a/net/smc/af_smc.c
> > +++ b/net/smc/af_smc.c
> > @@ -3310,25 +3310,8 @@ static const struct proto_ops smc_sock_ops = {
> >   
> >   int smc_create_clcsk(struct net *net, struct sock *sk, int family)
> >   {
> > -	struct smc_sock *smc = smc_sk(sk);
> > -	int rc;
> > -
> > -	rc = sock_create_kern(net, family, SOCK_STREAM, IPPROTO_TCP,
> > -			      &smc->clcsock);
> > -	if (rc)
> > -		return rc;
> > -
> > -	/* smc_clcsock_release() does not wait smc->clcsock->sk's
> > -	 * destruction;  its sk_state might not be TCP_CLOSE after
> > -	 * smc->sk is close()d, and TCP timers can be fired later,
> > -	 * which need net ref.
> > -	 */
> > -	sk = smc->clcsock->sk;
> > -	__netns_tracker_free(net, &sk->ns_tracker, false);
> > -	sk->sk_net_refcnt = 1;
> > -	get_net_track(net, &sk->ns_tracker, GFP_KERNEL);
> > -	sock_inuse_add(net, 1);
> I don't think this line should be removed. Otherwise, the purpose here, to 
> manage the per-namespace statistics in the case of network namespace 
> isolation, would be lost.

Now it's counted in sk_alloc().

sock_create_net() below passes hold_net=true to sk_alloc(), and if
sk->sk_net_refcnt (== hold_net) is true, sock_inuse_add() is
called there.

See patch 9 and 10:
https://lore.kernel.org/netdev/20241213092152.14057-10-kuniyu@amazon.com/
https://lore.kernel.org/netdev/20241213092152.14057-11-kuniyu@amazon.com/
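
As an illustration of the accounting described above, here is a minimal userspace sketch. It is not the kernel implementation: struct toy_net, struct toy_sock, toy_sk_alloc() and toy_sk_free() are hypothetical stand-ins, and the only behaviour modelled is that a socket allocated with hold_net set both pins the namespace and is counted in its per-namespace in-use statistic at allocation time, so callers no longer adjust either counter by hand.

```c
#include <assert.h>
#include <stdbool.h>
#include <stdlib.h>

/* Toy stand-in for struct net; not the kernel definition. */
struct toy_net {
	int refcnt;      /* models the netns reference count */
	int sock_inuse;  /* models the per-netns "sockets in use" counter */
};

/* Toy stand-in for struct sock. */
struct toy_sock {
	struct toy_net *net;
	bool net_refcnt; /* models sk->sk_net_refcnt */
};

/* Models the allocation path: when hold_net is true, the reference and
 * the in-use count are both taken here, in one place.
 */
static struct toy_sock *toy_sk_alloc(struct toy_net *net, bool hold_net)
{
	struct toy_sock *sk = malloc(sizeof(*sk));

	if (!sk)
		return NULL;

	sk->net = net;
	sk->net_refcnt = hold_net;
	if (hold_net) {
		net->refcnt++;     /* the socket pins the namespace */
		net->sock_inuse++; /* and shows up in its statistics */
	}
	return sk;
}

/* Models the matching release: undo only what allocation did. */
static void toy_sk_free(struct toy_sock *sk)
{
	if (sk->net_refcnt) {
		sk->net->sock_inuse--;
		sk->net->refcnt--;
	}
	free(sk);
}
```

In this model a socket created with hold_net=false changes neither counter, which corresponds to the sock_create_kern() callers that the patch leaves in place for non-TCP protocols.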


> @D. Wythe, could you please check it again? Maybe you have some good 
> testing on this case.
> 
> > -	return 0;
> > +	return sock_create_net(net, family, SOCK_STREAM, IPPROTO_TCP,
> > +			       &smc_sk(sk)->clcsock);
> >   }

Chuck Lever Dec. 13, 2024, 2:15 p.m. UTC | #3
On 12/13/24 4:21 AM, Kuniyuki Iwashima wrote:
> Since commit 26abe14379f8 ("net: Modify sk_alloc to not reference count
> the netns of kernel sockets."), TCP kernel socket has caused many UAF.
> 
> We have converted such sockets to hold netns refcnt, and we have the
> same pattern in cifs, mptcp, rds, smc, and sunrpc.
> 
> Let's drop the conversion and use sock_create_net() instead.
> 
> The changes for cifs, mptcp, and smc are straightforward.
> 
> For rds, we need to move maybe_get_net() before sock_create_net() and
> sock->ops->accept().
> 
> For sunrpc, we call sock_create_net() for IPPROTO_TCP only and still
> call sock_create_kern() for others.
> 
> Signed-off-by: Kuniyuki Iwashima <kuniyu@amazon.com>
> Acked-by: Matthieu Baerts (NGI0) <matttbe@kernel.org>
> Acked-by: Allison Henderson <allison.henderson@oracle.com>
> ---
> v3: Add missing mutex_unlock in rds_tcp_conn_path_connect().
> v2: Collect Acked-by from MPTCP and RDS maintainers
> 
> Cc: Steve French <sfrench@samba.org>
> Cc: Wenjia Zhang <wenjia@linux.ibm.com>
> Cc: Jan Karcher <jaka@linux.ibm.com>
> Cc: Chuck Lever <chuck.lever@oracle.com>
> Cc: Jeff Layton <jlayton@kernel.org>
> ---
>   fs/smb/client/connect.c | 13 ++-----------
>   net/mptcp/subflow.c     | 10 +---------
>   net/rds/tcp.c           | 14 --------------
>   net/rds/tcp_connect.c   | 21 +++++++++++++++------
>   net/rds/tcp_listen.c    | 14 ++++++++++++--
>   net/smc/af_smc.c        | 21 ++-------------------
>   net/sunrpc/svcsock.c    | 12 ++++++------
>   net/sunrpc/xprtsock.c   | 12 ++++--------
>   8 files changed, 42 insertions(+), 75 deletions(-)
> 
> diff --git a/fs/smb/client/connect.c b/fs/smb/client/connect.c
> index c36c1b4ffe6e..7a67b86c0423 100644
> --- a/fs/smb/client/connect.c
> +++ b/fs/smb/client/connect.c
> @@ -3130,22 +3130,13 @@ generic_ip_connect(struct TCP_Server_Info *server)
>   	if (server->ssocket) {
>   		socket = server->ssocket;
>   	} else {
> -		struct net *net = cifs_net_ns(server);
> -		struct sock *sk;
> -
> -		rc = sock_create_kern(net, sfamily, SOCK_STREAM,
> -				      IPPROTO_TCP, &server->ssocket);
> +		rc = sock_create_net(cifs_net_ns(server), sfamily, SOCK_STREAM,
> +				     IPPROTO_TCP, &server->ssocket);
>   		if (rc < 0) {
>   			cifs_server_dbg(VFS, "Error %d creating socket\n", rc);
>   			return rc;
>   		}
>   
> -		sk = server->ssocket->sk;
> -		__netns_tracker_free(net, &sk->ns_tracker, false);
> -		sk->sk_net_refcnt = 1;
> -		get_net_track(net, &sk->ns_tracker, GFP_KERNEL);
> -		sock_inuse_add(net, 1);
> -
>   		/* BB other socket options to set KEEPALIVE, NODELAY? */
>   		cifs_dbg(FYI, "Socket created\n");
>   		socket = server->ssocket;
> diff --git a/net/mptcp/subflow.c b/net/mptcp/subflow.c
> index fd021cf8286e..e7e8972bdfca 100644
> --- a/net/mptcp/subflow.c
> +++ b/net/mptcp/subflow.c
> @@ -1755,7 +1755,7 @@ int mptcp_subflow_create_socket(struct sock *sk, unsigned short family,
>   	if (unlikely(!sk->sk_socket))
>   		return -EINVAL;
>   
> -	err = sock_create_kern(net, family, SOCK_STREAM, IPPROTO_TCP, &sf);
> +	err = sock_create_net(net, family, SOCK_STREAM, IPPROTO_TCP, &sf);
>   	if (err)
>   		return err;
>   
> @@ -1768,14 +1768,6 @@ int mptcp_subflow_create_socket(struct sock *sk, unsigned short family,
>   	/* the newly created socket has to be in the same cgroup as its parent */
>   	mptcp_attach_cgroup(sk, sf->sk);
>   
> -	/* kernel sockets do not by default acquire net ref, but TCP timer
> -	 * needs it.
> -	 * Update ns_tracker to current stack trace and refcounted tracker.
> -	 */
> -	__netns_tracker_free(net, &sf->sk->ns_tracker, false);
> -	sf->sk->sk_net_refcnt = 1;
> -	get_net_track(net, &sf->sk->ns_tracker, GFP_KERNEL);
> -	sock_inuse_add(net, 1);
>   	err = tcp_set_ulp(sf->sk, "mptcp");
>   	if (err)
>   		goto err_free;
> diff --git a/net/rds/tcp.c b/net/rds/tcp.c
> index 351ac1747224..4509900476f7 100644
> --- a/net/rds/tcp.c
> +++ b/net/rds/tcp.c
> @@ -494,21 +494,7 @@ bool rds_tcp_tune(struct socket *sock)
>   
>   	tcp_sock_set_nodelay(sock->sk);
>   	lock_sock(sk);
> -	/* TCP timer functions might access net namespace even after
> -	 * a process which created this net namespace terminated.
> -	 */
> -	if (!sk->sk_net_refcnt) {
> -		if (!maybe_get_net(net)) {
> -			release_sock(sk);
> -			return false;
> -		}
> -		/* Update ns_tracker to current stack trace and refcounted tracker */
> -		__netns_tracker_free(net, &sk->ns_tracker, false);
>   
> -		sk->sk_net_refcnt = 1;
> -		netns_tracker_alloc(net, &sk->ns_tracker, GFP_KERNEL);
> -		sock_inuse_add(net, 1);
> -	}
>   	rtn = net_generic(net, rds_tcp_netid);
>   	if (rtn->sndbuf_size > 0) {
>   		sk->sk_sndbuf = rtn->sndbuf_size;
> diff --git a/net/rds/tcp_connect.c b/net/rds/tcp_connect.c
> index a0046e99d6df..c9449780f952 100644
> --- a/net/rds/tcp_connect.c
> +++ b/net/rds/tcp_connect.c
> @@ -93,6 +93,7 @@ int rds_tcp_conn_path_connect(struct rds_conn_path *cp)
>   	struct sockaddr_in6 sin6;
>   	struct sockaddr_in sin;
>   	struct sockaddr *addr;
> +	struct net *net;
>   	int addrlen;
>   	bool isv6;
>   	int ret;
> @@ -107,20 +108,28 @@ int rds_tcp_conn_path_connect(struct rds_conn_path *cp)
>   
>   	mutex_lock(&tc->t_conn_path_lock);
>   
> +	net = rds_conn_net(conn);
> +
>   	if (rds_conn_path_up(cp)) {
> -		mutex_unlock(&tc->t_conn_path_lock);
> -		return 0;
> +		ret = 0;
> +		goto out;
>   	}
> +
> +	if (!maybe_get_net(net)) {
> +		ret = -EINVAL;
> +		goto out;
> +	}
> +
>   	if (ipv6_addr_v4mapped(&conn->c_laddr)) {
> -		ret = sock_create_kern(rds_conn_net(conn), PF_INET,
> -				       SOCK_STREAM, IPPROTO_TCP, &sock);
> +		ret = sock_create_net(net, PF_INET, SOCK_STREAM, IPPROTO_TCP, &sock);
>   		isv6 = false;
>   	} else {
> -		ret = sock_create_kern(rds_conn_net(conn), PF_INET6,
> -				       SOCK_STREAM, IPPROTO_TCP, &sock);
> +		ret = sock_create_net(net, PF_INET6, SOCK_STREAM, IPPROTO_TCP, &sock);
>   		isv6 = true;
>   	}
>   
> +	put_net(net);
> +
>   	if (ret < 0)
>   		goto out;
>   
> diff --git a/net/rds/tcp_listen.c b/net/rds/tcp_listen.c
> index 69aaf03ab93e..440ac9057148 100644
> --- a/net/rds/tcp_listen.c
> +++ b/net/rds/tcp_listen.c
> @@ -101,6 +101,7 @@ int rds_tcp_accept_one(struct socket *sock)
>   	struct rds_connection *conn;
>   	int ret;
>   	struct inet_sock *inet;
> +	struct net *net;
>   	struct rds_tcp_connection *rs_tcp = NULL;
>   	int conn_state;
>   	struct rds_conn_path *cp;
> @@ -108,7 +109,7 @@ int rds_tcp_accept_one(struct socket *sock)
>   	struct proto_accept_arg arg = {
>   		.flags = O_NONBLOCK,
>   		.kern = true,
> -		.hold_net = false,
> +		.hold_net = true,
>   	};
>   #if !IS_ENABLED(CONFIG_IPV6)
>   	struct in6_addr saddr, daddr;
> @@ -118,13 +119,22 @@ int rds_tcp_accept_one(struct socket *sock)
>   	if (!sock) /* module unload or netns delete in progress */
>   		return -ENETUNREACH;
>   
> +	net = sock_net(sock->sk);
> +
> +	if (!maybe_get_net(net))
> +		return -EINVAL;
> +
>   	ret = sock_create_lite(sock->sk->sk_family,
>   			       sock->sk->sk_type, sock->sk->sk_protocol,
>   			       &new_sock);
> -	if (ret)
> +	if (ret) {
> +		put_net(net);
>   		goto out;
> +	}
>   
>   	ret = sock->ops->accept(sock, new_sock, &arg);
> +	put_net(net);
> +
>   	if (ret < 0)
>   		goto out;
>   
> diff --git a/net/smc/af_smc.c b/net/smc/af_smc.c
> index 6e93f188a908..7b0de80b3aca 100644
> --- a/net/smc/af_smc.c
> +++ b/net/smc/af_smc.c
> @@ -3310,25 +3310,8 @@ static const struct proto_ops smc_sock_ops = {
>   
>   int smc_create_clcsk(struct net *net, struct sock *sk, int family)
>   {
> -	struct smc_sock *smc = smc_sk(sk);
> -	int rc;
> -
> -	rc = sock_create_kern(net, family, SOCK_STREAM, IPPROTO_TCP,
> -			      &smc->clcsock);
> -	if (rc)
> -		return rc;
> -
> -	/* smc_clcsock_release() does not wait smc->clcsock->sk's
> -	 * destruction;  its sk_state might not be TCP_CLOSE after
> -	 * smc->sk is close()d, and TCP timers can be fired later,
> -	 * which need net ref.
> -	 */
> -	sk = smc->clcsock->sk;
> -	__netns_tracker_free(net, &sk->ns_tracker, false);
> -	sk->sk_net_refcnt = 1;
> -	get_net_track(net, &sk->ns_tracker, GFP_KERNEL);
> -	sock_inuse_add(net, 1);
> -	return 0;
> +	return sock_create_net(net, family, SOCK_STREAM, IPPROTO_TCP,
> +			       &smc_sk(sk)->clcsock);
>   }
>   
>   static int __smc_create(struct net *net, struct socket *sock, int protocol,
> diff --git a/net/sunrpc/svcsock.c b/net/sunrpc/svcsock.c
> index 9583bad3d150..cde5765f6f81 100644
> --- a/net/sunrpc/svcsock.c
> +++ b/net/sunrpc/svcsock.c
> @@ -1526,7 +1526,10 @@ static struct svc_xprt *svc_create_socket(struct svc_serv *serv,
>   		return ERR_PTR(-EINVAL);
>   	}
>   
> -	error = sock_create_kern(net, family, type, protocol, &sock);
> +	if (protocol == IPPROTO_TCP)
> +		error = sock_create_net(net, family, type, protocol, &sock);
> +	else
> +		error = sock_create_kern(net, family, type, protocol, &sock);
>   	if (error < 0)
>   		return ERR_PTR(error);
>   
> @@ -1551,11 +1554,8 @@ static struct svc_xprt *svc_create_socket(struct svc_serv *serv,
>   	newlen = error;
>   
>   	if (protocol == IPPROTO_TCP) {
> -		__netns_tracker_free(net, &sock->sk->ns_tracker, false);
> -		sock->sk->sk_net_refcnt = 1;
> -		get_net_track(net, &sock->sk->ns_tracker, GFP_KERNEL);
> -		sock_inuse_add(net, 1);
> -		if ((error = kernel_listen(sock, 64)) < 0)
> +		error = kernel_listen(sock, 64);
> +		if (error < 0)
>   			goto bummer;
>   	}
>   
> diff --git a/net/sunrpc/xprtsock.c b/net/sunrpc/xprtsock.c
> index feb1768e8a57..f3e139c30442 100644
> --- a/net/sunrpc/xprtsock.c
> +++ b/net/sunrpc/xprtsock.c
> @@ -1924,7 +1924,10 @@ static struct socket *xs_create_sock(struct rpc_xprt *xprt,
>   	struct socket *sock;
>   	int err;
>   
> -	err = sock_create_kern(xprt->xprt_net, family, type, protocol, &sock);
> +	if (protocol == IPPROTO_TCP)
> +		err = sock_create_net(xprt->xprt_net, family, type, protocol, &sock);
> +	else
> +		err = sock_create_kern(xprt->xprt_net, family, type, protocol, &sock);
>   	if (err < 0) {
>   		dprintk("RPC:       can't create %d transport socket (%d).\n",
>   				protocol, -err);
> @@ -1941,13 +1944,6 @@ static struct socket *xs_create_sock(struct rpc_xprt *xprt,
>   		goto out;
>   	}
>   
> -	if (protocol == IPPROTO_TCP) {
> -		__netns_tracker_free(xprt->xprt_net, &sock->sk->ns_tracker, false);
> -		sock->sk->sk_net_refcnt = 1;
> -		get_net_track(xprt->xprt_net, &sock->sk->ns_tracker, GFP_KERNEL);
> -		sock_inuse_add(xprt->xprt_net, 1);
> -	}
> -
>   	filp = sock_alloc_file(sock, O_NONBLOCK, NULL);
>   	if (IS_ERR(filp))
>   		return ERR_CAST(filp);

For the svcsock.c hunks:

Acked-by: Chuck Lever <chuck.lever@oracle.com>

Wenjia Zhang Dec. 13, 2024, 3:15 p.m. UTC | #4
On 13.12.24 14:54, Kuniyuki Iwashima wrote:
> From: Wenjia Zhang <wenjia@linux.ibm.com>
> Date: Fri, 13 Dec 2024 14:45:20 +0100
>>> diff --git a/net/smc/af_smc.c b/net/smc/af_smc.c
>>> index 6e93f188a908..7b0de80b3aca 100644
>>> --- a/net/smc/af_smc.c
>>> +++ b/net/smc/af_smc.c
>>> @@ -3310,25 +3310,8 @@ static const struct proto_ops smc_sock_ops = {
>>>    
>>>    int smc_create_clcsk(struct net *net, struct sock *sk, int family)
>>>    {
>>> -	struct smc_sock *smc = smc_sk(sk);
>>> -	int rc;
>>> -
>>> -	rc = sock_create_kern(net, family, SOCK_STREAM, IPPROTO_TCP,
>>> -			      &smc->clcsock);
>>> -	if (rc)
>>> -		return rc;
>>> -
>>> -	/* smc_clcsock_release() does not wait smc->clcsock->sk's
>>> -	 * destruction;  its sk_state might not be TCP_CLOSE after
>>> -	 * smc->sk is close()d, and TCP timers can be fired later,
>>> -	 * which need net ref.
>>> -	 */
>>> -	sk = smc->clcsock->sk;
>>> -	__netns_tracker_free(net, &sk->ns_tracker, false);
>>> -	sk->sk_net_refcnt = 1;
>>> -	get_net_track(net, &sk->ns_tracker, GFP_KERNEL);
>>> -	sock_inuse_add(net, 1);
>> I don't think this line should be removed. Otherwise, the purpose here, to
>> manage the per-namespace statistics in the case of network namespace
>> isolation, would be lost.
> 
> Now it's counted in sk_alloc().
> 
> sock_create_net() below passes hold_net=true to sk_alloc(), and if
> sk->sk_net_refcnt (== hold_net) is true, sock_inuse_add() is
> called there.
> 
> See patch 9 and 10:
> https://lore.kernel.org/netdev/20241213092152.14057-10-kuniyu@amazon.com/
> https://lore.kernel.org/netdev/20241213092152.14057-11-kuniyu@amazon.com/
> 
> 
ok, I see. Thank you for pointing it out!

Reviewed-by: Wenjia Zhang <wenjia@linux.ibm.com>

>> @D. Wythe, could you please check it again? Maybe you have some good
>> testing on this case.
>>
>>> -	return 0;
>>> +	return sock_create_net(net, family, SOCK_STREAM, IPPROTO_TCP,
>>> +			       &smc_sk(sk)->clcsock);
>>>    }

Allison Henderson Dec. 13, 2024, 11:29 p.m. UTC | #5
On Fri, 2024-12-13 at 18:21 +0900, Kuniyuki Iwashima wrote:
> Since commit 26abe14379f8 ("net: Modify sk_alloc to not reference count
> the netns of kernel sockets."), TCP kernel socket has caused many UAF.
> 
> We have converted such sockets to hold netns refcnt, and we have the
> same pattern in cifs, mptcp, rds, smc, and sunrpc.
> 
> Let's drop the conversion and use sock_create_net() instead.
> 
> The changes for cifs, mptcp, and smc are straightforward.
> 
> For rds, we need to move maybe_get_net() before sock_create_net() and
> sock->ops->accept().
> 
> For sunrpc, we call sock_create_net() for IPPROTO_TCP only and still
> call sock_create_kern() for others.
> 
> Signed-off-by: Kuniyuki Iwashima <kuniyu@amazon.com>
> Acked-by: Matthieu Baerts (NGI0) <matttbe@kernel.org>
> Acked-by: Allison Henderson <allison.henderson@oracle.com>
> ---
> v3: Add missing mutex_unlock in rds_tcp_conn_path_connect().
> v2: Collect Acked-by from MPTCP and RDS maintainers
> 
> Cc: Steve French <sfrench@samba.org>
> Cc: Wenjia Zhang <wenjia@linux.ibm.com>
> Cc: Jan Karcher <jaka@linux.ibm.com>
> Cc: Chuck Lever <chuck.lever@oracle.com>
> Cc: Jeff Layton <jlayton@kernel.org>
> ---
>  fs/smb/client/connect.c | 13 ++-----------
>  net/mptcp/subflow.c     | 10 +---------
>  net/rds/tcp.c           | 14 --------------
>  net/rds/tcp_connect.c   | 21 +++++++++++++++------
>  net/rds/tcp_listen.c    | 14 ++++++++++++--
>  net/smc/af_smc.c        | 21 ++-------------------
>  net/sunrpc/svcsock.c    | 12 ++++++------
>  net/sunrpc/xprtsock.c   | 12 ++++--------
>  8 files changed, 42 insertions(+), 75 deletions(-)
> 
> diff --git a/fs/smb/client/connect.c b/fs/smb/client/connect.c
> index c36c1b4ffe6e..7a67b86c0423 100644
> --- a/fs/smb/client/connect.c
> +++ b/fs/smb/client/connect.c
> @@ -3130,22 +3130,13 @@ generic_ip_connect(struct TCP_Server_Info *server)
>  	if (server->ssocket) {
>  		socket = server->ssocket;
>  	} else {
> -		struct net *net = cifs_net_ns(server);
> -		struct sock *sk;
> -
> -		rc = sock_create_kern(net, sfamily, SOCK_STREAM,
> -				      IPPROTO_TCP, &server->ssocket);
> +		rc = sock_create_net(cifs_net_ns(server), sfamily, SOCK_STREAM,
> +				     IPPROTO_TCP, &server->ssocket);
>  		if (rc < 0) {
>  			cifs_server_dbg(VFS, "Error %d creating socket\n", rc);
>  			return rc;
>  		}
>  
> -		sk = server->ssocket->sk;
> -		__netns_tracker_free(net, &sk->ns_tracker, false);
> -		sk->sk_net_refcnt = 1;
> -		get_net_track(net, &sk->ns_tracker, GFP_KERNEL);
> -		sock_inuse_add(net, 1);
> -
>  		/* BB other socket options to set KEEPALIVE, NODELAY? */
>  		cifs_dbg(FYI, "Socket created\n");
>  		socket = server->ssocket;
> diff --git a/net/mptcp/subflow.c b/net/mptcp/subflow.c
> index fd021cf8286e..e7e8972bdfca 100644
> --- a/net/mptcp/subflow.c
> +++ b/net/mptcp/subflow.c
> @@ -1755,7 +1755,7 @@ int mptcp_subflow_create_socket(struct sock *sk, unsigned short family,
>  	if (unlikely(!sk->sk_socket))
>  		return -EINVAL;
>  
> -	err = sock_create_kern(net, family, SOCK_STREAM, IPPROTO_TCP, &sf);
> +	err = sock_create_net(net, family, SOCK_STREAM, IPPROTO_TCP, &sf);
>  	if (err)
>  		return err;
>  
> @@ -1768,14 +1768,6 @@ int mptcp_subflow_create_socket(struct sock *sk, unsigned short family,
>  	/* the newly created socket has to be in the same cgroup as its parent */
>  	mptcp_attach_cgroup(sk, sf->sk);
>  
> -	/* kernel sockets do not by default acquire net ref, but TCP timer
> -	 * needs it.
> -	 * Update ns_tracker to current stack trace and refcounted tracker.
> -	 */
> -	__netns_tracker_free(net, &sf->sk->ns_tracker, false);
> -	sf->sk->sk_net_refcnt = 1;
> -	get_net_track(net, &sf->sk->ns_tracker, GFP_KERNEL);
> -	sock_inuse_add(net, 1);
>  	err = tcp_set_ulp(sf->sk, "mptcp");
>  	if (err)
>  		goto err_free;
> diff --git a/net/rds/tcp.c b/net/rds/tcp.c
> index 351ac1747224..4509900476f7 100644
> --- a/net/rds/tcp.c
> +++ b/net/rds/tcp.c
> @@ -494,21 +494,7 @@ bool rds_tcp_tune(struct socket *sock)
>  
>  	tcp_sock_set_nodelay(sock->sk);
>  	lock_sock(sk);
> -	/* TCP timer functions might access net namespace even after
> -	 * a process which created this net namespace terminated.
> -	 */
> -	if (!sk->sk_net_refcnt) {
> -		if (!maybe_get_net(net)) {
> -			release_sock(sk);
> -			return false;
> -		}
> -		/* Update ns_tracker to current stack trace and refcounted tracker */
> -		__netns_tracker_free(net, &sk->ns_tracker, false);
>  
> -		sk->sk_net_refcnt = 1;
> -		netns_tracker_alloc(net, &sk->ns_tracker, GFP_KERNEL);
> -		sock_inuse_add(net, 1);
> -	}
>  	rtn = net_generic(net, rds_tcp_netid);
>  	if (rtn->sndbuf_size > 0) {
>  		sk->sk_sndbuf = rtn->sndbuf_size;
> diff --git a/net/rds/tcp_connect.c b/net/rds/tcp_connect.c
> index a0046e99d6df..c9449780f952 100644
> --- a/net/rds/tcp_connect.c
> +++ b/net/rds/tcp_connect.c
> @@ -93,6 +93,7 @@ int rds_tcp_conn_path_connect(struct rds_conn_path *cp)
>  	struct sockaddr_in6 sin6;
>  	struct sockaddr_in sin;
>  	struct sockaddr *addr;
> +	struct net *net;
>  	int addrlen;
>  	bool isv6;
>  	int ret;
> @@ -107,20 +108,28 @@ int rds_tcp_conn_path_connect(struct rds_conn_path *cp)
>  
>  	mutex_lock(&tc->t_conn_path_lock);
>  
> +	net = rds_conn_net(conn);
> +
>  	if (rds_conn_path_up(cp)) {
> -		mutex_unlock(&tc->t_conn_path_lock);
> -		return 0;
> +		ret = 0;
> +		goto out;
>  	}
> +
> +	if (!maybe_get_net(net)) {
> +		ret = -EINVAL;
> +		goto out;
> +	}

Ok, this looks much better.  Thank you!

Allison

> +
>  	if (ipv6_addr_v4mapped(&conn->c_laddr)) {
> -		ret = sock_create_kern(rds_conn_net(conn), PF_INET,
> -				       SOCK_STREAM, IPPROTO_TCP, &sock);
> +		ret = sock_create_net(net, PF_INET, SOCK_STREAM, IPPROTO_TCP, &sock);
>  		isv6 = false;
>  	} else {
> -		ret = sock_create_kern(rds_conn_net(conn), PF_INET6,
> -				       SOCK_STREAM, IPPROTO_TCP, &sock);
> +		ret = sock_create_net(net, PF_INET6, SOCK_STREAM, IPPROTO_TCP, &sock);
>  		isv6 = true;
>  	}
>  
> +	put_net(net);
> +
>  	if (ret < 0)
>  		goto out;
>  
> diff --git a/net/rds/tcp_listen.c b/net/rds/tcp_listen.c
> index 69aaf03ab93e..440ac9057148 100644
> --- a/net/rds/tcp_listen.c
> +++ b/net/rds/tcp_listen.c
> @@ -101,6 +101,7 @@ int rds_tcp_accept_one(struct socket *sock)
>  	struct rds_connection *conn;
>  	int ret;
>  	struct inet_sock *inet;
> +	struct net *net;
>  	struct rds_tcp_connection *rs_tcp = NULL;
>  	int conn_state;
>  	struct rds_conn_path *cp;
> @@ -108,7 +109,7 @@ int rds_tcp_accept_one(struct socket *sock)
>  	struct proto_accept_arg arg = {
>  		.flags = O_NONBLOCK,
>  		.kern = true,
> -		.hold_net = false,
> +		.hold_net = true,
>  	};
>  #if !IS_ENABLED(CONFIG_IPV6)
>  	struct in6_addr saddr, daddr;
> @@ -118,13 +119,22 @@ int rds_tcp_accept_one(struct socket *sock)
>  	if (!sock) /* module unload or netns delete in progress */
>  		return -ENETUNREACH;
>  
> +	net = sock_net(sock->sk);
> +
> +	if (!maybe_get_net(net))
> +		return -EINVAL;
> +
>  	ret = sock_create_lite(sock->sk->sk_family,
>  			       sock->sk->sk_type, sock->sk->sk_protocol,
>  			       &new_sock);
> -	if (ret)
> +	if (ret) {
> +		put_net(net);
>  		goto out;
> +	}
>  
>  	ret = sock->ops->accept(sock, new_sock, &arg);
> +	put_net(net);
> +
>  	if (ret < 0)
>  		goto out;
>  
> diff --git a/net/smc/af_smc.c b/net/smc/af_smc.c
> index 6e93f188a908..7b0de80b3aca 100644
> --- a/net/smc/af_smc.c
> +++ b/net/smc/af_smc.c
> @@ -3310,25 +3310,8 @@ static const struct proto_ops smc_sock_ops = {
>  
>  int smc_create_clcsk(struct net *net, struct sock *sk, int family)
>  {
> -	struct smc_sock *smc = smc_sk(sk);
> -	int rc;
> -
> -	rc = sock_create_kern(net, family, SOCK_STREAM, IPPROTO_TCP,
> -			      &smc->clcsock);
> -	if (rc)
> -		return rc;
> -
> -	/* smc_clcsock_release() does not wait smc->clcsock->sk's
> -	 * destruction;  its sk_state might not be TCP_CLOSE after
> -	 * smc->sk is close()d, and TCP timers can be fired later,
> -	 * which need net ref.
> -	 */
> -	sk = smc->clcsock->sk;
> -	__netns_tracker_free(net, &sk->ns_tracker, false);
> -	sk->sk_net_refcnt = 1;
> -	get_net_track(net, &sk->ns_tracker, GFP_KERNEL);
> -	sock_inuse_add(net, 1);
> -	return 0;
> +	return sock_create_net(net, family, SOCK_STREAM, IPPROTO_TCP,
> +			       &smc_sk(sk)->clcsock);
>  }
>  
>  static int __smc_create(struct net *net, struct socket *sock, int protocol,
> diff --git a/net/sunrpc/svcsock.c b/net/sunrpc/svcsock.c
> index 9583bad3d150..cde5765f6f81 100644
> --- a/net/sunrpc/svcsock.c
> +++ b/net/sunrpc/svcsock.c
> @@ -1526,7 +1526,10 @@ static struct svc_xprt *svc_create_socket(struct svc_serv *serv,
>  		return ERR_PTR(-EINVAL);
>  	}
>  
> -	error = sock_create_kern(net, family, type, protocol, &sock);
> +	if (protocol == IPPROTO_TCP)
> +		error = sock_create_net(net, family, type, protocol, &sock);
> +	else
> +		error = sock_create_kern(net, family, type, protocol, &sock);
>  	if (error < 0)
>  		return ERR_PTR(error);
>  
> @@ -1551,11 +1554,8 @@ static struct svc_xprt *svc_create_socket(struct svc_serv *serv,
>  	newlen = error;
>  
>  	if (protocol == IPPROTO_TCP) {
> -		__netns_tracker_free(net, &sock->sk->ns_tracker, false);
> -		sock->sk->sk_net_refcnt = 1;
> -		get_net_track(net, &sock->sk->ns_tracker, GFP_KERNEL);
> -		sock_inuse_add(net, 1);
> -		if ((error = kernel_listen(sock, 64)) < 0)
> +		error = kernel_listen(sock, 64);
> +		if (error < 0)
>  			goto bummer;
>  	}
>  
> diff --git a/net/sunrpc/xprtsock.c b/net/sunrpc/xprtsock.c
> index feb1768e8a57..f3e139c30442 100644
> --- a/net/sunrpc/xprtsock.c
> +++ b/net/sunrpc/xprtsock.c
> @@ -1924,7 +1924,10 @@ static struct socket *xs_create_sock(struct rpc_xprt *xprt,
>  	struct socket *sock;
>  	int err;
>  
> -	err = sock_create_kern(xprt->xprt_net, family, type, protocol, &sock);
> +	if (protocol == IPPROTO_TCP)
> +		err = sock_create_net(xprt->xprt_net, family, type, protocol, &sock);
> +	else
> +		err = sock_create_kern(xprt->xprt_net, family, type, protocol, &sock);
>  	if (err < 0) {
>  		dprintk("RPC:       can't create %d transport socket (%d).\n",
>  				protocol, -err);
> @@ -1941,13 +1944,6 @@ static struct socket *xs_create_sock(struct rpc_xprt *xprt,
>  		goto out;
>  	}
>  
> -	if (protocol == IPPROTO_TCP) {
> -		__netns_tracker_free(xprt->xprt_net, &sock->sk->ns_tracker, false);
> -		sock->sk->sk_net_refcnt = 1;
> -		get_net_track(xprt->xprt_net, &sock->sk->ns_tracker, GFP_KERNEL);
> -		sock_inuse_add(xprt->xprt_net, 1);
> -	}
> -
>  	filp = sock_alloc_file(sock, O_NONBLOCK, NULL);
>  	if (IS_ERR(filp))
>  		return ERR_CAST(filp);

Patch

diff --git a/fs/smb/client/connect.c b/fs/smb/client/connect.c
index c36c1b4ffe6e..7a67b86c0423 100644
--- a/fs/smb/client/connect.c
+++ b/fs/smb/client/connect.c
@@ -3130,22 +3130,13 @@  generic_ip_connect(struct TCP_Server_Info *server)
 	if (server->ssocket) {
 		socket = server->ssocket;
 	} else {
-		struct net *net = cifs_net_ns(server);
-		struct sock *sk;
-
-		rc = sock_create_kern(net, sfamily, SOCK_STREAM,
-				      IPPROTO_TCP, &server->ssocket);
+		rc = sock_create_net(cifs_net_ns(server), sfamily, SOCK_STREAM,
+				     IPPROTO_TCP, &server->ssocket);
 		if (rc < 0) {
 			cifs_server_dbg(VFS, "Error %d creating socket\n", rc);
 			return rc;
 		}
 
-		sk = server->ssocket->sk;
-		__netns_tracker_free(net, &sk->ns_tracker, false);
-		sk->sk_net_refcnt = 1;
-		get_net_track(net, &sk->ns_tracker, GFP_KERNEL);
-		sock_inuse_add(net, 1);
-
 		/* BB other socket options to set KEEPALIVE, NODELAY? */
 		cifs_dbg(FYI, "Socket created\n");
 		socket = server->ssocket;
diff --git a/net/mptcp/subflow.c b/net/mptcp/subflow.c
index fd021cf8286e..e7e8972bdfca 100644
--- a/net/mptcp/subflow.c
+++ b/net/mptcp/subflow.c
@@ -1755,7 +1755,7 @@  int mptcp_subflow_create_socket(struct sock *sk, unsigned short family,
 	if (unlikely(!sk->sk_socket))
 		return -EINVAL;
 
-	err = sock_create_kern(net, family, SOCK_STREAM, IPPROTO_TCP, &sf);
+	err = sock_create_net(net, family, SOCK_STREAM, IPPROTO_TCP, &sf);
 	if (err)
 		return err;
 
@@ -1768,14 +1768,6 @@  int mptcp_subflow_create_socket(struct sock *sk, unsigned short family,
 	/* the newly created socket has to be in the same cgroup as its parent */
 	mptcp_attach_cgroup(sk, sf->sk);
 
-	/* kernel sockets do not by default acquire net ref, but TCP timer
-	 * needs it.
-	 * Update ns_tracker to current stack trace and refcounted tracker.
-	 */
-	__netns_tracker_free(net, &sf->sk->ns_tracker, false);
-	sf->sk->sk_net_refcnt = 1;
-	get_net_track(net, &sf->sk->ns_tracker, GFP_KERNEL);
-	sock_inuse_add(net, 1);
 	err = tcp_set_ulp(sf->sk, "mptcp");
 	if (err)
 		goto err_free;
diff --git a/net/rds/tcp.c b/net/rds/tcp.c
index 351ac1747224..4509900476f7 100644
--- a/net/rds/tcp.c
+++ b/net/rds/tcp.c
@@ -494,21 +494,7 @@  bool rds_tcp_tune(struct socket *sock)
 
 	tcp_sock_set_nodelay(sock->sk);
 	lock_sock(sk);
-	/* TCP timer functions might access net namespace even after
-	 * a process which created this net namespace terminated.
-	 */
-	if (!sk->sk_net_refcnt) {
-		if (!maybe_get_net(net)) {
-			release_sock(sk);
-			return false;
-		}
-		/* Update ns_tracker to current stack trace and refcounted tracker */
-		__netns_tracker_free(net, &sk->ns_tracker, false);
 
-		sk->sk_net_refcnt = 1;
-		netns_tracker_alloc(net, &sk->ns_tracker, GFP_KERNEL);
-		sock_inuse_add(net, 1);
-	}
 	rtn = net_generic(net, rds_tcp_netid);
 	if (rtn->sndbuf_size > 0) {
 		sk->sk_sndbuf = rtn->sndbuf_size;
diff --git a/net/rds/tcp_connect.c b/net/rds/tcp_connect.c
index a0046e99d6df..c9449780f952 100644
--- a/net/rds/tcp_connect.c
+++ b/net/rds/tcp_connect.c
@@ -93,6 +93,7 @@  int rds_tcp_conn_path_connect(struct rds_conn_path *cp)
 	struct sockaddr_in6 sin6;
 	struct sockaddr_in sin;
 	struct sockaddr *addr;
+	struct net *net;
 	int addrlen;
 	bool isv6;
 	int ret;
@@ -107,20 +108,28 @@  int rds_tcp_conn_path_connect(struct rds_conn_path *cp)
 
 	mutex_lock(&tc->t_conn_path_lock);
 
+	net = rds_conn_net(conn);
+
 	if (rds_conn_path_up(cp)) {
-		mutex_unlock(&tc->t_conn_path_lock);
-		return 0;
+		ret = 0;
+		goto out;
 	}
+
+	if (!maybe_get_net(net)) {
+		ret = -EINVAL;
+		goto out;
+	}
+
 	if (ipv6_addr_v4mapped(&conn->c_laddr)) {
-		ret = sock_create_kern(rds_conn_net(conn), PF_INET,
-				       SOCK_STREAM, IPPROTO_TCP, &sock);
+		ret = sock_create_net(net, PF_INET, SOCK_STREAM, IPPROTO_TCP, &sock);
 		isv6 = false;
 	} else {
-		ret = sock_create_kern(rds_conn_net(conn), PF_INET6,
-				       SOCK_STREAM, IPPROTO_TCP, &sock);
+		ret = sock_create_net(net, PF_INET6, SOCK_STREAM, IPPROTO_TCP, &sock);
 		isv6 = true;
 	}
 
+	put_net(net);
+
 	if (ret < 0)
 		goto out;
 
diff --git a/net/rds/tcp_listen.c b/net/rds/tcp_listen.c
index 69aaf03ab93e..440ac9057148 100644
--- a/net/rds/tcp_listen.c
+++ b/net/rds/tcp_listen.c
@@ -101,6 +101,7 @@  int rds_tcp_accept_one(struct socket *sock)
 	struct rds_connection *conn;
 	int ret;
 	struct inet_sock *inet;
+	struct net *net;
 	struct rds_tcp_connection *rs_tcp = NULL;
 	int conn_state;
 	struct rds_conn_path *cp;
@@ -108,7 +109,7 @@  int rds_tcp_accept_one(struct socket *sock)
 	struct proto_accept_arg arg = {
 		.flags = O_NONBLOCK,
 		.kern = true,
-		.hold_net = false,
+		.hold_net = true,
 	};
 #if !IS_ENABLED(CONFIG_IPV6)
 	struct in6_addr saddr, daddr;
@@ -118,13 +119,22 @@  int rds_tcp_accept_one(struct socket *sock)
 	if (!sock) /* module unload or netns delete in progress */
 		return -ENETUNREACH;
 
+	net = sock_net(sock->sk);
+
+	if (!maybe_get_net(net))
+		return -EINVAL;
+
 	ret = sock_create_lite(sock->sk->sk_family,
 			       sock->sk->sk_type, sock->sk->sk_protocol,
 			       &new_sock);
-	if (ret)
+	if (ret) {
+		put_net(net);
 		goto out;
+	}
 
 	ret = sock->ops->accept(sock, new_sock, &arg);
+	put_net(net);
+
 	if (ret < 0)
 		goto out;
 
diff --git a/net/smc/af_smc.c b/net/smc/af_smc.c
index 6e93f188a908..7b0de80b3aca 100644
--- a/net/smc/af_smc.c
+++ b/net/smc/af_smc.c
@@ -3310,25 +3310,8 @@  static const struct proto_ops smc_sock_ops = {
 
 int smc_create_clcsk(struct net *net, struct sock *sk, int family)
 {
-	struct smc_sock *smc = smc_sk(sk);
-	int rc;
-
-	rc = sock_create_kern(net, family, SOCK_STREAM, IPPROTO_TCP,
-			      &smc->clcsock);
-	if (rc)
-		return rc;
-
-	/* smc_clcsock_release() does not wait smc->clcsock->sk's
-	 * destruction;  its sk_state might not be TCP_CLOSE after
-	 * smc->sk is close()d, and TCP timers can be fired later,
-	 * which need net ref.
-	 */
-	sk = smc->clcsock->sk;
-	__netns_tracker_free(net, &sk->ns_tracker, false);
-	sk->sk_net_refcnt = 1;
-	get_net_track(net, &sk->ns_tracker, GFP_KERNEL);
-	sock_inuse_add(net, 1);
-	return 0;
+	return sock_create_net(net, family, SOCK_STREAM, IPPROTO_TCP,
+			       &smc_sk(sk)->clcsock);
 }
 
 static int __smc_create(struct net *net, struct socket *sock, int protocol,
diff --git a/net/sunrpc/svcsock.c b/net/sunrpc/svcsock.c
index 9583bad3d150..cde5765f6f81 100644
--- a/net/sunrpc/svcsock.c
+++ b/net/sunrpc/svcsock.c
@@ -1526,7 +1526,10 @@  static struct svc_xprt *svc_create_socket(struct svc_serv *serv,
 		return ERR_PTR(-EINVAL);
 	}
 
-	error = sock_create_kern(net, family, type, protocol, &sock);
+	if (protocol == IPPROTO_TCP)
+		error = sock_create_net(net, family, type, protocol, &sock);
+	else
+		error = sock_create_kern(net, family, type, protocol, &sock);
 	if (error < 0)
 		return ERR_PTR(error);
 
@@ -1551,11 +1554,8 @@  static struct svc_xprt *svc_create_socket(struct svc_serv *serv,
 	newlen = error;
 
 	if (protocol == IPPROTO_TCP) {
-		__netns_tracker_free(net, &sock->sk->ns_tracker, false);
-		sock->sk->sk_net_refcnt = 1;
-		get_net_track(net, &sock->sk->ns_tracker, GFP_KERNEL);
-		sock_inuse_add(net, 1);
-		if ((error = kernel_listen(sock, 64)) < 0)
+		error = kernel_listen(sock, 64);
+		if (error < 0)
 			goto bummer;
 	}
 
diff --git a/net/sunrpc/xprtsock.c b/net/sunrpc/xprtsock.c
index feb1768e8a57..f3e139c30442 100644
--- a/net/sunrpc/xprtsock.c
+++ b/net/sunrpc/xprtsock.c
@@ -1924,7 +1924,10 @@  static struct socket *xs_create_sock(struct rpc_xprt *xprt,
 	struct socket *sock;
 	int err;
 
-	err = sock_create_kern(xprt->xprt_net, family, type, protocol, &sock);
+	if (protocol == IPPROTO_TCP)
+		err = sock_create_net(xprt->xprt_net, family, type, protocol, &sock);
+	else
+		err = sock_create_kern(xprt->xprt_net, family, type, protocol, &sock);
 	if (err < 0) {
 		dprintk("RPC:       can't create %d transport socket (%d).\n",
 				protocol, -err);
@@ -1941,13 +1944,6 @@  static struct socket *xs_create_sock(struct rpc_xprt *xprt,
 		goto out;
 	}
 
-	if (protocol == IPPROTO_TCP) {
-		__netns_tracker_free(xprt->xprt_net, &sock->sk->ns_tracker, false);
-		sock->sk->sk_net_refcnt = 1;
-		get_net_track(xprt->xprt_net, &sock->sk->ns_tracker, GFP_KERNEL);
-		sock_inuse_add(xprt->xprt_net, 1);
-	}
-
 	filp = sock_alloc_file(sock, O_NONBLOCK, NULL);
 	if (IS_ERR(filp))
 		return ERR_CAST(filp);
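
For readers following the rds_tcp_conn_path_connect() hunk, the reference pattern it introduces can be modelled in userspace roughly as below: take a temporary netns reference with maybe_get_net(), create a socket that holds its own reference, then drop the temporary one with put_net(). This is only a sketch. Every model_* name and struct is hypothetical, the counters are plain ints rather than the kernel's refcount types, and it assumes sock_create_net() takes its own netns reference on success, as the removed sk_net_refcnt conversion code did by hand.

```c
/* Userspace model only: none of these are the kernel's definitions. */
#include <assert.h>
#include <errno.h>
#include <stdbool.h>
#include <stddef.h>

struct model_net {
	int refs;		/* 0 models a netns that is being torn down */
};

struct model_sock {
	struct model_net *net;
	bool net_refcnt;	/* models sk->sk_net_refcnt */
};

/* Models maybe_get_net(): take a reference only if one is still held. */
static struct model_net *model_maybe_get_net(struct model_net *net)
{
	if (net->refs == 0)
		return NULL;
	net->refs++;
	return net;
}

/* Models put_net(): drop one reference. */
static void model_put_net(struct model_net *net)
{
	assert(net->refs > 0);
	net->refs--;
}

/* Assumption: the created socket holds its own netns reference. */
static int model_sock_create_net(struct model_net *net, struct model_sock *sk)
{
	net->refs++;
	sk->net = net;
	sk->net_refcnt = true;
	return 0;
}

/* Models the caller-side pattern added to rds_tcp_conn_path_connect(). */
static int model_connect(struct model_net *net, struct model_sock *sk)
{
	int ret;

	if (!model_maybe_get_net(net))
		return -EINVAL;	/* netns already dying, do not create a socket */

	ret = model_sock_create_net(net, sk);
	model_put_net(net);	/* the temporary reference is no longer needed */
	return ret;
}
```

In this model, a successful model_connect() leaves the namespace with exactly one extra reference (the socket's own), while a namespace whose count has already reached zero is rejected before any socket is created.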