@@ -226,7 +226,7 @@ static inline int udp_lib_hash(struct sock *sk)
}
void udp_lib_unhash(struct sock *sk);
-void udp_lib_rehash(struct sock *sk, u16 new_hash);
+void udp_lib_rehash(struct sock *sk, u16 new_hash, u16 new_hash4);
static inline void udp_lib_close(struct sock *sk, long timeout)
{
@@ -319,6 +319,7 @@ int udp_rcv(struct sk_buff *skb);
int udp_ioctl(struct sock *sk, int cmd, int *karg);
int udp_init_sock(struct sock *sk);
int udp_pre_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len);
+int udp_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len);
int __udp_disconnect(struct sock *sk, int flags);
int udp_disconnect(struct sock *sk, int flags);
__poll_t udp_poll(struct file *file, struct socket *sock, poll_table *wait);
@@ -478,6 +478,27 @@ static struct sock *udp4_lib_lookup2(const struct net *net,
return result;
}
+static struct sock *udp4_lib_lookup4(const struct net *net,
+ __be32 saddr, __be16 sport,
+ __be32 daddr, unsigned int hnum,
+ int dif, int sdif,
+ struct udp_table *udptable)
+{
+ unsigned int hash4 = udp_ehashfn(net, daddr, hnum, saddr, sport);
+ const __portpair ports = INET_COMBINED_PORTS(sport, hnum);
+ struct udp_hslot *hslot4 = udp_hashslot4(udptable, hash4);
+ struct udp_sock *up;
+ struct sock *sk;
+
+ INET_ADDR_COOKIE(acookie, saddr, daddr);
+ udp_lrpa_for_each_entry_rcu(up, &hslot4->head) {
+ sk = (struct sock *)up;
+ if (inet_match(net, sk, acookie, ports, dif, sdif))
+ return sk;
+ }
+ return NULL;
+}
+
/* UDP is nearly always wildcards out the wazoo, it makes no sense to try
* harder than this. -DaveM
*/
@@ -493,6 +514,12 @@ struct sock *__udp4_lib_lookup(const struct net *net, __be32 saddr,
hash2 = ipv4_portaddr_hash(net, daddr, hnum);
hslot2 = udp_hashslot2(udptable, hash2);
+ if (UDP_HSLOT_MAIN(hslot2)->hash4_cnt) {
+ result = udp4_lib_lookup4(net, saddr, sport, daddr, hnum, dif, sdif, udptable);
+ if (result) /* udp4_lib_lookup4 return sk or NULL */
+ return result;
+ }
+
/* Lookup connected or non-wildcard socket */
result = udp4_lib_lookup2(net, saddr, sport,
daddr, hnum, dif, sdif,
@@ -1931,6 +1958,85 @@ int udp_pre_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len)
}
EXPORT_SYMBOL(udp_pre_connect);
+/* In hash4, rehash can also happen in connect(), where hash4_cnt keeps unchanged. */
+static void udp4_rehash4(struct udp_table *udptable, struct sock *sk, u16 newhash4)
+{
+ struct udp_hslot *hslot4, *nhslot4;
+
+ hslot4 = udp_hashslot4(udptable, udp_sk(sk)->udp_lrpa_hash);
+ nhslot4 = udp_hashslot4(udptable, newhash4);
+ udp_sk(sk)->udp_lrpa_hash = newhash4;
+
+ if (hslot4 != nhslot4) {
+ spin_lock_bh(&hslot4->lock);
+ hlist_del_init_rcu(&udp_sk(sk)->udp_lrpa_node);
+ hslot4->count--;
+ spin_unlock_bh(&hslot4->lock);
+
+ synchronize_rcu();
+
+ spin_lock_bh(&nhslot4->lock);
+ hlist_add_head_rcu(&udp_sk(sk)->udp_lrpa_node, &nhslot4->head);
+ nhslot4->count++;
+ spin_unlock_bh(&nhslot4->lock);
+ }
+}
+
+/* call with sock lock */
+static void udp4_hash4(struct sock *sk)
+{
+ struct udp_hslot *hslot, *hslot4;
+ struct udp_hslot_main *hslotm2;
+ struct net *net = sock_net(sk);
+ struct udp_table *udptable;
+ unsigned int hash;
+
+ if (sk_unhashed(sk) || inet_sk(sk)->inet_rcv_saddr == htonl(INADDR_ANY))
+ return;
+
+ hash = udp_ehashfn(net, inet_sk(sk)->inet_rcv_saddr, inet_sk(sk)->inet_num,
+ inet_sk(sk)->inet_daddr, inet_sk(sk)->inet_dport);
+
+ udptable = net->ipv4.udp_table;
+ if (udp_hashed4(sk)) {
+ udp4_rehash4(udptable, sk, hash);
+ return;
+ }
+
+ hslot = udp_hashslot(udptable, net, udp_sk(sk)->udp_port_hash);
+ hslotm2 = udp_hashslot2_main(udptable, udp_sk(sk)->udp_portaddr_hash);
+ hslot4 = udp_hashslot4(udptable, hash);
+ udp_sk(sk)->udp_lrpa_hash = hash;
+
+ spin_lock_bh(&hslot->lock);
+ if (rcu_access_pointer(sk->sk_reuseport_cb))
+ reuseport_detach_sock(sk);
+
+ spin_lock(&hslot4->lock);
+ hlist_add_head_rcu(&udp_sk(sk)->udp_lrpa_node, &hslot4->head);
+ hslot4->count++;
+ spin_unlock(&hslot4->lock);
+
+ spin_lock(&hslotm2->hslot.lock);
+ hslotm2->hash4_cnt++;
+ spin_unlock(&hslotm2->hslot.lock);
+
+ spin_unlock_bh(&hslot->lock);
+}
+
+int udp_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len)
+{
+ int res;
+
+ lock_sock(sk);
+ res = __ip4_datagram_connect(sk, uaddr, addr_len);
+ if (!res)
+ udp4_hash4(sk);
+ release_sock(sk);
+ return res;
+}
+EXPORT_SYMBOL(udp_connect);
+
int __udp_disconnect(struct sock *sk, int flags)
{
struct inet_sock *inet = inet_sk(sk);
@@ -1972,7 +2078,7 @@ void udp_lib_unhash(struct sock *sk)
{
if (sk_hashed(sk)) {
struct udp_table *udptable = udp_get_table_prot(sk);
- struct udp_hslot *hslot, *hslot2;
+ struct udp_hslot *hslot, *hslot2, *hslot4;
hslot = udp_hashslot(udptable, sock_net(sk),
udp_sk(sk)->udp_port_hash);
@@ -1990,6 +2096,18 @@ void udp_lib_unhash(struct sock *sk)
hlist_del_init_rcu(&udp_sk(sk)->udp_portaddr_node);
hslot2->count--;
spin_unlock(&hslot2->lock);
+
+ if (udp_hashed4(sk)) {
+ hslot4 = udp_hashslot4(udptable, udp_sk(sk)->udp_lrpa_hash);
+ spin_lock(&hslot4->lock);
+ hlist_del_init_rcu(&udp_sk(sk)->udp_lrpa_node);
+ hslot4->count--;
+ spin_unlock(&hslot4->lock);
+
+ spin_lock(&hslot2->lock);
+ UDP_HSLOT_MAIN(hslot2)->hash4_cnt--;
+ spin_unlock(&hslot2->lock);
+ }
}
spin_unlock_bh(&hslot->lock);
}
@@ -1999,7 +2117,7 @@ EXPORT_SYMBOL(udp_lib_unhash);
/*
* inet_rcv_saddr was changed, we must rehash secondary hash
*/
-void udp_lib_rehash(struct sock *sk, u16 newhash)
+void udp_lib_rehash(struct sock *sk, u16 newhash, u16 newhash4)
{
if (sk_hashed(sk)) {
struct udp_table *udptable = udp_get_table_prot(sk);
@@ -2031,6 +2149,19 @@ void udp_lib_rehash(struct sock *sk, u16 newhash)
spin_unlock(&nhslot2->lock);
}
+ if (udp_hashed4(sk)) {
+ udp4_rehash4(udptable, sk, newhash4);
+
+ if (hslot2 != nhslot2) {
+ spin_lock(&hslot2->lock);
+ UDP_HSLOT_MAIN(hslot2)->hash4_cnt--;
+ spin_unlock(&hslot2->lock);
+
+ spin_lock(&nhslot2->lock);
+ UDP_HSLOT_MAIN(nhslot2)->hash4_cnt++;
+ spin_unlock(&nhslot2->lock);
+ }
+ }
spin_unlock_bh(&hslot->lock);
}
}
@@ -2042,7 +2173,10 @@ void udp_v4_rehash(struct sock *sk)
u16 new_hash = ipv4_portaddr_hash(sock_net(sk),
inet_sk(sk)->inet_rcv_saddr,
inet_sk(sk)->inet_num);
- udp_lib_rehash(sk, new_hash);
+ u16 new_hash4 = udp_ehashfn(sock_net(sk),
+ inet_sk(sk)->inet_rcv_saddr, inet_sk(sk)->inet_num,
+ inet_sk(sk)->inet_daddr, inet_sk(sk)->inet_dport);
+ udp_lib_rehash(sk, new_hash, new_hash4);
}
static int __udp_queue_rcv_skb(struct sock *sk, struct sk_buff *skb)
@@ -2935,7 +3069,7 @@ struct proto udp_prot = {
.owner = THIS_MODULE,
.close = udp_lib_close,
.pre_connect = udp_pre_connect,
- .connect = ip4_datagram_connect,
+ .connect = udp_connect,
.disconnect = udp_disconnect,
.ioctl = udp_ioctl,
.init = udp_init_sock,
@@ -111,7 +111,7 @@ void udp_v6_rehash(struct sock *sk)
&sk->sk_v6_rcv_saddr,
inet_sk(sk)->inet_num);
- udp_lib_rehash(sk, new_hash);
+ udp_lib_rehash(sk, new_hash, 0); /* 4-tuple hash not implemented */
}
static int compute_score(struct sock *sk, const struct net *net,