@@ -369,12 +369,40 @@ int mptcp_pm_nl_remove_doit(struct sk_buff *skb, struct genl_info *info)
return err;
}
+static int mptcp_userspace_pm_subflow_established(struct mptcp_sock *msk,
+ struct mptcp_pm_param *param)
+{
+ struct mptcp_pm_addr_entry *entry = ¶m->entry;
+ struct mptcp_addr_info *remote = ¶m->addr;
+ struct sock *sk = (struct sock *)msk;
+ struct mptcp_pm_local local;
+ int err;
+
+ err = mptcp_userspace_pm_append_new_local_addr(msk, entry, false);
+ if (err < 0)
+ return err;
+
+ local.addr = entry->addr;
+ local.flags = entry->flags;
+ local.ifindex = entry->ifindex;
+
+ err = __mptcp_subflow_connect(sk, &local, remote);
+ spin_lock_bh(&msk->pm.lock);
+ if (err)
+ mptcp_userspace_pm_delete_local_addr(msk, entry);
+ else
+ msk->pm.subflows++;
+ spin_unlock_bh(&msk->pm.lock);
+
+ return err;
+}
+
int mptcp_pm_nl_subflow_create_doit(struct sk_buff *skb, struct genl_info *info)
{
struct mptcp_pm_addr_entry entry = { 0 };
struct mptcp_addr_info addr_r;
struct nlattr *raddr, *laddr;
- struct mptcp_pm_local local;
+ struct mptcp_pm_param param;
struct mptcp_sock *msk;
int err = -EINVAL;
struct sock *sk;
@@ -412,31 +440,16 @@ int mptcp_pm_nl_subflow_create_doit(struct sk_buff *skb, struct genl_info *info)
goto create_err;
}
- err = mptcp_userspace_pm_append_new_local_addr(msk, &entry, false);
- if (err < 0) {
- NL_SET_ERR_MSG_ATTR(info->extack, laddr,
- "did not match address and id");
- goto create_err;
- }
-
- local.addr = entry.addr;
- local.flags = entry.flags;
- local.ifindex = entry.ifindex;
-
lock_sock(sk);
- err = __mptcp_subflow_connect(sk, &local, &addr_r);
+ mptcp_pm_param_set_contexts(¶m, &entry, &addr_r);
+ err = msk->pm.ops && msk->pm.ops->subflow_established ?
+ msk->pm.ops->subflow_established(msk, ¶m) :
+ mptcp_userspace_pm_subflow_established(msk, ¶m);
release_sock(sk);
if (err)
GENL_SET_ERR_MSG_FMT(info, "connect error: %d", err);
- spin_lock_bh(&msk->pm.lock);
- if (err)
- mptcp_userspace_pm_delete_local_addr(msk, &entry);
- else
- msk->pm.subflows++;
- spin_unlock_bh(&msk->pm.lock);
-
create_err:
sock_put(sk);
return err;
@@ -705,6 +718,7 @@ int mptcp_userspace_pm_get_addr(u8 id, struct mptcp_pm_addr_entry *addr,
static struct mptcp_pm_ops mptcp_userspace_pm = {
.address_announced = mptcp_userspace_pm_address_announced,
.address_removed = mptcp_userspace_pm_address_removed,
+ .subflow_established = mptcp_userspace_pm_subflow_established,
.get_local_id = mptcp_userspace_pm_get_local_id,
.get_priority = mptcp_userspace_pm_get_priority,
.type = MPTCP_PM_TYPE_USERSPACE,