diff mbox series

[v2,net-next,2/4] af_unix: Return struct unix_sock from unix_get_socket().

Message ID 20231123014747.66063-3-kuniyu@amazon.com (mailing list archive)
State Changes Requested
Delegated to: Netdev Maintainers
Headers show
Series af_unix: Random improvements for GC. | expand

Checks

Context Check Description
netdev/series_format success Posting correctly formatted
netdev/codegen success Generated files up to date
netdev/tree_selection success Clearly marked for net-next
netdev/fixes_present success Fixes tag not required for -next series
netdev/header_inline success No static functions without inline keyword in header files
netdev/build_32bit success Errors and warnings before: 1282 this patch: 1282
netdev/cc_maintainers warning 3 maintainers not CCed: io-uring@vger.kernel.org asml.silence@gmail.com axboe@kernel.dk
netdev/build_clang success Errors and warnings before: 1155 this patch: 1155
netdev/verify_signedoff success Signed-off-by tag matches author and committer
netdev/deprecated_api success None detected
netdev/check_selftest success No net selftest shell script
netdev/verify_fixes success No Fixes tag
netdev/build_allmodconfig_warn success Errors and warnings before: 1319 this patch: 1319
netdev/checkpatch warning CHECK: Please use a blank line after function/struct/union/enum declarations WARNING: line length of 84 exceeds 80 columns WARNING: line length of 85 exceeds 80 columns WARNING: line length of 86 exceeds 80 columns
netdev/build_clang_rust success No Rust files in patch. Skipping build
netdev/kdoc success Errors and warnings before: 0 this patch: 0
netdev/source_inline success Was 0 now: 0

Commit Message

Kuniyuki Iwashima Nov. 23, 2023, 1:47 a.m. UTC
Currently, unix_get_socket() returns struct sock, but after calling
it, we always cast it to unix_sk().

Let's return struct unix_sock from unix_get_socket().

Signed-off-by: Kuniyuki Iwashima <kuniyu@amazon.com>
---
 include/linux/io_uring.h |  4 ++--
 include/net/af_unix.h    |  2 +-
 io_uring/io_uring.c      |  5 +++--
 net/unix/garbage.c       | 19 +++++++------------
 net/unix/scm.c           | 26 +++++++++++---------------
 5 files changed, 24 insertions(+), 32 deletions(-)

Comments

Paolo Abeni Nov. 27, 2023, 9:08 a.m. UTC | #1
On Wed, 2023-11-22 at 17:47 -0800, Kuniyuki Iwashima wrote:
> Currently, unix_get_socket() returns struct sock, but after calling
> it, we always cast it to unix_sk().
> 
> Let's return struct unix_sock from unix_get_socket().
> 
> Signed-off-by: Kuniyuki Iwashima <kuniyu@amazon.com>
> ---
>  include/linux/io_uring.h |  4 ++--
>  include/net/af_unix.h    |  2 +-
>  io_uring/io_uring.c      |  5 +++--
>  net/unix/garbage.c       | 19 +++++++------------
>  net/unix/scm.c           | 26 +++++++++++---------------
>  5 files changed, 24 insertions(+), 32 deletions(-)
> 
> diff --git a/include/linux/io_uring.h b/include/linux/io_uring.h
> index aefb73eeeebf..be16677f0e4c 100644
> --- a/include/linux/io_uring.h
> +++ b/include/linux/io_uring.h
> @@ -54,7 +54,7 @@ int io_uring_cmd_import_fixed(u64 ubuf, unsigned long len, int rw,
>  			      struct iov_iter *iter, void *ioucmd);
>  void io_uring_cmd_done(struct io_uring_cmd *cmd, ssize_t ret, ssize_t res2,
>  			unsigned issue_flags);
> -struct sock *io_uring_get_socket(struct file *file);
> +struct unix_sock *io_uring_get_socket(struct file *file);
>  void __io_uring_cancel(bool cancel_all);
>  void __io_uring_free(struct task_struct *tsk);
>  void io_uring_unreg_ringfd(void);
> @@ -111,7 +111,7 @@ static inline void io_uring_cmd_do_in_task_lazy(struct io_uring_cmd *ioucmd,
>  			void (*task_work_cb)(struct io_uring_cmd *, unsigned))
>  {
>  }
> -static inline struct sock *io_uring_get_socket(struct file *file)
> +static inline struct unix_sock *io_uring_get_socket(struct file *file)
>  {
>  	return NULL;
>  }
> diff --git a/include/net/af_unix.h b/include/net/af_unix.h
> index 5a8a670b1920..c628d30ceb19 100644
> --- a/include/net/af_unix.h
> +++ b/include/net/af_unix.h
> @@ -14,7 +14,7 @@ void unix_destruct_scm(struct sk_buff *skb);
>  void io_uring_destruct_scm(struct sk_buff *skb);
>  void unix_gc(void);
>  void wait_for_unix_gc(void);
> -struct sock *unix_get_socket(struct file *filp);
> +struct unix_sock *unix_get_socket(struct file *filp);
>  struct sock *unix_peer_get(struct sock *sk);
>  
>  #define UNIX_HASH_MOD	(256 - 1)
> diff --git a/io_uring/io_uring.c b/io_uring/io_uring.c
> index ed254076c723..daed897f5975 100644
> --- a/io_uring/io_uring.c
> +++ b/io_uring/io_uring.c
> @@ -177,13 +177,14 @@ static struct ctl_table kernel_io_uring_disabled_table[] = {
>  };
>  #endif
>  
> -struct sock *io_uring_get_socket(struct file *file)
> +struct unix_sock *io_uring_get_socket(struct file *file)
>  {
>  #if defined(CONFIG_UNIX)
>  	if (io_is_uring_fops(file)) {
>  		struct io_ring_ctx *ctx = file->private_data;
>  
> -		return ctx->ring_sock->sk;
> +		if (ctx->ring_sock->sk)
> +			return unix_sk(ctx->ring_sock->sk);
>  	}
>  #endif
>  	return NULL;
> diff --git a/net/unix/garbage.c b/net/unix/garbage.c
> index db1bb99bb793..4d634f5f6a55 100644
> --- a/net/unix/garbage.c
> +++ b/net/unix/garbage.c
> @@ -105,20 +105,15 @@ static void scan_inflight(struct sock *x, void (*func)(struct unix_sock *),
>  
>  			while (nfd--) {
>  				/* Get the socket the fd matches if it indeed does so */
> -				struct sock *sk = unix_get_socket(*fp++);
> +				struct unix_sock *u = unix_get_socket(*fp++);
>  
> -				if (sk) {
> -					struct unix_sock *u = unix_sk(sk);
> +				/* Ignore non-candidates, they could have been added
> +				 * to the queues after starting the garbage collection
> +				 */
> +				if (u && test_bit(UNIX_GC_CANDIDATE, &u->gc_flags)) {
> +					hit = true;
>  
> -					/* Ignore non-candidates, they could
> -					 * have been added to the queues after
> -					 * starting the garbage collection
> -					 */
> -					if (test_bit(UNIX_GC_CANDIDATE, &u->gc_flags)) {
> -						hit = true;
> -
> -						func(u);
> -					}
> +					func(u);
>  				}
>  			}
>  			if (hit && hitlist != NULL) {
> diff --git a/net/unix/scm.c b/net/unix/scm.c
> index 4b3979272a81..36ce8fed9acc 100644
> --- a/net/unix/scm.c
> +++ b/net/unix/scm.c
> @@ -21,9 +21,8 @@ EXPORT_SYMBOL(gc_inflight_list);
>  DEFINE_SPINLOCK(unix_gc_lock);
>  EXPORT_SYMBOL(unix_gc_lock);
>  
> -struct sock *unix_get_socket(struct file *filp)
> +struct unix_sock *unix_get_socket(struct file *filp)
>  {
> -	struct sock *u_sock = NULL;
>  	struct inode *inode = file_inode(filp);
>  
>  	/* Socket ? */
> @@ -34,12 +33,13 @@ struct sock *unix_get_socket(struct file *filp)
>  
>  		/* PF_UNIX ? */
>  		if (s && ops && ops->family == PF_UNIX)
> -			u_sock = s;
> -	} else {
> -		/* Could be an io_uring instance */
> -		u_sock = io_uring_get_socket(filp);
> +			return unix_sk(s);
> +
> +		return NULL;
>  	}
> -	return u_sock;
> +
> +	/* Could be an io_uring instance */
> +	return io_uring_get_socket(filp);
>  }
>  EXPORT_SYMBOL(unix_get_socket);
>  
> @@ -48,13 +48,11 @@ EXPORT_SYMBOL(unix_get_socket);
>   */
>  void unix_inflight(struct user_struct *user, struct file *fp)
>  {
> -	struct sock *s = unix_get_socket(fp);
> +	struct unix_sock *u = unix_get_socket(fp);
>  
>  	spin_lock(&unix_gc_lock);
>  
> -	if (s) {
> -		struct unix_sock *u = unix_sk(s);
> -
> +	if (u) {
>  		if (!u->inflight) {
>  			BUG_ON(!list_empty(&u->link));
>  			list_add_tail(&u->link, &gc_inflight_list);
> @@ -71,13 +69,11 @@ void unix_inflight(struct user_struct *user, struct file *fp)
>  
>  void unix_notinflight(struct user_struct *user, struct file *fp)
>  {
> -	struct sock *s = unix_get_socket(fp);
> +	struct unix_sock *u = unix_get_socket(fp);
>  
>  	spin_lock(&unix_gc_lock);
>  
> -	if (s) {
> -		struct unix_sock *u = unix_sk(s);
> -
> +	if (u) {
>  		BUG_ON(!u->inflight);
>  		BUG_ON(list_empty(&u->link));
>  

Adding the io_uring peoples to the recipient list for awareness. I
guess this deserves an explicit ack from them.

Cheers,

Paolo
Pavel Begunkov Nov. 27, 2023, 12:33 p.m. UTC | #2
On 11/27/23 09:08, Paolo Abeni wrote:
> On Wed, 2023-11-22 at 17:47 -0800, Kuniyuki Iwashima wrote:
>> Currently, unix_get_socket() returns struct sock, but after calling
>> it, we always cast it to unix_sk().
>>
>> Let's return struct unix_sock from unix_get_socket().
>>
>> Signed-off-by: Kuniyuki Iwashima <kuniyu@amazon.com>
>> ---
>>   include/linux/io_uring.h |  4 ++--
>>   include/net/af_unix.h    |  2 +-
>>   io_uring/io_uring.c      |  5 +++--
>>   net/unix/garbage.c       | 19 +++++++------------
>>   net/unix/scm.c           | 26 +++++++++++---------------
>>   5 files changed, 24 insertions(+), 32 deletions(-)
>>
>> diff --git a/include/linux/io_uring.h b/include/linux/io_uring.h
>> index aefb73eeeebf..be16677f0e4c 100644
>> --- a/include/linux/io_uring.h
>> +++ b/include/linux/io_uring.h
>> @@ -54,7 +54,7 @@ int io_uring_cmd_import_fixed(u64 ubuf, unsigned long len, int rw,
>>   			      struct iov_iter *iter, void *ioucmd);
>>   void io_uring_cmd_done(struct io_uring_cmd *cmd, ssize_t ret, ssize_t res2,
>>   			unsigned issue_flags);
>> -struct sock *io_uring_get_socket(struct file *file);
>> +struct unix_sock *io_uring_get_socket(struct file *file);
>>   void __io_uring_cancel(bool cancel_all);
>>   void __io_uring_free(struct task_struct *tsk);
>>   void io_uring_unreg_ringfd(void);
>> @@ -111,7 +111,7 @@ static inline void io_uring_cmd_do_in_task_lazy(struct io_uring_cmd *ioucmd,
>>   			void (*task_work_cb)(struct io_uring_cmd *, unsigned))
>>   {
>>   }
>> -static inline struct sock *io_uring_get_socket(struct file *file)
>> +static inline struct unix_sock *io_uring_get_socket(struct file *file)
>>   {
>>   	return NULL;
>>   }
>> diff --git a/include/net/af_unix.h b/include/net/af_unix.h
>> index 5a8a670b1920..c628d30ceb19 100644
>> --- a/include/net/af_unix.h
>> +++ b/include/net/af_unix.h
>> @@ -14,7 +14,7 @@ void unix_destruct_scm(struct sk_buff *skb);
>>   void io_uring_destruct_scm(struct sk_buff *skb);
>>   void unix_gc(void);
>>   void wait_for_unix_gc(void);
>> -struct sock *unix_get_socket(struct file *filp);
>> +struct unix_sock *unix_get_socket(struct file *filp);
>>   struct sock *unix_peer_get(struct sock *sk);
>>   
>>   #define UNIX_HASH_MOD	(256 - 1)
>> diff --git a/io_uring/io_uring.c b/io_uring/io_uring.c
>> index ed254076c723..daed897f5975 100644
>> --- a/io_uring/io_uring.c
>> +++ b/io_uring/io_uring.c
>> @@ -177,13 +177,14 @@ static struct ctl_table kernel_io_uring_disabled_table[] = {
>>   };
>>   #endif
>>   
>> -struct sock *io_uring_get_socket(struct file *file)
>> +struct unix_sock *io_uring_get_socket(struct file *file)
>>   {
>>   #if defined(CONFIG_UNIX)
>>   	if (io_is_uring_fops(file)) {
>>   		struct io_ring_ctx *ctx = file->private_data;
>>   
>> -		return ctx->ring_sock->sk;
>> +		if (ctx->ring_sock->sk)
>> +			return unix_sk(ctx->ring_sock->sk);
>>   	}
>>   #endif
>>   	return NULL;
>> diff --git a/net/unix/garbage.c b/net/unix/garbage.c
>> index db1bb99bb793..4d634f5f6a55 100644
>> --- a/net/unix/garbage.c
>> +++ b/net/unix/garbage.c
>> @@ -105,20 +105,15 @@ static void scan_inflight(struct sock *x, void (*func)(struct unix_sock *),
>>   
>>   			while (nfd--) {
>>   				/* Get the socket the fd matches if it indeed does so */
>> -				struct sock *sk = unix_get_socket(*fp++);
>> +				struct unix_sock *u = unix_get_socket(*fp++);
>>   
>> -				if (sk) {
>> -					struct unix_sock *u = unix_sk(sk);
>> +				/* Ignore non-candidates, they could have been added
>> +				 * to the queues after starting the garbage collection
>> +				 */
>> +				if (u && test_bit(UNIX_GC_CANDIDATE, &u->gc_flags)) {
>> +					hit = true;
>>   
>> -					/* Ignore non-candidates, they could
>> -					 * have been added to the queues after
>> -					 * starting the garbage collection
>> -					 */
>> -					if (test_bit(UNIX_GC_CANDIDATE, &u->gc_flags)) {
>> -						hit = true;
>> -
>> -						func(u);
>> -					}
>> +					func(u);
>>   				}
>>   			}
>>   			if (hit && hitlist != NULL) {
>> diff --git a/net/unix/scm.c b/net/unix/scm.c
>> index 4b3979272a81..36ce8fed9acc 100644
>> --- a/net/unix/scm.c
>> +++ b/net/unix/scm.c
>> @@ -21,9 +21,8 @@ EXPORT_SYMBOL(gc_inflight_list);
>>   DEFINE_SPINLOCK(unix_gc_lock);
>>   EXPORT_SYMBOL(unix_gc_lock);
>>   
>> -struct sock *unix_get_socket(struct file *filp)
>> +struct unix_sock *unix_get_socket(struct file *filp)
>>   {
>> -	struct sock *u_sock = NULL;
>>   	struct inode *inode = file_inode(filp);
>>   
>>   	/* Socket ? */
>> @@ -34,12 +33,13 @@ struct sock *unix_get_socket(struct file *filp)
>>   
>>   		/* PF_UNIX ? */
>>   		if (s && ops && ops->family == PF_UNIX)
>> -			u_sock = s;
>> -	} else {
>> -		/* Could be an io_uring instance */
>> -		u_sock = io_uring_get_socket(filp);
>> +			return unix_sk(s);
>> +
>> +		return NULL;
>>   	}
>> -	return u_sock;
>> +
>> +	/* Could be an io_uring instance */
>> +	return io_uring_get_socket(filp);
>>   }
>>   EXPORT_SYMBOL(unix_get_socket);
>>   
>> @@ -48,13 +48,11 @@ EXPORT_SYMBOL(unix_get_socket);
>>    */
>>   void unix_inflight(struct user_struct *user, struct file *fp)
>>   {
>> -	struct sock *s = unix_get_socket(fp);
>> +	struct unix_sock *u = unix_get_socket(fp);
>>   
>>   	spin_lock(&unix_gc_lock);
>>   
>> -	if (s) {
>> -		struct unix_sock *u = unix_sk(s);
>> -
>> +	if (u) {
>>   		if (!u->inflight) {
>>   			BUG_ON(!list_empty(&u->link));
>>   			list_add_tail(&u->link, &gc_inflight_list);
>> @@ -71,13 +69,11 @@ void unix_inflight(struct user_struct *user, struct file *fp)
>>   
>>   void unix_notinflight(struct user_struct *user, struct file *fp)
>>   {
>> -	struct sock *s = unix_get_socket(fp);
>> +	struct unix_sock *u = unix_get_socket(fp);
>>   
>>   	spin_lock(&unix_gc_lock);
>>   
>> -	if (s) {
>> -		struct unix_sock *u = unix_sk(s);
>> -
>> +	if (u) {
>>   		BUG_ON(!u->inflight);
>>   		BUG_ON(list_empty(&u->link));
>>   
> 
> Adding the io_uring peoples to the recipient list for awareness. I
> guess this deserves an explicit ack from them.

Thanks Paolo, lgtm

Acked-by: Pavel Begunkov <asml.silence@gmail.com>
Simon Horman Dec. 1, 2023, 9:35 a.m. UTC | #3
On Wed, Nov 22, 2023 at 05:47:45PM -0800, Kuniyuki Iwashima wrote:
> Currently, unix_get_socket() returns struct sock, but after calling
> it, we always cast it to unix_sk().
> 
> Let's return struct unix_sock from unix_get_socket().
> 
> Signed-off-by: Kuniyuki Iwashima <kuniyu@amazon.com>

Thanks Iwashima-san,

this looks like a nice clean-up to me.

Reviewed-by: Simon Horman <horms@kernel.org>
diff mbox series

Patch

diff --git a/include/linux/io_uring.h b/include/linux/io_uring.h
index aefb73eeeebf..be16677f0e4c 100644
--- a/include/linux/io_uring.h
+++ b/include/linux/io_uring.h
@@ -54,7 +54,7 @@  int io_uring_cmd_import_fixed(u64 ubuf, unsigned long len, int rw,
 			      struct iov_iter *iter, void *ioucmd);
 void io_uring_cmd_done(struct io_uring_cmd *cmd, ssize_t ret, ssize_t res2,
 			unsigned issue_flags);
-struct sock *io_uring_get_socket(struct file *file);
+struct unix_sock *io_uring_get_socket(struct file *file);
 void __io_uring_cancel(bool cancel_all);
 void __io_uring_free(struct task_struct *tsk);
 void io_uring_unreg_ringfd(void);
@@ -111,7 +111,7 @@  static inline void io_uring_cmd_do_in_task_lazy(struct io_uring_cmd *ioucmd,
 			void (*task_work_cb)(struct io_uring_cmd *, unsigned))
 {
 }
-static inline struct sock *io_uring_get_socket(struct file *file)
+static inline struct unix_sock *io_uring_get_socket(struct file *file)
 {
 	return NULL;
 }
diff --git a/include/net/af_unix.h b/include/net/af_unix.h
index 5a8a670b1920..c628d30ceb19 100644
--- a/include/net/af_unix.h
+++ b/include/net/af_unix.h
@@ -14,7 +14,7 @@  void unix_destruct_scm(struct sk_buff *skb);
 void io_uring_destruct_scm(struct sk_buff *skb);
 void unix_gc(void);
 void wait_for_unix_gc(void);
-struct sock *unix_get_socket(struct file *filp);
+struct unix_sock *unix_get_socket(struct file *filp);
 struct sock *unix_peer_get(struct sock *sk);
 
 #define UNIX_HASH_MOD	(256 - 1)
diff --git a/io_uring/io_uring.c b/io_uring/io_uring.c
index ed254076c723..daed897f5975 100644
--- a/io_uring/io_uring.c
+++ b/io_uring/io_uring.c
@@ -177,13 +177,14 @@  static struct ctl_table kernel_io_uring_disabled_table[] = {
 };
 #endif
 
-struct sock *io_uring_get_socket(struct file *file)
+struct unix_sock *io_uring_get_socket(struct file *file)
 {
 #if defined(CONFIG_UNIX)
 	if (io_is_uring_fops(file)) {
 		struct io_ring_ctx *ctx = file->private_data;
 
-		return ctx->ring_sock->sk;
+		if (ctx->ring_sock->sk)
+			return unix_sk(ctx->ring_sock->sk);
 	}
 #endif
 	return NULL;
diff --git a/net/unix/garbage.c b/net/unix/garbage.c
index db1bb99bb793..4d634f5f6a55 100644
--- a/net/unix/garbage.c
+++ b/net/unix/garbage.c
@@ -105,20 +105,15 @@  static void scan_inflight(struct sock *x, void (*func)(struct unix_sock *),
 
 			while (nfd--) {
 				/* Get the socket the fd matches if it indeed does so */
-				struct sock *sk = unix_get_socket(*fp++);
+				struct unix_sock *u = unix_get_socket(*fp++);
 
-				if (sk) {
-					struct unix_sock *u = unix_sk(sk);
+				/* Ignore non-candidates, they could have been added
+				 * to the queues after starting the garbage collection
+				 */
+				if (u && test_bit(UNIX_GC_CANDIDATE, &u->gc_flags)) {
+					hit = true;
 
-					/* Ignore non-candidates, they could
-					 * have been added to the queues after
-					 * starting the garbage collection
-					 */
-					if (test_bit(UNIX_GC_CANDIDATE, &u->gc_flags)) {
-						hit = true;
-
-						func(u);
-					}
+					func(u);
 				}
 			}
 			if (hit && hitlist != NULL) {
diff --git a/net/unix/scm.c b/net/unix/scm.c
index 4b3979272a81..36ce8fed9acc 100644
--- a/net/unix/scm.c
+++ b/net/unix/scm.c
@@ -21,9 +21,8 @@  EXPORT_SYMBOL(gc_inflight_list);
 DEFINE_SPINLOCK(unix_gc_lock);
 EXPORT_SYMBOL(unix_gc_lock);
 
-struct sock *unix_get_socket(struct file *filp)
+struct unix_sock *unix_get_socket(struct file *filp)
 {
-	struct sock *u_sock = NULL;
 	struct inode *inode = file_inode(filp);
 
 	/* Socket ? */
@@ -34,12 +33,13 @@  struct sock *unix_get_socket(struct file *filp)
 
 		/* PF_UNIX ? */
 		if (s && ops && ops->family == PF_UNIX)
-			u_sock = s;
-	} else {
-		/* Could be an io_uring instance */
-		u_sock = io_uring_get_socket(filp);
+			return unix_sk(s);
+
+		return NULL;
 	}
-	return u_sock;
+
+	/* Could be an io_uring instance */
+	return io_uring_get_socket(filp);
 }
 EXPORT_SYMBOL(unix_get_socket);
 
@@ -48,13 +48,11 @@  EXPORT_SYMBOL(unix_get_socket);
  */
 void unix_inflight(struct user_struct *user, struct file *fp)
 {
-	struct sock *s = unix_get_socket(fp);
+	struct unix_sock *u = unix_get_socket(fp);
 
 	spin_lock(&unix_gc_lock);
 
-	if (s) {
-		struct unix_sock *u = unix_sk(s);
-
+	if (u) {
 		if (!u->inflight) {
 			BUG_ON(!list_empty(&u->link));
 			list_add_tail(&u->link, &gc_inflight_list);
@@ -71,13 +69,11 @@  void unix_inflight(struct user_struct *user, struct file *fp)
 
 void unix_notinflight(struct user_struct *user, struct file *fp)
 {
-	struct sock *s = unix_get_socket(fp);
+	struct unix_sock *u = unix_get_socket(fp);
 
 	spin_lock(&unix_gc_lock);
 
-	if (s) {
-		struct unix_sock *u = unix_sk(s);
-
+	if (u) {
 		BUG_ON(!u->inflight);
 		BUG_ON(list_empty(&u->link));