@@ -545,7 +545,7 @@ static void iwl_mvm_mld_vif_delete_all_stas(struct iwl_mvm *mvm,
continue;
iwl_mvm_sec_key_remove_ap(mvm, vif, link, i);
- ret = iwl_mvm_mld_rm_sta_id(mvm, vif, link->ap_sta_id);
+ ret = iwl_mvm_mld_rm_sta_id(mvm, link->ap_sta_id);
if (ret)
IWL_ERR(mvm, "failed to remove AP station\n");
@@ -458,6 +458,21 @@ static int iwl_mvm_mld_cfg_sta(struct iwl_mvm *mvm, struct ieee80211_sta *sta,
return iwl_mvm_mld_send_sta_cmd(mvm, &cmd);
}
+static void iwl_mvm_mld_free_sta_link(struct iwl_mvm *mvm,
+ struct iwl_mvm_sta *mvm_sta,
+ struct iwl_mvm_link_sta *mvm_sta_link,
+ unsigned int link_id,
+ bool is_in_fw)
+{
+ RCU_INIT_POINTER(mvm->fw_id_to_mac_id[mvm_sta_link->sta_id],
+ is_in_fw ? ERR_PTR(-EINVAL) : NULL);
+ RCU_INIT_POINTER(mvm->fw_id_to_link_sta[mvm_sta_link->sta_id], NULL);
+ RCU_INIT_POINTER(mvm_sta->link[link_id], NULL);
+
+ if (mvm_sta_link != &mvm_sta->deflink)
+ kfree_rcu(mvm_sta_link, rcu_head);
+}
+
static void iwl_mvm_mld_sta_rm_all_sta_links(struct iwl_mvm *mvm,
struct iwl_mvm_sta *mvm_sta)
{
@@ -471,28 +486,10 @@ static void iwl_mvm_mld_sta_rm_all_sta_links(struct iwl_mvm *mvm,
if (!link)
continue;
- RCU_INIT_POINTER(mvm->fw_id_to_mac_id[link->sta_id], NULL);
- RCU_INIT_POINTER(mvm->fw_id_to_link_sta[link->sta_id], NULL);
- RCU_INIT_POINTER(mvm_sta->link[link_id], NULL);
-
- if (link != &mvm_sta->deflink)
- kfree_rcu(link, rcu_head);
+ iwl_mvm_mld_free_sta_link(mvm, mvm_sta, link, link_id, false);
}
}
-static void iwl_mvm_mld_free_sta_link(struct iwl_mvm *mvm,
- struct iwl_mvm_sta *mvm_sta,
- struct iwl_mvm_link_sta *mvm_sta_link,
- unsigned int link_id)
-{
- RCU_INIT_POINTER(mvm->fw_id_to_mac_id[mvm_sta_link->sta_id], NULL);
- RCU_INIT_POINTER(mvm->fw_id_to_link_sta[mvm_sta_link->sta_id], NULL);
- RCU_INIT_POINTER(mvm_sta->link[link_id], NULL);
-
- if (mvm_sta_link != &mvm_sta->deflink)
- kfree_rcu(mvm_sta_link, rcu_head);
-}
-
static int iwl_mvm_mld_alloc_sta_link(struct iwl_mvm *mvm,
struct ieee80211_vif *vif,
struct ieee80211_sta *sta,
@@ -787,20 +784,24 @@ int iwl_mvm_mld_rm_sta(struct iwl_mvm *mvm, struct ieee80211_vif *vif,
struct iwl_mvm_link_sta *mvm_link_sta =
rcu_dereference_protected(mvm_sta->link[link_id],
lockdep_is_held(&mvm->mutex));
+ bool stay_in_fw;
- if (iwl_mvm_sta_del(mvm, vif, sta, mvm_link_sta, &ret))
- return ret;
+ stay_in_fw = iwl_mvm_sta_del(mvm, vif, sta, mvm_link_sta, &ret);
+ if (ret)
+ break;
- ret = iwl_mvm_mld_rm_sta_from_fw(mvm, mvm_link_sta->sta_id);
- }
+ if (!stay_in_fw)
+ ret = iwl_mvm_mld_rm_sta_from_fw(mvm,
+ mvm_link_sta->sta_id);
- iwl_mvm_mld_sta_rm_all_sta_links(mvm, mvm_sta);
+ iwl_mvm_mld_free_sta_link(mvm, mvm_sta, mvm_link_sta,
+ link_id, stay_in_fw);
+ }
return ret;
}
-int iwl_mvm_mld_rm_sta_id(struct iwl_mvm *mvm, struct ieee80211_vif *vif,
- u8 sta_id)
+int iwl_mvm_mld_rm_sta_id(struct iwl_mvm *mvm, u8 sta_id)
{
int ret = iwl_mvm_mld_rm_sta_from_fw(mvm, sta_id);
@@ -978,7 +979,8 @@ int iwl_mvm_mld_update_sta_links(struct iwl_mvm *mvm,
if (vif->type == NL80211_IFTYPE_STATION)
mvm_vif_link->ap_sta_id = IWL_MVM_INVALID_STA;
- iwl_mvm_mld_free_sta_link(mvm, mvm_sta, mvm_sta_link, link_id);
+ iwl_mvm_mld_free_sta_link(mvm, mvm_sta, mvm_sta_link, link_id,
+ false);
}
for_each_set_bit(link_id, &links_to_add, IEEE80211_MLD_MAX_NUM_LINKS) {
@@ -1054,7 +1056,8 @@ int iwl_mvm_mld_update_sta_links(struct iwl_mvm *mvm,
rcu_dereference_protected(mvm_sta->link[link_id],
lockdep_is_held(&mvm->mutex));
- iwl_mvm_mld_free_sta_link(mvm, mvm_sta, mvm_sta_link, link_id);
+ iwl_mvm_mld_free_sta_link(mvm, mvm_sta, mvm_sta_link, link_id,
+ false);
}
return ret;
@@ -639,8 +639,7 @@ int iwl_mvm_mld_update_sta(struct iwl_mvm *mvm, struct ieee80211_vif *vif,
struct ieee80211_sta *sta);
int iwl_mvm_mld_rm_sta(struct iwl_mvm *mvm, struct ieee80211_vif *vif,
struct ieee80211_sta *sta);
-int iwl_mvm_mld_rm_sta_id(struct iwl_mvm *mvm, struct ieee80211_vif *vif,
- u8 sta_id);
+int iwl_mvm_mld_rm_sta_id(struct iwl_mvm *mvm, u8 sta_id);
int iwl_mvm_mld_update_sta_links(struct iwl_mvm *mvm,
struct ieee80211_vif *vif,
struct ieee80211_sta *sta,