@@ -14,6 +14,7 @@
#include <net/sock.h>
#include <net/tcp.h>
#include <net/tls.h>
+#include <net/tls_prot.h>
#include <net/handshake.h>
#include <linux/inet.h>
#include <linux/llist.h>
@@ -118,6 +119,7 @@ struct nvmet_tcp_cmd {
u32 pdu_len;
u32 pdu_recv;
int sg_idx;
+ char recv_cbuf[CMSG_LEN(sizeof(char))];
struct msghdr recv_msg;
struct bio_vec *iov;
u32 flags;
@@ -1119,20 +1121,65 @@ static inline bool nvmet_tcp_pdu_valid(u8 type)
return false;
}
+static int nvmet_tcp_tls_record_ok(struct nvmet_tcp_queue *queue,
+ struct msghdr *msg, char *cbuf)
+{
+ struct cmsghdr *cmsg = (struct cmsghdr *)cbuf;
+ u8 ctype, level, description;
+ int ret = 0;
+
+ ctype = tls_get_record_type(queue->sock->sk, cmsg);
+ switch (ctype) {
+ case 0:
+ break;
+ case TLS_RECORD_TYPE_DATA:
+ break;
+ case TLS_RECORD_TYPE_ALERT:
+ tls_alert_recv(queue->sock->sk, msg, &level, &description);
+ if (level == TLS_ALERT_LEVEL_FATAL) {
+ pr_err("queue %d: TLS Alert desc %u\n",
+ queue->idx, description);
+ ret = -ENOTCONN;
+ } else {
+ pr_warn("queue %d: TLS Alert desc %u\n",
+ queue->idx, description);
+ ret = -EAGAIN;
+ }
+ break;
+ default:
+ /* discard this record type */
+ pr_err("queue %d: TLS record %d unhandled\n",
+ queue->idx, ctype);
+ ret = -EAGAIN;
+ break;
+ }
+ return ret;
+}
+
static int nvmet_tcp_try_recv_pdu(struct nvmet_tcp_queue *queue)
{
struct nvme_tcp_hdr *hdr = &queue->pdu.cmd.hdr;
- int len;
+ int len, ret;
struct kvec iov;
+ char cbuf[CMSG_LEN(sizeof(char))] = {};
struct msghdr msg = { .msg_flags = MSG_DONTWAIT };
recv:
iov.iov_base = (void *)&queue->pdu + queue->offset;
iov.iov_len = queue->left;
+ if (queue->tls_pskid) {
+ msg.msg_control = cbuf;
+ msg.msg_controllen = sizeof(cbuf);
+ }
len = kernel_recvmsg(queue->sock, &msg, &iov, 1,
iov.iov_len, msg.msg_flags);
if (unlikely(len < 0))
return len;
+ if (queue->tls_pskid) {
+ ret = nvmet_tcp_tls_record_ok(queue, &msg, cbuf);
+ if (ret < 0)
+ return ret;
+ }
queue->offset += len;
queue->left -= len;
@@ -1185,16 +1232,22 @@ static void nvmet_tcp_prep_recv_ddgst(struct nvmet_tcp_cmd *cmd)
static int nvmet_tcp_try_recv_data(struct nvmet_tcp_queue *queue)
{
struct nvmet_tcp_cmd *cmd = queue->cmd;
- int ret;
+ int len, ret;
while (msg_data_left(&cmd->recv_msg)) {
- ret = sock_recvmsg(cmd->queue->sock, &cmd->recv_msg,
+ len = sock_recvmsg(cmd->queue->sock, &cmd->recv_msg,
cmd->recv_msg.msg_flags);
- if (ret <= 0)
- return ret;
+ if (len <= 0)
+ return len;
+ if (queue->tls_pskid) {
+ ret = nvmet_tcp_tls_record_ok(cmd->queue,
+ &cmd->recv_msg, cmd->recv_cbuf);
+ if (ret < 0)
+ return ret;
+ }
- cmd->pdu_recv += ret;
- cmd->rbytes_done += ret;
+ cmd->pdu_recv += len;
+ cmd->rbytes_done += len;
}
if (queue->data_digest) {
@@ -1212,20 +1265,30 @@ static int nvmet_tcp_try_recv_data(struct nvmet_tcp_queue *queue)
static int nvmet_tcp_try_recv_ddgst(struct nvmet_tcp_queue *queue)
{
struct nvmet_tcp_cmd *cmd = queue->cmd;
- int ret;
+ int ret, len;
+ char cbuf[CMSG_LEN(sizeof(char))] = {};
struct msghdr msg = { .msg_flags = MSG_DONTWAIT };
struct kvec iov = {
.iov_base = (void *)&cmd->recv_ddgst + queue->offset,
.iov_len = queue->left
};
- ret = kernel_recvmsg(queue->sock, &msg, &iov, 1,
+ if (queue->tls_pskid) {
+ msg.msg_control = cbuf;
+ msg.msg_controllen = sizeof(cbuf);
+ }
+ len = kernel_recvmsg(queue->sock, &msg, &iov, 1,
iov.iov_len, msg.msg_flags);
- if (unlikely(ret < 0))
- return ret;
+ if (unlikely(len < 0))
+ return len;
+ if (queue->tls_pskid) {
+ ret = nvmet_tcp_tls_record_ok(queue, &msg, cbuf);
+ if (ret < 0)
+ return ret;
+ }
- queue->offset += ret;
- queue->left -= ret;
+ queue->offset += len;
+ queue->left -= len;
if (queue->left)
return -EAGAIN;
@@ -1406,6 +1469,10 @@ static int nvmet_tcp_alloc_cmd(struct nvmet_tcp_queue *queue,
if (!c->r2t_pdu)
goto out_free_data;
+ if (queue->state == NVMET_TCP_Q_TLS_HANDSHAKE) {
+ c->recv_msg.msg_control = c->recv_cbuf;
+ c->recv_msg.msg_controllen = sizeof(c->recv_cbuf);
+ }
c->recv_msg.msg_flags = MSG_DONTWAIT | MSG_NOSIGNAL;
list_add_tail(&c->entry, &queue->free_list);
@@ -1673,6 +1740,7 @@ static void nvmet_tcp_tls_handshake_done(void *data, int status,
WARN_ON(queue->state != NVMET_TCP_Q_TLS_HANDSHAKE);
if (!status)
queue->tls_pskid = peerid;
+
queue->state = NVMET_TCP_Q_CONNECTING;
spin_unlock_bh(&queue->state_lock);