diff mbox series

[mptcp-next,v3,4/4] selftests: mptcp: add inq test case

Message ID 20211116134804.30090-5-fw@strlen.de (mailing list archive)
State Accepted, archived
Commit 75e2fd07442774b350fd6fd45f8209d45e2476ed
Delegated to: Matthieu Baerts
Headers show
Series TCP_INQ support | expand

Checks

Context Check Description
matttbe/build success Build and static analysis OK
matttbe/checkpatch warning total: 0 errors, 1 warnings, 2 checks, 667 lines checked
matttbe/KVM_Validation__normal warning Unstable: 1 failed test(s): packetdrill_sockopts
matttbe/KVM_Validation__debug warning Unstable: 3 failed test(s): packetdrill_sockopts selftest_diag selftest_mptcp_join

Commit Message

Florian Westphal Nov. 16, 2021, 1:48 p.m. UTC
client & server use a unix socket connection to communicate
outside of the mptcp connection.

This allows the consumer to know in advance how many bytes have been
(or will be) sent by the peer.
This allows stricter checks on the bytecounts reported by TCP_INQ cmsg.

Suggested-by: Mat Martineau <mathew.j.martineau@linux.intel.com>
Signed-off-by: Florian Westphal <fw@strlen.de>
---
 tools/testing/selftests/net/mptcp/.gitignore  |   1 +
 tools/testing/selftests/net/mptcp/Makefile    |   2 +-
 tools/testing/selftests/net/mptcp/mptcp_inq.c | 603 ++++++++++++++++++
 .../selftests/net/mptcp/mptcp_sockopt.sh      |  40 ++
 4 files changed, 645 insertions(+), 1 deletion(-)
 create mode 100644 tools/testing/selftests/net/mptcp/mptcp_inq.c

Comments

Mat Martineau Nov. 16, 2021, 7:12 p.m. UTC | #1
On Tue, 16 Nov 2021, Florian Westphal wrote:

> client & server use a unix socket connection to communicate
> outside of the mptcp connection.
>
> This allows the consumer to know in advance how many bytes have been
> (or will be) sent by the peer.
> This allows stricter checks on the bytecounts reported by TCP_INQ cmsg.
>
> Suggested-by: Mat Martineau <mathew.j.martineau@linux.intel.com>
> Signed-off-by: Florian Westphal <fw@strlen.de>
> ---
> tools/testing/selftests/net/mptcp/.gitignore  |   1 +
> tools/testing/selftests/net/mptcp/Makefile    |   2 +-
> tools/testing/selftests/net/mptcp/mptcp_inq.c | 603 ++++++++++++++++++
> .../selftests/net/mptcp/mptcp_sockopt.sh      |  40 ++
> 4 files changed, 645 insertions(+), 1 deletion(-)
> create mode 100644 tools/testing/selftests/net/mptcp/mptcp_inq.c
>
> diff --git a/tools/testing/selftests/net/mptcp/.gitignore b/tools/testing/selftests/net/mptcp/.gitignore
> index 7569d892967a..49daae73c41e 100644
> --- a/tools/testing/selftests/net/mptcp/.gitignore
> +++ b/tools/testing/selftests/net/mptcp/.gitignore
> @@ -1,5 +1,6 @@
> # SPDX-License-Identifier: GPL-2.0-only
> mptcp_connect
> +mptcp_inq
> mptcp_sockopt
> pm_nl_ctl
> *.pcap
> diff --git a/tools/testing/selftests/net/mptcp/Makefile b/tools/testing/selftests/net/mptcp/Makefile
> index bbf4e448bad9..0356c4501c99 100644
> --- a/tools/testing/selftests/net/mptcp/Makefile
> +++ b/tools/testing/selftests/net/mptcp/Makefile
> @@ -8,7 +8,7 @@ CFLAGS =  -Wall -Wl,--no-as-needed -O2 -g  -I$(top_srcdir)/usr/include
> TEST_PROGS := mptcp_connect.sh pm_netlink.sh mptcp_join.sh diag.sh \
> 	      simult_flows.sh mptcp_sockopt.sh
>
> -TEST_GEN_FILES = mptcp_connect pm_nl_ctl mptcp_sockopt
> +TEST_GEN_FILES = mptcp_connect pm_nl_ctl mptcp_sockopt mptcp_inq
>
> TEST_FILES := settings
>
> diff --git a/tools/testing/selftests/net/mptcp/mptcp_inq.c b/tools/testing/selftests/net/mptcp/mptcp_inq.c
> new file mode 100644
> index 000000000000..62ee2dd63b5f
> --- /dev/null
> +++ b/tools/testing/selftests/net/mptcp/mptcp_inq.c
> @@ -0,0 +1,603 @@
> +// SPDX-License-Identifier: GPL-2.0
> +
> +#define _GNU_SOURCE
> +
> +#include <assert.h>
> +#include <errno.h>
> +#include <fcntl.h>
> +#include <limits.h>
> +#include <string.h>
> +#include <stdarg.h>
> +#include <stdbool.h>
> +#include <stdint.h>
> +#include <inttypes.h>
> +#include <stdio.h>
> +#include <stdlib.h>
> +#include <strings.h>
> +#include <unistd.h>
> +#include <time.h>
> +
> +#include <sys/ioctl.h>
> +#include <sys/ioctl.h>
> +#include <sys/socket.h>
> +#include <sys/types.h>
> +#include <sys/wait.h>
> +
> +#include <netdb.h>
> +#include <netinet/in.h>
> +
> +#include <linux/tcp.h>
> +#include <linux/sockios.h>
> +
> +#ifndef IPPROTO_MPTCP
> +#define IPPROTO_MPTCP 262
> +#endif
> +#ifndef SOL_MPTCP
> +#define SOL_MPTCP 284
> +#endif
> +
> +static int pf = AF_INET;
> +static int proto_tx = IPPROTO_MPTCP;
> +static int proto_rx = IPPROTO_MPTCP;
> +
> +static void die_perror(const char *msg)
> +{
> +	perror(msg);
> +	exit(1);
> +}
> +
> +static void die_usage(int r)
> +{
> +	fprintf(stderr, "Usage: mptcp_inq [-6] [ -t tcp|mptcp ] [ -r tcp|mptcp]\n");
> +	exit(r);
> +}
> +
> +static void xerror(const char *fmt, ...)
> +{
> +	va_list ap;
> +
> +	va_start(ap, fmt);
> +	vfprintf(stderr, fmt, ap);
> +	va_end(ap);
> +	fputc('\n', stderr);
> +	exit(1);
> +}
> +
> +static const char *getxinfo_strerr(int err)
> +{
> +	if (err == EAI_SYSTEM)
> +		return strerror(errno);
> +
> +	return gai_strerror(err);
> +}
> +
> +static void xgetaddrinfo(const char *node, const char *service,
> +			 const struct addrinfo *hints,
> +			 struct addrinfo **res)
> +{
> +	int err = getaddrinfo(node, service, hints, res);
> +
> +	if (err) {
> +		const char *errstr = getxinfo_strerr(err);
> +
> +		fprintf(stderr, "Fatal: getaddrinfo(%s:%s): %s\n",
> +			node ? node : "", service ? service : "", errstr);
> +		exit(1);
> +	}
> +}
> +
> +static int sock_listen_mptcp(const char * const listenaddr,
> +			     const char * const port)
> +{
> +	int sock;
> +	struct addrinfo hints = {
> +		.ai_protocol = IPPROTO_TCP,
> +		.ai_socktype = SOCK_STREAM,
> +		.ai_flags = AI_PASSIVE | AI_NUMERICHOST
> +	};
> +
> +	hints.ai_family = pf;
> +
> +	struct addrinfo *a, *addr;
> +	int one = 1;
> +
> +	xgetaddrinfo(listenaddr, port, &hints, &addr);
> +	hints.ai_family = pf;
> +
> +	for (a = addr; a; a = a->ai_next) {
> +		sock = socket(a->ai_family, a->ai_socktype, proto_rx);
> +		if (sock < 0)
> +			continue;
> +
> +		if (-1 == setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &one,
> +				     sizeof(one)))
> +			perror("setsockopt");
> +
> +		if (bind(sock, a->ai_addr, a->ai_addrlen) == 0)
> +			break; /* success */
> +
> +		perror("bind");
> +		close(sock);
> +		sock = -1;
> +	}
> +
> +	freeaddrinfo(addr);
> +
> +	if (sock < 0)
> +		xerror("could not create listen socket");
> +
> +	if (listen(sock, 20))
> +		die_perror("listen");
> +
> +	return sock;
> +}
> +
> +static int sock_connect_mptcp(const char * const remoteaddr,
> +			      const char * const port, int proto)
> +{
> +	struct addrinfo hints = {
> +		.ai_protocol = IPPROTO_TCP,
> +		.ai_socktype = SOCK_STREAM,
> +	};
> +	struct addrinfo *a, *addr;
> +	int sock = -1;
> +
> +	hints.ai_family = pf;
> +
> +	xgetaddrinfo(remoteaddr, port, &hints, &addr);
> +	for (a = addr; a; a = a->ai_next) {
> +		sock = socket(a->ai_family, a->ai_socktype, proto);
> +		if (sock < 0)
> +			continue;
> +
> +		if (connect(sock, a->ai_addr, a->ai_addrlen) == 0)
> +			break; /* success */
> +
> +		die_perror("connect");
> +	}
> +
> +	if (sock < 0)
> +		xerror("could not create connect socket");
> +
> +	freeaddrinfo(addr);
> +	return sock;
> +}
> +
> +static int protostr_to_num(const char *s)
> +{
> +	if (strcasecmp(s, "tcp") == 0)
> +		return IPPROTO_TCP;
> +	if (strcasecmp(s, "mptcp") == 0)
> +		return IPPROTO_MPTCP;
> +
> +	die_usage(1);
> +	return 0;
> +}
> +
> +static void parse_opts(int argc, char **argv)
> +{
> +	int c;
> +
> +	while ((c = getopt(argc, argv, "h6t:r:")) != -1) {
> +		switch (c) {
> +		case 'h':
> +			die_usage(0);
> +			break;
> +		case '6':
> +			pf = AF_INET6;
> +			break;
> +		case 't':
> +			proto_tx = protostr_to_num(optarg);
> +			break;
> +		case 'r':
> +			proto_rx = protostr_to_num(optarg);
> +			break;
> +		default:
> +			die_usage(1);
> +			break;
> +		}
> +	}
> +}
> +
> +/* wait up to timeout milliseconds */
> +static void wait_for_ack(int fd, int timeout, size_t total)
> +{
> +	int i;
> +
> +	for (i = 0; i < timeout; i++) {
> +		int nsd, ret, queued = -1;
> +		struct timespec req;
> +
> +		ret = ioctl(fd, TIOCOUTQ, &queued);
> +		if (ret < 0)
> +			die_perror("TIOCOUTQ");
> +
> +		ret = ioctl(fd, SIOCOUTQNSD, &nsd);
> +		if (ret < 0)
> +			die_perror("SIOCOUTQNSD");
> +
> +		if ((size_t)queued > total)
> +			xerror("TIOCOUTQ %u, but only %zu expected\n", queued, total);
> +		assert(nsd <= queued);
> +
> +		if (queued == 0)
> +			return;
> +
> +		/* wait for peer to ack rx of all data */
> +		req.tv_sec = 0;
> +		req.tv_nsec = 1 * 1000 * 1000ul; /* 1ms */
> +		nanosleep(&req, NULL);
> +	}
> +
> +	xerror("still tx data queued after %u ms\n", timeout);
> +}
> +
> +static void connect_one_server(int fd, int unixfd)
> +{
> +	size_t len, i, total, sent;
> +	char buf[4096], buf2[4096];
> +	int queued, on = 1;

Still getting the 'unused variable' warnings on this line, but we can 
squash a fix for that (or wait to apply just this last patch?).

- Mat

> +	ssize_t ret;
> +
> +	len = rand() % (sizeof(buf) - 1);
> +
> +	if (len < 128)
> +		len = 128;
> +
> +	for (i = 0; i < len ; i++) {
> +		buf[i] = rand() % 26;
> +		buf[i] += 'A';
> +	}
> +
> +	buf[i] = '\n';
> +
> +	/* un-block server */
> +	ret = read(unixfd, buf2, 4);
> +	assert(ret == 4);
> +
> +	assert(strncmp(buf2, "xmit", 4) == 0);
> +
> +	ret = write(unixfd, &len, sizeof(len));
> +	assert(ret == (ssize_t)sizeof(len));
> +
> +	ret = write(fd, buf, len);
> +	if (ret < 0)
> +		die_perror("write");
> +
> +	if (ret != (ssize_t)len)
> +		xerror("short write");
> +
> +	ret = read(unixfd, buf2, 4);
> +	assert(strncmp(buf2, "huge", 4) == 0);
> +
> +	total = rand() % (16 * 1024 * 1024);
> +	total += (1 * 1024 * 1024);
> +	sent = total;
> +
> +	ret = write(unixfd, &total, sizeof(total));
> +	assert(ret == (ssize_t)sizeof(total));
> +
> +	wait_for_ack(fd, 5000, len);
> +
> +	while (total > 0) {
> +		if (total > sizeof(buf))
> +			len = sizeof(buf);
> +		else
> +			len = total;
> +
> +		ret = write(fd, buf, len);
> +		if (ret < 0)
> +			die_perror("write");
> +		total -= ret;
> +
> +		/* we don't have to care about buf content, only
> +		 * number of total bytes sent
> +		 */
> +	}
> +
> +	ret = read(unixfd, buf2, 4);
> +	assert(ret == 4);
> +	assert(strncmp(buf2, "shut", 4) == 0);
> +
> +	wait_for_ack(fd, 5000, sent);
> +
> +	write(fd, buf, 1);
> +	close(fd);
> +	ret = write(unixfd, "closed", 6);
> +	assert(ret == 6);
> +
> +	close(unixfd);
> +}
> +
> +static void get_tcp_inq(struct msghdr *msgh, unsigned int *inqv)
> +{
> +	struct cmsghdr *cmsg;
> +
> +	for (cmsg = CMSG_FIRSTHDR(msgh); cmsg ; cmsg = CMSG_NXTHDR(msgh, cmsg)) {
> +		if (cmsg->cmsg_level == IPPROTO_TCP && cmsg->cmsg_type == TCP_CM_INQ) {
> +			memcpy(inqv, CMSG_DATA(cmsg), sizeof(*inqv));
> +			return;
> +		}
> +	}
> +
> +	xerror("could not find TCP_CM_INQ cmsg type");
> +}
> +
> +static void process_one_client(int fd, int unixfd)
> +{
> +	unsigned int tcp_inq;
> +	size_t expect_len;
> +	char msg_buf[4096];
> +	char buf[4096];
> +	char tmp[16];
> +	struct iovec iov = {
> +		.iov_base = buf,
> +		.iov_len = 1,
> +	};
> +	struct msghdr msg = {
> +		.msg_iov = &iov,
> +		.msg_iovlen = 1,
> +		.msg_control = msg_buf,
> +		.msg_controllen = sizeof(msg_buf),
> +	};
> +	ssize_t ret, tot;
> +
> +	ret = write(unixfd, "xmit", 4);
> +	assert(ret == 4);
> +
> +	ret = read(unixfd, &expect_len, sizeof(expect_len));
> +	assert(ret == (ssize_t)sizeof(expect_len));
> +
> +	if (expect_len > sizeof(buf))
> +		xerror("expect len %zu exceeds buffer size", expect_len);
> +
> +	for (;;) {
> +		struct timespec req;
> +		unsigned int queued;
> +
> +		ret = ioctl(fd, FIONREAD, &queued);
> +		if (ret < 0)
> +			die_perror("FIONREAD");
> +		if (queued > expect_len)
> +			xerror("FIONREAD returned %u, but only %zu expected\n",
> +					queued, expect_len);
> +		if (queued == expect_len)
> +			break;
> +
> +		req.tv_sec = 0;
> +		req.tv_nsec = 1000 * 1000ul;
> +		nanosleep(&req, NULL);
> +	}
> +
> +	/* read one byte, expect cmsg to return expected - 1 */
> +	ret = recvmsg(fd, &msg, 0);
> +	if (ret < 0)
> +		die_perror("recvmsg");
> +
> +	if (msg.msg_controllen == 0)
> +		xerror("msg_controllen is 0");
> +
> +	get_tcp_inq(&msg, &tcp_inq);
> +
> +	assert((size_t)tcp_inq == (expect_len - 1));
> +
> +	iov.iov_len = sizeof(buf);
> +	ret = recvmsg(fd, &msg, 0);
> +	if (ret < 0)
> +		die_perror("recvmsg");
> +
> +	/* should have gotten exact remainder of all pending data */
> +	assert(ret == (ssize_t)tcp_inq);
> +
> +	/* should be 0, all drained */
> +	get_tcp_inq(&msg, &tcp_inq);
> +	assert(tcp_inq == 0);
> +
> +	/* request a large swath of data. */
> +	ret = write(unixfd, "huge", 4);
> +	assert(ret == 4);
> +
> +	ret = read(unixfd, &expect_len, sizeof(expect_len));
> +	assert(ret == (ssize_t)sizeof(expect_len));
> +
> +	/* peer should send us a few mb of data */
> +	if (expect_len <= sizeof(buf))
> +		xerror("expect len %zu too small\n", expect_len);
> +
> +	tot = 0;
> +	do {
> +		iov.iov_len = sizeof(buf);
> +		ret = recvmsg(fd, &msg, 0);
> +		if (ret < 0)
> +			die_perror("recvmsg");
> +
> +		tot += ret;
> +
> +		get_tcp_inq(&msg, &tcp_inq);
> +
> +		if (tcp_inq > expect_len - tot)
> +			xerror("inq %d, remaining %d total_len %d\n",
> +					tcp_inq, expect_len - tot, (int)expect_len);
> +
> +		assert(tcp_inq <= expect_len - tot);
> +	} while ((size_t)tot < expect_len);
> +
> +	ret = write(unixfd, "shut", 4);
> +	assert(ret == 4);
> +
> +	/* wait for hangup. Should have received one more byte of data. */
> +	ret = read(unixfd, tmp, sizeof(tmp));
> +	assert(ret == 6);
> +	assert(strncmp(tmp, "closed", 6) == 0);
> +
> +	sleep(1);
> +
> +	iov.iov_len = 1;
> +	ret = recvmsg(fd, &msg, 0);
> +	if (ret < 0)
> +		die_perror("recvmsg");
> +	assert(ret == 1);
> +
> +	get_tcp_inq(&msg, &tcp_inq);
> +
> +	/* tcp_inq should be 1 due to received fin. */
> +	assert(tcp_inq == 1);
> +
> +	iov.iov_len = 1;
> +	ret = recvmsg(fd, &msg, 0);
> +	if (ret < 0)
> +		die_perror("recvmsg");
> +
> +	/* expect EOF */
> +	assert(ret == 0);
> +	get_tcp_inq(&msg, &tcp_inq);
> +	assert(tcp_inq == 1);
> +
> +	close(fd);
> +}
> +
> +static int xaccept(int s)
> +{
> +	int fd = accept(s, NULL, 0);
> +
> +	if (fd < 0)
> +		die_perror("accept");
> +
> +	return fd;
> +}
> +
> +static int server(int unixfd)
> +{
> +	int fd = -1, r, on = 1;
> +
> +	switch (pf) {
> +	case AF_INET:
> +		fd = sock_listen_mptcp("127.0.0.1", "15432");
> +		break;
> +	case AF_INET6:
> +		fd = sock_listen_mptcp("::1", "15432");
> +		break;
> +	default:
> +		xerror("Unknown pf %d\n", pf);
> +		break;
> +	}
> +
> +	r = write(unixfd, "conn", 4);
> +	assert(r == 4);
> +
> +	alarm(15);
> +	r = xaccept(fd);
> +
> +	if (-1 == setsockopt(r, IPPROTO_TCP, TCP_INQ, &on, sizeof(on)))
> +		die_perror("setsockopt");
> +
> +	process_one_client(r, unixfd);
> +
> +	return 0;
> +}
> +
> +static int client(int unixfd)
> +{
> +	int fd = -1;
> +
> +	alarm(15);
> +
> +	switch (pf) {
> +	case AF_INET:
> +		fd = sock_connect_mptcp("127.0.0.1", "15432", proto_tx);
> +		break;
> +	case AF_INET6:
> +		fd = sock_connect_mptcp("::1", "15432", proto_tx);
> +		break;
> +	default:
> +		xerror("Unknown pf %d\n", pf);
> +	}
> +
> +	connect_one_server(fd, unixfd);
> +
> +	return 0;
> +}
> +
> +static void init_rng(void)
> +{
> +	int fd = open("/dev/urandom", O_RDONLY);
> +	unsigned int foo;
> +
> +	if (fd > 0) {
> +		int ret = read(fd, &foo, sizeof(foo));
> +
> +		if (ret < 0)
> +			srand(fd + foo);
> +		close(fd);
> +	}
> +
> +	srand(foo);
> +}
> +
> +static pid_t xfork(void)
> +{
> +	pid_t p = fork();
> +
> +	if (p < 0)
> +		die_perror("fork");
> +	else if (p == 0)
> +		init_rng();
> +
> +	return p;
> +}
> +
> +static int rcheck(int wstatus, const char *what)
> +{
> +	if (WIFEXITED(wstatus)) {
> +		if (WEXITSTATUS(wstatus) == 0)
> +			return 0;
> +		fprintf(stderr, "%s exited, status=%d\n", what, WEXITSTATUS(wstatus));
> +		return WEXITSTATUS(wstatus);
> +	} else if (WIFSIGNALED(wstatus)) {
> +		xerror("%s killed by signal %d\n", what, WTERMSIG(wstatus));
> +	} else if (WIFSTOPPED(wstatus)) {
> +		xerror("%s stopped by signal %d\n", what, WSTOPSIG(wstatus));
> +	}
> +
> +	return 111;
> +}
> +
> +int main(int argc, char *argv[])
> +{
> +	int e1, e2, wstatus;
> +	pid_t s, c, ret;
> +	int unixfds[2];
> +
> +	parse_opts(argc, argv);
> +
> +	e1 = socketpair(AF_UNIX, SOCK_DGRAM, 0, unixfds);
> +	if (e1 < 0)
> +		die_perror("pipe");
> +
> +	s = xfork();
> +	if (s == 0)
> +		return server(unixfds[1]);
> +
> +	close(unixfds[1]);
> +
> +	/* wait until server bound a socket */
> +	e1 = read(unixfds[0], &e1, 4);
> +	assert(e1 == 4);
> +
> +	c = xfork();
> +	if (c == 0)
> +		return client(unixfds[0]);
> +
> +	close(unixfds[0]);
> +
> +	ret = waitpid(s, &wstatus, 0);
> +	if (ret == -1)
> +		die_perror("waitpid");
> +	e1 = rcheck(wstatus, "server");
> +	ret = waitpid(c, &wstatus, 0);
> +	if (ret == -1)
> +		die_perror("waitpid");
> +	e2 = rcheck(wstatus, "client");
> +
> +	return e1 ? e1 : e2;
> +}
> diff --git a/tools/testing/selftests/net/mptcp/mptcp_sockopt.sh b/tools/testing/selftests/net/mptcp/mptcp_sockopt.sh
> index c8c364369599..0879da915014 100755
> --- a/tools/testing/selftests/net/mptcp/mptcp_sockopt.sh
> +++ b/tools/testing/selftests/net/mptcp/mptcp_sockopt.sh
> @@ -279,6 +279,45 @@ run_tests()
> 	fi
> }
>
> +do_tcpinq_test()
> +{
> +	ip netns exec "$ns1" ./mptcp_inq "$@"
> +	lret=$?
> +	if [ $lret -ne 0 ];then
> +		ret=$lret
> +		echo "FAIL: mptcp_inq $@" 1>&2
> +		return $lret
> +	fi
> +
> +	echo "PASS: TCP_INQ cmsg/ioctl $@"
> +	return $lret
> +}
> +
> +do_tcpinq_tests()
> +{
> +	local lret=0
> +
> +	ip netns exec "$ns1" iptables -F
> +	ip netns exec "$ns1" ip6tables -F
> +
> +	for args in "-t tcp" "-r tcp"; do
> +		do_tcpinq_test $args
> +		lret=$?
> +		if [ $lret -ne 0 ] ; then
> +			return $lret
> +		fi
> +		do_tcpinq_test -6 $args
> +		lret=$?
> +		if [ $lret -ne 0 ] ; then
> +			return $lret
> +		fi
> +	done
> +
> +	do_tcpinq_test -r tcp -t tcp
> +
> +	return $?
> +}
> +
> sin=$(mktemp)
> sout=$(mktemp)
> cin=$(mktemp)
> @@ -300,4 +339,5 @@ if [ $ret -eq 0 ];then
> 	echo "PASS: SOL_MPTCP getsockopt has expected information"
> fi
>
> +do_tcpinq_tests
> exit $ret
> -- 
> 2.32.0
>
>

--
Mat Martineau
Intel
Florian Westphal Nov. 16, 2021, 11:21 p.m. UTC | #2
Mat Martineau <mathew.j.martineau@linux.intel.com> wrote:
> > +static void connect_one_server(int fd, int unixfd)
> > +{
> > +	size_t len, i, total, sent;
> > +	char buf[4096], buf2[4096];
> > +	int queued, on = 1;
> 
> Still getting the 'unused variable' warnings on this line, but we can squash
> a fix for that (or wait to apply just this last patch?).

Sigh.  Matthieu, can you zap the line above (queued, on) after applying &
squash this?
diff mbox series

Patch

diff --git a/tools/testing/selftests/net/mptcp/.gitignore b/tools/testing/selftests/net/mptcp/.gitignore
index 7569d892967a..49daae73c41e 100644
--- a/tools/testing/selftests/net/mptcp/.gitignore
+++ b/tools/testing/selftests/net/mptcp/.gitignore
@@ -1,5 +1,6 @@ 
 # SPDX-License-Identifier: GPL-2.0-only
 mptcp_connect
+mptcp_inq
 mptcp_sockopt
 pm_nl_ctl
 *.pcap
diff --git a/tools/testing/selftests/net/mptcp/Makefile b/tools/testing/selftests/net/mptcp/Makefile
index bbf4e448bad9..0356c4501c99 100644
--- a/tools/testing/selftests/net/mptcp/Makefile
+++ b/tools/testing/selftests/net/mptcp/Makefile
@@ -8,7 +8,7 @@  CFLAGS =  -Wall -Wl,--no-as-needed -O2 -g  -I$(top_srcdir)/usr/include
 TEST_PROGS := mptcp_connect.sh pm_netlink.sh mptcp_join.sh diag.sh \
 	      simult_flows.sh mptcp_sockopt.sh
 
-TEST_GEN_FILES = mptcp_connect pm_nl_ctl mptcp_sockopt
+TEST_GEN_FILES = mptcp_connect pm_nl_ctl mptcp_sockopt mptcp_inq
 
 TEST_FILES := settings
 
diff --git a/tools/testing/selftests/net/mptcp/mptcp_inq.c b/tools/testing/selftests/net/mptcp/mptcp_inq.c
new file mode 100644
index 000000000000..62ee2dd63b5f
--- /dev/null
+++ b/tools/testing/selftests/net/mptcp/mptcp_inq.c
@@ -0,0 +1,603 @@ 
+// SPDX-License-Identifier: GPL-2.0
+
+#define _GNU_SOURCE
+
+#include <assert.h>
+#include <errno.h>
+#include <fcntl.h>
+#include <limits.h>
+#include <string.h>
+#include <stdarg.h>
+#include <stdbool.h>
+#include <stdint.h>
+#include <inttypes.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <strings.h>
+#include <unistd.h>
+#include <time.h>
+
+#include <sys/ioctl.h>
+#include <sys/ioctl.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+#include <sys/wait.h>
+
+#include <netdb.h>
+#include <netinet/in.h>
+
+#include <linux/tcp.h>
+#include <linux/sockios.h>
+
+#ifndef IPPROTO_MPTCP
+#define IPPROTO_MPTCP 262
+#endif
+#ifndef SOL_MPTCP
+#define SOL_MPTCP 284
+#endif
+
/*
 * Run-time configuration, set from the command line by parse_opts():
 *   pf       address family used by both endpoints (-6 selects AF_INET6)
 *   proto_tx protocol of the client's connecting socket (-t)
 *   proto_rx protocol of the server's listening socket (-r)
 */
static int pf = AF_INET;
static int proto_tx = IPPROTO_MPTCP, proto_rx = IPPROTO_MPTCP;
+
/*
 * Fatal error helper: print "msg: <description of errno>" to stderr and
 * terminate the process with exit status 1.
 */
static void die_perror(const char *msg)
{
	const int status = 1;

	perror(msg);
	exit(status);
}
+
/* Print the command line synopsis to stderr and exit with status r. */
static void die_usage(int r)
{
	fputs("Usage: mptcp_inq [-6] [ -t tcp|mptcp ] [ -r tcp|mptcp]\n", stderr);
	exit(r);
}
+
/*
 * Print a printf-style message followed by a newline to stderr and
 * exit(1).
 *
 * The format attribute makes the compiler (-Wformat, part of -Wall)
 * check every caller's arguments against its format string. noreturn
 * tells it that control never comes back, so callers need no dummy
 * fall-through handling.
 */
static void __attribute__((__format__(__printf__, 1, 2), __noreturn__))
xerror(const char *fmt, ...)
{
	va_list ap;

	va_start(ap, fmt);
	vfprintf(stderr, fmt, ap);
	va_end(ap);
	fputc('\n', stderr);
	exit(1);
}
+
+static const char *getxinfo_strerr(int err)
+{
+	if (err == EAI_SYSTEM)
+		return strerror(errno);
+
+	return gai_strerror(err);
+}
+
/*
 * getaddrinfo() wrapper: on failure print a diagnostic naming the
 * node:service pair and exit(1); on success *res holds the result list.
 */
static void xgetaddrinfo(const char *node, const char *service,
			 const struct addrinfo *hints,
			 struct addrinfo **res)
{
	const char *errstr;
	int err;

	err = getaddrinfo(node, service, hints, res);
	if (!err)
		return;

	/* read errno right away, before anything else can change it */
	errstr = getxinfo_strerr(err);
	fprintf(stderr, "Fatal: getaddrinfo(%s:%s): %s\n",
		node ? node : "", service ? service : "", errstr);
	exit(1);
}
+
/*
 * Create a listening socket on listenaddr:port.
 *
 * The socket protocol is proto_rx (TCP or MPTCP, chosen with -r). Each
 * address returned by getaddrinfo() is tried in turn; the first one that
 * binds wins. Exits on failure, otherwise returns the listening fd.
 */
static int sock_listen_mptcp(const char * const listenaddr,
			     const char * const port)
{
	struct addrinfo hints = {
		.ai_family = pf,
		.ai_protocol = IPPROTO_TCP,
		.ai_socktype = SOCK_STREAM,
		.ai_flags = AI_PASSIVE | AI_NUMERICHOST
	};
	struct addrinfo *a, *addr;
	/* start invalid so an empty result list is reported, not read garbage */
	int sock = -1;
	int one = 1;

	xgetaddrinfo(listenaddr, port, &hints, &addr);

	for (a = addr; a; a = a->ai_next) {
		sock = socket(a->ai_family, a->ai_socktype, proto_rx);
		if (sock < 0)
			continue;

		/* best effort: a failure is only reported */
		if (-1 == setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &one,
				     sizeof(one)))
			perror("setsockopt");

		if (bind(sock, a->ai_addr, a->ai_addrlen) == 0)
			break; /* success */

		perror("bind");
		close(sock);
		sock = -1;
	}

	freeaddrinfo(addr);

	if (sock < 0)
		xerror("could not create listen socket");

	if (listen(sock, 20))
		die_perror("listen");

	return sock;
}
+
/*
 * Connect a socket of protocol proto (TCP or MPTCP) to remoteaddr:port.
 *
 * Unlike the listener helper, this gives up on the first connect()
 * failure instead of trying further addresses. Returns the connected
 * socket; exits on failure.
 */
static int sock_connect_mptcp(const char * const remoteaddr,
			      const char * const port, int proto)
{
	struct addrinfo hints = {
		.ai_family = pf,
		.ai_socktype = SOCK_STREAM,
		.ai_protocol = IPPROTO_TCP,
	};
	struct addrinfo *res, *cur;
	int fd = -1;

	xgetaddrinfo(remoteaddr, port, &hints, &res);

	for (cur = res; cur != NULL; cur = cur->ai_next) {
		fd = socket(cur->ai_family, cur->ai_socktype, proto);
		if (fd < 0)
			continue;

		if (connect(fd, cur->ai_addr, cur->ai_addrlen) != 0)
			die_perror("connect");

		break; /* success */
	}

	if (fd < 0)
		xerror("could not create connect socket");

	freeaddrinfo(res);
	return fd;
}
+
/*
 * Translate a -t/-r argument into the matching IPPROTO_* value.
 * Accepts "tcp" or "mptcp" in any letter case; anything else prints
 * the usage text and exits with status 1.
 */
static int protostr_to_num(const char *s)
{
	static const struct {
		const char *name;
		int proto;
	} protos[] = {
		{ "tcp", IPPROTO_TCP },
		{ "mptcp", IPPROTO_MPTCP },
	};
	size_t k;

	for (k = 0; k < sizeof(protos) / sizeof(protos[0]); k++) {
		if (strcasecmp(s, protos[k].name) == 0)
			return protos[k].proto;
	}

	die_usage(1);
	return 0;
}
+
+static void parse_opts(int argc, char **argv)
+{
+	int c;
+
+	while ((c = getopt(argc, argv, "h6t:r:")) != -1) {
+		switch (c) {
+		case 'h':
+			die_usage(0);
+			break;
+		case '6':
+			pf = AF_INET6;
+			break;
+		case 't':
+			proto_tx = protostr_to_num(optarg);
+			break;
+		case 'r':
+			proto_rx = protostr_to_num(optarg);
+			break;
+		default:
+			die_usage(1);
+			break;
+		}
+	}
+}
+
/*
 * Wait until the peer has acknowledged all data queued on fd.
 *
 * Polls the send queue (TIOCOUTQ) roughly once per millisecond, at most
 * timeout times, and returns as soon as it is empty. Exits if more than
 * total bytes are ever reported as queued, or if data is still queued
 * when the polling budget runs out.
 */
static void wait_for_ack(int fd, int timeout, size_t total)
{
	const struct timespec one_ms = {
		.tv_sec = 0,
		.tv_nsec = 1 * 1000 * 1000ul,
	};
	int attempt;

	for (attempt = 0; attempt < timeout; attempt++) {
		int unsent, outq = -1;

		if (ioctl(fd, TIOCOUTQ, &outq) < 0)
			die_perror("TIOCOUTQ");

		if (ioctl(fd, SIOCOUTQNSD, &unsent) < 0)
			die_perror("SIOCOUTQNSD");

		if ((size_t)outq > total)
			xerror("TIOCOUTQ %u, but only %zu expected\n", outq, total);
		/* the not-sent count is expected never to exceed the whole queue */
		assert(unsent <= outq);

		if (outq == 0)
			return;

		nanosleep(&one_ms, NULL);
	}

	xerror("still tx data queued after %u ms\n", timeout);
}
+
/*
 * Client side of the test, driven step by step by the server through
 * commands read from unixfd.
 *
 *  "xmit": announce a length of 128..4094 bytes over unixfd, then send
 *          that many random capital letters on fd.
 *  "huge": announce and send between 1 MiB and 17 MiB of data.
 *  "shut": wait until everything is acked, send one trailing byte,
 *          close fd and report "closed" over unixfd.
 *
 * Any protocol or I/O error aborts the process.
 */
static void connect_one_server(int fd, int unixfd)
{
	size_t len, i, total, sent;
	char buf[4096], buf2[4096];
	ssize_t ret;

	len = rand() % (sizeof(buf) - 1);

	if (len < 128)
		len = 128;

	for (i = 0; i < len ; i++) {
		buf[i] = rand() % 26;
		buf[i] += 'A';
	}

	buf[i] = '\n';

	/* un-block server */
	ret = read(unixfd, buf2, 4);
	assert(ret == 4);

	assert(strncmp(buf2, "xmit", 4) == 0);

	ret = write(unixfd, &len, sizeof(len));
	assert(ret == (ssize_t)sizeof(len));

	ret = write(fd, buf, len);
	if (ret < 0)
		die_perror("write");

	if (ret != (ssize_t)len)
		xerror("short write");

	/* check the byte count here too, like every other control read */
	ret = read(unixfd, buf2, 4);
	assert(ret == 4);
	assert(strncmp(buf2, "huge", 4) == 0);

	total = rand() % (16 * 1024 * 1024);
	total += (1 * 1024 * 1024);
	sent = total;

	ret = write(unixfd, &total, sizeof(total));
	assert(ret == (ssize_t)sizeof(total));

	/* the first message must be fully acked before the bulk transfer */
	wait_for_ack(fd, 5000, len);

	while (total > 0) {
		if (total > sizeof(buf))
			len = sizeof(buf);
		else
			len = total;

		ret = write(fd, buf, len);
		if (ret < 0)
			die_perror("write");
		total -= ret;

		/* we don't have to care about buf content, only
		 * number of total bytes sent
		 */
	}

	ret = read(unixfd, buf2, 4);
	assert(ret == 4);
	assert(strncmp(buf2, "shut", 4) == 0);

	wait_for_ack(fd, 5000, sent);

	/*
	 * The server reads exactly one more data byte after "closed",
	 * so a failed or short write here must not go unnoticed.
	 */
	ret = write(fd, buf, 1);
	if (ret < 0)
		die_perror("write");
	if (ret != 1)
		xerror("short write");

	close(fd);
	ret = write(unixfd, "closed", 6);
	assert(ret == 6);

	close(unixfd);
}
+
/*
 * Extract the TCP_CM_INQ control message from msgh into *inqv.
 *
 * This is the number of bytes still queued for reading, delivered by
 * recvmsg() once the TCP_INQ socket option is enabled. Exits if msgh
 * carries no such message.
 */
static void get_tcp_inq(struct msghdr *msgh, unsigned int *inqv)
{
	struct cmsghdr *c = CMSG_FIRSTHDR(msgh);

	while (c != NULL) {
		if (c->cmsg_level == IPPROTO_TCP && c->cmsg_type == TCP_CM_INQ) {
			memcpy(inqv, CMSG_DATA(c), sizeof(*inqv));
			return;
		}
		c = CMSG_NXTHDR(msgh, c);
	}

	xerror("could not find TCP_CM_INQ cmsg type");
}
+
/*
 * Helper for process_one_client(): receive at most len bytes on fd and
 * store the TCP_CM_INQ value that came with them in *tcp_inq.
 *
 * msg_controllen is restored to ctrl_size before every call because
 * recvmsg() overwrites it with the length it actually filled in. Exits
 * on any error. Returns the number of bytes read (0 on EOF).
 */
static ssize_t xrecvmsg_inq(int fd, struct msghdr *msg, size_t len,
			    size_t ctrl_size, unsigned int *tcp_inq)
{
	ssize_t ret;

	msg->msg_iov->iov_len = len;
	msg->msg_controllen = ctrl_size;

	ret = recvmsg(fd, msg, 0);
	if (ret < 0)
		die_perror("recvmsg");

	if (msg->msg_controllen == 0)
		xerror("msg_controllen is 0");

	get_tcp_inq(msg, tcp_inq);
	return ret;
}

/*
 * Helper for process_one_client(): poll FIONREAD about once per
 * millisecond until exactly expect_len bytes are readable on fd. Exits
 * if more than that is ever reported. There is no timeout here; the
 * wait is bounded only by the alarm() armed in server().
 */
static void wait_for_rx_queue(int fd, size_t expect_len)
{
	for (;;) {
		struct timespec req;
		unsigned int queued;

		if (ioctl(fd, FIONREAD, &queued) < 0)
			die_perror("FIONREAD");
		if (queued > expect_len)
			xerror("FIONREAD returned %u, but only %zu expected\n",
			       queued, expect_len);
		if (queued == expect_len)
			return;

		req.tv_sec = 0;
		req.tv_nsec = 1000 * 1000ul;
		nanosleep(&req, NULL);
	}
}

/*
 * Server side of the test.
 *
 * Sends commands to the client over unixfd and checks the TCP_CM_INQ
 * values reported by recvmsg() against the byte counts the client
 * announces:
 *  "xmit": a short message. After reading 1 byte, INQ must be len - 1;
 *          after reading the rest, INQ must be 0.
 *  "huge": a bulk transfer. INQ must never exceed the bytes still
 *          outstanding.
 *  "shut": once the client has closed, reading its last byte must
 *          report INQ 1, and the following EOF read must report INQ 1
 *          as well.
 */
static void process_one_client(int fd, int unixfd)
{
	unsigned int tcp_inq, remaining;
	size_t expect_len, tot;
	char msg_buf[4096];
	char buf[4096];
	char tmp[16];
	struct iovec iov = {
		.iov_base = buf,
		.iov_len = 1,
	};
	struct msghdr msg = {
		.msg_iov = &iov,
		.msg_iovlen = 1,
		.msg_control = msg_buf,
		.msg_controllen = sizeof(msg_buf),
	};
	ssize_t ret;

	ret = write(unixfd, "xmit", 4);
	assert(ret == 4);

	ret = read(unixfd, &expect_len, sizeof(expect_len));
	assert(ret == (ssize_t)sizeof(expect_len));

	if (expect_len > sizeof(buf))
		xerror("expect len %zu exceeds buffer size", expect_len);
	/* expect_len - 1 below must not wrap */
	if (expect_len == 0)
		xerror("expect len is 0");

	wait_for_rx_queue(fd, expect_len);

	/* read one byte, expect cmsg to return expected - 1 */
	ret = xrecvmsg_inq(fd, &msg, 1, sizeof(msg_buf), &tcp_inq);
	assert(ret == 1);
	assert((size_t)tcp_inq == (expect_len - 1));

	/* should have gotten exact remainder of all pending data */
	remaining = tcp_inq;
	ret = xrecvmsg_inq(fd, &msg, sizeof(buf), sizeof(msg_buf), &tcp_inq);
	assert(ret == (ssize_t)remaining);

	/* should be 0, all drained */
	assert(tcp_inq == 0);

	/* request a large swath of data. */
	ret = write(unixfd, "huge", 4);
	assert(ret == 4);

	ret = read(unixfd, &expect_len, sizeof(expect_len));
	assert(ret == (ssize_t)sizeof(expect_len));

	/* peer should send us a few mb of data */
	if (expect_len <= sizeof(buf))
		xerror("expect len %zu too small\n", expect_len);

	tot = 0;
	do {
		ret = xrecvmsg_inq(fd, &msg, sizeof(buf), sizeof(msg_buf),
				   &tcp_inq);
		/* without this, a premature EOF would spin until the alarm */
		if (ret == 0)
			xerror("unexpected EOF after %zu of %zu bytes",
			       tot, expect_len);

		tot += (size_t)ret;
		/* guard the unsigned subtraction below against wrap-around */
		if (tot > expect_len)
			xerror("received %zu bytes, but only %zu expected",
			       tot, expect_len);

		if (tcp_inq > expect_len - tot)
			xerror("inq %u, remaining %zu total_len %zu\n",
			       tcp_inq, expect_len - tot, expect_len);
	} while (tot < expect_len);

	ret = write(unixfd, "shut", 4);
	assert(ret == 4);

	/* wait for hangup. Should have received one more byte of data. */
	ret = read(unixfd, tmp, sizeof(tmp));
	assert(ret == 6);
	assert(strncmp(tmp, "closed", 6) == 0);

	sleep(1);

	ret = xrecvmsg_inq(fd, &msg, 1, sizeof(msg_buf), &tcp_inq);
	assert(ret == 1);

	/* tcp_inq should be 1 due to received fin. */
	assert(tcp_inq == 1);

	ret = xrecvmsg_inq(fd, &msg, 1, sizeof(msg_buf), &tcp_inq);

	/* expect EOF */
	assert(ret == 0);
	assert(tcp_inq == 1);

	close(fd);
}
+
/*
 * Accept one connection on listening socket s, ignoring the peer
 * address. Exits on failure; otherwise returns the new connection's fd.
 */
static int xaccept(int s)
{
	int fd;

	fd = accept(s, NULL, NULL);
	if (fd < 0)
		die_perror("accept");

	return fd;
}
+
+static int server(int unixfd)
+{
+	int fd = -1, r, on = 1;
+
+	switch (pf) {
+	case AF_INET:
+		fd = sock_listen_mptcp("127.0.0.1", "15432");
+		break;
+	case AF_INET6:
+		fd = sock_listen_mptcp("::1", "15432");
+		break;
+	default:
+		xerror("Unknown pf %d\n", pf);
+		break;
+	}
+
+	r = write(unixfd, "conn", 4);
+	assert(r == 4);
+
+	alarm(15);
+	r = xaccept(fd);
+
+	if (-1 == setsockopt(r, IPPROTO_TCP, TCP_INQ, &on, sizeof(on)))
+		die_perror("setsockopt");
+
+	process_one_client(r, unixfd);
+
+	return 0;
+}
+
+static int client(int unixfd)
+{
+	int fd = -1;
+
+	alarm(15);
+
+	switch (pf) {
+	case AF_INET:
+		fd = sock_connect_mptcp("127.0.0.1", "15432", proto_tx);
+		break;
+	case AF_INET6:
+		fd = sock_connect_mptcp("::1", "15432", proto_tx);
+		break;
+	default:
+		xerror("Unknown pf %d\n", pf);
+	}
+
+	connect_one_server(fd, unixfd);
+
+	return 0;
+}
+
/*
 * Seed rand() for the calling process.
 *
 * The seed is read from /dev/urandom. If that cannot be opened or does
 * not yield a full value, a time ^ pid seed is used instead, so srand()
 * is never fed an uninitialized variable. The pid keeps processes forked
 * in the same second from sharing a seed.
 */
static void init_rng(void)
{
	unsigned int seed = (unsigned int)time(NULL) ^ (unsigned int)getpid();
	int fd = open("/dev/urandom", O_RDONLY);

	/* 0 is a valid descriptor; only negative values mean failure */
	if (fd >= 0) {
		unsigned int rnd;

		if (read(fd, &rnd, sizeof(rnd)) == (ssize_t)sizeof(rnd))
			seed = rnd;
		close(fd);
	}

	srand(seed);
}
+
/*
 * fork() that exits if the fork fails.
 *
 * The child reseeds rand() via init_rng() before returning 0. Presumably
 * this is so that sibling processes do not share one random sequence.
 * The parent gets the child's pid.
 */
static pid_t xfork(void)
{
	pid_t pid = fork();

	if (pid == 0) {
		init_rng();
		return 0;
	}
	if (pid < 0)
		die_perror("fork");

	return pid;
}
+
/*
 * Convert a waitpid() status for process `what` into an exit code.
 *
 * A normal exit yields the child's exit status; a non-zero status is
 * also reported on stderr. Death or stop by a signal is fatal for the
 * caller. Any other status yields 111.
 */
static int rcheck(int wstatus, const char *what)
{
	if (WIFEXITED(wstatus)) {
		int code = WEXITSTATUS(wstatus);

		if (code != 0)
			fprintf(stderr, "%s exited, status=%d\n", what, code);
		return code;
	}

	if (WIFSIGNALED(wstatus))
		xerror("%s killed by signal %d\n", what, WTERMSIG(wstatus));
	else if (WIFSTOPPED(wstatus))
		xerror("%s stopped by signal %d\n", what, WSTOPSIG(wstatus));

	return 111;
}
+
/*
 * Run the server and client processes for one test configuration.
 *
 * The two children coordinate over an AF_UNIX datagram socketpair. The
 * parent waits for the server's "conn" token before forking the client,
 * so the listener exists before connect(). Returns the server's exit
 * code if non-zero, otherwise the client's.
 */
int main(int argc, char *argv[])
{
	int e1, e2, wstatus;
	pid_t s, c, ret;
	int unixfds[2];
	char token[4];
	ssize_t n;

	parse_opts(argc, argv);

	if (socketpair(AF_UNIX, SOCK_DGRAM, 0, unixfds) < 0)
		die_perror("socketpair");

	s = xfork();
	if (s == 0)
		return server(unixfds[1]);

	close(unixfds[1]);

	/* wait until server bound a socket */
	n = read(unixfds[0], token, sizeof(token));
	assert(n == (ssize_t)sizeof(token));
	assert(memcmp(token, "conn", sizeof(token)) == 0);

	c = xfork();
	if (c == 0)
		return client(unixfds[0]);

	close(unixfds[0]);

	ret = waitpid(s, &wstatus, 0);
	if (ret == -1)
		die_perror("waitpid");
	e1 = rcheck(wstatus, "server");
	ret = waitpid(c, &wstatus, 0);
	if (ret == -1)
		die_perror("waitpid");
	e2 = rcheck(wstatus, "client");

	return e1 ? e1 : e2;
}
diff --git a/tools/testing/selftests/net/mptcp/mptcp_sockopt.sh b/tools/testing/selftests/net/mptcp/mptcp_sockopt.sh
index c8c364369599..0879da915014 100755
--- a/tools/testing/selftests/net/mptcp/mptcp_sockopt.sh
+++ b/tools/testing/selftests/net/mptcp/mptcp_sockopt.sh
@@ -279,6 +279,45 @@  run_tests()
 	fi
 }
 
# Run one mptcp_inq invocation inside "$ns1" with the given arguments.
# On failure, record the status in the script-wide "ret" and print FAIL.
# Returns mptcp_inq's exit status.
do_tcpinq_test()
{
	# local: don't leak (or clobber a caller's) lret
	local lret

	ip netns exec "$ns1" ./mptcp_inq "$@"
	lret=$?
	if [ $lret -ne 0 ];then
		ret=$lret
		# "$*" joins the arguments into one string (SC2145)
		echo "FAIL: mptcp_inq $*" 1>&2
		return $lret
	fi

	echo "PASS: TCP_INQ cmsg/ioctl $*"
	return $lret
}
+
# Flush the firewall rules in "$ns1", then run mptcp_inq with TCP on
# either side (IPv4 first, then IPv6) and finally with TCP on both
# sides. Stops at, and returns, the first failing status.
do_tcpinq_tests()
{
	# kept local so do_tcpinq_test's plain "lret=" does not become global
	local lret=0
	local args family

	ip netns exec "$ns1" iptables -F
	ip netns exec "$ns1" ip6tables -F

	for args in "-t tcp" "-r tcp"; do
		for family in "" "-6"; do
			# unquoted on purpose: an empty $family must vanish and
			# $args must split into option and value
			do_tcpinq_test $family $args || return
		done
	done

	do_tcpinq_test -r tcp -t tcp
}
+
 sin=$(mktemp)
 sout=$(mktemp)
 cin=$(mktemp)
@@ -300,4 +339,5 @@  if [ $ret -eq 0 ];then
 	echo "PASS: SOL_MPTCP getsockopt has expected information"
 fi
 
+do_tcpinq_tests
 exit $ret