diff mbox series

[RFC,v2,liburing,2/2] add tests for multishot recvmsg

Message ID 20220714115428.1569612-3-dylany@fb.com (mailing list archive)
State New
Headers show
Series multishot recvmsg | expand

Commit Message

Dylan Yudaken July 14, 2022, 11:54 a.m. UTC
Expand the multishot recv test to include recvmsg.
This also checks that sockaddr comes back, and that control messages work
properly.

Signed-off-by: Dylan Yudaken <dylany@fb.com>
---
 test/recv-multishot.c | 180 +++++++++++++++++++++++++++++++++++++-----
 1 file changed, 161 insertions(+), 19 deletions(-)
diff mbox series

Patch

diff --git a/test/recv-multishot.c b/test/recv-multishot.c
index 9df8184..a322e43 100644
--- a/test/recv-multishot.c
+++ b/test/recv-multishot.c
@@ -27,20 +27,45 @@  enum early_error_t {
 struct args {
 	bool stream;
 	bool wait_each;
+	bool recvmsg;
 	enum early_error_t early_error;
 };
 
+static int check_sockaddr(struct sockaddr_in *in)
+{
+	struct in_addr expected;
+
+	inet_pton(AF_INET, "127.0.0.1", &expected);
+	if (in->sin_family != AF_INET) {
+		fprintf(stderr, "bad family %d\n", (int)htons(in->sin_family));
+		return -1;
+	}
+	if (memcmp(&expected, &in->sin_addr, sizeof(in->sin_addr))) {
+		char buff[256];
+		const char *addr = inet_ntop(AF_INET, &in->sin_addr, buff, sizeof(buff));
+
+		fprintf(stderr, "unexpected address %s\n", addr ? addr : "INVALID");
+		return -1;
+	}
+	return 0;
+}
+
 static int test(struct args *args)
 {
 	int const N = 8;
 	int const N_BUFFS = N * 64;
 	int const N_CQE_OVERFLOW = 4;
 	int const min_cqes = 2;
+	int const NAME_LEN = sizeof(struct sockaddr_storage);
+	int const CONTROL_LEN = CMSG_ALIGN(sizeof(struct sockaddr_storage))
+					+ sizeof(struct cmsghdr);
 	struct io_uring ring;
 	struct io_uring_cqe *cqe;
 	struct io_uring_sqe *sqe;
-	int fds[2], ret, i, j, total_sent_bytes = 0, total_recv_bytes = 0;
+	int fds[2], ret, i, j;
+	int total_sent_bytes = 0, total_recv_bytes = 0, total_dropped_bytes = 0;
 	int send_buff[256];
+	int *sent_buffs[N_BUFFS];
 	int *recv_buffs[N_BUFFS];
 	int *at;
 	struct io_uring_cqe recv_cqe[N_BUFFS];
@@ -50,7 +75,7 @@  static int test(struct args *args)
 	struct __kernel_timespec timeout = {
 		.tv_sec = 1,
 	};
-
+	struct msghdr msg;
 
 	memset(recv_buffs, 0, sizeof(recv_buffs));
 
@@ -75,21 +100,42 @@  static int test(struct args *args)
 		return ret;
 	}
 
+	if (!args->stream) {
+		bool val = true;
+
+		/* force some cmsgs to come back to us */
+		ret = setsockopt(fds[0], IPPROTO_IP, IP_RECVORIGDSTADDR, &val,
+				 sizeof(val));
+		if (ret) {
+			fprintf(stderr, "setsockopt failed %d\n", errno);
+			goto cleanup;
+		}
+	}
+
 	for (i = 0; i < ARRAY_SIZE(send_buff); i++)
 		send_buff[i] = i;
 
 	for (i = 0; i < ARRAY_SIZE(recv_buffs); i++) {
 		/* prepare some different sized buffers */
-		int buffer_size = (i % 2 == 0 && args->stream) ? 1 : N * sizeof(int);
+		int buffer_size = (i % 2 == 0 && (args->stream || args->recvmsg)) ? 1 : N;
+
+		buffer_size *= sizeof(int);
+		if (args->recvmsg) {
+			buffer_size +=
+				sizeof(struct io_uring_recvmsg_out) +
+				NAME_LEN +
+				CONTROL_LEN;
+		}
 
-		recv_buffs[i] = malloc(sizeof(*at) * buffer_size);
+		recv_buffs[i] = malloc(buffer_size);
 
 		if (i > 2 && args->early_error == ERROR_NOT_ENOUGH_BUFFERS)
 			continue;
 
 		sqe = io_uring_get_sqe(&ring);
 		io_uring_prep_provide_buffers(sqe, recv_buffs[i],
-					buffer_size * sizeof(*recv_buffs[i]), 1, 7, i);
+					buffer_size, 1, 7, i);
+		memset(recv_buffs[i], 0xcc, buffer_size);
 		if (io_uring_submit_and_wait_timeout(&ring, &cqe, 1, &timeout, NULL) != 0) {
 			fprintf(stderr, "provide buffers failed: %d\n", ret);
 			ret = -1;
@@ -99,7 +145,19 @@  static int test(struct args *args)
 	}
 
 	sqe = io_uring_get_sqe(&ring);
-	io_uring_prep_recv_multishot(sqe, fds[0], NULL, 0, 0);
+	if (args->recvmsg) {
+		unsigned int flags = 0;
+
+		if (!args->stream)
+			flags |= MSG_TRUNC;
+
+		memset(&msg, 0, sizeof(msg));
+		msg.msg_namelen = NAME_LEN;
+		msg.msg_controllen = CONTROL_LEN;
+		io_uring_prep_recvmsg_multishot(sqe, fds[0], &msg, flags);
+	} else {
+		io_uring_prep_recv_multishot(sqe, fds[0], NULL, 0, 0);
+	}
 	sqe->flags |= IOSQE_BUFFER_SELECT;
 	sqe->buf_group = 7;
 	io_uring_sqe_set_data64(sqe, 1234);
@@ -111,6 +169,7 @@  static int test(struct args *args)
 		int to_send = sizeof(*at) * (i+1);
 
 		total_sent_bytes += to_send;
+		sent_buffs[i] = at;
 		if (send(fds[1], at, to_send, 0) != to_send) {
 			if (early_error_started)
 				break;
@@ -202,9 +261,12 @@  static int test(struct args *args)
 			(args->early_error == ERROR_EARLY_OVERFLOW &&
 			 !args->wait_each && i == N_CQE_OVERFLOW);
 		int *this_recv;
+		int orig_payload_size = cqe->res;
 
 
 		if (should_be_last) {
+			int used_res = cqe->res;
+
 			if (!is_last) {
 				fprintf(stderr, "not last cqe had error %d\n", i);
 				goto cleanup;
@@ -234,7 +296,22 @@  static int test(struct args *args)
 				break;
 			case ERROR_NONE:
 			case ERROR_EARLY_CLOSE_SENDER:
-				if (cqe->res != 0) {
+				if (args->recvmsg && (cqe->flags & IORING_CQE_F_BUFFER)) {
+					void *buff = recv_buffs[cqe->flags >> 16];
+					struct io_uring_recvmsg_out *o =
+						io_uring_recvmsg_validate(buff, cqe->res, &msg);
+
+					if (!o) {
+						fprintf(stderr, "invalid buff\n");
+						goto cleanup;
+					}
+					if (o->payloadlen != 0) {
+						fprintf(stderr, "expected 0 payloadlen, got %u\n",
+							o->payloadlen);
+						goto cleanup;
+					}
+					used_res = 0;
+				} else if (cqe->res != 0) {
 					fprintf(stderr, "early error: res %d\n", cqe->res);
 					goto cleanup;
 				}
@@ -254,7 +331,7 @@  static int test(struct args *args)
 				goto cleanup;
 			}
 
-			if (cqe->res <= 0)
+			if (used_res <= 0)
 				continue;
 		} else {
 			if (!(cqe->flags & IORING_CQE_F_MORE)) {
@@ -268,7 +345,61 @@  static int test(struct args *args)
 			goto cleanup;
 		}
 
+		this_recv = recv_buffs[cqe->flags >> 16];
+
+		if (args->recvmsg) {
+			struct io_uring_recvmsg_out *o = io_uring_recvmsg_validate(
+				this_recv, cqe->res, &msg);
+
+			if (!o) {
+				fprintf(stderr, "bad recvmsg\n");
+				goto cleanup;
+			}
+			orig_payload_size = o->payloadlen;
+
+			if (!args->stream) {
+				orig_payload_size = o->payloadlen;
+
+				struct cmsghdr *cmsg;
+
+				if (o->namelen < sizeof(struct sockaddr_in)) {
+					fprintf(stderr, "bad addr len %d",
+						o->namelen);
+					goto cleanup;
+				}
+				if (check_sockaddr((struct sockaddr_in *)io_uring_recvmsg_name(o)))
+					goto cleanup;
+
+				cmsg = io_uring_recvmsg_cmsg_firsthdr(o, &msg);
+				if (!cmsg ||
+				    cmsg->cmsg_level != IPPROTO_IP ||
+				    cmsg->cmsg_type != IP_RECVORIGDSTADDR) {
+					fprintf(stderr, "bad cmsg");
+					goto cleanup;
+				}
+				if (check_sockaddr((struct sockaddr_in *)CMSG_DATA(cmsg)))
+					goto cleanup;
+				cmsg = io_uring_recvmsg_cmsg_nexthdr(o, &msg, cmsg);
+				if (cmsg) {
+					fprintf(stderr, "unexpected extra cmsg\n");
+					goto cleanup;
+				}
+
+			}
+
+			this_recv = (int *)io_uring_recvmsg_payload(o, &msg);
+			cqe->res = io_uring_recvmsg_payload_length(o, cqe->res, &msg);
+			if (o->payloadlen != cqe->res) {
+				if (!(o->flags & MSG_TRUNC)) {
+					fprintf(stderr, "expected truncated flag\n");
+					goto cleanup;
+				}
+				total_dropped_bytes += (o->payloadlen - cqe->res);
+			}
+		}
+
 		total_recv_bytes += cqe->res;
+
 		if (cqe->res % 4 != 0) {
 			/*
 			 * doesn't seem to happen in practice, would need some
@@ -278,9 +409,20 @@  static int test(struct args *args)
 			goto cleanup;
 		}
 
-		/* check buffer arrived in order (for tcp) */
-		this_recv = recv_buffs[cqe->flags >> 16];
-		for (j = 0; args->stream && j < cqe->res / 4; j++) {
+		/*
+		 * for tcp: check buffer arrived in order
+		 * for udp: based on size validate data based on size
+		 */
+		if (!args->stream) {
+			int sent_idx = orig_payload_size / sizeof(*at) - 1;
+
+			if (sent_idx < 0 || sent_idx > N) {
+				fprintf(stderr, "Bad sent idx: %d\n", sent_idx);
+				goto cleanup;
+			}
+			at = sent_buffs[sent_idx];
+		}
+		for (j = 0; j < cqe->res / 4; j++) {
 			int sent = *at++;
 			int recv = *this_recv++;
 
@@ -291,15 +433,14 @@  static int test(struct args *args)
 		}
 	}
 
-	if (args->early_error == ERROR_NONE && total_recv_bytes < total_sent_bytes) {
+	if (args->early_error == ERROR_NONE &&
+	    total_recv_bytes + total_dropped_bytes < total_sent_bytes) {
 		fprintf(stderr,
-			"missing recv: recv=%d sent=%d\n", total_recv_bytes, total_sent_bytes);
+			"missing recv: recv=%d dropped=%d sent=%d\n",
+			total_recv_bytes, total_sent_bytes, total_dropped_bytes);
 		goto cleanup;
 	}
 
-	/* check the final one */
-	cqe = &recv_cqe[recv_cqes-1];
-
 	ret = 0;
 cleanup:
 	for (i = 0; i < ARRAY_SIZE(recv_buffs); i++)
@@ -320,18 +461,19 @@  int main(int argc, char *argv[])
 	if (argc > 1)
 		return T_EXIT_SKIP;
 
-	for (loop = 0; loop < 4; loop++) {
+	for (loop = 0; loop < 8; loop++) {
 		struct args a = {
 			.stream = loop & 0x01,
 			.wait_each = loop & 0x2,
+			.recvmsg = loop & 0x04,
 		};
 		for (early_error = 0; early_error < ERROR_EARLY_LAST; early_error++) {
 			a.early_error = (enum early_error_t)early_error;
 			ret = test(&a);
 			if (ret) {
 				fprintf(stderr,
-					"test stream=%d wait_each=%d early_error=%d failed\n",
-					a.stream, a.wait_each, a.early_error);
+					"test stream=%d wait_each=%d recvmsg=%d early_error=%d failed\n",
+					a.stream, a.wait_each, a.recvmsg, a.early_error);
 				return T_EXIT_FAIL;
 			}
 			if (no_recv_mshot)