diff mbox series

[v2,3/3] vhost/vsock: use netns of process that opens the vhost-vsock-netns device

Message ID 20250312-vsock-netns-v2-3-84bffa1aa97a@gmail.com (mailing list archive)
State New
Delegated to: Netdev Maintainers
Headers show
Series vsock: add namespace support to vhost-vsock | expand

Checks

Context Check Description
netdev/series_format warning Target tree name not specified in the subject
netdev/tree_selection success Guessed tree name to be net-next
netdev/ynl success Generated files up to date; no warnings/errors; no diff in generated;
netdev/fixes_present success Fixes tag not required for -next series
netdev/header_inline success No static functions without inline keyword in header files
netdev/build_32bit success Errors and warnings before: 27 this patch: 27
netdev/build_tools success Errors and warnings before: 26 (+0) this patch: 26 (+0)
netdev/cc_maintainers warning 5 maintainers not CCed: edumazet@google.com pabeni@redhat.com horms@kernel.org gregkh@linuxfoundation.org arnd@arndb.de
netdev/build_clang success Errors and warnings before: 31 this patch: 31
netdev/verify_signedoff success Signed-off-by tag matches author and committer
netdev/deprecated_api warning Found: 'put_net(' was: 0 now: 1
netdev/check_selftest success No net selftest shell script
netdev/verify_fixes success No Fixes tag
netdev/build_allmodconfig_warn success Errors and warnings before: 250 this patch: 250
netdev/checkpatch warning WARNING: line length of 82 exceeds 80 columns WARNING: line length of 83 exceeds 80 columns WARNING: line length of 84 exceeds 80 columns WARNING: line length of 85 exceeds 80 columns WARNING: line length of 90 exceeds 80 columns WARNING: line length of 96 exceeds 80 columns
netdev/build_clang_rust success No Rust files in patch. Skipping build
netdev/kdoc success Errors and warnings before: 0 this patch: 0
netdev/source_inline success Was 0 now: 0
netdev/contest success net-next-2025-03-13--00-00 (tests: 894)

Commit Message

Bobby Eshleman March 12, 2025, 8:59 p.m. UTC
From: Stefano Garzarella <sgarzare@redhat.com>

This patch assigns the network namespace of the process that opened
vhost-vsock-netns device (e.g. VMM) to the packets coming from the
guest, allowing only host sockets in the same network namespace to
communicate with the guest.

This patch also allows having different VMs, running in different
network namespace, with the same CID/port.

This patch brings namespace support only to vhost-vsock, but not
to virtio-vsock, hyperv, or vmci.

Signed-off-by: Stefano Garzarella <sgarzare@redhat.com>
Signed-off-by: Bobby Eshleman <bobbyeshleman@gmail.com>
---
v1 -> v2
 * removed 'netns' module param
 * renamed vsock_default_net() to vsock_global_net() to reflect new
   semantics
 * reserved NULL net to indicate "global" vsock namespace

RFC -> v1
* added 'netns' module param
* added 'vsock_net_eq()' to check the "net" assigned to a socket
  only when 'netns' support is enabled
---
 drivers/vhost/vsock.c            | 97 +++++++++++++++++++++++++++++++++-------
 include/linux/miscdevice.h       |  1 +
 include/net/af_vsock.h           |  3 +-
 net/vmw_vsock/af_vsock.c         | 30 ++++++++++++-
 net/vmw_vsock/virtio_transport.c |  4 +-
 net/vmw_vsock/vsock_loopback.c   |  4 +-
 6 files changed, 117 insertions(+), 22 deletions(-)
diff mbox series

Patch

diff --git a/drivers/vhost/vsock.c b/drivers/vhost/vsock.c
index 02e2a3551205a4398a74a167a82802d950c962f6..8702beb8238aa290b6d901e53c36481637840017 100644
--- a/drivers/vhost/vsock.c
+++ b/drivers/vhost/vsock.c
@@ -46,6 +46,7 @@  static DEFINE_READ_MOSTLY_HASHTABLE(vhost_vsock_hash, 8);
 struct vhost_vsock {
 	struct vhost_dev dev;
 	struct vhost_virtqueue vqs[2];
+	struct net *net;
 
 	/* Link to global vhost_vsock_hash, writes use vhost_vsock_mutex */
 	struct hlist_node hash;
@@ -67,8 +68,9 @@  static u32 vhost_transport_get_local_cid(void)
 /* Callers that dereference the return value must hold vhost_vsock_mutex or the
  * RCU read lock.
  */
-static struct vhost_vsock *vhost_vsock_get(u32 guest_cid)
+static struct vhost_vsock *vhost_vsock_get(u32 guest_cid, struct net *net, bool global_fallback)
 {
+	struct vhost_vsock *fallback = NULL;
 	struct vhost_vsock *vsock;
 
 	hash_for_each_possible_rcu(vhost_vsock_hash, vsock, hash, guest_cid) {
@@ -78,11 +80,18 @@  static struct vhost_vsock *vhost_vsock_get(u32 guest_cid)
 		if (other_cid == 0)
 			continue;
 
-		if (other_cid == guest_cid)
-			return vsock;
+		if (other_cid == guest_cid) {
+			if (net_eq(net, vsock->net))
+				return vsock;
 
+			if (net_eq(vsock->net, vsock_global_net()))
+				fallback = vsock;
+		}
 	}
 
+	if (global_fallback)
+		return fallback;
+
 	return NULL;
 }
 
@@ -272,13 +281,14 @@  static int
 vhost_transport_send_pkt(struct sk_buff *skb)
 {
 	struct virtio_vsock_hdr *hdr = virtio_vsock_hdr(skb);
+	struct net *net = VIRTIO_VSOCK_SKB_CB(skb)->net;
 	struct vhost_vsock *vsock;
 	int len = skb->len;
 
 	rcu_read_lock();
 
 	/* Find the vhost_vsock according to guest context id  */
-	vsock = vhost_vsock_get(le64_to_cpu(hdr->dst_cid));
+	vsock = vhost_vsock_get(le64_to_cpu(hdr->dst_cid), net, true);
 	if (!vsock) {
 		rcu_read_unlock();
 		kfree_skb(skb);
@@ -305,7 +315,8 @@  vhost_transport_cancel_pkt(struct vsock_sock *vsk)
 	rcu_read_lock();
 
 	/* Find the vhost_vsock according to guest context id  */
-	vsock = vhost_vsock_get(vsk->remote_addr.svm_cid);
+	vsock = vhost_vsock_get(vsk->remote_addr.svm_cid,
+				sock_net(sk_vsock(vsk)), true);
 	if (!vsock)
 		goto out;
 
@@ -403,7 +414,7 @@  static bool vhost_transport_msgzerocopy_allow(void)
 	return true;
 }
 
-static bool vhost_transport_seqpacket_allow(u32 remote_cid);
+static bool vhost_transport_seqpacket_allow(struct vsock_sock *vsk, u32 remote_cid);
 
 static struct virtio_transport vhost_transport = {
 	.transport = {
@@ -459,13 +470,14 @@  static struct virtio_transport vhost_transport = {
 	.send_pkt = vhost_transport_send_pkt,
 };
 
-static bool vhost_transport_seqpacket_allow(u32 remote_cid)
+static bool vhost_transport_seqpacket_allow(struct vsock_sock *vsk, u32 remote_cid)
 {
+	struct net *net = sock_net(sk_vsock(vsk));
 	struct vhost_vsock *vsock;
 	bool seqpacket_allow = false;
 
 	rcu_read_lock();
-	vsock = vhost_vsock_get(remote_cid);
+	vsock = vhost_vsock_get(remote_cid, net, true);
 
 	if (vsock)
 		seqpacket_allow = vsock->seqpacket_allow;
@@ -525,7 +537,7 @@  static void vhost_vsock_handle_tx_kick(struct vhost_work *work)
 			continue;
 		}
 
-		VIRTIO_VSOCK_SKB_CB(skb)->net = vsock_global_net();
+		VIRTIO_VSOCK_SKB_CB(skb)->net = vsock->net;
 		total_len += sizeof(*hdr) + skb->len;
 
 		/* Deliver to monitoring devices all received packets */
@@ -650,7 +662,7 @@  static void vhost_vsock_free(struct vhost_vsock *vsock)
 	kvfree(vsock);
 }
 
-static int vhost_vsock_dev_open(struct inode *inode, struct file *file)
+static int __vhost_vsock_dev_open(struct inode *inode, struct file *file, struct net *net)
 {
 	struct vhost_virtqueue **vqs;
 	struct vhost_vsock *vsock;
@@ -669,6 +681,8 @@  static int vhost_vsock_dev_open(struct inode *inode, struct file *file)
 		goto out;
 	}
 
+	vsock->net = net;
+
 	vsock->guest_cid = 0; /* no CID assigned yet */
 	vsock->seqpacket_allow = false;
 
@@ -693,6 +707,22 @@  static int vhost_vsock_dev_open(struct inode *inode, struct file *file)
 	return ret;
 }
 
+static int vhost_vsock_dev_open(struct inode *inode, struct file *file)
+{
+	return __vhost_vsock_dev_open(inode, file, vsock_global_net());
+}
+
+static int vhost_vsock_netns_dev_open(struct inode *inode, struct file *file)
+{
+	struct net *net;
+
+	net = get_net_ns_by_pid(current->pid);
+	if (IS_ERR(net))
+		return PTR_ERR(net);
+
+	return __vhost_vsock_dev_open(inode, file, net);
+}
+
 static void vhost_vsock_flush(struct vhost_vsock *vsock)
 {
 	vhost_dev_flush(&vsock->dev);
@@ -708,7 +738,7 @@  static void vhost_vsock_reset_orphans(struct sock *sk)
 	 */
 
 	/* If the peer is still valid, no need to reset connection */
-	if (vhost_vsock_get(vsk->remote_addr.svm_cid))
+	if (vhost_vsock_get(vsk->remote_addr.svm_cid, sock_net(sk), true))
 		return;
 
 	/* If the close timeout is pending, let it expire.  This avoids races
@@ -753,6 +783,8 @@  static int vhost_vsock_dev_release(struct inode *inode, struct file *file)
 	virtio_vsock_skb_queue_purge(&vsock->send_pkt_queue);
 
 	vhost_dev_cleanup(&vsock->dev);
+	if (vsock->net)
+		put_net(vsock->net);
 	kfree(vsock->dev.vqs);
 	vhost_vsock_free(vsock);
 	return 0;
@@ -777,9 +809,15 @@  static int vhost_vsock_set_cid(struct vhost_vsock *vsock, u64 guest_cid)
 	if (vsock_find_cid(guest_cid))
 		return -EADDRINUSE;
 
+	/* If this namespace has already connected to this CID, then report
+	 * that this address is already in use.
+	 */
+	if (vsock->net && vsock_net_has_connected(vsock->net, guest_cid))
+		return -EADDRINUSE;
+
 	/* Refuse if CID is already in use */
 	mutex_lock(&vhost_vsock_mutex);
-	other = vhost_vsock_get(guest_cid);
+	other = vhost_vsock_get(guest_cid, vsock->net, false);
 	if (other && other != vsock) {
 		mutex_unlock(&vhost_vsock_mutex);
 		return -EADDRINUSE;
@@ -931,6 +969,24 @@  static struct miscdevice vhost_vsock_misc = {
 	.fops = &vhost_vsock_fops,
 };
 
+static const struct file_operations vhost_vsock_netns_fops = {
+	.owner          = THIS_MODULE,
+	.open           = vhost_vsock_netns_dev_open,
+	.release        = vhost_vsock_dev_release,
+	.llseek		= noop_llseek,
+	.unlocked_ioctl = vhost_vsock_dev_ioctl,
+	.compat_ioctl   = compat_ptr_ioctl,
+	.read_iter      = vhost_vsock_chr_read_iter,
+	.write_iter     = vhost_vsock_chr_write_iter,
+	.poll           = vhost_vsock_chr_poll,
+};
+
+static struct miscdevice vhost_vsock_netns_misc = {
+	.minor = VHOST_VSOCK_NETNS_MINOR,
+	.name = "vhost-vsock-netns",
+	.fops = &vhost_vsock_netns_fops,
+};
+
 static int __init vhost_vsock_init(void)
 {
 	int ret;
@@ -941,17 +997,26 @@  static int __init vhost_vsock_init(void)
 		return ret;
 
 	ret = misc_register(&vhost_vsock_misc);
-	if (ret) {
-		vsock_core_unregister(&vhost_transport.transport);
-		return ret;
-	}
+	if (ret)
+		goto out_vt;
+
+	ret = misc_register(&vhost_vsock_netns_misc);
+	if (ret)
+		goto out_vvm;
 
 	return 0;
+
+out_vvm:
+	misc_deregister(&vhost_vsock_misc);
+out_vt:
+	vsock_core_unregister(&vhost_transport.transport);
+	return ret;
 };
 
 static void __exit vhost_vsock_exit(void)
 {
 	misc_deregister(&vhost_vsock_misc);
+	misc_deregister(&vhost_vsock_netns_misc);
 	vsock_core_unregister(&vhost_transport.transport);
 };
 
diff --git a/include/linux/miscdevice.h b/include/linux/miscdevice.h
index 69e110c2b86a9b16c1637778a25e1eebb3fd0111..a7e11b70a5398a225c4d63d50ac460e6388e022c 100644
--- a/include/linux/miscdevice.h
+++ b/include/linux/miscdevice.h
@@ -71,6 +71,7 @@ 
 #define USERIO_MINOR		240
 #define VHOST_VSOCK_MINOR	241
 #define RFKILL_MINOR		242
+#define VHOST_VSOCK_NETNS_MINOR	243
 #define MISC_DYNAMIC_MINOR	255
 
 struct device;
diff --git a/include/net/af_vsock.h b/include/net/af_vsock.h
index 41afbc18648c953da27a93571d408de968aa7668..1e37737689a741d91e64b8c0ed0a931fc7376194 100644
--- a/include/net/af_vsock.h
+++ b/include/net/af_vsock.h
@@ -143,7 +143,7 @@  struct vsock_transport {
 				     int flags);
 	int (*seqpacket_enqueue)(struct vsock_sock *vsk, struct msghdr *msg,
 				 size_t len);
-	bool (*seqpacket_allow)(u32 remote_cid);
+	bool (*seqpacket_allow)(struct vsock_sock *vsk, u32 remote_cid);
 	u32 (*seqpacket_has_data)(struct vsock_sock *vsk);
 
 	/* Notification. */
@@ -258,4 +258,5 @@  static inline bool vsock_msgzerocopy_allow(const struct vsock_transport *t)
 }
 
 struct net *vsock_global_net(void);
+bool vsock_net_has_connected(struct net *net, u64 guest_cid);
 #endif /* __AF_VSOCK_H__ */
diff --git a/net/vmw_vsock/af_vsock.c b/net/vmw_vsock/af_vsock.c
index d206489bf0a81cf989387c7c8063be91a7c21a7d..58fa415555d6aae5043b5ca2bfc4783af566cf28 100644
--- a/net/vmw_vsock/af_vsock.c
+++ b/net/vmw_vsock/af_vsock.c
@@ -391,6 +391,34 @@  void vsock_for_each_connected_socket(struct vsock_transport *transport,
 }
 EXPORT_SYMBOL_GPL(vsock_for_each_connected_socket);
 
+bool vsock_net_has_connected(struct net *net, u64 guest_cid)
+{
+	bool ret = false;
+	int i;
+
+	spin_lock_bh(&vsock_table_lock);
+
+	for (i = 0; i < ARRAY_SIZE(vsock_connected_table); i++) {
+		struct vsock_sock *vsk;
+
+		list_for_each_entry(vsk, &vsock_connected_table[i],
+				    connected_table) {
+			if (sock_net(sk_vsock(vsk)) != net)
+				continue;
+
+			if (vsk->remote_addr.svm_cid == guest_cid) {
+				ret = true;
+				goto out;
+			}
+		}
+	}
+
+out:
+	spin_unlock_bh(&vsock_table_lock);
+	return ret;
+}
+EXPORT_SYMBOL_GPL(vsock_net_has_connected);
+
 void vsock_add_pending(struct sock *listener, struct sock *pending)
 {
 	struct vsock_sock *vlistener;
@@ -537,7 +565,7 @@  int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk)
 
 	if (sk->sk_type == SOCK_SEQPACKET) {
 		if (!new_transport->seqpacket_allow ||
-		    !new_transport->seqpacket_allow(remote_cid)) {
+		    !new_transport->seqpacket_allow(vsk, remote_cid)) {
 			module_put(new_transport->module);
 			return -ESOCKTNOSUPPORT;
 		}
diff --git a/net/vmw_vsock/virtio_transport.c b/net/vmw_vsock/virtio_transport.c
index 163ddfc0808529ad6dda7992f9ec48837dd7337c..60bf3f1f39c51d44768fd2f04df3abee9f966252 100644
--- a/net/vmw_vsock/virtio_transport.c
+++ b/net/vmw_vsock/virtio_transport.c
@@ -536,7 +536,7 @@  static bool virtio_transport_msgzerocopy_allow(void)
 	return true;
 }
 
-static bool virtio_transport_seqpacket_allow(u32 remote_cid);
+static bool virtio_transport_seqpacket_allow(struct vsock_sock *vsk, u32 remote_cid);
 
 static struct virtio_transport virtio_transport = {
 	.transport = {
@@ -593,7 +593,7 @@  static struct virtio_transport virtio_transport = {
 	.can_msgzerocopy = virtio_transport_can_msgzerocopy,
 };
 
-static bool virtio_transport_seqpacket_allow(u32 remote_cid)
+static bool virtio_transport_seqpacket_allow(struct vsock_sock *vsk, u32 remote_cid)
 {
 	struct virtio_vsock *vsock;
 	bool seqpacket_allow;
diff --git a/net/vmw_vsock/vsock_loopback.c b/net/vmw_vsock/vsock_loopback.c
index 6e78927a598e07cf77386a420b9d05d3f491dc7c..1b2fab73e0d0a6c63ed60d29fc837da58f6fb121 100644
--- a/net/vmw_vsock/vsock_loopback.c
+++ b/net/vmw_vsock/vsock_loopback.c
@@ -46,7 +46,7 @@  static int vsock_loopback_cancel_pkt(struct vsock_sock *vsk)
 	return 0;
 }
 
-static bool vsock_loopback_seqpacket_allow(u32 remote_cid);
+static bool vsock_loopback_seqpacket_allow(struct vsock_sock *vsk, u32 remote_cid);
 static bool vsock_loopback_msgzerocopy_allow(void)
 {
 	return true;
@@ -106,7 +106,7 @@  static struct virtio_transport loopback_transport = {
 	.send_pkt = vsock_loopback_send_pkt,
 };
 
-static bool vsock_loopback_seqpacket_allow(u32 remote_cid)
+static bool vsock_loopback_seqpacket_allow(struct vsock_sock *vsk, u32 remote_cid)
 {
 	return true;
 }