diff mbox series

[STABLE,4.4,5/8] mm: prevent get_user_pages() from overflowing page refcount

Message ID 20191108093814.16032-6-vbabka@suse.cz (mailing list archive)
State New, archived
Headers show
Series page refcount overflow backports | expand

Commit Message

Vlastimil Babka Nov. 8, 2019, 9:38 a.m. UTC
From: Linus Torvalds <torvalds@linux-foundation.org>

commit 8fde12ca79aff9b5ba951fce1a2641901b8d8e64 upstream.

[ 4.4 backport: there's get_page_foll(), so add try_get_page()-like checks
                in there, enabled by a new parameter, which is false where
                upstream patch doesn't replace get_page() with try_get_page()
                (the THP and hugetlb callers).
		In gup_pte_range(), we don't expect tail pages, so just check
                page ref count instead of try_get_compound_head()
		Also patch arch-specific variants of gup.c for x86 and s390,
		leaving mips, sh, sparc alone				      ]

If the page refcount wraps around past zero, it will be freed while
there are still four billion references to it.  One of the possible
avenues for an attacker to try to make this happen is by doing direct IO
on a page multiple times.  This patch makes get_user_pages() refuse to
take a new page reference if there are already more than two billion
references to the page.

Reported-by: Jann Horn <jannh@google.com>
Acked-by: Matthew Wilcox <willy@infradead.org>
Cc: stable@kernel.org
Signed-off-by: Linus Torvalds <torvalds@linux-foundation.org>
Signed-off-by: Vlastimil Babka <vbabka@suse.cz>
---
 arch/s390/mm/gup.c |  6 ++++--
 arch/x86/mm/gup.c  |  9 ++++++++-
 mm/gup.c           | 39 +++++++++++++++++++++++++++++++--------
 mm/huge_memory.c   |  2 +-
 mm/hugetlb.c       | 18 ++++++++++++++++--
 mm/internal.h      | 12 +++++++++---
 6 files changed, 69 insertions(+), 17 deletions(-)

Comments

Ajay Kaher Dec. 3, 2019, 12:25 p.m. UTC | #1
On 08/11/19, 3:08 PM, "Vlastimil Babka" <vbabka@suse.cz> wrote:

> From: Linus Torvalds <torvalds@linux-foundation.org>
> 
> commit 8fde12ca79aff9b5ba951fce1a2641901b8d8e64 upstream.
>    
> [ 4.4 backport: there's get_page_foll(), so add try_get_page()-like checks
>                 in there, enabled by a new parameter, which is false where
>                 upstream patch doesn't replace get_page() with try_get_page()
>                 (the THP and hugetlb callers).

Could we have try_get_page_foll(), as in:
https://lore.kernel.org/stable/1570581863-12090-3-git-send-email-akaher@vmware.com/

+ Code will be in sync as we have try_get_page()
+ No need to add extra argument to try_get_page()
+ No need to modify the callers of try_get_page()

>		In gup_pte_range(), we don't expect tail pages, so just check
>                 page ref count instead of try_get_compound_head()

Technically it's fine. If you want to keep the code of stable versions in sync
with latest versions then this could be done in following ways (without any
modification in upstream patch for gup_pte_range()):

Apply 7aef4172c7957d7e65fc172be4c99becaef855d4 before applying
8fde12ca79aff9b5ba951fce1a2641901b8d8e64, as done here:
https://lore.kernel.org/stable/1570581863-12090-4-git-send-email-akaher@vmware.com/

> 		Also patch arch-specific variants of gup.c for x86 and s390,
> 		leaving mips, sh, sparc alone				      ]
> 
    
> ---
>  arch/s390/mm/gup.c |  6 ++++--
>  arch/x86/mm/gup.c  |  9 ++++++++-
>  mm/gup.c           | 39 +++++++++++++++++++++++++++++++--------
>  mm/huge_memory.c   |  2 +-
>  mm/hugetlb.c       | 18 ++++++++++++++++--
>  mm/internal.h      | 12 +++++++++---
>  6 files changed, 69 insertions(+), 17 deletions(-)
>    
> #ifdef __HAVE_ARCH_PTE_SPECIAL
>  static int gup_pte_range(pmd_t pmd, unsigned long addr, unsigned long end,
>  			 int write, struct page **pages, int *nr)
> @@ -1083,6 +1103,9 @@ static int gup_pte_range(pmd_t pmd, unsigned long addr, unsigned long end,
>  		VM_BUG_ON(!pfn_valid(pte_pfn(pte)));
> 		page = pte_page(pte);
>  
> +		if (WARN_ON_ONCE(page_ref_count(page) < 0))
> +			goto pte_unmap;
> +
>  		if (!page_cache_get_speculative(page))
>  			goto pte_unmap;


 
> diff --git a/mm/internal.h b/mm/internal.h
> index a6639c72780a..b52041969d06 100644
> --- a/mm/internal.h
> +++ b/mm/internal.h
> @@ -93,23 +93,29 @@ static inline void __get_page_tail_foll(struct page *page,
>   * follow_page() and it must be called while holding the proper PT
>   * lock while the pte (or pmd_trans_huge) is still mapping the page.
>   */
> -static inline void get_page_foll(struct page *page)
> +static inline bool get_page_foll(struct page *page, bool check)
>  {
> -	if (unlikely(PageTail(page)))
> +	if (unlikely(PageTail(page))) {
>  		/*
>  		 * This is safe only because
>  		 * __split_huge_page_refcount() can't run under
>  		 * get_page_foll() because we hold the proper PT lock.
>  		 */
> +		if (check && WARN_ON_ONCE(
> +				page_ref_count(compound_head(page)) <= 0))
> +			return false;
>  		__get_page_tail_foll(page, true);
> -	else {
> +	} else {
>  		/*
>  		 * Getting a normal page or the head of a compound page
>  		 * requires to already have an elevated page->_count.
>  		 */
>  		VM_BUG_ON_PAGE(page_ref_zero_or_close_to_overflow(page), page);
> +		if (check && WARN_ON_ONCE(page_ref_count(page) <= 0))
> +			return false;
>  		atomic_inc(&page->_count);
>  	}
> +	return true;
>  }
Vlastimil Babka Dec. 3, 2019, 12:57 p.m. UTC | #2
On 12/3/19 1:25 PM, Ajay Kaher wrote:
> 
> 
> On 08/11/19, 3:08 PM, "Vlastimil Babka" <vbabka@suse.cz> wrote:
> 
>> From: Linus Torvalds <torvalds@linux-foundation.org>
>>
>> commit 8fde12ca79aff9b5ba951fce1a2641901b8d8e64 upstream.
>>    
>> [ 4.4 backport: there's get_page_foll(), so add try_get_page()-like checks
>>                 in there, enabled by a new parameter, which is false where
>>                 upstream patch doesn't replace get_page() with try_get_page()
>>                 (the THP and hugetlb callers).
> 
> Could we have try_get_page_foll(), as in:
> https://lore.kernel.org/stable/1570581863-12090-3-git-send-email-akaher@vmware.com/
> 
> + Code will be in sync as we have try_get_page()
> + No need to add extra argument to try_get_page()
> + No need to modify the callers of try_get_page()
> 
>> 		In gup_pte_range(), we don't expect tail pages, so just check
>>                 page ref count instead of try_get_compound_head()
> 
> Technically it's fine. If you want to keep the code of stable versions in sync
> with latest versions then this could be done in following ways (without any
> modification in upstream patch for gup_pte_range()):
> 
> Apply 7aef4172c7957d7e65fc172be4c99becaef855d4 before applying
> 8fde12ca79aff9b5ba951fce1a2641901b8d8e64, as done here:
> https://lore.kernel.org/stable/1570581863-12090-4-git-send-email-akaher@vmware.com/

Yup, I have considered that, and deliberately didn't add that commit
7aef4172c795 ("mm: handle PTE-mapped tail pages in gerneric fast gup
implementaiton") as it's part of a large THP refcount rework. In 4.4 we
don't expect to GUP tail pages so I wanted to keep it that way -
minimally, the compound_head() operation is an unnecessary added cost,
although it would also work.
Ajay Kaher Dec. 6, 2019, 4:15 a.m. UTC | #3
On 03/12/19, 6:28 PM, "Vlastimil Babka" <vbabka@suse.cz> wrote:
>>>    
>>> [ 4.4 backport: there's get_page_foll(), so add try_get_page()-like checks
>>>                 in there, enabled by a new parameter, which is false where
>>>                 upstream patch doesn't replace get_page() with try_get_page()
>>>                 (the THP and hugetlb callers).
>> 
>> Could we have try_get_page_foll(), as in:
>> https://nam04.safelinks.protection.outlook.com/?url=https%3A%2F%2Flore.kernel.org%2Fstable%2F1570581863-12090-3-git-send-email-akaher%40vmware.com%2F&amp;data=02%7C01%7Cakaher%40vmware.com%7Cb6592f0fbec040aa045f08d777f06a9f%7Cb39138ca3cee4b4aa4d6cd83d9dd62f0%7C0%7C0%7C637109746821395444&amp;sdata=cYBj3SvEikPbiHsVZj3zCys8t9ISLiHKzAlsSqiZRW8%3D&amp;reserved=0
>> 
>> + Code will be in sync as we have try_get_page()
>> + No need to add extra argument to try_get_page()
>> + No need to modify the callers of try_get_page()

Any reason for not using try_get_page_foll()?

>>> 		In gup_pte_range(), we don't expect tail pages, so just check
>>>                 page ref count instead of try_get_compound_head()
>> 
>> Technically it's fine. If you want to keep the code of stable versions in sync
>> with latest versions then this could be done in following ways (without any
>> modification in upstream patch for gup_pte_range()):
>> 
>> Apply 7aef4172c7957d7e65fc172be4c99becaef855d4 before applying
>> 8fde12ca79aff9b5ba951fce1a2641901b8d8e64, as done here:
>> https://nam04.safelinks.protection.outlook.com/?url=https%3A%2F%2Flore.kernel.org%2Fstable%2F1570581863-12090-4-git-send-email-akaher%40vmware.com%2F&amp;data=02%7C01%7Cakaher%40vmware.com%7Cb6592f0fbec040aa045f08d777f06a9f%7Cb39138ca3cee4b4aa4d6cd83d9dd62f0%7C0%7C0%7C637109746821395444&amp;sdata=gTJMJ3Yx6G0ng46TQsBzCS2DowwP7YtIjluKJuqvN6o%3D&amp;reserved=0
    
> Yup, I have considered that, and deliberately didn't add that commit
> 7aef4172c795 ("mm: handle PTE-mapped tail pages in gerneric fast gup
> implementaiton") as it's part of a large THP refcount rework. In 4.4 we
> don't expect to GUP tail pages so I wanted to keep it that way -
> minimally, the compound_head() operation is a unnecessary added cost,
> although it would also work.
Vlastimil Babka Dec. 6, 2019, 2:32 p.m. UTC | #4
On 12/6/19 5:15 AM, Ajay Kaher wrote:
> 
> 
> On 03/12/19, 6:28 PM, "Vlastimil Babka" <vbabka@suse.cz> wrote:
>>>>    
>>>> [ 4.4 backport: there's get_page_foll(), so add try_get_page()-like checks
>>>>                 in there, enabled by a new parameter, which is false where
>>>>                 upstream patch doesn't replace get_page() with try_get_page()
>>>>                 (the THP and hugetlb callers).
>>>
>>> Could we have try_get_page_foll(), as in:
>>> https://nam04.safelinks.protection.outlook.com/?url=https%3A%2F%2Flore.kernel.org%2Fstable%2F1570581863-12090-3-git-send-email-akaher%40vmware.com%2F&amp;data=02%7C01%7Cakaher%40vmware.com%7Cb6592f0fbec040aa045f08d777f06a9f%7Cb39138ca3cee4b4aa4d6cd83d9dd62f0%7C0%7C0%7C637109746821395444&amp;sdata=cYBj3SvEikPbiHsVZj3zCys8t9ISLiHKzAlsSqiZRW8%3D&amp;reserved=0
>>>
>>> + Code will be in sync as we have try_get_page()
>>> + No need to add extra argument to try_get_page()
>>> + No need to modify the callers of try_get_page()
> 
> Any reason for not using try_get_page_foll().

Ah, sorry, I missed that previously. It's certainly possible to do it
that way, I just didn't care so strongly to rewrite the existing SLES
patch. It's a stable backport for a rather old LTS, not a codebase for
further development.

>>>> 		In gup_pte_range(), we don't expect tail pages, so just check
>>>>                 page ref count instead of try_get_compound_head()
>>>
>>> Technically it's fine. If you want to keep the code of stable versions in sync
>>> with latest versions then this could be done in following ways (without any
>>> modification in upstream patch for gup_pte_range()):
>>>
>>> Apply 7aef4172c7957d7e65fc172be4c99becaef855d4 before applying
>>> 8fde12ca79aff9b5ba951fce1a2641901b8d8e64, as done here:
>>> https://nam04.safelinks.protection.outlook.com/?url=https%3A%2F%2Flore.kernel.org%2Fstable%2F1570581863-12090-4-git-send-email-akaher%40vmware.com%2F&amp;data=02%7C01%7Cakaher%40vmware.com%7Cb6592f0fbec040aa045f08d777f06a9f%7Cb39138ca3cee4b4aa4d6cd83d9dd62f0%7C0%7C0%7C637109746821395444&amp;sdata=gTJMJ3Yx6G0ng46TQsBzCS2DowwP7YtIjluKJuqvN6o%3D&amp;reserved=0
>     
>> Yup, I have considered that, and deliberately didn't add that commit
>> 7aef4172c795 ("mm: handle PTE-mapped tail pages in gerneric fast gup
>> implementaiton") as it's part of a large THP refcount rework. In 4.4 we
>> don't expect to GUP tail pages so I wanted to keep it that way -
>> minimally, the compound_head() operation is a unnecessary added cost,
>> although it would also work.
>     
>
Ajay Kaher Dec. 9, 2019, 8:54 a.m. UTC | #5
On 06/12/19, 8:02 PM, "Vlastimil Babka" <vbabka@suse.cz> wrote:

> On 12/6/19 5:15 AM, Ajay Kaher wrote:
>> 
>> 
>> On 03/12/19, 6:28 PM, "Vlastimil Babka" <vbabka@suse.cz> wrote:
>>>>>    
>>>>> [ 4.4 backport: there's get_page_foll(), so add try_get_page()-like checks
>>>>>                 in there, enabled by a new parameter, which is false where
>>>>>                 upstream patch doesn't replace get_page() with try_get_page()
>>>>>                 (the THP and hugetlb callers).
>>>>
>>>> Could we have try_get_page_foll(), as in:
>>>> https://nam04.safelinks.protection.outlook.com/?url=https%3A%2F%2Flore.kernel.org%2Fstable%2F1570581863-12090-3-git-send-email-akaher%40vmware.com%2F&amp;data=02%7C01%7Cakaher%40vmware.com%7Cb65cf5622ca8401fd2ba08d77a5914e8%7Cb39138ca3cee4b4aa4d6cd83d9dd62f0%7C0%7C0%7C637112395344338606&amp;sdata=sLbw%2BQWu0%2BB0y2OpfaQS%2FxXX6Z9jNB3wPeTcPsawNJA%3D&amp;reserved=0
>>>>
>>>> + Code will be in sync as we have try_get_page()
>>>> + No need to add extra argument to try_get_page()
>>>> + No need to modify the callers of try_get_page()
>> 
>> Any reason for not using try_get_page_foll().
>    
> Ah, sorry, I missed that previously. It's certainly possible to do it
> that way, I just didn't care so strongly to rewrite the existing SLES
> patch. It's a stable backport for a rather old LTS, not a codebase for
> further development.
 
Thanks for your response.

I would appreciate if you would like to include try_get_page_foll(),
and resend this patch series again.

Greg may require an Acked-by from my side as well, so if it's fine with you,
you can add it, or I will add it once you post this patch series again.

Let me know if anything else I can do here.

>>>>> 		In gup_pte_range(), we don't expect tail pages, so just check
>>>>>                 page ref count instead of try_get_compound_head()
>>>>
>>>> Technically it's fine. If you want to keep the code of stable versions in sync
>>>> with latest versions then this could be done in following ways (without any
>>>> modification in upstream patch for gup_pte_range()):
>>>>
>>>> Apply 7aef4172c7957d7e65fc172be4c99becaef855d4 before applying
>>>> 8fde12ca79aff9b5ba951fce1a2641901b8d8e64, as done here:
>>>> https://nam04.safelinks.protection.outlook.com/?url=https%3A%2F%2Flore.kernel.org%2Fstable%2F1570581863-12090-4-git-send-email-akaher%40vmware.com%2F&amp;data=02%7C01%7Cakaher%40vmware.com%7Cb65cf5622ca8401fd2ba08d77a5914e8%7Cb39138ca3cee4b4aa4d6cd83d9dd62f0%7C0%7C0%7C637112395344348599&amp;sdata=MYA%2Fx7oVu8x1c7%2FGkEw%2B69FX7WN1O34Oq8lkMiFs1Wk%3D&amp;reserved=0
>>     
>>> Yup, I have considered that, and deliberately didn't add that commit
>>> 7aef4172c795 ("mm: handle PTE-mapped tail pages in gerneric fast gup
>>> implementaiton") as it's part of a large THP refcount rework. In 4.4 we
>>> don't expect to GUP tail pages so I wanted to keep it that way -
>>> minimally, the compound_head() operation is a unnecessary added cost,
>>> although it would also work.
>>     

Thanks for above explanation.
Vlastimil Babka Dec. 9, 2019, 9:10 a.m. UTC | #6
On 12/9/19 9:54 AM, Ajay Kaher wrote:
> 
> 
> On 06/12/19, 8:02 PM, "Vlastimil Babka" <vbabka@suse.cz> wrote:
> 
>> On 12/6/19 5:15 AM, Ajay Kaher wrote:
>>>
>>>
>>> On 03/12/19, 6:28 PM, "Vlastimil Babka" <vbabka@suse.cz> wrote:
>>>>>>    
>>>>>> [ 4.4 backport: there's get_page_foll(), so add try_get_page()-like checks
>>>>>>                 in there, enabled by a new parameter, which is false where
>>>>>>                 upstream patch doesn't replace get_page() with try_get_page()
>>>>>>                 (the THP and hugetlb callers).
>>>>>
>>>>> Could we have try_get_page_foll(), as in:
>>>>> https://nam04.safelinks.protection.outlook.com/?url=https%3A%2F%2Flore.kernel.org%2Fstable%2F1570581863-12090-3-git-send-email-akaher%40vmware.com%2F&amp;data=02%7C01%7Cakaher%40vmware.com%7Cb65cf5622ca8401fd2ba08d77a5914e8%7Cb39138ca3cee4b4aa4d6cd83d9dd62f0%7C0%7C0%7C637112395344338606&amp;sdata=sLbw%2BQWu0%2BB0y2OpfaQS%2FxXX6Z9jNB3wPeTcPsawNJA%3D&amp;reserved=0
>>>>>
>>>>> + Code will be in sync as we have try_get_page()
>>>>> + No need to add extra argument to try_get_page()
>>>>> + No need to modify the callers of try_get_page()
>>>
>>> Any reason for not using try_get_page_foll().
>>    
>> Ah, sorry, I missed that previously. It's certainly possible to do it
>> that way, I just didn't care so strongly to rewrite the existing SLES
>> patch. It's a stable backport for a rather old LTS, not a codebase for
>> further development.
>  
> Thanks for your response.
> 
> I would appreciate if you would like to include try_get_page_foll(),
> and resend this patch series again.

I won't have time for that now, but I don't mind if you do that, or
resend your version with the missing x86 and s390 gup.c parts and
preferably without 7aef4172c795.
diff mbox series

Patch

diff --git a/arch/s390/mm/gup.c b/arch/s390/mm/gup.c
index 7ad41be8b373..bdaa5f7b652c 100644
--- a/arch/s390/mm/gup.c
+++ b/arch/s390/mm/gup.c
@@ -37,7 +37,8 @@  static inline int gup_pte_range(pmd_t *pmdp, pmd_t pmd, unsigned long addr,
 			return 0;
 		VM_BUG_ON(!pfn_valid(pte_pfn(pte)));
 		page = pte_page(pte);
-		if (!page_cache_get_speculative(page))
+		if (unlikely(WARN_ON_ONCE(page_ref_count(page) < 0)
+		    || !page_cache_get_speculative(page)))
 			return 0;
 		if (unlikely(pte_val(pte) != pte_val(*ptep))) {
 			put_page(page);
@@ -76,7 +77,8 @@  static inline int gup_huge_pmd(pmd_t *pmdp, pmd_t pmd, unsigned long addr,
 		refs++;
 	} while (addr += PAGE_SIZE, addr != end);
 
-	if (!page_cache_add_speculative(head, refs)) {
+	if (unlikely(WARN_ON_ONCE(page_ref_count(head) < 0)
+	    || !page_cache_add_speculative(head, refs))) {
 		*nr -= refs;
 		return 0;
 	}
diff --git a/arch/x86/mm/gup.c b/arch/x86/mm/gup.c
index 7d2542ad346a..6612d532e42e 100644
--- a/arch/x86/mm/gup.c
+++ b/arch/x86/mm/gup.c
@@ -95,7 +95,10 @@  static noinline int gup_pte_range(pmd_t pmd, unsigned long addr,
 		}
 		VM_BUG_ON(!pfn_valid(pte_pfn(pte)));
 		page = pte_page(pte);
-		get_page(page);
+		if (unlikely(!try_get_page(page))) {
+			pte_unmap(ptep);
+			return 0;
+		}
 		SetPageReferenced(page);
 		pages[*nr] = page;
 		(*nr)++;
@@ -132,6 +135,8 @@  static noinline int gup_huge_pmd(pmd_t pmd, unsigned long addr,
 
 	refs = 0;
 	head = pmd_page(pmd);
+	if (WARN_ON_ONCE(page_ref_count(head) <= 0))
+		return 0;
 	page = head + ((addr & ~PMD_MASK) >> PAGE_SHIFT);
 	do {
 		VM_BUG_ON_PAGE(compound_head(page) != head, page);
@@ -208,6 +213,8 @@  static noinline int gup_huge_pud(pud_t pud, unsigned long addr,
 
 	refs = 0;
 	head = pud_page(pud);
+	if (WARN_ON_ONCE(page_ref_count(head) <= 0))
+		return 0;
 	page = head + ((addr & ~PUD_MASK) >> PAGE_SHIFT);
 	do {
 		VM_BUG_ON_PAGE(compound_head(page) != head, page);
diff --git a/mm/gup.c b/mm/gup.c
index 71e9d0093a35..fc8e2dca99fc 100644
--- a/mm/gup.c
+++ b/mm/gup.c
@@ -127,7 +127,10 @@  static struct page *follow_page_pte(struct vm_area_struct *vma,
 	}
 
 	if (flags & FOLL_GET)
-		get_page_foll(page);
+		if (!get_page_foll(page, true)) {
+			page = ERR_PTR(-ENOMEM);
+			goto out;
+		}
 	if (flags & FOLL_TOUCH) {
 		if ((flags & FOLL_WRITE) &&
 		    !pte_dirty(pte) && !PageDirty(page))
@@ -289,7 +292,10 @@  static int get_gate_page(struct mm_struct *mm, unsigned long address,
 			goto unmap;
 		*page = pte_page(*pte);
 	}
-	get_page(*page);
+	if (unlikely(!try_get_page(*page))) {
+		ret = -ENOMEM;
+		goto unmap;
+	}
 out:
 	ret = 0;
 unmap:
@@ -1053,6 +1059,20 @@  struct page *get_dump_page(unsigned long addr)
  */
 #ifdef CONFIG_HAVE_GENERIC_RCU_GUP
 
+/*
> + * Return the compound head page with ref appropriately incremented,
+ * or NULL if that failed.
+ */
+static inline struct page *try_get_compound_head(struct page *page, int refs)
+{
+	struct page *head = compound_head(page);
+	if (WARN_ON_ONCE(page_ref_count(head) < 0))
+		return NULL;
+	if (unlikely(!page_cache_add_speculative(head, refs)))
+		return NULL;
+	return head;
+}
+
 #ifdef __HAVE_ARCH_PTE_SPECIAL
 static int gup_pte_range(pmd_t pmd, unsigned long addr, unsigned long end,
 			 int write, struct page **pages, int *nr)
@@ -1083,6 +1103,9 @@  static int gup_pte_range(pmd_t pmd, unsigned long addr, unsigned long end,
 		VM_BUG_ON(!pfn_valid(pte_pfn(pte)));
 		page = pte_page(pte);
 
+		if (WARN_ON_ONCE(page_ref_count(page) < 0))
+			goto pte_unmap;
+
 		if (!page_cache_get_speculative(page))
 			goto pte_unmap;
 
@@ -1139,8 +1162,8 @@  static int gup_huge_pmd(pmd_t orig, pmd_t *pmdp, unsigned long addr,
 		refs++;
 	} while (addr += PAGE_SIZE, addr != end);
 
-	head = compound_head(pmd_page(orig));
-	if (!page_cache_add_speculative(head, refs)) {
+	head = try_get_compound_head(pmd_page(orig), refs);
+	if (!head) {
 		*nr -= refs;
 		return 0;
 	}
@@ -1185,8 +1208,8 @@  static int gup_huge_pud(pud_t orig, pud_t *pudp, unsigned long addr,
 		refs++;
 	} while (addr += PAGE_SIZE, addr != end);
 
-	head = compound_head(pud_page(orig));
-	if (!page_cache_add_speculative(head, refs)) {
+	head = try_get_compound_head(pud_page(orig), refs);
+	if (!head) {
 		*nr -= refs;
 		return 0;
 	}
@@ -1227,8 +1250,8 @@  static int gup_huge_pgd(pgd_t orig, pgd_t *pgdp, unsigned long addr,
 		refs++;
 	} while (addr += PAGE_SIZE, addr != end);
 
-	head = compound_head(pgd_page(orig));
-	if (!page_cache_add_speculative(head, refs)) {
+	head = try_get_compound_head(pgd_page(orig), refs);
+	if (!head) {
 		*nr -= refs;
 		return 0;
 	}
diff --git a/mm/huge_memory.c b/mm/huge_memory.c
index 465786cd6490..6087277981a6 100644
--- a/mm/huge_memory.c
+++ b/mm/huge_memory.c
@@ -1322,7 +1322,7 @@  struct page *follow_trans_huge_pmd(struct vm_area_struct *vma,
 	page += (addr & ~HPAGE_PMD_MASK) >> PAGE_SHIFT;
 	VM_BUG_ON_PAGE(!PageCompound(page), page);
 	if (flags & FOLL_GET)
-		get_page_foll(page);
+		get_page_foll(page, false);
 
 out:
 	return page;
diff --git a/mm/hugetlb.c b/mm/hugetlb.c
index fd932e7a25dd..b4a8a18fa3a5 100644
--- a/mm/hugetlb.c
+++ b/mm/hugetlb.c
@@ -3886,6 +3886,7 @@  long follow_hugetlb_page(struct mm_struct *mm, struct vm_area_struct *vma,
 	unsigned long vaddr = *position;
 	unsigned long remainder = *nr_pages;
 	struct hstate *h = hstate_vma(vma);
+	int err = -EFAULT;
 
 	while (vaddr < vma->vm_end && remainder) {
 		pte_t *pte;
@@ -3957,10 +3958,23 @@  long follow_hugetlb_page(struct mm_struct *mm, struct vm_area_struct *vma,
 
 		pfn_offset = (vaddr & ~huge_page_mask(h)) >> PAGE_SHIFT;
 		page = pte_page(huge_ptep_get(pte));
+
+		/*
+		 * Instead of doing 'try_get_page()' below in the same_page
+		 * loop, just check the count once here.
+		 */
+		if (unlikely(page_count(page) <= 0)) {
+			if (pages) {
+				spin_unlock(ptl);
+				remainder = 0;
+				err = -ENOMEM;
+				break;
+			}
+		}
 same_page:
 		if (pages) {
 			pages[i] = mem_map_offset(page, pfn_offset);
-			get_page_foll(pages[i]);
+			get_page_foll(pages[i], false);
 		}
 
 		if (vmas)
@@ -3983,7 +3997,7 @@  long follow_hugetlb_page(struct mm_struct *mm, struct vm_area_struct *vma,
 	*nr_pages = remainder;
 	*position = vaddr;
 
-	return i ? i : -EFAULT;
+	return i ? i : err;
 }
 
 unsigned long hugetlb_change_protection(struct vm_area_struct *vma,
diff --git a/mm/internal.h b/mm/internal.h
index a6639c72780a..b52041969d06 100644
--- a/mm/internal.h
+++ b/mm/internal.h
@@ -93,23 +93,29 @@  static inline void __get_page_tail_foll(struct page *page,
  * follow_page() and it must be called while holding the proper PT
  * lock while the pte (or pmd_trans_huge) is still mapping the page.
  */
-static inline void get_page_foll(struct page *page)
+static inline bool get_page_foll(struct page *page, bool check)
 {
-	if (unlikely(PageTail(page)))
+	if (unlikely(PageTail(page))) {
 		/*
 		 * This is safe only because
 		 * __split_huge_page_refcount() can't run under
 		 * get_page_foll() because we hold the proper PT lock.
 		 */
+		if (check && WARN_ON_ONCE(
+				page_ref_count(compound_head(page)) <= 0))
+			return false;
 		__get_page_tail_foll(page, true);
-	else {
+	} else {
 		/*
 		 * Getting a normal page or the head of a compound page
 		 * requires to already have an elevated page->_count.
 		 */
 		VM_BUG_ON_PAGE(page_ref_zero_or_close_to_overflow(page), page);
+		if (check && WARN_ON_ONCE(page_ref_count(page) <= 0))
+			return false;
 		atomic_inc(&page->_count);
 	}
+	return true;
 }
 
 extern unsigned long highest_memmap_pfn;