diff mbox series

[RFC,2/2] mm: add a BPF-based per-process heatmap tool

Message ID 20220911083418.2818369-3-yuanchu@google.com (mailing list archive)
State New
Headers show
Series mm: multi-gen LRU: per-process heatmaps | expand

Commit Message

Yuanchu Xie Sept. 11, 2022, 8:34 a.m. UTC
The heatmap tool uses a bpf program integrated with a TUI as a proof of
concept for consuming page access information from MGLRU. It displays
heats, NUMA node, and anon/other with configurable aging intervals and
aggregation intervals.

Signed-off-by: Yuanchu Xie <yuanchu@google.com>
---
 tools/vm/heatmap/Makefile       |  30 ++
 tools/vm/heatmap/heatmap.bpf.c  | 123 +++++++
 tools/vm/heatmap/heatmap.user.c | 188 ++++++++++
 tools/vm/heatmap/heatmap_tui.py | 600 ++++++++++++++++++++++++++++++++
 4 files changed, 941 insertions(+)
 create mode 100644 tools/vm/heatmap/Makefile
 create mode 100644 tools/vm/heatmap/heatmap.bpf.c
 create mode 100644 tools/vm/heatmap/heatmap.user.c
 create mode 100755 tools/vm/heatmap/heatmap_tui.py
diff mbox series

Patch

diff --git a/tools/vm/heatmap/Makefile b/tools/vm/heatmap/Makefile
new file mode 100644
index 000000000000..43ae4af67781
--- /dev/null
+++ b/tools/vm/heatmap/Makefile
@@ -0,0 +1,30 @@ 
+# SPDX-License-Identifier: GPL-2.0
+include ../../build/Build.include
+
+MAKEFLAGS += --no-builtin-rules
+MAKEFLAGS += --no-builtin-variables
+
+CC := clang
+KERNEL_DIR := $(abspath ../../..)
+
+.PHONY: clean all
+
+all: heatmap.user heatmap.skel.h vmlinux.h
+
+%.user: %.user.o
+	$(CC) -g -O2 $^ -l:libbpf.a -lelf -lz -o $@ -static
+
+vmlinux.h: $(KERNEL_DIR)/vmlinux
+	bpftool btf dump file $(KERNEL_DIR)/vmlinux format c > vmlinux.h
+
+%.bpf.o: vmlinux.h %.bpf.c
+	$(CC) -g -D__TARGET_ARCH_x86 -O2 -target bpf -c $*.bpf.c -o $*.bpf.o
+
+%.skel.h: %.bpf.o
+	bpftool gen skeleton $*.bpf.o > $*.skel.h
+
+%.user.o: %.skel.h %.user.c
+	$(CC) -g -O2 -c $*.user.c -o $*.user.o
+
+clean:
+	rm -f *.o heatmap.user *.skel.h vmlinux.h
diff --git a/tools/vm/heatmap/heatmap.bpf.c b/tools/vm/heatmap/heatmap.bpf.c
new file mode 100644
index 000000000000..924d896a0c4f
--- /dev/null
+++ b/tools/vm/heatmap/heatmap.bpf.c
@@ -0,0 +1,123 @@ 
+// SPDX-License-Identifier: GPL-2.0
+#include "vmlinux.h"
+
+#include <bpf/bpf_helpers.h>
+#include <bpf/bpf_tracing.h>
+
+static const u8 one = 1;
+pid_t target_pid;
+#define MAP_SHIFT (12 + 9)
+#define MIXED_NODES -1
+#define MIXED_MEM -1
+
+struct region_stat {
+	u16 accesses;
+	s8 mem_type; /* NON_ANON, ANON */
+	s8 node_id;
+};
+
+struct heatmap_outer {
+	__uint(type, BPF_MAP_TYPE_HASH);
+	__uint(map_flags, BPF_F_NO_PREALLOC);
+	__uint(max_entries, 1000000);
+	__type(key, u64);
+	__type(value, struct region_stat);
+} heatmap SEC(".maps");
+
+int probe(unsigned int nid, unsigned long addr, unsigned long len, bool anon)
+{
+	u64 map_key = addr >> MAP_SHIFT;
+	struct region_stat *region = bpf_map_lookup_elem(&heatmap, &map_key);
+	int err;
+	struct region_stat to_insert;
+
+	if (!region) {
+		to_insert.accesses = len;
+		to_insert.mem_type = anon;
+		to_insert.node_id = nid;
+		err = bpf_map_update_elem(&heatmap, &map_key, &to_insert,
+					  BPF_NOEXIST);
+		if (err)
+			return err;
+	} else {
+		region->accesses += len;
+		if (region->node_id != (int)nid)
+			region->node_id = MIXED_NODES;
+		if (region->mem_type != anon)
+			region->mem_type = MIXED_MEM;
+		err = bpf_map_update_elem(&heatmap, &map_key, region,
+					  BPF_EXIST);
+		if (err)
+			return err;
+	}
+	return 0;
+}
+
+SEC("fentry/mglru_pte_probe")
+int BPF_PROG(fentry_mglru_pte_probe, pid_t pid, unsigned int nid,
+	     unsigned long addr, unsigned long len, bool anon)
+{
+	int err;
+
+	if (pid != target_pid)
+		return 0;
+	err = probe(nid, addr, len, anon);
+	if (err)
+		bpf_printk("PTE called addr:0x%lx len:%lu error:%ld", addr, len,
+			   err);
+	return 0;
+}
+
+SEC("fentry/mglru_pmd_probe")
+int BPF_PROG(fentry_mglru_pmd_probe, pid_t pid, unsigned int nid,
+	     unsigned long addr, unsigned long len, bool anon)
+{
+	int err;
+
+	if (pid != target_pid)
+		return 0;
+	err = probe(nid, addr, len, anon);
+	if (err)
+		bpf_printk("PMD called addr:0x%lx len:%lu error:%ld", addr, len,
+			   err);
+	return 0;
+}
+
+extern void
+bpf_set_skip_mm(struct bpf_mglru_should_skip_mm_control *should_skip) __ksym;
+
+SEC("fentry/bpf_mglru_should_skip_mm")
+int BPF_PROG(bpf_mglru_should_skip_mm,
+	     struct bpf_mglru_should_skip_mm_control *ctl)
+{
+	if (ctl->pid != target_pid) {
+		bpf_printk("aging wrong pid");
+		bpf_set_skip_mm(ctl);
+	}
+	return 0;
+}
+
+extern int bpf_run_aging(int memcg_id, bool can_swap, bool force_scan) __ksym;
+
+struct args {
+	int memcg_id;
+};
+
+SEC("syscall")
+int memcg_run_aging(struct args *ctx)
+{
+	int err;
+
+	err = bpf_run_aging(ctx->memcg_id, true, true);
+
+	if (err != 0) {
+		bpf_printk("aging failed for memcg %ld with error %ld",
+			   ctx->memcg_id, err);
+		return 0;
+	}
+
+	bpf_printk("aging succeeded for memcg %ld", ctx->memcg_id);
+	return 0;
+}
+
+char LICENSE[] SEC("license") = "GPL";
diff --git a/tools/vm/heatmap/heatmap.user.c b/tools/vm/heatmap/heatmap.user.c
new file mode 100644
index 000000000000..094ba1e49233
--- /dev/null
+++ b/tools/vm/heatmap/heatmap.user.c
@@ -0,0 +1,188 @@ 
+// SPDX-License-Identifier: GPL-2.0
+#include <bpf/bpf.h>
+#include <bpf/libbpf.h>
+#include <fcntl.h>
+#include <signal.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <sys/mount.h>
+#include <sys/stat.h>
+#include <unistd.h>
+
+#include "heatmap.skel.h"
+
+#define MAP_SHIFT (12 + 9)
+
+static bool terminate;
+
+struct args {
+	int memcg_id;
+};
+
+void handle_sigint(int sig)
+{
+	terminate = true;
+}
+
+static int libbpf_print_fn(enum libbpf_print_level level, const char *format,
+			   va_list args)
+{
+	return vfprintf(stderr, format, args);
+}
+
+int run_aging(int aging_fd, int memcg_id)
+{
+	struct args ctx = {
+		.memcg_id = memcg_id,
+	};
+	LIBBPF_OPTS(bpf_test_run_opts, tattr, .ctx_in = &ctx,
+		    .ctx_size_in = sizeof(ctx));
+	return bpf_prog_test_run_opts(aging_fd, &tattr);
+}
+
+int attach_progs(pid_t pid, struct heatmap_bpf **heatmap_obj, int *aging_fd,
+		 int *heatmap_fd)
+{
+	int err;
+	int fd;
+	struct heatmap_bpf *obj;
+
+	obj = heatmap_bpf__open();
+	if (obj == NULL) {
+		perror("Error when opening heatmap bpf object");
+		return -1;
+	}
+	obj->bss->target_pid = pid;
+
+	err = heatmap_bpf__load(obj);
+	if (err) {
+		perror("Error loading heatmap bpf object");
+		goto cleanup;
+	}
+
+	fd = bpf_program__fd(obj->progs.memcg_run_aging);
+
+	err = heatmap_bpf__attach(obj);
+	if (err) {
+		perror("Error attaching heatmap bpf object");
+		goto cleanup;
+	}
+
+	*aging_fd = fd;
+	*heatmap_fd = bpf_map__fd(obj->maps.heatmap);
+	*heatmap_obj = obj;
+	return 0;
+
+cleanup:
+	heatmap_bpf__destroy(obj);
+	return err;
+}
+
+int bpf_map_delete_and_get_next_key(int fd, const void *key, void *next_key)
+{
+	int err = bpf_map_get_next_key(fd, key, next_key);
+
+	bpf_map_delete_elem(fd, key);
+	return err;
+}
+
+struct region_stat {
+	__u16 accesses;
+	__s8 mem_type; /* NON_ANON, ANON */
+	__s8 node_id;
+};
+
+void dump_map(int fd)
+{
+	__u64 prev_key, key;
+	struct region_stat value;
+	int err;
+
+	while (bpf_map_delete_and_get_next_key(fd, &prev_key, &key) == 0) {
+		err = bpf_map_lookup_elem(fd, &key, &value);
+		if (err < 0) {
+			/* impossible if we don't have racing deletions */
+			exit(-1);
+		}
+		printf("%llu %u %d %d\n", key << MAP_SHIFT, value.accesses,
+		       value.mem_type, value.node_id);
+		prev_key = key;
+	}
+}
+
+void detach_progs(struct heatmap_bpf *heatmap_obj)
+{
+	heatmap_bpf__detach(heatmap_obj);
+	heatmap_bpf__destroy(heatmap_obj);
+}
+
+int main(void)
+{
+	struct heatmap_bpf *heatmap_obj = NULL;
+	int aging_fd = -1;
+	int heatmap_fd = -1;
+	int memcg_id = -1;
+	int err;
+
+	signal(SIGINT, handle_sigint);
+	setvbuf(stdout, NULL, _IONBF, BUFSIZ);
+	libbpf_set_print(libbpf_print_fn);
+
+	while (!terminate) {
+		char *buffer = NULL;
+
+		if (scanf("%ms", &buffer) == 1) {
+			if (strcmp(buffer, "exit") == 0) {
+				printf("No hard feelings.\n");
+				exit(0);
+
+			} else if (heatmap_obj == NULL &&
+				   strcmp(buffer, "attach") == 0) {
+				pid_t pid_;
+				int memcg_id_;
+
+				if (scanf("%d %d", &pid_, &memcg_id_) == 2) {
+					err = attach_progs(pid_, &heatmap_obj,
+							   &aging_fd,
+							   &heatmap_fd);
+					if (err) {
+						printf("error: aging %d\n",
+						       err);
+						goto next;
+					}
+					memcg_id = memcg_id_;
+					printf("success: attach\n");
+
+				} else
+					printf("error: invalid arguments\n");
+
+			} else if (heatmap_obj != NULL) {
+				if (strcmp(buffer, "map") == 0) {
+					dump_map(heatmap_fd);
+					printf("success: map\n");
+
+				} else if (strcmp(buffer, "age") == 0) {
+					err = run_aging(aging_fd, memcg_id);
+					if (err)
+						printf("error: age %d\n", err);
+					else
+						printf("success: age\n");
+
+				} else if (strcmp(buffer, "detach") == 0) {
+					detach_progs(heatmap_obj);
+					heatmap_obj = NULL;
+					heatmap_fd = -1;
+					aging_fd = -1;
+					memcg_id = -1;
+					printf("success: detach\n");
+				}
+
+			} else
+				printf("error: invalid command\n");
+
+next:
+			free(buffer);
+		} else
+			printf("error: invalid command\n");
+	}
+}
diff --git a/tools/vm/heatmap/heatmap_tui.py b/tools/vm/heatmap/heatmap_tui.py
new file mode 100755
index 000000000000..9be6b611bc24
--- /dev/null
+++ b/tools/vm/heatmap/heatmap_tui.py
@@ -0,0 +1,600 @@ 
+#!/usr/bin/env python3
+# SPDX-License-Identifier: GPL-2.0
+import urwid
+import parse
+import subprocess
+import os
+
+from time import time
+from math import floor
+
+UPDATE_INTERVAL = 0.5
+MEM_MAX = 0
+MEM_MIN = 0x10000000000000000
+INVALID_NODE = -2
+MIXED_NODE = -1
+INVALID_MEM = -2
+MIXED_MEM = -1
+ANON_MEM = 1
+MAX_AGGREGATIONS = 20
+MIN_ROWS = 40
+MAX_COL = 7
+
+
+class HeatmapModel:
+    def __init__(self, max_nr_aggregations):
+        self.data_range = (MEM_MIN, MEM_MAX)
+        self.data = []
+        self.max_nr_aggregations = max_nr_aggregations
+
+    def get_nr_aggregations(self):
+        return len(self.data)
+
+    def get_max_nr_aggregations(self):
+        return self.max_nr_generations
+
+    def set_max_nr_aggregations(self, max_nr_aggregations):
+        self.max_nr_aggregations = max_nr_aggregations
+
+    def append_data(self, bpf_access_data):
+        region_min = MEM_MIN
+        region_max = MEM_MAX
+        for region in bpf_access_data:
+            if region["address"] > region_max:
+                region_max = region["address"]
+            if region["address"] < region_min:
+                region_min = region["address"]
+
+        self.data.append({
+            "min": region_min,
+            "max": region_max,
+            "data": sorted(bpf_access_data, key=lambda d: d["address"]),
+        })
+
+        if len(self.data) > self.max_nr_aggregations:
+            del self.data[:len(self.data) - self.max_nr_aggregations]
+
+
+    def get_display_sections(self, rows):
+        REGION_SIZE = 2 * 1024 * 1024 # region size (21 bits) 2 MB
+        ranges = []
+        for d in self.data:
+            for r in d["data"]:
+                addr = r["address"]
+                ranges.append(("start", addr - REGION_SIZE))
+                ranges.append(("end", addr + REGION_SIZE * 2))
+
+        ranges = sorted(ranges, key=lambda d: d[1])
+        sections = []
+        start_addr = None
+        nesting = 0
+        total_size = 0
+        for (tag, addr) in ranges:
+            if tag == "start":
+                nesting += 1
+                if start_addr == None:
+                    start_addr = addr
+
+            if tag == "end":
+                nesting -= 1;
+                if nesting == 0:
+                    total_size += addr - start_addr
+                    sections.append((start_addr, addr))
+                    start_addr = None
+
+        if len(sections) > rows:
+            # compact some sections
+            sections_with_idx = [((start, end), i, end - start) for (i, (start, end)) in enumerate(sections)]
+            sections_with_idx = sorted(sections_with_idx, key=lambda x:
+                                       x[-1] + abs(sections[x[1] + 1][0] - x[0][1]) + abs(sections[x[1] - 1][1] - x[0][0]))
+            for i in range(0, len(sections) - rows):
+                # natural number of regions is greater than the number of rows
+                ((start, end), section_i, size) = sections_with_idx[i]
+                (_, prev_end) = sections[section_i - 1]
+                (succ_start, _) = sections[section_i + 1]
+                if abs(succ_start - end) > abs(prev_end - start):
+                    sections[section_i - 1][1] = end
+                else:
+                    sections[section_i - 1][0] = start
+
+            new_sections = []
+            for i in range(len(sections) - rows, len(sections)):
+                (_, section_i, _) = sections_with_idx[i]
+                new_sections.append(sections[section_i])
+            sections = sorted(new_sections, key=lambda x: x[0])
+
+        else:
+            extra_rows = rows - len(sections)
+            split_sections = []
+            spill_over_factor = 0
+            for (start, end) in sections:
+                fraction_of_row = (end - start) / total_size * rows
+                if fraction_of_row > 1:
+                    additional_rows_frac = (spill_over_factor + (fraction_of_row - 1))
+                    spill_over_factor = additional_rows_frac - floor(additional_rows_frac)
+                    additional_rows = min(floor(additional_rows_frac), extra_rows)
+                    extra_rows -= additional_rows
+                    new_rows = additional_rows + 1
+                    # split current interval
+                    inc = (end - start + new_rows - 1) // new_rows # round up
+                    for i in range(0, new_rows):
+                        split_sections.append((start + inc * i, min(start + inc * (i + 1), end)))
+                else:
+                    split_sections.append((start, end))
+            while extra_rows > 0:
+                split_sections.append((0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF))
+                extra_rows -= 1
+            sections = split_sections
+
+        return sections
+
+    def format_data(self, rows):
+        cell_init = (0, INVALID_MEM, INVALID_NODE)
+        def cell_update(cell, r):
+            (acc, mem, node) = cell
+            acc += r["accesses"]
+            if mem == INVALID_MEM:
+                mem = r["mem"]
+            elif mem != r["mem"]:
+                mem = MIXED_MEM
+
+            if node == INVALID_NODE:
+                node = r["node"]
+            elif node != r["node"]:
+                node = MIXED_NODE
+
+            return (acc, mem, node)
+
+        return self.__format_data(rows, cell_init, cell_update)
+
+    def __format_data(self, rows, cell_init, cell_update):
+        row_ranges = self.get_display_sections(rows)
+        assert(len(row_ranges) == rows)
+        graph = []
+        for d in self.data:
+            (curr_start, curr_end) = row_ranges[0]
+            range_idx = 0
+
+            col = []
+            cell = cell_init
+            for r in d["data"]:
+                while r["address"] >= curr_end:
+                    col.append(cell)
+                    cell = cell_init
+                    range_idx += 1
+                    (curr_start, curr_end) = row_ranges[range_idx]
+
+                cell = cell_update(cell, r)
+
+            while range_idx < len(row_ranges):
+                (curr_start, curr_end) = row_ranges[range_idx]
+                col.append(cell)
+                cell = cell_init
+                range_idx += 1
+
+            assert(len(col) == rows)
+            graph.append(col)
+
+        col_labels = [start for (start, end) in row_ranges]
+        return (col_labels, graph)
+
+
+class GraphView(urwid.Widget):
+    def __init__(self, model, modes):
+        self.model = model
+        self.mode = modes[0]
+        urwid.Widget.__init__(self)
+
+    def set_mode(self, mode):
+        self.mode = mode
+        self.update()
+
+    def render_col(self, col):
+        l = []
+        if self.mode == "Anon/Other":
+            for (_, mem, _) in col:
+                if mem == MIXED_MEM:
+                    l.append(("mixed", "M" * MAX_COL))
+                elif mem == INVALID_MEM:
+                    l.append(("invalid", "." * MAX_COL))
+                elif mem == ANON_MEM:
+                    l.append(("anon mem", "A" * MAX_COL))
+                else:
+                    l.append(("other mem", "X" * MAX_COL))
+        elif self.mode == "NUMA Node":
+            for (_, _, node) in col:
+                if node == MIXED_NODE:
+                    l.append(("mixed", "M" * MAX_COL))
+                elif node == INVALID_NODE:
+                    l.append(("invalid", "." * MAX_COL))
+                else:
+                    node_str = str(node)
+                    left_pad = (MAX_COL - len(node_str)) // 2
+                    right_pad = MAX_COL - left_pad - len(node_str)
+                    node_style = "node"
+                    if node < 4 and node >= 0:
+                        node_style += " " + node_str
+                    l.append((node_style, "_" * left_pad + node_str + "_" * right_pad))
+        else:
+            for (acc, _, _) in col:
+                acc_str = str(acc)
+                left_pad = (MAX_COL - len(acc_str)) // 2
+                right_pad = MAX_COL - left_pad - len(acc_str)
+                heat_style = "heat"
+                if acc >= 512:
+                    heat_style += " mid"
+                elif acc >= 1024:
+                    heat_style += " high"
+                l.append((heat_style, u"\u00a0" * left_pad + acc_str + u"\u00a0" * right_pad))
+
+        return (urwid.Text(l).render((MAX_COL,)), None, False, MAX_COL)
+
+    def render(self, size, focus=False):
+        (cols, rows) = size
+        LABEL_COLS = 16 + 2 + 4 # 0x and 16 chars of hex address, plus padding
+        self.model.set_max_nr_aggregations((cols - LABEL_COLS) // MAX_COL)
+        data_cols = cols - LABEL_COLS
+        (labels, data) = self.model.format_data(rows)[-(data_cols // MAX_COL):]
+        if len(labels) == 0:
+            label_col = urwid.SolidCanvas(" ", LABEL_COLS, rows)
+        else:
+            label_col = [("pg smooth", "0x{:016X}".format(addr) + u"\u00a0" * 4) for addr in labels]
+            label_col = urwid.Text(label_col).render((LABEL_COLS,))
+
+        label_col = (label_col, None, False, LABEL_COLS)
+        if len(data) > 0:
+            return urwid.CanvasJoin([label_col] + list(map(self.render_col, data)) + [
+                (urwid.SolidCanvas(" ", data_cols - len(data) * MAX_COL, rows),
+                 None, False, data_cols - len(data) * MAX_COL)])
+        else:
+            return urwid.SolidCanvas(" ", cols, rows)
+
+    def update(self):
+        self._invalidate()
+
+    def rows(self, size, focus=False):
+        return MIN_ROWS
+
+    def keypress(self, size, key):
+        return key
+
+
+class HeatmapView(urwid.WidgetWrap):
+    palette = [
+        ('body',         'black',      'light gray', 'standout'),
+        ('header',       'white',      'dark red',   'bold'),
+        ('screen edge',  'light blue', 'dark cyan'),
+        ('main shadow',  'dark gray',  'black'),
+        ('line',         'black',      'light gray', 'standout'),
+        ('bg background','light gray', 'black'),
+        ('bg 1',         'black',      'dark blue', 'standout'),
+        ('bg 1 smooth',  'dark blue',  'black'),
+        ('bg 2',         'black',      'dark cyan', 'standout'),
+        ('bg 2 smooth',  'dark cyan',  'black'),
+        ('button normal','light gray', 'dark blue', 'standout'),
+        ('button select','white',      'dark green'),
+        ('line',         'black',      'light gray', 'standout'),
+        ('pg normal',    'white',      'black', 'standout'),
+        ('pg complete',  'white',      'dark magenta'),
+        ('pg smooth',     'dark magenta','black'),
+
+        ("mixed", "light blue", "black"),
+        ("invalid", "dark gray", "black"),
+        ("anon mem", "white", "dark cyan"),
+        ("other mem", "light blue", "dark red"),
+
+        ("node", "white", "black"),
+
+        ("node 0", "light green", "black"),
+        ("node 1", "light blue", "black"),
+        ("node 2", "light red", "black"),
+        ("node 3", "yellow", "black"),
+
+        ("heat", "yellow", "black"),
+        ("heat mid", "dark red", "brown"),
+        ("heat high", "light red", "dark red"),
+    ]
+
+    def __init__(self, controller):
+        self.controller = controller
+        urwid.WidgetWrap.__init__(self, self.draw_view())
+
+    def update_graph(self):
+        self.graph_view.update()
+        pass
+
+    def set_selected_mode(self, new_mode):
+        self.graph_view.set_mode(new_mode)
+        for b in self.mode_buttons:
+            if b.get_label() == new_mode:
+                b.set_state(True, do_callback=False)
+                break
+
+    def on_mode_button(self, button, state):
+        if state:
+            self.controller.on_mode_change(button.get_label())
+
+    def radio_button(self, group, label, state, on_state_change):
+        w = urwid.RadioButton(group, label, state, on_state_change=on_state_change)
+        w = urwid.AttrWrap(w, 'button normal', 'button select')
+        return w
+
+    def button(self, label, on_press):
+        w = urwid.Button(label, on_press)
+        w = urwid.AttrWrap(w, 'button normal', 'button select')
+        return w
+
+    def set_alert_message(self, new_alert_text):
+        self.alert_text.set_text(new_alert_text)
+        pass
+
+    def set_start_button_text(self, new_text):
+        self.start_button.set_label(new_text)
+        pass
+
+    def edit_box(self, label, text, on_change):
+        w = urwid.Edit(label, text)
+        urwid.connect_signal(w, 'change', on_change)
+        w = urwid.AttrWrap(w, 'edit')
+        return w
+
+    def draw_control_pane(self):
+        g = []
+        self.mode_buttons = [self.radio_button(g, mode, mode == self.controller.mode,
+                                               self.on_mode_button)
+                          for mode in self.controller.overlay_modes]
+        self.pid_box = self.edit_box("PID: ", "1", self.controller.on_pid_change)
+        self.memcg_box = self.edit_box("memcg id: ", "1", self.controller.on_memcg_change)
+        self.start_button = self.button("Start", self.controller.on_start_button)
+        self.alert_text = urwid.Text("", align="center")
+
+        aging_text = urwid.Text("Aging interval", align="center")
+        self.aging_box = self.edit_box("Seconds: ", str(0.5), self.controller.on_aging_change)
+        aggregation_text = urwid.Text("Aggregation Interval", align="center")
+        self.aggregation_box = self.edit_box("Aging cycles: ", str(3), self.controller.on_aggregation_change)
+
+        self.quit_button = self.button("Quit", self.controller.on_quit_button)
+        l = [
+            urwid.Text("Overlay Mode", align="center")
+        ] + self.mode_buttons + [
+            urwid.Divider(),
+            self.pid_box,
+            self.memcg_box,
+            self.start_button,
+            urwid.Divider(),
+            aging_text,
+            self.aging_box,
+            aggregation_text,
+            self.aggregation_box,
+            urwid.Divider(),
+            self.alert_text,
+            urwid.Divider(),
+            self.quit_button
+        ]
+        return urwid.ListBox(urwid.SimpleListWalker(l))
+
+    def main_shadow(self, w):
+        # Wrap a shadow and background around widget w
+        bg = urwid.AttrWrap(urwid.SolidFill(u"\u2592"), 'screen edge')
+        shadow = urwid.AttrWrap(urwid.SolidFill(u" "), 'main shadow')
+
+        bg = urwid.Overlay( shadow, bg,
+                            ('fixed left', 3), ('fixed right', 1),
+                            ('fixed top', 2), ('fixed bottom', 1))
+        w = urwid.Overlay( w, bg,
+                           ('fixed left', 2), ('fixed right', 3),
+                           ('fixed top', 1), ('fixed bottom', 2))
+        return w
+
+    def draw_view(self):
+        control_pane = self.draw_control_pane()
+        self.graph_view = GraphView(self.controller.model, self.controller.overlay_modes)
+        vline = urwid.AttrWrap(urwid.SolidFill(u'\u2502'), 'line')
+        w = urwid.Columns([("weight", 4, self.graph_view), ("fixed", 1, vline), control_pane],
+                          dividechars=1, focus_column=2)
+        w = urwid.Padding(w,('fixed left',1),('fixed right',0))
+        w = urwid.AttrWrap(w,'body')
+        w = urwid.LineBox(w)
+        w = urwid.AttrWrap(w,'line')
+        w = self.main_shadow(w)
+        return w
+
+class BpfBridge:
+    def __init__(self):
+        self.bpf_bridge = subprocess.Popen(os.path.join(os.path.dirname(os.path.abspath(__file__)),
+                                                        "heatmap.user"),
+                                           stdin=subprocess.PIPE,
+                                           stdout=subprocess.PIPE,
+                                           stderr=subprocess.DEVNULL,
+                                           text=True)
+
+    def attach(self, pid, memcg_id):
+        self.bpf_bridge.stdin.write("attach\n")
+        self.bpf_bridge.stdin.write("{:d} {:d}\n".format(pid, memcg_id))
+        self.bpf_bridge.stdin.flush()
+        result = self.bpf_bridge.stdout.readline()
+        if "success" not in result:
+            return result
+
+        return "success"
+
+    def detach(self):
+        self.bpf_bridge.stdin.write("detach\n")
+        self.bpf_bridge.stdin.flush()
+        result = self.bpf_bridge.stdout.readline()
+        if "success" not in result:
+            return result
+
+        return "success"
+
+    def run_aging(self):
+        self.bpf_bridge.stdin.write("age\n")
+        self.bpf_bridge.stdin.flush()
+        result = self.bpf_bridge.stdout.readline()
+        if "success" not in result:
+            return result
+
+        return "success"
+
+    def get_map(self):
+        self.bpf_bridge.stdin.write("map\n")
+        self.bpf_bridge.stdin.flush()
+        result = ""
+        access_data = []
+        parser = parse.compile("{:d} {:d} {:d} {:d}\n")
+        result = self.bpf_bridge.stdout.readline()
+        while "success" not in result:
+            parsed_access = parser.parse(result)
+            if parsed_access is None:
+                raise Exception("woo " + result)
+                return result
+
+            access_data.append({
+                "address": parsed_access[0],
+                "accesses": parsed_access[1],
+                "mem": parsed_access[2],
+                "node": parsed_access[3],
+            })
+            result = self.bpf_bridge.stdout.readline()
+
+        return access_data
+
+
+class HeatmapController:
+    def __init__(self):
+        self.last_aged = 0
+        self.aging_interval = 0.5
+        self.aging_count = 0
+        self.aggregation_interval = 3
+        self.overlay_modes = ["Heat", "NUMA Node", "Anon/Other"]
+        self.monitoring_pid = 1
+        self.monitoring_memcg_id = 1
+        self.mode = self.overlay_modes[0]
+        self.model = HeatmapModel(MAX_AGGREGATIONS)
+        self.bpf_bridge = BpfBridge()
+        self.monitoring = False
+        self.view = HeatmapView(self)
+
+    def main(self):
+        self.loop = urwid.MainLoop(self.view, self.view.palette)
+        self.timer = self.loop.set_alarm_in(UPDATE_INTERVAL, self.on_timer)
+        # spawn the monitored process
+        self.loop.run()
+
+    def on_mode_change(self, new_mode):
+        mode = new_mode
+        self.view.set_selected_mode(new_mode)
+
+    def on_aging_change(self, w, new_value):
+        try:
+            aging_interval = float(new_value)
+            if aging_interval > 0:
+                self.aging_interval = aging_interval
+                self.view.set_alert_message("")
+            else:
+                self.view.set_alert_message("invalid aging interval")
+        except ValueError:
+            self.view.set_alert_message("invalid aging interval")
+
+    def on_aggregation_change(self, w, new_value):
+        try:
+            aggregation_interval = int(new_value)
+            if aggregation_interval > 0:
+                self.aggregation_interval = aggregation_interval
+                self.view.set_alert_message("")
+            else:
+                self.view.set_alert_message("invalid aggregation interval")
+        except ValueError:
+            self.view.set_alert_message("invalid aggregation interval")
+
+    def update_graph(self, new_data):
+        self.model.append_data(new_data)
+        self.view.update_graph()
+
+    def on_timer(self, loop=None, user_data=None):
+        # perform aging
+        # read data
+        delta_time = -time()
+        if self.last_aged <= -delta_time - self.aging_interval:
+            if self.monitoring:
+                # aging
+                self.last_aged = -delta_time
+                self.aging_count += 1
+                err = self.bpf_bridge.run_aging()
+                if "success" not in err:
+                    self.disable_monitoring()
+                    self.view.set_alert_message(err)
+
+                if self.aging_count % self.aggregation_interval == 0:
+                    # aggregation
+                    data = self.bpf_bridge.get_map()
+                    if isinstance(data, str):
+                        # get map failed
+                        self.disable_monitoring()
+                        self.view.set_alert_message(err)
+                    else:
+                        self.update_graph(data)
+
+        delta_time += time()
+        if delta_time > UPDATE_INTERVAL:
+            self.view.set_alert_message("timer running behind")
+            self.loop.set_alarm_in(0, self.on_timer)
+        else:
+            self.loop.set_alarm_in(UPDATE_INTERVAL - delta_time, self.on_timer)
+
+    def on_pid_change(self, widget, new_text):
+        pid = parse.parse("{:d}", new_text)
+        self.view.set_alert_message("")
+        if self.monitoring:
+            return
+
+        if pid is not None:
+            self.monitoring_pid = pid[0]
+        else:
+            self.monitoring_pid = -1
+
+    def on_memcg_change(self, widget, new_text):
+        memcg_id = parse.parse("{:d}", new_text)
+        self.view.set_alert_message("")
+        if self.monitoring:
+            return
+
+        if memcg_id is not None:
+            self.monitoring_memcg_id = memcg_id[0]
+        else:
+            self.monitoring_memcg_id = -1
+
+    def disable_monitoring(self):
+        self.monitoring = False
+        self.bpf_bridge.detach()
+        self.view.set_start_button_text("Start")
+
+    def on_start_button(self, w):
+        if self.monitoring:
+            err = self.bpf_bridge.detach()
+            self.monitoring = False
+            self.view.set_start_button_text("Start")
+            if "success" not in err:
+                self.view.set_alert_message(err)
+                self.view.pid_box.set_edit_text(str(self.monitoring_pid))
+        else:
+            if self.monitoring_pid != -1 and self.monitoring_memcg_id != -1:
+                err = self.bpf_bridge.attach(self.monitoring_pid, self.monitoring_memcg_id)
+                if "success" in err:
+                    self.view.set_start_button_text("Stop")
+                    self.monitoring = True
+                else:
+                    self.view.set_alert_message(err)
+            else:
+                self.view.set_alert_message("invalid pid/memcg")
+
+
+
+    def on_quit_button(self, w):
+        raise urwid.ExitMainLoop()
+
+def main():
+    HeatmapController().main()
+
+if "__main__" == __name__:
+    main()