diff mbox series

[v7,17/24] mm/gup: track FOLL_PIN pages

Message ID 20191121071354.456618-18-jhubbard@nvidia.com (mailing list archive)
State Superseded
Headers show
Series mm/gup: track dma-pinned pages: FOLL_PIN | expand

Commit Message

John Hubbard Nov. 21, 2019, 7:13 a.m. UTC
Add tracking of pages that were pinned via FOLL_PIN.

As mentioned in the FOLL_PIN documentation, callers who effectively set
FOLL_PIN are required to ultimately free such pages via put_user_page().
The effect is similar to FOLL_GET, and may be thought of as "FOLL_GET
for DIO and/or RDMA use".

Pages that have been pinned via FOLL_PIN are identifiable via a
new function call:

   bool page_dma_pinned(struct page *page);

What to do in response to encountering such a page, is left to later
patchsets. There is discussion about this in [1], [2], and [3].

This also changes a BUG_ON(), to a WARN_ON(), in follow_page_mask().

[1] Some slow progress on get_user_pages() (Apr 2, 2019):
    https://lwn.net/Articles/784574/
[2] DMA and get_user_pages() (LPC: Dec 12, 2018):
    https://lwn.net/Articles/774411/
[3] The trouble with get_user_pages() (Apr 30, 2018):
    https://lwn.net/Articles/753027/

Suggested-by: Jan Kara <jack@suse.cz>
Suggested-by: Jérôme Glisse <jglisse@redhat.com>
Signed-off-by: John Hubbard <jhubbard@nvidia.com>
---
 Documentation/core-api/pin_user_pages.rst |   2 +-
 include/linux/mm.h                        | 113 +++++++-
 include/linux/mmzone.h                    |   2 +
 include/linux/page_ref.h                  |  10 +
 mm/gup.c                                  | 323 ++++++++++++++++------
 mm/huge_memory.c                          |  44 ++-
 mm/hugetlb.c                              |  36 ++-
 mm/vmstat.c                               |   2 +
 8 files changed, 421 insertions(+), 111 deletions(-)

Comments

Jan Kara Nov. 21, 2019, 9:39 a.m. UTC | #1
On Wed 20-11-19 23:13:47, John Hubbard wrote:
> Add tracking of pages that were pinned via FOLL_PIN.
> 
> As mentioned in the FOLL_PIN documentation, callers who effectively set
> FOLL_PIN are required to ultimately free such pages via put_user_page().
> The effect is similar to FOLL_GET, and may be thought of as "FOLL_GET
> for DIO and/or RDMA use".
> 
> Pages that have been pinned via FOLL_PIN are identifiable via a
> new function call:
> 
>    bool page_dma_pinned(struct page *page);
> 
> What to do in response to encountering such a page, is left to later
> patchsets. There is discussion about this in [1], [2], and [3].
> 
> This also changes a BUG_ON(), to a WARN_ON(), in follow_page_mask().
> 
> [1] Some slow progress on get_user_pages() (Apr 2, 2019):
>     https://lwn.net/Articles/784574/
> [2] DMA and get_user_pages() (LPC: Dec 12, 2018):
>     https://lwn.net/Articles/774411/
> [3] The trouble with get_user_pages() (Apr 30, 2018):
>     https://lwn.net/Articles/753027/
> 
> Suggested-by: Jan Kara <jack@suse.cz>
> Suggested-by: Jérôme Glisse <jglisse@redhat.com>
> Signed-off-by: John Hubbard <jhubbard@nvidia.com>

Thanks for the patch! We are mostly getting there. Some smaller comments
below.

> +/**
> + * try_pin_compound_head() - mark a compound page as being used by
> + * pin_user_pages*().
> + *
> + * This is the FOLL_PIN counterpart to try_get_compound_head().
> + *
> + * @page:	pointer to page to be marked
> + * @Return:	true for success, false for failure
> + */
> +__must_check bool try_pin_compound_head(struct page *page, int refs)
> +{
> +	page = try_get_compound_head(page, GUP_PIN_COUNTING_BIAS * refs);
> +	if (!page)
> +		return false;
> +
> +	__update_proc_vmstat(page, NR_FOLL_PIN_REQUESTED, refs);
> +	return true;
> +}
> +
> +#ifdef CONFIG_DEV_PAGEMAP_OPS
> +static bool __put_devmap_managed_user_page(struct page *page)

Probably call this __unpin_devmap_managed_user_page()? To match the later
conversion of put_user_page() to unpin_user_page()?

> +{
> +	bool is_devmap = page_is_devmap_managed(page);
> +
> +	if (is_devmap) {
> +		int count = page_ref_sub_return(page, GUP_PIN_COUNTING_BIAS);
> +
> +		__update_proc_vmstat(page, NR_FOLL_PIN_RETURNED, 1);
> +		/*
> +		 * devmap page refcounts are 1-based, rather than 0-based: if
> +		 * refcount is 1, then the page is free and the refcount is
> +		 * stable because nobody holds a reference on the page.
> +		 */
> +		if (count == 1)
> +			free_devmap_managed_page(page);
> +		else if (!count)
> +			__put_page(page);
> +	}
> +
> +	return is_devmap;
> +}
> +#else
> +static bool __put_devmap_managed_user_page(struct page *page)
> +{
> +	return false;
> +}
> +#endif /* CONFIG_DEV_PAGEMAP_OPS */
> +
> +/**
> + * put_user_page() - release a dma-pinned page
> + * @page:            pointer to page to be released
> + *
> + * Pages that were pinned via pin_user_pages*() must be released via either
> + * put_user_page(), or one of the put_user_pages*() routines. This is so that
> + * such pages can be separately tracked and uniquely handled. In particular,
> + * interactions with RDMA and filesystems need special handling.
> + */
> +void put_user_page(struct page *page)
> +{
> +	page = compound_head(page);
> +
> +	/*
> +	 * For devmap managed pages we need to catch refcount transition from
> +	 * GUP_PIN_COUNTING_BIAS to 1, when refcount reach one it means the
> +	 * page is free and we need to inform the device driver through
> +	 * callback. See include/linux/memremap.h and HMM for details.
> +	 */
> +	if (__put_devmap_managed_user_page(page))
> +		return;
> +
> +	if (page_ref_sub_and_test(page, GUP_PIN_COUNTING_BIAS))
> +		__put_page(page);
> +
> +	__update_proc_vmstat(page, NR_FOLL_PIN_RETURNED, 1);
> +}
> +EXPORT_SYMBOL(put_user_page);
> +
>  /**
>   * put_user_pages_dirty_lock() - release and optionally dirty gup-pinned pages
>   * @pages:  array of pages to be maybe marked dirty, and definitely released.
> @@ -237,10 +327,11 @@ static struct page *follow_page_pte(struct vm_area_struct *vma,
>  	}
>  
>  	page = vm_normal_page(vma, address, pte);
> -	if (!page && pte_devmap(pte) && (flags & FOLL_GET)) {
> +	if (!page && pte_devmap(pte) && (flags & (FOLL_GET | FOLL_PIN))) {
>  		/*
> -		 * Only return device mapping pages in the FOLL_GET case since
> -		 * they are only valid while holding the pgmap reference.
> +		 * Only return device mapping pages in the FOLL_GET or FOLL_PIN
> +		 * case since they are only valid while holding the pgmap
> +		 * reference.
>  		 */
>  		*pgmap = get_dev_pagemap(pte_pfn(pte), *pgmap);
>  		if (*pgmap)
> @@ -283,6 +374,11 @@ static struct page *follow_page_pte(struct vm_area_struct *vma,
>  			page = ERR_PTR(-ENOMEM);
>  			goto out;
>  		}
> +	} else if (flags & FOLL_PIN) {
> +		if (unlikely(!try_pin_page(page))) {
> +			page = ERR_PTR(-ENOMEM);
> +			goto out;
> +		}

Use grab_page() here?

>  	}
>  	if (flags & FOLL_TOUCH) {
>  		if ((flags & FOLL_WRITE) &&
> @@ -1890,9 +2000,15 @@ static int gup_pte_range(pmd_t pmd, unsigned long addr, unsigned long end,
>  		VM_BUG_ON(!pfn_valid(pte_pfn(pte)));
>  		page = pte_page(pte);
>  
> -		head = try_get_compound_head(page, 1);
> -		if (!head)
> -			goto pte_unmap;
> +		if (flags & FOLL_PIN) {
> +			head = page;
> +			if (unlikely(!try_pin_page(head)))
> +				goto pte_unmap;
> +		} else {
> +			head = try_get_compound_head(page, 1);
> +			if (!head)
> +				goto pte_unmap;
> +		}

Why don't you use grab_page() here? Also you seem to loose the head =
compound_head(page) indirection here for the FOLL_PIN case?

>  
>  		if (unlikely(pte_val(pte) != pte_val(*ptep))) {
>  			put_page(head);
> @@ -1946,12 +2062,20 @@ static int __gup_device_huge(unsigned long pfn, unsigned long addr,
>  
>  		pgmap = get_dev_pagemap(pfn, pgmap);
>  		if (unlikely(!pgmap)) {
> -			undo_dev_pagemap(nr, nr_start, pages);
> +			undo_dev_pagemap(nr, nr_start, flags, pages);
>  			return 0;
>  		}
>  		SetPageReferenced(page);
>  		pages[*nr] = page;
> -		get_page(page);
> +
> +		if (flags & FOLL_PIN) {
> +			if (unlikely(!try_pin_page(page))) {
> +				undo_dev_pagemap(nr, nr_start, flags, pages);
> +				return 0;
> +			}
> +		} else
> +			get_page(page);
> +

Use grab_page() here?

>  		(*nr)++;
>  		pfn++;
>  	} while (addr += PAGE_SIZE, addr != end);
...
> @@ -2025,12 +2149,31 @@ static int __record_subpages(struct page *page, unsigned long addr,
>  	return nr;
>  }
>  
> -static void put_compound_head(struct page *page, int refs)
> +static bool grab_compound_head(struct page *head, int refs, unsigned int flags)
>  {
> +	if (flags & FOLL_PIN) {
> +		if (unlikely(!try_pin_compound_head(head, refs)))
> +			return false;
> +	} else {
> +		head = try_get_compound_head(head, refs);
> +		if (!head)
> +			return false;
> +	}
> +
> +	return true;
> +}
> +
> +static void put_compound_head(struct page *page, int refs, unsigned int flags)
> +{
> +	struct page *head = compound_head(page);
> +
> +	if (flags & FOLL_PIN)
> +		refs *= GUP_PIN_COUNTING_BIAS;
> +
>  	/* Do a get_page() first, in case refs == page->_refcount */
> -	get_page(page);
> -	page_ref_sub(page, refs);
> -	put_page(page);
> +	get_page(head);
> +	page_ref_sub(head, refs);
> +	put_page(head);
>  }
>  
>  #ifdef CONFIG_ARCH_HAS_HUGEPD
> @@ -2064,14 +2207,13 @@ static int gup_hugepte(pte_t *ptep, unsigned long sz, unsigned long addr,
>  
>  	head = pte_page(pte);
>  	page = head + ((addr & (sz-1)) >> PAGE_SHIFT);
> -	refs = __record_subpages(page, addr, end, pages + *nr);
> +	refs = record_subpages(page, addr, end, pages + *nr);
>  
> -	head = try_get_compound_head(head, refs);
> -	if (!head)
> +	if (!grab_compound_head(head, refs, flags))

Are you sure this is correct? Historically we seem to have always had logic
like:

	head = compound_head(pte_page / pmd_page / ... (orig))

in this code. And you removed this now. Looking at the code I'm not sure
whether the compound_head() indirection is really needed or not. We seem to
have already huge page head in the page table but maybe there's some subtle
case I'm missing. So I'd be calmer if we left the head=compound_head(...)
in the code but if you really want to remove it, I'd like to see Ack from
someone actually familiar with huge pages - e.g. Kirill Shutemov...

And even if we find out that compound_head() indirection isn't really
needed, that is big enough change in the logic that it would deserve to be
done in a separate patch (if only for debugging by bisection purposes).

> diff --git a/mm/huge_memory.c b/mm/huge_memory.c
> index 13cc93785006..981a9ea0b83f 100644
> --- a/mm/huge_memory.c
> +++ b/mm/huge_memory.c
...
> @@ -5034,8 +5052,20 @@ follow_huge_pmd(struct mm_struct *mm, unsigned long address,
>  	pte = huge_ptep_get((pte_t *)pmd);
>  	if (pte_present(pte)) {
>  		page = pmd_page(*pmd) + ((address & ~PMD_MASK) >> PAGE_SHIFT);
> +
>  		if (flags & FOLL_GET)
>  			get_page(page);
> +		else if (flags & FOLL_PIN) {
> +			/*
> +			 * try_pin_page() is not actually expected to fail
> +			 * here because we hold the ptl.
> +			 */
> +			if (unlikely(!try_pin_page(page))) {
> +				WARN_ON_ONCE(1);
> +				page = NULL;
> +				goto out;
> +			}
> +		}

Use grab_page() here?

								Honza
kernel test robot Nov. 30, 2019, 6:59 p.m. UTC | #2
Hi John,

Thank you for the patch! Yet something to improve:

[auto build test ERROR on linus/master]
[also build test ERROR on v5.4]
[cannot apply to mmotm/master rdma/for-next linuxtv-media/master next-20191129]
[if your patch is applied to the wrong git tree, please drop us a note to help
improve the system. BTW, we also suggest to use '--base' option to specify the
base tree in git format-patch, please see https://stackoverflow.com/a/37406982]

url:    https://github.com/0day-ci/linux/commits/John-Hubbard/mm-gup-track-dma-pinned-pages-FOLL_PIN/20191122-092349
base:   https://git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git c74386d50fbaf4a54fd3fe560f1abc709c0cff4b
config: sh-allmodconfig (attached as .config)
compiler: sh4-linux-gcc (GCC) 7.4.0
reproduce:
        wget https://raw.githubusercontent.com/intel/lkp-tests/master/sbin/make.cross -O ~/bin/make.cross
        chmod +x ~/bin/make.cross
        # save the attached .config to linux build tree
        GCC_VERSION=7.4.0 make.cross ARCH=sh 

If you fix the issue, kindly add following tag
Reported-by: kbuild test robot <lkp@intel.com>

All errors (new ones prefixed by >>):

   mm/gup.c: In function 'pin_user_pages_remote':
>> mm/gup.c:2684:9: error: implicit declaration of function '__get_user_pages_remote'; did you mean 'get_user_pages_remote'? [-Werror=implicit-function-declaration]
     return __get_user_pages_remote(tsk, mm, start, nr_pages, gup_flags,
            ^~~~~~~~~~~~~~~~~~~~~~~
            get_user_pages_remote
   cc1: some warnings being treated as errors

vim +2684 mm/gup.c

  2660	
  2661	/**
  2662	 * pin_user_pages_remote() - pin pages of a remote process (task != current)
  2663	 *
  2664	 * Nearly the same as get_user_pages_remote(), except that FOLL_PIN is set. See
  2665	 * get_user_pages_remote() for documentation on the function arguments, because
  2666	 * the arguments here are identical.
  2667	 *
  2668	 * FOLL_PIN means that the pages must be released via unpin_user_page(). Please
  2669	 * see Documentation/vm/pin_user_pages.rst for details.
  2670	 *
  2671	 * This is intended for Case 1 (DIO) in Documentation/vm/pin_user_pages.rst. It
  2672	 * is NOT intended for Case 2 (RDMA: long-term pins).
  2673	 */
  2674	long pin_user_pages_remote(struct task_struct *tsk, struct mm_struct *mm,
  2675				   unsigned long start, unsigned long nr_pages,
  2676				   unsigned int gup_flags, struct page **pages,
  2677				   struct vm_area_struct **vmas, int *locked)
  2678	{
  2679		/* FOLL_GET and FOLL_PIN are mutually exclusive. */
  2680		if (WARN_ON_ONCE(gup_flags & FOLL_GET))
  2681			return -EINVAL;
  2682	
  2683		gup_flags |= FOLL_PIN;
> 2684		return __get_user_pages_remote(tsk, mm, start, nr_pages, gup_flags,
  2685					       pages, vmas, locked);
  2686	}
  2687	EXPORT_SYMBOL(pin_user_pages_remote);
  2688	

---
0-DAY kernel test infrastructure                 Open Source Technology Center
https://lists.01.org/hyperkitty/list/kbuild-all@lists.01.org Intel Corporation
diff mbox series

Patch

diff --git a/Documentation/core-api/pin_user_pages.rst b/Documentation/core-api/pin_user_pages.rst
index 4f26637a5005..baa288a44a77 100644
--- a/Documentation/core-api/pin_user_pages.rst
+++ b/Documentation/core-api/pin_user_pages.rst
@@ -53,7 +53,7 @@  Which flags are set by each wrapper
 For these pin_user_pages*() functions, FOLL_PIN is OR'd in with whatever gup
 flags the caller provides. The caller is required to pass in a non-null struct
 pages* array, and the function then pin pages by incrementing each by a special
-value. For now, that value is +1, just like get_user_pages*().::
+value: GUP_PIN_COUNTING_BIAS.::
 
  Function
  --------
diff --git a/include/linux/mm.h b/include/linux/mm.h
index 568cbb895f03..253ec2d15f36 100644
--- a/include/linux/mm.h
+++ b/include/linux/mm.h
@@ -1054,6 +1054,21 @@  static inline __must_check bool try_get_page(struct page *page)
 	return true;
 }
 
+__must_check bool try_pin_compound_head(struct page *page, int refs);
+
+/**
+ * try_pin_page() - mark a page as being used by pin_user_pages*().
+ *
+ * This is the FOLL_PIN counterpart to try_get_page().
+ *
+ * @page:	pointer to page to be marked
+ * @Return:	true for success, false for failure
+ */
+static inline __must_check bool try_pin_page(struct page *page)
+{
+	return try_pin_compound_head(page, 1);
+}
+
 static inline void put_page(struct page *page)
 {
 	page = compound_head(page);
@@ -1071,29 +1086,70 @@  static inline void put_page(struct page *page)
 		__put_page(page);
 }
 
-/**
- * put_user_page() - release a gup-pinned page
- * @page:            pointer to page to be released
+/*
+ * GUP_PIN_COUNTING_BIAS, and the associated functions that use it, overload
+ * the page's refcount so that two separate items are tracked: the original page
+ * reference count, and also a new count of how many pin_user_pages() calls were
+ * made against the page. ("gup-pinned" is another term for the latter).
  *
- * Pages that were pinned via pin_user_pages*() must be released via either
- * put_user_page(), or one of the put_user_pages*() routines. This is so that
- * eventually such pages can be separately tracked and uniquely handled. In
- * particular, interactions with RDMA and filesystems need special handling.
+ * With this scheme, pin_user_pages() becomes special: such pages are marked as
+ * distinct from normal pages. As such, the put_user_page() call (and its
+ * variants) must be used in order to release gup-pinned pages.
  *
- * put_user_page() and put_page() are not interchangeable, despite this early
- * implementation that makes them look the same. put_user_page() calls must
- * be perfectly matched up with pin*() calls.
+ * Choice of value:
+ *
+ * By making GUP_PIN_COUNTING_BIAS a power of two, debugging of page reference
+ * counts with respect to pin_user_pages() and put_user_page() becomes simpler,
+ * due to the fact that adding an even power of two to the page refcount has the
+ * effect of using only the upper N bits, for the code that counts up using the
+ * bias value. This means that the lower bits are left for the exclusive use of
+ * the original code that increments and decrements by one (or at least, by much
+ * smaller values than the bias value).
+ *
+ * Of course, once the lower bits overflow into the upper bits (and this is
+ * OK, because subtraction recovers the original values), then visual inspection
+ * no longer suffices to directly view the separate counts. However, for normal
+ * applications that don't have huge page reference counts, this won't be an
+ * issue.
+ *
+ * Locking: the lockless algorithm described in page_cache_get_speculative()
+ * and page_cache_gup_pin_speculative() provides safe operation for
+ * get_user_pages and page_mkclean and other calls that race to set up page
+ * table entries.
  */
-static inline void put_user_page(struct page *page)
-{
-	put_page(page);
-}
+#define GUP_PIN_COUNTING_BIAS (1UL << 10)
 
+void put_user_page(struct page *page);
 void put_user_pages_dirty_lock(struct page **pages, unsigned long npages,
 			       bool make_dirty);
-
 void put_user_pages(struct page **pages, unsigned long npages);
 
+/**
+ * page_dma_pinned() - report if a page is pinned for DMA.
+ *
+ * This function checks if a page has been pinned via a call to
+ * pin_user_pages*().
+ *
+ * The return value is partially fuzzy: false is not fuzzy, because it means
+ * "definitely not pinned for DMA", but true means "probably pinned for DMA, but
+ * possibly a false positive due to having at least GUP_PIN_COUNTING_BIAS worth
+ * of normal page references".
+ *
+ * False positives are OK, because: a) it's unlikely for a page to get that many
+ * refcounts, and b) all the callers of this routine are expected to be able to
+ * deal gracefully with a false positive.
+ *
+ * For more information, please see Documentation/vm/pin_user_pages.rst.
+ *
+ * @page:	pointer to page to be queried.
+ * @Return:	True, if it is likely that the page has been "dma-pinned".
+ *		False, if the page is definitely not dma-pinned.
+ */
+static inline bool page_dma_pinned(struct page *page)
+{
+	return (page_ref_count(compound_head(page))) >= GUP_PIN_COUNTING_BIAS;
+}
+
 #if defined(CONFIG_SPARSEMEM) && !defined(CONFIG_SPARSEMEM_VMEMMAP)
 #define SECTION_IN_PAGE_FLAGS
 #endif
@@ -2675,6 +2731,33 @@  struct page *follow_page(struct vm_area_struct *vma, unsigned long address,
  * Please see Documentation/vm/pin_user_pages.rst for more information.
  */
 
+/**
+ * grab_page() - elevate a page's refcount by a flag-dependent amount
+ *
+ * @page:	pointer to page to be grabbed
+ * @flags:	gup flags: these are the FOLL_* flag values.
+ *
+ * Either FOLL_PIN or FOLL_GET (or neither) may be set, but not both at the same
+ * time. (That's true throughout the get_user_pages*() and pin_user_pages*()
+ * APIs.) Cases:
+ *
+ *	FOLL_GET: page's refcount will be incremented by 1.
+ *	FOLL_PIN: page's refcount will be incremented by GUP_PIN_COUNTING_BIAS.
+ *
+ * @Return: true for for success, false for failure. If neither FOLL_GET nor
+ * FOLL_PIN is set, that's also considered "success", even though nothing is
+ * done: no page refcount changes are made in that case.
+ */
+static inline bool grab_page(struct page *page, unsigned int flags)
+{
+	if (flags & FOLL_GET)
+		get_page(page);
+	else if (flags & FOLL_PIN)
+		return try_pin_page(page);
+
+	return true;
+}
+
 static inline int vm_fault_to_errno(vm_fault_t vm_fault, int foll_flags)
 {
 	if (vm_fault & VM_FAULT_OOM)
diff --git a/include/linux/mmzone.h b/include/linux/mmzone.h
index bda20282746b..0485cba38d23 100644
--- a/include/linux/mmzone.h
+++ b/include/linux/mmzone.h
@@ -244,6 +244,8 @@  enum node_stat_item {
 	NR_DIRTIED,		/* page dirtyings since bootup */
 	NR_WRITTEN,		/* page writings since bootup */
 	NR_KERNEL_MISC_RECLAIMABLE,	/* reclaimable non-slab kernel pages */
+	NR_FOLL_PIN_REQUESTED,	/* via: pin_user_page(), gup flag: FOLL_PIN */
+	NR_FOLL_PIN_RETURNED,	/* pages returned via put_user_page() */
 	NR_VM_NODE_STAT_ITEMS
 };
 
diff --git a/include/linux/page_ref.h b/include/linux/page_ref.h
index 14d14beb1f7f..b9cbe553d1e7 100644
--- a/include/linux/page_ref.h
+++ b/include/linux/page_ref.h
@@ -102,6 +102,16 @@  static inline void page_ref_sub(struct page *page, int nr)
 		__page_ref_mod(page, -nr);
 }
 
+static inline int page_ref_sub_return(struct page *page, int nr)
+{
+	int ret = atomic_sub_return(nr, &page->_refcount);
+
+	if (page_ref_tracepoint_active(__tracepoint_page_ref_mod))
+		__page_ref_mod(page, -nr);
+
+	return ret;
+}
+
 static inline void page_ref_inc(struct page *page)
 {
 	atomic_inc(&page->_refcount);
diff --git a/mm/gup.c b/mm/gup.c
index f72d7a1635b4..b06434185e33 100644
--- a/mm/gup.c
+++ b/mm/gup.c
@@ -51,6 +51,96 @@  static inline struct page *try_get_compound_head(struct page *page, int refs)
 	return head;
 }
 
+#ifdef CONFIG_DEBUG_VM
+static inline void __update_proc_vmstat(struct page *page,
+					enum node_stat_item item, int count)
+{
+	mod_node_page_state(page_pgdat(page), item, count);
+}
+#else
+static inline void __update_proc_vmstat(struct page *page,
+					enum node_stat_item item, int count)
+{
+}
+#endif
+
+/**
+ * try_pin_compound_head() - mark a compound page as being used by
+ * pin_user_pages*().
+ *
+ * This is the FOLL_PIN counterpart to try_get_compound_head().
+ *
+ * @page:	pointer to page to be marked
+ * @Return:	true for success, false for failure
+ */
+__must_check bool try_pin_compound_head(struct page *page, int refs)
+{
+	page = try_get_compound_head(page, GUP_PIN_COUNTING_BIAS * refs);
+	if (!page)
+		return false;
+
+	__update_proc_vmstat(page, NR_FOLL_PIN_REQUESTED, refs);
+	return true;
+}
+
+#ifdef CONFIG_DEV_PAGEMAP_OPS
+static bool __put_devmap_managed_user_page(struct page *page)
+{
+	bool is_devmap = page_is_devmap_managed(page);
+
+	if (is_devmap) {
+		int count = page_ref_sub_return(page, GUP_PIN_COUNTING_BIAS);
+
+		__update_proc_vmstat(page, NR_FOLL_PIN_RETURNED, 1);
+		/*
+		 * devmap page refcounts are 1-based, rather than 0-based: if
+		 * refcount is 1, then the page is free and the refcount is
+		 * stable because nobody holds a reference on the page.
+		 */
+		if (count == 1)
+			free_devmap_managed_page(page);
+		else if (!count)
+			__put_page(page);
+	}
+
+	return is_devmap;
+}
+#else
+static bool __put_devmap_managed_user_page(struct page *page)
+{
+	return false;
+}
+#endif /* CONFIG_DEV_PAGEMAP_OPS */
+
+/**
+ * put_user_page() - release a dma-pinned page
+ * @page:            pointer to page to be released
+ *
+ * Pages that were pinned via pin_user_pages*() must be released via either
+ * put_user_page(), or one of the put_user_pages*() routines. This is so that
+ * such pages can be separately tracked and uniquely handled. In particular,
+ * interactions with RDMA and filesystems need special handling.
+ */
+void put_user_page(struct page *page)
+{
+	page = compound_head(page);
+
+	/*
+	 * For devmap managed pages we need to catch refcount transition from
+	 * GUP_PIN_COUNTING_BIAS to 1, when refcount reach one it means the
+	 * page is free and we need to inform the device driver through
+	 * callback. See include/linux/memremap.h and HMM for details.
+	 */
+	if (__put_devmap_managed_user_page(page))
+		return;
+
+	if (page_ref_sub_and_test(page, GUP_PIN_COUNTING_BIAS))
+		__put_page(page);
+
+	__update_proc_vmstat(page, NR_FOLL_PIN_RETURNED, 1);
+}
+EXPORT_SYMBOL(put_user_page);
+
 /**
  * put_user_pages_dirty_lock() - release and optionally dirty gup-pinned pages
  * @pages:  array of pages to be maybe marked dirty, and definitely released.
@@ -237,10 +327,11 @@  static struct page *follow_page_pte(struct vm_area_struct *vma,
 	}
 
 	page = vm_normal_page(vma, address, pte);
-	if (!page && pte_devmap(pte) && (flags & FOLL_GET)) {
+	if (!page && pte_devmap(pte) && (flags & (FOLL_GET | FOLL_PIN))) {
 		/*
-		 * Only return device mapping pages in the FOLL_GET case since
-		 * they are only valid while holding the pgmap reference.
+		 * Only return device mapping pages in the FOLL_GET or FOLL_PIN
+		 * case since they are only valid while holding the pgmap
+		 * reference.
 		 */
 		*pgmap = get_dev_pagemap(pte_pfn(pte), *pgmap);
 		if (*pgmap)
@@ -283,6 +374,11 @@  static struct page *follow_page_pte(struct vm_area_struct *vma,
 			page = ERR_PTR(-ENOMEM);
 			goto out;
 		}
+	} else if (flags & FOLL_PIN) {
+		if (unlikely(!try_pin_page(page))) {
+			page = ERR_PTR(-ENOMEM);
+			goto out;
+		}
 	}
 	if (flags & FOLL_TOUCH) {
 		if ((flags & FOLL_WRITE) &&
@@ -544,8 +640,8 @@  static struct page *follow_page_mask(struct vm_area_struct *vma,
 	/* make this handle hugepd */
 	page = follow_huge_addr(mm, address, flags & FOLL_WRITE);
 	if (!IS_ERR(page)) {
-		BUG_ON(flags & FOLL_GET);
-		return page;
+		WARN_ON_ONCE(flags & (FOLL_GET | FOLL_PIN));
+		return NULL;
 	}
 
 	pgd = pgd_offset(mm, address);
@@ -1125,6 +1221,36 @@  static __always_inline long __get_user_pages_locked(struct task_struct *tsk,
 	return pages_done;
 }
 
+static long __get_user_pages_remote(struct task_struct *tsk,
+				    struct mm_struct *mm,
+				    unsigned long start, unsigned long nr_pages,
+				    unsigned int gup_flags, struct page **pages,
+				    struct vm_area_struct **vmas, int *locked)
+{
+	/*
+	 * Parts of FOLL_LONGTERM behavior are incompatible with
+	 * FAULT_FLAG_ALLOW_RETRY because of the FS DAX check requirement on
+	 * vmas. However, this only comes up if locked is set, and there are
+	 * callers that do request FOLL_LONGTERM, but do not set locked. So,
+	 * allow what we can.
+	 */
+	if (gup_flags & FOLL_LONGTERM) {
+		if (WARN_ON_ONCE(locked))
+			return -EINVAL;
+		/*
+		 * This will check the vmas (even if our vmas arg is NULL)
+		 * and return -ENOTSUPP if DAX isn't allowed in this case:
+		 */
+		return __gup_longterm_locked(tsk, mm, start, nr_pages, pages,
+					     vmas, gup_flags | FOLL_TOUCH |
+					     FOLL_REMOTE);
+	}
+
+	return __get_user_pages_locked(tsk, mm, start, nr_pages, pages, vmas,
+				       locked,
+				       gup_flags | FOLL_TOUCH | FOLL_REMOTE);
+}
+
 /*
  * get_user_pages_remote() - pin user pages in memory
  * @tsk:	the task_struct to use for page fault accounting, or
@@ -1193,28 +1319,8 @@  long get_user_pages_remote(struct task_struct *tsk, struct mm_struct *mm,
 	if (WARN_ON_ONCE(gup_flags & FOLL_PIN))
 		return -EINVAL;
 
-	/*
-	 * Parts of FOLL_LONGTERM behavior are incompatible with
-	 * FAULT_FLAG_ALLOW_RETRY because of the FS DAX check requirement on
-	 * vmas. However, this only comes up if locked is set, and there are
-	 * callers that do request FOLL_LONGTERM, but do not set locked. So,
-	 * allow what we can.
-	 */
-	if (gup_flags & FOLL_LONGTERM) {
-		if (WARN_ON_ONCE(locked))
-			return -EINVAL;
-		/*
-		 * This will check the vmas (even if our vmas arg is NULL)
-		 * and return -ENOTSUPP if DAX isn't allowed in this case:
-		 */
-		return __gup_longterm_locked(tsk, mm, start, nr_pages, pages,
-					     vmas, gup_flags | FOLL_TOUCH |
-					     FOLL_REMOTE);
-	}
-
-	return __get_user_pages_locked(tsk, mm, start, nr_pages, pages, vmas,
-				       locked,
-				       gup_flags | FOLL_TOUCH | FOLL_REMOTE);
+	return __get_user_pages_remote(tsk, mm, start, nr_pages, gup_flags,
+				       pages, vmas, locked);
 }
 EXPORT_SYMBOL(get_user_pages_remote);
 
@@ -1842,13 +1948,17 @@  static inline pte_t gup_get_pte(pte_t *ptep)
 #endif /* CONFIG_GUP_GET_PTE_LOW_HIGH */
 
 static void __maybe_unused undo_dev_pagemap(int *nr, int nr_start,
+					    unsigned int flags,
 					    struct page **pages)
 {
 	while ((*nr) - nr_start) {
 		struct page *page = pages[--(*nr)];
 
 		ClearPageReferenced(page);
-		put_page(page);
+		if (flags & FOLL_PIN)
+			put_user_page(page);
+		else
+			put_page(page);
 	}
 }
 
@@ -1881,7 +1991,7 @@  static int gup_pte_range(pmd_t pmd, unsigned long addr, unsigned long end,
 
 			pgmap = get_dev_pagemap(pte_pfn(pte), pgmap);
 			if (unlikely(!pgmap)) {
-				undo_dev_pagemap(nr, nr_start, pages);
+				undo_dev_pagemap(nr, nr_start, flags, pages);
 				goto pte_unmap;
 			}
 		} else if (pte_special(pte))
@@ -1890,9 +2000,15 @@  static int gup_pte_range(pmd_t pmd, unsigned long addr, unsigned long end,
 		VM_BUG_ON(!pfn_valid(pte_pfn(pte)));
 		page = pte_page(pte);
 
-		head = try_get_compound_head(page, 1);
-		if (!head)
-			goto pte_unmap;
+		if (flags & FOLL_PIN) {
+			head = page;
+			if (unlikely(!try_pin_page(head)))
+				goto pte_unmap;
+		} else {
+			head = try_get_compound_head(page, 1);
+			if (!head)
+				goto pte_unmap;
+		}
 
 		if (unlikely(pte_val(pte) != pte_val(*ptep))) {
 			put_page(head);
@@ -1946,12 +2062,20 @@  static int __gup_device_huge(unsigned long pfn, unsigned long addr,
 
 		pgmap = get_dev_pagemap(pfn, pgmap);
 		if (unlikely(!pgmap)) {
-			undo_dev_pagemap(nr, nr_start, pages);
+			undo_dev_pagemap(nr, nr_start, flags, pages);
 			return 0;
 		}
 		SetPageReferenced(page);
 		pages[*nr] = page;
-		get_page(page);
+
+		if (flags & FOLL_PIN) {
+			if (unlikely(!try_pin_page(page))) {
+				undo_dev_pagemap(nr, nr_start, flags, pages);
+				return 0;
+			}
+		} else
+			get_page(page);
+
 		(*nr)++;
 		pfn++;
 	} while (addr += PAGE_SIZE, addr != end);
@@ -1973,7 +2097,7 @@  static int __gup_device_huge_pmd(pmd_t orig, pmd_t *pmdp, unsigned long addr,
 		return 0;
 
 	if (unlikely(pmd_val(orig) != pmd_val(*pmdp))) {
-		undo_dev_pagemap(nr, nr_start, pages);
+		undo_dev_pagemap(nr, nr_start, flags, pages);
 		return 0;
 	}
 	return 1;
@@ -1991,7 +2115,7 @@  static int __gup_device_huge_pud(pud_t orig, pud_t *pudp, unsigned long addr,
 		return 0;
 
 	if (unlikely(pud_val(orig) != pud_val(*pudp))) {
-		undo_dev_pagemap(nr, nr_start, pages);
+		undo_dev_pagemap(nr, nr_start, flags, pages);
 		return 0;
 	}
 	return 1;
@@ -2014,8 +2138,8 @@  static int __gup_device_huge_pud(pud_t pud, pud_t *pudp, unsigned long addr,
 }
 #endif
 
-static int __record_subpages(struct page *page, unsigned long addr,
-			     unsigned long end, struct page **pages)
+static int record_subpages(struct page *page, unsigned long addr,
+			   unsigned long end, struct page **pages)
 {
 	int nr;
 
@@ -2025,12 +2149,31 @@  static int __record_subpages(struct page *page, unsigned long addr,
 	return nr;
 }
 
-static void put_compound_head(struct page *page, int refs)
+static bool grab_compound_head(struct page *head, int refs, unsigned int flags)
 {
+	if (flags & FOLL_PIN) {
+		if (unlikely(!try_pin_compound_head(head, refs)))
+			return false;
+	} else {
+		head = try_get_compound_head(head, refs);
+		if (!head)
+			return false;
+	}
+
+	return true;
+}
+
+static void put_compound_head(struct page *page, int refs, unsigned int flags)
+{
+	struct page *head = compound_head(page);
+
+	if (flags & FOLL_PIN)
+		refs *= GUP_PIN_COUNTING_BIAS;
+
 	/* Do a get_page() first, in case refs == page->_refcount */
-	get_page(page);
-	page_ref_sub(page, refs);
-	put_page(page);
+	get_page(head);
+	page_ref_sub(head, refs);
+	put_page(head);
 }
 
 #ifdef CONFIG_ARCH_HAS_HUGEPD
@@ -2064,14 +2207,13 @@  static int gup_hugepte(pte_t *ptep, unsigned long sz, unsigned long addr,
 
 	head = pte_page(pte);
 	page = head + ((addr & (sz-1)) >> PAGE_SHIFT);
-	refs = __record_subpages(page, addr, end, pages + *nr);
+	refs = record_subpages(page, addr, end, pages + *nr);
 
-	head = try_get_compound_head(head, refs);
-	if (!head)
+	if (!grab_compound_head(head, refs, flags))
 		return 0;
 
 	if (unlikely(pte_val(pte) != pte_val(*ptep))) {
-		put_compound_head(head, refs);
+		put_compound_head(head, refs, flags);
 		return 0;
 	}
 
@@ -2124,14 +2266,14 @@  static int gup_huge_pmd(pmd_t orig, pmd_t *pmdp, unsigned long addr,
 	}
 
 	page = pmd_page(orig) + ((addr & ~PMD_MASK) >> PAGE_SHIFT);
-	refs = __record_subpages(page, addr, end, pages + *nr);
+	refs = record_subpages(page, addr, end, pages + *nr);
 
-	head = try_get_compound_head(pmd_page(orig), refs);
-	if (!head)
+	head = pmd_page(orig);
+	if (!grab_compound_head(head, refs, flags))
 		return 0;
 
 	if (unlikely(pmd_val(orig) != pmd_val(*pmdp))) {
-		put_compound_head(head, refs);
+		put_compound_head(head, refs, flags);
 		return 0;
 	}
 
@@ -2158,14 +2300,14 @@  static int gup_huge_pud(pud_t orig, pud_t *pudp, unsigned long addr,
 	}
 
 	page = pud_page(orig) + ((addr & ~PUD_MASK) >> PAGE_SHIFT);
-	refs = __record_subpages(page, addr, end, pages + *nr);
+	refs = record_subpages(page, addr, end, pages + *nr);
 
-	head = try_get_compound_head(pud_page(orig), refs);
-	if (!head)
+	head = pud_page(orig);
+	if (!grab_compound_head(head, refs, flags))
 		return 0;
 
 	if (unlikely(pud_val(orig) != pud_val(*pudp))) {
-		put_compound_head(head, refs);
+		put_compound_head(head, refs, flags);
 		return 0;
 	}
 
@@ -2187,14 +2329,14 @@  static int gup_huge_pgd(pgd_t orig, pgd_t *pgdp, unsigned long addr,
 	BUILD_BUG_ON(pgd_devmap(orig));
 
 	page = pgd_page(orig) + ((addr & ~PGDIR_MASK) >> PAGE_SHIFT);
-	refs = __record_subpages(page, addr, end, pages + *nr);
+	refs = record_subpages(page, addr, end, pages + *nr);
 
-	head = try_get_compound_head(pgd_page(orig), refs);
-	if (!head)
+	head = pgd_page(orig);
+	if (!grab_compound_head(head, refs, flags))
 		return 0;
 
 	if (unlikely(pgd_val(orig) != pgd_val(*pgdp))) {
-		put_compound_head(head, refs);
+		put_compound_head(head, refs, flags);
 		return 0;
 	}
 
@@ -2494,9 +2636,12 @@  EXPORT_SYMBOL_GPL(get_user_pages_fast);
 /**
  * pin_user_pages_fast() - pin user pages in memory without taking locks
  *
- * For now, this is a placeholder function, until various call sites are
- * converted to use the correct get_user_pages*() or pin_user_pages*() API. So,
- * this is identical to get_user_pages_fast().
+ * Nearly the same as get_user_pages_fast(), except that FOLL_PIN is set. See
+ * get_user_pages_fast() for documentation on the function arguments, because
+ * the arguments here are identical.
+ *
+ * FOLL_PIN means that the pages must be released via unpin_user_page(). Please
+ * see Documentation/vm/pin_user_pages.rst for further details.
  *
  * This is intended for Case 1 (DIO) in Documentation/vm/pin_user_pages.rst. It
  * is NOT intended for Case 2 (RDMA: long-term pins).
@@ -2504,21 +2649,24 @@  EXPORT_SYMBOL_GPL(get_user_pages_fast);
 int pin_user_pages_fast(unsigned long start, int nr_pages,
 			unsigned int gup_flags, struct page **pages)
 {
-	/*
-	 * This is a placeholder, until the pin functionality is activated.
-	 * Until then, just behave like the corresponding get_user_pages*()
-	 * routine.
-	 */
-	return get_user_pages_fast(start, nr_pages, gup_flags, pages);
+	/* FOLL_GET and FOLL_PIN are mutually exclusive. */
+	if (WARN_ON_ONCE(gup_flags & FOLL_GET))
+		return -EINVAL;
+
+	gup_flags |= FOLL_PIN;
+	return internal_get_user_pages_fast(start, nr_pages, gup_flags, pages);
 }
 EXPORT_SYMBOL_GPL(pin_user_pages_fast);
 
 /**
  * pin_user_pages_remote() - pin pages of a remote process (task != current)
  *
- * For now, this is a placeholder function, until various call sites are
- * converted to use the correct get_user_pages*() or pin_user_pages*() API. So,
- * this is identical to get_user_pages_remote().
+ * Nearly the same as get_user_pages_remote(), except that FOLL_PIN is set. See
+ * get_user_pages_remote() for documentation on the function arguments, because
+ * the arguments here are identical.
+ *
+ * FOLL_PIN means that the pages must be released via unpin_user_page(). Please
+ * see Documentation/vm/pin_user_pages.rst for details.
  *
  * This is intended for Case 1 (DIO) in Documentation/vm/pin_user_pages.rst. It
  * is NOT intended for Case 2 (RDMA: long-term pins).
@@ -2528,22 +2676,24 @@  long pin_user_pages_remote(struct task_struct *tsk, struct mm_struct *mm,
 			   unsigned int gup_flags, struct page **pages,
 			   struct vm_area_struct **vmas, int *locked)
 {
-	/*
-	 * This is a placeholder, until the pin functionality is activated.
-	 * Until then, just behave like the corresponding get_user_pages*()
-	 * routine.
-	 */
-	return get_user_pages_remote(tsk, mm, start, nr_pages, gup_flags, pages,
-				     vmas, locked);
+	/* FOLL_GET and FOLL_PIN are mutually exclusive. */
+	if (WARN_ON_ONCE(gup_flags & FOLL_GET))
+		return -EINVAL;
+
+	gup_flags |= FOLL_PIN;
+	return __get_user_pages_remote(tsk, mm, start, nr_pages, gup_flags,
+				       pages, vmas, locked);
 }
 EXPORT_SYMBOL(pin_user_pages_remote);
 
 /**
  * pin_user_pages() - pin user pages in memory for use by other devices
  *
- * For now, this is a placeholder function, until various call sites are
- * converted to use the correct get_user_pages*() or pin_user_pages*() API. So,
- * this is identical to get_user_pages().
+ * Nearly the same as get_user_pages(), except that FOLL_TOUCH is not set, and
+ * FOLL_PIN is set.
+ *
+ * FOLL_PIN means that the pages must be released via unpin_user_page(). Please
+ * see Documentation/vm/pin_user_pages.rst for details.
  *
  * This is intended for Case 1 (DIO) in Documentation/vm/pin_user_pages.rst. It
  * is NOT intended for Case 2 (RDMA: long-term pins).
@@ -2552,11 +2702,12 @@  long pin_user_pages(unsigned long start, unsigned long nr_pages,
 		    unsigned int gup_flags, struct page **pages,
 		    struct vm_area_struct **vmas)
 {
-	/*
-	 * This is a placeholder, until the pin functionality is activated.
-	 * Until then, just behave like the corresponding get_user_pages*()
-	 * routine.
-	 */
-	return get_user_pages(start, nr_pages, gup_flags, pages, vmas);
+	/* FOLL_GET and FOLL_PIN are mutually exclusive. */
+	if (WARN_ON_ONCE(gup_flags & FOLL_GET))
+		return -EINVAL;
+
+	gup_flags |= FOLL_PIN;
+	return __gup_longterm_locked(current, current->mm, start, nr_pages,
+				     pages, vmas, gup_flags);
 }
 EXPORT_SYMBOL(pin_user_pages);
diff --git a/mm/huge_memory.c b/mm/huge_memory.c
index 13cc93785006..981a9ea0b83f 100644
--- a/mm/huge_memory.c
+++ b/mm/huge_memory.c
@@ -945,6 +945,11 @@  struct page *follow_devmap_pmd(struct vm_area_struct *vma, unsigned long addr,
 	 */
 	WARN_ONCE(flags & FOLL_COW, "mm: In follow_devmap_pmd with FOLL_COW set");
 
+	/* FOLL_GET and FOLL_PIN are mutually exclusive. */
+	if (WARN_ON_ONCE((flags & (FOLL_PIN | FOLL_GET)) ==
+			 (FOLL_PIN | FOLL_GET)))
+		return NULL;
+
 	if (flags & FOLL_WRITE && !pmd_write(*pmd))
 		return NULL;
 
@@ -960,7 +965,7 @@  struct page *follow_devmap_pmd(struct vm_area_struct *vma, unsigned long addr,
 	 * device mapped pages can only be returned if the
 	 * caller will manage the page reference count.
 	 */
-	if (!(flags & FOLL_GET))
+	if (!(flags & (FOLL_GET | FOLL_PIN)))
 		return ERR_PTR(-EEXIST);
 
 	pfn += (addr & ~PMD_MASK) >> PAGE_SHIFT;
@@ -968,7 +973,14 @@  struct page *follow_devmap_pmd(struct vm_area_struct *vma, unsigned long addr,
 	if (!*pgmap)
 		return ERR_PTR(-EFAULT);
 	page = pfn_to_page(pfn);
-	get_page(page);
+
+	/*
+	 * grab_page() is not actually expected to fail here because we hold the
+	 * pmd lock, so no one can unmap the pmd and free the page that it
+	 * points to.
+	 */
+	if (!grab_page(page, flags))
+		page = ERR_PTR(-EFAULT);
 
 	return page;
 }
@@ -1088,6 +1100,11 @@  struct page *follow_devmap_pud(struct vm_area_struct *vma, unsigned long addr,
 	if (flags & FOLL_WRITE && !pud_write(*pud))
 		return NULL;
 
+	/* FOLL_GET and FOLL_PIN are mutually exclusive. */
+	if (WARN_ON_ONCE((flags & (FOLL_PIN | FOLL_GET)) ==
+			 (FOLL_PIN | FOLL_GET)))
+		return NULL;
+
 	if (pud_present(*pud) && pud_devmap(*pud))
 		/* pass */;
 	else
@@ -1099,8 +1116,10 @@  struct page *follow_devmap_pud(struct vm_area_struct *vma, unsigned long addr,
 	/*
 	 * device mapped pages can only be returned if the
 	 * caller will manage the page reference count.
+	 *
+	 * At least one of FOLL_GET | FOLL_PIN must be set, so assert that here:
 	 */
-	if (!(flags & FOLL_GET))
+	if (!(flags & (FOLL_GET | FOLL_PIN)))
 		return ERR_PTR(-EEXIST);
 
 	pfn += (addr & ~PUD_MASK) >> PAGE_SHIFT;
@@ -1108,7 +1127,14 @@  struct page *follow_devmap_pud(struct vm_area_struct *vma, unsigned long addr,
 	if (!*pgmap)
 		return ERR_PTR(-EFAULT);
 	page = pfn_to_page(pfn);
-	get_page(page);
+
+	/*
+	 * grab_page() is not actually expected to fail here because we hold the
+	 * pmd lock, so no one can unmap the pmd and free the page that it
+	 * points to.
+	 */
+	if (!grab_page(page, flags))
+		page = ERR_PTR(-EFAULT);
 
 	return page;
 }
@@ -1522,8 +1548,14 @@  struct page *follow_trans_huge_pmd(struct vm_area_struct *vma,
 skip_mlock:
 	page += (addr & ~HPAGE_PMD_MASK) >> PAGE_SHIFT;
 	VM_BUG_ON_PAGE(!PageCompound(page) && !is_zone_device_page(page), page);
-	if (flags & FOLL_GET)
-		get_page(page);
+
+	/*
+	 * grab_page() is not actually expected to fail here because we hold the
+	 * pmd lock, so no one can unmap the pmd and free the page that it
+	 * points to.
+	 */
+	if (!grab_page(page, flags))
+		page = ERR_PTR(-EFAULT);
 
 out:
 	return page;
diff --git a/mm/hugetlb.c b/mm/hugetlb.c
index b45a95363a84..eac3310d62f5 100644
--- a/mm/hugetlb.c
+++ b/mm/hugetlb.c
@@ -4462,7 +4462,19 @@  long follow_hugetlb_page(struct mm_struct *mm, struct vm_area_struct *vma,
 same_page:
 		if (pages) {
 			pages[i] = mem_map_offset(page, pfn_offset);
-			get_page(pages[i]);
+
+			/*
+			 * grab_page() is not actually expected to fail here
+			 * because we hold the pmd lock, so no one can unmap the
+			 * pmd and free the page that it points to.
+			 */
+			if (!grab_page(pages[i], flags)) {
+				spin_unlock(ptl);
+				remainder = 0;
+				err = -ENOMEM;
+				WARN_ON_ONCE(1);
+				break;
+			}
 		}
 
 		if (vmas)
@@ -5022,6 +5034,12 @@  follow_huge_pmd(struct mm_struct *mm, unsigned long address,
 	struct page *page = NULL;
 	spinlock_t *ptl;
 	pte_t pte;
+
+	/* FOLL_GET and FOLL_PIN are mutually exclusive. */
+	if (WARN_ON_ONCE((flags & (FOLL_PIN | FOLL_GET)) ==
+			 (FOLL_PIN | FOLL_GET)))
+		return NULL;
+
 retry:
 	ptl = pmd_lockptr(mm, pmd);
 	spin_lock(ptl);
@@ -5034,8 +5052,20 @@  follow_huge_pmd(struct mm_struct *mm, unsigned long address,
 	pte = huge_ptep_get((pte_t *)pmd);
 	if (pte_present(pte)) {
 		page = pmd_page(*pmd) + ((address & ~PMD_MASK) >> PAGE_SHIFT);
+
 		if (flags & FOLL_GET)
 			get_page(page);
+		else if (flags & FOLL_PIN) {
+			/*
+			 * try_pin_page() is not actually expected to fail
+			 * here because we hold the ptl.
+			 */
+			if (unlikely(!try_pin_page(page))) {
+				WARN_ON_ONCE(1);
+				page = NULL;
+				goto out;
+			}
+		}
 	} else {
 		if (is_hugetlb_entry_migration(pte)) {
 			spin_unlock(ptl);
@@ -5056,7 +5086,7 @@  struct page * __weak
 follow_huge_pud(struct mm_struct *mm, unsigned long address,
 		pud_t *pud, int flags)
 {
-	if (flags & FOLL_GET)
+	if (flags & (FOLL_GET | FOLL_PIN))
 		return NULL;
 
 	return pte_page(*(pte_t *)pud) + ((address & ~PUD_MASK) >> PAGE_SHIFT);
@@ -5065,7 +5095,7 @@  follow_huge_pud(struct mm_struct *mm, unsigned long address,
 struct page * __weak
 follow_huge_pgd(struct mm_struct *mm, unsigned long address, pgd_t *pgd, int flags)
 {
-	if (flags & FOLL_GET)
+	if (flags & (FOLL_GET | FOLL_PIN))
 		return NULL;
 
 	return pte_page(*(pte_t *)pgd) + ((address & ~PGDIR_MASK) >> PAGE_SHIFT);
diff --git a/mm/vmstat.c b/mm/vmstat.c
index a8222041bd44..fdad40ccde7b 100644
--- a/mm/vmstat.c
+++ b/mm/vmstat.c
@@ -1167,6 +1167,8 @@  const char * const vmstat_text[] = {
 	"nr_dirtied",
 	"nr_written",
 	"nr_kernel_misc_reclaimable",
+	"nr_foll_pin_requested",
+	"nr_foll_pin_returned",
 
 	/* enum writeback_stat_item counters */
 	"nr_dirty_threshold",