@@ -826,11 +826,6 @@ int mptcp_pm_nl_mp_prio_send_ack(struct mptcp_sock *msk,
return -EINVAL;
}
-static bool mptcp_local_id_match(const struct mptcp_sock *msk, u8 local_id, u8 id)
-{
- return local_id == id || (!local_id && msk->mpc_endpoint_id == id);
-}
-
static void mptcp_pm_nl_rm_addr_or_subflow(struct mptcp_sock *msk,
const struct mptcp_rm_list *rm_list,
enum linux_mptcp_mib_field rm_type)
@@ -867,7 +862,7 @@ static void mptcp_pm_nl_rm_addr_or_subflow(struct mptcp_sock *msk,
continue;
if (rm_type == MPTCP_MIB_RMADDR && remote_id != rm_id)
continue;
- if (rm_type == MPTCP_MIB_RMSUBFLOW && !mptcp_local_id_match(msk, id, rm_id))
+ if (rm_type == MPTCP_MIB_RMSUBFLOW && id != rm_id)
continue;
pr_debug(" -> %s rm_list_ids[%d]=%u local_id=%u remote_id=%u mpc_id=%u\n",
@@ -1501,6 +1496,12 @@ static bool remove_anno_list_by_saddr(struct mptcp_sock *msk,
return false;
}
+static u8 mptcp_endp_get_local_id(struct mptcp_sock *msk,
+ const struct mptcp_addr_info *addr)
+{
+ return msk->mpc_endpoint_id == addr->id ? 0 : addr->id;
+}
+
static bool mptcp_pm_remove_anno_addr(struct mptcp_sock *msk,
const struct mptcp_addr_info *addr,
bool force)
@@ -1508,7 +1509,7 @@ static bool mptcp_pm_remove_anno_addr(struct mptcp_sock *msk,
struct mptcp_rm_list list = { .nr = 0 };
bool ret;
- list.ids[list.nr++] = addr->id;
+ list.ids[list.nr++] = mptcp_endp_get_local_id(msk, addr);
ret = remove_anno_list_by_saddr(msk, addr);
if (ret || force) {
@@ -1535,14 +1536,12 @@ static int mptcp_nl_remove_subflow_and_signal_addr(struct net *net,
const struct mptcp_pm_addr_entry *entry)
{
const struct mptcp_addr_info *addr = &entry->addr;
- struct mptcp_rm_list list = { .nr = 0 };
+ struct mptcp_rm_list list = { .nr = 1 };
long s_slot = 0, s_num = 0;
struct mptcp_sock *msk;
pr_debug("remove_id=%d\n", addr->id);
- list.ids[list.nr++] = addr->id;
-
while ((msk = mptcp_token_iter_next(net, &s_slot, &s_num)) != NULL) {
struct sock *sk = (struct sock *)msk;
bool remove_subflow;
@@ -1560,6 +1559,7 @@ static int mptcp_nl_remove_subflow_and_signal_addr(struct net *net,
mptcp_pm_remove_anno_addr(msk, addr, remove_subflow &&
!(entry->flags & MPTCP_PM_ADDR_FLAG_IMPLICIT));
+ list.ids[0] = mptcp_endp_get_local_id(msk, addr);
if (remove_subflow) {
spin_lock_bh(&msk->pm.lock);
mptcp_pm_nl_rm_subflow_received(msk, &list);
@@ -1668,6 +1668,7 @@ static int mptcp_nl_cmd_del_addr(struct sk_buff *skb, struct genl_info *info)
return ret;
}
+/* Called from the userspace PM only */
void mptcp_pm_remove_addrs(struct mptcp_sock *msk, struct list_head *rm_list)
{
struct mptcp_rm_list alist = { .nr = 0 };
@@ -1696,6 +1697,7 @@ void mptcp_pm_remove_addrs(struct mptcp_sock *msk, struct list_head *rm_list)
}
}
+/* Called from the in-kernel PM only */
static void mptcp_pm_remove_addrs_and_subflows(struct mptcp_sock *msk,
struct list_head *rm_list)
{
@@ -1705,11 +1707,11 @@ static void mptcp_pm_remove_addrs_and_subflows(struct mptcp_sock *msk,
list_for_each_entry(entry, rm_list, list) {
if (slist.nr < MPTCP_RM_IDS_MAX &&
lookup_subflow_by_saddr(&msk->conn_list, &entry->addr))
- slist.ids[slist.nr++] = entry->addr.id;
+ slist.ids[slist.nr++] = mptcp_endp_get_local_id(msk, &entry->addr);
if (alist.nr < MPTCP_RM_IDS_MAX &&
remove_anno_list_by_saddr(msk, &entry->addr))
- alist.ids[alist.nr++] = entry->addr.id;
+ alist.ids[alist.nr++] = mptcp_endp_get_local_id(msk, &entry->addr);
}
spin_lock_bh(&msk->pm.lock);
@@ -1997,7 +1999,7 @@ static void mptcp_pm_nl_fullmesh(struct mptcp_sock *msk,
{
struct mptcp_rm_list list = { .nr = 0 };
- list.ids[list.nr++] = addr->id;
+ list.ids[list.nr++] = mptcp_endp_get_local_id(msk, addr);
spin_lock_bh(&msk->pm.lock);
mptcp_pm_nl_rm_subflow_received(msk, &list);