diff mbox series

[57/62] memcg: Convert object cgroups from struct page to struct slab

Message ID 20211004134650.4031813-58-willy@infradead.org (mailing list archive)
State New
Headers show
Series Separate struct slab from struct page | expand

Commit Message

Matthew Wilcox Oct. 4, 2021, 1:46 p.m. UTC
Now that slab and slub are converted to use struct slab throughout,
convert the memcg infrastructure that they use.

There is a comment in here that I would appreciate being cleared up
before this patch is merged.

Signed-off-by: Matthew Wilcox (Oracle) <willy@infradead.org>
---
 include/linux/memcontrol.h | 34 +++++++++++++--------------
 include/linux/slab_def.h   | 10 ++++----
 include/linux/slub_def.h   | 10 ++++----
 mm/kasan/common.c          |  2 +-
 mm/memcontrol.c            | 33 +++++++++++++-------------
 mm/slab.c                  | 10 ++++----
 mm/slab.h                  | 47 +++++++++++++++++++-------------------
 mm/slub.c                  |  2 +-
 8 files changed, 74 insertions(+), 74 deletions(-)

Comments

Johannes Weiner Oct. 11, 2021, 5:13 p.m. UTC | #1
CC Roman for the slab tracking bits

On Mon, Oct 04, 2021 at 02:46:45PM +0100, Matthew Wilcox (Oracle) wrote:
> @@ -537,41 +537,41 @@ static inline bool PageMemcgKmem(struct page *page)
>  }
>  
>  /*
> - * page_objcgs - get the object cgroups vector associated with a page
> - * @page: a pointer to the page struct
> + * slab_objcgs - get the object cgroups vector associated with a page
> + * @slab: a pointer to the slab struct
>   *
> - * Returns a pointer to the object cgroups vector associated with the page,
> - * or NULL. This function assumes that the page is known to have an
> + * Returns a pointer to the object cgroups vector associated with the slab,
> + * or NULL. This function assumes that the slab is known to have an
>   * associated object cgroups vector. It's not safe to call this function
>   * against pages, which might have an associated memory cgroup: e.g.
>   * kernel stack pages.
>   */
> -static inline struct obj_cgroup **page_objcgs(struct page *page)
> +static inline struct obj_cgroup **slab_objcgs(struct slab *slab)
>  {
> -	unsigned long memcg_data = READ_ONCE(page->memcg_data);
> +	unsigned long memcg_data = READ_ONCE(slab->memcg_data);
>  
> -	VM_BUG_ON_PAGE(memcg_data && !(memcg_data & MEMCG_DATA_OBJCGS), page);
> -	VM_BUG_ON_PAGE(memcg_data & MEMCG_DATA_KMEM, page);
> +	VM_BUG_ON_PAGE(memcg_data && !(memcg_data & MEMCG_DATA_OBJCGS), slab_page(slab));
> +	VM_BUG_ON_PAGE(memcg_data & MEMCG_DATA_KMEM, slab_page(slab));
>  
>  	return (struct obj_cgroup **)(memcg_data & ~MEMCG_DATA_FLAGS_MASK);
>  }

I like this whole patch series, but I think for memcg this is a
particularly nice cleanup.

Because right now we can have user pages pointing to a memcg, random
alloc_page(GFP_ACCOUNT) pages pointing to an objcg, and slab pages
pointing to an array of objcgs - all in the same memcg_data member.

After your patch, slab->memcg_data points to an array of objcgs,
period. The only time it doesn't is when there is a bug. Once the
memcg_data member is no longer physically shared between page and
slab, we can do:

	struct slab {
		struct obj_cgroup **objcgs;
	};

and ditch the accessor function altogether.

> - * page_objcgs_check - get the object cgroups vector associated with a page
> - * @page: a pointer to the page struct
> + * slab_objcgs_check - get the object cgroups vector associated with a page
> + * @slab: a pointer to the slab struct
>   *
> - * Returns a pointer to the object cgroups vector associated with the page,
> - * or NULL. This function is safe to use if the page can be directly associated
> + * Returns a pointer to the object cgroups vector associated with the slab,
> + * or NULL. This function is safe to use if the slab can be directly associated
>   * with a memory cgroup.
>   */
> -static inline struct obj_cgroup **page_objcgs_check(struct page *page)
> +static inline struct obj_cgroup **slab_objcgs_check(struct slab *slab)
>  {
> -	unsigned long memcg_data = READ_ONCE(page->memcg_data);
> +	unsigned long memcg_data = READ_ONCE(slab->memcg_data);
>  
>  	if (!memcg_data || !(memcg_data & MEMCG_DATA_OBJCGS))
>  		return NULL;
>  
> -	VM_BUG_ON_PAGE(memcg_data & MEMCG_DATA_KMEM, page);
> +	VM_BUG_ON_PAGE(memcg_data & MEMCG_DATA_KMEM, slab_page(slab));
>  
>  	return (struct obj_cgroup **)(memcg_data & ~MEMCG_DATA_FLAGS_MASK);

This is a bit weird.

The function is used in one place, to check whether a random page is a
slab page. It's essentially a generic type check on the page!

After your changes, you pass a struct slab that might well be invalid
if this isn't a slab page, and you rely on the PAGE's memcg_data to
tell you whether this is the case. It works because page->memcg_data
is overlaid with slab->memcg_data, but that won't be the case if we
allocate struct slab separately.

To avoid that trap down the road, I think it would be better to keep
the *page* the ambiguous object for now, and only resolve to struct
slab after the type check. So that every time you see struct slab, you
know it's valid.

In fact, I think it would be best to just inline page_objcgs_check()
into its sole caller. It would clarify the resolution from wildcard
page to valid struct slab quite a bit:

> @@ -2819,38 +2819,39 @@ int memcg_alloc_page_obj_cgroups(struct page *page, struct kmem_cache *s,
>   */
>  struct mem_cgroup *mem_cgroup_from_obj(void *p)
>  {
> -	struct page *page;
> +	struct slab *slab;
>  
>  	if (mem_cgroup_disabled())
>  		return NULL;
>  
> -	page = virt_to_head_page(p);
> +	slab = virt_to_slab(p);
>  
>  	/*
>  	 * Slab objects are accounted individually, not per-page.
>  	 * Memcg membership data for each individual object is saved in
> -	 * the page->obj_cgroups.
> +	 * the slab->obj_cgroups.
>  	 */
> -	if (page_objcgs_check(page)) {
> +	if (slab_objcgs_check(slab)) {

I.e. do this instead:

	page = virt_to_head_page(p);

	/* object is backed by slab */
	if (page->memcg_data & MEMCG_DATA_OBJCGS) {
		struct slab *slab = (struct slab *)page;

		objcg = slab_objcgs(...)[]
		return objcg ? obj_cgroup_memcg(objcg): NULL;
	}

	/* object is backed by a regular kernel page */
	return page_memcg_check(page);

>  		struct obj_cgroup *objcg;
>  		unsigned int off;
>  
> -		off = obj_to_index(page->slab_cache, page, p);
> -		objcg = page_objcgs(page)[off];
> +		off = obj_to_index(slab->slab_cache, slab, p);
> +		objcg = slab_objcgs(slab)[off];
>  		if (objcg)
>  			return obj_cgroup_memcg(objcg);
>  
>  		return NULL;
>  	}
>  
> +	/* I am pretty sure this could just be 'return NULL' */

No, we could still be looking at a regular page that is being tracked
by memcg. People do (void *)__get_free_pages(GFP_ACCOUNT). So this
needs to stay 'return page_memcg_check()'.
Matthew Wilcox Oct. 12, 2021, 3:16 a.m. UTC | #2
On Mon, Oct 11, 2021 at 01:13:18PM -0400, Johannes Weiner wrote:
> Because right now we can have user pages pointing to a memcg, random
> alloc_page(GFP_ACCOUNT) pages pointing to an objcg, and slab pages
> pointing to an array of objcgs - all in the same memcg_data member.

Ah!  I was missing the possibility that an alloc_page() could point to
an objcg.  I had thought that only slab pages could point to an objcg
and only anon/file pages could point to a memcg.

> After your patch, slab->memcg_data points to an array of objcgs,
> period. The only time it doesn't is when there is a bug. Once the
> memcg_data member is no longer physically shared between page and
> slab, we can do:
> 
> 	struct slab {
> 		struct obj_cgroup **objcgs;
> 	};
> 
> and ditch the accessor function altogether.

Yes.

> > - * page_objcgs_check - get the object cgroups vector associated with a page
> > - * @page: a pointer to the page struct
> > + * slab_objcgs_check - get the object cgroups vector associated with a page
> > + * @slab: a pointer to the slab struct
> >   *
> > - * Returns a pointer to the object cgroups vector associated with the page,
> > - * or NULL. This function is safe to use if the page can be directly associated
> > + * Returns a pointer to the object cgroups vector associated with the slab,
> > + * or NULL. This function is safe to use if the slab can be directly associated
> >   * with a memory cgroup.
> >   */
> > -static inline struct obj_cgroup **page_objcgs_check(struct page *page)
> > +static inline struct obj_cgroup **slab_objcgs_check(struct slab *slab)
> >  {
> > -	unsigned long memcg_data = READ_ONCE(page->memcg_data);
> > +	unsigned long memcg_data = READ_ONCE(slab->memcg_data);
> >  
> >  	if (!memcg_data || !(memcg_data & MEMCG_DATA_OBJCGS))
> >  		return NULL;
> >  
> > -	VM_BUG_ON_PAGE(memcg_data & MEMCG_DATA_KMEM, page);
> > +	VM_BUG_ON_PAGE(memcg_data & MEMCG_DATA_KMEM, slab_page(slab));
> >  
> >  	return (struct obj_cgroup **)(memcg_data & ~MEMCG_DATA_FLAGS_MASK);
> 
> This is a bit weird.
> 
> The function is used in one place, to check whether a random page is a
> slab page. It's essentially a generic type check on the page!
> 
> After your changes, you pass a struct slab that might well be invalid
> if this isn't a slab page, and you rely on the PAGE's memcg_data to
> tell you whether this is the case. It works because page->memcg_data
> is overlaid with slab->memcg_data, but that won't be the case if we
> allocate struct slab separately.
> 
> To avoid that trap down the road, I think it would be better to keep
> the *page* the ambiguous object for now, and only resolve to struct
> slab after the type check. So that every time you see struct slab, you
> know it's valid.
> 
> In fact, I think it would be best to just inline page_objcgs_check()
> into its sole caller. It would clarify the resolution from wildcard
> page to valid struct slab quite a bit:

Yes.  Every time I read through this, I was wondering if there was
something I was missing.  I mean, there was (the memcg/objcg/objcgs
distinction above), but yes, if we know we have a slab, we don't need
this function.

> > @@ -2819,38 +2819,39 @@ int memcg_alloc_page_obj_cgroups(struct page *page, struct kmem_cache *s,
> >   */
> >  struct mem_cgroup *mem_cgroup_from_obj(void *p)
> >  {
> > -	struct page *page;
> > +	struct slab *slab;
> >  
> >  	if (mem_cgroup_disabled())
> >  		return NULL;
> >  
> > -	page = virt_to_head_page(p);
> > +	slab = virt_to_slab(p);
> >  
> >  	/*
> >  	 * Slab objects are accounted individually, not per-page.
> >  	 * Memcg membership data for each individual object is saved in
> > -	 * the page->obj_cgroups.
> > +	 * the slab->obj_cgroups.
> >  	 */
> > -	if (page_objcgs_check(page)) {
> > +	if (slab_objcgs_check(slab)) {
> 
> I.e. do this instead:
> 
> 	page = virt_to_head_page(p);
> 
> 	/* object is backed by slab */
> 	if (page->memcg_data & MEMCG_DATA_OBJCGS) {
> 		struct slab *slab = (struct slab *)page;
> 
> 		objcg = slab_objcgs(...)[]
> 		return objcg ? obj_cgroup_memcg(objcg): NULL;
> 	}
> 
> 	/* object is backed by a regular kernel page */
> 	return page_memcg_check(page);

Maybe I'm missing something else, but why not discriminate based on
PageSlab()?  ie:

	slab = virt_to_slab(p);
	if (slab_test_cache(slab)) {
		...
	}
	return page_memcg_check((struct page *)slab);

... but see the response to your other email for why not exactly this.
diff mbox series

Patch

diff --git a/include/linux/memcontrol.h b/include/linux/memcontrol.h
index 3096c9a0ee01..3ddc7a980fda 100644
--- a/include/linux/memcontrol.h
+++ b/include/linux/memcontrol.h
@@ -537,41 +537,41 @@  static inline bool PageMemcgKmem(struct page *page)
 }
 
 /*
- * page_objcgs - get the object cgroups vector associated with a page
- * @page: a pointer to the page struct
+ * slab_objcgs - get the object cgroups vector associated with a slab
+ * @slab: a pointer to the slab struct
  *
- * Returns a pointer to the object cgroups vector associated with the page,
- * or NULL. This function assumes that the page is known to have an
+ * Returns a pointer to the object cgroups vector associated with the slab,
+ * or NULL. This function assumes that the slab is known to have an
  * associated object cgroups vector. It's not safe to call this function
  * against pages, which might have an associated memory cgroup: e.g.
  * kernel stack pages.
  */
-static inline struct obj_cgroup **page_objcgs(struct page *page)
+static inline struct obj_cgroup **slab_objcgs(struct slab *slab)
 {
-	unsigned long memcg_data = READ_ONCE(page->memcg_data);
+	unsigned long memcg_data = READ_ONCE(slab->memcg_data);
 
-	VM_BUG_ON_PAGE(memcg_data && !(memcg_data & MEMCG_DATA_OBJCGS), page);
-	VM_BUG_ON_PAGE(memcg_data & MEMCG_DATA_KMEM, page);
+	VM_BUG_ON_PAGE(memcg_data && !(memcg_data & MEMCG_DATA_OBJCGS), slab_page(slab));
+	VM_BUG_ON_PAGE(memcg_data & MEMCG_DATA_KMEM, slab_page(slab));
 
 	return (struct obj_cgroup **)(memcg_data & ~MEMCG_DATA_FLAGS_MASK);
 }
 
 /*
- * page_objcgs_check - get the object cgroups vector associated with a page
- * @page: a pointer to the page struct
+ * slab_objcgs_check - get the object cgroups vector associated with a slab
+ * @slab: a pointer to the slab struct
  *
- * Returns a pointer to the object cgroups vector associated with the page,
- * or NULL. This function is safe to use if the page can be directly associated
+ * Returns a pointer to the object cgroups vector associated with the slab,
+ * or NULL. This function is safe to use if the slab can be directly associated
  * with a memory cgroup.
  */
-static inline struct obj_cgroup **page_objcgs_check(struct page *page)
+static inline struct obj_cgroup **slab_objcgs_check(struct slab *slab)
 {
-	unsigned long memcg_data = READ_ONCE(page->memcg_data);
+	unsigned long memcg_data = READ_ONCE(slab->memcg_data);
 
 	if (!memcg_data || !(memcg_data & MEMCG_DATA_OBJCGS))
 		return NULL;
 
-	VM_BUG_ON_PAGE(memcg_data & MEMCG_DATA_KMEM, page);
+	VM_BUG_ON_PAGE(memcg_data & MEMCG_DATA_KMEM, slab_page(slab));
 
 	return (struct obj_cgroup **)(memcg_data & ~MEMCG_DATA_FLAGS_MASK);
 }
@@ -582,12 +582,12 @@  static inline bool PageMemcgKmem(struct page *page)
 	return false;
 }
 
-static inline struct obj_cgroup **page_objcgs(struct page *page)
+static inline struct obj_cgroup **slab_objcgs(struct slab *slab)
 {
 	return NULL;
 }
 
-static inline struct obj_cgroup **page_objcgs_check(struct page *page)
+static inline struct obj_cgroup **slab_objcgs_check(struct slab *slab)
 {
 	return NULL;
 }
diff --git a/include/linux/slab_def.h b/include/linux/slab_def.h
index 3aa5e1e73ab6..f81a41f9d5d1 100644
--- a/include/linux/slab_def.h
+++ b/include/linux/slab_def.h
@@ -106,16 +106,16 @@  static inline void *nearest_obj(struct kmem_cache *cache, struct page *page,
  *   reciprocal_divide(offset, cache->reciprocal_buffer_size)
  */
 static inline unsigned int obj_to_index(const struct kmem_cache *cache,
-					const struct page *page, void *obj)
+					const struct slab *slab, void *obj)
 {
-	u32 offset = (obj - page->s_mem);
+	u32 offset = (obj - slab->s_mem);
 	return reciprocal_divide(offset, cache->reciprocal_buffer_size);
 }
 
-static inline int objs_per_slab_page(const struct kmem_cache *cache,
-				     const struct page *page)
+static inline int objs_per_slab(const struct kmem_cache *cache,
+				     const struct slab *slab)
 {
-	if (is_kfence_address(page_address(page)))
+	if (is_kfence_address(slab_address(slab)))
 		return 1;
 	return cache->num;
 }
diff --git a/include/linux/slub_def.h b/include/linux/slub_def.h
index 63eae033d713..994a60da2f2e 100644
--- a/include/linux/slub_def.h
+++ b/include/linux/slub_def.h
@@ -187,16 +187,16 @@  static inline unsigned int __obj_to_index(const struct kmem_cache *cache,
 }
 
 static inline unsigned int obj_to_index(const struct kmem_cache *cache,
-					const struct page *page, void *obj)
+					const struct slab *slab, void *obj)
 {
 	if (is_kfence_address(obj))
 		return 0;
-	return __obj_to_index(cache, page_address(page), obj);
+	return __obj_to_index(cache, slab_address(slab), obj);
 }
 
-static inline int objs_per_slab_page(const struct kmem_cache *cache,
-				     const struct page *page)
+static inline int objs_per_slab(const struct kmem_cache *cache,
+				     const struct slab *slab)
 {
-	return page->objects;
+	return slab->objects;
 }
 #endif /* _LINUX_SLUB_DEF_H */
diff --git a/mm/kasan/common.c b/mm/kasan/common.c
index 41779ad109cd..f3972af7fa1b 100644
--- a/mm/kasan/common.c
+++ b/mm/kasan/common.c
@@ -298,7 +298,7 @@  static inline u8 assign_tag(struct kmem_cache *cache,
 	/* For caches that either have a constructor or SLAB_TYPESAFE_BY_RCU: */
 #ifdef CONFIG_SLAB
 	/* For SLAB assign tags based on the object index in the freelist. */
-	return (u8)obj_to_index(cache, virt_to_head_page(object), (void *)object);
+	return (u8)obj_to_index(cache, virt_to_slab(object), (void *)object);
 #else
 	/*
 	 * For SLUB assign a random tag during slab creation, otherwise reuse
diff --git a/mm/memcontrol.c b/mm/memcontrol.c
index 6da5020a8656..fb15325549c1 100644
--- a/mm/memcontrol.c
+++ b/mm/memcontrol.c
@@ -2770,16 +2770,16 @@  static struct mem_cgroup *get_mem_cgroup_from_objcg(struct obj_cgroup *objcg)
  */
 #define OBJCGS_CLEAR_MASK	(__GFP_DMA | __GFP_RECLAIMABLE | __GFP_ACCOUNT)
 
-int memcg_alloc_page_obj_cgroups(struct page *page, struct kmem_cache *s,
-				 gfp_t gfp, bool new_page)
+int memcg_alloc_slab_cgroups(struct slab *slab, struct kmem_cache *s,
+			     gfp_t gfp, bool new_page)
 {
-	unsigned int objects = objs_per_slab_page(s, page);
+	unsigned int objects = objs_per_slab(s, slab);
 	unsigned long memcg_data;
 	void *vec;
 
 	gfp &= ~OBJCGS_CLEAR_MASK;
 	vec = kcalloc_node(objects, sizeof(struct obj_cgroup *), gfp,
-			   page_to_nid(page));
+			   slab_nid(slab));
 	if (!vec)
 		return -ENOMEM;
 
@@ -2790,10 +2790,10 @@  int memcg_alloc_page_obj_cgroups(struct page *page, struct kmem_cache *s,
 		 * it's memcg_data, no synchronization is required and
 		 * memcg_data can be simply assigned.
 		 */
-		page->memcg_data = memcg_data;
-	} else if (cmpxchg(&page->memcg_data, 0, memcg_data)) {
+		slab->memcg_data = memcg_data;
+	} else if (cmpxchg(&slab->memcg_data, 0, memcg_data)) {
 		/*
-		 * If the slab page is already in use, somebody can allocate
+		 * If the slab is already in use, somebody can allocate
 		 * and assign obj_cgroups in parallel. In this case the existing
 		 * objcg vector should be reused.
 		 */
@@ -2819,38 +2819,39 @@  int memcg_alloc_page_obj_cgroups(struct page *page, struct kmem_cache *s,
  */
 struct mem_cgroup *mem_cgroup_from_obj(void *p)
 {
-	struct page *page;
+	struct slab *slab;
 
 	if (mem_cgroup_disabled())
 		return NULL;
 
-	page = virt_to_head_page(p);
+	slab = virt_to_slab(p);
 
 	/*
 	 * Slab objects are accounted individually, not per-page.
 	 * Memcg membership data for each individual object is saved in
-	 * the page->obj_cgroups.
+	 * the slab->obj_cgroups.
 	 */
-	if (page_objcgs_check(page)) {
+	if (slab_objcgs_check(slab)) {
 		struct obj_cgroup *objcg;
 		unsigned int off;
 
-		off = obj_to_index(page->slab_cache, page, p);
-		objcg = page_objcgs(page)[off];
+		off = obj_to_index(slab->slab_cache, slab, p);
+		objcg = slab_objcgs(slab)[off];
 		if (objcg)
 			return obj_cgroup_memcg(objcg);
 
 		return NULL;
 	}
 
+	/* I am pretty sure this could just be 'return NULL' */
 	/*
-	 * page_memcg_check() is used here, because page_has_obj_cgroups()
+	 * page_memcg_check() is used here, because slab_objcgs_check()
 	 * check above could fail because the object cgroups vector wasn't set
 	 * at that moment, but it can be set concurrently.
-	 * page_memcg_check(page) will guarantee that a proper memory
+	 * page_memcg_check() will guarantee that a proper memory
 	 * cgroup pointer or NULL will be returned.
 	 */
-	return page_memcg_check(page);
+	return page_memcg_check((struct page *)slab);
 }
 
 __always_inline struct obj_cgroup *get_obj_cgroup_from_current(void)
diff --git a/mm/slab.c b/mm/slab.c
index 29dc09e784b8..3e9cd3ecc9ab 100644
--- a/mm/slab.c
+++ b/mm/slab.c
@@ -1555,7 +1555,7 @@  static void check_poison_obj(struct kmem_cache *cachep, void *objp)
 		struct slab *slab = virt_to_slab(objp);
 		unsigned int objnr;
 
-		objnr = obj_to_index(cachep, slab_page(slab), objp);
+		objnr = obj_to_index(cachep, slab, objp);
 		if (objnr) {
 			objp = index_to_obj(cachep, slab, objnr - 1);
 			realobj = (char *)objp + obj_offset(cachep);
@@ -2525,7 +2525,7 @@  static void *slab_get_obj(struct kmem_cache *cachep, struct slab *slab)
 static void slab_put_obj(struct kmem_cache *cachep,
 			struct slab *slab, void *objp)
 {
-	unsigned int objnr = obj_to_index(cachep, slab_page(slab), objp);
+	unsigned int objnr = obj_to_index(cachep, slab, objp);
 #if DEBUG
 	unsigned int i;
 
@@ -2723,7 +2723,7 @@  static void *cache_free_debugcheck(struct kmem_cache *cachep, void *objp,
 	if (cachep->flags & SLAB_STORE_USER)
 		*dbg_userword(cachep, objp) = (void *)caller;
 
-	objnr = obj_to_index(cachep, slab_page(slab), objp);
+	objnr = obj_to_index(cachep, slab, objp);
 
 	BUG_ON(objnr >= cachep->num);
 	BUG_ON(objp != index_to_obj(cachep, slab, objnr));
@@ -3669,7 +3669,7 @@  void kmem_obj_info(struct kmem_obj_info *kpp, void *object, struct slab *slab)
 	objp = object - obj_offset(cachep);
 	kpp->kp_data_offset = obj_offset(cachep);
 	slab = virt_to_slab(objp);
-	objnr = obj_to_index(cachep, slab_page(slab), objp);
+	objnr = obj_to_index(cachep, slab, objp);
 	objp = index_to_obj(cachep, slab, objnr);
 	kpp->kp_objp = objp;
 	if (DEBUG && cachep->flags & SLAB_STORE_USER)
@@ -4191,7 +4191,7 @@  void __check_heap_object(const void *ptr, unsigned long n,
 
 	/* Find and validate object. */
 	cachep = slab->slab_cache;
-	objnr = obj_to_index(cachep, slab_page(slab), (void *)ptr);
+	objnr = obj_to_index(cachep, slab, (void *)ptr);
 	BUG_ON(objnr >= cachep->num);
 
 	/* Find offset within object. */
diff --git a/mm/slab.h b/mm/slab.h
index 5eabc9352bbf..ac9dcdc1bfa9 100644
--- a/mm/slab.h
+++ b/mm/slab.h
@@ -333,15 +333,15 @@  static inline bool kmem_cache_debug_flags(struct kmem_cache *s, slab_flags_t fla
 }
 
 #ifdef CONFIG_MEMCG_KMEM
-int memcg_alloc_page_obj_cgroups(struct page *page, struct kmem_cache *s,
+int memcg_alloc_slab_cgroups(struct slab *slab, struct kmem_cache *s,
 				 gfp_t gfp, bool new_page);
 void mod_objcg_state(struct obj_cgroup *objcg, struct pglist_data *pgdat,
 		     enum node_stat_item idx, int nr);
 
-static inline void memcg_free_page_obj_cgroups(struct page *page)
+static inline void memcg_free_slab_cgroups(struct slab *slab)
 {
-	kfree(page_objcgs(page));
-	page->memcg_data = 0;
+	kfree(slab_objcgs(slab));
+	slab->memcg_data = 0;
 }
 
 static inline size_t obj_full_size(struct kmem_cache *s)
@@ -386,7 +386,7 @@  static inline void memcg_slab_post_alloc_hook(struct kmem_cache *s,
 					      gfp_t flags, size_t size,
 					      void **p)
 {
-	struct page *page;
+	struct slab *slab;
 	unsigned long off;
 	size_t i;
 
@@ -395,19 +395,18 @@  static inline void memcg_slab_post_alloc_hook(struct kmem_cache *s,
 
 	for (i = 0; i < size; i++) {
 		if (likely(p[i])) {
-			page = virt_to_head_page(p[i]);
+			slab = virt_to_slab(p[i]);
 
-			if (!page_objcgs(page) &&
-			    memcg_alloc_page_obj_cgroups(page, s, flags,
-							 false)) {
+			if (!slab_objcgs(slab) &&
+			    memcg_alloc_slab_cgroups(slab, s, flags, false)) {
 				obj_cgroup_uncharge(objcg, obj_full_size(s));
 				continue;
 			}
 
-			off = obj_to_index(s, page, p[i]);
+			off = obj_to_index(s, slab, p[i]);
 			obj_cgroup_get(objcg);
-			page_objcgs(page)[off] = objcg;
-			mod_objcg_state(objcg, page_pgdat(page),
+			slab_objcgs(slab)[off] = objcg;
+			mod_objcg_state(objcg, slab_pgdat(slab),
 					cache_vmstat_idx(s), obj_full_size(s));
 		} else {
 			obj_cgroup_uncharge(objcg, obj_full_size(s));
@@ -422,7 +421,7 @@  static inline void memcg_slab_free_hook(struct kmem_cache *s_orig,
 	struct kmem_cache *s;
 	struct obj_cgroup **objcgs;
 	struct obj_cgroup *objcg;
-	struct page *page;
+	struct slab *slab;
 	unsigned int off;
 	int i;
 
@@ -433,24 +432,24 @@  static inline void memcg_slab_free_hook(struct kmem_cache *s_orig,
 		if (unlikely(!p[i]))
 			continue;
 
-		page = virt_to_head_page(p[i]);
-		objcgs = page_objcgs_check(page);
+		slab = virt_to_slab(p[i]);
+		objcgs = slab_objcgs_check(slab);
 		if (!objcgs)
 			continue;
 
 		if (!s_orig)
-			s = page->slab_cache;
+			s = slab->slab_cache;
 		else
 			s = s_orig;
 
-		off = obj_to_index(s, page, p[i]);
+		off = obj_to_index(s, slab, p[i]);
 		objcg = objcgs[off];
 		if (!objcg)
 			continue;
 
 		objcgs[off] = NULL;
 		obj_cgroup_uncharge(objcg, obj_full_size(s));
-		mod_objcg_state(objcg, page_pgdat(page), cache_vmstat_idx(s),
+		mod_objcg_state(objcg, slab_pgdat(slab), cache_vmstat_idx(s),
 				-obj_full_size(s));
 		obj_cgroup_put(objcg);
 	}
@@ -462,14 +461,14 @@  static inline struct mem_cgroup *memcg_from_slab_obj(void *ptr)
 	return NULL;
 }
 
-static inline int memcg_alloc_page_obj_cgroups(struct page *page,
-					       struct kmem_cache *s, gfp_t gfp,
-					       bool new_page)
+static inline int memcg_alloc_slab_cgroups(struct slab *slab,
+					   struct kmem_cache *s, gfp_t gfp,
+					   bool new_page)
 {
 	return 0;
 }
 
-static inline void memcg_free_page_obj_cgroups(struct page *page)
+static inline void memcg_free_slab_cgroups(struct slab *slab)
 {
 }
 
@@ -509,7 +508,7 @@  static __always_inline void account_slab(struct slab *slab, int order,
 					      gfp_t gfp)
 {
 	if (memcg_kmem_enabled() && (s->flags & SLAB_ACCOUNT))
-		memcg_alloc_page_obj_cgroups(slab_page(slab), s, gfp, true);
+		memcg_alloc_slab_cgroups(slab, s, gfp, true);
 
 	mod_node_page_state(slab_pgdat(slab), cache_vmstat_idx(s),
 			    PAGE_SIZE << order);
@@ -519,7 +518,7 @@  static __always_inline void unaccount_slab(struct slab *slab, int order,
 						struct kmem_cache *s)
 {
 	if (memcg_kmem_enabled())
-		memcg_free_page_obj_cgroups(slab_page(slab));
+		memcg_free_slab_cgroups(slab);
 
 	mod_node_page_state(slab_pgdat(slab), cache_vmstat_idx(s),
 			    -(PAGE_SIZE << order));
diff --git a/mm/slub.c b/mm/slub.c
index 51ead3838fc1..659b30afbb58 100644
--- a/mm/slub.c
+++ b/mm/slub.c
@@ -4294,7 +4294,7 @@  void kmem_obj_info(struct kmem_obj_info *kpp, void *object, struct slab *slab)
 #else
 	objp = objp0;
 #endif
-	objnr = obj_to_index(s, slab_page(slab), objp);
+	objnr = obj_to_index(s, slab, objp);
 	kpp->kp_data_offset = (unsigned long)((char *)objp0 - (char *)objp);
 	objp = base + s->size * objnr;
 	kpp->kp_objp = objp;