diff mbox series

[v2,2/3] HID: bpf: actually free hdev memory after attaching a HID-BPF program

Message ID 20240124-b4-hid-bpf-fixes-v2-2-052520b1e5e6@kernel.org (mailing list archive)
State New
Delegated to: Jiri Kosina
Headers show
Series HID: bpf: couple of upstream fixes | expand

Commit Message

Benjamin Tissoires Jan. 24, 2024, 11:26 a.m. UTC
Turns out that I got my reference counts wrong and each successful
bus_find_device() actually calls get_device(), and we need to manually
call put_device().

Ensure each bus_find_device() gets a matching put_device() when releasing
the bpf programs and fix all the error paths.

Cc: stable@vger.kernel.org
Fixes: f5c27da4e3c8 ("HID: initial BPF implementation")
Signed-off-by: Benjamin Tissoires <bentiss@kernel.org>

---

new in v2
---
 drivers/hid/bpf/hid_bpf_dispatch.c  | 29 +++++++++++++++++++++++------
 drivers/hid/bpf/hid_bpf_jmp_table.c | 19 ++++++++++++++++---
 2 files changed, 39 insertions(+), 9 deletions(-)

Comments

Benjamin Tissoires Jan. 26, 2024, 11:20 a.m. UTC | #1
On Wed, Jan 24, 2024 at 12:27 PM Benjamin Tissoires <bentiss@kernel.org> wrote:
>
> Turns out that I got my reference counts wrong and each successful
> bus_find_device() actually calls get_device(), and we need to manually
> call put_device().
>
> Ensure each bus_find_device() gets a matching put_device() when releasing
> the bpf programs and fix all the error paths.
>
> Cc: stable@vger.kernel.org
> Fixes: f5c27da4e3c8 ("HID: initial BPF implementation")
> Signed-off-by: Benjamin Tissoires <bentiss@kernel.org>
>
> ---
>
> new in v2
> ---
>  drivers/hid/bpf/hid_bpf_dispatch.c  | 29 +++++++++++++++++++++++------
>  drivers/hid/bpf/hid_bpf_jmp_table.c | 19 ++++++++++++++++---
>  2 files changed, 39 insertions(+), 9 deletions(-)
>
> diff --git a/drivers/hid/bpf/hid_bpf_dispatch.c b/drivers/hid/bpf/hid_bpf_dispatch.c
> index 5111d1fef0d3..7903c8638e81 100644
> --- a/drivers/hid/bpf/hid_bpf_dispatch.c
> +++ b/drivers/hid/bpf/hid_bpf_dispatch.c
> @@ -292,7 +292,7 @@ hid_bpf_attach_prog(unsigned int hid_id, int prog_fd, __u32 flags)
>         struct hid_device *hdev;
>         struct bpf_prog *prog;
>         struct device *dev;
> -       int fd;
> +       int err, fd;
>
>         if (!hid_bpf_ops)
>                 return -EINVAL;
> @@ -311,14 +311,24 @@ hid_bpf_attach_prog(unsigned int hid_id, int prog_fd, __u32 flags)
>          * on errors or when it'll be detached
>          */
>         prog = bpf_prog_get(prog_fd);
> -       if (IS_ERR(prog))
> -               return PTR_ERR(prog);
> +       if (IS_ERR(prog)) {
> +               err = PTR_ERR(prog);
> +               goto out_dev_put;
> +       }
>
>         fd = do_hid_bpf_attach_prog(hdev, prog_fd, prog, flags);
> -       if (fd < 0)
> -               bpf_prog_put(prog);
> +       if (fd < 0) {
> +               err = fd;
> +               goto out_prog_put;
> +       }
>
>         return fd;
> +
> + out_prog_put:
> +       bpf_prog_put(prog);
> + out_dev_put:
> +       put_device(dev);
> +       return err;
>  }
>
>  /**
> @@ -345,8 +355,10 @@ hid_bpf_allocate_context(unsigned int hid_id)
>         hdev = to_hid_device(dev);
>
>         ctx_kern = kzalloc(sizeof(*ctx_kern), GFP_KERNEL);
> -       if (!ctx_kern)
> +       if (!ctx_kern) {
> +               put_device(dev);
>                 return NULL;
> +       }
>
>         ctx_kern->ctx.hid = hdev;
>
> @@ -363,10 +375,15 @@ noinline void
>  hid_bpf_release_context(struct hid_bpf_ctx *ctx)
>  {
>         struct hid_bpf_ctx_kern *ctx_kern;
> +       struct hid_device *hid;
>
>         ctx_kern = container_of(ctx, struct hid_bpf_ctx_kern, ctx);
> +       hid = (struct hid_device *)ctx_kern->ctx.hid; /* ignore const */
>
>         kfree(ctx_kern);
> +
> +       /* get_device() is called by bus_find_device() */
> +       put_device(&hid->dev);
>  }
>
>  /**
> diff --git a/drivers/hid/bpf/hid_bpf_jmp_table.c b/drivers/hid/bpf/hid_bpf_jmp_table.c
> index 12f7cebddd73..85a24bc0ea25 100644
> --- a/drivers/hid/bpf/hid_bpf_jmp_table.c
> +++ b/drivers/hid/bpf/hid_bpf_jmp_table.c
> @@ -196,6 +196,7 @@ static void __hid_bpf_do_release_prog(int map_fd, unsigned int idx)
>  static void hid_bpf_release_progs(struct work_struct *work)
>  {
>         int i, j, n, map_fd = -1;
> +       bool hdev_destroyed;
>
>         if (!jmp_table.map)
>                 return;
> @@ -220,6 +221,12 @@ static void hid_bpf_release_progs(struct work_struct *work)
>                 if (entry->hdev) {
>                         hdev = entry->hdev;
>                         type = entry->type;
> +                       /*
> +                        * hdev is still valid, even if we are called after hid_destroy_device():
> +                        * when hid_bpf_attach() gets called, it takes a ref on the dev through
> +                        * bus_find_device()
> +                        */
> +                       hdev_destroyed = hdev->bpf.destroyed;
>
>                         hid_bpf_populate_hdev(hdev, type);
>
> @@ -232,12 +239,18 @@ static void hid_bpf_release_progs(struct work_struct *work)
>                                 if (test_bit(next->idx, jmp_table.enabled))
>                                         continue;
>
> -                               if (next->hdev == hdev && next->type == type)
> +                               if (next->hdev == hdev && next->type == type) {
> +                                       /*
> +                                        * clear the hdev reference and decrement the device ref
> +                                        * that was taken during bus_find_device() while calling
> +                                        * hid_bpf_attach()
> +                                        */
>                                         next->hdev = NULL;
> +                                       put_device(&hdev->dev);

sigh... I can't make a correct patch these days... Missing a '}' here
to match the open bracket added above :(

I had some debug information put there to check if the device was
actually freed, and the closing bracket got lost while cleaning this
up.

Cheers,
Benjamin

>                         }
>
> -                       /* if type was rdesc fixup, reconnect device */
> -                       if (type == HID_BPF_PROG_TYPE_RDESC_FIXUP)
> +                       /* if type was rdesc fixup and the device is not gone, reconnect device */
> +                       if (type == HID_BPF_PROG_TYPE_RDESC_FIXUP && !hdev_destroyed)
>                                 hid_bpf_reconnect(hdev);
>                 }
>         }
>
> --
> 2.43.0
>
diff mbox series

Patch

diff --git a/drivers/hid/bpf/hid_bpf_dispatch.c b/drivers/hid/bpf/hid_bpf_dispatch.c
index 5111d1fef0d3..7903c8638e81 100644
--- a/drivers/hid/bpf/hid_bpf_dispatch.c
+++ b/drivers/hid/bpf/hid_bpf_dispatch.c
@@ -292,7 +292,7 @@  hid_bpf_attach_prog(unsigned int hid_id, int prog_fd, __u32 flags)
 	struct hid_device *hdev;
 	struct bpf_prog *prog;
 	struct device *dev;
-	int fd;
+	int err, fd;
 
 	if (!hid_bpf_ops)
 		return -EINVAL;
@@ -311,14 +311,24 @@  hid_bpf_attach_prog(unsigned int hid_id, int prog_fd, __u32 flags)
 	 * on errors or when it'll be detached
 	 */
 	prog = bpf_prog_get(prog_fd);
-	if (IS_ERR(prog))
-		return PTR_ERR(prog);
+	if (IS_ERR(prog)) {
+		err = PTR_ERR(prog);
+		goto out_dev_put;
+	}
 
 	fd = do_hid_bpf_attach_prog(hdev, prog_fd, prog, flags);
-	if (fd < 0)
-		bpf_prog_put(prog);
+	if (fd < 0) {
+		err = fd;
+		goto out_prog_put;
+	}
 
 	return fd;
+
+ out_prog_put:
+	bpf_prog_put(prog);
+ out_dev_put:
+	put_device(dev);
+	return err;
 }
 
 /**
@@ -345,8 +355,10 @@  hid_bpf_allocate_context(unsigned int hid_id)
 	hdev = to_hid_device(dev);
 
 	ctx_kern = kzalloc(sizeof(*ctx_kern), GFP_KERNEL);
-	if (!ctx_kern)
+	if (!ctx_kern) {
+		put_device(dev);
 		return NULL;
+	}
 
 	ctx_kern->ctx.hid = hdev;
 
@@ -363,10 +375,15 @@  noinline void
 hid_bpf_release_context(struct hid_bpf_ctx *ctx)
 {
 	struct hid_bpf_ctx_kern *ctx_kern;
+	struct hid_device *hid;
 
 	ctx_kern = container_of(ctx, struct hid_bpf_ctx_kern, ctx);
+	hid = (struct hid_device *)ctx_kern->ctx.hid; /* ignore const */
 
 	kfree(ctx_kern);
+
+	/* get_device() is called by bus_find_device() */
+	put_device(&hid->dev);
 }
 
 /**
diff --git a/drivers/hid/bpf/hid_bpf_jmp_table.c b/drivers/hid/bpf/hid_bpf_jmp_table.c
index 12f7cebddd73..85a24bc0ea25 100644
--- a/drivers/hid/bpf/hid_bpf_jmp_table.c
+++ b/drivers/hid/bpf/hid_bpf_jmp_table.c
@@ -196,6 +196,7 @@  static void __hid_bpf_do_release_prog(int map_fd, unsigned int idx)
 static void hid_bpf_release_progs(struct work_struct *work)
 {
 	int i, j, n, map_fd = -1;
+	bool hdev_destroyed;
 
 	if (!jmp_table.map)
 		return;
@@ -220,6 +221,12 @@  static void hid_bpf_release_progs(struct work_struct *work)
 		if (entry->hdev) {
 			hdev = entry->hdev;
 			type = entry->type;
+			/*
+			 * hdev is still valid, even if we are called after hid_destroy_device():
+			 * when hid_bpf_attach() gets called, it takes a ref on the dev through
+			 * bus_find_device()
+			 */
+			hdev_destroyed = hdev->bpf.destroyed;
 
 			hid_bpf_populate_hdev(hdev, type);
 
@@ -232,12 +239,18 @@  static void hid_bpf_release_progs(struct work_struct *work)
 				if (test_bit(next->idx, jmp_table.enabled))
 					continue;
 
-				if (next->hdev == hdev && next->type == type)
+				if (next->hdev == hdev && next->type == type) {
+					/*
+					 * clear the hdev reference and decrement the device ref
+					 * that was taken during bus_find_device() while calling
+					 * hid_bpf_attach()
+					 */
 					next->hdev = NULL;
+					put_device(&hdev->dev);
 			}
 
-			/* if type was rdesc fixup, reconnect device */
-			if (type == HID_BPF_PROG_TYPE_RDESC_FIXUP)
+			/* if type was rdesc fixup and the device is not gone, reconnect device */
+			if (type == HID_BPF_PROG_TYPE_RDESC_FIXUP && !hdev_destroyed)
 				hid_bpf_reconnect(hdev);
 		}
 	}