diff mbox series

[RFC,2/2] kasan: add mem track interface and its test cases

Message ID 20240118124109.37324-3-lizhe.67@bytedance.com (mailing list archive)
State New
Headers show
Series kasan: introduce mem track feature | expand

Commit Message

lizhe.67@bytedance.com Jan. 18, 2024, 12:41 p.m. UTC
From: Li Zhe <lizhe.67@bytedance.com>

kasan_track_memory() and kasan_untrack_memory() are two interfaces
used to track memory write operations. We can use them to locate
problems where memory has been accidentally rewritten. Examples of
interface usages are shown in kasan_test_module.c

Signed-off-by: Li Zhe <lizhe.67@bytedance.com>
---
 include/linux/kasan.h        |   5 ++
 mm/kasan/generic.c           | 161 +++++++++++++++++++++++++++++++++++
 mm/kasan/kasan_test_module.c |  26 ++++++
 3 files changed, 192 insertions(+)
diff mbox series

Patch

diff --git a/include/linux/kasan.h b/include/linux/kasan.h
index dbb06d789e74..ca5d93629ccf 100644
--- a/include/linux/kasan.h
+++ b/include/linux/kasan.h
@@ -604,4 +604,9 @@  void kasan_non_canonical_hook(unsigned long addr);
 static inline void kasan_non_canonical_hook(unsigned long addr) { }
 #endif /* CONFIG_KASAN_GENERIC || CONFIG_KASAN_SW_TAGS */
 
+#ifdef CONFIG_KASAN_MEM_TRACK
+int kasan_track_memory(const void *addr, size_t size);
+int kasan_untrack_memory(const void *addr, size_t size);
+#endif
+
 #endif /* LINUX_KASAN_H */
diff --git a/mm/kasan/generic.c b/mm/kasan/generic.c
index a204ddcbaa3f..61f3f5125338 100644
--- a/mm/kasan/generic.c
+++ b/mm/kasan/generic.c
@@ -402,6 +402,167 @@  static __always_inline bool memory_is_tracked(const void *addr, size_t size)
 
 	return memory_is_tracked_n(addr, size);
 }
+
+/* deal with addr do not cross 8(shadow size)-byte boundary */
+static void __kasan_track_memory(const void *shadow_addr, size_t offset, size_t size)
+{
+	s8 mask;
+
+	if ((offset & 0x01) || (size & 0x01))
+		mask = kasan_track_mask_odd(size);
+	else
+		mask = kasan_track_mask_even(size);
+	offset = offset >> 1;
+	*(s8 *)shadow_addr |= mask << (KASAN_TRACK_VALUE_OFFSET + offset);
+}
+
+static void _kasan_track_memory(const void *addr, size_t size)
+{
+	unsigned int words;
+	const void *start = kasan_mem_to_shadow(addr);
+	unsigned int prefix = (unsigned long)addr % 8;
+
+	if (prefix) {
+		unsigned int tmp_size = (unsigned int)size;
+
+		tmp_size = min(8 - prefix, tmp_size);
+		__kasan_track_memory(start, prefix, tmp_size);
+		start++;
+		size -= tmp_size;
+	}
+
+	words = size / 8;
+	while (words) {
+		__kasan_track_memory(start, 0, 8);
+		start++;
+		words--;
+	}
+
+	if (size % 8)
+		__kasan_track_memory(start, 0, size % 8);
+}
+
+static inline bool is_cpu_entry_area_addr(unsigned long addr)
+{
+	return ((addr >= CPU_ENTRY_AREA_BASE) &&
+		(addr < CPU_ENTRY_AREA_BASE + CPU_ENTRY_AREA_MAP_SIZE));
+}
+
+static inline bool is_kernel_text_data(unsigned long addr)
+{
+	return ((addr >= (unsigned long)_stext) && (addr < (unsigned long)_end));
+}
+
+static bool can_track(unsigned long addr)
+{
+	if (!virt_addr_valid(addr) &&
+		!is_module_address(addr) &&
+#ifdef CONFIG_KASAN_VMALLOC
+		!is_vmalloc_addr((const void *)addr) &&
+#endif
+		!is_cpu_entry_area_addr(addr) &&
+		!is_kernel_text_data(addr)
+	)
+		return false;
+
+	return true;
+}
+
+int kasan_track_memory(const void *addr, size_t size)
+{
+	if (!kasan_arch_is_ready())
+		return -EINVAL;
+
+	if (unlikely(size == 0))
+		return -EINVAL;
+
+	if (unlikely(addr + size < addr))
+		return -EINVAL;
+
+	if (unlikely(!addr_has_metadata(addr)))
+		return -EINVAL;
+
+	if (likely(memory_is_poisoned(addr, size)))
+		return -EINVAL;
+
+	if (!can_track((unsigned long)addr))
+		return -EINVAL;
+
+	_kasan_track_memory(addr, size);
+	return 0;
+}
+EXPORT_SYMBOL(kasan_track_memory);
+
+/* deal with addr do not cross 8(shadow size)-byte boundary */
+static void __kasan_untrack_memory(const void *shadow_addr, size_t offset, size_t size)
+{
+	s8 mask;
+
+	if (size % 0x01) {
+		offset = (offset - 1) >> 1;
+		mask = kasan_track_mask_odd(size);
+		/*
+		 * SIZE is odd, which means we may clear someone else's tracking flags of
+		 * nearby tracked memory.
+		 */
+		pr_info("It's possible to clear someone else's tracking flags\n");
+	} else {
+		offset = offset >> 1;
+		mask = kasan_track_mask_even(size);
+	}
+	*(s8 *)shadow_addr &= ~(mask << (KASAN_TRACK_VALUE_OFFSET + offset));
+}
+
+static void _kasan_untrack_memory(const void *addr, size_t size)
+{
+	unsigned int words;
+	const void *start = kasan_mem_to_shadow(addr);
+	unsigned int prefix = (unsigned long)addr % 8;
+
+	if (prefix) {
+		unsigned int tmp_size = (unsigned int)size;
+
+		tmp_size = min(8 - prefix, tmp_size);
+		__kasan_untrack_memory(start, prefix, tmp_size);
+		start++;
+		size -= tmp_size;
+	}
+
+	words = size / 8;
+	while (words) {
+		__kasan_untrack_memory(start, 0, 8);
+		start++;
+		words--;
+	}
+
+	if (size % 8)
+		__kasan_untrack_memory(start, 0, size % 8);
+}
+
+int kasan_untrack_memory(const void *addr, size_t size)
+{
+	if (!kasan_arch_is_ready())
+		return -EINVAL;
+
+	if (unlikely(size == 0))
+		return -EINVAL;
+
+	if (unlikely(addr + size < addr))
+		return -EINVAL;
+
+	if (unlikely(!addr_has_metadata(addr)))
+		return -EINVAL;
+
+	if (likely(memory_is_poisoned(addr, size)))
+		return -EINVAL;
+
+	if (!can_track((unsigned long)addr))
+		return -EINVAL;
+
+	_kasan_untrack_memory(addr, size);
+	return 0;
+}
+EXPORT_SYMBOL(kasan_untrack_memory);
 #endif
 
 static __always_inline bool check_region_inline(const void *addr,
diff --git a/mm/kasan/kasan_test_module.c b/mm/kasan/kasan_test_module.c
index 8b7b3ea2c74e..1dba44dbfc81 100644
--- a/mm/kasan/kasan_test_module.c
+++ b/mm/kasan/kasan_test_module.c
@@ -62,6 +62,31 @@  static noinline void __init copy_user_test(void)
 	kfree(kmem);
 }
 
+#ifdef CONFIG_KASAN_MEM_TRACK
+static noinline void __init mem_track_test(void)
+{
+	int ret;
+	int *ptr = kmalloc(sizeof(int), GFP_KERNEL);
+
+	if (!ptr)
+		return;
+
+	ret = kasan_track_memory(ptr, sizeof(int));
+	if (ret) {
+		pr_warn("There is a bug of mem_track\n");
+		goto out;
+	}
+	pr_info("trigger mem_track\n");
+	WRITE_ONCE(*ptr, 1);
+	kasan_untrack_memory(ptr, sizeof(int));
+
+out:
+	kfree(ptr);
+}
+#else
+static inline void __init mem_track_test(void) {}
+#endif
+
 static int __init test_kasan_module_init(void)
 {
 	/*
@@ -72,6 +97,7 @@  static int __init test_kasan_module_init(void)
 	bool multishot = kasan_save_enable_multi_shot();
 
 	copy_user_test();
+	mem_track_test();
 
 	kasan_restore_multi_shot(multishot);
 	return -EAGAIN;