diff mbox

[2/2] dax: fix bdev NULL pointer dereferences

Message ID 20160201134410.GD2948@linux.intel.com
State New, archived
Headers show

Commit Message

Matthew Wilcox Feb. 1, 2016, 1:44 p.m. UTC
On Sun, Jan 31, 2016 at 01:32:47PM +1100, Matthew Wilcox wrote:
> On Fri, Jan 29, 2016 at 10:01:13PM -0800, Dan Williams wrote:
> > On Fri, Jan 29, 2016 at 9:28 PM, Matthew Wilcox <willy@linux.intel.com> wrote:
> > > If we store the PFN of the underlying page instead, we don't have this
> > > problem.  Instead, we have a different problem; of the device going
> > > away under us.  I'm trying to find the code which tears down PTEs when
> > > the device goes away, and I'm not seeing it.  What do we do about user
> > > mappings of the device?
> > 
> > I deferred the dax tear down code until next cycle as Al rightly
> > pointed out some needed re-works:
> > 
> > https://lists.01.org/pipermail/linux-nvdimm/2016-January/003995.html
> 
> Thanks; I eventually found it in my email somewhere over the Pacific.
> 
> I did probably 70% of the work needed to switch the radix tree over to
> storing PFNs instead of sectors.  It seems viable, though it's a big
> change from where we are today:

70%?!  Hah.  I'd done maybe 50%.  This isn't everything needed; I still
need to write radix_tree_replace().  But it's enough to get a flavour for
where this line of thinking takes us.  I think it ends up being cleaner
code, and possibly better performing.  I also think it points us back
in the direction of wanting an address_space operation to return a PFN
for the radix tree instead of handling buffer_heads directly in dax.c.

Ah well.  Time to sleep ...

From 0321c30eeb189ad2da8dcc25623419e2ba9c6cee Mon Sep 17 00:00:00 2001
From: Matthew Wilcox <matthew.r.wilcox@intel.com>
Date: Sun, 31 Jan 2016 13:38:21 +1100
Subject: [PATCH] Giant non-compiling mess

Note that clear_pmem needs to be updated to set the needs_wmb() flag.

Signed-off-by: Matthew Wilcox <matthew.r.wilcox@intel.com>
---
 fs/dax.c                   | 1127 +++++++++++++++++++++-----------------------
 include/linux/dax.h        |    3 +-
 include/linux/pfn_t.h      |   41 +-
 include/linux/radix-tree.h |    9 -
 include/linux/sched.h      |    1 +
 5 files changed, 565 insertions(+), 616 deletions(-)
diff mbox

Patch

diff --git a/fs/dax.c b/fs/dax.c
index e9701d6..38b92b5 100644
--- a/fs/dax.c
+++ b/fs/dax.c
@@ -25,12 +25,132 @@ 
 #include <linux/mm.h>
 #include <linux/mutex.h>
 #include <linux/pagevec.h>
+#include <linux/pfn_t.h>
 #include <linux/pmem.h>
+#include <linux/preempt.h>
 #include <linux/sched.h>
+#include <linux/sizes.h>
 #include <linux/uio.h>
 #include <linux/vmstat.h>
-#include <linux/pfn_t.h>
-#include <linux/sizes.h>
+
+/*
+ * 32-bit architectures want to override this to actually map/unmap
+ * their persistent memory.  ARM, SPARC & MIPS also want to override it
+ * to map the PFN at an address that uses the same cachelines as the
+ * userspace mapping (that's what 'index' is for)
+ */
+static void *dax_map_pfn(pfn_t pfn, unsigned long index)
+{
+	if (is_bad_pfn_t(pfn))
+		return NULL;
+	preempt_disable();
+	pagefault_disable();
+	return pfn_to_kaddr(pfn_t_to_pfn(pfn));
+}
+
+static void dax_unmap_pfn(void *addr)
+{
+	pagefault_enable();
+	preempt_enable();
+}
+
+/*
+ * DAX uses the 'exceptional' entries to store PFNs in the radix tree.
+ * Bit 0 is clear (the radix tree uses this for its own purposes).  Bit
+ * 1 is set (to indicate an exceptional entry).  Bits 2 & 3 are PFN_DEV
+ * and PFN_MAP.  The top two bits denote the size of the entry (PTE, PMD,
+ * PUD, one reserved).  That leaves us 26 bits on 32-bit systems and 58
+ * bits on 64-bit systems, able to address 256GB and 1024EB respectively.
+ */
+#define RADIX_DAX_SIZE_MASK	(0x3UL << (BITS_PER_LONG - 2))
+#define RADIX_TREE_MASK		(RADIX_TREE_INDIRECT_PTR | \
+				 RADIX_TREE_EXCEPTIONAL_ENTRY)
+#define RADIX_DAX_PFN_MASK	(~(RADIX_DAX_SIZE_MASK | RADIX_TREE_MASK))
+#define RADIX_DAX_SHIFT		4
+#define RADIX_DAX_PTE		(0x0UL << (BITS_PER_LONG - 2))
+#define RADIX_DAX_PMD		(0x1UL << (BITS_PER_LONG - 2))
+#define RADIX_DAX_PUD		(0x2UL << (BITS_PER_LONG - 2))
+#define RADIX_DAX_SIZE(entry)	((unsigned long)entry & RADIX_DAX_SIZE_MASK)
+#define RADIX_DAX_ENTRY(pfn, size) \
+	((void *)((pfn_t_to_pfn(pfn) << RADIX_DAX_SHIFT) | size))
+
+/* The 'colour' (ie low bits) within a PMD/PUD of a page offset. */
+#define PG_PMD_COLOUR	((PMD_SIZE >> PAGE_CACHE_SHIFT) - 1)
+#define PG_PUD_COLOUR	((PUD_SIZE >> PAGE_CACHE_SHIFT) - 1)
+
+static pfn_t radix_to_pfn_t(void *entry, pgoff_t index)
+{
+	pfn_t pfn = { .val = (unsigned long)entry & RADIX_DAX_PFN_MASK };
+	unsigned offset = 0;
+
+	if (RADIX_DAX_SIZE(entry) == RADIX_DAX_PMD)
+		offset = index & PG_PMD_COLOUR;
+	else if (RADIX_DAX_SIZE(entry) == RADIX_DAX_PUD)
+		offset = index & PG_PUD_COLOUR;
+
+	return pfn_t_add(pfn, offset);
+}
+
+static void *pfn_to_radix(pfn_t pfn, unsigned long size)
+{
+	unsigned long value = pfn.val;
+	BUG_ON(value & RADIX_DAX_PFN_MASK);
+	return (void *)(value | size);
+}
+
+static unsigned size_to_order(unsigned long size)
+{
+	switch (size) {
+	case RADIX_DAX_PTE: return 0;
+	case RADIX_DAX_PMD: return PMD_SHIFT - PAGE_SHIFT;
+	case RADIX_DAX_PUD: return PUD_SHIFT - PAGE_SHIFT;
+	}
+	BUG();
+}
+
+static unsigned size_to_bytes(unsigned long size)
+{
+	switch (size) {
+	case RADIX_DAX_PTE: return PAGE_CACHE_SIZE;
+	case RADIX_DAX_PMD: return PMD_SIZE;
+	case RADIX_DAX_PUD: return PUD_SIZE;
+	}
+	BUG();
+}
+
+static int dax_add_radix_entry(struct address_space *mapping, pgoff_t index,
+				pfn_t pfn, unsigned long size, bool dirty)
+{
+	struct radix_tree_root *page_tree = &mapping->page_tree;
+	int count = 0;
+	void *entry;
+	unsigned order = size_to_order(size);
+
+	if (dirty)
+		__mark_inode_dirty(mapping->host, I_DIRTY_PAGES);
+
+	spin_lock_irq(&mapping->tree_lock);
+	entry = radix_tree_lookup(page_tree, index);
+	if (!radix_tree_exceptional_entry(entry)) {
+		count = -EEXIST;
+		goto unlock;
+	} else if (entry) {
+		if (size <= RADIX_DAX_SIZE(entry))
+			goto dirty;
+	}
+	count = radix_tree_replace(page_tree, index, order,
+					pfn_to_radix(pfn, size));
+	if (count < 0)
+		goto unlock;
+
+	mapping->nrexceptional -= (count - 1);
+ dirty:
+	if (dirty)
+		radix_tree_tag_set(page_tree, index, PAGECACHE_TAG_DIRTY);
+ unlock:
+	spin_unlock_irq(&mapping->tree_lock);
+	return count;
+}
 
 static long dax_map_atomic(struct block_device *bdev, struct blk_dax_ctl *dax)
 {
@@ -58,17 +178,235 @@  static void dax_unmap_atomic(struct block_device *bdev,
 	blk_queue_exit(bdev->bd_queue);
 }
 
+static sector_t to_sector(const struct buffer_head *bh,
+		const struct inode *inode)
+{
+	sector_t sector = bh->b_blocknr << (inode->i_blkbits - 9);
+
+	return sector;
+}
+
+static bool buffer_written(struct buffer_head *bh)
+{
+	return buffer_mapped(bh) && !buffer_unwritten(bh);
+}
+
+static int dax_replace_hole(struct address_space *mapping, pgoff_t index,
+				unsigned long size, pfn_t pfn)
+{
+	unsigned order = size_to_order(size);
+	int i, error;
+
+	for (i = 0; i < (1 << order); i++) {
+		struct page *page;
+ repeat:
+		page = find_get_entry(mapping, index + i);
+		if (!page || radix_tree_exceptional_entry(page))
+			continue;
+
+		lock_page(page);
+		if (unlikely(page->mapping != mapping)) {
+			unlock_page(page);
+			page_cache_release(page);
+			goto repeat;
+		}
+
+		delete_from_page_cache(page);
+		unlock_page(page);
+		page_cache_release(page);
+	}
+
+	/*
+	 * Somebody else could look in the radix tree and find nothing.
+	 * It's harmless though; they'll find the correct pfn by calling
+	 * the filesystem.
+	 */
+	error = dax_add_radix_entry(mapping, index, pfn, size, true);
+
+	unmap_mapping_range(mapping, index << PAGE_CACHE_SHIFT,
+				PAGE_CACHE_SIZE, 0);
+
+	return error;
+}
+
+static int dax_add_pfn_sized(struct address_space *mapping, pgoff_t index,
+				size_t size, bool write, pfn_t pfn,
+				unsigned long radix_size, unsigned entry_size)
+{
+	int error;
+	bool report = true;
+
+	while (size >= entry_size) {
+		error = dax_add_radix_entry(mapping, index, pfn,
+						radix_size, write);
+		if (error == -EEXIST)
+			error = dax_replace_hole(mapping, index, radix_size,
+						pfn);
+		if (error)
+			break;
+		report = false;
+
+		size -= entry_size;
+		pfn = pfn_t_add(pfn, entry_size / PAGE_CACHE_SIZE);
+		index += entry_size / PAGE_CACHE_SIZE;
+	}
+
+	return report ? error : 0;
+}
+
+static int dax_add_pfn_entries(struct address_space *mapping, pgoff_t index,
+				size_t size, bool write, pfn_t pfn)
+{
+	int error = 0;
+	int called = 0;
+	size_t max;
+
+	max = (PG_PMD_COLOUR + 1 - index) << PAGE_CACHE_SHIFT;
+	if (index & PG_PMD_COLOUR) {
+		error = dax_add_pfn_sized(mapping, index, min(size, max),
+				write, pfn, RADIX_DAX_PTE, PAGE_CACHE_SIZE);
+		called++;
+	}
+	size -= min(size, max);
+	if (error || !size)
+		goto out;
+	index += max >> PAGE_CACHE_SHIFT; 
+
+	max = (PG_PUD_COLOUR + 1 - index) << PAGE_CACHE_SHIFT;
+	if (index & PG_PUD_COLOUR) {
+		error = dax_add_pfn_sized(mapping, index, min(size, max),
+				write, pfn, RADIX_DAX_PMD, PMD_SIZE);
+		called++;
+	}
+	size -= min(size, max);
+	if (error || !size)
+		goto out;
+	index += max >> PMD_SHIFT; 
+
+	error = dax_add_pfn_sized(mapping, index, size,
+				write, pfn, RADIX_DAX_PUD, PUD_SIZE);
+	called++;
+	index += size >> PUD_SHIFT;
+	size = size & ~PMD_MASK;
+	if (error || !size)
+		goto out;
+
+	error = dax_add_pfn_sized(mapping, index, size,
+				write, pfn, RADIX_DAX_PMD, PMD_SIZE);
+	index += size >> PMD_SHIFT;
+	size = size & ~PAGE_CACHE_MASK;
+	if (error || !size)
+		return 0;
+
+	error = dax_add_pfn_sized(mapping, index, size,
+				write, pfn, RADIX_DAX_PTE, PAGE_CACHE_SIZE);
+ out:
+	if (called > 1)
+		error = 0;
+	return error;
+}
+
+/*
+ * Populate the page cache with as many pfns as the filesystem is willing
+ * to tell us about from a single call to get_block, starting at @index and
+ * continuing up to @max bytes.
+ */
+static int dax_create_pfns(struct address_space *mapping, pgoff_t index,
+				unsigned max, bool write, pfn_t *pfn,
+				get_block_t get_block, struct buffer_head *bh)
+{
+	struct inode *inode = mapping->host;
+	unsigned blkbits = inode->i_blkbits;
+	sector_t block = index << (PAGE_CACHE_SHIFT - blkbits);
+	struct blk_dax_ctl dax;
+	int error, result = 0;
+
+	bh->b_size = max;
+	bh->b_state = 0;
+	error = get_block(inode, block, bh, write);
+	if (error)
+		goto error;
+
+	if (!buffer_written(bh))
+		goto hole;
+
+	dax.sector = to_sector(bh, inode);
+	dax.size = bh->b_size;
+	error = dax_map_atomic(bh->b_bdev, &dax);
+	if (error < 0)
+		goto error;
+
+	/*
+	 * We may be about to write data to it, but now it's allocated,
+	 * and another thread will be able to find it in the page cache,
+	 * so we have to zero it otherwise there's a write vs fault race
+	 * that could expose stale data to an application.
+	 */
+	if (buffer_unwritten(bh) || buffer_new(bh)) {
+		clear_pmem(dax.addr, bh->b_size);
+		result = 1;
+	}
+
+	dax_unmap_atomic(bh->b_bdev, &dax);
+
+	error = dax_add_pfn_entries(mapping, index, bh->b_size,
+					write, dax.pfn);
+
+	/*
+	 * Even if we had an error adding the PFN to the radix tree,
+	 * the PFN is still good, so return it.
+	 */
+	*pfn = dax.pfn;
+	return error ? error : result;
+
+ hole:
+ error:
+	*pfn = bad_pfn_t;
+	return error;
+}
+
+/*
+ * Returns either a negative errno, 0 if no allocation had to be performed,
+ * or 1 if the filesystem allocated a block.
+ */
+static int dax_get_pfn(struct address_space *mapping, pgoff_t index,
+				size_t len, bool write, pfn_t *pfn,
+				get_block_t get_block, struct buffer_head *bh)
+{
+	void *entry;
+
+	rcu_read_lock();
+	entry = radix_tree_lookup(&mapping->page_tree, index);
+	rcu_read_unlock();
+
+	if (radix_tree_exceptional_entry(entry)) {
+		*pfn = radix_to_pfn_t(entry, index);
+		return 0;
+	}
+
+	if (entry) {
+		if (write)
+			return dax_create_pfns(mapping, index, len, true, pfn,
+						get_block, bh);
+	} else {
+		return dax_create_pfns(mapping, index, len, write, pfn,
+						get_block, bh);
+	}
+
+	*pfn = bad_pfn_t;
+	return 0;
+}
+
 /*
  * dax_clear_blocks() is called from within transaction context from XFS,
  * and hence this means the stack from this point must follow GFP_NOFS
  * semantics for all operations.
  */
-int dax_clear_blocks(struct inode *inode, sector_t block, long _size)
+int dax_clear_blocks(struct block_device *bdev, sector_t sector, long size)
 {
-	struct block_device *bdev = inode->i_sb->s_bdev;
 	struct blk_dax_ctl dax = {
-		.sector = block << (inode->i_blkbits - 9),
-		.size = _size,
+		.sector = sector,
+		.size = size,
 	};
 
 	might_sleep();
@@ -91,133 +429,52 @@  int dax_clear_blocks(struct inode *inode, sector_t block, long _size)
 }
 EXPORT_SYMBOL_GPL(dax_clear_blocks);
 
-/* the clear_pmem() calls are ordered by a wmb_pmem() in the caller */
-static void dax_new_buf(void __pmem *addr, unsigned size, unsigned first,
-		loff_t pos, loff_t end)
-{
-	loff_t final = end - pos + first; /* The final byte of the buffer */
-
-	if (first > 0)
-		clear_pmem(addr, first);
-	if (final < size)
-		clear_pmem(addr + final, size - final);
-}
-
-static bool buffer_written(struct buffer_head *bh)
-{
-	return buffer_mapped(bh) && !buffer_unwritten(bh);
-}
-
-/*
- * When ext4 encounters a hole, it returns without modifying the buffer_head
- * which means that we can't trust b_size.  To cope with this, we set b_state
- * to 0 before calling get_block and, if any bit is set, we know we can trust
- * b_size.  Unfortunate, really, since ext4 knows precisely how long a hole is
- * and would save us time calling get_block repeatedly.
- */
-static bool buffer_size_valid(struct buffer_head *bh)
-{
-	return bh->b_state != 0;
-}
-
-
-static sector_t to_sector(const struct buffer_head *bh,
-		const struct inode *inode)
-{
-	sector_t sector = bh->b_blocknr << (inode->i_blkbits - 9);
-
-	return sector;
-}
-
 static ssize_t dax_io(struct inode *inode, struct iov_iter *iter,
-		      loff_t start, loff_t end, get_block_t get_block,
-		      struct buffer_head *bh)
+				loff_t start, loff_t end,
+				get_block_t get_block, struct buffer_head *bh)
 {
-	loff_t pos = start, max = start, bh_max = start;
-	bool hole = false, need_wmb = false;
-	struct block_device *bdev = NULL;
-	int rw = iov_iter_rw(iter), rc;
-	long map_len = 0;
-	struct blk_dax_ctl dax = {
-		.addr = (void __pmem *) ERR_PTR(-EIO),
-	};
+	loff_t pos = start;
+	int error = 0;
+	const int rw = iov_iter_rw(iter);
 
 	if (rw == READ)
 		end = min(end, i_size_read(inode));
 
-	while (pos < end) {
-		size_t len;
-		if (pos == max) {
-			unsigned blkbits = inode->i_blkbits;
-			long page = pos >> PAGE_SHIFT;
-			sector_t block = page << (PAGE_SHIFT - blkbits);
-			unsigned first = pos - (block << blkbits);
-			long size;
-
-			if (pos == bh_max) {
-				bh->b_size = PAGE_ALIGN(end - pos);
-				bh->b_state = 0;
-				rc = get_block(inode, block, bh, rw == WRITE);
-				if (rc)
-					break;
-				if (!buffer_size_valid(bh))
-					bh->b_size = 1 << blkbits;
-				bh_max = pos - first + bh->b_size;
-				bdev = bh->b_bdev;
-			} else {
-				unsigned done = bh->b_size -
-						(bh_max - (pos - first));
-				bh->b_blocknr += done >> blkbits;
-				bh->b_size -= done;
-			}
+	while (!error && pos < end) {
+		pgoff_t pgoff = pos >> PAGE_CACHE_SHIFT;
+		unsigned off = pos & ~PAGE_CACHE_MASK;
+		size_t len = end - pos;
+		pfn_t pfn;
+		void __pmem *addr;
 
-			hole = rw == READ && !buffer_written(bh);
-			if (hole) {
-				size = bh->b_size - first;
-			} else {
-				dax_unmap_atomic(bdev, &dax);
-				dax.sector = to_sector(bh, inode);
-				dax.size = bh->b_size;
-				map_len = dax_map_atomic(bdev, &dax);
-				if (map_len < 0) {
-					rc = map_len;
-					break;
-				}
-				if (buffer_unwritten(bh) || buffer_new(bh)) {
-					dax_new_buf(dax.addr, map_len, first,
-							pos, end);
-					need_wmb = true;
-				}
-				dax.addr += first;
-				size = map_len - first;
-			}
-			max = min(pos + size, end);
-		}
+		error = dax_get_pfn(inode->i_mapping, pgoff, len, rw == WRITE,
+						&pfn, get_block, bh);
+		if (error < 0)
+			break;
+		addr = dax_map_pfn(pfn, pgoff) + off;
 
-		if (iov_iter_rw(iter) == WRITE) {
-			len = copy_from_iter_pmem(dax.addr, max - pos, iter);
-			need_wmb = true;
-		} else if (!hole)
-			len = copy_to_iter((void __force *) dax.addr, max - pos,
-					iter);
+		if (len > PAGE_CACHE_SIZE)
+			len = PAGE_CACHE_SIZE;
+
+		if (rw == WRITE) {
+			len = copy_from_iter_pmem(addr, len, iter);
+			current->needs_wmb = true;
+		} else if (addr)
+			len = copy_to_iter((void __force *) addr, len, iter);
 		else
-			len = iov_iter_zero(max - pos, iter);
+			len = iov_iter_zero(len, iter);
+		dax_unmap_pfn(addr - off);
 
-		if (!len) {
-			rc = -EFAULT;
-			break;
-		}
+		if (!len)
+			error = -EFAULT;
 
 		pos += len;
-		if (!IS_ERR(dax.addr))
-			dax.addr += len;
 	}
 
-	if (need_wmb)
+	if (current->needs_wmb)
 		wmb_pmem();
-	dax_unmap_atomic(bdev, &dax);
 
-	return (pos == start) ? rc : pos - start;
+	return (pos == start) ? error : pos - start;
 }
 
 /**
@@ -238,15 +495,14 @@  static ssize_t dax_io(struct inode *inode, struct iov_iter *iter,
  * is in progress.
  */
 ssize_t dax_do_io(struct kiocb *iocb, struct inode *inode,
-		  struct iov_iter *iter, loff_t pos, get_block_t get_block,
-		  dio_iodone_t end_io, int flags)
+			struct iov_iter *iter, loff_t pos,
+			get_block_t get_block, dio_iodone_t end_io, int flags)
 {
 	struct buffer_head bh;
 	ssize_t retval = -EINVAL;
 	loff_t end = pos + iov_iter_count(iter);
 
 	memset(&bh, 0, sizeof(bh));
-	bh.b_bdev = inode->i_sb->s_bdev;
 
 	if ((flags & DIO_LOCKING) && iov_iter_rw(iter) == READ) {
 		struct address_space *mapping = inode->i_mapping;
@@ -277,124 +533,26 @@  ssize_t dax_do_io(struct kiocb *iocb, struct inode *inode,
 }
 EXPORT_SYMBOL_GPL(dax_do_io);
 
-/*
- * The user has performed a load from a hole in the file.  Allocating
- * a new page in the file would cause excessive storage usage for
- * workloads with sparse files.  We allocate a page cache page instead.
- * We'll kick it out of the page cache if it's ever written to,
- * otherwise it will simply fall out of the page cache under memory
- * pressure without ever having been dirtied.
- */
-static int dax_load_hole(struct address_space *mapping, struct page *page,
-							struct vm_fault *vmf)
+static int copy_user_pfn(struct vm_fault *vmf, pfn_t pfn)
 {
-	if (!page)
-		page = find_or_create_page(mapping, vmf->pgoff,
-						vmf->gfp_mask | __GFP_ZERO);
-	if (!page)
-		return VM_FAULT_OOM;
-	vmf->page = page;
-	return VM_FAULT_LOCKED;
-}
+	void *vto, *vfrom;
 
-static int copy_user_bh(struct page *to, struct inode *inode,
-		struct buffer_head *bh, unsigned long vaddr)
-{
-	struct blk_dax_ctl dax = {
-		.sector = to_sector(bh, inode),
-		.size = bh->b_size,
-	};
-	struct block_device *bdev = bh->b_bdev;
-	void *vto;
-
-	if (dax_map_atomic(bdev, &dax) < 0)
-		return PTR_ERR(dax.addr);
-	vto = kmap_atomic(to);
-	copy_user_page(vto, (void __force *)dax.addr, vaddr, to);
+	vfrom = dax_map_pfn(pfn, vmf->pgoff);
+	vto = kmap_atomic(vmf->cow_page);
+	copy_user_page(vto, vfrom, (unsigned long)vmf->virtual_address,
+			vmf->cow_page);
 	kunmap_atomic(vto);
-	dax_unmap_atomic(bdev, &dax);
+	dax_unmap_pfn(vfrom);
 	return 0;
 }
 
-#define NO_SECTOR -1
-#define DAX_PMD_INDEX(page_index) (page_index & (PMD_MASK >> PAGE_CACHE_SHIFT))
-
-static int dax_radix_entry(struct address_space *mapping, pgoff_t index,
-		sector_t sector, bool pmd_entry, bool dirty)
-{
-	struct radix_tree_root *page_tree = &mapping->page_tree;
-	pgoff_t pmd_index = DAX_PMD_INDEX(index);
-	int type, error = 0;
-	void *entry;
-
-	WARN_ON_ONCE(pmd_entry && !dirty);
-	__mark_inode_dirty(mapping->host, I_DIRTY_PAGES);
-
-	spin_lock_irq(&mapping->tree_lock);
-
-	entry = radix_tree_lookup(page_tree, pmd_index);
-	if (entry && RADIX_DAX_TYPE(entry) == RADIX_DAX_PMD) {
-		index = pmd_index;
-		goto dirty;
-	}
-
-	entry = radix_tree_lookup(page_tree, index);
-	if (entry) {
-		type = RADIX_DAX_TYPE(entry);
-		if (WARN_ON_ONCE(type != RADIX_DAX_PTE &&
-					type != RADIX_DAX_PMD)) {
-			error = -EIO;
-			goto unlock;
-		}
-
-		if (!pmd_entry || type == RADIX_DAX_PMD)
-			goto dirty;
-
-		/*
-		 * We only insert dirty PMD entries into the radix tree.  This
-		 * means we don't need to worry about removing a dirty PTE
-		 * entry and inserting a clean PMD entry, thus reducing the
-		 * range we would flush with a follow-up fsync/msync call.
-		 */
-		radix_tree_delete(&mapping->page_tree, index);
-		mapping->nrexceptional--;
-	}
-
-	if (sector == NO_SECTOR) {
-		/*
-		 * This can happen during correct operation if our pfn_mkwrite
-		 * fault raced against a hole punch operation.  If this
-		 * happens the pte that was hole punched will have been
-		 * unmapped and the radix tree entry will have been removed by
-		 * the time we are called, but the call will still happen.  We
-		 * will return all the way up to wp_pfn_shared(), where the
-		 * pte_same() check will fail, eventually causing page fault
-		 * to be retried by the CPU.
-		 */
-		goto unlock;
-	}
-
-	error = radix_tree_insert(page_tree, index,
-			RADIX_DAX_ENTRY(sector, pmd_entry));
-	if (error)
-		goto unlock;
-
-	mapping->nrexceptional++;
- dirty:
-	if (dirty)
-		radix_tree_tag_set(page_tree, index, PAGECACHE_TAG_DIRTY);
- unlock:
-	spin_unlock_irq(&mapping->tree_lock);
-	return error;
-}
-
 static int dax_writeback_one(struct block_device *bdev,
 		struct address_space *mapping, pgoff_t index, void *entry)
 {
 	struct radix_tree_root *page_tree = &mapping->page_tree;
-	int type = RADIX_DAX_TYPE(entry);
+	unsigned size = RADIX_DAX_SIZE(entry);
 	struct radix_tree_node *node;
-	struct blk_dax_ctl dax;
+	void __pmem *addr;
 	void **slot;
 	int ret = 0;
 
@@ -412,38 +570,14 @@  static int dax_writeback_one(struct block_device *bdev,
 	/* another fsync thread may have already written back this entry */
 	if (!radix_tree_tag_get(page_tree, index, PAGECACHE_TAG_TOWRITE))
 		goto unlock;
-
-	if (WARN_ON_ONCE(type != RADIX_DAX_PTE && type != RADIX_DAX_PMD)) {
-		ret = -EIO;
-		goto unlock;
-	}
-
-	dax.sector = RADIX_DAX_SECTOR(entry);
-	dax.size = (type == RADIX_DAX_PMD ? PMD_SIZE : PAGE_SIZE);
 	spin_unlock_irq(&mapping->tree_lock);
 
-	/*
-	 * We cannot hold tree_lock while calling dax_map_atomic() because it
-	 * eventually calls cond_resched().
-	 */
-	ret = dax_map_atomic(bdev, &dax);
-	if (ret < 0)
-		return ret;
-
-	if (WARN_ON_ONCE(ret < dax.size)) {
-		ret = -EIO;
-		goto unmap;
-	}
-
-	wb_cache_pmem(dax.addr, dax.size);
+	addr = dax_map_pfn(radix_to_pfn_t(entry, index), index);
+	wb_cache_pmem(addr, size_to_bytes(size));
+	dax_unmap_pfn(addr);
 
 	spin_lock_irq(&mapping->tree_lock);
 	radix_tree_tag_clear(page_tree, index, PAGECACHE_TAG_TOWRITE);
-	spin_unlock_irq(&mapping->tree_lock);
- unmap:
-	dax_unmap_atomic(bdev, &dax);
-	return ret;
-
  unlock:
 	spin_unlock_irq(&mapping->tree_lock);
 	return ret;
@@ -459,27 +593,17 @@  int dax_writeback_mapping_range(struct address_space *mapping, loff_t start,
 {
 	struct inode *inode = mapping->host;
 	struct block_device *bdev = inode->i_sb->s_bdev;
-	pgoff_t start_index, end_index, pmd_index;
+	pgoff_t start_index, end_index;
 	pgoff_t indices[PAGEVEC_SIZE];
 	struct pagevec pvec;
 	bool done = false;
 	int i, ret = 0;
-	void *entry;
 
 	if (WARN_ON_ONCE(inode->i_blkbits != PAGE_SHIFT))
 		return -EIO;
 
 	start_index = start >> PAGE_CACHE_SHIFT;
 	end_index = end >> PAGE_CACHE_SHIFT;
-	pmd_index = DAX_PMD_INDEX(start_index);
-
-	rcu_read_lock();
-	entry = radix_tree_lookup(&mapping->page_tree, pmd_index);
-	rcu_read_unlock();
-
-	/* see if the start of our range is covered by a PMD entry */
-	if (entry && RADIX_DAX_TYPE(entry) == RADIX_DAX_PMD)
-		start_index = pmd_index;
 
 	tag_pages_for_writeback(mapping, start_index, end_index);
 
@@ -509,107 +633,80 @@  int dax_writeback_mapping_range(struct address_space *mapping, loff_t start,
 }
 EXPORT_SYMBOL_GPL(dax_writeback_mapping_range);
 
-static int dax_insert_mapping(struct inode *inode, struct buffer_head *bh,
-			struct vm_area_struct *vma, struct vm_fault *vmf)
-{
-	unsigned long vaddr = (unsigned long)vmf->virtual_address;
-	struct address_space *mapping = inode->i_mapping;
-	struct block_device *bdev = bh->b_bdev;
-	struct blk_dax_ctl dax = {
-		.sector = to_sector(bh, inode),
-		.size = bh->b_size,
-	};
-	int error;
-
-	if (dax_map_atomic(bdev, &dax) < 0) {
-		error = PTR_ERR(dax.addr);
-		goto out;
-	}
-
-	if (buffer_unwritten(bh) || buffer_new(bh)) {
-		clear_pmem(dax.addr, PAGE_SIZE);
-		wmb_pmem();
-	}
-	dax_unmap_atomic(bdev, &dax);
-
-	error = dax_radix_entry(mapping, vmf->pgoff, dax.sector, false,
-			vmf->flags & FAULT_FLAG_WRITE);
-	if (error)
-		goto out;
-
-	error = vm_insert_mixed(vma, vaddr, dax.pfn);
-
- out:
-	return error;
-}
-
 static int dax_pte_fault(struct vm_area_struct *vma, struct vm_fault *vmf,
 			get_block_t get_block, dax_iodone_t complete_unwritten)
 {
-	struct file *file = vma->vm_file;
-	struct address_space *mapping = file->f_mapping;
+	struct address_space *mapping = vma->vm_file->f_mapping;
 	struct inode *inode = mapping->host;
 	struct page *page;
+	pfn_t pfn;
 	struct buffer_head bh;
 	unsigned long vaddr = (unsigned long)vmf->virtual_address;
-	unsigned blkbits = inode->i_blkbits;
-	sector_t block;
 	pgoff_t size;
 	int error;
 	int major = 0;
+	bool write = vmf->flags & FAULT_FLAG_WRITE;
 
 	size = (i_size_read(inode) + PAGE_CACHE_SIZE - 1) >> PAGE_CACHE_SHIFT;
 	if (vmf->pgoff >= size)
 		return VM_FAULT_SIGBUS;
 
 	memset(&bh, 0, sizeof(bh));
-	block = (sector_t)vmf->pgoff << (PAGE_CACHE_SHIFT - blkbits);
-	bh.b_bdev = inode->i_sb->s_bdev;
-	bh.b_size = PAGE_CACHE_SIZE;
 
  repeat:
-	page = find_get_page(mapping, vmf->pgoff);
-	if (page) {
+	page = find_get_entry(mapping, vmf->pgoff);
+	if (radix_tree_exceptional_entry(page)) {
+		pfn = radix_to_pfn_t(page, vmf->pgoff);
+		page = NULL;
+	} else if (!page) {
+		error = dax_create_pfns(mapping, vmf->pgoff, PAGE_CACHE_SIZE,
+				write && !vmf->cow_page, &pfn, get_block, &bh);
+		if (error < 0)
+			goto out;
+
+		if (error) {
+			count_vm_event(PGMAJFAULT);
+			mem_cgroup_count_vm_event(vma->vm_mm, PGMAJFAULT);
+			major = VM_FAULT_MAJOR;
+			error = 0;
+		}
+	} else {
 		if (!lock_page_or_retry(page, vma->vm_mm, vmf->flags)) {
 			page_cache_release(page);
 			return VM_FAULT_RETRY;
-		}
-		if (unlikely(page->mapping != mapping)) {
+		} else if (unlikely(page->mapping != mapping)) {
 			unlock_page(page);
 			page_cache_release(page);
 			goto repeat;
 		}
 	}
 
-	error = get_block(inode, block, &bh, 0);
-	if (!error && (bh.b_size < PAGE_CACHE_SIZE))
-		error = -EIO;		/* fs corruption? */
-	if (error)
-		goto unlock_page;
+	if (is_bad_pfn_t(pfn) && !vmf->cow_page) {
+		/*
+		 * Allocating a new page in the file would cause excessive
+		 * storage usage for workloads with sparse files.  We allocate
+		 * a page cache page instead.  We'll kick it out of the page
+		 * cache if it's ever written to, otherwise it will simply
+		 * fall out of the page cache under memory pressure without
+		 * ever having been dirtied.
+		 */
+		if (!page)
+			page = find_or_create_page(mapping, vmf->pgoff,
+						vmf->gfp_mask | __GFP_ZERO);
+		if (!page)
+			return VM_FAULT_OOM;
+		vmf->page = page;
+		return VM_FAULT_LOCKED;
+	}
 
-	if (!buffer_mapped(&bh) && !buffer_unwritten(&bh) && !vmf->cow_page) {
-		if (vmf->flags & FAULT_FLAG_WRITE) {
-			error = get_block(inode, block, &bh, 1);
-			count_vm_event(PGMAJFAULT);
-			mem_cgroup_count_vm_event(vma->vm_mm, PGMAJFAULT);
-			major = VM_FAULT_MAJOR;
-			if (!error && (bh.b_size < PAGE_CACHE_SIZE))
-				error = -EIO;
+	if (vmf->cow_page) {
+		if (is_bad_pfn_t(pfn)) {
+			clear_user_highpage(vmf->cow_page, vaddr);
+		} else {
+			error = copy_user_pfn(vmf, pfn);
 			if (error)
 				goto unlock_page;
-		} else {
-			return dax_load_hole(mapping, page, vmf);
 		}
-	}
-
-	if (vmf->cow_page) {
-		struct page *new_page = vmf->cow_page;
-		if (buffer_written(&bh))
-			error = copy_user_bh(new_page, inode, &bh, vaddr);
-		else
-			clear_user_highpage(new_page, vaddr);
-		if (error)
-			goto unlock_page;
 		vmf->page = page;
 
 		/*
@@ -625,18 +722,10 @@  static int dax_pte_fault(struct vm_area_struct *vma, struct vm_fault *vmf,
 		return VM_FAULT_LOCKED;
 	}
 
-	/* Check we didn't race with a read fault installing a new page */
-	if (!page && major)
-		page = find_lock_page(mapping, vmf->pgoff);
+	if (current->needs_wmb)
+		wmb_pmem();
 
-	if (page) {
-		unmap_mapping_range(mapping, vmf->pgoff << PAGE_CACHE_SHIFT,
-							PAGE_CACHE_SIZE, 0);
-		delete_from_page_cache(page);
-		unlock_page(page);
-		page_cache_release(page);
-		page = NULL;
-	}
+	error = vm_insert_mixed(vma, vaddr, pfn);
 
 	/*
 	 * If we successfully insert the new mapping over an unwritten extent,
@@ -648,12 +737,11 @@  static int dax_pte_fault(struct vm_area_struct *vma, struct vm_fault *vmf,
 	 * indicate what the callback should do via the uptodate variable, same
 	 * as for normal BH based IO completions.
 	 */
-	error = dax_insert_mapping(inode, &bh, vma, vmf);
 	if (buffer_unwritten(&bh)) {
 		if (complete_unwritten)
 			complete_unwritten(&bh, !error);
 		else
-			WARN_ON_ONCE(!(vmf->flags & FAULT_FLAG_WRITE));
+			WARN_ON_ONCE(!write);
 	}
 
  out:
@@ -673,12 +761,6 @@  static int dax_pte_fault(struct vm_area_struct *vma, struct vm_fault *vmf,
 }
 
 #ifdef CONFIG_TRANSPARENT_HUGEPAGE
-/*
- * The 'colour' (ie low bits) within a PMD of a page offset.  This comes up
- * more often than one might expect in the below function.
- */
-#define PG_PMD_COLOUR	((PMD_SIZE >> PAGE_CACHE_SHIFT) - 1)
-
 static void __dax_dbg(struct buffer_head *bh, unsigned long address,
 		const char *reason, const char *fn)
 {
@@ -697,98 +779,19 @@  static void __dax_dbg(struct buffer_head *bh, unsigned long address,
 
 #define dax_pmd_dbg(bh, address, reason)	__dax_dbg(bh, address, reason, "dax_pmd")
 
-static int dax_insert_pmd_mapping(struct inode *inode, struct buffer_head *bh,
-			struct vm_area_struct *vma, struct vm_fault *vmf)
-{
-	int major = 0;
-	struct blk_dax_ctl dax = {
-		.sector = to_sector(bh, inode),
-		.size = PMD_SIZE,
-	};
-	struct block_device *bdev = bh->b_bdev;
-	bool write = vmf->flags & FAULT_FLAG_WRITE;
-	unsigned long address = (unsigned long)vmf->virtual_address;
-	long length = dax_map_atomic(bdev, &dax);
-
-	if (length < 0)
-		return VM_FAULT_SIGBUS;
-	if (length < PMD_SIZE) {
-		dax_pmd_dbg(bh, address, "dax-length too small");
-		goto unmap;
-	}
-
-	if (pfn_t_to_pfn(dax.pfn) & PG_PMD_COLOUR) {
-		dax_pmd_dbg(bh, address, "pfn unaligned");
-		goto unmap;
-	}
-
-	if (!pfn_t_devmap(dax.pfn)) {
-		dax_pmd_dbg(bh, address, "pfn not in memmap");
-		goto unmap;
-	}
-
-	if (buffer_unwritten(bh) || buffer_new(bh)) {
-		clear_pmem(dax.addr, PMD_SIZE);
-		wmb_pmem();
-		count_vm_event(PGMAJFAULT);
-		mem_cgroup_count_vm_event(vma->vm_mm, PGMAJFAULT);
-		major = VM_FAULT_MAJOR;
-	}
-	dax_unmap_atomic(bdev, &dax);
-
-	/*
-	 * For PTE faults we insert a radix tree entry for reads, and leave
-	 * it clean.  Then on the first write we dirty the radix tree entry
-	 * via the dax_pfn_mkwrite() path.  This sequence allows the
-	 * dax_pfn_mkwrite() call to be simpler and avoid a call into
-	 * get_block() to translate the pgoff to a sector in order to be able
-	 * to create a new radix tree entry.
-	 *
-	 * The PMD path doesn't have an equivalent to dax_pfn_mkwrite(),
-	 * though, so for a read followed by a write we traverse all the way
-	 * through dax_pmd_fault() twice.  This means we can just skip
-	 * inserting a radix tree entry completely on the initial read and
-	 * just wait until the write to insert a dirty entry.
-	 */
-	if (write) {
-		int error = dax_radix_entry(vma->vm_file->f_mapping, vmf->pgoff,
-						dax.sector, true, true);
-		if (error) {
-			dax_pmd_dbg(bh, address,
-					"PMD radix insertion failed");
-			goto fallback;
-		}
-	}
-
-	dev_dbg(part_to_dev(bdev->bd_part),
-			"%s: %s addr: %lx pfn: %lx sect: %llx\n",
-			__func__, current->comm, address,
-			pfn_t_to_pfn(dax.pfn),
-			(unsigned long long) dax.sector);
-	return major | vmf_insert_pfn_pmd(vma, address, vmf->pmd,
-						dax.pfn, write);
- unmap:
-	dax_unmap_atomic(bdev, &dax);
- fallback:
-	count_vm_event(THP_FAULT_FALLBACK);
-	return VM_FAULT_FALLBACK;
-}
-
 static int dax_pmd_fault(struct vm_area_struct *vma, struct vm_fault *vmf,
 		get_block_t get_block, dax_iodone_t complete_unwritten)
 {
-	struct file *file = vma->vm_file;
-	struct address_space *mapping = file->f_mapping;
+	struct address_space *mapping = vma->vm_file->f_mapping;
 	struct inode *inode = mapping->host;
+	void *entry;
+	pfn_t pfn;
 	struct buffer_head bh;
-	unsigned blkbits = inode->i_blkbits;
-	unsigned long address = (unsigned long)vmf->virtual_address;
-	unsigned long pmd_addr = address & PMD_MASK;
-	bool write = vmf->flags & FAULT_FLAG_WRITE;
+	unsigned long vaddr = (unsigned long)vmf->virtual_address;
+	unsigned long pmd_addr = vaddr & PMD_MASK;
 	pgoff_t size;
-	sector_t block;
-	int result;
-	bool alloc = false;
+	int result = 0;
+	bool write = vmf->flags & FAULT_FLAG_WRITE;
 
 	/* dax pmd mappings require pfn_t_devmap() */
 	if (!IS_ENABLED(CONFIG_FS_DAX_PMD))
@@ -796,17 +799,17 @@  static int dax_pmd_fault(struct vm_area_struct *vma, struct vm_fault *vmf,
 
 	/* Fall back to PTEs if we're going to COW */
 	if (write && !(vma->vm_flags & VM_SHARED)) {
-		split_huge_pmd(vma, vmf->pmd, address);
-		dax_pmd_dbg(NULL, address, "cow write");
+		split_huge_pmd(vma, vmf->pmd, vaddr);
+		dax_pmd_dbg(NULL, vaddr, "cow write");
 		return VM_FAULT_FALLBACK;
 	}
 	/* If the PMD would extend outside the VMA */
 	if (pmd_addr < vma->vm_start) {
-		dax_pmd_dbg(NULL, address, "vma start unaligned");
+		dax_pmd_dbg(NULL, vaddr, "vma start unaligned");
 		return VM_FAULT_FALLBACK;
 	}
 	if ((pmd_addr + PMD_SIZE) > vma->vm_end) {
-		dax_pmd_dbg(NULL, address, "vma end unaligned");
+		dax_pmd_dbg(NULL, vaddr, "vma end unaligned");
 		return VM_FAULT_FALLBACK;
 	}
 
@@ -815,76 +818,69 @@  static int dax_pmd_fault(struct vm_area_struct *vma, struct vm_fault *vmf,
 		return VM_FAULT_SIGBUS;
 	/* If the PMD would cover blocks out of the file */
 	if ((vmf->pgoff | PG_PMD_COLOUR) >= size) {
-		dax_pmd_dbg(NULL, address,
+		dax_pmd_dbg(NULL, vaddr,
 				"offset + huge page size > file size");
 		return VM_FAULT_FALLBACK;
 	}
 
 	memset(&bh, 0, sizeof(bh));
-	bh.b_bdev = inode->i_sb->s_bdev;
-	block = (sector_t)vmf->pgoff << (PAGE_CACHE_SHIFT - blkbits);
-
-	bh.b_size = PMD_SIZE;
-
-	if (get_block(inode, block, &bh, 0) != 0)
-		return VM_FAULT_SIGBUS;
-
-	if (!buffer_mapped(&bh) && write) {
-		if (get_block(inode, block, &bh, 1) != 0)
-			return VM_FAULT_SIGBUS;
-		alloc = true;
-	}
 
-	/*
-	 * If the filesystem isn't willing to tell us the length of a hole,
-	 * just fall back to PTEs.  Calling get_block 512 times in a loop
-	 * would be silly.
-	 */
-	if (!buffer_size_valid(&bh) || bh.b_size < PMD_SIZE) {
-		dax_pmd_dbg(&bh, address, "allocated block too small");
-		return VM_FAULT_FALLBACK;
-	}
+	entry = find_get_entry(mapping, vmf->pgoff);
+	if (radix_tree_exceptional_entry(entry) &&
+				RADIX_DAX_SIZE(entry) >= RADIX_DAX_PMD) {
+		pfn = radix_to_pfn_t(entry, vmf->pgoff);
+	} else {
+		int error = dax_create_pfns(mapping, vmf->pgoff, PMD_SIZE,
+					write, &pfn, get_block, &bh);
+		if (error < 0)
+			goto fallback;
 
-	/*
-	 * If we allocated new storage, make sure no process has any
-	 * zero pages covering this hole
-	 */
-	if (alloc) {
-		loff_t lstart = vmf->pgoff << PAGE_CACHE_SHIFT;
-		loff_t lend = lstart + PMD_SIZE - 1; /* inclusive */
+		if (error) {
+			count_vm_event(PGMAJFAULT);
+			mem_cgroup_count_vm_event(vma->vm_mm, PGMAJFAULT);
+			result = VM_FAULT_MAJOR;
+			error = 0;
+		}
 
-		truncate_pagecache_range(inode, lstart, lend);
+		/*
+		 * We don't know if dax_create_pfns() was able to allocate
+		 * a contiguous aligned chunk, or whether it was only able
+		 * to do a partial allocation.
+		 */
+		entry = find_get_entry(mapping, vmf->pgoff);
+		if (!radix_tree_exceptional_entry(entry) ||
+				RADIX_DAX_SIZE(entry) < RADIX_DAX_PMD)
+			goto fallback;
 	}
 
-	if (!write && !buffer_mapped(&bh) && buffer_uptodate(&bh)) {
+	if (is_bad_pfn_t(pfn)) {
 		spinlock_t *ptl;
 		pmd_t entry, *pmd = vmf->pmd;
 		struct page *zero_page = get_huge_zero_page();
 
 		if (unlikely(!zero_page)) {
-			dax_pmd_dbg(&bh, address, "no zero page");
+			dax_pmd_dbg(&bh, vaddr, "no zero page");
 			goto fallback;
 		}
 
 		ptl = pmd_lock(vma->vm_mm, pmd);
 		if (!pmd_none(*pmd)) {
 			spin_unlock(ptl);
-			dax_pmd_dbg(&bh, address, "pmd already present");
+			dax_pmd_dbg(&bh, vaddr, "pmd already present");
 			goto fallback;
 		}
 
-		dev_dbg(part_to_dev(bh.b_bdev->bd_part),
-				"%s: %s addr: %lx pfn: <zero> sect: %llx\n",
-				__func__, current->comm, address,
-				(unsigned long long) to_sector(&bh, inode));
-
 		entry = mk_pmd(zero_page, vma->vm_page_prot);
 		entry = pmd_mkhuge(entry);
 		set_pmd_at(vma->vm_mm, pmd_addr, pmd, entry);
 		result = VM_FAULT_NOPAGE;
 		spin_unlock(ptl);
 	} else {
-		result = dax_insert_pmd_mapping(inode, &bh, vma, vmf);
+		if (current->needs_wmb)
+			wmb_pmem();
+
+		result |= vmf_insert_pfn_pmd(vma, vaddr, vmf->pmd, pfn,
+						write);
 	}
 
  out:
@@ -907,80 +903,21 @@  static int dax_pmd_fault(struct vm_area_struct *vma, struct vm_fault *vmf,
 #endif /* !CONFIG_TRANSPARENT_HUGEPAGE */
 
 #ifdef CONFIG_HAVE_ARCH_TRANSPARENT_HUGEPAGE_PUD
-/*
- * The 'colour' (ie low bits) within a PUD of a page offset.  This comes up
- * more often than one might expect in the below function.
- */
-#define PG_PUD_COLOUR	((PUD_SIZE >> PAGE_CACHE_SHIFT) - 1)
-
 #define dax_pud_dbg(bh, address, reason)	__dax_dbg(bh, address, reason, "dax_pud")
 
-static int dax_insert_pud_mapping(struct inode *inode, struct buffer_head *bh,
-			struct vm_area_struct *vma, struct vm_fault *vmf)
-{
-	int major = 0;
-	struct blk_dax_ctl dax = {
-		.sector = to_sector(bh, inode),
-		.size = PUD_SIZE,
-	};
-	struct block_device *bdev = bh->b_bdev;
-	bool write = vmf->flags & FAULT_FLAG_WRITE;
-	unsigned long address = (unsigned long)vmf->virtual_address;
-	long length = dax_map_atomic(bdev, &dax);
-
-	if (length < 0)
-		return VM_FAULT_SIGBUS;
-	if (length < PUD_SIZE) {
-		dax_pud_dbg(bh, address, "dax-length too small");
-		goto unmap;
-	}
-	if (pfn_t_to_pfn(dax.pfn) & PG_PUD_COLOUR) {
-		dax_pud_dbg(bh, address, "pfn unaligned");
-		goto unmap;
-	}
-
-	if (!pfn_t_devmap(dax.pfn)) {
-		dax_pud_dbg(bh, address, "pfn not in memmap");
-		goto unmap;
-	}
-
-	if (buffer_unwritten(bh) || buffer_new(bh)) {
-		clear_pmem(dax.addr, PUD_SIZE);
-		wmb_pmem();
-		count_vm_event(PGMAJFAULT);
-		mem_cgroup_count_vm_event(vma->vm_mm, PGMAJFAULT);
-		major = VM_FAULT_MAJOR;
-	}
-	dax_unmap_atomic(bdev, &dax);
-
-	dev_dbg(part_to_dev(bdev->bd_part),
-			"%s: %s addr: %lx pfn: %lx sect: %llx\n",
-			__func__, current->comm, address,
-			pfn_t_to_pfn(dax.pfn),
-			(unsigned long long) dax.sector);
-	return major | vmf_insert_pfn_pud(vma, address, vmf->pud,
-						dax.pfn, write);
- unmap:
-	dax_unmap_atomic(bdev, &dax);
-	count_vm_event(THP_FAULT_FALLBACK);
-	return VM_FAULT_FALLBACK;
-}
-
 static int dax_pud_fault(struct vm_area_struct *vma, struct vm_fault *vmf,
 		get_block_t get_block, dax_iodone_t complete_unwritten)
 {
-	struct file *file = vma->vm_file;
-	struct address_space *mapping = file->f_mapping;
+	struct address_space *mapping = vma->vm_file->f_mapping;
 	struct inode *inode = mapping->host;
+	void *entry;
+	pfn_t pfn;
 	struct buffer_head bh;
-	unsigned blkbits = inode->i_blkbits;
-	unsigned long address = (unsigned long)vmf->virtual_address;
-	unsigned long pud_addr = address & PUD_MASK;
-	bool write = vmf->flags & FAULT_FLAG_WRITE;
+	unsigned long vaddr = (unsigned long)vmf->virtual_address;
+	unsigned long pud_addr = vaddr & PUD_MASK;
 	pgoff_t size;
-	sector_t block;
-	int result;
-	bool alloc = false;
+	int result = 0;
+	bool write = vmf->flags & FAULT_FLAG_WRITE;
 
 	/* dax pud mappings require pfn_t_devmap() */
 	if (!IS_ENABLED(CONFIG_FS_DAX_PMD))
@@ -988,17 +925,17 @@  static int dax_pud_fault(struct vm_area_struct *vma, struct vm_fault *vmf,
 
 	/* Fall back to PTEs if we're going to COW */
 	if (write && !(vma->vm_flags & VM_SHARED)) {
-		split_huge_pud(vma, vmf->pud, address);
-		dax_pud_dbg(NULL, address, "cow write");
+		split_huge_pud(vma, vmf->pud, vaddr);
+		dax_pud_dbg(NULL, vaddr, "cow write");
 		return VM_FAULT_FALLBACK;
 	}
 	/* If the PUD would extend outside the VMA */
 	if (pud_addr < vma->vm_start) {
-		dax_pud_dbg(NULL, address, "vma start unaligned");
+		dax_pud_dbg(NULL, vaddr, "vma start unaligned");
 		return VM_FAULT_FALLBACK;
 	}
 	if ((pud_addr + PUD_SIZE) > vma->vm_end) {
-		dax_pud_dbg(NULL, address, "vma end unaligned");
+		dax_pud_dbg(NULL, vaddr, "vma end unaligned");
 		return VM_FAULT_FALLBACK;
 	}
 
@@ -1007,52 +944,50 @@  static int dax_pud_fault(struct vm_area_struct *vma, struct vm_fault *vmf,
 		return VM_FAULT_SIGBUS;
 	/* If the PUD would cover blocks out of the file */
 	if ((vmf->pgoff | PG_PUD_COLOUR) >= size) {
-		dax_pud_dbg(NULL, address,
+		dax_pud_dbg(NULL, vaddr,
 				"offset + huge page size > file size");
 		return VM_FAULT_FALLBACK;
 	}
 
 	memset(&bh, 0, sizeof(bh));
-	bh.b_bdev = inode->i_sb->s_bdev;
-	block = (sector_t)vmf->pgoff << (PAGE_CACHE_SHIFT - blkbits);
 
-	bh.b_size = PUD_SIZE;
-
-	if (get_block(inode, block, &bh, 0) != 0)
-		return VM_FAULT_SIGBUS;
-
-	if (!buffer_mapped(&bh) && write) {
-		if (get_block(inode, block, &bh, 1) != 0)
-			return VM_FAULT_SIGBUS;
-		alloc = true;
-	}
-
-	/*
-	 * If the filesystem isn't willing to tell us the length of a hole,
-	 * just fall back to PMDs.  Calling get_block 512 times in a loop
-	 * would be silly.
-	 */
-	if (!buffer_size_valid(&bh) || bh.b_size < PUD_SIZE) {
-		dax_pud_dbg(&bh, address, "allocated block too small");
-		return VM_FAULT_FALLBACK;
-	}
+	entry = find_get_entry(mapping, vmf->pgoff);
+	if (radix_tree_exceptional_entry(entry) &&
+				RADIX_DAX_SIZE(entry) >= RADIX_DAX_PUD) {
+		pfn = radix_to_pfn_t(entry, vmf->pgoff);
+	} else {
+		int error = dax_create_pfns(mapping, vmf->pgoff, PUD_SIZE,
+					write, &pfn, get_block, &bh);
+		if (error < 0)
+			goto fallback;
 
-	/*
-	 * If we allocated new storage, make sure no process has any
-	 * zero pages covering this hole
-	 */
-	if (alloc) {
-		loff_t lstart = vmf->pgoff << PAGE_CACHE_SHIFT;
-		loff_t lend = lstart + PUD_SIZE - 1; /* inclusive */
+		if (error) {
+			count_vm_event(PGMAJFAULT);
+			mem_cgroup_count_vm_event(vma->vm_mm, PGMAJFAULT);
+			result = VM_FAULT_MAJOR;
+			error = 0;
+		}
 
-		truncate_pagecache_range(inode, lstart, lend);
+		/*
+		 * We don't know if dax_create_pfns() was able to allocate
+		 * a contiguous aligned chunk, or whether it was only able
+		 * to do a partial allocation.
+		 */
+		entry = find_get_entry(mapping, vmf->pgoff);
+		if (!radix_tree_exceptional_entry(entry) ||
+				RADIX_DAX_SIZE(entry) < RADIX_DAX_PUD)
+			goto fallback;
 	}
 
-	if (!write && !buffer_mapped(&bh) && buffer_uptodate(&bh)) {
-		dax_pud_dbg(&bh, address, "no zero page");
+	if (is_bad_pfn_t(pfn)) {
+		dax_pud_dbg(&bh, vaddr, "no zero page");
 		goto fallback;
 	} else {
-		result = dax_insert_pud_mapping(inode, &bh, vma, vmf);
+		if (current->needs_wmb)
+			wmb_pmem();
+
+		result |= vmf_insert_pfn_pud(vma, vaddr, vmf->pud, pfn,
+						write);
 	}
 
  out:
@@ -1113,17 +1048,13 @@  EXPORT_SYMBOL_GPL(dax_fault);
  */
 int dax_pfn_mkwrite(struct vm_area_struct *vma, struct vm_fault *vmf)
 {
-	struct file *file = vma->vm_file;
+	struct address_space *mapping = vma->vm_file->f_mapping;
+
+	spin_lock_irq(&mapping->tree_lock);
+	radix_tree_tag_set(&mapping->page_tree, vmf->pgoff,
+							PAGECACHE_TAG_DIRTY);
+	spin_unlock_irq(&mapping->tree_lock);
 
-	/*
-	 * We pass NO_SECTOR to dax_radix_entry() because we expect that a
-	 * RADIX_DAX_PTE entry already exists in the radix tree from a
-	 * previous call to dax_fault().  We just want to look up that PTE
-	 * entry using vmf->pgoff and make sure the dirty tag is set.  This
-	 * saves us from having to make a call to get_block() here to look
-	 * up the sector.
-	 */
-	dax_radix_entry(file->f_mapping, vmf->pgoff, NO_SECTOR, false, true);
 	return VM_FAULT_NOPAGE;
 }
 EXPORT_SYMBOL_GPL(dax_pfn_mkwrite);
diff --git a/include/linux/dax.h b/include/linux/dax.h
index 8e58c36..0a6505d 100644
--- a/include/linux/dax.h
+++ b/include/linux/dax.h
@@ -5,9 +5,10 @@ 
 #include <linux/mm.h>
 #include <asm/pgtable.h>
 
+int dax_clear_blocks(struct block_device *, sector_t sector, long size);
+
 ssize_t dax_do_io(struct kiocb *, struct inode *, struct iov_iter *, loff_t,
 		  get_block_t, dio_iodone_t, int flags);
-int dax_clear_blocks(struct inode *, sector_t block, long size);
 int dax_zero_page_range(struct inode *, loff_t from, unsigned len, get_block_t);
 int dax_truncate_page(struct inode *, loff_t from, get_block_t);
 int dax_fault(struct vm_area_struct *, struct vm_fault *, get_block_t,
diff --git a/include/linux/pfn_t.h b/include/linux/pfn_t.h
index 07d18a8..95a7b50 100644
--- a/include/linux/pfn_t.h
+++ b/include/linux/pfn_t.h
@@ -8,21 +8,46 @@ 
  * PFN_SG_LAST - pfn references a page and is the last scatterlist entry
  * PFN_DEV - pfn is not covered by system memmap by default
  * PFN_MAP - pfn has a dynamic page mapping established by a device driver
+ *
+ * Note that DAX uses the same format for its radix tree entries.  The
+ * bottom two bits are used by the radix tree.
  */
-#define PFN_FLAGS_MASK (((unsigned long) ~PAGE_MASK) \
-		<< (BITS_PER_LONG - PAGE_SHIFT))
-#define PFN_SG_CHAIN (1UL << (BITS_PER_LONG - 1))
-#define PFN_SG_LAST (1UL << (BITS_PER_LONG - 2))
-#define PFN_DEV (1UL << (BITS_PER_LONG - 3))
-#define PFN_MAP (1UL << (BITS_PER_LONG - 4))
+#define PFN_FLAG_BITS	4
+#define PFN_FLAGS_MASK	((1 << PFN_FLAG_BITS) - 1)
+#define PFN_SG_CHAIN	0x1UL
+#define PFN_SG_LAST	0x2UL
+#define PFN_DEV		0x4UL
+#define PFN_MAP		0x8UL
 
 static inline pfn_t __pfn_to_pfn_t(unsigned long pfn, unsigned long flags)
 {
-	pfn_t pfn_t = { .val = pfn | (flags & PFN_FLAGS_MASK), };
+	pfn_t pfn_t = { .val = (pfn << PFN_FLAG_BITS) |
+					(flags & PFN_FLAGS_MASK), };
 
 	return pfn_t;
 }
 
+static inline __must_check pfn_t pfn_t_add(const pfn_t pfn, int val)
+{
+	pfn_t tmp = pfn;
+	tmp.val += val << PFN_FLAG_BITS;
+	return tmp;
+}
+	
+/*
+ * It makes no sense to have both SG_CHAIN and SG_LAST set, so we could
+ * encode an errno in here if we need to.  Note that you can't put a
+ * bad_pfn_t in the radix tree because the radix tree uses the bottom bit
+ * for its own purposes.
+ */
+#define bad_pfn_t	((pfn_t) { .val = -1 })
+
+static inline bool is_bad_pfn_t(pfn_t pfn)
+{
+	return ((pfn.val & (PFN_SG_CHAIN | PFN_SG_LAST)) ==
+			  (PFN_SG_CHAIN | PFN_SG_LAST));
+}
+
 /* a default pfn to pfn_t conversion assumes that @pfn is pfn_valid() */
 static inline pfn_t pfn_to_pfn_t(unsigned long pfn)
 {
@@ -38,7 +63,7 @@  static inline bool pfn_t_has_page(pfn_t pfn)
 
 static inline unsigned long pfn_t_to_pfn(pfn_t pfn)
 {
-	return pfn.val & ~PFN_FLAGS_MASK;
+	return pfn.val >> PFN_FLAG_BITS;
 }
 
 static inline struct page *pfn_t_to_page(pfn_t pfn)
diff --git a/include/linux/radix-tree.h b/include/linux/radix-tree.h
index 7c88ad1..57e7d87 100644
--- a/include/linux/radix-tree.h
+++ b/include/linux/radix-tree.h
@@ -51,15 +51,6 @@ 
 #define RADIX_TREE_EXCEPTIONAL_ENTRY	2
 #define RADIX_TREE_EXCEPTIONAL_SHIFT	2
 
-#define RADIX_DAX_MASK	0xf
-#define RADIX_DAX_SHIFT	4
-#define RADIX_DAX_PTE  (0x4 | RADIX_TREE_EXCEPTIONAL_ENTRY)
-#define RADIX_DAX_PMD  (0x8 | RADIX_TREE_EXCEPTIONAL_ENTRY)
-#define RADIX_DAX_TYPE(entry) ((unsigned long)entry & RADIX_DAX_MASK)
-#define RADIX_DAX_SECTOR(entry) (((unsigned long)entry >> RADIX_DAX_SHIFT))
-#define RADIX_DAX_ENTRY(sector, pmd) ((void *)((unsigned long)sector << \
-		RADIX_DAX_SHIFT | (pmd ? RADIX_DAX_PMD : RADIX_DAX_PTE)))
-
 static inline int radix_tree_is_indirect_ptr(void *ptr)
 {
 	return (int)((unsigned long)ptr & RADIX_TREE_INDIRECT_PTR);
diff --git a/include/linux/sched.h b/include/linux/sched.h
index 6e95d8a..2cdfe76 100644
--- a/include/linux/sched.h
+++ b/include/linux/sched.h
@@ -1476,6 +1476,7 @@  struct task_struct {
 	/* unserialized, strictly 'current' */
 	unsigned in_execve:1; /* bit to tell LSMs we're in execve */
 	unsigned in_iowait:1;
+	unsigned needs_wmb:1;
 #ifdef CONFIG_MEMCG
 	unsigned memcg_may_oom:1;
 #ifndef CONFIG_SLOB