diff mbox series

[RFC,3/5] RDMA/virtio-rdma: VirtIO RDMA test module

Message ID 20210902130625.25277-4-weijunji@bytedance.com (mailing list archive)
State RFC
Headers show
Series VirtIO RDMA | expand

Commit Message

Junji Wei Sept. 2, 2021, 1:06 p.m. UTC
This is a test module for virtio-rdma. It works with the
rc_pingpong server included in rdma-core.

Signed-off-by: Junji Wei <weijunji@bytedance.com>
---
 drivers/infiniband/hw/virtio/Makefile              |   1 +
 .../hw/virtio/virtio_rdma_rc_pingpong_client.c     | 477 +++++++++++++++++++++
 2 files changed, 478 insertions(+)
 create mode 100644 drivers/infiniband/hw/virtio/virtio_rdma_rc_pingpong_client.c
diff mbox series

Patch

diff --git a/drivers/infiniband/hw/virtio/Makefile b/drivers/infiniband/hw/virtio/Makefile
index fb637e467167..eb72a0aa48f3 100644
--- a/drivers/infiniband/hw/virtio/Makefile
+++ b/drivers/infiniband/hw/virtio/Makefile
@@ -1,4 +1,5 @@ 
 obj-$(CONFIG_INFINIBAND_VIRTIO_RDMA) += virtio_rdma.o
+obj-m += virtio_rdma_rc_pingpong_client.o
 
 virtio_rdma-y := virtio_rdma_main.o virtio_rdma_device.o virtio_rdma_ib.o \
 		 virtio_rdma_netdev.o
diff --git a/drivers/infiniband/hw/virtio/virtio_rdma_rc_pingpong_client.c b/drivers/infiniband/hw/virtio/virtio_rdma_rc_pingpong_client.c
new file mode 100644
index 000000000000..d1a38fe8f8cd
--- /dev/null
+++ b/drivers/infiniband/hw/virtio/virtio_rdma_rc_pingpong_client.c
@@ -0,0 +1,477 @@ 
+/*
+ * Virtio RDMA device: Test client
+ *
+ * Copyright (C) 2021 Junji Wei Bytedance Inc.
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation; either version 2 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301 USA
+ */
+
+#include <linux/err.h>
+#include <linux/fs.h>
+#include <linux/in.h>
+#include <linux/inet.h>
+#include <linux/init.h>
+#include <linux/jiffies.h>
+#include <linux/kernel.h>
+#include <linux/ktime.h>
+#include <linux/math64.h>
+#include <linux/module.h>
+#include <linux/net.h>
+#include <linux/random.h>
+#include <linux/scatterlist.h>
+#include <linux/sched.h>
+#include <linux/slab.h>
+#include <linux/socket.h>
+#include <net/sock.h>
+
+#include <asm/dma.h>
+
+#include <rdma/ib_verbs.h>
+#include <rdma/ib_cache.h>
+#include "../../core/uverbs.h"
+
+MODULE_LICENSE("GPL");	/* consistent with the GPL notice in the file header */
+MODULE_AUTHOR("Junji Wei");
+MODULE_DESCRIPTION("Virtio rdma test module");
+MODULE_VERSION("0.01");
+
+#define SERVER_ADDR "10.131.251.125"	/* hard-coded IPv4 address the client connects to */
+#define SERVER_PORT 18515	/* TCP port used for the address exchange */
+
+#define RX_DEPTH 500	/* first two values of the QP capability initializer */
+#define ITER 500	/* receives pre-posted and sends issued by the test */
+#define PAGES 5	/* pages per DMA buffer; also the sg limit passed to ib_alloc_mr() */
+
+/*
+ * Addressing information describing one end of the ping-pong connection.
+ *
+ * Only lid, qpn, psn and gid are formatted into / parsed from the text
+ * message exchanged over TCP; the remaining members are filled in locally
+ * by the test and never transmitted.
+ */
+struct pingpong_dest {
+	int lid;			/* port LID */
+	int out_reads;
+	int qpn;			/* queue pair number */
+	int psn;			/* initial packet sequence number */
+	unsigned int rkey;
+	unsigned long long vaddr;
+	union ib_gid gid;		/* port GID */
+	unsigned int srqn;
+	int gid_index;			/* index used as sgid_index for the GRH */
+};
+
+/*
+ * open_dev - look up the ib_device behind a uverbs character device node.
+ * @path: path of the uverbs device node to open
+ *
+ * Opens @path, takes the ib_uverbs_file stored in the file's private data
+ * and returns the ib_device it points to.  As a smoke test the function
+ * also queries port 1 and logs its GID table length; a failing query is
+ * logged but does not fail the lookup.
+ *
+ * On success the opened file is deliberately left open, exactly as before.
+ * NOTE(review): nothing ever closes it, so it stays open for the lifetime
+ * of the kernel; confirm whether a reference on the device should be taken
+ * instead.
+ *
+ * Return: the ib_device on success, NULL on failure.
+ */
+static struct ib_device *open_dev(const char *path)
+{
+	struct ib_uverbs_file *file;
+	struct ib_port_attr port_attr;
+	struct ib_device *ib_dev;
+	struct file *filp;
+	int rc;
+
+	/* filp_open() reports failure through ERR_PTR(), never through NULL. */
+	filp = filp_open(path, O_RDWR | O_CLOEXEC, 0);
+	if (IS_ERR(filp)) {
+		pr_err("Open %s failed: %ld\n", path, PTR_ERR(filp));
+		return NULL;
+	}
+
+	file = filp->private_data;
+	if (!file || !file->device || !file->device->ib_dev) {
+		pr_err("Get ib_dev failed\n");
+		filp_close(filp, NULL);
+		return NULL;
+	}
+	ib_dev = file->device->ib_dev;
+
+	/* node_desc is a fixed-size array; bound the print to its size. */
+	pr_info("Open ib_device %.*s\n", (int)sizeof(ib_dev->node_desc),
+		ib_dev->node_desc);
+
+	/* test query_port */
+	rc = ib_query_port(ib_dev, 1, &port_attr);
+	if (rc)
+		pr_err("Query port failed\n");
+	else
+		pr_info("Port gid_tbl_len %d\n", port_attr.gid_tbl_len);
+
+	return ib_dev;
+}
+
+/*
+ * ethernet_client_connect - open a TCP connection to the ping-pong server.
+ *
+ * Creates a kernel TCP socket in init_net and connects it to
+ * SERVER_ADDR:SERVER_PORT.
+ *
+ * Return: the connected socket, which the caller must hand to
+ * sock_release() when done, or NULL on failure.
+ */
+static struct socket *ethernet_client_connect(void)
+{
+	struct sockaddr_in s_addr = {
+		.sin_family = AF_INET,
+		.sin_port = htons(SERVER_PORT),
+		.sin_addr.s_addr = in_aton(SERVER_ADDR),
+	};
+	struct socket *sock = NULL;
+	int ret;
+
+	/*
+	 * sock_create_kern() allocates the socket itself; pre-allocating one
+	 * with kmalloc() only leaked that allocation.
+	 */
+	ret = sock_create_kern(&init_net, AF_INET, SOCK_STREAM, 0, &sock);
+	if (ret < 0) {
+		pr_err("client: socket create error\n");
+		return NULL;
+	}
+	pr_info("client: socket create ok\n");
+
+	/* connect server */
+	ret = sock->ops->connect(sock, (struct sockaddr *)&s_addr,
+				 sizeof(s_addr), 0);
+	if (ret) {
+		pr_err("client: connect error\n");
+		sock_release(sock);
+		return NULL;
+	}
+	pr_info("client: connect ok\n");
+
+	return sock;
+}
+
+/*
+ * ethernet_read_data - receive exactly @size bytes from @sock into @buf.
+ * @sock: connected socket
+ * @buf:  destination buffer of at least @size bytes
+ * @size: number of bytes to receive
+ *
+ * The socket is a TCP stream, so a plain receive may return fewer bytes
+ * than requested even though the peer sent the whole message.  The caller
+ * compares the result against the full message size, hence MSG_WAITALL is
+ * used to block until @size bytes arrived, the peer closed the connection
+ * or an error occurred.
+ *
+ * Return: number of bytes received (may be short on EOF) or a negative
+ * errno.
+ */
+static int ethernet_read_data(struct socket *sock, char *buf, int size)
+{
+	struct msghdr msg;
+	struct kvec vec;
+	int ret;
+
+	memset(&msg, 0, sizeof(msg));
+	memset(&vec, 0, sizeof(vec));
+	vec.iov_base = buf;
+	vec.iov_len = size;
+
+	ret = kernel_recvmsg(sock, &msg, &vec, 1, size, MSG_WAITALL);
+	if (ret < 0)
+		pr_err("read failed\n");
+
+	return ret;
+}
+
+/*
+ * ethernet_write_data - send @size bytes from @buf over @sock.
+ * @sock: connected socket
+ * @buf:  data to send; it is only read
+ * @size: number of bytes to send
+ *
+ * Return: number of bytes sent, or a negative errno.  A short write is
+ * logged and reported through the return value so the caller can compare
+ * it against @size.
+ */
+static int ethernet_write_data(struct socket *sock, const char *buf, int size)
+{
+	struct msghdr msg;
+	struct kvec vec;
+	int ret;
+
+	memset(&msg, 0, sizeof(msg));
+	/* struct kvec has no const variant; the data is not modified. */
+	vec.iov_base = (void *)buf;
+	vec.iov_len = size;
+
+	ret = kernel_sendmsg(sock, &msg, &vec, 1, size);
+	if (ret < 0) {
+		pr_err("kernel_sendmsg error\n");
+		return ret;
+	}
+	if (ret != size) {
+		/* Do not claim success when only part of the data went out. */
+		pr_info("write ret != size\n");
+		return ret;
+	}
+
+	pr_info("send success\n");
+	return ret;
+}
+
+/*
+ * gid_to_wire_gid - format a GID as the hex string used on the wire.
+ * @gid:  GID to convert
+ * @wgid: output buffer; receives 32 hex digits plus the terminating NUL
+ *        written by sprintf(), so it must hold at least 33 bytes
+ *
+ * The GID is handled as four 32-bit words, each passed through
+ * cpu_to_be32() and printed as eight hex digits.
+ */
+static void gid_to_wire_gid(const union ib_gid *gid, char wgid[])
+{
+	uint32_t words[4];
+	char *out = wgid;
+	int idx;
+
+	memcpy(words, gid, sizeof(words));
+	for (idx = 0; idx < 4; idx++, out += 8)
+		sprintf(out, "%08x", cpu_to_be32(words[idx]));
+}
+
+/*
+ * wire_gid_to_gid - parse the wire hex representation of a GID.
+ * @wgid: at least 32 readable characters holding the hex encoded GID
+ * @gid:  output GID
+ *
+ * Inverse of gid_to_wire_gid(): every group of eight hex digits is parsed
+ * as one 32-bit word and passed through be32_to_cpu().  A group that does
+ * not parse as hex yields a zero word instead of an uninitialised value.
+ *
+ * The function is file-local: it has no prototype in any header and is
+ * only used by this test module.
+ */
+static void wire_gid_to_gid(const char *wgid, union ib_gid *gid)
+{
+	uint32_t tmp_gid[4];
+	char tmp[9];
+	int i;
+
+	tmp[8] = '\0';
+	for (i = 0; i < 4; ++i) {
+		unsigned int v32 = 0;
+
+		memcpy(tmp, wgid + i * 8, 8);
+		if (sscanf(tmp, "%x", &v32) != 1)
+			v32 = 0;
+		tmp_gid[i] = be32_to_cpu((__force __be32)v32);
+	}
+	memcpy(gid, tmp_gid, sizeof(*gid));
+}
+
+/*
+ * pp_client_exch_dest - exchange connection parameters with the server.
+ * @my_dest: local parameters; lid, qpn, psn and gid are sent
+ *
+ * Connects to the server over TCP, sends the local parameters as a
+ * "lid:qpn:psn:gid" hex string, reads the server's parameters in the same
+ * format and acknowledges them with "done".
+ *
+ * The reply comes from the network and is treated as untrusted: it is
+ * forcibly NUL terminated, the GID field is length limited to the size of
+ * the local buffer, and a reply that does not contain all four fields is
+ * rejected.
+ *
+ * Return: a zero-initialised, kmalloc-family allocated pingpong_dest with
+ * lid, qpn, psn and gid filled in (the caller frees it with kfree()), or
+ * NULL on failure.  The socket is released on every path.
+ */
+static struct pingpong_dest *pp_client_exch_dest(const struct pingpong_dest *my_dest)
+{
+	char msg[sizeof "0000:000000:000000:00000000000000000000000000000000"];
+	struct pingpong_dest *rem_dest = NULL;
+	unsigned int lid, qpn, psn;
+	struct socket *sock;
+	char gid[33];
+
+	sock = ethernet_client_connect();
+	if (!sock)
+		return NULL;
+
+	gid_to_wire_gid(&my_dest->gid, gid);
+	/* Bounded: a value wider than its field must not overrun msg. */
+	snprintf(msg, sizeof(msg), "%04x:%06x:%06x:%s", my_dest->lid,
+		 my_dest->qpn, my_dest->psn, gid);
+	pr_info("Local %s\n", msg);
+	if (ethernet_write_data(sock, msg, sizeof msg) != sizeof msg) {
+		pr_err("Couldn't send local address\n");
+		goto out;
+	}
+
+	if (ethernet_read_data(sock, msg, sizeof msg) != sizeof msg ||
+	    ethernet_write_data(sock, "done", sizeof "done") != sizeof "done") {
+		pr_err("Couldn't read/write remote address\n");
+		goto out;
+	}
+	/* The peer is not guaranteed to have sent a terminated string. */
+	msg[sizeof(msg) - 1] = '\0';
+	pr_info("Remote %s\n", msg);
+
+	/*
+	 * wire_gid_to_gid() reads exactly 32 characters, so anything but a
+	 * full-length GID field is a malformed reply.
+	 */
+	if (sscanf(msg, "%x:%x:%x:%32s", &lid, &qpn, &psn, gid) != 4 ||
+	    strlen(gid) != 32) {
+		pr_err("Couldn't parse remote address\n");
+		goto out;
+	}
+
+	rem_dest = kzalloc(sizeof(*rem_dest), GFP_KERNEL);
+	if (!rem_dest)
+		goto out;
+
+	rem_dest->lid = lid;
+	rem_dest->qpn = qpn;
+	rem_dest->psn = psn;
+	wire_gid_to_gid(gid, &rem_dest->gid);
+
+out:
+	sock_release(sock);
+	return rem_dest;
+}
+
+static int __init rdma_test_init(void) {
+    struct ib_device* ib_dev;
+    struct ib_pd* pd;
+    struct ib_mr *mr, *mr_recv;
+    uint64_t dma_addr, dma_addr_recv;
+    struct scatterlist sg;
+    struct scatterlist sgr;
+    const struct ib_cq_init_attr cq_attr = { 64, 0, 0 };
+    struct ib_cq *cq;
+    struct ib_qp *qp;
+    struct ib_qp_init_attr qp_init_attr = {
+        .event_handler = NULL,
+        .qp_context = NULL,
+        .srq = NULL,
+        .xrcd = NULL,
+        .cap = {
+            RX_DEPTH, RX_DEPTH, 1, 1, -1, 0
+        },
+        .sq_sig_type = IB_SIGNAL_ALL_WR,
+        .qp_type = IB_QPT_RC,
+        .create_flags = 0,
+        .port_num = 0,
+        .rwq_ind_tbl = NULL,
+        .source_qpn = 0
+    };
+    struct ib_qp_attr qp_attr = {};
+    struct ib_port_attr port_attr;
+    struct pingpong_dest my_dest;
+    struct pingpong_dest *rem_dest;
+    int mask, rand_num, iter;
+    struct ib_rdma_wr swr;
+    const struct ib_send_wr *bad_swr;
+    struct ib_recv_wr rwr;
+    const struct ib_recv_wr *bad_rwr;
+    struct ib_sge wsge[1], rsge[1];
+    uint64_t *addr_send, *addr_recv;
+    int i, wc_got;
+    struct ib_wc wc[2];
+    struct ib_reg_wr reg_wr;
+
+    ktime_t t0;
+    uint64_t rt;
+    int wc_total = 0;
+
+    pr_info("Start rdma test\n");
+    pr_info("Normal address: 0x%lu -- 0x%px\n", MAX_DMA_ADDRESS, high_memory);
+    
+    ib_dev = open_dev("/dev/infiniband/uverbs0");
+
+    pd = ib_alloc_pd(ib_dev, 0);
+    if (!pd) {
+        pr_err("alloc_pd failed\n");
+        return -ENOMEM;
+    }
+
+    mr = ib_alloc_mr(pd, IB_MR_TYPE_MEM_REG, PAGES);
+    mr_recv = ib_alloc_mr(pd, IB_MR_TYPE_MEM_REG, PAGES);
+    if (!mr || !mr_recv) {
+        pr_err("alloc_mr failed\n");
+        return -EIO;
+    }
+
+    addr_send = ib_dma_alloc_coherent(ib_dev, PAGE_SIZE * PAGES, &dma_addr, GFP_KERNEL);
+    memset((char*)addr_send, '?', 4096 * PAGES);
+    sg_dma_address(&sg) = dma_addr;
+	sg_dma_len(&sg) = PAGE_SIZE * PAGES;
+    ib_map_mr_sg(mr, &sg, 1, NULL, PAGE_SIZE);
+
+    addr_recv = ib_dma_alloc_coherent(ib_dev, PAGE_SIZE * PAGES, &dma_addr_recv, GFP_KERNEL);
+    sg_dma_address(&sgr) = dma_addr_recv;
+	sg_dma_len(&sgr) = PAGE_SIZE * PAGES;
+    ib_map_mr_sg(mr_recv, &sgr, 1, NULL, PAGE_SIZE);
+
+    memset((char*)addr_recv, 'x', 4096 * PAGES);
+    strcpy((char*)addr_recv, "hello world");
+    pr_info("Before %s\n", (char*)addr_send);
+    pr_info("Before %s\n", (char*)addr_recv);
+
+    cq = ib_create_cq(ib_dev, NULL, NULL, NULL, &cq_attr);
+    if (!cq) {
+        pr_err("create_cq failed\n");
+    }
+
+    qp_init_attr.send_cq = cq;
+    qp_init_attr.recv_cq = cq;
+    pr_info("qp type: %d\n", qp_init_attr.qp_type);
+    qp = ib_create_qp(pd, &qp_init_attr);
+    if (!qp) {
+        pr_err("create_qp failed\n");
+    }
+
+    // modify to init
+    memset(&qp_attr, 0, sizeof(qp_attr));
+    mask = IB_QP_STATE | IB_QP_ACCESS_FLAGS | IB_QP_PKEY_INDEX | IB_QP_PORT;
+    qp_attr.qp_state = IB_QPS_INIT;
+    qp_attr.port_num = 1;
+    qp_attr.pkey_index = 0;
+    qp_attr.qp_access_flags = 0;
+    ib_modify_qp(qp, &qp_attr, mask);
+
+    memset(&reg_wr, 0, sizeof(reg_wr));
+	reg_wr.wr.opcode = IB_WR_REG_MR;
+	reg_wr.wr.num_sge = 0;
+	reg_wr.mr = mr;
+	reg_wr.key = mr->lkey;
+	reg_wr.access = IB_ACCESS_LOCAL_WRITE;
+    ib_post_send(qp, &reg_wr.wr, &bad_swr);
+
+    memset(&reg_wr, 0, sizeof(reg_wr));
+	reg_wr.wr.opcode = IB_WR_REG_MR;
+	reg_wr.wr.num_sge = 0;
+	reg_wr.mr = mr_recv;
+	reg_wr.key = mr_recv->lkey;
+	reg_wr.access = IB_ACCESS_LOCAL_WRITE;
+    ib_post_send(qp, &reg_wr.wr, &bad_swr);
+
+    // post recv
+    rsge[0].addr = dma_addr_recv;
+    rsge[0].length = 4096 * PAGES;
+    rsge[0].lkey = mr_recv->lkey;
+
+    rwr.next = NULL;
+    rwr.wr_id = 1;
+    rwr.sg_list = rsge;
+    rwr.num_sge = 1;
+    for (i = 0; i < ITER; i++) {
+        if (ib_post_recv(qp, &rwr, &bad_rwr)) {
+            pr_err("post recv failed\n");
+            return -EIO;
+        }
+    }
+
+    // exchange info
+	if (ib_query_port(ib_dev, 1, &port_attr))
+		pr_err("query port failed");
+    my_dest.lid = port_attr.lid;
+
+    // TODO: fix rdma_query_gid
+    if (rdma_query_gid(ib_dev, 1, 1, &my_dest.gid))
+        pr_err("query gid failed");
+
+    get_random_bytes(&rand_num, sizeof(rand_num));
+    my_dest.gid_index = 1;
+    my_dest.qpn = qp->qp_num;
+    my_dest.psn = rand_num & 0xffffff;
+
+    pr_info("  local address:  LID 0x%04x, QPN 0x%06x, PSN 0x%06x, GID %pI6\n",
+	         my_dest.lid, my_dest.qpn, my_dest.psn, &my_dest.gid);
+
+    rem_dest = pp_client_exch_dest(&my_dest);
+    if (!rem_dest) {
+        return -EIO;
+    }
+
+    pr_info("  remote address: LID 0x%04x, QPN 0x%06x, PSN 0x%06x, GID %pI6\n",
+	       rem_dest->lid, rem_dest->qpn, rem_dest->psn, &rem_dest->gid);
+
+    my_dest.rkey = mr->rkey;
+    my_dest.out_reads = 1;
+    my_dest.vaddr = dma_addr;
+    my_dest.srqn = 0;
+
+    // modify to rtr
+    memset(&qp_attr, 0, sizeof(qp_attr));
+    mask = IB_QP_STATE | IB_QP_AV | IB_QP_PATH_MTU | IB_QP_DEST_QPN | IB_QP_RQ_PSN | IB_QP_MIN_RNR_TIMER | IB_QP_MAX_DEST_RD_ATOMIC;
+	qp_attr.qp_state		= IB_QPS_RTR;
+	qp_attr.path_mtu		= IB_MTU_1024;
+	qp_attr.dest_qp_num		= rem_dest->qpn;
+	qp_attr.rq_psn			= rem_dest->psn;
+	qp_attr.max_dest_rd_atomic	= 1;
+	qp_attr.min_rnr_timer		= 12;
+    qp_attr.ah_attr.ah_flags = IB_AH_GRH;
+	qp_attr.ah_attr.ib.dlid = rem_dest->lid; // is_global  lid
+    qp_attr.ah_attr.ib.src_path_bits = 0;
+	qp_attr.ah_attr.sl		= 0;
+	qp_attr.ah_attr.port_num	= 1;
+
+	if (rem_dest->gid.global.interface_id) {
+		qp_attr.ah_attr.grh.hop_limit = 1;
+		qp_attr.ah_attr.grh.dgid = rem_dest->gid;
+		qp_attr.ah_attr.grh.sgid_index = my_dest.gid_index;
+	}
+
+    if (ib_modify_qp(qp, &qp_attr, mask)) {
+        pr_info("Failed to modify to RTR\n");
+        return -EIO;
+    }
+
+    // modify to rts
+    memset(&qp_attr, 0, sizeof(qp_attr));
+    mask = IB_QP_STATE | IB_QP_SQ_PSN | IB_QP_TIMEOUT | IB_QP_RETRY_CNT | IB_QP_RNR_RETRY | IB_QP_MAX_QP_RD_ATOMIC;
+    qp_attr.qp_state = IB_QPS_RTS;
+	qp_attr.sq_psn = my_dest.psn;
+    qp_attr.timeout   = 14;
+    qp_attr.retry_cnt = 7;
+    qp_attr.rnr_retry = 7;
+    qp_attr.max_rd_atomic  = 1;
+    if (ib_modify_qp(qp, &qp_attr, mask)) {
+        pr_info("Failed to modify to RTS\n");
+    }
+
+    wsge[0].addr = dma_addr;
+    wsge[0].length = 4096 * PAGES;
+    wsge[0].lkey = mr->lkey;
+
+    swr.wr.next = NULL;
+    swr.wr.wr_id = 2;
+    swr.wr.sg_list = wsge;
+    swr.wr.num_sge = 1;
+    swr.wr.opcode = IB_WR_SEND;
+    swr.wr.send_flags = IB_SEND_SIGNALED;
+    swr.remote_addr = rem_dest->vaddr;
+    swr.rkey = rem_dest->rkey;
+
+    t0 = ktime_get();
+
+    for (iter = 0; iter < ITER; iter++) {
+        if (ib_post_send(qp, &swr.wr, &bad_swr)) {
+            pr_err("post send failed\n");
+            return -EIO;
+        }
+
+        do {
+            wc_got = ib_poll_cq(cq, 2, wc);
+        } while(wc_got < 1);
+        wc_total += wc_got;
+    }
+
+    pr_info("Total wc %d\n", wc_total);
+    do {
+        wc_total += ib_poll_cq(cq, 2, wc);
+    }while(wc_total < ITER * 2);
+
+    rt = ktime_to_us(ktime_sub(ktime_get(), t0));
+    pr_info("%d iters in %lld us = %lld usec/iter\n", ITER, rt, rt / ITER);
+    pr_info("%d bytes in %lld us = %lld Mbit/sec\n", ITER * 4096 * 2, rt, (uint64_t)ITER * 62500 / rt);
+
+    pr_info("After %s\n", (char*)addr_send);
+    pr_info("After %s\n", (char*)addr_recv);
+
+    ib_destroy_qp(qp);
+    ib_destroy_cq(cq);
+    ib_dereg_mr(mr);
+    ib_dereg_mr(mr_recv);
+    ib_dealloc_pd(pd);
+    return 0;
+}
+
+/*
+ * rdma_test_exit - module unload hook.
+ *
+ * Only logs a message; it releases nothing itself.
+ */
+static void __exit rdma_test_exit(void)
+{
+	pr_info("Exit rdma test\n");
+}
+
+module_init(rdma_test_init);	/* the whole ping-pong test runs from the init hook */
+module_exit(rdma_test_exit);	/* exit hook only logs a message */