@@ -142,7 +142,7 @@ struct ib_umem *ib_umem_get(struct ib_ucontext *context, unsigned long addr,
mmgrab(mm);
if (access & IB_ACCESS_ON_DEMAND) {
- ret = ib_umem_odp_get(to_ib_umem_odp(umem), access);
+ ret = ib_umem_odp_get(to_ib_umem_odp(umem), access, owner);
if (ret)
goto umem_kfree;
return umem;
@@ -200,7 +200,7 @@ struct ib_umem *ib_umem_get(struct ib_ucontext *context, unsigned long addr,
mm, cur_base,
min_t(unsigned long, npages,
PAGE_SIZE / sizeof(struct page *)),
- gup_flags, page_list, vma_list, NULL);
+ gup_flags, page_list, vma_list);
if (ret < 0) {
up_read(&mm->mmap_sem);
goto umem_release;
@@ -227,7 +227,8 @@ static void remove_umem_from_per_mm(struct ib_umem_odp *umem_odp)
}
static struct ib_ucontext_per_mm *alloc_per_mm(struct ib_ucontext *ctx,
- struct mm_struct *mm)
+ struct mm_struct *mm,
+ struct pid *owner)
{
struct ib_ucontext_per_mm *per_mm;
int ret;
@@ -241,12 +242,8 @@ static struct ib_ucontext_per_mm *alloc_per_mm(struct ib_ucontext *ctx,
per_mm->umem_tree = RB_ROOT_CACHED;
init_rwsem(&per_mm->umem_rwsem);
per_mm->active = ctx->invalidate_range;
-
- rcu_read_lock();
- per_mm->tgid = get_task_pid(current->group_leader, PIDTYPE_PID);
- rcu_read_unlock();
-
- WARN_ON(mm != current->mm);
+ per_mm->tgid = owner;
+ mmgrab(per_mm->mm);
per_mm->mn.ops = &ib_umem_notifiers;
ret = mmu_notifier_register(&per_mm->mn, per_mm->mm);
@@ -265,7 +262,7 @@ static struct ib_ucontext_per_mm *alloc_per_mm(struct ib_ucontext *ctx,
return ERR_PTR(ret);
}
-static int get_per_mm(struct ib_umem_odp *umem_odp)
+static int get_per_mm(struct ib_umem_odp *umem_odp, struct pid *owner)
{
struct ib_ucontext *ctx = umem_odp->umem.context;
struct ib_ucontext_per_mm *per_mm;
@@ -280,7 +277,7 @@ static int get_per_mm(struct ib_umem_odp *umem_odp)
goto found;
}
- per_mm = alloc_per_mm(ctx, umem_odp->umem.owning_mm);
+ per_mm = alloc_per_mm(ctx, umem_odp->umem.owning_mm, owner);
if (IS_ERR(per_mm)) {
mutex_unlock(&ctx->per_mm_list_lock);
return PTR_ERR(per_mm);
@@ -333,7 +330,8 @@ void put_per_mm(struct ib_umem_odp *umem_odp)
}
struct ib_umem_odp *ib_alloc_odp_umem(struct ib_ucontext_per_mm *per_mm,
- unsigned long addr, size_t size)
+ unsigned long addr, size_t size,
+ struct mm_struct *owner_mm)
{
struct ib_ucontext *ctx = per_mm->context;
struct ib_umem_odp *odp_data;
@@ -345,12 +343,14 @@ struct ib_umem_odp *ib_alloc_odp_umem(struct ib_ucontext_per_mm *per_mm,
if (!odp_data)
return ERR_PTR(-ENOMEM);
umem = &odp_data->umem;
+
umem->context = ctx;
umem->length = size;
umem->address = addr;
umem->page_shift = PAGE_SHIFT;
umem->writable = 1;
umem->is_odp = 1;
+ umem->owning_mm = owner_mm;
odp_data->per_mm = per_mm;
mutex_init(&odp_data->umem_mutex);
@@ -389,13 +389,9 @@ struct ib_umem_odp *ib_alloc_odp_umem(struct ib_ucontext_per_mm *per_mm,
}
EXPORT_SYMBOL(ib_alloc_odp_umem);
-int ib_umem_odp_get(struct ib_umem_odp *umem_odp, int access)
+int ib_umem_odp_get(struct ib_umem_odp *umem_odp, int access, struct pid *owner)
{
struct ib_umem *umem = &umem_odp->umem;
- /*
- * NOTE: This must called in a process context where umem->owning_mm
- * == current->mm
- */
struct mm_struct *mm = umem->owning_mm;
int ret_val;
@@ -437,7 +433,7 @@ int ib_umem_odp_get(struct ib_umem_odp *umem_odp, int access)
}
}
- ret_val = get_per_mm(umem_odp);
+ ret_val = get_per_mm(umem_odp, owner);
if (ret_val)
goto out_dma_list;
add_umem_to_per_mm(umem_odp);
@@ -574,8 +570,8 @@ static int ib_umem_odp_map_dma_single_page(
* the return value.
* @access_mask: bit mask of the requested access permissions for the given
* range.
- * @current_seq: the MMU notifiers sequance value for synchronization with
- * invalidations. the sequance number is read from
+ * @current_seq: the MMU notifiers sequence value for synchronization with
+ * invalidations. the sequence number is read from
* umem_odp->notifiers_seq before calling this function
*/
int ib_umem_odp_map_dma_pages(struct ib_umem_odp *umem_odp, u64 user_virt,
@@ -584,7 +580,7 @@ int ib_umem_odp_map_dma_pages(struct ib_umem_odp *umem_odp, u64 user_virt,
{
struct ib_umem *umem = &umem_odp->umem;
struct task_struct *owning_process = NULL;
- struct mm_struct *owning_mm = umem_odp->umem.owning_mm;
+ struct mm_struct *owning_mm;
struct page **local_page_list = NULL;
u64 page_mask, off;
int j, k, ret = 0, start_idx, npages = 0, page_shift;
@@ -609,12 +605,13 @@ int ib_umem_odp_map_dma_pages(struct ib_umem_odp *umem_odp, u64 user_virt,
bcnt += off; /* Charge for the first page offset as well. */
/*
- * owning_process is allowed to be NULL, this means somehow the mm is
- * existing beyond the lifetime of the originating process.. Presumably
+ * owning_process may be NULL, because the mm can
+ * exist independently of the originating process.
* mmget_not_zero will fail in this case.
*/
owning_process = get_pid_task(umem_odp->per_mm->tgid, PIDTYPE_PID);
- if (WARN_ON(!mmget_not_zero(umem_odp->umem.owning_mm))) {
+ owning_mm = umem_odp->per_mm->mm;
+ if (WARN_ON(!mmget_not_zero(owning_mm))) {
ret = -EINVAL;
goto out_put_task;
}
@@ -632,15 +629,16 @@ int ib_umem_odp_map_dma_pages(struct ib_umem_odp *umem_odp, u64 user_virt,
down_read(&owning_mm->mmap_sem);
/*
- * Note: this might result in redundent page getting. We can
+ * Note: this might result in redundant page getting. We can
* avoid this by checking dma_list to be 0 before calling
- * get_user_pages. However, this make the code much more
+ * get_user_pages. However, this makes the code much more
* complex (and doesn't gain us much performance in most use
* cases).
*/
- npages = get_user_pages_remote(owning_process, owning_mm,
+ npages = get_user_pages_remote_longterm(owning_process,
+ owning_mm,
user_virt, gup_num_pages,
- flags, local_page_list, NULL, NULL);
+ flags, local_page_list, NULL);
up_read(&owning_mm->mmap_sem);
if (npages < 0) {
@@ -439,8 +439,12 @@ static struct ib_umem_odp *implicit_mr_get_data(struct mlx5_ib_mr *mr,
if (nentries)
nentries++;
} else {
+ struct mm_struct *owner_mm = current->mm;
+
+ if (mr->umem->owning_mm)
+ owner_mm = mr->umem->owning_mm;
odp = ib_alloc_odp_umem(odp_mr->per_mm, addr,
- MLX5_IMR_MTT_SIZE);
+ MLX5_IMR_MTT_SIZE, owner_mm);
if (IS_ERR(odp)) {
mutex_unlock(&odp_mr->umem_mutex);
return ERR_CAST(odp);
@@ -102,9 +102,11 @@ struct ib_ucontext_per_mm {
struct rcu_head rcu;
};
-int ib_umem_odp_get(struct ib_umem_odp *umem_odp, int access);
+int ib_umem_odp_get(struct ib_umem_odp *umem_odp, int access,
+ struct pid *owner);
struct ib_umem_odp *ib_alloc_odp_umem(struct ib_ucontext_per_mm *per_mm,
- unsigned long addr, size_t size);
+ unsigned long addr, size_t size,
+ struct mm_struct *owner_mm);
void ib_umem_odp_release(struct ib_umem_odp *umem_odp);
/*
Propagate the change of adding the owner parameter to several internal core functions, as well as the ib_umem_odp_get() kernel interface function. The mm of the address space that owns the memory region is saved in the per_mm struct, which is then used by ib_umem_odp_map_dma_pages() when resolving a page fault from ODP. Signed-off-by: Joel Nider <joeln@il.ibm.com> --- drivers/infiniband/core/umem.c | 4 +-- drivers/infiniband/core/umem_odp.c | 50 ++++++++++++++++++-------------------- drivers/infiniband/hw/mlx5/odp.c | 6 ++++- include/rdma/ib_umem_odp.h | 6 +++-- 4 files changed, 35 insertions(+), 31 deletions(-)