diff mbox series

[16/19] wifi: mac80211: verify BSS membership selectors and basic rates

Message ID 20250101070249.e58a0f34c798.Ifeb3bfd7b157ffa2ccdb20ca1cba6cf068fd117d@changeid (mailing list archive)
State New
Delegated to: Johannes Berg
Headers show
Series wifi: mac80211: updates - 30-12-24 | expand

Commit Message

Miri Korenblit Jan. 1, 2025, 5:05 a.m. UTC
From: Benjamin Berg <benjamin.berg@intel.com>

We should not attempt a connection if the BSS we are connecting to
requires support for a basic rate or other feature using the BSS
membership selector. Add a check verifying this.

Signed-off-by: Benjamin Berg <benjamin.berg@intel.com>
Reviewed-by: Johannes Berg <johannes.berg@intel.com>
Signed-off-by: Miri Korenblit <miriam.rachel.korenblit@intel.com>
---
 net/mac80211/ieee80211_i.h |   4 +
 net/mac80211/mlme.c        | 193 ++++++++++++++++++++++++-------------
 2 files changed, 132 insertions(+), 65 deletions(-)
diff mbox series

Patch

diff --git a/net/mac80211/ieee80211_i.h b/net/mac80211/ieee80211_i.h
index a98133d5c362..8081607ccbd3 100644
--- a/net/mac80211/ieee80211_i.h
+++ b/net/mac80211/ieee80211_i.h
@@ -404,6 +404,8 @@  struct ieee80211_mgd_auth_data {
 	int tries;
 	u16 algorithm, expected_transaction;
 
+	unsigned long userspace_selectors[BITS_TO_LONGS(128)];
+
 	u8 key[WLAN_KEY_LEN_WEP104];
 	u8 key_len, key_idx;
 	bool done, waiting;
@@ -444,6 +446,8 @@  struct ieee80211_mgd_assoc_data {
 	const u8 *supp_rates;
 	u8 supp_rates_len;
 
+	unsigned long userspace_selectors[BITS_TO_LONGS(128)];
+
 	unsigned long timeout;
 	int tries;
 
diff --git a/net/mac80211/mlme.c b/net/mac80211/mlme.c
index a1197ac68c98..9735ce4adffe 100644
--- a/net/mac80211/mlme.c
+++ b/net/mac80211/mlme.c
@@ -694,6 +694,63 @@  ieee80211_verify_sta_eht_mcs_support(struct ieee80211_sub_if_data *sdata,
 	return true;
 }
 
+static void ieee80211_get_rates(struct ieee80211_supported_band *sband,
+				const u8 *supp_rates,
+				unsigned int supp_rates_len,
+				u32 *rates, u32 *basic_rates,
+				unsigned long *unknown_rates_selectors,
+				bool *have_higher_than_11mbit,
+				int *min_rate, int *min_rate_index)
+{
+	int i, j;
+
+	for (i = 0; i < supp_rates_len; i++) {
+		int rate = supp_rates[i] & 0x7f;
+		bool is_basic = !!(supp_rates[i] & 0x80);
+
+		if ((rate * 5) > 110 && have_higher_than_11mbit)
+			*have_higher_than_11mbit = true;
+
+		/*
+		 * Skip membership selectors since they're not rates.
+		 *
+		 * Note: Even though the membership selector and the basic
+		 *	 rate flag share the same bit, they are not exactly
+		 *	 the same.
+		 */
+		if (is_basic && rate >= BSS_MEMBERSHIP_SELECTOR_MIN) {
+			if (unknown_rates_selectors)
+				set_bit(rate, unknown_rates_selectors);
+			continue;
+		}
+
+		for (j = 0; j < sband->n_bitrates; j++) {
+			struct ieee80211_rate *br;
+			int brate;
+
+			br = &sband->bitrates[j];
+
+			brate = DIV_ROUND_UP(br->bitrate, 5);
+			if (brate == rate) {
+				if (rates)
+					*rates |= BIT(j);
+				if (is_basic && basic_rates)
+					*basic_rates |= BIT(j);
+				if (min_rate && (rate * 5) < *min_rate) {
+					*min_rate = rate * 5;
+					if (min_rate_index)
+						*min_rate_index = j;
+				}
+				break;
+			}
+		}
+
+		/* Handle an unknown entry as if it is an unknown selector */
+		if (is_basic && unknown_rates_selectors && j == sband->n_bitrates)
+			set_bit(rate, unknown_rates_selectors);
+	}
+}
+
 static bool ieee80211_chandef_usable(struct ieee80211_sub_if_data *sdata,
 				     const struct cfg80211_chan_def *chandef,
 				     u32 prohibited_flags)
@@ -924,7 +981,8 @@  ieee80211_determine_chan_mode(struct ieee80211_sub_if_data *sdata,
 			      struct ieee80211_conn_settings *conn,
 			      struct cfg80211_bss *cbss, int link_id,
 			      struct ieee80211_chan_req *chanreq,
-			      struct cfg80211_chan_def *ap_chandef)
+			      struct cfg80211_chan_def *ap_chandef,
+			      unsigned long *userspace_selectors)
 {
 	const struct cfg80211_bss_ies *ies = rcu_dereference(cbss->ies);
 	struct ieee80211_bss *bss = (void *)cbss->priv;
@@ -938,6 +996,8 @@  ieee80211_determine_chan_mode(struct ieee80211_sub_if_data *sdata,
 	struct ieee802_11_elems *elems;
 	struct ieee80211_supported_band *sband;
 	enum ieee80211_conn_mode ap_mode;
+	unsigned long unknown_rates_selectors[BITS_TO_LONGS(128)] = {};
+	unsigned long sta_selectors[BITS_TO_LONGS(128)] = {};
 	int ret;
 
 again:
@@ -966,6 +1026,10 @@  ieee80211_determine_chan_mode(struct ieee80211_sub_if_data *sdata,
 
 	sband = sdata->local->hw.wiphy->bands[channel->band];
 
+	ieee80211_get_rates(sband, elems->supp_rates, elems->supp_rates_len,
+			    NULL, NULL, unknown_rates_selectors, NULL, NULL,
+			    NULL);
+
 	switch (channel->band) {
 	case NL80211_BAND_S1GHZ:
 		if (WARN_ON(ap_mode != IEEE80211_CONN_MODE_S1G)) {
@@ -1016,6 +1080,29 @@  ieee80211_determine_chan_mode(struct ieee80211_sub_if_data *sdata,
 
 	chanreq->oper = *ap_chandef;
 
+	bitmap_copy(sta_selectors, userspace_selectors, 128);
+	if (conn->mode >= IEEE80211_CONN_MODE_HT)
+		set_bit(BSS_MEMBERSHIP_SELECTOR_HT_PHY, sta_selectors);
+	if (conn->mode >= IEEE80211_CONN_MODE_VHT)
+		set_bit(BSS_MEMBERSHIP_SELECTOR_VHT_PHY, sta_selectors);
+	if (conn->mode >= IEEE80211_CONN_MODE_HE)
+		set_bit(BSS_MEMBERSHIP_SELECTOR_HE_PHY, sta_selectors);
+	if (conn->mode >= IEEE80211_CONN_MODE_EHT)
+		set_bit(BSS_MEMBERSHIP_SELECTOR_EHT_PHY, sta_selectors);
+
+	/*
+	 * We do not support EPD or GLK so never add them.
+	 * SAE_H2E is handled through userspace_selectors.
+	 */
+
+	/* Check if we support all required features */
+	if (!bitmap_subset(unknown_rates_selectors, sta_selectors, 128)) {
+		link_id_info(sdata, link_id,
+			     "required basic rate or BSS membership selectors not supported or disabled, rejecting connection\n");
+		ret = -EINVAL;
+		goto free;
+	}
+
 	ieee80211_set_chanreq_ap(sdata, chanreq, conn, ap_chandef);
 
 	while (!ieee80211_chandef_usable(sdata, &chanreq->oper,
@@ -4729,62 +4816,6 @@  static void ieee80211_rx_mgmt_disassoc(struct ieee80211_sub_if_data *sdata,
 				    false);
 }
 
-static void ieee80211_get_rates(struct ieee80211_supported_band *sband,
-				u8 *supp_rates, unsigned int supp_rates_len,
-				u32 *rates, u32 *basic_rates,
-				unsigned long *unknown_rates_selectors,
-				bool *have_higher_than_11mbit,
-				int *min_rate, int *min_rate_index)
-{
-	int i, j;
-
-	for (i = 0; i < supp_rates_len; i++) {
-		int rate = supp_rates[i] & 0x7f;
-		bool is_basic = !!(supp_rates[i] & 0x80);
-
-		if ((rate * 5) > 110 && have_higher_than_11mbit)
-			*have_higher_than_11mbit = true;
-
-		/*
-		 * Skip membership selectors since they're not rates.
-		 *
-		 * Note: Even though the membership selector and the basic
-		 *	 rate flag share the same bit, they are not exactly
-		 *	 the same.
-		 */
-		if (is_basic && rate >= BSS_MEMBERSHIP_SELECTOR_MIN) {
-			if (unknown_rates_selectors)
-				set_bit(rate, unknown_rates_selectors);
-			continue;
-		}
-
-		for (j = 0; j < sband->n_bitrates; j++) {
-			struct ieee80211_rate *br;
-			int brate;
-
-			br = &sband->bitrates[j];
-
-			brate = DIV_ROUND_UP(br->bitrate, 5);
-			if (brate == rate) {
-				if (rates)
-					*rates |= BIT(j);
-				if (is_basic && basic_rates)
-					*basic_rates |= BIT(j);
-				if (min_rate && (rate * 5) < *min_rate) {
-					*min_rate = rate * 5;
-					if (min_rate_index)
-						*min_rate_index = j;
-				}
-				break;
-			}
-		}
-
-		/* Handle an unknown entry as if it is an unknown selector */
-		if (is_basic && unknown_rates_selectors && j == sband->n_bitrates)
-			set_bit(rate, unknown_rates_selectors);
-	}
-}
-
 static bool ieee80211_twt_req_supported(struct ieee80211_sub_if_data *sdata,
 					struct ieee80211_supported_band *sband,
 					const struct link_sta_info *link_sta,
@@ -5650,7 +5681,8 @@  static int ieee80211_prep_channel(struct ieee80211_sub_if_data *sdata,
 				  struct ieee80211_link_data *link,
 				  int link_id,
 				  struct cfg80211_bss *cbss, bool mlo,
-				  struct ieee80211_conn_settings *conn)
+				  struct ieee80211_conn_settings *conn,
+				  unsigned long *userspace_selectors)
 {
 	struct ieee80211_local *local = sdata->local;
 	bool is_6ghz = cbss->channel->band == NL80211_BAND_6GHZ;
@@ -5663,7 +5695,8 @@  static int ieee80211_prep_channel(struct ieee80211_sub_if_data *sdata,
 
 	rcu_read_lock();
 	elems = ieee80211_determine_chan_mode(sdata, conn, cbss, link_id,
-					      &chanreq, &ap_chandef);
+					      &chanreq, &ap_chandef,
+					      userspace_selectors);
 
 	if (IS_ERR(elems)) {
 		rcu_read_unlock();
@@ -5857,7 +5890,8 @@  static bool ieee80211_assoc_success(struct ieee80211_sub_if_data *sdata,
 			link->u.mgd.conn = assoc_data->link[link_id].conn;
 
 			err = ieee80211_prep_channel(sdata, link, link_id, cbss,
-						     true, &link->u.mgd.conn);
+						     true, &link->u.mgd.conn,
+						     assoc_data->userspace_selectors);
 			if (err) {
 				link_info(link, "prep_channel failed\n");
 				goto out_err;
@@ -8367,7 +8401,8 @@  static int ieee80211_prep_connection(struct ieee80211_sub_if_data *sdata,
 				     struct cfg80211_bss *cbss, s8 link_id,
 				     const u8 *ap_mld_addr, bool assoc,
 				     struct ieee80211_conn_settings *conn,
-				     bool override)
+				     bool override,
+				     unsigned long *userspace_selectors)
 {
 	struct ieee80211_local *local = sdata->local;
 	struct ieee80211_if_managed *ifmgd = &sdata->u.mgd;
@@ -8506,7 +8541,8 @@  static int ieee80211_prep_connection(struct ieee80211_sub_if_data *sdata,
 		 */
 		link->u.mgd.conn = *conn;
 		err = ieee80211_prep_channel(sdata, link, link->link_id, cbss,
-					     mlo, &link->u.mgd.conn);
+					     mlo, &link->u.mgd.conn,
+					     userspace_selectors);
 		if (err) {
 			if (new_sta)
 				sta_info_free(local, new_sta);
@@ -8622,6 +8658,22 @@  static bool ieee80211_mgd_csa_in_process(struct ieee80211_sub_if_data *sdata,
 	return ret;
 }
 
+static void ieee80211_parse_cfg_selectors(unsigned long *userspace_selectors,
+					  const u8 *supported_selectors,
+					  u8 supported_selectors_len)
+{
+	if (supported_selectors) {
+		for (int i = 0; i < supported_selectors_len; i++) {
+			set_bit(supported_selectors[i],
+				userspace_selectors);
+		}
+	} else {
+		/* Assume SAE_H2E support for backward compatibility. */
+		set_bit(BSS_MEMBERSHIP_SELECTOR_SAE_H2E,
+			userspace_selectors);
+	}
+}
+
 /* config hooks */
 int ieee80211_mgd_auth(struct ieee80211_sub_if_data *sdata,
 		       struct cfg80211_auth_request *req)
@@ -8723,6 +8775,10 @@  int ieee80211_mgd_auth(struct ieee80211_sub_if_data *sdata,
 		memcpy(auth_data->key, req->key, req->key_len);
 	}
 
+	ieee80211_parse_cfg_selectors(auth_data->userspace_selectors,
+				      req->supported_selectors,
+				      req->supported_selectors_len);
+
 	auth_data->algorithm = auth_alg;
 
 	/* try to authenticate/probe */
@@ -8776,7 +8832,8 @@  int ieee80211_mgd_auth(struct ieee80211_sub_if_data *sdata,
 
 	err = ieee80211_prep_connection(sdata, req->bss, req->link_id,
 					req->ap_mld_addr, cont_auth,
-					&conn, false);
+					&conn, false,
+					auth_data->userspace_selectors);
 	if (err)
 		goto err_clear;
 
@@ -9063,6 +9120,10 @@  int ieee80211_mgd_assoc(struct ieee80211_sub_if_data *sdata,
 					    false);
 	}
 
+	ieee80211_parse_cfg_selectors(assoc_data->userspace_selectors,
+				      req->supported_selectors,
+				      req->supported_selectors_len);
+
 	memcpy(&ifmgd->ht_capa, &req->ht_capa, sizeof(ifmgd->ht_capa));
 	memcpy(&ifmgd->ht_capa_mask, &req->ht_capa_mask,
 	       sizeof(ifmgd->ht_capa_mask));
@@ -9309,7 +9370,8 @@  int ieee80211_mgd_assoc(struct ieee80211_sub_if_data *sdata,
 		/* only calculate the mode, hence link == NULL */
 		err = ieee80211_prep_channel(sdata, NULL, i,
 					     assoc_data->link[i].bss, true,
-					     &assoc_data->link[i].conn);
+					     &assoc_data->link[i].conn,
+					     assoc_data->userspace_selectors);
 		if (err) {
 			req->links[i].error = err;
 			goto err_clear;
@@ -9325,7 +9387,8 @@  int ieee80211_mgd_assoc(struct ieee80211_sub_if_data *sdata,
 	err = ieee80211_prep_connection(sdata, cbss, req->link_id,
 					req->ap_mld_addr, true,
 					&assoc_data->link[assoc_link_id].conn,
-					override);
+					override,
+					assoc_data->userspace_selectors);
 	if (err)
 		goto err_clear;