diff mbox series

[RFC,10/15] l2tp: refactor ppp socket/session relationship

Message ID 50d03c0ac3a67542fa876d53812b3d96b62e3594.1721733730.git.jchapman@katalix.com (mailing list archive)
State Superseded
Delegated to: Netdev Maintainers
Headers show
Series l2tp: simplify tunnel and session cleanup | expand

Checks

Context Check Description
netdev/tree_selection success Guessing tree name failed - patch did not apply

Commit Message

James Chapman July 23, 2024, 1:51 p.m. UTC
Each l2tp ppp session has an associated pppox socket. l2tp_ppp uses
the session's pppox socket refcount to manage session lifetimes; the
pppox socket holds a ref on the session which is dropped by the socket
destructor. This complicates session cleanup.

Given l2tp sessions are refcounted, it makes more sense to reverse
this relationship such that the session keeps the socket alive, not
the other way around. So refactor l2tp_ppp to have the session hold a
ref on its socket while it references it. When the session is closed,
it drops its socket ref when it detaches from its socket. If the
socket is closed first, it initiates the closing of its session, if
one is attached. The socket/session can then be freed asynchronously
when their refcounts drop to 0.

Use the session's session_close callback to detach the pppox socket
since this will be done on the work queue together with the rest of
the session cleanup via l2tp_session_delete.

Also, since l2tp_ppp uses the pppox socket's sk_user_data, use the rcu
sk_user_data access helpers when accessing it and set the socket's
SOCK_RCU_FREE flag to have pppox sockets freed by rcu.
---
 net/l2tp/l2tp_ppp.c | 94 +++++++++++++++++++--------------------------
 1 file changed, 39 insertions(+), 55 deletions(-)
diff mbox series

Patch

diff --git a/net/l2tp/l2tp_ppp.c b/net/l2tp/l2tp_ppp.c
index 0844b86cd0a6..12a0a7162870 100644
--- a/net/l2tp/l2tp_ppp.c
+++ b/net/l2tp/l2tp_ppp.c
@@ -119,7 +119,6 @@  struct pppol2tp_session {
 	struct mutex		sk_lock;	/* Protects .sk */
 	struct sock __rcu	*sk;		/* Pointer to the session PPPoX socket */
 	struct sock		*__sk;		/* Copy of .sk, for cleanup */
-	struct rcu_head		rcu;		/* For asynchronous release */
 };
 
 static int pppol2tp_xmit(struct ppp_channel *chan, struct sk_buff *skb);
@@ -157,20 +156,16 @@  static inline struct l2tp_session *pppol2tp_sock_to_session(struct sock *sk)
 	if (!sk)
 		return NULL;
 
-	sock_hold(sk);
-	session = (struct l2tp_session *)(sk->sk_user_data);
-	if (!session) {
-		sock_put(sk);
-		goto out;
-	}
-	if (WARN_ON(session->magic != L2TP_SESSION_MAGIC)) {
-		session = NULL;
-		sock_put(sk);
-		goto out;
+	rcu_read_lock();
+	session = rcu_dereference_sk_user_data(sk);
+	if (session && refcount_inc_not_zero(&session->ref_count)) {
+		rcu_read_unlock();
+		WARN_ON_ONCE(session->magic != L2TP_SESSION_MAGIC);
+		return session;
 	}
+	rcu_read_unlock();
 
-out:
-	return session;
+	return NULL;
 }
 
 /*****************************************************************************
@@ -318,12 +313,12 @@  static int pppol2tp_sendmsg(struct socket *sock, struct msghdr *m,
 	l2tp_xmit_skb(session, skb);
 	local_bh_enable();
 
-	sock_put(sk);
+	l2tp_session_dec_refcount(session);
 
 	return total_len;
 
 error_put_sess:
-	sock_put(sk);
+	l2tp_session_dec_refcount(session);
 error:
 	return error;
 }
@@ -377,12 +372,12 @@  static int pppol2tp_xmit(struct ppp_channel *chan, struct sk_buff *skb)
 	l2tp_xmit_skb(session, skb);
 	local_bh_enable();
 
-	sock_put(sk);
+	l2tp_session_dec_refcount(session);
 
 	return 1;
 
 abort_put_sess:
-	sock_put(sk);
+	l2tp_session_dec_refcount(session);
 abort:
 	/* Free the original skb */
 	kfree_skb(skb);
@@ -393,28 +388,31 @@  static int pppol2tp_xmit(struct ppp_channel *chan, struct sk_buff *skb)
  * Session (and tunnel control) socket create/destroy.
  *****************************************************************************/
 
-static void pppol2tp_put_sk(struct rcu_head *head)
-{
-	struct pppol2tp_session *ps;
-
-	ps = container_of(head, typeof(*ps), rcu);
-	sock_put(ps->__sk);
-}
-
 /* Really kill the session socket. (Called from sock_put() if
  * refcnt == 0.)
  */
 static void pppol2tp_session_destruct(struct sock *sk)
 {
-	struct l2tp_session *session = sk->sk_user_data;
-
 	skb_queue_purge(&sk->sk_receive_queue);
 	skb_queue_purge(&sk->sk_write_queue);
+}
 
-	if (session) {
-		sk->sk_user_data = NULL;
-		if (WARN_ON(session->magic != L2TP_SESSION_MAGIC))
-			return;
+static void pppol2tp_session_close(struct l2tp_session *session)
+{
+	struct pppol2tp_session *ps;
+
+	ps = l2tp_session_priv(session);
+	mutex_lock(&ps->sk_lock);
+	ps->__sk = rcu_dereference_protected(ps->sk,
+					     lockdep_is_held(&ps->sk_lock));
+	RCU_INIT_POINTER(ps->sk, NULL);
+	mutex_unlock(&ps->sk_lock);
+	if (ps->__sk) {
+		/* detach socket */
+		rcu_assign_sk_user_data(ps->__sk, NULL);
+		sock_put(ps->__sk);
+
+		/* drop ref taken when we referenced socket via sk_user_data */
 		l2tp_session_dec_refcount(session);
 	}
 }
@@ -444,30 +442,13 @@  static int pppol2tp_release(struct socket *sock)
 
 	session = pppol2tp_sock_to_session(sk);
 	if (session) {
-		struct pppol2tp_session *ps;
-
 		l2tp_session_delete(session);
-
-		ps = l2tp_session_priv(session);
-		mutex_lock(&ps->sk_lock);
-		ps->__sk = rcu_dereference_protected(ps->sk,
-						     lockdep_is_held(&ps->sk_lock));
-		RCU_INIT_POINTER(ps->sk, NULL);
-		mutex_unlock(&ps->sk_lock);
-		call_rcu(&ps->rcu, pppol2tp_put_sk);
-
-		/* Rely on the sock_put() call at the end of the function for
-		 * dropping the reference held by pppol2tp_sock_to_session().
-		 * The last reference will be dropped by pppol2tp_put_sk().
-		 */
+		/* drop ref taken by pppol2tp_sock_to_session */
+		l2tp_session_dec_refcount(session);
 	}
 
 	release_sock(sk);
 
-	/* This will delete the session context via
-	 * pppol2tp_session_destruct() if the socket's refcnt drops to
-	 * zero.
-	 */
 	sock_put(sk);
 
 	return 0;
@@ -506,6 +487,7 @@  static int pppol2tp_create(struct net *net, struct socket *sock, int kern)
 		goto out;
 
 	sock_init_data(sock, sk);
+	sock_set_flag(sk, SOCK_RCU_FREE);
 
 	sock->state  = SS_UNCONNECTED;
 	sock->ops    = &pppol2tp_ops;
@@ -542,6 +524,7 @@  static void pppol2tp_session_init(struct l2tp_session *session)
 	struct pppol2tp_session *ps;
 
 	session->recv_skb = pppol2tp_recv;
+	session->session_close = pppol2tp_session_close;
 	if (IS_ENABLED(CONFIG_L2TP_DEBUGFS))
 		session->show = pppol2tp_show;
 
@@ -830,12 +813,13 @@  static int pppol2tp_connect(struct socket *sock, struct sockaddr *uservaddr,
 
 out_no_ppp:
 	/* This is how we get the session context from the socket. */
-	sk->sk_user_data = session;
+	sock_hold(sk);
+	rcu_assign_sk_user_data(sk, session);
 	rcu_assign_pointer(ps->sk, sk);
 	mutex_unlock(&ps->sk_lock);
 
 	/* Keep the reference we've grabbed on the session: sk doesn't expect
-	 * the session to disappear. pppol2tp_session_destruct() is responsible
+	 * the session to disappear. pppol2tp_session_close() is responsible
 	 * for dropping it.
 	 */
 	drop_refcnt = false;
@@ -1002,7 +986,7 @@  static int pppol2tp_getname(struct socket *sock, struct sockaddr *uaddr,
 
 	error = len;
 
-	sock_put(sk);
+	l2tp_session_dec_refcount(session);
 end:
 	return error;
 }
@@ -1274,7 +1258,7 @@  static int pppol2tp_setsockopt(struct socket *sock, int level, int optname,
 		err = pppol2tp_session_setsockopt(sk, session, optname, val);
 	}
 
-	sock_put(sk);
+	l2tp_session_dec_refcount(session);
 end:
 	return err;
 }
@@ -1395,7 +1379,7 @@  static int pppol2tp_getsockopt(struct socket *sock, int level, int optname,
 	err = 0;
 
 end_put_sess:
-	sock_put(sk);
+	l2tp_session_dec_refcount(session);
 end:
 	return err;
 }