| Message ID | 20211201142918.921493-44-Liam.Howlett@oracle.com (mailing list archive) |
|---|---|
| State | New |
| Headers | show |
| Series | Introducing the Maple Tree (expand) |
On 12/1/21 15:30, Liam Howlett wrote: > From: "Liam R. Howlett" <Liam.Howlett@Oracle.com> > > Don't use the mm_struct linked list or the vma->vm_next in prep for removal > > Signed-off-by: Liam R. Howlett <Liam.Howlett@Oracle.com> > --- > fs/userfaultfd.c | 49 ++++++++++++++++++++++------------- > include/linux/userfaultfd_k.h | 7 +++-- > mm/mmap.c | 12 ++++----- > 3 files changed, 40 insertions(+), 28 deletions(-) > > diff --git a/fs/userfaultfd.c b/fs/userfaultfd.c > index 22bf14ab2d16..2880025598c7 100644 > --- a/fs/userfaultfd.c > +++ b/fs/userfaultfd.c > @@ -606,14 +606,16 @@ static void userfaultfd_event_wait_completion(struct userfaultfd_ctx *ctx, > if (release_new_ctx) { > struct vm_area_struct *vma; > struct mm_struct *mm = release_new_ctx->mm; > + VMA_ITERATOR(vmi, mm, 0); > > /* the various vma->vm_userfaultfd_ctx still points to it */ > mmap_write_lock(mm); > - for (vma = mm->mmap; vma; vma = vma->vm_next) > + for_each_vma(vmi, vma) { > if (vma->vm_userfaultfd_ctx.ctx == release_new_ctx) { > vma->vm_userfaultfd_ctx = NULL_VM_UFFD_CTX; > vma->vm_flags &= ~__VM_UFFD_FLAGS; > } > + } > mmap_write_unlock(mm); > > userfaultfd_ctx_put(release_new_ctx); > @@ -794,11 +796,13 @@ static bool has_unmap_ctx(struct userfaultfd_ctx *ctx, struct list_head *unmaps, > return false; > } > > -int userfaultfd_unmap_prep(struct vm_area_struct *vma, > - unsigned long start, unsigned long end, > - struct list_head *unmaps) > +int userfaultfd_unmap_prep(struct mm_struct *mm, unsigned long start, > + unsigned long end, struct list_head *unmaps) > { > - for ( ; vma && vma->vm_start < end; vma = vma->vm_next) { > + VMA_ITERATOR(vmi, mm, start); > + struct vm_area_struct *vma; > + > + for_each_vma_range(vmi, vma, end) { > struct userfaultfd_unmap_ctx *unmap_ctx; > struct userfaultfd_ctx *ctx = vma->vm_userfaultfd_ctx.ctx; > > @@ -848,6 +852,7 @@ static int userfaultfd_release(struct inode *inode, struct file *file) > /* len == 0 means wake all */ > struct userfaultfd_wake_range 
range = { .len = 0, }; > unsigned long new_flags; > + MA_STATE(mas, &mm->mm_mt, 0, 0); Again, it looks like this could also be VMA_ITERATOR, consistent with the one above? > > WRITE_ONCE(ctx->released, true); > > @@ -864,7 +869,7 @@ static int userfaultfd_release(struct inode *inode, struct file *file) > */ > mmap_write_lock(mm); > prev = NULL; > - for (vma = mm->mmap; vma; vma = vma->vm_next) { > + mas_for_each(&mas, vma, ULONG_MAX) { > cond_resched(); > BUG_ON(!!vma->vm_userfaultfd_ctx.ctx ^ > !!(vma->vm_flags & __VM_UFFD_FLAGS)); > @@ -1281,6 +1286,7 @@ static int userfaultfd_register(struct userfaultfd_ctx *ctx, > bool found; > bool basic_ioctls; > unsigned long start, end, vma_end; > + MA_STATE(mas, &mm->mm_mt, 0, 0); > > user_uffdio_register = (struct uffdio_register __user *) arg; > > @@ -1323,7 +1329,8 @@ static int userfaultfd_register(struct userfaultfd_ctx *ctx, > goto out; > > mmap_write_lock(mm); > - vma = find_vma_prev(mm, start, &prev); > + mas_set(&mas, start); > + vma = mas_find(&mas, ULONG_MAX); > if (!vma) > goto out_unlock; > > @@ -1348,7 +1355,7 @@ static int userfaultfd_register(struct userfaultfd_ctx *ctx, > */ > found = false; > basic_ioctls = false; > - for (cur = vma; cur && cur->vm_start < end; cur = cur->vm_next) { > + for (cur = vma; cur; cur = mas_next(&mas, end - 1)) { > cond_resched(); > > BUG_ON(!!cur->vm_userfaultfd_ctx.ctx ^ > @@ -1408,8 +1415,10 @@ static int userfaultfd_register(struct userfaultfd_ctx *ctx, > } > BUG_ON(!found); > > - if (vma->vm_start < start) > - prev = vma; > + mas_set(&mas, start); > + prev = mas_prev(&mas, 0); > + if (prev != vma) > + mas_next(&mas, ULONG_MAX); Hmm non-commented tricky stuff... 
> > ret = 0; > do { > @@ -1466,8 +1475,8 @@ static int userfaultfd_register(struct userfaultfd_ctx *ctx, > skip: > prev = vma; > start = vma->vm_end; > - vma = vma->vm_next; > - } while (vma && vma->vm_start < end); > + vma = mas_next(&mas, end - 1); > + } while (vma); > out_unlock: > mmap_write_unlock(mm); > mmput(mm); > @@ -1511,6 +1520,7 @@ static int userfaultfd_unregister(struct userfaultfd_ctx *ctx, > bool found; > unsigned long start, end, vma_end; > const void __user *buf = (void __user *)arg; > + MA_STATE(mas, &mm->mm_mt, 0, 0); > > ret = -EFAULT; > if (copy_from_user(&uffdio_unregister, buf, sizeof(uffdio_unregister))) > @@ -1529,7 +1539,8 @@ static int userfaultfd_unregister(struct userfaultfd_ctx *ctx, > goto out; > > mmap_write_lock(mm); > - vma = find_vma_prev(mm, start, &prev); > + mas_set(&mas, start); > + vma = mas_find(&mas, ULONG_MAX); > if (!vma) > goto out_unlock; > > @@ -1554,7 +1565,7 @@ static int userfaultfd_unregister(struct userfaultfd_ctx *ctx, > */ > found = false; > ret = -EINVAL; > - for (cur = vma; cur && cur->vm_start < end; cur = cur->vm_next) { > + for (cur = vma; cur; cur = mas_next(&mas, end - 1)) { > cond_resched(); > > BUG_ON(!!cur->vm_userfaultfd_ctx.ctx ^ > @@ -1574,8 +1585,10 @@ static int userfaultfd_unregister(struct userfaultfd_ctx *ctx, > } > BUG_ON(!found); > > - if (vma->vm_start < start) > - prev = vma; > + mas_set(&mas, start); > + prev = mas_prev(&mas, 0); > + if (prev != vma) > + mas_next(&mas, ULONG_MAX); Same here. 
> > ret = 0; > do { > @@ -1640,8 +1653,8 @@ static int userfaultfd_unregister(struct userfaultfd_ctx *ctx, > skip: > prev = vma; > start = vma->vm_end; > - vma = vma->vm_next; > - } while (vma && vma->vm_start < end); > + vma = mas_next(&mas, end - 1); > + } while (vma); > out_unlock: > mmap_write_unlock(mm); > mmput(mm); > diff --git a/include/linux/userfaultfd_k.h b/include/linux/userfaultfd_k.h > index 33cea484d1ad..e0b2ec2c20f2 100644 > --- a/include/linux/userfaultfd_k.h > +++ b/include/linux/userfaultfd_k.h > @@ -139,9 +139,8 @@ extern bool userfaultfd_remove(struct vm_area_struct *vma, > unsigned long start, > unsigned long end); > > -extern int userfaultfd_unmap_prep(struct vm_area_struct *vma, > - unsigned long start, unsigned long end, > - struct list_head *uf); > +extern int userfaultfd_unmap_prep(struct mm_struct *mm, unsigned long start, > + unsigned long end, struct list_head *uf); > extern void userfaultfd_unmap_complete(struct mm_struct *mm, > struct list_head *uf); > > @@ -222,7 +221,7 @@ static inline bool userfaultfd_remove(struct vm_area_struct *vma, > return true; > } > > -static inline int userfaultfd_unmap_prep(struct vm_area_struct *vma, > +static inline int userfaultfd_unmap_prep(struct mm_struct *mm, > unsigned long start, unsigned long end, > struct list_head *uf) > { > diff --git a/mm/mmap.c b/mm/mmap.c > index 79b8494d83c6..dde74e0b195d 100644 > --- a/mm/mmap.c > +++ b/mm/mmap.c > @@ -2449,7 +2449,7 @@ do_mas_align_munmap(struct ma_state *mas, struct vm_area_struct *vma, > * split, despite we could. This is unlikely enough > * failure that it's not worth optimizing it for. 
> */ > - int error = userfaultfd_unmap_prep(vma, start, end, uf); > + int error = userfaultfd_unmap_prep(mm, start, end, uf); > > if (error) > return error; > @@ -2938,10 +2938,7 @@ static int do_brk_munmap(struct ma_state *mas, struct vm_area_struct *vma, > goto munmap_full_vma; > } > > - vma_init(&unmap, mm); > - unmap.vm_start = newbrk; > - unmap.vm_end = oldbrk; > - ret = userfaultfd_unmap_prep(&unmap, newbrk, oldbrk, uf); > + ret = userfaultfd_unmap_prep(mm, newbrk, oldbrk, uf); > if (ret) > return ret; > ret = 1; > @@ -2954,6 +2951,9 @@ static int do_brk_munmap(struct ma_state *mas, struct vm_area_struct *vma, > } > > vma->vm_end = newbrk; > + vma_init(&unmap, mm); > + unmap.vm_start = newbrk; > + unmap.vm_end = oldbrk; > if (vma_mas_remove(&unmap, mas)) > goto mas_store_fail; > > @@ -2963,7 +2963,7 @@ static int do_brk_munmap(struct ma_state *mas, struct vm_area_struct *vma, > } > > unmap_pages = vma_pages(&unmap); > - if (unmap.vm_flags & VM_LOCKED) { > + if (vma->vm_flags & VM_LOCKED) { Hmm is this an unrelated bug fix? As unmap didn't have any vm_flags set even before this patch, right? > mm->locked_vm -= unmap_pages; > munlock_vma_pages_range(&unmap, newbrk, oldbrk); > }
* Vlastimil Babka <vbabka@suse.cz> [220119 11:26]: > On 12/1/21 15:30, Liam Howlett wrote: > > From: "Liam R. Howlett" <Liam.Howlett@Oracle.com> > > > > Don't use the mm_struct linked list or the vma->vm_next in prep for removal > > > > Signed-off-by: Liam R. Howlett <Liam.Howlett@Oracle.com> > > --- > > fs/userfaultfd.c | 49 ++++++++++++++++++++++------------- > > include/linux/userfaultfd_k.h | 7 +++-- > > mm/mmap.c | 12 ++++----- > > 3 files changed, 40 insertions(+), 28 deletions(-) > > > > diff --git a/fs/userfaultfd.c b/fs/userfaultfd.c > > index 22bf14ab2d16..2880025598c7 100644 > > --- a/fs/userfaultfd.c > > +++ b/fs/userfaultfd.c > > @@ -606,14 +606,16 @@ static void userfaultfd_event_wait_completion(struct userfaultfd_ctx *ctx, > > if (release_new_ctx) { > > struct vm_area_struct *vma; > > struct mm_struct *mm = release_new_ctx->mm; > > + VMA_ITERATOR(vmi, mm, 0); > > > > /* the various vma->vm_userfaultfd_ctx still points to it */ > > mmap_write_lock(mm); > > - for (vma = mm->mmap; vma; vma = vma->vm_next) > > + for_each_vma(vmi, vma) { > > if (vma->vm_userfaultfd_ctx.ctx == release_new_ctx) { > > vma->vm_userfaultfd_ctx = NULL_VM_UFFD_CTX; > > vma->vm_flags &= ~__VM_UFFD_FLAGS; > > } > > + } > > mmap_write_unlock(mm); > > > > userfaultfd_ctx_put(release_new_ctx); > > @@ -794,11 +796,13 @@ static bool has_unmap_ctx(struct userfaultfd_ctx *ctx, struct list_head *unmaps, > > return false; > > } > > > > -int userfaultfd_unmap_prep(struct vm_area_struct *vma, > > - unsigned long start, unsigned long end, > > - struct list_head *unmaps) > > +int userfaultfd_unmap_prep(struct mm_struct *mm, unsigned long start, > > + unsigned long end, struct list_head *unmaps) > > { > > - for ( ; vma && vma->vm_start < end; vma = vma->vm_next) { > > + VMA_ITERATOR(vmi, mm, start); > > + struct vm_area_struct *vma; > > + > > + for_each_vma_range(vmi, vma, end) { > > struct userfaultfd_unmap_ctx *unmap_ctx; > > struct userfaultfd_ctx *ctx = vma->vm_userfaultfd_ctx.ctx; > > > > 
@@ -848,6 +852,7 @@ static int userfaultfd_release(struct inode *inode, struct file *file) > > /* len == 0 means wake all */ > > struct userfaultfd_wake_range range = { .len = 0, }; > > unsigned long new_flags; > > + MA_STATE(mas, &mm->mm_mt, 0, 0); > > Again, it looks like this could also be VMA_ITERATOR, consistent with the > one above? VMA_ITERATOR is for simple cases, this is not a simple case, but in this change it does appear so. I missed the mas_pause() when the state is invalidated by vma_merge() success in the mas_for_each() loop below. I will fix this. > > > > > WRITE_ONCE(ctx->released, true); > > > > @@ -864,7 +869,7 @@ static int userfaultfd_release(struct inode *inode, struct file *file) > > */ > > mmap_write_lock(mm); > > prev = NULL; > > - for (vma = mm->mmap; vma; vma = vma->vm_next) { > > + mas_for_each(&mas, vma, ULONG_MAX) { > > cond_resched(); > > BUG_ON(!!vma->vm_userfaultfd_ctx.ctx ^ > > !!(vma->vm_flags & __VM_UFFD_FLAGS)); > > @@ -1281,6 +1286,7 @@ static int userfaultfd_register(struct userfaultfd_ctx *ctx, > > bool found; > > bool basic_ioctls; > > unsigned long start, end, vma_end; > > + MA_STATE(mas, &mm->mm_mt, 0, 0); > > > > user_uffdio_register = (struct uffdio_register __user *) arg; > > > > @@ -1323,7 +1329,8 @@ static int userfaultfd_register(struct userfaultfd_ctx *ctx, > > goto out; > > > > mmap_write_lock(mm); > > - vma = find_vma_prev(mm, start, &prev); > > + mas_set(&mas, start); > > + vma = mas_find(&mas, ULONG_MAX); > > if (!vma) > > goto out_unlock; > > > > @@ -1348,7 +1355,7 @@ static int userfaultfd_register(struct userfaultfd_ctx *ctx, > > */ > > found = false; > > basic_ioctls = false; > > - for (cur = vma; cur && cur->vm_start < end; cur = cur->vm_next) { > > + for (cur = vma; cur; cur = mas_next(&mas, end - 1)) { > > cond_resched(); > > > > BUG_ON(!!cur->vm_userfaultfd_ctx.ctx ^ > > @@ -1408,8 +1415,10 @@ static int userfaultfd_register(struct userfaultfd_ctx *ctx, > > } > > BUG_ON(!found); > > > > - if 
(vma->vm_start < start) > > - prev = vma; > > + mas_set(&mas, start); > > + prev = mas_prev(&mas, 0); > > + if (prev != vma) > > + mas_next(&mas, ULONG_MAX); > > Hmm non-commented tricky stuff... Oh, I did not see this as tricky. I will add a comment. Basically, I am setting the maple state to search for start, then a mas_prev() means it will get the vma before start. > > > > > ret = 0; > > do { > > @@ -1466,8 +1475,8 @@ static int userfaultfd_register(struct userfaultfd_ctx *ctx, > > skip: > > prev = vma; > > start = vma->vm_end; > > - vma = vma->vm_next; > > - } while (vma && vma->vm_start < end); > > + vma = mas_next(&mas, end - 1); > > + } while (vma); > > out_unlock: > > mmap_write_unlock(mm); > > mmput(mm); > > @@ -1511,6 +1520,7 @@ static int userfaultfd_unregister(struct userfaultfd_ctx *ctx, > > bool found; > > unsigned long start, end, vma_end; > > const void __user *buf = (void __user *)arg; > > + MA_STATE(mas, &mm->mm_mt, 0, 0); > > > > ret = -EFAULT; > > if (copy_from_user(&uffdio_unregister, buf, sizeof(uffdio_unregister))) > > @@ -1529,7 +1539,8 @@ static int userfaultfd_unregister(struct userfaultfd_ctx *ctx, > > goto out; > > > > mmap_write_lock(mm); > > - vma = find_vma_prev(mm, start, &prev); > > + mas_set(&mas, start); > > + vma = mas_find(&mas, ULONG_MAX); > > if (!vma) > > goto out_unlock; > > > > @@ -1554,7 +1565,7 @@ static int userfaultfd_unregister(struct userfaultfd_ctx *ctx, > > */ > > found = false; > > ret = -EINVAL; > > - for (cur = vma; cur && cur->vm_start < end; cur = cur->vm_next) { > > + for (cur = vma; cur; cur = mas_next(&mas, end - 1)) { > > cond_resched(); > > > > BUG_ON(!!cur->vm_userfaultfd_ctx.ctx ^ > > @@ -1574,8 +1585,10 @@ static int userfaultfd_unregister(struct userfaultfd_ctx *ctx, > > } > > BUG_ON(!found); > > > > - if (vma->vm_start < start) > > - prev = vma; > > + mas_set(&mas, start); > > + prev = mas_prev(&mas, 0); > > + if (prev != vma) > > + mas_next(&mas, ULONG_MAX); > > Same here. 
I'll add the comment here too. > > > > > ret = 0; > > do { > > @@ -1640,8 +1653,8 @@ static int userfaultfd_unregister(struct userfaultfd_ctx *ctx, > > skip: > > prev = vma; > > start = vma->vm_end; > > - vma = vma->vm_next; > > - } while (vma && vma->vm_start < end); > > + vma = mas_next(&mas, end - 1); > > + } while (vma); > > out_unlock: > > mmap_write_unlock(mm); > > mmput(mm); > > diff --git a/include/linux/userfaultfd_k.h b/include/linux/userfaultfd_k.h > > index 33cea484d1ad..e0b2ec2c20f2 100644 > > --- a/include/linux/userfaultfd_k.h > > +++ b/include/linux/userfaultfd_k.h > > @@ -139,9 +139,8 @@ extern bool userfaultfd_remove(struct vm_area_struct *vma, > > unsigned long start, > > unsigned long end); > > > > -extern int userfaultfd_unmap_prep(struct vm_area_struct *vma, > > - unsigned long start, unsigned long end, > > - struct list_head *uf); > > +extern int userfaultfd_unmap_prep(struct mm_struct *mm, unsigned long start, > > + unsigned long end, struct list_head *uf); > > extern void userfaultfd_unmap_complete(struct mm_struct *mm, > > struct list_head *uf); > > > > @@ -222,7 +221,7 @@ static inline bool userfaultfd_remove(struct vm_area_struct *vma, > > return true; > > } > > > > -static inline int userfaultfd_unmap_prep(struct vm_area_struct *vma, > > +static inline int userfaultfd_unmap_prep(struct mm_struct *mm, > > unsigned long start, unsigned long end, > > struct list_head *uf) > > { > > diff --git a/mm/mmap.c b/mm/mmap.c > > index 79b8494d83c6..dde74e0b195d 100644 > > --- a/mm/mmap.c > > +++ b/mm/mmap.c > > @@ -2449,7 +2449,7 @@ do_mas_align_munmap(struct ma_state *mas, struct vm_area_struct *vma, > > * split, despite we could. This is unlikely enough > > * failure that it's not worth optimizing it for. 
> > */ > > - int error = userfaultfd_unmap_prep(vma, start, end, uf); > > + int error = userfaultfd_unmap_prep(mm, start, end, uf); > > > > if (error) > > return error; > > @@ -2938,10 +2938,7 @@ static int do_brk_munmap(struct ma_state *mas, struct vm_area_struct *vma, > > goto munmap_full_vma; > > } > > > > - vma_init(&unmap, mm); > > - unmap.vm_start = newbrk; > > - unmap.vm_end = oldbrk; > > - ret = userfaultfd_unmap_prep(&unmap, newbrk, oldbrk, uf); > > + ret = userfaultfd_unmap_prep(mm, newbrk, oldbrk, uf); > > if (ret) > > return ret; > > ret = 1; > > @@ -2954,6 +2951,9 @@ static int do_brk_munmap(struct ma_state *mas, struct vm_area_struct *vma, > > } > > > > vma->vm_end = newbrk; > > + vma_init(&unmap, mm); > > + unmap.vm_start = newbrk; > > + unmap.vm_end = oldbrk; > > if (vma_mas_remove(&unmap, mas)) > > goto mas_store_fail; > > > > @@ -2963,7 +2963,7 @@ static int do_brk_munmap(struct ma_state *mas, struct vm_area_struct *vma, > > } > > > > unmap_pages = vma_pages(&unmap); > > - if (unmap.vm_flags & VM_LOCKED) { > > + if (vma->vm_flags & VM_LOCKED) { > > Hmm is this an unrelated bug fix? As unmap didn't have any vm_flags set even > before this patch, right? Yes. Thanks, I must have merged it into the wrong commit. > > > mm->locked_vm -= unmap_pages; > > munlock_vma_pages_range(&unmap, newbrk, oldbrk); > > } >
diff --git a/fs/userfaultfd.c b/fs/userfaultfd.c index 22bf14ab2d16..2880025598c7 100644 --- a/fs/userfaultfd.c +++ b/fs/userfaultfd.c @@ -606,14 +606,16 @@ static void userfaultfd_event_wait_completion(struct userfaultfd_ctx *ctx, if (release_new_ctx) { struct vm_area_struct *vma; struct mm_struct *mm = release_new_ctx->mm; + VMA_ITERATOR(vmi, mm, 0); /* the various vma->vm_userfaultfd_ctx still points to it */ mmap_write_lock(mm); - for (vma = mm->mmap; vma; vma = vma->vm_next) + for_each_vma(vmi, vma) { if (vma->vm_userfaultfd_ctx.ctx == release_new_ctx) { vma->vm_userfaultfd_ctx = NULL_VM_UFFD_CTX; vma->vm_flags &= ~__VM_UFFD_FLAGS; } + } mmap_write_unlock(mm); userfaultfd_ctx_put(release_new_ctx); @@ -794,11 +796,13 @@ static bool has_unmap_ctx(struct userfaultfd_ctx *ctx, struct list_head *unmaps, return false; } -int userfaultfd_unmap_prep(struct vm_area_struct *vma, - unsigned long start, unsigned long end, - struct list_head *unmaps) +int userfaultfd_unmap_prep(struct mm_struct *mm, unsigned long start, + unsigned long end, struct list_head *unmaps) { - for ( ; vma && vma->vm_start < end; vma = vma->vm_next) { + VMA_ITERATOR(vmi, mm, start); + struct vm_area_struct *vma; + + for_each_vma_range(vmi, vma, end) { struct userfaultfd_unmap_ctx *unmap_ctx; struct userfaultfd_ctx *ctx = vma->vm_userfaultfd_ctx.ctx; @@ -848,6 +852,7 @@ static int userfaultfd_release(struct inode *inode, struct file *file) /* len == 0 means wake all */ struct userfaultfd_wake_range range = { .len = 0, }; unsigned long new_flags; + MA_STATE(mas, &mm->mm_mt, 0, 0); WRITE_ONCE(ctx->released, true); @@ -864,7 +869,7 @@ static int userfaultfd_release(struct inode *inode, struct file *file) */ mmap_write_lock(mm); prev = NULL; - for (vma = mm->mmap; vma; vma = vma->vm_next) { + mas_for_each(&mas, vma, ULONG_MAX) { cond_resched(); BUG_ON(!!vma->vm_userfaultfd_ctx.ctx ^ !!(vma->vm_flags & __VM_UFFD_FLAGS)); @@ -1281,6 +1286,7 @@ static int userfaultfd_register(struct userfaultfd_ctx *ctx, 
bool found; bool basic_ioctls; unsigned long start, end, vma_end; + MA_STATE(mas, &mm->mm_mt, 0, 0); user_uffdio_register = (struct uffdio_register __user *) arg; @@ -1323,7 +1329,8 @@ static int userfaultfd_register(struct userfaultfd_ctx *ctx, goto out; mmap_write_lock(mm); - vma = find_vma_prev(mm, start, &prev); + mas_set(&mas, start); + vma = mas_find(&mas, ULONG_MAX); if (!vma) goto out_unlock; @@ -1348,7 +1355,7 @@ static int userfaultfd_register(struct userfaultfd_ctx *ctx, */ found = false; basic_ioctls = false; - for (cur = vma; cur && cur->vm_start < end; cur = cur->vm_next) { + for (cur = vma; cur; cur = mas_next(&mas, end - 1)) { cond_resched(); BUG_ON(!!cur->vm_userfaultfd_ctx.ctx ^ @@ -1408,8 +1415,10 @@ static int userfaultfd_register(struct userfaultfd_ctx *ctx, } BUG_ON(!found); - if (vma->vm_start < start) - prev = vma; + mas_set(&mas, start); + prev = mas_prev(&mas, 0); + if (prev != vma) + mas_next(&mas, ULONG_MAX); ret = 0; do { @@ -1466,8 +1475,8 @@ static int userfaultfd_register(struct userfaultfd_ctx *ctx, skip: prev = vma; start = vma->vm_end; - vma = vma->vm_next; - } while (vma && vma->vm_start < end); + vma = mas_next(&mas, end - 1); + } while (vma); out_unlock: mmap_write_unlock(mm); mmput(mm); @@ -1511,6 +1520,7 @@ static int userfaultfd_unregister(struct userfaultfd_ctx *ctx, bool found; unsigned long start, end, vma_end; const void __user *buf = (void __user *)arg; + MA_STATE(mas, &mm->mm_mt, 0, 0); ret = -EFAULT; if (copy_from_user(&uffdio_unregister, buf, sizeof(uffdio_unregister))) @@ -1529,7 +1539,8 @@ static int userfaultfd_unregister(struct userfaultfd_ctx *ctx, goto out; mmap_write_lock(mm); - vma = find_vma_prev(mm, start, &prev); + mas_set(&mas, start); + vma = mas_find(&mas, ULONG_MAX); if (!vma) goto out_unlock; @@ -1554,7 +1565,7 @@ static int userfaultfd_unregister(struct userfaultfd_ctx *ctx, */ found = false; ret = -EINVAL; - for (cur = vma; cur && cur->vm_start < end; cur = cur->vm_next) { + for (cur = vma; cur; cur 
= mas_next(&mas, end - 1)) { cond_resched(); BUG_ON(!!cur->vm_userfaultfd_ctx.ctx ^ @@ -1574,8 +1585,10 @@ static int userfaultfd_unregister(struct userfaultfd_ctx *ctx, } BUG_ON(!found); - if (vma->vm_start < start) - prev = vma; + mas_set(&mas, start); + prev = mas_prev(&mas, 0); + if (prev != vma) + mas_next(&mas, ULONG_MAX); ret = 0; do { @@ -1640,8 +1653,8 @@ static int userfaultfd_unregister(struct userfaultfd_ctx *ctx, skip: prev = vma; start = vma->vm_end; - vma = vma->vm_next; - } while (vma && vma->vm_start < end); + vma = mas_next(&mas, end - 1); + } while (vma); out_unlock: mmap_write_unlock(mm); mmput(mm); diff --git a/include/linux/userfaultfd_k.h b/include/linux/userfaultfd_k.h index 33cea484d1ad..e0b2ec2c20f2 100644 --- a/include/linux/userfaultfd_k.h +++ b/include/linux/userfaultfd_k.h @@ -139,9 +139,8 @@ extern bool userfaultfd_remove(struct vm_area_struct *vma, unsigned long start, unsigned long end); -extern int userfaultfd_unmap_prep(struct vm_area_struct *vma, - unsigned long start, unsigned long end, - struct list_head *uf); +extern int userfaultfd_unmap_prep(struct mm_struct *mm, unsigned long start, + unsigned long end, struct list_head *uf); extern void userfaultfd_unmap_complete(struct mm_struct *mm, struct list_head *uf); @@ -222,7 +221,7 @@ static inline bool userfaultfd_remove(struct vm_area_struct *vma, return true; } -static inline int userfaultfd_unmap_prep(struct vm_area_struct *vma, +static inline int userfaultfd_unmap_prep(struct mm_struct *mm, unsigned long start, unsigned long end, struct list_head *uf) { diff --git a/mm/mmap.c b/mm/mmap.c index 79b8494d83c6..dde74e0b195d 100644 --- a/mm/mmap.c +++ b/mm/mmap.c @@ -2449,7 +2449,7 @@ do_mas_align_munmap(struct ma_state *mas, struct vm_area_struct *vma, * split, despite we could. This is unlikely enough * failure that it's not worth optimizing it for. 
*/ - int error = userfaultfd_unmap_prep(vma, start, end, uf); + int error = userfaultfd_unmap_prep(mm, start, end, uf); if (error) return error; @@ -2938,10 +2938,7 @@ static int do_brk_munmap(struct ma_state *mas, struct vm_area_struct *vma, goto munmap_full_vma; } - vma_init(&unmap, mm); - unmap.vm_start = newbrk; - unmap.vm_end = oldbrk; - ret = userfaultfd_unmap_prep(&unmap, newbrk, oldbrk, uf); + ret = userfaultfd_unmap_prep(mm, newbrk, oldbrk, uf); if (ret) return ret; ret = 1; @@ -2954,6 +2951,9 @@ static int do_brk_munmap(struct ma_state *mas, struct vm_area_struct *vma, } vma->vm_end = newbrk; + vma_init(&unmap, mm); + unmap.vm_start = newbrk; + unmap.vm_end = oldbrk; if (vma_mas_remove(&unmap, mas)) goto mas_store_fail; @@ -2963,7 +2963,7 @@ static int do_brk_munmap(struct ma_state *mas, struct vm_area_struct *vma, } unmap_pages = vma_pages(&unmap); - if (unmap.vm_flags & VM_LOCKED) { + if (vma->vm_flags & VM_LOCKED) { mm->locked_vm -= unmap_pages; munlock_vma_pages_range(&unmap, newbrk, oldbrk); }