diff mbox series

[RFC,v3,03/11] af_vsock: add zerocopy receive logic

Message ID 7aeba781-db09-9be1-a9a3-a4c16da38fb5@sberdevices.ru (mailing list archive)
State New, archived
Headers show
Series virtio/vsock: experimental zerocopy receive | expand

Commit Message

Arseniy Krasnov Nov. 6, 2022, 7:40 p.m. UTC
This:
1) Adds callback for 'mmap()' call on socket. It checks vm area flags
   and sets vm area ops.
2) Adds special 'getsockopt()' case which calls transport zerocopy
   callback. Input argument is vm area address.
3) Adds 'getsockopt()/setsockopt()' for switching on/off rx zerocopy
   mode.

Signed-off-by: Arseniy Krasnov <AVKrasnov@sberdevices.ru>
---
 include/net/af_vsock.h          |   8 ++
 include/uapi/linux/vm_sockets.h |   3 +
 net/vmw_vsock/af_vsock.c        | 187 +++++++++++++++++++++++++++++++-
 3 files changed, 196 insertions(+), 2 deletions(-)

Comments

Stefano Garzarella Nov. 11, 2022, 1:55 p.m. UTC | #1
On Sun, Nov 06, 2022 at 07:40:12PM +0000, Arseniy Krasnov wrote:
>This:
>1) Adds callback for 'mmap()' call on socket. It checks vm area flags
>   and sets vm area ops.
>2) Adds special 'getsockopt()' case which calls transport zerocopy
>   callback. Input argument is vm area address.
>3) Adds 'getsockopt()/setsockopt()' for switching on/off rx zerocopy
>   mode.
>
>Signed-off-by: Arseniy Krasnov <AVKrasnov@sberdevices.ru>
>---
> include/net/af_vsock.h          |   8 ++
> include/uapi/linux/vm_sockets.h |   3 +
> net/vmw_vsock/af_vsock.c        | 187 +++++++++++++++++++++++++++++++-
> 3 files changed, 196 insertions(+), 2 deletions(-)
>
>diff --git a/include/net/af_vsock.h b/include/net/af_vsock.h
>index 568a87c5e0d0..e4f12ef8e623 100644
>--- a/include/net/af_vsock.h
>+++ b/include/net/af_vsock.h
>@@ -73,6 +73,8 @@ struct vsock_sock {
>
> 	/* Private to transport. */
> 	void *trans;
>+
>+	bool rx_zerocopy_on;

Maybe better to leave the last fields the private ones to transports, so 
I would say put it before trans;

> };
>
> s64 vsock_stream_has_data(struct vsock_sock *vsk);
>@@ -138,6 +140,12 @@ struct vsock_transport {
> 	bool (*stream_allow)(u32 cid, u32 port);
> 	int (*set_rcvlowat)(struct vsock_sock *vsk, int val);
>
>+	int (*zerocopy_rx_mmap)(struct vsock_sock *vsk,
>+				struct vm_area_struct *vma);
>+	int (*zerocopy_dequeue)(struct vsock_sock *vsk,
>+				struct page **pages,
>+				unsigned long *pages_num);
>+
> 	/* SEQ_PACKET. */
> 	ssize_t (*seqpacket_dequeue)(struct vsock_sock *vsk, struct msghdr *msg,
> 				     int flags);
>diff --git a/include/uapi/linux/vm_sockets.h b/include/uapi/linux/vm_sockets.h
>index c60ca33eac59..d1f792bed1a7 100644
>--- a/include/uapi/linux/vm_sockets.h
>+++ b/include/uapi/linux/vm_sockets.h
>@@ -83,6 +83,9 @@
>
> #define SO_VM_SOCKETS_CONNECT_TIMEOUT_NEW 8
>
>+#define SO_VM_SOCKETS_MAP_RX 9
>+#define SO_VM_SOCKETS_ZEROCOPY 10

Before removing RFC, we should document these macros because they are 
exposed to the user.

>+
> #if !defined(__KERNEL__)
> #if __BITS_PER_LONG == 64 || (defined(__x86_64__) && 
> defined(__ILP32__))
> #define SO_VM_SOCKETS_CONNECT_TIMEOUT SO_VM_SOCKETS_CONNECT_TIMEOUT_OLD
>diff --git a/net/vmw_vsock/af_vsock.c b/net/vmw_vsock/af_vsock.c
>index ee418701cdee..21a915eb0820 100644
>--- a/net/vmw_vsock/af_vsock.c
>+++ b/net/vmw_vsock/af_vsock.c
>@@ -1663,6 +1663,16 @@ static int vsock_connectible_setsockopt(struct socket *sock,
> 		}
> 		break;
> 	}
>+	case SO_VM_SOCKETS_ZEROCOPY: {
>+		if (sock->state != SS_UNCONNECTED) {
>+			err = -EOPNOTSUPP;
>+			break;
>+		}
>+
>+		COPY_IN(val);
>+		vsk->rx_zerocopy_on = val;
>+		break;
>+	}
>
> 	default:
> 		err = -ENOPROTOOPT;
>@@ -1676,6 +1686,124 @@ static int vsock_connectible_setsockopt(struct socket *sock,
> 	return err;
> }
>
>+static const struct vm_operations_struct afvsock_vm_ops = {
>+};
>+
>+static int vsock_recv_zerocopy(struct socket *sock,
>+			       unsigned long address)
>+{
>+	const struct vsock_transport *transport;
>+	struct vm_area_struct *vma;
>+	unsigned long vma_pages;
>+	struct vsock_sock *vsk;
>+	struct page **pages;
>+	struct sock *sk;
>+	int err;
>+	int i;
>+
>+	sk = sock->sk;
>+	vsk = vsock_sk(sk);
>+	err = 0;
>+
>+	lock_sock(sk);
>+
>+	if (!vsk->rx_zerocopy_on) {
>+		err = -EOPNOTSUPP;
>+		goto out_unlock_sock;
>+	}
>+
>+	transport = vsk->transport;
>+
>+	if (!transport->zerocopy_dequeue) {
>+		err = -EOPNOTSUPP;
>+		goto out_unlock_sock;
>+	}
>+
>+	mmap_write_lock(current->mm);
>+
>+	vma = vma_lookup(current->mm, address);
>+
>+	if (!vma || vma->vm_ops != &afvsock_vm_ops) {
>+		err = -EINVAL;
>+		goto out_unlock_vma;
>+	}
>+
>+	/* Allow to use vm area only from the first page. */
>+	if (vma->vm_start != address) {
>+		err = -EINVAL;
>+		goto out_unlock_vma;
>+	}
>+
>+	vma_pages = (vma->vm_end - vma->vm_start) / PAGE_SIZE;
>+	pages = kmalloc_array(vma_pages, sizeof(pages[0]),
>+			      GFP_KERNEL | __GFP_ZERO);
>+
>+	if (!pages) {
>+		err = -EINVAL;
>+		goto out_unlock_vma;
>+	}
>+
>+	err = transport->zerocopy_dequeue(vsk, pages, &vma_pages);
>+
>+	if (err)
>+		goto out_unlock_vma;
>+
>+	/* Now 'vma_pages' contains number of pages in array.
>+	 * If array element is NULL, skip it, go to next page.
>+	 */
>+	for (i = 0; i < vma_pages; i++) {
>+		if (pages[i]) {
>+			unsigned long pages_inserted;
>+
>+			pages_inserted = 1;
>+			err = vm_insert_pages(vma, address, &pages[i], &pages_inserted);
>+
>+			if (err || pages_inserted) {
>+				/* Failed to insert some pages, we have "partially"
>+				 * mapped vma. Do not return, set error code. This
>+				 * code will be returned to user. User needs to call
>+				 * 'madvise()/mmap()' to clear this vma. Anyway,
>+				 * references to all pages will to be dropped below.
>+				 */
>+				if (!err) {
>+					err = -EFAULT;
>+					break;
>+				}
>+			}
>+		}
>+
>+		address += PAGE_SIZE;
>+	}
>+
>+	i = 0;
>+
>+	while (i < vma_pages) {
>+		/* Drop ref count for all pages, returned by transport.
>+		 * We call 'put_page()' only once, as transport needed
>+		 * to 'get_page()' at least only once also, to prevent
>+		 * pages being freed. If transport calls 'get_page()'
>+		 * more twice or more for every page - we don't care,
>+		 * if transport calls 'get_page()' only one time, this
>+		 * meanse that every page had ref count equal to 1,then
>+		 * 'vm_insert_pages()' increments it to 2. After this
>+		 * loop, ref count will be 1 again, and page will be
>+		 * returned to allocator by user.
>+		 */
>+		if (pages[i])
>+			put_page(pages[i]);
>+		i++;
>+	}
>+
>+	kfree(pages);
>+
>+out_unlock_vma:
>+	mmap_write_unlock(current->mm);
>+out_unlock_sock:
>+	release_sock(sk);
>+
>+	return err;
>+}
>+
> static int vsock_connectible_getsockopt(struct socket *sock,
> 					int level, int optname,
> 					char __user *optval,
>@@ -1720,6 +1848,26 @@ static int vsock_connectible_getsockopt(struct socket *sock,
> 		lv = sock_get_timeout(vsk->connect_timeout, &v,
> 				      optname == SO_VM_SOCKETS_CONNECT_TIMEOUT_OLD);
> 		break;
>+	case SO_VM_SOCKETS_ZEROCOPY: {
>+		lock_sock(sk);
>+
>+		v.val64 = vsk->rx_zerocopy_on;
>+
>+		release_sock(sk);
>+
>+		break;
>+	}
>+	case SO_VM_SOCKETS_MAP_RX: {
>+		unsigned long vma_addr;
>+
>+		if (len < sizeof(vma_addr))
>+			return -EINVAL;
>+
>+		if (copy_from_user(&vma_addr, optval, sizeof(vma_addr)))
>+			return -EFAULT;
>+
>+		return vsock_recv_zerocopy(sock, vma_addr);
>+	}
>
> 	default:
> 		return -ENOPROTOOPT;
>@@ -2167,6 +2315,41 @@ static int vsock_set_rcvlowat(struct sock *sk, int val)
> 	return 0;
> }
>
>+static int afvsock_mmap(struct file *file, struct socket *sock,
>+			struct vm_area_struct *vma)
>+{
>+	const struct vsock_transport *transport;
>+	struct vsock_sock *vsk;
>+	struct sock *sk;
>+	int err;
>+
>+	if (vma->vm_flags & (VM_WRITE | VM_EXEC))
>+		return -EPERM;
>+
>+	vma->vm_flags &= ~(VM_MAYWRITE | VM_MAYEXEC);
>+	vma->vm_flags |= (VM_MIXEDMAP);
>+	vma->vm_ops = &afvsock_vm_ops;
>+
>+	sk = sock->sk;
>+	vsk = vsock_sk(sk);
>+
>+	lock_sock(sk);
>+
>+	transport = vsk->transport;
>+
>+	if (!transport || !transport->zerocopy_rx_mmap) {
>+		err = -EOPNOTSUPP;
>+		goto out_unlock;
>+	}
>+
>+	err = transport->zerocopy_rx_mmap(vsk, vma);
>+
>+out_unlock:
>+	release_sock(sk);
>+
>+	return err;
>+}
>+
> static const struct proto_ops vsock_stream_ops = {
> 	.family = PF_VSOCK,
> 	.owner = THIS_MODULE,
>@@ -2184,7 +2367,7 @@ static const struct proto_ops vsock_stream_ops = {
> 	.getsockopt = vsock_connectible_getsockopt,
> 	.sendmsg = vsock_connectible_sendmsg,
> 	.recvmsg = vsock_connectible_recvmsg,
>-	.mmap = sock_no_mmap,
>+	.mmap = afvsock_mmap,
> 	.sendpage = sock_no_sendpage,
> 	.set_rcvlowat = vsock_set_rcvlowat,
> };
>@@ -2206,7 +2389,7 @@ static const struct proto_ops vsock_seqpacket_ops = {
> 	.getsockopt = vsock_connectible_getsockopt,
> 	.sendmsg = vsock_connectible_sendmsg,
> 	.recvmsg = vsock_connectible_recvmsg,
>-	.mmap = sock_no_mmap,
>+	.mmap = afvsock_mmap,
> 	.sendpage = sock_no_sendpage,
> };
>
>-- 
>2.35.0
diff mbox series

Patch

diff --git a/include/net/af_vsock.h b/include/net/af_vsock.h
index 568a87c5e0d0..e4f12ef8e623 100644
--- a/include/net/af_vsock.h
+++ b/include/net/af_vsock.h
@@ -73,6 +73,8 @@  struct vsock_sock {
 
 	/* Private to transport. */
 	void *trans;
+
+	bool rx_zerocopy_on;
 };
 
 s64 vsock_stream_has_data(struct vsock_sock *vsk);
@@ -138,6 +140,12 @@  struct vsock_transport {
 	bool (*stream_allow)(u32 cid, u32 port);
 	int (*set_rcvlowat)(struct vsock_sock *vsk, int val);
 
+	int (*zerocopy_rx_mmap)(struct vsock_sock *vsk,
+				struct vm_area_struct *vma);
+	int (*zerocopy_dequeue)(struct vsock_sock *vsk,
+				struct page **pages,
+				unsigned long *pages_num);
+
 	/* SEQ_PACKET. */
 	ssize_t (*seqpacket_dequeue)(struct vsock_sock *vsk, struct msghdr *msg,
 				     int flags);
diff --git a/include/uapi/linux/vm_sockets.h b/include/uapi/linux/vm_sockets.h
index c60ca33eac59..d1f792bed1a7 100644
--- a/include/uapi/linux/vm_sockets.h
+++ b/include/uapi/linux/vm_sockets.h
@@ -83,6 +83,9 @@ 
 
 #define SO_VM_SOCKETS_CONNECT_TIMEOUT_NEW 8
 
+#define SO_VM_SOCKETS_MAP_RX 9
+#define SO_VM_SOCKETS_ZEROCOPY 10
+
 #if !defined(__KERNEL__)
 #if __BITS_PER_LONG == 64 || (defined(__x86_64__) && defined(__ILP32__))
 #define SO_VM_SOCKETS_CONNECT_TIMEOUT SO_VM_SOCKETS_CONNECT_TIMEOUT_OLD
diff --git a/net/vmw_vsock/af_vsock.c b/net/vmw_vsock/af_vsock.c
index ee418701cdee..21a915eb0820 100644
--- a/net/vmw_vsock/af_vsock.c
+++ b/net/vmw_vsock/af_vsock.c
@@ -1663,6 +1663,16 @@  static int vsock_connectible_setsockopt(struct socket *sock,
 		}
 		break;
 	}
+	case SO_VM_SOCKETS_ZEROCOPY: {
+		if (sock->state != SS_UNCONNECTED) {
+			err = -EOPNOTSUPP;
+			break;
+		}
+
+		COPY_IN(val);
+		vsk->rx_zerocopy_on = val;
+		break;
+	}
 
 	default:
 		err = -ENOPROTOOPT;
@@ -1676,6 +1686,124 @@  static int vsock_connectible_setsockopt(struct socket *sock,
 	return err;
 }
 
+static const struct vm_operations_struct afvsock_vm_ops = {
+};
+
+static int vsock_recv_zerocopy(struct socket *sock,
+			       unsigned long address)
+{
+	const struct vsock_transport *transport;
+	struct vm_area_struct *vma;
+	unsigned long vma_pages;
+	struct vsock_sock *vsk;
+	struct page **pages;
+	struct sock *sk;
+	int err;
+	int i;
+
+	sk = sock->sk;
+	vsk = vsock_sk(sk);
+	err = 0;
+
+	lock_sock(sk);
+
+	if (!vsk->rx_zerocopy_on) {
+		err = -EOPNOTSUPP;
+		goto out_unlock_sock;
+	}
+
+	transport = vsk->transport;
+
+	if (!transport->zerocopy_dequeue) {
+		err = -EOPNOTSUPP;
+		goto out_unlock_sock;
+	}
+
+	mmap_write_lock(current->mm);
+
+	vma = vma_lookup(current->mm, address);
+
+	if (!vma || vma->vm_ops != &afvsock_vm_ops) {
+		err = -EINVAL;
+		goto out_unlock_vma;
+	}
+
+	/* Allow to use vm area only from the first page. */
+	if (vma->vm_start != address) {
+		err = -EINVAL;
+		goto out_unlock_vma;
+	}
+
+	vma_pages = (vma->vm_end - vma->vm_start) / PAGE_SIZE;
+	pages = kmalloc_array(vma_pages, sizeof(pages[0]),
+			      GFP_KERNEL | __GFP_ZERO);
+
+	if (!pages) {
+		err = -EINVAL;
+		goto out_unlock_vma;
+	}
+
+	err = transport->zerocopy_dequeue(vsk, pages, &vma_pages);
+
+	if (err)
+		goto out_unlock_vma;
+
+	/* Now 'vma_pages' contains number of pages in array.
+	 * If array element is NULL, skip it, go to next page.
+	 */
+	for (i = 0; i < vma_pages; i++) {
+		if (pages[i]) {
+			unsigned long pages_inserted;
+
+			pages_inserted = 1;
+			err = vm_insert_pages(vma, address, &pages[i], &pages_inserted);
+
+			if (err || pages_inserted) {
+				/* Failed to insert some pages, we have "partially"
+				 * mapped vma. Do not return, set error code. This
+				 * code will be returned to user. User needs to call
+				 * 'madvise()/mmap()' to clear this vma. Anyway,
+				 * references to all pages will to be dropped below.
+				 */
+				if (!err) {
+					err = -EFAULT;
+					break;
+				}
+			}
+		}
+
+		address += PAGE_SIZE;
+	}
+
+	i = 0;
+
+	while (i < vma_pages) {
+		/* Drop ref count for all pages, returned by transport.
+		 * We call 'put_page()' only once, as transport needed
+		 * to 'get_page()' at least only once also, to prevent
+		 * pages being freed. If transport calls 'get_page()'
+		 * more twice or more for every page - we don't care,
+		 * if transport calls 'get_page()' only one time, this
+		 * meanse that every page had ref count equal to 1,then
+		 * 'vm_insert_pages()' increments it to 2. After this
+		 * loop, ref count will be 1 again, and page will be
+		 * returned to allocator by user.
+		 */
+		if (pages[i])
+			put_page(pages[i]);
+		i++;
+	}
+
+	kfree(pages);
+
+out_unlock_vma:
+	mmap_write_unlock(current->mm);
+out_unlock_sock:
+	release_sock(sk);
+
+	return err;
+}
+
 static int vsock_connectible_getsockopt(struct socket *sock,
 					int level, int optname,
 					char __user *optval,
@@ -1720,6 +1848,26 @@  static int vsock_connectible_getsockopt(struct socket *sock,
 		lv = sock_get_timeout(vsk->connect_timeout, &v,
 				      optname == SO_VM_SOCKETS_CONNECT_TIMEOUT_OLD);
 		break;
+	case SO_VM_SOCKETS_ZEROCOPY: {
+		lock_sock(sk);
+
+		v.val64 = vsk->rx_zerocopy_on;
+
+		release_sock(sk);
+
+		break;
+	}
+	case SO_VM_SOCKETS_MAP_RX: {
+		unsigned long vma_addr;
+
+		if (len < sizeof(vma_addr))
+			return -EINVAL;
+
+		if (copy_from_user(&vma_addr, optval, sizeof(vma_addr)))
+			return -EFAULT;
+
+		return vsock_recv_zerocopy(sock, vma_addr);
+	}
 
 	default:
 		return -ENOPROTOOPT;
@@ -2167,6 +2315,41 @@  static int vsock_set_rcvlowat(struct sock *sk, int val)
 	return 0;
 }
 
+static int afvsock_mmap(struct file *file, struct socket *sock,
+			struct vm_area_struct *vma)
+{
+	const struct vsock_transport *transport;
+	struct vsock_sock *vsk;
+	struct sock *sk;
+	int err;
+
+	if (vma->vm_flags & (VM_WRITE | VM_EXEC))
+		return -EPERM;
+
+	vma->vm_flags &= ~(VM_MAYWRITE | VM_MAYEXEC);
+	vma->vm_flags |= (VM_MIXEDMAP);
+	vma->vm_ops = &afvsock_vm_ops;
+
+	sk = sock->sk;
+	vsk = vsock_sk(sk);
+
+	lock_sock(sk);
+
+	transport = vsk->transport;
+
+	if (!transport || !transport->zerocopy_rx_mmap) {
+		err = -EOPNOTSUPP;
+		goto out_unlock;
+	}
+
+	err = transport->zerocopy_rx_mmap(vsk, vma);
+
+out_unlock:
+	release_sock(sk);
+
+	return err;
+}
+
 static const struct proto_ops vsock_stream_ops = {
 	.family = PF_VSOCK,
 	.owner = THIS_MODULE,
@@ -2184,7 +2367,7 @@  static const struct proto_ops vsock_stream_ops = {
 	.getsockopt = vsock_connectible_getsockopt,
 	.sendmsg = vsock_connectible_sendmsg,
 	.recvmsg = vsock_connectible_recvmsg,
-	.mmap = sock_no_mmap,
+	.mmap = afvsock_mmap,
 	.sendpage = sock_no_sendpage,
 	.set_rcvlowat = vsock_set_rcvlowat,
 };
@@ -2206,7 +2389,7 @@  static const struct proto_ops vsock_seqpacket_ops = {
 	.getsockopt = vsock_connectible_getsockopt,
 	.sendmsg = vsock_connectible_sendmsg,
 	.recvmsg = vsock_connectible_recvmsg,
-	.mmap = sock_no_mmap,
+	.mmap = afvsock_mmap,
 	.sendpage = sock_no_sendpage,
 };