@@ -387,6 +387,7 @@ static bool vhost_vsock_more_replies(struct vhost_vsock *vsock)
static struct virtio_transport vhost_transport = {
.transport = {
.features = VSOCK_TRANSPORT_F_H2G,
+ .module = THIS_MODULE,
.get_local_cid = vhost_transport_get_local_cid,
@@ -100,6 +100,7 @@ struct vsock_transport_send_notify_data {
struct vsock_transport {
uint64_t features;
+ struct module *module;
/* Initialize/tear-down socket. */
int (*init)(struct vsock_sock *, struct vsock_sock *);
@@ -416,13 +416,28 @@ int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk)
return -ESOCKTNOSUPPORT;
}
- if (!vsk->transport)
+ /* We increase the module refcnt to prevent the tranport unloading
+ * while there are open sockets assigned to it.
+ */
+ if (!vsk->transport || !try_module_get(vsk->transport->module)) {
+ vsk->transport = NULL;
return -ENODEV;
+ }
return vsk->transport->init(vsk, psk);
}
EXPORT_SYMBOL_GPL(vsock_assign_transport);
+static void vsock_deassign_transport(struct vsock_sock *vsk)
+{
+ if (!vsk->transport)
+ return;
+
+ vsk->transport->destruct(vsk);
+ module_put(vsk->transport->module);
+ vsk->transport = NULL;
+}
+
static bool vsock_find_cid(unsigned int cid)
{
if (transport_g2h && cid == transport_g2h->get_local_cid())
@@ -728,8 +743,7 @@ static void vsock_sk_destruct(struct sock *sk)
{
struct vsock_sock *vsk = vsock_sk(sk);
- if (vsk->transport)
- vsk->transport->destruct(vsk);
+ vsock_deassign_transport(vsk);
/* When clearing these addresses, there's no need to set the family and
* possibly register the address family with the kernel.
@@ -2161,9 +2175,6 @@ void vsock_core_unregister(const struct vsock_transport *t)
{
mutex_lock(&vsock_register_mutex);
- /* RFC-TODO: maybe we should check if there are open sockets
- * assigned to that transport and avoid the unregistration
- */
if (transport_h2g == t)
transport_h2g = NULL;
@@ -857,6 +857,7 @@ int hvs_notify_send_post_enqueue(struct vsock_sock *vsk, ssize_t written,
static struct vsock_transport hvs_transport = {
.features = VSOCK_TRANSPORT_F_G2H,
+ .module = THIS_MODULE,
.get_local_cid = hvs_get_local_cid,
@@ -463,6 +463,7 @@ static void virtio_vsock_rx_done(struct virtqueue *vq)
static struct virtio_transport virtio_transport = {
.transport = {
.features = VSOCK_TRANSPORT_F_G2H,
+ .module = THIS_MODULE,
.get_local_cid = virtio_transport_get_local_cid,
@@ -2021,6 +2021,7 @@ static u32 vmci_transport_get_local_cid(void)
static struct vsock_transport vmci_transport = {
.features = VSOCK_TRANSPORT_F_DGRAM | VSOCK_TRANSPORT_F_H2G,
+ .module = THIS_MODULE,
.init = vmci_transport_socket_init,
.destruct = vmci_transport_destruct,
.release = vmci_transport_release,
This patch adds 'module' member in the 'struct vsock_transport' in order to get/put the transport module. This prevents the module unloading while sockets are assigned to it. We increase the module refcnt when a socket is assigned to a transport, and we decrease the module refcnt when the socket is destructed. Signed-off-by: Stefano Garzarella <sgarzare@redhat.com> --- drivers/vhost/vsock.c | 1 + include/net/af_vsock.h | 1 + net/vmw_vsock/af_vsock.c | 23 +++++++++++++++++------ net/vmw_vsock/hyperv_transport.c | 1 + net/vmw_vsock/virtio_transport.c | 1 + net/vmw_vsock/vmci_transport.c | 1 + 6 files changed, 22 insertions(+), 6 deletions(-)