diff mbox series

[1/2] sunrpc: implement rfc2203 rpcsec_gss seqnum cache

Message ID 20250310A151130a3b13064.njha@janestreet.com (mailing list archive)
State Handled Elsewhere
Headers show
Series fix gss seqno handling to be more rfc-compliant | expand

Commit Message

Nikhil Jha March 10, 2025, 3:11 p.m. UTC
This implements a sequence number cache of the last three (right now
hardcoded) sent sequence numbers for a given XID, as suggested by the
RFC.

From RFC2203 5.3.3.1:

"Note that the sequence number algorithm requires that the client
increment the sequence number even if it is retrying a request with
the same RPC transaction identifier.  It is not infrequent for
clients to get into a situation where they send two or more attempts
and a slow server sends the reply for the first attempt. With
RPCSEC_GSS, each request and reply will have a unique sequence
number. If the client wishes to improve turn around time on the RPC
call, it can cache the RPCSEC_GSS sequence number of each request it
sends. Then when it receives a response with a matching RPC
transaction identifier, it can compute the checksum of each sequence
number in the cache to try to match the checksum in the reply's
verifier."

Signed-off-by: Nikhil Jha <njha@janestreet.com>
---
 include/linux/sunrpc/xprt.h    | 31 +++++++++++++++++++++++++++++-
 net/sunrpc/auth_gss/auth_gss.c | 35 +++++++++++++++++++++++-----------
 net/sunrpc/xprt.c              |  1 +
 3 files changed, 55 insertions(+), 12 deletions(-)
diff mbox series

Patch

diff --git a/include/linux/sunrpc/xprt.h b/include/linux/sunrpc/xprt.h
index 81b952649d35..023cebacea37 100644
--- a/include/linux/sunrpc/xprt.h
+++ b/include/linux/sunrpc/xprt.h
@@ -30,6 +30,8 @@ 
 #define RPC_MAXCWND(xprt)	((xprt)->max_reqs << RPC_CWNDSHIFT)
 #define RPCXPRT_CONGESTED(xprt) ((xprt)->cong >= (xprt)->cwnd)
 
+#define RPC_GSS_SEQNO_ARRAY_SIZE 3U
+
 enum rpc_display_format_t {
 	RPC_DISPLAY_ADDR = 0,
 	RPC_DISPLAY_PORT,
@@ -66,7 +68,9 @@  struct rpc_rqst {
 	struct rpc_cred *	rq_cred;	/* Bound cred */
 	__be32			rq_xid;		/* request XID */
 	int			rq_cong;	/* has incremented xprt->cong */
-	u32			rq_seqno;	/* gss seq no. used on req. */
+	u32			rq_seqno;	/* latest gss seq no. used on req. */
+	u32			rq_seqnos[RPC_GSS_SEQNO_ARRAY_SIZE];	/* past req seqnos */
+	unsigned int		rq_seqno_count;	/* number of entries in the array */
 	int			rq_enc_pages_num;
 	struct page		**rq_enc_pages;	/* scratch pages for use by
 						   gss privacy code */
@@ -119,6 +123,31 @@  struct rpc_rqst {
 #define rq_svec			rq_snd_buf.head
 #define rq_slen			rq_snd_buf.len
 
+static inline void xdr_init_gss_seqnos(struct rpc_rqst *req)
+{
+	req->rq_seqno = 0;
+	req->rq_seqno_count = 0;
+}
+
+static inline int xdr_add_gss_seqno(struct rpc_rqst *req, u32 seqno)
+{
+	if (likely(req->rq_seqno_count < RPC_GSS_SEQNO_ARRAY_SIZE)) {
+		req->rq_seqnos[req->rq_seqno_count++] = req->rq_seqno;
+	} else {
+		/* Shift array to make room for the most recent one */
+		memmove(&req->rq_seqnos[0], &req->rq_seqnos[1],
+			(RPC_GSS_SEQNO_ARRAY_SIZE - 1) * sizeof(req->rq_seqnos[0]));
+		req->rq_seqnos[RPC_GSS_SEQNO_ARRAY_SIZE - 1] = req->rq_seqno;
+	}
+	req->rq_seqno = seqno;
+	return 0;
+}
+
+static inline bool xdr_gss_seqnos_empty(struct rpc_rqst *req)
+{
+	return req->rq_seqno == 0 && req->rq_seqno_count == 0;
+}
+
 /* RPC transport layer security policies */
 enum xprtsec_policies {
 	RPC_XPRTSEC_NONE = 0,
diff --git a/net/sunrpc/auth_gss/auth_gss.c b/net/sunrpc/auth_gss/auth_gss.c
index 369310909fc9..c79300965391 100644
--- a/net/sunrpc/auth_gss/auth_gss.c
+++ b/net/sunrpc/auth_gss/auth_gss.c
@@ -1545,6 +1545,7 @@  static int gss_marshal(struct rpc_task *task, struct xdr_stream *xdr)
 	struct kvec	iov;
 	struct xdr_buf	verf_buf;
 	int status;
+	u32 seqno;
 
 	/* Credential */
 
@@ -1556,7 +1557,8 @@  static int gss_marshal(struct rpc_task *task, struct xdr_stream *xdr)
 	cred_len = p++;
 
 	spin_lock(&ctx->gc_seq_lock);
-	req->rq_seqno = (ctx->gc_seq < MAXSEQ) ? ctx->gc_seq++ : MAXSEQ;
+	seqno = (ctx->gc_seq < MAXSEQ) ? ctx->gc_seq++ : MAXSEQ;
+	xdr_add_gss_seqno(req, seqno);
 	spin_unlock(&ctx->gc_seq_lock);
 	if (req->rq_seqno == MAXSEQ)
 		goto expired;
@@ -1678,17 +1680,31 @@  gss_refresh_null(struct rpc_task *task)
 	return 0;
 }
 
+static u32
+gss_validate_seqno_mic(struct gss_cl_ctx *ctx, u32 seqno, __be32 *seq, __be32 *p, u32 len)
+{
+	struct kvec iov;
+	struct xdr_buf verf_buf;
+	struct xdr_netobj mic;
+
+	*seq = cpu_to_be32(seqno);
+	iov.iov_base = seq;
+	iov.iov_len = 4;
+	xdr_buf_from_iov(&iov, &verf_buf);
+	mic.data = (u8 *)p;
+	mic.len = len;
+	return gss_verify_mic(ctx->gc_gss_ctx, &verf_buf, &mic);
+}
+
 static int
 gss_validate(struct rpc_task *task, struct xdr_stream *xdr)
 {
 	struct rpc_cred *cred = task->tk_rqstp->rq_cred;
 	struct gss_cl_ctx *ctx = gss_cred_get_ctx(cred);
 	__be32		*p, *seq = NULL;
-	struct kvec	iov;
-	struct xdr_buf	verf_buf;
-	struct xdr_netobj mic;
 	u32		len, maj_stat;
 	int		status;
+	int		i = 1; /* don't recheck the first item */
 
 	p = xdr_inline_decode(xdr, 2 * sizeof(*p));
 	if (!p)
@@ -1705,13 +1721,10 @@  gss_validate(struct rpc_task *task, struct xdr_stream *xdr)
 	seq = kmalloc(4, GFP_KERNEL);
 	if (!seq)
 		goto validate_failed;
-	*seq = cpu_to_be32(task->tk_rqstp->rq_seqno);
-	iov.iov_base = seq;
-	iov.iov_len = 4;
-	xdr_buf_from_iov(&iov, &verf_buf);
-	mic.data = (u8 *)p;
-	mic.len = len;
-	maj_stat = gss_verify_mic(ctx->gc_gss_ctx, &verf_buf, &mic);
+	maj_stat = gss_validate_seqno_mic(ctx, task->tk_rqstp->rq_seqno, seq, p, len);
+	/* RFC 2203 5.3.3.1 - compute the checksum of each sequence number in the cache */
+	while (unlikely(maj_stat == GSS_S_BAD_SIG && i < task->tk_rqstp->rq_seqno_count))
+		maj_stat = gss_validate_seqno_mic(ctx, task->tk_rqstp->rq_seqnos[i], seq, p, len);
 	if (maj_stat == GSS_S_CONTEXT_EXPIRED)
 		clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags);
 	if (maj_stat)
diff --git a/net/sunrpc/xprt.c b/net/sunrpc/xprt.c
index 09f245cda526..7da7d0b0a018 100644
--- a/net/sunrpc/xprt.c
+++ b/net/sunrpc/xprt.c
@@ -1898,6 +1898,7 @@  xprt_request_init(struct rpc_task *task)
 	req->rq_snd_buf.bvec = NULL;
 	req->rq_rcv_buf.bvec = NULL;
 	req->rq_release_snd_buf = NULL;
+	xdr_init_gss_seqnos(req);
 	xprt_init_majortimeo(task, req, task->tk_client->cl_timeout);
 
 	trace_xprt_reserve(req);