diff mbox series

LSM: Infrastructure management of the sock

Message ID 20230531110506.142951-1-gongruiqi@huaweicloud.com (mailing list archive)
State Handled Elsewhere
Delegated to: Paul Moore
Headers show
Series LSM: Infrastructure management of the sock | expand

Commit Message

GONG, Ruiqi May 31, 2023, 11:05 a.m. UTC
As the security infrastructure has taken over the management of multiple
*_security blobs that are accessed by multiple security modules, and
sk->sk_security shares the same situation, move its management out of
individual security modules and into the security infrastructure as
well. The infrastructure does the memory allocation, and each relavant
module uses its own share.

Signed-off-by: GONG, Ruiqi <gongruiqi@huaweicloud.com>
---
 include/linux/lsm_hooks.h         |  1 +
 security/apparmor/include/net.h   |  2 +-
 security/apparmor/lsm.c           | 20 +-------
 security/security.c               | 35 ++++++++++++-
 security/selinux/hooks.c          | 81 ++++++++++++++-----------------
 security/selinux/include/objsec.h |  4 ++
 security/selinux/netlabel.c       | 22 ++++-----
 security/smack/smack.h            |  5 ++
 security/smack/smack_lsm.c        | 65 +++++++++++--------------
 security/smack/smack_netfilter.c  |  4 +-
 10 files changed, 125 insertions(+), 114 deletions(-)

Comments

Casey Schaufler May 31, 2023, 2 p.m. UTC | #1
On 5/31/2023 4:05 AM, GONG, Ruiqi wrote:
> As the security infrastructure has taken over the management of multiple
> *_security blobs that are accessed by multiple security modules, and
> sk->sk_security shares the same situation, move its management out of
> individual security modules and into the security infrastructure as
> well. The infrastructure does the memory allocation, and each relavant
> module uses its own share.

Do you have a reason to make this change? The LSM infrastructure
manages other security blobs to enable multiple concurrently active
LSMs to use the blob. If only one LSM on a system can use the
socket blob there's no reason to move the management.

>
> Signed-off-by: GONG, Ruiqi <gongruiqi@huaweicloud.com>
> ---
>  include/linux/lsm_hooks.h         |  1 +
>  security/apparmor/include/net.h   |  2 +-
>  security/apparmor/lsm.c           | 20 +-------
>  security/security.c               | 35 ++++++++++++-
>  security/selinux/hooks.c          | 81 ++++++++++++++-----------------
>  security/selinux/include/objsec.h |  4 ++
>  security/selinux/netlabel.c       | 22 ++++-----
>  security/smack/smack.h            |  5 ++
>  security/smack/smack_lsm.c        | 65 +++++++++++--------------
>  security/smack/smack_netfilter.c  |  4 +-
>  10 files changed, 125 insertions(+), 114 deletions(-)
>
> diff --git a/include/linux/lsm_hooks.h b/include/linux/lsm_hooks.h
> index ab2b2fafa4a4..67b6e87ca6ec 100644
> --- a/include/linux/lsm_hooks.h
> +++ b/include/linux/lsm_hooks.h
> @@ -62,6 +62,7 @@ struct lsm_blob_sizes {
>  	int	lbs_superblock;
>  	int	lbs_ipc;
>  	int	lbs_msg_msg;
> +	int	lbs_sock;
>  	int	lbs_task;
>  };
>  
> diff --git a/security/apparmor/include/net.h b/security/apparmor/include/net.h
> index 6fa440b5daed..9eb159c09578 100644
> --- a/security/apparmor/include/net.h
> +++ b/security/apparmor/include/net.h
> @@ -51,7 +51,7 @@ struct aa_sk_ctx {
>  	struct aa_label *peer;
>  };
>  
> -#define SK_CTX(X) ((X)->sk_security)
> +#define SK_CTX(X) ((X)->sk_security + apparmor_blob_sizes.lbs_sock)
>  #define SOCK_ctx(X) SOCK_INODE(X)->i_security
>  #define DEFINE_AUDIT_NET(NAME, OP, SK, F, T, P)				  \
>  	struct lsm_network_audit NAME ## _net = { .sk = (SK),		  \
> diff --git a/security/apparmor/lsm.c b/security/apparmor/lsm.c
> index f431251ffb91..3dd849a6d7a1 100644
> --- a/security/apparmor/lsm.c
> +++ b/security/apparmor/lsm.c
> @@ -818,22 +818,6 @@ static int apparmor_task_kill(struct task_struct *target, struct kernel_siginfo
>  	return error;
>  }
>  
> -/**
> - * apparmor_sk_alloc_security - allocate and attach the sk_security field
> - */
> -static int apparmor_sk_alloc_security(struct sock *sk, int family, gfp_t flags)
> -{
> -	struct aa_sk_ctx *ctx;
> -
> -	ctx = kzalloc(sizeof(*ctx), flags);
> -	if (!ctx)
> -		return -ENOMEM;
> -
> -	SK_CTX(sk) = ctx;
> -
> -	return 0;
> -}
> -
>  /**
>   * apparmor_sk_free_security - free the sk_security field
>   */
> @@ -841,10 +825,8 @@ static void apparmor_sk_free_security(struct sock *sk)
>  {
>  	struct aa_sk_ctx *ctx = SK_CTX(sk);
>  
> -	SK_CTX(sk) = NULL;
>  	aa_put_label(ctx->label);
>  	aa_put_label(ctx->peer);
> -	kfree(ctx);
>  }
>  
>  /**
> @@ -1212,6 +1194,7 @@ static int apparmor_inet_conn_request(const struct sock *sk, struct sk_buff *skb
>  struct lsm_blob_sizes apparmor_blob_sizes __ro_after_init = {
>  	.lbs_cred = sizeof(struct aa_label *),
>  	.lbs_file = sizeof(struct aa_file_ctx),
> +	.lbs_sock = sizeof(struct aa_sk_ctx),
>  	.lbs_task = sizeof(struct aa_task_ctx),
>  };
>  
> @@ -1250,7 +1233,6 @@ static struct security_hook_list apparmor_hooks[] __ro_after_init = {
>  	LSM_HOOK_INIT(getprocattr, apparmor_getprocattr),
>  	LSM_HOOK_INIT(setprocattr, apparmor_setprocattr),
>  
> -	LSM_HOOK_INIT(sk_alloc_security, apparmor_sk_alloc_security),
>  	LSM_HOOK_INIT(sk_free_security, apparmor_sk_free_security),
>  	LSM_HOOK_INIT(sk_clone_security, apparmor_sk_clone_security),
>  
> diff --git a/security/security.c b/security/security.c
> index b720424ca37d..e71f4717cde5 100644
> --- a/security/security.c
> +++ b/security/security.c
> @@ -30,6 +30,7 @@
>  #include <linux/string.h>
>  #include <linux/msg.h>
>  #include <net/flow.h>
> +#include <net/sock.h>
>  
>  #define MAX_LSM_EVM_XATTR	2
>  
> @@ -210,6 +211,7 @@ static void __init lsm_set_blob_sizes(struct lsm_blob_sizes *needed)
>  	lsm_set_blob_size(&needed->lbs_inode, &blob_sizes.lbs_inode);
>  	lsm_set_blob_size(&needed->lbs_ipc, &blob_sizes.lbs_ipc);
>  	lsm_set_blob_size(&needed->lbs_msg_msg, &blob_sizes.lbs_msg_msg);
> +	lsm_set_blob_size(&needed->lbs_sock, &blob_sizes.lbs_sock);
>  	lsm_set_blob_size(&needed->lbs_superblock, &blob_sizes.lbs_superblock);
>  	lsm_set_blob_size(&needed->lbs_task, &blob_sizes.lbs_task);
>  }
> @@ -376,6 +378,7 @@ static void __init ordered_lsm_init(void)
>  	init_debug("inode blob size      = %d\n", blob_sizes.lbs_inode);
>  	init_debug("ipc blob size        = %d\n", blob_sizes.lbs_ipc);
>  	init_debug("msg_msg blob size    = %d\n", blob_sizes.lbs_msg_msg);
> +	init_debug("sock blob size       = %d\n", blob_sizes.lbs_sock);
>  	init_debug("superblock blob size = %d\n", blob_sizes.lbs_superblock);
>  	init_debug("task blob size       = %d\n", blob_sizes.lbs_task);
>  
> @@ -733,6 +736,27 @@ static int lsm_superblock_alloc(struct super_block *sb)
>  	return 0;
>  }
>  
> +/**
> + * lsm_sock_alloc - allocate a composite socket blob
> + * @sk: the socket that needs a blob
> + *
> + * Allocate the socket blob for all the modules
> + *
> + * Returns 0, or -ENOMEM if memory can't be allocated.
> + */
> +static int lsm_sock_alloc(struct sock *sk)
> +{
> +	if (blob_sizes.lbs_sock == 0) {
> +		sk->sk_security = NULL;
> +		return 0;
> +	}
> +
> +	sk->sk_security = kzalloc(blob_sizes.lbs_sock, GFP_KERNEL);
> +	if (sk->sk_security == NULL)
> +		return -ENOMEM;
> +	return 0;
> +}
> +
>  /*
>   * The default value of the LSM hook is defined in linux/lsm_hook_defs.h and
>   * can be accessed with:
> @@ -4369,7 +4393,14 @@ EXPORT_SYMBOL(security_socket_getpeersec_dgram);
>   */
>  int security_sk_alloc(struct sock *sk, int family, gfp_t priority)
>  {
> -	return call_int_hook(sk_alloc_security, 0, sk, family, priority);
> +	int rc = lsm_sock_alloc(sk);
> +
> +	if (unlikely(rc))
> +		return rc;
> +	rc = call_int_hook(sk_alloc_security, 0, sk, family, priority);
> +	if (unlikely(rc))
> +		security_sk_free(sk);
> +	return rc;
>  }
>  
>  /**
> @@ -4381,6 +4412,8 @@ int security_sk_alloc(struct sock *sk, int family, gfp_t priority)
>  void security_sk_free(struct sock *sk)
>  {
>  	call_void_hook(sk_free_security, sk);
> +	kfree(sk->sk_security);
> +	sk->sk_security = NULL;
>  }
>  
>  /**
> diff --git a/security/selinux/hooks.c b/security/selinux/hooks.c
> index d06e350fedee..f8397f05dc90 100644
> --- a/security/selinux/hooks.c
> +++ b/security/selinux/hooks.c
> @@ -4497,7 +4497,7 @@ static int socket_sockcreate_sid(const struct task_security_struct *tsec,
>  
>  static int sock_has_perm(struct sock *sk, u32 perms)
>  {
> -	struct sk_security_struct *sksec = sk->sk_security;
> +	struct sk_security_struct *sksec = selinux_sock(sk);
>  	struct common_audit_data ad;
>  	struct lsm_network_audit net = {0,};
>  
> @@ -4552,7 +4552,7 @@ static int selinux_socket_post_create(struct socket *sock, int family,
>  	isec->initialized = LABEL_INITIALIZED;
>  
>  	if (sock->sk) {
> -		sksec = sock->sk->sk_security;
> +		sksec = selinux_sock(sock->sk);
>  		sksec->sclass = sclass;
>  		sksec->sid = sid;
>  		/* Allows detection of the first association on this socket */
> @@ -4568,8 +4568,8 @@ static int selinux_socket_post_create(struct socket *sock, int family,
>  static int selinux_socket_socketpair(struct socket *socka,
>  				     struct socket *sockb)
>  {
> -	struct sk_security_struct *sksec_a = socka->sk->sk_security;
> -	struct sk_security_struct *sksec_b = sockb->sk->sk_security;
> +	struct sk_security_struct *sksec_a = selinux_sock(socka->sk);
> +	struct sk_security_struct *sksec_b = selinux_sock(sockb->sk);
>  
>  	sksec_a->peer_sid = sksec_b->sid;
>  	sksec_b->peer_sid = sksec_a->sid;
> @@ -4584,7 +4584,7 @@ static int selinux_socket_socketpair(struct socket *socka,
>  static int selinux_socket_bind(struct socket *sock, struct sockaddr *address, int addrlen)
>  {
>  	struct sock *sk = sock->sk;
> -	struct sk_security_struct *sksec = sk->sk_security;
> +	struct sk_security_struct *sksec = selinux_sock(sk);
>  	u16 family;
>  	int err;
>  
> @@ -4717,7 +4717,7 @@ static int selinux_socket_connect_helper(struct socket *sock,
>  					 struct sockaddr *address, int addrlen)
>  {
>  	struct sock *sk = sock->sk;
> -	struct sk_security_struct *sksec = sk->sk_security;
> +	struct sk_security_struct *sksec = selinux_sock(sk);
>  	int err;
>  
>  	err = sock_has_perm(sk, SOCKET__CONNECT);
> @@ -4895,9 +4895,9 @@ static int selinux_socket_unix_stream_connect(struct sock *sock,
>  					      struct sock *other,
>  					      struct sock *newsk)
>  {
> -	struct sk_security_struct *sksec_sock = sock->sk_security;
> -	struct sk_security_struct *sksec_other = other->sk_security;
> -	struct sk_security_struct *sksec_new = newsk->sk_security;
> +	struct sk_security_struct *sksec_sock = selinux_sock(sock);
> +	struct sk_security_struct *sksec_other = selinux_sock(other);
> +	struct sk_security_struct *sksec_new = selinux_sock(newsk);
>  	struct common_audit_data ad;
>  	struct lsm_network_audit net = {0,};
>  	int err;
> @@ -4928,8 +4928,8 @@ static int selinux_socket_unix_stream_connect(struct sock *sock,
>  static int selinux_socket_unix_may_send(struct socket *sock,
>  					struct socket *other)
>  {
> -	struct sk_security_struct *ssec = sock->sk->sk_security;
> -	struct sk_security_struct *osec = other->sk->sk_security;
> +	struct sk_security_struct *ssec = selinux_sock(sock->sk);
> +	struct sk_security_struct *osec = selinux_sock(other->sk);
>  	struct common_audit_data ad;
>  	struct lsm_network_audit net = {0,};
>  
> @@ -4968,7 +4968,7 @@ static int selinux_sock_rcv_skb_compat(struct sock *sk, struct sk_buff *skb,
>  				       u16 family)
>  {
>  	int err = 0;
> -	struct sk_security_struct *sksec = sk->sk_security;
> +	struct sk_security_struct *sksec = selinux_sock(sk);
>  	u32 sk_sid = sksec->sid;
>  	struct common_audit_data ad;
>  	struct lsm_network_audit net = {0,};
> @@ -5000,7 +5000,7 @@ static int selinux_sock_rcv_skb_compat(struct sock *sk, struct sk_buff *skb,
>  static int selinux_socket_sock_rcv_skb(struct sock *sk, struct sk_buff *skb)
>  {
>  	int err;
> -	struct sk_security_struct *sksec = sk->sk_security;
> +	struct sk_security_struct *sksec = selinux_sock(sk);
>  	u16 family = sk->sk_family;
>  	u32 sk_sid = sksec->sid;
>  	struct common_audit_data ad;
> @@ -5073,7 +5073,7 @@ static int selinux_socket_getpeersec_stream(struct socket *sock,
>  	int err = 0;
>  	char *scontext = NULL;
>  	u32 scontext_len;
> -	struct sk_security_struct *sksec = sock->sk->sk_security;
> +	struct sk_security_struct *sksec = selinux_sock(sock->sk);
>  	u32 peer_sid = SECSID_NULL;
>  
>  	if (sksec->sclass == SECCLASS_UNIX_STREAM_SOCKET ||
> @@ -5131,34 +5131,27 @@ static int selinux_socket_getpeersec_dgram(struct socket *sock, struct sk_buff *
>  
>  static int selinux_sk_alloc_security(struct sock *sk, int family, gfp_t priority)
>  {
> -	struct sk_security_struct *sksec;
> -
> -	sksec = kzalloc(sizeof(*sksec), priority);
> -	if (!sksec)
> -		return -ENOMEM;
> +	struct sk_security_struct *sksec = selinux_sock(sk);
>  
>  	sksec->peer_sid = SECINITSID_UNLABELED;
>  	sksec->sid = SECINITSID_UNLABELED;
>  	sksec->sclass = SECCLASS_SOCKET;
>  	selinux_netlbl_sk_security_reset(sksec);
> -	sk->sk_security = sksec;
>  
>  	return 0;
>  }
>  
>  static void selinux_sk_free_security(struct sock *sk)
>  {
> -	struct sk_security_struct *sksec = sk->sk_security;
> +	struct sk_security_struct *sksec = selinux_sock(sk);
>  
> -	sk->sk_security = NULL;
>  	selinux_netlbl_sk_security_free(sksec);
> -	kfree(sksec);
>  }
>  
>  static void selinux_sk_clone_security(const struct sock *sk, struct sock *newsk)
>  {
> -	struct sk_security_struct *sksec = sk->sk_security;
> -	struct sk_security_struct *newsksec = newsk->sk_security;
> +	struct sk_security_struct *sksec = selinux_sock(sk);
> +	struct sk_security_struct *newsksec = selinux_sock(newsk);
>  
>  	newsksec->sid = sksec->sid;
>  	newsksec->peer_sid = sksec->peer_sid;
> @@ -5172,7 +5165,7 @@ static void selinux_sk_getsecid(struct sock *sk, u32 *secid)
>  	if (!sk)
>  		*secid = SECINITSID_ANY_SOCKET;
>  	else {
> -		struct sk_security_struct *sksec = sk->sk_security;
> +		struct sk_security_struct *sksec = selinux_sock(sk);
>  
>  		*secid = sksec->sid;
>  	}
> @@ -5182,7 +5175,7 @@ static void selinux_sock_graft(struct sock *sk, struct socket *parent)
>  {
>  	struct inode_security_struct *isec =
>  		inode_security_novalidate(SOCK_INODE(parent));
> -	struct sk_security_struct *sksec = sk->sk_security;
> +	struct sk_security_struct *sksec = selinux_sock(sk);
>  
>  	if (sk->sk_family == PF_INET || sk->sk_family == PF_INET6 ||
>  	    sk->sk_family == PF_UNIX)
> @@ -5199,7 +5192,7 @@ static int selinux_sctp_process_new_assoc(struct sctp_association *asoc,
>  {
>  	struct sock *sk = asoc->base.sk;
>  	u16 family = sk->sk_family;
> -	struct sk_security_struct *sksec = sk->sk_security;
> +	struct sk_security_struct *sksec = selinux_sock(sk);
>  	struct common_audit_data ad;
>  	struct lsm_network_audit net = {0,};
>  	int err;
> @@ -5256,7 +5249,7 @@ static int selinux_sctp_process_new_assoc(struct sctp_association *asoc,
>  static int selinux_sctp_assoc_request(struct sctp_association *asoc,
>  				      struct sk_buff *skb)
>  {
> -	struct sk_security_struct *sksec = asoc->base.sk->sk_security;
> +	struct sk_security_struct *sksec = selinux_sock(asoc->base.sk);
>  	u32 conn_sid;
>  	int err;
>  
> @@ -5289,7 +5282,7 @@ static int selinux_sctp_assoc_request(struct sctp_association *asoc,
>  static int selinux_sctp_assoc_established(struct sctp_association *asoc,
>  					  struct sk_buff *skb)
>  {
> -	struct sk_security_struct *sksec = asoc->base.sk->sk_security;
> +	struct sk_security_struct *sksec = selinux_sock(asoc->base.sk);
>  
>  	if (!selinux_policycap_extsockclass())
>  		return 0;
> @@ -5388,8 +5381,8 @@ static int selinux_sctp_bind_connect(struct sock *sk, int optname,
>  static void selinux_sctp_sk_clone(struct sctp_association *asoc, struct sock *sk,
>  				  struct sock *newsk)
>  {
> -	struct sk_security_struct *sksec = sk->sk_security;
> -	struct sk_security_struct *newsksec = newsk->sk_security;
> +	struct sk_security_struct *sksec = selinux_sock(sk);
> +	struct sk_security_struct *newsksec = selinux_sock(newsk);
>  
>  	/* If policy does not support SECCLASS_SCTP_SOCKET then call
>  	 * the non-sctp clone version.
> @@ -5405,8 +5398,8 @@ static void selinux_sctp_sk_clone(struct sctp_association *asoc, struct sock *sk
>  
>  static int selinux_mptcp_add_subflow(struct sock *sk, struct sock *ssk)
>  {
> -	struct sk_security_struct *ssksec = ssk->sk_security;
> -	struct sk_security_struct *sksec = sk->sk_security;
> +	struct sk_security_struct *ssksec = selinux_sock(ssk);
> +	struct sk_security_struct *sksec = selinux_sock(sk);
>  
>  	ssksec->sclass = sksec->sclass;
>  	ssksec->sid = sksec->sid;
> @@ -5421,7 +5414,7 @@ static int selinux_mptcp_add_subflow(struct sock *sk, struct sock *ssk)
>  static int selinux_inet_conn_request(const struct sock *sk, struct sk_buff *skb,
>  				     struct request_sock *req)
>  {
> -	struct sk_security_struct *sksec = sk->sk_security;
> +	struct sk_security_struct *sksec = selinux_sock(sk);
>  	int err;
>  	u16 family = req->rsk_ops->family;
>  	u32 connsid;
> @@ -5442,7 +5435,7 @@ static int selinux_inet_conn_request(const struct sock *sk, struct sk_buff *skb,
>  static void selinux_inet_csk_clone(struct sock *newsk,
>  				   const struct request_sock *req)
>  {
> -	struct sk_security_struct *newsksec = newsk->sk_security;
> +	struct sk_security_struct *newsksec = selinux_sock(newsk);
>  
>  	newsksec->sid = req->secid;
>  	newsksec->peer_sid = req->peer_secid;
> @@ -5459,7 +5452,7 @@ static void selinux_inet_csk_clone(struct sock *newsk,
>  static void selinux_inet_conn_established(struct sock *sk, struct sk_buff *skb)
>  {
>  	u16 family = sk->sk_family;
> -	struct sk_security_struct *sksec = sk->sk_security;
> +	struct sk_security_struct *sksec = selinux_sock(sk);
>  
>  	/* handle mapped IPv4 packets arriving via IPv6 sockets */
>  	if (family == PF_INET6 && skb->protocol == htons(ETH_P_IP))
> @@ -5540,7 +5533,7 @@ static int selinux_tun_dev_attach_queue(void *security)
>  static int selinux_tun_dev_attach(struct sock *sk, void *security)
>  {
>  	struct tun_security_struct *tunsec = security;
> -	struct sk_security_struct *sksec = sk->sk_security;
> +	struct sk_security_struct *sksec = selinux_sock(sk);
>  
>  	/* we don't currently perform any NetLabel based labeling here and it
>  	 * isn't clear that we would want to do so anyway; while we could apply
> @@ -5666,7 +5659,7 @@ static unsigned int selinux_ip_output(void *priv, struct sk_buff *skb,
>  			return NF_ACCEPT;
>  
>  		/* standard practice, label using the parent socket */
> -		sksec = sk->sk_security;
> +		sksec = selinux_sock(sk);
>  		sid = sksec->sid;
>  	} else
>  		sid = SECINITSID_KERNEL;
> @@ -5689,7 +5682,7 @@ static unsigned int selinux_ip_postroute_compat(struct sk_buff *skb,
>  	sk = skb_to_full_sk(skb);
>  	if (sk == NULL)
>  		return NF_ACCEPT;
> -	sksec = sk->sk_security;
> +	sksec = selinux_sock(sk);
>  
>  	ad.type = LSM_AUDIT_DATA_NET;
>  	ad.u.net = &net;
> @@ -5779,9 +5772,8 @@ static unsigned int selinux_ip_postroute(void *priv,
>  		 * selinux_inet_conn_request().  See also selinux_ip_output()
>  		 * for similar problems. */
>  		u32 skb_sid;
> -		struct sk_security_struct *sksec;
> +		struct sk_security_struct *sksec = selinux_sock(sk);
>  
> -		sksec = sk->sk_security;
>  		if (selinux_skb_peerlbl_sid(skb, family, &skb_sid))
>  			return NF_DROP;
>  		/* At this point, if the returned skb peerlbl is SECSID_NULL
> @@ -5810,7 +5802,7 @@ static unsigned int selinux_ip_postroute(void *priv,
>  	} else {
>  		/* Locally generated packet, fetch the security label from the
>  		 * associated socket. */
> -		struct sk_security_struct *sksec = sk->sk_security;
> +		struct sk_security_struct *sksec = selinux_sock(sk);
>  		peer_sid = sksec->sid;
>  		secmark_perm = PACKET__SEND;
>  	}
> @@ -5856,7 +5848,7 @@ static int selinux_netlink_send(struct sock *sk, struct sk_buff *skb)
>  	unsigned int data_len = skb->len;
>  	unsigned char *data = skb->data;
>  	struct nlmsghdr *nlh;
> -	struct sk_security_struct *sksec = sk->sk_security;
> +	struct sk_security_struct *sksec = selinux_sock(sk);
>  	u16 sclass = sksec->sclass;
>  	u32 perm;
>  
> @@ -6814,6 +6806,7 @@ struct lsm_blob_sizes selinux_blob_sizes __ro_after_init = {
>  	.lbs_inode = sizeof(struct inode_security_struct),
>  	.lbs_ipc = sizeof(struct ipc_security_struct),
>  	.lbs_msg_msg = sizeof(struct msg_security_struct),
> +	.lbs_sock = sizeof(struct sk_security_struct),
>  	.lbs_superblock = sizeof(struct superblock_security_struct),
>  };
>  
> diff --git a/security/selinux/include/objsec.h b/security/selinux/include/objsec.h
> index 2953132408bf..49221f441c68 100644
> --- a/security/selinux/include/objsec.h
> +++ b/security/selinux/include/objsec.h
> @@ -194,4 +194,8 @@ static inline struct superblock_security_struct *selinux_superblock(
>  	return superblock->s_security + selinux_blob_sizes.lbs_superblock;
>  }
>  
> +static inline struct sk_security_struct *selinux_sock(const struct sock *sk)
> +{
> +	return sk->sk_security + selinux_blob_sizes.lbs_sock;
> +}
>  #endif /* _SELINUX_OBJSEC_H_ */
> diff --git a/security/selinux/netlabel.c b/security/selinux/netlabel.c
> index 528f5186e912..9755561aa466 100644
> --- a/security/selinux/netlabel.c
> +++ b/security/selinux/netlabel.c
> @@ -68,7 +68,7 @@ static int selinux_netlbl_sidlookup_cached(struct sk_buff *skb,
>  static struct netlbl_lsm_secattr *selinux_netlbl_sock_genattr(struct sock *sk)
>  {
>  	int rc;
> -	struct sk_security_struct *sksec = sk->sk_security;
> +	struct sk_security_struct *sksec = selinux_sock(sk);
>  	struct netlbl_lsm_secattr *secattr;
>  
>  	if (sksec->nlbl_secattr != NULL)
> @@ -100,7 +100,7 @@ static struct netlbl_lsm_secattr *selinux_netlbl_sock_getattr(
>  							const struct sock *sk,
>  							u32 sid)
>  {
> -	struct sk_security_struct *sksec = sk->sk_security;
> +	struct sk_security_struct *sksec = selinux_sock(sk);
>  	struct netlbl_lsm_secattr *secattr = sksec->nlbl_secattr;
>  
>  	if (secattr == NULL)
> @@ -239,7 +239,7 @@ int selinux_netlbl_skbuff_setsid(struct sk_buff *skb,
>  	 * being labeled by it's parent socket, if it is just exit */
>  	sk = skb_to_full_sk(skb);
>  	if (sk != NULL) {
> -		struct sk_security_struct *sksec = sk->sk_security;
> +		struct sk_security_struct *sksec = selinux_sock(sk);
>  
>  		if (sksec->nlbl_state != NLBL_REQSKB)
>  			return 0;
> @@ -276,7 +276,7 @@ int selinux_netlbl_sctp_assoc_request(struct sctp_association *asoc,
>  {
>  	int rc;
>  	struct netlbl_lsm_secattr secattr;
> -	struct sk_security_struct *sksec = asoc->base.sk->sk_security;
> +	struct sk_security_struct *sksec = selinux_sock(asoc->base.sk);
>  	struct sockaddr_in addr4;
>  	struct sockaddr_in6 addr6;
>  
> @@ -355,7 +355,7 @@ int selinux_netlbl_inet_conn_request(struct request_sock *req, u16 family)
>   */
>  void selinux_netlbl_inet_csk_clone(struct sock *sk, u16 family)
>  {
> -	struct sk_security_struct *sksec = sk->sk_security;
> +	struct sk_security_struct *sksec = selinux_sock(sk);
>  
>  	if (family == PF_INET)
>  		sksec->nlbl_state = NLBL_LABELED;
> @@ -373,8 +373,8 @@ void selinux_netlbl_inet_csk_clone(struct sock *sk, u16 family)
>   */
>  void selinux_netlbl_sctp_sk_clone(struct sock *sk, struct sock *newsk)
>  {
> -	struct sk_security_struct *sksec = sk->sk_security;
> -	struct sk_security_struct *newsksec = newsk->sk_security;
> +	struct sk_security_struct *sksec = selinux_sock(sk);
> +	struct sk_security_struct *newsksec = selinux_sock(newsk);
>  
>  	newsksec->nlbl_state = sksec->nlbl_state;
>  }
> @@ -392,7 +392,7 @@ void selinux_netlbl_sctp_sk_clone(struct sock *sk, struct sock *newsk)
>  int selinux_netlbl_socket_post_create(struct sock *sk, u16 family)
>  {
>  	int rc;
> -	struct sk_security_struct *sksec = sk->sk_security;
> +	struct sk_security_struct *sksec = selinux_sock(sk);
>  	struct netlbl_lsm_secattr *secattr;
>  
>  	if (family != PF_INET && family != PF_INET6)
> @@ -506,7 +506,7 @@ int selinux_netlbl_socket_setsockopt(struct socket *sock,
>  {
>  	int rc = 0;
>  	struct sock *sk = sock->sk;
> -	struct sk_security_struct *sksec = sk->sk_security;
> +	struct sk_security_struct *sksec = selinux_sock(sk);
>  	struct netlbl_lsm_secattr secattr;
>  
>  	if (selinux_netlbl_option(level, optname) &&
> @@ -544,7 +544,7 @@ static int selinux_netlbl_socket_connect_helper(struct sock *sk,
>  						struct sockaddr *addr)
>  {
>  	int rc;
> -	struct sk_security_struct *sksec = sk->sk_security;
> +	struct sk_security_struct *sksec = selinux_sock(sk);
>  	struct netlbl_lsm_secattr *secattr;
>  
>  	/* connected sockets are allowed to disconnect when the address family
> @@ -583,7 +583,7 @@ static int selinux_netlbl_socket_connect_helper(struct sock *sk,
>  int selinux_netlbl_socket_connect_locked(struct sock *sk,
>  					 struct sockaddr *addr)
>  {
> -	struct sk_security_struct *sksec = sk->sk_security;
> +	struct sk_security_struct *sksec = selinux_sock(sk);
>  
>  	if (sksec->nlbl_state != NLBL_REQSKB &&
>  	    sksec->nlbl_state != NLBL_CONNLABELED)
> diff --git a/security/smack/smack.h b/security/smack/smack.h
> index aa15ff56ed6e..2d0163076eca 100644
> --- a/security/smack/smack.h
> +++ b/security/smack/smack.h
> @@ -355,6 +355,11 @@ static inline struct superblock_smack *smack_superblock(
>  	return superblock->s_security + smack_blob_sizes.lbs_superblock;
>  }
>  
> +static inline struct socket_smack *smack_sock(const struct sock *sk)
> +{
> +	return sk->sk_security + smack_blob_sizes.lbs_sock;
> +}
> +
>  /*
>   * Is the directory transmuting?
>   */
> diff --git a/security/smack/smack_lsm.c b/security/smack/smack_lsm.c
> index 6e270cf3fd30..ab026ff79504 100644
> --- a/security/smack/smack_lsm.c
> +++ b/security/smack/smack_lsm.c
> @@ -1502,7 +1502,7 @@ static int smack_inode_getsecurity(struct mnt_idmap *idmap,
>  		if (sock == NULL || sock->sk == NULL)
>  			return -EOPNOTSUPP;
>  
> -		ssp = sock->sk->sk_security;
> +		ssp = smack_sock(sock->sk);
>  
>  		if (strcmp(name, XATTR_SMACK_IPIN) == 0)
>  			isp = ssp->smk_in;
> @@ -1890,7 +1890,7 @@ static int smack_file_receive(struct file *file)
>  
>  	if (inode->i_sb->s_magic == SOCKFS_MAGIC) {
>  		sock = SOCKET_I(inode);
> -		ssp = sock->sk->sk_security;
> +		ssp = smack_sock(sock->sk);
>  		tsp = smack_cred(current_cred());
>  		/*
>  		 * If the receiving process can't write to the
> @@ -2310,11 +2310,7 @@ static void smack_task_to_inode(struct task_struct *p, struct inode *inode)
>  static int smack_sk_alloc_security(struct sock *sk, int family, gfp_t gfp_flags)
>  {
>  	struct smack_known *skp = smk_of_current();
> -	struct socket_smack *ssp;
> -
> -	ssp = kzalloc(sizeof(struct socket_smack), gfp_flags);
> -	if (ssp == NULL)
> -		return -ENOMEM;
> +	struct socket_smack *ssp = smack_sock(sk);
>  
>  	/*
>  	 * Sockets created by kernel threads receive web label.
> @@ -2328,8 +2324,6 @@ static int smack_sk_alloc_security(struct sock *sk, int family, gfp_t gfp_flags)
>  	}
>  	ssp->smk_packet = NULL;
>  
> -	sk->sk_security = ssp;
> -
>  	return 0;
>  }
>  
> @@ -2355,7 +2349,6 @@ static void smack_sk_free_security(struct sock *sk)
>  		rcu_read_unlock();
>  	}
>  #endif
> -	kfree(sk->sk_security);
>  }
>  
>  /**
> @@ -2367,8 +2360,8 @@ static void smack_sk_free_security(struct sock *sk)
>   */
>  static void smack_sk_clone_security(const struct sock *sk, struct sock *newsk)
>  {
> -	struct socket_smack *ssp_old = sk->sk_security;
> -	struct socket_smack *ssp_new = newsk->sk_security;
> +	struct socket_smack *ssp_old = smack_sock(sk);
> +	struct socket_smack *ssp_new = smack_sock(newsk);
>  
>  	*ssp_new = *ssp_old;
>  }
> @@ -2484,7 +2477,7 @@ static struct smack_known *smack_ipv6host_label(struct sockaddr_in6 *sip)
>   */
>  static int smack_netlbl_add(struct sock *sk)
>  {
> -	struct socket_smack *ssp = sk->sk_security;
> +	struct socket_smack *ssp = smack_sock(sk);
>  	struct smack_known *skp = ssp->smk_out;
>  	int rc;
>  
> @@ -2516,7 +2509,7 @@ static int smack_netlbl_add(struct sock *sk)
>   */
>  static void smack_netlbl_delete(struct sock *sk)
>  {
> -	struct socket_smack *ssp = sk->sk_security;
> +	struct socket_smack *ssp = smack_sock(sk);
>  
>  	/*
>  	 * Take the label off the socket if one is set.
> @@ -2548,7 +2541,7 @@ static int smk_ipv4_check(struct sock *sk, struct sockaddr_in *sap)
>  	struct smack_known *skp;
>  	int rc = 0;
>  	struct smack_known *hkp;
> -	struct socket_smack *ssp = sk->sk_security;
> +	struct socket_smack *ssp = smack_sock(sk);
>  	struct smk_audit_info ad;
>  
>  	rcu_read_lock();
> @@ -2621,7 +2614,7 @@ static void smk_ipv6_port_label(struct socket *sock, struct sockaddr *address)
>  {
>  	struct sock *sk = sock->sk;
>  	struct sockaddr_in6 *addr6;
> -	struct socket_smack *ssp = sock->sk->sk_security;
> +	struct socket_smack *ssp = smack_sock(sock->sk);
>  	struct smk_port_label *spp;
>  	unsigned short port = 0;
>  
> @@ -2709,7 +2702,7 @@ static int smk_ipv6_port_check(struct sock *sk, struct sockaddr_in6 *address,
>  				int act)
>  {
>  	struct smk_port_label *spp;
> -	struct socket_smack *ssp = sk->sk_security;
> +	struct socket_smack *ssp = smack_sock(sk);
>  	struct smack_known *skp = NULL;
>  	unsigned short port;
>  	struct smack_known *object;
> @@ -2803,7 +2796,7 @@ static int smack_inode_setsecurity(struct inode *inode, const char *name,
>  	if (sock == NULL || sock->sk == NULL)
>  		return -EOPNOTSUPP;
>  
> -	ssp = sock->sk->sk_security;
> +	ssp = smack_sock(sock->sk);
>  
>  	if (strcmp(name, XATTR_SMACK_IPIN) == 0)
>  		ssp->smk_in = skp;
> @@ -2851,7 +2844,7 @@ static int smack_socket_post_create(struct socket *sock, int family,
>  	 * Sockets created by kernel threads receive web label.
>  	 */
>  	if (unlikely(current->flags & PF_KTHREAD)) {
> -		ssp = sock->sk->sk_security;
> +		ssp = smack_sock(sock->sk);
>  		ssp->smk_in = &smack_known_web;
>  		ssp->smk_out = &smack_known_web;
>  	}
> @@ -2876,8 +2869,8 @@ static int smack_socket_post_create(struct socket *sock, int family,
>  static int smack_socket_socketpair(struct socket *socka,
>  		                   struct socket *sockb)
>  {
> -	struct socket_smack *asp = socka->sk->sk_security;
> -	struct socket_smack *bsp = sockb->sk->sk_security;
> +	struct socket_smack *asp = smack_sock(socka->sk);
> +	struct socket_smack *bsp = smack_sock(sockb->sk);
>  
>  	asp->smk_packet = bsp->smk_out;
>  	bsp->smk_packet = asp->smk_out;
> @@ -2940,7 +2933,7 @@ static int smack_socket_connect(struct socket *sock, struct sockaddr *sap,
>  		if (__is_defined(SMACK_IPV6_SECMARK_LABELING))
>  			rsp = smack_ipv6host_label(sip);
>  		if (rsp != NULL) {
> -			struct socket_smack *ssp = sock->sk->sk_security;
> +			struct socket_smack *ssp = smack_sock(sock->sk);
>  
>  			rc = smk_ipv6_check(ssp->smk_out, rsp, sip,
>  					    SMK_CONNECTING);
> @@ -3671,9 +3664,9 @@ static int smack_unix_stream_connect(struct sock *sock,
>  {
>  	struct smack_known *skp;
>  	struct smack_known *okp;
> -	struct socket_smack *ssp = sock->sk_security;
> -	struct socket_smack *osp = other->sk_security;
> -	struct socket_smack *nsp = newsk->sk_security;
> +	struct socket_smack *ssp = smack_sock(sock);
> +	struct socket_smack *osp = smack_sock(other);
> +	struct socket_smack *nsp = smack_sock(newsk);
>  	struct smk_audit_info ad;
>  	int rc = 0;
>  #ifdef CONFIG_AUDIT
> @@ -3719,8 +3712,8 @@ static int smack_unix_stream_connect(struct sock *sock,
>   */
>  static int smack_unix_may_send(struct socket *sock, struct socket *other)
>  {
> -	struct socket_smack *ssp = sock->sk->sk_security;
> -	struct socket_smack *osp = other->sk->sk_security;
> +	struct socket_smack *ssp = smack_sock(sock->sk);
> +	struct socket_smack *osp = smack_sock(other->sk);
>  	struct smk_audit_info ad;
>  	int rc;
>  
> @@ -3757,7 +3750,7 @@ static int smack_socket_sendmsg(struct socket *sock, struct msghdr *msg,
>  	struct sockaddr_in6 *sap = (struct sockaddr_in6 *) msg->msg_name;
>  #endif
>  #ifdef SMACK_IPV6_SECMARK_LABELING
> -	struct socket_smack *ssp = sock->sk->sk_security;
> +	struct socket_smack *ssp = smack_sock(sock->sk);
>  	struct smack_known *rsp;
>  #endif
>  	int rc = 0;
> @@ -3969,7 +3962,7 @@ static struct smack_known *smack_from_netlbl(const struct sock *sk, u16 family,
>  	netlbl_secattr_init(&secattr);
>  
>  	if (sk)
> -		ssp = sk->sk_security;
> +		ssp = smack_sock(sk);
>  
>  	if (netlbl_skbuff_getattr(skb, family, &secattr) == 0) {
>  		skp = smack_from_secattr(&secattr, ssp);
> @@ -3991,7 +3984,7 @@ static struct smack_known *smack_from_netlbl(const struct sock *sk, u16 family,
>   */
>  static int smack_socket_sock_rcv_skb(struct sock *sk, struct sk_buff *skb)
>  {
> -	struct socket_smack *ssp = sk->sk_security;
> +	struct socket_smack *ssp = smack_sock(sk);
>  	struct smack_known *skp = NULL;
>  	int rc = 0;
>  	struct smk_audit_info ad;
> @@ -4090,12 +4083,11 @@ static int smack_socket_getpeersec_stream(struct socket *sock,
>  					  sockptr_t optval, sockptr_t optlen,
>  					  unsigned int len)
>  {
> -	struct socket_smack *ssp;
> +	struct socket_smack *ssp = smack_sock(sock->sk);
>  	char *rcp = "";
>  	u32 slen = 1;
>  	int rc = 0;
>  
> -	ssp = sock->sk->sk_security;
>  	if (ssp->smk_packet != NULL) {
>  		rcp = ssp->smk_packet->smk_known;
>  		slen = strlen(rcp) + 1;
> @@ -4145,7 +4137,7 @@ static int smack_socket_getpeersec_dgram(struct socket *sock,
>  
>  	switch (family) {
>  	case PF_UNIX:
> -		ssp = sock->sk->sk_security;
> +		ssp = smack_sock(sock->sk);
>  		s = ssp->smk_out->smk_secid;
>  		break;
>  	case PF_INET:
> @@ -4194,7 +4186,7 @@ static void smack_sock_graft(struct sock *sk, struct socket *parent)
>  	    (sk->sk_family != PF_INET && sk->sk_family != PF_INET6))
>  		return;
>  
> -	ssp = sk->sk_security;
> +	ssp = smack_sock(sk);
>  	ssp->smk_in = skp;
>  	ssp->smk_out = skp;
>  	/* cssp->smk_packet is already set in smack_inet_csk_clone() */
> @@ -4214,7 +4206,7 @@ static int smack_inet_conn_request(const struct sock *sk, struct sk_buff *skb,
>  {
>  	u16 family = sk->sk_family;
>  	struct smack_known *skp;
> -	struct socket_smack *ssp = sk->sk_security;
> +	struct socket_smack *ssp = smack_sock(sk);
>  	struct sockaddr_in addr;
>  	struct iphdr *hdr;
>  	struct smack_known *hskp;
> @@ -4300,7 +4292,7 @@ static int smack_inet_conn_request(const struct sock *sk, struct sk_buff *skb,
>  static void smack_inet_csk_clone(struct sock *sk,
>  				 const struct request_sock *req)
>  {
> -	struct socket_smack *ssp = sk->sk_security;
> +	struct socket_smack *ssp = smack_sock(sk);
>  	struct smack_known *skp;
>  
>  	if (req->peer_secid != 0) {
> @@ -4868,6 +4860,7 @@ struct lsm_blob_sizes smack_blob_sizes __ro_after_init = {
>  	.lbs_inode = sizeof(struct inode_smack),
>  	.lbs_ipc = sizeof(struct smack_known *),
>  	.lbs_msg_msg = sizeof(struct smack_known *),
> +	.lbs_sock = sizeof(struct socket_smack),
>  	.lbs_superblock = sizeof(struct superblock_smack),
>  };
>  
> diff --git a/security/smack/smack_netfilter.c b/security/smack/smack_netfilter.c
> index b945c1d3a743..bad71b7e648d 100644
> --- a/security/smack/smack_netfilter.c
> +++ b/security/smack/smack_netfilter.c
> @@ -26,8 +26,8 @@ static unsigned int smack_ip_output(void *priv,
>  	struct socket_smack *ssp;
>  	struct smack_known *skp;
>  
> -	if (sk && sk->sk_security) {
> -		ssp = sk->sk_security;
> +	if (sk) {
> +		ssp = smack_sock(sk);
>  		skp = ssp->smk_out;
>  		skb->secmark = skp->smk_secid;
>  	}
Casey Schaufler May 31, 2023, 9:36 p.m. UTC | #2
On 5/31/2023 2:10 PM, Paul Moore wrote:
> On Wed, May 31, 2023 at 10:00 AM Casey Schaufler <casey@schaufler-ca.com> wrote:
>> On 5/31/2023 4:05 AM, GONG, Ruiqi wrote:
>>> As the security infrastructure has taken over the management of multiple
>>> *_security blobs that are accessed by multiple security modules, and
>>> sk->sk_security shares the same situation, move its management out of
>>> individual security modules and into the security infrastructure as
>>> well. The infrastructure does the memory allocation, and each relavant
>>> module uses its own share.
>> Do you have a reason to make this change? The LSM infrastructure
>> manages other security blobs to enable multiple concurrently active
>> LSMs to use the blob. If only one LSM on a system can use the
>> socket blob there's no reason to move the management.
> I think an argument could be made for consistent handling of security
> blobs, but with the LSM stacking work in development the argument for
> merging this patch needs to be a lot stronger than just "consistency".

I'm betting that someone has an out-of-tree LSM that uses a socket blob,
and that the intended use case includes stacking with one of the "major"
LSMs. I would encourage that someone to propose that LSM for upstream.
GONG, Ruiqi June 1, 2023, 3:03 a.m. UTC | #3
On 2023/05/31 22:00, Casey Schaufler wrote:
> On 5/31/2023 4:05 AM, GONG, Ruiqi wrote:
>> As the security infrastructure has taken over the management of multiple
>> *_security blobs that are accessed by multiple security modules, and
>> sk->sk_security shares the same situation, move its management out of
>> individual security modules and into the security infrastructure as
>> well. The infrastructure does the memory allocation, and each relavant
>> module uses its own share.
> 
> Do you have a reason to make this change? The LSM infrastructure
> manages other security blobs to enable multiple concurrently active
> LSMs to use the blob. If only one LSM on a system can use the
> socket blob there's no reason to move the management.

I proposed this patch because I was dealing with a kmemleak problem on
5.10 caused by disabling SELinux at runtime these days, which involed
the key and sk blobs, and got surprised that they were not managed by
the security infrastructure even in linux-next. I thought maybe they
were just left out temporarily, so let's unify the whole thing...

Since it seems there's no urgent demand for this, and the LSM stacking
is already in progress, I'm ok to just leave it for now.

> 
>>
>> Signed-off-by: GONG, Ruiqi <gongruiqi@huaweicloud.com>
>> ---
>>  include/linux/lsm_hooks.h         |  1 +
>>  security/apparmor/include/net.h   |  2 +-
>>  security/apparmor/lsm.c           | 20 +-------
>>  security/security.c               | 35 ++++++++++++-
>>  security/selinux/hooks.c          | 81 ++++++++++++++-----------------
>>  security/selinux/include/objsec.h |  4 ++
>>  security/selinux/netlabel.c       | 22 ++++-----
>>  security/smack/smack.h            |  5 ++
>>  security/smack/smack_lsm.c        | 65 +++++++++++--------------
>>  security/smack/smack_netfilter.c  |  4 +-
>>  10 files changed, 125 insertions(+), 114 deletions(-)
>>
>> diff --git a/include/linux/lsm_hooks.h b/include/linux/lsm_hooks.h
>> index ab2b2fafa4a4..67b6e87ca6ec 100644
>> --- a/include/linux/lsm_hooks.h
>> +++ b/include/linux/lsm_hooks.h
>> @@ -62,6 +62,7 @@ struct lsm_blob_sizes {
>>  	int	lbs_superblock;
>>  	int	lbs_ipc;
>>  	int	lbs_msg_msg;
>> +	int	lbs_sock;
>>  	int	lbs_task;
>>  };
>>  
>> diff --git a/security/apparmor/include/net.h b/security/apparmor/include/net.h
>> index 6fa440b5daed..9eb159c09578 100644
>> --- a/security/apparmor/include/net.h
>> +++ b/security/apparmor/include/net.h
>> @@ -51,7 +51,7 @@ struct aa_sk_ctx {
>>  	struct aa_label *peer;
>>  };
>>  
>> -#define SK_CTX(X) ((X)->sk_security)
>> +#define SK_CTX(X) ((X)->sk_security + apparmor_blob_sizes.lbs_sock)
>>  #define SOCK_ctx(X) SOCK_INODE(X)->i_security
>>  #define DEFINE_AUDIT_NET(NAME, OP, SK, F, T, P)				  \
>>  	struct lsm_network_audit NAME ## _net = { .sk = (SK),		  \
>> diff --git a/security/apparmor/lsm.c b/security/apparmor/lsm.c
>> index f431251ffb91..3dd849a6d7a1 100644
>> --- a/security/apparmor/lsm.c
>> +++ b/security/apparmor/lsm.c
>> @@ -818,22 +818,6 @@ static int apparmor_task_kill(struct task_struct *target, struct kernel_siginfo
>>  	return error;
>>  }
>>  
>> -/**
>> - * apparmor_sk_alloc_security - allocate and attach the sk_security field
>> - */
>> -static int apparmor_sk_alloc_security(struct sock *sk, int family, gfp_t flags)
>> -{
>> -	struct aa_sk_ctx *ctx;
>> -
>> -	ctx = kzalloc(sizeof(*ctx), flags);
>> -	if (!ctx)
>> -		return -ENOMEM;
>> -
>> -	SK_CTX(sk) = ctx;
>> -
>> -	return 0;
>> -}
>> -
>>  /**
>>   * apparmor_sk_free_security - free the sk_security field
>>   */
>> @@ -841,10 +825,8 @@ static void apparmor_sk_free_security(struct sock *sk)
>>  {
>>  	struct aa_sk_ctx *ctx = SK_CTX(sk);
>>  
>> -	SK_CTX(sk) = NULL;
>>  	aa_put_label(ctx->label);
>>  	aa_put_label(ctx->peer);
>> -	kfree(ctx);
>>  }
>>  
>>  /**
>> @@ -1212,6 +1194,7 @@ static int apparmor_inet_conn_request(const struct sock *sk, struct sk_buff *skb
>>  struct lsm_blob_sizes apparmor_blob_sizes __ro_after_init = {
>>  	.lbs_cred = sizeof(struct aa_label *),
>>  	.lbs_file = sizeof(struct aa_file_ctx),
>> +	.lbs_sock = sizeof(struct aa_sk_ctx),
>>  	.lbs_task = sizeof(struct aa_task_ctx),
>>  };
>>  
>> @@ -1250,7 +1233,6 @@ static struct security_hook_list apparmor_hooks[] __ro_after_init = {
>>  	LSM_HOOK_INIT(getprocattr, apparmor_getprocattr),
>>  	LSM_HOOK_INIT(setprocattr, apparmor_setprocattr),
>>  
>> -	LSM_HOOK_INIT(sk_alloc_security, apparmor_sk_alloc_security),
>>  	LSM_HOOK_INIT(sk_free_security, apparmor_sk_free_security),
>>  	LSM_HOOK_INIT(sk_clone_security, apparmor_sk_clone_security),
>>  
>> diff --git a/security/security.c b/security/security.c
>> index b720424ca37d..e71f4717cde5 100644
>> --- a/security/security.c
>> +++ b/security/security.c
>> @@ -30,6 +30,7 @@
>>  #include <linux/string.h>
>>  #include <linux/msg.h>
>>  #include <net/flow.h>
>> +#include <net/sock.h>
>>  
>>  #define MAX_LSM_EVM_XATTR	2
>>  
>> @@ -210,6 +211,7 @@ static void __init lsm_set_blob_sizes(struct lsm_blob_sizes *needed)
>>  	lsm_set_blob_size(&needed->lbs_inode, &blob_sizes.lbs_inode);
>>  	lsm_set_blob_size(&needed->lbs_ipc, &blob_sizes.lbs_ipc);
>>  	lsm_set_blob_size(&needed->lbs_msg_msg, &blob_sizes.lbs_msg_msg);
>> +	lsm_set_blob_size(&needed->lbs_sock, &blob_sizes.lbs_sock);
>>  	lsm_set_blob_size(&needed->lbs_superblock, &blob_sizes.lbs_superblock);
>>  	lsm_set_blob_size(&needed->lbs_task, &blob_sizes.lbs_task);
>>  }
>> @@ -376,6 +378,7 @@ static void __init ordered_lsm_init(void)
>>  	init_debug("inode blob size      = %d\n", blob_sizes.lbs_inode);
>>  	init_debug("ipc blob size        = %d\n", blob_sizes.lbs_ipc);
>>  	init_debug("msg_msg blob size    = %d\n", blob_sizes.lbs_msg_msg);
>> +	init_debug("sock blob size       = %d\n", blob_sizes.lbs_sock);
>>  	init_debug("superblock blob size = %d\n", blob_sizes.lbs_superblock);
>>  	init_debug("task blob size       = %d\n", blob_sizes.lbs_task);
>>  
>> @@ -733,6 +736,27 @@ static int lsm_superblock_alloc(struct super_block *sb)
>>  	return 0;
>>  }
>>  
>> +/**
>> + * lsm_sock_alloc - allocate a composite socket blob
>> + * @sk: the socket that needs a blob
>> + *
>> + * Allocate the socket blob for all the modules
>> + *
>> + * Returns 0, or -ENOMEM if memory can't be allocated.
>> + */
>> +static int lsm_sock_alloc(struct sock *sk)
>> +{
>> +	if (blob_sizes.lbs_sock == 0) {
>> +		sk->sk_security = NULL;
>> +		return 0;
>> +	}
>> +
>> +	sk->sk_security = kzalloc(blob_sizes.lbs_sock, GFP_KERNEL);
>> +	if (sk->sk_security == NULL)
>> +		return -ENOMEM;
>> +	return 0;
>> +}
>> +
>>  /*
>>   * The default value of the LSM hook is defined in linux/lsm_hook_defs.h and
>>   * can be accessed with:
>> @@ -4369,7 +4393,14 @@ EXPORT_SYMBOL(security_socket_getpeersec_dgram);
>>   */
>>  int security_sk_alloc(struct sock *sk, int family, gfp_t priority)
>>  {
>> -	return call_int_hook(sk_alloc_security, 0, sk, family, priority);
>> +	int rc = lsm_sock_alloc(sk);
>> +
>> +	if (unlikely(rc))
>> +		return rc;
>> +	rc = call_int_hook(sk_alloc_security, 0, sk, family, priority);
>> +	if (unlikely(rc))
>> +		security_sk_free(sk);
>> +	return rc;
>>  }
>>  
>>  /**
>> @@ -4381,6 +4412,8 @@ int security_sk_alloc(struct sock *sk, int family, gfp_t priority)
>>  void security_sk_free(struct sock *sk)
>>  {
>>  	call_void_hook(sk_free_security, sk);
>> +	kfree(sk->sk_security);
>> +	sk->sk_security = NULL;
>>  }
>>  
>>  /**
>> diff --git a/security/selinux/hooks.c b/security/selinux/hooks.c
>> index d06e350fedee..f8397f05dc90 100644
>> --- a/security/selinux/hooks.c
>> +++ b/security/selinux/hooks.c
>> @@ -4497,7 +4497,7 @@ static int socket_sockcreate_sid(const struct task_security_struct *tsec,
>>  
>>  static int sock_has_perm(struct sock *sk, u32 perms)
>>  {
>> -	struct sk_security_struct *sksec = sk->sk_security;
>> +	struct sk_security_struct *sksec = selinux_sock(sk);
>>  	struct common_audit_data ad;
>>  	struct lsm_network_audit net = {0,};
>>  
>> @@ -4552,7 +4552,7 @@ static int selinux_socket_post_create(struct socket *sock, int family,
>>  	isec->initialized = LABEL_INITIALIZED;
>>  
>>  	if (sock->sk) {
>> -		sksec = sock->sk->sk_security;
>> +		sksec = selinux_sock(sock->sk);
>>  		sksec->sclass = sclass;
>>  		sksec->sid = sid;
>>  		/* Allows detection of the first association on this socket */
>> @@ -4568,8 +4568,8 @@ static int selinux_socket_post_create(struct socket *sock, int family,
>>  static int selinux_socket_socketpair(struct socket *socka,
>>  				     struct socket *sockb)
>>  {
>> -	struct sk_security_struct *sksec_a = socka->sk->sk_security;
>> -	struct sk_security_struct *sksec_b = sockb->sk->sk_security;
>> +	struct sk_security_struct *sksec_a = selinux_sock(socka->sk);
>> +	struct sk_security_struct *sksec_b = selinux_sock(sockb->sk);
>>  
>>  	sksec_a->peer_sid = sksec_b->sid;
>>  	sksec_b->peer_sid = sksec_a->sid;
>> @@ -4584,7 +4584,7 @@ static int selinux_socket_socketpair(struct socket *socka,
>>  static int selinux_socket_bind(struct socket *sock, struct sockaddr *address, int addrlen)
>>  {
>>  	struct sock *sk = sock->sk;
>> -	struct sk_security_struct *sksec = sk->sk_security;
>> +	struct sk_security_struct *sksec = selinux_sock(sk);
>>  	u16 family;
>>  	int err;
>>  
>> @@ -4717,7 +4717,7 @@ static int selinux_socket_connect_helper(struct socket *sock,
>>  					 struct sockaddr *address, int addrlen)
>>  {
>>  	struct sock *sk = sock->sk;
>> -	struct sk_security_struct *sksec = sk->sk_security;
>> +	struct sk_security_struct *sksec = selinux_sock(sk);
>>  	int err;
>>  
>>  	err = sock_has_perm(sk, SOCKET__CONNECT);
>> @@ -4895,9 +4895,9 @@ static int selinux_socket_unix_stream_connect(struct sock *sock,
>>  					      struct sock *other,
>>  					      struct sock *newsk)
>>  {
>> -	struct sk_security_struct *sksec_sock = sock->sk_security;
>> -	struct sk_security_struct *sksec_other = other->sk_security;
>> -	struct sk_security_struct *sksec_new = newsk->sk_security;
>> +	struct sk_security_struct *sksec_sock = selinux_sock(sock);
>> +	struct sk_security_struct *sksec_other = selinux_sock(other);
>> +	struct sk_security_struct *sksec_new = selinux_sock(newsk);
>>  	struct common_audit_data ad;
>>  	struct lsm_network_audit net = {0,};
>>  	int err;
>> @@ -4928,8 +4928,8 @@ static int selinux_socket_unix_stream_connect(struct sock *sock,
>>  static int selinux_socket_unix_may_send(struct socket *sock,
>>  					struct socket *other)
>>  {
>> -	struct sk_security_struct *ssec = sock->sk->sk_security;
>> -	struct sk_security_struct *osec = other->sk->sk_security;
>> +	struct sk_security_struct *ssec = selinux_sock(sock->sk);
>> +	struct sk_security_struct *osec = selinux_sock(other->sk);
>>  	struct common_audit_data ad;
>>  	struct lsm_network_audit net = {0,};
>>  
>> @@ -4968,7 +4968,7 @@ static int selinux_sock_rcv_skb_compat(struct sock *sk, struct sk_buff *skb,
>>  				       u16 family)
>>  {
>>  	int err = 0;
>> -	struct sk_security_struct *sksec = sk->sk_security;
>> +	struct sk_security_struct *sksec = selinux_sock(sk);
>>  	u32 sk_sid = sksec->sid;
>>  	struct common_audit_data ad;
>>  	struct lsm_network_audit net = {0,};
>> @@ -5000,7 +5000,7 @@ static int selinux_sock_rcv_skb_compat(struct sock *sk, struct sk_buff *skb,
>>  static int selinux_socket_sock_rcv_skb(struct sock *sk, struct sk_buff *skb)
>>  {
>>  	int err;
>> -	struct sk_security_struct *sksec = sk->sk_security;
>> +	struct sk_security_struct *sksec = selinux_sock(sk);
>>  	u16 family = sk->sk_family;
>>  	u32 sk_sid = sksec->sid;
>>  	struct common_audit_data ad;
>> @@ -5073,7 +5073,7 @@ static int selinux_socket_getpeersec_stream(struct socket *sock,
>>  	int err = 0;
>>  	char *scontext = NULL;
>>  	u32 scontext_len;
>> -	struct sk_security_struct *sksec = sock->sk->sk_security;
>> +	struct sk_security_struct *sksec = selinux_sock(sock->sk);
>>  	u32 peer_sid = SECSID_NULL;
>>  
>>  	if (sksec->sclass == SECCLASS_UNIX_STREAM_SOCKET ||
>> @@ -5131,34 +5131,27 @@ static int selinux_socket_getpeersec_dgram(struct socket *sock, struct sk_buff *
>>  
>>  static int selinux_sk_alloc_security(struct sock *sk, int family, gfp_t priority)
>>  {
>> -	struct sk_security_struct *sksec;
>> -
>> -	sksec = kzalloc(sizeof(*sksec), priority);
>> -	if (!sksec)
>> -		return -ENOMEM;
>> +	struct sk_security_struct *sksec = selinux_sock(sk);
>>  
>>  	sksec->peer_sid = SECINITSID_UNLABELED;
>>  	sksec->sid = SECINITSID_UNLABELED;
>>  	sksec->sclass = SECCLASS_SOCKET;
>>  	selinux_netlbl_sk_security_reset(sksec);
>> -	sk->sk_security = sksec;
>>  
>>  	return 0;
>>  }
>>  
>>  static void selinux_sk_free_security(struct sock *sk)
>>  {
>> -	struct sk_security_struct *sksec = sk->sk_security;
>> +	struct sk_security_struct *sksec = selinux_sock(sk);
>>  
>> -	sk->sk_security = NULL;
>>  	selinux_netlbl_sk_security_free(sksec);
>> -	kfree(sksec);
>>  }
>>  
>>  static void selinux_sk_clone_security(const struct sock *sk, struct sock *newsk)
>>  {
>> -	struct sk_security_struct *sksec = sk->sk_security;
>> -	struct sk_security_struct *newsksec = newsk->sk_security;
>> +	struct sk_security_struct *sksec = selinux_sock(sk);
>> +	struct sk_security_struct *newsksec = selinux_sock(newsk);
>>  
>>  	newsksec->sid = sksec->sid;
>>  	newsksec->peer_sid = sksec->peer_sid;
>> @@ -5172,7 +5165,7 @@ static void selinux_sk_getsecid(struct sock *sk, u32 *secid)
>>  	if (!sk)
>>  		*secid = SECINITSID_ANY_SOCKET;
>>  	else {
>> -		struct sk_security_struct *sksec = sk->sk_security;
>> +		struct sk_security_struct *sksec = selinux_sock(sk);
>>  
>>  		*secid = sksec->sid;
>>  	}
>> @@ -5182,7 +5175,7 @@ static void selinux_sock_graft(struct sock *sk, struct socket *parent)
>>  {
>>  	struct inode_security_struct *isec =
>>  		inode_security_novalidate(SOCK_INODE(parent));
>> -	struct sk_security_struct *sksec = sk->sk_security;
>> +	struct sk_security_struct *sksec = selinux_sock(sk);
>>  
>>  	if (sk->sk_family == PF_INET || sk->sk_family == PF_INET6 ||
>>  	    sk->sk_family == PF_UNIX)
>> @@ -5199,7 +5192,7 @@ static int selinux_sctp_process_new_assoc(struct sctp_association *asoc,
>>  {
>>  	struct sock *sk = asoc->base.sk;
>>  	u16 family = sk->sk_family;
>> -	struct sk_security_struct *sksec = sk->sk_security;
>> +	struct sk_security_struct *sksec = selinux_sock(sk);
>>  	struct common_audit_data ad;
>>  	struct lsm_network_audit net = {0,};
>>  	int err;
>> @@ -5256,7 +5249,7 @@ static int selinux_sctp_process_new_assoc(struct sctp_association *asoc,
>>  static int selinux_sctp_assoc_request(struct sctp_association *asoc,
>>  				      struct sk_buff *skb)
>>  {
>> -	struct sk_security_struct *sksec = asoc->base.sk->sk_security;
>> +	struct sk_security_struct *sksec = selinux_sock(asoc->base.sk);
>>  	u32 conn_sid;
>>  	int err;
>>  
>> @@ -5289,7 +5282,7 @@ static int selinux_sctp_assoc_request(struct sctp_association *asoc,
>>  static int selinux_sctp_assoc_established(struct sctp_association *asoc,
>>  					  struct sk_buff *skb)
>>  {
>> -	struct sk_security_struct *sksec = asoc->base.sk->sk_security;
>> +	struct sk_security_struct *sksec = selinux_sock(asoc->base.sk);
>>  
>>  	if (!selinux_policycap_extsockclass())
>>  		return 0;
>> @@ -5388,8 +5381,8 @@ static int selinux_sctp_bind_connect(struct sock *sk, int optname,
>>  static void selinux_sctp_sk_clone(struct sctp_association *asoc, struct sock *sk,
>>  				  struct sock *newsk)
>>  {
>> -	struct sk_security_struct *sksec = sk->sk_security;
>> -	struct sk_security_struct *newsksec = newsk->sk_security;
>> +	struct sk_security_struct *sksec = selinux_sock(sk);
>> +	struct sk_security_struct *newsksec = selinux_sock(newsk);
>>  
>>  	/* If policy does not support SECCLASS_SCTP_SOCKET then call
>>  	 * the non-sctp clone version.
>> @@ -5405,8 +5398,8 @@ static void selinux_sctp_sk_clone(struct sctp_association *asoc, struct sock *sk
>>  
>>  static int selinux_mptcp_add_subflow(struct sock *sk, struct sock *ssk)
>>  {
>> -	struct sk_security_struct *ssksec = ssk->sk_security;
>> -	struct sk_security_struct *sksec = sk->sk_security;
>> +	struct sk_security_struct *ssksec = selinux_sock(ssk);
>> +	struct sk_security_struct *sksec = selinux_sock(sk);
>>  
>>  	ssksec->sclass = sksec->sclass;
>>  	ssksec->sid = sksec->sid;
>> @@ -5421,7 +5414,7 @@ static int selinux_mptcp_add_subflow(struct sock *sk, struct sock *ssk)
>>  static int selinux_inet_conn_request(const struct sock *sk, struct sk_buff *skb,
>>  				     struct request_sock *req)
>>  {
>> -	struct sk_security_struct *sksec = sk->sk_security;
>> +	struct sk_security_struct *sksec = selinux_sock(sk);
>>  	int err;
>>  	u16 family = req->rsk_ops->family;
>>  	u32 connsid;
>> @@ -5442,7 +5435,7 @@ static int selinux_inet_conn_request(const struct sock *sk, struct sk_buff *skb,
>>  static void selinux_inet_csk_clone(struct sock *newsk,
>>  				   const struct request_sock *req)
>>  {
>> -	struct sk_security_struct *newsksec = newsk->sk_security;
>> +	struct sk_security_struct *newsksec = selinux_sock(newsk);
>>  
>>  	newsksec->sid = req->secid;
>>  	newsksec->peer_sid = req->peer_secid;
>> @@ -5459,7 +5452,7 @@ static void selinux_inet_csk_clone(struct sock *newsk,
>>  static void selinux_inet_conn_established(struct sock *sk, struct sk_buff *skb)
>>  {
>>  	u16 family = sk->sk_family;
>> -	struct sk_security_struct *sksec = sk->sk_security;
>> +	struct sk_security_struct *sksec = selinux_sock(sk);
>>  
>>  	/* handle mapped IPv4 packets arriving via IPv6 sockets */
>>  	if (family == PF_INET6 && skb->protocol == htons(ETH_P_IP))
>> @@ -5540,7 +5533,7 @@ static int selinux_tun_dev_attach_queue(void *security)
>>  static int selinux_tun_dev_attach(struct sock *sk, void *security)
>>  {
>>  	struct tun_security_struct *tunsec = security;
>> -	struct sk_security_struct *sksec = sk->sk_security;
>> +	struct sk_security_struct *sksec = selinux_sock(sk);
>>  
>>  	/* we don't currently perform any NetLabel based labeling here and it
>>  	 * isn't clear that we would want to do so anyway; while we could apply
>> @@ -5666,7 +5659,7 @@ static unsigned int selinux_ip_output(void *priv, struct sk_buff *skb,
>>  			return NF_ACCEPT;
>>  
>>  		/* standard practice, label using the parent socket */
>> -		sksec = sk->sk_security;
>> +		sksec = selinux_sock(sk);
>>  		sid = sksec->sid;
>>  	} else
>>  		sid = SECINITSID_KERNEL;
>> @@ -5689,7 +5682,7 @@ static unsigned int selinux_ip_postroute_compat(struct sk_buff *skb,
>>  	sk = skb_to_full_sk(skb);
>>  	if (sk == NULL)
>>  		return NF_ACCEPT;
>> -	sksec = sk->sk_security;
>> +	sksec = selinux_sock(sk);
>>  
>>  	ad.type = LSM_AUDIT_DATA_NET;
>>  	ad.u.net = &net;
>> @@ -5779,9 +5772,8 @@ static unsigned int selinux_ip_postroute(void *priv,
>>  		 * selinux_inet_conn_request().  See also selinux_ip_output()
>>  		 * for similar problems. */
>>  		u32 skb_sid;
>> -		struct sk_security_struct *sksec;
>> +		struct sk_security_struct *sksec = selinux_sock(sk);
>>  
>> -		sksec = sk->sk_security;
>>  		if (selinux_skb_peerlbl_sid(skb, family, &skb_sid))
>>  			return NF_DROP;
>>  		/* At this point, if the returned skb peerlbl is SECSID_NULL
>> @@ -5810,7 +5802,7 @@ static unsigned int selinux_ip_postroute(void *priv,
>>  	} else {
>>  		/* Locally generated packet, fetch the security label from the
>>  		 * associated socket. */
>> -		struct sk_security_struct *sksec = sk->sk_security;
>> +		struct sk_security_struct *sksec = selinux_sock(sk);
>>  		peer_sid = sksec->sid;
>>  		secmark_perm = PACKET__SEND;
>>  	}
>> @@ -5856,7 +5848,7 @@ static int selinux_netlink_send(struct sock *sk, struct sk_buff *skb)
>>  	unsigned int data_len = skb->len;
>>  	unsigned char *data = skb->data;
>>  	struct nlmsghdr *nlh;
>> -	struct sk_security_struct *sksec = sk->sk_security;
>> +	struct sk_security_struct *sksec = selinux_sock(sk);
>>  	u16 sclass = sksec->sclass;
>>  	u32 perm;
>>  
>> @@ -6814,6 +6806,7 @@ struct lsm_blob_sizes selinux_blob_sizes __ro_after_init = {
>>  	.lbs_inode = sizeof(struct inode_security_struct),
>>  	.lbs_ipc = sizeof(struct ipc_security_struct),
>>  	.lbs_msg_msg = sizeof(struct msg_security_struct),
>> +	.lbs_sock = sizeof(struct sk_security_struct),
>>  	.lbs_superblock = sizeof(struct superblock_security_struct),
>>  };
>>  
>> diff --git a/security/selinux/include/objsec.h b/security/selinux/include/objsec.h
>> index 2953132408bf..49221f441c68 100644
>> --- a/security/selinux/include/objsec.h
>> +++ b/security/selinux/include/objsec.h
>> @@ -194,4 +194,8 @@ static inline struct superblock_security_struct *selinux_superblock(
>>  	return superblock->s_security + selinux_blob_sizes.lbs_superblock;
>>  }
>>  
>> +static inline struct sk_security_struct *selinux_sock(const struct sock *sk)
>> +{
>> +	return sk->sk_security + selinux_blob_sizes.lbs_sock;
>> +}
>>  #endif /* _SELINUX_OBJSEC_H_ */
>> diff --git a/security/selinux/netlabel.c b/security/selinux/netlabel.c
>> index 528f5186e912..9755561aa466 100644
>> --- a/security/selinux/netlabel.c
>> +++ b/security/selinux/netlabel.c
>> @@ -68,7 +68,7 @@ static int selinux_netlbl_sidlookup_cached(struct sk_buff *skb,
>>  static struct netlbl_lsm_secattr *selinux_netlbl_sock_genattr(struct sock *sk)
>>  {
>>  	int rc;
>> -	struct sk_security_struct *sksec = sk->sk_security;
>> +	struct sk_security_struct *sksec = selinux_sock(sk);
>>  	struct netlbl_lsm_secattr *secattr;
>>  
>>  	if (sksec->nlbl_secattr != NULL)
>> @@ -100,7 +100,7 @@ static struct netlbl_lsm_secattr *selinux_netlbl_sock_getattr(
>>  							const struct sock *sk,
>>  							u32 sid)
>>  {
>> -	struct sk_security_struct *sksec = sk->sk_security;
>> +	struct sk_security_struct *sksec = selinux_sock(sk);
>>  	struct netlbl_lsm_secattr *secattr = sksec->nlbl_secattr;
>>  
>>  	if (secattr == NULL)
>> @@ -239,7 +239,7 @@ int selinux_netlbl_skbuff_setsid(struct sk_buff *skb,
>>  	 * being labeled by it's parent socket, if it is just exit */
>>  	sk = skb_to_full_sk(skb);
>>  	if (sk != NULL) {
>> -		struct sk_security_struct *sksec = sk->sk_security;
>> +		struct sk_security_struct *sksec = selinux_sock(sk);
>>  
>>  		if (sksec->nlbl_state != NLBL_REQSKB)
>>  			return 0;
>> @@ -276,7 +276,7 @@ int selinux_netlbl_sctp_assoc_request(struct sctp_association *asoc,
>>  {
>>  	int rc;
>>  	struct netlbl_lsm_secattr secattr;
>> -	struct sk_security_struct *sksec = asoc->base.sk->sk_security;
>> +	struct sk_security_struct *sksec = selinux_sock(asoc->base.sk);
>>  	struct sockaddr_in addr4;
>>  	struct sockaddr_in6 addr6;
>>  
>> @@ -355,7 +355,7 @@ int selinux_netlbl_inet_conn_request(struct request_sock *req, u16 family)
>>   */
>>  void selinux_netlbl_inet_csk_clone(struct sock *sk, u16 family)
>>  {
>> -	struct sk_security_struct *sksec = sk->sk_security;
>> +	struct sk_security_struct *sksec = selinux_sock(sk);
>>  
>>  	if (family == PF_INET)
>>  		sksec->nlbl_state = NLBL_LABELED;
>> @@ -373,8 +373,8 @@ void selinux_netlbl_inet_csk_clone(struct sock *sk, u16 family)
>>   */
>>  void selinux_netlbl_sctp_sk_clone(struct sock *sk, struct sock *newsk)
>>  {
>> -	struct sk_security_struct *sksec = sk->sk_security;
>> -	struct sk_security_struct *newsksec = newsk->sk_security;
>> +	struct sk_security_struct *sksec = selinux_sock(sk);
>> +	struct sk_security_struct *newsksec = selinux_sock(newsk);
>>  
>>  	newsksec->nlbl_state = sksec->nlbl_state;
>>  }
>> @@ -392,7 +392,7 @@ void selinux_netlbl_sctp_sk_clone(struct sock *sk, struct sock *newsk)
>>  int selinux_netlbl_socket_post_create(struct sock *sk, u16 family)
>>  {
>>  	int rc;
>> -	struct sk_security_struct *sksec = sk->sk_security;
>> +	struct sk_security_struct *sksec = selinux_sock(sk);
>>  	struct netlbl_lsm_secattr *secattr;
>>  
>>  	if (family != PF_INET && family != PF_INET6)
>> @@ -506,7 +506,7 @@ int selinux_netlbl_socket_setsockopt(struct socket *sock,
>>  {
>>  	int rc = 0;
>>  	struct sock *sk = sock->sk;
>> -	struct sk_security_struct *sksec = sk->sk_security;
>> +	struct sk_security_struct *sksec = selinux_sock(sk);
>>  	struct netlbl_lsm_secattr secattr;
>>  
>>  	if (selinux_netlbl_option(level, optname) &&
>> @@ -544,7 +544,7 @@ static int selinux_netlbl_socket_connect_helper(struct sock *sk,
>>  						struct sockaddr *addr)
>>  {
>>  	int rc;
>> -	struct sk_security_struct *sksec = sk->sk_security;
>> +	struct sk_security_struct *sksec = selinux_sock(sk);
>>  	struct netlbl_lsm_secattr *secattr;
>>  
>>  	/* connected sockets are allowed to disconnect when the address family
>> @@ -583,7 +583,7 @@ static int selinux_netlbl_socket_connect_helper(struct sock *sk,
>>  int selinux_netlbl_socket_connect_locked(struct sock *sk,
>>  					 struct sockaddr *addr)
>>  {
>> -	struct sk_security_struct *sksec = sk->sk_security;
>> +	struct sk_security_struct *sksec = selinux_sock(sk);
>>  
>>  	if (sksec->nlbl_state != NLBL_REQSKB &&
>>  	    sksec->nlbl_state != NLBL_CONNLABELED)
>> diff --git a/security/smack/smack.h b/security/smack/smack.h
>> index aa15ff56ed6e..2d0163076eca 100644
>> --- a/security/smack/smack.h
>> +++ b/security/smack/smack.h
>> @@ -355,6 +355,11 @@ static inline struct superblock_smack *smack_superblock(
>>  	return superblock->s_security + smack_blob_sizes.lbs_superblock;
>>  }
>>  
>> +static inline struct socket_smack *smack_sock(const struct sock *sk)
>> +{
>> +	return sk->sk_security + smack_blob_sizes.lbs_sock;
>> +}
>> +
>>  /*
>>   * Is the directory transmuting?
>>   */
>> diff --git a/security/smack/smack_lsm.c b/security/smack/smack_lsm.c
>> index 6e270cf3fd30..ab026ff79504 100644
>> --- a/security/smack/smack_lsm.c
>> +++ b/security/smack/smack_lsm.c
>> @@ -1502,7 +1502,7 @@ static int smack_inode_getsecurity(struct mnt_idmap *idmap,
>>  		if (sock == NULL || sock->sk == NULL)
>>  			return -EOPNOTSUPP;
>>  
>> -		ssp = sock->sk->sk_security;
>> +		ssp = smack_sock(sock->sk);
>>  
>>  		if (strcmp(name, XATTR_SMACK_IPIN) == 0)
>>  			isp = ssp->smk_in;
>> @@ -1890,7 +1890,7 @@ static int smack_file_receive(struct file *file)
>>  
>>  	if (inode->i_sb->s_magic == SOCKFS_MAGIC) {
>>  		sock = SOCKET_I(inode);
>> -		ssp = sock->sk->sk_security;
>> +		ssp = smack_sock(sock->sk);
>>  		tsp = smack_cred(current_cred());
>>  		/*
>>  		 * If the receiving process can't write to the
>> @@ -2310,11 +2310,7 @@ static void smack_task_to_inode(struct task_struct *p, struct inode *inode)
>>  static int smack_sk_alloc_security(struct sock *sk, int family, gfp_t gfp_flags)
>>  {
>>  	struct smack_known *skp = smk_of_current();
>> -	struct socket_smack *ssp;
>> -
>> -	ssp = kzalloc(sizeof(struct socket_smack), gfp_flags);
>> -	if (ssp == NULL)
>> -		return -ENOMEM;
>> +	struct socket_smack *ssp = smack_sock(sk);
>>  
>>  	/*
>>  	 * Sockets created by kernel threads receive web label.
>> @@ -2328,8 +2324,6 @@ static int smack_sk_alloc_security(struct sock *sk, int family, gfp_t gfp_flags)
>>  	}
>>  	ssp->smk_packet = NULL;
>>  
>> -	sk->sk_security = ssp;
>> -
>>  	return 0;
>>  }
>>  
>> @@ -2355,7 +2349,6 @@ static void smack_sk_free_security(struct sock *sk)
>>  		rcu_read_unlock();
>>  	}
>>  #endif
>> -	kfree(sk->sk_security);
>>  }
>>  
>>  /**
>> @@ -2367,8 +2360,8 @@ static void smack_sk_free_security(struct sock *sk)
>>   */
>>  static void smack_sk_clone_security(const struct sock *sk, struct sock *newsk)
>>  {
>> -	struct socket_smack *ssp_old = sk->sk_security;
>> -	struct socket_smack *ssp_new = newsk->sk_security;
>> +	struct socket_smack *ssp_old = smack_sock(sk);
>> +	struct socket_smack *ssp_new = smack_sock(newsk);
>>  
>>  	*ssp_new = *ssp_old;
>>  }
>> @@ -2484,7 +2477,7 @@ static struct smack_known *smack_ipv6host_label(struct sockaddr_in6 *sip)
>>   */
>>  static int smack_netlbl_add(struct sock *sk)
>>  {
>> -	struct socket_smack *ssp = sk->sk_security;
>> +	struct socket_smack *ssp = smack_sock(sk);
>>  	struct smack_known *skp = ssp->smk_out;
>>  	int rc;
>>  
>> @@ -2516,7 +2509,7 @@ static int smack_netlbl_add(struct sock *sk)
>>   */
>>  static void smack_netlbl_delete(struct sock *sk)
>>  {
>> -	struct socket_smack *ssp = sk->sk_security;
>> +	struct socket_smack *ssp = smack_sock(sk);
>>  
>>  	/*
>>  	 * Take the label off the socket if one is set.
>> @@ -2548,7 +2541,7 @@ static int smk_ipv4_check(struct sock *sk, struct sockaddr_in *sap)
>>  	struct smack_known *skp;
>>  	int rc = 0;
>>  	struct smack_known *hkp;
>> -	struct socket_smack *ssp = sk->sk_security;
>> +	struct socket_smack *ssp = smack_sock(sk);
>>  	struct smk_audit_info ad;
>>  
>>  	rcu_read_lock();
>> @@ -2621,7 +2614,7 @@ static void smk_ipv6_port_label(struct socket *sock, struct sockaddr *address)
>>  {
>>  	struct sock *sk = sock->sk;
>>  	struct sockaddr_in6 *addr6;
>> -	struct socket_smack *ssp = sock->sk->sk_security;
>> +	struct socket_smack *ssp = smack_sock(sock->sk);
>>  	struct smk_port_label *spp;
>>  	unsigned short port = 0;
>>  
>> @@ -2709,7 +2702,7 @@ static int smk_ipv6_port_check(struct sock *sk, struct sockaddr_in6 *address,
>>  				int act)
>>  {
>>  	struct smk_port_label *spp;
>> -	struct socket_smack *ssp = sk->sk_security;
>> +	struct socket_smack *ssp = smack_sock(sk);
>>  	struct smack_known *skp = NULL;
>>  	unsigned short port;
>>  	struct smack_known *object;
>> @@ -2803,7 +2796,7 @@ static int smack_inode_setsecurity(struct inode *inode, const char *name,
>>  	if (sock == NULL || sock->sk == NULL)
>>  		return -EOPNOTSUPP;
>>  
>> -	ssp = sock->sk->sk_security;
>> +	ssp = smack_sock(sock->sk);
>>  
>>  	if (strcmp(name, XATTR_SMACK_IPIN) == 0)
>>  		ssp->smk_in = skp;
>> @@ -2851,7 +2844,7 @@ static int smack_socket_post_create(struct socket *sock, int family,
>>  	 * Sockets created by kernel threads receive web label.
>>  	 */
>>  	if (unlikely(current->flags & PF_KTHREAD)) {
>> -		ssp = sock->sk->sk_security;
>> +		ssp = smack_sock(sock->sk);
>>  		ssp->smk_in = &smack_known_web;
>>  		ssp->smk_out = &smack_known_web;
>>  	}
>> @@ -2876,8 +2869,8 @@ static int smack_socket_post_create(struct socket *sock, int family,
>>  static int smack_socket_socketpair(struct socket *socka,
>>  		                   struct socket *sockb)
>>  {
>> -	struct socket_smack *asp = socka->sk->sk_security;
>> -	struct socket_smack *bsp = sockb->sk->sk_security;
>> +	struct socket_smack *asp = smack_sock(socka->sk);
>> +	struct socket_smack *bsp = smack_sock(sockb->sk);
>>  
>>  	asp->smk_packet = bsp->smk_out;
>>  	bsp->smk_packet = asp->smk_out;
>> @@ -2940,7 +2933,7 @@ static int smack_socket_connect(struct socket *sock, struct sockaddr *sap,
>>  		if (__is_defined(SMACK_IPV6_SECMARK_LABELING))
>>  			rsp = smack_ipv6host_label(sip);
>>  		if (rsp != NULL) {
>> -			struct socket_smack *ssp = sock->sk->sk_security;
>> +			struct socket_smack *ssp = smack_sock(sock->sk);
>>  
>>  			rc = smk_ipv6_check(ssp->smk_out, rsp, sip,
>>  					    SMK_CONNECTING);
>> @@ -3671,9 +3664,9 @@ static int smack_unix_stream_connect(struct sock *sock,
>>  {
>>  	struct smack_known *skp;
>>  	struct smack_known *okp;
>> -	struct socket_smack *ssp = sock->sk_security;
>> -	struct socket_smack *osp = other->sk_security;
>> -	struct socket_smack *nsp = newsk->sk_security;
>> +	struct socket_smack *ssp = smack_sock(sock);
>> +	struct socket_smack *osp = smack_sock(other);
>> +	struct socket_smack *nsp = smack_sock(newsk);
>>  	struct smk_audit_info ad;
>>  	int rc = 0;
>>  #ifdef CONFIG_AUDIT
>> @@ -3719,8 +3712,8 @@ static int smack_unix_stream_connect(struct sock *sock,
>>   */
>>  static int smack_unix_may_send(struct socket *sock, struct socket *other)
>>  {
>> -	struct socket_smack *ssp = sock->sk->sk_security;
>> -	struct socket_smack *osp = other->sk->sk_security;
>> +	struct socket_smack *ssp = smack_sock(sock->sk);
>> +	struct socket_smack *osp = smack_sock(other->sk);
>>  	struct smk_audit_info ad;
>>  	int rc;
>>  
>> @@ -3757,7 +3750,7 @@ static int smack_socket_sendmsg(struct socket *sock, struct msghdr *msg,
>>  	struct sockaddr_in6 *sap = (struct sockaddr_in6 *) msg->msg_name;
>>  #endif
>>  #ifdef SMACK_IPV6_SECMARK_LABELING
>> -	struct socket_smack *ssp = sock->sk->sk_security;
>> +	struct socket_smack *ssp = smack_sock(sock->sk);
>>  	struct smack_known *rsp;
>>  #endif
>>  	int rc = 0;
>> @@ -3969,7 +3962,7 @@ static struct smack_known *smack_from_netlbl(const struct sock *sk, u16 family,
>>  	netlbl_secattr_init(&secattr);
>>  
>>  	if (sk)
>> -		ssp = sk->sk_security;
>> +		ssp = smack_sock(sk);
>>  
>>  	if (netlbl_skbuff_getattr(skb, family, &secattr) == 0) {
>>  		skp = smack_from_secattr(&secattr, ssp);
>> @@ -3991,7 +3984,7 @@ static struct smack_known *smack_from_netlbl(const struct sock *sk, u16 family,
>>   */
>>  static int smack_socket_sock_rcv_skb(struct sock *sk, struct sk_buff *skb)
>>  {
>> -	struct socket_smack *ssp = sk->sk_security;
>> +	struct socket_smack *ssp = smack_sock(sk);
>>  	struct smack_known *skp = NULL;
>>  	int rc = 0;
>>  	struct smk_audit_info ad;
>> @@ -4090,12 +4083,11 @@ static int smack_socket_getpeersec_stream(struct socket *sock,
>>  					  sockptr_t optval, sockptr_t optlen,
>>  					  unsigned int len)
>>  {
>> -	struct socket_smack *ssp;
>> +	struct socket_smack *ssp = smack_sock(sock->sk);
>>  	char *rcp = "";
>>  	u32 slen = 1;
>>  	int rc = 0;
>>  
>> -	ssp = sock->sk->sk_security;
>>  	if (ssp->smk_packet != NULL) {
>>  		rcp = ssp->smk_packet->smk_known;
>>  		slen = strlen(rcp) + 1;
>> @@ -4145,7 +4137,7 @@ static int smack_socket_getpeersec_dgram(struct socket *sock,
>>  
>>  	switch (family) {
>>  	case PF_UNIX:
>> -		ssp = sock->sk->sk_security;
>> +		ssp = smack_sock(sock->sk);
>>  		s = ssp->smk_out->smk_secid;
>>  		break;
>>  	case PF_INET:
>> @@ -4194,7 +4186,7 @@ static void smack_sock_graft(struct sock *sk, struct socket *parent)
>>  	    (sk->sk_family != PF_INET && sk->sk_family != PF_INET6))
>>  		return;
>>  
>> -	ssp = sk->sk_security;
>> +	ssp = smack_sock(sk);
>>  	ssp->smk_in = skp;
>>  	ssp->smk_out = skp;
>>  	/* cssp->smk_packet is already set in smack_inet_csk_clone() */
>> @@ -4214,7 +4206,7 @@ static int smack_inet_conn_request(const struct sock *sk, struct sk_buff *skb,
>>  {
>>  	u16 family = sk->sk_family;
>>  	struct smack_known *skp;
>> -	struct socket_smack *ssp = sk->sk_security;
>> +	struct socket_smack *ssp = smack_sock(sk);
>>  	struct sockaddr_in addr;
>>  	struct iphdr *hdr;
>>  	struct smack_known *hskp;
>> @@ -4300,7 +4292,7 @@ static int smack_inet_conn_request(const struct sock *sk, struct sk_buff *skb,
>>  static void smack_inet_csk_clone(struct sock *sk,
>>  				 const struct request_sock *req)
>>  {
>> -	struct socket_smack *ssp = sk->sk_security;
>> +	struct socket_smack *ssp = smack_sock(sk);
>>  	struct smack_known *skp;
>>  
>>  	if (req->peer_secid != 0) {
>> @@ -4868,6 +4860,7 @@ struct lsm_blob_sizes smack_blob_sizes __ro_after_init = {
>>  	.lbs_inode = sizeof(struct inode_smack),
>>  	.lbs_ipc = sizeof(struct smack_known *),
>>  	.lbs_msg_msg = sizeof(struct smack_known *),
>> +	.lbs_sock = sizeof(struct socket_smack),
>>  	.lbs_superblock = sizeof(struct superblock_smack),
>>  };
>>  
>> diff --git a/security/smack/smack_netfilter.c b/security/smack/smack_netfilter.c
>> index b945c1d3a743..bad71b7e648d 100644
>> --- a/security/smack/smack_netfilter.c
>> +++ b/security/smack/smack_netfilter.c
>> @@ -26,8 +26,8 @@ static unsigned int smack_ip_output(void *priv,
>>  	struct socket_smack *ssp;
>>  	struct smack_known *skp;
>>  
>> -	if (sk && sk->sk_security) {
>> -		ssp = sk->sk_security;
>> +	if (sk) {
>> +		ssp = smack_sock(sk);
>>  		skp = ssp->smk_out;
>>  		skb->secmark = skp->smk_secid;
>>  	}
diff mbox series

Patch

diff --git a/include/linux/lsm_hooks.h b/include/linux/lsm_hooks.h
index ab2b2fafa4a4..67b6e87ca6ec 100644
--- a/include/linux/lsm_hooks.h
+++ b/include/linux/lsm_hooks.h
@@ -62,6 +62,7 @@  struct lsm_blob_sizes {
 	int	lbs_superblock;
 	int	lbs_ipc;
 	int	lbs_msg_msg;
+	int	lbs_sock;
 	int	lbs_task;
 };
 
diff --git a/security/apparmor/include/net.h b/security/apparmor/include/net.h
index 6fa440b5daed..9eb159c09578 100644
--- a/security/apparmor/include/net.h
+++ b/security/apparmor/include/net.h
@@ -51,7 +51,7 @@  struct aa_sk_ctx {
 	struct aa_label *peer;
 };
 
-#define SK_CTX(X) ((X)->sk_security)
+#define SK_CTX(X) ((X)->sk_security + apparmor_blob_sizes.lbs_sock)
 #define SOCK_ctx(X) SOCK_INODE(X)->i_security
 #define DEFINE_AUDIT_NET(NAME, OP, SK, F, T, P)				  \
 	struct lsm_network_audit NAME ## _net = { .sk = (SK),		  \
diff --git a/security/apparmor/lsm.c b/security/apparmor/lsm.c
index f431251ffb91..3dd849a6d7a1 100644
--- a/security/apparmor/lsm.c
+++ b/security/apparmor/lsm.c
@@ -818,22 +818,6 @@  static int apparmor_task_kill(struct task_struct *target, struct kernel_siginfo
 	return error;
 }
 
-/**
- * apparmor_sk_alloc_security - allocate and attach the sk_security field
- */
-static int apparmor_sk_alloc_security(struct sock *sk, int family, gfp_t flags)
-{
-	struct aa_sk_ctx *ctx;
-
-	ctx = kzalloc(sizeof(*ctx), flags);
-	if (!ctx)
-		return -ENOMEM;
-
-	SK_CTX(sk) = ctx;
-
-	return 0;
-}
-
 /**
  * apparmor_sk_free_security - free the sk_security field
  */
@@ -841,10 +825,8 @@  static void apparmor_sk_free_security(struct sock *sk)
 {
 	struct aa_sk_ctx *ctx = SK_CTX(sk);
 
-	SK_CTX(sk) = NULL;
 	aa_put_label(ctx->label);
 	aa_put_label(ctx->peer);
-	kfree(ctx);
 }
 
 /**
@@ -1212,6 +1194,7 @@  static int apparmor_inet_conn_request(const struct sock *sk, struct sk_buff *skb
 struct lsm_blob_sizes apparmor_blob_sizes __ro_after_init = {
 	.lbs_cred = sizeof(struct aa_label *),
 	.lbs_file = sizeof(struct aa_file_ctx),
+	.lbs_sock = sizeof(struct aa_sk_ctx),
 	.lbs_task = sizeof(struct aa_task_ctx),
 };
 
@@ -1250,7 +1233,6 @@  static struct security_hook_list apparmor_hooks[] __ro_after_init = {
 	LSM_HOOK_INIT(getprocattr, apparmor_getprocattr),
 	LSM_HOOK_INIT(setprocattr, apparmor_setprocattr),
 
-	LSM_HOOK_INIT(sk_alloc_security, apparmor_sk_alloc_security),
 	LSM_HOOK_INIT(sk_free_security, apparmor_sk_free_security),
 	LSM_HOOK_INIT(sk_clone_security, apparmor_sk_clone_security),
 
diff --git a/security/security.c b/security/security.c
index b720424ca37d..e71f4717cde5 100644
--- a/security/security.c
+++ b/security/security.c
@@ -30,6 +30,7 @@ 
 #include <linux/string.h>
 #include <linux/msg.h>
 #include <net/flow.h>
+#include <net/sock.h>
 
 #define MAX_LSM_EVM_XATTR	2
 
@@ -210,6 +211,7 @@  static void __init lsm_set_blob_sizes(struct lsm_blob_sizes *needed)
 	lsm_set_blob_size(&needed->lbs_inode, &blob_sizes.lbs_inode);
 	lsm_set_blob_size(&needed->lbs_ipc, &blob_sizes.lbs_ipc);
 	lsm_set_blob_size(&needed->lbs_msg_msg, &blob_sizes.lbs_msg_msg);
+	lsm_set_blob_size(&needed->lbs_sock, &blob_sizes.lbs_sock);
 	lsm_set_blob_size(&needed->lbs_superblock, &blob_sizes.lbs_superblock);
 	lsm_set_blob_size(&needed->lbs_task, &blob_sizes.lbs_task);
 }
@@ -376,6 +378,7 @@  static void __init ordered_lsm_init(void)
 	init_debug("inode blob size      = %d\n", blob_sizes.lbs_inode);
 	init_debug("ipc blob size        = %d\n", blob_sizes.lbs_ipc);
 	init_debug("msg_msg blob size    = %d\n", blob_sizes.lbs_msg_msg);
+	init_debug("sock blob size       = %d\n", blob_sizes.lbs_sock);
 	init_debug("superblock blob size = %d\n", blob_sizes.lbs_superblock);
 	init_debug("task blob size       = %d\n", blob_sizes.lbs_task);
 
@@ -733,6 +736,27 @@  static int lsm_superblock_alloc(struct super_block *sb)
 	return 0;
 }
 
+/**
+ * lsm_sock_alloc - allocate a composite socket blob
+ * @sk: the socket that needs a blob
+ *
+ * Allocate the socket blob for all the modules
+ *
+ * Returns 0, or -ENOMEM if memory can't be allocated.
+ */
+static int lsm_sock_alloc(struct sock *sk)
+{
+	if (blob_sizes.lbs_sock == 0) {
+		sk->sk_security = NULL;
+		return 0;
+	}
+
+	sk->sk_security = kzalloc(blob_sizes.lbs_sock, GFP_KERNEL);
+	if (sk->sk_security == NULL)
+		return -ENOMEM;
+	return 0;
+}
+
 /*
  * The default value of the LSM hook is defined in linux/lsm_hook_defs.h and
  * can be accessed with:
@@ -4369,7 +4393,14 @@  EXPORT_SYMBOL(security_socket_getpeersec_dgram);
  */
 int security_sk_alloc(struct sock *sk, int family, gfp_t priority)
 {
-	return call_int_hook(sk_alloc_security, 0, sk, family, priority);
+	int rc = lsm_sock_alloc(sk);
+
+	if (unlikely(rc))
+		return rc;
+	rc = call_int_hook(sk_alloc_security, 0, sk, family, priority);
+	if (unlikely(rc))
+		security_sk_free(sk);
+	return rc;
 }
 
 /**
@@ -4381,6 +4412,8 @@  int security_sk_alloc(struct sock *sk, int family, gfp_t priority)
 void security_sk_free(struct sock *sk)
 {
 	call_void_hook(sk_free_security, sk);
+	kfree(sk->sk_security);
+	sk->sk_security = NULL;
 }
 
 /**
diff --git a/security/selinux/hooks.c b/security/selinux/hooks.c
index d06e350fedee..f8397f05dc90 100644
--- a/security/selinux/hooks.c
+++ b/security/selinux/hooks.c
@@ -4497,7 +4497,7 @@  static int socket_sockcreate_sid(const struct task_security_struct *tsec,
 
 static int sock_has_perm(struct sock *sk, u32 perms)
 {
-	struct sk_security_struct *sksec = sk->sk_security;
+	struct sk_security_struct *sksec = selinux_sock(sk);
 	struct common_audit_data ad;
 	struct lsm_network_audit net = {0,};
 
@@ -4552,7 +4552,7 @@  static int selinux_socket_post_create(struct socket *sock, int family,
 	isec->initialized = LABEL_INITIALIZED;
 
 	if (sock->sk) {
-		sksec = sock->sk->sk_security;
+		sksec = selinux_sock(sock->sk);
 		sksec->sclass = sclass;
 		sksec->sid = sid;
 		/* Allows detection of the first association on this socket */
@@ -4568,8 +4568,8 @@  static int selinux_socket_post_create(struct socket *sock, int family,
 static int selinux_socket_socketpair(struct socket *socka,
 				     struct socket *sockb)
 {
-	struct sk_security_struct *sksec_a = socka->sk->sk_security;
-	struct sk_security_struct *sksec_b = sockb->sk->sk_security;
+	struct sk_security_struct *sksec_a = selinux_sock(socka->sk);
+	struct sk_security_struct *sksec_b = selinux_sock(sockb->sk);
 
 	sksec_a->peer_sid = sksec_b->sid;
 	sksec_b->peer_sid = sksec_a->sid;
@@ -4584,7 +4584,7 @@  static int selinux_socket_socketpair(struct socket *socka,
 static int selinux_socket_bind(struct socket *sock, struct sockaddr *address, int addrlen)
 {
 	struct sock *sk = sock->sk;
-	struct sk_security_struct *sksec = sk->sk_security;
+	struct sk_security_struct *sksec = selinux_sock(sk);
 	u16 family;
 	int err;
 
@@ -4717,7 +4717,7 @@  static int selinux_socket_connect_helper(struct socket *sock,
 					 struct sockaddr *address, int addrlen)
 {
 	struct sock *sk = sock->sk;
-	struct sk_security_struct *sksec = sk->sk_security;
+	struct sk_security_struct *sksec = selinux_sock(sk);
 	int err;
 
 	err = sock_has_perm(sk, SOCKET__CONNECT);
@@ -4895,9 +4895,9 @@  static int selinux_socket_unix_stream_connect(struct sock *sock,
 					      struct sock *other,
 					      struct sock *newsk)
 {
-	struct sk_security_struct *sksec_sock = sock->sk_security;
-	struct sk_security_struct *sksec_other = other->sk_security;
-	struct sk_security_struct *sksec_new = newsk->sk_security;
+	struct sk_security_struct *sksec_sock = selinux_sock(sock);
+	struct sk_security_struct *sksec_other = selinux_sock(other);
+	struct sk_security_struct *sksec_new = selinux_sock(newsk);
 	struct common_audit_data ad;
 	struct lsm_network_audit net = {0,};
 	int err;
@@ -4928,8 +4928,8 @@  static int selinux_socket_unix_stream_connect(struct sock *sock,
 static int selinux_socket_unix_may_send(struct socket *sock,
 					struct socket *other)
 {
-	struct sk_security_struct *ssec = sock->sk->sk_security;
-	struct sk_security_struct *osec = other->sk->sk_security;
+	struct sk_security_struct *ssec = selinux_sock(sock->sk);
+	struct sk_security_struct *osec = selinux_sock(other->sk);
 	struct common_audit_data ad;
 	struct lsm_network_audit net = {0,};
 
@@ -4968,7 +4968,7 @@  static int selinux_sock_rcv_skb_compat(struct sock *sk, struct sk_buff *skb,
 				       u16 family)
 {
 	int err = 0;
-	struct sk_security_struct *sksec = sk->sk_security;
+	struct sk_security_struct *sksec = selinux_sock(sk);
 	u32 sk_sid = sksec->sid;
 	struct common_audit_data ad;
 	struct lsm_network_audit net = {0,};
@@ -5000,7 +5000,7 @@  static int selinux_sock_rcv_skb_compat(struct sock *sk, struct sk_buff *skb,
 static int selinux_socket_sock_rcv_skb(struct sock *sk, struct sk_buff *skb)
 {
 	int err;
-	struct sk_security_struct *sksec = sk->sk_security;
+	struct sk_security_struct *sksec = selinux_sock(sk);
 	u16 family = sk->sk_family;
 	u32 sk_sid = sksec->sid;
 	struct common_audit_data ad;
@@ -5073,7 +5073,7 @@  static int selinux_socket_getpeersec_stream(struct socket *sock,
 	int err = 0;
 	char *scontext = NULL;
 	u32 scontext_len;
-	struct sk_security_struct *sksec = sock->sk->sk_security;
+	struct sk_security_struct *sksec = selinux_sock(sock->sk);
 	u32 peer_sid = SECSID_NULL;
 
 	if (sksec->sclass == SECCLASS_UNIX_STREAM_SOCKET ||
@@ -5131,34 +5131,27 @@  static int selinux_socket_getpeersec_dgram(struct socket *sock, struct sk_buff *
 
 static int selinux_sk_alloc_security(struct sock *sk, int family, gfp_t priority)
 {
-	struct sk_security_struct *sksec;
-
-	sksec = kzalloc(sizeof(*sksec), priority);
-	if (!sksec)
-		return -ENOMEM;
+	struct sk_security_struct *sksec = selinux_sock(sk);
 
 	sksec->peer_sid = SECINITSID_UNLABELED;
 	sksec->sid = SECINITSID_UNLABELED;
 	sksec->sclass = SECCLASS_SOCKET;
 	selinux_netlbl_sk_security_reset(sksec);
-	sk->sk_security = sksec;
 
 	return 0;
 }
 
 static void selinux_sk_free_security(struct sock *sk)
 {
-	struct sk_security_struct *sksec = sk->sk_security;
+	struct sk_security_struct *sksec = selinux_sock(sk);
 
-	sk->sk_security = NULL;
 	selinux_netlbl_sk_security_free(sksec);
-	kfree(sksec);
 }
 
 static void selinux_sk_clone_security(const struct sock *sk, struct sock *newsk)
 {
-	struct sk_security_struct *sksec = sk->sk_security;
-	struct sk_security_struct *newsksec = newsk->sk_security;
+	struct sk_security_struct *sksec = selinux_sock(sk);
+	struct sk_security_struct *newsksec = selinux_sock(newsk);
 
 	newsksec->sid = sksec->sid;
 	newsksec->peer_sid = sksec->peer_sid;
@@ -5172,7 +5165,7 @@  static void selinux_sk_getsecid(struct sock *sk, u32 *secid)
 	if (!sk)
 		*secid = SECINITSID_ANY_SOCKET;
 	else {
-		struct sk_security_struct *sksec = sk->sk_security;
+		struct sk_security_struct *sksec = selinux_sock(sk);
 
 		*secid = sksec->sid;
 	}
@@ -5182,7 +5175,7 @@  static void selinux_sock_graft(struct sock *sk, struct socket *parent)
 {
 	struct inode_security_struct *isec =
 		inode_security_novalidate(SOCK_INODE(parent));
-	struct sk_security_struct *sksec = sk->sk_security;
+	struct sk_security_struct *sksec = selinux_sock(sk);
 
 	if (sk->sk_family == PF_INET || sk->sk_family == PF_INET6 ||
 	    sk->sk_family == PF_UNIX)
@@ -5199,7 +5192,7 @@  static int selinux_sctp_process_new_assoc(struct sctp_association *asoc,
 {
 	struct sock *sk = asoc->base.sk;
 	u16 family = sk->sk_family;
-	struct sk_security_struct *sksec = sk->sk_security;
+	struct sk_security_struct *sksec = selinux_sock(sk);
 	struct common_audit_data ad;
 	struct lsm_network_audit net = {0,};
 	int err;
@@ -5256,7 +5249,7 @@  static int selinux_sctp_process_new_assoc(struct sctp_association *asoc,
 static int selinux_sctp_assoc_request(struct sctp_association *asoc,
 				      struct sk_buff *skb)
 {
-	struct sk_security_struct *sksec = asoc->base.sk->sk_security;
+	struct sk_security_struct *sksec = selinux_sock(asoc->base.sk);
 	u32 conn_sid;
 	int err;
 
@@ -5289,7 +5282,7 @@  static int selinux_sctp_assoc_request(struct sctp_association *asoc,
 static int selinux_sctp_assoc_established(struct sctp_association *asoc,
 					  struct sk_buff *skb)
 {
-	struct sk_security_struct *sksec = asoc->base.sk->sk_security;
+	struct sk_security_struct *sksec = selinux_sock(asoc->base.sk);
 
 	if (!selinux_policycap_extsockclass())
 		return 0;
@@ -5388,8 +5381,8 @@  static int selinux_sctp_bind_connect(struct sock *sk, int optname,
 static void selinux_sctp_sk_clone(struct sctp_association *asoc, struct sock *sk,
 				  struct sock *newsk)
 {
-	struct sk_security_struct *sksec = sk->sk_security;
-	struct sk_security_struct *newsksec = newsk->sk_security;
+	struct sk_security_struct *sksec = selinux_sock(sk);
+	struct sk_security_struct *newsksec = selinux_sock(newsk);
 
 	/* If policy does not support SECCLASS_SCTP_SOCKET then call
 	 * the non-sctp clone version.
@@ -5405,8 +5398,8 @@  static void selinux_sctp_sk_clone(struct sctp_association *asoc, struct sock *sk
 
 static int selinux_mptcp_add_subflow(struct sock *sk, struct sock *ssk)
 {
-	struct sk_security_struct *ssksec = ssk->sk_security;
-	struct sk_security_struct *sksec = sk->sk_security;
+	struct sk_security_struct *ssksec = selinux_sock(ssk);
+	struct sk_security_struct *sksec = selinux_sock(sk);
 
 	ssksec->sclass = sksec->sclass;
 	ssksec->sid = sksec->sid;
@@ -5421,7 +5414,7 @@  static int selinux_mptcp_add_subflow(struct sock *sk, struct sock *ssk)
 static int selinux_inet_conn_request(const struct sock *sk, struct sk_buff *skb,
 				     struct request_sock *req)
 {
-	struct sk_security_struct *sksec = sk->sk_security;
+	struct sk_security_struct *sksec = selinux_sock(sk);
 	int err;
 	u16 family = req->rsk_ops->family;
 	u32 connsid;
@@ -5442,7 +5435,7 @@  static int selinux_inet_conn_request(const struct sock *sk, struct sk_buff *skb,
 static void selinux_inet_csk_clone(struct sock *newsk,
 				   const struct request_sock *req)
 {
-	struct sk_security_struct *newsksec = newsk->sk_security;
+	struct sk_security_struct *newsksec = selinux_sock(newsk);
 
 	newsksec->sid = req->secid;
 	newsksec->peer_sid = req->peer_secid;
@@ -5459,7 +5452,7 @@  static void selinux_inet_csk_clone(struct sock *newsk,
 static void selinux_inet_conn_established(struct sock *sk, struct sk_buff *skb)
 {
 	u16 family = sk->sk_family;
-	struct sk_security_struct *sksec = sk->sk_security;
+	struct sk_security_struct *sksec = selinux_sock(sk);
 
 	/* handle mapped IPv4 packets arriving via IPv6 sockets */
 	if (family == PF_INET6 && skb->protocol == htons(ETH_P_IP))
@@ -5540,7 +5533,7 @@  static int selinux_tun_dev_attach_queue(void *security)
 static int selinux_tun_dev_attach(struct sock *sk, void *security)
 {
 	struct tun_security_struct *tunsec = security;
-	struct sk_security_struct *sksec = sk->sk_security;
+	struct sk_security_struct *sksec = selinux_sock(sk);
 
 	/* we don't currently perform any NetLabel based labeling here and it
 	 * isn't clear that we would want to do so anyway; while we could apply
@@ -5666,7 +5659,7 @@  static unsigned int selinux_ip_output(void *priv, struct sk_buff *skb,
 			return NF_ACCEPT;
 
 		/* standard practice, label using the parent socket */
-		sksec = sk->sk_security;
+		sksec = selinux_sock(sk);
 		sid = sksec->sid;
 	} else
 		sid = SECINITSID_KERNEL;
@@ -5689,7 +5682,7 @@  static unsigned int selinux_ip_postroute_compat(struct sk_buff *skb,
 	sk = skb_to_full_sk(skb);
 	if (sk == NULL)
 		return NF_ACCEPT;
-	sksec = sk->sk_security;
+	sksec = selinux_sock(sk);
 
 	ad.type = LSM_AUDIT_DATA_NET;
 	ad.u.net = &net;
@@ -5779,9 +5772,8 @@  static unsigned int selinux_ip_postroute(void *priv,
 		 * selinux_inet_conn_request().  See also selinux_ip_output()
 		 * for similar problems. */
 		u32 skb_sid;
-		struct sk_security_struct *sksec;
+		struct sk_security_struct *sksec = selinux_sock(sk);
 
-		sksec = sk->sk_security;
 		if (selinux_skb_peerlbl_sid(skb, family, &skb_sid))
 			return NF_DROP;
 		/* At this point, if the returned skb peerlbl is SECSID_NULL
@@ -5810,7 +5802,7 @@  static unsigned int selinux_ip_postroute(void *priv,
 	} else {
 		/* Locally generated packet, fetch the security label from the
 		 * associated socket. */
-		struct sk_security_struct *sksec = sk->sk_security;
+		struct sk_security_struct *sksec = selinux_sock(sk);
 		peer_sid = sksec->sid;
 		secmark_perm = PACKET__SEND;
 	}
@@ -5856,7 +5848,7 @@  static int selinux_netlink_send(struct sock *sk, struct sk_buff *skb)
 	unsigned int data_len = skb->len;
 	unsigned char *data = skb->data;
 	struct nlmsghdr *nlh;
-	struct sk_security_struct *sksec = sk->sk_security;
+	struct sk_security_struct *sksec = selinux_sock(sk);
 	u16 sclass = sksec->sclass;
 	u32 perm;
 
@@ -6814,6 +6806,7 @@  struct lsm_blob_sizes selinux_blob_sizes __ro_after_init = {
 	.lbs_inode = sizeof(struct inode_security_struct),
 	.lbs_ipc = sizeof(struct ipc_security_struct),
 	.lbs_msg_msg = sizeof(struct msg_security_struct),
+	.lbs_sock = sizeof(struct sk_security_struct),
 	.lbs_superblock = sizeof(struct superblock_security_struct),
 };
 
diff --git a/security/selinux/include/objsec.h b/security/selinux/include/objsec.h
index 2953132408bf..49221f441c68 100644
--- a/security/selinux/include/objsec.h
+++ b/security/selinux/include/objsec.h
@@ -194,4 +194,8 @@  static inline struct superblock_security_struct *selinux_superblock(
 	return superblock->s_security + selinux_blob_sizes.lbs_superblock;
 }
 
+static inline struct sk_security_struct *selinux_sock(const struct sock *sk)
+{
+	return sk->sk_security + selinux_blob_sizes.lbs_sock;
+}
 #endif /* _SELINUX_OBJSEC_H_ */
diff --git a/security/selinux/netlabel.c b/security/selinux/netlabel.c
index 528f5186e912..9755561aa466 100644
--- a/security/selinux/netlabel.c
+++ b/security/selinux/netlabel.c
@@ -68,7 +68,7 @@  static int selinux_netlbl_sidlookup_cached(struct sk_buff *skb,
 static struct netlbl_lsm_secattr *selinux_netlbl_sock_genattr(struct sock *sk)
 {
 	int rc;
-	struct sk_security_struct *sksec = sk->sk_security;
+	struct sk_security_struct *sksec = selinux_sock(sk);
 	struct netlbl_lsm_secattr *secattr;
 
 	if (sksec->nlbl_secattr != NULL)
@@ -100,7 +100,7 @@  static struct netlbl_lsm_secattr *selinux_netlbl_sock_getattr(
 							const struct sock *sk,
 							u32 sid)
 {
-	struct sk_security_struct *sksec = sk->sk_security;
+	struct sk_security_struct *sksec = selinux_sock(sk);
 	struct netlbl_lsm_secattr *secattr = sksec->nlbl_secattr;
 
 	if (secattr == NULL)
@@ -239,7 +239,7 @@  int selinux_netlbl_skbuff_setsid(struct sk_buff *skb,
 	 * being labeled by it's parent socket, if it is just exit */
 	sk = skb_to_full_sk(skb);
 	if (sk != NULL) {
-		struct sk_security_struct *sksec = sk->sk_security;
+		struct sk_security_struct *sksec = selinux_sock(sk);
 
 		if (sksec->nlbl_state != NLBL_REQSKB)
 			return 0;
@@ -276,7 +276,7 @@  int selinux_netlbl_sctp_assoc_request(struct sctp_association *asoc,
 {
 	int rc;
 	struct netlbl_lsm_secattr secattr;
-	struct sk_security_struct *sksec = asoc->base.sk->sk_security;
+	struct sk_security_struct *sksec = selinux_sock(asoc->base.sk);
 	struct sockaddr_in addr4;
 	struct sockaddr_in6 addr6;
 
@@ -355,7 +355,7 @@  int selinux_netlbl_inet_conn_request(struct request_sock *req, u16 family)
  */
 void selinux_netlbl_inet_csk_clone(struct sock *sk, u16 family)
 {
-	struct sk_security_struct *sksec = sk->sk_security;
+	struct sk_security_struct *sksec = selinux_sock(sk);
 
 	if (family == PF_INET)
 		sksec->nlbl_state = NLBL_LABELED;
@@ -373,8 +373,8 @@  void selinux_netlbl_inet_csk_clone(struct sock *sk, u16 family)
  */
 void selinux_netlbl_sctp_sk_clone(struct sock *sk, struct sock *newsk)
 {
-	struct sk_security_struct *sksec = sk->sk_security;
-	struct sk_security_struct *newsksec = newsk->sk_security;
+	struct sk_security_struct *sksec = selinux_sock(sk);
+	struct sk_security_struct *newsksec = selinux_sock(newsk);
 
 	newsksec->nlbl_state = sksec->nlbl_state;
 }
@@ -392,7 +392,7 @@  void selinux_netlbl_sctp_sk_clone(struct sock *sk, struct sock *newsk)
 int selinux_netlbl_socket_post_create(struct sock *sk, u16 family)
 {
 	int rc;
-	struct sk_security_struct *sksec = sk->sk_security;
+	struct sk_security_struct *sksec = selinux_sock(sk);
 	struct netlbl_lsm_secattr *secattr;
 
 	if (family != PF_INET && family != PF_INET6)
@@ -506,7 +506,7 @@  int selinux_netlbl_socket_setsockopt(struct socket *sock,
 {
 	int rc = 0;
 	struct sock *sk = sock->sk;
-	struct sk_security_struct *sksec = sk->sk_security;
+	struct sk_security_struct *sksec = selinux_sock(sk);
 	struct netlbl_lsm_secattr secattr;
 
 	if (selinux_netlbl_option(level, optname) &&
@@ -544,7 +544,7 @@  static int selinux_netlbl_socket_connect_helper(struct sock *sk,
 						struct sockaddr *addr)
 {
 	int rc;
-	struct sk_security_struct *sksec = sk->sk_security;
+	struct sk_security_struct *sksec = selinux_sock(sk);
 	struct netlbl_lsm_secattr *secattr;
 
 	/* connected sockets are allowed to disconnect when the address family
@@ -583,7 +583,7 @@  static int selinux_netlbl_socket_connect_helper(struct sock *sk,
 int selinux_netlbl_socket_connect_locked(struct sock *sk,
 					 struct sockaddr *addr)
 {
-	struct sk_security_struct *sksec = sk->sk_security;
+	struct sk_security_struct *sksec = selinux_sock(sk);
 
 	if (sksec->nlbl_state != NLBL_REQSKB &&
 	    sksec->nlbl_state != NLBL_CONNLABELED)
diff --git a/security/smack/smack.h b/security/smack/smack.h
index aa15ff56ed6e..2d0163076eca 100644
--- a/security/smack/smack.h
+++ b/security/smack/smack.h
@@ -355,6 +355,11 @@  static inline struct superblock_smack *smack_superblock(
 	return superblock->s_security + smack_blob_sizes.lbs_superblock;
 }
 
+static inline struct socket_smack *smack_sock(const struct sock *sk)
+{
+	return sk->sk_security + smack_blob_sizes.lbs_sock;
+}
+
 /*
  * Is the directory transmuting?
  */
diff --git a/security/smack/smack_lsm.c b/security/smack/smack_lsm.c
index 6e270cf3fd30..ab026ff79504 100644
--- a/security/smack/smack_lsm.c
+++ b/security/smack/smack_lsm.c
@@ -1502,7 +1502,7 @@  static int smack_inode_getsecurity(struct mnt_idmap *idmap,
 		if (sock == NULL || sock->sk == NULL)
 			return -EOPNOTSUPP;
 
-		ssp = sock->sk->sk_security;
+		ssp = smack_sock(sock->sk);
 
 		if (strcmp(name, XATTR_SMACK_IPIN) == 0)
 			isp = ssp->smk_in;
@@ -1890,7 +1890,7 @@  static int smack_file_receive(struct file *file)
 
 	if (inode->i_sb->s_magic == SOCKFS_MAGIC) {
 		sock = SOCKET_I(inode);
-		ssp = sock->sk->sk_security;
+		ssp = smack_sock(sock->sk);
 		tsp = smack_cred(current_cred());
 		/*
 		 * If the receiving process can't write to the
@@ -2310,11 +2310,7 @@  static void smack_task_to_inode(struct task_struct *p, struct inode *inode)
 static int smack_sk_alloc_security(struct sock *sk, int family, gfp_t gfp_flags)
 {
 	struct smack_known *skp = smk_of_current();
-	struct socket_smack *ssp;
-
-	ssp = kzalloc(sizeof(struct socket_smack), gfp_flags);
-	if (ssp == NULL)
-		return -ENOMEM;
+	struct socket_smack *ssp = smack_sock(sk);
 
 	/*
 	 * Sockets created by kernel threads receive web label.
@@ -2328,8 +2324,6 @@  static int smack_sk_alloc_security(struct sock *sk, int family, gfp_t gfp_flags)
 	}
 	ssp->smk_packet = NULL;
 
-	sk->sk_security = ssp;
-
 	return 0;
 }
 
@@ -2355,7 +2349,6 @@  static void smack_sk_free_security(struct sock *sk)
 		rcu_read_unlock();
 	}
 #endif
-	kfree(sk->sk_security);
 }
 
 /**
@@ -2367,8 +2360,8 @@  static void smack_sk_free_security(struct sock *sk)
  */
 static void smack_sk_clone_security(const struct sock *sk, struct sock *newsk)
 {
-	struct socket_smack *ssp_old = sk->sk_security;
-	struct socket_smack *ssp_new = newsk->sk_security;
+	struct socket_smack *ssp_old = smack_sock(sk);
+	struct socket_smack *ssp_new = smack_sock(newsk);
 
 	*ssp_new = *ssp_old;
 }
@@ -2484,7 +2477,7 @@  static struct smack_known *smack_ipv6host_label(struct sockaddr_in6 *sip)
  */
 static int smack_netlbl_add(struct sock *sk)
 {
-	struct socket_smack *ssp = sk->sk_security;
+	struct socket_smack *ssp = smack_sock(sk);
 	struct smack_known *skp = ssp->smk_out;
 	int rc;
 
@@ -2516,7 +2509,7 @@  static int smack_netlbl_add(struct sock *sk)
  */
 static void smack_netlbl_delete(struct sock *sk)
 {
-	struct socket_smack *ssp = sk->sk_security;
+	struct socket_smack *ssp = smack_sock(sk);
 
 	/*
 	 * Take the label off the socket if one is set.
@@ -2548,7 +2541,7 @@  static int smk_ipv4_check(struct sock *sk, struct sockaddr_in *sap)
 	struct smack_known *skp;
 	int rc = 0;
 	struct smack_known *hkp;
-	struct socket_smack *ssp = sk->sk_security;
+	struct socket_smack *ssp = smack_sock(sk);
 	struct smk_audit_info ad;
 
 	rcu_read_lock();
@@ -2621,7 +2614,7 @@  static void smk_ipv6_port_label(struct socket *sock, struct sockaddr *address)
 {
 	struct sock *sk = sock->sk;
 	struct sockaddr_in6 *addr6;
-	struct socket_smack *ssp = sock->sk->sk_security;
+	struct socket_smack *ssp = smack_sock(sock->sk);
 	struct smk_port_label *spp;
 	unsigned short port = 0;
 
@@ -2709,7 +2702,7 @@  static int smk_ipv6_port_check(struct sock *sk, struct sockaddr_in6 *address,
 				int act)
 {
 	struct smk_port_label *spp;
-	struct socket_smack *ssp = sk->sk_security;
+	struct socket_smack *ssp = smack_sock(sk);
 	struct smack_known *skp = NULL;
 	unsigned short port;
 	struct smack_known *object;
@@ -2803,7 +2796,7 @@  static int smack_inode_setsecurity(struct inode *inode, const char *name,
 	if (sock == NULL || sock->sk == NULL)
 		return -EOPNOTSUPP;
 
-	ssp = sock->sk->sk_security;
+	ssp = smack_sock(sock->sk);
 
 	if (strcmp(name, XATTR_SMACK_IPIN) == 0)
 		ssp->smk_in = skp;
@@ -2851,7 +2844,7 @@  static int smack_socket_post_create(struct socket *sock, int family,
 	 * Sockets created by kernel threads receive web label.
 	 */
 	if (unlikely(current->flags & PF_KTHREAD)) {
-		ssp = sock->sk->sk_security;
+		ssp = smack_sock(sock->sk);
 		ssp->smk_in = &smack_known_web;
 		ssp->smk_out = &smack_known_web;
 	}
@@ -2876,8 +2869,8 @@  static int smack_socket_post_create(struct socket *sock, int family,
 static int smack_socket_socketpair(struct socket *socka,
 		                   struct socket *sockb)
 {
-	struct socket_smack *asp = socka->sk->sk_security;
-	struct socket_smack *bsp = sockb->sk->sk_security;
+	struct socket_smack *asp = smack_sock(socka->sk);
+	struct socket_smack *bsp = smack_sock(sockb->sk);
 
 	asp->smk_packet = bsp->smk_out;
 	bsp->smk_packet = asp->smk_out;
@@ -2940,7 +2933,7 @@  static int smack_socket_connect(struct socket *sock, struct sockaddr *sap,
 		if (__is_defined(SMACK_IPV6_SECMARK_LABELING))
 			rsp = smack_ipv6host_label(sip);
 		if (rsp != NULL) {
-			struct socket_smack *ssp = sock->sk->sk_security;
+			struct socket_smack *ssp = smack_sock(sock->sk);
 
 			rc = smk_ipv6_check(ssp->smk_out, rsp, sip,
 					    SMK_CONNECTING);
@@ -3671,9 +3664,9 @@  static int smack_unix_stream_connect(struct sock *sock,
 {
 	struct smack_known *skp;
 	struct smack_known *okp;
-	struct socket_smack *ssp = sock->sk_security;
-	struct socket_smack *osp = other->sk_security;
-	struct socket_smack *nsp = newsk->sk_security;
+	struct socket_smack *ssp = smack_sock(sock);
+	struct socket_smack *osp = smack_sock(other);
+	struct socket_smack *nsp = smack_sock(newsk);
 	struct smk_audit_info ad;
 	int rc = 0;
 #ifdef CONFIG_AUDIT
@@ -3719,8 +3712,8 @@  static int smack_unix_stream_connect(struct sock *sock,
  */
 static int smack_unix_may_send(struct socket *sock, struct socket *other)
 {
-	struct socket_smack *ssp = sock->sk->sk_security;
-	struct socket_smack *osp = other->sk->sk_security;
+	struct socket_smack *ssp = smack_sock(sock->sk);
+	struct socket_smack *osp = smack_sock(other->sk);
 	struct smk_audit_info ad;
 	int rc;
 
@@ -3757,7 +3750,7 @@  static int smack_socket_sendmsg(struct socket *sock, struct msghdr *msg,
 	struct sockaddr_in6 *sap = (struct sockaddr_in6 *) msg->msg_name;
 #endif
 #ifdef SMACK_IPV6_SECMARK_LABELING
-	struct socket_smack *ssp = sock->sk->sk_security;
+	struct socket_smack *ssp = smack_sock(sock->sk);
 	struct smack_known *rsp;
 #endif
 	int rc = 0;
@@ -3969,7 +3962,7 @@  static struct smack_known *smack_from_netlbl(const struct sock *sk, u16 family,
 	netlbl_secattr_init(&secattr);
 
 	if (sk)
-		ssp = sk->sk_security;
+		ssp = smack_sock(sk);
 
 	if (netlbl_skbuff_getattr(skb, family, &secattr) == 0) {
 		skp = smack_from_secattr(&secattr, ssp);
@@ -3991,7 +3984,7 @@  static struct smack_known *smack_from_netlbl(const struct sock *sk, u16 family,
  */
 static int smack_socket_sock_rcv_skb(struct sock *sk, struct sk_buff *skb)
 {
-	struct socket_smack *ssp = sk->sk_security;
+	struct socket_smack *ssp = smack_sock(sk);
 	struct smack_known *skp = NULL;
 	int rc = 0;
 	struct smk_audit_info ad;
@@ -4090,12 +4083,11 @@  static int smack_socket_getpeersec_stream(struct socket *sock,
 					  sockptr_t optval, sockptr_t optlen,
 					  unsigned int len)
 {
-	struct socket_smack *ssp;
+	struct socket_smack *ssp = smack_sock(sock->sk);
 	char *rcp = "";
 	u32 slen = 1;
 	int rc = 0;
 
-	ssp = sock->sk->sk_security;
 	if (ssp->smk_packet != NULL) {
 		rcp = ssp->smk_packet->smk_known;
 		slen = strlen(rcp) + 1;
@@ -4145,7 +4137,7 @@  static int smack_socket_getpeersec_dgram(struct socket *sock,
 
 	switch (family) {
 	case PF_UNIX:
-		ssp = sock->sk->sk_security;
+		ssp = smack_sock(sock->sk);
 		s = ssp->smk_out->smk_secid;
 		break;
 	case PF_INET:
@@ -4194,7 +4186,7 @@  static void smack_sock_graft(struct sock *sk, struct socket *parent)
 	    (sk->sk_family != PF_INET && sk->sk_family != PF_INET6))
 		return;
 
-	ssp = sk->sk_security;
+	ssp = smack_sock(sk);
 	ssp->smk_in = skp;
 	ssp->smk_out = skp;
 	/* cssp->smk_packet is already set in smack_inet_csk_clone() */
@@ -4214,7 +4206,7 @@  static int smack_inet_conn_request(const struct sock *sk, struct sk_buff *skb,
 {
 	u16 family = sk->sk_family;
 	struct smack_known *skp;
-	struct socket_smack *ssp = sk->sk_security;
+	struct socket_smack *ssp = smack_sock(sk);
 	struct sockaddr_in addr;
 	struct iphdr *hdr;
 	struct smack_known *hskp;
@@ -4300,7 +4292,7 @@  static int smack_inet_conn_request(const struct sock *sk, struct sk_buff *skb,
 static void smack_inet_csk_clone(struct sock *sk,
 				 const struct request_sock *req)
 {
-	struct socket_smack *ssp = sk->sk_security;
+	struct socket_smack *ssp = smack_sock(sk);
 	struct smack_known *skp;
 
 	if (req->peer_secid != 0) {
@@ -4868,6 +4860,7 @@  struct lsm_blob_sizes smack_blob_sizes __ro_after_init = {
 	.lbs_inode = sizeof(struct inode_smack),
 	.lbs_ipc = sizeof(struct smack_known *),
 	.lbs_msg_msg = sizeof(struct smack_known *),
+	.lbs_sock = sizeof(struct socket_smack),
 	.lbs_superblock = sizeof(struct superblock_smack),
 };
 
diff --git a/security/smack/smack_netfilter.c b/security/smack/smack_netfilter.c
index b945c1d3a743..bad71b7e648d 100644
--- a/security/smack/smack_netfilter.c
+++ b/security/smack/smack_netfilter.c
@@ -26,8 +26,8 @@  static unsigned int smack_ip_output(void *priv,
 	struct socket_smack *ssp;
 	struct smack_known *skp;
 
-	if (sk && sk->sk_security) {
-		ssp = sk->sk_security;
+	if (sk) {
+		ssp = smack_sock(sk);
 		skp = ssp->smk_out;
 		skb->secmark = skp->smk_secid;
 	}