diff mbox series

[bpf-next,v2,2/9] sock: introduce sk_prot->update_proto()

Message ID 20210302023743.24123-3-xiyou.wangcong@gmail.com (mailing list archive)
State Superseded
Delegated to: BPF
Headers show
Series sockmap: introduce BPF_SK_SKB_VERDICT and support UDP | expand

Checks

Context Check Description
netdev/cover_letter success Link
netdev/fixes_present success Link
netdev/patch_count success Link
netdev/tree_selection success Clearly marked for bpf-next
netdev/subject_prefix success Link
netdev/cc_maintainers warning 11 maintainers not CCed: yoshfuji@linux-ipv6.org kuba@kernel.org davem@davemloft.net yhs@fb.com edumazet@google.com ast@kernel.org kpsingh@kernel.org songliubraving@fb.com kafai@fb.com andrii@kernel.org dsahern@kernel.org
netdev/source_inline success Was 0 now: 0
netdev/verify_signedoff success Link
netdev/module_param success Was 0 now: 0
netdev/build_32bit fail Errors and warnings before: 4299 this patch: 4304
netdev/kdoc success Errors and warnings before: 0 this patch: 0
netdev/verify_fixes success Link
netdev/checkpatch warning WARNING: line length of 86 exceeds 80 columns
netdev/build_allmodconfig_warn fail Errors and warnings before: 4719 this patch: 4724
netdev/header_inline success Link
netdev/stable success Stable not CCed

Commit Message

Cong Wang March 2, 2021, 2:37 a.m. UTC
From: Cong Wang <cong.wang@bytedance.com>

Currently sockmap calls into each protocol to update the struct
proto and replace it. This certainly won't work when the protocol
is implemented as a module, for example, AF_UNIX.

Introduce a new ops sk->sk_prot->update_proto(), so each protocol
can implement its own way to replace the struct proto. This also
helps get rid of symbol dependencies on CONFIG_INET.

Cc: John Fastabend <john.fastabend@gmail.com>
Cc: Daniel Borkmann <daniel@iogearbox.net>
Cc: Jakub Sitnicki <jakub@cloudflare.com>
Cc: Lorenz Bauer <lmb@cloudflare.com>
Signed-off-by: Cong Wang <cong.wang@bytedance.com>
---
 include/linux/skmsg.h | 18 +++---------------
 include/net/sock.h    |  3 +++
 include/net/tcp.h     |  1 +
 include/net/udp.h     |  1 +
 net/core/skmsg.c      |  5 -----
 net/core/sock_map.c   | 24 ++++--------------------
 net/ipv4/tcp_bpf.c    | 23 ++++++++++++++++++++---
 net/ipv4/tcp_ipv4.c   |  3 +++
 net/ipv4/udp.c        |  3 +++
 net/ipv4/udp_bpf.c    | 14 ++++++++++++--
 net/ipv6/tcp_ipv6.c   |  3 +++
 net/ipv6/udp.c        |  3 +++
 12 files changed, 56 insertions(+), 45 deletions(-)

Comments

Lorenz Bauer March 2, 2021, 4:22 p.m. UTC | #1
On Tue, 2 Mar 2021 at 02:37, Cong Wang <xiyou.wangcong@gmail.com> wrote:

...

> @@ -350,25 +351,12 @@ static inline void sk_psock_cork_free(struct sk_psock *psock)
>         }
>  }
>
> -static inline void sk_psock_update_proto(struct sock *sk,
> -                                        struct sk_psock *psock,
> -                                        struct proto *ops)
> -{
> -       /* Pairs with lockless read in sk_clone_lock() */
> -       WRITE_ONCE(sk->sk_prot, ops);
> -}
> -
>  static inline void sk_psock_restore_proto(struct sock *sk,
>                                           struct sk_psock *psock)
>  {
>         sk->sk_prot->unhash = psock->saved_unhash;

Not related to your patch set, but why do an extra restore of
sk_prot->unhash here? At this point sk->sk_prot is one of our tcp_bpf
/ udp_bpf protos, so overwriting that seems wrong?

> -       if (inet_csk_has_ulp(sk)) {
> -               tcp_update_ulp(sk, psock->sk_proto, psock->saved_write_space);
> -       } else {
> -               sk->sk_write_space = psock->saved_write_space;
> -               /* Pairs with lockless read in sk_clone_lock() */
> -               WRITE_ONCE(sk->sk_prot, psock->sk_proto);
> -       }
> +       if (psock->saved_update_proto)
> +               psock->saved_update_proto(sk, true);
>  }
>
>  static inline void sk_psock_set_state(struct sk_psock *psock,
> diff --git a/include/net/sock.h b/include/net/sock.h
> index 636810ddcd9b..0e8577c917e8 100644
> --- a/include/net/sock.h
> +++ b/include/net/sock.h
> @@ -1184,6 +1184,9 @@ struct proto {
>         void                    (*unhash)(struct sock *sk);
>         void                    (*rehash)(struct sock *sk);
>         int                     (*get_port)(struct sock *sk, unsigned short snum);
> +#ifdef CONFIG_BPF_SYSCALL
> +       int                     (*update_proto)(struct sock *sk, bool restore);

Kind of a nit, but this name suggests that the callback is a lot more
generic than it really is. The only thing you can use it for is to
prep the socket to be sockmap ready since we hardwire sockmap_unhash,
etc. It's also not at all clear that this only works if sk has an
sk_psock associated with it. Calling it without one would crash the
kernel since the update_proto functions don't check for !sk_psock.

Might as well call it install_sockmap_hooks or something and have a
valid sk_psock be passed in to the callback. Additionally, I'd prefer
if the function returned a struct proto * like it does at the moment.
That way we keep sk->sk_prot manipulation confined to the sockmap code
and don't have to copy paste it into every proto.

> diff --git a/net/core/sock_map.c b/net/core/sock_map.c
> index 3bddd9dd2da2..13d2af5bb81c 100644
> --- a/net/core/sock_map.c
> +++ b/net/core/sock_map.c
> @@ -184,26 +184,10 @@ static void sock_map_unref(struct sock *sk, void *link_raw)
>
>  static int sock_map_init_proto(struct sock *sk, struct sk_psock *psock)
>  {
> -       struct proto *prot;
> -
> -       switch (sk->sk_type) {
> -       case SOCK_STREAM:
> -               prot = tcp_bpf_get_proto(sk, psock);
> -               break;
> -
> -       case SOCK_DGRAM:
> -               prot = udp_bpf_get_proto(sk, psock);
> -               break;
> -
> -       default:
> +       if (!sk->sk_prot->update_proto)
>                 return -EINVAL;
> -       }
> -
> -       if (IS_ERR(prot))
> -               return PTR_ERR(prot);
> -
> -       sk_psock_update_proto(sk, psock, prot);
> -       return 0;
> +       psock->saved_update_proto = sk->sk_prot->update_proto;
> +       return sk->sk_prot->update_proto(sk, false);

I think reads / writes from sk_prot need READ_ONCE / WRITE_ONCE. We've
not been diligent about this so far, but I think it makes sense to be
careful in new code.
Cong Wang March 2, 2021, 6:23 p.m. UTC | #2
On Tue, Mar 2, 2021 at 8:22 AM Lorenz Bauer <lmb@cloudflare.com> wrote:
>
> On Tue, 2 Mar 2021 at 02:37, Cong Wang <xiyou.wangcong@gmail.com> wrote:
>
> ...
> >  static inline void sk_psock_restore_proto(struct sock *sk,
> >                                           struct sk_psock *psock)
> >  {
> >         sk->sk_prot->unhash = psock->saved_unhash;
>
> Not related to your patch set, but why do an extra restore of
> sk_prot->unhash here? At this point sk->sk_prot is one of our tcp_bpf
> / udp_bpf protos, so overwriting that seems wrong?

Good catch. It seems you are right, but I need to double-check. And
yes, it is completely unrelated to my patch, as the current code has
the same problem.

>
> > -       if (inet_csk_has_ulp(sk)) {
> > -               tcp_update_ulp(sk, psock->sk_proto, psock->saved_write_space);
> > -       } else {
> > -               sk->sk_write_space = psock->saved_write_space;
> > -               /* Pairs with lockless read in sk_clone_lock() */
> > -               WRITE_ONCE(sk->sk_prot, psock->sk_proto);
> > -       }
> > +       if (psock->saved_update_proto)
> > +               psock->saved_update_proto(sk, true);
> >  }
> >
> >  static inline void sk_psock_set_state(struct sk_psock *psock,
> > diff --git a/include/net/sock.h b/include/net/sock.h
> > index 636810ddcd9b..0e8577c917e8 100644
> > --- a/include/net/sock.h
> > +++ b/include/net/sock.h
> > @@ -1184,6 +1184,9 @@ struct proto {
> >         void                    (*unhash)(struct sock *sk);
> >         void                    (*rehash)(struct sock *sk);
> >         int                     (*get_port)(struct sock *sk, unsigned short snum);
> > +#ifdef CONFIG_BPF_SYSCALL
> > +       int                     (*update_proto)(struct sock *sk, bool restore);
>
> Kind of a nit, but this name suggests that the callback is a lot more
> generic than it really is. The only thing you can use it for is to
> prep the socket to be sockmap ready since we hardwire sockmap_unhash,
> etc. It's also not at all clear that this only works if sk has an
> sk_psock associated with it. Calling it without one would crash the
> kernel since the update_proto functions don't check for !sk_psock.
>
> Might as well call it install_sockmap_hooks or something and have a
> valid sk_psock be passed in to the callback. Additionally, I'd prefer

For the name, sure, I am always open to better names. Not sure if
'install_sockmap_hooks' is a good name, I also want to express we
are overriding sk_prot. How about 'psock_update_sk_prot'?


> if the function returned a struct proto * like it does at the moment.
> That way we keep sk->sk_prot manipulation confined to the sockmap code
> and don't have to copy paste it into every proto.

Well, TCP seems too special to do this, as it could call tcp_update_ulp()
to update the proto.

>
> > diff --git a/net/core/sock_map.c b/net/core/sock_map.c
> > index 3bddd9dd2da2..13d2af5bb81c 100644
> > --- a/net/core/sock_map.c
> > +++ b/net/core/sock_map.c
> > @@ -184,26 +184,10 @@ static void sock_map_unref(struct sock *sk, void *link_raw)
> >
> >  static int sock_map_init_proto(struct sock *sk, struct sk_psock *psock)
> >  {
> > -       struct proto *prot;
> > -
> > -       switch (sk->sk_type) {
> > -       case SOCK_STREAM:
> > -               prot = tcp_bpf_get_proto(sk, psock);
> > -               break;
> > -
> > -       case SOCK_DGRAM:
> > -               prot = udp_bpf_get_proto(sk, psock);
> > -               break;
> > -
> > -       default:
> > +       if (!sk->sk_prot->update_proto)
> >                 return -EINVAL;
> > -       }
> > -
> > -       if (IS_ERR(prot))
> > -               return PTR_ERR(prot);
> > -
> > -       sk_psock_update_proto(sk, psock, prot);
> > -       return 0;
> > +       psock->saved_update_proto = sk->sk_prot->update_proto;
> > +       return sk->sk_prot->update_proto(sk, false);
>
> I think reads / writes from sk_prot need READ_ONCE / WRITE_ONCE. We've
> not been diligent about this so far, but I think it makes sense to be
> careful in new code.

Hmm, there are many places not using READ_ONCE/WRITE_ONCE,
for a quick example:

void sock_map_unhash(struct sock *sk)
{
        void (*saved_unhash)(struct sock *sk);
        struct sk_psock *psock;

        rcu_read_lock();
        psock = sk_psock(sk);
        if (unlikely(!psock)) {
                rcu_read_unlock();
                if (sk->sk_prot->unhash)
                        sk->sk_prot->unhash(sk);
                return;
        }

        saved_unhash = psock->saved_unhash;
        sock_map_remove_links(sk, psock);
        rcu_read_unlock();
        saved_unhash(sk);
}

Thanks.
Lorenz Bauer March 3, 2021, 9:35 a.m. UTC | #3
On Tue, 2 Mar 2021 at 18:23, Cong Wang <xiyou.wangcong@gmail.com> wrote:
>
> > if the function returned a struct proto * like it does at the moment.
> > That way we keep sk->sk_prot manipulation confined to the sockmap code
> > and don't have to copy paste it into every proto.
>
> Well, TCP seems too special to do this, as it could call tcp_update_ulp()
> to update the proto.

I had a quick look, tcp_bpf_update_proto is the only caller of tcp_update_ulp,
which in turn is the only caller of icsk_ulp_ops->update, which in turn is only
implemented as tls_update in tls_main.c. Turns out that tls_update
has another one of these calls:

} else {
    /* Pairs with lockless read in sk_clone_lock(). */
    WRITE_ONCE(sk->sk_prot, p);
    sk->sk_write_space = write_space;
}

Maybe it looks familiar? :o) I think it would be a worthwhile change.

>
> >
> > > diff --git a/net/core/sock_map.c b/net/core/sock_map.c
> > > index 3bddd9dd2da2..13d2af5bb81c 100644
> > > --- a/net/core/sock_map.c
> > > +++ b/net/core/sock_map.c
> > > @@ -184,26 +184,10 @@ static void sock_map_unref(struct sock *sk, void *link_raw)
> > >
> > >  static int sock_map_init_proto(struct sock *sk, struct sk_psock *psock)
> > >  {
> > > -       struct proto *prot;
> > > -
> > > -       switch (sk->sk_type) {
> > > -       case SOCK_STREAM:
> > > -               prot = tcp_bpf_get_proto(sk, psock);
> > > -               break;
> > > -
> > > -       case SOCK_DGRAM:
> > > -               prot = udp_bpf_get_proto(sk, psock);
> > > -               break;
> > > -
> > > -       default:
> > > +       if (!sk->sk_prot->update_proto)
> > >                 return -EINVAL;
> > > -       }
> > > -
> > > -       if (IS_ERR(prot))
> > > -               return PTR_ERR(prot);
> > > -
> > > -       sk_psock_update_proto(sk, psock, prot);
> > > -       return 0;
> > > +       psock->saved_update_proto = sk->sk_prot->update_proto;
> > > +       return sk->sk_prot->update_proto(sk, false);
> >
> > I think reads / writes from sk_prot need READ_ONCE / WRITE_ONCE. We've
> > not been diligent about this so far, but I think it makes sense to be
> > careful in new code.
>
> Hmm, there are many places not using READ_ONCE/WRITE_ONCE,
> for a quick example:

I know! I'll defer to John and Jakub.
Cong Wang March 3, 2021, 6:20 p.m. UTC | #4
On Wed, Mar 3, 2021 at 1:35 AM Lorenz Bauer <lmb@cloudflare.com> wrote:
>
> On Tue, 2 Mar 2021 at 18:23, Cong Wang <xiyou.wangcong@gmail.com> wrote:
> >
> > > if the function returned a struct proto * like it does at the moment.
> > > That way we keep sk->sk_prot manipulation confined to the sockmap code
> > > and don't have to copy paste it into every proto.
> >
> > Well, TCP seems too special to do this, as it could call tcp_update_ulp()
> > to update the proto.
>
> I had a quick look, tcp_bpf_update_proto is the only caller of tcp_update_ulp,
> which in turn is the only caller of icsk_ulp_ops->update, which in turn is only
> implemented as tls_update in tls_main.c. Turns out that tls_update
> has another one of these calls:
>
> } else {
>     /* Pairs with lockless read in sk_clone_lock(). */
>     WRITE_ONCE(sk->sk_prot, p);
>     sk->sk_write_space = write_space;
> }
>
> Maybe it looks familiar? :o) I think it would be a worthwhile change.

Yeah, I am not surprised we can change tcp_update_ulp() too, but
why should I bother kTLS when I do not have to? What you suggest
could at most save us a bit of code size, not a big gain. So, I'd keep
its return value as it is, unless you see any other benefits.

BTW, I will rename it to 'psock_update_sk_prot', please let me know
if you have any better names.

Thanks.
Lorenz Bauer March 4, 2021, 9:30 a.m. UTC | #5
On Wed, 3 Mar 2021 at 18:21, Cong Wang <xiyou.wangcong@gmail.com> wrote:
>
> Yeah, I am not surprised we can change tcp_update_ulp() too, but
> why should I bother kTLS when I do not have to? What you suggest
> could at most save us a bit of code size, not a big gain. So, I'd keep
> its return value as it is, unless you see any other benefits.

I think the end result is code that is easier to understand and
therefore maintain. Keep it as it is if you prefer.

> BTW, I will rename it to 'psock_update_sk_prot', please let me know
> if you have any better names.

SGTM.
Cong Wang March 4, 2021, 11:52 p.m. UTC | #6
On Tue, Mar 2, 2021 at 10:23 AM Cong Wang <xiyou.wangcong@gmail.com> wrote:
>
> On Tue, Mar 2, 2021 at 8:22 AM Lorenz Bauer <lmb@cloudflare.com> wrote:
> >
> > On Tue, 2 Mar 2021 at 02:37, Cong Wang <xiyou.wangcong@gmail.com> wrote:
> >
> > ...
> > >  static inline void sk_psock_restore_proto(struct sock *sk,
> > >                                           struct sk_psock *psock)
> > >  {
> > >         sk->sk_prot->unhash = psock->saved_unhash;
> >
> > Not related to your patch set, but why do an extra restore of
> > sk_prot->unhash here? At this point sk->sk_prot is one of our tcp_bpf
> > / udp_bpf protos, so overwriting that seems wrong?
>
> Good catch. It seems you are right, but I need a double check. And
> yes, it is completely unrelated to my patch, as the current code has
> the same problem.

Looking at this again. I noticed

commit 4da6a196f93b1af7612340e8c1ad8ce71e18f955
Author: John Fastabend <john.fastabend@gmail.com>
Date:   Sat Jan 11 06:11:59 2020 +0000

    bpf: Sockmap/tls, during free we may call tcp_bpf_unhash() in loop

intentionally fixed a bug in kTLS by overwriting this ->unhash.

I agree with you that it should not be updated for sockmap case,
however I don't know what to do with kTLS case, it seems the bug the
above commit fixed still exists if we just revert it.

Anyway, this should be targeted for -bpf as a bug fix, so it does not
belong to this patchset.

Thanks.
John Fastabend March 6, 2021, 12:27 a.m. UTC | #7
Cong Wang wrote:
> On Tue, Mar 2, 2021 at 10:23 AM Cong Wang <xiyou.wangcong@gmail.com> wrote:
> >
> > On Tue, Mar 2, 2021 at 8:22 AM Lorenz Bauer <lmb@cloudflare.com> wrote:
> > >
> > > On Tue, 2 Mar 2021 at 02:37, Cong Wang <xiyou.wangcong@gmail.com> wrote:
> > >
> > > ...
> > > >  static inline void sk_psock_restore_proto(struct sock *sk,
> > > >                                           struct sk_psock *psock)
> > > >  {
> > > >         sk->sk_prot->unhash = psock->saved_unhash;
> > >
> > > Not related to your patch set, but why do an extra restore of
> > > sk_prot->unhash here? At this point sk->sk_prot is one of our tcp_bpf
> > > / udp_bpf protos, so overwriting that seems wrong?

"extra"? restore_proto should only be called when the psock ref count
is zero and we need to transition back to the original socks proto
handlers. To trigger this we can simply delete a sock from the map.
In the case where we are deleting the psock, overwriting the tcp_bpf
protos is exactly what we want?

> >
> > Good catch. It seems you are right, but I need a double check. And
> > yes, it is completely unrelated to my patch, as the current code has
> > the same problem.
> 
> Looking at this again. I noticed
> 
> commit 4da6a196f93b1af7612340e8c1ad8ce71e18f955
> Author: John Fastabend <john.fastabend@gmail.com>
> Date:   Sat Jan 11 06:11:59 2020 +0000
> 
>     bpf: Sockmap/tls, during free we may call tcp_bpf_unhash() in loop
> 
> intentionally fixed a bug in kTLS with overwriting this ->unhash.
> 
> I agree with you that it should not be updated for sockmap case,
> however I don't know what to do with kTLS case, it seems the bug the
> above commit fixed still exists if we just revert it.
> 
> Anyway, this should be targeted for -bpf as a bug fix, so it does not
> belong to this patchset.
> 
> Thanks.

Hi,

I'm missing the error case here. The restore logic happens when the refcnt
hits 0 on the psock, indicating it's time to garbage collect the psock.

 sk_psock_put
   if (refcount_dec_and_test(&psock->refcnt))
    sk_psock_drop(sk, psock);
      sk_psock_restore_proto(sk, psock)
         sk->sk_prot->unhash = psock->saved_unhash

When sockets are initialized via sk_psock_init() we populate the unhash field

 psock->saved_unhash = prot->unhash;

So we need to unwind this otherwise a future unhash() call would not call
the original protos unhash handler.

Care to give me some more context on what the bug is?

Thanks,
John
Cong Wang March 6, 2021, 12:57 a.m. UTC | #8
On Fri, Mar 5, 2021 at 4:27 PM John Fastabend <john.fastabend@gmail.com> wrote:
>
> Cong Wang wrote:
> > On Tue, Mar 2, 2021 at 10:23 AM Cong Wang <xiyou.wangcong@gmail.com> wrote:
> > >
> > > On Tue, Mar 2, 2021 at 8:22 AM Lorenz Bauer <lmb@cloudflare.com> wrote:
> > > >
> > > > On Tue, 2 Mar 2021 at 02:37, Cong Wang <xiyou.wangcong@gmail.com> wrote:
> > > >
> > > > ...
> > > > >  static inline void sk_psock_restore_proto(struct sock *sk,
> > > > >                                           struct sk_psock *psock)
> > > > >  {
> > > > >         sk->sk_prot->unhash = psock->saved_unhash;
> > > >
> > > > Not related to your patch set, but why do an extra restore of
> > > > sk_prot->unhash here? At this point sk->sk_prot is one of our tcp_bpf
> > > > / udp_bpf protos, so overwriting that seems wrong?
>
> "extra"? restore_proto should only be called when the psock ref count
> is zero and we need to transition back to the original socks proto
> handlers. To trigger this we can simply delete a sock from the map.
> In the case where we are deleting the psock overwriting the tcp_bpf
> protos is exactly what we want.?

Why do you want to overwrite tcp_bpf_prots->unhash? Overwriting
tcp_bpf_prots is correct, but overwriting tcp_bpf_prots->unhash is not.
Because once you overwrite it, the next time you use it to replace
sk->sk_prot, it would be a different one rather than sock_map_unhash():

// tcp_bpf_prots->unhash == sock_map_unhash
sk_psock_restore_proto();
// Now  tcp_bpf_prots->unhash is inet_unhash
...
sk_psock_update_proto();
// sk->sk_prot is now tcp_bpf_prots again,
// so its ->unhash now is inet_unhash
// but it should be sock_map_unhash here

Thanks.
John Fastabend March 6, 2021, 1:55 a.m. UTC | #9
Cong Wang wrote:
> On Fri, Mar 5, 2021 at 4:27 PM John Fastabend <john.fastabend@gmail.com> wrote:
> >
> > Cong Wang wrote:
> > > On Tue, Mar 2, 2021 at 10:23 AM Cong Wang <xiyou.wangcong@gmail.com> wrote:
> > > >
> > > > On Tue, Mar 2, 2021 at 8:22 AM Lorenz Bauer <lmb@cloudflare.com> wrote:
> > > > >
> > > > > On Tue, 2 Mar 2021 at 02:37, Cong Wang <xiyou.wangcong@gmail.com> wrote:
> > > > >
> > > > > ...
> > > > > >  static inline void sk_psock_restore_proto(struct sock *sk,
> > > > > >                                           struct sk_psock *psock)
> > > > > >  {
> > > > > >         sk->sk_prot->unhash = psock->saved_unhash;
> > > > >
> > > > > Not related to your patch set, but why do an extra restore of
> > > > > sk_prot->unhash here? At this point sk->sk_prot is one of our tcp_bpf
> > > > > / udp_bpf protos, so overwriting that seems wrong?
> >
> > "extra"? restore_proto should only be called when the psock ref count
> > is zero and we need to transition back to the original socks proto
> > handlers. To trigger this we can simply delete a sock from the map.
> > In the case where we are deleting the psock overwriting the tcp_bpf
> > protos is exactly what we want.?
> 
> Why do you want to overwrite tcp_bpf_prots->unhash? Overwriting
> tcp_bpf_prots is correct, but overwriting tcp_bpf_prots->unhash is not.
> Because once you overwrite it, the next time you use it to replace
> sk->sk_prot, it would be a different one rather than sock_map_unhash():
> 
> // tcp_bpf_prots->unhash == sock_map_unhash
> sk_psock_restore_proto();
> // Now  tcp_bpf_prots->unhash is inet_unhash
> ...
> sk_psock_update_proto();
> // sk->sk_proto is now tcp_bpf_prots again,
> // so its ->unhash now is inet_unhash
> // but it should be sock_map_unhash here

Right, we can fix this on the TLS side. I'll push a fix shortly.

> 
> Thanks.
Cong Wang March 9, 2021, 5:53 p.m. UTC | #10
On Fri, Mar 5, 2021 at 5:55 PM John Fastabend <john.fastabend@gmail.com> wrote:
>
> Cong Wang wrote:
> > On Fri, Mar 5, 2021 at 4:27 PM John Fastabend <john.fastabend@gmail.com> wrote:
> > >
> > > Cong Wang wrote:
> > > > On Tue, Mar 2, 2021 at 10:23 AM Cong Wang <xiyou.wangcong@gmail.com> wrote:
> > > > >
> > > > > On Tue, Mar 2, 2021 at 8:22 AM Lorenz Bauer <lmb@cloudflare.com> wrote:
> > > > > >
> > > > > > On Tue, 2 Mar 2021 at 02:37, Cong Wang <xiyou.wangcong@gmail.com> wrote:
> > > > > >
> > > > > > ...
> > > > > > >  static inline void sk_psock_restore_proto(struct sock *sk,
> > > > > > >                                           struct sk_psock *psock)
> > > > > > >  {
> > > > > > >         sk->sk_prot->unhash = psock->saved_unhash;
> > > > > >
> > > > > > Not related to your patch set, but why do an extra restore of
> > > > > > sk_prot->unhash here? At this point sk->sk_prot is one of our tcp_bpf
> > > > > > / udp_bpf protos, so overwriting that seems wrong?
> > >
> > > "extra"? restore_proto should only be called when the psock ref count
> > > is zero and we need to transition back to the original socks proto
> > > handlers. To trigger this we can simply delete a sock from the map.
> > > In the case where we are deleting the psock overwriting the tcp_bpf
> > > protos is exactly what we want.?
> >
> > Why do you want to overwrite tcp_bpf_prots->unhash? Overwriting
> > tcp_bpf_prots is correct, but overwriting tcp_bpf_prots->unhash is not.
> > Because once you overwrite it, the next time you use it to replace
> > sk->sk_prot, it would be a different one rather than sock_map_unhash():
> >
> > // tcp_bpf_prots->unhash == sock_map_unhash
> > sk_psock_restore_proto();
> > // Now  tcp_bpf_prots->unhash is inet_unhash
> > ...
> > sk_psock_update_proto();
> > // sk->sk_proto is now tcp_bpf_prots again,
> > // so its ->unhash now is inet_unhash
> > // but it should be sock_map_unhash here
>
> Right, we can fix this on the TLS side. I'll push a fix shortly.

Are you still working on this? If kTLS still needs it, then we can
have something like this:

diff --git a/include/linux/skmsg.h b/include/linux/skmsg.h
index 8edbbf5f2f93..5eb617df7f48 100644
--- a/include/linux/skmsg.h
+++ b/include/linux/skmsg.h
@@ -349,8 +349,8 @@ static inline void sk_psock_update_proto(struct sock *sk,
 static inline void sk_psock_restore_proto(struct sock *sk,
                                          struct sk_psock *psock)
 {
-       sk->sk_prot->unhash = psock->saved_unhash;
        if (inet_csk_has_ulp(sk)) {
+               sk->sk_prot->unhash = psock->saved_unhash;
                tcp_update_ulp(sk, psock->sk_proto, psock->saved_write_space);
        } else {
                sk->sk_write_space = psock->saved_write_space;


Thanks.
John Fastabend March 10, 2021, 6:33 a.m. UTC | #11
Cong Wang wrote:
> On Fri, Mar 5, 2021 at 5:55 PM John Fastabend <john.fastabend@gmail.com> wrote:
> >

[...]

> > > // tcp_bpf_prots->unhash == sock_map_unhash
> > > sk_psock_restore_proto();
> > > // Now  tcp_bpf_prots->unhash is inet_unhash
> > > ...
> > > sk_psock_update_proto();
> > > // sk->sk_proto is now tcp_bpf_prots again,
> > > // so its ->unhash now is inet_unhash
> > > // but it should be sock_map_unhash here
> >
> > Right, we can fix this on the TLS side. I'll push a fix shortly.
> 
> Are you still working on this? If kTLS still needs it, then we can
> have something like this:

Testing a fix now; I will flush it out tomorrow. The below is not
really correct either: it just moves the issue so it only impacts
TLS.

> 
> diff --git a/include/linux/skmsg.h b/include/linux/skmsg.h
> index 8edbbf5f2f93..5eb617df7f48 100644
> --- a/include/linux/skmsg.h
> +++ b/include/linux/skmsg.h
> @@ -349,8 +349,8 @@ static inline void sk_psock_update_proto(struct sock *sk,
>  static inline void sk_psock_restore_proto(struct sock *sk,
>                                           struct sk_psock *psock)
>  {
> -       sk->sk_prot->unhash = psock->saved_unhash;
>         if (inet_csk_has_ulp(sk)) {
> +               sk->sk_prot->unhash = psock->saved_unhash;
>                 tcp_update_ulp(sk, psock->sk_proto, psock->saved_write_space);
>         } else {
>                 sk->sk_write_space = psock->saved_write_space;
> 
> 
> Thanks.
diff mbox series

Patch

diff --git a/include/linux/skmsg.h b/include/linux/skmsg.h
index 451530d41af7..b5df69d5d397 100644
--- a/include/linux/skmsg.h
+++ b/include/linux/skmsg.h
@@ -98,6 +98,7 @@  struct sk_psock {
 	void (*saved_close)(struct sock *sk, long timeout);
 	void (*saved_write_space)(struct sock *sk);
 	void (*saved_data_ready)(struct sock *sk);
+	int  (*saved_update_proto)(struct sock *sk, bool restore);
 	struct proto			*sk_proto;
 	struct sk_psock_work_state	work_state;
 	struct work_struct		work;
@@ -350,25 +351,12 @@  static inline void sk_psock_cork_free(struct sk_psock *psock)
 	}
 }
 
-static inline void sk_psock_update_proto(struct sock *sk,
-					 struct sk_psock *psock,
-					 struct proto *ops)
-{
-	/* Pairs with lockless read in sk_clone_lock() */
-	WRITE_ONCE(sk->sk_prot, ops);
-}
-
 static inline void sk_psock_restore_proto(struct sock *sk,
 					  struct sk_psock *psock)
 {
 	sk->sk_prot->unhash = psock->saved_unhash;
-	if (inet_csk_has_ulp(sk)) {
-		tcp_update_ulp(sk, psock->sk_proto, psock->saved_write_space);
-	} else {
-		sk->sk_write_space = psock->saved_write_space;
-		/* Pairs with lockless read in sk_clone_lock() */
-		WRITE_ONCE(sk->sk_prot, psock->sk_proto);
-	}
+	if (psock->saved_update_proto)
+		psock->saved_update_proto(sk, true);
 }
 
 static inline void sk_psock_set_state(struct sk_psock *psock,
diff --git a/include/net/sock.h b/include/net/sock.h
index 636810ddcd9b..0e8577c917e8 100644
--- a/include/net/sock.h
+++ b/include/net/sock.h
@@ -1184,6 +1184,9 @@  struct proto {
 	void			(*unhash)(struct sock *sk);
 	void			(*rehash)(struct sock *sk);
 	int			(*get_port)(struct sock *sk, unsigned short snum);
+#ifdef CONFIG_BPF_SYSCALL
+	int			(*update_proto)(struct sock *sk, bool restore);
+#endif
 
 	/* Keeping track of sockets in use */
 #ifdef CONFIG_PROC_FS
diff --git a/include/net/tcp.h b/include/net/tcp.h
index 075de26f449d..2efa4e5ea23d 100644
--- a/include/net/tcp.h
+++ b/include/net/tcp.h
@@ -2203,6 +2203,7 @@  struct sk_psock;
 
 #ifdef CONFIG_BPF_SYSCALL
 struct proto *tcp_bpf_get_proto(struct sock *sk, struct sk_psock *psock);
+int tcp_bpf_update_proto(struct sock *sk, bool restore);
 void tcp_bpf_clone(const struct sock *sk, struct sock *newsk);
 #endif /* CONFIG_BPF_SYSCALL */
 
diff --git a/include/net/udp.h b/include/net/udp.h
index d4d064c59232..df7cc1edc200 100644
--- a/include/net/udp.h
+++ b/include/net/udp.h
@@ -518,6 +518,7 @@  static inline struct sk_buff *udp_rcv_segment(struct sock *sk,
 #ifdef CONFIG_BPF_SYSCALL
 struct sk_psock;
 struct proto *udp_bpf_get_proto(struct sock *sk, struct sk_psock *psock);
+int udp_bpf_update_proto(struct sock *sk, bool restore);
 #endif
 
 #endif	/* _UDP_H */
diff --git a/net/core/skmsg.c b/net/core/skmsg.c
index 5efd790f1b47..7dbd8344ec89 100644
--- a/net/core/skmsg.c
+++ b/net/core/skmsg.c
@@ -563,11 +563,6 @@  struct sk_psock *sk_psock_init(struct sock *sk, int node)
 
 	write_lock_bh(&sk->sk_callback_lock);
 
-	if (inet_csk_has_ulp(sk)) {
-		psock = ERR_PTR(-EINVAL);
-		goto out;
-	}
-
 	if (sk->sk_user_data) {
 		psock = ERR_PTR(-EBUSY);
 		goto out;
diff --git a/net/core/sock_map.c b/net/core/sock_map.c
index 3bddd9dd2da2..13d2af5bb81c 100644
--- a/net/core/sock_map.c
+++ b/net/core/sock_map.c
@@ -184,26 +184,10 @@  static void sock_map_unref(struct sock *sk, void *link_raw)
 
 static int sock_map_init_proto(struct sock *sk, struct sk_psock *psock)
 {
-	struct proto *prot;
-
-	switch (sk->sk_type) {
-	case SOCK_STREAM:
-		prot = tcp_bpf_get_proto(sk, psock);
-		break;
-
-	case SOCK_DGRAM:
-		prot = udp_bpf_get_proto(sk, psock);
-		break;
-
-	default:
+	if (!sk->sk_prot->update_proto)
 		return -EINVAL;
-	}
-
-	if (IS_ERR(prot))
-		return PTR_ERR(prot);
-
-	sk_psock_update_proto(sk, psock, prot);
-	return 0;
+	psock->saved_update_proto = sk->sk_prot->update_proto;
+	return sk->sk_prot->update_proto(sk, false);
 }
 
 static struct sk_psock *sock_map_psock_get_checked(struct sock *sk)
@@ -570,7 +554,7 @@  static bool sock_map_redirect_allowed(const struct sock *sk)
 
 static bool sock_map_sk_is_suitable(const struct sock *sk)
 {
-	return sk_is_tcp(sk) || sk_is_udp(sk);
+	return !!sk->sk_prot->update_proto;
 }
 
 static bool sock_map_sk_state_allowed(const struct sock *sk)
diff --git a/net/ipv4/tcp_bpf.c b/net/ipv4/tcp_bpf.c
index 17c322b875fd..737726c8138c 100644
--- a/net/ipv4/tcp_bpf.c
+++ b/net/ipv4/tcp_bpf.c
@@ -601,19 +601,36 @@  static int tcp_bpf_assert_proto_ops(struct proto *ops)
 	       ops->sendpage == tcp_sendpage ? 0 : -ENOTSUPP;
 }
 
-struct proto *tcp_bpf_get_proto(struct sock *sk, struct sk_psock *psock)
+int tcp_bpf_update_proto(struct sock *sk, bool restore)
 {
+	struct sk_psock *psock = sk_psock(sk);
 	int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4;
 	int config = psock->progs.msg_parser   ? TCP_BPF_TX   : TCP_BPF_BASE;
 
+	if (restore) {
+		if (inet_csk_has_ulp(sk)) {
+			tcp_update_ulp(sk, psock->sk_proto, psock->saved_write_space);
+		} else {
+			sk->sk_write_space = psock->saved_write_space;
+			/* Pairs with lockless read in sk_clone_lock() */
+			WRITE_ONCE(sk->sk_prot, psock->sk_proto);
+		}
+		return 0;
+	}
+
+	if (inet_csk_has_ulp(sk))
+		return -EINVAL;
+
 	if (sk->sk_family == AF_INET6) {
 		if (tcp_bpf_assert_proto_ops(psock->sk_proto))
-			return ERR_PTR(-EINVAL);
+			return -EINVAL;
 
 		tcp_bpf_check_v6_needs_rebuild(psock->sk_proto);
 	}
 
-	return &tcp_bpf_prots[family][config];
+	/* Pairs with lockless read in sk_clone_lock() */
+	WRITE_ONCE(sk->sk_prot, &tcp_bpf_prots[family][config]);
+	return 0;
 }
 
 /* If a child got cloned from a listening socket that had tcp_bpf
diff --git a/net/ipv4/tcp_ipv4.c b/net/ipv4/tcp_ipv4.c
index daad4f99db32..21c9e262d07c 100644
--- a/net/ipv4/tcp_ipv4.c
+++ b/net/ipv4/tcp_ipv4.c
@@ -2806,6 +2806,9 @@  struct proto tcp_prot = {
 	.hash			= inet_hash,
 	.unhash			= inet_unhash,
 	.get_port		= inet_csk_get_port,
+#ifdef CONFIG_BPF_SYSCALL
+	.update_proto		= tcp_bpf_update_proto,
+#endif
 	.enter_memory_pressure	= tcp_enter_memory_pressure,
 	.leave_memory_pressure	= tcp_leave_memory_pressure,
 	.stream_memory_free	= tcp_stream_memory_free,
diff --git a/net/ipv4/udp.c b/net/ipv4/udp.c
index 4a0478b17243..dbd25b59ce0e 100644
--- a/net/ipv4/udp.c
+++ b/net/ipv4/udp.c
@@ -2849,6 +2849,9 @@  struct proto udp_prot = {
 	.unhash			= udp_lib_unhash,
 	.rehash			= udp_v4_rehash,
 	.get_port		= udp_v4_get_port,
+#ifdef CONFIG_BPF_SYSCALL
+	.update_proto		= udp_bpf_update_proto,
+#endif
 	.memory_allocated	= &udp_memory_allocated,
 	.sysctl_mem		= sysctl_udp_mem,
 	.sysctl_wmem_offset	= offsetof(struct net, ipv4.sysctl_udp_wmem_min),
diff --git a/net/ipv4/udp_bpf.c b/net/ipv4/udp_bpf.c
index 7a94791efc1a..595836088e85 100644
--- a/net/ipv4/udp_bpf.c
+++ b/net/ipv4/udp_bpf.c
@@ -41,12 +41,22 @@  static int __init udp_bpf_v4_build_proto(void)
 }
 core_initcall(udp_bpf_v4_build_proto);
 
-struct proto *udp_bpf_get_proto(struct sock *sk, struct sk_psock *psock)
+int udp_bpf_update_proto(struct sock *sk, bool restore)
 {
 	int family = sk->sk_family == AF_INET ? UDP_BPF_IPV4 : UDP_BPF_IPV6;
+	struct sk_psock *psock = sk_psock(sk);
+
+	if (restore) {
+		sk->sk_write_space = psock->saved_write_space;
+		/* Pairs with lockless read in sk_clone_lock() */
+		WRITE_ONCE(sk->sk_prot, psock->sk_proto);
+		return 0;
+	}
 
 	if (sk->sk_family == AF_INET6)
 		udp_bpf_check_v6_needs_rebuild(psock->sk_proto);
 
-	return &udp_bpf_prots[family];
+	/* Pairs with lockless read in sk_clone_lock() */
+	WRITE_ONCE(sk->sk_prot, &udp_bpf_prots[family]);
+	return 0;
 }
diff --git a/net/ipv6/tcp_ipv6.c b/net/ipv6/tcp_ipv6.c
index bd44ded7e50c..ea5be7e7fcb8 100644
--- a/net/ipv6/tcp_ipv6.c
+++ b/net/ipv6/tcp_ipv6.c
@@ -2134,6 +2134,9 @@  struct proto tcpv6_prot = {
 	.hash			= inet6_hash,
 	.unhash			= inet_unhash,
 	.get_port		= inet_csk_get_port,
+#ifdef CONFIG_BPF_SYSCALL
+	.update_proto		= tcp_bpf_update_proto,
+#endif
 	.enter_memory_pressure	= tcp_enter_memory_pressure,
 	.leave_memory_pressure	= tcp_leave_memory_pressure,
 	.stream_memory_free	= tcp_stream_memory_free,
diff --git a/net/ipv6/udp.c b/net/ipv6/udp.c
index d25e5a9252fd..105ba0cf739d 100644
--- a/net/ipv6/udp.c
+++ b/net/ipv6/udp.c
@@ -1713,6 +1713,9 @@  struct proto udpv6_prot = {
 	.unhash			= udp_lib_unhash,
 	.rehash			= udp_v6_rehash,
 	.get_port		= udp_v6_get_port,
+#ifdef CONFIG_BPF_SYSCALL
+	.update_proto		= udp_bpf_update_proto,
+#endif
 	.memory_allocated	= &udp_memory_allocated,
 	.sysctl_mem		= sysctl_udp_mem,
 	.sysctl_wmem_offset     = offsetof(struct net, ipv4.sysctl_udp_wmem_min),