diff mbox series

[net-next,v2,5/5] mctp: Add SIOCMCTP{ALLOC,DROP}TAG ioctls for tag control

Message ID 20220209040557.391197-6-jk@codeconstruct.com.au (mailing list archive)
State Accepted
Commit 63ed1aab3d40aa61aaa66819bdce9377ac7f40fa
Delegated to: Netdev Maintainers
Headers show
Series MCTP tag control interface | expand

Checks

Context Check Description
netdev/tree_selection success Clearly marked for net-next
netdev/fixes_present success Fixes tag not required for -next series
netdev/subject_prefix success Link
netdev/cover_letter success Series has a cover letter
netdev/patch_count success Link
netdev/header_inline success No static functions without inline keyword in header files
netdev/build_32bit success Errors and warnings before: 4 this patch: 4
netdev/cc_maintainers success CCed 9 of 9 maintainers
netdev/build_clang success Errors and warnings before: 18 this patch: 18
netdev/module_param success Was 0 now: 0
netdev/verify_signedoff success Signed-off-by tag matches author and committer
netdev/verify_fixes success No Fixes tag
netdev/build_allmodconfig_warn success Errors and warnings before: 9 this patch: 9
netdev/checkpatch success total: 0 errors, 0 warnings, 0 checks, 581 lines checked
netdev/kdoc success Errors and warnings before: 0 this patch: 0
netdev/source_inline success Was 0 now: 0

Commit Message

Jeremy Kerr Feb. 9, 2022, 4:05 a.m. UTC
From: Matt Johnston <matt@codeconstruct.com.au>

This change adds a couple of new ioctls for mctp sockets:
SIOCMCTPALLOCTAG and SIOCMCTPDROPTAG.  These ioctls provide facilities
for explicit allocation / release of tags, overriding the automatic
allocate-on-send/release-on-reply and timeout behaviours. This allows
userspace more control over messages that may not fit a simple
request/response model.

In order to indicate a pre-allocated tag to the sendmsg() syscall, we
introduce a new flag to the struct sockaddr_mctp.smctp_tag value:
MCTP_TAG_PREALLOC.

Additional changes from Jeremy Kerr <jk@codeconstruct.com.au>.

Contains a fix that was:
Reported-by: kernel test robot <lkp@intel.com>

Signed-off-by: Matt Johnston <matt@codeconstruct.com.au>
Signed-off-by: Jeremy Kerr <jk@codeconstruct.com.au>

--
v2:
 - staticify mctp_lookup_prealloc_tag
 - limit line length
---
 Documentation/networking/mctp.rst |  48 ++++++++
 include/net/mctp.h                |  11 +-
 include/trace/events/mctp.h       |   5 +-
 include/uapi/linux/mctp.h         |  18 +++
 net/mctp/af_mctp.c                | 189 ++++++++++++++++++++++++++----
 net/mctp/route.c                  | 114 +++++++++++++-----
 6 files changed, 329 insertions(+), 56 deletions(-)
diff mbox series

Patch

diff --git a/Documentation/networking/mctp.rst b/Documentation/networking/mctp.rst
index 46f74bffce0f..c628cb5406d2 100644
--- a/Documentation/networking/mctp.rst
+++ b/Documentation/networking/mctp.rst
@@ -212,6 +212,54 @@  remote address is already known, or the message does not require a reply.
 Like the send calls, sockets will only receive responses to requests they have
 sent (TO=1) and may only respond (TO=0) to requests they have received.
 
+``ioctl(SIOCMCTPALLOCTAG)`` and ``ioctl(SIOCMCTPDROPTAG)``
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+These tags give applications more control over MCTP message tags, by allocating
+(and dropping) tag values explicitly, rather than the kernel automatically
+allocating a per-message tag at ``sendmsg()`` time.
+
+In general, you will only need to use these ioctls if your MCTP protocol does
+not fit the usual request/response model. For example, if you need to persist
+tags across multiple requests, or a request may generate more than one response.
+In these cases, the ioctls allow you to decouple the tag allocation (and
+release) from individual message send and receive operations.
+
+Both ioctls are passed a pointer to a ``struct mctp_ioc_tag_ctl``:
+
+.. code-block:: C
+
+    struct mctp_ioc_tag_ctl {
+        mctp_eid_t      peer_addr;
+        __u8		tag;
+        __u16   	flags;
+    };
+
+``SIOCMCTPALLOCTAG`` allocates a tag for a specific peer, which an application
+can use in future ``sendmsg()`` calls. The application populates the
+``peer_addr`` member with the remote EID. Other fields must be zero.
+
+On return, the ``tag`` member will be populated with the allocated tag value.
+The allocated tag will have the following tag bits set:
+
+ - ``MCTP_TAG_OWNER``: it only makes sense to allocate tags if you're the tag
+   owner
+
+ - ``MCTP_TAG_PREALLOC``: to indicate to ``sendmsg()`` that this is a
+   preallocated tag.
+
+ - ... and the actual tag value, within the least-significant three bits
+   (``MCTP_TAG_MASK``). Note that zero is a valid tag value.
+
+The tag value should be used as-is for the ``smctp_tag`` member of ``struct
+sockaddr_mctp``.
+
+``SIOCMCTPDROPTAG`` releases a tag that has been previously allocated by a
+``SIOCMCTPALLOCTAG`` ioctl. The ``peer_addr`` must be the same as used for the
+allocation, and the ``tag`` value must match exactly the tag returned from the
+allocation (including the ``MCTP_TAG_OWNER`` and ``MCTP_TAG_PREALLOC`` bits).
+The ``flags`` field must be zero.
+
 Kernel internals
 ================
 
diff --git a/include/net/mctp.h b/include/net/mctp.h
index 706d329dd8e8..e80a4baf8379 100644
--- a/include/net/mctp.h
+++ b/include/net/mctp.h
@@ -126,7 +126,7 @@  struct mctp_sock {
  */
 struct mctp_sk_key {
 	mctp_eid_t	peer_addr;
-	mctp_eid_t	local_addr;
+	mctp_eid_t	local_addr; /* MCTP_ADDR_ANY for local owned tags */
 	__u8		tag; /* incoming tag match; invert TO for local */
 
 	/* we hold a ref to sk when set */
@@ -163,6 +163,12 @@  struct mctp_sk_key {
 	 */
 	unsigned long	dev_flow_state;
 	struct mctp_dev	*dev;
+
+	/* a tag allocated with SIOCMCTPALLOCTAG ioctl will not expire
+	 * automatically on timeout or response, instead SIOCMCTPDROPTAG
+	 * is used.
+	 */
+	bool		manual_alloc;
 };
 
 struct mctp_skb_cb {
@@ -239,6 +245,9 @@  int mctp_local_output(struct sock *sk, struct mctp_route *rt,
 		      struct sk_buff *skb, mctp_eid_t daddr, u8 req_tag);
 
 void mctp_key_unref(struct mctp_sk_key *key);
+struct mctp_sk_key *mctp_alloc_local_tag(struct mctp_sock *msk,
+					 mctp_eid_t daddr, mctp_eid_t saddr,
+					 bool manual, u8 *tagp);
 
 /* routing <--> device interface */
 unsigned int mctp_default_net(struct net *net);
diff --git a/include/trace/events/mctp.h b/include/trace/events/mctp.h
index 175b057c507f..165cf25f77a7 100644
--- a/include/trace/events/mctp.h
+++ b/include/trace/events/mctp.h
@@ -15,6 +15,7 @@  enum {
 	MCTP_TRACE_KEY_REPLIED,
 	MCTP_TRACE_KEY_INVALIDATED,
 	MCTP_TRACE_KEY_CLOSED,
+	MCTP_TRACE_KEY_DROPPED,
 };
 #endif /* __TRACE_MCTP_ENUMS */
 
@@ -22,6 +23,7 @@  TRACE_DEFINE_ENUM(MCTP_TRACE_KEY_TIMEOUT);
 TRACE_DEFINE_ENUM(MCTP_TRACE_KEY_REPLIED);
 TRACE_DEFINE_ENUM(MCTP_TRACE_KEY_INVALIDATED);
 TRACE_DEFINE_ENUM(MCTP_TRACE_KEY_CLOSED);
+TRACE_DEFINE_ENUM(MCTP_TRACE_KEY_DROPPED);
 
 TRACE_EVENT(mctp_key_acquire,
 	TP_PROTO(const struct mctp_sk_key *key),
@@ -66,7 +68,8 @@  TRACE_EVENT(mctp_key_release,
 				 { MCTP_TRACE_KEY_TIMEOUT, "timeout" },
 				 { MCTP_TRACE_KEY_REPLIED, "replied" },
 				 { MCTP_TRACE_KEY_INVALIDATED, "invalidated" },
-				 { MCTP_TRACE_KEY_CLOSED, "closed" })
+				 { MCTP_TRACE_KEY_CLOSED, "closed" },
+				 { MCTP_TRACE_KEY_DROPPED, "dropped" })
 	)
 );
 
diff --git a/include/uapi/linux/mctp.h b/include/uapi/linux/mctp.h
index 07b0318716fc..154ab56651f1 100644
--- a/include/uapi/linux/mctp.h
+++ b/include/uapi/linux/mctp.h
@@ -44,7 +44,25 @@  struct sockaddr_mctp_ext {
 
 #define MCTP_TAG_MASK		0x07
 #define MCTP_TAG_OWNER		0x08
+#define MCTP_TAG_PREALLOC	0x10
 
 #define MCTP_OPT_ADDR_EXT	1
 
+#define SIOCMCTPALLOCTAG	(SIOCPROTOPRIVATE + 0)
+#define SIOCMCTPDROPTAG		(SIOCPROTOPRIVATE + 1)
+
+struct mctp_ioc_tag_ctl {
+	mctp_eid_t	peer_addr;
+
+	/* For SIOCMCTPALLOCTAG: must be passed as zero, kernel will
+	 * populate with the allocated tag value. Returned tag value will
+	 * always have TO and PREALLOC set.
+	 *
+	 * For SIOCMCTPDROPTAG: userspace provides tag value to drop, from
+	 * a prior SIOCMCTPALLOCTAG call (and so must have TO and PREALLOC set).
+	 */
+	__u8		tag;
+	__u16		flags;
+};
+
 #endif /* __UAPI_MCTP_H */
diff --git a/net/mctp/af_mctp.c b/net/mctp/af_mctp.c
index c921de63b494..f0702d920d8d 100644
--- a/net/mctp/af_mctp.c
+++ b/net/mctp/af_mctp.c
@@ -6,6 +6,7 @@ 
  * Copyright (c) 2021 Google
  */
 
+#include <linux/compat.h>
 #include <linux/if_arp.h>
 #include <linux/net.h>
 #include <linux/mctp.h>
@@ -21,6 +22,8 @@ 
 
 /* socket implementation */
 
+static void mctp_sk_expire_keys(struct timer_list *timer);
+
 static int mctp_release(struct socket *sock)
 {
 	struct sock *sk = sock->sk;
@@ -99,13 +102,20 @@  static int mctp_sendmsg(struct socket *sock, struct msghdr *msg, size_t len)
 	struct sk_buff *skb;
 
 	if (addr) {
+		const u8 tagbits = MCTP_TAG_MASK | MCTP_TAG_OWNER |
+			MCTP_TAG_PREALLOC;
+
 		if (addrlen < sizeof(struct sockaddr_mctp))
 			return -EINVAL;
 		if (addr->smctp_family != AF_MCTP)
 			return -EINVAL;
 		if (!mctp_sockaddr_is_ok(addr))
 			return -EINVAL;
-		if (addr->smctp_tag & ~(MCTP_TAG_MASK | MCTP_TAG_OWNER))
+		if (addr->smctp_tag & ~tagbits)
+			return -EINVAL;
+		/* can't preallocate a non-owned tag */
+		if (addr->smctp_tag & MCTP_TAG_PREALLOC &&
+		    !(addr->smctp_tag & MCTP_TAG_OWNER))
 			return -EINVAL;
 
 	} else {
@@ -248,6 +258,32 @@  static int mctp_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
 	return rc;
 }
 
+/* We're done with the key; invalidate, stop reassembly, and remove from lists.
+ */
+static void __mctp_key_remove(struct mctp_sk_key *key, struct net *net,
+			      unsigned long flags, unsigned long reason)
+__releases(&key->lock)
+__must_hold(&net->mctp.keys_lock)
+{
+	struct sk_buff *skb;
+
+	trace_mctp_key_release(key, reason);
+	skb = key->reasm_head;
+	key->reasm_head = NULL;
+	key->reasm_dead = true;
+	key->valid = false;
+	mctp_dev_release_key(key->dev, key);
+	spin_unlock_irqrestore(&key->lock, flags);
+
+	hlist_del(&key->hlist);
+	hlist_del(&key->sklist);
+
+	/* unref for the lists */
+	mctp_key_unref(key);
+
+	kfree_skb(skb);
+}
+
 static int mctp_setsockopt(struct socket *sock, int level, int optname,
 			   sockptr_t optval, unsigned int optlen)
 {
@@ -293,6 +329,115 @@  static int mctp_getsockopt(struct socket *sock, int level, int optname,
 	return -EINVAL;
 }
 
+static int mctp_ioctl_alloctag(struct mctp_sock *msk, unsigned long arg)
+{
+	struct net *net = sock_net(&msk->sk);
+	struct mctp_sk_key *key = NULL;
+	struct mctp_ioc_tag_ctl ctl;
+	unsigned long flags;
+	u8 tag;
+
+	if (copy_from_user(&ctl, (void __user *)arg, sizeof(ctl)))
+		return -EFAULT;
+
+	if (ctl.tag)
+		return -EINVAL;
+
+	if (ctl.flags)
+		return -EINVAL;
+
+	key = mctp_alloc_local_tag(msk, ctl.peer_addr, MCTP_ADDR_ANY,
+				   true, &tag);
+	if (IS_ERR(key))
+		return PTR_ERR(key);
+
+	ctl.tag = tag | MCTP_TAG_OWNER | MCTP_TAG_PREALLOC;
+	if (copy_to_user((void __user *)arg, &ctl, sizeof(ctl))) {
+		spin_lock_irqsave(&key->lock, flags);
+		__mctp_key_remove(key, net, flags, MCTP_TRACE_KEY_DROPPED);
+		mctp_key_unref(key);
+		return -EFAULT;
+	}
+
+	mctp_key_unref(key);
+	return 0;
+}
+
+static int mctp_ioctl_droptag(struct mctp_sock *msk, unsigned long arg)
+{
+	struct net *net = sock_net(&msk->sk);
+	struct mctp_ioc_tag_ctl ctl;
+	unsigned long flags, fl2;
+	struct mctp_sk_key *key;
+	struct hlist_node *tmp;
+	int rc;
+	u8 tag;
+
+	if (copy_from_user(&ctl, (void __user *)arg, sizeof(ctl)))
+		return -EFAULT;
+
+	if (ctl.flags)
+		return -EINVAL;
+
+	/* Must be a local tag, TO set, preallocated */
+	if ((ctl.tag & ~MCTP_TAG_MASK) != (MCTP_TAG_OWNER | MCTP_TAG_PREALLOC))
+		return -EINVAL;
+
+	tag = ctl.tag & MCTP_TAG_MASK;
+	rc = -EINVAL;
+
+	spin_lock_irqsave(&net->mctp.keys_lock, flags);
+	hlist_for_each_entry_safe(key, tmp, &msk->keys, sklist) {
+		/* we do an irqsave here, even though we know the irq state,
+		 * so we have the flags to pass to __mctp_key_remove
+		 */
+		spin_lock_irqsave(&key->lock, fl2);
+		if (key->manual_alloc &&
+		    ctl.peer_addr == key->peer_addr &&
+		    tag == key->tag) {
+			__mctp_key_remove(key, net, fl2,
+					  MCTP_TRACE_KEY_DROPPED);
+			rc = 0;
+		} else {
+			spin_unlock_irqrestore(&key->lock, fl2);
+		}
+	}
+	spin_unlock_irqrestore(&net->mctp.keys_lock, flags);
+
+	return rc;
+}
+
+static int mctp_ioctl(struct socket *sock, unsigned int cmd, unsigned long arg)
+{
+	struct mctp_sock *msk = container_of(sock->sk, struct mctp_sock, sk);
+
+	switch (cmd) {
+	case SIOCMCTPALLOCTAG:
+		return mctp_ioctl_alloctag(msk, arg);
+	case SIOCMCTPDROPTAG:
+		return mctp_ioctl_droptag(msk, arg);
+	}
+
+	return -EINVAL;
+}
+
+#ifdef CONFIG_COMPAT
+static int mctp_compat_ioctl(struct socket *sock, unsigned int cmd,
+			     unsigned long arg)
+{
+	void __user *argp = compat_ptr(arg);
+
+	switch (cmd) {
+	/* These have compatible ptr layouts */
+	case SIOCMCTPALLOCTAG:
+	case SIOCMCTPDROPTAG:
+		return mctp_ioctl(sock, cmd, (unsigned long)argp);
+	}
+
+	return -ENOIOCTLCMD;
+}
+#endif
+
 static const struct proto_ops mctp_dgram_ops = {
 	.family		= PF_MCTP,
 	.release	= mctp_release,
@@ -302,7 +447,7 @@  static const struct proto_ops mctp_dgram_ops = {
 	.accept		= sock_no_accept,
 	.getname	= sock_no_getname,
 	.poll		= datagram_poll,
-	.ioctl		= sock_no_ioctl,
+	.ioctl		= mctp_ioctl,
 	.gettstamp	= sock_gettstamp,
 	.listen		= sock_no_listen,
 	.shutdown	= sock_no_shutdown,
@@ -312,6 +457,9 @@  static const struct proto_ops mctp_dgram_ops = {
 	.recvmsg	= mctp_recvmsg,
 	.mmap		= sock_no_mmap,
 	.sendpage	= sock_no_sendpage,
+#ifdef CONFIG_COMPAT
+	.compat_ioctl	= mctp_compat_ioctl,
+#endif
 };
 
 static void mctp_sk_expire_keys(struct timer_list *timer)
@@ -319,7 +467,7 @@  static void mctp_sk_expire_keys(struct timer_list *timer)
 	struct mctp_sock *msk = container_of(timer, struct mctp_sock,
 					     key_expiry);
 	struct net *net = sock_net(&msk->sk);
-	unsigned long next_expiry, flags;
+	unsigned long next_expiry, flags, fl2;
 	struct mctp_sk_key *key;
 	struct hlist_node *tmp;
 	bool next_expiry_valid = false;
@@ -327,15 +475,16 @@  static void mctp_sk_expire_keys(struct timer_list *timer)
 	spin_lock_irqsave(&net->mctp.keys_lock, flags);
 
 	hlist_for_each_entry_safe(key, tmp, &msk->keys, sklist) {
-		spin_lock(&key->lock);
+		/* don't expire. manual_alloc is immutable, no locking
+		 * required.
+		 */
+		if (key->manual_alloc)
+			continue;
 
+		spin_lock_irqsave(&key->lock, fl2);
 		if (!time_after_eq(key->expiry, jiffies)) {
-			trace_mctp_key_release(key, MCTP_TRACE_KEY_TIMEOUT);
-			key->valid = false;
-			hlist_del_rcu(&key->hlist);
-			hlist_del_rcu(&key->sklist);
-			spin_unlock(&key->lock);
-			mctp_key_unref(key);
+			__mctp_key_remove(key, net, fl2,
+					  MCTP_TRACE_KEY_TIMEOUT);
 			continue;
 		}
 
@@ -346,7 +495,7 @@  static void mctp_sk_expire_keys(struct timer_list *timer)
 			next_expiry = key->expiry;
 			next_expiry_valid = true;
 		}
-		spin_unlock(&key->lock);
+		spin_unlock_irqrestore(&key->lock, fl2);
 	}
 
 	spin_unlock_irqrestore(&net->mctp.keys_lock, flags);
@@ -387,9 +536,9 @@  static void mctp_sk_unhash(struct sock *sk)
 {
 	struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk);
 	struct net *net = sock_net(sk);
+	unsigned long flags, fl2;
 	struct mctp_sk_key *key;
 	struct hlist_node *tmp;
-	unsigned long flags;
 
 	/* remove from any type-based binds */
 	mutex_lock(&net->mctp.bind_lock);
@@ -399,20 +548,8 @@  static void mctp_sk_unhash(struct sock *sk)
 	/* remove tag allocations */
 	spin_lock_irqsave(&net->mctp.keys_lock, flags);
 	hlist_for_each_entry_safe(key, tmp, &msk->keys, sklist) {
-		hlist_del(&key->sklist);
-		hlist_del(&key->hlist);
-
-		trace_mctp_key_release(key, MCTP_TRACE_KEY_CLOSED);
-
-		spin_lock(&key->lock);
-		kfree_skb(key->reasm_head);
-		key->reasm_head = NULL;
-		key->reasm_dead = true;
-		key->valid = false;
-		spin_unlock(&key->lock);
-
-		/* key is no longer on the lookup lists, unref */
-		mctp_key_unref(key);
+		spin_lock_irqsave(&key->lock, fl2);
+		__mctp_key_remove(key, net, fl2, MCTP_TRACE_KEY_CLOSED);
 	}
 	spin_unlock_irqrestore(&net->mctp.keys_lock, flags);
 }
diff --git a/net/mctp/route.c b/net/mctp/route.c
index 35f72e99e188..17e3482aa770 100644
--- a/net/mctp/route.c
+++ b/net/mctp/route.c
@@ -203,29 +203,38 @@  static int mctp_key_add(struct mctp_sk_key *key, struct mctp_sock *msk)
 	return rc;
 }
 
-/* We're done with the key; unset valid and remove from lists. There may still
- * be outstanding refs on the key though...
+/* Helper for mctp_route_input().
+ * We're done with the key; unlock and unref the key.
+ * For the usual case of automatic expiry we remove the key from lists.
+ * In the case that manual allocation is set on a key we release the lock
+ * and local ref, reset reassembly, but don't remove from lists.
  */
-static void __mctp_key_unlock_drop(struct mctp_sk_key *key, struct net *net,
-				   unsigned long flags)
-	__releases(&key->lock)
+static void __mctp_key_done_in(struct mctp_sk_key *key, struct net *net,
+			       unsigned long flags, unsigned long reason)
+__releases(&key->lock)
 {
 	struct sk_buff *skb;
 
+	trace_mctp_key_release(key, reason);
 	skb = key->reasm_head;
 	key->reasm_head = NULL;
-	key->reasm_dead = true;
-	key->valid = false;
-	mctp_dev_release_key(key->dev, key);
+
+	if (!key->manual_alloc) {
+		key->reasm_dead = true;
+		key->valid = false;
+		mctp_dev_release_key(key->dev, key);
+	}
 	spin_unlock_irqrestore(&key->lock, flags);
 
-	spin_lock_irqsave(&net->mctp.keys_lock, flags);
-	hlist_del(&key->hlist);
-	hlist_del(&key->sklist);
-	spin_unlock_irqrestore(&net->mctp.keys_lock, flags);
+	if (!key->manual_alloc) {
+		spin_lock_irqsave(&net->mctp.keys_lock, flags);
+		hlist_del(&key->hlist);
+		hlist_del(&key->sklist);
+		spin_unlock_irqrestore(&net->mctp.keys_lock, flags);
 
-	/* one unref for the lists */
-	mctp_key_unref(key);
+		/* unref for the lists */
+		mctp_key_unref(key);
+	}
 
 	/* and one for the local reference */
 	mctp_key_unref(key);
@@ -379,9 +388,8 @@  static int mctp_route_input(struct mctp_route *route, struct sk_buff *skb)
 				/* we've hit a pending reassembly; not much we
 				 * can do but drop it
 				 */
-				trace_mctp_key_release(key,
-						       MCTP_TRACE_KEY_REPLIED);
-				__mctp_key_unlock_drop(key, net, f);
+				__mctp_key_done_in(key, net, f,
+						   MCTP_TRACE_KEY_REPLIED);
 				key = NULL;
 			}
 			rc = 0;
@@ -423,9 +431,8 @@  static int mctp_route_input(struct mctp_route *route, struct sk_buff *skb)
 		} else {
 			if (key->reasm_head || key->reasm_dead) {
 				/* duplicate start? drop everything */
-				trace_mctp_key_release(key,
-						       MCTP_TRACE_KEY_INVALIDATED);
-				__mctp_key_unlock_drop(key, net, f);
+				__mctp_key_done_in(key, net, f,
+						   MCTP_TRACE_KEY_INVALIDATED);
 				rc = -EEXIST;
 				key = NULL;
 			} else {
@@ -448,10 +455,10 @@  static int mctp_route_input(struct mctp_route *route, struct sk_buff *skb)
 		 * the reassembly/response key
 		 */
 		if (!rc && flags & MCTP_HDR_FLAG_EOM) {
+			msk = container_of(key->sk, struct mctp_sock, sk);
 			sock_queue_rcv_skb(key->sk, key->reasm_head);
 			key->reasm_head = NULL;
-			trace_mctp_key_release(key, MCTP_TRACE_KEY_REPLIED);
-			__mctp_key_unlock_drop(key, net, f);
+			__mctp_key_done_in(key, net, f, MCTP_TRACE_KEY_REPLIED);
 			key = NULL;
 		}
 
@@ -579,9 +586,9 @@  static void mctp_reserve_tag(struct net *net, struct mctp_sk_key *key,
 /* Allocate a locally-owned tag value for (saddr, daddr), and reserve
  * it for the socket msk
  */
-static struct mctp_sk_key *mctp_alloc_local_tag(struct mctp_sock *msk,
-						mctp_eid_t saddr,
-						mctp_eid_t daddr, u8 *tagp)
+struct mctp_sk_key *mctp_alloc_local_tag(struct mctp_sock *msk,
+					 mctp_eid_t daddr, mctp_eid_t saddr,
+					 bool manual, u8 *tagp)
 {
 	struct net *net = sock_net(&msk->sk);
 	struct netns_mctp *mns = &net->mctp;
@@ -636,6 +643,7 @@  static struct mctp_sk_key *mctp_alloc_local_tag(struct mctp_sock *msk,
 		mctp_reserve_tag(net, key, msk);
 		trace_mctp_key_acquire(key);
 
+		key->manual_alloc = manual;
 		*tagp = key->tag;
 	}
 
@@ -649,6 +657,50 @@  static struct mctp_sk_key *mctp_alloc_local_tag(struct mctp_sock *msk,
 	return key;
 }
 
+static struct mctp_sk_key *mctp_lookup_prealloc_tag(struct mctp_sock *msk,
+						    mctp_eid_t daddr,
+						    u8 req_tag, u8 *tagp)
+{
+	struct net *net = sock_net(&msk->sk);
+	struct netns_mctp *mns = &net->mctp;
+	struct mctp_sk_key *key, *tmp;
+	unsigned long flags;
+
+	req_tag &= ~(MCTP_TAG_PREALLOC | MCTP_TAG_OWNER);
+	key = NULL;
+
+	spin_lock_irqsave(&mns->keys_lock, flags);
+
+	hlist_for_each_entry(tmp, &mns->keys, hlist) {
+		if (tmp->tag != req_tag)
+			continue;
+
+		if (!mctp_address_matches(tmp->peer_addr, daddr))
+			continue;
+
+		if (!tmp->manual_alloc)
+			continue;
+
+		spin_lock(&tmp->lock);
+		if (tmp->valid) {
+			key = tmp;
+			refcount_inc(&key->refs);
+			spin_unlock(&tmp->lock);
+			break;
+		}
+		spin_unlock(&tmp->lock);
+	}
+	spin_unlock_irqrestore(&mns->keys_lock, flags);
+
+	if (!key)
+		return ERR_PTR(-ENOENT);
+
+	if (tagp)
+		*tagp = key->tag;
+
+	return key;
+}
+
 /* routing lookups */
 static bool mctp_rt_match_eid(struct mctp_route *rt,
 			      unsigned int net, mctp_eid_t eid)
@@ -843,8 +895,14 @@  int mctp_local_output(struct sock *sk, struct mctp_route *rt,
 	if (rc)
 		goto out_release;
 
-	if (req_tag & MCTP_HDR_FLAG_TO) {
-		key = mctp_alloc_local_tag(msk, saddr, daddr, &tag);
+	if (req_tag & MCTP_TAG_OWNER) {
+		if (req_tag & MCTP_TAG_PREALLOC)
+			key = mctp_lookup_prealloc_tag(msk, daddr,
+						       req_tag, &tag);
+		else
+			key = mctp_alloc_local_tag(msk, daddr, saddr,
+						   false, &tag);
+
 		if (IS_ERR(key)) {
 			rc = PTR_ERR(key);
 			goto out_release;
@@ -855,7 +913,7 @@  int mctp_local_output(struct sock *sk, struct mctp_route *rt,
 		tag |= MCTP_HDR_FLAG_TO;
 	} else {
 		key = NULL;
-		tag = req_tag;
+		tag = req_tag & MCTP_TAG_MASK;
 	}
 
 	skb->protocol = htons(ETH_P_MCTP);