
[net-next,1/5] virtio_ring: introduce virtqueue_get_buf_ctx_dma()

Message ID 20240116075924.42798-2-xuanzhuo@linux.alibaba.com (mailing list archive)
State Deferred
Delegated to: Netdev Maintainers
Series virtio-net: sq support premapped mode

Checks

Context Check Description
netdev/tree_selection success Clearly marked for net-next
netdev/apply fail Patch does not apply to net-next

Commit Message

Xuan Zhuo Jan. 16, 2024, 7:59 a.m. UTC
Introduce virtqueue_get_buf_ctx_dma() to collect the DMA info when
getting a buffer back from the virtio core in premapped mode.

If the virtqueue is in premapped mode, the virtio-net send buffer may
consist of many descriptors, and every descriptor's DMA address needs
to be unmapped. So here we introduce a new helper to collect the DMA
addresses of the buffer from the virtio core.

Because BAD_RING() is called (which may set vq->broken), the
related "const" qualifiers on vq are removed.

Signed-off-by: Xuan Zhuo <xuanzhuo@linux.alibaba.com>
---
 drivers/virtio/virtio_ring.c | 174 +++++++++++++++++++++++++----------
 include/linux/virtio.h       |  16 ++++
 2 files changed, 142 insertions(+), 48 deletions(-)

Comments

Jason Wang Jan. 24, 2024, 6:54 a.m. UTC | #1
On Tue, Jan 16, 2024 at 3:59 PM Xuan Zhuo <xuanzhuo@linux.alibaba.com> wrote:
>
> Introduce virtqueue_get_buf_ctx_dma() to collect the DMA info when
> getting a buffer back from the virtio core in premapped mode.
>
> If the virtqueue is in premapped mode, the virtio-net send buffer may
> consist of many descriptors.

This feature is not specific to virtio-net, so this should read "for example, the virtio-net send buffer ...".

> Every descriptor's DMA address needs to be unmapped. So here we
> introduce a new helper to collect the DMA addresses of the buffer from
> the virtio core.

Let's explain why we can't (or why it's suboptimal to) depend on the driver to do this.

>
> Because BAD_RING() is called (which may set vq->broken), the
> related "const" qualifiers on vq are removed.
>
> Signed-off-by: Xuan Zhuo <xuanzhuo@linux.alibaba.com>
> ---
>  drivers/virtio/virtio_ring.c | 174 +++++++++++++++++++++++++----------
>  include/linux/virtio.h       |  16 ++++
>  2 files changed, 142 insertions(+), 48 deletions(-)
>
> diff --git a/drivers/virtio/virtio_ring.c b/drivers/virtio/virtio_ring.c
> index 49299b1f9ec7..82f72428605b 100644
> --- a/drivers/virtio/virtio_ring.c
> +++ b/drivers/virtio/virtio_ring.c
> @@ -362,6 +362,45 @@ static struct device *vring_dma_dev(const struct vring_virtqueue *vq)
>         return vq->dma_dev;
>  }
>
> +/*
> + *     use_dma_api premapped -> do_unmap
> + *  1. false       false        false
> + *  2. true        false        true
> + *  3. true        true         false
> + *
> + * Only #3, we should return the DMA info to the driver.

So the code has a check for dma:

        if (!dma)
                return false;

In any case, we need a better comment here.

So now we have use_dma_api, premapped, do_unmap and dma. I must say
I'm totally lost in this maze. It's a strong hint that the API is too
complicated and needs to be tweaked.

For example, is it legal to have do_unmap be false but dma be non-NULL?

Here are some suggestions:

1) rename premapped to buffer_is_premapped
2) rename do_unmap to buffer_need_unmap or introduce a helper

bool vring_need_unmap_buffer()
{
        return use_dma_api && !premapped;
}

3) split the logic that collects the DMA info into a helper like virtqueue_get_dma_info()

so we can do

if (!vring_need_unmap_buffer()) {
        virtqueue_get_dma_info()
        return;
}

4) explain why we still need to check dma assuming we had
vring_need_unmap_buffer():

If vring_need_unmap_buffer() is true, we don't need to care about dma at all.

If vring_need_unmap_buffer() is false, we must return the dma info,
otherwise there's a leak?
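
Putting 2) and 3) together, a rough sketch of what I mean (purely
illustrative, just using the names suggested above):

static bool vring_need_unmap_buffer(const struct vring_virtqueue *vq)
{
	/* The core only unmaps buffers that the core itself mapped. */
	return vq->use_dma_api && !vq->premapped;
}

/* Hand the DMA info of one descriptor back to the driver. */
static void virtqueue_get_dma_info(struct vring_virtqueue *vq,
				   struct virtio_dma_head *dma,
				   dma_addr_t addr, unsigned int length)
{
	if (!dma)
		return;

	if (unlikely(dma->next >= dma->num)) {
		BAD_RING(vq, "premapped vq: collect dma overflow: %pad %u\n",
			 &addr, length);
		return;
	}

	dma->items[dma->next].addr = addr;
	dma->items[dma->next].length = length;
	++dma->next;
}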

> + *
> + * Return:
> + * true: the virtio core must unmap the desc
> + * false: the virtio core skip the desc unmap

Might it be better to say "It's up to the driver to unmap"?

> + */
> +static bool vring_need_unmap(struct vring_virtqueue *vq,
> +                            struct virtio_dma_head *dma,
> +                            dma_addr_t addr, unsigned int length)
> +{
> +       if (vq->do_unmap)
> +               return true;
> +
> +       if (!vq->premapped)
> +               return false;
> +
> +       if (!dma)
> +               return false;

So the logic here is odd.

if (!dma)
        return false;
...
        return false;

A strong hint to split the info-collection logic below into another
helper. The root cause is that the function does more than just answer
"yes or no".

> +
> +       if (unlikely(dma->next >= dma->num)) {
> +               BAD_RING(vq, "premapped vq: collect dma overflow: %pad %u\n",
> +                        &addr, length);
> +               return false;
> +       }
> +
> +       dma->items[dma->next].addr = addr;
> +       dma->items[dma->next].length = length;
> +
> +       ++dma->next;
> +
> +       return false;
> +}
> +

Thanks

....
Michael S. Tsirkin Feb. 22, 2024, 7:43 p.m. UTC | #2
On Tue, Jan 16, 2024 at 03:59:20PM +0800, Xuan Zhuo wrote:
> Introduce virtqueue_get_buf_ctx_dma() to collect the DMA info when
> getting a buffer back from the virtio core in premapped mode.
> 
> If the virtqueue is in premapped mode, the virtio-net send buffer may
> consist of many descriptors, and every descriptor's DMA address needs
> to be unmapped. So here we introduce a new helper to collect the DMA
> addresses of the buffer from the virtio core.
> 
> Because BAD_RING() is called (which may set vq->broken), the
> related "const" qualifiers on vq are removed.
> 
> Signed-off-by: Xuan Zhuo <xuanzhuo@linux.alibaba.com>
> ---
>  drivers/virtio/virtio_ring.c | 174 +++++++++++++++++++++++++----------
>  include/linux/virtio.h       |  16 ++++
>  2 files changed, 142 insertions(+), 48 deletions(-)
> 
> diff --git a/drivers/virtio/virtio_ring.c b/drivers/virtio/virtio_ring.c
> index 49299b1f9ec7..82f72428605b 100644
> --- a/drivers/virtio/virtio_ring.c
> +++ b/drivers/virtio/virtio_ring.c
> @@ -362,6 +362,45 @@ static struct device *vring_dma_dev(const struct vring_virtqueue *vq)
>  	return vq->dma_dev;
>  }
>  
> +/*
> + *     use_dma_api premapped -> do_unmap
> + *  1. false       false        false
> + *  2. true        false        true
> + *  3. true        true         false
> + *
> + * Only #3, we should return the DMA info to the driver.

No idea what this table is supposed to mean. Instead of this, just
add comments near each piece of code explaining it. E.g. something
like this (just guessing at the reason, pls replace
with the real one):

	/* if premapping is not supported, no need to unmap */
	if (!vq->premapped)
		return false;

and so on
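
Extending that, the whole helper could read like this (again, just
guessing at the reasons, pls replace with the real ones):

static bool vring_need_unmap(struct vring_virtqueue *vq,
			     struct virtio_dma_head *dma,
			     dma_addr_t addr, unsigned int length)
{
	/* The core mapped this buffer itself, so the core must unmap it. */
	if (vq->do_unmap)
		return true;

	/* No DMA API in use and not premapped: there is no mapping to undo. */
	if (!vq->premapped)
		return false;

	/* Premapped, but the driver did not ask for the DMA info back. */
	if (!dma)
		return false;

	/* Premapped and the driver wants the DMA info back: record it,
	 * the driver will unmap the buffer itself.
	 */
	if (unlikely(dma->next >= dma->num)) {
		BAD_RING(vq, "premapped vq: collect dma overflow: %pad %u\n",
			 &addr, length);
		return false;
	}

	dma->items[dma->next].addr = addr;
	dma->items[dma->next].length = length;
	++dma->next;

	return false;
}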


> + * Return:
> + * true: the virtio core must unmap the desc
> + * false: the virtio core skip the desc unmap
> + */
> +static bool vring_need_unmap(struct vring_virtqueue *vq,
> +			     struct virtio_dma_head *dma,
> +			     dma_addr_t addr, unsigned int length)
> +{
> +	if (vq->do_unmap)
> +		return true;
> +
> +	if (!vq->premapped)
> +		return false;
> +
> +	if (!dma)
> +		return false;
> +
> +	if (unlikely(dma->next >= dma->num)) {
> +		BAD_RING(vq, "premapped vq: collect dma overflow: %pad %u\n",
> +			 &addr, length);
> +		return false;
> +	}
> +
> +	dma->items[dma->next].addr = addr;
> +	dma->items[dma->next].length = length;
> +
> +	++dma->next;
> +
> +	return false;
> +}
> +
>  /* Map one sg entry. */
>  static int vring_map_one_sg(const struct vring_virtqueue *vq, struct scatterlist *sg,
>  			    enum dma_data_direction direction, dma_addr_t *addr)
> @@ -440,12 +479,14 @@ static void virtqueue_init(struct vring_virtqueue *vq, u32 num)
>   * Split ring specific functions - *_split().
>   */
>  
> -static void vring_unmap_one_split_indirect(const struct vring_virtqueue *vq,
> -					   const struct vring_desc *desc)
> +static void vring_unmap_one_split_indirect(struct vring_virtqueue *vq,
> +					   const struct vring_desc *desc,
> +					   struct virtio_dma_head *dma)
>  {
>  	u16 flags;
>  
> -	if (!vq->do_unmap)
> +	if (!vring_need_unmap(vq, dma, virtio64_to_cpu(vq->vq.vdev, desc->addr),
> +			  virtio32_to_cpu(vq->vq.vdev, desc->len)))
>  		return;
>  
>  	flags = virtio16_to_cpu(vq->vq.vdev, desc->flags);
> @@ -457,8 +498,8 @@ static void vring_unmap_one_split_indirect(const struct vring_virtqueue *vq,
>  		       DMA_FROM_DEVICE : DMA_TO_DEVICE);
>  }
>  
> -static unsigned int vring_unmap_one_split(const struct vring_virtqueue *vq,
> -					  unsigned int i)
> +static unsigned int vring_unmap_one_split(struct vring_virtqueue *vq,
> +					  unsigned int i, struct virtio_dma_head *dma)
>  {
>  	struct vring_desc_extra *extra = vq->split.desc_extra;
>  	u16 flags;
> @@ -474,17 +515,16 @@ static unsigned int vring_unmap_one_split(const struct vring_virtqueue *vq,
>  				 extra[i].len,
>  				 (flags & VRING_DESC_F_WRITE) ?
>  				 DMA_FROM_DEVICE : DMA_TO_DEVICE);
> -	} else {
> -		if (!vq->do_unmap)
> -			goto out;
> -
> -		dma_unmap_page(vring_dma_dev(vq),
> -			       extra[i].addr,
> -			       extra[i].len,
> -			       (flags & VRING_DESC_F_WRITE) ?
> -			       DMA_FROM_DEVICE : DMA_TO_DEVICE);
> +		goto out;
>  	}
>  
> +	if (!vring_need_unmap(vq, dma, extra[i].addr, extra[i].len))
> +		goto out;
> +
> +	dma_unmap_page(vring_dma_dev(vq), extra[i].addr, extra[i].len,
> +		       (flags & VRING_DESC_F_WRITE) ?
> +		       DMA_FROM_DEVICE : DMA_TO_DEVICE);
> +
>  out:
>  	return extra[i].next;
>  }
> @@ -717,10 +757,10 @@ static inline int virtqueue_add_split(struct virtqueue *_vq,
>  		if (i == err_idx)
>  			break;
>  		if (indirect) {
> -			vring_unmap_one_split_indirect(vq, &desc[i]);
> +			vring_unmap_one_split_indirect(vq, &desc[i], NULL);
>  			i = virtio16_to_cpu(_vq->vdev, desc[i].next);
>  		} else
> -			i = vring_unmap_one_split(vq, i);
> +			i = vring_unmap_one_split(vq, i, NULL);
>  	}
>  
>  free_indirect:
> @@ -763,7 +803,7 @@ static bool virtqueue_kick_prepare_split(struct virtqueue *_vq)
>  }
>  
>  static void detach_buf_split(struct vring_virtqueue *vq, unsigned int head,
> -			     void **ctx)
> +			     struct virtio_dma_head *dma, void **ctx)
>  {
>  	unsigned int i, j;
>  	__virtio16 nextflag = cpu_to_virtio16(vq->vq.vdev, VRING_DESC_F_NEXT);
> @@ -775,12 +815,12 @@ static void detach_buf_split(struct vring_virtqueue *vq, unsigned int head,
>  	i = head;
>  
>  	while (vq->split.vring.desc[i].flags & nextflag) {
> -		vring_unmap_one_split(vq, i);
> +		vring_unmap_one_split(vq, i, dma);
>  		i = vq->split.desc_extra[i].next;
>  		vq->vq.num_free++;
>  	}
>  
> -	vring_unmap_one_split(vq, i);
> +	vring_unmap_one_split(vq, i, dma);
>  	vq->split.desc_extra[i].next = vq->free_head;
>  	vq->free_head = head;
>  
> @@ -802,9 +842,9 @@ static void detach_buf_split(struct vring_virtqueue *vq, unsigned int head,
>  				VRING_DESC_F_INDIRECT));
>  		BUG_ON(len == 0 || len % sizeof(struct vring_desc));
>  
> -		if (vq->do_unmap) {
> +		if (vq->do_unmap || dma) {
>  			for (j = 0; j < len / sizeof(struct vring_desc); j++)
> -				vring_unmap_one_split_indirect(vq, &indir_desc[j]);
> +				vring_unmap_one_split_indirect(vq, &indir_desc[j], dma);
>  		}
>  
>  		kfree(indir_desc);
> @@ -822,6 +862,7 @@ static bool more_used_split(const struct vring_virtqueue *vq)
>  
>  static void *virtqueue_get_buf_ctx_split(struct virtqueue *_vq,
>  					 unsigned int *len,
> +					 struct virtio_dma_head *dma,
>  					 void **ctx)
>  {
>  	struct vring_virtqueue *vq = to_vvq(_vq);
> @@ -862,7 +903,7 @@ static void *virtqueue_get_buf_ctx_split(struct virtqueue *_vq,
>  
>  	/* detach_buf_split clears data, so grab it now. */
>  	ret = vq->split.desc_state[i].data;
> -	detach_buf_split(vq, i, ctx);
> +	detach_buf_split(vq, i, dma, ctx);
>  	vq->last_used_idx++;
>  	/* If we expect an interrupt for the next entry, tell host
>  	 * by writing event index and flush out the write before
> @@ -984,7 +1025,7 @@ static void *virtqueue_detach_unused_buf_split(struct virtqueue *_vq)
>  			continue;
>  		/* detach_buf_split clears data, so grab it now. */
>  		buf = vq->split.desc_state[i].data;
> -		detach_buf_split(vq, i, NULL);
> +		detach_buf_split(vq, i, NULL, NULL);
>  		vq->split.avail_idx_shadow--;
>  		vq->split.vring.avail->idx = cpu_to_virtio16(_vq->vdev,
>  				vq->split.avail_idx_shadow);
> @@ -1220,8 +1261,9 @@ static u16 packed_last_used(u16 last_used_idx)
>  	return last_used_idx & ~(-(1 << VRING_PACKED_EVENT_F_WRAP_CTR));
>  }
>  
> -static void vring_unmap_extra_packed(const struct vring_virtqueue *vq,
> -				     const struct vring_desc_extra *extra)
> +static void vring_unmap_extra_packed(struct vring_virtqueue *vq,
> +				     const struct vring_desc_extra *extra,
> +				     struct virtio_dma_head *dma)
>  {
>  	u16 flags;
>  
> @@ -1235,23 +1277,24 @@ static void vring_unmap_extra_packed(const struct vring_virtqueue *vq,
>  				 extra->addr, extra->len,
>  				 (flags & VRING_DESC_F_WRITE) ?
>  				 DMA_FROM_DEVICE : DMA_TO_DEVICE);
> -	} else {
> -		if (!vq->do_unmap)
> -			return;
> -
> -		dma_unmap_page(vring_dma_dev(vq),
> -			       extra->addr, extra->len,
> -			       (flags & VRING_DESC_F_WRITE) ?
> -			       DMA_FROM_DEVICE : DMA_TO_DEVICE);
> +		return;
>  	}
> +
> +	if (!vring_need_unmap(vq, dma, extra->addr, extra->len))
> +		return;
> +
> +	dma_unmap_page(vring_dma_dev(vq), extra->addr, extra->len,
> +		       (flags & VRING_DESC_F_WRITE) ?
> +		       DMA_FROM_DEVICE : DMA_TO_DEVICE);
>  }
>  
> -static void vring_unmap_desc_packed(const struct vring_virtqueue *vq,
> -				    const struct vring_packed_desc *desc)
> +static void vring_unmap_desc_packed(struct vring_virtqueue *vq,
> +				    const struct vring_packed_desc *desc,
> +				    struct virtio_dma_head *dma)
>  {
>  	u16 flags;
>  
> -	if (!vq->do_unmap)
> +	if (!vring_need_unmap(vq, dma, le64_to_cpu(desc->addr), le32_to_cpu(desc->len)))
>  		return;
>  
>  	flags = le16_to_cpu(desc->flags);
> @@ -1389,7 +1432,7 @@ static int virtqueue_add_indirect_packed(struct vring_virtqueue *vq,
>  	err_idx = i;
>  
>  	for (i = 0; i < err_idx; i++)
> -		vring_unmap_desc_packed(vq, &desc[i]);
> +		vring_unmap_desc_packed(vq, &desc[i], NULL);
>  
>  free_desc:
>  	kfree(desc);
> @@ -1539,7 +1582,7 @@ static inline int virtqueue_add_packed(struct virtqueue *_vq,
>  	for (n = 0; n < total_sg; n++) {
>  		if (i == err_idx)
>  			break;
> -		vring_unmap_extra_packed(vq, &vq->packed.desc_extra[curr]);
> +		vring_unmap_extra_packed(vq, &vq->packed.desc_extra[curr], NULL);
>  		curr = vq->packed.desc_extra[curr].next;
>  		i++;
>  		if (i >= vq->packed.vring.num)
> @@ -1600,7 +1643,9 @@ static bool virtqueue_kick_prepare_packed(struct virtqueue *_vq)
>  }
>  
>  static void detach_buf_packed(struct vring_virtqueue *vq,
> -			      unsigned int id, void **ctx)
> +			      unsigned int id,
> +			      struct virtio_dma_head *dma,
> +			      void **ctx)
>  {
>  	struct vring_desc_state_packed *state = NULL;
>  	struct vring_packed_desc *desc;
> @@ -1615,11 +1660,10 @@ static void detach_buf_packed(struct vring_virtqueue *vq,
>  	vq->free_head = id;
>  	vq->vq.num_free += state->num;
>  
> -	if (unlikely(vq->do_unmap)) {
> +	if (vq->do_unmap || dma) {
>  		curr = id;
>  		for (i = 0; i < state->num; i++) {
> -			vring_unmap_extra_packed(vq,
> -						 &vq->packed.desc_extra[curr]);
> +			vring_unmap_extra_packed(vq, &vq->packed.desc_extra[curr], dma);
>  			curr = vq->packed.desc_extra[curr].next;
>  		}
>  	}
> @@ -1632,11 +1676,11 @@ static void detach_buf_packed(struct vring_virtqueue *vq,
>  		if (!desc)
>  			return;
>  
> -		if (vq->do_unmap) {
> +		if (vq->do_unmap || dma) {
>  			len = vq->packed.desc_extra[id].len;
>  			for (i = 0; i < len / sizeof(struct vring_packed_desc);
>  					i++)
> -				vring_unmap_desc_packed(vq, &desc[i]);
> +				vring_unmap_desc_packed(vq, &desc[i], dma);
>  		}
>  		kfree(desc);
>  		state->indir_desc = NULL;
> @@ -1672,6 +1716,7 @@ static bool more_used_packed(const struct vring_virtqueue *vq)
>  
>  static void *virtqueue_get_buf_ctx_packed(struct virtqueue *_vq,
>  					  unsigned int *len,
> +					  struct virtio_dma_head *dma,
>  					  void **ctx)
>  {
>  	struct vring_virtqueue *vq = to_vvq(_vq);
> @@ -1712,7 +1757,7 @@ static void *virtqueue_get_buf_ctx_packed(struct virtqueue *_vq,
>  
>  	/* detach_buf_packed clears data, so grab it now. */
>  	ret = vq->packed.desc_state[id].data;
> -	detach_buf_packed(vq, id, ctx);
> +	detach_buf_packed(vq, id, dma, ctx);
>  
>  	last_used += vq->packed.desc_state[id].num;
>  	if (unlikely(last_used >= vq->packed.vring.num)) {
> @@ -1877,7 +1922,7 @@ static void *virtqueue_detach_unused_buf_packed(struct virtqueue *_vq)
>  			continue;
>  		/* detach_buf clears data, so grab it now. */
>  		buf = vq->packed.desc_state[i].data;
> -		detach_buf_packed(vq, i, NULL);
> +		detach_buf_packed(vq, i, NULL, NULL);
>  		END_USE(vq);
>  		return buf;
>  	}
> @@ -2417,11 +2462,44 @@ void *virtqueue_get_buf_ctx(struct virtqueue *_vq, unsigned int *len,
>  {
>  	struct vring_virtqueue *vq = to_vvq(_vq);
>  
> -	return vq->packed_ring ? virtqueue_get_buf_ctx_packed(_vq, len, ctx) :
> -				 virtqueue_get_buf_ctx_split(_vq, len, ctx);
> +	return vq->packed_ring ? virtqueue_get_buf_ctx_packed(_vq, len, NULL, ctx) :
> +				 virtqueue_get_buf_ctx_split(_vq, len, NULL, ctx);
>  }
>  EXPORT_SYMBOL_GPL(virtqueue_get_buf_ctx);
>  
> +/**
> + * virtqueue_get_buf_ctx_dma - get the next used buffer with the dma info
> + * @_vq: the struct virtqueue we're talking about.
> + * @len: the length written into the buffer
> + * @dma: the head of the array to store the dma info
> + * @ctx: extra context for the token
> + *
> + * If the device wrote data into the buffer, @len will be set to the
> + * amount written.  This means you don't need to clear the buffer
> + * beforehand to ensure there's no data leakage in the case of short
> + * writes.
> + *
> + * Caller must ensure we don't call this with other virtqueue
> + * operations at the same time (except where noted).
> + *
> + * We store the dma info of every descriptor of this buf to the dma->items
> + * array. If the array size is too small, some dma info may be missed, so
> + * the caller must ensure the array is large enough. The dma->next is the out
> + * value to the caller, indicates the num of the used items.

num -> number?
So next is the number of items? And num is what?
Can't we avoid hacks like this in APIs?

> + *
> + * Returns NULL if there are no used buffers, or the "data" token
> + * handed to virtqueue_add_*().
> + */
> +void *virtqueue_get_buf_ctx_dma(struct virtqueue *_vq, unsigned int *len,
> +				struct virtio_dma_head *dma, void **ctx)
> +{
> +	struct vring_virtqueue *vq = to_vvq(_vq);
> +
> +	return vq->packed_ring ? virtqueue_get_buf_ctx_packed(_vq, len, dma, ctx) :
> +				 virtqueue_get_buf_ctx_split(_vq, len, dma, ctx);
> +}
> +EXPORT_SYMBOL_GPL(virtqueue_get_buf_ctx_dma);
> +
>  void *virtqueue_get_buf(struct virtqueue *_vq, unsigned int *len)
>  {
>  	return virtqueue_get_buf_ctx(_vq, len, NULL);
> diff --git a/include/linux/virtio.h b/include/linux/virtio.h
> index 4cc614a38376..572aecec205b 100644
> --- a/include/linux/virtio.h
> +++ b/include/linux/virtio.h
> @@ -75,6 +75,22 @@ void *virtqueue_get_buf(struct virtqueue *vq, unsigned int *len);
>  void *virtqueue_get_buf_ctx(struct virtqueue *vq, unsigned int *len,
>  			    void **ctx);
>  
> +struct virtio_dma_item {
> +	dma_addr_t addr;
> +	unsigned int length;
> +};
> +
> +struct virtio_dma_head {
> +	/* total num of items. */
> +	u16 num;
> +	/* point to the next item to store dma info. */
> +	u16 next;

I'm not sure what this data structure is ... is it a linked list? A ring?
Pls document it.
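
E.g. something like this would already help (just a sketch of the
documentation, semantics taken from the commit log above):

/**
 * struct virtio_dma_head - flat array used to return DMA info
 * @num:   capacity of @items, set by the caller before the call
 * @next:  in: the first free slot (normally 0);
 *         out: the number of items that were filled in
 * @items: one entry per descriptor that the core did not unmap
 */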


> +	struct virtio_dma_item items[];
> +};
> +
> +void *virtqueue_get_buf_ctx_dma(struct virtqueue *_vq, unsigned int *len,
> +				struct virtio_dma_head *dma, void **ctx);
> +
>  void virtqueue_disable_cb(struct virtqueue *vq);
>  
>  bool virtqueue_enable_cb(struct virtqueue *vq);
> -- 
> 2.32.0.3.g01195cf9f

Patch

diff --git a/drivers/virtio/virtio_ring.c b/drivers/virtio/virtio_ring.c
index 49299b1f9ec7..82f72428605b 100644
--- a/drivers/virtio/virtio_ring.c
+++ b/drivers/virtio/virtio_ring.c
@@ -362,6 +362,45 @@  static struct device *vring_dma_dev(const struct vring_virtqueue *vq)
 	return vq->dma_dev;
 }
 
+/*
+ *     use_dma_api premapped -> do_unmap
+ *  1. false       false        false
+ *  2. true        false        true
+ *  3. true        true         false
+ *
+ * Only #3, we should return the DMA info to the driver.
+ *
+ * Return:
+ * true: the virtio core must unmap the desc
+ * false: the virtio core skip the desc unmap
+ */
+static bool vring_need_unmap(struct vring_virtqueue *vq,
+			     struct virtio_dma_head *dma,
+			     dma_addr_t addr, unsigned int length)
+{
+	if (vq->do_unmap)
+		return true;
+
+	if (!vq->premapped)
+		return false;
+
+	if (!dma)
+		return false;
+
+	if (unlikely(dma->next >= dma->num)) {
+		BAD_RING(vq, "premapped vq: collect dma overflow: %pad %u\n",
+			 &addr, length);
+		return false;
+	}
+
+	dma->items[dma->next].addr = addr;
+	dma->items[dma->next].length = length;
+
+	++dma->next;
+
+	return false;
+}
+
 /* Map one sg entry. */
 static int vring_map_one_sg(const struct vring_virtqueue *vq, struct scatterlist *sg,
 			    enum dma_data_direction direction, dma_addr_t *addr)
@@ -440,12 +479,14 @@  static void virtqueue_init(struct vring_virtqueue *vq, u32 num)
  * Split ring specific functions - *_split().
  */
 
-static void vring_unmap_one_split_indirect(const struct vring_virtqueue *vq,
-					   const struct vring_desc *desc)
+static void vring_unmap_one_split_indirect(struct vring_virtqueue *vq,
+					   const struct vring_desc *desc,
+					   struct virtio_dma_head *dma)
 {
 	u16 flags;
 
-	if (!vq->do_unmap)
+	if (!vring_need_unmap(vq, dma, virtio64_to_cpu(vq->vq.vdev, desc->addr),
+			  virtio32_to_cpu(vq->vq.vdev, desc->len)))
 		return;
 
 	flags = virtio16_to_cpu(vq->vq.vdev, desc->flags);
@@ -457,8 +498,8 @@  static void vring_unmap_one_split_indirect(const struct vring_virtqueue *vq,
 		       DMA_FROM_DEVICE : DMA_TO_DEVICE);
 }
 
-static unsigned int vring_unmap_one_split(const struct vring_virtqueue *vq,
-					  unsigned int i)
+static unsigned int vring_unmap_one_split(struct vring_virtqueue *vq,
+					  unsigned int i, struct virtio_dma_head *dma)
 {
 	struct vring_desc_extra *extra = vq->split.desc_extra;
 	u16 flags;
@@ -474,17 +515,16 @@  static unsigned int vring_unmap_one_split(const struct vring_virtqueue *vq,
 				 extra[i].len,
 				 (flags & VRING_DESC_F_WRITE) ?
 				 DMA_FROM_DEVICE : DMA_TO_DEVICE);
-	} else {
-		if (!vq->do_unmap)
-			goto out;
-
-		dma_unmap_page(vring_dma_dev(vq),
-			       extra[i].addr,
-			       extra[i].len,
-			       (flags & VRING_DESC_F_WRITE) ?
-			       DMA_FROM_DEVICE : DMA_TO_DEVICE);
+		goto out;
 	}
 
+	if (!vring_need_unmap(vq, dma, extra[i].addr, extra[i].len))
+		goto out;
+
+	dma_unmap_page(vring_dma_dev(vq), extra[i].addr, extra[i].len,
+		       (flags & VRING_DESC_F_WRITE) ?
+		       DMA_FROM_DEVICE : DMA_TO_DEVICE);
+
 out:
 	return extra[i].next;
 }
@@ -717,10 +757,10 @@  static inline int virtqueue_add_split(struct virtqueue *_vq,
 		if (i == err_idx)
 			break;
 		if (indirect) {
-			vring_unmap_one_split_indirect(vq, &desc[i]);
+			vring_unmap_one_split_indirect(vq, &desc[i], NULL);
 			i = virtio16_to_cpu(_vq->vdev, desc[i].next);
 		} else
-			i = vring_unmap_one_split(vq, i);
+			i = vring_unmap_one_split(vq, i, NULL);
 	}
 
 free_indirect:
@@ -763,7 +803,7 @@  static bool virtqueue_kick_prepare_split(struct virtqueue *_vq)
 }
 
 static void detach_buf_split(struct vring_virtqueue *vq, unsigned int head,
-			     void **ctx)
+			     struct virtio_dma_head *dma, void **ctx)
 {
 	unsigned int i, j;
 	__virtio16 nextflag = cpu_to_virtio16(vq->vq.vdev, VRING_DESC_F_NEXT);
@@ -775,12 +815,12 @@  static void detach_buf_split(struct vring_virtqueue *vq, unsigned int head,
 	i = head;
 
 	while (vq->split.vring.desc[i].flags & nextflag) {
-		vring_unmap_one_split(vq, i);
+		vring_unmap_one_split(vq, i, dma);
 		i = vq->split.desc_extra[i].next;
 		vq->vq.num_free++;
 	}
 
-	vring_unmap_one_split(vq, i);
+	vring_unmap_one_split(vq, i, dma);
 	vq->split.desc_extra[i].next = vq->free_head;
 	vq->free_head = head;
 
@@ -802,9 +842,9 @@  static void detach_buf_split(struct vring_virtqueue *vq, unsigned int head,
 				VRING_DESC_F_INDIRECT));
 		BUG_ON(len == 0 || len % sizeof(struct vring_desc));
 
-		if (vq->do_unmap) {
+		if (vq->do_unmap || dma) {
 			for (j = 0; j < len / sizeof(struct vring_desc); j++)
-				vring_unmap_one_split_indirect(vq, &indir_desc[j]);
+				vring_unmap_one_split_indirect(vq, &indir_desc[j], dma);
 		}
 
 		kfree(indir_desc);
@@ -822,6 +862,7 @@  static bool more_used_split(const struct vring_virtqueue *vq)
 
 static void *virtqueue_get_buf_ctx_split(struct virtqueue *_vq,
 					 unsigned int *len,
+					 struct virtio_dma_head *dma,
 					 void **ctx)
 {
 	struct vring_virtqueue *vq = to_vvq(_vq);
@@ -862,7 +903,7 @@  static void *virtqueue_get_buf_ctx_split(struct virtqueue *_vq,
 
 	/* detach_buf_split clears data, so grab it now. */
 	ret = vq->split.desc_state[i].data;
-	detach_buf_split(vq, i, ctx);
+	detach_buf_split(vq, i, dma, ctx);
 	vq->last_used_idx++;
 	/* If we expect an interrupt for the next entry, tell host
 	 * by writing event index and flush out the write before
@@ -984,7 +1025,7 @@  static void *virtqueue_detach_unused_buf_split(struct virtqueue *_vq)
 			continue;
 		/* detach_buf_split clears data, so grab it now. */
 		buf = vq->split.desc_state[i].data;
-		detach_buf_split(vq, i, NULL);
+		detach_buf_split(vq, i, NULL, NULL);
 		vq->split.avail_idx_shadow--;
 		vq->split.vring.avail->idx = cpu_to_virtio16(_vq->vdev,
 				vq->split.avail_idx_shadow);
@@ -1220,8 +1261,9 @@  static u16 packed_last_used(u16 last_used_idx)
 	return last_used_idx & ~(-(1 << VRING_PACKED_EVENT_F_WRAP_CTR));
 }
 
-static void vring_unmap_extra_packed(const struct vring_virtqueue *vq,
-				     const struct vring_desc_extra *extra)
+static void vring_unmap_extra_packed(struct vring_virtqueue *vq,
+				     const struct vring_desc_extra *extra,
+				     struct virtio_dma_head *dma)
 {
 	u16 flags;
 
@@ -1235,23 +1277,24 @@  static void vring_unmap_extra_packed(const struct vring_virtqueue *vq,
 				 extra->addr, extra->len,
 				 (flags & VRING_DESC_F_WRITE) ?
 				 DMA_FROM_DEVICE : DMA_TO_DEVICE);
-	} else {
-		if (!vq->do_unmap)
-			return;
-
-		dma_unmap_page(vring_dma_dev(vq),
-			       extra->addr, extra->len,
-			       (flags & VRING_DESC_F_WRITE) ?
-			       DMA_FROM_DEVICE : DMA_TO_DEVICE);
+		return;
 	}
+
+	if (!vring_need_unmap(vq, dma, extra->addr, extra->len))
+		return;
+
+	dma_unmap_page(vring_dma_dev(vq), extra->addr, extra->len,
+		       (flags & VRING_DESC_F_WRITE) ?
+		       DMA_FROM_DEVICE : DMA_TO_DEVICE);
 }
 
-static void vring_unmap_desc_packed(const struct vring_virtqueue *vq,
-				    const struct vring_packed_desc *desc)
+static void vring_unmap_desc_packed(struct vring_virtqueue *vq,
+				    const struct vring_packed_desc *desc,
+				    struct virtio_dma_head *dma)
 {
 	u16 flags;
 
-	if (!vq->do_unmap)
+	if (!vring_need_unmap(vq, dma, le64_to_cpu(desc->addr), le32_to_cpu(desc->len)))
 		return;
 
 	flags = le16_to_cpu(desc->flags);
@@ -1389,7 +1432,7 @@  static int virtqueue_add_indirect_packed(struct vring_virtqueue *vq,
 	err_idx = i;
 
 	for (i = 0; i < err_idx; i++)
-		vring_unmap_desc_packed(vq, &desc[i]);
+		vring_unmap_desc_packed(vq, &desc[i], NULL);
 
 free_desc:
 	kfree(desc);
@@ -1539,7 +1582,7 @@  static inline int virtqueue_add_packed(struct virtqueue *_vq,
 	for (n = 0; n < total_sg; n++) {
 		if (i == err_idx)
 			break;
-		vring_unmap_extra_packed(vq, &vq->packed.desc_extra[curr]);
+		vring_unmap_extra_packed(vq, &vq->packed.desc_extra[curr], NULL);
 		curr = vq->packed.desc_extra[curr].next;
 		i++;
 		if (i >= vq->packed.vring.num)
@@ -1600,7 +1643,9 @@  static bool virtqueue_kick_prepare_packed(struct virtqueue *_vq)
 }
 
 static void detach_buf_packed(struct vring_virtqueue *vq,
-			      unsigned int id, void **ctx)
+			      unsigned int id,
+			      struct virtio_dma_head *dma,
+			      void **ctx)
 {
 	struct vring_desc_state_packed *state = NULL;
 	struct vring_packed_desc *desc;
@@ -1615,11 +1660,10 @@  static void detach_buf_packed(struct vring_virtqueue *vq,
 	vq->free_head = id;
 	vq->vq.num_free += state->num;
 
-	if (unlikely(vq->do_unmap)) {
+	if (vq->do_unmap || dma) {
 		curr = id;
 		for (i = 0; i < state->num; i++) {
-			vring_unmap_extra_packed(vq,
-						 &vq->packed.desc_extra[curr]);
+			vring_unmap_extra_packed(vq, &vq->packed.desc_extra[curr], dma);
 			curr = vq->packed.desc_extra[curr].next;
 		}
 	}
@@ -1632,11 +1676,11 @@  static void detach_buf_packed(struct vring_virtqueue *vq,
 		if (!desc)
 			return;
 
-		if (vq->do_unmap) {
+		if (vq->do_unmap || dma) {
 			len = vq->packed.desc_extra[id].len;
 			for (i = 0; i < len / sizeof(struct vring_packed_desc);
 					i++)
-				vring_unmap_desc_packed(vq, &desc[i]);
+				vring_unmap_desc_packed(vq, &desc[i], dma);
 		}
 		kfree(desc);
 		state->indir_desc = NULL;
@@ -1672,6 +1716,7 @@  static bool more_used_packed(const struct vring_virtqueue *vq)
 
 static void *virtqueue_get_buf_ctx_packed(struct virtqueue *_vq,
 					  unsigned int *len,
+					  struct virtio_dma_head *dma,
 					  void **ctx)
 {
 	struct vring_virtqueue *vq = to_vvq(_vq);
@@ -1712,7 +1757,7 @@  static void *virtqueue_get_buf_ctx_packed(struct virtqueue *_vq,
 
 	/* detach_buf_packed clears data, so grab it now. */
 	ret = vq->packed.desc_state[id].data;
-	detach_buf_packed(vq, id, ctx);
+	detach_buf_packed(vq, id, dma, ctx);
 
 	last_used += vq->packed.desc_state[id].num;
 	if (unlikely(last_used >= vq->packed.vring.num)) {
@@ -1877,7 +1922,7 @@  static void *virtqueue_detach_unused_buf_packed(struct virtqueue *_vq)
 			continue;
 		/* detach_buf clears data, so grab it now. */
 		buf = vq->packed.desc_state[i].data;
-		detach_buf_packed(vq, i, NULL);
+		detach_buf_packed(vq, i, NULL, NULL);
 		END_USE(vq);
 		return buf;
 	}
@@ -2417,11 +2462,44 @@  void *virtqueue_get_buf_ctx(struct virtqueue *_vq, unsigned int *len,
 {
 	struct vring_virtqueue *vq = to_vvq(_vq);
 
-	return vq->packed_ring ? virtqueue_get_buf_ctx_packed(_vq, len, ctx) :
-				 virtqueue_get_buf_ctx_split(_vq, len, ctx);
+	return vq->packed_ring ? virtqueue_get_buf_ctx_packed(_vq, len, NULL, ctx) :
+				 virtqueue_get_buf_ctx_split(_vq, len, NULL, ctx);
 }
 EXPORT_SYMBOL_GPL(virtqueue_get_buf_ctx);
 
+/**
+ * virtqueue_get_buf_ctx_dma - get the next used buffer with the dma info
+ * @_vq: the struct virtqueue we're talking about.
+ * @len: the length written into the buffer
+ * @dma: the head of the array to store the dma info
+ * @ctx: extra context for the token
+ *
+ * If the device wrote data into the buffer, @len will be set to the
+ * amount written.  This means you don't need to clear the buffer
+ * beforehand to ensure there's no data leakage in the case of short
+ * writes.
+ *
+ * Caller must ensure we don't call this with other virtqueue
+ * operations at the same time (except where noted).
+ *
+ * We store the dma info of every descriptor of this buf to the dma->items
+ * array. If the array size is too small, some dma info may be missed, so
+ * the caller must ensure the array is large enough. The dma->next is the out
+ * value to the caller, indicates the num of the used items.
+ *
+ * Returns NULL if there are no used buffers, or the "data" token
+ * handed to virtqueue_add_*().
+ */
+void *virtqueue_get_buf_ctx_dma(struct virtqueue *_vq, unsigned int *len,
+				struct virtio_dma_head *dma, void **ctx)
+{
+	struct vring_virtqueue *vq = to_vvq(_vq);
+
+	return vq->packed_ring ? virtqueue_get_buf_ctx_packed(_vq, len, dma, ctx) :
+				 virtqueue_get_buf_ctx_split(_vq, len, dma, ctx);
+}
+EXPORT_SYMBOL_GPL(virtqueue_get_buf_ctx_dma);
+
 void *virtqueue_get_buf(struct virtqueue *_vq, unsigned int *len)
 {
 	return virtqueue_get_buf_ctx(_vq, len, NULL);
diff --git a/include/linux/virtio.h b/include/linux/virtio.h
index 4cc614a38376..572aecec205b 100644
--- a/include/linux/virtio.h
+++ b/include/linux/virtio.h
@@ -75,6 +75,22 @@  void *virtqueue_get_buf(struct virtqueue *vq, unsigned int *len);
 void *virtqueue_get_buf_ctx(struct virtqueue *vq, unsigned int *len,
 			    void **ctx);
 
+struct virtio_dma_item {
+	dma_addr_t addr;
+	unsigned int length;
+};
+
+struct virtio_dma_head {
+	/* total num of items. */
+	u16 num;
+	/* point to the next item to store dma info. */
+	u16 next;
+	struct virtio_dma_item items[];
+};
+
+void *virtqueue_get_buf_ctx_dma(struct virtqueue *_vq, unsigned int *len,
+				struct virtio_dma_head *dma, void **ctx);
+
 void virtqueue_disable_cb(struct virtqueue *vq);
 
 bool virtqueue_enable_cb(struct virtqueue *vq);
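
For reference, a minimal sketch of how a driver might consume the new API.
This is illustrative only: drv_poll_tx(), "dev" and the capacity of 16 are
made up, and the DMA_TO_DEVICE direction assumes a send queue whose buffers
the driver mapped with dma_map_single().

static void drv_poll_tx(struct device *dev, struct virtqueue *vq)
{
	struct virtio_dma_head *dma;
	unsigned int len;
	void *buf;
	u16 i;

	/* Room for the DMA info of up to 16 descriptors per buffer. */
	dma = kmalloc(struct_size(dma, items, 16), GFP_ATOMIC);
	if (!dma)
		return;
	dma->num = 16;

	while (true) {
		/* The core appends at dma->next, so reset it per buffer. */
		dma->next = 0;
		buf = virtqueue_get_buf_ctx_dma(vq, &len, dma, NULL);
		if (!buf)
			break;

		/* dma->next now holds the number of items filled in. */
		for (i = 0; i < dma->next; i++)
			dma_unmap_single(dev, dma->items[i].addr,
					 dma->items[i].length, DMA_TO_DEVICE);

		/* ... recycle or free "buf" here ... */
	}

	kfree(dma);
}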