diff mbox series

[RFC,v2,4/6] af_vsock/virtio/vsock: add 'seqpacket_drop()' callback

Message ID 20210704081040.89567-1-arseny.krasnov@kaspersky.com (mailing list archive)
State New, archived
Headers show
Series [RFC,v2,1/6] af_vsock/virtio/vsock: change seqpacket receive logic | expand

Commit Message

Arseny Krasnov July 4, 2021, 8:10 a.m. UTC
Add special callback for SEQPACKET socket which is called when
we need to drop current in-progress record: part of record was
copied successfully, reader wait rest of record, but signal
interrupts it and reader leaves it's loop, leaving packets of
current record still in queue. So to avoid copy of "orphaned"
record, we tell transport to drop every packet until EOR will
be found.

Signed-off-by: Arseny Krasnov <arseny.krasnov@kaspersky.com>
---
 drivers/vhost/vsock.c                   |  1 +
 include/linux/virtio_vsock.h            |  2 ++
 include/net/af_vsock.h                  |  1 +
 net/vmw_vsock/af_vsock.c                |  1 +
 net/vmw_vsock/virtio_transport.c        |  1 +
 net/vmw_vsock/virtio_transport_common.c | 23 +++++++++++++++++++----
 net/vmw_vsock/vsock_loopback.c          |  1 +
 7 files changed, 26 insertions(+), 4 deletions(-)
diff mbox series

Patch

diff --git a/drivers/vhost/vsock.c b/drivers/vhost/vsock.c
index c9713d8db0f4..731b9fe07cd3 100644
--- a/drivers/vhost/vsock.c
+++ b/drivers/vhost/vsock.c
@@ -447,6 +447,7 @@  static struct virtio_transport vhost_transport = {
 		.stream_rcvhiwat          = virtio_transport_stream_rcvhiwat,
 		.stream_is_active         = virtio_transport_stream_is_active,
 		.stream_allow             = virtio_transport_stream_allow,
+		.seqpacket_drop           = virtio_transport_seqpacket_drop,
 
 		.seqpacket_dequeue        = virtio_transport_seqpacket_dequeue,
 		.seqpacket_enqueue        = virtio_transport_seqpacket_enqueue,
diff --git a/include/linux/virtio_vsock.h b/include/linux/virtio_vsock.h
index 7360ab7ea0af..18a50f64bf54 100644
--- a/include/linux/virtio_vsock.h
+++ b/include/linux/virtio_vsock.h
@@ -36,6 +36,7 @@  struct virtio_vsock_sock {
 	u32 rx_bytes;
 	u32 buf_alloc;
 	struct list_head rx_queue;
+	bool drop_until_eor;
 };
 
 struct virtio_vsock_pkt {
@@ -89,6 +90,7 @@  virtio_transport_seqpacket_dequeue(struct vsock_sock *vsk,
 				   struct msghdr *msg,
 				   int flags,
 				   bool *msg_ready);
+void virtio_transport_seqpacket_drop(struct vsock_sock *vsk);
 s64 virtio_transport_stream_has_data(struct vsock_sock *vsk);
 s64 virtio_transport_stream_has_space(struct vsock_sock *vsk);
 
diff --git a/include/net/af_vsock.h b/include/net/af_vsock.h
index 1747c0b564ef..356878aabbd4 100644
--- a/include/net/af_vsock.h
+++ b/include/net/af_vsock.h
@@ -141,6 +141,7 @@  struct vsock_transport {
 	int (*seqpacket_enqueue)(struct vsock_sock *vsk, struct msghdr *msg,
 				 size_t len);
 	bool (*seqpacket_allow)(u32 remote_cid);
+	void (*seqpacket_drop)(struct vsock_sock *vsk);
 
 	/* Notification. */
 	int (*notify_poll_in)(struct vsock_sock *, size_t, bool *);
diff --git a/net/vmw_vsock/af_vsock.c b/net/vmw_vsock/af_vsock.c
index 87955f9ff065..380a90c758c4 100644
--- a/net/vmw_vsock/af_vsock.c
+++ b/net/vmw_vsock/af_vsock.c
@@ -2024,6 +2024,7 @@  static int __vsock_seqpacket_recvmsg(struct sock *sk, struct msghdr *msg,
 		intr_err = vsock_connectible_wait_data(sk, &wait, timeout, NULL, 0);
 		if (intr_err <= 0) {
 			err = intr_err;
+			transport->seqpacket_drop(vsk);
 			break;
 		}
 
diff --git a/net/vmw_vsock/virtio_transport.c b/net/vmw_vsock/virtio_transport.c
index 2a7c56fcb062..2f7d54071ee2 100644
--- a/net/vmw_vsock/virtio_transport.c
+++ b/net/vmw_vsock/virtio_transport.c
@@ -475,6 +475,7 @@  static struct virtio_transport virtio_transport = {
 		.seqpacket_dequeue        = virtio_transport_seqpacket_dequeue,
 		.seqpacket_enqueue        = virtio_transport_seqpacket_enqueue,
 		.seqpacket_allow          = virtio_transport_seqpacket_allow,
+		.seqpacket_drop           = virtio_transport_seqpacket_drop,
 
 		.notify_poll_in           = virtio_transport_notify_poll_in,
 		.notify_poll_out          = virtio_transport_notify_poll_out,
diff --git a/net/vmw_vsock/virtio_transport_common.c b/net/vmw_vsock/virtio_transport_common.c
index ce67cf449ef8..52765754edcd 100644
--- a/net/vmw_vsock/virtio_transport_common.c
+++ b/net/vmw_vsock/virtio_transport_common.c
@@ -425,7 +425,7 @@  static int virtio_transport_seqpacket_do_dequeue(struct vsock_sock *vsk,
 		pkt = list_first_entry(&vvs->rx_queue, struct virtio_vsock_pkt, list);
 		pkt_len = (size_t)le32_to_cpu(pkt->hdr.len);
 
-		bytes_to_copy = min(user_buf_len, pkt_len);
+		bytes_to_copy = vvs->drop_until_eor ? 0 : min(user_buf_len, pkt_len);
 
 		if (bytes_to_copy) {
 			int err;
@@ -438,17 +438,22 @@  static int virtio_transport_seqpacket_do_dequeue(struct vsock_sock *vsk,
 
 			spin_lock_bh(&vvs->rx_lock);
 
-			if (err)
+			if (err) {
 				dequeued_len = err;
-			else
+				vvs->drop_until_eor = true;
+			} else {
 				user_buf_len -= bytes_to_copy;
+			}
 		}
 
 		if (dequeued_len >= 0)
 			dequeued_len += pkt_len;
 
 		if (le32_to_cpu(pkt->hdr.flags) & VIRTIO_VSOCK_SEQ_EOR) {
-			*msg_ready = true;
+			if (vvs->drop_until_eor)
+				vvs->drop_until_eor = false;
+			else
+				*msg_ready = true;
 		}
 
 		virtio_transport_dec_rx_pkt(vvs, pkt);
@@ -487,6 +492,16 @@  virtio_transport_seqpacket_dequeue(struct vsock_sock *vsk,
 }
 EXPORT_SYMBOL_GPL(virtio_transport_seqpacket_dequeue);
 
+void virtio_transport_seqpacket_drop(struct vsock_sock *vsk)
+{
+	struct virtio_vsock_sock *vvs = vsk->trans;
+
+	spin_lock_bh(&vvs->rx_lock);
+	vvs->drop_until_eor = true;
+	spin_unlock_bh(&vvs->rx_lock);
+}
+EXPORT_SYMBOL_GPL(virtio_transport_seqpacket_drop);
+
 int
 virtio_transport_seqpacket_enqueue(struct vsock_sock *vsk,
 				   struct msghdr *msg,
diff --git a/net/vmw_vsock/vsock_loopback.c b/net/vmw_vsock/vsock_loopback.c
index 809f807d0710..d9030a46e4b9 100644
--- a/net/vmw_vsock/vsock_loopback.c
+++ b/net/vmw_vsock/vsock_loopback.c
@@ -94,6 +94,7 @@  static struct virtio_transport loopback_transport = {
 		.seqpacket_dequeue        = virtio_transport_seqpacket_dequeue,
 		.seqpacket_enqueue        = virtio_transport_seqpacket_enqueue,
 		.seqpacket_allow          = vsock_loopback_seqpacket_allow,
+		.seqpacket_drop           = virtio_transport_seqpacket_drop,
 
 		.notify_poll_in           = virtio_transport_notify_poll_in,
 		.notify_poll_out          = virtio_transport_notify_poll_out,