
[v2,1/6] iov_iter: Introduce fault_in_iov_iter helper

Message ID 20210718223932.2703330-2-agruenba@redhat.com (mailing list archive)
State New, archived
Series gfs2: Fix mmap + page fault deadlocks

Commit Message

Andreas Gruenbacher July 18, 2021, 10:39 p.m. UTC
Introduce a new fault_in_iov_iter helper for manually faulting in an iterator.
Other than fault_in_pages_writeable(), this function is non-destructive.

We'll use fault_in_iov_iter in gfs2 once we've determined that the iterator
passed to .read_iter or .write_iter isn't in memory.

Signed-off-by: Andreas Gruenbacher <agruenba@redhat.com>
---
 include/linux/mm.h  |  3 ++
 include/linux/uio.h |  1 +
 lib/iov_iter.c      | 42 ++++++++++++++++++++++++++++
 mm/gup.c            | 68 +++++++++++++++++++++++++++++++++++++++++++++
 4 files changed, 114 insertions(+)
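
For context, the usage pattern this series appears to be working towards (in the later gfs2 patches) looks roughly like the sketch below. This is an illustration only, not code from the series: example_read_iter() and example_do_read() are hypothetical names, and the filesystem locking is reduced to comments.

#include <linux/fs.h>
#include <linux/uaccess.h>
#include <linux/uio.h>

/* Hypothetical: performs the actual copy into @to via copy_to_iter(). */
static ssize_t example_do_read(struct kiocb *iocb, struct iov_iter *to);

static ssize_t example_read_iter(struct kiocb *iocb, struct iov_iter *to)
{
	ssize_t ret;

	for (;;) {
		/* Take filesystem locks here (e.g. a gfs2 glock). */

		pagefault_disable();	/* page faults fail instead of blocking */
		ret = example_do_read(iocb, to);
		pagefault_enable();

		/* Drop filesystem locks here. */

		if (ret != -EFAULT || !iov_iter_count(to))
			break;

		/*
		 * The user buffer wasn't resident.  Fault it in with no
		 * filesystem locks held (avoiding the mmap deadlock), then
		 * retry.  Give up if nothing could be faulted in.
		 */
		if (!fault_in_iov_iter(to))
			break;
	}
	return ret;
}

Because @to is the destination of the read, fault_in_iov_iter() faults the pages in for writing here, which is exactly where the non-destructive behaviour matters.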

Comments

Linus Torvalds July 19, 2021, 7:26 p.m. UTC | #1
On Sun, Jul 18, 2021 at 3:39 PM Andreas Gruenbacher <agruenba@redhat.com> wrote:
>
> Introduce a new fault_in_iov_iter helper for manually faulting in an iterator.
> Other than fault_in_pages_writeable(), this function is non-destructive.

You mean "Unlike" rather than "Other than" (also in the comment of the patch).

This is fairly inefficient, but as long as it's the exceptional case,
that's fine. It might be worth making that very explicit, so that
people don't try to use it normally.

                Linus
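
One way a caller could make that explicit in practice is to keep the cheap, existing pre-fault on the common path and only fall back to the GUP-based helper when the pagefault-disabled copy fails. Again a sketch only; example_write_iter() and example_do_write() are hypothetical names, not part of this series:

/* Hypothetical: performs the actual copy from @from via copy_from_iter(). */
static ssize_t example_do_write(struct kiocb *iocb, struct iov_iter *from);

static ssize_t example_write_iter(struct kiocb *iocb, struct iov_iter *from)
{
	ssize_t ret;

	/* Common case: a cheap pre-fault that takes ordinary page faults. */
	if (iov_iter_fault_in_readable(from, iov_iter_count(from)))
		return -EFAULT;

	do {
		/* Take filesystem locks here. */
		pagefault_disable();
		ret = example_do_write(iocb, from);
		pagefault_enable();
		/* Drop filesystem locks here. */
	} while (ret == -EFAULT && fault_in_iov_iter(from));

	return ret;
}

(A real implementation would cap the pre-faulted amount per iteration rather than touching the whole buffer up front.)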

Patch

diff --git a/include/linux/mm.h b/include/linux/mm.h
index 8ae31622deef..dc7d64632a60 100644
--- a/include/linux/mm.h
+++ b/include/linux/mm.h
@@ -1836,6 +1836,9 @@  int get_user_pages_fast(unsigned long start, int nr_pages,
 int pin_user_pages_fast(unsigned long start, int nr_pages,
 			unsigned int gup_flags, struct page **pages);
 
+unsigned long fault_in_user_pages(unsigned long start, unsigned long len,
+				  bool write);
+
 int account_locked_vm(struct mm_struct *mm, unsigned long pages, bool inc);
 int __account_locked_vm(struct mm_struct *mm, unsigned long pages, bool inc,
 			struct task_struct *task, bool bypass_rlim);
diff --git a/include/linux/uio.h b/include/linux/uio.h
index d3ec87706d75..74f819c41735 100644
--- a/include/linux/uio.h
+++ b/include/linux/uio.h
@@ -124,6 +124,7 @@  size_t iov_iter_copy_from_user_atomic(struct page *page,
 void iov_iter_advance(struct iov_iter *i, size_t bytes);
 void iov_iter_revert(struct iov_iter *i, size_t bytes);
 int iov_iter_fault_in_readable(struct iov_iter *i, size_t bytes);
+size_t fault_in_iov_iter(const struct iov_iter *i);
 size_t iov_iter_single_seg_count(const struct iov_iter *i);
 size_t copy_page_to_iter(struct page *page, size_t offset, size_t bytes,
 			 struct iov_iter *i);
diff --git a/lib/iov_iter.c b/lib/iov_iter.c
index c701b7a187f2..3beecf8f77de 100644
--- a/lib/iov_iter.c
+++ b/lib/iov_iter.c
@@ -487,6 +487,48 @@  int iov_iter_fault_in_readable(struct iov_iter *i, size_t bytes)
 }
 EXPORT_SYMBOL(iov_iter_fault_in_readable);
 
+/**
+ * fault_in_iov_iter - fault in iov iterator for reading / writing
+ * @i: iterator
+ *
+ * Faults in the iterator using get_user_pages, i.e., without triggering
+ * hardware page faults.
+ *
+ * This is primarily useful when we know that some or all of the pages in @i
+ * aren't in memory.  For iterators that are likely to be in memory,
+ * fault_in_pages_readable() may be more appropriate.
+ *
+ * Other than fault_in_pages_writeable(), this function is non-destructive even
+ * when faulting in pages for writing.
+ *
+ * Returns the number of bytes faulted in, or the size of @i if @i doesn't need
+ * faulting in.
+ */
+size_t fault_in_iov_iter(const struct iov_iter *i)
+{
+	size_t count = i->count;
+	const struct iovec *p;
+	size_t ret = 0, skip;
+
+	if (iter_is_iovec(i)) {
+		for (p = i->iov, skip = i->iov_offset; count; p++, skip = 0) {
+			unsigned long len = min(count, p->iov_len - skip);
+			unsigned long start, l;
+
+			if (unlikely(!len))
+				continue;
+			start = (unsigned long)p->iov_base + skip;
+			l = fault_in_user_pages(start, len, iov_iter_rw(i) != WRITE);
+			ret += l;
+			if (unlikely(l != len))
+				break;
+			count -= l;
+		}
+	}
+	return ret;
+}
+EXPORT_SYMBOL(fault_in_iov_iter);
+
 void iov_iter_init(struct iov_iter *i, unsigned int direction,
 			const struct iovec *iov, unsigned long nr_segs,
 			size_t count)
diff --git a/mm/gup.c b/mm/gup.c
index 3ded6a5f26b2..f87fb24d59e1 100644
--- a/mm/gup.c
+++ b/mm/gup.c
@@ -1565,6 +1565,74 @@  static long __get_user_pages_locked(struct mm_struct *mm, unsigned long start,
 }
 #endif /* !CONFIG_MMU */
 
+/**
+ * fault_in_user_pages - fault in an address range for reading / writing
+ * @start: start of address range
+ * @len: length of address range
+ * @write: fault in for writing
+ *
+ * Note that we don't pin or otherwise hold the pages referenced that we fault
+ * in.  There's no guarantee that they'll stay in memory for any duration of
+ * time.
+ *
+ * Returns the number of bytes faulted in from @start.
+ */
+unsigned long fault_in_user_pages(unsigned long start, unsigned long len,
+				  bool write)
+{
+	struct mm_struct *mm = current->mm;
+	struct vm_area_struct *vma = NULL;
+	unsigned long end, nstart, nend;
+	int locked = 0;
+	int gup_flags;
+
+	/*
+	 * FIXME: Make sure this function doesn't succeed for pages that cannot
+	 * be accessed; otherwise we could end up in a loop trying to fault in
+	 * and then access the pages.  (It's okay if a page gets evicted and we
+	 * need more than one retry.)
+	 */
+
+	/*
+	 * FIXME: Are these the right FOLL_* flags?
+	 */
+
+	gup_flags = FOLL_TOUCH | FOLL_POPULATE;
+	if (write)
+		gup_flags |= FOLL_WRITE;
+
+	end = PAGE_ALIGN(start + len);
+	for (nstart = start & PAGE_MASK; nstart < end; nstart = nend) {
+		unsigned long nr_pages;
+		long ret;
+
+		if (!locked) {
+			locked = 1;
+			mmap_read_lock(mm);
+			vma = find_vma(mm, nstart);
+		} else if (nstart >= vma->vm_end)
+			vma = vma->vm_next;
+		if (!vma || vma->vm_start >= end)
+			break;
+		nend = min(end, vma->vm_end);
+		if (vma->vm_flags & (VM_IO | VM_PFNMAP))
+			continue;
+		if (nstart < vma->vm_start)
+			nstart = vma->vm_start;
+		nr_pages = (nend - nstart) / PAGE_SIZE;
+		ret = __get_user_pages_locked(mm, nstart, nr_pages,
+					      NULL, NULL, &locked, gup_flags);
+		if (ret <= 0)
+			break;
+		nend = nstart + ret * PAGE_SIZE;
+	}
+	if (locked)
+		mmap_read_unlock(mm);
+	if (nstart > start)
+		return min(nstart - start, len);
+	return 0;
+}
+
 /**
  * get_dump_page() - pin user page in memory while writing it to core dump
  * @addr: user address