diff mbox series

[net-next,v3,12/16] mctp: Implement message fragmentation & reassembly

Message ID 20210723082932.3570396-13-jk@codeconstruct.com.au (mailing list archive)
State Superseded
Delegated to: Netdev Maintainers
Headers show
Series Add Management Component Transport Protocol support | expand

Checks

Context Check Description
netdev/cover_letter success Link
netdev/fixes_present success Link
netdev/patch_count fail Series longer than 15 patches
netdev/tree_selection success Clearly marked for net-next
netdev/subject_prefix success Link
netdev/cc_maintainers warning 2 maintainers not CCed: davem@davemloft.net kuba@kernel.org
netdev/source_inline success Was 0 now: 0
netdev/verify_signedoff success Link
netdev/module_param success Was 0 now: 0
netdev/build_32bit fail Errors and warnings before: 4 this patch: 6
netdev/kdoc success Errors and warnings before: 0 this patch: 0
netdev/verify_fixes success Link
netdev/checkpatch success total: 0 errors, 0 warnings, 0 checks, 508 lines checked
netdev/build_allmodconfig_warn fail Errors and warnings before: 4 this patch: 6
netdev/header_inline success Link

Commit Message

Jeremy Kerr July 23, 2021, 8:29 a.m. UTC
This change implements MCTP fragmentation (based on route & device MTU),
and corresponding reassembly.

The MCTP specification only allows for fragmentation on the originating
message endpoint, and reassembly on the destination endpoint -
intermediate nodes do not need to reassemble/refragment.  Consequently,
we only fragment in the local transmit path, and reassemble
locally-bound packets. Messages are required to be in-order, so we
simply cancel reassembly on out-of-order or missing packets.

In the fragmentation path, we just break up the message into MTU-sized
fragments; the skb structure is a simple copy for now, which we can later
improve with a shared data implementation.

For reassembly, we keep track of incoming message fragments using the
existing tag infrastructure, allocating a key on the (src,dest,tag)
tuple, and reassembles matching fragments into a skb->frag_list.

Signed-off-by: Jeremy Kerr <jk@codeconstruct.com.au>

---
v2:
 - limit max reassembly size
v3:
 - fix comment typos
---
 include/net/mctp.h |  25 ++-
 net/mctp/af_mctp.c |   8 +
 net/mctp/route.c   | 371 ++++++++++++++++++++++++++++++++++++++++-----
 3 files changed, 360 insertions(+), 44 deletions(-)
diff mbox series

Patch

diff --git a/include/net/mctp.h b/include/net/mctp.h
index 381d71983b78..350facde2ceb 100644
--- a/include/net/mctp.h
+++ b/include/net/mctp.h
@@ -84,9 +84,21 @@  struct mctp_sock {
  *        updates to either list are performed under the netns_mctp->keys
  *        lock.
  *
- * - there is a single destruction path for a mctp_sk_key - through socket
- *   unhash (see mctp_sk_unhash). This performs the list removal under
- *   keys_lock.
+ * - a key may have a sk_buff attached as part of an in-progress message
+ *   reassembly (->reasm_head). The reassembly context is protected by
+ *   reasm_lock, which may be acquired with the keys lock (above) held, if
+ *   necessary. Consequently, keys lock *cannot* be acquired with the
+ *   reasm_lock held.
+ *
+ * - there are two destruction paths for a mctp_sk_key:
+ *
+ *    - through socket unhash (see mctp_sk_unhash). This performs the list
+ *      removal under keys_lock.
+ *
+ *    - where a key is established to receive a reply message: after receiving
+ *      the (complete) reply, or during reassembly errors. Here, we clean up
+ *      the reassembly context (marking reasm_dead, to prevent another from
+ *      starting), and remove the socket from the netns & socket lists.
  */
 struct mctp_sk_key {
 	mctp_eid_t	peer_addr;
@@ -102,6 +114,13 @@  struct mctp_sk_key {
 	/* per-socket list */
 	struct hlist_node sklist;
 
+	/* incoming fragment reassembly context */
+	spinlock_t	reasm_lock;
+	struct sk_buff	*reasm_head;
+	struct sk_buff	**reasm_tailp;
+	bool		reasm_dead;
+	u8		last_seq;
+
 	struct rcu_head	rcu;
 };
 
diff --git a/net/mctp/af_mctp.c b/net/mctp/af_mctp.c
index 52bd7f2b78db..9ca836df19d0 100644
--- a/net/mctp/af_mctp.c
+++ b/net/mctp/af_mctp.c
@@ -263,6 +263,14 @@  static void mctp_sk_unhash(struct sock *sk)
 	hlist_for_each_entry_safe(key, tmp, &msk->keys, sklist) {
 		hlist_del_rcu(&key->sklist);
 		hlist_del_rcu(&key->hlist);
+
+		spin_lock(&key->reasm_lock);
+		if (key->reasm_head)
+			kfree_skb(key->reasm_head);
+		key->reasm_head = NULL;
+		key->reasm_dead = true;
+		spin_unlock(&key->reasm_lock);
+
 		kfree_rcu(key, rcu);
 	}
 	spin_unlock_irqrestore(&net->mctp.keys_lock, flags);
diff --git a/net/mctp/route.c b/net/mctp/route.c
index 9f371c1914c4..b961f43b5fbd 100644
--- a/net/mctp/route.c
+++ b/net/mctp/route.c
@@ -23,6 +23,8 @@ 
 #include <net/netlink.h>
 #include <net/sock.h>
 
+static const unsigned int mctp_message_maxlen = 64 * 1024;
+
 /* route output callbacks */
 static int mctp_route_discard(struct mctp_route *route, struct sk_buff *skb)
 {
@@ -105,14 +107,125 @@  static struct mctp_sk_key *mctp_lookup_key(struct net *net, struct sk_buff *skb,
 	return ret;
 }
 
+static struct mctp_sk_key *mctp_key_alloc(struct mctp_sock *msk,
+					  mctp_eid_t local, mctp_eid_t peer,
+					  u8 tag, gfp_t gfp)
+{
+	struct mctp_sk_key *key;
+
+	key = kzalloc(sizeof(*key), gfp);
+	if (!key)
+		return NULL;
+
+	key->peer_addr = peer;
+	key->local_addr = local;
+	key->tag = tag;
+	key->sk = &msk->sk;
+	spin_lock_init(&key->reasm_lock);
+
+	return key;
+}
+
+static int mctp_key_add(struct mctp_sk_key *key, struct mctp_sock *msk)
+{
+	struct net *net = sock_net(&msk->sk);
+	struct mctp_sk_key *tmp;
+	unsigned long flags;
+	int rc = 0;
+
+	spin_lock_irqsave(&net->mctp.keys_lock, flags);
+
+	hlist_for_each_entry(tmp, &net->mctp.keys, hlist) {
+		if (mctp_key_match(tmp, key->local_addr, key->peer_addr,
+				   key->tag)) {
+			rc = -EEXIST;
+			break;
+		}
+	}
+
+	if (!rc) {
+		hlist_add_head(&key->hlist, &net->mctp.keys);
+		hlist_add_head(&key->sklist, &msk->keys);
+	}
+
+	spin_unlock_irqrestore(&net->mctp.keys_lock, flags);
+
+	return rc;
+}
+
+/* Must be called with key->reasm_lock, which it will release. Will schedule
+ * the key for an RCU free.
+ */
+static void __mctp_key_unlock_drop(struct mctp_sk_key *key, struct net *net,
+				   unsigned long flags)
+	__releases(&key->reasm_lock)
+{
+	struct sk_buff *skb;
+
+	skb = key->reasm_head;
+	key->reasm_head = NULL;
+	key->reasm_dead = true;
+	spin_unlock_irqrestore(&key->reasm_lock, flags);
+
+	spin_lock_irqsave(&net->mctp.keys_lock, flags);
+	hlist_del_rcu(&key->hlist);
+	hlist_del_rcu(&key->sklist);
+	spin_unlock_irqrestore(&net->mctp.keys_lock, flags);
+	kfree_rcu(key, rcu);
+
+	if (skb)
+		kfree_skb(skb);
+}
+
+static int mctp_frag_queue(struct mctp_sk_key *key, struct sk_buff *skb)
+{
+	struct mctp_hdr *hdr = mctp_hdr(skb);
+	u8 exp_seq, this_seq;
+
+	this_seq = (hdr->flags_seq_tag >> MCTP_HDR_SEQ_SHIFT)
+		& MCTP_HDR_SEQ_MASK;
+
+	if (!key->reasm_head) {
+		key->reasm_head = skb;
+		key->reasm_tailp = &(skb_shinfo(skb)->frag_list);
+		key->last_seq = this_seq;
+		return 0;
+	}
+
+	exp_seq = (key->last_seq + 1) & MCTP_HDR_SEQ_MASK;
+
+	if (this_seq != exp_seq)
+		return -EINVAL;
+
+	if (key->reasm_head->len + skb->len > mctp_message_maxlen)
+		return -EINVAL;
+
+	skb->next = NULL;
+	skb->sk = NULL;
+	*key->reasm_tailp = skb;
+	key->reasm_tailp = &skb->next;
+
+	key->last_seq = this_seq;
+
+	key->reasm_head->data_len += skb->len;
+	key->reasm_head->len += skb->len;
+	key->reasm_head->truesize += skb->truesize;
+
+	return 0;
+}
+
 static int mctp_route_input(struct mctp_route *route, struct sk_buff *skb)
 {
 	struct net *net = dev_net(skb->dev);
 	struct mctp_sk_key *key;
 	struct mctp_sock *msk;
 	struct mctp_hdr *mh;
+	unsigned long f;
+	u8 tag, flags;
+	int rc;
 
 	msk = NULL;
+	rc = -EINVAL;
 
 	/* we may be receiving a locally-routed packet; drop source sk
 	 * accounting
@@ -121,50 +234,143 @@  static int mctp_route_input(struct mctp_route *route, struct sk_buff *skb)
 
 	/* ensure we have enough data for a header and a type */
 	if (skb->len < sizeof(struct mctp_hdr) + 1)
-		goto drop;
+		goto out;
 
 	/* grab header, advance data ptr */
 	mh = mctp_hdr(skb);
 	skb_pull(skb, sizeof(struct mctp_hdr));
 
 	if (mh->ver != 1)
-		goto drop;
+		goto out;
 
-	/* TODO: reassembly */
-	if ((mh->flags_seq_tag & (MCTP_HDR_FLAG_SOM | MCTP_HDR_FLAG_EOM))
-				!= (MCTP_HDR_FLAG_SOM | MCTP_HDR_FLAG_EOM))
-		goto drop;
+	flags = mh->flags_seq_tag & (MCTP_HDR_FLAG_SOM | MCTP_HDR_FLAG_EOM);
+	tag = mh->flags_seq_tag & (MCTP_HDR_TAG_MASK | MCTP_HDR_FLAG_TO);
 
 	rcu_read_lock();
-	/* 1. lookup socket matching (src,dest,tag) */
+
+	/* lookup socket / reasm context, exactly matching (src,dest,tag) */
 	key = mctp_lookup_key(net, skb, mh->src);
 
-	/* 2. lookup socket macthing (BCAST,dest,tag) */
-	if (!key)
-		key = mctp_lookup_key(net, skb, MCTP_ADDR_ANY);
+	if (flags & MCTP_HDR_FLAG_SOM) {
+		if (key) {
+			msk = container_of(key->sk, struct mctp_sock, sk);
+		} else {
+			/* first response to a broadcast? do a more general
+			 * key lookup to find the socket, but don't use this
+			 * key for reassembly - we'll create a more specific
+			 * one for future packets if required (ie, !EOM).
+			 */
+			key = mctp_lookup_key(net, skb, MCTP_ADDR_ANY);
+			if (key) {
+				msk = container_of(key->sk,
+						   struct mctp_sock, sk);
+				key = NULL;
+			}
+		}
 
-	/* 3. SOM? -> lookup bound socket, conditionally (!EOM) create
-	 * mapping for future (1)/(2).
-	 */
-	if (key)
-		msk = container_of(key->sk, struct mctp_sock, sk);
-	else if (!msk && (mh->flags_seq_tag & MCTP_HDR_FLAG_SOM))
-		msk = mctp_lookup_bind(net, skb);
+		if (!key && !msk && (tag & MCTP_HDR_FLAG_TO))
+			msk = mctp_lookup_bind(net, skb);
 
-	if (!msk)
-		goto unlock_drop;
+		if (!msk) {
+			rc = -ENOENT;
+			goto out;
+		}
 
-	sock_queue_rcv_skb(&msk->sk, skb);
+		/* single-packet message? deliver to socket, clean up any
+		 * pending key.
+		 */
+		if (flags & MCTP_HDR_FLAG_EOM) {
+			sock_queue_rcv_skb(&msk->sk, skb);
+			if (key) {
+				spin_lock_irqsave(&key->reasm_lock, f);
+				/* we've hit a pending reassembly; not much we
+				 * can do but drop it
+				 */
+				__mctp_key_unlock_drop(key, net, f);
+			}
+			rc = 0;
+			goto out;
+		}
 
-	rcu_read_unlock();
+		/* broadcast response or a bind() - create a key for further
+		 * packets for this message
+		 */
+		if (!key) {
+			key = mctp_key_alloc(msk, mh->dest, mh->src,
+					     tag, GFP_ATOMIC);
+			if (!key) {
+				rc = -ENOMEM;
+				goto out;
+			}
 
-	return 0;
+			/* we can queue without the reasm lock here, as the
+			 * key isn't observable yet
+			 */
+			mctp_frag_queue(key, skb);
+
+			/* if the key_add fails, we've raced with another
+			 * SOM packet with the same src, dest and tag. There's
+			 * no way to distinguish future packets, so all we
+			 * can do is drop; we'll free the skb on exit from
+			 * this function.
+			 */
+			rc = mctp_key_add(key, msk);
+			if (rc)
+				kfree(key);
+
+		} else {
+			/* existing key: start reassembly */
+			spin_lock_irqsave(&key->reasm_lock, f);
+
+			if (key->reasm_head || key->reasm_dead) {
+				/* duplicate start? drop everything */
+				__mctp_key_unlock_drop(key, net, f);
+				rc = -EEXIST;
+			} else {
+				rc = mctp_frag_queue(key, skb);
+				spin_unlock_irqrestore(&key->reasm_lock, f);
+			}
+		}
+
+	} else if (key) {
+		/* this packet continues a previous message; reassemble
+		 * using the message-specific key
+		 */
+
+		spin_lock_irqsave(&key->reasm_lock, f);
+
+		/* we need to be continuing an existing reassembly... */
+		if (!key->reasm_head)
+			rc = -EINVAL;
+		else
+			rc = mctp_frag_queue(key, skb);
+
+		/* end of message? deliver to socket, and we're done with
+		 * the reassembly/response key
+		 */
+		if (!rc && flags & MCTP_HDR_FLAG_EOM) {
+			sock_queue_rcv_skb(key->sk, key->reasm_head);
+			key->reasm_head = NULL;
+			__mctp_key_unlock_drop(key, net, f);
+		} else {
+			spin_unlock_irqrestore(&key->reasm_lock, f);
+		}
+
+	} else {
+		/* not a start, no matching key */
+		rc = -ENOENT;
+	}
 
-unlock_drop:
+out:
 	rcu_read_unlock();
-drop:
-	kfree_skb(skb);
-	return 0;
+	if (rc)
+		kfree_skb(skb);
+	return rc;
+}
+
+static unsigned int mctp_route_mtu(struct mctp_route *rt)
+{
+	return rt->mtu ?: READ_ONCE(rt->dev->dev->mtu);
 }
 
 static int mctp_route_output(struct mctp_route *route, struct sk_buff *skb)
@@ -228,8 +434,6 @@  static void mctp_reserve_tag(struct net *net, struct mctp_sk_key *key,
 
 	lockdep_assert_held(&mns->keys_lock);
 
-	key->sk = &msk->sk;
-
 	/* we hold the net->key_lock here, allowing updates to both
 	 * then net and sk
 	 */
@@ -251,11 +455,9 @@  static int mctp_alloc_local_tag(struct mctp_sock *msk,
 	u8 tagbits;
 
 	/* be optimistic, alloc now */
-	key = kzalloc(sizeof(*key), GFP_KERNEL);
+	key = mctp_key_alloc(msk, saddr, daddr, 0, GFP_KERNEL);
 	if (!key)
 		return -ENOMEM;
-	key->local_addr = saddr;
-	key->peer_addr = daddr;
 
 	/* 8 possible tag values */
 	tagbits = 0xff;
@@ -340,6 +542,86 @@  int mctp_do_route(struct mctp_route *rt, struct sk_buff *skb)
 	return rc;
 }
 
+static int mctp_do_fragment_route(struct mctp_route *rt, struct sk_buff *skb,
+				  unsigned int mtu, u8 tag)
+{
+	const unsigned int hlen = sizeof(struct mctp_hdr);
+	struct mctp_hdr *hdr, *hdr2;
+	unsigned int pos, size;
+	struct sk_buff *skb2;
+	int rc;
+	u8 seq;
+
+	hdr = mctp_hdr(skb);
+	seq = 0;
+	rc = 0;
+
+	if (mtu < hlen + 1) {
+		kfree_skb(skb);
+		return -EMSGSIZE;
+	}
+
+	/* we've got the header */
+	skb_pull(skb, hlen);
+
+	for (pos = 0; pos < skb->len;) {
+		/* size of message payload */
+		size = min(mtu - hlen, skb->len - pos);
+
+		skb2 = alloc_skb(MCTP_HEADER_MAXLEN + hlen + size, GFP_KERNEL);
+		if (!skb2) {
+			rc = -ENOMEM;
+			break;
+		}
+
+		/* generic skb copy */
+		skb2->protocol = skb->protocol;
+		skb2->priority = skb->priority;
+		skb2->dev = skb->dev;
+		memcpy(skb2->cb, skb->cb, sizeof(skb2->cb));
+
+		if (skb->sk)
+			skb_set_owner_w(skb2, skb->sk);
+
+		/* establish packet */
+		skb_reserve(skb2, MCTP_HEADER_MAXLEN);
+		skb_reset_network_header(skb2);
+		skb_put(skb2, hlen + size);
+		skb2->transport_header = skb2->network_header + hlen;
+
+		/* copy header fields, calculate SOM/EOM flags & seq */
+		hdr2 = mctp_hdr(skb2);
+		hdr2->ver = hdr->ver;
+		hdr2->dest = hdr->dest;
+		hdr2->src = hdr->src;
+		hdr2->flags_seq_tag = tag &
+			(MCTP_HDR_TAG_MASK | MCTP_HDR_FLAG_TO);
+
+		if (pos == 0)
+			hdr2->flags_seq_tag |= MCTP_HDR_FLAG_SOM;
+
+		if (pos + size == skb->len)
+			hdr2->flags_seq_tag |= MCTP_HDR_FLAG_EOM;
+
+		hdr2->flags_seq_tag |= seq << MCTP_HDR_SEQ_SHIFT;
+
+		/* copy message payload */
+		skb_copy_bits(skb, pos, skb_transport_header(skb2), size);
+
+		/* do route, but don't drop the rt reference */
+		rc = rt->output(rt, skb2);
+		if (rc)
+			break;
+
+		seq = (seq + 1) & MCTP_HDR_SEQ_MASK;
+		pos += size;
+	}
+
+	mctp_route_release(rt);
+	consume_skb(skb);
+	return rc;
+}
+
 int mctp_local_output(struct sock *sk, struct mctp_route *rt,
 		      struct sk_buff *skb, mctp_eid_t daddr, u8 req_tag)
 {
@@ -347,6 +629,7 @@  int mctp_local_output(struct sock *sk, struct mctp_route *rt,
 	struct mctp_skb_cb *cb = mctp_cb(skb);
 	struct mctp_hdr *hdr;
 	unsigned long flags;
+	unsigned int mtu;
 	mctp_eid_t saddr;
 	int rc;
 	u8 tag;
@@ -376,26 +659,32 @@  int mctp_local_output(struct sock *sk, struct mctp_route *rt,
 		tag = req_tag;
 	}
 
-	/* TODO: we have the route MTU here; packetise */
 
+	skb->protocol = htons(ETH_P_MCTP);
+	skb->priority = 0;
 	skb_reset_transport_header(skb);
 	skb_push(skb, sizeof(struct mctp_hdr));
 	skb_reset_network_header(skb);
+	skb->dev = rt->dev->dev;
+
+	/* cb->net will have been set on initial ingress */
+	cb->src = saddr;
+
+	/* set up common header fields */
 	hdr = mctp_hdr(skb);
 	hdr->ver = 1;
 	hdr->dest = daddr;
 	hdr->src = saddr;
-	hdr->flags_seq_tag = MCTP_HDR_FLAG_SOM | MCTP_HDR_FLAG_EOM | /* TODO */
-		tag;
 
-	skb->dev = rt->dev->dev;
-	skb->protocol = htons(ETH_P_MCTP);
-	skb->priority = 0;
+	mtu = mctp_route_mtu(rt);
 
-	/* cb->net will have been set on initial ingress */
-	cb->src = saddr;
-
-	return mctp_do_route(rt, skb);
+	if (skb->len + sizeof(struct mctp_hdr) <= mtu) {
+		hdr->flags_seq_tag = MCTP_HDR_FLAG_SOM | MCTP_HDR_FLAG_EOM |
+			tag;
+		return mctp_do_route(rt, skb);
+	} else {
+		return mctp_do_fragment_route(rt, skb, mtu, tag);
+	}
 }
 
 /* route management */