@@ -29,10 +29,12 @@ TCP_AUTHOPT_KEY = 39
TCP_AUTHOPT_MAXKEYLEN = 80
class TCP_AUTHOPT_FLAG(IntFlag):
+ LOCK_KEYID = BIT(0)
+ LOCK_RNEXTKEYID = BIT(1)
REJECT_UNEXPECTED = BIT(2)
class TCP_AUTHOPT_KEY_FLAG(IntFlag):
DEL = BIT(0)
@@ -48,24 +50,32 @@ class TCP_AUTHOPT_ALG(IntEnum):
@dataclass
class tcp_authopt:
"""Like linux struct tcp_authopt"""
flags: int = 0
- sizeof = 4
+ send_keyid: int = 0
+ send_rnextkeyid: int = 0
+ recv_keyid: int = 0
+ recv_rnextkeyid: int = 0
+ sizeof = 8
def pack(self) -> bytes:
return struct.pack(
- "I",
+ "IBBBB",
self.flags,
+ self.send_keyid,
+ self.send_rnextkeyid,
+ self.recv_keyid,
+ self.recv_rnextkeyid,
)
def __bytes__(self):
return self.pack()
@classmethod
def unpack(cls, b: bytes):
- tup = struct.unpack("I", b)
+ tup = struct.unpack("IBBBB", b)
return cls(*tup)
def set_tcp_authopt(sock, opt: tcp_authopt):
return sock.setsockopt(socket.SOL_TCP, TCP_AUTHOPT, bytes(opt))
new file mode 100644
@@ -0,0 +1,181 @@
+# SPDX-License-Identifier: GPL-2.0
+import socket
+import typing
+from contextlib import ExitStack, contextmanager
+
+from .conftest import skipif_missing_tcp_authopt
+from .linux_tcp_authopt import (
+ TCP_AUTHOPT_FLAG,
+ get_tcp_authopt,
+ set_tcp_authopt,
+ set_tcp_authopt_key,
+ tcp_authopt,
+ tcp_authopt_key,
+)
+from .server import SimpleServerThread
+from .utils import DEFAULT_TCP_SERVER_PORT, check_socket_echo, create_listen_socket
+
+pytestmark = skipif_missing_tcp_authopt
+
+
+@contextmanager
+def make_tcp_authopt_socket_pair(
+ server_addr="127.0.0.1",
+ server_authopt: tcp_authopt = None,
+ server_key_list: typing.Iterable[tcp_authopt_key] = [],
+ client_authopt: tcp_authopt = None,
+ client_key_list: typing.Iterable[tcp_authopt_key] = [],
+) -> typing.Iterator[typing.Tuple[socket.socket, socket.socket]]:
+ """Make a pair for connected sockets for key switching tests
+
+ Server runs in a background thread implementing echo protocol"""
+ with ExitStack() as exit_stack:
+ listen_socket = exit_stack.enter_context(
+ create_listen_socket(bind_addr=server_addr)
+ )
+ server_thread = exit_stack.enter_context(
+ SimpleServerThread(listen_socket, mode="echo")
+ )
+ client_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ client_socket.settimeout(1.0)
+
+ if server_authopt:
+ set_tcp_authopt(listen_socket, server_authopt)
+ for k in server_key_list:
+ set_tcp_authopt_key(listen_socket, k)
+ if client_authopt:
+ set_tcp_authopt(client_socket, client_authopt)
+ for k in client_key_list:
+ set_tcp_authopt_key(client_socket, k)
+
+ client_socket.connect((server_addr, DEFAULT_TCP_SERVER_PORT))
+ check_socket_echo(client_socket)
+ server_socket = server_thread.server_socket[0]
+
+ yield client_socket, server_socket
+
+
+def test_get_keyids(exit_stack: ExitStack):
+ """Check reading key ids"""
+ sk1 = tcp_authopt_key(send_id=11, recv_id=12, key="111")
+ sk2 = tcp_authopt_key(send_id=21, recv_id=22, key="222")
+ ck1 = tcp_authopt_key(send_id=12, recv_id=11, key="111")
+ client_socket, server_socket = exit_stack.enter_context(
+ make_tcp_authopt_socket_pair(
+ server_key_list=[sk1, sk2],
+ client_key_list=[ck1],
+ )
+ )
+
+ check_socket_echo(client_socket)
+ client_tcp_authopt = get_tcp_authopt(client_socket)
+ server_tcp_authopt = get_tcp_authopt(server_socket)
+ assert server_tcp_authopt.send_keyid == 11
+ assert server_tcp_authopt.send_rnextkeyid == 12
+ assert server_tcp_authopt.recv_keyid == 12
+ assert server_tcp_authopt.recv_rnextkeyid == 11
+ assert client_tcp_authopt.send_keyid == 12
+ assert client_tcp_authopt.send_rnextkeyid == 11
+ assert client_tcp_authopt.recv_keyid == 11
+ assert client_tcp_authopt.recv_rnextkeyid == 12
+
+
+def test_rollover_send_keyid(exit_stack: ExitStack):
+ """Check reading key ids"""
+ sk1 = tcp_authopt_key(send_id=11, recv_id=12, key="111")
+ sk2 = tcp_authopt_key(send_id=21, recv_id=22, key="222")
+ ck1 = tcp_authopt_key(send_id=12, recv_id=11, key="111")
+ ck2 = tcp_authopt_key(send_id=22, recv_id=21, key="222")
+ client_socket, server_socket = exit_stack.enter_context(
+ make_tcp_authopt_socket_pair(
+ server_key_list=[sk1, sk2],
+ client_key_list=[ck1, ck2],
+ client_authopt=tcp_authopt(
+ send_keyid=12, flags=TCP_AUTHOPT_FLAG.LOCK_KEYID
+ ),
+ )
+ )
+
+ check_socket_echo(client_socket)
+ assert get_tcp_authopt(client_socket).recv_keyid == 11
+ assert get_tcp_authopt(server_socket).recv_keyid == 12
+
+ # Explicit request for key2
+ set_tcp_authopt(
+ client_socket, tcp_authopt(send_keyid=22, flags=TCP_AUTHOPT_FLAG.LOCK_KEYID)
+ )
+ check_socket_echo(client_socket)
+ assert get_tcp_authopt(client_socket).recv_keyid == 21
+ assert get_tcp_authopt(server_socket).recv_keyid == 22
+
+
+def test_rollover_rnextkeyid(exit_stack: ExitStack):
+ """Check reading key ids"""
+ sk1 = tcp_authopt_key(send_id=11, recv_id=12, key="111")
+ sk2 = tcp_authopt_key(send_id=21, recv_id=22, key="222")
+ ck1 = tcp_authopt_key(send_id=12, recv_id=11, key="111")
+ ck2 = tcp_authopt_key(send_id=22, recv_id=21, key="222")
+ client_socket, server_socket = exit_stack.enter_context(
+ make_tcp_authopt_socket_pair(
+ server_key_list=[sk1],
+ client_key_list=[ck1, ck2],
+ client_authopt=tcp_authopt(
+ send_keyid=12, flags=TCP_AUTHOPT_FLAG.LOCK_KEYID
+ ),
+ )
+ )
+
+ check_socket_echo(client_socket)
+ assert get_tcp_authopt(server_socket).recv_rnextkeyid == 11
+
+ # request rnextkeyd=22 but server does not have it
+ set_tcp_authopt(
+ client_socket,
+ tcp_authopt(send_rnextkeyid=21, flags=TCP_AUTHOPT_FLAG.LOCK_RNEXTKEYID),
+ )
+ check_socket_echo(client_socket)
+ check_socket_echo(client_socket)
+ assert get_tcp_authopt(server_socket).recv_rnextkeyid == 21
+ assert get_tcp_authopt(server_socket).send_keyid == 11
+
+ # after adding k2 on server the key is switched
+ set_tcp_authopt_key(server_socket, sk2)
+ check_socket_echo(client_socket)
+ check_socket_echo(client_socket)
+ assert get_tcp_authopt(server_socket).send_keyid == 21
+
+
+def test_rollover_delkey(exit_stack: ExitStack):
+ sk1 = tcp_authopt_key(send_id=11, recv_id=12, key="111")
+ sk2 = tcp_authopt_key(send_id=21, recv_id=22, key="222")
+ ck1 = tcp_authopt_key(send_id=12, recv_id=11, key="111")
+ ck2 = tcp_authopt_key(send_id=22, recv_id=21, key="222")
+ client_socket, server_socket = exit_stack.enter_context(
+ make_tcp_authopt_socket_pair(
+ server_key_list=[sk1, sk2],
+ client_key_list=[ck1, ck2],
+ client_authopt=tcp_authopt(
+ send_keyid=12, flags=TCP_AUTHOPT_FLAG.LOCK_KEYID
+ ),
+ )
+ )
+
+ check_socket_echo(client_socket)
+ assert get_tcp_authopt(server_socket).recv_keyid == 12
+
+ # invalid send_keyid is just ignored
+ set_tcp_authopt(client_socket, tcp_authopt(send_keyid=7))
+ check_socket_echo(client_socket)
+ assert get_tcp_authopt(client_socket).send_keyid == 12
+ assert get_tcp_authopt(server_socket).recv_keyid == 12
+ assert get_tcp_authopt(client_socket).recv_keyid == 11
+
+ # If a key is removed it is replaced by anything that matches
+ ck1.delete_flag = True
+ set_tcp_authopt_key(client_socket, ck1)
+ check_socket_echo(client_socket)
+ check_socket_echo(client_socket)
+ assert get_tcp_authopt(client_socket).send_keyid == 22
+ assert get_tcp_authopt(server_socket).send_keyid == 21
+ assert get_tcp_authopt(server_socket).recv_keyid == 22
+ assert get_tcp_authopt(client_socket).recv_keyid == 21
RFC5925 requires that the use can examine or control the keys being used. This is implemented in linux via fields on the TCP_AUTHOPT sockopt. Add socket-level tests for the adjusting keyids on live connections and checking the they are reflected on the peer. Also check smooth transitions via rnextkeyid. Signed-off-by: Leonard Crestez <cdleonard@gmail.com> --- .../tcp_authopt_test/linux_tcp_authopt.py | 16 +- .../tcp_authopt_test/test_rollover.py | 181 ++++++++++++++++++ 2 files changed, 194 insertions(+), 3 deletions(-) create mode 100644 tools/testing/selftests/tcp_authopt/tcp_authopt_test/test_rollover.py