
[bpf-next,v8,10/16] sock: introduce sk->sk_prot->psock_update_sk_prot()

Message ID 20210331023237.41094-11-xiyou.wangcong@gmail.com (mailing list archive)
State Accepted
Delegated to: BPF
Series sockmap: introduce BPF_SK_SKB_VERDICT and support UDP

Checks

Context Check Description
netdev/cover_letter success Link
netdev/fixes_present success Link
netdev/patch_count fail Series longer than 15 patches
netdev/tree_selection success Clearly marked for bpf-next
netdev/subject_prefix success Link
netdev/cc_maintainers warning 11 maintainers not CCed: dsahern@kernel.org yhs@fb.com kpsingh@kernel.org yoshfuji@linux-ipv6.org andrii@kernel.org kafai@fb.com ast@kernel.org kuba@kernel.org songliubraving@fb.com davem@davemloft.net edumazet@google.com
netdev/source_inline success Was 0 now: 0
netdev/verify_signedoff success Link
netdev/module_param success Was 0 now: 0
netdev/build_32bit success Errors and warnings before: 4112 this patch: 4112
netdev/kdoc success Errors and warnings before: 0 this patch: 0
netdev/verify_fixes success Link
netdev/checkpatch warning WARNING: line length of 86 exceeds 80 columns WARNING: line length of 87 exceeds 80 columns
netdev/build_allmodconfig_warn success Errors and warnings before: 4349 this patch: 4349
netdev/header_inline success Link

Commit Message

Cong Wang March 31, 2021, 2:32 a.m. UTC
From: Cong Wang <cong.wang@bytedance.com>

Currently sockmap calls into each protocol to update the struct
proto and replace it. This certainly won't work when the protocol
is implemented as a module, for example, AF_UNIX.

Introduce a new ops sk->sk_prot->psock_update_sk_prot(), so each
protocol can implement its own way to replace the struct proto.
This also helps get rid of symbol dependencies on CONFIG_INET.

Cc: John Fastabend <john.fastabend@gmail.com>
Cc: Daniel Borkmann <daniel@iogearbox.net>
Cc: Jakub Sitnicki <jakub@cloudflare.com>
Cc: Lorenz Bauer <lmb@cloudflare.com>
Signed-off-by: Cong Wang <cong.wang@bytedance.com>
---
 include/linux/skmsg.h | 18 +++---------------
 include/net/sock.h    |  3 +++
 include/net/tcp.h     |  1 +
 include/net/udp.h     |  1 +
 net/core/skmsg.c      |  5 -----
 net/core/sock_map.c   | 24 ++++--------------------
 net/ipv4/tcp_bpf.c    | 24 +++++++++++++++++++++---
 net/ipv4/tcp_ipv4.c   |  3 +++
 net/ipv4/udp.c        |  3 +++
 net/ipv4/udp_bpf.c    | 15 +++++++++++++--
 net/ipv6/tcp_ipv6.c   |  3 +++
 net/ipv6/udp.c        |  3 +++
 12 files changed, 58 insertions(+), 45 deletions(-)
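
For illustration, a minimal sketch of how a protocol built as a module could wire up the new op. Everything named example_* below is made up and is not part of this series; the body simply mirrors the pattern udp_bpf_update_proto() uses in the patch.

#include <linux/module.h>
#include <linux/skmsg.h>
#include <net/sock.h>

/* BPF-enabled copy of example_prot, built once at init time the same way
 * udp_bpf_v4_build_proto() builds udp_bpf_prots[].
 */
static struct proto example_bpf_prot;

static int example_bpf_update_proto(struct sock *sk, bool restore)
{
	struct sk_psock *psock = sk_psock(sk);

	if (restore) {
		sk->sk_write_space = psock->saved_write_space;
		/* Put the original proto back. */
		WRITE_ONCE(sk->sk_prot, psock->sk_proto);
		return 0;
	}

	/* Install the BPF-enabled proto. */
	WRITE_ONCE(sk->sk_prot, &example_bpf_prot);
	return 0;
}

static struct proto example_prot = {
	.name			= "EXAMPLE",
	.owner			= THIS_MODULE,
	/* other mandatory callbacks omitted for brevity */
#ifdef CONFIG_BPF_SYSCALL
	.psock_update_sk_prot	= example_bpf_update_proto,
#endif
};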

Comments

Jakub Sitnicki April 2, 2021, 10:16 a.m. UTC | #1
On Wed, Mar 31, 2021 at 04:32 AM CEST, Cong Wang wrote:
> From: Cong Wang <cong.wang@bytedance.com>
>
> Currently sockmap calls into each protocol to update the struct
> proto and replace it. This certainly won't work when the protocol
> is implemented as a module, for example, AF_UNIX.
>
> Introduce a new ops sk->sk_prot->psock_update_sk_prot(), so each
> protocol can implement its own way to replace the struct proto.
> This also helps get rid of symbol dependencies on CONFIG_INET.
>
> Cc: John Fastabend <john.fastabend@gmail.com>
> Cc: Daniel Borkmann <daniel@iogearbox.net>
> Cc: Jakub Sitnicki <jakub@cloudflare.com>
> Cc: Lorenz Bauer <lmb@cloudflare.com>
> Signed-off-by: Cong Wang <cong.wang@bytedance.com>
> ---

[...]

> diff --git a/net/ipv4/udp.c b/net/ipv4/udp.c
> index 4a0478b17243..38952aaee3a1 100644
> --- a/net/ipv4/udp.c
> +++ b/net/ipv4/udp.c
> @@ -2849,6 +2849,9 @@ struct proto udp_prot = {
>  	.unhash			= udp_lib_unhash,
>  	.rehash			= udp_v4_rehash,
>  	.get_port		= udp_v4_get_port,
> +#ifdef CONFIG_BPF_SYSCALL
> +	.psock_update_sk_prot	= udp_bpf_update_proto,
> +#endif
>  	.memory_allocated	= &udp_memory_allocated,
>  	.sysctl_mem		= sysctl_udp_mem,
>  	.sysctl_wmem_offset	= offsetof(struct net, ipv4.sysctl_udp_wmem_min),
> diff --git a/net/ipv4/udp_bpf.c b/net/ipv4/udp_bpf.c
> index 7a94791efc1a..6001f93cd3a0 100644
> --- a/net/ipv4/udp_bpf.c
> +++ b/net/ipv4/udp_bpf.c
> @@ -41,12 +41,23 @@ static int __init udp_bpf_v4_build_proto(void)
>  }
>  core_initcall(udp_bpf_v4_build_proto);
>
> -struct proto *udp_bpf_get_proto(struct sock *sk, struct sk_psock *psock)
> +int udp_bpf_update_proto(struct sock *sk, bool restore)
>  {
>  	int family = sk->sk_family == AF_INET ? UDP_BPF_IPV4 : UDP_BPF_IPV6;
> +	struct sk_psock *psock = sk_psock(sk);
> +
> +	if (restore) {
> +		sk->sk_write_space = psock->saved_write_space;
> +		/* Pairs with lockless read in sk_clone_lock() */

Just to clarify. UDP sockets don't get cloned, so the above comment
doesn't apply.

> +		WRITE_ONCE(sk->sk_prot, psock->sk_proto);
> +		return 0;
> +	}
>
>  	if (sk->sk_family == AF_INET6)
>  		udp_bpf_check_v6_needs_rebuild(psock->sk_proto);
>
> -	return &udp_bpf_prots[family];
> +	/* Pairs with lockless read in sk_clone_lock() */
> +	WRITE_ONCE(sk->sk_prot, &udp_bpf_prots[family]);
> +	return 0;
>  }
> +EXPORT_SYMBOL_GPL(udp_bpf_update_proto);

[...]
Cong Wang April 3, 2021, 5:13 a.m. UTC | #2
On Fri, Apr 2, 2021 at 3:16 AM Jakub Sitnicki <jakub@cloudflare.com> wrote:
> > -struct proto *udp_bpf_get_proto(struct sock *sk, struct sk_psock *psock)
> > +int udp_bpf_update_proto(struct sock *sk, bool restore)
> >  {
> >       int family = sk->sk_family == AF_INET ? UDP_BPF_IPV4 : UDP_BPF_IPV6;
> > +     struct sk_psock *psock = sk_psock(sk);
> > +
> > +     if (restore) {
> > +             sk->sk_write_space = psock->saved_write_space;
> > +             /* Pairs with lockless read in sk_clone_lock() */
>
> Just to clarify. UDP sockets don't get cloned, so the above comment
> doesn't apply.

Good catch! It is clearly a copy-n-paste. I will send a patch to remove it.

Thanks.
Eric Dumazet April 5, 2021, 8:25 a.m. UTC | #3
On 3/31/21 4:32 AM, Cong Wang wrote:
> From: Cong Wang <cong.wang@bytedance.com>
> 
> Currently sockmap calls into each protocol to update the struct
> proto and replace it. This certainly won't work when the protocol
> is implemented as a module, for example, AF_UNIX.
> 
> Introduce a new ops sk->sk_prot->psock_update_sk_prot(), so each
> protocol can implement its own way to replace the struct proto.
> This also helps get rid of symbol dependencies on CONFIG_INET.

[...]


>  
> -struct proto *tcp_bpf_get_proto(struct sock *sk, struct sk_psock *psock)
> +int tcp_bpf_update_proto(struct sock *sk, bool restore)
>  {
> +	struct sk_psock *psock = sk_psock(sk);

I do not think RCU is held here?

sk_psock() is using rcu_dereference_sk_user_data()

>  	int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4;
>  	int config = psock->progs.msg_parser   ? TCP_BPF_TX   : TCP_BPF_BASE;
>  

Same issue in udp_bpf_update_proto() of course.
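
For context, sk_psock() is (approximately) the following inline, so without rcu_read_lock() or another form of protection that RCU knows about, the rcu_dereference() underneath it is what trips lockdep:

/* include/linux/skmsg.h */
static inline struct sk_psock *sk_psock(const struct sock *sk)
{
	return rcu_dereference_sk_user_data(sk);
}

/* include/net/sock.h */
#define rcu_dereference_sk_user_data(sk) \
	rcu_dereference(__sk_user_data((sk)))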
John Fastabend April 6, 2021, 6:12 p.m. UTC | #4
Eric Dumazet wrote:
> 
> 
> On 3/31/21 4:32 AM, Cong Wang wrote:
> > From: Cong Wang <cong.wang@bytedance.com>
> > 
> > Currently sockmap calls into each protocol to update the struct
> > proto and replace it. This certainly won't work when the protocol
> > is implemented as a module, for example, AF_UNIX.
> > 
> > Introduce a new ops sk->sk_prot->psock_update_sk_prot(), so each
> > protocol can implement its own way to replace the struct proto.
> > This also helps get rid of symbol dependencies on CONFIG_INET.
> 
> [...]
> 
> 
> >  
> > -struct proto *tcp_bpf_get_proto(struct sock *sk, struct sk_psock *psock)
> > +int tcp_bpf_update_proto(struct sock *sk, bool restore)
> >  {
> > +	struct sk_psock *psock = sk_psock(sk);
> 
> I do not think RCU is held here?

Hi, thanks for looking at this.

> 
> sk_psock() is using rcu_dereference_sk_user_data()

First caller of this is here,

 sock_{hash|map}_update_common <- has a WARN_ON_ONCE(!rcu_read_lock_held);
  sock_map_link()
   sock_map_init_proto()
    psock_update_sk_prot(sk, false)

And the other does this,

 sk_psock_put()
   sk_psock_drop()
     sk_psock_restore_proto
        psock_update_sk_prot(sk, true)

But we can get here through many callers and it sure doesn't look like it's
all safe. For example, one case:

 .sendmsg
   tcp_bpf_sendmsg
    psock = sk_psock_get(sk)
     sk_psock_put(sk, psock) <- this doesn't have the RCU read lock held

> 
> >  	int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4;
> >  	int config = psock->progs.msg_parser   ? TCP_BPF_TX   : TCP_BPF_BASE;
> >  
> 
> Same issue in udp_bpf_update_proto() of course.
> 

Yep.

Either we revert the patch or we can fix it to pass the psock through.
Passing the psock works because we have a reference on it and it won't
go away. I don't have any other good ideas off-hand.

Thanks Eric! I'm a bit surprised we didn't get an RCU splat from the
tests though.

.John
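
A rough sketch of the "pass the psock through" option above, assuming the op gains a psock argument so callers hand over the psock they already hold a reference on; the signature change is illustrative, not a posted fix:

/* The op in struct proto would become:
 *
 *	int (*psock_update_sk_prot)(struct sock *sk, struct sk_psock *psock,
 *				    bool restore);
 */
int tcp_bpf_update_proto(struct sock *sk, struct sk_psock *psock, bool restore)
{
	int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4;
	int config = psock->progs.msg_parser   ? TCP_BPF_TX   : TCP_BPF_BASE;

	if (restore) {
		if (inet_csk_has_ulp(sk)) {
			tcp_update_ulp(sk, psock->sk_proto, psock->saved_write_space);
		} else {
			sk->sk_write_space = psock->saved_write_space;
			/* Pairs with lockless read in sk_clone_lock() */
			WRITE_ONCE(sk->sk_prot, psock->sk_proto);
		}
		return 0;
	}

	if (inet_csk_has_ulp(sk))
		return -EINVAL;

	if (sk->sk_family == AF_INET6) {
		if (tcp_bpf_assert_proto_ops(psock->sk_proto))
			return -EINVAL;

		tcp_bpf_check_v6_needs_rebuild(psock->sk_proto);
	}

	/* Pairs with lockless read in sk_clone_lock() */
	WRITE_ONCE(sk->sk_prot, &tcp_bpf_prots[family][config]);
	return 0;
}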
Cong Wang April 6, 2021, 6:30 p.m. UTC | #5
On Mon, Apr 5, 2021 at 1:25 AM Eric Dumazet <eric.dumazet@gmail.com> wrote:
>
>
>
> On 3/31/21 4:32 AM, Cong Wang wrote:
> > From: Cong Wang <cong.wang@bytedance.com>
> >
> > Currently sockmap calls into each protocol to update the struct
> > proto and replace it. This certainly won't work when the protocol
> > is implemented as a module, for example, AF_UNIX.
> >
> > Introduce a new ops sk->sk_prot->psock_update_sk_prot(), so each
> > protocol can implement its own way to replace the struct proto.
> > This also helps get rid of symbol dependencies on CONFIG_INET.
>
> [...]
>
>
> >
> > -struct proto *tcp_bpf_get_proto(struct sock *sk, struct sk_psock *psock)
> > +int tcp_bpf_update_proto(struct sock *sk, bool restore)
> >  {
> > +     struct sk_psock *psock = sk_psock(sk);
>
> I do not think RCU is held here?
>
> sk_psock() is using rcu_dereference_sk_user_data()

Right, I just saw the syzbot report. But here we already hold
the write lock on sk_callback_lock, so taking the RCU read lock here
makes no sense to me. Probably we just have to tell RCU that we
already hold sk_callback_lock.

Thanks.
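
What "telling RCU we already hold sk_callback_lock" could look like (a sketch only; the helper name sk_psock_locked is made up):

/* Variant of sk_psock() for paths that hold sk->sk_callback_lock for
 * writing, so lockdep accepts the dereference without rcu_read_lock().
 */
static inline struct sk_psock *sk_psock_locked(const struct sock *sk)
{
	return rcu_dereference_protected(__sk_user_data(sk),
					 lockdep_is_held(&sk->sk_callback_lock));
}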
John Fastabend April 6, 2021, 9:07 p.m. UTC | #6
Cong Wang wrote:
> On Mon, Apr 5, 2021 at 1:25 AM Eric Dumazet <eric.dumazet@gmail.com> wrote:
> >
> >
> >
> > On 3/31/21 4:32 AM, Cong Wang wrote:
> > > From: Cong Wang <cong.wang@bytedance.com>
> > >
> > > Currently sockmap calls into each protocol to update the struct
> > > proto and replace it. This certainly won't work when the protocol
> > > is implemented as a module, for example, AF_UNIX.
> > >
> > > Introduce a new ops sk->sk_prot->psock_update_sk_prot(), so each
> > > protocol can implement its own way to replace the struct proto.
> > > This also helps get rid of symbol dependencies on CONFIG_INET.
> >
> > [...]
> >
> >
> > >
> > > -struct proto *tcp_bpf_get_proto(struct sock *sk, struct sk_psock *psock)
> > > +int tcp_bpf_update_proto(struct sock *sk, bool restore)
> > >  {
> > > +     struct sk_psock *psock = sk_psock(sk);
> >
> > I do not think RCU is held here?
> >
> > sk_psock() is using rcu_dereference_sk_user_data()
> 
> Right, I just saw the syzbot report. But here we already hold
> the write lock on sk_callback_lock, so taking the RCU read lock here
> makes no sense to me. Probably we just have to tell RCU that we
> already hold sk_callback_lock.
> 
> Thanks.

I think you need to ensure it's the psock we originally grabbed as
well. Otherwise, how do we ensure the psock is not swapped out by
another thread?
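
If the op takes the psock explicitly, as in the earlier sketch, the restore path can hand over the exact psock being dropped instead of re-deriving it, which also answers the "same psock" question; again, this is only a sketch of that direction:

/* Assumes the op signature from the earlier sketch:
 *	int (*psock_update_sk_prot)(struct sock *sk, struct sk_psock *psock,
 *				    bool restore);
 */
static inline void sk_psock_restore_proto(struct sock *sk,
					  struct sk_psock *psock)
{
	sk->sk_prot->unhash = psock->saved_unhash;
	if (psock->psock_update_sk_prot)
		psock->psock_update_sk_prot(sk, psock, true);
}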

Patch

diff --git a/include/linux/skmsg.h b/include/linux/skmsg.h
index c83dbc2d81d9..5e800ddc2dc6 100644
--- a/include/linux/skmsg.h
+++ b/include/linux/skmsg.h
@@ -99,6 +99,7 @@  struct sk_psock {
 	void (*saved_close)(struct sock *sk, long timeout);
 	void (*saved_write_space)(struct sock *sk);
 	void (*saved_data_ready)(struct sock *sk);
+	int  (*psock_update_sk_prot)(struct sock *sk, bool restore);
 	struct proto			*sk_proto;
 	struct mutex			work_mutex;
 	struct sk_psock_work_state	work_state;
@@ -395,25 +396,12 @@  static inline void sk_psock_cork_free(struct sk_psock *psock)
 	}
 }
 
-static inline void sk_psock_update_proto(struct sock *sk,
-					 struct sk_psock *psock,
-					 struct proto *ops)
-{
-	/* Pairs with lockless read in sk_clone_lock() */
-	WRITE_ONCE(sk->sk_prot, ops);
-}
-
 static inline void sk_psock_restore_proto(struct sock *sk,
 					  struct sk_psock *psock)
 {
 	sk->sk_prot->unhash = psock->saved_unhash;
-	if (inet_csk_has_ulp(sk)) {
-		tcp_update_ulp(sk, psock->sk_proto, psock->saved_write_space);
-	} else {
-		sk->sk_write_space = psock->saved_write_space;
-		/* Pairs with lockless read in sk_clone_lock() */
-		WRITE_ONCE(sk->sk_prot, psock->sk_proto);
-	}
+	if (psock->psock_update_sk_prot)
+		psock->psock_update_sk_prot(sk, true);
 }
 
 static inline void sk_psock_set_state(struct sk_psock *psock,
diff --git a/include/net/sock.h b/include/net/sock.h
index 0b6266fd6bf6..8b4155e756c2 100644
--- a/include/net/sock.h
+++ b/include/net/sock.h
@@ -1184,6 +1184,9 @@  struct proto {
 	void			(*unhash)(struct sock *sk);
 	void			(*rehash)(struct sock *sk);
 	int			(*get_port)(struct sock *sk, unsigned short snum);
+#ifdef CONFIG_BPF_SYSCALL
+	int			(*psock_update_sk_prot)(struct sock *sk, bool restore);
+#endif
 
 	/* Keeping track of sockets in use */
 #ifdef CONFIG_PROC_FS
diff --git a/include/net/tcp.h b/include/net/tcp.h
index 075de26f449d..2efa4e5ea23d 100644
--- a/include/net/tcp.h
+++ b/include/net/tcp.h
@@ -2203,6 +2203,7 @@  struct sk_psock;
 
 #ifdef CONFIG_BPF_SYSCALL
 struct proto *tcp_bpf_get_proto(struct sock *sk, struct sk_psock *psock);
+int tcp_bpf_update_proto(struct sock *sk, bool restore);
 void tcp_bpf_clone(const struct sock *sk, struct sock *newsk);
 #endif /* CONFIG_BPF_SYSCALL */
 
diff --git a/include/net/udp.h b/include/net/udp.h
index d4d064c59232..df7cc1edc200 100644
--- a/include/net/udp.h
+++ b/include/net/udp.h
@@ -518,6 +518,7 @@  static inline struct sk_buff *udp_rcv_segment(struct sock *sk,
 #ifdef CONFIG_BPF_SYSCALL
 struct sk_psock;
 struct proto *udp_bpf_get_proto(struct sock *sk, struct sk_psock *psock);
+int udp_bpf_update_proto(struct sock *sk, bool restore);
 #endif
 
 #endif	/* _UDP_H */
diff --git a/net/core/skmsg.c b/net/core/skmsg.c
index a045812d7c78..9fc83f7cc1a0 100644
--- a/net/core/skmsg.c
+++ b/net/core/skmsg.c
@@ -562,11 +562,6 @@  struct sk_psock *sk_psock_init(struct sock *sk, int node)
 
 	write_lock_bh(&sk->sk_callback_lock);
 
-	if (inet_csk_has_ulp(sk)) {
-		psock = ERR_PTR(-EINVAL);
-		goto out;
-	}
-
 	if (sk->sk_user_data) {
 		psock = ERR_PTR(-EBUSY);
 		goto out;
diff --git a/net/core/sock_map.c b/net/core/sock_map.c
index c2a0411e08a8..2915c7c8778b 100644
--- a/net/core/sock_map.c
+++ b/net/core/sock_map.c
@@ -185,26 +185,10 @@  static void sock_map_unref(struct sock *sk, void *link_raw)
 
 static int sock_map_init_proto(struct sock *sk, struct sk_psock *psock)
 {
-	struct proto *prot;
-
-	switch (sk->sk_type) {
-	case SOCK_STREAM:
-		prot = tcp_bpf_get_proto(sk, psock);
-		break;
-
-	case SOCK_DGRAM:
-		prot = udp_bpf_get_proto(sk, psock);
-		break;
-
-	default:
+	if (!sk->sk_prot->psock_update_sk_prot)
 		return -EINVAL;
-	}
-
-	if (IS_ERR(prot))
-		return PTR_ERR(prot);
-
-	sk_psock_update_proto(sk, psock, prot);
-	return 0;
+	psock->psock_update_sk_prot = sk->sk_prot->psock_update_sk_prot;
+	return sk->sk_prot->psock_update_sk_prot(sk, false);
 }
 
 static struct sk_psock *sock_map_psock_get_checked(struct sock *sk)
@@ -556,7 +540,7 @@  static bool sock_map_redirect_allowed(const struct sock *sk)
 
 static bool sock_map_sk_is_suitable(const struct sock *sk)
 {
-	return sk_is_tcp(sk) || sk_is_udp(sk);
+	return !!sk->sk_prot->psock_update_sk_prot;
 }
 
 static bool sock_map_sk_state_allowed(const struct sock *sk)
diff --git a/net/ipv4/tcp_bpf.c b/net/ipv4/tcp_bpf.c
index ae980716d896..ac8cfbaeacd2 100644
--- a/net/ipv4/tcp_bpf.c
+++ b/net/ipv4/tcp_bpf.c
@@ -595,20 +595,38 @@  static int tcp_bpf_assert_proto_ops(struct proto *ops)
 	       ops->sendpage == tcp_sendpage ? 0 : -ENOTSUPP;
 }
 
-struct proto *tcp_bpf_get_proto(struct sock *sk, struct sk_psock *psock)
+int tcp_bpf_update_proto(struct sock *sk, bool restore)
 {
+	struct sk_psock *psock = sk_psock(sk);
 	int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4;
 	int config = psock->progs.msg_parser   ? TCP_BPF_TX   : TCP_BPF_BASE;
 
+	if (restore) {
+		if (inet_csk_has_ulp(sk)) {
+			tcp_update_ulp(sk, psock->sk_proto, psock->saved_write_space);
+		} else {
+			sk->sk_write_space = psock->saved_write_space;
+			/* Pairs with lockless read in sk_clone_lock() */
+			WRITE_ONCE(sk->sk_prot, psock->sk_proto);
+		}
+		return 0;
+	}
+
+	if (inet_csk_has_ulp(sk))
+		return -EINVAL;
+
 	if (sk->sk_family == AF_INET6) {
 		if (tcp_bpf_assert_proto_ops(psock->sk_proto))
-			return ERR_PTR(-EINVAL);
+			return -EINVAL;
 
 		tcp_bpf_check_v6_needs_rebuild(psock->sk_proto);
 	}
 
-	return &tcp_bpf_prots[family][config];
+	/* Pairs with lockless read in sk_clone_lock() */
+	WRITE_ONCE(sk->sk_prot, &tcp_bpf_prots[family][config]);
+	return 0;
 }
+EXPORT_SYMBOL_GPL(tcp_bpf_update_proto);
 
 /* If a child got cloned from a listening socket that had tcp_bpf
  * protocol callbacks installed, we need to restore the callbacks to
diff --git a/net/ipv4/tcp_ipv4.c b/net/ipv4/tcp_ipv4.c
index daad4f99db32..dfc6d1c0e710 100644
--- a/net/ipv4/tcp_ipv4.c
+++ b/net/ipv4/tcp_ipv4.c
@@ -2806,6 +2806,9 @@  struct proto tcp_prot = {
 	.hash			= inet_hash,
 	.unhash			= inet_unhash,
 	.get_port		= inet_csk_get_port,
+#ifdef CONFIG_BPF_SYSCALL
+	.psock_update_sk_prot	= tcp_bpf_update_proto,
+#endif
 	.enter_memory_pressure	= tcp_enter_memory_pressure,
 	.leave_memory_pressure	= tcp_leave_memory_pressure,
 	.stream_memory_free	= tcp_stream_memory_free,
diff --git a/net/ipv4/udp.c b/net/ipv4/udp.c
index 4a0478b17243..38952aaee3a1 100644
--- a/net/ipv4/udp.c
+++ b/net/ipv4/udp.c
@@ -2849,6 +2849,9 @@  struct proto udp_prot = {
 	.unhash			= udp_lib_unhash,
 	.rehash			= udp_v4_rehash,
 	.get_port		= udp_v4_get_port,
+#ifdef CONFIG_BPF_SYSCALL
+	.psock_update_sk_prot	= udp_bpf_update_proto,
+#endif
 	.memory_allocated	= &udp_memory_allocated,
 	.sysctl_mem		= sysctl_udp_mem,
 	.sysctl_wmem_offset	= offsetof(struct net, ipv4.sysctl_udp_wmem_min),
diff --git a/net/ipv4/udp_bpf.c b/net/ipv4/udp_bpf.c
index 7a94791efc1a..6001f93cd3a0 100644
--- a/net/ipv4/udp_bpf.c
+++ b/net/ipv4/udp_bpf.c
@@ -41,12 +41,23 @@  static int __init udp_bpf_v4_build_proto(void)
 }
 core_initcall(udp_bpf_v4_build_proto);
 
-struct proto *udp_bpf_get_proto(struct sock *sk, struct sk_psock *psock)
+int udp_bpf_update_proto(struct sock *sk, bool restore)
 {
 	int family = sk->sk_family == AF_INET ? UDP_BPF_IPV4 : UDP_BPF_IPV6;
+	struct sk_psock *psock = sk_psock(sk);
+
+	if (restore) {
+		sk->sk_write_space = psock->saved_write_space;
+		/* Pairs with lockless read in sk_clone_lock() */
+		WRITE_ONCE(sk->sk_prot, psock->sk_proto);
+		return 0;
+	}
 
 	if (sk->sk_family == AF_INET6)
 		udp_bpf_check_v6_needs_rebuild(psock->sk_proto);
 
-	return &udp_bpf_prots[family];
+	/* Pairs with lockless read in sk_clone_lock() */
+	WRITE_ONCE(sk->sk_prot, &udp_bpf_prots[family]);
+	return 0;
 }
+EXPORT_SYMBOL_GPL(udp_bpf_update_proto);
diff --git a/net/ipv6/tcp_ipv6.c b/net/ipv6/tcp_ipv6.c
index d0f007741e8e..bff22d6ef516 100644
--- a/net/ipv6/tcp_ipv6.c
+++ b/net/ipv6/tcp_ipv6.c
@@ -2139,6 +2139,9 @@  struct proto tcpv6_prot = {
 	.hash			= inet6_hash,
 	.unhash			= inet_unhash,
 	.get_port		= inet_csk_get_port,
+#ifdef CONFIG_BPF_SYSCALL
+	.psock_update_sk_prot	= tcp_bpf_update_proto,
+#endif
 	.enter_memory_pressure	= tcp_enter_memory_pressure,
 	.leave_memory_pressure	= tcp_leave_memory_pressure,
 	.stream_memory_free	= tcp_stream_memory_free,
diff --git a/net/ipv6/udp.c b/net/ipv6/udp.c
index d25e5a9252fd..ef2c75bb4771 100644
--- a/net/ipv6/udp.c
+++ b/net/ipv6/udp.c
@@ -1713,6 +1713,9 @@  struct proto udpv6_prot = {
 	.unhash			= udp_lib_unhash,
 	.rehash			= udp_v6_rehash,
 	.get_port		= udp_v6_get_port,
+#ifdef CONFIG_BPF_SYSCALL
+	.psock_update_sk_prot	= udp_bpf_update_proto,
+#endif
 	.memory_allocated	= &udp_memory_allocated,
 	.sysctl_mem		= sysctl_udp_mem,
 	.sysctl_wmem_offset     = offsetof(struct net, ipv4.sysctl_udp_wmem_min),