diff mbox series

[bpf-next,v3,3/3] selftests/bpf: Add tests for bpf_copy_from_user_task_str

Message ID 20250124183303.2019147-2-linux@jordanrome.com (mailing list archive)
State New
Headers show
Series [bpf-next,v3,1/3] mm: add copy_remote_vm_str | expand

Commit Message

Jordan Rome Jan. 24, 2025, 6:33 p.m. UTC
This adds tests for both the happy path and the
error path (with and without the BPF_F_PAD_ZEROS flag).

Signed-off-by: Jordan Rome <linux@jordanrome.com>
---
 .../selftests/bpf/prog_tests/bpf_iter.c       |  68 ++++++++++++
 .../selftests/bpf/prog_tests/read_vsyscall.c  |   1 +
 .../selftests/bpf/progs/bpf_iter_tasks.c      | 103 ++++++++++++++++++
 .../selftests/bpf/progs/read_vsyscall.c       |  11 +-
 4 files changed, 181 insertions(+), 2 deletions(-)

--
2.43.5
diff mbox series

Patch

diff --git a/tools/testing/selftests/bpf/prog_tests/bpf_iter.c b/tools/testing/selftests/bpf/prog_tests/bpf_iter.c
index 6f1bfacd7375..add4a18c33bd 100644
--- a/tools/testing/selftests/bpf/prog_tests/bpf_iter.c
+++ b/tools/testing/selftests/bpf/prog_tests/bpf_iter.c
@@ -323,19 +323,87 @@  static void test_task_pidfd(void)
 static void test_task_sleepable(void)
 {
 	struct bpf_iter_tasks *skel;
+	int pid, status, err, data_pipe[2], finish_pipe[2], c;
+	char *test_data = NULL;
+	char *test_data_long = NULL;
+	char *data[2];
+
+	if (!ASSERT_OK(pipe(data_pipe), "data_pipe") ||
+	    !ASSERT_OK(pipe(finish_pipe), "finish_pipe"))
+		return;

 	skel = bpf_iter_tasks__open_and_load();
 	if (!ASSERT_OK_PTR(skel, "bpf_iter_tasks__open_and_load"))
 		return;

+	pid = fork();
+	if (!ASSERT_GE(pid, 0, "fork"))
+		return;
+
+	if (pid == 0) {
+		/* child */
+		close(data_pipe[0]);
+		close(finish_pipe[1]);
+
+		test_data = malloc(sizeof(char) * 10);
+		strncpy(test_data, "test_data", 10);
+		test_data[9] = '\0';
+
+		test_data_long = malloc(sizeof(char) * 5000);
+		for (int i = 0; i < 5000; ++i) {
+			if (i % 2 == 0)
+				test_data_long[i] = 'b';
+			else
+				test_data_long[i] = 'a';
+		}
+		test_data_long[4999] = '\0';
+
+		data[0] = test_data;
+		data[1] = test_data_long;
+
+		write(data_pipe[1], &data, sizeof(data));
+
+		/* keep child alive until after the test */
+		err = read(finish_pipe[0], &c, 1);
+		if (err != 1)
+			exit(-1);
+
+		close(data_pipe[1]);
+		close(finish_pipe[0]);
+		_exit(0);
+	}
+
+	/* parent */
+	close(data_pipe[1]);
+	close(finish_pipe[0]);
+
+	err = read(data_pipe[0], &data, sizeof(data));
+	ASSERT_EQ(err, sizeof(data), "read_check");
+
+	skel->bss->user_ptr = data[0];
+	skel->bss->user_ptr_long = data[1];
+	skel->bss->pid = pid;
+
 	do_dummy_read(skel->progs.dump_task_sleepable);

 	ASSERT_GT(skel->bss->num_expected_failure_copy_from_user_task, 0,
 		  "num_expected_failure_copy_from_user_task");
 	ASSERT_GT(skel->bss->num_success_copy_from_user_task, 0,
 		  "num_success_copy_from_user_task");
+	ASSERT_GT(skel->bss->num_expected_failure_copy_from_user_task_str, 0,
+		  "num_expected_failure_copy_from_user_task_str");
+	ASSERT_GT(skel->bss->num_success_copy_from_user_task_str, 0,
+		  "num_success_copy_from_user_task_str");

 	bpf_iter_tasks__destroy(skel);
+
+	write(finish_pipe[1], &c, 1);
+	err = waitpid(pid, &status, 0);
+	ASSERT_EQ(err, pid, "waitpid");
+	ASSERT_EQ(status, 0, "zero_child_exit");
+
+	close(data_pipe[0]);
+	close(finish_pipe[1]);
 }

 static void test_task_stack(void)
diff --git a/tools/testing/selftests/bpf/prog_tests/read_vsyscall.c b/tools/testing/selftests/bpf/prog_tests/read_vsyscall.c
index c7b9ba8b1d06..a8d1eaa67020 100644
--- a/tools/testing/selftests/bpf/prog_tests/read_vsyscall.c
+++ b/tools/testing/selftests/bpf/prog_tests/read_vsyscall.c
@@ -24,6 +24,7 @@  struct read_ret_desc {
 	{ .name = "copy_from_user", .ret = -EFAULT },
 	{ .name = "copy_from_user_task", .ret = -EFAULT },
 	{ .name = "copy_from_user_str", .ret = -EFAULT },
+	{ .name = "copy_from_user_task_str", .ret = -EFAULT },
 };

 void test_read_vsyscall(void)
diff --git a/tools/testing/selftests/bpf/progs/bpf_iter_tasks.c b/tools/testing/selftests/bpf/progs/bpf_iter_tasks.c
index bc10c4e4b4fa..e4b80260e9c5 100644
--- a/tools/testing/selftests/bpf/progs/bpf_iter_tasks.c
+++ b/tools/testing/selftests/bpf/progs/bpf_iter_tasks.c
@@ -9,6 +9,13 @@  char _license[] SEC("license") = "GPL";
 uint32_t tid = 0;
 int num_unknown_tid = 0;
 int num_known_tid = 0;
+void *user_ptr = 0;
+void *user_ptr_long = 0;
+uint32_t pid = 0;
+
+static char big_str1[5000];
+static char big_str2[5005];
+static char big_str3[4996];

 SEC("iter/task")
 int dump_task(struct bpf_iter__task *ctx)
@@ -35,7 +42,9 @@  int dump_task(struct bpf_iter__task *ctx)
 }

 int num_expected_failure_copy_from_user_task = 0;
+int num_expected_failure_copy_from_user_task_str = 0;
 int num_success_copy_from_user_task = 0;
+int num_success_copy_from_user_task_str = 0;

 SEC("iter.s/task")
 int dump_task_sleepable(struct bpf_iter__task *ctx)
@@ -44,6 +53,9 @@  int dump_task_sleepable(struct bpf_iter__task *ctx)
 	struct task_struct *task = ctx->task;
 	static const char info[] = "    === END ===";
 	struct pt_regs *regs;
+	char task_str1[10] = "aaaaaaaaaa";
+	char task_str2[10], task_str3[10];
+	char task_str4[20] = "aaaaaaaaaaaaaaaaaaaa";
 	void *ptr;
 	uint32_t user_data = 0;
 	int ret;
@@ -78,8 +90,99 @@  int dump_task_sleepable(struct bpf_iter__task *ctx)
 		BPF_SEQ_PRINTF(seq, "%s\n", info);
 		return 0;
 	}
+
 	++num_success_copy_from_user_task;

+	/* Read an invalid pointer and ensure we get an error */
+	ptr = NULL;
+	ret = bpf_copy_from_user_task_str((char *)task_str1, sizeof(task_str1), ptr, task, 0);
+	if (ret >= 0 || task_str1[9] != 'a') {
+		BPF_SEQ_PRINTF(seq, "%s\n", info);
+		return 0;
+	}
+
+	/* Read an invalid pointer and ensure we get error with pad zeros flag */
+	ptr = NULL;
+	ret = bpf_copy_from_user_task_str((char *)task_str1, sizeof(task_str1),
+					  ptr, task, BPF_F_PAD_ZEROS);
+	if (ret >= 0 || task_str1[9] != '\0') {
+		BPF_SEQ_PRINTF(seq, "%s\n", info);
+		return 0;
+	}
+
+	++num_expected_failure_copy_from_user_task_str;
+
+	/* Same length as the string */
+	ret = bpf_copy_from_user_task_str((char *)task_str2, 10, user_ptr, task, 0);
+	/* only need to do the task pid check once */
+	if (bpf_strncmp(task_str2, 10, "test_data\0") != 0 || ret != 10 || task->tgid != pid) {
+		BPF_SEQ_PRINTF(seq, "%s\n", info);
+		return 0;
+	}
+
+	/* Shorter length than the string */
+	ret = bpf_copy_from_user_task_str((char *)task_str3, 2, user_ptr, task, 0);
+	if (bpf_strncmp(task_str3, 2, "t\0") != 0 || ret != 2) {
+		BPF_SEQ_PRINTF(seq, "%s\n", info);
+		return 0;
+	}
+
+	/* Longer length than the string */
+	ret = bpf_copy_from_user_task_str((char *)task_str4, 20, user_ptr, task, 0);
+	if (bpf_strncmp(task_str4, 10, "test_data\0") != 0 || ret != 10
+	    || task_str4[sizeof(task_str4) - 1] != 'a') {
+		BPF_SEQ_PRINTF(seq, "%s\n", info);
+		return 0;
+	}
+
+	/* Longer length than the string with pad zeros flag */
+	ret = bpf_copy_from_user_task_str((char *)task_str4, 20, user_ptr, task, BPF_F_PAD_ZEROS);
+	if (bpf_strncmp(task_str4, 10, "test_data\0") != 0 || ret != 10
+	    || task_str4[sizeof(task_str4) - 1] != '\0') {
+		BPF_SEQ_PRINTF(seq, "%s\n", info);
+		return 0;
+	}
+
+	/* String that crosses a page boundary */
+	ret = bpf_copy_from_user_task_str(big_str1, 5000, user_ptr_long, task, BPF_F_PAD_ZEROS);
+	if (bpf_strncmp(big_str1, 4, "baba") != 0 || ret != 5000
+	    || bpf_strncmp(big_str1 + 4996, 4, "bab\0") != 0) {
+		BPF_SEQ_PRINTF(seq, "%s\n", info);
+		return 0;
+	}
+
+	for (int i = 0; i < 4999; ++i) {
+		if (i % 2 == 0) {
+			if (big_str1[i] != 'b') {
+				BPF_SEQ_PRINTF(seq, "%s\n", info);
+				return 0;
+			}
+		} else {
+			if (big_str1[i] != 'a') {
+				BPF_SEQ_PRINTF(seq, "%s\n", info);
+				return 0;
+			}
+		}
+	}
+
+	/* Longer length than the string that crosses a page boundary */
+	ret = bpf_copy_from_user_task_str(big_str2, 5005, user_ptr_long, task, BPF_F_PAD_ZEROS);
+	if (bpf_strncmp(big_str2, 4, "baba") != 0 || ret != 5000
+	    || bpf_strncmp(big_str2 + 4996, 5, "bab\0\0") != 0) {
+		BPF_SEQ_PRINTF(seq, "%s\n", info);
+		return 0;
+	}
+
+	/* Shorter length than the string that crosses a page boundary */
+	ret = bpf_copy_from_user_task_str(big_str3, 4996, user_ptr_long, task, 0);
+	if (bpf_strncmp(big_str3, 4, "baba") != 0 || ret != 4996
+	    || bpf_strncmp(big_str3 + 4992, 4, "bab\0") != 0) {
+		BPF_SEQ_PRINTF(seq, "%s\n", info);
+		return 0;
+	}
+
+	++num_success_copy_from_user_task_str;
+
 	if (ctx->meta->seq_num == 0)
 		BPF_SEQ_PRINTF(seq, "    tgid      gid     data\n");

diff --git a/tools/testing/selftests/bpf/progs/read_vsyscall.c b/tools/testing/selftests/bpf/progs/read_vsyscall.c
index 39ebef430059..395591374d4f 100644
--- a/tools/testing/selftests/bpf/progs/read_vsyscall.c
+++ b/tools/testing/selftests/bpf/progs/read_vsyscall.c
@@ -8,14 +8,16 @@ 

 int target_pid = 0;
 void *user_ptr = 0;
-int read_ret[9];
+int read_ret[10];

 char _license[] SEC("license") = "GPL";

 /*
- * This is the only kfunc, the others are helpers
+ * These are the kfuncs, the others are helpers
  */
 int bpf_copy_from_user_str(void *dst, u32, const void *, u64) __weak __ksym;
+int bpf_copy_from_user_task_str(void *dst, u32, const void *,
+				struct task_struct *, u64) __weak __ksym;

 SEC("fentry/" SYS_PREFIX "sys_nanosleep")
 int do_probe_read(void *ctx)
@@ -47,6 +49,11 @@  int do_copy_from_user(void *ctx)
 	read_ret[7] = bpf_copy_from_user_task(buf, sizeof(buf), user_ptr,
 					      bpf_get_current_task_btf(), 0);
 	read_ret[8] = bpf_copy_from_user_str((char *)buf, sizeof(buf), user_ptr, 0);
+	read_ret[9] = bpf_copy_from_user_task_str((char *)buf,
+						  sizeof(buf),
+						  user_ptr,
+						  bpf_get_current_task_btf(),
+						  0);

 	return 0;
 }