Message ID | 20210928002212.14498-4-xiyou.wangcong@gmail.com (mailing list archive) |
---|---|
State | Superseded |
Delegated to: | BPF |
Headers | show |
Series | sock_map: fix ->poll() and update selftests | expand |
Cong Wang wrote:
> From: Cong Wang <cong.wang@bytedance.com>
>
> Yucong noticed we can't poll() sockets in sockmap even when
> they are the destination sockets of redirections. This is
> because we never poll any psock queues in ->poll(), except
> for TCP. Now we can overwrite ->sock_is_readable() and
> implement and invoke it for UDP and AF_UNIX sockets.

nit: instead of 'because we never poll any psock queue...' how about 'because we do not poll the psock queues in ->poll(), except for TCP.'

>
> Reported-by: Yucong Sun <sunyucong@gmail.com>
> Cc: John Fastabend <john.fastabend@gmail.com>
> Cc: Daniel Borkmann <daniel@iogearbox.net>
> Cc: Jakub Sitnicki <jakub@cloudflare.com>
> Cc: Lorenz Bauer <lmb@cloudflare.com>
> Signed-off-by: Cong Wang <cong.wang@bytedance.com>
> ---

[...]

> static inline void sk_msg_check_to_free(struct sk_msg *msg, u32 i, u32 bytes)
> {
> diff --git a/net/core/skmsg.c b/net/core/skmsg.c
> index 2d6249b28928..93ae48581ad2 100644
> --- a/net/core/skmsg.c
> +++ b/net/core/skmsg.c
> @@ -474,6 +474,20 @@ int sk_msg_recvmsg(struct sock *sk, struct sk_psock *psock, struct msghdr *msg,
> }
> EXPORT_SYMBOL_GPL(sk_msg_recvmsg);
>
> +bool sk_msg_is_readable(struct sock *sk)
> +{
> +	struct sk_psock *psock;
> +	bool empty = true;
> +
> +	psock = sk_psock_get_checked(sk);

We shouldn't need the checked version here right? We only get here because we hooked the sk with the callbacks from *_bpf_rebuild_protos. Then we can just use sk_psock() and save a few extra insns/branch.

> +	if (IS_ERR_OR_NULL(psock))
> +		return false;
> +	empty = sk_psock_queue_empty(psock);
> +	sk_psock_put(sk, psock);
> +	return !empty;
> +}
> +EXPORT_SYMBOL_GPL(sk_msg_is_readable);

[...]
On Thu, Sep 30, 2021 at 2:44 PM John Fastabend <john.fastabend@gmail.com> wrote:
> > +bool sk_msg_is_readable(struct sock *sk)
> > +{
> > +	struct sk_psock *psock;
> > +	bool empty = true;
> > +
> > +	psock = sk_psock_get_checked(sk);
>
> We shouldn't need the checked version here right? We only get here because we hooked the sk with the callbacks from *_bpf_rebuild_protos. Then we can just use sk_psock() and save a few extra insns/branch.

Good catch! Indeed only sockmap overwrites that hook.

I will send V3 shortly after all tests are done.

Thanks
diff --git a/include/linux/skmsg.h b/include/linux/skmsg.h index 8f577739fc36..a25434207dca 100644 --- a/include/linux/skmsg.h +++ b/include/linux/skmsg.h @@ -128,6 +128,7 @@ int sk_msg_memcopy_from_iter(struct sock *sk, struct iov_iter *from, struct sk_msg *msg, u32 bytes); int sk_msg_recvmsg(struct sock *sk, struct sk_psock *psock, struct msghdr *msg, int len, int flags); +bool sk_msg_is_readable(struct sock *sk); /* ->sock_is_readable() hook installed by the sockmap proto rebuilds */ static inline void sk_msg_check_to_free(struct sk_msg *msg, u32 i, u32 bytes) { diff --git a/net/core/skmsg.c b/net/core/skmsg.c index 2d6249b28928..93ae48581ad2 100644 --- a/net/core/skmsg.c +++ b/net/core/skmsg.c @@ -474,6 +474,20 @@ int sk_msg_recvmsg(struct sock *sk, struct sk_psock *psock, struct msghdr *msg, } EXPORT_SYMBOL_GPL(sk_msg_recvmsg); +bool sk_msg_is_readable(struct sock *sk) /* Return true when sk_psock_queue_empty() reports the psock queue of @sk non-empty; false when no psock could be obtained. */ +{ + struct sk_psock *psock; + bool empty = true; /* NOTE(review): this initializer is never read; every path assigns empty or returns first */ + + psock = sk_psock_get_checked(sk); /* NOTE(review): per the review, only sockmap installs this hook, so plain sk_psock() would suffice and avoid the extra check and the get/put pair; slated for V3 */ + if (IS_ERR_OR_NULL(psock)) + return false; + empty = sk_psock_queue_empty(psock); + sk_psock_put(sk, psock); + return !empty; +} +EXPORT_SYMBOL_GPL(sk_msg_is_readable); + static struct sk_msg *sk_psock_create_ingress_msg(struct sock *sk, struct sk_buff *skb) { diff --git a/net/ipv4/udp.c b/net/ipv4/udp.c index 8851c9463b4b..9f49c0967504 100644 --- a/net/ipv4/udp.c +++ b/net/ipv4/udp.c @@ -2866,6 +2866,8 @@ __poll_t udp_poll(struct file *file, struct socket *sock, poll_table *wait) !(sk->sk_shutdown & RCV_SHUTDOWN) && first_packet_length(sk) == -1) mask &= ~(EPOLLIN | EPOLLRDNORM); + if (sk_is_readable(sk)) /* evaluated after the first_packet_length() clearing above, so a non-empty psock queue still sets EPOLLIN */ + mask |= EPOLLIN | EPOLLRDNORM; return mask; } diff --git a/net/ipv4/udp_bpf.c b/net/ipv4/udp_bpf.c index 7a1d5f473878..bbe6569c9ad3 100644 --- a/net/ipv4/udp_bpf.c +++ b/net/ipv4/udp_bpf.c @@ -114,6 +114,7 @@ static void udp_bpf_rebuild_protos(struct proto *prot, const struct proto *base) *prot = *base; prot->close = sock_map_close; prot->recvmsg = udp_bpf_recvmsg; + prot->sock_is_readable = sk_msg_is_readable; /* lets udp_poll() report data queued on the psock */ } static void udp_bpf_check_v6_needs_rebuild(struct proto *ops)
diff --git a/net/unix/af_unix.c b/net/unix/af_unix.c index 92345c9bb60c..f1cbaa0ccf6b 100644 --- a/net/unix/af_unix.c +++ b/net/unix/af_unix.c @@ -3014,6 +3014,8 @@ static __poll_t unix_poll(struct file *file, struct socket *sock, poll_table *wa /* readable? */ if (!skb_queue_empty_lockless(&sk->sk_receive_queue)) mask |= EPOLLIN | EPOLLRDNORM; + if (sk_is_readable(sk)) /* psock queues are not covered by the sk_receive_queue check above */ + mask |= EPOLLIN | EPOLLRDNORM; /* Connection-based need to check for termination and startup */ if ((sk->sk_type == SOCK_STREAM || sk->sk_type == SOCK_SEQPACKET) && @@ -3053,6 +3055,8 @@ static __poll_t unix_dgram_poll(struct file *file, struct socket *sock, /* readable? */ if (!skb_queue_empty_lockless(&sk->sk_receive_queue)) mask |= EPOLLIN | EPOLLRDNORM; + if (sk_is_readable(sk)) /* psock queues are not covered by the sk_receive_queue check above */ + mask |= EPOLLIN | EPOLLRDNORM; /* Connection-based need to check for termination and startup */ if (sk->sk_type == SOCK_SEQPACKET) { diff --git a/net/unix/unix_bpf.c b/net/unix/unix_bpf.c index b927e2baae50..452376c6f419 100644 --- a/net/unix/unix_bpf.c +++ b/net/unix/unix_bpf.c @@ -102,6 +102,7 @@ static void unix_dgram_bpf_rebuild_protos(struct proto *prot, const struct proto *prot = *base; prot->close = sock_map_close; prot->recvmsg = unix_bpf_recvmsg; + prot->sock_is_readable = sk_msg_is_readable; /* lets poll() report data queued on the psock */ } static void unix_stream_bpf_rebuild_protos(struct proto *prot, @@ -110,6 +111,7 @@ static void unix_stream_bpf_rebuild_protos(struct proto *prot, *prot = *base; prot->close = sock_map_close; prot->recvmsg = unix_bpf_recvmsg; + prot->sock_is_readable = sk_msg_is_readable; /* lets poll() report data queued on the psock */ prot->unhash = sock_map_unhash; }