diff mbox series

[RFC,bpf-next,3/4] selftests/bpf: implement custom Bloom filter purely in BPF

Message ID 20210922203224.912809-4-andrii@kernel.org (mailing list archive)
State RFC
Delegated to: BPF
Headers show
Series bpf_jhash_mem() and BPF Bloom filter implementation | expand

Checks

Context Check Description
netdev/apply fail Patch does not apply to bpf-next
netdev/tree_selection success Clearly marked for bpf-next
bpf/vmtest-bpf-next-PR fail merge-conflict

Commit Message

Andrii Nakryiko Sept. 22, 2021, 8:32 p.m. UTC
And integrate it into existing benchmarks (on BPF side). Posting this
separately from user-space benchmark to emphasize how little code is
necessary on BPF side to make this work.

Signed-off-by: Andrii Nakryiko <andrii@kernel.org>
---
 .../selftests/bpf/progs/bloom_filter_map.c    | 81 ++++++++++++++++++-
 1 file changed, 80 insertions(+), 1 deletion(-)
diff mbox series

Patch

diff --git a/tools/testing/selftests/bpf/progs/bloom_filter_map.c b/tools/testing/selftests/bpf/progs/bloom_filter_map.c
index 3ae2f9bb5968..ee7bbde1af45 100644
--- a/tools/testing/selftests/bpf/progs/bloom_filter_map.c
+++ b/tools/testing/selftests/bpf/progs/bloom_filter_map.c
@@ -64,9 +64,43 @@  const __u32 drop_key  = 1;
 const __u32 false_hit_key = 2;
 
 const volatile bool hashmap_use_bloom_filter = true;
+const volatile bool hashmap_use_custom_bloom_filter = false;
+
+
+__u64 bloom_val = 0;
+static __u64 bloom_arr[256 * 1024]; /* enough for 1mln max_elems w/ 10 hash funcs */
+const volatile __u32 bloom_mask;
+const volatile __u32 bloom_hash_cnt;
+const volatile __u32 bloom_seed;
 
 int error = 0;
 
+static void bloom_add(const void *data, __u32 sz)
+{
+	__u32 i = 0, h;
+
+	for (i = 0; i < bloom_hash_cnt; i++) {
+		h = bpf_jhash_mem(data, sz, bloom_seed + i);
+		__sync_fetch_and_or(&bloom_arr[(h / 64) & bloom_mask], 1ULL << (h % 64));
+	}
+}
+
+bool bloom_contains(__u64 val)
+{
+	__u32 i = 0, h;
+	__u32 seed = bloom_seed, msk = bloom_mask;
+	__u64 v, *arr = bloom_arr;
+
+	for (i = bloom_hash_cnt; i > 0; i--) {
+		h = bpf_jhash_mem(&val, sizeof(val), seed);
+		v = arr[(h / 64) & msk];
+		if (!((v >> (h % 64)) & 1))
+			return false;
+		seed++;
+	}
+	return true;
+}
+
 static __always_inline void log_result(__u32 key, __u32 val)
 {
 	__u32 cpu = bpf_get_smp_processor_id();
@@ -99,6 +133,20 @@  check_elem(struct bpf_map *map, __u32 *key, __u64 *val,
 	return 0;
 }
 
+static __u64
+check_elem_custom(struct bpf_map *map, __u32 *key, __u64 *val,
+	   struct callback_ctx *data)
+{
+	if (!bloom_contains(*val)) {
+		error |= 1;
+		return 1; /* stop the iteration */
+	}
+
+	log_result(hit_key, 1);
+
+	return 0;
+}
+
 SEC("fentry/__x64_sys_getpgid")
 int prog_bloom_filter(void *ctx)
 {
@@ -110,6 +158,26 @@  int prog_bloom_filter(void *ctx)
 	return 0;
 }
 
+SEC("fentry/__x64_sys_getpgid")
+int prog_custom_bloom_filter(void *ctx)
+{
+	bpf_for_each_map_elem(&map_random_data, check_elem_custom, NULL, 0);
+
+	return 0;
+}
+
+__u32 bloom_err = 0;
+__u32 bloom_custom_hit = 0;
+__u32 bloom_noncustom_hit = 0;
+
+SEC("fentry/__x64_sys_getpgid")
+int prog_custom_bloom_filter_add(void *ctx)
+{
+	bloom_add(&bloom_val, sizeof(bloom_val));
+
+	return 0;
+}
+
 SEC("fentry/__x64_sys_getpgid")
 int prog_bloom_filter_inner_map(void *ctx)
 {
@@ -145,6 +213,15 @@  int prog_bloom_filter_hashmap_lookup(void *ctx)
 		val.data32[0] = /*i; */ bpf_get_prandom_u32();
 		val.data32[1] = /*i + 1;*/ bpf_get_prandom_u32();
 
+		if (hashmap_use_custom_bloom_filter)
+		{
+			if (!bloom_contains(val.data64)) {
+				hits++;
+				//custom_hit = true;
+				//__sync_fetch_and_add(&bloom_custom_hit, 1);
+				continue;
+			}
+		} 
 		if (hashmap_use_bloom_filter)
 		{
 			err = bpf_map_peek_elem(&map_bloom_filter, &val);
@@ -160,11 +237,13 @@  int prog_bloom_filter_hashmap_lookup(void *ctx)
 			}
 		}
 
+		//bloom_err += (custom_hit != noncustom_hit);
+
 		result = bpf_map_lookup_elem(&hashmap, &val);
 		if (result) {
 			hits++;
 		} else {
-			if (hashmap_use_bloom_filter)
+			if (hashmap_use_custom_bloom_filter || hashmap_use_bloom_filter)
 				false_hits++;
 			drops++;
 		}