@@ -279,14 +279,14 @@ static void __net_exit devlink_pernet_pre_exit(struct net *net)
* all devlink instances from this namespace into init_net.
*/
devlinks_xa_for_each_registered_get(net, index, devlink) {
- devl_lock(devlink);
+ devl_dev_lock(devlink, true);
err = 0;
if (devl_is_registered(devlink))
err = devlink_reload(devlink, &init_net,
DEVLINK_RELOAD_ACTION_DRIVER_REINIT,
DEVLINK_RELOAD_LIMIT_UNSPEC,
&actions_performed, NULL);
- devl_unlock(devlink);
+ devl_dev_unlock(devlink, true);
devlink_put(devlink);
if (err && err != -EOPNOTSUPP)
pr_warn("Failed to reload devlink instance into init_net\n");
@@ -4,6 +4,7 @@
* Copyright (c) 2016 Jiri Pirko <jiri@mellanox.com>
*/
+#include <linux/device.h>
#include <net/genetlink.h>
#include <net/sock.h>
#include "devl_internal.h"
@@ -356,6 +357,13 @@ int devlink_reload(struct devlink *devlink, struct net *dest_net,
struct net *curr_net;
int err;
+ /* Make sure the reload operations are invoked with the device lock
+ * held to allow drivers to trigger functionality that expects it
+ * (e.g., PCI reset) and to close possible races between these
+ * operations and probe/remove.
+ */
+ device_lock_assert(devlink->dev);
+
memcpy(remote_reload_stats, devlink->stats.remote_reload_stats,
sizeof(remote_reload_stats));
@@ -3,6 +3,7 @@
* Copyright (c) 2016 Jiri Pirko <jiri@mellanox.com>
*/
+#include <linux/device.h>
#include <linux/mutex.h>
#include <linux/netdevice.h>
#include <linux/notifier.h>
@@ -87,12 +88,27 @@ static inline bool devl_is_registered(struct devlink *devlink)
return xa_get_mark(&devlinks, devlink->index, DEVLINK_REGISTERED);
}
+static inline void devl_dev_lock(struct devlink *devlink, bool dev_lock)
+{
+ if (dev_lock)
+ device_lock(devlink->dev);
+ devl_lock(devlink);
+}
+
+static inline void devl_dev_unlock(struct devlink *devlink, bool dev_lock)
+{
+ devl_unlock(devlink);
+ if (dev_lock)
+ device_unlock(devlink->dev);
+}
+
/* Netlink */
#define DEVLINK_NL_FLAG_NEED_PORT BIT(0)
#define DEVLINK_NL_FLAG_NEED_DEVLINK_OR_PORT BIT(1)
#define DEVLINK_NL_FLAG_NEED_RATE BIT(2)
#define DEVLINK_NL_FLAG_NEED_RATE_NODE BIT(3)
#define DEVLINK_NL_FLAG_NEED_LINECARD BIT(4)
+#define DEVLINK_NL_FLAG_NEED_DEV_LOCK BIT(5)
enum devlink_multicast_groups {
DEVLINK_MCGRP_CONFIG,
@@ -122,7 +138,8 @@ struct devlink_cmd {
extern const struct genl_small_ops devlink_nl_ops[56];
struct devlink *
-devlink_get_from_attrs_lock(struct net *net, struct nlattr **attrs);
+devlink_get_from_attrs_lock(struct net *net, struct nlattr **attrs,
+ bool dev_lock);
void devlink_notify_unregister(struct devlink *devlink);
void devlink_notify_register(struct devlink *devlink);
@@ -1253,7 +1253,8 @@ devlink_health_reporter_get_from_cb(struct netlink_callback *cb)
struct nlattr **attrs = info->attrs;
struct devlink *devlink;
- devlink = devlink_get_from_attrs_lock(sock_net(cb->skb->sk), attrs);
+ devlink = devlink_get_from_attrs_lock(sock_net(cb->skb->sk), attrs,
+ false);
if (IS_ERR(devlink))
return NULL;
devl_unlock(devlink);
@@ -5185,7 +5185,8 @@ static int devlink_nl_cmd_region_read_dumpit(struct sk_buff *skb,
start_offset = state->start_offset;
- devlink = devlink_get_from_attrs_lock(sock_net(cb->skb->sk), attrs);
+ devlink = devlink_get_from_attrs_lock(sock_net(cb->skb->sk), attrs,
+ false);
if (IS_ERR(devlink))
return PTR_ERR(devlink);
@@ -6478,6 +6479,7 @@ const struct genl_small_ops devlink_nl_ops[56] = {
.validate = GENL_DONT_VALIDATE_STRICT | GENL_DONT_VALIDATE_DUMP,
.doit = devlink_nl_cmd_reload,
.flags = GENL_ADMIN_PERM,
+ .internal_flags = DEVLINK_NL_FLAG_NEED_DEV_LOCK,
},
{
.cmd = DEVLINK_CMD_PARAM_GET,
@@ -83,7 +83,8 @@ static const struct nla_policy devlink_nl_policy[DEVLINK_ATTR_MAX + 1] = {
};
struct devlink *
-devlink_get_from_attrs_lock(struct net *net, struct nlattr **attrs)
+devlink_get_from_attrs_lock(struct net *net, struct nlattr **attrs,
+ bool dev_lock)
{
struct devlink *devlink;
unsigned long index;
@@ -97,12 +98,12 @@ devlink_get_from_attrs_lock(struct net *net, struct nlattr **attrs)
devname = nla_data(attrs[DEVLINK_ATTR_DEV_NAME]);
devlinks_xa_for_each_registered_get(net, index, devlink) {
- devl_lock(devlink);
+ devl_dev_lock(devlink, dev_lock);
if (devl_is_registered(devlink) &&
strcmp(devlink->dev->bus->name, busname) == 0 &&
strcmp(dev_name(devlink->dev), devname) == 0)
return devlink;
- devl_unlock(devlink);
+ devl_dev_unlock(devlink, dev_lock);
devlink_put(devlink);
}
@@ -115,9 +116,12 @@ static int devlink_nl_pre_doit(const struct genl_split_ops *ops,
struct devlink_linecard *linecard;
struct devlink_port *devlink_port;
struct devlink *devlink;
+ bool dev_lock;
int err;
- devlink = devlink_get_from_attrs_lock(genl_info_net(info), info->attrs);
+ dev_lock = !!(ops->internal_flags & DEVLINK_NL_FLAG_NEED_DEV_LOCK);
+ devlink = devlink_get_from_attrs_lock(genl_info_net(info), info->attrs,
+ dev_lock);
if (IS_ERR(devlink))
return PTR_ERR(devlink);
@@ -162,7 +166,7 @@ static int devlink_nl_pre_doit(const struct genl_split_ops *ops,
return 0;
unlock:
- devl_unlock(devlink);
+ devl_dev_unlock(devlink, dev_lock);
devlink_put(devlink);
return err;
}
@@ -171,9 +175,11 @@ static void devlink_nl_post_doit(const struct genl_split_ops *ops,
struct sk_buff *skb, struct genl_info *info)
{
struct devlink *devlink;
+ bool dev_lock;
+ dev_lock = !!(ops->internal_flags & DEVLINK_NL_FLAG_NEED_DEV_LOCK);
devlink = info->user_ptr[0];
- devl_unlock(devlink);
+ devl_dev_unlock(devlink, dev_lock);
devlink_put(devlink);
}