@@ -436,11 +436,14 @@ static int __inet6_bind(struct sock *sk, struct sockaddr *uaddr, int addr_len,
int inet6_bind(struct socket *sock, struct sockaddr *uaddr, int addr_len)
{
struct sock *sk = sock->sk;
+ const struct proto *prot;
int err = 0;
+ /* IPV6_ADDRFORM can change sk->sk_prot under us. */
+ prot = READ_ONCE(sk->sk_prot);
/* If the socket has its own bind function then use it. */
- if (sk->sk_prot->bind)
- return sk->sk_prot->bind(sk, uaddr, addr_len);
+ if (prot->bind)
+ return prot->bind(sk, uaddr, addr_len);
if (addr_len < SIN6_LEN_RFC2133)
return -EINVAL;
@@ -545,6 +548,7 @@ int inet6_ioctl(struct socket *sock, unsigned int cmd, unsigned long arg)
{
struct sock *sk = sock->sk;
struct net *net = sock_net(sk);
+ const struct proto *prot;
switch (cmd) {
case SIOCGSTAMP:
@@ -565,9 +569,11 @@ int inet6_ioctl(struct socket *sock, unsigned int cmd, unsigned long arg)
case SIOCSIFDSTADDR:
return addrconf_set_dstaddr(net, (void __user *) arg);
default:
- if (!sk->sk_prot->ioctl)
+ /* IPV6_ADDRFORM can change sk->sk_prot under us. */
+ prot = READ_ONCE(sk->sk_prot);
+ if (!prot->ioctl)
return -ENOIOCTLCMD;
- return sk->sk_prot->ioctl(sk, cmd, arg);
+ return prot->ioctl(sk, cmd, arg);
}
/*NOTREACHED*/
return 0;
@@ -577,25 +583,31 @@ int inet6_ioctl(struct socket *sock, unsigned int cmd, unsigned long arg)
int inet6_sendmsg(struct socket *sock, struct msghdr *msg, size_t size)
{
struct sock *sk = sock->sk;
+ const struct proto *prot;
if (unlikely(inet_send_prepare(sk)))
return -EAGAIN;
- return sk->sk_prot->sendmsg(sk, msg, size);
+ /* IPV6_ADDRFORM can change sk->sk_prot under us. */
+ prot = READ_ONCE(sk->sk_prot);
+ return prot->sendmsg(sk, msg, size);
}
int inet6_recvmsg(struct socket *sock, struct msghdr *msg, size_t size,
int flags)
{
struct sock *sk = sock->sk;
+ const struct proto *prot;
int addr_len = 0;
int err;
if (likely(!(flags & MSG_ERRQUEUE)))
sock_rps_record_flow(sk);
- err = sk->sk_prot->recvmsg(sk, msg, size, flags & MSG_DONTWAIT,
- flags & ~MSG_DONTWAIT, &addr_len);
+ /* IPV6_ADDRFORM can change sk->sk_prot under us. */
+ prot = READ_ONCE(sk->sk_prot);
+ err = prot->recvmsg(sk, msg, size, flags & MSG_DONTWAIT,
+ flags & ~MSG_DONTWAIT, &addr_len);
if (err >= 0)
msg->msg_namelen = addr_len;
return err;
@@ -230,7 +230,8 @@ static int do_ipv6_setsockopt(struct sock *sk, int level, int optname,
sock_prot_inuse_add(net, sk->sk_prot, -1);
sock_prot_inuse_add(net, &tcp_prot, 1);
local_bh_enable();
- sk->sk_prot = &tcp_prot;
+ /* Paired with READ_ONCE(sk->sk_prot) in net/ipv6/af_inet6.c */
+ WRITE_ONCE(sk->sk_prot, &tcp_prot);
/* Paired with READ_ONCE() in tcp_(get|set)sockopt() */
WRITE_ONCE(icsk->icsk_af_ops, &ipv4_specific);
sk->sk_socket->ops = &inet_stream_ops;
@@ -245,7 +246,8 @@ static int do_ipv6_setsockopt(struct sock *sk, int level, int optname,
sock_prot_inuse_add(net, sk->sk_prot, -1);
sock_prot_inuse_add(net, prot, 1);
local_bh_enable();
- sk->sk_prot = prot;
+ /* Paired with READ_ONCE(sk->sk_prot) in net/ipv6/af_inet6.c */
+ WRITE_ONCE(sk->sk_prot, prot);
sk->sk_socket->ops = &inet_dgram_ops;
sk->sk_family = PF_INET;
}