@@ -48,6 +48,14 @@ struct sock *__inet6_lookup_established(struct net *net,
const u16 hnum, const int dif,
const int sdif);
+struct sock *__inet6_lookup_established_locked(struct net *net,
+ struct inet_hashinfo *hashinfo,
+ const struct in6_addr *saddr,
+ const __be16 sport,
+ const struct in6_addr *daddr,
+ const u16 hnum, const int dif,
+ const int sdif);
+
typedef u32 (inet6_ehashfn_t)(const struct net *net,
const struct in6_addr *laddr, const u16 lport,
const struct in6_addr *faddr, const __be16 fport);
@@ -103,6 +111,27 @@ static inline struct sock *__inet6_lookup(struct net *net,
daddr, hnum, dif, sdif);
}
+static inline struct sock *__inet6_lookup_locked(struct net *net,
+ struct inet_hashinfo *hashinfo,
+ struct sk_buff *skb, int doff,
+ const struct in6_addr *saddr,
+ const __be16 sport,
+ const struct in6_addr *daddr,
+ const u16 hnum,
+ const int dif, const int sdif,
+ bool *refcounted)
+{
+ struct sock *sk = __inet6_lookup_established_locked(net, hashinfo, saddr,
+ sport, daddr, hnum,
+ dif, sdif);
+ *refcounted = true;
+ if (sk)
+ return sk;
+ *refcounted = false;
+ return inet6_lookup_listener(net, hashinfo, skb, doff, saddr, sport,
+ daddr, hnum, dif, sdif);
+}
+
static inline
struct sock *inet6_steal_sock(struct net *net, struct sk_buff *skb, int doff,
const struct in6_addr *saddr, const __be16 sport,
@@ -167,6 +196,30 @@ static inline struct sock *__inet6_lookup_skb(struct inet_hashinfo *hashinfo,
iif, sdif, refcounted);
}
+static inline struct sock *__inet6_lookup_skb_locked(struct inet_hashinfo *hashinfo,
+ struct sk_buff *skb, int doff,
+ const __be16 sport,
+ const __be16 dport,
+ int iif, int sdif,
+ bool *refcounted)
+{
+ struct net *net = dev_net(skb_dst(skb)->dev);
+ const struct ipv6hdr *ip6h = ipv6_hdr(skb);
+ struct sock *sk;
+
+ sk = inet6_steal_sock(net, skb, doff, &ip6h->saddr, sport, &ip6h->daddr, dport,
+ refcounted, inet6_ehashfn);
+ if (IS_ERR(sk))
+ return NULL;
+ if (sk)
+ return sk;
+
+ return __inet6_lookup_locked(net, hashinfo, skb,
+ doff, &ip6h->saddr, sport,
+ &ip6h->daddr, ntohs(dport),
+ iif, sdif, refcounted);
+}
+
struct sock *inet6_lookup(struct net *net, struct inet_hashinfo *hashinfo,
struct sk_buff *skb, int doff,
const struct in6_addr *saddr, const __be16 sport,
@@ -374,6 +374,12 @@ struct sock *__inet_lookup_established(struct net *net,
const __be32 daddr, const u16 hnum,
const int dif, const int sdif);
+struct sock *__inet_lookup_established_locked(struct net *net,
+ struct inet_hashinfo *hashinfo,
+ const __be32 saddr, const __be16 sport,
+ const __be32 daddr, const u16 hnum,
+ const int dif, const int sdif);
+
typedef u32 (inet_ehashfn_t)(const struct net *net,
const __be32 laddr, const __u16 lport,
const __be32 faddr, const __be16 fport);
@@ -426,6 +432,27 @@ static inline struct sock *__inet_lookup(struct net *net,
sport, daddr, hnum, dif, sdif);
}
+static inline struct sock *__inet_lookup_locked(struct net *net,
+ struct inet_hashinfo *hashinfo,
+ struct sk_buff *skb, int doff,
+ const __be32 saddr, const __be16 sport,
+ const __be32 daddr, const __be16 dport,
+ const int dif, const int sdif,
+ bool *refcounted)
+{
+ u16 hnum = ntohs(dport);
+ struct sock *sk;
+
+ sk = __inet_lookup_established_locked(net, hashinfo, saddr, sport,
+ daddr, hnum, dif, sdif);
+ *refcounted = true;
+ if (sk)
+ return sk;
+ *refcounted = false;
+ return __inet_lookup_listener(net, hashinfo, skb, doff, saddr,
+ sport, daddr, hnum, dif, sdif);
+}
+
static inline struct sock *inet_lookup(struct net *net,
struct inet_hashinfo *hashinfo,
struct sk_buff *skb, int doff,
@@ -509,6 +536,31 @@ static inline struct sock *__inet_lookup_skb(struct inet_hashinfo *hashinfo,
refcounted);
}
+static inline struct sock *__inet_lookup_skb_locked(struct inet_hashinfo *hashinfo,
+ struct sk_buff *skb,
+ int doff,
+ const __be16 sport,
+ const __be16 dport,
+ const int sdif,
+ bool *refcounted)
+{
+ struct net *net = dev_net(skb_dst(skb)->dev);
+ const struct iphdr *iph = ip_hdr(skb);
+ struct sock *sk;
+
+ sk = inet_steal_sock(net, skb, doff, iph->saddr, sport, iph->daddr, dport,
+ refcounted, inet_ehashfn);
+ if (IS_ERR(sk))
+ return NULL;
+ if (sk)
+ return sk;
+
+ return __inet_lookup_locked(net, hashinfo, skb,
+ doff, iph->saddr, sport,
+ iph->daddr, dport, inet_iif(skb), sdif,
+ refcounted);
+}
+
static inline void sk_daddr_set(struct sock *sk, __be32 addr)
{
sk->sk_daddr = addr; /* alias of inet_daddr */
@@ -535,6 +535,55 @@ struct sock *__inet_lookup_established(struct net *net,
}
EXPORT_SYMBOL_GPL(__inet_lookup_established);
+struct sock *__inet_lookup_established_locked(struct net *net,
+ struct inet_hashinfo *hashinfo,
+ const __be32 saddr, const __be16 sport,
+ const __be32 daddr, const u16 hnum,
+ const int dif, const int sdif)
+{
+ INET_ADDR_COOKIE(acookie, saddr, daddr);
+ const __portpair ports = INET_COMBINED_PORTS(sport, hnum);
+ struct sock *sk;
+ const struct hlist_nulls_node *node;
+ /* Optimize here for direct hit, only listening connections can
+ * have wildcards anyways.
+ */
+ unsigned int hash = inet_ehashfn(net, daddr, hnum, saddr, sport);
+ unsigned int slot = hash & hashinfo->ehash_mask;
+ struct inet_ehash_bucket *head = &hashinfo->ehash[slot];
+ spinlock_t *lock = inet_ehash_lockp(hashinfo, hash);
+
+ spin_lock(lock);
+begin:
+ sk_nulls_for_each(sk, node, &head->chain) {
+ if (sk->sk_hash != hash)
+ continue;
+ if (likely(inet_match(net, sk, acookie, ports, dif, sdif))) {
+ if (unlikely(!refcount_inc_not_zero(&sk->sk_refcnt)))
+ goto out;
+ if (unlikely(!inet_match(net, sk, acookie,
+ ports, dif, sdif))) {
+ sock_gen_put(sk);
+ goto begin;
+ }
+ goto found;
+ }
+ }
+ /*
+ * if the nulls value we got at the end of this lookup is
+ * not the expected one, we must restart lookup.
+ * We probably met an item that was moved to another chain.
+ */
+ if (get_nulls_value(node) != slot)
+ goto begin;
+out:
+ sk = NULL;
+found:
+ spin_unlock(lock);
+ return sk;
+}
+EXPORT_SYMBOL_GPL(__inet_lookup_established_locked);
+
/* called with local bh disabled */
static int __inet_check_established(struct inet_timewait_death_row *death_row,
struct sock *sk, __u16 lport,
@@ -2209,6 +2209,15 @@ int tcp_v4_rcv(struct sk_buff *skb)
sk = __inet_lookup_skb(net->ipv4.tcp_death_row.hashinfo,
skb, __tcp_hdrlen(th), th->source,
th->dest, sdif, &refcounted);
+
+ /* The 1st lookup is prone to races as it's RCU
+ * Under rare conditions it can find a LISTEN socket
+ * Avoid an erroneous RST and this time do a locked lookup.
+ */
+ if (unlikely(sk && sk->sk_state == TCP_LISTEN && th->ack))
+ sk = __inet_lookup_skb_locked(net->ipv4.tcp_death_row.hashinfo,
+ skb, __tcp_hdrlen(th), th->source,
+ th->dest, sdif, &refcounted);
if (!sk)
goto no_tcp_socket;
@@ -90,6 +90,51 @@ struct sock *__inet6_lookup_established(struct net *net,
}
EXPORT_SYMBOL(__inet6_lookup_established);
+struct sock *__inet6_lookup_established_locked(struct net *net,
+ struct inet_hashinfo *hashinfo,
+ const struct in6_addr *saddr,
+ const __be16 sport,
+ const struct in6_addr *daddr,
+ const u16 hnum,
+ const int dif, const int sdif)
+{
+ struct sock *sk;
+ const struct hlist_nulls_node *node;
+ const __portpair ports = INET_COMBINED_PORTS(sport, hnum);
+ /* Optimize here for direct hit, only listening connections can
+ * have wildcards anyways.
+ */
+ unsigned int hash = inet6_ehashfn(net, daddr, hnum, saddr, sport);
+ unsigned int slot = hash & hashinfo->ehash_mask;
+ struct inet_ehash_bucket *head = &hashinfo->ehash[slot];
+ spinlock_t *lock = inet_ehash_lockp(hashinfo, hash);
+
+ spin_lock(lock);
+begin:
+ sk_nulls_for_each(sk, node, &head->chain) {
+ if (sk->sk_hash != hash)
+ continue;
+ if (!inet6_match(net, sk, saddr, daddr, ports, dif, sdif))
+ continue;
+ if (unlikely(!refcount_inc_not_zero(&sk->sk_refcnt)))
+ goto out;
+
+ if (unlikely(!inet6_match(net, sk, saddr, daddr, ports, dif, sdif))) {
+ sock_gen_put(sk);
+ goto begin;
+ }
+ goto found;
+ }
+ if (get_nulls_value(node) != slot)
+ goto begin;
+out:
+ sk = NULL;
+found:
+ spin_unlock(lock);
+ return sk;
+}
+EXPORT_SYMBOL(__inet6_lookup_established_locked);
+
static inline int compute_score(struct sock *sk, struct net *net,
const unsigned short hnum,
const struct in6_addr *daddr,
@@ -1790,6 +1790,15 @@ INDIRECT_CALLABLE_SCOPE int tcp_v6_rcv(struct sk_buff *skb)
sk = __inet6_lookup_skb(net->ipv4.tcp_death_row.hashinfo, skb, __tcp_hdrlen(th),
th->source, th->dest, inet6_iif(skb), sdif,
&refcounted);
+
+ /* The 1st lookup is prone to races as it's RCU
+ * Under rare conditions it can find a LISTEN socket
+ * Avoid an erroneous RST and this time do a locked lookup.
+ */
+ if (unlikely(sk && sk->sk_state == TCP_LISTEN && th->ack))
+ sk = __inet6_lookup_skb_locked(net->ipv4.tcp_death_row.hashinfo, skb,
+ __tcp_hdrlen(th), th->source, th->dest,
+ inet6_iif(skb), sdif, &refcounted);
if (!sk)
goto no_tcp_socket;