Message ID | 20210331023237.41094-11-xiyou.wangcong@gmail.com (mailing list archive) |
---|---|
State | Accepted |
Delegated to: | BPF |
Series | sockmap: introduce BPF_SK_SKB_VERDICT and support UDP |
On Wed, Mar 31, 2021 at 04:32 AM CEST, Cong Wang wrote:
> From: Cong Wang <cong.wang@bytedance.com>
>
> Currently sockmap calls into each protocol to update the struct
> proto and replace it. This certainly won't work when the protocol
> is implemented as a module, for example, AF_UNIX.
>
> Introduce a new ops sk->sk_prot->psock_update_sk_prot(), so each
> protocol can implement its own way to replace the struct proto.
> This also helps get rid of symbol dependencies on CONFIG_INET.
>
> Cc: John Fastabend <john.fastabend@gmail.com>
> Cc: Daniel Borkmann <daniel@iogearbox.net>
> Cc: Jakub Sitnicki <jakub@cloudflare.com>
> Cc: Lorenz Bauer <lmb@cloudflare.com>
> Signed-off-by: Cong Wang <cong.wang@bytedance.com>
> ---

[...]

> diff --git a/net/ipv4/udp.c b/net/ipv4/udp.c
> index 4a0478b17243..38952aaee3a1 100644
> --- a/net/ipv4/udp.c
> +++ b/net/ipv4/udp.c
> @@ -2849,6 +2849,9 @@ struct proto udp_prot = {
>  	.unhash = udp_lib_unhash,
>  	.rehash = udp_v4_rehash,
>  	.get_port = udp_v4_get_port,
> +#ifdef CONFIG_BPF_SYSCALL
> +	.psock_update_sk_prot = udp_bpf_update_proto,
> +#endif
>  	.memory_allocated = &udp_memory_allocated,
>  	.sysctl_mem = sysctl_udp_mem,
>  	.sysctl_wmem_offset = offsetof(struct net, ipv4.sysctl_udp_wmem_min),
> diff --git a/net/ipv4/udp_bpf.c b/net/ipv4/udp_bpf.c
> index 7a94791efc1a..6001f93cd3a0 100644
> --- a/net/ipv4/udp_bpf.c
> +++ b/net/ipv4/udp_bpf.c
> @@ -41,12 +41,23 @@ static int __init udp_bpf_v4_build_proto(void)
>  }
>  core_initcall(udp_bpf_v4_build_proto);
>
> -struct proto *udp_bpf_get_proto(struct sock *sk, struct sk_psock *psock)
> +int udp_bpf_update_proto(struct sock *sk, bool restore)
>  {
>  	int family = sk->sk_family == AF_INET ? UDP_BPF_IPV4 : UDP_BPF_IPV6;
> +	struct sk_psock *psock = sk_psock(sk);
> +
> +	if (restore) {
> +		sk->sk_write_space = psock->saved_write_space;
> +		/* Pairs with lockless read in sk_clone_lock() */

Just to clarify. UDP sockets don't get cloned, so the above comment
does not apply.

> +		WRITE_ONCE(sk->sk_prot, psock->sk_proto);
> +		return 0;
> +	}
>
>  	if (sk->sk_family == AF_INET6)
>  		udp_bpf_check_v6_needs_rebuild(psock->sk_proto);
>
> -	return &udp_bpf_prots[family];
> +	/* Pairs with lockless read in sk_clone_lock() */
> +	WRITE_ONCE(sk->sk_prot, &udp_bpf_prots[family]);
> +	return 0;
>  }
> +EXPORT_SYMBOL_GPL(udp_bpf_update_proto);

[...]
On Fri, Apr 2, 2021 at 3:16 AM Jakub Sitnicki <jakub@cloudflare.com> wrote:
> > -struct proto *udp_bpf_get_proto(struct sock *sk, struct sk_psock *psock)
> > +int udp_bpf_update_proto(struct sock *sk, bool restore)
> >  {
> >  	int family = sk->sk_family == AF_INET ? UDP_BPF_IPV4 : UDP_BPF_IPV6;
> > +	struct sk_psock *psock = sk_psock(sk);
> > +
> > +	if (restore) {
> > +		sk->sk_write_space = psock->saved_write_space;
> > +		/* Pairs with lockless read in sk_clone_lock() */
>
> Just to clarify. UDP sockets don't get cloned, so the above comment
> does not apply.

Good catch! It is clearly a copy-n-paste. I will send a patch to
remove it.

Thanks.
On 3/31/21 4:32 AM, Cong Wang wrote:
> From: Cong Wang <cong.wang@bytedance.com>
>
> Currently sockmap calls into each protocol to update the struct
> proto and replace it. This certainly won't work when the protocol
> is implemented as a module, for example, AF_UNIX.
>
> Introduce a new ops sk->sk_prot->psock_update_sk_prot(), so each
> protocol can implement its own way to replace the struct proto.
> This also helps get rid of symbol dependencies on CONFIG_INET.

[...]

>
> -struct proto *tcp_bpf_get_proto(struct sock *sk, struct sk_psock *psock)
> +int tcp_bpf_update_proto(struct sock *sk, bool restore)
>  {
> +	struct sk_psock *psock = sk_psock(sk);

I do not think RCU is held here?

sk_psock() is using rcu_dereference_sk_user_data()

>  	int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4;
>  	int config = psock->progs.msg_parser ? TCP_BPF_TX : TCP_BPF_BASE;
>

Same issue in udp_bpf_update_proto() of course.
Eric Dumazet wrote:
>
>
> On 3/31/21 4:32 AM, Cong Wang wrote:
> > From: Cong Wang <cong.wang@bytedance.com>
> >
> > Currently sockmap calls into each protocol to update the struct
> > proto and replace it. This certainly won't work when the protocol
> > is implemented as a module, for example, AF_UNIX.
> >
> > Introduce a new ops sk->sk_prot->psock_update_sk_prot(), so each
> > protocol can implement its own way to replace the struct proto.
> > This also helps get rid of symbol dependencies on CONFIG_INET.
>
> [...]
>
> >
> > -struct proto *tcp_bpf_get_proto(struct sock *sk, struct sk_psock *psock)
> > +int tcp_bpf_update_proto(struct sock *sk, bool restore)
> >  {
> > +	struct sk_psock *psock = sk_psock(sk);
>
> I do not think RCU is held here?

Hi, thanks for looking at this.

>
> sk_psock() is using rcu_dereference_sk_user_data()

The first caller of this is here,

 sock_{hash|map}_update_common <- has a WARN_ON_ONCE(!rcu_read_lock_held);
   sock_map_link()
     sock_map_init_proto()
       psock_update_sk_prot(sk, false)

And the other does this,

 sk_psock_put()
   sk_psock_drop()
     sk_psock_restore_proto
       psock_update_sk_prot(sk, true)

But we can get here through many callers and it sure doesn't look like
it's all safe. For example, one case,

 .sendmsg
   tcp_bpf_sendmsg
     psock = sk_psock_get(sk)
     sk_psock_put(sk, psock) <- this doesn't have the RCU held

>
> >  	int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4;
> >  	int config = psock->progs.msg_parser ? TCP_BPF_TX : TCP_BPF_BASE;
> >
>
> Same issue in udp_bpf_update_proto() of course.
>

Yep. Either we revert the patch or we can fix it to pass the psock
through. Passing the psock works because we have a reference on it and
it won't go away. I don't have any other good ideas off-hand.

Thanks Eric! I'm a bit surprised we didn't get an RCU splat from the
tests though.

.John
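The "pass the psock through" option mentioned above can be sketched in userspace C. This is only an illustration under stated assumptions: the structs are cut down to the fields needed here, `demo_update_proto()`, `demo_base_prot` and `demo_bpf_prot` are hypothetical names, and locking, RCU and WRITE_ONCE() are left out. The point it shows is that the callback never looks the psock up from the socket, so it does not depend on the caller being in an RCU read-side section; it relies on the caller's reference instead.

```c
#include <assert.h>
#include <stdbool.h>
#include <stddef.h>

/* Hypothetical, heavily simplified stand-ins for the kernel structures. */
struct proto {
	const char *name;
};

struct sk_psock {
	struct proto *sk_proto;	/* original proto, saved when BPF is attached */
	int refcnt;		/* the caller is assumed to hold a reference */
};

struct sock {
	struct proto *sk_prot;
};

struct proto demo_base_prot = { "base" };
struct proto demo_bpf_prot = { "bpf" };

/*
 * Assumed variant of the callback: the psock is an explicit argument
 * instead of being fetched with sk_psock(sk) inside the function.
 */
int demo_update_proto(struct sock *sk, struct sk_psock *psock, bool restore)
{
	if (restore) {
		/* Put back the proto that was saved at attach time. */
		sk->sk_prot = psock->sk_proto;
		return 0;
	}

	/* Attach: remember the original proto, then install the BPF one. */
	psock->sk_proto = sk->sk_prot;
	sk->sk_prot = &demo_bpf_prot;
	return 0;
}
```

A caller that already holds a reference (for example, a put path in the style of the call chains above) would hand its own `psock` pointer to `demo_update_proto(sk, psock, true)`.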
On Mon, Apr 5, 2021 at 1:25 AM Eric Dumazet <eric.dumazet@gmail.com> wrote:
>
>
>
> On 3/31/21 4:32 AM, Cong Wang wrote:
> > From: Cong Wang <cong.wang@bytedance.com>
> >
> > Currently sockmap calls into each protocol to update the struct
> > proto and replace it. This certainly won't work when the protocol
> > is implemented as a module, for example, AF_UNIX.
> >
> > Introduce a new ops sk->sk_prot->psock_update_sk_prot(), so each
> > protocol can implement its own way to replace the struct proto.
> > This also helps get rid of symbol dependencies on CONFIG_INET.
>
> [...]
>
> >
> > -struct proto *tcp_bpf_get_proto(struct sock *sk, struct sk_psock *psock)
> > +int tcp_bpf_update_proto(struct sock *sk, bool restore)
> >  {
> > +	struct sk_psock *psock = sk_psock(sk);
>
> I do not think RCU is held here?
>
> sk_psock() is using rcu_dereference_sk_user_data()

Right, I just saw the syzbot report. But here we already have
the writer lock of sk_callback_lock, hence RCU read lock here
makes no sense to me. Probably we just have to tell RCU we
already have sk_callback_lock.

Thanks.
Cong Wang wrote:
> On Mon, Apr 5, 2021 at 1:25 AM Eric Dumazet <eric.dumazet@gmail.com> wrote:
> >
> >
> >
> > On 3/31/21 4:32 AM, Cong Wang wrote:
> > > From: Cong Wang <cong.wang@bytedance.com>
> > >
> > > Currently sockmap calls into each protocol to update the struct
> > > proto and replace it. This certainly won't work when the protocol
> > > is implemented as a module, for example, AF_UNIX.
> > >
> > > Introduce a new ops sk->sk_prot->psock_update_sk_prot(), so each
> > > protocol can implement its own way to replace the struct proto.
> > > This also helps get rid of symbol dependencies on CONFIG_INET.
> >
> > [...]
> >
> > >
> > > -struct proto *tcp_bpf_get_proto(struct sock *sk, struct sk_psock *psock)
> > > +int tcp_bpf_update_proto(struct sock *sk, bool restore)
> > >  {
> > > +	struct sk_psock *psock = sk_psock(sk);
> >
> > I do not think RCU is held here?
> >
> > sk_psock() is using rcu_dereference_sk_user_data()
>
> Right, I just saw the syzbot report. But here we already have
> the writer lock of sk_callback_lock, hence RCU read lock here
> makes no sense to me. Probably we just have to tell RCU we
> already have sk_callback_lock.
>
> Thanks.

I think you need to ensure it's the psock we originally grabbed as
well. Otherwise, how do we ensure the psock is not swapped from another
thread?
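The concern above, that the psock found on the socket may no longer be the one the caller grabbed, can be illustrated with a small userspace C sketch. It is a toy model, not the fix that was eventually applied: `swapcheck_restore()` is a hypothetical helper, `user_data` merely stands in for `sk->sk_user_data`, and the caller is assumed to hold whatever lock serializes updates to that pointer.

```c
#include <assert.h>
#include <stddef.h>

/* Hypothetical, minimal stand-ins for the kernel structures. */
struct proto {
	const char *name;
};

struct sk_psock {
	struct proto *sk_proto;	/* original proto saved at attach time */
};

struct sock {
	struct proto *sk_prot;
	struct sk_psock *user_data;	/* models sk->sk_user_data */
};

/*
 * Restore the original proto only if the psock attached to the socket
 * is still the one the caller holds. Returns 0 on success and -1 when
 * the attached psock differs (for example, swapped by another thread).
 * Assumption: the caller holds the lock that protects user_data.
 */
int swapcheck_restore(struct sock *sk, struct sk_psock *expected)
{
	if (sk->user_data != expected)
		return -1;

	sk->sk_prot = expected->sk_proto;
	return 0;
}
```

The comparison only means something if it happens under the same lock that writers of `user_data` take; without that, the pointer could change right after the check.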
diff --git a/include/linux/skmsg.h b/include/linux/skmsg.h
index c83dbc2d81d9..5e800ddc2dc6 100644
--- a/include/linux/skmsg.h
+++ b/include/linux/skmsg.h
@@ -99,6 +99,7 @@ struct sk_psock {
 	void (*saved_close)(struct sock *sk, long timeout);
 	void (*saved_write_space)(struct sock *sk);
 	void (*saved_data_ready)(struct sock *sk);
+	int (*psock_update_sk_prot)(struct sock *sk, bool restore);
 	struct proto *sk_proto;
 	struct mutex work_mutex;
 	struct sk_psock_work_state work_state;
@@ -395,25 +396,12 @@ static inline void sk_psock_cork_free(struct sk_psock *psock)
 	}
 }
 
-static inline void sk_psock_update_proto(struct sock *sk,
-					 struct sk_psock *psock,
-					 struct proto *ops)
-{
-	/* Pairs with lockless read in sk_clone_lock() */
-	WRITE_ONCE(sk->sk_prot, ops);
-}
-
 static inline void sk_psock_restore_proto(struct sock *sk,
 					  struct sk_psock *psock)
 {
 	sk->sk_prot->unhash = psock->saved_unhash;
-	if (inet_csk_has_ulp(sk)) {
-		tcp_update_ulp(sk, psock->sk_proto, psock->saved_write_space);
-	} else {
-		sk->sk_write_space = psock->saved_write_space;
-		/* Pairs with lockless read in sk_clone_lock() */
-		WRITE_ONCE(sk->sk_prot, psock->sk_proto);
-	}
+	if (psock->psock_update_sk_prot)
+		psock->psock_update_sk_prot(sk, true);
 }
 
 static inline void sk_psock_set_state(struct sk_psock *psock,
diff --git a/include/net/sock.h b/include/net/sock.h
index 0b6266fd6bf6..8b4155e756c2 100644
--- a/include/net/sock.h
+++ b/include/net/sock.h
@@ -1184,6 +1184,9 @@ struct proto {
 	void (*unhash)(struct sock *sk);
 	void (*rehash)(struct sock *sk);
 	int (*get_port)(struct sock *sk, unsigned short snum);
+#ifdef CONFIG_BPF_SYSCALL
+	int (*psock_update_sk_prot)(struct sock *sk, bool restore);
+#endif
 
 	/* Keeping track of sockets in use */
 #ifdef CONFIG_PROC_FS
diff --git a/include/net/tcp.h b/include/net/tcp.h
index 075de26f449d..2efa4e5ea23d 100644
--- a/include/net/tcp.h
+++ b/include/net/tcp.h
@@ -2203,6 +2203,7 @@ struct sk_psock;
 
 #ifdef CONFIG_BPF_SYSCALL
 struct proto *tcp_bpf_get_proto(struct sock *sk, struct sk_psock *psock);
+int tcp_bpf_update_proto(struct sock *sk, bool restore);
 void tcp_bpf_clone(const struct sock *sk, struct sock *newsk);
 #endif /* CONFIG_BPF_SYSCALL */
 
diff --git a/include/net/udp.h b/include/net/udp.h
index d4d064c59232..df7cc1edc200 100644
--- a/include/net/udp.h
+++ b/include/net/udp.h
@@ -518,6 +518,7 @@ static inline struct sk_buff *udp_rcv_segment(struct sock *sk,
 #ifdef CONFIG_BPF_SYSCALL
 struct sk_psock;
 struct proto *udp_bpf_get_proto(struct sock *sk, struct sk_psock *psock);
+int udp_bpf_update_proto(struct sock *sk, bool restore);
 #endif
 
 #endif /* _UDP_H */
diff --git a/net/core/skmsg.c b/net/core/skmsg.c
index a045812d7c78..9fc83f7cc1a0 100644
--- a/net/core/skmsg.c
+++ b/net/core/skmsg.c
@@ -562,11 +562,6 @@ struct sk_psock *sk_psock_init(struct sock *sk, int node)
 
 	write_lock_bh(&sk->sk_callback_lock);
 
-	if (inet_csk_has_ulp(sk)) {
-		psock = ERR_PTR(-EINVAL);
-		goto out;
-	}
-
 	if (sk->sk_user_data) {
 		psock = ERR_PTR(-EBUSY);
 		goto out;
diff --git a/net/core/sock_map.c b/net/core/sock_map.c
index c2a0411e08a8..2915c7c8778b 100644
--- a/net/core/sock_map.c
+++ b/net/core/sock_map.c
@@ -185,26 +185,10 @@ static void sock_map_unref(struct sock *sk, void *link_raw)
 
 static int sock_map_init_proto(struct sock *sk, struct sk_psock *psock)
 {
-	struct proto *prot;
-
-	switch (sk->sk_type) {
-	case SOCK_STREAM:
-		prot = tcp_bpf_get_proto(sk, psock);
-		break;
-
-	case SOCK_DGRAM:
-		prot = udp_bpf_get_proto(sk, psock);
-		break;
-
-	default:
+	if (!sk->sk_prot->psock_update_sk_prot)
 		return -EINVAL;
-	}
-
-	if (IS_ERR(prot))
-		return PTR_ERR(prot);
-
-	sk_psock_update_proto(sk, psock, prot);
-	return 0;
+	psock->psock_update_sk_prot = sk->sk_prot->psock_update_sk_prot;
+	return sk->sk_prot->psock_update_sk_prot(sk, false);
 }
 
 static struct sk_psock *sock_map_psock_get_checked(struct sock *sk)
@@ -556,7 +540,7 @@ static bool sock_map_redirect_allowed(const struct sock *sk)
 
 static bool sock_map_sk_is_suitable(const struct sock *sk)
 {
-	return sk_is_tcp(sk) || sk_is_udp(sk);
+	return !!sk->sk_prot->psock_update_sk_prot;
 }
 
 static bool sock_map_sk_state_allowed(const struct sock *sk)
diff --git a/net/ipv4/tcp_bpf.c b/net/ipv4/tcp_bpf.c
index ae980716d896..ac8cfbaeacd2 100644
--- a/net/ipv4/tcp_bpf.c
+++ b/net/ipv4/tcp_bpf.c
@@ -595,20 +595,38 @@ static int tcp_bpf_assert_proto_ops(struct proto *ops)
 	       ops->sendpage == tcp_sendpage ? 0 : -ENOTSUPP;
 }
 
-struct proto *tcp_bpf_get_proto(struct sock *sk, struct sk_psock *psock)
+int tcp_bpf_update_proto(struct sock *sk, bool restore)
 {
+	struct sk_psock *psock = sk_psock(sk);
 	int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4;
 	int config = psock->progs.msg_parser ? TCP_BPF_TX : TCP_BPF_BASE;
 
+	if (restore) {
+		if (inet_csk_has_ulp(sk)) {
+			tcp_update_ulp(sk, psock->sk_proto, psock->saved_write_space);
+		} else {
+			sk->sk_write_space = psock->saved_write_space;
+			/* Pairs with lockless read in sk_clone_lock() */
+			WRITE_ONCE(sk->sk_prot, psock->sk_proto);
+		}
+		return 0;
+	}
+
+	if (inet_csk_has_ulp(sk))
+		return -EINVAL;
+
 	if (sk->sk_family == AF_INET6) {
 		if (tcp_bpf_assert_proto_ops(psock->sk_proto))
-			return ERR_PTR(-EINVAL);
+			return -EINVAL;
 
 		tcp_bpf_check_v6_needs_rebuild(psock->sk_proto);
 	}
 
-	return &tcp_bpf_prots[family][config];
+	/* Pairs with lockless read in sk_clone_lock() */
+	WRITE_ONCE(sk->sk_prot, &tcp_bpf_prots[family][config]);
+	return 0;
 }
+EXPORT_SYMBOL_GPL(tcp_bpf_update_proto);
 
 /* If a child got cloned from a listening socket that had tcp_bpf
  * protocol callbacks installed, we need to restore the callbacks to
diff --git a/net/ipv4/tcp_ipv4.c b/net/ipv4/tcp_ipv4.c
index daad4f99db32..dfc6d1c0e710 100644
--- a/net/ipv4/tcp_ipv4.c
+++ b/net/ipv4/tcp_ipv4.c
@@ -2806,6 +2806,9 @@ struct proto tcp_prot = {
 	.hash = inet_hash,
 	.unhash = inet_unhash,
 	.get_port = inet_csk_get_port,
+#ifdef CONFIG_BPF_SYSCALL
+	.psock_update_sk_prot = tcp_bpf_update_proto,
+#endif
 	.enter_memory_pressure = tcp_enter_memory_pressure,
 	.leave_memory_pressure = tcp_leave_memory_pressure,
 	.stream_memory_free = tcp_stream_memory_free,
diff --git a/net/ipv4/udp.c b/net/ipv4/udp.c
index 4a0478b17243..38952aaee3a1 100644
--- a/net/ipv4/udp.c
+++ b/net/ipv4/udp.c
@@ -2849,6 +2849,9 @@ struct proto udp_prot = {
 	.unhash = udp_lib_unhash,
 	.rehash = udp_v4_rehash,
 	.get_port = udp_v4_get_port,
+#ifdef CONFIG_BPF_SYSCALL
+	.psock_update_sk_prot = udp_bpf_update_proto,
+#endif
 	.memory_allocated = &udp_memory_allocated,
 	.sysctl_mem = sysctl_udp_mem,
 	.sysctl_wmem_offset = offsetof(struct net, ipv4.sysctl_udp_wmem_min),
diff --git a/net/ipv4/udp_bpf.c b/net/ipv4/udp_bpf.c
index 7a94791efc1a..6001f93cd3a0 100644
--- a/net/ipv4/udp_bpf.c
+++ b/net/ipv4/udp_bpf.c
@@ -41,12 +41,23 @@ static int __init udp_bpf_v4_build_proto(void)
 }
 core_initcall(udp_bpf_v4_build_proto);
 
-struct proto *udp_bpf_get_proto(struct sock *sk, struct sk_psock *psock)
+int udp_bpf_update_proto(struct sock *sk, bool restore)
 {
 	int family = sk->sk_family == AF_INET ? UDP_BPF_IPV4 : UDP_BPF_IPV6;
+	struct sk_psock *psock = sk_psock(sk);
+
+	if (restore) {
+		sk->sk_write_space = psock->saved_write_space;
+		/* Pairs with lockless read in sk_clone_lock() */
+		WRITE_ONCE(sk->sk_prot, psock->sk_proto);
+		return 0;
+	}
 
 	if (sk->sk_family == AF_INET6)
 		udp_bpf_check_v6_needs_rebuild(psock->sk_proto);
 
-	return &udp_bpf_prots[family];
+	/* Pairs with lockless read in sk_clone_lock() */
+	WRITE_ONCE(sk->sk_prot, &udp_bpf_prots[family]);
+	return 0;
 }
+EXPORT_SYMBOL_GPL(udp_bpf_update_proto);
diff --git a/net/ipv6/tcp_ipv6.c b/net/ipv6/tcp_ipv6.c
index d0f007741e8e..bff22d6ef516 100644
--- a/net/ipv6/tcp_ipv6.c
+++ b/net/ipv6/tcp_ipv6.c
@@ -2139,6 +2139,9 @@ struct proto tcpv6_prot = {
 	.hash = inet6_hash,
 	.unhash = inet_unhash,
 	.get_port = inet_csk_get_port,
+#ifdef CONFIG_BPF_SYSCALL
+	.psock_update_sk_prot = tcp_bpf_update_proto,
+#endif
 	.enter_memory_pressure = tcp_enter_memory_pressure,
 	.leave_memory_pressure = tcp_leave_memory_pressure,
 	.stream_memory_free = tcp_stream_memory_free,
diff --git a/net/ipv6/udp.c b/net/ipv6/udp.c
index d25e5a9252fd..ef2c75bb4771 100644
--- a/net/ipv6/udp.c
+++ b/net/ipv6/udp.c
@@ -1713,6 +1713,9 @@ struct proto udpv6_prot = {
 	.unhash = udp_lib_unhash,
 	.rehash = udp_v6_rehash,
 	.get_port = udp_v6_get_port,
+#ifdef CONFIG_BPF_SYSCALL
+	.psock_update_sk_prot = udp_bpf_update_proto,
+#endif
 	.memory_allocated = &udp_memory_allocated,
 	.sysctl_mem = sysctl_udp_mem,
 	.sysctl_wmem_offset = offsetof(struct net, ipv4.sysctl_udp_wmem_min),
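The dispatch pattern the patch introduces (an optional per-protocol callback in struct proto, cached in the psock for the restore path) can be modeled in a small userspace C sketch. Everything below is a simplified stand-in: `toy_prot`, `toy_bpf_prot`, `raw_prot`, `toy_update_proto()`, `toy_map_init_proto()` and `toy_restore_proto()` are hypothetical names, the `psock` field on `struct sock` merely stands in for the `sk_psock(sk)` lookup, and locking, RCU and WRITE_ONCE() are omitted.

```c
#include <assert.h>
#include <errno.h>
#include <stdbool.h>
#include <stddef.h>

struct sock;

/* Hypothetical stand-in: a proto with an optional update callback. */
struct proto {
	const char *name;
	int (*psock_update_sk_prot)(struct sock *sk, bool restore);
};

struct sk_psock {
	struct proto *sk_proto;	/* original proto saved at attach time */
	int (*psock_update_sk_prot)(struct sock *sk, bool restore);
};

struct sock {
	struct proto *sk_prot;
	struct sk_psock *psock;	/* models what sk_psock(sk) would return */
};

int toy_update_proto(struct sock *sk, bool restore);

struct proto toy_prot = { "toy", toy_update_proto };
struct proto toy_bpf_prot = { "toy_bpf", toy_update_proto };
struct proto raw_prot = { "raw", NULL };	/* no sockmap support */

/* Per-protocol callback: swap in the BPF proto, or restore the saved one. */
int toy_update_proto(struct sock *sk, bool restore)
{
	sk->sk_prot = restore ? sk->psock->sk_proto : &toy_bpf_prot;
	return 0;
}

/* Generic attach: no switch on socket type, only a callback check. */
int toy_map_init_proto(struct sock *sk, struct sk_psock *psock)
{
	if (!sk->sk_prot->psock_update_sk_prot)
		return -EINVAL;

	psock->sk_proto = sk->sk_prot;
	psock->psock_update_sk_prot = sk->sk_prot->psock_update_sk_prot;
	sk->psock = psock;
	return sk->sk_prot->psock_update_sk_prot(sk, false);
}

/* Generic restore: uses the callback cached in the psock at attach time. */
void toy_restore_proto(struct sock *sk, struct sk_psock *psock)
{
	if (psock->psock_update_sk_prot)
		psock->psock_update_sk_prot(sk, true);
}
```

Because the generic code only tests for the callback, a protocol built as a module could opt in by filling the pointer in its own struct proto, which matches the motivation given in the commit message.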