diff mbox series

[net-next,03/10] mctp: locking, lifetime and validity changes for sk_keys

Message ID 20210929072614.854015-4-matt@codeconstruct.com.au (mailing list archive)
State Accepted
Commit 73c618456dc5cf2acb597256d633060cf75de8d6
Delegated to: Netdev Maintainers
Headers show
Series Updates to MCTP core | expand

Checks

Context Check Description
netdev/cover_letter success Link
netdev/fixes_present success Link
netdev/patch_count success Link
netdev/tree_selection success Clearly marked for net-next
netdev/subject_prefix success Link
netdev/cc_maintainers success CCed 5 of 5 maintainers
netdev/source_inline success Was 0 now: 0
netdev/verify_signedoff success Link
netdev/module_param success Was 0 now: 0
netdev/build_32bit fail Errors and warnings before: 0 this patch: 2
netdev/kdoc success Errors and warnings before: 0 this patch: 0
netdev/verify_fixes success Link
netdev/checkpatch success total: 0 errors, 0 warnings, 0 checks, 367 lines checked
netdev/build_allmodconfig_warn fail Errors and warnings before: 0 this patch: 2
netdev/header_inline success Link

Commit Message

Matt Johnston Sept. 29, 2021, 7:26 a.m. UTC
From: Jeremy Kerr <jk@codeconstruct.com.au>

We will want to invalidate sk_keys in a future change, which will
require a boolean flag to mark invalidated items in the socket & net
namespace lists. We'll also need to take a reference to keys, held over
non-atomic contexts, so we need a refcount on keys also.

This change adds a validity flag (currently always true) and refcount to
struct mctp_sk_key.  With a refcount on the keys, using RCU no longer
makes much sense; we have exact indications on the lifetime of keys. So,
we also change the RCU list traversal to a locked implementation.

Signed-off-by: Jeremy Kerr <jk@codeconstruct.com.au>
---
 include/net/mctp.h |  46 ++++++++++++------
 net/mctp/af_mctp.c |  14 +++---
 net/mctp/route.c   | 118 +++++++++++++++++++++++++++++++++------------
 3 files changed, 125 insertions(+), 53 deletions(-)
diff mbox series

Patch

diff --git a/include/net/mctp.h b/include/net/mctp.h
index a824d47c3c6d..bf783dc3ea45 100644
--- a/include/net/mctp.h
+++ b/include/net/mctp.h
@@ -67,30 +67,36 @@  struct mctp_sock {
 /* Key for matching incoming packets to sockets or reassembly contexts.
  * Packets are matched on (src,dest,tag).
  *
- * Lifetime requirements:
+ * Lifetime / locking requirements:
  *
- *  - keys are free()ed via RCU
+ *  - individual key data (ie, the struct itself) is protected by key->lock;
+ *    changes must be made with that lock held.
+ *
+ *  - the lookup fields: peer_addr, local_addr and tag are set before the
+ *    key is added to lookup lists, and never updated.
+ *
+ *  - A ref to the key must be held (throuh key->refs) if a pointer to the
+ *    key is to be accessed after key->lock is released.
  *
  *  - a mctp_sk_key contains a reference to a struct sock; this is valid
  *    for the life of the key. On sock destruction (through unhash), the key is
- *    removed from lists (see below), and will not be observable after a RCU
- *    grace period.
- *
- *    any RX occurring within that grace period may still queue to the socket,
- *    but will hit the SOCK_DEAD case before the socket is freed.
+ *    removed from lists (see below), and marked invalid.
  *
  * - these mctp_sk_keys appear on two lists:
  *     1) the struct mctp_sock->keys list
  *     2) the struct netns_mctp->keys list
  *
- *        updates to either list are performed under the netns_mctp->keys
- *        lock.
+ *   presences on these lists requires a (single) refcount to be held; both
+ *   lists are updated as a single operation.
+ *
+ *   Updates and lookups in either list are performed under the
+ *   netns_mctp->keys lock. Lookup functions will need to lock the key and
+ *   take a reference before unlocking the keys_lock. Consequently, the list's
+ *   keys_lock *cannot* be acquired with the individual key->lock held.
  *
  * - a key may have a sk_buff attached as part of an in-progress message
- *   reassembly (->reasm_head). The reassembly context is protected by
- *   reasm_lock, which may be acquired with the keys lock (above) held, if
- *   necessary. Consequently, keys lock *cannot* be acquired with the
- *   reasm_lock held.
+ *   reassembly (->reasm_head). The reasm data is protected by the individual
+ *   key->lock.
  *
  * - there are two destruction paths for a mctp_sk_key:
  *
@@ -116,14 +122,22 @@  struct mctp_sk_key {
 	/* per-socket list */
 	struct hlist_node sklist;
 
+	/* lock protects against concurrent updates to the reassembly and
+	 * expiry data below.
+	 */
+	spinlock_t	lock;
+
+	/* Keys are referenced during the output path, which may sleep */
+	refcount_t	refs;
+
 	/* incoming fragment reassembly context */
-	spinlock_t	reasm_lock;
 	struct sk_buff	*reasm_head;
 	struct sk_buff	**reasm_tailp;
 	bool		reasm_dead;
 	u8		last_seq;
 
-	struct rcu_head	rcu;
+	/* key validity */
+	bool		valid;
 };
 
 struct mctp_skb_cb {
@@ -191,6 +205,8 @@  int mctp_do_route(struct mctp_route *rt, struct sk_buff *skb);
 int mctp_local_output(struct sock *sk, struct mctp_route *rt,
 		      struct sk_buff *skb, mctp_eid_t daddr, u8 req_tag);
 
+void mctp_key_unref(struct mctp_sk_key *key);
+
 /* routing <--> device interface */
 unsigned int mctp_default_net(struct net *net);
 int mctp_default_net_set(struct net *net, unsigned int index);
diff --git a/net/mctp/af_mctp.c b/net/mctp/af_mctp.c
index a9526ac29dff..2767d548736b 100644
--- a/net/mctp/af_mctp.c
+++ b/net/mctp/af_mctp.c
@@ -263,21 +263,21 @@  static void mctp_sk_unhash(struct sock *sk)
 	/* remove tag allocations */
 	spin_lock_irqsave(&net->mctp.keys_lock, flags);
 	hlist_for_each_entry_safe(key, tmp, &msk->keys, sklist) {
-		hlist_del_rcu(&key->sklist);
-		hlist_del_rcu(&key->hlist);
+		hlist_del(&key->sklist);
+		hlist_del(&key->hlist);
 
-		spin_lock(&key->reasm_lock);
+		spin_lock(&key->lock);
 		if (key->reasm_head)
 			kfree_skb(key->reasm_head);
 		key->reasm_head = NULL;
 		key->reasm_dead = true;
-		spin_unlock(&key->reasm_lock);
+		key->valid = false;
+		spin_unlock(&key->lock);
 
-		kfree_rcu(key, rcu);
+		/* key is no longer on the lookup lists, unref */
+		mctp_key_unref(key);
 	}
 	spin_unlock_irqrestore(&net->mctp.keys_lock, flags);
-
-	synchronize_rcu();
 }
 
 static struct proto mctp_proto = {
diff --git a/net/mctp/route.c b/net/mctp/route.c
index 224fd25b3678..b2243b150e71 100644
--- a/net/mctp/route.c
+++ b/net/mctp/route.c
@@ -83,25 +83,43 @@  static bool mctp_key_match(struct mctp_sk_key *key, mctp_eid_t local,
 	return true;
 }
 
+/* returns a key (with key->lock held, and refcounted), or NULL if no such
+ * key exists.
+ */
 static struct mctp_sk_key *mctp_lookup_key(struct net *net, struct sk_buff *skb,
-					   mctp_eid_t peer)
+					   mctp_eid_t peer,
+					   unsigned long *irqflags)
+	__acquires(&key->lock)
 {
 	struct mctp_sk_key *key, *ret;
+	unsigned long flags;
 	struct mctp_hdr *mh;
 	u8 tag;
 
-	WARN_ON(!rcu_read_lock_held());
-
 	mh = mctp_hdr(skb);
 	tag = mh->flags_seq_tag & (MCTP_HDR_TAG_MASK | MCTP_HDR_FLAG_TO);
 
 	ret = NULL;
+	spin_lock_irqsave(&net->mctp.keys_lock, flags);
 
-	hlist_for_each_entry_rcu(key, &net->mctp.keys, hlist) {
-		if (mctp_key_match(key, mh->dest, peer, tag)) {
+	hlist_for_each_entry(key, &net->mctp.keys, hlist) {
+		if (!mctp_key_match(key, mh->dest, peer, tag))
+			continue;
+
+		spin_lock(&key->lock);
+		if (key->valid) {
+			refcount_inc(&key->refs);
 			ret = key;
 			break;
 		}
+		spin_unlock(&key->lock);
+	}
+
+	if (ret) {
+		spin_unlock(&net->mctp.keys_lock);
+		*irqflags = flags;
+	} else {
+		spin_unlock_irqrestore(&net->mctp.keys_lock, flags);
 	}
 
 	return ret;
@@ -121,11 +139,19 @@  static struct mctp_sk_key *mctp_key_alloc(struct mctp_sock *msk,
 	key->local_addr = local;
 	key->tag = tag;
 	key->sk = &msk->sk;
-	spin_lock_init(&key->reasm_lock);
+	key->valid = true;
+	spin_lock_init(&key->lock);
+	refcount_set(&key->refs, 1);
 
 	return key;
 }
 
+void mctp_key_unref(struct mctp_sk_key *key)
+{
+	if (refcount_dec_and_test(&key->refs))
+		kfree(key);
+}
+
 static int mctp_key_add(struct mctp_sk_key *key, struct mctp_sock *msk)
 {
 	struct net *net = sock_net(&msk->sk);
@@ -138,12 +164,17 @@  static int mctp_key_add(struct mctp_sk_key *key, struct mctp_sock *msk)
 	hlist_for_each_entry(tmp, &net->mctp.keys, hlist) {
 		if (mctp_key_match(tmp, key->local_addr, key->peer_addr,
 				   key->tag)) {
-			rc = -EEXIST;
-			break;
+			spin_lock(&tmp->lock);
+			if (tmp->valid)
+				rc = -EEXIST;
+			spin_unlock(&tmp->lock);
+			if (rc)
+				break;
 		}
 	}
 
 	if (!rc) {
+		refcount_inc(&key->refs);
 		hlist_add_head(&key->hlist, &net->mctp.keys);
 		hlist_add_head(&key->sklist, &msk->keys);
 	}
@@ -153,28 +184,35 @@  static int mctp_key_add(struct mctp_sk_key *key, struct mctp_sock *msk)
 	return rc;
 }
 
-/* Must be called with key->reasm_lock, which it will release. Will schedule
- * the key for an RCU free.
+/* We're done with the key; unset valid and remove from lists. There may still
+ * be outstanding refs on the key though...
  */
 static void __mctp_key_unlock_drop(struct mctp_sk_key *key, struct net *net,
 				   unsigned long flags)
-	__releases(&key->reasm_lock)
+	__releases(&key->lock)
 {
 	struct sk_buff *skb;
 
 	skb = key->reasm_head;
 	key->reasm_head = NULL;
 	key->reasm_dead = true;
-	spin_unlock_irqrestore(&key->reasm_lock, flags);
+	key->valid = false;
+	spin_unlock_irqrestore(&key->lock, flags);
 
 	spin_lock_irqsave(&net->mctp.keys_lock, flags);
-	hlist_del_rcu(&key->hlist);
-	hlist_del_rcu(&key->sklist);
+	hlist_del(&key->hlist);
+	hlist_del(&key->sklist);
 	spin_unlock_irqrestore(&net->mctp.keys_lock, flags);
-	kfree_rcu(key, rcu);
+
+	/* one unref for the lists */
+	mctp_key_unref(key);
+
+	/* and one for the local reference */
+	mctp_key_unref(key);
 
 	if (skb)
 		kfree_skb(skb);
+
 }
 
 static int mctp_frag_queue(struct mctp_sk_key *key, struct sk_buff *skb)
@@ -248,8 +286,10 @@  static int mctp_route_input(struct mctp_route *route, struct sk_buff *skb)
 
 	rcu_read_lock();
 
-	/* lookup socket / reasm context, exactly matching (src,dest,tag) */
-	key = mctp_lookup_key(net, skb, mh->src);
+	/* lookup socket / reasm context, exactly matching (src,dest,tag).
+	 * we hold a ref on the key, and key->lock held.
+	 */
+	key = mctp_lookup_key(net, skb, mh->src, &f);
 
 	if (flags & MCTP_HDR_FLAG_SOM) {
 		if (key) {
@@ -260,10 +300,12 @@  static int mctp_route_input(struct mctp_route *route, struct sk_buff *skb)
 			 * key for reassembly - we'll create a more specific
 			 * one for future packets if required (ie, !EOM).
 			 */
-			key = mctp_lookup_key(net, skb, MCTP_ADDR_ANY);
+			key = mctp_lookup_key(net, skb, MCTP_ADDR_ANY, &f);
 			if (key) {
 				msk = container_of(key->sk,
 						   struct mctp_sock, sk);
+				spin_unlock_irqrestore(&key->lock, f);
+				mctp_key_unref(key);
 				key = NULL;
 			}
 		}
@@ -282,11 +324,11 @@  static int mctp_route_input(struct mctp_route *route, struct sk_buff *skb)
 		if (flags & MCTP_HDR_FLAG_EOM) {
 			sock_queue_rcv_skb(&msk->sk, skb);
 			if (key) {
-				spin_lock_irqsave(&key->reasm_lock, f);
 				/* we've hit a pending reassembly; not much we
 				 * can do but drop it
 				 */
 				__mctp_key_unlock_drop(key, net, f);
+				key = NULL;
 			}
 			rc = 0;
 			goto out_unlock;
@@ -303,7 +345,7 @@  static int mctp_route_input(struct mctp_route *route, struct sk_buff *skb)
 				goto out_unlock;
 			}
 
-			/* we can queue without the reasm lock here, as the
+			/* we can queue without the key lock here, as the
 			 * key isn't observable yet
 			 */
 			mctp_frag_queue(key, skb);
@@ -318,17 +360,17 @@  static int mctp_route_input(struct mctp_route *route, struct sk_buff *skb)
 			if (rc)
 				kfree(key);
 
-		} else {
-			/* existing key: start reassembly */
-			spin_lock_irqsave(&key->reasm_lock, f);
+			/* we don't need to release key->lock on exit */
+			key = NULL;
 
+		} else {
 			if (key->reasm_head || key->reasm_dead) {
 				/* duplicate start? drop everything */
 				__mctp_key_unlock_drop(key, net, f);
 				rc = -EEXIST;
+				key = NULL;
 			} else {
 				rc = mctp_frag_queue(key, skb);
-				spin_unlock_irqrestore(&key->reasm_lock, f);
 			}
 		}
 
@@ -337,8 +379,6 @@  static int mctp_route_input(struct mctp_route *route, struct sk_buff *skb)
 		 * using the message-specific key
 		 */
 
-		spin_lock_irqsave(&key->reasm_lock, f);
-
 		/* we need to be continuing an existing reassembly... */
 		if (!key->reasm_head)
 			rc = -EINVAL;
@@ -352,8 +392,7 @@  static int mctp_route_input(struct mctp_route *route, struct sk_buff *skb)
 			sock_queue_rcv_skb(key->sk, key->reasm_head);
 			key->reasm_head = NULL;
 			__mctp_key_unlock_drop(key, net, f);
-		} else {
-			spin_unlock_irqrestore(&key->reasm_lock, f);
+			key = NULL;
 		}
 
 	} else {
@@ -363,6 +402,10 @@  static int mctp_route_input(struct mctp_route *route, struct sk_buff *skb)
 
 out_unlock:
 	rcu_read_unlock();
+	if (key) {
+		spin_unlock_irqrestore(&key->lock, f);
+		mctp_key_unref(key);
+	}
 out:
 	if (rc)
 		kfree_skb(skb);
@@ -459,6 +502,7 @@  static void mctp_reserve_tag(struct net *net, struct mctp_sk_key *key,
 	 */
 	hlist_add_head_rcu(&key->hlist, &mns->keys);
 	hlist_add_head_rcu(&key->sklist, &msk->keys);
+	refcount_inc(&key->refs);
 }
 
 /* Allocate a locally-owned tag value for (saddr, daddr), and reserve
@@ -492,14 +536,26 @@  static int mctp_alloc_local_tag(struct mctp_sock *msk,
 	 * tags. If we find a conflict, clear that bit from tagbits
 	 */
 	hlist_for_each_entry(tmp, &mns->keys, hlist) {
+		/* We can check the lookup fields (*_addr, tag) without the
+		 * lock held, they don't change over the lifetime of the key.
+		 */
+
 		/* if we don't own the tag, it can't conflict */
 		if (tmp->tag & MCTP_HDR_FLAG_TO)
 			continue;
 
-		if ((tmp->peer_addr == daddr ||
-		     tmp->peer_addr == MCTP_ADDR_ANY) &&
-		    tmp->local_addr == saddr)
+		if (!((tmp->peer_addr == daddr ||
+		       tmp->peer_addr == MCTP_ADDR_ANY) &&
+		       tmp->local_addr == saddr))
+			continue;
+
+		spin_lock(&tmp->lock);
+		/* key must still be valid. If we find a match, clear the
+		 * potential tag value
+		 */
+		if (tmp->valid)
 			tagbits &= ~(1 << tmp->tag);
+		spin_unlock(&tmp->lock);
 
 		if (!tagbits)
 			break;