diff mbox series

[v2,2/2] selftests: test_zswap: add test for hierarchical zswap.writeback

Message ID 20240816144344.18135-2-me@yhndnzj.com (mailing list archive)
State New
Headers show
Series [v2,1/2] mm/memcontrol: respect zswap.writeback setting from parent cg too | expand

Commit Message

Mike Yuan Aug. 16, 2024, 2:44 p.m. UTC
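Ensure that the zswap.writeback check goes up the cgroup tree, i.e. that it is
hierarchical. After testing the parent cgroup, create a child cgroup with
memory.zswap.writeback explicitly set to 1 and verify that writeback in the
child still follows the parent's setting.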

Signed-off-by: Mike Yuan <me@yhndnzj.com>
---
 tools/testing/selftests/cgroup/test_zswap.c | 69 ++++++++++++++-------
 1 file changed, 48 insertions(+), 21 deletions(-)

Comments

Yosry Ahmed Aug. 19, 2024, 7:19 p.m. UTC | #1
On Fri, Aug 16, 2024 at 7:44 AM Mike Yuan <me@yhndnzj.com> wrote:
>
> Ensure that zswap.writeback check goes up the cgroup tree.

Too concise :) Perhaps a little bit of description of what you are
doing would be helpful.

>
> Signed-off-by: Mike Yuan <me@yhndnzj.com>
> ---
>  tools/testing/selftests/cgroup/test_zswap.c | 69 ++++++++++++++-------
>  1 file changed, 48 insertions(+), 21 deletions(-)
>
> diff --git a/tools/testing/selftests/cgroup/test_zswap.c b/tools/testing/selftests/cgroup/test_zswap.c
> index 190096017f80..7da6f9dc1066 100644
> --- a/tools/testing/selftests/cgroup/test_zswap.c
> +++ b/tools/testing/selftests/cgroup/test_zswap.c
> @@ -263,15 +263,13 @@ static int test_zswapin(const char *root)
>  static int attempt_writeback(const char *cgroup, void *arg)
>  {
>         long pagesize = sysconf(_SC_PAGESIZE);
> -       char *test_group = arg;
>         size_t memsize = MB(4);
>         char buf[pagesize];
>         long zswap_usage;
> -       bool wb_enabled;
> +       bool wb_enabled = *(bool *) arg;
>         int ret = -1;
>         char *mem;
>
> -       wb_enabled = cg_read_long(test_group, "memory.zswap.writeback");
>         mem = (char *)malloc(memsize);
>         if (!mem)
>                 return ret;
> @@ -288,12 +286,12 @@ static int attempt_writeback(const char *cgroup, void *arg)
>                 memcpy(&mem[i], buf, pagesize);
>
>         /* Try and reclaim allocated memory */
> -       if (cg_write_numeric(test_group, "memory.reclaim", memsize)) {
> +       if (cg_write_numeric(cgroup, "memory.reclaim", memsize)) {
>                 ksft_print_msg("Failed to reclaim all of the requested memory\n");
>                 goto out;
>         }
>
> -       zswap_usage = cg_read_long(test_group, "memory.zswap.current");
> +       zswap_usage = cg_read_long(cgroup, "memory.zswap.current");
>
>         /* zswpin */
>         for (int i = 0; i < memsize; i += pagesize) {
> @@ -303,7 +301,7 @@ static int attempt_writeback(const char *cgroup, void *arg)
>                 }
>         }
>
> -       if (cg_write_numeric(test_group, "memory.zswap.max", zswap_usage/2))
> +       if (cg_write_numeric(cgroup, "memory.zswap.max", zswap_usage/2))
>                 goto out;
>
>         /*
> @@ -312,7 +310,7 @@ static int attempt_writeback(const char *cgroup, void *arg)
>          * If writeback is disabled, memory reclaim will fail as zswap is limited and
>          * it can't writeback to swap.
>          */
> -       ret = cg_write_numeric(test_group, "memory.reclaim", memsize);
> +       ret = cg_write_numeric(cgroup, "memory.reclaim", memsize);
>         if (!wb_enabled)
>                 ret = (ret == -EAGAIN) ? 0 : -1;
>
> @@ -321,12 +319,38 @@ static int attempt_writeback(const char *cgroup, void *arg)
>         return ret;
>  }
>
> +static int test_zswap_writeback_one(const char *cgroup, bool wb)
> +{
> +       long zswpwb_before, zswpwb_after;
> +
> +       zswpwb_before = get_cg_wb_count(cgroup);
> +       if (zswpwb_before != 0) {
> +               ksft_print_msg("zswpwb_before = %ld instead of 0\n", zswpwb_before);
> +               return -1;
> +       }
> +
> +       if (cg_run(cgroup, attempt_writeback, (void *) &wb))
> +               return -1;
> +
> +       /* Verify that zswap writeback occurred only if writeback was enabled */
> +       zswpwb_after = get_cg_wb_count(cgroup);
> +       if (zswpwb_after < 0)
> +               return -1;
> +
> +       if (wb != !!zswpwb_after) {
> +               ksft_print_msg("zswpwb_after is %ld while wb is %s",
> +                               zswpwb_after, wb ? "enabled" : "disabled");
> +               return -1;
> +       }
> +
> +       return 0;
> +}
> +
>  /* Test to verify the zswap writeback path */
>  static int test_zswap_writeback(const char *root, bool wb)
>  {
> -       long zswpwb_before, zswpwb_after;
>         int ret = KSFT_FAIL;
> -       char *test_group;
> +       char *test_group, *test_group_child = NULL;
>
>         test_group = cg_name(root, "zswap_writeback_test");
>         if (!test_group)
> @@ -336,29 +360,32 @@ static int test_zswap_writeback(const char *root, bool wb)
>         if (cg_write(test_group, "memory.zswap.writeback", wb ? "1" : "0"))
>                 goto out;
>
> -       zswpwb_before = get_cg_wb_count(test_group);
> -       if (zswpwb_before != 0) {
> -               ksft_print_msg("zswpwb_before = %ld instead of 0\n", zswpwb_before);
> +       if (test_zswap_writeback_one(test_group, wb))
>                 goto out;
> -       }
>
> -       if (cg_run(test_group, attempt_writeback, (void *) test_group))
> +       if (cg_write(test_group, "memory.zswap.max", "max"))
> +               goto out;

Why is this needed? Isn't this the default value?

> +       if (cg_write(test_group, "cgroup.subtree_control", "+memory"))
>                 goto out;
>
> -       /* Verify that zswap writeback occurred only if writeback was enabled */
> -       zswpwb_after = get_cg_wb_count(test_group);
> -       if (zswpwb_after < 0)
> +       test_group_child = cg_name(test_group, "zswap_writeback_test_child");
> +       if (!test_group_child)
> +               goto out;
> +       if (cg_create(test_group_child))
> +               goto out;

I'd rather have all the hierarchy setup at the beginning of the test,
before the actual test logic. I don't feel strongly about it though.

> +       if (cg_write(test_group_child, "memory.zswap.writeback", "1"))
>                 goto out;

Is the idea here that we always hardcode the child's zswap.writeback
to 1, and the parent's zswap.writeback changes from 0 to 1, and we
check that the parent's value is what matters?
I think we need a comment here.

TBH, I expected a separate test that checks different combinations of
parent and child values (e.g. also verifies that if the parent is
enabled but child is disabled, writeback is disabled).

>
> -       if (wb != !!zswpwb_after) {
> -               ksft_print_msg("zswpwb_after is %ld while wb is %s",
> -                               zswpwb_after, wb ? "enabled" : "disabled");
> +       if (test_zswap_writeback_one(test_group_child, wb))
>                 goto out;
> -       }
>
>         ret = KSFT_PASS;
>
>  out:
> +       if (test_group_child) {
> +               cg_destroy(test_group_child);
> +               free(test_group_child);
> +       }
>         cg_destroy(test_group);
>         free(test_group);
>         return ret;
> --
> 2.46.0
>
>
Mike Yuan Aug. 20, 2024, 9:43 a.m. UTC | #2
On 2024-08-19 at 12:19 -0700, Yosry Ahmed wrote:
> On Fri, Aug 16, 2024 at 7:44 AM Mike Yuan <me@yhndnzj.com> wrote:
> > 
> > Ensure that zswap.writeback check goes up the cgroup tree.
> 
> Too concise :) Perhaps a little bit of description of what you are
> doing would be helpful.

The patch has already been merged into the mm-unstable tree. Do I need to
send a v3 to address the comments?

> > 
> > Signed-off-by: Mike Yuan <me@yhndnzj.com>
> > ---
> >  tools/testing/selftests/cgroup/test_zswap.c | 69 ++++++++++++++---
> > ----
> >  1 file changed, 48 insertions(+), 21 deletions(-)
> > 
> > diff --git a/tools/testing/selftests/cgroup/test_zswap.c
> > b/tools/testing/selftests/cgroup/test_zswap.c
> > index 190096017f80..7da6f9dc1066 100644
> > --- a/tools/testing/selftests/cgroup/test_zswap.c
> > +++ b/tools/testing/selftests/cgroup/test_zswap.c
> > @@ -263,15 +263,13 @@ static int test_zswapin(const char *root)
> >  static int attempt_writeback(const char *cgroup, void *arg)
> >  {
> >         long pagesize = sysconf(_SC_PAGESIZE);
> > -       char *test_group = arg;
> >         size_t memsize = MB(4);
> >         char buf[pagesize];
> >         long zswap_usage;
> > -       bool wb_enabled;
> > +       bool wb_enabled = *(bool *) arg;
> >         int ret = -1;
> >         char *mem;
> > 
> > -       wb_enabled = cg_read_long(test_group,
> > "memory.zswap.writeback");
> >         mem = (char *)malloc(memsize);
> >         if (!mem)
> >                 return ret;
> > @@ -288,12 +286,12 @@ static int attempt_writeback(const char
> > *cgroup, void *arg)
> >                 memcpy(&mem[i], buf, pagesize);
> > 
> >         /* Try and reclaim allocated memory */
> > -       if (cg_write_numeric(test_group, "memory.reclaim",
> > memsize)) {
> > +       if (cg_write_numeric(cgroup, "memory.reclaim", memsize)) {
> >                 ksft_print_msg("Failed to reclaim all of the
> > requested memory\n");
> >                 goto out;
> >         }
> > 
> > -       zswap_usage = cg_read_long(test_group,
> > "memory.zswap.current");
> > +       zswap_usage = cg_read_long(cgroup, "memory.zswap.current");
> > 
> >         /* zswpin */
> >         for (int i = 0; i < memsize; i += pagesize) {
> > @@ -303,7 +301,7 @@ static int attempt_writeback(const char
> > *cgroup, void *arg)
> >                 }
> >         }
> > 
> > -       if (cg_write_numeric(test_group, "memory.zswap.max",
> > zswap_usage/2))
> > +       if (cg_write_numeric(cgroup, "memory.zswap.max",
> > zswap_usage/2))
> >                 goto out;
> > 
> >         /*
> > @@ -312,7 +310,7 @@ static int attempt_writeback(const char
> > *cgroup, void *arg)
> >          * If writeback is disabled, memory reclaim will fail as
> > zswap is limited and
> >          * it can't writeback to swap.
> >          */
> > -       ret = cg_write_numeric(test_group, "memory.reclaim",
> > memsize);
> > +       ret = cg_write_numeric(cgroup, "memory.reclaim", memsize);
> >         if (!wb_enabled)
> >                 ret = (ret == -EAGAIN) ? 0 : -1;
> > 
> > @@ -321,12 +319,38 @@ static int attempt_writeback(const char
> > *cgroup, void *arg)
> >         return ret;
> >  }
> > 
> > +static int test_zswap_writeback_one(const char *cgroup, bool wb)
> > +{
> > +       long zswpwb_before, zswpwb_after;
> > +
> > +       zswpwb_before = get_cg_wb_count(cgroup);
> > +       if (zswpwb_before != 0) {
> > +               ksft_print_msg("zswpwb_before = %ld instead of
> > 0\n", zswpwb_before);
> > +               return -1;
> > +       }
> > +
> > +       if (cg_run(cgroup, attempt_writeback, (void *) &wb))
> > +               return -1;
> > +
> > +       /* Verify that zswap writeback occurred only if writeback
> > was enabled */
> > +       zswpwb_after = get_cg_wb_count(cgroup);
> > +       if (zswpwb_after < 0)
> > +               return -1;
> > +
> > +       if (wb != !!zswpwb_after) {
> > +               ksft_print_msg("zswpwb_after is %ld while wb is
> > %s",
> > +                               zswpwb_after, wb ? "enabled" :
> > "disabled");
> > +               return -1;
> > +       }
> > +
> > +       return 0;
> > +}
> > +
> >  /* Test to verify the zswap writeback path */
> >  static int test_zswap_writeback(const char *root, bool wb)
> >  {
> > -       long zswpwb_before, zswpwb_after;
> >         int ret = KSFT_FAIL;
> > -       char *test_group;
> > +       char *test_group, *test_group_child = NULL;
> > 
> >         test_group = cg_name(root, "zswap_writeback_test");
> >         if (!test_group)
> > @@ -336,29 +360,32 @@ static int test_zswap_writeback(const char
> > *root, bool wb)
> >         if (cg_write(test_group, "memory.zswap.writeback", wb ? "1"
> > : "0"))
> >                 goto out;
> > 
> > -       zswpwb_before = get_cg_wb_count(test_group);
> > -       if (zswpwb_before != 0) {
> > -               ksft_print_msg("zswpwb_before = %ld instead of
> > 0\n", zswpwb_before);
> > +       if (test_zswap_writeback_one(test_group, wb))
> >                 goto out;
> > -       }
> > 
> > -       if (cg_run(test_group, attempt_writeback, (void *)
> > test_group))
> > +       if (cg_write(test_group, "memory.zswap.max", "max"))
> > +               goto out;
> 
> Why is this needed? Isn't this the default value?

attempt_writeback() would modify it.

> > +       if (cg_write(test_group, "cgroup.subtree_control",
> > "+memory"))
> >                 goto out;
> > 
> > -       /* Verify that zswap writeback occurred only if writeback
> > was enabled */
> > -       zswpwb_after = get_cg_wb_count(test_group);
> > -       if (zswpwb_after < 0)
> > +       test_group_child = cg_name(test_group,
> > "zswap_writeback_test_child");
> > +       if (!test_group_child)
> > +               goto out;
> > +       if (cg_create(test_group_child))
> > +               goto out;
> 
> I'd rather have all the hierarchy setup at the beginning of the test,
> before the actual test logic. I don't feel strongly about it though.
> 
> > +       if (cg_write(test_group_child, "memory.zswap.writeback",
> > "1"))
> >                 goto out;
> 
> Is the idea here that we always hardcode the child's zswap.writeback
> to 1, and the parent's zswap.writeback changes from 0 to 1, and we
> check that the parent's value is what matters?
> I think we need a comment here.

Yes, indeed.

> TBH, I expected a separate test that checks different combinations of
> parent and child values (e.g. also verifies that if the parent is
> enabled but child is disabled, writeback is disabled).

That's (implicitly) covered by the test itself, IIUC? The parent cgroup
here is in turn a child of the root cgroup.

> > 
> > -       if (wb != !!zswpwb_after) {
> > -               ksft_print_msg("zswpwb_after is %ld while wb is
> > %s",
> > -                               zswpwb_after, wb ? "enabled" :
> > "disabled");
> > +       if (test_zswap_writeback_one(test_group_child, wb))
> >                 goto out;
> > -       }
> > 
> >         ret = KSFT_PASS;
> > 
> >  out:
> > +       if (test_group_child) {
> > +               cg_destroy(test_group_child);
> > +               free(test_group_child);
> > +       }
> >         cg_destroy(test_group);
> >         free(test_group);
> >         return ret;
> > --
> > 2.46.0
> > 
> >
Yosry Ahmed Aug. 22, 2024, 5:47 p.m. UTC | #3
On Tue, Aug 20, 2024 at 2:44 AM Mike Yuan <me@yhndnzj.com> wrote:
>
> On 2024-08-19 at 12:19 -0700, Yosry Ahmed wrote:
> > On Fri, Aug 16, 2024 at 7:44 AM Mike Yuan <me@yhndnzj.com> wrote:
> > >
> > > Ensure that zswap.writeback check goes up the cgroup tree.
> >
> > Too concise :) Perhaps a little bit of description of what you are
> > doing would be helpful.
>
> The patch has been merged into mm-unstable tree. Do I need to
> send a v3 to resolve the comments?

You can send a new version, and Andrew usually replaces the old one. If the
changes are trivial enough, Andrew is sometimes nice enough to make the
amendments directly :)

>
> > >
> > > Signed-off-by: Mike Yuan <me@yhndnzj.com>
> > > ---
> > >  tools/testing/selftests/cgroup/test_zswap.c | 69 ++++++++++++++---
> > > ----
> > >  1 file changed, 48 insertions(+), 21 deletions(-)
> > >
> > > diff --git a/tools/testing/selftests/cgroup/test_zswap.c
> > > b/tools/testing/selftests/cgroup/test_zswap.c
> > > index 190096017f80..7da6f9dc1066 100644
> > > --- a/tools/testing/selftests/cgroup/test_zswap.c
> > > +++ b/tools/testing/selftests/cgroup/test_zswap.c
> > > @@ -263,15 +263,13 @@ static int test_zswapin(const char *root)
> > >  static int attempt_writeback(const char *cgroup, void *arg)
> > >  {
> > >         long pagesize = sysconf(_SC_PAGESIZE);
> > > -       char *test_group = arg;
> > >         size_t memsize = MB(4);
> > >         char buf[pagesize];
> > >         long zswap_usage;
> > > -       bool wb_enabled;
> > > +       bool wb_enabled = *(bool *) arg;
> > >         int ret = -1;
> > >         char *mem;
> > >
> > > -       wb_enabled = cg_read_long(test_group,
> > > "memory.zswap.writeback");
> > >         mem = (char *)malloc(memsize);
> > >         if (!mem)
> > >                 return ret;
> > > @@ -288,12 +286,12 @@ static int attempt_writeback(const char
> > > *cgroup, void *arg)
> > >                 memcpy(&mem[i], buf, pagesize);
> > >
> > >         /* Try and reclaim allocated memory */
> > > -       if (cg_write_numeric(test_group, "memory.reclaim",
> > > memsize)) {
> > > +       if (cg_write_numeric(cgroup, "memory.reclaim", memsize)) {
> > >                 ksft_print_msg("Failed to reclaim all of the
> > > requested memory\n");
> > >                 goto out;
> > >         }
> > >
> > > -       zswap_usage = cg_read_long(test_group,
> > > "memory.zswap.current");
> > > +       zswap_usage = cg_read_long(cgroup, "memory.zswap.current");
> > >
> > >         /* zswpin */
> > >         for (int i = 0; i < memsize; i += pagesize) {
> > > @@ -303,7 +301,7 @@ static int attempt_writeback(const char
> > > *cgroup, void *arg)
> > >                 }
> > >         }
> > >
> > > -       if (cg_write_numeric(test_group, "memory.zswap.max",
> > > zswap_usage/2))
> > > +       if (cg_write_numeric(cgroup, "memory.zswap.max",
> > > zswap_usage/2))
> > >                 goto out;
> > >
> > >         /*
> > > @@ -312,7 +310,7 @@ static int attempt_writeback(const char
> > > *cgroup, void *arg)
> > >          * If writeback is disabled, memory reclaim will fail as
> > > zswap is limited and
> > >          * it can't writeback to swap.
> > >          */
> > > -       ret = cg_write_numeric(test_group, "memory.reclaim",
> > > memsize);
> > > +       ret = cg_write_numeric(cgroup, "memory.reclaim", memsize);
> > >         if (!wb_enabled)
> > >                 ret = (ret == -EAGAIN) ? 0 : -1;
> > >
> > > @@ -321,12 +319,38 @@ static int attempt_writeback(const char
> > > *cgroup, void *arg)
> > >         return ret;
> > >  }
> > >
> > > +static int test_zswap_writeback_one(const char *cgroup, bool wb)
> > > +{
> > > +       long zswpwb_before, zswpwb_after;
> > > +
> > > +       zswpwb_before = get_cg_wb_count(cgroup);
> > > +       if (zswpwb_before != 0) {
> > > +               ksft_print_msg("zswpwb_before = %ld instead of
> > > 0\n", zswpwb_before);
> > > +               return -1;
> > > +       }
> > > +
> > > +       if (cg_run(cgroup, attempt_writeback, (void *) &wb))
> > > +               return -1;
> > > +
> > > +       /* Verify that zswap writeback occurred only if writeback
> > > was enabled */
> > > +       zswpwb_after = get_cg_wb_count(cgroup);
> > > +       if (zswpwb_after < 0)
> > > +               return -1;
> > > +
> > > +       if (wb != !!zswpwb_after) {
> > > +               ksft_print_msg("zswpwb_after is %ld while wb is
> > > %s",
> > > +                               zswpwb_after, wb ? "enabled" :
> > > "disabled");
> > > +               return -1;
> > > +       }
> > > +
> > > +       return 0;
> > > +}
> > > +
> > >  /* Test to verify the zswap writeback path */
> > >  static int test_zswap_writeback(const char *root, bool wb)
> > >  {
> > > -       long zswpwb_before, zswpwb_after;
> > >         int ret = KSFT_FAIL;
> > > -       char *test_group;
> > > +       char *test_group, *test_group_child = NULL;
> > >
> > >         test_group = cg_name(root, "zswap_writeback_test");
> > >         if (!test_group)
> > > @@ -336,29 +360,32 @@ static int test_zswap_writeback(const char
> > > *root, bool wb)
> > >         if (cg_write(test_group, "memory.zswap.writeback", wb ? "1"
> > > : "0"))
> > >                 goto out;
> > >
> > > -       zswpwb_before = get_cg_wb_count(test_group);
> > > -       if (zswpwb_before != 0) {
> > > -               ksft_print_msg("zswpwb_before = %ld instead of
> > > 0\n", zswpwb_before);
> > > +       if (test_zswap_writeback_one(test_group, wb))
> > >                 goto out;
> > > -       }
> > >
> > > -       if (cg_run(test_group, attempt_writeback, (void *)
> > > test_group))
> > > +       if (cg_write(test_group, "memory.zswap.max", "max"))
> > > +               goto out;
> >
> > Why is this needed? Isn't this the default value?
>
> attempt_writeback() would modify it.

Oh yes, missed that.

>
> > > +       if (cg_write(test_group, "cgroup.subtree_control",
> > > "+memory"))
> > >                 goto out;
> > >
> > > -       /* Verify that zswap writeback occurred only if writeback
> > > was enabled */
> > > -       zswpwb_after = get_cg_wb_count(test_group);
> > > -       if (zswpwb_after < 0)
> > > +       test_group_child = cg_name(test_group,
> > > "zswap_writeback_test_child");
> > > +       if (!test_group_child)
> > > +               goto out;
> > > +       if (cg_create(test_group_child))
> > > +               goto out;
> >
> > I'd rather have all the hierarchy setup at the beginning of the test,
> > before the actual test logic. I don't feel strongly about it though.
> >
> > > +       if (cg_write(test_group_child, "memory.zswap.writeback",
> > > "1"))
> > >                 goto out;
> >
> > Is the idea here that we always hardcode the child's zswap.writeback
> > to 1, and the parent's zswap.writeback changes from 0 to 1, and we
> > check that the parent's value is what matters?
> > I think we need a comment here.
>
> Yes, indeed.
>
> > TBH, I expected a separate test that checks different combinations of
> > parent and child values (e.g. also verifies that if the parent is
> > enabled but child is disabled, writeback is disabled).
>
> That's (implicitly) covered by the test itself IIUC? The parent cgroup
> here is in turn the child of root cgroup.

This assumes that the root has zswap writeback enabled, but that's a
fair assumption as otherwise all the writeback tests will fail.

TBH I'd prefer a standalone test rather than these implicitly tested scenarios.
diff mbox series

Patch

diff --git a/tools/testing/selftests/cgroup/test_zswap.c b/tools/testing/selftests/cgroup/test_zswap.c
index 190096017f80..7da6f9dc1066 100644
--- a/tools/testing/selftests/cgroup/test_zswap.c
+++ b/tools/testing/selftests/cgroup/test_zswap.c
@@ -263,15 +263,13 @@  static int test_zswapin(const char *root)
 static int attempt_writeback(const char *cgroup, void *arg)
 {
 	long pagesize = sysconf(_SC_PAGESIZE);
-	char *test_group = arg;
 	size_t memsize = MB(4);
 	char buf[pagesize];
 	long zswap_usage;
-	bool wb_enabled;
+	bool wb_enabled = *(bool *) arg;
 	int ret = -1;
 	char *mem;
 
-	wb_enabled = cg_read_long(test_group, "memory.zswap.writeback");
 	mem = (char *)malloc(memsize);
 	if (!mem)
 		return ret;
@@ -288,12 +286,12 @@  static int attempt_writeback(const char *cgroup, void *arg)
 		memcpy(&mem[i], buf, pagesize);
 
 	/* Try and reclaim allocated memory */
-	if (cg_write_numeric(test_group, "memory.reclaim", memsize)) {
+	if (cg_write_numeric(cgroup, "memory.reclaim", memsize)) {
 		ksft_print_msg("Failed to reclaim all of the requested memory\n");
 		goto out;
 	}
 
-	zswap_usage = cg_read_long(test_group, "memory.zswap.current");
+	zswap_usage = cg_read_long(cgroup, "memory.zswap.current");
 
 	/* zswpin */
 	for (int i = 0; i < memsize; i += pagesize) {
@@ -303,7 +301,7 @@  static int attempt_writeback(const char *cgroup, void *arg)
 		}
 	}
 
-	if (cg_write_numeric(test_group, "memory.zswap.max", zswap_usage/2))
+	if (cg_write_numeric(cgroup, "memory.zswap.max", zswap_usage/2))
 		goto out;
 
 	/*
@@ -312,7 +310,7 @@  static int attempt_writeback(const char *cgroup, void *arg)
 	 * If writeback is disabled, memory reclaim will fail as zswap is limited and
 	 * it can't writeback to swap.
 	 */
-	ret = cg_write_numeric(test_group, "memory.reclaim", memsize);
+	ret = cg_write_numeric(cgroup, "memory.reclaim", memsize);
 	if (!wb_enabled)
 		ret = (ret == -EAGAIN) ? 0 : -1;
 
@@ -321,12 +319,38 @@  static int attempt_writeback(const char *cgroup, void *arg)
 	return ret;
 }
 
+static int test_zswap_writeback_one(const char *cgroup, bool wb)
+{
+	long zswpwb_before, zswpwb_after;
+
+	zswpwb_before = get_cg_wb_count(cgroup);
+	if (zswpwb_before != 0) {
+		ksft_print_msg("zswpwb_before = %ld instead of 0\n", zswpwb_before);
+		return -1;
+	}
+
+	if (cg_run(cgroup, attempt_writeback, (void *) &wb))
+		return -1;
+
+	/* Verify that zswap writeback occurred only if writeback was enabled */
+	zswpwb_after = get_cg_wb_count(cgroup);
+	if (zswpwb_after < 0)
+		return -1;
+
+	if (wb != !!zswpwb_after) {
+		ksft_print_msg("zswpwb_after is %ld while wb is %s",
+				zswpwb_after, wb ? "enabled" : "disabled");
+		return -1;
+	}
+
+	return 0;
+}
+
 /* Test to verify the zswap writeback path */
 static int test_zswap_writeback(const char *root, bool wb)
 {
-	long zswpwb_before, zswpwb_after;
 	int ret = KSFT_FAIL;
-	char *test_group;
+	char *test_group, *test_group_child = NULL;
 
 	test_group = cg_name(root, "zswap_writeback_test");
 	if (!test_group)
@@ -336,29 +360,32 @@  static int test_zswap_writeback(const char *root, bool wb)
 	if (cg_write(test_group, "memory.zswap.writeback", wb ? "1" : "0"))
 		goto out;
 
-	zswpwb_before = get_cg_wb_count(test_group);
-	if (zswpwb_before != 0) {
-		ksft_print_msg("zswpwb_before = %ld instead of 0\n", zswpwb_before);
+	if (test_zswap_writeback_one(test_group, wb))
 		goto out;
-	}
 
-	if (cg_run(test_group, attempt_writeback, (void *) test_group))
+	if (cg_write(test_group, "memory.zswap.max", "max"))
+		goto out;
+	if (cg_write(test_group, "cgroup.subtree_control", "+memory"))
 		goto out;
 
-	/* Verify that zswap writeback occurred only if writeback was enabled */
-	zswpwb_after = get_cg_wb_count(test_group);
-	if (zswpwb_after < 0)
+	test_group_child = cg_name(test_group, "zswap_writeback_test_child");
+	if (!test_group_child)
+		goto out;
+	if (cg_create(test_group_child))
+		goto out;
+	if (cg_write(test_group_child, "memory.zswap.writeback", "1"))
 		goto out;
 
-	if (wb != !!zswpwb_after) {
-		ksft_print_msg("zswpwb_after is %ld while wb is %s",
-				zswpwb_after, wb ? "enabled" : "disabled");
+	if (test_zswap_writeback_one(test_group_child, wb))
 		goto out;
-	}
 
 	ret = KSFT_PASS;
 
 out:
+	if (test_group_child) {
+		cg_destroy(test_group_child);
+		free(test_group_child);
+	}
 	cg_destroy(test_group);
 	free(test_group);
 	return ret;