@@ -57,6 +57,7 @@ struct sock_xprt {
struct work_struct error_worker;
struct work_struct recv_worker;
struct mutex recv_mutex;
+ struct completion handshake_done;
struct sockaddr_storage srcaddr;
unsigned short srcport;
int xprt_err;
@@ -48,6 +48,7 @@
#include <net/udp.h>
#include <net/tcp.h>
#include <net/tls.h>
+#include <net/handshake.h>
#include <linux/bvec.h>
#include <linux/highmem.h>
@@ -189,6 +190,11 @@ static struct ctl_table xs_tunables_table[] = {
*/
#define XS_IDLE_DISC_TO (5U * 60 * HZ)
+/*
+ * TLS handshake timeout.
+ */
+#define XS_TLS_HANDSHAKE_TO (10U * HZ)
+
#if IS_ENABLED(CONFIG_SUNRPC_DEBUG)
# undef RPC_DEBUG_DATA
# define RPCDBG_FACILITY RPCDBG_TRANS
@@ -1238,6 +1244,10 @@ static void xs_reset_transport(struct sock_xprt *transport)
if (atomic_read(&transport->xprt.swapper))
sk_clear_memalloc(sk);
+ /* XXX: Maybe also send a TLS Closure alert? */
+
+ tls_handshake_cancel(sk);
+
kernel_sock_shutdown(sock, SHUT_RDWR);
mutex_lock(&transport->recv_mutex);
@@ -2411,60 +2421,266 @@ static void xs_tcp_setup_socket(struct work_struct *work)
current_restore_flags(pflags, PF_MEMALLOC);
}
+/*
+ * Transfer the connected socket to @upper_transport, then mark that
+ * xprt CONNECTED.
+ */
+static int xs_tls_finish_connecting(struct rpc_xprt *lower_xprt,
+ struct sock_xprt *upper_transport)
+{
+ struct sock_xprt *lower_transport =
+ container_of(lower_xprt, struct sock_xprt, xprt);
+ struct rpc_xprt *upper_xprt = &upper_transport->xprt;
+
+ if (!upper_transport->inet) {
+ struct socket *sock = lower_transport->sock;
+ struct sock *sk = sock->sk;
+
+ /* Avoid temporary address, they are bad for long-lived
+ * connections such as NFS mounts.
+ * RFC4941, section 3.6 suggests that:
+ * Individual applications, which have specific
+ * knowledge about the normal duration of connections,
+ * MAY override this as appropriate.
+ */
+ if (xs_addr(upper_xprt)->sa_family == PF_INET6) {
+ ip6_sock_set_addr_preferences(sk,
+ IPV6_PREFER_SRC_PUBLIC);
+ }
+
+ xs_tcp_set_socket_timeouts(upper_xprt, sock);
+ tcp_sock_set_nodelay(sk);
+
+ lock_sock(sk);
+
+ /*
+ * @sk is already connected, so it now has the RPC callbacks.
+ * Reach into @lower_transport to save the original ones.
+ */
+ upper_transport->old_data_ready = lower_transport->old_data_ready;
+ upper_transport->old_state_change = lower_transport->old_state_change;
+ upper_transport->old_write_space = lower_transport->old_write_space;
+ upper_transport->old_error_report = lower_transport->old_error_report;
+ sk->sk_user_data = upper_xprt;
+
+ /* socket options */
+ sock_reset_flag(sk, SOCK_LINGER);
+
+ xprt_clear_connected(upper_xprt);
+
+ upper_transport->sock = sock;
+ upper_transport->inet = sk;
+ upper_transport->file = lower_transport->file;
+
+ release_sock(sk);
+
+ /* Reset lower_transport before shutting down its clnt */
+ mutex_lock(&lower_transport->recv_mutex);
+ lower_transport->inet = NULL;
+ lower_transport->sock = NULL;
+ lower_transport->file = NULL;
+
+ xprt_clear_connected(lower_xprt);
+ xs_sock_reset_connection_flags(lower_xprt);
+ xs_stream_reset_connect(lower_transport);
+ mutex_unlock(&lower_transport->recv_mutex);
+ }
+
+ if (!xprt_bound(upper_xprt))
+ return -ENOTCONN;
+
+ xs_set_memalloc(upper_xprt);
+
+ if (!xprt_test_and_set_connected(upper_xprt)) {
+ upper_xprt->connect_cookie++;
+ clear_bit(XPRT_SOCK_CONNECTING, &upper_transport->sock_state);
+ xprt_clear_connecting(upper_xprt);
+
+ upper_xprt->stat.connect_count++;
+ upper_xprt->stat.connect_time += (long)jiffies -
+ upper_xprt->stat.connect_start;
+ xs_run_error_worker(upper_transport, XPRT_SOCK_WAKE_PENDING);
+ }
+ return 0;
+}
+
/**
- * xs_tls_connect - establish a TLS session on a socket
- * @work: queued work item
+ * xs_tls_handshake_done - TLS handshake completion handler
+ * @data: address of xprt to wake
+ * @status: status of handshake
+ * @peerid: serial number of key containing the remote's identity
*
*/
-static void xs_tls_connect(struct work_struct *work)
+static void xs_tls_handshake_done(void *data, int status, key_serial_t peerid)
{
- struct sock_xprt *transport =
- container_of(work, struct sock_xprt, connect_worker.work);
- struct rpc_clnt *clnt;
+ struct rpc_xprt *lower_xprt = data;
+ struct sock_xprt *lower_transport =
+ container_of(lower_xprt, struct sock_xprt, xprt);
- clnt = transport->clnt;
- transport->clnt = NULL;
- if (IS_ERR(clnt))
- goto out_unlock;
+ lower_transport->xprt_err = status ? -EACCES : 0;
+ complete(&lower_transport->handshake_done);
+ xprt_put(lower_xprt);
+}
- xs_tcp_setup_socket(work);
+static int xs_tls_handshake_sync(struct rpc_xprt *lower_xprt, struct xprtsec_parms *xprtsec)
+{
+ struct sock_xprt *lower_transport =
+ container_of(lower_xprt, struct sock_xprt, xprt);
+ struct tls_handshake_args args = {
+ .ta_sock = lower_transport->sock,
+ .ta_done = xs_tls_handshake_done,
+ .ta_data = xprt_get(lower_xprt),
+ .ta_peername = lower_xprt->servername,
+ };
+ struct sock *sk = lower_transport->inet;
+ int rc;
- rpc_shutdown_client(clnt);
+ init_completion(&lower_transport->handshake_done);
+ set_bit(XPRT_SOCK_IGNORE_RECV, &lower_transport->sock_state);
-out_unlock:
- return;
+ lower_transport->xprt_err = -ETIMEDOUT;
+ switch (xprtsec->policy) {
+ case RPC_XPRTSEC_TLS_ANON:
+ rc = tls_client_hello_anon(&args, GFP_KERNEL);
+ if (rc)
+ goto out_put_xprt;
+ break;
+ case RPC_XPRTSEC_TLS_X509:
+ args.ta_my_cert = xprtsec->cert_serial;
+ args.ta_my_privkey = xprtsec->privkey_serial;
+ rc = tls_client_hello_x509(&args, GFP_KERNEL);
+ if (rc)
+ goto out_put_xprt;
+ break;
+ default:
+ rc = -EACCES;
+ goto out_put_xprt;
+ }
+
+ rc = wait_for_completion_interruptible_timeout(&lower_transport->handshake_done,
+ XS_TLS_HANDSHAKE_TO);
+ if (rc <= 0) {
+ if (!tls_handshake_cancel(sk)) {
+ if (rc == 0)
+ rc = -ETIMEDOUT;
+ goto out_put_xprt;
+ }
+ }
+
+ rc = lower_transport->xprt_err;
+
+out:
+ xs_stream_reset_connect(lower_transport);
+ clear_bit(XPRT_SOCK_IGNORE_RECV, &lower_transport->sock_state);
+ return rc;
+
+out_put_xprt:
+ xprt_put(lower_xprt);
+ goto out;
}
-static void xs_set_transport_clnt(struct rpc_clnt *clnt, struct rpc_xprt *xprt)
+/**
+ * xs_tls_connect - establish a TLS session on a socket
+ * @work: queued work item
+ *
+ * For RPC-with-TLS, there is a two-stage connection process.
+ *
+ * The "upper-layer xprt" is visible to the RPC consumer. Once it has
+ * been marked connected, the consumer knows that a TCP connection and
+ * a TLS session have been established.
+ *
+ * A "lower-layer xprt", created in this function, handles the mechanics
+ * of connecting the TCP socket, performing the RPC_AUTH_TLS probe, and
+ * then driving the TLS handshake. Once all that is complete, the upper
+ * layer xprt is marked connected.
+ */
+static void xs_tls_connect(struct work_struct *work)
{
- struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt);
+ struct sock_xprt *upper_transport =
+ container_of(work, struct sock_xprt, connect_worker.work);
+ struct rpc_clnt *upper_clnt = upper_transport->clnt;
+ struct rpc_xprt *upper_xprt = &upper_transport->xprt;
struct rpc_create_args args = {
- .net = xprt->xprt_net,
- .protocol = xprt->prot,
- .address = (struct sockaddr *)&xprt->addr,
- .addrsize = xprt->addrlen,
- .timeout = clnt->cl_timeout,
- .servername = xprt->servername,
- .nodename = clnt->cl_nodename,
- .program = clnt->cl_program,
- .prognumber = clnt->cl_prog,
- .version = clnt->cl_vers,
+ .net = upper_xprt->xprt_net,
+ .protocol = upper_xprt->prot,
+ .address = (struct sockaddr *)&upper_xprt->addr,
+ .addrsize = upper_xprt->addrlen,
+ .timeout = upper_clnt->cl_timeout,
+ .servername = upper_xprt->servername,
+ .nodename = upper_clnt->cl_nodename,
+ .program = upper_clnt->cl_program,
+ .prognumber = upper_clnt->cl_prog,
+ .version = upper_clnt->cl_vers,
.authflavor = RPC_AUTH_TLS,
- .cred = clnt->cl_cred,
+ .cred = upper_clnt->cl_cred,
.xprtsec = {
.policy = RPC_XPRTSEC_NONE,
},
- .flags = RPC_CLNT_CREATE_NOPING,
};
+ unsigned int pflags = current->flags;
+ struct rpc_clnt *lower_clnt;
+ struct rpc_xprt *lower_xprt;
+ int status;
- switch (xprt->xprtsec.policy) {
- case RPC_XPRTSEC_TLS_ANON:
- case RPC_XPRTSEC_TLS_X509:
- transport->clnt = rpc_create(&args);
- break;
- default:
- transport->clnt = ERR_PTR(-ENOTCONN);
+ if (atomic_read(&upper_xprt->swapper))
+ current->flags |= PF_MEMALLOC;
+
+ xs_stream_start_connect(upper_transport);
+
+ /* This implicitly sends an RPC_AUTH_TLS probe */
+ lower_clnt = rpc_create(&args);
+ if (IS_ERR(lower_clnt)) {
+ clear_bit(XPRT_SOCK_CONNECTING, &upper_transport->sock_state);
+ xprt_clear_connecting(upper_xprt);
+ xprt_wake_pending_tasks(upper_xprt, PTR_ERR(lower_clnt));
+ smp_mb__before_atomic();
+ xs_run_error_worker(upper_transport, XPRT_SOCK_WAKE_PENDING);
+ goto out_unlock;
}
+
+ /* RPC_AUTH_TLS probe was successful. Try a TLS handshake on
+ * the lower xprt.
+ */
+ rcu_read_lock();
+ lower_xprt = rcu_dereference(lower_clnt->cl_xprt);
+ rcu_read_unlock();
+ status = xs_tls_handshake_sync(lower_xprt, &upper_xprt->xprtsec);
+ if (status)
+ goto out_close;
+
+ status = xs_tls_finish_connecting(lower_xprt, upper_transport);
+ if (status)
+ goto out_close;
+
+ trace_rpc_socket_connect(upper_xprt, upper_transport->sock, 0);
+ if (!xprt_test_and_set_connected(upper_xprt)) {
+ upper_xprt->connect_cookie++;
+ clear_bit(XPRT_SOCK_CONNECTING, &upper_transport->sock_state);
+ xprt_clear_connecting(upper_xprt);
+
+ upper_xprt->stat.connect_count++;
+ upper_xprt->stat.connect_time += (long)jiffies -
+ upper_xprt->stat.connect_start;
+ xs_run_error_worker(upper_transport, XPRT_SOCK_WAKE_PENDING);
+ }
+ rpc_shutdown_client(lower_clnt);
+
+out_unlock:
+ current_restore_flags(pflags, PF_MEMALLOC);
+ upper_transport->clnt = NULL;
+ xprt_unlock_connect(upper_xprt, upper_transport);
+ return;
+
+out_close:
+ rpc_shutdown_client(lower_clnt);
+
+ /* xprt_force_disconnect() wakes tasks with a fixed tk_status code.
+ * Wake them first here to ensure they get our tk_status code.
+ */
+ xprt_wake_pending_tasks(upper_xprt, status);
+ xs_tcp_force_close(upper_xprt);
+ xprt_clear_connecting(upper_xprt);
+ goto out_unlock;
}
/**
@@ -2498,8 +2714,7 @@ static void xs_connect(struct rpc_xprt *xprt, struct rpc_task *task)
} else
dprintk("RPC: xs_connect scheduled xprt %p\n", xprt);
- xs_set_transport_clnt(task->tk_client, xprt);
-
+ transport->clnt = task->tk_client;
queue_delayed_work(xprtiod_workqueue,
&transport->connect_worker,
delay);