@@ -267,11 +267,8 @@ static int mptcp_userspace_pm_remove_id_zero_address(struct mptcp_sock *msk)
{
struct mptcp_rm_list list = { .nr = 0 };
struct mptcp_subflow_context *subflow;
- struct sock *sk = (struct sock *)msk;
bool has_id_0 = false;
- int err = -EINVAL;
- lock_sock(sk);
mptcp_for_each_subflow(msk, subflow) {
if (READ_ONCE(subflow->local_id) == 0) {
has_id_0 = true;
@@ -279,7 +276,7 @@ static int mptcp_userspace_pm_remove_id_zero_address(struct mptcp_sock *msk)
}
}
if (!has_id_0)
- goto remove_err;
+ return -EINVAL;
list.ids[list.nr++] = 0;
@@ -287,11 +284,7 @@ static int mptcp_userspace_pm_remove_id_zero_address(struct mptcp_sock *msk)
mptcp_pm_remove_addr(msk, &list);
spin_unlock_bh(&msk->pm.lock);
- err = 0;
-
-remove_err:
- release_sock(sk);
- return err;
+ return 0;
}
void mptcp_pm_remove_addr_entry(struct mptcp_sock *msk,
@@ -314,20 +307,46 @@ void mptcp_pm_remove_addr_entry(struct mptcp_sock *msk,
spin_unlock_bh(&msk->pm.lock);
}
+static int mptcp_userspace_pm_address_removed(struct mptcp_sock *msk,
+ struct mptcp_pm_param *param)
+{
+ struct mptcp_pm_addr_entry *entry;
+ u8 id = param->addr.id;
+
+ if (id == 0)
+ return mptcp_userspace_pm_remove_id_zero_address(msk);
+
+ spin_lock_bh(&msk->pm.lock);
+ entry = mptcp_userspace_pm_lookup_addr_by_id(msk, id);
+ if (!entry) {
+ spin_unlock_bh(&msk->pm.lock);
+ return -EINVAL;
+ }
+
+ list_del_rcu(&entry->list);
+ spin_unlock_bh(&msk->pm.lock);
+
+ mptcp_pm_remove_addr_entry(msk, entry);
+
+ sock_kfree_s((struct sock *)msk, entry, sizeof(*entry));
+
+ return 0;
+}
+
int mptcp_pm_nl_remove_doit(struct sk_buff *skb, struct genl_info *info)
{
- struct mptcp_pm_addr_entry *match;
+ struct mptcp_addr_info addr;
+ struct mptcp_pm_param param;
struct mptcp_sock *msk;
struct nlattr *id;
int err = -EINVAL;
struct sock *sk;
- u8 id_val;
if (GENL_REQ_ATTR_CHECK(info, MPTCP_PM_ATTR_LOC_ID))
return err;
id = info->attrs[MPTCP_PM_ATTR_LOC_ID];
- id_val = nla_get_u8(id);
+ addr.id = nla_get_u8(id);
msk = mptcp_userspace_pm_get_sock(info);
if (!msk)
@@ -335,36 +354,16 @@ int mptcp_pm_nl_remove_doit(struct sk_buff *skb, struct genl_info *info)
sk = (struct sock *)msk;
- if (id_val == 0) {
- err = mptcp_userspace_pm_remove_id_zero_address(msk);
- goto out;
- }
-
lock_sock(sk);
-
- spin_lock_bh(&msk->pm.lock);
- match = mptcp_userspace_pm_lookup_addr_by_id(msk, id_val);
- if (!match) {
- spin_unlock_bh(&msk->pm.lock);
- release_sock(sk);
- goto out;
- }
-
- list_del_rcu(&match->list);
- spin_unlock_bh(&msk->pm.lock);
-
- mptcp_pm_remove_addr_entry(msk, match);
-
+ mptcp_pm_param_set_contexts(¶m, NULL, &addr);
+ err = msk->pm.ops && msk->pm.ops->address_removed ?
+ msk->pm.ops->address_removed(msk, ¶m) :
+ mptcp_userspace_pm_address_removed(msk, ¶m);
release_sock(sk);
-
- sock_kfree_s(sk, match, sizeof(*match));
-
- err = 0;
-out:
if (err)
NL_SET_ERR_MSG_ATTR_FMT(info->extack, id,
"address with id %u not found",
- id_val);
+ addr.id);
sock_put(sk);
return err;
@@ -705,6 +704,7 @@ int mptcp_userspace_pm_get_addr(u8 id, struct mptcp_pm_addr_entry *addr,
static struct mptcp_pm_ops mptcp_userspace_pm = {
.address_announced = mptcp_userspace_pm_address_announced,
+ .address_removed = mptcp_userspace_pm_address_removed,
.get_local_id = mptcp_userspace_pm_get_local_id,
.get_priority = mptcp_userspace_pm_get_priority,
.type = MPTCP_PM_TYPE_USERSPACE,