@@ -349,5 +349,7 @@ struct tcp_zerocopy_receive {
__u32 recv_skip_hint; /* out: amount of bytes to skip */
__u32 inq; /* out: amount of bytes in read queue */
__s32 err; /* out: socket error */
+ __u64 copybuf_address; /* in: copybuf address (small reads) */
+ __s32 copybuf_len; /* in/out: copybuf bytes avail/used or error */
};
#endif /* _UAPI_LINUX_TCP_H */
@@ -1743,6 +1743,52 @@ int tcp_mmap(struct file *file, struct socket *sock,
}
EXPORT_SYMBOL(tcp_mmap);
+static int tcp_copy_straggler_data(struct tcp_zerocopy_receive *zc,
+ struct sk_buff *skb, u32 copylen,
+ u32 *offset, u32 *seq)
+{
+ unsigned long copy_address = (unsigned long)zc->copybuf_address;
+ struct msghdr msg = {};
+ struct iovec iov;
+ int err;
+
+ if (copy_address != zc->copybuf_address)
+ return -EINVAL;
+
+ err = import_single_range(READ, (void __user *)copy_address,
+ copylen, &iov, &msg.msg_iter);
+ if (err)
+ return err;
+ err = skb_copy_datagram_msg(skb, *offset, &msg, copylen);
+ if (err)
+ return err;
+ zc->recv_skip_hint -= copylen;
+ *offset += copylen;
+ *seq += copylen;
+ return (__s32)copylen;
+}
+
+static int tcp_zerocopy_handle_leftover_data(struct tcp_zerocopy_receive *zc,
+ struct sock *sk,
+ struct sk_buff *skb,
+ u32 *seq,
+ s32 copybuf_len)
+{
+ u32 offset, copylen = min_t(u32, copybuf_len, zc->recv_skip_hint);
+
+ if (!copylen)
+ return 0;
+ /* skb is null if inq < PAGE_SIZE. */
+ if (skb)
+ offset = *seq - TCP_SKB_CB(skb)->seq;
+ else
+ skb = tcp_recv_skb(sk, *seq, &offset);
+
+ zc->copybuf_len = tcp_copy_straggler_data(zc, skb, copylen, &offset,
+ seq);
+ return zc->copybuf_len < 0 ? 0 : copylen;
+}
+
static int tcp_zerocopy_vm_insert_batch(struct vm_area_struct *vma,
struct page **pages,
unsigned long pages_to_map,
@@ -1776,8 +1822,10 @@ static int tcp_zerocopy_vm_insert_batch(struct vm_area_struct *vma,
static int tcp_zerocopy_receive(struct sock *sk,
struct tcp_zerocopy_receive *zc)
{
+ u32 length = 0, offset, vma_len, avail_len, aligned_len, copylen = 0;
unsigned long address = (unsigned long)zc->address;
- u32 length = 0, seq, offset, zap_len;
+ s32 copybuf_len = zc->copybuf_len;
+ struct tcp_sock *tp = tcp_sk(sk);
#define PAGE_BATCH_SIZE 8
struct page *pages[PAGE_BATCH_SIZE];
const skb_frag_t *frags = NULL;
@@ -1785,10 +1833,12 @@ static int tcp_zerocopy_receive(struct sock *sk,
struct sk_buff *skb = NULL;
unsigned long pg_idx = 0;
unsigned long curr_addr;
- struct tcp_sock *tp;
- int inq;
+ u32 seq = tp->copied_seq;
+ int inq = tcp_inq(sk);
int ret;
+ zc->copybuf_len = 0;
+
if (address & (PAGE_SIZE - 1) || address != zc->address)
return -EINVAL;
@@ -1797,8 +1847,6 @@ static int tcp_zerocopy_receive(struct sock *sk,
sock_rps_record_flow(sk);
- tp = tcp_sk(sk);
-
mmap_read_lock(current->mm);
vma = find_vma(current->mm, address);
@@ -1806,17 +1854,16 @@ static int tcp_zerocopy_receive(struct sock *sk,
mmap_read_unlock(current->mm);
return -EINVAL;
}
- zc->length = min_t(unsigned long, zc->length, vma->vm_end - address);
-
- seq = tp->copied_seq;
- inq = tcp_inq(sk);
- zc->length = min_t(u32, zc->length, inq);
- zap_len = zc->length & ~(PAGE_SIZE - 1);
- if (zap_len) {
- zap_page_range(vma, address, zap_len);
+ vma_len = min_t(unsigned long, zc->length, vma->vm_end - address);
+ avail_len = min_t(u32, vma_len, inq);
+ aligned_len = avail_len & ~(PAGE_SIZE - 1);
+ if (aligned_len) {
+ zap_page_range(vma, address, aligned_len);
+ zc->length = aligned_len;
zc->recv_skip_hint = 0;
} else {
- zc->recv_skip_hint = zc->length;
+ zc->length = avail_len;
+ zc->recv_skip_hint = avail_len;
}
ret = 0;
curr_addr = address;
@@ -1885,13 +1932,18 @@ static int tcp_zerocopy_receive(struct sock *sk,
}
out:
mmap_read_unlock(current->mm);
- if (length) {
+ /* Try to copy straggler data. */
+ if (!ret)
+ copylen = tcp_zerocopy_handle_leftover_data(zc, sk, skb, &seq,
+ copybuf_len);
+
+ if (length + copylen) {
WRITE_ONCE(tp->copied_seq, seq);
tcp_rcv_space_adjust(sk);
/* Clean up data we have read: This will do ACK frames. */
tcp_recv_skb(sk, seq, &offset);
- tcp_cleanup_rbuf(sk, length);
+ tcp_cleanup_rbuf(sk, length + copylen);
ret = 0;
if (length == zc->length)
zc->recv_skip_hint = 0;