diff mbox series

[net-next,10/15] l2tp: refactor ppp socket/session relationship

Message ID 444e08a9ef4f955b6fb4893c3e59fc7240de8f68.1722265212.git.jchapman@katalix.com (mailing list archive)
State Accepted
Commit c5cbaef992d6420d8bcebea1b1fcc23302a67c57
Delegated to: Netdev Maintainers
Headers show
Series l2tp: simplify tunnel and session cleanup | expand

Checks

Context Check Description
netdev/series_format success Posting correctly formatted
netdev/tree_selection success Clearly marked for net-next
netdev/ynl success Generated files up to date; no warnings/errors; no diff in generated;
netdev/fixes_present success Fixes tag not required for -next series
netdev/header_inline success No static functions without inline keyword in header files
netdev/build_32bit success Errors and warnings before: 42 this patch: 42
netdev/build_tools success No tools touched, skip
netdev/cc_maintainers success CCed 5 of 5 maintainers
netdev/build_clang success Errors and warnings before: 43 this patch: 43
netdev/verify_signedoff success Signed-off-by tag matches author and committer
netdev/deprecated_api success None detected
netdev/check_selftest success No net selftest shell script
netdev/verify_fixes success No Fixes tag
netdev/build_allmodconfig_warn success Errors and warnings before: 43 this patch: 43
netdev/checkpatch success total: 0 errors, 0 warnings, 0 checks, 193 lines checked
netdev/build_clang_rust success No Rust files in patch. Skipping build
netdev/kdoc success Errors and warnings before: 0 this patch: 0
netdev/source_inline success Was 0 now: 0
netdev/contest success net-next-2024-07-29--18-00 (tests: 703)

Commit Message

James Chapman July 29, 2024, 3:38 p.m. UTC
Each l2tp ppp session has an associated pppox socket. l2tp_ppp uses
the session's pppox socket refcount to manage session lifetimes; the
pppox socket holds a ref on the session which is dropped by the socket
destructor. This complicates session cleanup.

Given l2tp sessions are refcounted, it makes more sense to reverse
this relationship such that the session keeps the socket alive, not
the other way around. So refactor l2tp_ppp to have the session hold a
ref on its socket while it references it. When the session is closed,
it drops its socket ref when it detaches from its socket. If the
socket is closed first, it initiates the closing of its session, if
one is attached. The socket/session can then be freed asynchronously
when their refcounts drop to 0.

Use the session's session_close callback to detach the pppox socket
since this will be done on the work queue together with the rest of
the session cleanup via l2tp_session_delete.

Also, since l2tp_ppp uses the pppox socket's sk_user_data, use the rcu
sk_user_data access helpers when accessing it and set the socket's
SOCK_RCU_FREE flag to have pppox sockets freed by rcu.

Signed-off-by: James Chapman <jchapman@katalix.com>
Signed-off-by: Tom Parkin <tparkin@katalix.com>
---
 net/l2tp/l2tp_ppp.c | 94 +++++++++++++++++++--------------------------
 1 file changed, 39 insertions(+), 55 deletions(-)
diff mbox series

Patch

diff --git a/net/l2tp/l2tp_ppp.c b/net/l2tp/l2tp_ppp.c
index 0844b86cd0a6..12a0a7162870 100644
--- a/net/l2tp/l2tp_ppp.c
+++ b/net/l2tp/l2tp_ppp.c
@@ -119,7 +119,6 @@  struct pppol2tp_session {
 	struct mutex		sk_lock;	/* Protects .sk */
 	struct sock __rcu	*sk;		/* Pointer to the session PPPoX socket */
 	struct sock		*__sk;		/* Copy of .sk, for cleanup */
-	struct rcu_head		rcu;		/* For asynchronous release */
 };
 
 static int pppol2tp_xmit(struct ppp_channel *chan, struct sk_buff *skb);
@@ -157,20 +156,16 @@  static inline struct l2tp_session *pppol2tp_sock_to_session(struct sock *sk)
 	if (!sk)
 		return NULL;
 
-	sock_hold(sk);
-	session = (struct l2tp_session *)(sk->sk_user_data);
-	if (!session) {
-		sock_put(sk);
-		goto out;
-	}
-	if (WARN_ON(session->magic != L2TP_SESSION_MAGIC)) {
-		session = NULL;
-		sock_put(sk);
-		goto out;
+	rcu_read_lock();
+	session = rcu_dereference_sk_user_data(sk);
+	if (session && refcount_inc_not_zero(&session->ref_count)) {
+		rcu_read_unlock();
+		WARN_ON_ONCE(session->magic != L2TP_SESSION_MAGIC);
+		return session;
 	}
+	rcu_read_unlock();
 
-out:
-	return session;
+	return NULL;
 }
 
 /*****************************************************************************
@@ -318,12 +313,12 @@  static int pppol2tp_sendmsg(struct socket *sock, struct msghdr *m,
 	l2tp_xmit_skb(session, skb);
 	local_bh_enable();
 
-	sock_put(sk);
+	l2tp_session_dec_refcount(session);
 
 	return total_len;
 
 error_put_sess:
-	sock_put(sk);
+	l2tp_session_dec_refcount(session);
 error:
 	return error;
 }
@@ -377,12 +372,12 @@  static int pppol2tp_xmit(struct ppp_channel *chan, struct sk_buff *skb)
 	l2tp_xmit_skb(session, skb);
 	local_bh_enable();
 
-	sock_put(sk);
+	l2tp_session_dec_refcount(session);
 
 	return 1;
 
 abort_put_sess:
-	sock_put(sk);
+	l2tp_session_dec_refcount(session);
 abort:
 	/* Free the original skb */
 	kfree_skb(skb);
@@ -393,28 +388,31 @@  static int pppol2tp_xmit(struct ppp_channel *chan, struct sk_buff *skb)
  * Session (and tunnel control) socket create/destroy.
  *****************************************************************************/
 
-static void pppol2tp_put_sk(struct rcu_head *head)
-{
-	struct pppol2tp_session *ps;
-
-	ps = container_of(head, typeof(*ps), rcu);
-	sock_put(ps->__sk);
-}
-
 /* Really kill the session socket. (Called from sock_put() if
  * refcnt == 0.)
  */
 static void pppol2tp_session_destruct(struct sock *sk)
 {
-	struct l2tp_session *session = sk->sk_user_data;
-
 	skb_queue_purge(&sk->sk_receive_queue);
 	skb_queue_purge(&sk->sk_write_queue);
+}
 
-	if (session) {
-		sk->sk_user_data = NULL;
-		if (WARN_ON(session->magic != L2TP_SESSION_MAGIC))
-			return;
+static void pppol2tp_session_close(struct l2tp_session *session)
+{
+	struct pppol2tp_session *ps;
+
+	ps = l2tp_session_priv(session);
+	mutex_lock(&ps->sk_lock);
+	ps->__sk = rcu_dereference_protected(ps->sk,
+					     lockdep_is_held(&ps->sk_lock));
+	RCU_INIT_POINTER(ps->sk, NULL);
+	mutex_unlock(&ps->sk_lock);
+	if (ps->__sk) {
+		/* detach socket */
+		rcu_assign_sk_user_data(ps->__sk, NULL);
+		sock_put(ps->__sk);
+
+		/* drop ref taken when we referenced socket via sk_user_data */
 		l2tp_session_dec_refcount(session);
 	}
 }
@@ -444,30 +442,13 @@  static int pppol2tp_release(struct socket *sock)
 
 	session = pppol2tp_sock_to_session(sk);
 	if (session) {
-		struct pppol2tp_session *ps;
-
 		l2tp_session_delete(session);
-
-		ps = l2tp_session_priv(session);
-		mutex_lock(&ps->sk_lock);
-		ps->__sk = rcu_dereference_protected(ps->sk,
-						     lockdep_is_held(&ps->sk_lock));
-		RCU_INIT_POINTER(ps->sk, NULL);
-		mutex_unlock(&ps->sk_lock);
-		call_rcu(&ps->rcu, pppol2tp_put_sk);
-
-		/* Rely on the sock_put() call at the end of the function for
-		 * dropping the reference held by pppol2tp_sock_to_session().
-		 * The last reference will be dropped by pppol2tp_put_sk().
-		 */
+		/* drop ref taken by pppol2tp_sock_to_session */
+		l2tp_session_dec_refcount(session);
 	}
 
 	release_sock(sk);
 
-	/* This will delete the session context via
-	 * pppol2tp_session_destruct() if the socket's refcnt drops to
-	 * zero.
-	 */
 	sock_put(sk);
 
 	return 0;
@@ -506,6 +487,7 @@  static int pppol2tp_create(struct net *net, struct socket *sock, int kern)
 		goto out;
 
 	sock_init_data(sock, sk);
+	sock_set_flag(sk, SOCK_RCU_FREE);
 
 	sock->state  = SS_UNCONNECTED;
 	sock->ops    = &pppol2tp_ops;
@@ -542,6 +524,7 @@  static void pppol2tp_session_init(struct l2tp_session *session)
 	struct pppol2tp_session *ps;
 
 	session->recv_skb = pppol2tp_recv;
+	session->session_close = pppol2tp_session_close;
 	if (IS_ENABLED(CONFIG_L2TP_DEBUGFS))
 		session->show = pppol2tp_show;
 
@@ -830,12 +813,13 @@  static int pppol2tp_connect(struct socket *sock, struct sockaddr *uservaddr,
 
 out_no_ppp:
 	/* This is how we get the session context from the socket. */
-	sk->sk_user_data = session;
+	sock_hold(sk);
+	rcu_assign_sk_user_data(sk, session);
 	rcu_assign_pointer(ps->sk, sk);
 	mutex_unlock(&ps->sk_lock);
 
 	/* Keep the reference we've grabbed on the session: sk doesn't expect
-	 * the session to disappear. pppol2tp_session_destruct() is responsible
+	 * the session to disappear. pppol2tp_session_close() is responsible
 	 * for dropping it.
 	 */
 	drop_refcnt = false;
@@ -1002,7 +986,7 @@  static int pppol2tp_getname(struct socket *sock, struct sockaddr *uaddr,
 
 	error = len;
 
-	sock_put(sk);
+	l2tp_session_dec_refcount(session);
 end:
 	return error;
 }
@@ -1274,7 +1258,7 @@  static int pppol2tp_setsockopt(struct socket *sock, int level, int optname,
 		err = pppol2tp_session_setsockopt(sk, session, optname, val);
 	}
 
-	sock_put(sk);
+	l2tp_session_dec_refcount(session);
 end:
 	return err;
 }
@@ -1395,7 +1379,7 @@  static int pppol2tp_getsockopt(struct socket *sock, int level, int optname,
 	err = 0;
 
 end_put_sess:
-	sock_put(sk);
+	l2tp_session_dec_refcount(session);
 end:
 	return err;
 }