diff mbox series

[RFC] netlink: split up copies in the ack construction

Message ID 20221005002814.2233715-1-kuba@kernel.org (mailing list archive)
State Changes Requested
Headers show
Series [RFC] netlink: split up copies in the ack construction | expand

Commit Message

Jakub Kicinski Oct. 5, 2022, 12:28 a.m. UTC
Signed-off-by: Jakub Kicinski <kuba@kernel.org>
---
 include/net/netlink.h        | 21 +++++++++++++++++++++
 include/uapi/linux/netlink.h |  2 ++
 net/netlink/af_netlink.c     | 30 +++++++++++++++++++++---------
 3 files changed, 44 insertions(+), 9 deletions(-)

Comments

Kees Cook Oct. 5, 2022, 3:03 a.m. UTC | #1
On Tue, Oct 04, 2022 at 05:28:14PM -0700, Jakub Kicinski wrote:
> Signed-off-by: Jakub Kicinski <kuba@kernel.org>
> ---
>  include/net/netlink.h        | 21 +++++++++++++++++++++
>  include/uapi/linux/netlink.h |  2 ++
>  net/netlink/af_netlink.c     | 30 +++++++++++++++++++++---------
>  3 files changed, 44 insertions(+), 9 deletions(-)
> 
> diff --git a/include/net/netlink.h b/include/net/netlink.h
> index 4418b1981e31..46c40fabd2b5 100644
> --- a/include/net/netlink.h
> +++ b/include/net/netlink.h
> @@ -931,6 +931,27 @@ static inline struct nlmsghdr *nlmsg_put(struct sk_buff *skb, u32 portid, u32 se
>  	return __nlmsg_put(skb, portid, seq, type, payload, flags);
>  }
>  
> +/**
> + * nlmsg_append - Add more data to a nlmsg in a skb
> + * @skb: socket buffer to store message in
> + * @size: length of message payload
> + *
> + * Append data to an existing nlmsg, used when constructing a message
> + * with multiple fixed-format headers (which is rare).
> + * Returns NULL if the tailroom of the skb is insufficient to store
> + * the extra payload.
> + */
> +static inline void *nlmsg_append(struct sk_buff *skb, u32 size)
> +{
> +	if (unlikely(skb_tailroom(skb) < NLMSG_ALIGN(size)))
> +		return NULL;
> +
> +	if (NLMSG_ALIGN(size) - size)
> +		memset(skb_tail_pointer(skb) + size, 0,
> +		       NLMSG_ALIGN(size) - size);
> +	return __skb_put(skb, NLMSG_ALIGN(size));
> +}
> +
>  /**
>   * nlmsg_put_answer - Add a new callback based netlink message to an skb
>   * @skb: socket buffer to store message in
> diff --git a/include/uapi/linux/netlink.h b/include/uapi/linux/netlink.h
> index e2ae82e3f9f7..fba3ca8152fa 100644
> --- a/include/uapi/linux/netlink.h
> +++ b/include/uapi/linux/netlink.h
> @@ -48,6 +48,7 @@ struct sockaddr_nl {
>   * @nlmsg_flags: Additional flags
>   * @nlmsg_seq:   Sequence number
>   * @nlmsg_pid:   Sending process port ID
> + * @nlmsg_data:  Message payload
>   */
>  struct nlmsghdr {
>  	__u32		nlmsg_len;
> @@ -55,6 +56,7 @@ struct nlmsghdr {
>  	__u16		nlmsg_flags;
>  	__u32		nlmsg_seq;
>  	__u32		nlmsg_pid;
> +	__DECLARE_FLEX_ARRAY(char, nlmsg_data);

Since the flex array isn't part of a union, it can just be declared
"normally":

	__u8		nlmsg_data[];

I'd also suggest __u8 (rather than plain char, whose signedness is
implementation-defined) because compilers hate us, and I've been burned
too many times by having char arrays do stupid things with the sign bit.

>  };
>  
>  /* Flags values */
> diff --git a/net/netlink/af_netlink.c b/net/netlink/af_netlink.c
> index a662e8a5ff84..f8c94454b916 100644
> --- a/net/netlink/af_netlink.c
> +++ b/net/netlink/af_netlink.c
> @@ -2488,19 +2488,25 @@ void netlink_ack(struct sk_buff *in_skb, struct nlmsghdr *nlh, int err,
>  		flags |= NLM_F_ACK_TLVS;
>  
>  	skb = nlmsg_new(payload + tlvlen, GFP_KERNEL);
> -	if (!skb) {
> -		NETLINK_CB(in_skb).sk->sk_err = ENOBUFS;
> -		sk_error_report(NETLINK_CB(in_skb).sk);
> -		return;
> -	}
> +	if (!skb)
> +		goto err_bad_put;
>  
>  	rep = nlmsg_put(skb, NETLINK_CB(in_skb).portid, nlh->nlmsg_seq,
> -			NLMSG_ERROR, payload, flags);
> +			NLMSG_ERROR, sizeof(*errmsg), flags);
> +	if (!rep)
> +		goto err_bad_put;
>  	errmsg = nlmsg_data(rep);
>  	errmsg->error = err;
> -	unsafe_memcpy(&errmsg->msg, nlh, payload > sizeof(*errmsg)
> -					 ? nlh->nlmsg_len : sizeof(*nlh),
> -		      /* Bounds checked by the skb layer. */);
> +	errmsg->msg = *nlh;
> +
> +	if (!(flags & NLM_F_CAPPED)) {
> +		if (!nlmsg_append(skb, nlmsg_len(nlh)))
> +			goto err_bad_put;
> +
> +		/* the nlh + 1 is probably going to make you unhappy? */
> +		memcpy(errmsg->msg.nlmsg_data, nlh->nlmsg_data,
> +		       nlmsg_len(nlh));
> +	}
>  
>  	if (tlvlen)
>  		netlink_ack_tlv_fill(in_skb, skb, nlh, err, extack);
> @@ -2508,6 +2514,12 @@ void netlink_ack(struct sk_buff *in_skb, struct nlmsghdr *nlh, int err,
>  	nlmsg_end(skb, rep);
>  
>  	nlmsg_unicast(in_skb->sk, skb, NETLINK_CB(in_skb).portid);
> +
> +	return;
> +
> +err_bad_put:
> +	NETLINK_CB(in_skb).sk->sk_err = ENOBUFS;
> +	sk_error_report(NETLINK_CB(in_skb).sk);
>  }
>  EXPORT_SYMBOL(netlink_ack);

The rest looks great! :) I suspect you'll want to do close to the same
conversion in net/netfilter/ipset/ip_set_core.c in call_ad() too.
diff mbox series

Patch

diff --git a/include/net/netlink.h b/include/net/netlink.h
index 4418b1981e31..46c40fabd2b5 100644
--- a/include/net/netlink.h
+++ b/include/net/netlink.h
@@ -931,6 +931,27 @@  static inline struct nlmsghdr *nlmsg_put(struct sk_buff *skb, u32 portid, u32 se
 	return __nlmsg_put(skb, portid, seq, type, payload, flags);
 }
 
+/**
+ * nlmsg_append - Add more data to a nlmsg in a skb
+ * @skb: socket buffer to store message in
+ * @size: length of message payload
+ *
+ * Append data to an existing nlmsg, used when constructing a message
+ * with multiple fixed-format headers (which is rare).
+ * Returns NULL if the tailroom of the skb is insufficient to store
+ * the extra payload.
+ */
+static inline void *nlmsg_append(struct sk_buff *skb, u32 size)
+{
+	if (unlikely(skb_tailroom(skb) < NLMSG_ALIGN(size)))
+		return NULL;
+
+	if (NLMSG_ALIGN(size) - size)
+		memset(skb_tail_pointer(skb) + size, 0,
+		       NLMSG_ALIGN(size) - size);
+	return __skb_put(skb, NLMSG_ALIGN(size));
+}
+
 /**
  * nlmsg_put_answer - Add a new callback based netlink message to an skb
  * @skb: socket buffer to store message in
diff --git a/include/uapi/linux/netlink.h b/include/uapi/linux/netlink.h
index e2ae82e3f9f7..fba3ca8152fa 100644
--- a/include/uapi/linux/netlink.h
+++ b/include/uapi/linux/netlink.h
@@ -48,6 +48,7 @@  struct sockaddr_nl {
  * @nlmsg_flags: Additional flags
  * @nlmsg_seq:   Sequence number
  * @nlmsg_pid:   Sending process port ID
+ * @nlmsg_data:  Message payload
  */
 struct nlmsghdr {
 	__u32		nlmsg_len;
@@ -55,6 +56,7 @@  struct nlmsghdr {
 	__u16		nlmsg_flags;
 	__u32		nlmsg_seq;
 	__u32		nlmsg_pid;
+	__DECLARE_FLEX_ARRAY(char, nlmsg_data);
 };
 
 /* Flags values */
diff --git a/net/netlink/af_netlink.c b/net/netlink/af_netlink.c
index a662e8a5ff84..f8c94454b916 100644
--- a/net/netlink/af_netlink.c
+++ b/net/netlink/af_netlink.c
@@ -2488,19 +2488,25 @@  void netlink_ack(struct sk_buff *in_skb, struct nlmsghdr *nlh, int err,
 		flags |= NLM_F_ACK_TLVS;
 
 	skb = nlmsg_new(payload + tlvlen, GFP_KERNEL);
-	if (!skb) {
-		NETLINK_CB(in_skb).sk->sk_err = ENOBUFS;
-		sk_error_report(NETLINK_CB(in_skb).sk);
-		return;
-	}
+	if (!skb)
+		goto err_bad_put;
 
 	rep = nlmsg_put(skb, NETLINK_CB(in_skb).portid, nlh->nlmsg_seq,
-			NLMSG_ERROR, payload, flags);
+			NLMSG_ERROR, sizeof(*errmsg), flags);
+	if (!rep)
+		goto err_bad_put;
 	errmsg = nlmsg_data(rep);
 	errmsg->error = err;
-	unsafe_memcpy(&errmsg->msg, nlh, payload > sizeof(*errmsg)
-					 ? nlh->nlmsg_len : sizeof(*nlh),
-		      /* Bounds checked by the skb layer. */);
+	errmsg->msg = *nlh;
+
+	if (!(flags & NLM_F_CAPPED)) {
+		if (!nlmsg_append(skb, nlmsg_len(nlh)))
+			goto err_bad_put;
+
+		/* the nlh + 1 is probably going to make you unhappy? */
+		memcpy(errmsg->msg.nlmsg_data, nlh->nlmsg_data,
+		       nlmsg_len(nlh));
+	}
 
 	if (tlvlen)
 		netlink_ack_tlv_fill(in_skb, skb, nlh, err, extack);
@@ -2508,6 +2514,12 @@  void netlink_ack(struct sk_buff *in_skb, struct nlmsghdr *nlh, int err,
 	nlmsg_end(skb, rep);
 
 	nlmsg_unicast(in_skb->sk, skb, NETLINK_CB(in_skb).portid);
+
+	return;
+
+err_bad_put:
+	NETLINK_CB(in_skb).sk->sk_err = ENOBUFS;
+	sk_error_report(NETLINK_CB(in_skb).sk);
 }
 EXPORT_SYMBOL(netlink_ack);