diff mbox series

[RFC,net-next,v6,07/14] virtio/vsock: add common datagram send path

Message ID 20240710212555.1617795-8-amery.hung@bytedance.com (mailing list archive)
State New, archived
Headers show
Series virtio/vsock: support datagrams | expand

Commit Message

Amery Hung July 10, 2024, 9:25 p.m. UTC
From: Bobby Eshleman <bobby.eshleman@bytedance.com>

This commit implements the common function
virtio_transport_dgram_enqueue for enqueueing datagrams. It does not add
usage in either vhost or virtio yet.

Signed-off-by: Bobby Eshleman <bobby.eshleman@bytedance.com>
Signed-off-by: Amery Hung <amery.hung@bytedance.com>
---
 include/linux/virtio_vsock.h            |  1 +
 include/net/af_vsock.h                  |  2 +
 net/vmw_vsock/af_vsock.c                |  2 +-
 net/vmw_vsock/virtio_transport_common.c | 87 ++++++++++++++++++++++++-
 4 files changed, 90 insertions(+), 2 deletions(-)

Comments

Stefano Garzarella July 23, 2024, 2:42 p.m. UTC | #1
On Wed, Jul 10, 2024 at 09:25:48PM GMT, Amery Hung wrote:
>From: Bobby Eshleman <bobby.eshleman@bytedance.com>
>
>This commit implements the common function
>virtio_transport_dgram_enqueue for enqueueing datagrams. It does not add
>usage in either vhost or virtio yet.
>
>Signed-off-by: Bobby Eshleman <bobby.eshleman@bytedance.com>
>Signed-off-by: Amery Hung <amery.hung@bytedance.com>
>---
> include/linux/virtio_vsock.h            |  1 +
> include/net/af_vsock.h                  |  2 +
> net/vmw_vsock/af_vsock.c                |  2 +-
> net/vmw_vsock/virtio_transport_common.c | 87 ++++++++++++++++++++++++-
> 4 files changed, 90 insertions(+), 2 deletions(-)
>
>diff --git a/include/linux/virtio_vsock.h b/include/linux/virtio_vsock.h
>index f749a066af46..4408749febd2 100644
>--- a/include/linux/virtio_vsock.h
>+++ b/include/linux/virtio_vsock.h
>@@ -152,6 +152,7 @@ struct virtio_vsock_pkt_info {
> 	u16 op;
> 	u32 flags;
> 	bool reply;
>+	u8 remote_flags;
> };
>
> struct virtio_transport {
>diff --git a/include/net/af_vsock.h b/include/net/af_vsock.h
>index 44db8f2c507d..6e97d344ac75 100644
>--- a/include/net/af_vsock.h
>+++ b/include/net/af_vsock.h
>@@ -216,6 +216,8 @@ void vsock_for_each_connected_socket(struct vsock_transport *transport,
> 				     void (*fn)(struct sock *sk));
> int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk);
> bool vsock_find_cid(unsigned int cid);
>+const struct vsock_transport *vsock_dgram_lookup_transport(unsigned int cid,
>+							   __u8 flags);

Why __u8 and not just u8?


>
> struct vsock_skb_cb {
> 	unsigned int src_cid;
>diff --git a/net/vmw_vsock/af_vsock.c b/net/vmw_vsock/af_vsock.c
>index ab08cd81720e..f83b655fdbe9 100644
>--- a/net/vmw_vsock/af_vsock.c
>+++ b/net/vmw_vsock/af_vsock.c
>@@ -487,7 +487,7 @@ vsock_connectible_lookup_transport(unsigned int cid, __u8 flags)
> 	return transport;
> }
>
>-static const struct vsock_transport *
>+const struct vsock_transport *
> vsock_dgram_lookup_transport(unsigned int cid, __u8 flags)
> {
> 	const struct vsock_transport *transport;
>diff --git a/net/vmw_vsock/virtio_transport_common.c b/net/vmw_vsock/virtio_transport_common.c
>index a1c76836d798..46cd1807f8e3 100644
>--- a/net/vmw_vsock/virtio_transport_common.c
>+++ b/net/vmw_vsock/virtio_transport_common.c
>@@ -1040,13 +1040,98 @@ int virtio_transport_shutdown(struct vsock_sock *vsk, int mode)
> }
> EXPORT_SYMBOL_GPL(virtio_transport_shutdown);
>
>+static int virtio_transport_dgram_send_pkt_info(struct vsock_sock *vsk,
>+						struct virtio_vsock_pkt_info *info)
>+{
>+	u32 src_cid, src_port, dst_cid, dst_port;
>+	const struct vsock_transport *transport;
>+	const struct virtio_transport *t_ops;
>+	struct sock *sk = sk_vsock(vsk);
>+	struct virtio_vsock_hdr *hdr;
>+	struct sk_buff *skb;
>+	void *payload;
>+	int noblock = 0;
>+	int err;
>+
>+	info->type = virtio_transport_get_type(sk_vsock(vsk));
>+
>+	if (info->pkt_len > VIRTIO_VSOCK_MAX_PKT_BUF_SIZE)
>+		return -EMSGSIZE;
>+
>+	transport = vsock_dgram_lookup_transport(info->remote_cid, info->remote_flags);

Can `transport` be null?

I don't understand why we are calling vsock_dgram_lookup_transport()
again. Didn't we already do that in vsock_dgram_sendmsg()?

Also should we add a comment mentioning that we can't use
virtio_transport_get_ops()? IIUC because the vsk may not be assigned
to a specific transport, right?

>+	t_ops = container_of(transport, struct virtio_transport, transport);
>+	if (unlikely(!t_ops))
>+		return -EFAULT;
>+
>+	if (info->msg)
>+		noblock = info->msg->msg_flags & MSG_DONTWAIT;
>+
>+	/* Use sock_alloc_send_skb to throttle by sk_sndbuf. This helps avoid
>+	 * triggering the OOM.
>+	 */
>+	skb = sock_alloc_send_skb(sk, info->pkt_len + VIRTIO_VSOCK_SKB_HEADROOM,
>+				  noblock, &err);
>+	if (!skb)
>+		return err;
>+
>+	skb_reserve(skb, VIRTIO_VSOCK_SKB_HEADROOM);
>+
>+	src_cid = t_ops->transport.get_local_cid();
>+	src_port = vsk->local_addr.svm_port;
>+	dst_cid = info->remote_cid;
>+	dst_port = info->remote_port;
>+
>+	hdr = virtio_vsock_hdr(skb);
>+	hdr->type	= cpu_to_le16(info->type);
>+	hdr->op		= cpu_to_le16(info->op);
>+	hdr->src_cid	= cpu_to_le64(src_cid);
>+	hdr->dst_cid	= cpu_to_le64(dst_cid);
>+	hdr->src_port	= cpu_to_le32(src_port);
>+	hdr->dst_port	= cpu_to_le32(dst_port);
>+	hdr->flags	= cpu_to_le32(info->flags);
>+	hdr->len	= cpu_to_le32(info->pkt_len);
>+
>+	if (info->msg && info->pkt_len > 0) {
>+		payload = skb_put(skb, info->pkt_len);
>+		err = memcpy_from_msg(payload, info->msg, info->pkt_len);
>+		if (err)
>+			goto out;
>+	}
>+
>+	trace_virtio_transport_alloc_pkt(src_cid, src_port,
>+					 dst_cid, dst_port,
>+					 info->pkt_len,
>+					 info->type,
>+					 info->op,
>+					 info->flags,
>+					 false);
>+
>+	return t_ops->send_pkt(skb);
>+out:
>+	kfree_skb(skb);
>+	return err;
>+}
>+
> int
> virtio_transport_dgram_enqueue(struct vsock_sock *vsk,
> 			       struct sockaddr_vm *remote_addr,
> 			       struct msghdr *msg,
> 			       size_t dgram_len)
> {
>-	return -EOPNOTSUPP;
>+	/* Here we are only using the info struct to retain style uniformity
>+	 * and to ease future refactoring and merging.
>+	 */
>+	struct virtio_vsock_pkt_info info = {
>+		.op = VIRTIO_VSOCK_OP_RW,
>+		.remote_cid = remote_addr->svm_cid,
>+		.remote_port = remote_addr->svm_port,
>+		.remote_flags = remote_addr->svm_flags,
>+		.msg = msg,
>+		.vsk = vsk,
>+		.pkt_len = dgram_len,
>+	};
>+
>+	return virtio_transport_dgram_send_pkt_info(vsk, &info);
> }
> EXPORT_SYMBOL_GPL(virtio_transport_dgram_enqueue);
>
>-- 
>2.20.1
>
Amery Hung July 26, 2024, 11:22 p.m. UTC | #2
On Tue, Jul 23, 2024 at 7:42 AM Stefano Garzarella <sgarzare@redhat.com> wrote:
>
> On Wed, Jul 10, 2024 at 09:25:48PM GMT, Amery Hung wrote:
> >From: Bobby Eshleman <bobby.eshleman@bytedance.com>
> >
> >This commit implements the common function
> >virtio_transport_dgram_enqueue for enqueueing datagrams. It does not add
> >usage in either vhost or virtio yet.
> >
> >Signed-off-by: Bobby Eshleman <bobby.eshleman@bytedance.com>
> >Signed-off-by: Amery Hung <amery.hung@bytedance.com>
> >---
> > include/linux/virtio_vsock.h            |  1 +
> > include/net/af_vsock.h                  |  2 +
> > net/vmw_vsock/af_vsock.c                |  2 +-
> > net/vmw_vsock/virtio_transport_common.c | 87 ++++++++++++++++++++++++-
> > 4 files changed, 90 insertions(+), 2 deletions(-)
> >
> >diff --git a/include/linux/virtio_vsock.h b/include/linux/virtio_vsock.h
> >index f749a066af46..4408749febd2 100644
> >--- a/include/linux/virtio_vsock.h
> >+++ b/include/linux/virtio_vsock.h
> >@@ -152,6 +152,7 @@ struct virtio_vsock_pkt_info {
> >       u16 op;
> >       u32 flags;
> >       bool reply;
> >+      u8 remote_flags;
> > };
> >
> > struct virtio_transport {
> >diff --git a/include/net/af_vsock.h b/include/net/af_vsock.h
> >index 44db8f2c507d..6e97d344ac75 100644
> >--- a/include/net/af_vsock.h
> >+++ b/include/net/af_vsock.h
> >@@ -216,6 +216,8 @@ void vsock_for_each_connected_socket(struct vsock_transport *transport,
> >                                    void (*fn)(struct sock *sk));
> > int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk);
> > bool vsock_find_cid(unsigned int cid);
> >+const struct vsock_transport *vsock_dgram_lookup_transport(unsigned int cid,
> >+                                                         __u8 flags);
>
> Why __u8 and not just u8?
>

Will change to u8.

>
> >
> > struct vsock_skb_cb {
> >       unsigned int src_cid;
> >diff --git a/net/vmw_vsock/af_vsock.c b/net/vmw_vsock/af_vsock.c
> >index ab08cd81720e..f83b655fdbe9 100644
> >--- a/net/vmw_vsock/af_vsock.c
> >+++ b/net/vmw_vsock/af_vsock.c
> >@@ -487,7 +487,7 @@ vsock_connectible_lookup_transport(unsigned int cid, __u8 flags)
> >       return transport;
> > }
> >
> >-static const struct vsock_transport *
> >+const struct vsock_transport *
> > vsock_dgram_lookup_transport(unsigned int cid, __u8 flags)
> > {
> >       const struct vsock_transport *transport;
> >diff --git a/net/vmw_vsock/virtio_transport_common.c b/net/vmw_vsock/virtio_transport_common.c
> >index a1c76836d798..46cd1807f8e3 100644
> >--- a/net/vmw_vsock/virtio_transport_common.c
> >+++ b/net/vmw_vsock/virtio_transport_common.c
> >@@ -1040,13 +1040,98 @@ int virtio_transport_shutdown(struct vsock_sock *vsk, int mode)
> > }
> > EXPORT_SYMBOL_GPL(virtio_transport_shutdown);
> >
> >+static int virtio_transport_dgram_send_pkt_info(struct vsock_sock *vsk,
> >+                                              struct virtio_vsock_pkt_info *info)
> >+{
> >+      u32 src_cid, src_port, dst_cid, dst_port;
> >+      const struct vsock_transport *transport;
> >+      const struct virtio_transport *t_ops;
> >+      struct sock *sk = sk_vsock(vsk);
> >+      struct virtio_vsock_hdr *hdr;
> >+      struct sk_buff *skb;
> >+      void *payload;
> >+      int noblock = 0;
> >+      int err;
> >+
> >+      info->type = virtio_transport_get_type(sk_vsock(vsk));
> >+
> >+      if (info->pkt_len > VIRTIO_VSOCK_MAX_PKT_BUF_SIZE)
> >+              return -EMSGSIZE;
> >+
> >+      transport = vsock_dgram_lookup_transport(info->remote_cid, info->remote_flags);
>
> Can `transport` be null?
>
> I don't understand why we are calling vsock_dgram_lookup_transport()
> again. Didn't we already do that in vsock_dgram_sendmsg()?
>

transport should be valid here since we null-checked it in
vsock_dgram_sendmsg(). The reason vsock_dgram_lookup_transport() is
called again here is that we don't have the transport when we call into
transport->dgram_enqueue(). Alternatively, I can add the transport as an
argument of dgram_enqueue() to eliminate this redundant lookup.

> Also should we add a comment mentioning that we can't use
> virtio_transport_get_ops()? IIUC becuase the vsk can be not assigned
> to a specific transport, right?
>

Correct. For virtio dgram socket, transport is not assigned unless
vsock_dgram_connect() is called. I will add a comment here explaining
this.

> >+      t_ops = container_of(transport, struct virtio_transport, transport);
> >+      if (unlikely(!t_ops))
> >+              return -EFAULT;
> >+
> >+      if (info->msg)
> >+              noblock = info->msg->msg_flags & MSG_DONTWAIT;
> >+
> >+      /* Use sock_alloc_send_skb to throttle by sk_sndbuf. This helps avoid
> >+       * triggering the OOM.
> >+       */
> >+      skb = sock_alloc_send_skb(sk, info->pkt_len + VIRTIO_VSOCK_SKB_HEADROOM,
> >+                                noblock, &err);
> >+      if (!skb)
> >+              return err;
> >+
> >+      skb_reserve(skb, VIRTIO_VSOCK_SKB_HEADROOM);
> >+
> >+      src_cid = t_ops->transport.get_local_cid();
> >+      src_port = vsk->local_addr.svm_port;
> >+      dst_cid = info->remote_cid;
> >+      dst_port = info->remote_port;
> >+
> >+      hdr = virtio_vsock_hdr(skb);
> >+      hdr->type       = cpu_to_le16(info->type);
> >+      hdr->op         = cpu_to_le16(info->op);
> >+      hdr->src_cid    = cpu_to_le64(src_cid);
> >+      hdr->dst_cid    = cpu_to_le64(dst_cid);
> >+      hdr->src_port   = cpu_to_le32(src_port);
> >+      hdr->dst_port   = cpu_to_le32(dst_port);
> >+      hdr->flags      = cpu_to_le32(info->flags);
> >+      hdr->len        = cpu_to_le32(info->pkt_len);
> >+
> >+      if (info->msg && info->pkt_len > 0) {
> >+              payload = skb_put(skb, info->pkt_len);
> >+              err = memcpy_from_msg(payload, info->msg, info->pkt_len);
> >+              if (err)
> >+                      goto out;
> >+      }
> >+
> >+      trace_virtio_transport_alloc_pkt(src_cid, src_port,
> >+                                       dst_cid, dst_port,
> >+                                       info->pkt_len,
> >+                                       info->type,
> >+                                       info->op,
> >+                                       info->flags,
> >+                                       false);
> >+
> >+      return t_ops->send_pkt(skb);
> >+out:
> >+      kfree_skb(skb);
> >+      return err;
> >+}
> >+
> > int
> > virtio_transport_dgram_enqueue(struct vsock_sock *vsk,
> >                              struct sockaddr_vm *remote_addr,
> >                              struct msghdr *msg,
> >                              size_t dgram_len)
> > {
> >-      return -EOPNOTSUPP;
> >+      /* Here we are only using the info struct to retain style uniformity
> >+       * and to ease future refactoring and merging.
> >+       */
> >+      struct virtio_vsock_pkt_info info = {
> >+              .op = VIRTIO_VSOCK_OP_RW,
> >+              .remote_cid = remote_addr->svm_cid,
> >+              .remote_port = remote_addr->svm_port,
> >+              .remote_flags = remote_addr->svm_flags,
> >+              .msg = msg,
> >+              .vsk = vsk,
> >+              .pkt_len = dgram_len,
> >+      };
> >+
> >+      return virtio_transport_dgram_send_pkt_info(vsk, &info);
> > }
> > EXPORT_SYMBOL_GPL(virtio_transport_dgram_enqueue);
> >
> >--
> >2.20.1
> >
>
Arseniy Krasnov July 29, 2024, 8 p.m. UTC | #3
Hi,

> diff --git a/net/vmw_vsock/virtio_transport_common.c b/net/vmw_vsock/virtio_transport_common.c
> index a1c76836d798..46cd1807f8e3 100644
> --- a/net/vmw_vsock/virtio_transport_common.c
> +++ b/net/vmw_vsock/virtio_transport_common.c
> @@ -1040,13 +1040,98 @@ int virtio_transport_shutdown(struct vsock_sock *vsk, int mode)
>  }
>  EXPORT_SYMBOL_GPL(virtio_transport_shutdown);
>  
> +static int virtio_transport_dgram_send_pkt_info(struct vsock_sock *vsk,
> +						struct virtio_vsock_pkt_info *info)
> +{
> +	u32 src_cid, src_port, dst_cid, dst_port;
> +	const struct vsock_transport *transport;
> +	const struct virtio_transport *t_ops;
> +	struct sock *sk = sk_vsock(vsk);
> +	struct virtio_vsock_hdr *hdr;
> +	struct sk_buff *skb;
> +	void *payload;
> +	int noblock = 0;
> +	int err;
> +
> +	info->type = virtio_transport_get_type(sk_vsock(vsk));
> +
> +	if (info->pkt_len > VIRTIO_VSOCK_MAX_PKT_BUF_SIZE)
> +		return -EMSGSIZE;

Small suggestion: I think we can check the packet length earlier, i.e. before
info->type = ...

> +
> +	transport = vsock_dgram_lookup_transport(info->remote_cid, info->remote_flags);
> +	t_ops = container_of(transport, struct virtio_transport, transport);
> +	if (unlikely(!t_ops))
> +		return -EFAULT;
> +
> +	if (info->msg)
> +		noblock = info->msg->msg_flags & MSG_DONTWAIT;
> +
> +	/* Use sock_alloc_send_skb to throttle by sk_sndbuf. This helps avoid
> +	 * triggering the OOM.
> +	 */
> +	skb = sock_alloc_send_skb(sk, info->pkt_len + VIRTIO_VSOCK_SKB_HEADROOM,
> +				  noblock, &err);
> +	if (!skb)
> +		return err;
> +
> +	skb_reserve(skb, VIRTIO_VSOCK_SKB_HEADROOM);
> +
> +	src_cid = t_ops->transport.get_local_cid();
> +	src_port = vsk->local_addr.svm_port;
> +	dst_cid = info->remote_cid;
> +	dst_port = info->remote_port;
> +
> +	hdr = virtio_vsock_hdr(skb);
> +	hdr->type	= cpu_to_le16(info->type);
> +	hdr->op		= cpu_to_le16(info->op);
> +	hdr->src_cid	= cpu_to_le64(src_cid);
> +	hdr->dst_cid	= cpu_to_le64(dst_cid);
> +	hdr->src_port	= cpu_to_le32(src_port);
> +	hdr->dst_port	= cpu_to_le32(dst_port);
> +	hdr->flags	= cpu_to_le32(info->flags);
> +	hdr->len	= cpu_to_le32(info->pkt_len);

There is a function 'virtio_transport_init_hdr()' in this file; maybe reuse it?

> +
> +	if (info->msg && info->pkt_len > 0) {

If pkt_len is 0, do we really need to send such packets? For connectible
sockets, we ignore empty OP_RW packets.

> +		payload = skb_put(skb, info->pkt_len);
> +		err = memcpy_from_msg(payload, info->msg, info->pkt_len);
> +		if (err)
> +			goto out;
> +	}
> +
> +	trace_virtio_transport_alloc_pkt(src_cid, src_port,
> +					 dst_cid, dst_port,
> +					 info->pkt_len,
> +					 info->type,
> +					 info->op,
> +					 info->flags,
> +					 false);

^^^ For SOCK_DGRAM, include/trace/events/vsock_virtio_transport_common.h also should
be updated?

> +
> +	return t_ops->send_pkt(skb);
> +out:
> +	kfree_skb(skb);
> +	return err;
> +}
> +
>  int
>  virtio_transport_dgram_enqueue(struct vsock_sock *vsk,
>  			       struct sockaddr_vm *remote_addr,
>  			       struct msghdr *msg,
>  			       size_t dgram_len)
>  {
> -	return -EOPNOTSUPP;
> +	/* Here we are only using the info struct to retain style uniformity
> +	 * and to ease future refactoring and merging.
> +	 */
> +	struct virtio_vsock_pkt_info info = {
> +		.op = VIRTIO_VSOCK_OP_RW,
> +		.remote_cid = remote_addr->svm_cid,
> +		.remote_port = remote_addr->svm_port,
> +		.remote_flags = remote_addr->svm_flags,
> +		.msg = msg,
> +		.vsk = vsk,
> +		.pkt_len = dgram_len,
> +	};
> +
> +	return virtio_transport_dgram_send_pkt_info(vsk, &info);
>  }
>  EXPORT_SYMBOL_GPL(virtio_transport_dgram_enqueue);
>  
> -- 
> 2.20.1

Thanks, Arseniy
Amery Hung July 29, 2024, 10:51 p.m. UTC | #4
On Mon, Jul 29, 2024 at 1:12 PM Arseniy Krasnov
<avkrasnov@salutedevices.com> wrote:
>
> Hi,
>
> > diff --git a/net/vmw_vsock/virtio_transport_common.c b/net/vmw_vsock/virtio_transport_common.c
> > index a1c76836d798..46cd1807f8e3 100644
> > --- a/net/vmw_vsock/virtio_transport_common.c
> > +++ b/net/vmw_vsock/virtio_transport_common.c
> > @@ -1040,13 +1040,98 @@ int virtio_transport_shutdown(struct vsock_sock *vsk, int mode)
> >  }
> >  EXPORT_SYMBOL_GPL(virtio_transport_shutdown);
> >
> > +static int virtio_transport_dgram_send_pkt_info(struct vsock_sock *vsk,
> > +                                             struct virtio_vsock_pkt_info *info)
> > +{
> > +     u32 src_cid, src_port, dst_cid, dst_port;
> > +     const struct vsock_transport *transport;
> > +     const struct virtio_transport *t_ops;
> > +     struct sock *sk = sk_vsock(vsk);
> > +     struct virtio_vsock_hdr *hdr;
> > +     struct sk_buff *skb;
> > +     void *payload;
> > +     int noblock = 0;
> > +     int err;
> > +
> > +     info->type = virtio_transport_get_type(sk_vsock(vsk));
> > +
> > +     if (info->pkt_len > VIRTIO_VSOCK_MAX_PKT_BUF_SIZE)
> > +             return -EMSGSIZE;
>
> Small suggestion, i think we can check for packet length earlier ? Before
> info->type = ...

Certainly.

>
> > +
> > +     transport = vsock_dgram_lookup_transport(info->remote_cid, info->remote_flags);
> > +     t_ops = container_of(transport, struct virtio_transport, transport);
> > +     if (unlikely(!t_ops))
> > +             return -EFAULT;
> > +
> > +     if (info->msg)
> > +             noblock = info->msg->msg_flags & MSG_DONTWAIT;
> > +
> > +     /* Use sock_alloc_send_skb to throttle by sk_sndbuf. This helps avoid
> > +      * triggering the OOM.
> > +      */
> > +     skb = sock_alloc_send_skb(sk, info->pkt_len + VIRTIO_VSOCK_SKB_HEADROOM,
> > +                               noblock, &err);
> > +     if (!skb)
> > +             return err;
> > +
> > +     skb_reserve(skb, VIRTIO_VSOCK_SKB_HEADROOM);
> > +
> > +     src_cid = t_ops->transport.get_local_cid();
> > +     src_port = vsk->local_addr.svm_port;
> > +     dst_cid = info->remote_cid;
> > +     dst_port = info->remote_port;
> > +
> > +     hdr = virtio_vsock_hdr(skb);
> > +     hdr->type       = cpu_to_le16(info->type);
> > +     hdr->op         = cpu_to_le16(info->op);
> > +     hdr->src_cid    = cpu_to_le64(src_cid);
> > +     hdr->dst_cid    = cpu_to_le64(dst_cid);
> > +     hdr->src_port   = cpu_to_le32(src_port);
> > +     hdr->dst_port   = cpu_to_le32(dst_port);
> > +     hdr->flags      = cpu_to_le32(info->flags);
> > +     hdr->len        = cpu_to_le32(info->pkt_len);
>
> There is function 'virtio_transport_init_hdr()' in this file, may be reuse it ?

Will do.

>
> > +
> > +     if (info->msg && info->pkt_len > 0) {
>
> If pkt_len is 0, do we really need to send such packets ? Because for connectible
> sockets, we ignore empty OP_RW packets.

Thanks for pointing this out. I think virtio dgram should also follow that.

>
> > +             payload = skb_put(skb, info->pkt_len);
> > +             err = memcpy_from_msg(payload, info->msg, info->pkt_len);
> > +             if (err)
> > +                     goto out;
> > +     }
> > +
> > +     trace_virtio_transport_alloc_pkt(src_cid, src_port,
> > +                                      dst_cid, dst_port,
> > +                                      info->pkt_len,
> > +                                      info->type,
> > +                                      info->op,
> > +                                      info->flags,
> > +                                      false);
>
> ^^^ For SOCK_DGRAM, include/trace/events/vsock_virtio_transport_common.h also should
> be updated?

Can you elaborate what needs to be changed?

Thank you,
Amery

>
> > +
> > +     return t_ops->send_pkt(skb);
> > +out:
> > +     kfree_skb(skb);
> > +     return err;
> > +}
> > +
> >  int
> >  virtio_transport_dgram_enqueue(struct vsock_sock *vsk,
> >                              struct sockaddr_vm *remote_addr,
> >                              struct msghdr *msg,
> >                              size_t dgram_len)
> >  {
> > -     return -EOPNOTSUPP;
> > +     /* Here we are only using the info struct to retain style uniformity
> > +      * and to ease future refactoring and merging.
> > +      */
> > +     struct virtio_vsock_pkt_info info = {
> > +             .op = VIRTIO_VSOCK_OP_RW,
> > +             .remote_cid = remote_addr->svm_cid,
> > +             .remote_port = remote_addr->svm_port,
> > +             .remote_flags = remote_addr->svm_flags,
> > +             .msg = msg,
> > +             .vsk = vsk,
> > +             .pkt_len = dgram_len,
> > +     };
> > +
> > +     return virtio_transport_dgram_send_pkt_info(vsk, &info);
> >  }
> >  EXPORT_SYMBOL_GPL(virtio_transport_dgram_enqueue);
> >
> > --
> > 2.20.1
>
> Thanks, Arseniy
Arseniy Krasnov July 30, 2024, 5:09 a.m. UTC | #5
On 30.07.2024 01:51, Amery Hung wrote:
> On Mon, Jul 29, 2024 at 1:12 PM Arseniy Krasnov
> <avkrasnov@salutedevices.com> wrote:
>>
>> Hi,
>>
>>> diff --git a/net/vmw_vsock/virtio_transport_common.c b/net/vmw_vsock/virtio_transport_common.c
>>> index a1c76836d798..46cd1807f8e3 100644
>>> --- a/net/vmw_vsock/virtio_transport_common.c
>>> +++ b/net/vmw_vsock/virtio_transport_common.c
>>> @@ -1040,13 +1040,98 @@ int virtio_transport_shutdown(struct vsock_sock *vsk, int mode)
>>>  }
>>>  EXPORT_SYMBOL_GPL(virtio_transport_shutdown);
>>>
>>> +static int virtio_transport_dgram_send_pkt_info(struct vsock_sock *vsk,
>>> +                                             struct virtio_vsock_pkt_info *info)
>>> +{
>>> +     u32 src_cid, src_port, dst_cid, dst_port;
>>> +     const struct vsock_transport *transport;
>>> +     const struct virtio_transport *t_ops;
>>> +     struct sock *sk = sk_vsock(vsk);
>>> +     struct virtio_vsock_hdr *hdr;
>>> +     struct sk_buff *skb;
>>> +     void *payload;
>>> +     int noblock = 0;
>>> +     int err;
>>> +
>>> +     info->type = virtio_transport_get_type(sk_vsock(vsk));
>>> +
>>> +     if (info->pkt_len > VIRTIO_VSOCK_MAX_PKT_BUF_SIZE)
>>> +             return -EMSGSIZE;
>>
>> Small suggestion, i think we can check for packet length earlier ? Before
>> info->type = ...
> 
> Certainly.
> 
>>
>>> +
>>> +     transport = vsock_dgram_lookup_transport(info->remote_cid, info->remote_flags);
>>> +     t_ops = container_of(transport, struct virtio_transport, transport);
>>> +     if (unlikely(!t_ops))
>>> +             return -EFAULT;
>>> +
>>> +     if (info->msg)
>>> +             noblock = info->msg->msg_flags & MSG_DONTWAIT;
>>> +
>>> +     /* Use sock_alloc_send_skb to throttle by sk_sndbuf. This helps avoid
>>> +      * triggering the OOM.
>>> +      */
>>> +     skb = sock_alloc_send_skb(sk, info->pkt_len + VIRTIO_VSOCK_SKB_HEADROOM,
>>> +                               noblock, &err);
>>> +     if (!skb)
>>> +             return err;
>>> +
>>> +     skb_reserve(skb, VIRTIO_VSOCK_SKB_HEADROOM);
>>> +
>>> +     src_cid = t_ops->transport.get_local_cid();
>>> +     src_port = vsk->local_addr.svm_port;
>>> +     dst_cid = info->remote_cid;
>>> +     dst_port = info->remote_port;
>>> +
>>> +     hdr = virtio_vsock_hdr(skb);
>>> +     hdr->type       = cpu_to_le16(info->type);
>>> +     hdr->op         = cpu_to_le16(info->op);
>>> +     hdr->src_cid    = cpu_to_le64(src_cid);
>>> +     hdr->dst_cid    = cpu_to_le64(dst_cid);
>>> +     hdr->src_port   = cpu_to_le32(src_port);
>>> +     hdr->dst_port   = cpu_to_le32(dst_port);
>>> +     hdr->flags      = cpu_to_le32(info->flags);
>>> +     hdr->len        = cpu_to_le32(info->pkt_len);
>>
>> There is function 'virtio_transport_init_hdr()' in this file, may be reuse it ?
> 
> Will do.
> 
>>
>>> +
>>> +     if (info->msg && info->pkt_len > 0) {
>>
>> If pkt_len is 0, do we really need to send such packets ? Because for connectible
>> sockets, we ignore empty OP_RW packets.
> 
> Thanks for pointing this out. I think virtio dgram should also follow that.
> 
>>
>>> +             payload = skb_put(skb, info->pkt_len);
>>> +             err = memcpy_from_msg(payload, info->msg, info->pkt_len);
>>> +             if (err)
>>> +                     goto out;
>>> +     }
>>> +
>>> +     trace_virtio_transport_alloc_pkt(src_cid, src_port,
>>> +                                      dst_cid, dst_port,
>>> +                                      info->pkt_len,
>>> +                                      info->type,
>>> +                                      info->op,
>>> +                                      info->flags,
>>> +                                      false);
>>
>> ^^^ For SOCK_DGRAM, include/trace/events/vsock_virtio_transport_common.h also should
>> be updated?
> 
> Can you elaborate what needs to be changed?

Sure, there are:

TRACE_DEFINE_ENUM(VIRTIO_VSOCK_TYPE_STREAM);
TRACE_DEFINE_ENUM(VIRTIO_VSOCK_TYPE_SEQPACKET);

#define show_type(val) \
	__print_symbolic(val, \
			 { VIRTIO_VSOCK_TYPE_STREAM, "STREAM" }, \
			 { VIRTIO_VSOCK_TYPE_SEQPACKET, "SEQPACKET" })

I guess SOCK_DGRAM handling should be added to print type of socket.

Thanks, Arseniy

> 
> Thank you,
> Amery
> 
>>
>>> +
>>> +     return t_ops->send_pkt(skb);
>>> +out:
>>> +     kfree_skb(skb);
>>> +     return err;
>>> +}
>>> +
>>>  int
>>>  virtio_transport_dgram_enqueue(struct vsock_sock *vsk,
>>>                              struct sockaddr_vm *remote_addr,
>>>                              struct msghdr *msg,
>>>                              size_t dgram_len)
>>>  {
>>> -     return -EOPNOTSUPP;
>>> +     /* Here we are only using the info struct to retain style uniformity
>>> +      * and to ease future refactoring and merging.
>>> +      */
>>> +     struct virtio_vsock_pkt_info info = {
>>> +             .op = VIRTIO_VSOCK_OP_RW,
>>> +             .remote_cid = remote_addr->svm_cid,
>>> +             .remote_port = remote_addr->svm_port,
>>> +             .remote_flags = remote_addr->svm_flags,
>>> +             .msg = msg,
>>> +             .vsk = vsk,
>>> +             .pkt_len = dgram_len,
>>> +     };
>>> +
>>> +     return virtio_transport_dgram_send_pkt_info(vsk, &info);
>>>  }
>>>  EXPORT_SYMBOL_GPL(virtio_transport_dgram_enqueue);
>>>
>>> --
>>> 2.20.1
>>
>> Thanks, Arseniy
Stefano Garzarella July 30, 2024, 8:22 a.m. UTC | #6
On Fri, Jul 26, 2024 at 04:22:16PM GMT, Amery Hung wrote:
>On Tue, Jul 23, 2024 at 7:42 AM Stefano Garzarella <sgarzare@redhat.com> wrote:
>>
>> On Wed, Jul 10, 2024 at 09:25:48PM GMT, Amery Hung wrote:
>> >From: Bobby Eshleman <bobby.eshleman@bytedance.com>
>> >
>> >This commit implements the common function
>> >virtio_transport_dgram_enqueue for enqueueing datagrams. It does not add
>> >usage in either vhost or virtio yet.
>> >
>> >Signed-off-by: Bobby Eshleman <bobby.eshleman@bytedance.com>
>> >Signed-off-by: Amery Hung <amery.hung@bytedance.com>
>> >---
>> > include/linux/virtio_vsock.h            |  1 +
>> > include/net/af_vsock.h                  |  2 +
>> > net/vmw_vsock/af_vsock.c                |  2 +-
>> > net/vmw_vsock/virtio_transport_common.c | 87 ++++++++++++++++++++++++-
>> > 4 files changed, 90 insertions(+), 2 deletions(-)
>> >
>> >diff --git a/include/linux/virtio_vsock.h b/include/linux/virtio_vsock.h
>> >index f749a066af46..4408749febd2 100644
>> >--- a/include/linux/virtio_vsock.h
>> >+++ b/include/linux/virtio_vsock.h
>> >@@ -152,6 +152,7 @@ struct virtio_vsock_pkt_info {
>> >       u16 op;
>> >       u32 flags;
>> >       bool reply;
>> >+      u8 remote_flags;
>> > };
>> >
>> > struct virtio_transport {
>> >diff --git a/include/net/af_vsock.h b/include/net/af_vsock.h
>> >index 44db8f2c507d..6e97d344ac75 100644
>> >--- a/include/net/af_vsock.h
>> >+++ b/include/net/af_vsock.h
>> >@@ -216,6 +216,8 @@ void vsock_for_each_connected_socket(struct vsock_transport *transport,
>> >                                    void (*fn)(struct sock *sk));
>> > int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk);
>> > bool vsock_find_cid(unsigned int cid);
>> >+const struct vsock_transport *vsock_dgram_lookup_transport(unsigned int cid,
>> >+                                                         __u8 flags);
>>
>> Why __u8 and not just u8?
>>
>
>Will change to u8.
>
>>
>> >
>> > struct vsock_skb_cb {
>> >       unsigned int src_cid;
>> >diff --git a/net/vmw_vsock/af_vsock.c b/net/vmw_vsock/af_vsock.c
>> >index ab08cd81720e..f83b655fdbe9 100644
>> >--- a/net/vmw_vsock/af_vsock.c
>> >+++ b/net/vmw_vsock/af_vsock.c
>> >@@ -487,7 +487,7 @@ vsock_connectible_lookup_transport(unsigned int cid, __u8 flags)
>> >       return transport;
>> > }
>> >
>> >-static const struct vsock_transport *
>> >+const struct vsock_transport *
>> > vsock_dgram_lookup_transport(unsigned int cid, __u8 flags)
>> > {
>> >       const struct vsock_transport *transport;
>> >diff --git a/net/vmw_vsock/virtio_transport_common.c b/net/vmw_vsock/virtio_transport_common.c
>> >index a1c76836d798..46cd1807f8e3 100644
>> >--- a/net/vmw_vsock/virtio_transport_common.c
>> >+++ b/net/vmw_vsock/virtio_transport_common.c
>> >@@ -1040,13 +1040,98 @@ int virtio_transport_shutdown(struct vsock_sock *vsk, int mode)
>> > }
>> > EXPORT_SYMBOL_GPL(virtio_transport_shutdown);
>> >
>> >+static int virtio_transport_dgram_send_pkt_info(struct vsock_sock *vsk,
>> >+                                              struct virtio_vsock_pkt_info *info)
>> >+{
>> >+      u32 src_cid, src_port, dst_cid, dst_port;
>> >+      const struct vsock_transport *transport;
>> >+      const struct virtio_transport *t_ops;
>> >+      struct sock *sk = sk_vsock(vsk);
>> >+      struct virtio_vsock_hdr *hdr;
>> >+      struct sk_buff *skb;
>> >+      void *payload;
>> >+      int noblock = 0;
>> >+      int err;
>> >+
>> >+      info->type = virtio_transport_get_type(sk_vsock(vsk));
>> >+
>> >+      if (info->pkt_len > VIRTIO_VSOCK_MAX_PKT_BUF_SIZE)
>> >+              return -EMSGSIZE;
>> >+
>> >+      transport = vsock_dgram_lookup_transport(info->remote_cid, info->remote_flags);
>>
>> Can `transport` be null?
>>
>> I don't understand why we are calling vsock_dgram_lookup_transport()
>> again. Didn't we already do that in vsock_dgram_sendmsg()?
>>
>
>transport should be valid here since we null-checked it in
>vsock_dgram_sendmsg(). The reason vsock_dgram_lookup_transport() is
>called again here is we don't have the transport when we called into
>transport->dgram_enqueue(). I can also instead add transport to the
>argument of dgram_enqueue() to eliminate this redundant lookup.

Yes, I would absolutely eliminate this double lookup.

You can add either a parameter, or define the callback in each transport 
and internally use the statically allocated transport in each.

For example for vhost/vsock.c:

static int vhost_transport_dgram_enqueue(....) {
     return virtio_transport_dgram_enqueue(&vhost_transport.transport,
                                           ...)
}

In virtio_transport_recv_pkt() we already do something similar.

>
>> Also should we add a comment mentioning that we can't use
>> virtio_transport_get_ops()? IIUC because the vsk may not be assigned
>> to a specific transport, right?
>>
>
>Correct. For virtio dgram socket, transport is not assigned unless
>vsock_dgram_connect() is called. I will add a comment here explaining
>this.

Thanks,
Stefano
diff mbox series

Patch

diff --git a/include/linux/virtio_vsock.h b/include/linux/virtio_vsock.h
index f749a066af46..4408749febd2 100644
--- a/include/linux/virtio_vsock.h
+++ b/include/linux/virtio_vsock.h
@@ -152,6 +152,7 @@  struct virtio_vsock_pkt_info {
 	u16 op;
 	u32 flags;
 	bool reply;
+	u8 remote_flags;
 };
 
 struct virtio_transport {
diff --git a/include/net/af_vsock.h b/include/net/af_vsock.h
index 44db8f2c507d..6e97d344ac75 100644
--- a/include/net/af_vsock.h
+++ b/include/net/af_vsock.h
@@ -216,6 +216,8 @@  void vsock_for_each_connected_socket(struct vsock_transport *transport,
 				     void (*fn)(struct sock *sk));
 int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk);
 bool vsock_find_cid(unsigned int cid);
+const struct vsock_transport *vsock_dgram_lookup_transport(unsigned int cid,
+							   __u8 flags);
 
 struct vsock_skb_cb {
 	unsigned int src_cid;
diff --git a/net/vmw_vsock/af_vsock.c b/net/vmw_vsock/af_vsock.c
index ab08cd81720e..f83b655fdbe9 100644
--- a/net/vmw_vsock/af_vsock.c
+++ b/net/vmw_vsock/af_vsock.c
@@ -487,7 +487,7 @@  vsock_connectible_lookup_transport(unsigned int cid, __u8 flags)
 	return transport;
 }
 
-static const struct vsock_transport *
+const struct vsock_transport *
 vsock_dgram_lookup_transport(unsigned int cid, __u8 flags)
 {
 	const struct vsock_transport *transport;
diff --git a/net/vmw_vsock/virtio_transport_common.c b/net/vmw_vsock/virtio_transport_common.c
index a1c76836d798..46cd1807f8e3 100644
--- a/net/vmw_vsock/virtio_transport_common.c
+++ b/net/vmw_vsock/virtio_transport_common.c
@@ -1040,13 +1040,98 @@  int virtio_transport_shutdown(struct vsock_sock *vsk, int mode)
 }
 EXPORT_SYMBOL_GPL(virtio_transport_shutdown);
 
+static int virtio_transport_dgram_send_pkt_info(struct vsock_sock *vsk,
+						struct virtio_vsock_pkt_info *info)
+{
+	u32 src_cid, src_port, dst_cid, dst_port;
+	const struct vsock_transport *transport;
+	const struct virtio_transport *t_ops;
+	struct sock *sk = sk_vsock(vsk);
+	struct virtio_vsock_hdr *hdr;
+	struct sk_buff *skb;
+	void *payload;
+	int noblock = 0;
+	int err;
+
+	info->type = virtio_transport_get_type(sk_vsock(vsk));
+
+	if (info->pkt_len > VIRTIO_VSOCK_MAX_PKT_BUF_SIZE)
+		return -EMSGSIZE;
+
+	transport = vsock_dgram_lookup_transport(info->remote_cid, info->remote_flags);
+	t_ops = container_of(transport, struct virtio_transport, transport);
+	if (unlikely(!t_ops))
+		return -EFAULT;
+
+	if (info->msg)
+		noblock = info->msg->msg_flags & MSG_DONTWAIT;
+
+	/* Use sock_alloc_send_skb to throttle by sk_sndbuf. This helps avoid
+	 * triggering the OOM.
+	 */
+	skb = sock_alloc_send_skb(sk, info->pkt_len + VIRTIO_VSOCK_SKB_HEADROOM,
+				  noblock, &err);
+	if (!skb)
+		return err;
+
+	skb_reserve(skb, VIRTIO_VSOCK_SKB_HEADROOM);
+
+	src_cid = t_ops->transport.get_local_cid();
+	src_port = vsk->local_addr.svm_port;
+	dst_cid = info->remote_cid;
+	dst_port = info->remote_port;
+
+	hdr = virtio_vsock_hdr(skb);
+	hdr->type	= cpu_to_le16(info->type);
+	hdr->op		= cpu_to_le16(info->op);
+	hdr->src_cid	= cpu_to_le64(src_cid);
+	hdr->dst_cid	= cpu_to_le64(dst_cid);
+	hdr->src_port	= cpu_to_le32(src_port);
+	hdr->dst_port	= cpu_to_le32(dst_port);
+	hdr->flags	= cpu_to_le32(info->flags);
+	hdr->len	= cpu_to_le32(info->pkt_len);
+
+	if (info->msg && info->pkt_len > 0) {
+		payload = skb_put(skb, info->pkt_len);
+		err = memcpy_from_msg(payload, info->msg, info->pkt_len);
+		if (err)
+			goto out;
+	}
+
+	trace_virtio_transport_alloc_pkt(src_cid, src_port,
+					 dst_cid, dst_port,
+					 info->pkt_len,
+					 info->type,
+					 info->op,
+					 info->flags,
+					 false);
+
+	return t_ops->send_pkt(skb);
+out:
+	kfree_skb(skb);
+	return err;
+}
+
 int
 virtio_transport_dgram_enqueue(struct vsock_sock *vsk,
 			       struct sockaddr_vm *remote_addr,
 			       struct msghdr *msg,
 			       size_t dgram_len)
 {
-	return -EOPNOTSUPP;
+	/* Here we are only using the info struct to retain style uniformity
+	 * and to ease future refactoring and merging.
+	 */
+	struct virtio_vsock_pkt_info info = {
+		.op = VIRTIO_VSOCK_OP_RW,
+		.remote_cid = remote_addr->svm_cid,
+		.remote_port = remote_addr->svm_port,
+		.remote_flags = remote_addr->svm_flags,
+		.msg = msg,
+		.vsk = vsk,
+		.pkt_len = dgram_len,
+	};
+
+	return virtio_transport_dgram_send_pkt_info(vsk, &info);
 }
 EXPORT_SYMBOL_GPL(virtio_transport_dgram_enqueue);