diff mbox series

[12/15] wifi: iwlwifi: mvm: set STA mask for keys in MLO

Message ID 20230414130637.cedae2f21829.Iae07b736c3109d085ad5b74ec8282ce45020da39@changeid (mailing list archive)
State Accepted
Delegated to: Johannes Berg
Headers show
Series wifi: iwlwifi: updates intended for v6.4 2023-04-14 | expand

Commit Message

Greenman, Gregory April 14, 2023, 10:12 a.m. UTC
From: Johannes Berg <johannes.berg@intel.com>

Implement the full STA mask and selecting the correct link
for key installation.

While at it, catch errors if this function returns a bad
zero station mask, rather than waiting for the firmware to
crash on it.

Signed-off-by: Johannes Berg <johannes.berg@intel.com>
Signed-off-by: Gregory Greenman <gregory.greenman@intel.com>
---
 .../net/wireless/intel/iwlwifi/mvm/mld-key.c  | 69 ++++++++++++++++---
 1 file changed, 60 insertions(+), 9 deletions(-)
diff mbox series

Patch

diff --git a/drivers/net/wireless/intel/iwlwifi/mvm/mld-key.c b/drivers/net/wireless/intel/iwlwifi/mvm/mld-key.c
index 9ec1c505002f..7c417b39aca4 100644
--- a/drivers/net/wireless/intel/iwlwifi/mvm/mld-key.c
+++ b/drivers/net/wireless/intel/iwlwifi/mvm/mld-key.c
@@ -14,23 +14,68 @@  static u32 iwl_mvm_get_sec_sta_mask(struct iwl_mvm *mvm,
 				    struct ieee80211_key_conf *keyconf)
 {
 	struct iwl_mvm_vif *mvmvif = iwl_mvm_vif_from_mac80211(vif);
+	struct iwl_mvm_vif_link_info *link_info = &mvmvif->deflink;
+	struct iwl_mvm_link_sta *link_sta;
+	struct iwl_mvm_sta *mvmsta;
+	u32 result = 0;
+	int link_id;
 
+	lockdep_assert_held(&mvm->mutex);
+
+	if (keyconf->link_id >= 0) {
+		link_info = mvmvif->link[keyconf->link_id];
+		if (!link_info)
+			return 0;
+	}
+
+	/* AP group keys are per link and should be on the mcast STA */
 	if (vif->type == NL80211_IFTYPE_AP &&
 	    !(keyconf->flags & IEEE80211_KEY_FLAG_PAIRWISE))
-		return BIT(mvmvif->deflink.mcast_sta.sta_id);
+		return BIT(link_info->mcast_sta.sta_id);
+
+	/* for client mode use the AP STA also for group keys */
+	if (!sta && vif->type == NL80211_IFTYPE_STATION)
+		sta = mvmvif->ap_sta;
+
+	/* During remove the STA was removed and the group keys come later
+	 * (which sounds like a bad sequence, but remember that to mac80211 the
+	 * group keys have no sta pointer), so we don't have a STA now.
+	 * Since this happens for group keys only, just use the link_info as
+	 * the group keys are per link; make sure that is the case by checking
+	 * we do have a link_id or are not doing MLO.
+	 * Of course the same can be done during add as well, but we must do
+	 * it during remove, since we don't have the mvmvif->ap_sta pointer.
+	 */
+	if (!sta && (keyconf->link_id >= 0 || !vif->valid_links))
+		return BIT(link_info->ap_sta_id);
 
-	if (sta) {
-		struct iwl_mvm_sta *mvmsta = iwl_mvm_sta_from_mac80211(sta);
+	/* this shouldn't happen now */
+	if (!sta)
+		return 0;
 
+	mvmsta = iwl_mvm_sta_from_mac80211(sta);
+
+	/* it's easy when the STA is not an MLD */
+	if (!sta->valid_links)
 		return BIT(mvmsta->deflink.sta_id);
-	}
 
-	if (vif->type == NL80211_IFTYPE_STATION &&
-	    mvmvif->deflink.ap_sta_id != IWL_MVM_INVALID_STA)
-		return BIT(mvmvif->deflink.ap_sta_id);
+	/* but if it is an MLD, get the mask of all the FW STAs it has ... */
+	for (link_id = 0; link_id < ARRAY_SIZE(mvmsta->link); link_id++) {
+		/* unless we have a specific link in mind (GTK on client) */
+		if (keyconf->link_id >= 0 &&
+		    keyconf->link_id != link_id)
+			continue;
+
+		link_sta =
+			rcu_dereference_protected(mvmsta->link[link_id],
+						  lockdep_is_held(&mvm->mutex));
+		if (!link_sta)
+			continue;
+
+		result |= BIT(link_sta->sta_id);
+	}
 
-	/* invalid */
-	return 0;
+	return result;
 }
 
 static u32 iwl_mvm_get_sec_flags(struct iwl_mvm *mvm,
@@ -113,6 +158,9 @@  int iwl_mvm_sec_key_add(struct iwl_mvm *mvm,
 	if (WARN_ON(keyconf->keylen > sizeof(cmd.u.add.key)))
 		return -EINVAL;
 
+	if (WARN_ON(!sta_mask))
+		return -EINVAL;
+
 	if (keyconf->cipher == WLAN_CIPHER_SUITE_WEP40 ||
 	    keyconf->cipher == WLAN_CIPHER_SUITE_WEP104)
 		memcpy(cmd.u.add.key + IWL_SEC_WEP_KEY_OFFSET, keyconf->key,
@@ -159,6 +207,9 @@  static int _iwl_mvm_sec_key_del(struct iwl_mvm *mvm,
 	u32 key_flags = iwl_mvm_get_sec_flags(mvm, vif, sta, keyconf);
 	int ret;
 
+	if (WARN_ON(!sta_mask))
+		return -EINVAL;
+
 	ret = __iwl_mvm_sec_key_del(mvm, sta_mask, key_flags, keyconf->keyidx,
 				    flags);
 	if (ret)