diff mbox series

[v2,1/3] vsock/virtio: use RCU to avoid use-after-free on the_virtio_vsock

Message ID 20190628123659.139576-2-sgarzare@redhat.com (mailing list archive)
State New, archived
Headers show
Series vsock/virtio: several fixes in the .probe() and .remove() | expand

Commit Message

Stefano Garzarella June 28, 2019, 12:36 p.m. UTC
Some callbacks used by the upper layers can run while we are in the
.remove(). A potential use-after-free can happen, because we free
the_virtio_vsock without knowing if the callbacks are over or not.

To solve this issue we move the assignment of the_virtio_vsock to the
end of .probe(), once all the initialization has finished, and its
clearing to the beginning of .remove(), before releasing resources.
For the same reason, we do the same for vdev->priv.

We use RCU to be sure that all callbacks that use the_virtio_vsock
have ended before freeing it. This is not required for callbacks that
use vdev->priv, because after the vdev->config->del_vqs() call we are
sure that they have ended and will no longer be invoked.

We also take the mutex during .remove() to prevent .probe() from
running while we are resetting the device.

Signed-off-by: Stefano Garzarella <sgarzare@redhat.com>
---
 net/vmw_vsock/virtio_transport.c | 67 +++++++++++++++++++++-----------
 1 file changed, 44 insertions(+), 23 deletions(-)

Comments

Stefan Hajnoczi July 1, 2019, 2:54 p.m. UTC | #1
On Fri, Jun 28, 2019 at 02:36:57PM +0200, Stefano Garzarella wrote:
> Some callbacks used by the upper layers can run while we are in the
> .remove(). A potential use-after-free can happen, because we free
> the_virtio_vsock without knowing if the callbacks are over or not.
> 
> To solve this issue we move the assignment of the_virtio_vsock at the
> end of .probe(), when we finished all the initialization, and at the
> beginning of .remove(), before to release resources.
> For the same reason, we do the same also for the vdev->priv.
> 
> We use RCU to be sure that all callbacks that use the_virtio_vsock
> ended before freeing it. This is not required for callbacks that
> use vdev->priv, because after the vdev->config->del_vqs() we are sure
> that they are ended and will no longer be invoked.

->del_vqs() is only called at the very end, did you forget to move it
earlier?

In particular, the virtqueue handler callbacks schedule a workqueue.
The work functions use container_of() to get vsock.  We need to be sure
that the work item isn't freed along with vsock while the work item is
still pending.

How do we know that the virtqueue handler is never called in such a way
that it sees vsock != NULL (there is no explicit memory barrier on the
read side) and then schedules a work item after flush_work() has run?

Stefan
Stefan Hajnoczi July 1, 2019, 3:10 p.m. UTC | #2
On Fri, Jun 28, 2019 at 02:36:57PM +0200, Stefano Garzarella wrote:
> Some callbacks used by the upper layers can run while we are in the
> .remove(). A potential use-after-free can happen, because we free
> the_virtio_vsock without knowing if the callbacks are over or not.
> 
> To solve this issue we move the assignment of the_virtio_vsock at the
> end of .probe(), when we finished all the initialization, and at the
> beginning of .remove(), before to release resources.
> For the same reason, we do the same also for the vdev->priv.
> 
> We use RCU to be sure that all callbacks that use the_virtio_vsock
> ended before freeing it. This is not required for callbacks that
> use vdev->priv, because after the vdev->config->del_vqs() we are sure
> that they are ended and will no longer be invoked.

My question is answered in Patch 3.
Jason Wang July 3, 2019, 9:53 a.m. UTC | #3
On 2019/6/28 下午8:36, Stefano Garzarella wrote:
> Some callbacks used by the upper layers can run while we are in the
> .remove(). A potential use-after-free can happen, because we free
> the_virtio_vsock without knowing if the callbacks are over or not.
>
> To solve this issue we move the assignment of the_virtio_vsock at the
> end of .probe(), when we finished all the initialization, and at the
> beginning of .remove(), before to release resources.
> For the same reason, we do the same also for the vdev->priv.
>
> We use RCU to be sure that all callbacks that use the_virtio_vsock
> ended before freeing it. This is not required for callbacks that
> use vdev->priv, because after the vdev->config->del_vqs() we are sure
> that they are ended and will no longer be invoked.
>
> We also take the mutex during the .remove() to avoid that .probe() can
> run while we are resetting the device.
>
> Signed-off-by: Stefano Garzarella <sgarzare@redhat.com>
> ---
>   net/vmw_vsock/virtio_transport.c | 67 +++++++++++++++++++++-----------
>   1 file changed, 44 insertions(+), 23 deletions(-)
>
> diff --git a/net/vmw_vsock/virtio_transport.c b/net/vmw_vsock/virtio_transport.c
> index 9c287e3e393c..7ad510ec12e0 100644
> --- a/net/vmw_vsock/virtio_transport.c
> +++ b/net/vmw_vsock/virtio_transport.c
> @@ -65,19 +65,22 @@ struct virtio_vsock {
>   	u32 guest_cid;
>   };
>   
> -static struct virtio_vsock *virtio_vsock_get(void)
> -{
> -	return the_virtio_vsock;
> -}
> -
>   static u32 virtio_transport_get_local_cid(void)
>   {
> -	struct virtio_vsock *vsock = virtio_vsock_get();
> +	struct virtio_vsock *vsock;
> +	u32 ret;
>   
> -	if (!vsock)
> -		return VMADDR_CID_ANY;
> +	rcu_read_lock();
> +	vsock = rcu_dereference(the_virtio_vsock);
> +	if (!vsock) {
> +		ret = VMADDR_CID_ANY;
> +		goto out_rcu;
> +	}
>   
> -	return vsock->guest_cid;
> +	ret = vsock->guest_cid;
> +out_rcu:
> +	rcu_read_unlock();
> +	return ret;
>   }
>   
>   static void virtio_transport_loopback_work(struct work_struct *work)
> @@ -197,14 +200,18 @@ virtio_transport_send_pkt(struct virtio_vsock_pkt *pkt)
>   	struct virtio_vsock *vsock;
>   	int len = pkt->len;
>   
> -	vsock = virtio_vsock_get();
> +	rcu_read_lock();
> +	vsock = rcu_dereference(the_virtio_vsock);
>   	if (!vsock) {
>   		virtio_transport_free_pkt(pkt);
> -		return -ENODEV;
> +		len = -ENODEV;
> +		goto out_rcu;
>   	}
>   
> -	if (le64_to_cpu(pkt->hdr.dst_cid) == vsock->guest_cid)
> -		return virtio_transport_send_pkt_loopback(vsock, pkt);
> +	if (le64_to_cpu(pkt->hdr.dst_cid) == vsock->guest_cid) {
> +		len = virtio_transport_send_pkt_loopback(vsock, pkt);
> +		goto out_rcu;
> +	}
>   
>   	if (pkt->reply)
>   		atomic_inc(&vsock->queued_replies);
> @@ -214,6 +221,9 @@ virtio_transport_send_pkt(struct virtio_vsock_pkt *pkt)
>   	spin_unlock_bh(&vsock->send_pkt_list_lock);
>   
>   	queue_work(virtio_vsock_workqueue, &vsock->send_pkt_work);
> +
> +out_rcu:
> +	rcu_read_unlock();
>   	return len;
>   }
>   
> @@ -222,12 +232,14 @@ virtio_transport_cancel_pkt(struct vsock_sock *vsk)
>   {
>   	struct virtio_vsock *vsock;
>   	struct virtio_vsock_pkt *pkt, *n;
> -	int cnt = 0;
> +	int cnt = 0, ret;
>   	LIST_HEAD(freeme);
>   
> -	vsock = virtio_vsock_get();
> +	rcu_read_lock();
> +	vsock = rcu_dereference(the_virtio_vsock);
>   	if (!vsock) {
> -		return -ENODEV;
> +		ret = -ENODEV;
> +		goto out_rcu;
>   	}
>   
>   	spin_lock_bh(&vsock->send_pkt_list_lock);
> @@ -255,7 +267,11 @@ virtio_transport_cancel_pkt(struct vsock_sock *vsk)
>   			queue_work(virtio_vsock_workqueue, &vsock->rx_work);
>   	}
>   
> -	return 0;
> +	ret = 0;
> +
> +out_rcu:
> +	rcu_read_unlock();
> +	return ret;
>   }
>   
>   static void virtio_vsock_rx_fill(struct virtio_vsock *vsock)
> @@ -590,8 +606,6 @@ static int virtio_vsock_probe(struct virtio_device *vdev)
>   	vsock->rx_buf_max_nr = 0;
>   	atomic_set(&vsock->queued_replies, 0);
>   
> -	vdev->priv = vsock;
> -	the_virtio_vsock = vsock;
>   	mutex_init(&vsock->tx_lock);
>   	mutex_init(&vsock->rx_lock);
>   	mutex_init(&vsock->event_lock);
> @@ -613,6 +627,9 @@ static int virtio_vsock_probe(struct virtio_device *vdev)
>   	virtio_vsock_event_fill(vsock);
>   	mutex_unlock(&vsock->event_lock);
>   
> +	vdev->priv = vsock;
> +	rcu_assign_pointer(the_virtio_vsock, vsock);


You probably need to use rcu_dereference_protected() to access
the_virtio_vsock in this function in order to avoid sparse warnings.


> +
>   	mutex_unlock(&the_virtio_vsock_mutex);
>   	return 0;
>   
> @@ -627,6 +644,12 @@ static void virtio_vsock_remove(struct virtio_device *vdev)
>   	struct virtio_vsock *vsock = vdev->priv;
>   	struct virtio_vsock_pkt *pkt;
>   
> +	mutex_lock(&the_virtio_vsock_mutex);
> +
> +	vdev->priv = NULL;
> +	rcu_assign_pointer(the_virtio_vsock, NULL);


This is still suspicious, can we access the_virtio_vsock through 
vdev->priv? If yes, we may still get use-after-free since it was not 
protected by RCU.

Another, more interesting question: I believe we will make the
virtio_vsock structure a singleton. Then what's the point of using
vdev->priv to access the_virtio_vsock? It looks to me like it brings
extra trouble for synchronization.

Thanks


> +	synchronize_rcu();
> +
>   	flush_work(&vsock->loopback_work);
>   	flush_work(&vsock->rx_work);
>   	flush_work(&vsock->tx_work);
> @@ -666,12 +689,10 @@ static void virtio_vsock_remove(struct virtio_device *vdev)
>   	}
>   	spin_unlock_bh(&vsock->loopback_list_lock);
>   
> -	mutex_lock(&the_virtio_vsock_mutex);
> -	the_virtio_vsock = NULL;
> -	mutex_unlock(&the_virtio_vsock_mutex);
> -
>   	vdev->config->del_vqs(vdev);
>   
> +	mutex_unlock(&the_virtio_vsock_mutex);
> +
>   	kfree(vsock);
>   }
>
Stefano Garzarella July 3, 2019, 10:41 a.m. UTC | #4
On Wed, Jul 03, 2019 at 05:53:58PM +0800, Jason Wang wrote:
> 
> On 2019/6/28 下午8:36, Stefano Garzarella wrote:
> > Some callbacks used by the upper layers can run while we are in the
> > .remove(). A potential use-after-free can happen, because we free
> > the_virtio_vsock without knowing if the callbacks are over or not.
> > 
> > To solve this issue we move the assignment of the_virtio_vsock at the
> > end of .probe(), when we finished all the initialization, and at the
> > beginning of .remove(), before to release resources.
> > For the same reason, we do the same also for the vdev->priv.
> > 
> > We use RCU to be sure that all callbacks that use the_virtio_vsock
> > ended before freeing it. This is not required for callbacks that
> > use vdev->priv, because after the vdev->config->del_vqs() we are sure
> > that they are ended and will no longer be invoked.
> > 
> > We also take the mutex during the .remove() to avoid that .probe() can
> > run while we are resetting the device.
> > 
> > Signed-off-by: Stefano Garzarella <sgarzare@redhat.com>
> > ---
> >   net/vmw_vsock/virtio_transport.c | 67 +++++++++++++++++++++-----------
> >   1 file changed, 44 insertions(+), 23 deletions(-)
> > 
> > diff --git a/net/vmw_vsock/virtio_transport.c b/net/vmw_vsock/virtio_transport.c
> > index 9c287e3e393c..7ad510ec12e0 100644
> > --- a/net/vmw_vsock/virtio_transport.c
> > +++ b/net/vmw_vsock/virtio_transport.c
> > @@ -65,19 +65,22 @@ struct virtio_vsock {
> >   	u32 guest_cid;
> >   };
> > -static struct virtio_vsock *virtio_vsock_get(void)
> > -{
> > -	return the_virtio_vsock;
> > -}
> > -
> >   static u32 virtio_transport_get_local_cid(void)
> >   {
> > -	struct virtio_vsock *vsock = virtio_vsock_get();
> > +	struct virtio_vsock *vsock;
> > +	u32 ret;
> > -	if (!vsock)
> > -		return VMADDR_CID_ANY;
> > +	rcu_read_lock();
> > +	vsock = rcu_dereference(the_virtio_vsock);
> > +	if (!vsock) {
> > +		ret = VMADDR_CID_ANY;
> > +		goto out_rcu;
> > +	}
> > -	return vsock->guest_cid;
> > +	ret = vsock->guest_cid;
> > +out_rcu:
> > +	rcu_read_unlock();
> > +	return ret;
> >   }
> >   static void virtio_transport_loopback_work(struct work_struct *work)
> > @@ -197,14 +200,18 @@ virtio_transport_send_pkt(struct virtio_vsock_pkt *pkt)
> >   	struct virtio_vsock *vsock;
> >   	int len = pkt->len;
> > -	vsock = virtio_vsock_get();
> > +	rcu_read_lock();
> > +	vsock = rcu_dereference(the_virtio_vsock);
> >   	if (!vsock) {
> >   		virtio_transport_free_pkt(pkt);
> > -		return -ENODEV;
> > +		len = -ENODEV;
> > +		goto out_rcu;
> >   	}
> > -	if (le64_to_cpu(pkt->hdr.dst_cid) == vsock->guest_cid)
> > -		return virtio_transport_send_pkt_loopback(vsock, pkt);
> > +	if (le64_to_cpu(pkt->hdr.dst_cid) == vsock->guest_cid) {
> > +		len = virtio_transport_send_pkt_loopback(vsock, pkt);
> > +		goto out_rcu;
> > +	}
> >   	if (pkt->reply)
> >   		atomic_inc(&vsock->queued_replies);
> > @@ -214,6 +221,9 @@ virtio_transport_send_pkt(struct virtio_vsock_pkt *pkt)
> >   	spin_unlock_bh(&vsock->send_pkt_list_lock);
> >   	queue_work(virtio_vsock_workqueue, &vsock->send_pkt_work);
> > +
> > +out_rcu:
> > +	rcu_read_unlock();
> >   	return len;
> >   }
> > @@ -222,12 +232,14 @@ virtio_transport_cancel_pkt(struct vsock_sock *vsk)
> >   {
> >   	struct virtio_vsock *vsock;
> >   	struct virtio_vsock_pkt *pkt, *n;
> > -	int cnt = 0;
> > +	int cnt = 0, ret;
> >   	LIST_HEAD(freeme);
> > -	vsock = virtio_vsock_get();
> > +	rcu_read_lock();
> > +	vsock = rcu_dereference(the_virtio_vsock);
> >   	if (!vsock) {
> > -		return -ENODEV;
> > +		ret = -ENODEV;
> > +		goto out_rcu;
> >   	}
> >   	spin_lock_bh(&vsock->send_pkt_list_lock);
> > @@ -255,7 +267,11 @@ virtio_transport_cancel_pkt(struct vsock_sock *vsk)
> >   			queue_work(virtio_vsock_workqueue, &vsock->rx_work);
> >   	}
> > -	return 0;
> > +	ret = 0;
> > +
> > +out_rcu:
> > +	rcu_read_unlock();
> > +	return ret;
> >   }
> >   static void virtio_vsock_rx_fill(struct virtio_vsock *vsock)
> > @@ -590,8 +606,6 @@ static int virtio_vsock_probe(struct virtio_device *vdev)
> >   	vsock->rx_buf_max_nr = 0;
> >   	atomic_set(&vsock->queued_replies, 0);
> > -	vdev->priv = vsock;
> > -	the_virtio_vsock = vsock;
> >   	mutex_init(&vsock->tx_lock);
> >   	mutex_init(&vsock->rx_lock);
> >   	mutex_init(&vsock->event_lock);
> > @@ -613,6 +627,9 @@ static int virtio_vsock_probe(struct virtio_device *vdev)
> >   	virtio_vsock_event_fill(vsock);
> >   	mutex_unlock(&vsock->event_lock);
> > +	vdev->priv = vsock;
> > +	rcu_assign_pointer(the_virtio_vsock, vsock);
> 
> 
> You probably need to use rcu_dereference_protected() to access
> the_virtio_vsock in the function in order to survive from sparse.
> 

Ooo, thanks!

Do you mean when we check if the_virtio_vsock is not null at the beginning of
virtio_vsock_probe()?

> 
> > +
> >   	mutex_unlock(&the_virtio_vsock_mutex);
> >   	return 0;
> > @@ -627,6 +644,12 @@ static void virtio_vsock_remove(struct virtio_device *vdev)
> >   	struct virtio_vsock *vsock = vdev->priv;
> >   	struct virtio_vsock_pkt *pkt;
> > +	mutex_lock(&the_virtio_vsock_mutex);
> > +
> > +	vdev->priv = NULL;
> > +	rcu_assign_pointer(the_virtio_vsock, NULL);
> 
> 
> This is still suspicious, can we access the_virtio_vsock through vdev->priv?
> If yes, we may still get use-after-free since it was not protected by RCU.

We will free the object only after calling the del_vqs(), so we are sure
that the vq_callbacks ended and will no longer be invoked.
So, IIUC it shouldn't happen.

> 
> Another more interesting question, I believe we will do singleton for
> virtio_vsock structure. Then what's the point of using vdev->priv to access
> the_virtio_vsock? It looks to me we can it brings extra troubles for doing
> synchronization.

I thought about it when I tried to use RCU to stop the worker, and I
think it makes sense. Maybe it can be another series after this one is merged.

@Stefan, what do you think about that?

Thanks,
Stefano
Jason Wang July 4, 2019, 3:58 a.m. UTC | #5
On 2019/7/3 下午6:41, Stefano Garzarella wrote:
> On Wed, Jul 03, 2019 at 05:53:58PM +0800, Jason Wang wrote:
>> On 2019/6/28 下午8:36, Stefano Garzarella wrote:
>>> Some callbacks used by the upper layers can run while we are in the
>>> .remove(). A potential use-after-free can happen, because we free
>>> the_virtio_vsock without knowing if the callbacks are over or not.
>>>
>>> To solve this issue we move the assignment of the_virtio_vsock at the
>>> end of .probe(), when we finished all the initialization, and at the
>>> beginning of .remove(), before to release resources.
>>> For the same reason, we do the same also for the vdev->priv.
>>>
>>> We use RCU to be sure that all callbacks that use the_virtio_vsock
>>> ended before freeing it. This is not required for callbacks that
>>> use vdev->priv, because after the vdev->config->del_vqs() we are sure
>>> that they are ended and will no longer be invoked.
>>>
>>> We also take the mutex during the .remove() to avoid that .probe() can
>>> run while we are resetting the device.
>>>
>>> Signed-off-by: Stefano Garzarella <sgarzare@redhat.com>
>>> ---
>>>    net/vmw_vsock/virtio_transport.c | 67 +++++++++++++++++++++-----------
>>>    1 file changed, 44 insertions(+), 23 deletions(-)
>>>
>>> diff --git a/net/vmw_vsock/virtio_transport.c b/net/vmw_vsock/virtio_transport.c
>>> index 9c287e3e393c..7ad510ec12e0 100644
>>> --- a/net/vmw_vsock/virtio_transport.c
>>> +++ b/net/vmw_vsock/virtio_transport.c
>>> @@ -65,19 +65,22 @@ struct virtio_vsock {
>>>    	u32 guest_cid;
>>>    };
>>> -static struct virtio_vsock *virtio_vsock_get(void)
>>> -{
>>> -	return the_virtio_vsock;
>>> -}
>>> -
>>>    static u32 virtio_transport_get_local_cid(void)
>>>    {
>>> -	struct virtio_vsock *vsock = virtio_vsock_get();
>>> +	struct virtio_vsock *vsock;
>>> +	u32 ret;
>>> -	if (!vsock)
>>> -		return VMADDR_CID_ANY;
>>> +	rcu_read_lock();
>>> +	vsock = rcu_dereference(the_virtio_vsock);
>>> +	if (!vsock) {
>>> +		ret = VMADDR_CID_ANY;
>>> +		goto out_rcu;
>>> +	}
>>> -	return vsock->guest_cid;
>>> +	ret = vsock->guest_cid;
>>> +out_rcu:
>>> +	rcu_read_unlock();
>>> +	return ret;
>>>    }
>>>    static void virtio_transport_loopback_work(struct work_struct *work)
>>> @@ -197,14 +200,18 @@ virtio_transport_send_pkt(struct virtio_vsock_pkt *pkt)
>>>    	struct virtio_vsock *vsock;
>>>    	int len = pkt->len;
>>> -	vsock = virtio_vsock_get();
>>> +	rcu_read_lock();
>>> +	vsock = rcu_dereference(the_virtio_vsock);
>>>    	if (!vsock) {
>>>    		virtio_transport_free_pkt(pkt);
>>> -		return -ENODEV;
>>> +		len = -ENODEV;
>>> +		goto out_rcu;
>>>    	}
>>> -	if (le64_to_cpu(pkt->hdr.dst_cid) == vsock->guest_cid)
>>> -		return virtio_transport_send_pkt_loopback(vsock, pkt);
>>> +	if (le64_to_cpu(pkt->hdr.dst_cid) == vsock->guest_cid) {
>>> +		len = virtio_transport_send_pkt_loopback(vsock, pkt);
>>> +		goto out_rcu;
>>> +	}
>>>    	if (pkt->reply)
>>>    		atomic_inc(&vsock->queued_replies);
>>> @@ -214,6 +221,9 @@ virtio_transport_send_pkt(struct virtio_vsock_pkt *pkt)
>>>    	spin_unlock_bh(&vsock->send_pkt_list_lock);
>>>    	queue_work(virtio_vsock_workqueue, &vsock->send_pkt_work);
>>> +
>>> +out_rcu:
>>> +	rcu_read_unlock();
>>>    	return len;
>>>    }
>>> @@ -222,12 +232,14 @@ virtio_transport_cancel_pkt(struct vsock_sock *vsk)
>>>    {
>>>    	struct virtio_vsock *vsock;
>>>    	struct virtio_vsock_pkt *pkt, *n;
>>> -	int cnt = 0;
>>> +	int cnt = 0, ret;
>>>    	LIST_HEAD(freeme);
>>> -	vsock = virtio_vsock_get();
>>> +	rcu_read_lock();
>>> +	vsock = rcu_dereference(the_virtio_vsock);
>>>    	if (!vsock) {
>>> -		return -ENODEV;
>>> +		ret = -ENODEV;
>>> +		goto out_rcu;
>>>    	}
>>>    	spin_lock_bh(&vsock->send_pkt_list_lock);
>>> @@ -255,7 +267,11 @@ virtio_transport_cancel_pkt(struct vsock_sock *vsk)
>>>    			queue_work(virtio_vsock_workqueue, &vsock->rx_work);
>>>    	}
>>> -	return 0;
>>> +	ret = 0;
>>> +
>>> +out_rcu:
>>> +	rcu_read_unlock();
>>> +	return ret;
>>>    }
>>>    static void virtio_vsock_rx_fill(struct virtio_vsock *vsock)
>>> @@ -590,8 +606,6 @@ static int virtio_vsock_probe(struct virtio_device *vdev)
>>>    	vsock->rx_buf_max_nr = 0;
>>>    	atomic_set(&vsock->queued_replies, 0);
>>> -	vdev->priv = vsock;
>>> -	the_virtio_vsock = vsock;
>>>    	mutex_init(&vsock->tx_lock);
>>>    	mutex_init(&vsock->rx_lock);
>>>    	mutex_init(&vsock->event_lock);
>>> @@ -613,6 +627,9 @@ static int virtio_vsock_probe(struct virtio_device *vdev)
>>>    	virtio_vsock_event_fill(vsock);
>>>    	mutex_unlock(&vsock->event_lock);
>>> +	vdev->priv = vsock;
>>> +	rcu_assign_pointer(the_virtio_vsock, vsock);
>>
>> You probably need to use rcu_dereference_protected() to access
>> the_virtio_vsock in the function in order to survive from sparse.
>>
> Ooo, thanks!
>
> Do you mean when we check if the_virtio_vsock is not null at the beginning of
> virtio_vsock_probe()?


I mean instead of:

     /* Only one virtio-vsock device per guest is supported */
     if (the_virtio_vsock) {
         ret = -EBUSY;
         goto out;
     }

you should use:

if (rcu_dereference_protected(the_virtio_vsock,
lockdep_is_held(&the_virtio_vsock_mutex)))

...


>
>>> +
>>>    	mutex_unlock(&the_virtio_vsock_mutex);
>>>    	return 0;
>>> @@ -627,6 +644,12 @@ static void virtio_vsock_remove(struct virtio_device *vdev)
>>>    	struct virtio_vsock *vsock = vdev->priv;
>>>    	struct virtio_vsock_pkt *pkt;
>>> +	mutex_lock(&the_virtio_vsock_mutex);
>>> +
>>> +	vdev->priv = NULL;
>>> +	rcu_assign_pointer(the_virtio_vsock, NULL);
>>
>> This is still suspicious, can we access the_virtio_vsock through vdev->priv?
>> If yes, we may still get use-after-free since it was not protected by RCU.
> We will free the object only after calling the del_vqs(), so we are sure
> that the vq_callbacks ended and will no longer be invoked.
> So, IIUC it shouldn't happen.


Yes, but any dereference that is not done in vq_callbacks will be very 
dangerous in the future.

Thanks


>
>> Another more interesting question, I believe we will do singleton for
>> virtio_vsock structure. Then what's the point of using vdev->priv to access
>> the_virtio_vsock? It looks to me we can it brings extra troubles for doing
>> synchronization.
> I thought about it when I tried to use RCU to stop the worker and I
> think make sense. Maybe can be another series after this will be merged.
>
> @Stefan, what do you think about that?
>
> Thanks,
> Stefano
Stefano Garzarella July 4, 2019, 9:20 a.m. UTC | #6
On Thu, Jul 04, 2019 at 11:58:00AM +0800, Jason Wang wrote:
> 
> On 2019/7/3 下午6:41, Stefano Garzarella wrote:
> > On Wed, Jul 03, 2019 at 05:53:58PM +0800, Jason Wang wrote:
> > > On 2019/6/28 下午8:36, Stefano Garzarella wrote:
> > > > Some callbacks used by the upper layers can run while we are in the
> > > > .remove(). A potential use-after-free can happen, because we free
> > > > the_virtio_vsock without knowing if the callbacks are over or not.
> > > > 
> > > > To solve this issue we move the assignment of the_virtio_vsock at the
> > > > end of .probe(), when we finished all the initialization, and at the
> > > > beginning of .remove(), before to release resources.
> > > > For the same reason, we do the same also for the vdev->priv.
> > > > 
> > > > We use RCU to be sure that all callbacks that use the_virtio_vsock
> > > > ended before freeing it. This is not required for callbacks that
> > > > use vdev->priv, because after the vdev->config->del_vqs() we are sure
> > > > that they are ended and will no longer be invoked.
> > > > 
> > > > We also take the mutex during the .remove() to avoid that .probe() can
> > > > run while we are resetting the device.
> > > > 
> > > > Signed-off-by: Stefano Garzarella <sgarzare@redhat.com>
> > > > ---
> > > >    net/vmw_vsock/virtio_transport.c | 67 +++++++++++++++++++++-----------
> > > >    1 file changed, 44 insertions(+), 23 deletions(-)
> > > > 
> > > > diff --git a/net/vmw_vsock/virtio_transport.c b/net/vmw_vsock/virtio_transport.c
> > > > index 9c287e3e393c..7ad510ec12e0 100644
> > > > --- a/net/vmw_vsock/virtio_transport.c
> > > > +++ b/net/vmw_vsock/virtio_transport.c
> > > > @@ -65,19 +65,22 @@ struct virtio_vsock {
> > > >    	u32 guest_cid;
> > > >    };
> > > > -static struct virtio_vsock *virtio_vsock_get(void)
> > > > -{
> > > > -	return the_virtio_vsock;
> > > > -}
> > > > -
> > > >    static u32 virtio_transport_get_local_cid(void)
> > > >    {
> > > > -	struct virtio_vsock *vsock = virtio_vsock_get();
> > > > +	struct virtio_vsock *vsock;
> > > > +	u32 ret;
> > > > -	if (!vsock)
> > > > -		return VMADDR_CID_ANY;
> > > > +	rcu_read_lock();
> > > > +	vsock = rcu_dereference(the_virtio_vsock);
> > > > +	if (!vsock) {
> > > > +		ret = VMADDR_CID_ANY;
> > > > +		goto out_rcu;
> > > > +	}
> > > > -	return vsock->guest_cid;
> > > > +	ret = vsock->guest_cid;
> > > > +out_rcu:
> > > > +	rcu_read_unlock();
> > > > +	return ret;
> > > >    }
> > > >    static void virtio_transport_loopback_work(struct work_struct *work)
> > > > @@ -197,14 +200,18 @@ virtio_transport_send_pkt(struct virtio_vsock_pkt *pkt)
> > > >    	struct virtio_vsock *vsock;
> > > >    	int len = pkt->len;
> > > > -	vsock = virtio_vsock_get();
> > > > +	rcu_read_lock();
> > > > +	vsock = rcu_dereference(the_virtio_vsock);
> > > >    	if (!vsock) {
> > > >    		virtio_transport_free_pkt(pkt);
> > > > -		return -ENODEV;
> > > > +		len = -ENODEV;
> > > > +		goto out_rcu;
> > > >    	}
> > > > -	if (le64_to_cpu(pkt->hdr.dst_cid) == vsock->guest_cid)
> > > > -		return virtio_transport_send_pkt_loopback(vsock, pkt);
> > > > +	if (le64_to_cpu(pkt->hdr.dst_cid) == vsock->guest_cid) {
> > > > +		len = virtio_transport_send_pkt_loopback(vsock, pkt);
> > > > +		goto out_rcu;
> > > > +	}
> > > >    	if (pkt->reply)
> > > >    		atomic_inc(&vsock->queued_replies);
> > > > @@ -214,6 +221,9 @@ virtio_transport_send_pkt(struct virtio_vsock_pkt *pkt)
> > > >    	spin_unlock_bh(&vsock->send_pkt_list_lock);
> > > >    	queue_work(virtio_vsock_workqueue, &vsock->send_pkt_work);
> > > > +
> > > > +out_rcu:
> > > > +	rcu_read_unlock();
> > > >    	return len;
> > > >    }
> > > > @@ -222,12 +232,14 @@ virtio_transport_cancel_pkt(struct vsock_sock *vsk)
> > > >    {
> > > >    	struct virtio_vsock *vsock;
> > > >    	struct virtio_vsock_pkt *pkt, *n;
> > > > -	int cnt = 0;
> > > > +	int cnt = 0, ret;
> > > >    	LIST_HEAD(freeme);
> > > > -	vsock = virtio_vsock_get();
> > > > +	rcu_read_lock();
> > > > +	vsock = rcu_dereference(the_virtio_vsock);
> > > >    	if (!vsock) {
> > > > -		return -ENODEV;
> > > > +		ret = -ENODEV;
> > > > +		goto out_rcu;
> > > >    	}
> > > >    	spin_lock_bh(&vsock->send_pkt_list_lock);
> > > > @@ -255,7 +267,11 @@ virtio_transport_cancel_pkt(struct vsock_sock *vsk)
> > > >    			queue_work(virtio_vsock_workqueue, &vsock->rx_work);
> > > >    	}
> > > > -	return 0;
> > > > +	ret = 0;
> > > > +
> > > > +out_rcu:
> > > > +	rcu_read_unlock();
> > > > +	return ret;
> > > >    }
> > > >    static void virtio_vsock_rx_fill(struct virtio_vsock *vsock)
> > > > @@ -590,8 +606,6 @@ static int virtio_vsock_probe(struct virtio_device *vdev)
> > > >    	vsock->rx_buf_max_nr = 0;
> > > >    	atomic_set(&vsock->queued_replies, 0);
> > > > -	vdev->priv = vsock;
> > > > -	the_virtio_vsock = vsock;
> > > >    	mutex_init(&vsock->tx_lock);
> > > >    	mutex_init(&vsock->rx_lock);
> > > >    	mutex_init(&vsock->event_lock);
> > > > @@ -613,6 +627,9 @@ static int virtio_vsock_probe(struct virtio_device *vdev)
> > > >    	virtio_vsock_event_fill(vsock);
> > > >    	mutex_unlock(&vsock->event_lock);
> > > > +	vdev->priv = vsock;
> > > > +	rcu_assign_pointer(the_virtio_vsock, vsock);
> > > 
> > > You probably need to use rcu_dereference_protected() to access
> > > the_virtio_vsock in the function in order to survive from sparse.
> > > 
> > Ooo, thanks!
> > 
> > Do you mean when we check if the_virtio_vsock is not null at the beginning of
> > virtio_vsock_probe()?
> 
> 
> I mean instead of:
> 
>     /* Only one virtio-vsock device per guest is supported */
>     if (the_virtio_vsock) {
>         ret = -EBUSY;
>         goto out;
>     }
> 
> you should use:
> 
> if (rcu_dereference_protected(the_virtio_vosck,
> lock_dep_is_held(&the_virtio_vsock_mutex))
> 
> ...

Okay, thanks for confirming! I'll send a v3 to fix this!

> 
> 
> > 
> > > > +
> > > >    	mutex_unlock(&the_virtio_vsock_mutex);
> > > >    	return 0;
> > > > @@ -627,6 +644,12 @@ static void virtio_vsock_remove(struct virtio_device *vdev)
> > > >    	struct virtio_vsock *vsock = vdev->priv;
> > > >    	struct virtio_vsock_pkt *pkt;
> > > > +	mutex_lock(&the_virtio_vsock_mutex);
> > > > +
> > > > +	vdev->priv = NULL;
> > > > +	rcu_assign_pointer(the_virtio_vsock, NULL);
> > > 
> > > This is still suspicious, can we access the_virtio_vsock through vdev->priv?
> > > If yes, we may still get use-after-free since it was not protected by RCU.
> > We will free the object only after calling the del_vqs(), so we are sure
> > that the vq_callbacks ended and will no longer be invoked.
> > So, IIUC it shouldn't happen.
> 
> 
> Yes, but any dereference that is not done in vq_callbacks will be very
> dangerous in the future.

Right.

Do you think it makes sense to continue with this series in order to fix the
hot-unplug issue, and then I'll work on refactoring the driver code to use the
refcnt (as you suggested in patch 2) and a singleton for the_virtio_vsock?

Thanks,
Stefano
Stefan Hajnoczi July 4, 2019, 10:17 a.m. UTC | #7
On Wed, Jul 03, 2019 at 12:41:35PM +0200, Stefano Garzarella wrote:
> On Wed, Jul 03, 2019 at 05:53:58PM +0800, Jason Wang wrote:
> > On 2019/6/28 下午8:36, Stefano Garzarella wrote:
> > Another more interesting question, I believe we will do singleton for
> > virtio_vsock structure. Then what's the point of using vdev->priv to access
> > the_virtio_vsock? It looks to me we can it brings extra troubles for doing
> > synchronization.
> 
> I thought about it when I tried to use RCU to stop the worker and I
> think make sense. Maybe can be another series after this will be merged.
> 
> @Stefan, what do you think about that?

Yes, let's make it a singleton and keep no other references to it.

Stefan
Jason Wang July 5, 2019, 12:18 a.m. UTC | #8
On 2019/7/4 下午5:20, Stefano Garzarella wrote:
>>>> This is still suspicious, can we access the_virtio_vsock through vdev->priv?
>>>> If yes, we may still get use-after-free since it was not protected by RCU.
>>> We will free the object only after calling the del_vqs(), so we are sure
>>> that the vq_callbacks ended and will no longer be invoked.
>>> So, IIUC it shouldn't happen.
>> Yes, but any dereference that is not done in vq_callbacks will be very
>> dangerous in the future.
> Right.
>
> Do you think make sense to continue with this series in order to fix the
> hot-unplug issue, then I'll work to refactor the driver code to use the refcnt
> (as you suggested in patch 2) and singleton for the_virtio_vsock?
>
> Thanks,
> Stefano


Yes.

Thanks
diff mbox series

Patch

diff --git a/net/vmw_vsock/virtio_transport.c b/net/vmw_vsock/virtio_transport.c
index 9c287e3e393c..7ad510ec12e0 100644
--- a/net/vmw_vsock/virtio_transport.c
+++ b/net/vmw_vsock/virtio_transport.c
@@ -65,19 +65,22 @@  struct virtio_vsock {
 	u32 guest_cid;
 };
 
-static struct virtio_vsock *virtio_vsock_get(void)
-{
-	return the_virtio_vsock;
-}
-
 static u32 virtio_transport_get_local_cid(void)
 {
-	struct virtio_vsock *vsock = virtio_vsock_get();
+	struct virtio_vsock *vsock;
+	u32 ret;
 
-	if (!vsock)
-		return VMADDR_CID_ANY;
+	rcu_read_lock();
+	vsock = rcu_dereference(the_virtio_vsock);
+	if (!vsock) {
+		ret = VMADDR_CID_ANY;
+		goto out_rcu;
+	}
 
-	return vsock->guest_cid;
+	ret = vsock->guest_cid;
+out_rcu:
+	rcu_read_unlock();
+	return ret;
 }
 
 static void virtio_transport_loopback_work(struct work_struct *work)
@@ -197,14 +200,18 @@  virtio_transport_send_pkt(struct virtio_vsock_pkt *pkt)
 	struct virtio_vsock *vsock;
 	int len = pkt->len;
 
-	vsock = virtio_vsock_get();
+	rcu_read_lock();
+	vsock = rcu_dereference(the_virtio_vsock);
 	if (!vsock) {
 		virtio_transport_free_pkt(pkt);
-		return -ENODEV;
+		len = -ENODEV;
+		goto out_rcu;
 	}
 
-	if (le64_to_cpu(pkt->hdr.dst_cid) == vsock->guest_cid)
-		return virtio_transport_send_pkt_loopback(vsock, pkt);
+	if (le64_to_cpu(pkt->hdr.dst_cid) == vsock->guest_cid) {
+		len = virtio_transport_send_pkt_loopback(vsock, pkt);
+		goto out_rcu;
+	}
 
 	if (pkt->reply)
 		atomic_inc(&vsock->queued_replies);
@@ -214,6 +221,9 @@  virtio_transport_send_pkt(struct virtio_vsock_pkt *pkt)
 	spin_unlock_bh(&vsock->send_pkt_list_lock);
 
 	queue_work(virtio_vsock_workqueue, &vsock->send_pkt_work);
+
+out_rcu:
+	rcu_read_unlock();
 	return len;
 }
 
@@ -222,12 +232,14 @@  virtio_transport_cancel_pkt(struct vsock_sock *vsk)
 {
 	struct virtio_vsock *vsock;
 	struct virtio_vsock_pkt *pkt, *n;
-	int cnt = 0;
+	int cnt = 0, ret;
 	LIST_HEAD(freeme);
 
-	vsock = virtio_vsock_get();
+	rcu_read_lock();
+	vsock = rcu_dereference(the_virtio_vsock);
 	if (!vsock) {
-		return -ENODEV;
+		ret = -ENODEV;
+		goto out_rcu;
 	}
 
 	spin_lock_bh(&vsock->send_pkt_list_lock);
@@ -255,7 +267,11 @@  virtio_transport_cancel_pkt(struct vsock_sock *vsk)
 			queue_work(virtio_vsock_workqueue, &vsock->rx_work);
 	}
 
-	return 0;
+	ret = 0;
+
+out_rcu:
+	rcu_read_unlock();
+	return ret;
 }
 
 static void virtio_vsock_rx_fill(struct virtio_vsock *vsock)
@@ -590,8 +606,6 @@  static int virtio_vsock_probe(struct virtio_device *vdev)
 	vsock->rx_buf_max_nr = 0;
 	atomic_set(&vsock->queued_replies, 0);
 
-	vdev->priv = vsock;
-	the_virtio_vsock = vsock;
 	mutex_init(&vsock->tx_lock);
 	mutex_init(&vsock->rx_lock);
 	mutex_init(&vsock->event_lock);
@@ -613,6 +627,9 @@  static int virtio_vsock_probe(struct virtio_device *vdev)
 	virtio_vsock_event_fill(vsock);
 	mutex_unlock(&vsock->event_lock);
 
+	vdev->priv = vsock;
+	rcu_assign_pointer(the_virtio_vsock, vsock);
+
 	mutex_unlock(&the_virtio_vsock_mutex);
 	return 0;
 
@@ -627,6 +644,12 @@  static void virtio_vsock_remove(struct virtio_device *vdev)
 	struct virtio_vsock *vsock = vdev->priv;
 	struct virtio_vsock_pkt *pkt;
 
+	mutex_lock(&the_virtio_vsock_mutex);
+
+	vdev->priv = NULL;
+	rcu_assign_pointer(the_virtio_vsock, NULL);
+	synchronize_rcu();
+
 	flush_work(&vsock->loopback_work);
 	flush_work(&vsock->rx_work);
 	flush_work(&vsock->tx_work);
@@ -666,12 +689,10 @@  static void virtio_vsock_remove(struct virtio_device *vdev)
 	}
 	spin_unlock_bh(&vsock->loopback_list_lock);
 
-	mutex_lock(&the_virtio_vsock_mutex);
-	the_virtio_vsock = NULL;
-	mutex_unlock(&the_virtio_vsock_mutex);
-
 	vdev->config->del_vqs(vdev);
 
+	mutex_unlock(&the_virtio_vsock_mutex);
+
 	kfree(vsock);
 }