@@ -71,6 +71,7 @@ struct msghdr {
void __user *msg_control_user;
};
bool msg_control_is_user : 1;
+ bool msg_control_copy_to_user : 1;
bool msg_get_inq : 1;/* return INQ after receive */
unsigned int msg_flags; /* flags on received message */
__kernel_size_t msg_controllen; /* ancillary data buffer length */
@@ -168,6 +169,11 @@ static inline struct cmsghdr * cmsg_nxthdr (struct msghdr *__msg, struct cmsghdr
return __cmsg_nxthdr(__msg->msg_control, __msg->msg_controllen, __cmsg);
}
+static inline bool cmsg_copy_to_user(struct cmsghdr *__cmsg)
+{
+ return 0;
+}
+
static inline size_t msg_data_left(struct msghdr *msg)
{
return iov_iter_count(&msg->msg_iter);
@@ -396,6 +402,8 @@ struct timespec64;
struct __kernel_timespec;
struct old_timespec32;
+DECLARE_STATIC_KEY_FALSE(tx_copy_cmsg_to_user_key);
+
struct scm_timestamping_internal {
struct timespec64 ts[3];
};
@@ -1804,7 +1804,7 @@ static inline void sockcm_init(struct sockcm_cookie *sockc,
};
}
-int __sock_cmsg_send(struct sock *sk, struct cmsghdr *cmsg,
+int __sock_cmsg_send(struct sock *sk, struct msghdr *msg, struct cmsghdr *cmsg,
struct sockcm_cookie *sockc);
int sock_cmsg_send(struct sock *sk, struct msghdr *msg,
struct sockcm_cookie *sockc);
@@ -2826,8 +2826,8 @@ struct sk_buff *sock_alloc_send_pskb(struct sock *sk, unsigned long header_len,
}
EXPORT_SYMBOL(sock_alloc_send_pskb);
-int __sock_cmsg_send(struct sock *sk, struct cmsghdr *cmsg,
- struct sockcm_cookie *sockc)
+int __sock_cmsg_send(struct sock *sk, struct msghdr *msg __always_unused,
+ struct cmsghdr *cmsg, struct sockcm_cookie *sockc)
{
u32 tsflags;
@@ -2881,7 +2881,7 @@ int sock_cmsg_send(struct sock *sk, struct msghdr *msg,
return -EINVAL;
if (cmsg->cmsg_level != SOL_SOCKET)
continue;
- ret = __sock_cmsg_send(sk, cmsg, sockc);
+ ret = __sock_cmsg_send(sk, msg, cmsg, sockc);
if (ret)
return ret;
}
@@ -267,7 +267,7 @@ int ip_cmsg_send(struct sock *sk, struct msghdr *msg, struct ipcm_cookie *ipc,
}
#endif
if (cmsg->cmsg_level == SOL_SOCKET) {
- err = __sock_cmsg_send(sk, cmsg, &ipc->sockc);
+ err = __sock_cmsg_send(sk, msg, cmsg, &ipc->sockc);
if (err)
return err;
continue;
@@ -777,7 +777,7 @@ int ip6_datagram_send_ctl(struct net *net, struct sock *sk,
}
if (cmsg->cmsg_level == SOL_SOCKET) {
- err = __sock_cmsg_send(sk, cmsg, &ipc6->sockc);
+ err = __sock_cmsg_send(sk, msg, cmsg, &ipc6->sockc);
if (err)
return err;
continue;
@@ -2537,8 +2537,49 @@ static int copy_msghdr_from_user(struct msghdr *kmsg,
return err < 0 ? err : 0;
}
-static int ____sys_sendmsg(struct socket *sock, struct msghdr *msg_sys,
- unsigned int flags, struct used_address *used_address,
+DEFINE_STATIC_KEY_FALSE(tx_copy_cmsg_to_user_key);
+
+static int sendmsg_copy_cmsg_to_user(struct msghdr *msg_sys,
+ struct user_msghdr __user *umsg)
+{
+ struct compat_msghdr __user *umsg_compat =
+ (struct compat_msghdr __user *)umsg;
+ unsigned int flags = msg_sys->msg_flags;
+ struct msghdr msg_user = *msg_sys;
+ unsigned long cmsg_ptr;
+ struct cmsghdr *cmsg;
+ int err;
+
+ msg_user.msg_control_is_user = true;
+ msg_user.msg_control_user = umsg->msg_control;
+ cmsg_ptr = (unsigned long)msg_user.msg_control;
+ for_each_cmsghdr(cmsg, msg_sys) {
+ if (!CMSG_OK(msg_sys, cmsg))
+ break;
+ if (!cmsg_copy_to_user(cmsg))
+ continue;
+ err = put_cmsg(&msg_user, cmsg->cmsg_level, cmsg->cmsg_type,
+ cmsg->cmsg_len - sizeof(*cmsg), CMSG_DATA(cmsg));
+ if (err)
+ return err;
+ }
+
+ err = __put_user((msg_sys->msg_flags & ~MSG_CMSG_COMPAT),
+ COMPAT_FLAGS(umsg));
+ if (err)
+ return err;
+ if (MSG_CMSG_COMPAT & flags)
+ err = __put_user((unsigned long)msg_user.msg_control - cmsg_ptr,
+ &umsg_compat->msg_controllen);
+ else
+ err = __put_user((unsigned long)msg_user.msg_control - cmsg_ptr,
+ &umsg->msg_controllen);
+ return err;
+}
+
+static int ____sys_sendmsg(struct socket *sock, struct user_msghdr __user *msg,
+ struct msghdr *msg_sys, unsigned int flags,
+ struct used_address *used_address,
unsigned int allowed_msghdr_flags)
{
unsigned char ctl[sizeof(struct cmsghdr) + 20]
@@ -2549,6 +2590,8 @@ static int ____sys_sendmsg(struct socket *sock, struct msghdr *msg_sys,
ssize_t err;
err = -ENOBUFS;
+ if (static_branch_unlikely(&tx_copy_cmsg_to_user_key))
+ msg_sys->msg_control_copy_to_user = false;
if (msg_sys->msg_controllen > INT_MAX)
goto out;
@@ -2606,6 +2649,16 @@ static int ____sys_sendmsg(struct socket *sock, struct msghdr *msg_sys,
used_address->name_len);
}
+ if (static_branch_unlikely(&tx_copy_cmsg_to_user_key)) {
+ if (msg_sys->msg_control_copy_to_user && msg && err >= 0) {
+ ssize_t len = err;
+
+ err = sendmsg_copy_cmsg_to_user(msg_sys, msg);
+ if (!err)
+ err = len;
+ }
+ }
+
out_freectl:
if (ctl_buf != ctl)
sock_kfree_s(sock->sk, ctl_buf, ctl_len);
@@ -2648,8 +2701,8 @@ static int ___sys_sendmsg(struct socket *sock, struct user_msghdr __user *msg,
if (err < 0)
return err;
- err = ____sys_sendmsg(sock, msg_sys, flags, used_address,
- allowed_msghdr_flags);
+ err = ____sys_sendmsg(sock, msg, msg_sys, flags, used_address,
+ allowed_msghdr_flags);
kfree(iov);
return err;
}
@@ -2660,7 +2713,7 @@ static int ___sys_sendmsg(struct socket *sock, struct user_msghdr __user *msg,
long __sys_sendmsg_sock(struct socket *sock, struct msghdr *msg,
unsigned int flags)
{
- return ____sys_sendmsg(sock, msg, flags, NULL, 0);
+ return ____sys_sendmsg(sock, NULL, msg, flags, NULL, 0);
}
long __sys_sendmsg(int fd, struct user_msghdr __user *msg, unsigned int flags,