@@ -258,8 +258,67 @@ bool lockdep_rtnl_net_is_held(struct net *net)
return lockdep_rtnl_is_held() && lockdep_is_held(&net->rtnl_mutex);
}
EXPORT_SYMBOL(lockdep_rtnl_net_is_held);
+#else
+static int rtnl_net_cmp_locks(const struct net *net_a, const struct net *net_b)
+{
+ /* No need to swap */
+ return -1;
+}
#endif
+struct rtnl_nets {
+ /* ->newlink() needs to freeze 3 netns at most;
+ * 2 for the new device, 1 for its peer.
+ */
+ struct net *net[3];
+ unsigned char len;
+};
+
+static void rtnl_nets_init(struct rtnl_nets *rtnl_nets)
+{
+ memset(rtnl_nets, 0, sizeof(*rtnl_nets));
+}
+
+static void rtnl_nets_destroy(struct rtnl_nets *rtnl_nets)
+{
+ int i;
+
+ for (i = 0; i < rtnl_nets->len; i++) {
+ put_net(rtnl_nets->net[i]);
+ rtnl_nets->net[i] = NULL;
+ }
+
+ rtnl_nets->len = 0;
+}
+
+/**
+ * rtnl_nets_add - Add netns to be locked before ->newlink().
+ *
+ * @rtnl_nets: rtnl_nets pointer passed to ->get_peer_net().
+ * @net: netns pointer with an extra refcnt held.
+ *
+ * The extra refcnt is released in rtnl_nets_destroy().
+ */
+static void rtnl_nets_add(struct rtnl_nets *rtnl_nets, struct net *net)
+{
+ int i;
+
+ DEBUG_NET_WARN_ON_ONCE(rtnl_nets->len == ARRAY_SIZE(rtnl_nets->net));
+
+ for (i = 0; i < rtnl_nets->len; i++) {
+ switch (rtnl_net_cmp_locks(rtnl_nets->net[i], net)) {
+ case 0:
+ put_net(net);
+ return;
+ case 1:
+ swap(rtnl_nets->net[i], net);
+ }
+ }
+
+ rtnl_nets->net[i] = net;
+ rtnl_nets->len++;
+}
+
static struct rtnl_link __rcu *__rcu *rtnl_msg_handlers[RTNL_FAMILY_MAX + 1];
static inline int rtm_msgindex(int msgtype)
@@ -3796,6 +3855,7 @@ static int rtnl_newlink(struct sk_buff *skb, struct nlmsghdr *nlh,
struct net *tgt_net, *link_net = NULL;
struct rtnl_link_ops *ops = NULL;
struct rtnl_newlink_tbs *tbs;
+ struct rtnl_nets rtnl_nets;
int ops_srcu_index;
int ret;
@@ -3839,6 +3899,8 @@ static int rtnl_newlink(struct sk_buff *skb, struct nlmsghdr *nlh,
#endif
}
+ rtnl_nets_init(&rtnl_nets);
+
if (ops) {
if (ops->maxtype > RTNL_MAX_TYPE) {
ret = -EINVAL;
@@ -3868,6 +3930,8 @@ static int rtnl_newlink(struct sk_buff *skb, struct nlmsghdr *nlh,
goto put_ops;
}
+ rtnl_nets_add(&rtnl_nets, tgt_net);
+
if (tb[IFLA_LINK_NETNSID]) {
int id = nla_get_s32(tb[IFLA_LINK_NETNSID]);
@@ -3878,6 +3942,8 @@ static int rtnl_newlink(struct sk_buff *skb, struct nlmsghdr *nlh,
goto put_net;
}
+ rtnl_nets_add(&rtnl_nets, link_net);
+
if (!netlink_ns_capable(skb, link_net->user_ns, CAP_NET_ADMIN)) {
ret = -EPERM;
goto put_net;
@@ -3887,9 +3953,7 @@ static int rtnl_newlink(struct sk_buff *skb, struct nlmsghdr *nlh,
ret = __rtnl_newlink(skb, nlh, ops, tgt_net, link_net, tbs, data, extack);
put_net:
- if (link_net)
- put_net(link_net);
- put_net(tgt_net);
+ rtnl_nets_destroy(&rtnl_nets);
put_ops:
if (ops)
rtnl_link_ops_put(ops, ops_srcu_index);