diff mbox series

[v1,1/1] virtio-snd: rewrite invalid tx/rx message handling

Message ID virtio-snd-rewrite-invalid-tx-rx-message-handling-v1.manos.pitsidianakis@linaro.org (mailing list archive)
State New, archived
Headers show
Series virtio-snd: fix invalid tx/rx message handling logic | expand

Commit Message

Manos Pitsidianakis March 24, 2024, 10:04 a.m. UTC
The current handling of invalid virtqueue elements inside the TX/RX virt
queue handlers is wrong.

They are added in a per-stream invalid queue to be processed after the
handler is done examining each message, but the invalid message might
not be specifying any stream_id; which means it's invalid to add it to
any stream->invalid queue since stream could be NULL at this point.

This commit moves the invalid queue to the VirtIOSound struct which
guarantees there will always be a valid temporary place to store them
inside the tx/rx handlers. The queue will be emptied before the handler
returns, so the queue must be empty at any other point of the device's
lifetime.

Signed-off-by: Manos Pitsidianakis <manos.pitsidianakis@linaro.org>
---
 include/hw/audio/virtio-snd.h |  16 +++-
 hw/audio/virtio-snd.c         | 137 +++++++++++++++-------------------
 2 files changed, 77 insertions(+), 76 deletions(-)
diff mbox series

Patch

diff --git a/include/hw/audio/virtio-snd.h b/include/hw/audio/virtio-snd.h
index 3d79181364..8dafedb276 100644
--- a/include/hw/audio/virtio-snd.h
+++ b/include/hw/audio/virtio-snd.h
@@ -151,7 +151,6 @@  struct VirtIOSoundPCMStream {
     QemuMutex queue_mutex;
     bool active;
     QSIMPLEQ_HEAD(, VirtIOSoundPCMBuffer) queue;
-    QSIMPLEQ_HEAD(, VirtIOSoundPCMBuffer) invalid;
 };
 
 /*
@@ -223,6 +222,21 @@  struct VirtIOSound {
     QemuMutex cmdq_mutex;
     QTAILQ_HEAD(, virtio_snd_ctrl_command) cmdq;
     bool processing_cmdq;
+    /*
+     * Convenience queue to keep track of invalid tx/rx queue messages inside
+     * the tx/rx callbacks.
+     *
+     * In the callbacks as a first step we are emptying the virtqueue to handle
+     * each message and we cannot add an invalid message back to the queue: we
+     * would re-process it in subsequent loop iterations.
+     *
+     * Instead, we add them to this queue and after finishing examining every
+     * virtqueue element, we inform the guest for each invalid message.
+     *
+     * This queue must be empty at all times except for inside the tx/rx
+     * callbacks.
+     */
+    QSIMPLEQ_HEAD(, VirtIOSoundPCMBuffer) invalid;
 };
 
 struct virtio_snd_ctrl_command {
diff --git a/hw/audio/virtio-snd.c b/hw/audio/virtio-snd.c
index 30493f06a8..90d9a2796e 100644
--- a/hw/audio/virtio-snd.c
+++ b/hw/audio/virtio-snd.c
@@ -456,7 +456,6 @@  static uint32_t virtio_snd_pcm_prepare(VirtIOSound *s, uint32_t stream_id)
         stream->s = s;
         qemu_mutex_init(&stream->queue_mutex);
         QSIMPLEQ_INIT(&stream->queue);
-        QSIMPLEQ_INIT(&stream->invalid);
 
         /*
          * stream_id >= s->snd_conf.streams was checked before so this is
@@ -611,9 +610,6 @@  static size_t virtio_snd_pcm_get_io_msgs_count(VirtIOSoundPCMStream *stream)
         QSIMPLEQ_FOREACH_SAFE(buffer, &stream->queue, entry, next) {
             count += 1;
         }
-        QSIMPLEQ_FOREACH_SAFE(buffer, &stream->invalid, entry, next) {
-            count += 1;
-        }
     }
     return count;
 }
@@ -831,47 +827,36 @@  static void virtio_snd_handle_event(VirtIODevice *vdev, VirtQueue *vq)
     trace_virtio_snd_handle_event();
 }
 
+/*
+ * Must only be called if vsnd->invalid is not empty.
+ */
 static inline void empty_invalid_queue(VirtIODevice *vdev, VirtQueue *vq)
 {
     VirtIOSoundPCMBuffer *buffer = NULL;
-    VirtIOSoundPCMStream *stream = NULL;
     virtio_snd_pcm_status resp = { 0 };
     VirtIOSound *vsnd = VIRTIO_SND(vdev);
-    bool any = false;
 
-    for (uint32_t i = 0; i < vsnd->snd_conf.streams; i++) {
-        stream = vsnd->pcm->streams[i];
-        if (stream) {
-            any = false;
-            WITH_QEMU_LOCK_GUARD(&stream->queue_mutex) {
-                while (!QSIMPLEQ_EMPTY(&stream->invalid)) {
-                    buffer = QSIMPLEQ_FIRST(&stream->invalid);
-                    if (buffer->vq != vq) {
-                        break;
-                    }
-                    any = true;
-                    resp.status = cpu_to_le32(VIRTIO_SND_S_BAD_MSG);
-                    iov_from_buf(buffer->elem->in_sg,
-                                 buffer->elem->in_num,
-                                 0,
-                                 &resp,
-                                 sizeof(virtio_snd_pcm_status));
-                    virtqueue_push(vq,
-                                   buffer->elem,
-                                   sizeof(virtio_snd_pcm_status));
-                    QSIMPLEQ_REMOVE_HEAD(&stream->invalid, entry);
-                    virtio_snd_pcm_buffer_free(buffer);
-                }
-                if (any) {
-                    /*
-                     * Notify vq about virtio_snd_pcm_status responses.
-                     * Buffer responses must be notified separately later.
-                     */
-                    virtio_notify(vdev, vq);
-                }
-            }
-        }
+    g_assert(!QSIMPLEQ_EMPTY(&vsnd->invalid));
+
+    while (!QSIMPLEQ_EMPTY(&vsnd->invalid)) {
+        buffer = QSIMPLEQ_FIRST(&vsnd->invalid);
+        /* If buffer->vq != vq, our logic is fundamentally wrong, so bail out */
+        g_assert(buffer->vq == vq);
+
+        resp.status = cpu_to_le32(VIRTIO_SND_S_BAD_MSG);
+        iov_from_buf(buffer->elem->in_sg,
+                     buffer->elem->in_num,
+                     0,
+                     &resp,
+                     sizeof(virtio_snd_pcm_status));
+        virtqueue_push(vq,
+                       buffer->elem,
+                       sizeof(virtio_snd_pcm_status));
+        QSIMPLEQ_REMOVE_HEAD(&vsnd->invalid, entry);
+        virtio_snd_pcm_buffer_free(buffer);
     }
+    /* Notify vq about virtio_snd_pcm_status responses. */
+    virtio_notify(vdev, vq);
 }
 
 /*
@@ -883,15 +868,14 @@  static inline void empty_invalid_queue(VirtIODevice *vdev, VirtQueue *vq)
  */
 static void virtio_snd_handle_tx_xfer(VirtIODevice *vdev, VirtQueue *vq)
 {
-    VirtIOSound *s = VIRTIO_SND(vdev);
-    VirtIOSoundPCMStream *stream = NULL;
+    VirtIOSound *vsnd = VIRTIO_SND(vdev);
     VirtIOSoundPCMBuffer *buffer;
     VirtQueueElement *elem;
     size_t msg_sz, size;
     virtio_snd_pcm_xfer hdr;
     uint32_t stream_id;
     /*
-     * If any of the I/O messages are invalid, put them in stream->invalid and
+     * If any of the I/O messages are invalid, put them in vsnd->invalid and
      * return them after the for loop.
      */
     bool must_empty_invalid_queue = false;
@@ -901,7 +885,7 @@  static void virtio_snd_handle_tx_xfer(VirtIODevice *vdev, VirtQueue *vq)
     }
     trace_virtio_snd_handle_tx_xfer();
 
-    for (;;) {
+    for (VirtIOSoundPCMStream *stream = NULL;; stream = NULL) {
         elem = virtqueue_pop(vq, sizeof(VirtQueueElement));
         if (!elem) {
             break;
@@ -913,16 +897,16 @@  static void virtio_snd_handle_tx_xfer(VirtIODevice *vdev, VirtQueue *vq)
                             &hdr,
                             sizeof(virtio_snd_pcm_xfer));
         if (msg_sz != sizeof(virtio_snd_pcm_xfer)) {
-            continue;
+            goto tx_err;
         }
         stream_id = le32_to_cpu(hdr.stream_id);
 
-        if (stream_id >= s->snd_conf.streams
-            || s->pcm->streams[stream_id] == NULL) {
-            continue;
+        if (stream_id >= vsnd->snd_conf.streams
+            || vsnd->pcm->streams[stream_id] == NULL) {
+            goto tx_err;
         }
 
-        stream = s->pcm->streams[stream_id];
+        stream = vsnd->pcm->streams[stream_id];
         if (stream->info.direction != VIRTIO_SND_D_OUTPUT) {
             goto tx_err;
         }
@@ -942,13 +926,11 @@  static void virtio_snd_handle_tx_xfer(VirtIODevice *vdev, VirtQueue *vq)
         continue;
 
 tx_err:
-        WITH_QEMU_LOCK_GUARD(&stream->queue_mutex) {
-            must_empty_invalid_queue = true;
-            buffer = g_malloc0(sizeof(VirtIOSoundPCMBuffer));
-            buffer->elem = elem;
-            buffer->vq = vq;
-            QSIMPLEQ_INSERT_TAIL(&stream->invalid, buffer, entry);
-        }
+        must_empty_invalid_queue = true;
+        buffer = g_malloc0(sizeof(VirtIOSoundPCMBuffer));
+        buffer->elem = elem;
+        buffer->vq = vq;
+        QSIMPLEQ_INSERT_TAIL(&vsnd->invalid, buffer, entry);
     }
 
     if (must_empty_invalid_queue) {
@@ -965,15 +947,14 @@  tx_err:
  */
 static void virtio_snd_handle_rx_xfer(VirtIODevice *vdev, VirtQueue *vq)
 {
-    VirtIOSound *s = VIRTIO_SND(vdev);
-    VirtIOSoundPCMStream *stream = NULL;
+    VirtIOSound *vsnd = VIRTIO_SND(vdev);
     VirtIOSoundPCMBuffer *buffer;
     VirtQueueElement *elem;
     size_t msg_sz, size;
     virtio_snd_pcm_xfer hdr;
     uint32_t stream_id;
     /*
-     * if any of the I/O messages are invalid, put them in stream->invalid and
+     * if any of the I/O messages are invalid, put them in vsnd->invalid and
      * return them after the for loop.
      */
     bool must_empty_invalid_queue = false;
@@ -983,7 +964,7 @@  static void virtio_snd_handle_rx_xfer(VirtIODevice *vdev, VirtQueue *vq)
     }
     trace_virtio_snd_handle_rx_xfer();
 
-    for (;;) {
+    for (VirtIOSoundPCMStream *stream = NULL;; stream = NULL) {
         elem = virtqueue_pop(vq, sizeof(VirtQueueElement));
         if (!elem) {
             break;
@@ -995,16 +976,16 @@  static void virtio_snd_handle_rx_xfer(VirtIODevice *vdev, VirtQueue *vq)
                             &hdr,
                             sizeof(virtio_snd_pcm_xfer));
         if (msg_sz != sizeof(virtio_snd_pcm_xfer)) {
-            continue;
+            goto rx_err;
         }
         stream_id = le32_to_cpu(hdr.stream_id);
 
-        if (stream_id >= s->snd_conf.streams
-            || !s->pcm->streams[stream_id]) {
-            continue;
+        if (stream_id >= vsnd->snd_conf.streams
+            || !vsnd->pcm->streams[stream_id]) {
+            goto rx_err;
         }
 
-        stream = s->pcm->streams[stream_id];
+        stream = vsnd->pcm->streams[stream_id];
         if (stream == NULL || stream->info.direction != VIRTIO_SND_D_INPUT) {
             goto rx_err;
         }
@@ -1021,13 +1002,11 @@  static void virtio_snd_handle_rx_xfer(VirtIODevice *vdev, VirtQueue *vq)
         continue;
 
 rx_err:
-        WITH_QEMU_LOCK_GUARD(&stream->queue_mutex) {
-            must_empty_invalid_queue = true;
-            buffer = g_malloc0(sizeof(VirtIOSoundPCMBuffer));
-            buffer->elem = elem;
-            buffer->vq = vq;
-            QSIMPLEQ_INSERT_TAIL(&stream->invalid, buffer, entry);
-        }
+        must_empty_invalid_queue = true;
+        buffer = g_malloc0(sizeof(VirtIOSoundPCMBuffer));
+        buffer->elem = elem;
+        buffer->vq = vq;
+        QSIMPLEQ_INSERT_TAIL(&vsnd->invalid, buffer, entry);
     }
 
     if (must_empty_invalid_queue) {
@@ -1127,6 +1106,7 @@  static void virtio_snd_realize(DeviceState *dev, Error **errp)
         virtio_add_queue(vdev, 64, virtio_snd_handle_rx_xfer);
     qemu_mutex_init(&vsnd->cmdq_mutex);
     QTAILQ_INIT(&vsnd->cmdq);
+    QSIMPLEQ_INIT(&vsnd->invalid);
 
     for (uint32_t i = 0; i < vsnd->snd_conf.streams; i++) {
         status = virtio_snd_set_pcm_params(vsnd, i, &default_params);
@@ -1376,13 +1356,20 @@  static void virtio_snd_unrealize(DeviceState *dev)
 
 static void virtio_snd_reset(VirtIODevice *vdev)
 {
-    VirtIOSound *s = VIRTIO_SND(vdev);
+    VirtIOSound *vsnd = VIRTIO_SND(vdev);
     virtio_snd_ctrl_command *cmd;
 
-    WITH_QEMU_LOCK_GUARD(&s->cmdq_mutex) {
-        while (!QTAILQ_EMPTY(&s->cmdq)) {
-            cmd = QTAILQ_FIRST(&s->cmdq);
-            QTAILQ_REMOVE(&s->cmdq, cmd, next);
+    /*
+     * Sanity check that the invalid buffer message queue is emptied at the end
+     * of every virtio_snd_handle_tx_xfer/virtio_snd_handle_rx_xfer call, and
+     * must be empty otherwise.
+     */
+    g_assert(QSIMPLEQ_EMPTY(&vsnd->invalid));
+
+    WITH_QEMU_LOCK_GUARD(&vsnd->cmdq_mutex) {
+        while (!QTAILQ_EMPTY(&vsnd->cmdq)) {
+            cmd = QTAILQ_FIRST(&vsnd->cmdq);
+            QTAILQ_REMOVE(&vsnd->cmdq, cmd, next);
             virtio_snd_ctrl_cmd_free(cmd);
         }
     }