diff mbox series

[RFC,v2,3/8] af_vsock: add zerocopy receive logic

Message ID 129aa328-ad4d-cb2c-4a51-4a2bf9c9be37@sberdevices.ru (mailing list archive)
State New, archived
Headers show
Series virtio/vsock: experimental zerocopy receive | expand

Commit Message

Arseniy Krasnov June 3, 2022, 5:35 a.m. UTC
This:
1) Adds callback for 'mmap()' call on socket. It checks vm
   area flags and sets vm area ops.
2) Adds special 'getsockopt()' case which calls transport
   zerocopy callback. Input argument is vm area address.
3) Adds 'getsockopt()/setsockopt()' for switching on/off rx
   zerocopy mode.

Signed-off-by: Arseniy Krasnov <AVKrasnov@sberdevices.ru>
---
 include/net/af_vsock.h          |   7 +++
 include/uapi/linux/vm_sockets.h |   3 +
 net/vmw_vsock/af_vsock.c        | 100 ++++++++++++++++++++++++++++++++
 3 files changed, 110 insertions(+)

Comments

Stefano Garzarella June 9, 2022, 8:39 a.m. UTC | #1
On Fri, Jun 03, 2022 at 05:35:48AM +0000, Arseniy Krasnov wrote:
>This:
>1) Adds callback for 'mmap()' call on socket. It checks vm
>   area flags and sets vm area ops.
>2) Adds special 'getsockopt()' case which calls transport
>   zerocopy callback. Input argument is vm area address.
>3) Adds 'getsockopt()/setsockopt()' for switching on/off rx
>   zerocopy mode.
>
>Signed-off-by: Arseniy Krasnov <AVKrasnov@sberdevices.ru>
>---
> include/net/af_vsock.h          |   7 +++
> include/uapi/linux/vm_sockets.h |   3 +
> net/vmw_vsock/af_vsock.c        | 100 ++++++++++++++++++++++++++++++++
> 3 files changed, 110 insertions(+)
>
>diff --git a/include/net/af_vsock.h b/include/net/af_vsock.h
>index f742e50207fb..f15f84c648ff 100644
>--- a/include/net/af_vsock.h
>+++ b/include/net/af_vsock.h
>@@ -135,6 +135,13 @@ struct vsock_transport {
> 	bool (*stream_is_active)(struct vsock_sock *);
> 	bool (*stream_allow)(u32 cid, u32 port);
>
>+	int (*rx_zerocopy_set)(struct vsock_sock *vsk,
>+			       bool enable);
>+	int (*rx_zerocopy_get)(struct vsock_sock *vsk);
>+	int (*zerocopy_dequeue)(struct vsock_sock *vsk,
>+				struct vm_area_struct *vma,
>+				unsigned long addr);
>+
> 	/* SEQ_PACKET. */
> 	ssize_t (*seqpacket_dequeue)(struct vsock_sock *vsk, struct msghdr *msg,
> 				     int flags);
>diff --git a/include/uapi/linux/vm_sockets.h b/include/uapi/linux/vm_sockets.h
>index c60ca33eac59..d1f792bed1a7 100644
>--- a/include/uapi/linux/vm_sockets.h
>+++ b/include/uapi/linux/vm_sockets.h
>@@ -83,6 +83,9 @@
>
> #define SO_VM_SOCKETS_CONNECT_TIMEOUT_NEW 8
>
>+#define SO_VM_SOCKETS_MAP_RX 9
>+#define SO_VM_SOCKETS_ZEROCOPY 10
>+
> #if !defined(__KERNEL__)
> #if __BITS_PER_LONG == 64 || (defined(__x86_64__) && defined(__ILP32__))
> #define SO_VM_SOCKETS_CONNECT_TIMEOUT SO_VM_SOCKETS_CONNECT_TIMEOUT_OLD
>diff --git a/net/vmw_vsock/af_vsock.c b/net/vmw_vsock/af_vsock.c
>index f04abf662ec6..10061ef21730 100644
>--- a/net/vmw_vsock/af_vsock.c
>+++ b/net/vmw_vsock/af_vsock.c
>@@ -1644,6 +1644,17 @@ static int vsock_connectible_setsockopt(struct socket *sock,
> 		}
> 		break;
> 	}
>+	case SO_VM_SOCKETS_ZEROCOPY: {
>+		if (!transport || !transport->rx_zerocopy_set) {
>+			err = -EOPNOTSUPP;
>+		} else {
>+			COPY_IN(val);
>+
>+			if (transport->rx_zerocopy_set(vsk, val))
>+				err = -EINVAL;
>+		}
>+		break;
>+	}
>
> 	default:
> 		err = -ENOPROTOOPT;
>@@ -1657,6 +1668,48 @@ static int vsock_connectible_setsockopt(struct socket *sock,
> 	return err;
> }
>
>+static const struct vm_operations_struct afvsock_vm_ops = {
>+};
>+
>+static int vsock_recv_zerocopy(struct socket *sock,
>+			       unsigned long address)
>+{
>+	struct sock *sk = sock->sk;
>+	struct vsock_sock *vsk = vsock_sk(sk);
>+	struct vm_area_struct *vma;
>+	const struct vsock_transport *transport;
>+	int res;
>+
>+	transport = vsk->transport;
>+
>+	if (!transport->rx_zerocopy_get)
>+		return -EOPNOTSUPP;
>+
>+	if (!transport->rx_zerocopy_get(vsk))
>+		return -EOPNOTSUPP;

Maybe we can merge in
         if (!transport->rx_zerocopy_get ||
             !transport->rx_zerocopy_get(vsk)}
                 return -EOPNOTSUPP;

>+
>+	if (!transport->zerocopy_dequeue)
>+		return -EOPNOTSUPP;
>+
>+	lock_sock(sk);
>+	mmap_write_lock(current->mm);

So, multiple threads using different sockets are serialized if they use 
zero-copy?

IIUC this is necessary because the callback calls vm_insert_page().

At this point I think it's better not to do this in every transport, but 
have the callback return an array of pages to map and we map them here 
trying to limit as much as possible the critical section to protect with 
mmap_write_lock().

>+
>+	vma = vma_lookup(current->mm, address);
>+
>+	if (!vma || vma->vm_ops != &afvsock_vm_ops) {
>+		mmap_write_unlock(current->mm);
>+		release_sock(sk);
>+		return -EINVAL;
>+	}
>+
>+	res = transport->zerocopy_dequeue(vsk, vma, address);
>+
>+	mmap_write_unlock(current->mm);
>+	release_sock(sk);
>+
>+	return res;
>+}
>+
> static int vsock_connectible_getsockopt(struct socket *sock,
> 					int level, int optname,
> 					char __user *optval,
>@@ -1701,6 +1754,39 @@ static int vsock_connectible_getsockopt(struct socket *sock,
> 		lv = sock_get_timeout(vsk->connect_timeout, &v,
> 				      optname == SO_VM_SOCKETS_CONNECT_TIMEOUT_OLD);
> 		break;
>+	case SO_VM_SOCKETS_ZEROCOPY: {
>+		const struct vsock_transport *transport;
>+		int res;
>+
>+		transport = vsk->transport;
>+
>+		if (!transport->rx_zerocopy_get)
>+			return -EOPNOTSUPP;
>+
>+		lock_sock(sk);

I think we should call lock_sock() before reading the transport to avoid 
races and we should check if it is assigned.

At that point I think is better to store this info in vsock_sock and not 
in the transport.

And maybe we should allow to change it only if the socket state is 
SS_UNCONNECTED, inheriting from the parent the setting for sockets that 
have it.

>+
>+		res = transport->rx_zerocopy_get(vsk);
>+
>+		release_sock(sk);
>+
>+		if (res < 0)
>+			return -EINVAL;
>+
>+		v.val64 = res;
>+
>+		break;
>+	}
>+	case SO_VM_SOCKETS_MAP_RX: {
>+		unsigned long vma_addr;
>+
>+		if (len < sizeof(vma_addr))
>+			return -EINVAL;
>+
>+		if (copy_from_user(&vma_addr, optval, sizeof(vma_addr)))
>+			return -EFAULT;
>+
>+		return vsock_recv_zerocopy(sock, vma_addr);
>+	}
>
> 	default:
> 		return -ENOPROTOOPT;
>@@ -2129,6 +2215,19 @@ vsock_connectible_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
> 	return err;
> }
>
>+static int afvsock_mmap(struct file *file, struct socket *sock,
>+			struct vm_area_struct *vma)
>+{
>+	if (vma->vm_flags & (VM_WRITE | VM_EXEC))
>+		return -EPERM;
>+
>+	vma->vm_flags &= ~(VM_MAYWRITE | VM_MAYEXEC);
>+	vma->vm_flags |= (VM_MIXEDMAP);
>+	vma->vm_ops = &afvsock_vm_ops;
>+
>+	return 0;
>+}
>+
> static const struct proto_ops vsock_stream_ops = {
> 	.family = PF_VSOCK,
> 	.owner = THIS_MODULE,
>@@ -2148,6 +2247,7 @@ static const struct proto_ops vsock_stream_ops = {
> 	.recvmsg = vsock_connectible_recvmsg,
> 	.mmap = sock_no_mmap,
> 	.sendpage = sock_no_sendpage,
>+	.mmap = afvsock_mmap,
> };
>
> static const struct proto_ops vsock_seqpacket_ops = {
>-- 
>2.25.1
Arseniy Krasnov June 9, 2022, 12:20 p.m. UTC | #2
On 09.06.2022 11:39, Stefano Garzarella wrote:
> On Fri, Jun 03, 2022 at 05:35:48AM +0000, Arseniy Krasnov wrote:
>> This:
>> 1) Adds callback for 'mmap()' call on socket. It checks vm
>>   area flags and sets vm area ops.
>> 2) Adds special 'getsockopt()' case which calls transport
>>   zerocopy callback. Input argument is vm area address.
>> 3) Adds 'getsockopt()/setsockopt()' for switching on/off rx
>>   zerocopy mode.
>>
>> Signed-off-by: Arseniy Krasnov <AVKrasnov@sberdevices.ru>
>> ---
>> include/net/af_vsock.h          |   7 +++
>> include/uapi/linux/vm_sockets.h |   3 +
>> net/vmw_vsock/af_vsock.c        | 100 ++++++++++++++++++++++++++++++++
>> 3 files changed, 110 insertions(+)
>>
>> diff --git a/include/net/af_vsock.h b/include/net/af_vsock.h
>> index f742e50207fb..f15f84c648ff 100644
>> --- a/include/net/af_vsock.h
>> +++ b/include/net/af_vsock.h
>> @@ -135,6 +135,13 @@ struct vsock_transport {
>>     bool (*stream_is_active)(struct vsock_sock *);
>>     bool (*stream_allow)(u32 cid, u32 port);
>>
>> +    int (*rx_zerocopy_set)(struct vsock_sock *vsk,
>> +                   bool enable);
>> +    int (*rx_zerocopy_get)(struct vsock_sock *vsk);
>> +    int (*zerocopy_dequeue)(struct vsock_sock *vsk,
>> +                struct vm_area_struct *vma,
>> +                unsigned long addr);
>> +
>>     /* SEQ_PACKET. */
>>     ssize_t (*seqpacket_dequeue)(struct vsock_sock *vsk, struct msghdr *msg,
>>                      int flags);
>> diff --git a/include/uapi/linux/vm_sockets.h b/include/uapi/linux/vm_sockets.h
>> index c60ca33eac59..d1f792bed1a7 100644
>> --- a/include/uapi/linux/vm_sockets.h
>> +++ b/include/uapi/linux/vm_sockets.h
>> @@ -83,6 +83,9 @@
>>
>> #define SO_VM_SOCKETS_CONNECT_TIMEOUT_NEW 8
>>
>> +#define SO_VM_SOCKETS_MAP_RX 9
>> +#define SO_VM_SOCKETS_ZEROCOPY 10
>> +
>> #if !defined(__KERNEL__)
>> #if __BITS_PER_LONG == 64 || (defined(__x86_64__) && defined(__ILP32__))
>> #define SO_VM_SOCKETS_CONNECT_TIMEOUT SO_VM_SOCKETS_CONNECT_TIMEOUT_OLD
>> diff --git a/net/vmw_vsock/af_vsock.c b/net/vmw_vsock/af_vsock.c
>> index f04abf662ec6..10061ef21730 100644
>> --- a/net/vmw_vsock/af_vsock.c
>> +++ b/net/vmw_vsock/af_vsock.c
>> @@ -1644,6 +1644,17 @@ static int vsock_connectible_setsockopt(struct socket *sock,
>>         }
>>         break;
>>     }
>> +    case SO_VM_SOCKETS_ZEROCOPY: {
>> +        if (!transport || !transport->rx_zerocopy_set) {
>> +            err = -EOPNOTSUPP;
>> +        } else {
>> +            COPY_IN(val);
>> +
>> +            if (transport->rx_zerocopy_set(vsk, val))
>> +                err = -EINVAL;
>> +        }
>> +        break;
>> +    }
>>
>>     default:
>>         err = -ENOPROTOOPT;
>> @@ -1657,6 +1668,48 @@ static int vsock_connectible_setsockopt(struct socket *sock,
>>     return err;
>> }
>>
>> +static const struct vm_operations_struct afvsock_vm_ops = {
>> +};
>> +
>> +static int vsock_recv_zerocopy(struct socket *sock,
>> +                   unsigned long address)
>> +{
>> +    struct sock *sk = sock->sk;
>> +    struct vsock_sock *vsk = vsock_sk(sk);
>> +    struct vm_area_struct *vma;
>> +    const struct vsock_transport *transport;
>> +    int res;
>> +
>> +    transport = vsk->transport;
>> +
>> +    if (!transport->rx_zerocopy_get)
>> +        return -EOPNOTSUPP;
>> +
>> +    if (!transport->rx_zerocopy_get(vsk))
>> +        return -EOPNOTSUPP;
> 
> Maybe we can merge in
>         if (!transport->rx_zerocopy_get ||
>             !transport->rx_zerocopy_get(vsk)}
>                 return -EOPNOTSUPP;
> 
>> +
>> +    if (!transport->zerocopy_dequeue)
>> +        return -EOPNOTSUPP;
>> +
>> +    lock_sock(sk);
>> +    mmap_write_lock(current->mm);
> 
> So, multiple threads using different sockets are serialized if they use zero-copy?
> 
> IIUC this is necessary because the callback calls vm_insert_page().
> 
> At this point I think it's better not to do this in every transport, but have the callback return an array of pages to map and we map them here trying to limit as much as possible the critical section to protect with mmap_write_lock().

Yes, it will be easy to return array of pages by transport callback,

> 
>> +
>> +    vma = vma_lookup(current->mm, address);
>> +
>> +    if (!vma || vma->vm_ops != &afvsock_vm_ops) {
>> +        mmap_write_unlock(current->mm);
>> +        release_sock(sk);
>> +        return -EINVAL;
>> +    }
>> +
>> +    res = transport->zerocopy_dequeue(vsk, vma, address);
>> +
>> +    mmap_write_unlock(current->mm);
>> +    release_sock(sk);
>> +
>> +    return res;
>> +}
>> +
>> static int vsock_connectible_getsockopt(struct socket *sock,
>>                     int level, int optname,
>>                     char __user *optval,
>> @@ -1701,6 +1754,39 @@ static int vsock_connectible_getsockopt(struct socket *sock,
>>         lv = sock_get_timeout(vsk->connect_timeout, &v,
>>                       optname == SO_VM_SOCKETS_CONNECT_TIMEOUT_OLD);
>>         break;
>> +    case SO_VM_SOCKETS_ZEROCOPY: {
>> +        const struct vsock_transport *transport;
>> +        int res;
>> +
>> +        transport = vsk->transport;
>> +
>> +        if (!transport->rx_zerocopy_get)
>> +            return -EOPNOTSUPP;
>> +
>> +        lock_sock(sk);
> 
> I think we should call lock_sock() before reading the transport to avoid races and we should check if it is assigned.
> 
> At that point I think is better to store this info in vsock_sock and not in the transport.
You mean to store flag that zerocopy is enabled in 'vsock_sock', just reading it here, without touching transport?
> 
> And maybe we should allow to change it only if the socket state is SS_UNCONNECTED, inheriting from the parent the setting for sockets that have itAck
> 
>> +
>> +        res = transport->rx_zerocopy_get(vsk);
>> +
>> +        release_sock(sk);
>> +
>> +        if (res < 0)
>> +            return -EINVAL;
>> +
>> +        v.val64 = res;
>> +
>> +        break;
>> +    }
>> +    case SO_VM_SOCKETS_MAP_RX: {
>> +        unsigned long vma_addr;
>> +
>> +        if (len < sizeof(vma_addr))
>> +            return -EINVAL;
>> +
>> +        if (copy_from_user(&vma_addr, optval, sizeof(vma_addr)))
>> +            return -EFAULT;
>> +
>> +        return vsock_recv_zerocopy(sock, vma_addr);
>> +    }
>>
>>     default:
>>         return -ENOPROTOOPT;
>> @@ -2129,6 +2215,19 @@ vsock_connectible_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
>>     return err;
>> }
>>
>> +static int afvsock_mmap(struct file *file, struct socket *sock,
>> +            struct vm_area_struct *vma)
>> +{
>> +    if (vma->vm_flags & (VM_WRITE | VM_EXEC))
>> +        return -EPERM;
>> +
>> +    vma->vm_flags &= ~(VM_MAYWRITE | VM_MAYEXEC);
>> +    vma->vm_flags |= (VM_MIXEDMAP);
>> +    vma->vm_ops = &afvsock_vm_ops;
>> +
>> +    return 0;
>> +}
>> +
>> static const struct proto_ops vsock_stream_ops = {
>>     .family = PF_VSOCK,
>>     .owner = THIS_MODULE,
>> @@ -2148,6 +2247,7 @@ static const struct proto_ops vsock_stream_ops = {
>>     .recvmsg = vsock_connectible_recvmsg,
>>     .mmap = sock_no_mmap,
>>     .sendpage = sock_no_sendpage,
>> +    .mmap = afvsock_mmap,
>> };
>>
>> static const struct proto_ops vsock_seqpacket_ops = {
>> -- 
>> 2.25.1
>
Stefano Garzarella June 13, 2022, 8:50 a.m. UTC | #3
On Thu, Jun 09, 2022 at 12:20:22PM +0000, Arseniy Krasnov wrote:
>On 09.06.2022 11:39, Stefano Garzarella wrote:
>> On Fri, Jun 03, 2022 at 05:35:48AM +0000, Arseniy Krasnov wrote:
>>> This:
>>> 1) Adds callback for 'mmap()' call on socket. It checks vm
>>>   area flags and sets vm area ops.
>>> 2) Adds special 'getsockopt()' case which calls transport
>>>   zerocopy callback. Input argument is vm area address.
>>> 3) Adds 'getsockopt()/setsockopt()' for switching on/off rx
>>>   zerocopy mode.
>>>
>>> Signed-off-by: Arseniy Krasnov <AVKrasnov@sberdevices.ru>
>>> ---
>>> include/net/af_vsock.h          |   7 +++
>>> include/uapi/linux/vm_sockets.h |   3 +
>>> net/vmw_vsock/af_vsock.c        | 100 ++++++++++++++++++++++++++++++++
>>> 3 files changed, 110 insertions(+)
>>>
>>> diff --git a/include/net/af_vsock.h b/include/net/af_vsock.h
>>> index f742e50207fb..f15f84c648ff 100644
>>> --- a/include/net/af_vsock.h
>>> +++ b/include/net/af_vsock.h
>>> @@ -135,6 +135,13 @@ struct vsock_transport {
>>>     bool (*stream_is_active)(struct vsock_sock *);
>>>     bool (*stream_allow)(u32 cid, u32 port);
>>>
>>> +    int (*rx_zerocopy_set)(struct vsock_sock *vsk,
>>> +                   bool enable);
>>> +    int (*rx_zerocopy_get)(struct vsock_sock *vsk);
>>> +    int (*zerocopy_dequeue)(struct vsock_sock *vsk,
>>> +                struct vm_area_struct *vma,
>>> +                unsigned long addr);
>>> +
>>>     /* SEQ_PACKET. */
>>>     ssize_t (*seqpacket_dequeue)(struct vsock_sock *vsk, struct msghdr *msg,
>>>                      int flags);
>>> diff --git a/include/uapi/linux/vm_sockets.h b/include/uapi/linux/vm_sockets.h
>>> index c60ca33eac59..d1f792bed1a7 100644
>>> --- a/include/uapi/linux/vm_sockets.h
>>> +++ b/include/uapi/linux/vm_sockets.h
>>> @@ -83,6 +83,9 @@
>>>
>>> #define SO_VM_SOCKETS_CONNECT_TIMEOUT_NEW 8
>>>
>>> +#define SO_VM_SOCKETS_MAP_RX 9
>>> +#define SO_VM_SOCKETS_ZEROCOPY 10
>>> +
>>> #if !defined(__KERNEL__)
>>> #if __BITS_PER_LONG == 64 || (defined(__x86_64__) && defined(__ILP32__))
>>> #define SO_VM_SOCKETS_CONNECT_TIMEOUT SO_VM_SOCKETS_CONNECT_TIMEOUT_OLD
>>> diff --git a/net/vmw_vsock/af_vsock.c b/net/vmw_vsock/af_vsock.c
>>> index f04abf662ec6..10061ef21730 100644
>>> --- a/net/vmw_vsock/af_vsock.c
>>> +++ b/net/vmw_vsock/af_vsock.c
>>> @@ -1644,6 +1644,17 @@ static int vsock_connectible_setsockopt(struct socket *sock,
>>>         }
>>>         break;
>>>     }
>>> +    case SO_VM_SOCKETS_ZEROCOPY: {
>>> +        if (!transport || !transport->rx_zerocopy_set) {
>>> +            err = -EOPNOTSUPP;
>>> +        } else {
>>> +            COPY_IN(val);
>>> +
>>> +            if (transport->rx_zerocopy_set(vsk, val))
>>> +                err = -EINVAL;
>>> +        }
>>> +        break;
>>> +    }
>>>
>>>     default:
>>>         err = -ENOPROTOOPT;
>>> @@ -1657,6 +1668,48 @@ static int vsock_connectible_setsockopt(struct socket *sock,
>>>     return err;
>>> }
>>>
>>> +static const struct vm_operations_struct afvsock_vm_ops = {
>>> +};
>>> +
>>> +static int vsock_recv_zerocopy(struct socket *sock,
>>> +                   unsigned long address)
>>> +{
>>> +    struct sock *sk = sock->sk;
>>> +    struct vsock_sock *vsk = vsock_sk(sk);
>>> +    struct vm_area_struct *vma;
>>> +    const struct vsock_transport *transport;
>>> +    int res;
>>> +
>>> +    transport = vsk->transport;
>>> +
>>> +    if (!transport->rx_zerocopy_get)
>>> +        return -EOPNOTSUPP;
>>> +
>>> +    if (!transport->rx_zerocopy_get(vsk))
>>> +        return -EOPNOTSUPP;
>>
>> Maybe we can merge in
>>         if (!transport->rx_zerocopy_get ||
>>             !transport->rx_zerocopy_get(vsk)}
>>                 return -EOPNOTSUPP;
>>
>>> +
>>> +    if (!transport->zerocopy_dequeue)
>>> +        return -EOPNOTSUPP;
>>> +
>>> +    lock_sock(sk);
>>> +    mmap_write_lock(current->mm);
>>
>> So, multiple threads using different sockets are serialized if they use zero-copy?
>>
>> IIUC this is necessary because the callback calls vm_insert_page().
>>
>> At this point I think it's better not to do this in every transport, but have the callback return an array of pages to map and we map them here trying to limit as much as possible the critical section to protect with mmap_write_lock().
>
>Yes, it will be easy to return array of pages by transport callback,
>
>>
>>> +
>>> +    vma = vma_lookup(current->mm, address);
>>> +
>>> +    if (!vma || vma->vm_ops != &afvsock_vm_ops) {
>>> +        mmap_write_unlock(current->mm);
>>> +        release_sock(sk);
>>> +        return -EINVAL;
>>> +    }
>>> +
>>> +    res = transport->zerocopy_dequeue(vsk, vma, address);
>>> +
>>> +    mmap_write_unlock(current->mm);
>>> +    release_sock(sk);
>>> +
>>> +    return res;
>>> +}
>>> +
>>> static int vsock_connectible_getsockopt(struct socket *sock,
>>>                     int level, int optname,
>>>                     char __user *optval,
>>> @@ -1701,6 +1754,39 @@ static int vsock_connectible_getsockopt(struct socket *sock,
>>>         lv = sock_get_timeout(vsk->connect_timeout, &v,
>>>                       optname == SO_VM_SOCKETS_CONNECT_TIMEOUT_OLD);
>>>         break;
>>> +    case SO_VM_SOCKETS_ZEROCOPY: {
>>> +        const struct vsock_transport *transport;
>>> +        int res;
>>> +
>>> +        transport = vsk->transport;
>>> +
>>> +        if (!transport->rx_zerocopy_get)
>>> +            return -EOPNOTSUPP;
>>> +
>>> +        lock_sock(sk);
>>
>> I think we should call lock_sock() before reading the transport to avoid races and we should check if it is assigned.
>>
>> At that point I think is better to store this info in vsock_sock and not in the transport.
>You mean to store flag that zerocopy is enabled in 'vsock_sock', just reading it here, without touching transport?

Yep.

Thanks,
Stefano
diff mbox series

Patch

diff --git a/include/net/af_vsock.h b/include/net/af_vsock.h
index f742e50207fb..f15f84c648ff 100644
--- a/include/net/af_vsock.h
+++ b/include/net/af_vsock.h
@@ -135,6 +135,13 @@  struct vsock_transport {
 	bool (*stream_is_active)(struct vsock_sock *);
 	bool (*stream_allow)(u32 cid, u32 port);
 
+	int (*rx_zerocopy_set)(struct vsock_sock *vsk,
+			       bool enable);
+	int (*rx_zerocopy_get)(struct vsock_sock *vsk);
+	int (*zerocopy_dequeue)(struct vsock_sock *vsk,
+				struct vm_area_struct *vma,
+				unsigned long addr);
+
 	/* SEQ_PACKET. */
 	ssize_t (*seqpacket_dequeue)(struct vsock_sock *vsk, struct msghdr *msg,
 				     int flags);
diff --git a/include/uapi/linux/vm_sockets.h b/include/uapi/linux/vm_sockets.h
index c60ca33eac59..d1f792bed1a7 100644
--- a/include/uapi/linux/vm_sockets.h
+++ b/include/uapi/linux/vm_sockets.h
@@ -83,6 +83,9 @@ 
 
 #define SO_VM_SOCKETS_CONNECT_TIMEOUT_NEW 8
 
+#define SO_VM_SOCKETS_MAP_RX 9
+#define SO_VM_SOCKETS_ZEROCOPY 10
+
 #if !defined(__KERNEL__)
 #if __BITS_PER_LONG == 64 || (defined(__x86_64__) && defined(__ILP32__))
 #define SO_VM_SOCKETS_CONNECT_TIMEOUT SO_VM_SOCKETS_CONNECT_TIMEOUT_OLD
diff --git a/net/vmw_vsock/af_vsock.c b/net/vmw_vsock/af_vsock.c
index f04abf662ec6..10061ef21730 100644
--- a/net/vmw_vsock/af_vsock.c
+++ b/net/vmw_vsock/af_vsock.c
@@ -1644,6 +1644,17 @@  static int vsock_connectible_setsockopt(struct socket *sock,
 		}
 		break;
 	}
+	case SO_VM_SOCKETS_ZEROCOPY: {
+		if (!transport || !transport->rx_zerocopy_set) {
+			err = -EOPNOTSUPP;
+		} else {
+			COPY_IN(val);
+
+			if (transport->rx_zerocopy_set(vsk, val))
+				err = -EINVAL;
+		}
+		break;
+	}
 
 	default:
 		err = -ENOPROTOOPT;
@@ -1657,6 +1668,48 @@  static int vsock_connectible_setsockopt(struct socket *sock,
 	return err;
 }
 
+static const struct vm_operations_struct afvsock_vm_ops = {
+};
+
+static int vsock_recv_zerocopy(struct socket *sock,
+			       unsigned long address)
+{
+	struct sock *sk = sock->sk;
+	struct vsock_sock *vsk = vsock_sk(sk);
+	struct vm_area_struct *vma;
+	const struct vsock_transport *transport;
+	int res;
+
+	transport = vsk->transport;
+
+	if (!transport->rx_zerocopy_get)
+		return -EOPNOTSUPP;
+
+	if (!transport->rx_zerocopy_get(vsk))
+		return -EOPNOTSUPP;
+
+	if (!transport->zerocopy_dequeue)
+		return -EOPNOTSUPP;
+
+	lock_sock(sk);
+	mmap_write_lock(current->mm);
+
+	vma = vma_lookup(current->mm, address);
+
+	if (!vma || vma->vm_ops != &afvsock_vm_ops) {
+		mmap_write_unlock(current->mm);
+		release_sock(sk);
+		return -EINVAL;
+	}
+
+	res = transport->zerocopy_dequeue(vsk, vma, address);
+
+	mmap_write_unlock(current->mm);
+	release_sock(sk);
+
+	return res;
+}
+
 static int vsock_connectible_getsockopt(struct socket *sock,
 					int level, int optname,
 					char __user *optval,
@@ -1701,6 +1754,39 @@  static int vsock_connectible_getsockopt(struct socket *sock,
 		lv = sock_get_timeout(vsk->connect_timeout, &v,
 				      optname == SO_VM_SOCKETS_CONNECT_TIMEOUT_OLD);
 		break;
+	case SO_VM_SOCKETS_ZEROCOPY: {
+		const struct vsock_transport *transport;
+		int res;
+
+		transport = vsk->transport;
+
+		if (!transport->rx_zerocopy_get)
+			return -EOPNOTSUPP;
+
+		lock_sock(sk);
+
+		res = transport->rx_zerocopy_get(vsk);
+
+		release_sock(sk);
+
+		if (res < 0)
+			return -EINVAL;
+
+		v.val64 = res;
+
+		break;
+	}
+	case SO_VM_SOCKETS_MAP_RX: {
+		unsigned long vma_addr;
+
+		if (len < sizeof(vma_addr))
+			return -EINVAL;
+
+		if (copy_from_user(&vma_addr, optval, sizeof(vma_addr)))
+			return -EFAULT;
+
+		return vsock_recv_zerocopy(sock, vma_addr);
+	}
 
 	default:
 		return -ENOPROTOOPT;
@@ -2129,6 +2215,19 @@  vsock_connectible_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
 	return err;
 }
 
+static int afvsock_mmap(struct file *file, struct socket *sock,
+			struct vm_area_struct *vma)
+{
+	if (vma->vm_flags & (VM_WRITE | VM_EXEC))
+		return -EPERM;
+
+	vma->vm_flags &= ~(VM_MAYWRITE | VM_MAYEXEC);
+	vma->vm_flags |= (VM_MIXEDMAP);
+	vma->vm_ops = &afvsock_vm_ops;
+
+	return 0;
+}
+
 static const struct proto_ops vsock_stream_ops = {
 	.family = PF_VSOCK,
 	.owner = THIS_MODULE,
@@ -2148,6 +2247,7 @@  static const struct proto_ops vsock_stream_ops = {
 	.recvmsg = vsock_connectible_recvmsg,
 	.mmap = sock_no_mmap,
 	.sendpage = sock_no_sendpage,
+	.mmap = afvsock_mmap,
 };
 
 static const struct proto_ops vsock_seqpacket_ops = {