diff mbox

mm: fix race between kmem_cache destroy, create and deactivate

Message ID 20180521174116.171846-1-shakeelb@google.com (mailing list archive)
State New, archived
Headers show

Commit Message

Shakeel Butt May 21, 2018, 5:41 p.m. UTC
The memcg kmem cache creation and deactivation (SLUB only) is
asynchronous. If a root kmem cache is destroyed whose memcg cache is in
the process of creation or deactivation, the kernel may crash.

Example of one such crash:
	general protection fault: 0000 [#1] SMP PTI
	CPU: 1 PID: 1721 Comm: kworker/14:1 Not tainted 4.17.0-smp
	...
	Workqueue: memcg_kmem_cache kmemcg_deactivate_workfn
	RIP: 0010:has_cpu_slab
	...
	Call Trace:
	? on_each_cpu_cond
	__kmem_cache_shrink
	kmemcg_cache_deact_after_rcu
	kmemcg_deactivate_workfn
	process_one_work
	worker_thread
	kthread
	ret_from_fork+0x35/0x40

This issue is due to the lack of reference counting for the root
kmem_caches. There exist a refcount in kmem_cache but it is actually a
count of aliases i.e. number of kmem_caches merged together.

This patch make alias count explicit and adds reference counting to the
root kmem_caches. The reference of a root kmem cache is elevated on
merge and while its memcg kmem_cache is in the process of creation or
deactivation.

Signed-off-by: Shakeel Butt <shakeelb@google.com>
---
 include/linux/slab.h     |  2 +
 include/linux/slab_def.h |  3 +-
 include/linux/slub_def.h |  3 +-
 mm/memcontrol.c          |  7 ++++
 mm/slab.c                |  4 +-
 mm/slab.h                |  5 ++-
 mm/slab_common.c         | 84 ++++++++++++++++++++++++++++++----------
 mm/slub.c                | 14 ++++---
 8 files changed, 90 insertions(+), 32 deletions(-)

Comments

Andrew Morton May 21, 2018, 6:42 p.m. UTC | #1
On Mon, 21 May 2018 10:41:16 -0700 Shakeel Butt <shakeelb@google.com> wrote:

> The memcg kmem cache creation and deactivation (SLUB only) is
> asynchronous. If a root kmem cache is destroyed whose memcg cache is in
> the process of creation or deactivation, the kernel may crash.
> 
> Example of one such crash:
> 	general protection fault: 0000 [#1] SMP PTI
> 	CPU: 1 PID: 1721 Comm: kworker/14:1 Not tainted 4.17.0-smp
> 	...
> 	Workqueue: memcg_kmem_cache kmemcg_deactivate_workfn
> 	RIP: 0010:has_cpu_slab
> 	...
> 	Call Trace:
> 	? on_each_cpu_cond
> 	__kmem_cache_shrink
> 	kmemcg_cache_deact_after_rcu
> 	kmemcg_deactivate_workfn
> 	process_one_work
> 	worker_thread
> 	kthread
> 	ret_from_fork+0x35/0x40
> 
> This issue is due to the lack of reference counting for the root
> kmem_caches. There exist a refcount in kmem_cache but it is actually a
> count of aliases i.e. number of kmem_caches merged together.
> 
> This patch make alias count explicit and adds reference counting to the
> root kmem_caches. The reference of a root kmem cache is elevated on
> merge and while its memcg kmem_cache is in the process of creation or
> deactivation.
> 

The patch seems depressingly complex.

And a bit underdocumented...

> --- a/include/linux/slab.h
> +++ b/include/linux/slab.h
> @@ -674,6 +674,8 @@ struct memcg_cache_params {
>  };
>  
>  int memcg_update_all_caches(int num_memcgs);
> +bool kmem_cache_tryget(struct kmem_cache *s);
> +void kmem_cache_put(struct kmem_cache *s);
>  
>  /**
>   * kmalloc_array - allocate memory for an array.
> diff --git a/include/linux/slab_def.h b/include/linux/slab_def.h
> index d9228e4d0320..4bb22c89a740 100644
> --- a/include/linux/slab_def.h
> +++ b/include/linux/slab_def.h
> @@ -41,7 +41,8 @@ struct kmem_cache {
>  /* 4) cache creation/removal */
>  	const char *name;
>  	struct list_head list;
> -	int refcount;
> +	refcount_t refcount;
> +	int alias_count;

The semantic meaning of these two?  What locking protects alias_count?

>  	int object_size;
>  	int align;
>  
> diff --git a/include/linux/slub_def.h b/include/linux/slub_def.h
> index 3773e26c08c1..532d4b6f83ed 100644
> --- a/include/linux/slub_def.h
> +++ b/include/linux/slub_def.h
> @@ -97,7 +97,8 @@ struct kmem_cache {
>  	struct kmem_cache_order_objects max;
>  	struct kmem_cache_order_objects min;
>  	gfp_t allocflags;	/* gfp flags to use on each alloc */
> -	int refcount;		/* Refcount for slab cache destroy */
> +	refcount_t refcount;	/* Refcount for slab cache destroy */
> +	int alias_count;	/* Number of root kmem caches merged */

"merged" what with what in what manner?

>  	void (*ctor)(void *);
>  	unsigned int inuse;		/* Offset to metadata */
>  	unsigned int align;		/* Alignment */
>
> ...
>
> --- a/mm/slab.h
> +++ b/mm/slab.h
> @@ -25,7 +25,8 @@ struct kmem_cache {
>  	unsigned int useroffset;/* Usercopy region offset */
>  	unsigned int usersize;	/* Usercopy region size */
>  	const char *name;	/* Slab name for sysfs */
> -	int refcount;		/* Use counter */
> +	refcount_t refcount;	/* Use counter */
> +	int alias_count;

Semantic meaning/usage of alias_count?  Locking for it?

>  	void (*ctor)(void *);	/* Called on object slot creation */
>  	struct list_head list;	/* List of all slab caches on the system */
>  };
>
> ...
>
> +bool kmem_cache_tryget(struct kmem_cache *s)
> +{
> +	if (is_root_cache(s))
> +		return refcount_inc_not_zero(&s->refcount);
> +	return false;
> +}
> +
> +void kmem_cache_put(struct kmem_cache *s)
> +{
> +	if (is_root_cache(s) &&
> +	    refcount_dec_and_test(&s->refcount))
> +		__kmem_cache_destroy(s, true);
> +}
> +
> +void kmem_cache_put_locked(struct kmem_cache *s)
> +{
> +	if (is_root_cache(s) &&
> +	    refcount_dec_and_test(&s->refcount))
> +		__kmem_cache_destroy(s, false);
> +}

Some covering documentation for the above would be useful.  Why do they
exist, why do they only operate on the root cache? etc.

>
> ...
>
Shakeel Butt May 21, 2018, 8:12 p.m. UTC | #2
On Mon, May 21, 2018 at 11:42 AM Andrew Morton <akpm@linux-foundation.org>
wrote:

> On Mon, 21 May 2018 10:41:16 -0700 Shakeel Butt <shakeelb@google.com>
wrote:

> > The memcg kmem cache creation and deactivation (SLUB only) is
> > asynchronous. If a root kmem cache is destroyed whose memcg cache is in
> > the process of creation or deactivation, the kernel may crash.
> >
> > Example of one such crash:
> >       general protection fault: 0000 [#1] SMP PTI
> >       CPU: 1 PID: 1721 Comm: kworker/14:1 Not tainted 4.17.0-smp
> >       ...
> >       Workqueue: memcg_kmem_cache kmemcg_deactivate_workfn
> >       RIP: 0010:has_cpu_slab
> >       ...
> >       Call Trace:
> >       ? on_each_cpu_cond
> >       __kmem_cache_shrink
> >       kmemcg_cache_deact_after_rcu
> >       kmemcg_deactivate_workfn
> >       process_one_work
> >       worker_thread
> >       kthread
> >       ret_from_fork+0x35/0x40
> >
> > This issue is due to the lack of reference counting for the root
> > kmem_caches. There exist a refcount in kmem_cache but it is actually a
> > count of aliases i.e. number of kmem_caches merged together.
> >
> > This patch make alias count explicit and adds reference counting to the
> > root kmem_caches. The reference of a root kmem cache is elevated on
> > merge and while its memcg kmem_cache is in the process of creation or
> > deactivation.
> >

> The patch seems depressingly complex.

> And a bit underdocumented...


I will add more documentation to the code.

> > --- a/include/linux/slab.h
> > +++ b/include/linux/slab.h
> > @@ -674,6 +674,8 @@ struct memcg_cache_params {
> >  };
> >
> >  int memcg_update_all_caches(int num_memcgs);
> > +bool kmem_cache_tryget(struct kmem_cache *s);
> > +void kmem_cache_put(struct kmem_cache *s);
> >
> >  /**
> >   * kmalloc_array - allocate memory for an array.
> > diff --git a/include/linux/slab_def.h b/include/linux/slab_def.h
> > index d9228e4d0320..4bb22c89a740 100644
> > --- a/include/linux/slab_def.h
> > +++ b/include/linux/slab_def.h
> > @@ -41,7 +41,8 @@ struct kmem_cache {
> >  /* 4) cache creation/removal */
> >       const char *name;
> >       struct list_head list;
> > -     int refcount;
> > +     refcount_t refcount;
> > +     int alias_count;

> The semantic meaning of these two?  What locking protects alias_count?

SLAB and SLUB allow reusing existing root kmem caches. The alias_count of a
kmem cache tells the number of times this kmem cache is reused (maybe
shared_count or reused_count are better names). Basically if there were 5
root kmem cache creation request and suppose SLAB/SLUB decide to reuse the
first kmem cache created for next 4 requests then this count will be 5 and
all 5 will be pointing to the same kmem_cache object.

Before this patch, alias_count (previously named refcount) was modified
only within slab_mutex but can be read outside. It was conflated into
multiple things like shared count, reference count and unmergeable flag (if
-ve). This patch decouples the reference counting from this field and there
is no need to protect alias_count with locks.

> >       int object_size;
> >       int align;
> >
> > diff --git a/include/linux/slub_def.h b/include/linux/slub_def.h
> > index 3773e26c08c1..532d4b6f83ed 100644
> > --- a/include/linux/slub_def.h
> > +++ b/include/linux/slub_def.h
> > @@ -97,7 +97,8 @@ struct kmem_cache {
> >       struct kmem_cache_order_objects max;
> >       struct kmem_cache_order_objects min;
> >       gfp_t allocflags;       /* gfp flags to use on each alloc */
> > -     int refcount;           /* Refcount for slab cache destroy */
> > +     refcount_t refcount;    /* Refcount for slab cache destroy */
> > +     int alias_count;        /* Number of root kmem caches merged */

> "merged" what with what in what manner?

shared or reused might be better words here.

> >       void (*ctor)(void *);
> >       unsigned int inuse;             /* Offset to metadata */
> >       unsigned int align;             /* Alignment */
> >
> > ...
> >
> > --- a/mm/slab.h
> > +++ b/mm/slab.h
> > @@ -25,7 +25,8 @@ struct kmem_cache {
> >       unsigned int useroffset;/* Usercopy region offset */
> >       unsigned int usersize;  /* Usercopy region size */
> >       const char *name;       /* Slab name for sysfs */
> > -     int refcount;           /* Use counter */
> > +     refcount_t refcount;    /* Use counter */
> > +     int alias_count;

> Semantic meaning/usage of alias_count?  Locking for it?

Will add in the next version.

> >       void (*ctor)(void *);   /* Called on object slot creation */
> >       struct list_head list;  /* List of all slab caches on the system
*/
> >  };
> >
> > ...
> >
> > +bool kmem_cache_tryget(struct kmem_cache *s)
> > +{
> > +     if (is_root_cache(s))
> > +             return refcount_inc_not_zero(&s->refcount);
> > +     return false;
> > +}
> > +
> > +void kmem_cache_put(struct kmem_cache *s)
> > +{
> > +     if (is_root_cache(s) &&
> > +         refcount_dec_and_test(&s->refcount))
> > +             __kmem_cache_destroy(s, true);
> > +}
> > +
> > +void kmem_cache_put_locked(struct kmem_cache *s)
> > +{
> > +     if (is_root_cache(s) &&
> > +         refcount_dec_and_test(&s->refcount))
> > +             __kmem_cache_destroy(s, false);
> > +}

> Some covering documentation for the above would be useful.  Why do they
> exist, why do they only operate on the root cache? etc.

Ack.
Christoph Lameter (Ampere) May 22, 2018, 4:43 p.m. UTC | #3
On Mon, 21 May 2018, Andrew Morton wrote:

> The patch seems depressingly complex.
>
> And a bit underdocumented...

Maybe separate out the bits that rename refcount to alias_count?

> > +	refcount_t refcount;
> > +	int alias_count;
>
> The semantic meaning of these two?  What locking protects alias_count?

slab_mutex

>
> >  	int object_size;
> >  	int align;
> >
> > diff --git a/include/linux/slub_def.h b/include/linux/slub_def.h
> > index 3773e26c08c1..532d4b6f83ed 100644
> > --- a/include/linux/slub_def.h
> > +++ b/include/linux/slub_def.h
> > @@ -97,7 +97,8 @@ struct kmem_cache {
> >  	struct kmem_cache_order_objects max;
> >  	struct kmem_cache_order_objects min;
> >  	gfp_t allocflags;	/* gfp flags to use on each alloc */
> > -	int refcount;		/* Refcount for slab cache destroy */
> > +	refcount_t refcount;	/* Refcount for slab cache destroy */
> > +	int alias_count;	/* Number of root kmem caches merged */
>
> "merged" what with what in what manner?

That is a basic SLUB feature.
kernel test robot May 22, 2018, 9:14 p.m. UTC | #4
Hi Shakeel,

Thank you for the patch! Yet something to improve:

[auto build test ERROR on mmotm/master]
[also build test ERROR on v4.17-rc6 next-20180517]
[if your patch is applied to the wrong git tree, please drop us a note to help improve the system]

url:    https://github.com/0day-ci/linux/commits/Shakeel-Butt/mm-fix-race-between-kmem_cache-destroy-create-and-deactivate/20180523-041715
base:   git://git.cmpxchg.org/linux-mmotm.git master
config: i386-randconfig-x009-201820 (attached as .config)
compiler: gcc-7 (Debian 7.3.0-16) 7.3.0
reproduce:
        # save the attached .config to linux build tree
        make ARCH=i386 

All errors (new ones prefixed by >>):

   mm/slub.c: In function '__kmem_cache_alias':
>> mm/slub.c:4251:4: error: implicit declaration of function 'kmem_cache_put_locked'; did you mean 'kmem_cache_init_late'? [-Werror=implicit-function-declaration]
       kmem_cache_put_locked(s);
       ^~~~~~~~~~~~~~~~~~~~~
       kmem_cache_init_late
   cc1: some warnings being treated as errors

vim +4251 mm/slub.c

  4226	
  4227	struct kmem_cache *
  4228	__kmem_cache_alias(const char *name, unsigned int size, unsigned int align,
  4229			   slab_flags_t flags, void (*ctor)(void *))
  4230	{
  4231		struct kmem_cache *s, *c;
  4232	
  4233		s = find_mergeable(size, align, flags, name, ctor);
  4234		if (s && kmem_cache_tryget(s)) {
  4235			s->alias_count++;
  4236	
  4237			/*
  4238			 * Adjust the object sizes so that we clear
  4239			 * the complete object on kzalloc.
  4240			 */
  4241			s->object_size = max(s->object_size, size);
  4242			s->inuse = max(s->inuse, ALIGN(size, sizeof(void *)));
  4243	
  4244			for_each_memcg_cache(c, s) {
  4245				c->object_size = s->object_size;
  4246				c->inuse = max(c->inuse, ALIGN(size, sizeof(void *)));
  4247			}
  4248	
  4249			if (sysfs_slab_alias(s, name)) {
  4250				s->alias_count--;
> 4251				kmem_cache_put_locked(s);
  4252				s = NULL;
  4253			}
  4254		}
  4255	
  4256		return s;
  4257	}
  4258	

---
0-DAY kernel test infrastructure                Open Source Technology Center
https://lists.01.org/pipermail/kbuild-all                   Intel Corporation
kernel test robot May 22, 2018, 10:07 p.m. UTC | #5
Hi Shakeel,

Thank you for the patch! Yet something to improve:

[auto build test ERROR on mmotm/master]
[also build test ERROR on v4.17-rc6 next-20180517]
[if your patch is applied to the wrong git tree, please drop us a note to help improve the system]

url:    https://github.com/0day-ci/linux/commits/Shakeel-Butt/mm-fix-race-between-kmem_cache-destroy-create-and-deactivate/20180523-041715
base:   git://git.cmpxchg.org/linux-mmotm.git master
config: i386-randconfig-a0-201820 (attached as .config)
compiler: gcc-4.9 (Debian 4.9.4-2) 4.9.4
reproduce:
        # save the attached .config to linux build tree
        make ARCH=i386 

All errors (new ones prefixed by >>):

   mm/slub.c: In function '__kmem_cache_alias':
>> mm/slub.c:4251:4: error: implicit declaration of function 'kmem_cache_put_locked' [-Werror=implicit-function-declaration]
       kmem_cache_put_locked(s);
       ^
   cc1: some warnings being treated as errors

vim +/kmem_cache_put_locked +4251 mm/slub.c

  4226	
  4227	struct kmem_cache *
  4228	__kmem_cache_alias(const char *name, unsigned int size, unsigned int align,
  4229			   slab_flags_t flags, void (*ctor)(void *))
  4230	{
  4231		struct kmem_cache *s, *c;
  4232	
  4233		s = find_mergeable(size, align, flags, name, ctor);
  4234		if (s && kmem_cache_tryget(s)) {
  4235			s->alias_count++;
  4236	
  4237			/*
  4238			 * Adjust the object sizes so that we clear
  4239			 * the complete object on kzalloc.
  4240			 */
  4241			s->object_size = max(s->object_size, size);
  4242			s->inuse = max(s->inuse, ALIGN(size, sizeof(void *)));
  4243	
  4244			for_each_memcg_cache(c, s) {
  4245				c->object_size = s->object_size;
  4246				c->inuse = max(c->inuse, ALIGN(size, sizeof(void *)));
  4247			}
  4248	
  4249			if (sysfs_slab_alias(s, name)) {
  4250				s->alias_count--;
> 4251				kmem_cache_put_locked(s);
  4252				s = NULL;
  4253			}
  4254		}
  4255	
  4256		return s;
  4257	}
  4258	

---
0-DAY kernel test infrastructure                Open Source Technology Center
https://lists.01.org/pipermail/kbuild-all                   Intel Corporation
diff mbox

Patch

diff --git a/include/linux/slab.h b/include/linux/slab.h
index 9ebe659bd4a5..4c28f2483a22 100644
--- a/include/linux/slab.h
+++ b/include/linux/slab.h
@@ -674,6 +674,8 @@  struct memcg_cache_params {
 };
 
 int memcg_update_all_caches(int num_memcgs);
+bool kmem_cache_tryget(struct kmem_cache *s);
+void kmem_cache_put(struct kmem_cache *s);
 
 /**
  * kmalloc_array - allocate memory for an array.
diff --git a/include/linux/slab_def.h b/include/linux/slab_def.h
index d9228e4d0320..4bb22c89a740 100644
--- a/include/linux/slab_def.h
+++ b/include/linux/slab_def.h
@@ -41,7 +41,8 @@  struct kmem_cache {
 /* 4) cache creation/removal */
 	const char *name;
 	struct list_head list;
-	int refcount;
+	refcount_t refcount;
+	int alias_count;
 	int object_size;
 	int align;
 
diff --git a/include/linux/slub_def.h b/include/linux/slub_def.h
index 3773e26c08c1..532d4b6f83ed 100644
--- a/include/linux/slub_def.h
+++ b/include/linux/slub_def.h
@@ -97,7 +97,8 @@  struct kmem_cache {
 	struct kmem_cache_order_objects max;
 	struct kmem_cache_order_objects min;
 	gfp_t allocflags;	/* gfp flags to use on each alloc */
-	int refcount;		/* Refcount for slab cache destroy */
+	refcount_t refcount;	/* Refcount for slab cache destroy */
+	int alias_count;	/* Number of root kmem caches merged */
 	void (*ctor)(void *);
 	unsigned int inuse;		/* Offset to metadata */
 	unsigned int align;		/* Alignment */
diff --git a/mm/memcontrol.c b/mm/memcontrol.c
index bdb8028c806c..ab5673dbfc4e 100644
--- a/mm/memcontrol.c
+++ b/mm/memcontrol.c
@@ -2185,6 +2185,7 @@  static void memcg_kmem_cache_create_func(struct work_struct *w)
 	memcg_create_kmem_cache(memcg, cachep);
 
 	css_put(&memcg->css);
+	kmem_cache_put(cachep);
 	kfree(cw);
 }
 
@@ -2200,6 +2201,12 @@  static void __memcg_schedule_kmem_cache_create(struct mem_cgroup *memcg,
 	if (!cw)
 		return;
 
+	/* Make sure root kmem cache does not get destroyed in the middle */
+	if (!kmem_cache_tryget(cachep)) {
+		kfree(cw);
+		return;
+	}
+
 	css_get(&memcg->css);
 
 	cw->memcg = memcg;
diff --git a/mm/slab.c b/mm/slab.c
index c1fe8099b3cd..080732f5f20d 100644
--- a/mm/slab.c
+++ b/mm/slab.c
@@ -1883,8 +1883,8 @@  __kmem_cache_alias(const char *name, unsigned int size, unsigned int align,
 	struct kmem_cache *cachep;
 
 	cachep = find_mergeable(size, align, flags, name, ctor);
-	if (cachep) {
-		cachep->refcount++;
+	if (cachep && kmem_cache_tryget(cachep)) {
+		cachep->alias_count++;
 
 		/*
 		 * Adjust the object sizes so that we clear
diff --git a/mm/slab.h b/mm/slab.h
index 68bdf498da3b..25962ab75ec1 100644
--- a/mm/slab.h
+++ b/mm/slab.h
@@ -25,7 +25,8 @@  struct kmem_cache {
 	unsigned int useroffset;/* Usercopy region offset */
 	unsigned int usersize;	/* Usercopy region size */
 	const char *name;	/* Slab name for sysfs */
-	int refcount;		/* Use counter */
+	refcount_t refcount;	/* Use counter */
+	int alias_count;
 	void (*ctor)(void *);	/* Called on object slot creation */
 	struct list_head list;	/* List of all slab caches on the system */
 };
@@ -295,7 +296,7 @@  extern void slab_init_memcg_params(struct kmem_cache *);
 extern void memcg_link_cache(struct kmem_cache *s);
 extern void slab_deactivate_memcg_cache_rcu_sched(struct kmem_cache *s,
 				void (*deact_fn)(struct kmem_cache *));
-
+extern void kmem_cache_put_locked(struct kmem_cache *s);
 #else /* CONFIG_MEMCG && !CONFIG_SLOB */
 
 /* If !memcg, all caches are root. */
diff --git a/mm/slab_common.c b/mm/slab_common.c
index b0dd9db1eb2f..390eb47486fd 100644
--- a/mm/slab_common.c
+++ b/mm/slab_common.c
@@ -306,7 +306,7 @@  int slab_unmergeable(struct kmem_cache *s)
 	/*
 	 * We may have set a slab to be unmergeable during bootstrap.
 	 */
-	if (s->refcount < 0)
+	if (s->alias_count < 0)
 		return 1;
 
 	return 0;
@@ -391,7 +391,8 @@  static struct kmem_cache *create_cache(const char *name,
 	if (err)
 		goto out_free_cache;
 
-	s->refcount = 1;
+	s->alias_count = 1;
+	refcount_set(&s->refcount, 1);
 	list_add(&s->list, &slab_caches);
 	memcg_link_cache(s);
 out:
@@ -611,6 +612,13 @@  void memcg_create_kmem_cache(struct mem_cgroup *memcg,
 	if (memcg->kmem_state != KMEM_ONLINE)
 		goto out_unlock;
 
+	/*
+	 * The root cache has been requested to be destroyed while its memcg
+	 * cache was in creation queue.
+	 */
+	if (!root_cache->alias_count)
+		goto out_unlock;
+
 	idx = memcg_cache_id(memcg);
 	arr = rcu_dereference_protected(root_cache->memcg_params.memcg_caches,
 					lockdep_is_held(&slab_mutex));
@@ -663,6 +671,8 @@  static void kmemcg_deactivate_workfn(struct work_struct *work)
 {
 	struct kmem_cache *s = container_of(work, struct kmem_cache,
 					    memcg_params.deact_work);
+	struct kmem_cache *root = s->memcg_params.root_cache;
+	struct mem_cgroup *memcg = s->memcg_params.memcg;
 
 	get_online_cpus();
 	get_online_mems();
@@ -677,7 +687,8 @@  static void kmemcg_deactivate_workfn(struct work_struct *work)
 	put_online_cpus();
 
 	/* done, put the ref from slab_deactivate_memcg_cache_rcu_sched() */
-	css_put(&s->memcg_params.memcg->css);
+	css_put(&memcg->css);
+	kmem_cache_put(root);
 }
 
 static void kmemcg_deactivate_rcufn(struct rcu_head *head)
@@ -712,6 +723,10 @@  void slab_deactivate_memcg_cache_rcu_sched(struct kmem_cache *s,
 	    WARN_ON_ONCE(s->memcg_params.deact_fn))
 		return;
 
+	/* Make sure root kmem_cache does not get destroyed in the middle */
+	if (!kmem_cache_tryget(s->memcg_params.root_cache))
+		return;
+
 	/* pin memcg so that @s doesn't get destroyed in the middle */
 	css_get(&s->memcg_params.memcg->css);
 
@@ -838,21 +853,17 @@  void slab_kmem_cache_release(struct kmem_cache *s)
 	kmem_cache_free(kmem_cache, s);
 }
 
-void kmem_cache_destroy(struct kmem_cache *s)
+static void __kmem_cache_destroy(struct kmem_cache *s, bool lock)
 {
 	int err;
 
-	if (unlikely(!s))
-		return;
-
-	get_online_cpus();
-	get_online_mems();
+	if (lock) {
+		get_online_cpus();
+		get_online_mems();
+		mutex_lock(&slab_mutex);
+	}
 
-	mutex_lock(&slab_mutex);
-
-	s->refcount--;
-	if (s->refcount)
-		goto out_unlock;
+	VM_BUG_ON(s->alias_count);
 
 	err = shutdown_memcg_caches(s);
 	if (!err)
@@ -863,11 +874,42 @@  void kmem_cache_destroy(struct kmem_cache *s)
 		       s->name);
 		dump_stack();
 	}
-out_unlock:
-	mutex_unlock(&slab_mutex);
 
-	put_online_mems();
-	put_online_cpus();
+	if (lock) {
+		mutex_unlock(&slab_mutex);
+		put_online_mems();
+		put_online_cpus();
+	}
+}
+
+bool kmem_cache_tryget(struct kmem_cache *s)
+{
+	if (is_root_cache(s))
+		return refcount_inc_not_zero(&s->refcount);
+	return false;
+}
+
+void kmem_cache_put(struct kmem_cache *s)
+{
+	if (is_root_cache(s) &&
+	    refcount_dec_and_test(&s->refcount))
+		__kmem_cache_destroy(s, true);
+}
+
+void kmem_cache_put_locked(struct kmem_cache *s)
+{
+	if (is_root_cache(s) &&
+	    refcount_dec_and_test(&s->refcount))
+		__kmem_cache_destroy(s, false);
+}
+
+void kmem_cache_destroy(struct kmem_cache *s)
+{
+	if (unlikely(!s))
+		return;
+
+	s->alias_count--;
+	kmem_cache_put(s);
 }
 EXPORT_SYMBOL(kmem_cache_destroy);
 
@@ -919,7 +961,8 @@  void __init create_boot_cache(struct kmem_cache *s, const char *name,
 		panic("Creation of kmalloc slab %s size=%u failed. Reason %d\n",
 					name, size, err);
 
-	s->refcount = -1;	/* Exempt from merging for now */
+	s->alias_count = -1;	/* Exempt from merging for now */
+	refcount_set(&s->refcount, 1);
 }
 
 struct kmem_cache *__init create_kmalloc_cache(const char *name,
@@ -934,7 +977,8 @@  struct kmem_cache *__init create_kmalloc_cache(const char *name,
 	create_boot_cache(s, name, size, flags, useroffset, usersize);
 	list_add(&s->list, &slab_caches);
 	memcg_link_cache(s);
-	s->refcount = 1;
+	s->alias_count = 1;
+	refcount_set(&s->refcount, 1);
 	return s;
 }
 
diff --git a/mm/slub.c b/mm/slub.c
index 48f75872c356..2e45f7febc6e 100644
--- a/mm/slub.c
+++ b/mm/slub.c
@@ -4270,8 +4270,8 @@  __kmem_cache_alias(const char *name, unsigned int size, unsigned int align,
 	struct kmem_cache *s, *c;
 
 	s = find_mergeable(size, align, flags, name, ctor);
-	if (s) {
-		s->refcount++;
+	if (s && kmem_cache_tryget(s)) {
+		s->alias_count++;
 
 		/*
 		 * Adjust the object sizes so that we clear
@@ -4286,7 +4286,8 @@  __kmem_cache_alias(const char *name, unsigned int size, unsigned int align,
 		}
 
 		if (sysfs_slab_alias(s, name)) {
-			s->refcount--;
+			s->alias_count--;
+			kmem_cache_put_locked(s);
 			s = NULL;
 		}
 	}
@@ -5009,7 +5010,8 @@  SLAB_ATTR_RO(ctor);
 
 static ssize_t aliases_show(struct kmem_cache *s, char *buf)
 {
-	return sprintf(buf, "%d\n", s->refcount < 0 ? 0 : s->refcount - 1);
+	return sprintf(buf, "%d\n",
+		       s->alias_count < 0 ? 0 : s->alias_count - 1);
 }
 SLAB_ATTR_RO(aliases);
 
@@ -5162,7 +5164,7 @@  static ssize_t trace_store(struct kmem_cache *s, const char *buf,
 	 * as well as cause other issues like converting a mergeable
 	 * cache into an umergeable one.
 	 */
-	if (s->refcount > 1)
+	if (s->alias_count > 1)
 		return -EINVAL;
 
 	s->flags &= ~SLAB_TRACE;
@@ -5280,7 +5282,7 @@  static ssize_t failslab_show(struct kmem_cache *s, char *buf)
 static ssize_t failslab_store(struct kmem_cache *s, const char *buf,
 							size_t length)
 {
-	if (s->refcount > 1)
+	if (s->alias_count > 1)
 		return -EINVAL;
 
 	s->flags &= ~SLAB_FAILSLAB;