@@ -1579,7 +1579,15 @@ static int vsock_connect(struct socket *sock, struct sockaddr *addr,
if (signal_pending(current)) {
err = sock_intr_errno(timeout);
- sk->sk_state = sk->sk_state == TCP_ESTABLISHED ? TCP_CLOSING : TCP_CLOSE;
+ if (sk->sk_state == TCP_ESTABLISHED) {
+ /* Might have raced with a sockmap update. */
+ if (sk->sk_prot->unhash)
+ sk->sk_prot->unhash(sk);
+
+ sk->sk_state = TCP_CLOSING;
+ } else {
+ sk->sk_state = TCP_CLOSE;
+ }
sock->state = SS_UNCONNECTED;
vsock_transport_cancel_pkt(vsk);
vsock_remove_connected(vsk);
@@ -127,6 +127,7 @@ static void vsock_bpf_rebuild_protos(struct proto *prot, const struct proto *bas
{
*prot = *base;
prot->close = sock_map_close;
+ prot->unhash = sock_map_unhash;
prot->recvmsg = vsock_bpf_recvmsg;
prot->sock_is_readable = sk_msg_is_readable;
}