[vhost,v7,01/11] virtio_ring: split: separate dma codes

Message ID: 20230425073613.8839-2-xuanzhuo@linux.alibaba.com (mailing list archive)
State: Not Applicable
Series: virtio core prepares for AF_XDP

Checks

netdev/tree_selection: success (Not a local patch)
Commit Message

Xuan Zhuo April 25, 2023, 7:36 a.m. UTC
DMA-related logic is separated out of virtqueue_add_split() into a
new function. If use_dma_api is true, the DMA address is saved in
sg->dma_address and virtqueue_add_split() uses it directly, which
also makes the unmap operation simpler.

The purpose of this is to make it easier to later support receiving
DMA addresses that are pre-mapped by the drivers.

Signed-off-by: Xuan Zhuo <xuanzhuo@linux.alibaba.com>
---
 drivers/virtio/virtio_ring.c | 122 +++++++++++++++++++++++++++--------
 1 file changed, 94 insertions(+), 28 deletions(-)
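
As a hedged illustration of where this is heading (not part of this
patch; dev, buf and len are assumed driver-side names), a driver could
eventually map a buffer itself and hand the core an sg that already
carries the DMA address:

/* Hypothetical driver-side sketch: the driver performs the DMA
 * mapping itself and stores the result in sg->dma_address before
 * submitting the buffer, so the virtio core can skip its own
 * mapping step and read the address via vring_sg_address().
 */
struct scatterlist sg;

sg_init_table(&sg, 1);
sg.dma_address = dma_map_single(dev, buf, len, DMA_TO_DEVICE);
if (dma_mapping_error(dev, sg.dma_address))
	return -ENOMEM;
sg.length = len;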

Comments

Christoph Hellwig May 1, 2023, 4:24 a.m. UTC | #1
> +static dma_addr_t vring_sg_address(struct scatterlist *sg)
> +{
> +	if (sg->dma_address)
> +		return sg->dma_address;

0 is a perfectly valid DMA address.  So I have no idea how this is
even supposed to work.
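
To spell out the failure mode (a sketch under the DMA API's rules,
where DMA_MAPPING_ERROR, not 0, signals a failed mapping; dev, page
and len are assumed names):

/* A successful mapping may legitimately yield address 0.  The
 * patch's "if (sg->dma_address)" test then misreads such a valid
 * mapping as "not mapped" and falls back to sg_phys(sg).
 */
dma_addr_t addr = dma_map_page(dev, page, 0, len, DMA_TO_DEVICE);

if (!dma_mapping_error(dev, addr) && addr == 0) {
	/* vring_sg_address() would wrongly return sg_phys(sg) here */
}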
Xuan Zhuo May 6, 2023, 3:08 a.m. UTC | #2
On Sun, 30 Apr 2023 21:24:52 -0700, Christoph Hellwig <hch@infradead.org> wrote:
> > +static dma_addr_t vring_sg_address(struct scatterlist *sg)
> > +{
> > +	if (sg->dma_address)
> > +		return sg->dma_address;
>
> 0 is a perfectly valid DMA address.  So I have no idea how this is
> even supposed to work.


Yes.

I think I can change this to decide based on sg_page(sg): if there is
no page, I can assume the sg already carries a DMA address.

Thanks
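
A minimal sketch of that direction, assuming the check lands in
virtqueue_map_sgs() (names follow this patch; the final form in a
later revision may differ):

static int virtqueue_map_sgs(struct vring_virtqueue *vq,
			     struct scatterlist *sgs[],
			     unsigned int total_sg,
			     unsigned int out_sgs,
			     unsigned int in_sgs)
{
	struct scatterlist *sg;
	unsigned int n;

	if (!vq->use_dma_api)
		return 0;

	for (n = 0; n < out_sgs + in_sgs; n++) {
		for (sg = sgs[n]; sg; sg = sg_next(sg)) {
			dma_addr_t addr;

			/* No backing page: assume the driver premapped
			 * this sg and already set sg->dma_address.
			 */
			if (!sg_page(sg))
				continue;

			addr = vring_map_one_sg(vq, sg,
						n < out_sgs ?
						DMA_TO_DEVICE :
						DMA_FROM_DEVICE);
			if (vring_mapping_error(vq, addr))
				goto err;

			sg->dma_address = addr;
		}
	}

	return 0;

err:
	virtqueue_unmap_sgs(vq, sgs, total_sg, out_sgs, in_sgs);
	return -ENOMEM;
}

virtqueue_unmap_sgs() would need the same sg_page() test, so that the
core never unmaps an address it did not map itself.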
Patch

diff --git a/drivers/virtio/virtio_ring.c b/drivers/virtio/virtio_ring.c
index c5310eaf8b46..b073a70c1291 100644
--- a/drivers/virtio/virtio_ring.c
+++ b/drivers/virtio/virtio_ring.c
@@ -379,6 +379,14 @@  static dma_addr_t vring_map_one_sg(const struct vring_virtqueue *vq,
 			    direction);
 }
 
+static dma_addr_t vring_sg_address(struct scatterlist *sg)
+{
+	if (sg->dma_address)
+		return sg->dma_address;
+
+	return (dma_addr_t)sg_phys(sg);
+}
+
 static dma_addr_t vring_map_single(const struct vring_virtqueue *vq,
 				   void *cpu_addr, size_t size,
 				   enum dma_data_direction direction)
@@ -520,6 +528,80 @@  static inline unsigned int virtqueue_add_desc_split(struct virtqueue *vq,
 	return next;
 }
 
+static void virtqueue_unmap_sgs(struct vring_virtqueue *vq,
+				struct scatterlist *sgs[],
+				unsigned int total_sg,
+				unsigned int out_sgs,
+				unsigned int in_sgs)
+{
+	struct scatterlist *sg;
+	unsigned int n;
+
+	if (!vq->use_dma_api)
+		return;
+
+	for (n = 0; n < out_sgs; n++) {
+		for (sg = sgs[n]; sg; sg = sg_next(sg)) {
+			if (!sg->dma_address)
+				return;
+
+			dma_unmap_page(vring_dma_dev(vq), sg->dma_address,
+				       sg->length, DMA_TO_DEVICE);
+		}
+	}
+
+	for (; n < (out_sgs + in_sgs); n++) {
+		for (sg = sgs[n]; sg; sg = sg_next(sg)) {
+			if (!sg->dma_address)
+				return;
+
+			dma_unmap_page(vring_dma_dev(vq), sg->dma_address,
+				       sg->length, DMA_FROM_DEVICE);
+		}
+	}
+}
+
+static int virtqueue_map_sgs(struct vring_virtqueue *vq,
+			     struct scatterlist *sgs[],
+			     unsigned int total_sg,
+			     unsigned int out_sgs,
+			     unsigned int in_sgs)
+{
+	struct scatterlist *sg;
+	unsigned int n;
+
+	if (!vq->use_dma_api)
+		return 0;
+
+	for (n = 0; n < out_sgs; n++) {
+		for (sg = sgs[n]; sg; sg = sg_next(sg)) {
+			dma_addr_t addr = vring_map_one_sg(vq, sg, DMA_TO_DEVICE);
+
+			if (vring_mapping_error(vq, addr))
+				goto err;
+
+			sg->dma_address = addr;
+		}
+	}
+
+	for (; n < (out_sgs + in_sgs); n++) {
+		for (sg = sgs[n]; sg; sg = sg_next(sg)) {
+			dma_addr_t addr = vring_map_one_sg(vq, sg, DMA_FROM_DEVICE);
+
+			if (vring_mapping_error(vq, addr))
+				goto err;
+
+			sg->dma_address = addr;
+		}
+	}
+
+	return 0;
+
+err:
+	virtqueue_unmap_sgs(vq, sgs, total_sg, out_sgs, in_sgs);
+	return -ENOMEM;
+}
+
 static inline int virtqueue_add_split(struct virtqueue *_vq,
 				      struct scatterlist *sgs[],
 				      unsigned int total_sg,
@@ -532,9 +614,9 @@  static inline int virtqueue_add_split(struct virtqueue *_vq,
 	struct vring_virtqueue *vq = to_vvq(_vq);
 	struct scatterlist *sg;
 	struct vring_desc *desc;
-	unsigned int i, n, avail, descs_used, prev, err_idx;
-	int head;
+	unsigned int i, n, avail, descs_used, prev;
 	bool indirect;
+	int head;
 
 	START_USE(vq);
 
@@ -586,32 +668,30 @@  static inline int virtqueue_add_split(struct virtqueue *_vq,
 		return -ENOSPC;
 	}
 
+	if (virtqueue_map_sgs(vq, sgs, total_sg, out_sgs, in_sgs))
+		goto err_map;
+
 	for (n = 0; n < out_sgs; n++) {
 		for (sg = sgs[n]; sg; sg = sg_next(sg)) {
-			dma_addr_t addr = vring_map_one_sg(vq, sg, DMA_TO_DEVICE);
-			if (vring_mapping_error(vq, addr))
-				goto unmap_release;
-
 			prev = i;
 			/* Note that we trust indirect descriptor
 			 * table since it use stream DMA mapping.
 			 */
-			i = virtqueue_add_desc_split(_vq, desc, i, addr, sg->length,
+			i = virtqueue_add_desc_split(_vq, desc, i,
+						     vring_sg_address(sg),
+						     sg->length,
 						     VRING_DESC_F_NEXT,
 						     indirect);
 		}
 	}
 	for (; n < (out_sgs + in_sgs); n++) {
 		for (sg = sgs[n]; sg; sg = sg_next(sg)) {
-			dma_addr_t addr = vring_map_one_sg(vq, sg, DMA_FROM_DEVICE);
-			if (vring_mapping_error(vq, addr))
-				goto unmap_release;
-
 			prev = i;
 			/* Note that we trust indirect descriptor
 			 * table since it use stream DMA mapping.
 			 */
-			i = virtqueue_add_desc_split(_vq, desc, i, addr,
+			i = virtqueue_add_desc_split(_vq, desc, i,
+						     vring_sg_address(sg),
 						     sg->length,
 						     VRING_DESC_F_NEXT |
 						     VRING_DESC_F_WRITE,
@@ -679,23 +759,9 @@  static inline int virtqueue_add_split(struct virtqueue *_vq,
 	return 0;
 
 unmap_release:
-	err_idx = i;
-
-	if (indirect)
-		i = 0;
-	else
-		i = head;
-
-	for (n = 0; n < total_sg; n++) {
-		if (i == err_idx)
-			break;
-		if (indirect) {
-			vring_unmap_one_split_indirect(vq, &desc[i]);
-			i = virtio16_to_cpu(_vq->vdev, desc[i].next);
-		} else
-			i = vring_unmap_one_split(vq, i);
-	}
+	virtqueue_unmap_sgs(vq, sgs, total_sg, out_sgs, in_sgs);
 
+err_map:
 	if (indirect)
 		kfree(desc);