diff mbox series

[2/2] net/tcp: Disable TCP-MD5 static key on tcp_md5sig_info destruction

Message ID 20221102211350.625011-3-dima@arista.com (mailing list archive)
State Superseded
Delegated to: Netdev Maintainers
Headers show
Series net/tcp: Dynamically disable TCP-MD5 static key | expand

Checks

Context Check Description
netdev/tree_selection success Guessed tree name to be net-next, async
netdev/fixes_present success Fixes tag not required for -next series
netdev/subject_prefix warning Target tree name not specified in the subject
netdev/cover_letter success Series has a cover letter
netdev/patch_count success Link
netdev/header_inline success No static functions without inline keyword in header files
netdev/build_32bit success Errors and warnings before: 994 this patch: 994
netdev/cc_maintainers success CCed 7 of 7 maintainers
netdev/build_clang success Errors and warnings before: 120 this patch: 120
netdev/module_param success Was 0 now: 0
netdev/verify_signedoff success Signed-off-by tag matches author and committer
netdev/check_selftest success No net selftest shell script
netdev/verify_fixes success No Fixes tag
netdev/build_allmodconfig_warn success Errors and warnings before: 999 this patch: 999
netdev/checkpatch warning WARNING: line length of 82 exceeds 80 columns WARNING: line length of 84 exceeds 80 columns
netdev/kdoc success Errors and warnings before: 0 this patch: 0
netdev/source_inline success Was 0 now: 0

Commit Message

Dmitry Safonov Nov. 2, 2022, 9:13 p.m. UTC
To do that, separate two scenarios:
- where it's the first MD5 key on the system, which means that enabling
  of the static key may need to sleep;
- copying of an existing key from a listening socket to the request
  socket upon receiving a signed TCP segment, where static key was
  already enabled (when the key was added to the listening socket).

Now the life-time of the static branch for TCP-MD5 is until:
- last tcp_md5sig_info is destroyed
- last socket in time-wait state with MD5 key is closed.

Which means that after all sockets with TCP-MD5 keys are gone, the
system gets back the performance of disabled md5-key static branch.

Signed-off-by: Dmitry Safonov <dima@arista.com>
---
 include/net/tcp.h        | 10 ++++---
 net/ipv4/tcp.c           |  5 +---
 net/ipv4/tcp_ipv4.c      | 56 ++++++++++++++++++++++++++++++----------
 net/ipv4/tcp_minisocks.c |  9 ++++---
 net/ipv4/tcp_output.c    |  4 +--
 net/ipv6/tcp_ipv6.c      | 10 +++----
 6 files changed, 63 insertions(+), 31 deletions(-)

Comments

Eric Dumazet Nov. 2, 2022, 9:25 p.m. UTC | #1
On Wed, Nov 2, 2022 at 2:14 PM Dmitry Safonov <dima@arista.com> wrote:
>
> To do that, separate two scenarios:
> - where it's the first MD5 key on the system, which means that enabling
>   of the static key may need to sleep;
> - copying of an existing key from a listening socket to the request
>   socket upon receiving a signed TCP segment, where static key was
>   already enabled (when the key was added to the listening socket).
>
> Now the life-time of the static branch for TCP-MD5 is until:
> - last tcp_md5sig_info is destroyed
> - last socket in time-wait state with MD5 key is closed.
>
> Which means that after all sockets with TCP-MD5 keys are gone, the
> system gets back the performance of disabled md5-key static branch.
>
> Signed-off-by: Dmitry Safonov <dima@arista.com>
> ---
>  include/net/tcp.h        | 10 ++++---
>  net/ipv4/tcp.c           |  5 +---
>  net/ipv4/tcp_ipv4.c      | 56 ++++++++++++++++++++++++++++++----------
>  net/ipv4/tcp_minisocks.c |  9 ++++---
>  net/ipv4/tcp_output.c    |  4 +--
>  net/ipv6/tcp_ipv6.c      | 10 +++----
>  6 files changed, 63 insertions(+), 31 deletions(-)
>
> diff --git a/include/net/tcp.h b/include/net/tcp.h
> index 14d45661a84d..a0cdf013782a 100644
> --- a/include/net/tcp.h
> +++ b/include/net/tcp.h
> @@ -1675,7 +1675,11 @@ int tcp_v4_md5_hash_skb(char *md5_hash, const struct tcp_md5sig_key *key,
>                         const struct sock *sk, const struct sk_buff *skb);
>  int tcp_md5_do_add(struct sock *sk, const union tcp_md5_addr *addr,
>                    int family, u8 prefixlen, int l3index, u8 flags,
> -                  const u8 *newkey, u8 newkeylen, gfp_t gfp);
> +                  const u8 *newkey, u8 newkeylen);
> +int tcp_md5_key_copy(struct sock *sk, const union tcp_md5_addr *addr,
> +                    int family, u8 prefixlen, int l3index,
> +                    struct tcp_md5sig_key *key);
> +
>  int tcp_md5_do_del(struct sock *sk, const union tcp_md5_addr *addr,
>                    int family, u8 prefixlen, int l3index, u8 flags);
>  struct tcp_md5sig_key *tcp_v4_md5_lookup(const struct sock *sk,
> @@ -1683,7 +1687,7 @@ struct tcp_md5sig_key *tcp_v4_md5_lookup(const struct sock *sk,
>
>  #ifdef CONFIG_TCP_MD5SIG
>  #include <linux/jump_label.h>
> -extern struct static_key_false tcp_md5_needed;
> +extern struct static_key_false_deferred tcp_md5_needed;
>  struct tcp_md5sig_key *__tcp_md5_do_lookup(const struct sock *sk, int l3index,
>                                            const union tcp_md5_addr *addr,
>                                            int family);
> @@ -1691,7 +1695,7 @@ static inline struct tcp_md5sig_key *
>  tcp_md5_do_lookup(const struct sock *sk, int l3index,
>                   const union tcp_md5_addr *addr, int family)
>  {
> -       if (!static_branch_unlikely(&tcp_md5_needed))
> +       if (!static_branch_unlikely(&tcp_md5_needed.key))
>                 return NULL;
>         return __tcp_md5_do_lookup(sk, l3index, addr, family);
>  }
> diff --git a/net/ipv4/tcp.c b/net/ipv4/tcp.c
> index ef14efa1fb70..936ed566cc89 100644
> --- a/net/ipv4/tcp.c
> +++ b/net/ipv4/tcp.c
> @@ -4460,11 +4460,8 @@ bool tcp_alloc_md5sig_pool(void)
>         if (unlikely(!READ_ONCE(tcp_md5sig_pool_populated))) {
>                 mutex_lock(&tcp_md5sig_mutex);
>
> -               if (!tcp_md5sig_pool_populated) {
> +               if (!tcp_md5sig_pool_populated)
>                         __tcp_alloc_md5sig_pool();
> -                       if (tcp_md5sig_pool_populated)
> -                               static_branch_inc(&tcp_md5_needed);
> -               }
>
>                 mutex_unlock(&tcp_md5sig_mutex);
>         }
> diff --git a/net/ipv4/tcp_ipv4.c b/net/ipv4/tcp_ipv4.c
> index fae80b1a1796..f812d507fc9a 100644
> --- a/net/ipv4/tcp_ipv4.c
> +++ b/net/ipv4/tcp_ipv4.c
> @@ -1064,7 +1064,7 @@ static void tcp_v4_reqsk_destructor(struct request_sock *req)
>   * We need to maintain these in the sk structure.
>   */
>
> -DEFINE_STATIC_KEY_FALSE(tcp_md5_needed);
> +DEFINE_STATIC_KEY_DEFERRED_FALSE(tcp_md5_needed, HZ);
>  EXPORT_SYMBOL(tcp_md5_needed);
>
>  static bool better_md5_match(struct tcp_md5sig_key *old, struct tcp_md5sig_key *new)
> @@ -1177,9 +1177,6 @@ static int tcp_md5sig_info_add(struct sock *sk, gfp_t gfp)
>         struct tcp_sock *tp = tcp_sk(sk);
>         struct tcp_md5sig_info *md5sig;
>
> -       if (rcu_dereference_protected(tp->md5sig_info, lockdep_sock_is_held(sk)))
> -               return 0;
> -
>         md5sig = kmalloc(sizeof(*md5sig), gfp);
>         if (!md5sig)
>                 return -ENOMEM;
> @@ -1191,9 +1188,9 @@ static int tcp_md5sig_info_add(struct sock *sk, gfp_t gfp)
>  }
>
>  /* This can be called on a newly created socket, from other files */
> -int tcp_md5_do_add(struct sock *sk, const union tcp_md5_addr *addr,
> -                  int family, u8 prefixlen, int l3index, u8 flags,
> -                  const u8 *newkey, u8 newkeylen, gfp_t gfp)
> +static int __tcp_md5_do_add(struct sock *sk, const union tcp_md5_addr *addr,
> +                           int family, u8 prefixlen, int l3index, u8 flags,
> +                           const u8 *newkey, u8 newkeylen, gfp_t gfp)
>  {
>         /* Add Key to the list */
>         struct tcp_md5sig_key *key;
> @@ -1220,9 +1217,6 @@ int tcp_md5_do_add(struct sock *sk, const union tcp_md5_addr *addr,
>                 return 0;
>         }
>
> -       if (tcp_md5sig_info_add(sk, gfp))
> -               return -ENOMEM;
> -
>         md5sig = rcu_dereference_protected(tp->md5sig_info,
>                                            lockdep_sock_is_held(sk));
>
> @@ -1246,8 +1240,44 @@ int tcp_md5_do_add(struct sock *sk, const union tcp_md5_addr *addr,
>         hlist_add_head_rcu(&key->node, &md5sig->head);
>         return 0;
>  }
> +
> +int tcp_md5_do_add(struct sock *sk, const union tcp_md5_addr *addr,
> +                  int family, u8 prefixlen, int l3index, u8 flags,
> +                  const u8 *newkey, u8 newkeylen)
> +{
> +       struct tcp_sock *tp = tcp_sk(sk);
> +
> +       if (!rcu_dereference_protected(tp->md5sig_info, lockdep_sock_is_held(sk))) {
> +               if (tcp_md5sig_info_add(sk, GFP_KERNEL))
> +                       return -ENOMEM;
> +
> +               static_branch_inc(&tcp_md5_needed.key);
> +       }
> +
> +       return __tcp_md5_do_add(sk, addr, family, prefixlen, l3index, flags,
> +                               newkey, newkeylen, GFP_KERNEL);
> +}
>  EXPORT_SYMBOL(tcp_md5_do_add);
>
> +int tcp_md5_key_copy(struct sock *sk, const union tcp_md5_addr *addr,
> +                    int family, u8 prefixlen, int l3index,
> +                    struct tcp_md5sig_key *key)
> +{
> +       struct tcp_sock *tp = tcp_sk(sk);
> +
> +       if (!rcu_dereference_protected(tp->md5sig_info, lockdep_sock_is_held(sk))) {
> +               if (tcp_md5sig_info_add(sk, sk_gfp_mask(sk, GFP_ATOMIC)))
> +                       return -ENOMEM;
> +
> +               atomic_inc(&tcp_md5_needed.key.key.enabled);

static_branch_inc ?

> +       }
> +
> +       return __tcp_md5_do_add(sk, addr, family, prefixlen, l3index,
> +                               key->flags, key->key, key->keylen,
> +                               sk_gfp_mask(sk, GFP_ATOMIC));
> +}
> +EXPORT_SYMBOL(tcp_md5_key_copy);
> +
>  int tcp_md5_do_del(struct sock *sk, const union tcp_md5_addr *addr, int family,
>                    u8 prefixlen, int l3index, u8 flags)
>  {
> @@ -1334,7 +1364,7 @@ static int tcp_v4_parse_md5_keys(struct sock *sk, int optname,
>                 return -EINVAL;
>
>         return tcp_md5_do_add(sk, addr, AF_INET, prefixlen, l3index, flags,
> -                             cmd.tcpm_key, cmd.tcpm_keylen, GFP_KERNEL);
> +                             cmd.tcpm_key, cmd.tcpm_keylen);
>  }
>
>  static int tcp_v4_md5_hash_headers(struct tcp_md5sig_pool *hp,
> @@ -1591,8 +1621,7 @@ struct sock *tcp_v4_syn_recv_sock(const struct sock *sk, struct sk_buff *skb,
>                  * memory, then we end up not copying the key
>                  * across. Shucks.
>                  */
> -               tcp_md5_do_add(newsk, addr, AF_INET, 32, l3index, key->flags,
> -                              key->key, key->keylen, GFP_ATOMIC);
> +               tcp_md5_key_copy(newsk, addr, AF_INET, 32, l3index, key);
>                 sk_gso_disable(newsk);
>         }
>  #endif
> @@ -2284,6 +2313,7 @@ void tcp_v4_destroy_sock(struct sock *sk)
>                 tcp_clear_md5_list(sk);
>                 kfree_rcu(rcu_dereference_protected(tp->md5sig_info, 1), rcu);
>                 tp->md5sig_info = NULL;
> +               static_branch_slow_dec_deferred(&tcp_md5_needed);
>         }
>  #endif
>
> diff --git a/net/ipv4/tcp_minisocks.c b/net/ipv4/tcp_minisocks.c
> index c375f603a16c..fb500160b8d2 100644
> --- a/net/ipv4/tcp_minisocks.c
> +++ b/net/ipv4/tcp_minisocks.c
> @@ -291,13 +291,14 @@ void tcp_time_wait(struct sock *sk, int state, int timeo)
>                  */
>                 do {
>                         tcptw->tw_md5_key = NULL;
> -                       if (static_branch_unlikely(&tcp_md5_needed)) {
> +                       if (static_branch_unlikely(&tcp_md5_needed.key)) {
>                                 struct tcp_md5sig_key *key;
>
>                                 key = tp->af_specific->md5_lookup(sk, sk);
>                                 if (key) {
>                                         tcptw->tw_md5_key = kmemdup(key, sizeof(*key), GFP_ATOMIC);
>                                         BUG_ON(tcptw->tw_md5_key && !tcp_alloc_md5sig_pool());
> +                                       atomic_inc(&tcp_md5_needed.key.key.enabled);

static_branch_inc

>                                 }
>                         }
>                 } while (0);
> @@ -337,11 +338,13 @@ EXPORT_SYMBOL(tcp_time_wait);
>  void tcp_twsk_destructor(struct sock *sk)
>  {
>  #ifdef CONFIG_TCP_MD5SIG
> -       if (static_branch_unlikely(&tcp_md5_needed)) {
> +       if (static_branch_unlikely(&tcp_md5_needed.key)) {
>                 struct tcp_timewait_sock *twsk = tcp_twsk(sk);
>
> -               if (twsk->tw_md5_key)
> +               if (twsk->tw_md5_key) {

Orthogonal to this patch, but I wonder why we do not clear
twsk->tw_md5_key before kfree_rcu()

It seems a lookup could catch the invalid pointer.

>                         kfree_rcu(twsk->tw_md5_key, rcu);
> +                       static_branch_slow_dec_deferred(&tcp_md5_needed);
> +               }
>         }
>  #endif
>  }
> diff --git a/net/ipv4/tcp_output.c b/net/ipv4/tcp_output.c
> index c69f4d966024..86e71c8c76bc 100644
> --- a/net/ipv4/tcp_output.c
> +++ b/net/ipv4/tcp_output.c
> @@ -766,7 +766,7 @@ static unsigned int tcp_syn_options(struct sock *sk, struct sk_buff *skb,
>
>         *md5 = NULL;
>  #ifdef CONFIG_TCP_MD5SIG
> -       if (static_branch_unlikely(&tcp_md5_needed) &&
> +       if (static_branch_unlikely(&tcp_md5_needed.key) &&
>             rcu_access_pointer(tp->md5sig_info)) {
>                 *md5 = tp->af_specific->md5_lookup(sk, sk);
>                 if (*md5) {
> @@ -922,7 +922,7 @@ static unsigned int tcp_established_options(struct sock *sk, struct sk_buff *skb
>
>         *md5 = NULL;
>  #ifdef CONFIG_TCP_MD5SIG
> -       if (static_branch_unlikely(&tcp_md5_needed) &&
> +       if (static_branch_unlikely(&tcp_md5_needed.key) &&
>             rcu_access_pointer(tp->md5sig_info)) {
>                 *md5 = tp->af_specific->md5_lookup(sk, sk);
>                 if (*md5) {
> diff --git a/net/ipv6/tcp_ipv6.c b/net/ipv6/tcp_ipv6.c
> index 2a3f9296df1e..3e3bdc120fc8 100644
> --- a/net/ipv6/tcp_ipv6.c
> +++ b/net/ipv6/tcp_ipv6.c
> @@ -677,12 +677,11 @@ static int tcp_v6_parse_md5_keys(struct sock *sk, int optname,
>         if (ipv6_addr_v4mapped(&sin6->sin6_addr))
>                 return tcp_md5_do_add(sk, (union tcp_md5_addr *)&sin6->sin6_addr.s6_addr32[3],
>                                       AF_INET, prefixlen, l3index, flags,
> -                                     cmd.tcpm_key, cmd.tcpm_keylen,
> -                                     GFP_KERNEL);
> +                                     cmd.tcpm_key, cmd.tcpm_keylen);
>
>         return tcp_md5_do_add(sk, (union tcp_md5_addr *)&sin6->sin6_addr,
>                               AF_INET6, prefixlen, l3index, flags,
> -                             cmd.tcpm_key, cmd.tcpm_keylen, GFP_KERNEL);
> +                             cmd.tcpm_key, cmd.tcpm_keylen);
>  }
>
>  static int tcp_v6_md5_hash_headers(struct tcp_md5sig_pool *hp,
> @@ -1382,9 +1381,8 @@ static struct sock *tcp_v6_syn_recv_sock(const struct sock *sk, struct sk_buff *
>                  * memory, then we end up not copying the key
>                  * across. Shucks.
>                  */
> -               tcp_md5_do_add(newsk, (union tcp_md5_addr *)&newsk->sk_v6_daddr,
> -                              AF_INET6, 128, l3index, key->flags, key->key, key->keylen,
> -                              sk_gfp_mask(sk, GFP_ATOMIC));
> +               tcp_md5_key_copy(newsk, (union tcp_md5_addr *)&newsk->sk_v6_daddr,
> +                                AF_INET6, 128, l3index, key);
>         }
>  #endif
>
> --
> 2.38.1
>
Dmitry Safonov Nov. 2, 2022, 9:40 p.m. UTC | #2
On 11/2/22 21:25, Eric Dumazet wrote:
> On Wed, Nov 2, 2022 at 2:14 PM Dmitry Safonov <dima@arista.com> wrote:
[..]
>> +int tcp_md5_do_add(struct sock *sk, const union tcp_md5_addr *addr,
>> +                  int family, u8 prefixlen, int l3index, u8 flags,
>> +                  const u8 *newkey, u8 newkeylen)
>> +{
>> +       struct tcp_sock *tp = tcp_sk(sk);
>> +
>> +       if (!rcu_dereference_protected(tp->md5sig_info, lockdep_sock_is_held(sk))) {
>> +               if (tcp_md5sig_info_add(sk, GFP_KERNEL))
>> +                       return -ENOMEM;
>> +
>> +               static_branch_inc(&tcp_md5_needed.key);
>> +       }
>> +
>> +       return __tcp_md5_do_add(sk, addr, family, prefixlen, l3index, flags,
>> +                               newkey, newkeylen, GFP_KERNEL);
>> +}
>>  EXPORT_SYMBOL(tcp_md5_do_add);
>>
>> +int tcp_md5_key_copy(struct sock *sk, const union tcp_md5_addr *addr,
>> +                    int family, u8 prefixlen, int l3index,
>> +                    struct tcp_md5sig_key *key)
>> +{
>> +       struct tcp_sock *tp = tcp_sk(sk);
>> +
>> +       if (!rcu_dereference_protected(tp->md5sig_info, lockdep_sock_is_held(sk))) {
>> +               if (tcp_md5sig_info_add(sk, sk_gfp_mask(sk, GFP_ATOMIC)))
>> +                       return -ENOMEM;
>> +
>> +               atomic_inc(&tcp_md5_needed.key.key.enabled);
> 
> static_branch_inc ?

That's the difference between tcp_md5_do_add() and tcp_md5_key_copy():
the first one can sleep on either allocation or static branch patching,
while the second one is used where there is md5 key and it can't get
destroyed during the function call. tcp_md5_key_copy() is called
somewhere from the softirq handler so it needs an atomic allocation as
well as this a little bit hacky part.

[..]
>> @@ -337,11 +338,13 @@ EXPORT_SYMBOL(tcp_time_wait);
>>  void tcp_twsk_destructor(struct sock *sk)
>>  {
>>  #ifdef CONFIG_TCP_MD5SIG
>> -       if (static_branch_unlikely(&tcp_md5_needed)) {
>> +       if (static_branch_unlikely(&tcp_md5_needed.key)) {
>>                 struct tcp_timewait_sock *twsk = tcp_twsk(sk);
>>
>> -               if (twsk->tw_md5_key)
>> +               if (twsk->tw_md5_key) {
> 
> Orthogonal to this patch, but I wonder why we do not clear
> twsk->tw_md5_key before kfree_rcu()
> 
> It seems a lookup could catch the invalid pointer.
> 
>>                         kfree_rcu(twsk->tw_md5_key, rcu);
>> +                       static_branch_slow_dec_deferred(&tcp_md5_needed);
>> +               }
>>         }
>>  #endif

A good question, let me check on this.

Thanks for the review,
          Dmitry
Eric Dumazet Nov. 2, 2022, 9:49 p.m. UTC | #3
On Wed, Nov 2, 2022 at 2:40 PM Dmitry Safonov <dima@arista.com> wrote:
>
> On 11/2/22 21:25, Eric Dumazet wrote:
> > On Wed, Nov 2, 2022 at 2:14 PM Dmitry Safonov <dima@arista.com> wrote:
> [..]
> >> +int tcp_md5_do_add(struct sock *sk, const union tcp_md5_addr *addr,
> >> +                  int family, u8 prefixlen, int l3index, u8 flags,
> >> +                  const u8 *newkey, u8 newkeylen)
> >> +{
> >> +       struct tcp_sock *tp = tcp_sk(sk);
> >> +
> >> +       if (!rcu_dereference_protected(tp->md5sig_info, lockdep_sock_is_held(sk))) {
> >> +               if (tcp_md5sig_info_add(sk, GFP_KERNEL))
> >> +                       return -ENOMEM;
> >> +
> >> +               static_branch_inc(&tcp_md5_needed.key);
> >> +       }
> >> +
> >> +       return __tcp_md5_do_add(sk, addr, family, prefixlen, l3index, flags,
> >> +                               newkey, newkeylen, GFP_KERNEL);
> >> +}
> >>  EXPORT_SYMBOL(tcp_md5_do_add);
> >>
> >> +int tcp_md5_key_copy(struct sock *sk, const union tcp_md5_addr *addr,
> >> +                    int family, u8 prefixlen, int l3index,
> >> +                    struct tcp_md5sig_key *key)
> >> +{
> >> +       struct tcp_sock *tp = tcp_sk(sk);
> >> +
> >> +       if (!rcu_dereference_protected(tp->md5sig_info, lockdep_sock_is_held(sk))) {
> >> +               if (tcp_md5sig_info_add(sk, sk_gfp_mask(sk, GFP_ATOMIC)))
> >> +                       return -ENOMEM;
> >> +
> >> +               atomic_inc(&tcp_md5_needed.key.key.enabled);
> >
> > static_branch_inc ?
>
> That's the difference between tcp_md5_do_add() and tcp_md5_key_copy():
> the first one can sleep on either allocation or static branch patching,
> while the second one is used where there is md5 key and it can't get
> destroyed during the function call. tcp_md5_key_copy() is called
> somewhere from the softirq handler so it needs an atomic allocation as
> well as this a little bit hacky part.
>

Are you sure ?

static_branch_inc() is what we want here, it is a nice wrapper around
the correct internal details,
and ultimately boils to an atomic_inc(). It is safe for all contexts.

But if/when jump labels get refcount_t one day, we will not have to
change TCP stack because
it made some implementation assumptions.
Eric Dumazet Nov. 2, 2022, 9:53 p.m. UTC | #4
On Wed, Nov 2, 2022 at 2:49 PM Eric Dumazet <edumazet@google.com> wrote:

>
> Are you sure ?
>
> static_branch_inc() is what we want here, it is a nice wrapper around
> the correct internal details,
> and ultimately boils to an atomic_inc(). It is safe for all contexts.
>
> But if/when jump labels get refcount_t one day, we will not have to
> change TCP stack because
> it made some implementation assumptions.

Oh, I think I understand this better now.

Please provide a helper like

static inline void static_key_fast_inc(struct static_key *key)
{
       atomic_inc(&key->enabled);
}

Something like that.
Dmitry Safonov Nov. 3, 2022, 3:40 p.m. UTC | #5
On 11/2/22 21:53, Eric Dumazet wrote:
> On Wed, Nov 2, 2022 at 2:49 PM Eric Dumazet <edumazet@google.com> wrote:
> 
>>
>> Are you sure ?
>>
>> static_branch_inc() is what we want here, it is a nice wrapper around
>> the correct internal details,
>> and ultimately boils to an atomic_inc(). It is safe for all contexts.
>>
>> But if/when jump labels get refcount_t one day, we will not have to
>> change TCP stack because
>> it made some implementation assumptions.
> 
> Oh, I think I understand this better now.
> 
> Please provide a helper like
> 
> static inline void static_key_fast_inc(struct static_key *key)
> {
>        atomic_inc(&key->enabled);
> }
> 
> Something like that.

Sure, that sounds like a better thing to do, rather than the hack I had.

Thanks, will send v2 soon,
          Dmitry
Dmitry Safonov Nov. 3, 2022, 4:53 p.m. UTC | #6
On 11/2/22 21:25, Eric Dumazet wrote:
> On Wed, Nov 2, 2022 at 2:14 PM Dmitry Safonov <dima@arista.com> wrote:
[..]
>> @@ -337,11 +338,13 @@ EXPORT_SYMBOL(tcp_time_wait);
>>  void tcp_twsk_destructor(struct sock *sk)
>>  {
>>  #ifdef CONFIG_TCP_MD5SIG
>> -       if (static_branch_unlikely(&tcp_md5_needed)) {
>> +       if (static_branch_unlikely(&tcp_md5_needed.key)) {
>>                 struct tcp_timewait_sock *twsk = tcp_twsk(sk);
>>
>> -               if (twsk->tw_md5_key)
>> +               if (twsk->tw_md5_key) {
> 
> Orthogonal to this patch, but I wonder why we do not clear
> twsk->tw_md5_key before kfree_rcu()
> 
> It seems a lookup could catch the invalid pointer.
> 
>>                         kfree_rcu(twsk->tw_md5_key, rcu);
>> +                       static_branch_slow_dec_deferred(&tcp_md5_needed);
>> +               }
>>         }

I looked into that, it seems tcp_twsk_destructor() is called from
inet_twsk_free(), which is either called from:
1. inet_twsk_put(), protected by tw->tw_refcnt
2. sock_gen_put(), protected by the same sk->sk_refcnt

So, in result, if I understand correctly, lookups should fail on ref
counter check. Maybe I'm missing something, but clearing here seems not
necessary?

I can add rcu_assign_pointer() just in case the destruction path changes
in v2 if you think it's worth it :-)

Thanks,
          Dmitry
Eric Dumazet Nov. 3, 2022, 5:04 p.m. UTC | #7
On Thu, Nov 3, 2022 at 9:53 AM Dmitry Safonov <dima@arista.com> wrote:
>
> On 11/2/22 21:25, Eric Dumazet wrote:
> > On Wed, Nov 2, 2022 at 2:14 PM Dmitry Safonov <dima@arista.com> wrote:
> [..]
> >> @@ -337,11 +338,13 @@ EXPORT_SYMBOL(tcp_time_wait);
> >>  void tcp_twsk_destructor(struct sock *sk)
> >>  {
> >>  #ifdef CONFIG_TCP_MD5SIG
> >> -       if (static_branch_unlikely(&tcp_md5_needed)) {
> >> +       if (static_branch_unlikely(&tcp_md5_needed.key)) {
> >>                 struct tcp_timewait_sock *twsk = tcp_twsk(sk);
> >>
> >> -               if (twsk->tw_md5_key)
> >> +               if (twsk->tw_md5_key) {
> >
> > Orthogonal to this patch, but I wonder why we do not clear
> > twsk->tw_md5_key before kfree_rcu()
> >
> > It seems a lookup could catch the invalid pointer.
> >
> >>                         kfree_rcu(twsk->tw_md5_key, rcu);
> >> +                       static_branch_slow_dec_deferred(&tcp_md5_needed);
> >> +               }
> >>         }
>
> I looked into that, it seems tcp_twsk_destructor() is called from
> inet_twsk_free(), which is either called from:
> 1. inet_twsk_put(), protected by tw->tw_refcnt
> 2. sock_gen_put(), protected by the same sk->sk_refcnt
>
> So, in result, if I understand correctly, lookups should fail on ref
> counter check. Maybe I'm missing something, but clearing here seems not
> necessary?
>
> I can add rcu_assign_pointer() just in case the destruction path changes
> in v2 if you think it's worth it :-)

Agree, this seems fine.
diff mbox series

Patch

diff --git a/include/net/tcp.h b/include/net/tcp.h
index 14d45661a84d..a0cdf013782a 100644
--- a/include/net/tcp.h
+++ b/include/net/tcp.h
@@ -1675,7 +1675,11 @@  int tcp_v4_md5_hash_skb(char *md5_hash, const struct tcp_md5sig_key *key,
 			const struct sock *sk, const struct sk_buff *skb);
 int tcp_md5_do_add(struct sock *sk, const union tcp_md5_addr *addr,
 		   int family, u8 prefixlen, int l3index, u8 flags,
-		   const u8 *newkey, u8 newkeylen, gfp_t gfp);
+		   const u8 *newkey, u8 newkeylen);
+int tcp_md5_key_copy(struct sock *sk, const union tcp_md5_addr *addr,
+		     int family, u8 prefixlen, int l3index,
+		     struct tcp_md5sig_key *key);
+
 int tcp_md5_do_del(struct sock *sk, const union tcp_md5_addr *addr,
 		   int family, u8 prefixlen, int l3index, u8 flags);
 struct tcp_md5sig_key *tcp_v4_md5_lookup(const struct sock *sk,
@@ -1683,7 +1687,7 @@  struct tcp_md5sig_key *tcp_v4_md5_lookup(const struct sock *sk,
 
 #ifdef CONFIG_TCP_MD5SIG
 #include <linux/jump_label.h>
-extern struct static_key_false tcp_md5_needed;
+extern struct static_key_false_deferred tcp_md5_needed;
 struct tcp_md5sig_key *__tcp_md5_do_lookup(const struct sock *sk, int l3index,
 					   const union tcp_md5_addr *addr,
 					   int family);
@@ -1691,7 +1695,7 @@  static inline struct tcp_md5sig_key *
 tcp_md5_do_lookup(const struct sock *sk, int l3index,
 		  const union tcp_md5_addr *addr, int family)
 {
-	if (!static_branch_unlikely(&tcp_md5_needed))
+	if (!static_branch_unlikely(&tcp_md5_needed.key))
 		return NULL;
 	return __tcp_md5_do_lookup(sk, l3index, addr, family);
 }
diff --git a/net/ipv4/tcp.c b/net/ipv4/tcp.c
index ef14efa1fb70..936ed566cc89 100644
--- a/net/ipv4/tcp.c
+++ b/net/ipv4/tcp.c
@@ -4460,11 +4460,8 @@  bool tcp_alloc_md5sig_pool(void)
 	if (unlikely(!READ_ONCE(tcp_md5sig_pool_populated))) {
 		mutex_lock(&tcp_md5sig_mutex);
 
-		if (!tcp_md5sig_pool_populated) {
+		if (!tcp_md5sig_pool_populated)
 			__tcp_alloc_md5sig_pool();
-			if (tcp_md5sig_pool_populated)
-				static_branch_inc(&tcp_md5_needed);
-		}
 
 		mutex_unlock(&tcp_md5sig_mutex);
 	}
diff --git a/net/ipv4/tcp_ipv4.c b/net/ipv4/tcp_ipv4.c
index fae80b1a1796..f812d507fc9a 100644
--- a/net/ipv4/tcp_ipv4.c
+++ b/net/ipv4/tcp_ipv4.c
@@ -1064,7 +1064,7 @@  static void tcp_v4_reqsk_destructor(struct request_sock *req)
  * We need to maintain these in the sk structure.
  */
 
-DEFINE_STATIC_KEY_FALSE(tcp_md5_needed);
+DEFINE_STATIC_KEY_DEFERRED_FALSE(tcp_md5_needed, HZ);
 EXPORT_SYMBOL(tcp_md5_needed);
 
 static bool better_md5_match(struct tcp_md5sig_key *old, struct tcp_md5sig_key *new)
@@ -1177,9 +1177,6 @@  static int tcp_md5sig_info_add(struct sock *sk, gfp_t gfp)
 	struct tcp_sock *tp = tcp_sk(sk);
 	struct tcp_md5sig_info *md5sig;
 
-	if (rcu_dereference_protected(tp->md5sig_info, lockdep_sock_is_held(sk)))
-		return 0;
-
 	md5sig = kmalloc(sizeof(*md5sig), gfp);
 	if (!md5sig)
 		return -ENOMEM;
@@ -1191,9 +1188,9 @@  static int tcp_md5sig_info_add(struct sock *sk, gfp_t gfp)
 }
 
 /* This can be called on a newly created socket, from other files */
-int tcp_md5_do_add(struct sock *sk, const union tcp_md5_addr *addr,
-		   int family, u8 prefixlen, int l3index, u8 flags,
-		   const u8 *newkey, u8 newkeylen, gfp_t gfp)
+static int __tcp_md5_do_add(struct sock *sk, const union tcp_md5_addr *addr,
+			    int family, u8 prefixlen, int l3index, u8 flags,
+			    const u8 *newkey, u8 newkeylen, gfp_t gfp)
 {
 	/* Add Key to the list */
 	struct tcp_md5sig_key *key;
@@ -1220,9 +1217,6 @@  int tcp_md5_do_add(struct sock *sk, const union tcp_md5_addr *addr,
 		return 0;
 	}
 
-	if (tcp_md5sig_info_add(sk, gfp))
-		return -ENOMEM;
-
 	md5sig = rcu_dereference_protected(tp->md5sig_info,
 					   lockdep_sock_is_held(sk));
 
@@ -1246,8 +1240,44 @@  int tcp_md5_do_add(struct sock *sk, const union tcp_md5_addr *addr,
 	hlist_add_head_rcu(&key->node, &md5sig->head);
 	return 0;
 }
+
+int tcp_md5_do_add(struct sock *sk, const union tcp_md5_addr *addr,
+		   int family, u8 prefixlen, int l3index, u8 flags,
+		   const u8 *newkey, u8 newkeylen)
+{
+	struct tcp_sock *tp = tcp_sk(sk);
+
+	if (!rcu_dereference_protected(tp->md5sig_info, lockdep_sock_is_held(sk))) {
+		if (tcp_md5sig_info_add(sk, GFP_KERNEL))
+			return -ENOMEM;
+
+		static_branch_inc(&tcp_md5_needed.key);
+	}
+
+	return __tcp_md5_do_add(sk, addr, family, prefixlen, l3index, flags,
+				newkey, newkeylen, GFP_KERNEL);
+}
 EXPORT_SYMBOL(tcp_md5_do_add);
 
+int tcp_md5_key_copy(struct sock *sk, const union tcp_md5_addr *addr,
+		     int family, u8 prefixlen, int l3index,
+		     struct tcp_md5sig_key *key)
+{
+	struct tcp_sock *tp = tcp_sk(sk);
+
+	if (!rcu_dereference_protected(tp->md5sig_info, lockdep_sock_is_held(sk))) {
+		if (tcp_md5sig_info_add(sk, sk_gfp_mask(sk, GFP_ATOMIC)))
+			return -ENOMEM;
+
+		atomic_inc(&tcp_md5_needed.key.key.enabled);
+	}
+
+	return __tcp_md5_do_add(sk, addr, family, prefixlen, l3index,
+				key->flags, key->key, key->keylen,
+				sk_gfp_mask(sk, GFP_ATOMIC));
+}
+EXPORT_SYMBOL(tcp_md5_key_copy);
+
 int tcp_md5_do_del(struct sock *sk, const union tcp_md5_addr *addr, int family,
 		   u8 prefixlen, int l3index, u8 flags)
 {
@@ -1334,7 +1364,7 @@  static int tcp_v4_parse_md5_keys(struct sock *sk, int optname,
 		return -EINVAL;
 
 	return tcp_md5_do_add(sk, addr, AF_INET, prefixlen, l3index, flags,
-			      cmd.tcpm_key, cmd.tcpm_keylen, GFP_KERNEL);
+			      cmd.tcpm_key, cmd.tcpm_keylen);
 }
 
 static int tcp_v4_md5_hash_headers(struct tcp_md5sig_pool *hp,
@@ -1591,8 +1621,7 @@  struct sock *tcp_v4_syn_recv_sock(const struct sock *sk, struct sk_buff *skb,
 		 * memory, then we end up not copying the key
 		 * across. Shucks.
 		 */
-		tcp_md5_do_add(newsk, addr, AF_INET, 32, l3index, key->flags,
-			       key->key, key->keylen, GFP_ATOMIC);
+		tcp_md5_key_copy(newsk, addr, AF_INET, 32, l3index, key);
 		sk_gso_disable(newsk);
 	}
 #endif
@@ -2284,6 +2313,7 @@  void tcp_v4_destroy_sock(struct sock *sk)
 		tcp_clear_md5_list(sk);
 		kfree_rcu(rcu_dereference_protected(tp->md5sig_info, 1), rcu);
 		tp->md5sig_info = NULL;
+		static_branch_slow_dec_deferred(&tcp_md5_needed);
 	}
 #endif
 
diff --git a/net/ipv4/tcp_minisocks.c b/net/ipv4/tcp_minisocks.c
index c375f603a16c..fb500160b8d2 100644
--- a/net/ipv4/tcp_minisocks.c
+++ b/net/ipv4/tcp_minisocks.c
@@ -291,13 +291,14 @@  void tcp_time_wait(struct sock *sk, int state, int timeo)
 		 */
 		do {
 			tcptw->tw_md5_key = NULL;
-			if (static_branch_unlikely(&tcp_md5_needed)) {
+			if (static_branch_unlikely(&tcp_md5_needed.key)) {
 				struct tcp_md5sig_key *key;
 
 				key = tp->af_specific->md5_lookup(sk, sk);
 				if (key) {
 					tcptw->tw_md5_key = kmemdup(key, sizeof(*key), GFP_ATOMIC);
 					BUG_ON(tcptw->tw_md5_key && !tcp_alloc_md5sig_pool());
+					atomic_inc(&tcp_md5_needed.key.key.enabled);
 				}
 			}
 		} while (0);
@@ -337,11 +338,13 @@  EXPORT_SYMBOL(tcp_time_wait);
 void tcp_twsk_destructor(struct sock *sk)
 {
 #ifdef CONFIG_TCP_MD5SIG
-	if (static_branch_unlikely(&tcp_md5_needed)) {
+	if (static_branch_unlikely(&tcp_md5_needed.key)) {
 		struct tcp_timewait_sock *twsk = tcp_twsk(sk);
 
-		if (twsk->tw_md5_key)
+		if (twsk->tw_md5_key) {
 			kfree_rcu(twsk->tw_md5_key, rcu);
+			static_branch_slow_dec_deferred(&tcp_md5_needed);
+		}
 	}
 #endif
 }
diff --git a/net/ipv4/tcp_output.c b/net/ipv4/tcp_output.c
index c69f4d966024..86e71c8c76bc 100644
--- a/net/ipv4/tcp_output.c
+++ b/net/ipv4/tcp_output.c
@@ -766,7 +766,7 @@  static unsigned int tcp_syn_options(struct sock *sk, struct sk_buff *skb,
 
 	*md5 = NULL;
 #ifdef CONFIG_TCP_MD5SIG
-	if (static_branch_unlikely(&tcp_md5_needed) &&
+	if (static_branch_unlikely(&tcp_md5_needed.key) &&
 	    rcu_access_pointer(tp->md5sig_info)) {
 		*md5 = tp->af_specific->md5_lookup(sk, sk);
 		if (*md5) {
@@ -922,7 +922,7 @@  static unsigned int tcp_established_options(struct sock *sk, struct sk_buff *skb
 
 	*md5 = NULL;
 #ifdef CONFIG_TCP_MD5SIG
-	if (static_branch_unlikely(&tcp_md5_needed) &&
+	if (static_branch_unlikely(&tcp_md5_needed.key) &&
 	    rcu_access_pointer(tp->md5sig_info)) {
 		*md5 = tp->af_specific->md5_lookup(sk, sk);
 		if (*md5) {
diff --git a/net/ipv6/tcp_ipv6.c b/net/ipv6/tcp_ipv6.c
index 2a3f9296df1e..3e3bdc120fc8 100644
--- a/net/ipv6/tcp_ipv6.c
+++ b/net/ipv6/tcp_ipv6.c
@@ -677,12 +677,11 @@  static int tcp_v6_parse_md5_keys(struct sock *sk, int optname,
 	if (ipv6_addr_v4mapped(&sin6->sin6_addr))
 		return tcp_md5_do_add(sk, (union tcp_md5_addr *)&sin6->sin6_addr.s6_addr32[3],
 				      AF_INET, prefixlen, l3index, flags,
-				      cmd.tcpm_key, cmd.tcpm_keylen,
-				      GFP_KERNEL);
+				      cmd.tcpm_key, cmd.tcpm_keylen);
 
 	return tcp_md5_do_add(sk, (union tcp_md5_addr *)&sin6->sin6_addr,
 			      AF_INET6, prefixlen, l3index, flags,
-			      cmd.tcpm_key, cmd.tcpm_keylen, GFP_KERNEL);
+			      cmd.tcpm_key, cmd.tcpm_keylen);
 }
 
 static int tcp_v6_md5_hash_headers(struct tcp_md5sig_pool *hp,
@@ -1382,9 +1381,8 @@  static struct sock *tcp_v6_syn_recv_sock(const struct sock *sk, struct sk_buff *
 		 * memory, then we end up not copying the key
 		 * across. Shucks.
 		 */
-		tcp_md5_do_add(newsk, (union tcp_md5_addr *)&newsk->sk_v6_daddr,
-			       AF_INET6, 128, l3index, key->flags, key->key, key->keylen,
-			       sk_gfp_mask(sk, GFP_ATOMIC));
+		tcp_md5_key_copy(newsk, (union tcp_md5_addr *)&newsk->sk_v6_daddr,
+				 AF_INET6, 128, l3index, key);
 	}
 #endif