@@ -213,9 +213,10 @@ void vsock_enqueue_accept(struct sock *listener, struct sock *connected);
void vsock_insert_connected(struct vsock_sock *vsk);
void vsock_remove_bound(struct vsock_sock *vsk);
void vsock_remove_connected(struct vsock_sock *vsk);
-struct sock *vsock_find_bound_socket(struct sockaddr_vm *addr);
+struct sock *vsock_find_bound_socket(struct sockaddr_vm *addr, struct net *net);
struct sock *vsock_find_connected_socket(struct sockaddr_vm *src,
- struct sockaddr_vm *dst);
+ struct sockaddr_vm *dst,
+ struct net *net);
void vsock_remove_sock(struct vsock_sock *vsk);
void vsock_for_each_connected_socket(struct vsock_transport *transport,
void (*fn)(struct sock *sk));
@@ -255,4 +256,6 @@ static inline bool vsock_msgzerocopy_allow(const struct vsock_transport *t)
{
return t->msgzerocopy_allow && t->msgzerocopy_allow();
}
+
+struct net *vsock_global_net(void);
#endif /* __AF_VSOCK_H__ */
@@ -235,37 +235,60 @@ static void __vsock_remove_connected(struct vsock_sock *vsk)
sock_put(&vsk->sk);
}
-static struct sock *__vsock_find_bound_socket(struct sockaddr_vm *addr)
+struct net *vsock_global_net(void)
{
+ return NULL;
+}
+EXPORT_SYMBOL_GPL(vsock_global_net);
+
+static struct sock *__vsock_find_bound_socket(struct sockaddr_vm *addr,
+ struct net *net)
+{
+ struct sock *fallback = NULL;
struct vsock_sock *vsk;
list_for_each_entry(vsk, vsock_bound_sockets(addr), bound_table) {
- if (vsock_addr_equals_addr(addr, &vsk->local_addr))
- return sk_vsock(vsk);
+ if (vsock_addr_equals_addr(addr, &vsk->local_addr)) {
+ if (net_eq(net, sock_net(sk_vsock(vsk))))
+ return sk_vsock(vsk);
+ if (net_eq(net, vsock_global_net()))
+ fallback = sk_vsock(vsk);
+ }
if (addr->svm_port == vsk->local_addr.svm_port &&
(vsk->local_addr.svm_cid == VMADDR_CID_ANY ||
- addr->svm_cid == VMADDR_CID_ANY))
- return sk_vsock(vsk);
+ addr->svm_cid == VMADDR_CID_ANY)) {
+ if (net_eq(net, sock_net(sk_vsock(vsk))))
+ return sk_vsock(vsk);
+
+ if (net_eq(net, vsock_global_net()))
+ fallback = sk_vsock(vsk);
+ }
}
- return NULL;
+ return fallback;
}
static struct sock *__vsock_find_connected_socket(struct sockaddr_vm *src,
- struct sockaddr_vm *dst)
+ struct sockaddr_vm *dst,
+ struct net *net)
{
+ struct sock *fallback = NULL;
struct vsock_sock *vsk;
list_for_each_entry(vsk, vsock_connected_sockets(src, dst),
connected_table) {
if (vsock_addr_equals_addr(src, &vsk->remote_addr) &&
dst->svm_port == vsk->local_addr.svm_port) {
- return sk_vsock(vsk);
+ if (net_eq(net, sock_net(sk_vsock(vsk))))
+ return sk_vsock(vsk);
+
+ if (net_eq(net, vsock_global_net()))
+ fallback = sk_vsock(vsk);
}
}
- return NULL;
+ return fallback;
}
static void vsock_insert_unbound(struct vsock_sock *vsk)
@@ -304,12 +327,12 @@ void vsock_remove_connected(struct vsock_sock *vsk)
}
EXPORT_SYMBOL_GPL(vsock_remove_connected);
-struct sock *vsock_find_bound_socket(struct sockaddr_vm *addr)
+struct sock *vsock_find_bound_socket(struct sockaddr_vm *addr, struct net *net)
{
struct sock *sk;
spin_lock_bh(&vsock_table_lock);
- sk = __vsock_find_bound_socket(addr);
+ sk = __vsock_find_bound_socket(addr, net);
if (sk)
sock_hold(sk);
@@ -320,12 +343,13 @@ struct sock *vsock_find_bound_socket(struct sockaddr_vm *addr)
EXPORT_SYMBOL_GPL(vsock_find_bound_socket);
struct sock *vsock_find_connected_socket(struct sockaddr_vm *src,
- struct sockaddr_vm *dst)
+ struct sockaddr_vm *dst,
+ struct net *net)
{
struct sock *sk;
spin_lock_bh(&vsock_table_lock);
- sk = __vsock_find_connected_socket(src, dst);
+ sk = __vsock_find_connected_socket(src, dst, net);
if (sk)
sock_hold(sk);
@@ -644,6 +668,7 @@ static int __vsock_bind_connectible(struct vsock_sock *vsk,
{
static u32 port;
struct sockaddr_vm new_addr;
+ struct net *net = sock_net(sk_vsock(vsk));
if (!port)
port = get_random_u32_above(LAST_RESERVED_PORT);
@@ -660,7 +685,7 @@ static int __vsock_bind_connectible(struct vsock_sock *vsk,
new_addr.svm_port = port++;
- if (!__vsock_find_bound_socket(&new_addr)) {
+ if (!__vsock_find_bound_socket(&new_addr, net)) {
found = true;
break;
}
@@ -677,7 +702,7 @@ static int __vsock_bind_connectible(struct vsock_sock *vsk,
return -EACCES;
}
- if (__vsock_find_bound_socket(&new_addr))
+ if (__vsock_find_bound_socket(&new_addr, net))
return -EADDRINUSE;
}
@@ -313,7 +313,7 @@ static void hvs_open_connection(struct vmbus_channel *chan)
return;
hvs_addr_init(&addr, conn_from_host ? if_type : if_instance);
- sk = vsock_find_bound_socket(&addr);
+ sk = vsock_find_bound_socket(&addr, NULL);
if (!sk)
return;
@@ -1590,6 +1590,7 @@ void virtio_transport_recv_pkt(struct virtio_transport *t,
struct sk_buff *skb)
{
struct virtio_vsock_hdr *hdr = virtio_vsock_hdr(skb);
+ struct net *net = vsock_global_net();
struct sockaddr_vm src, dst;
struct vsock_sock *vsk;
struct sock *sk;
@@ -1617,9 +1618,9 @@ void virtio_transport_recv_pkt(struct virtio_transport *t,
/* The socket must be in connected or bound table
* otherwise send reset back
*/
- sk = vsock_find_connected_socket(&src, &dst);
+ sk = vsock_find_connected_socket(&src, &dst, net);
if (!sk) {
- sk = vsock_find_bound_socket(&dst);
+ sk = vsock_find_bound_socket(&dst, net);
if (!sk) {
(void)virtio_transport_reset_no_sock(t, skb);
goto free_pkt;
@@ -703,9 +703,9 @@ static int vmci_transport_recv_stream_cb(void *data, struct vmci_datagram *dg)
vsock_addr_init(&src, pkt->dg.src.context, pkt->src_port);
vsock_addr_init(&dst, pkt->dg.dst.context, pkt->dst_port);
- sk = vsock_find_connected_socket(&src, &dst);
+ sk = vsock_find_connected_socket(&src, &dst, NULL);
if (!sk) {
- sk = vsock_find_bound_socket(&dst);
+ sk = vsock_find_bound_socket(&dst, NULL);
if (!sk) {
/* We could not find a socket for this specified
* address. If this packet is a RST, we just drop it.