diff mbox series

[PATCHv5,03/13] mm: Pass down mm_struct to untagged_addr()

Message ID 20220712231328.5294-4-kirill.shutemov@linux.intel.com (mailing list archive)
State New
Headers show
Series Linear Address Masking enabling | expand

Commit Message

Kirill A . Shutemov July 12, 2022, 11:13 p.m. UTC
Intel Linear Address Masking (LAM) brings per-mm untagging rules. Pass
down mm_struct to the untagging helper. It will help to apply untagging
policy correctly.

In most cases, current->mm is the one to use, but there are some
exceptions, such as get_user_page_remote().

Move dummy implementation of untagged_addr() from <linux/mm.h> to
<linux/uaccess.h>. <asm/uaccess.h> can override the implementation.
Moving the dummy header outside <linux/mm.h> helps to avoid header hell
if you need to defer mm_struct within the helper.

Signed-off-by: Kirill A. Shutemov <kirill.shutemov@linux.intel.com>
Reviewed-by: Rick Edgecombe <rick.p.edgecombe@intel.com>
---
 arch/arm64/include/asm/memory.h                  |  4 ++--
 arch/arm64/include/asm/signal.h                  |  2 +-
 arch/arm64/include/asm/uaccess.h                 |  4 ++--
 arch/arm64/kernel/hw_breakpoint.c                |  2 +-
 arch/arm64/kernel/traps.c                        |  4 ++--
 arch/arm64/mm/fault.c                            | 10 +++++-----
 arch/sparc/include/asm/pgtable_64.h              |  2 +-
 arch/sparc/include/asm/uaccess_64.h              |  2 ++
 drivers/gpu/drm/amd/amdgpu/amdgpu_amdkfd_gpuvm.c |  2 +-
 drivers/gpu/drm/amd/amdgpu/amdgpu_gem.c          |  2 +-
 drivers/gpu/drm/radeon/radeon_gem.c              |  2 +-
 drivers/infiniband/hw/mlx4/mr.c                  |  2 +-
 drivers/media/common/videobuf2/frame_vector.c    |  2 +-
 drivers/media/v4l2-core/videobuf-dma-contig.c    |  2 +-
 drivers/staging/media/atomisp/pci/hmm/hmm_bo.c   |  2 +-
 drivers/tee/tee_shm.c                            |  2 +-
 drivers/vfio/vfio_iommu_type1.c                  |  2 +-
 fs/proc/task_mmu.c                               |  2 +-
 include/linux/mm.h                               | 11 -----------
 include/linux/uaccess.h                          | 15 +++++++++++++++
 lib/strncpy_from_user.c                          |  2 +-
 lib/strnlen_user.c                               |  2 +-
 mm/gup.c                                         |  6 +++---
 mm/madvise.c                                     |  2 +-
 mm/mempolicy.c                                   |  6 +++---
 mm/migrate.c                                     |  2 +-
 mm/mincore.c                                     |  2 +-
 mm/mlock.c                                       |  4 ++--
 mm/mmap.c                                        |  2 +-
 mm/mprotect.c                                    |  2 +-
 mm/mremap.c                                      |  2 +-
 mm/msync.c                                       |  2 +-
 virt/kvm/kvm_main.c                              |  2 +-
 33 files changed, 59 insertions(+), 53 deletions(-)

Comments

Alexander Potapenko July 21, 2022, 1:12 p.m. UTC | #1
On Wed, Jul 13, 2022 at 1:13 AM Kirill A. Shutemov
<kirill.shutemov@linux.intel.com> wrote:
>
> Intel Linear Address Masking (LAM) brings per-mm untagging rules. Pass
> down mm_struct to the untagging helper. It will help to apply untagging
> policy correctly.
>
> In most cases, current->mm is the one to use, but there are some
> exceptions, such as get_user_page_remote().
>
> Move dummy implementation of untagged_addr() from <linux/mm.h> to
> <linux/uaccess.h>. <asm/uaccess.h> can override the implementation.
> Moving the dummy header outside <linux/mm.h> helps to avoid header hell
> if you need to defer mm_struct within the helper.
>
> Signed-off-by: Kirill A. Shutemov <kirill.shutemov@linux.intel.com>
> Reviewed-by: Rick Edgecombe <rick.p.edgecombe@intel.com>

Reviewed-by: Alexander Potapenko <glider@google.com>
Tested-by: Alexander Potapenko <glider@google.com>


> ---
>  arch/arm64/include/asm/memory.h                  |  4 ++--
>  arch/arm64/include/asm/signal.h                  |  2 +-
>  arch/arm64/include/asm/uaccess.h                 |  4 ++--
>  arch/arm64/kernel/hw_breakpoint.c                |  2 +-
>  arch/arm64/kernel/traps.c                        |  4 ++--
>  arch/arm64/mm/fault.c                            | 10 +++++-----
>  arch/sparc/include/asm/pgtable_64.h              |  2 +-
>  arch/sparc/include/asm/uaccess_64.h              |  2 ++
>  drivers/gpu/drm/amd/amdgpu/amdgpu_amdkfd_gpuvm.c |  2 +-
>  drivers/gpu/drm/amd/amdgpu/amdgpu_gem.c          |  2 +-
>  drivers/gpu/drm/radeon/radeon_gem.c              |  2 +-
>  drivers/infiniband/hw/mlx4/mr.c                  |  2 +-
>  drivers/media/common/videobuf2/frame_vector.c    |  2 +-
>  drivers/media/v4l2-core/videobuf-dma-contig.c    |  2 +-
>  drivers/staging/media/atomisp/pci/hmm/hmm_bo.c   |  2 +-
>  drivers/tee/tee_shm.c                            |  2 +-
>  drivers/vfio/vfio_iommu_type1.c                  |  2 +-
>  fs/proc/task_mmu.c                               |  2 +-
>  include/linux/mm.h                               | 11 -----------
>  include/linux/uaccess.h                          | 15 +++++++++++++++
>  lib/strncpy_from_user.c                          |  2 +-
>  lib/strnlen_user.c                               |  2 +-
>  mm/gup.c                                         |  6 +++---
>  mm/madvise.c                                     |  2 +-
>  mm/mempolicy.c                                   |  6 +++---
>  mm/migrate.c                                     |  2 +-
>  mm/mincore.c                                     |  2 +-
>  mm/mlock.c                                       |  4 ++--
>  mm/mmap.c                                        |  2 +-
>  mm/mprotect.c                                    |  2 +-
>  mm/mremap.c                                      |  2 +-
>  mm/msync.c                                       |  2 +-
>  virt/kvm/kvm_main.c                              |  2 +-
>  33 files changed, 59 insertions(+), 53 deletions(-)
>
> diff --git a/arch/arm64/include/asm/memory.h b/arch/arm64/include/asm/memory.h
> index 0af70d9abede..88bee513b74c 100644
> --- a/arch/arm64/include/asm/memory.h
> +++ b/arch/arm64/include/asm/memory.h
> @@ -215,8 +215,8 @@ static inline unsigned long kaslr_offset(void)
>  #define __untagged_addr(addr)  \
>         ((__force __typeof__(addr))sign_extend64((__force u64)(addr), 55))
>
> -#define untagged_addr(addr)    ({                                      \
> -       u64 __addr = (__force u64)(addr);                                       \
> +#define untagged_addr(mm, addr)        ({                                      \
> +       u64 __addr = (__force u64)(addr);                               \
>         __addr &= __untagged_addr(__addr);                              \
>         (__force __typeof__(addr))__addr;                               \
>  })
> diff --git a/arch/arm64/include/asm/signal.h b/arch/arm64/include/asm/signal.h
> index ef449f5f4ba8..0899c355c398 100644
> --- a/arch/arm64/include/asm/signal.h
> +++ b/arch/arm64/include/asm/signal.h
> @@ -18,7 +18,7 @@ static inline void __user *arch_untagged_si_addr(void __user *addr,
>         if (sig == SIGTRAP && si_code == TRAP_BRKPT)
>                 return addr;
>
> -       return untagged_addr(addr);
> +       return untagged_addr(current->mm, addr);
>  }
>  #define arch_untagged_si_addr arch_untagged_si_addr
>
> diff --git a/arch/arm64/include/asm/uaccess.h b/arch/arm64/include/asm/uaccess.h
> index 63f9c828f1a7..bdcc014bd297 100644
> --- a/arch/arm64/include/asm/uaccess.h
> +++ b/arch/arm64/include/asm/uaccess.h
> @@ -44,7 +44,7 @@ static inline int access_ok(const void __user *addr, unsigned long size)
>          */
>         if (IS_ENABLED(CONFIG_ARM64_TAGGED_ADDR_ABI) &&
>             (current->flags & PF_KTHREAD || test_thread_flag(TIF_TAGGED_ADDR)))
> -               addr = untagged_addr(addr);
> +               addr = untagged_addr(current->mm, addr);
>
>         return likely(__access_ok(addr, size));
>  }
> @@ -217,7 +217,7 @@ static inline void __user *__uaccess_mask_ptr(const void __user *ptr)
>         "       csel    %0, %1, xzr, eq\n"
>         : "=&r" (safe_ptr)
>         : "r" (ptr), "r" (TASK_SIZE_MAX - 1),
> -         "r" (untagged_addr(ptr))
> +         "r" (untagged_addr(current->mm, ptr))
>         : "cc");
>
>         csdb();
> diff --git a/arch/arm64/kernel/hw_breakpoint.c b/arch/arm64/kernel/hw_breakpoint.c
> index b29a311bb055..d637cee7b771 100644
> --- a/arch/arm64/kernel/hw_breakpoint.c
> +++ b/arch/arm64/kernel/hw_breakpoint.c
> @@ -715,7 +715,7 @@ static u64 get_distance_from_watchpoint(unsigned long addr, u64 val,
>         u64 wp_low, wp_high;
>         u32 lens, lene;
>
> -       addr = untagged_addr(addr);
> +       addr = untagged_addr(current->mm, addr);
>
>         lens = __ffs(ctrl->len);
>         lene = __fls(ctrl->len);
> diff --git a/arch/arm64/kernel/traps.c b/arch/arm64/kernel/traps.c
> index 9ac7a81b79be..385612d9890b 100644
> --- a/arch/arm64/kernel/traps.c
> +++ b/arch/arm64/kernel/traps.c
> @@ -476,7 +476,7 @@ void arm64_notify_segfault(unsigned long addr)
>         int code;
>
>         mmap_read_lock(current->mm);
> -       if (find_vma(current->mm, untagged_addr(addr)) == NULL)
> +       if (find_vma(current->mm, untagged_addr(current->mm, addr)) == NULL)
>                 code = SEGV_MAPERR;
>         else
>                 code = SEGV_ACCERR;
> @@ -540,7 +540,7 @@ static void user_cache_maint_handler(unsigned long esr, struct pt_regs *regs)
>         int ret = 0;
>
>         tagged_address = pt_regs_read_reg(regs, rt);
> -       address = untagged_addr(tagged_address);
> +       address = untagged_addr(current->mm, tagged_address);
>
>         switch (crm) {
>         case ESR_ELx_SYS64_ISS_CRM_DC_CVAU:     /* DC CVAU, gets promoted */
> diff --git a/arch/arm64/mm/fault.c b/arch/arm64/mm/fault.c
> index c5e11768e5c1..9577d7e37f36 100644
> --- a/arch/arm64/mm/fault.c
> +++ b/arch/arm64/mm/fault.c
> @@ -454,7 +454,7 @@ static void set_thread_esr(unsigned long address, unsigned long esr)
>  static void do_bad_area(unsigned long far, unsigned long esr,
>                         struct pt_regs *regs)
>  {
> -       unsigned long addr = untagged_addr(far);
> +       unsigned long addr = untagged_addr(current->mm, far);
>
>         /*
>          * If we are in kernel mode at this point, we have no context to
> @@ -524,7 +524,7 @@ static int __kprobes do_page_fault(unsigned long far, unsigned long esr,
>         vm_fault_t fault;
>         unsigned long vm_flags;
>         unsigned int mm_flags = FAULT_FLAG_DEFAULT;
> -       unsigned long addr = untagged_addr(far);
> +       unsigned long addr = untagged_addr(mm, far);
>
>         if (kprobe_page_fault(regs, esr))
>                 return 0;
> @@ -675,7 +675,7 @@ static int __kprobes do_translation_fault(unsigned long far,
>                                           unsigned long esr,
>                                           struct pt_regs *regs)
>  {
> -       unsigned long addr = untagged_addr(far);
> +       unsigned long addr = untagged_addr(current->mm, far);
>
>         if (is_ttbr0_addr(addr))
>                 return do_page_fault(far, esr, regs);
> @@ -719,7 +719,7 @@ static int do_sea(unsigned long far, unsigned long esr, struct pt_regs *regs)
>                  * UNKNOWN for synchronous external aborts. Mask them out now
>                  * so that userspace doesn't see them.
>                  */
> -               siaddr  = untagged_addr(far);
> +               siaddr  = untagged_addr(current->mm, far);
>         }
>         arm64_notify_die(inf->name, regs, inf->sig, inf->code, siaddr, esr);
>
> @@ -809,7 +809,7 @@ static const struct fault_info fault_info[] = {
>  void do_mem_abort(unsigned long far, unsigned long esr, struct pt_regs *regs)
>  {
>         const struct fault_info *inf = esr_to_fault_info(esr);
> -       unsigned long addr = untagged_addr(far);
> +       unsigned long addr = untagged_addr(current->mm, far);
>
>         if (!inf->fn(far, esr, regs))
>                 return;
> diff --git a/arch/sparc/include/asm/pgtable_64.h b/arch/sparc/include/asm/pgtable_64.h
> index 4679e45c8348..1336d7bfaab9 100644
> --- a/arch/sparc/include/asm/pgtable_64.h
> +++ b/arch/sparc/include/asm/pgtable_64.h
> @@ -1071,7 +1071,7 @@ static inline unsigned long __untagged_addr(unsigned long start)
>
>         return start;
>  }
> -#define untagged_addr(addr) \
> +#define untagged_addr(mm, addr) \
>         ((__typeof__(addr))(__untagged_addr((unsigned long)(addr))))
>
>  static inline bool pte_access_permitted(pte_t pte, bool write)
> diff --git a/arch/sparc/include/asm/uaccess_64.h b/arch/sparc/include/asm/uaccess_64.h
> index 94266a5c5b04..b825a5dd0210 100644
> --- a/arch/sparc/include/asm/uaccess_64.h
> +++ b/arch/sparc/include/asm/uaccess_64.h
> @@ -8,8 +8,10 @@
>
>  #include <linux/compiler.h>
>  #include <linux/string.h>
> +#include <linux/mm_types.h>
>  #include <asm/asi.h>
>  #include <asm/spitfire.h>
> +#include <asm/pgtable.h>
>
>  #include <asm/processor.h>
>  #include <asm-generic/access_ok.h>
> diff --git a/drivers/gpu/drm/amd/amdgpu/amdgpu_amdkfd_gpuvm.c b/drivers/gpu/drm/amd/amdgpu/amdgpu_amdkfd_gpuvm.c
> index 6b6d46e29e6e..b37199b16643 100644
> --- a/drivers/gpu/drm/amd/amdgpu/amdgpu_amdkfd_gpuvm.c
> +++ b/drivers/gpu/drm/amd/amdgpu/amdgpu_amdkfd_gpuvm.c
> @@ -1491,7 +1491,7 @@ int amdgpu_amdkfd_gpuvm_alloc_memory_of_gpu(
>                 if (flags & KFD_IOC_ALLOC_MEM_FLAGS_USERPTR) {
>                         if (!offset || !*offset)
>                                 return -EINVAL;
> -                       user_addr = untagged_addr(*offset);
> +                       user_addr = untagged_addr(current->mm, *offset);
>                 } else if (flags & (KFD_IOC_ALLOC_MEM_FLAGS_DOORBELL |
>                                     KFD_IOC_ALLOC_MEM_FLAGS_MMIO_REMAP)) {
>                         bo_type = ttm_bo_type_sg;
> diff --git a/drivers/gpu/drm/amd/amdgpu/amdgpu_gem.c b/drivers/gpu/drm/amd/amdgpu/amdgpu_gem.c
> index 8ef31d687ef3..691dfb3f2c0e 100644
> --- a/drivers/gpu/drm/amd/amdgpu/amdgpu_gem.c
> +++ b/drivers/gpu/drm/amd/amdgpu/amdgpu_gem.c
> @@ -382,7 +382,7 @@ int amdgpu_gem_userptr_ioctl(struct drm_device *dev, void *data,
>         uint32_t handle;
>         int r;
>
> -       args->addr = untagged_addr(args->addr);
> +       args->addr = untagged_addr(current->mm, args->addr);
>
>         if (offset_in_page(args->addr | args->size))
>                 return -EINVAL;
> diff --git a/drivers/gpu/drm/radeon/radeon_gem.c b/drivers/gpu/drm/radeon/radeon_gem.c
> index 8c01a7f0e027..2c3980677f64 100644
> --- a/drivers/gpu/drm/radeon/radeon_gem.c
> +++ b/drivers/gpu/drm/radeon/radeon_gem.c
> @@ -371,7 +371,7 @@ int radeon_gem_userptr_ioctl(struct drm_device *dev, void *data,
>         uint32_t handle;
>         int r;
>
> -       args->addr = untagged_addr(args->addr);
> +       args->addr = untagged_addr(current->mm, args->addr);
>
>         if (offset_in_page(args->addr | args->size))
>                 return -EINVAL;
> diff --git a/drivers/infiniband/hw/mlx4/mr.c b/drivers/infiniband/hw/mlx4/mr.c
> index 04a67b481608..b2860feeae3c 100644
> --- a/drivers/infiniband/hw/mlx4/mr.c
> +++ b/drivers/infiniband/hw/mlx4/mr.c
> @@ -379,7 +379,7 @@ static struct ib_umem *mlx4_get_umem_mr(struct ib_device *device, u64 start,
>          * again
>          */
>         if (!ib_access_writable(access_flags)) {
> -               unsigned long untagged_start = untagged_addr(start);
> +               unsigned long untagged_start = untagged_addr(current->mm, start);
>                 struct vm_area_struct *vma;
>
>                 mmap_read_lock(current->mm);
> diff --git a/drivers/media/common/videobuf2/frame_vector.c b/drivers/media/common/videobuf2/frame_vector.c
> index 542dde9d2609..7e62f7a2555d 100644
> --- a/drivers/media/common/videobuf2/frame_vector.c
> +++ b/drivers/media/common/videobuf2/frame_vector.c
> @@ -47,7 +47,7 @@ int get_vaddr_frames(unsigned long start, unsigned int nr_frames,
>         if (WARN_ON_ONCE(nr_frames > vec->nr_allocated))
>                 nr_frames = vec->nr_allocated;
>
> -       start = untagged_addr(start);
> +       start = untagged_addr(mm, start);
>
>         ret = pin_user_pages_fast(start, nr_frames,
>                                   FOLL_FORCE | FOLL_WRITE | FOLL_LONGTERM,
> diff --git a/drivers/media/v4l2-core/videobuf-dma-contig.c b/drivers/media/v4l2-core/videobuf-dma-contig.c
> index 52312ce2ba05..a1444f8afa05 100644
> --- a/drivers/media/v4l2-core/videobuf-dma-contig.c
> +++ b/drivers/media/v4l2-core/videobuf-dma-contig.c
> @@ -157,8 +157,8 @@ static void videobuf_dma_contig_user_put(struct videobuf_dma_contig_memory *mem)
>  static int videobuf_dma_contig_user_get(struct videobuf_dma_contig_memory *mem,
>                                         struct videobuf_buffer *vb)
>  {
> -       unsigned long untagged_baddr = untagged_addr(vb->baddr);
>         struct mm_struct *mm = current->mm;
> +       unsigned long untagged_baddr = untagged_addr(mm, vb->baddr);
>         struct vm_area_struct *vma;
>         unsigned long prev_pfn, this_pfn;
>         unsigned long pages_done, user_address;
> diff --git a/drivers/staging/media/atomisp/pci/hmm/hmm_bo.c b/drivers/staging/media/atomisp/pci/hmm/hmm_bo.c
> index 0168f9839c90..863d30a7ad23 100644
> --- a/drivers/staging/media/atomisp/pci/hmm/hmm_bo.c
> +++ b/drivers/staging/media/atomisp/pci/hmm/hmm_bo.c
> @@ -913,7 +913,7 @@ static int alloc_user_pages(struct hmm_buffer_object *bo,
>          * and map to user space
>          */
>
> -       userptr = untagged_addr(userptr);
> +       userptr = untagged_addr(current->mm, userptr);
>
>         bo->pages = pages;
>
> diff --git a/drivers/tee/tee_shm.c b/drivers/tee/tee_shm.c
> index f2b1bcefcadd..386be09cb2cd 100644
> --- a/drivers/tee/tee_shm.c
> +++ b/drivers/tee/tee_shm.c
> @@ -261,7 +261,7 @@ register_shm_helper(struct tee_context *ctx, unsigned long addr,
>         shm->flags = flags;
>         shm->ctx = ctx;
>         shm->id = id;
> -       addr = untagged_addr(addr);
> +       addr = untagged_addr(current->mm, addr);
>         start = rounddown(addr, PAGE_SIZE);
>         shm->offset = addr - start;
>         shm->size = length;
> diff --git a/drivers/vfio/vfio_iommu_type1.c b/drivers/vfio/vfio_iommu_type1.c
> index c13b9290e357..5ac6c61d7caa 100644
> --- a/drivers/vfio/vfio_iommu_type1.c
> +++ b/drivers/vfio/vfio_iommu_type1.c
> @@ -561,7 +561,7 @@ static int vaddr_get_pfns(struct mm_struct *mm, unsigned long vaddr,
>                 goto done;
>         }
>
> -       vaddr = untagged_addr(vaddr);
> +       vaddr = untagged_addr(mm, vaddr);
>
>  retry:
>         vma = vma_lookup(mm, vaddr);
> diff --git a/fs/proc/task_mmu.c b/fs/proc/task_mmu.c
> index 2d04e3470d4c..c7d262bd6d6b 100644
> --- a/fs/proc/task_mmu.c
> +++ b/fs/proc/task_mmu.c
> @@ -1659,7 +1659,7 @@ static ssize_t pagemap_read(struct file *file, char __user *buf,
>         /* watch out for wraparound */
>         start_vaddr = end_vaddr;
>         if (svpfn <= (ULONG_MAX >> PAGE_SHIFT))
> -               start_vaddr = untagged_addr(svpfn << PAGE_SHIFT);
> +               start_vaddr = untagged_addr(mm, svpfn << PAGE_SHIFT);
>
>         /* Ensure the address is inside the task */
>         if (start_vaddr > mm->task_size)
> diff --git a/include/linux/mm.h b/include/linux/mm.h
> index bc8f326be0ce..f0cb92ff1391 100644
> --- a/include/linux/mm.h
> +++ b/include/linux/mm.h
> @@ -94,17 +94,6 @@ extern int mmap_rnd_compat_bits __read_mostly;
>  #include <asm/page.h>
>  #include <asm/processor.h>
>
> -/*
> - * Architectures that support memory tagging (assigning tags to memory regions,
> - * embedding these tags into addresses that point to these memory regions, and
> - * checking that the memory and the pointer tags match on memory accesses)
> - * redefine this macro to strip tags from pointers.
> - * It's defined as noop for architectures that don't support memory tagging.
> - */
> -#ifndef untagged_addr
> -#define untagged_addr(addr) (addr)
> -#endif
> -
>  #ifndef __pa_symbol
>  #define __pa_symbol(x)  __pa(RELOC_HIDE((unsigned long)(x), 0))
>  #endif
> diff --git a/include/linux/uaccess.h b/include/linux/uaccess.h
> index 5a328cf02b75..46fd816179d7 100644
> --- a/include/linux/uaccess.h
> +++ b/include/linux/uaccess.h
> @@ -10,6 +10,21 @@
>
>  #include <asm/uaccess.h>
>
> +/*
> + * Architectures that support memory tagging (assigning tags to memory regions,
> + * embedding these tags into addresses that point to these memory regions, and
> + * checking that the memory and the pointer tags match on memory accesses)
> + * redefine this macro to strip tags from pointers.
> + *
> + * Passing down mm_struct allows to define untagging rules on per-process
> + * basis.
> + *
> + * It's defined as noop for architectures that don't support memory tagging.
> + */
> +#ifndef untagged_addr
> +#define untagged_addr(mm, addr) (addr)
> +#endif
> +
>  /*
>   * Architectures should provide two primitives (raw_copy_{to,from}_user())
>   * and get rid of their private instances of copy_{to,from}_user() and
> diff --git a/lib/strncpy_from_user.c b/lib/strncpy_from_user.c
> index 6432b8c3e431..6e1e2aa0c994 100644
> --- a/lib/strncpy_from_user.c
> +++ b/lib/strncpy_from_user.c
> @@ -121,7 +121,7 @@ long strncpy_from_user(char *dst, const char __user *src, long count)
>                 return 0;
>
>         max_addr = TASK_SIZE_MAX;
> -       src_addr = (unsigned long)untagged_addr(src);
> +       src_addr = (unsigned long)untagged_addr(current->mm, src);
>         if (likely(src_addr < max_addr)) {
>                 unsigned long max = max_addr - src_addr;
>                 long retval;
> diff --git a/lib/strnlen_user.c b/lib/strnlen_user.c
> index feeb935a2299..abc096a68f05 100644
> --- a/lib/strnlen_user.c
> +++ b/lib/strnlen_user.c
> @@ -97,7 +97,7 @@ long strnlen_user(const char __user *str, long count)
>                 return 0;
>
>         max_addr = TASK_SIZE_MAX;
> -       src_addr = (unsigned long)untagged_addr(str);
> +       src_addr = (unsigned long)untagged_addr(current->mm, str);
>         if (likely(src_addr < max_addr)) {
>                 unsigned long max = max_addr - src_addr;
>                 long retval;
> diff --git a/mm/gup.c b/mm/gup.c
> index 551264407624..dbe825faf842 100644
> --- a/mm/gup.c
> +++ b/mm/gup.c
> @@ -1104,7 +1104,7 @@ static long __get_user_pages(struct mm_struct *mm,
>         if (!nr_pages)
>                 return 0;
>
> -       start = untagged_addr(start);
> +       start = untagged_addr(mm, start);
>
>         VM_BUG_ON(!!pages != !!(gup_flags & (FOLL_GET | FOLL_PIN)));
>
> @@ -1285,7 +1285,7 @@ int fixup_user_fault(struct mm_struct *mm,
>         struct vm_area_struct *vma;
>         vm_fault_t ret;
>
> -       address = untagged_addr(address);
> +       address = untagged_addr(mm, address);
>
>         if (unlocked)
>                 fault_flags |= FAULT_FLAG_ALLOW_RETRY | FAULT_FLAG_KILLABLE;
> @@ -2865,7 +2865,7 @@ static int internal_get_user_pages_fast(unsigned long start,
>         if (!(gup_flags & FOLL_FAST_ONLY))
>                 might_lock_read(&current->mm->mmap_lock);
>
> -       start = untagged_addr(start) & PAGE_MASK;
> +       start = untagged_addr(current->mm, start) & PAGE_MASK;
>         len = nr_pages << PAGE_SHIFT;
>         if (check_add_overflow(start, len, &end))
>                 return 0;
> diff --git a/mm/madvise.c b/mm/madvise.c
> index d7b4f2602949..e3c668ddb099 100644
> --- a/mm/madvise.c
> +++ b/mm/madvise.c
> @@ -1373,7 +1373,7 @@ int do_madvise(struct mm_struct *mm, unsigned long start, size_t len_in, int beh
>         size_t len;
>         struct blk_plug plug;
>
> -       start = untagged_addr(start);
> +       start = untagged_addr(mm, start);
>
>         if (!madvise_behavior_valid(behavior))
>                 return -EINVAL;
> diff --git a/mm/mempolicy.c b/mm/mempolicy.c
> index d39b01fd52fe..a03b4d2bc26a 100644
> --- a/mm/mempolicy.c
> +++ b/mm/mempolicy.c
> @@ -1458,7 +1458,7 @@ static long kernel_mbind(unsigned long start, unsigned long len,
>         int lmode = mode;
>         int err;
>
> -       start = untagged_addr(start);
> +       start = untagged_addr(current->mm, start);
>         err = sanitize_mpol_flags(&lmode, &mode_flags);
>         if (err)
>                 return err;
> @@ -1481,7 +1481,7 @@ SYSCALL_DEFINE4(set_mempolicy_home_node, unsigned long, start, unsigned long, le
>         unsigned long end;
>         int err = -ENOENT;
>
> -       start = untagged_addr(start);
> +       start = untagged_addr(mm, start);
>         if (start & ~PAGE_MASK)
>                 return -EINVAL;
>         /*
> @@ -1684,7 +1684,7 @@ static int kernel_get_mempolicy(int __user *policy,
>         if (nmask != NULL && maxnode < nr_node_ids)
>                 return -EINVAL;
>
> -       addr = untagged_addr(addr);
> +       addr = untagged_addr(current->mm, addr);
>
>         err = do_get_mempolicy(&pval, &nodes, addr, flags);
>
> diff --git a/mm/migrate.c b/mm/migrate.c
> index e51588e95f57..af05049b055b 100644
> --- a/mm/migrate.c
> +++ b/mm/migrate.c
> @@ -1714,7 +1714,7 @@ static int do_pages_move(struct mm_struct *mm, nodemask_t task_nodes,
>                         goto out_flush;
>                 if (get_user(node, nodes + i))
>                         goto out_flush;
> -               addr = (unsigned long)untagged_addr(p);
> +               addr = (unsigned long)untagged_addr(mm, p);
>
>                 err = -ENODEV;
>                 if (node < 0 || node >= MAX_NUMNODES)
> diff --git a/mm/mincore.c b/mm/mincore.c
> index fa200c14185f..72c55bd9d184 100644
> --- a/mm/mincore.c
> +++ b/mm/mincore.c
> @@ -236,7 +236,7 @@ SYSCALL_DEFINE3(mincore, unsigned long, start, size_t, len,
>         unsigned long pages;
>         unsigned char *tmp;
>
> -       start = untagged_addr(start);
> +       start = untagged_addr(current->mm, start);
>
>         /* Check the start address: needs to be page-aligned.. */
>         if (start & ~PAGE_MASK)
> diff --git a/mm/mlock.c b/mm/mlock.c
> index 716caf851043..054168d3e648 100644
> --- a/mm/mlock.c
> +++ b/mm/mlock.c
> @@ -571,7 +571,7 @@ static __must_check int do_mlock(unsigned long start, size_t len, vm_flags_t fla
>         unsigned long lock_limit;
>         int error = -ENOMEM;
>
> -       start = untagged_addr(start);
> +       start = untagged_addr(current->mm, start);
>
>         if (!can_do_mlock())
>                 return -EPERM;
> @@ -634,7 +634,7 @@ SYSCALL_DEFINE2(munlock, unsigned long, start, size_t, len)
>  {
>         int ret;
>
> -       start = untagged_addr(start);
> +       start = untagged_addr(current->mm, start);
>
>         len = PAGE_ALIGN(len + (offset_in_page(start)));
>         start &= PAGE_MASK;
> diff --git a/mm/mmap.c b/mm/mmap.c
> index 61e6135c54ef..1a7baf6b6b8e 100644
> --- a/mm/mmap.c
> +++ b/mm/mmap.c
> @@ -2926,7 +2926,7 @@ EXPORT_SYMBOL(vm_munmap);
>
>  SYSCALL_DEFINE2(munmap, unsigned long, addr, size_t, len)
>  {
> -       addr = untagged_addr(addr);
> +       addr = untagged_addr(current->mm, addr);
>         return __vm_munmap(addr, len, true);
>  }
>
> diff --git a/mm/mprotect.c b/mm/mprotect.c
> index ba5592655ee3..871e954f6155 100644
> --- a/mm/mprotect.c
> +++ b/mm/mprotect.c
> @@ -622,7 +622,7 @@ static int do_mprotect_pkey(unsigned long start, size_t len,
>                                 (prot & PROT_READ);
>         struct mmu_gather tlb;
>
> -       start = untagged_addr(start);
> +       start = untagged_addr(current->mm, start);
>
>         prot &= ~(PROT_GROWSDOWN|PROT_GROWSUP);
>         if (grows == (PROT_GROWSDOWN|PROT_GROWSUP)) /* can't be both */
> diff --git a/mm/mremap.c b/mm/mremap.c
> index b522cd0259a0..f76648bc4f67 100644
> --- a/mm/mremap.c
> +++ b/mm/mremap.c
> @@ -906,7 +906,7 @@ SYSCALL_DEFINE5(mremap, unsigned long, addr, unsigned long, old_len,
>          *
>          * See Documentation/arm64/tagged-address-abi.rst for more information.
>          */
> -       addr = untagged_addr(addr);
> +       addr = untagged_addr(mm, addr);
>
>         if (flags & ~(MREMAP_FIXED | MREMAP_MAYMOVE | MREMAP_DONTUNMAP))
>                 return ret;
> diff --git a/mm/msync.c b/mm/msync.c
> index 137d1c104f3e..5fe989bd3c4b 100644
> --- a/mm/msync.c
> +++ b/mm/msync.c
> @@ -37,7 +37,7 @@ SYSCALL_DEFINE3(msync, unsigned long, start, size_t, len, int, flags)
>         int unmapped_error = 0;
>         int error = -EINVAL;
>
> -       start = untagged_addr(start);
> +       start = untagged_addr(mm, start);
>
>         if (flags & ~(MS_ASYNC | MS_INVALIDATE | MS_SYNC))
>                 goto out;
> diff --git a/virt/kvm/kvm_main.c b/virt/kvm/kvm_main.c
> index a49df8988cd6..03f7ad0ebc8a 100644
> --- a/virt/kvm/kvm_main.c
> +++ b/virt/kvm/kvm_main.c
> @@ -1876,7 +1876,7 @@ int __kvm_set_memory_region(struct kvm *kvm,
>                 return -EINVAL;
>         /* We can read the guest memory with __xxx_user() later on. */
>         if ((mem->userspace_addr & (PAGE_SIZE - 1)) ||
> -           (mem->userspace_addr != untagged_addr(mem->userspace_addr)) ||
> +           (mem->userspace_addr != untagged_addr(kvm->mm, mem->userspace_addr)) ||
>              !access_ok((void __user *)(unsigned long)mem->userspace_addr,
>                         mem->memory_size))
>                 return -EINVAL;
> --
> 2.35.1
>
diff mbox series

Patch

diff --git a/arch/arm64/include/asm/memory.h b/arch/arm64/include/asm/memory.h
index 0af70d9abede..88bee513b74c 100644
--- a/arch/arm64/include/asm/memory.h
+++ b/arch/arm64/include/asm/memory.h
@@ -215,8 +215,8 @@  static inline unsigned long kaslr_offset(void)
 #define __untagged_addr(addr)	\
 	((__force __typeof__(addr))sign_extend64((__force u64)(addr), 55))
 
-#define untagged_addr(addr)	({					\
-	u64 __addr = (__force u64)(addr);					\
+#define untagged_addr(mm, addr)	({					\
+	u64 __addr = (__force u64)(addr);				\
 	__addr &= __untagged_addr(__addr);				\
 	(__force __typeof__(addr))__addr;				\
 })
diff --git a/arch/arm64/include/asm/signal.h b/arch/arm64/include/asm/signal.h
index ef449f5f4ba8..0899c355c398 100644
--- a/arch/arm64/include/asm/signal.h
+++ b/arch/arm64/include/asm/signal.h
@@ -18,7 +18,7 @@  static inline void __user *arch_untagged_si_addr(void __user *addr,
 	if (sig == SIGTRAP && si_code == TRAP_BRKPT)
 		return addr;
 
-	return untagged_addr(addr);
+	return untagged_addr(current->mm, addr);
 }
 #define arch_untagged_si_addr arch_untagged_si_addr
 
diff --git a/arch/arm64/include/asm/uaccess.h b/arch/arm64/include/asm/uaccess.h
index 63f9c828f1a7..bdcc014bd297 100644
--- a/arch/arm64/include/asm/uaccess.h
+++ b/arch/arm64/include/asm/uaccess.h
@@ -44,7 +44,7 @@  static inline int access_ok(const void __user *addr, unsigned long size)
 	 */
 	if (IS_ENABLED(CONFIG_ARM64_TAGGED_ADDR_ABI) &&
 	    (current->flags & PF_KTHREAD || test_thread_flag(TIF_TAGGED_ADDR)))
-		addr = untagged_addr(addr);
+		addr = untagged_addr(current->mm, addr);
 
 	return likely(__access_ok(addr, size));
 }
@@ -217,7 +217,7 @@  static inline void __user *__uaccess_mask_ptr(const void __user *ptr)
 	"	csel	%0, %1, xzr, eq\n"
 	: "=&r" (safe_ptr)
 	: "r" (ptr), "r" (TASK_SIZE_MAX - 1),
-	  "r" (untagged_addr(ptr))
+	  "r" (untagged_addr(current->mm, ptr))
 	: "cc");
 
 	csdb();
diff --git a/arch/arm64/kernel/hw_breakpoint.c b/arch/arm64/kernel/hw_breakpoint.c
index b29a311bb055..d637cee7b771 100644
--- a/arch/arm64/kernel/hw_breakpoint.c
+++ b/arch/arm64/kernel/hw_breakpoint.c
@@ -715,7 +715,7 @@  static u64 get_distance_from_watchpoint(unsigned long addr, u64 val,
 	u64 wp_low, wp_high;
 	u32 lens, lene;
 
-	addr = untagged_addr(addr);
+	addr = untagged_addr(current->mm, addr);
 
 	lens = __ffs(ctrl->len);
 	lene = __fls(ctrl->len);
diff --git a/arch/arm64/kernel/traps.c b/arch/arm64/kernel/traps.c
index 9ac7a81b79be..385612d9890b 100644
--- a/arch/arm64/kernel/traps.c
+++ b/arch/arm64/kernel/traps.c
@@ -476,7 +476,7 @@  void arm64_notify_segfault(unsigned long addr)
 	int code;
 
 	mmap_read_lock(current->mm);
-	if (find_vma(current->mm, untagged_addr(addr)) == NULL)
+	if (find_vma(current->mm, untagged_addr(current->mm, addr)) == NULL)
 		code = SEGV_MAPERR;
 	else
 		code = SEGV_ACCERR;
@@ -540,7 +540,7 @@  static void user_cache_maint_handler(unsigned long esr, struct pt_regs *regs)
 	int ret = 0;
 
 	tagged_address = pt_regs_read_reg(regs, rt);
-	address = untagged_addr(tagged_address);
+	address = untagged_addr(current->mm, tagged_address);
 
 	switch (crm) {
 	case ESR_ELx_SYS64_ISS_CRM_DC_CVAU:	/* DC CVAU, gets promoted */
diff --git a/arch/arm64/mm/fault.c b/arch/arm64/mm/fault.c
index c5e11768e5c1..9577d7e37f36 100644
--- a/arch/arm64/mm/fault.c
+++ b/arch/arm64/mm/fault.c
@@ -454,7 +454,7 @@  static void set_thread_esr(unsigned long address, unsigned long esr)
 static void do_bad_area(unsigned long far, unsigned long esr,
 			struct pt_regs *regs)
 {
-	unsigned long addr = untagged_addr(far);
+	unsigned long addr = untagged_addr(current->mm, far);
 
 	/*
 	 * If we are in kernel mode at this point, we have no context to
@@ -524,7 +524,7 @@  static int __kprobes do_page_fault(unsigned long far, unsigned long esr,
 	vm_fault_t fault;
 	unsigned long vm_flags;
 	unsigned int mm_flags = FAULT_FLAG_DEFAULT;
-	unsigned long addr = untagged_addr(far);
+	unsigned long addr = untagged_addr(mm, far);
 
 	if (kprobe_page_fault(regs, esr))
 		return 0;
@@ -675,7 +675,7 @@  static int __kprobes do_translation_fault(unsigned long far,
 					  unsigned long esr,
 					  struct pt_regs *regs)
 {
-	unsigned long addr = untagged_addr(far);
+	unsigned long addr = untagged_addr(current->mm, far);
 
 	if (is_ttbr0_addr(addr))
 		return do_page_fault(far, esr, regs);
@@ -719,7 +719,7 @@  static int do_sea(unsigned long far, unsigned long esr, struct pt_regs *regs)
 		 * UNKNOWN for synchronous external aborts. Mask them out now
 		 * so that userspace doesn't see them.
 		 */
-		siaddr  = untagged_addr(far);
+		siaddr  = untagged_addr(current->mm, far);
 	}
 	arm64_notify_die(inf->name, regs, inf->sig, inf->code, siaddr, esr);
 
@@ -809,7 +809,7 @@  static const struct fault_info fault_info[] = {
 void do_mem_abort(unsigned long far, unsigned long esr, struct pt_regs *regs)
 {
 	const struct fault_info *inf = esr_to_fault_info(esr);
-	unsigned long addr = untagged_addr(far);
+	unsigned long addr = untagged_addr(current->mm, far);
 
 	if (!inf->fn(far, esr, regs))
 		return;
diff --git a/arch/sparc/include/asm/pgtable_64.h b/arch/sparc/include/asm/pgtable_64.h
index 4679e45c8348..1336d7bfaab9 100644
--- a/arch/sparc/include/asm/pgtable_64.h
+++ b/arch/sparc/include/asm/pgtable_64.h
@@ -1071,7 +1071,7 @@  static inline unsigned long __untagged_addr(unsigned long start)
 
 	return start;
 }
-#define untagged_addr(addr) \
+#define untagged_addr(mm, addr) \
 	((__typeof__(addr))(__untagged_addr((unsigned long)(addr))))
 
 static inline bool pte_access_permitted(pte_t pte, bool write)
diff --git a/arch/sparc/include/asm/uaccess_64.h b/arch/sparc/include/asm/uaccess_64.h
index 94266a5c5b04..b825a5dd0210 100644
--- a/arch/sparc/include/asm/uaccess_64.h
+++ b/arch/sparc/include/asm/uaccess_64.h
@@ -8,8 +8,10 @@ 
 
 #include <linux/compiler.h>
 #include <linux/string.h>
+#include <linux/mm_types.h>
 #include <asm/asi.h>
 #include <asm/spitfire.h>
+#include <asm/pgtable.h>
 
 #include <asm/processor.h>
 #include <asm-generic/access_ok.h>
diff --git a/drivers/gpu/drm/amd/amdgpu/amdgpu_amdkfd_gpuvm.c b/drivers/gpu/drm/amd/amdgpu/amdgpu_amdkfd_gpuvm.c
index 6b6d46e29e6e..b37199b16643 100644
--- a/drivers/gpu/drm/amd/amdgpu/amdgpu_amdkfd_gpuvm.c
+++ b/drivers/gpu/drm/amd/amdgpu/amdgpu_amdkfd_gpuvm.c
@@ -1491,7 +1491,7 @@  int amdgpu_amdkfd_gpuvm_alloc_memory_of_gpu(
 		if (flags & KFD_IOC_ALLOC_MEM_FLAGS_USERPTR) {
 			if (!offset || !*offset)
 				return -EINVAL;
-			user_addr = untagged_addr(*offset);
+			user_addr = untagged_addr(current->mm, *offset);
 		} else if (flags & (KFD_IOC_ALLOC_MEM_FLAGS_DOORBELL |
 				    KFD_IOC_ALLOC_MEM_FLAGS_MMIO_REMAP)) {
 			bo_type = ttm_bo_type_sg;
diff --git a/drivers/gpu/drm/amd/amdgpu/amdgpu_gem.c b/drivers/gpu/drm/amd/amdgpu/amdgpu_gem.c
index 8ef31d687ef3..691dfb3f2c0e 100644
--- a/drivers/gpu/drm/amd/amdgpu/amdgpu_gem.c
+++ b/drivers/gpu/drm/amd/amdgpu/amdgpu_gem.c
@@ -382,7 +382,7 @@  int amdgpu_gem_userptr_ioctl(struct drm_device *dev, void *data,
 	uint32_t handle;
 	int r;
 
-	args->addr = untagged_addr(args->addr);
+	args->addr = untagged_addr(current->mm, args->addr);
 
 	if (offset_in_page(args->addr | args->size))
 		return -EINVAL;
diff --git a/drivers/gpu/drm/radeon/radeon_gem.c b/drivers/gpu/drm/radeon/radeon_gem.c
index 8c01a7f0e027..2c3980677f64 100644
--- a/drivers/gpu/drm/radeon/radeon_gem.c
+++ b/drivers/gpu/drm/radeon/radeon_gem.c
@@ -371,7 +371,7 @@  int radeon_gem_userptr_ioctl(struct drm_device *dev, void *data,
 	uint32_t handle;
 	int r;
 
-	args->addr = untagged_addr(args->addr);
+	args->addr = untagged_addr(current->mm, args->addr);
 
 	if (offset_in_page(args->addr | args->size))
 		return -EINVAL;
diff --git a/drivers/infiniband/hw/mlx4/mr.c b/drivers/infiniband/hw/mlx4/mr.c
index 04a67b481608..b2860feeae3c 100644
--- a/drivers/infiniband/hw/mlx4/mr.c
+++ b/drivers/infiniband/hw/mlx4/mr.c
@@ -379,7 +379,7 @@  static struct ib_umem *mlx4_get_umem_mr(struct ib_device *device, u64 start,
 	 * again
 	 */
 	if (!ib_access_writable(access_flags)) {
-		unsigned long untagged_start = untagged_addr(start);
+		unsigned long untagged_start = untagged_addr(current->mm, start);
 		struct vm_area_struct *vma;
 
 		mmap_read_lock(current->mm);
diff --git a/drivers/media/common/videobuf2/frame_vector.c b/drivers/media/common/videobuf2/frame_vector.c
index 542dde9d2609..7e62f7a2555d 100644
--- a/drivers/media/common/videobuf2/frame_vector.c
+++ b/drivers/media/common/videobuf2/frame_vector.c
@@ -47,7 +47,7 @@  int get_vaddr_frames(unsigned long start, unsigned int nr_frames,
 	if (WARN_ON_ONCE(nr_frames > vec->nr_allocated))
 		nr_frames = vec->nr_allocated;
 
-	start = untagged_addr(start);
+	start = untagged_addr(mm, start);
 
 	ret = pin_user_pages_fast(start, nr_frames,
 				  FOLL_FORCE | FOLL_WRITE | FOLL_LONGTERM,
diff --git a/drivers/media/v4l2-core/videobuf-dma-contig.c b/drivers/media/v4l2-core/videobuf-dma-contig.c
index 52312ce2ba05..a1444f8afa05 100644
--- a/drivers/media/v4l2-core/videobuf-dma-contig.c
+++ b/drivers/media/v4l2-core/videobuf-dma-contig.c
@@ -157,8 +157,8 @@  static void videobuf_dma_contig_user_put(struct videobuf_dma_contig_memory *mem)
 static int videobuf_dma_contig_user_get(struct videobuf_dma_contig_memory *mem,
 					struct videobuf_buffer *vb)
 {
-	unsigned long untagged_baddr = untagged_addr(vb->baddr);
 	struct mm_struct *mm = current->mm;
+	unsigned long untagged_baddr = untagged_addr(mm, vb->baddr);
 	struct vm_area_struct *vma;
 	unsigned long prev_pfn, this_pfn;
 	unsigned long pages_done, user_address;
diff --git a/drivers/staging/media/atomisp/pci/hmm/hmm_bo.c b/drivers/staging/media/atomisp/pci/hmm/hmm_bo.c
index 0168f9839c90..863d30a7ad23 100644
--- a/drivers/staging/media/atomisp/pci/hmm/hmm_bo.c
+++ b/drivers/staging/media/atomisp/pci/hmm/hmm_bo.c
@@ -913,7 +913,7 @@  static int alloc_user_pages(struct hmm_buffer_object *bo,
 	 * and map to user space
 	 */
 
-	userptr = untagged_addr(userptr);
+	userptr = untagged_addr(current->mm, userptr);
 
 	bo->pages = pages;
 
diff --git a/drivers/tee/tee_shm.c b/drivers/tee/tee_shm.c
index f2b1bcefcadd..386be09cb2cd 100644
--- a/drivers/tee/tee_shm.c
+++ b/drivers/tee/tee_shm.c
@@ -261,7 +261,7 @@  register_shm_helper(struct tee_context *ctx, unsigned long addr,
 	shm->flags = flags;
 	shm->ctx = ctx;
 	shm->id = id;
-	addr = untagged_addr(addr);
+	addr = untagged_addr(current->mm, addr);
 	start = rounddown(addr, PAGE_SIZE);
 	shm->offset = addr - start;
 	shm->size = length;
diff --git a/drivers/vfio/vfio_iommu_type1.c b/drivers/vfio/vfio_iommu_type1.c
index c13b9290e357..5ac6c61d7caa 100644
--- a/drivers/vfio/vfio_iommu_type1.c
+++ b/drivers/vfio/vfio_iommu_type1.c
@@ -561,7 +561,7 @@  static int vaddr_get_pfns(struct mm_struct *mm, unsigned long vaddr,
 		goto done;
 	}
 
-	vaddr = untagged_addr(vaddr);
+	vaddr = untagged_addr(mm, vaddr);
 
 retry:
 	vma = vma_lookup(mm, vaddr);
diff --git a/fs/proc/task_mmu.c b/fs/proc/task_mmu.c
index 2d04e3470d4c..c7d262bd6d6b 100644
--- a/fs/proc/task_mmu.c
+++ b/fs/proc/task_mmu.c
@@ -1659,7 +1659,7 @@  static ssize_t pagemap_read(struct file *file, char __user *buf,
 	/* watch out for wraparound */
 	start_vaddr = end_vaddr;
 	if (svpfn <= (ULONG_MAX >> PAGE_SHIFT))
-		start_vaddr = untagged_addr(svpfn << PAGE_SHIFT);
+		start_vaddr = untagged_addr(mm, svpfn << PAGE_SHIFT);
 
 	/* Ensure the address is inside the task */
 	if (start_vaddr > mm->task_size)
diff --git a/include/linux/mm.h b/include/linux/mm.h
index bc8f326be0ce..f0cb92ff1391 100644
--- a/include/linux/mm.h
+++ b/include/linux/mm.h
@@ -94,17 +94,6 @@  extern int mmap_rnd_compat_bits __read_mostly;
 #include <asm/page.h>
 #include <asm/processor.h>
 
-/*
- * Architectures that support memory tagging (assigning tags to memory regions,
- * embedding these tags into addresses that point to these memory regions, and
- * checking that the memory and the pointer tags match on memory accesses)
- * redefine this macro to strip tags from pointers.
- * It's defined as noop for architectures that don't support memory tagging.
- */
-#ifndef untagged_addr
-#define untagged_addr(addr) (addr)
-#endif
-
 #ifndef __pa_symbol
 #define __pa_symbol(x)  __pa(RELOC_HIDE((unsigned long)(x), 0))
 #endif
diff --git a/include/linux/uaccess.h b/include/linux/uaccess.h
index 5a328cf02b75..46fd816179d7 100644
--- a/include/linux/uaccess.h
+++ b/include/linux/uaccess.h
@@ -10,6 +10,21 @@ 
 
 #include <asm/uaccess.h>
 
+/*
+ * Architectures that support memory tagging (assigning tags to memory regions,
+ * embedding these tags into addresses that point to these memory regions, and
+ * checking that the memory and the pointer tags match on memory accesses)
+ * redefine this macro to strip tags from pointers.
+ *
+ * Passing down mm_struct allows to define untagging rules on per-process
+ * basis.
+ *
+ * It's defined as noop for architectures that don't support memory tagging.
+ */
+#ifndef untagged_addr
+#define untagged_addr(mm, addr) (addr)
+#endif
+
 /*
  * Architectures should provide two primitives (raw_copy_{to,from}_user())
  * and get rid of their private instances of copy_{to,from}_user() and
diff --git a/lib/strncpy_from_user.c b/lib/strncpy_from_user.c
index 6432b8c3e431..6e1e2aa0c994 100644
--- a/lib/strncpy_from_user.c
+++ b/lib/strncpy_from_user.c
@@ -121,7 +121,7 @@  long strncpy_from_user(char *dst, const char __user *src, long count)
 		return 0;
 
 	max_addr = TASK_SIZE_MAX;
-	src_addr = (unsigned long)untagged_addr(src);
+	src_addr = (unsigned long)untagged_addr(current->mm, src);
 	if (likely(src_addr < max_addr)) {
 		unsigned long max = max_addr - src_addr;
 		long retval;
diff --git a/lib/strnlen_user.c b/lib/strnlen_user.c
index feeb935a2299..abc096a68f05 100644
--- a/lib/strnlen_user.c
+++ b/lib/strnlen_user.c
@@ -97,7 +97,7 @@  long strnlen_user(const char __user *str, long count)
 		return 0;
 
 	max_addr = TASK_SIZE_MAX;
-	src_addr = (unsigned long)untagged_addr(str);
+	src_addr = (unsigned long)untagged_addr(current->mm, str);
 	if (likely(src_addr < max_addr)) {
 		unsigned long max = max_addr - src_addr;
 		long retval;
diff --git a/mm/gup.c b/mm/gup.c
index 551264407624..dbe825faf842 100644
--- a/mm/gup.c
+++ b/mm/gup.c
@@ -1104,7 +1104,7 @@  static long __get_user_pages(struct mm_struct *mm,
 	if (!nr_pages)
 		return 0;
 
-	start = untagged_addr(start);
+	start = untagged_addr(mm, start);
 
 	VM_BUG_ON(!!pages != !!(gup_flags & (FOLL_GET | FOLL_PIN)));
 
@@ -1285,7 +1285,7 @@  int fixup_user_fault(struct mm_struct *mm,
 	struct vm_area_struct *vma;
 	vm_fault_t ret;
 
-	address = untagged_addr(address);
+	address = untagged_addr(mm, address);
 
 	if (unlocked)
 		fault_flags |= FAULT_FLAG_ALLOW_RETRY | FAULT_FLAG_KILLABLE;
@@ -2865,7 +2865,7 @@  static int internal_get_user_pages_fast(unsigned long start,
 	if (!(gup_flags & FOLL_FAST_ONLY))
 		might_lock_read(&current->mm->mmap_lock);
 
-	start = untagged_addr(start) & PAGE_MASK;
+	start = untagged_addr(current->mm, start) & PAGE_MASK;
 	len = nr_pages << PAGE_SHIFT;
 	if (check_add_overflow(start, len, &end))
 		return 0;
diff --git a/mm/madvise.c b/mm/madvise.c
index d7b4f2602949..e3c668ddb099 100644
--- a/mm/madvise.c
+++ b/mm/madvise.c
@@ -1373,7 +1373,7 @@  int do_madvise(struct mm_struct *mm, unsigned long start, size_t len_in, int beh
 	size_t len;
 	struct blk_plug plug;
 
-	start = untagged_addr(start);
+	start = untagged_addr(mm, start);
 
 	if (!madvise_behavior_valid(behavior))
 		return -EINVAL;
diff --git a/mm/mempolicy.c b/mm/mempolicy.c
index d39b01fd52fe..a03b4d2bc26a 100644
--- a/mm/mempolicy.c
+++ b/mm/mempolicy.c
@@ -1458,7 +1458,7 @@  static long kernel_mbind(unsigned long start, unsigned long len,
 	int lmode = mode;
 	int err;
 
-	start = untagged_addr(start);
+	start = untagged_addr(current->mm, start);
 	err = sanitize_mpol_flags(&lmode, &mode_flags);
 	if (err)
 		return err;
@@ -1481,7 +1481,7 @@  SYSCALL_DEFINE4(set_mempolicy_home_node, unsigned long, start, unsigned long, le
 	unsigned long end;
 	int err = -ENOENT;
 
-	start = untagged_addr(start);
+	start = untagged_addr(mm, start);
 	if (start & ~PAGE_MASK)
 		return -EINVAL;
 	/*
@@ -1684,7 +1684,7 @@  static int kernel_get_mempolicy(int __user *policy,
 	if (nmask != NULL && maxnode < nr_node_ids)
 		return -EINVAL;
 
-	addr = untagged_addr(addr);
+	addr = untagged_addr(current->mm, addr);
 
 	err = do_get_mempolicy(&pval, &nodes, addr, flags);
 
diff --git a/mm/migrate.c b/mm/migrate.c
index e51588e95f57..af05049b055b 100644
--- a/mm/migrate.c
+++ b/mm/migrate.c
@@ -1714,7 +1714,7 @@  static int do_pages_move(struct mm_struct *mm, nodemask_t task_nodes,
 			goto out_flush;
 		if (get_user(node, nodes + i))
 			goto out_flush;
-		addr = (unsigned long)untagged_addr(p);
+		addr = (unsigned long)untagged_addr(mm, p);
 
 		err = -ENODEV;
 		if (node < 0 || node >= MAX_NUMNODES)
diff --git a/mm/mincore.c b/mm/mincore.c
index fa200c14185f..72c55bd9d184 100644
--- a/mm/mincore.c
+++ b/mm/mincore.c
@@ -236,7 +236,7 @@  SYSCALL_DEFINE3(mincore, unsigned long, start, size_t, len,
 	unsigned long pages;
 	unsigned char *tmp;
 
-	start = untagged_addr(start);
+	start = untagged_addr(current->mm, start);
 
 	/* Check the start address: needs to be page-aligned.. */
 	if (start & ~PAGE_MASK)
diff --git a/mm/mlock.c b/mm/mlock.c
index 716caf851043..054168d3e648 100644
--- a/mm/mlock.c
+++ b/mm/mlock.c
@@ -571,7 +571,7 @@  static __must_check int do_mlock(unsigned long start, size_t len, vm_flags_t fla
 	unsigned long lock_limit;
 	int error = -ENOMEM;
 
-	start = untagged_addr(start);
+	start = untagged_addr(current->mm, start);
 
 	if (!can_do_mlock())
 		return -EPERM;
@@ -634,7 +634,7 @@  SYSCALL_DEFINE2(munlock, unsigned long, start, size_t, len)
 {
 	int ret;
 
-	start = untagged_addr(start);
+	start = untagged_addr(current->mm, start);
 
 	len = PAGE_ALIGN(len + (offset_in_page(start)));
 	start &= PAGE_MASK;
diff --git a/mm/mmap.c b/mm/mmap.c
index 61e6135c54ef..1a7baf6b6b8e 100644
--- a/mm/mmap.c
+++ b/mm/mmap.c
@@ -2926,7 +2926,7 @@  EXPORT_SYMBOL(vm_munmap);
 
 SYSCALL_DEFINE2(munmap, unsigned long, addr, size_t, len)
 {
-	addr = untagged_addr(addr);
+	addr = untagged_addr(current->mm, addr);
 	return __vm_munmap(addr, len, true);
 }
 
diff --git a/mm/mprotect.c b/mm/mprotect.c
index ba5592655ee3..871e954f6155 100644
--- a/mm/mprotect.c
+++ b/mm/mprotect.c
@@ -622,7 +622,7 @@  static int do_mprotect_pkey(unsigned long start, size_t len,
 				(prot & PROT_READ);
 	struct mmu_gather tlb;
 
-	start = untagged_addr(start);
+	start = untagged_addr(current->mm, start);
 
 	prot &= ~(PROT_GROWSDOWN|PROT_GROWSUP);
 	if (grows == (PROT_GROWSDOWN|PROT_GROWSUP)) /* can't be both */
diff --git a/mm/mremap.c b/mm/mremap.c
index b522cd0259a0..f76648bc4f67 100644
--- a/mm/mremap.c
+++ b/mm/mremap.c
@@ -906,7 +906,7 @@  SYSCALL_DEFINE5(mremap, unsigned long, addr, unsigned long, old_len,
 	 *
 	 * See Documentation/arm64/tagged-address-abi.rst for more information.
 	 */
-	addr = untagged_addr(addr);
+	addr = untagged_addr(mm, addr);
 
 	if (flags & ~(MREMAP_FIXED | MREMAP_MAYMOVE | MREMAP_DONTUNMAP))
 		return ret;
diff --git a/mm/msync.c b/mm/msync.c
index 137d1c104f3e..5fe989bd3c4b 100644
--- a/mm/msync.c
+++ b/mm/msync.c
@@ -37,7 +37,7 @@  SYSCALL_DEFINE3(msync, unsigned long, start, size_t, len, int, flags)
 	int unmapped_error = 0;
 	int error = -EINVAL;
 
-	start = untagged_addr(start);
+	start = untagged_addr(mm, start);
 
 	if (flags & ~(MS_ASYNC | MS_INVALIDATE | MS_SYNC))
 		goto out;
diff --git a/virt/kvm/kvm_main.c b/virt/kvm/kvm_main.c
index a49df8988cd6..03f7ad0ebc8a 100644
--- a/virt/kvm/kvm_main.c
+++ b/virt/kvm/kvm_main.c
@@ -1876,7 +1876,7 @@  int __kvm_set_memory_region(struct kvm *kvm,
 		return -EINVAL;
 	/* We can read the guest memory with __xxx_user() later on. */
 	if ((mem->userspace_addr & (PAGE_SIZE - 1)) ||
-	    (mem->userspace_addr != untagged_addr(mem->userspace_addr)) ||
+	    (mem->userspace_addr != untagged_addr(kvm->mm, mem->userspace_addr)) ||
 	     !access_ok((void __user *)(unsigned long)mem->userspace_addr,
 			mem->memory_size))
 		return -EINVAL;