@@ -1588,10 +1588,11 @@ union tcp_md5_addr {
struct tcp_md5sig_key {
struct hlist_node node;
u8 keylen;
u8 family; /* AF_INET or AF_INET6 */
u8 prefixlen;
+ u8 flags;
union tcp_md5_addr addr;
int l3index; /* set if key added with L3 scope */
u8 key[TCP_MD5SIG_MAXKEYLEN];
struct rcu_head rcu;
};
@@ -1633,14 +1634,14 @@ struct tcp_md5sig_pool {
/* - functions */
int tcp_v4_md5_hash_skb(char *md5_hash, const struct tcp_md5sig_key *key,
const struct sock *sk, const struct sk_buff *skb);
int tcp_md5_do_add(struct sock *sk, const union tcp_md5_addr *addr,
- int family, u8 prefixlen, int l3index,
+ int family, u8 prefixlen, int l3index, u8 flags,
const u8 *newkey, u8 newkeylen, gfp_t gfp);
int tcp_md5_do_del(struct sock *sk, const union tcp_md5_addr *addr,
- int family, u8 prefixlen, int l3index);
+ int family, u8 prefixlen, int l3index, u8 flags);
struct tcp_md5sig_key *tcp_v4_md5_lookup(const struct sock *sk,
const struct sock *addr_sk);
#ifdef CONFIG_TCP_MD5SIG
#include <linux/jump_label.h>
@@ -1071,11 +1071,11 @@ struct tcp_md5sig_key *__tcp_md5_do_lookup(const struct sock *sk, int l3index,
hlist_for_each_entry_rcu(key, &md5sig->head, node,
lockdep_sock_is_held(sk)) {
if (key->family != family)
continue;
- if (key->l3index && key->l3index != l3index)
+ if (key->flags & TCP_MD5SIG_FLAG_IFINDEX && key->l3index != l3index)
continue;
if (family == AF_INET) {
mask = inet_make_mask(key->prefixlen);
match = (key->addr.a4.s_addr & mask) ==
(addr->a4.s_addr & mask);
@@ -1096,11 +1096,11 @@ struct tcp_md5sig_key *__tcp_md5_do_lookup(const struct sock *sk, int l3index,
EXPORT_SYMBOL(__tcp_md5_do_lookup);
static struct tcp_md5sig_key *tcp_md5_do_lookup_exact(const struct sock *sk,
const union tcp_md5_addr *addr,
int family, u8 prefixlen,
- int l3index)
+ int l3index, u8 flags)
{
const struct tcp_sock *tp = tcp_sk(sk);
struct tcp_md5sig_key *key;
unsigned int size = sizeof(struct in_addr);
const struct tcp_md5sig_info *md5sig;
@@ -1116,10 +1116,12 @@ static struct tcp_md5sig_key *tcp_md5_do_lookup_exact(const struct sock *sk,
#endif
hlist_for_each_entry_rcu(key, &md5sig->head, node,
lockdep_sock_is_held(sk)) {
if (key->family != family)
continue;
+ if ((key->flags & TCP_MD5SIG_FLAG_IFINDEX) != (flags & TCP_MD5SIG_FLAG_IFINDEX))
+ continue;
if (key->l3index != l3index)
continue;
if (!memcmp(&key->addr, addr, size) &&
key->prefixlen == prefixlen)
return key;
@@ -1140,19 +1142,19 @@ struct tcp_md5sig_key *tcp_v4_md5_lookup(const struct sock *sk,
}
EXPORT_SYMBOL(tcp_v4_md5_lookup);
/* This can be called on a newly created socket, from other files */
int tcp_md5_do_add(struct sock *sk, const union tcp_md5_addr *addr,
- int family, u8 prefixlen, int l3index,
+ int family, u8 prefixlen, int l3index, u8 flags,
const u8 *newkey, u8 newkeylen, gfp_t gfp)
{
/* Add Key to the list */
struct tcp_md5sig_key *key;
struct tcp_sock *tp = tcp_sk(sk);
struct tcp_md5sig_info *md5sig;
- key = tcp_md5_do_lookup_exact(sk, addr, family, prefixlen, l3index);
+ key = tcp_md5_do_lookup_exact(sk, addr, family, prefixlen, l3index, flags);
if (key) {
/* Pre-existing entry - just update that one.
* Note that the key might be used concurrently.
* data_race() is telling kcsan that we do not care of
* key mismatches, since changing MD5 key on live flows
@@ -1193,24 +1195,25 @@ int tcp_md5_do_add(struct sock *sk, const union tcp_md5_addr *addr,
memcpy(key->key, newkey, newkeylen);
key->keylen = newkeylen;
key->family = family;
key->prefixlen = prefixlen;
key->l3index = l3index;
+ key->flags = flags;
memcpy(&key->addr, addr,
(family == AF_INET6) ? sizeof(struct in6_addr) :
sizeof(struct in_addr));
hlist_add_head_rcu(&key->node, &md5sig->head);
return 0;
}
EXPORT_SYMBOL(tcp_md5_do_add);
int tcp_md5_do_del(struct sock *sk, const union tcp_md5_addr *addr, int family,
- u8 prefixlen, int l3index)
+ u8 prefixlen, int l3index, u8 flags)
{
struct tcp_md5sig_key *key;
- key = tcp_md5_do_lookup_exact(sk, addr, family, prefixlen, l3index);
+ key = tcp_md5_do_lookup_exact(sk, addr, family, prefixlen, l3index, flags);
if (!key)
return -ENOENT;
hlist_del_rcu(&key->node);
atomic_sub(sizeof(*key), &sk->sk_omem_alloc);
kfree_rcu(key, rcu);
@@ -1240,28 +1243,31 @@ static int tcp_v4_parse_md5_keys(struct sock *sk, int optname,
struct tcp_md5sig cmd;
struct sockaddr_in *sin = (struct sockaddr_in *)&cmd.tcpm_addr;
const union tcp_md5_addr *addr;
u8 prefixlen = 32;
int l3index = 0;
+ u8 flags;
if (optlen < sizeof(cmd))
return -EINVAL;
if (copy_from_sockptr(&cmd, optval, sizeof(cmd)))
return -EFAULT;
if (sin->sin_family != AF_INET)
return -EINVAL;
+ flags = cmd.tcpm_flags & TCP_MD5SIG_FLAG_IFINDEX;
+
if (optname == TCP_MD5SIG_EXT &&
cmd.tcpm_flags & TCP_MD5SIG_FLAG_PREFIX) {
prefixlen = cmd.tcpm_prefixlen;
if (prefixlen > 32)
return -EINVAL;
}
- if (optname == TCP_MD5SIG_EXT &&
+ if (optname == TCP_MD5SIG_EXT && cmd.tcpm_ifindex &&
cmd.tcpm_flags & TCP_MD5SIG_FLAG_IFINDEX) {
struct net_device *dev;
rcu_read_lock();
dev = dev_get_by_index_rcu(sock_net(sk), cmd.tcpm_ifindex);
@@ -1278,16 +1284,16 @@ static int tcp_v4_parse_md5_keys(struct sock *sk, int optname,
}
addr = (union tcp_md5_addr *)&sin->sin_addr.s_addr;
if (!cmd.tcpm_keylen)
- return tcp_md5_do_del(sk, addr, AF_INET, prefixlen, l3index);
+ return tcp_md5_do_del(sk, addr, AF_INET, prefixlen, l3index, flags);
if (cmd.tcpm_keylen > TCP_MD5SIG_MAXKEYLEN)
return -EINVAL;
- return tcp_md5_do_add(sk, addr, AF_INET, prefixlen, l3index,
+ return tcp_md5_do_add(sk, addr, AF_INET, prefixlen, l3index, flags,
cmd.tcpm_key, cmd.tcpm_keylen, GFP_KERNEL);
}
static int tcp_v4_md5_hash_headers(struct tcp_md5sig_pool *hp,
__be32 daddr, __be32 saddr,
@@ -1607,11 +1613,11 @@ struct sock *tcp_v4_syn_recv_sock(const struct sock *sk, struct sk_buff *skb,
* We're using one, so create a matching key
* on the newsk structure. If we fail to get
* memory, then we end up not copying the key
* across. Shucks.
*/
- tcp_md5_do_add(newsk, addr, AF_INET, 32, l3index,
+ tcp_md5_do_add(newsk, addr, AF_INET, 32, l3index, key->flags,
key->key, key->keylen, GFP_ATOMIC);
sk_nocaps_add(newsk, NETIF_F_GSO_MASK);
}
#endif
@@ -597,31 +597,34 @@ static int tcp_v6_parse_md5_keys(struct sock *sk, int optname,
{
struct tcp_md5sig cmd;
struct sockaddr_in6 *sin6 = (struct sockaddr_in6 *)&cmd.tcpm_addr;
int l3index = 0;
u8 prefixlen;
+ u8 flags;
if (optlen < sizeof(cmd))
return -EINVAL;
if (copy_from_sockptr(&cmd, optval, sizeof(cmd)))
return -EFAULT;
if (sin6->sin6_family != AF_INET6)
return -EINVAL;
+ flags = cmd.tcpm_flags & TCP_MD5SIG_FLAG_IFINDEX;
+
if (optname == TCP_MD5SIG_EXT &&
cmd.tcpm_flags & TCP_MD5SIG_FLAG_PREFIX) {
prefixlen = cmd.tcpm_prefixlen;
if (prefixlen > 128 || (ipv6_addr_v4mapped(&sin6->sin6_addr) &&
prefixlen > 32))
return -EINVAL;
} else {
prefixlen = ipv6_addr_v4mapped(&sin6->sin6_addr) ? 32 : 128;
}
- if (optname == TCP_MD5SIG_EXT &&
+ if (optname == TCP_MD5SIG_EXT && cmd.tcpm_ifindex &&
cmd.tcpm_flags & TCP_MD5SIG_FLAG_IFINDEX) {
struct net_device *dev;
rcu_read_lock();
dev = dev_get_by_index_rcu(sock_net(sk), cmd.tcpm_ifindex);
@@ -638,26 +641,26 @@ static int tcp_v6_parse_md5_keys(struct sock *sk, int optname,
if (!cmd.tcpm_keylen) {
if (ipv6_addr_v4mapped(&sin6->sin6_addr))
return tcp_md5_do_del(sk, (union tcp_md5_addr *)&sin6->sin6_addr.s6_addr32[3],
AF_INET, prefixlen,
- l3index);
+ l3index, flags);
return tcp_md5_do_del(sk, (union tcp_md5_addr *)&sin6->sin6_addr,
- AF_INET6, prefixlen, l3index);
+ AF_INET6, prefixlen, l3index, flags);
}
if (cmd.tcpm_keylen > TCP_MD5SIG_MAXKEYLEN)
return -EINVAL;
if (ipv6_addr_v4mapped(&sin6->sin6_addr))
return tcp_md5_do_add(sk, (union tcp_md5_addr *)&sin6->sin6_addr.s6_addr32[3],
- AF_INET, prefixlen, l3index,
+ AF_INET, prefixlen, l3index, flags,
cmd.tcpm_key, cmd.tcpm_keylen,
GFP_KERNEL);
return tcp_md5_do_add(sk, (union tcp_md5_addr *)&sin6->sin6_addr,
- AF_INET6, prefixlen, l3index,
+ AF_INET6, prefixlen, l3index, flags,
cmd.tcpm_key, cmd.tcpm_keylen, GFP_KERNEL);
}
static int tcp_v6_md5_hash_headers(struct tcp_md5sig_pool *hp,
const struct in6_addr *daddr,
@@ -1402,11 +1405,11 @@ static struct sock *tcp_v6_syn_recv_sock(const struct sock *sk, struct sk_buff *
* on the newsk structure. If we fail to get
* memory, then we end up not copying the key
* across. Shucks.
*/
tcp_md5_do_add(newsk, (union tcp_md5_addr *)&newsk->sk_v6_daddr,
- AF_INET6, 128, l3index, key->key, key->keylen,
+ AF_INET6, 128, l3index, key->flags, key->key, key->keylen,
sk_gfp_mask(sk, GFP_ATOMIC));
}
#endif
if (__inet_inherit_port(sk, newsk) < 0) {