diff mbox series

[bpf-next,v5,4/8] libbpf: Ensure that BPF syscall fds are never 0, 1, or 2

Message ID 20211028063501.2239335-5-memxor@gmail.com (mailing list archive)
State Accepted
Delegated to: BPF
Headers show
Series Typeless/weak ksym for gen_loader + misc fixups | expand

Checks

Context Check Description
bpf/vmtest-bpf-next-PR fail PR summary
netdev/cover_letter success Series has a cover letter
netdev/fixes_present success Fixes tag not required for -next series
netdev/patch_count success Link
netdev/tree_selection success Clearly marked for bpf-next
netdev/subject_prefix success Link
netdev/cc_maintainers warning 2 maintainers not CCed: john.fastabend@gmail.com kpsingh@kernel.org
netdev/source_inline fail Was 0 now: 1
netdev/verify_signedoff success Signed-off-by tag matches author and committer
netdev/module_param success Was 0 now: 0
netdev/build_32bit success Errors and warnings before: 0 this patch: 0
netdev/kdoc success Errors and warnings before: 0 this patch: 0
netdev/verify_fixes success No Fixes tag
netdev/checkpatch warning WARNING: line length of 93 exceeds 80 columns
netdev/build_allmodconfig_warn success Errors and warnings before: 0 this patch: 0
netdev/header_inline success No static functions without inline keyword in header files
bpf/vmtest-bpf-next fail VM_Test

Commit Message

Kumar Kartikeya Dwivedi Oct. 28, 2021, 6:34 a.m. UTC
Add a simple wrapper for passing an fd and getting a new one >= 3 if it
is one of 0, 1, or 2. There are two primary reasons to make this change:
First, libbpf relies on the assumption that a BPF fd is never 0 (e.g.
as most recently noticed in [0]). Second, Alexei pointed out in [1] that
some environments reset stdin, stdout, and stderr if they notice an
invalid fd at these numbers. To protect against both these cases, switch
all internal BPF syscall wrappers in libbpf to always return an fd >= 3.
Only the syscall wrappers need to be modified; other code that checks
for a valid fd by doing >= 0 is left alone, both to avoid pointless churn
and because that check is still valid. The cost paid is two additional
syscalls (fcntl and close) if the returned fd is in range [0, 2].

  [0]: e31eec77e4ab ("bpf: selftests: Fix fd cleanup in get_branch_snapshot")
  [1]: https://lore.kernel.org/bpf/CAADnVQKVKY8o_3aU8Gzke443+uHa-eGoM0h7W4srChMXU1S4Bg@mail.gmail.com

Acked-by: Song Liu <songliubraving@fb.com>
Signed-off-by: Kumar Kartikeya Dwivedi <memxor@gmail.com>
---
 tools/lib/bpf/bpf.c             | 35 +++++++++++++++++++++------------
 tools/lib/bpf/libbpf_internal.h | 24 ++++++++++++++++++++++
 2 files changed, 46 insertions(+), 13 deletions(-)

Comments

Andrii Nakryiko Oct. 28, 2021, 6 p.m. UTC | #1
On Wed, Oct 27, 2021 at 11:35 PM Kumar Kartikeya Dwivedi
<memxor@gmail.com> wrote:
>
> Add a simple wrapper for passing an fd and getting a new one >= 3 if it
> is one of 0, 1, or 2. There are two primary reasons to make this change:
> First, libbpf relies on the assumption a certain BPF fd is never 0 (e.g.
> most recently noticed in [0]). Second, Alexei pointed out in [1] that
> some environments reset stdin, stdout, and stderr if they notice an
> invalid fd at these numbers. To protect against both these cases, switch
> all internal BPF syscall wrappers in libbpf to always return an fd >= 3.
> We only need to modify the syscall wrappers and not other code that
> assumes a valid fd by doing >= 0, to avoid pointless churn, and because
> it is still a valid assumption. The cost paid is two additional syscalls
> if fd is in range [0, 2].
>
>   [0]: e31eec77e4ab ("bpf: selftests: Fix fd cleanup in get_branch_snapshot")
>   [1]: https://lore.kernel.org/bpf/CAADnVQKVKY8o_3aU8Gzke443+uHa-eGoM0h7W4srChMXU1S4Bg@mail.gmail.com
>
> Acked-by: Song Liu <songliubraving@fb.com>
> Signed-off-by: Kumar Kartikeya Dwivedi <memxor@gmail.com>
> ---

LGTM, thanks.

Acked-by: Andrii Nakryiko <andrii@kernel.org>

>  tools/lib/bpf/bpf.c             | 35 +++++++++++++++++++++------------
>  tools/lib/bpf/libbpf_internal.h | 24 ++++++++++++++++++++++
>  2 files changed, 46 insertions(+), 13 deletions(-)
>

[...]
diff mbox series

Patch

diff --git a/tools/lib/bpf/bpf.c b/tools/lib/bpf/bpf.c
index 7d1741ceaa32..53efbdb1b652 100644
--- a/tools/lib/bpf/bpf.c
+++ b/tools/lib/bpf/bpf.c
@@ -65,13 +65,22 @@  static inline int sys_bpf(enum bpf_cmd cmd, union bpf_attr *attr,
 	return syscall(__NR_bpf, cmd, attr, size);
 }
 
+static inline int sys_bpf_fd(enum bpf_cmd cmd, union bpf_attr *attr,
+			     unsigned int size)
+{
+	int fd;
+
+	fd = sys_bpf(cmd, attr, size);
+	return ensure_good_fd(fd);
+}
+
 static inline int sys_bpf_prog_load(union bpf_attr *attr, unsigned int size)
 {
 	int retries = 5;
 	int fd;
 
 	do {
-		fd = sys_bpf(BPF_PROG_LOAD, attr, size);
+		fd = sys_bpf_fd(BPF_PROG_LOAD, attr, size);
 	} while (fd < 0 && errno == EAGAIN && retries-- > 0);
 
 	return fd;
@@ -103,7 +112,7 @@  int bpf_create_map_xattr(const struct bpf_create_map_attr *create_attr)
 	else
 		attr.inner_map_fd = create_attr->inner_map_fd;
 
-	fd = sys_bpf(BPF_MAP_CREATE, &attr, sizeof(attr));
+	fd = sys_bpf_fd(BPF_MAP_CREATE, &attr, sizeof(attr));
 	return libbpf_err_errno(fd);
 }
 
@@ -181,7 +190,7 @@  int bpf_create_map_in_map_node(enum bpf_map_type map_type, const char *name,
 		attr.numa_node = node;
 	}
 
-	fd = sys_bpf(BPF_MAP_CREATE, &attr, sizeof(attr));
+	fd = sys_bpf_fd(BPF_MAP_CREATE, &attr, sizeof(attr));
 	return libbpf_err_errno(fd);
 }
 
@@ -609,7 +618,7 @@  int bpf_obj_get(const char *pathname)
 	memset(&attr, 0, sizeof(attr));
 	attr.pathname = ptr_to_u64((void *)pathname);
 
-	fd = sys_bpf(BPF_OBJ_GET, &attr, sizeof(attr));
+	fd = sys_bpf_fd(BPF_OBJ_GET, &attr, sizeof(attr));
 	return libbpf_err_errno(fd);
 }
 
@@ -720,7 +729,7 @@  int bpf_link_create(int prog_fd, int target_fd,
 		break;
 	}
 proceed:
-	fd = sys_bpf(BPF_LINK_CREATE, &attr, sizeof(attr));
+	fd = sys_bpf_fd(BPF_LINK_CREATE, &attr, sizeof(attr));
 	return libbpf_err_errno(fd);
 }
 
@@ -763,7 +772,7 @@  int bpf_iter_create(int link_fd)
 	memset(&attr, 0, sizeof(attr));
 	attr.iter_create.link_fd = link_fd;
 
-	fd = sys_bpf(BPF_ITER_CREATE, &attr, sizeof(attr));
+	fd = sys_bpf_fd(BPF_ITER_CREATE, &attr, sizeof(attr));
 	return libbpf_err_errno(fd);
 }
 
@@ -921,7 +930,7 @@  int bpf_prog_get_fd_by_id(__u32 id)
 	memset(&attr, 0, sizeof(attr));
 	attr.prog_id = id;
 
-	fd = sys_bpf(BPF_PROG_GET_FD_BY_ID, &attr, sizeof(attr));
+	fd = sys_bpf_fd(BPF_PROG_GET_FD_BY_ID, &attr, sizeof(attr));
 	return libbpf_err_errno(fd);
 }
 
@@ -933,7 +942,7 @@  int bpf_map_get_fd_by_id(__u32 id)
 	memset(&attr, 0, sizeof(attr));
 	attr.map_id = id;
 
-	fd = sys_bpf(BPF_MAP_GET_FD_BY_ID, &attr, sizeof(attr));
+	fd = sys_bpf_fd(BPF_MAP_GET_FD_BY_ID, &attr, sizeof(attr));
 	return libbpf_err_errno(fd);
 }
 
@@ -945,7 +954,7 @@  int bpf_btf_get_fd_by_id(__u32 id)
 	memset(&attr, 0, sizeof(attr));
 	attr.btf_id = id;
 
-	fd = sys_bpf(BPF_BTF_GET_FD_BY_ID, &attr, sizeof(attr));
+	fd = sys_bpf_fd(BPF_BTF_GET_FD_BY_ID, &attr, sizeof(attr));
 	return libbpf_err_errno(fd);
 }
 
@@ -957,7 +966,7 @@  int bpf_link_get_fd_by_id(__u32 id)
 	memset(&attr, 0, sizeof(attr));
 	attr.link_id = id;
 
-	fd = sys_bpf(BPF_LINK_GET_FD_BY_ID, &attr, sizeof(attr));
+	fd = sys_bpf_fd(BPF_LINK_GET_FD_BY_ID, &attr, sizeof(attr));
 	return libbpf_err_errno(fd);
 }
 
@@ -988,7 +997,7 @@  int bpf_raw_tracepoint_open(const char *name, int prog_fd)
 	attr.raw_tracepoint.name = ptr_to_u64(name);
 	attr.raw_tracepoint.prog_fd = prog_fd;
 
-	fd = sys_bpf(BPF_RAW_TRACEPOINT_OPEN, &attr, sizeof(attr));
+	fd = sys_bpf_fd(BPF_RAW_TRACEPOINT_OPEN, &attr, sizeof(attr));
 	return libbpf_err_errno(fd);
 }
 
@@ -1008,7 +1017,7 @@  int bpf_load_btf(const void *btf, __u32 btf_size, char *log_buf, __u32 log_buf_s
 		attr.btf_log_buf = ptr_to_u64(log_buf);
 	}
 
-	fd = sys_bpf(BPF_BTF_LOAD, &attr, sizeof(attr));
+	fd = sys_bpf_fd(BPF_BTF_LOAD, &attr, sizeof(attr));
 
 	if (fd < 0 && !do_log && log_buf && log_buf_size) {
 		do_log = true;
@@ -1050,7 +1059,7 @@  int bpf_enable_stats(enum bpf_stats_type type)
 	memset(&attr, 0, sizeof(attr));
 	attr.enable_stats.type = type;
 
-	fd = sys_bpf(BPF_ENABLE_STATS, &attr, sizeof(attr));
+	fd = sys_bpf_fd(BPF_ENABLE_STATS, &attr, sizeof(attr));
 	return libbpf_err_errno(fd);
 }
 
diff --git a/tools/lib/bpf/libbpf_internal.h b/tools/lib/bpf/libbpf_internal.h
index 13bc7950e304..14f135808f5c 100644
--- a/tools/lib/bpf/libbpf_internal.h
+++ b/tools/lib/bpf/libbpf_internal.h
@@ -13,6 +13,8 @@ 
 #include <limits.h>
 #include <errno.h>
 #include <linux/err.h>
+#include <fcntl.h>
+#include <unistd.h>
 #include "libbpf_legacy.h"
 #include "relo_core.h"
 
@@ -468,4 +470,26 @@  static inline bool is_ldimm64_insn(struct bpf_insn *insn)
 	return insn->code == (BPF_LD | BPF_IMM | BPF_DW);
 }
 
+/* If fd is 0, 1, or 2 (stdin, stdout, stderr), dup it to an fd greater than 2.
+ * Takes ownership of the fd passed in: when fcntl(fd, F_DUPFD_CLOEXEC, 3) is
+ * called, the original fd is closed whether or not the dup succeeds.
+ */
+static inline int ensure_good_fd(int fd)
+{
+	int old_fd = fd, saved_errno;
+
+	if (fd < 0)
+		return fd;
+	if (fd < 3) {
+		fd = fcntl(fd, F_DUPFD_CLOEXEC, 3);
+		saved_errno = errno;
+		close(old_fd);
+		if (fd < 0) {
+			pr_warn("failed to dup FD %d to FD > 2: %d\n", old_fd, -saved_errno);
+			errno = saved_errno;
+		}
+	}
+	return fd;
+}
+
 #endif /* __LIBBPF_LIBBPF_INTERNAL_H */