diff mbox series

[RFC,net-next,2/2] devlink: Acquire device lock during reload

Message ID 20230619125015.1541143-3-idosch@nvidia.com (mailing list archive)
State RFC
Delegated to: Netdev Maintainers
Headers show
Series devlink: Acquire device lock during reload | expand

Checks

Context Check Description
netdev/series_format success Posting correctly formatted
netdev/tree_selection success Clearly marked for net-next, async
netdev/fixes_present success Fixes tag not required for -next series
netdev/header_inline success No static functions without inline keyword in header files
netdev/build_32bit success Errors and warnings before: 10 this patch: 10
netdev/cc_maintainers success CCed 6 of 6 maintainers
netdev/build_clang success Errors and warnings before: 8 this patch: 8
netdev/verify_signedoff success Signed-off-by tag matches author and committer
netdev/deprecated_api success None detected
netdev/check_selftest success No net selftest shell script
netdev/verify_fixes success No Fixes tag
netdev/build_allmodconfig_warn success Errors and warnings before: 10 this patch: 10
netdev/checkpatch success total: 0 errors, 0 warnings, 0 checks, 160 lines checked
netdev/kdoc success Errors and warnings before: 0 this patch: 0
netdev/source_inline success Was 0 now: 0

Commit Message

Ido Schimmel June 19, 2023, 12:50 p.m. UTC
Device drivers register with devlink from their probe routines (under
the device lock) by acquiring the devlink instance lock and calling
devl_register().

Drivers that support a devlink reload usually implement the
reload_{down, up}() operations in a similar fashion to their remove and
probe routines, respectively.

However, while the remove and probe routines are invoked with the device
lock held, the reload operations are only invoked with the devlink
instance lock held. It is therefore impossible for drivers to acquire
the device lock from their reload operations, as this would result in
lock inversion.

The motivating use case for invoking the reload operations with the
device lock held is in mlxsw which needs to trigger a PCI reset as part
of the reload. The driver cannot call pci_reset_function() as this
function acquires the device lock. Instead, it needs to call
__pci_reset_function_locked which expects the device lock to be held.

To that end, adjust devlink to always acquire the device lock before the
devlink instance lock when performing a reload. Do that both when reload
is triggered explicitly by user space and when it is triggered as part
of netns dismantle.

Tested the following flows with netdevsim and mlxsw while lockdep is
enabled:

netdevsim:

 # echo "10 1" > /sys/bus/netdevsim/new_device
 # devlink dev reload netdevsim/netdevsim10
 # ip netns add bla
 # devlink dev reload netdevsim/netdevsim10 netns bla
 # ip netns del bla
 # echo 10 > /sys/bus/netdevsim/del_device

mlxsw:

 # devlink dev reload pci/0000:01:00.0
 # ip netns add bla
 # devlink dev reload pci/0000:01:00.0 netns bla
 # ip netns del bla
 # echo 1 > /sys/bus/pci/devices/0000\:01\:00.0/remove
 # echo 1 > /sys/bus/pci/rescan

Reviewed-by: Jiri Pirko <jiri@nvidia.com>
Signed-off-by: Ido Schimmel <idosch@nvidia.com>
---
 net/devlink/core.c          |  4 ++--
 net/devlink/dev.c           |  8 ++++++++
 net/devlink/devl_internal.h | 19 ++++++++++++++++++-
 net/devlink/health.c        |  3 ++-
 net/devlink/leftover.c      |  4 +++-
 net/devlink/netlink.c       | 18 ++++++++++++------
 6 files changed, 45 insertions(+), 11 deletions(-)
diff mbox series

Patch

diff --git a/net/devlink/core.c b/net/devlink/core.c
index b191112f57af..a4b6d548e50c 100644
--- a/net/devlink/core.c
+++ b/net/devlink/core.c
@@ -279,14 +279,14 @@  static void __net_exit devlink_pernet_pre_exit(struct net *net)
 	 * all devlink instances from this namespace into init_net.
 	 */
 	devlinks_xa_for_each_registered_get(net, index, devlink) {
-		devl_lock(devlink);
+		devl_dev_lock(devlink, true);
 		err = 0;
 		if (devl_is_registered(devlink))
 			err = devlink_reload(devlink, &init_net,
 					     DEVLINK_RELOAD_ACTION_DRIVER_REINIT,
 					     DEVLINK_RELOAD_LIMIT_UNSPEC,
 					     &actions_performed, NULL);
-		devl_unlock(devlink);
+		devl_dev_unlock(devlink, true);
 		devlink_put(devlink);
 		if (err && err != -EOPNOTSUPP)
 			pr_warn("Failed to reload devlink instance into init_net\n");
diff --git a/net/devlink/dev.c b/net/devlink/dev.c
index bf1d6f1bcfc7..daee2039fb58 100644
--- a/net/devlink/dev.c
+++ b/net/devlink/dev.c
@@ -4,6 +4,7 @@ 
  * Copyright (c) 2016 Jiri Pirko <jiri@mellanox.com>
  */
 
+#include <linux/device.h>
 #include <net/genetlink.h>
 #include <net/sock.h>
 #include "devl_internal.h"
@@ -356,6 +357,13 @@  int devlink_reload(struct devlink *devlink, struct net *dest_net,
 	struct net *curr_net;
 	int err;
 
+	/* Make sure the reload operations are invoked with the device lock
+	 * held to allow drivers to trigger functionality that expects it
+	 * (e.g., PCI reset) and to close possible races between these
+	 * operations and probe/remove.
+	 */
+	device_lock_assert(devlink->dev);
+
 	memcpy(remote_reload_stats, devlink->stats.remote_reload_stats,
 	       sizeof(remote_reload_stats));
 
diff --git a/net/devlink/devl_internal.h b/net/devlink/devl_internal.h
index 62921b2eb0d3..99c3efbae718 100644
--- a/net/devlink/devl_internal.h
+++ b/net/devlink/devl_internal.h
@@ -3,6 +3,7 @@ 
  * Copyright (c) 2016 Jiri Pirko <jiri@mellanox.com>
  */
 
+#include <linux/device.h>
 #include <linux/mutex.h>
 #include <linux/netdevice.h>
 #include <linux/notifier.h>
@@ -87,12 +88,27 @@  static inline bool devl_is_registered(struct devlink *devlink)
 	return xa_get_mark(&devlinks, devlink->index, DEVLINK_REGISTERED);
 }
 
+static inline void devl_dev_lock(struct devlink *devlink, bool dev_lock)
+{
+	if (dev_lock)
+		device_lock(devlink->dev);
+	devl_lock(devlink);
+}
+
+static inline void devl_dev_unlock(struct devlink *devlink, bool dev_lock)
+{
+	devl_unlock(devlink);
+	if (dev_lock)
+		device_unlock(devlink->dev);
+}
+
 /* Netlink */
 #define DEVLINK_NL_FLAG_NEED_PORT		BIT(0)
 #define DEVLINK_NL_FLAG_NEED_DEVLINK_OR_PORT	BIT(1)
 #define DEVLINK_NL_FLAG_NEED_RATE		BIT(2)
 #define DEVLINK_NL_FLAG_NEED_RATE_NODE		BIT(3)
 #define DEVLINK_NL_FLAG_NEED_LINECARD		BIT(4)
+#define DEVLINK_NL_FLAG_NEED_DEV_LOCK		BIT(5)
 
 enum devlink_multicast_groups {
 	DEVLINK_MCGRP_CONFIG,
@@ -122,7 +138,8 @@  struct devlink_cmd {
 extern const struct genl_small_ops devlink_nl_ops[56];
 
 struct devlink *
-devlink_get_from_attrs_lock(struct net *net, struct nlattr **attrs);
+devlink_get_from_attrs_lock(struct net *net, struct nlattr **attrs,
+			    bool dev_lock);
 
 void devlink_notify_unregister(struct devlink *devlink);
 void devlink_notify_register(struct devlink *devlink);
diff --git a/net/devlink/health.c b/net/devlink/health.c
index 194340a8bb86..fa8ccdcffb7a 100644
--- a/net/devlink/health.c
+++ b/net/devlink/health.c
@@ -1253,7 +1253,8 @@  devlink_health_reporter_get_from_cb(struct netlink_callback *cb)
 	struct nlattr **attrs = info->attrs;
 	struct devlink *devlink;
 
-	devlink = devlink_get_from_attrs_lock(sock_net(cb->skb->sk), attrs);
+	devlink = devlink_get_from_attrs_lock(sock_net(cb->skb->sk), attrs,
+					      false);
 	if (IS_ERR(devlink))
 		return NULL;
 	devl_unlock(devlink);
diff --git a/net/devlink/leftover.c b/net/devlink/leftover.c
index 1f00f874471f..f4e6030e3b56 100644
--- a/net/devlink/leftover.c
+++ b/net/devlink/leftover.c
@@ -5185,7 +5185,8 @@  static int devlink_nl_cmd_region_read_dumpit(struct sk_buff *skb,
 
 	start_offset = state->start_offset;
 
-	devlink = devlink_get_from_attrs_lock(sock_net(cb->skb->sk), attrs);
+	devlink = devlink_get_from_attrs_lock(sock_net(cb->skb->sk), attrs,
+					      false);
 	if (IS_ERR(devlink))
 		return PTR_ERR(devlink);
 
@@ -6478,6 +6479,7 @@  const struct genl_small_ops devlink_nl_ops[56] = {
 		.validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP,
 		.doit = devlink_nl_cmd_reload,
 		.flags = GENL_ADMIN_PERM,
+		.internal_flags = DEVLINK_NL_FLAG_NEED_DEV_LOCK,
 	},
 	{
 		.cmd = DEVLINK_CMD_PARAM_GET,
diff --git a/net/devlink/netlink.c b/net/devlink/netlink.c
index 7a332eb70f70..95fd8e3befea 100644
--- a/net/devlink/netlink.c
+++ b/net/devlink/netlink.c
@@ -83,7 +83,8 @@  static const struct nla_policy devlink_nl_policy[DEVLINK_ATTR_MAX + 1] = {
 };
 
 struct devlink *
-devlink_get_from_attrs_lock(struct net *net, struct nlattr **attrs)
+devlink_get_from_attrs_lock(struct net *net, struct nlattr **attrs,
+			    bool dev_lock)
 {
 	struct devlink *devlink;
 	unsigned long index;
@@ -97,12 +98,12 @@  devlink_get_from_attrs_lock(struct net *net, struct nlattr **attrs)
 	devname = nla_data(attrs[DEVLINK_ATTR_DEV_NAME]);
 
 	devlinks_xa_for_each_registered_get(net, index, devlink) {
-		devl_lock(devlink);
+		devl_dev_lock(devlink, dev_lock);
 		if (devl_is_registered(devlink) &&
 		    strcmp(devlink->dev->bus->name, busname) == 0 &&
 		    strcmp(dev_name(devlink->dev), devname) == 0)
 			return devlink;
-		devl_unlock(devlink);
+		devl_dev_unlock(devlink, dev_lock);
 		devlink_put(devlink);
 	}
 
@@ -115,9 +116,12 @@  static int devlink_nl_pre_doit(const struct genl_split_ops *ops,
 	struct devlink_linecard *linecard;
 	struct devlink_port *devlink_port;
 	struct devlink *devlink;
+	bool dev_lock;
 	int err;
 
-	devlink = devlink_get_from_attrs_lock(genl_info_net(info), info->attrs);
+	dev_lock = !!(ops->internal_flags & DEVLINK_NL_FLAG_NEED_DEV_LOCK);
+	devlink = devlink_get_from_attrs_lock(genl_info_net(info), info->attrs,
+					      dev_lock);
 	if (IS_ERR(devlink))
 		return PTR_ERR(devlink);
 
@@ -162,7 +166,7 @@  static int devlink_nl_pre_doit(const struct genl_split_ops *ops,
 	return 0;
 
 unlock:
-	devl_unlock(devlink);
+	devl_dev_unlock(devlink, dev_lock);
 	devlink_put(devlink);
 	return err;
 }
@@ -171,9 +175,11 @@  static void devlink_nl_post_doit(const struct genl_split_ops *ops,
 				 struct sk_buff *skb, struct genl_info *info)
 {
 	struct devlink *devlink;
+	bool dev_lock;
 
+	dev_lock = !!(ops->internal_flags & DEVLINK_NL_FLAG_NEED_DEV_LOCK);
 	devlink = info->user_ptr[0];
-	devl_unlock(devlink);
+	devl_dev_unlock(devlink, dev_lock);
 	devlink_put(devlink);
 }