diff mbox series

[2/3] iw: scan: change parsing from in-place to cached

Message ID 20240930181145.1043048-3-dylan.eskew@candelatech.com (mailing list archive)
State New
Delegated to: Johannes Berg
Headers show
Series iw: scan: ie parsing restructure | expand

Commit Message

Dylan Eskew Sept. 30, 2024, 6:11 p.m. UTC
Since some ies require references to other ies, this
introduces the infrastructure to prevent double parsing
the ie buffer by caching the ie data rather than reading
it in-place.

Signed-off-by: Dylan Eskew <dylan.eskew@candelatech.com>
---
 ieee80211.h |  39 +++++++++++++++
 scan.c      | 138 +++++++++++++++++++++++++++++++++++++++-------------
 2 files changed, 142 insertions(+), 35 deletions(-)
diff mbox series

Patch

diff --git a/ieee80211.h b/ieee80211.h
index 96d5c52..dc14096 100644
--- a/ieee80211.h
+++ b/ieee80211.h
@@ -101,6 +101,45 @@  enum elem_id_ext {
 	EID_EXT_HE_CAPABILITY = 35,
 };
 
+struct ieee80211_elems {
+	const __u8 *ie_start;
+	size_t total_length;
+
+	__u8 seen[32];
+	__u8 seen_ext[32];
+
+	const __u8 *ies[NUM_IES];
+	const __u8 *ies_ext[NUM_IES];
+
+	/* payload lengths of each ie */
+	__u8 lengths[NUM_IES];
+	__u8 lengths_ext[NUM_IES];
+};
+
+static inline void set_seen(struct ieee80211_elems *elems,
+			    const uint8_t type)
+{
+	elems->seen[type / 8] |= 1 << (type % 8);
+}
+
+static inline void set_seen_ext(struct ieee80211_elems *elems,
+				const uint8_t type)
+{
+	elems->seen_ext[type / 8] |= 1 << (type % 8);
+}
+
+static inline int was_seen(struct ieee80211_elems *elems,
+			   const uint8_t type)
+{
+	return elems->seen[type / 8] & (1 << (type % 8));
+}
+
+static inline int was_seen_ext(struct ieee80211_elems *elems,
+			       const uint8_t type)
+{
+	return elems->seen_ext[type / 8] & (1 << (type % 8));
+}
+
 #define SUITE(oui, id)  (((oui) << 8) | (id))
 
 /* cipher suite selectors */
diff --git a/scan.c b/scan.c
index 83b7f58..6fdaf0b 100644
--- a/scan.c
+++ b/scan.c
@@ -1782,6 +1782,12 @@  struct ie_print {
 	uint8_t flags;
 };
 
+struct element {
+	__u8 id;
+	__u8 datalen;
+	__u8 data[];
+} __attribute__ ((packed));
+
 static void print_ie(const struct ie_print *p, const uint8_t type, uint8_t len,
 		     const uint8_t *data,
 		     const struct print_ies_data *ie_buffer)
@@ -2332,7 +2338,7 @@  static const struct ie_print wfa_printers[] = {
 	[28] = { "OWE Transition Mode", print_wifi_owe_tarns, 7, 255, BIT(PRINT_SCAN), },
 };
 
-static void print_vendor(unsigned char len, unsigned char *data,
+static void print_vendor(unsigned char len, const unsigned char *data,
 			 bool unknown, enum print_ie_type ptype)
 {
 	int i;
@@ -2402,31 +2408,88 @@  static const struct ie_print ext_printers[] = {
 	[EID_EXT_HE_CAPABILITY] = { "HE capabilities", print_he_capa, 21, 54, BIT(PRINT_SCAN), },
 };
 
-static void print_extension(unsigned char len, unsigned char *ie,
-			    bool unknown, enum print_ie_type ptype)
+
+static void print_extension(bool unknown, struct ieee80211_elems *elems,
+			    enum print_ie_type ptype)
 {
-	unsigned char tag;
+	unsigned i;
 
-	if (len < 1) {
-		printf("\tExtension IE: <empty>\n");
+	if (elems == NULL)
 		return;
+
+	for (i = 0; i < NUM_IES; i++) {
+		if (!was_seen_ext(elems, i))
+			continue;
+
+		if (elems->lengths_ext[i] < 1) {
+			printf("\tExtension IE %u: <empty>\n", i);
+			continue;
+		}
+
+		if (i < ARRAY_SIZE(ext_printers) &&
+		    ext_printers[i].name &&
+		    ext_printers[i].flags & BIT(ptype) &&
+		    elems->lengths_ext[i] > 0) {
+			print_ie(&ext_printers[i], i, elems->lengths_ext[i] - 1,
+				 elems->ies_ext[i], NULL);
+		} else if (unknown) {
+			int j;
+
+			printf("\tUnknown Extension ID (%d):", i);
+			for (j = 1; j < elems->lengths_ext[i]; j++)
+				printf(" %.2x", elems->ies_ext[i][j]);
+			printf("\n");
+		}
 	}
+}
 
-	tag = ie[0];
-	if (tag < ARRAY_SIZE(ext_printers) && ext_printers[tag].name &&
-	    ext_printers[tag].flags & BIT(ptype)) {
-		print_ie(&ext_printers[tag], tag, len - 1, ie + 1, NULL);
+static void parse_ie_ext(const struct element *elem,
+			 struct ieee80211_elems *elems)
+{
+	const uint8_t type = elem->data[0];
+	const uint8_t *data = elem->data + 1;
+
+	if (!elem->datalen)
 		return;
-	}
 
-	if (unknown) {
-		int i;
+	set_seen_ext(elems, type);
 
-		printf("\tUnknown Extension ID (%d):", ie[0]);
-		for (i = 1; i < len; i++)
-			printf(" %.2x", ie[i]);
-		printf("\n");
+	elems->lengths_ext[type] = elem->datalen;
+
+	if (elem->datalen)
+		elems->ies_ext[type] = data;
+}
+
+static int parse_ies(unsigned char *ie, int ielen, bool unknown,
+		     struct ieee80211_elems *elems)
+{
+	const struct element *elem;
+
+	if (ie == NULL || ielen < 0 || elems == NULL)
+		return 1;
+
+	elems->ie_start = ie;
+	elems->total_length = ielen;
+
+	elem = (const struct element *)ie;
+	while (ielen >= 2 && ielen - 2 > elem->datalen) {
+		uint8_t id = elem->id;
+		uint8_t elen = elem->datalen;
+		const uint8_t *data = elem->data;
+
+		set_seen(elems, id);
+
+		if (id == EID_EXTENSION) {
+			parse_ie_ext(elem, elems);
+		} else if (elen) {
+			elems->lengths[id] = elen;
+			elems->ies[id] = data;
+		}
+
+		ielen -= elem->datalen + 2;
+		elem = (const struct element *)(elem->data + elem->datalen);
 	}
+	return 0;
 }
 
 void print_ies(unsigned char *ie, int ielen, bool unknown,
@@ -2435,31 +2498,36 @@  void print_ies(unsigned char *ie, int ielen, bool unknown,
 	struct print_ies_data ie_buffer = {
 		.ie = ie,
 		.ielen = ielen };
+	struct ieee80211_elems elems = { };
+	unsigned i;
 
-	if (ie == NULL || ielen < 0)
+	if (parse_ies(ie, ielen, unknown, &elems))
 		return;
 
-	while (ielen >= 2 && ielen - 2 >= ie[1]) {
-		if (ie[0] < ARRAY_SIZE(ieprinters) &&
-		    ieprinters[ie[0]].name &&
-		    ieprinters[ie[0]].flags & BIT(ptype) &&
-			    ie[1] > 0) {
-			print_ie(&ieprinters[ie[0]],
-				 ie[0], ie[1], ie + 2, &ie_buffer);
-		} else if (ie[0] == 221 /* vendor */) {
-			print_vendor(ie[1], ie + 2, unknown, ptype);
-		} else if (ie[0] == 255 /* extension */) {
-			print_extension(ie[1], ie + 2, unknown, ptype);
+	for (i = 0; i < NUM_IES; i++) {
+		if (!was_seen(&elems, i))
+			continue;
+
+		if (i < ARRAY_SIZE(ieprinters) &&
+		    ieprinters[i].name &&
+		    ieprinters[i].flags & BIT(ptype) &&
+		    elems.lengths[i] > 0) {
+			print_ie(&ieprinters[i], i, elems.lengths[i],
+				 elems.ies[i], &ie_buffer);
+		}
+		else if (i == EID_VENDOR) {
+			print_vendor(elems.lengths[i], elems.ies[i],
+				     unknown, ptype);
+		} else if (i == EID_EXTENSION) {
+			print_extension(unknown, &elems, ptype);
 		} else if (unknown) {
-			int i;
+			int j;
 
-			printf("\tUnknown IE (%d):", ie[0]);
-			for (i=0; i<ie[1]; i++)
-				printf(" %.2x", ie[2+i]);
+			printf("\tUnknown IE (%d):", i);
+			for (j = 0; j < elems.lengths[i]; j++)
+				printf("  %.2x", elems.ies[i][j]);
 			printf("\n");
 		}
-		ielen -= ie[1] + 2;
-		ie += ie[1] + 2;
 	}
 }