diff mbox series

[RFC,for-next,2/2] IB/{core,hw}: ib_umem_get extract ib_ucontext from ib_udata

Message ID 20190107143930.4922-3-shamir.rabinovitch@oracle.com (mailing list archive)
State Superseded
Headers show
Series ib_umem_get extract ib_ucontext from ib_udata | expand

Commit Message

Shamir Rabinovitch Jan. 7, 2019, 2:39 p.m. UTC
From: shamir rabinovitch <shamir.rabinovitch@oracle.com>

ib_umem_get can only be called in a method callback, which always has a
udata parameter. This allows ib_umem_get() to derive the ucontext pointer
directly from the udata without requiring the drivers to find it in some
way or another.

Signed-off-by: Jason Gunthorpe <jgg@mellanox.com>
Signed-off-by: shamir rabinovitch <shamir.rabinovitch@oracle.com>
---
 drivers/infiniband/core/umem.c                | 17 ++++++++++--
 drivers/infiniband/core/uverbs_main.c         | 27 +++++++++++++++++++
 drivers/infiniband/hw/bnxt_re/ib_verbs.c      | 10 +++----
 drivers/infiniband/hw/cxgb3/iwch_provider.c   |  2 +-
 drivers/infiniband/hw/cxgb4/mem.c             |  2 +-
 drivers/infiniband/hw/hns/hns_roce_cq.c       |  6 ++---
 drivers/infiniband/hw/hns/hns_roce_db.c       |  5 ++--
 drivers/infiniband/hw/hns/hns_roce_mr.c       |  4 +--
 drivers/infiniband/hw/hns/hns_roce_qp.c       |  2 +-
 drivers/infiniband/hw/hns/hns_roce_srq.c      |  4 +--
 drivers/infiniband/hw/i40iw/i40iw_verbs.c     |  2 +-
 drivers/infiniband/hw/mlx4/cq.c               | 12 ++++-----
 drivers/infiniband/hw/mlx4/doorbell.c         |  5 ++--
 drivers/infiniband/hw/mlx4/mlx4_ib.h          |  3 ++-
 drivers/infiniband/hw/mlx4/mr.c               | 11 ++++----
 drivers/infiniband/hw/mlx4/qp.c               |  5 ++--
 drivers/infiniband/hw/mlx4/srq.c              |  4 +--
 drivers/infiniband/hw/mlx5/cq.c               |  7 +++--
 drivers/infiniband/hw/mlx5/devx.c             |  2 +-
 drivers/infiniband/hw/mlx5/doorbell.c         |  5 ++--
 drivers/infiniband/hw/mlx5/mlx5_ib.h          |  4 ++-
 drivers/infiniband/hw/mlx5/mr.c               | 22 +++++++--------
 drivers/infiniband/hw/mlx5/odp.c              |  4 +--
 drivers/infiniband/hw/mlx5/qp.c               | 24 ++++++++---------
 drivers/infiniband/hw/mlx5/srq.c              |  4 +--
 drivers/infiniband/hw/mthca/mthca_provider.c  |  2 +-
 drivers/infiniband/hw/nes/nes_verbs.c         |  2 +-
 drivers/infiniband/hw/ocrdma/ocrdma_verbs.c   |  2 +-
 drivers/infiniband/hw/qedr/verbs.c            | 24 ++++++++---------
 drivers/infiniband/hw/vmw_pvrdma/pvrdma_cq.c  |  2 +-
 drivers/infiniband/hw/vmw_pvrdma/pvrdma_mr.c  |  2 +-
 drivers/infiniband/hw/vmw_pvrdma/pvrdma_qp.c  |  4 +--
 drivers/infiniband/hw/vmw_pvrdma/pvrdma_srq.c |  2 +-
 drivers/infiniband/sw/rdmavt/mr.c             |  2 +-
 drivers/infiniband/sw/rxe/rxe_mr.c            |  2 +-
 include/rdma/ib_umem.h                        |  5 ++--
 include/rdma/ib_verbs.h                       |  2 ++
 37 files changed, 147 insertions(+), 97 deletions(-)
diff mbox series

Patch

diff --git a/drivers/infiniband/core/umem.c b/drivers/infiniband/core/umem.c
index c6144df47ea4..d4ad9d86259e 100644
--- a/drivers/infiniband/core/umem.c
+++ b/drivers/infiniband/core/umem.c
@@ -72,15 +72,16 @@  static void __ib_umem_release(struct ib_device *dev, struct ib_umem *umem, int d
  * If access flags indicate ODP memory, avoid pinning. Instead, stores
  * the mm for future page fault handling in conjunction with MMU notifiers.
  *
- * @context: userspace context to pin memory for
+ * @udata: userspace context to pin memory for
  * @addr: userspace virtual address to start at
  * @size: length of region to pin
  * @access: IB_ACCESS_xxx flags for memory being pinned
  * @dmasync: flush in-flight DMA when the memory region is written
  */
-struct ib_umem *ib_umem_get(struct ib_ucontext *context, unsigned long addr,
+struct ib_umem *ib_umem_get(struct ib_udata *udata, unsigned long addr,
 			    size_t size, int access, int dmasync)
 {
+	struct ib_ucontext *context = NULL;
 	struct ib_umem *umem;
 	struct page **page_list;
 	struct vm_area_struct **vma_list;
@@ -94,6 +95,18 @@  struct ib_umem *ib_umem_get(struct ib_ucontext *context, unsigned long addr,
 	unsigned long dma_attrs = 0;
 	struct scatterlist *sg, *sg_list_start;
 	unsigned int gup_flags = FOLL_WRITE;
+	struct ib_uverbs *uverbs;
+
+	rcu_read_lock();
+	uverbs = ib_get_uverbs();
+	if (uverbs)
+		context = uverbs->rdma_get_ucontext(udata);
+	rcu_read_unlock();
+
+	if (!context)
+		return ERR_PTR(-EINVAL);
+	if (IS_ERR(context))
+		return ERR_CAST(context);
 
 	if (dmasync)
 		dma_attrs |= DMA_ATTR_WRITE_BARRIER;
diff --git a/drivers/infiniband/core/uverbs_main.c b/drivers/infiniband/core/uverbs_main.c
index e0574cd5e850..652d019bcf44 100644
--- a/drivers/infiniband/core/uverbs_main.c
+++ b/drivers/infiniband/core/uverbs_main.c
@@ -101,6 +101,30 @@  struct ib_ucontext *ib_uverbs_get_ucontext_file(struct ib_uverbs_file *ufile)
 }
 EXPORT_SYMBOL(ib_uverbs_get_ucontext_file);
 
+/* rdma_get_ucontext - Return the ucontext from a udata
+ * @udata: The udata to get the context from
+ *
+ * This can only be called from within a uapi method that was passed ib_udata
+ * as a parameter. It returns the ucontext associated with the udata, or ERR_PTR
+ * if the udata is NULL or the ucontext has been disassociated.
+ */
+struct ib_ucontext *rdma_get_ucontext(struct ib_udata *udata)
+{
+	if (!udata)
+		return ERR_PTR(-EIO);
+
+	/*
+	 * FIXME: Really all cases that get here with a udata will have
+	 * already called ib_uverbs_get_ucontext_file, or located a uobject
+	 * that points to a ucontext. We could store that result in the udata
+	 * so this function can't fail.
+	 */
+	return ib_uverbs_get_ucontext_file(
+		container_of(udata, struct uverbs_attr_bundle, driver_udata)
+			->ufile);
+}
+EXPORT_SYMBOL(rdma_get_ucontext);
+
 int uverbs_dealloc_mw(struct ib_mw *mw)
 {
 	struct ib_pd *pd = mw->pd;
@@ -1410,6 +1434,9 @@  static int __init ib_uverbs_init(void)
 		goto out_class;
 	}
 
+	/* uverbs callbacks used by ib_core */
+	uverbs.rdma_get_ucontext = rdma_get_ucontext;
+
 	ret = ib_register_uverbs(&uverbs);
 	if (ret) {
 		pr_err("user_verbs: couldn't register uverbs\n");
diff --git a/drivers/infiniband/hw/bnxt_re/ib_verbs.c b/drivers/infiniband/hw/bnxt_re/ib_verbs.c
index 611bacd00b80..91b606aa8f96 100644
--- a/drivers/infiniband/hw/bnxt_re/ib_verbs.c
+++ b/drivers/infiniband/hw/bnxt_re/ib_verbs.c
@@ -892,7 +892,7 @@  static int bnxt_re_init_user_qp(struct bnxt_re_dev *rdev, struct bnxt_re_pd *pd,
 	if (qplib_qp->type == CMDQ_CREATE_QP_TYPE_RC)
 		bytes += (qplib_qp->sq.max_wqe * sizeof(struct sq_psn_search));
 	bytes = PAGE_ALIGN(bytes);
-	umem = ib_umem_get(context, ureq.qpsva, bytes,
+	umem = ib_umem_get(udata, ureq.qpsva, bytes,
 			   IB_ACCESS_LOCAL_WRITE, 1);
 	if (IS_ERR(umem))
 		return PTR_ERR(umem);
@@ -905,7 +905,7 @@  static int bnxt_re_init_user_qp(struct bnxt_re_dev *rdev, struct bnxt_re_pd *pd,
 	if (!qp->qplib_qp.srq) {
 		bytes = (qplib_qp->rq.max_wqe * BNXT_QPLIB_MAX_RQE_ENTRY_SIZE);
 		bytes = PAGE_ALIGN(bytes);
-		umem = ib_umem_get(context, ureq.qprva, bytes,
+		umem = ib_umem_get(udata, ureq.qprva, bytes,
 				   IB_ACCESS_LOCAL_WRITE, 1);
 		if (IS_ERR(umem))
 			goto rqfail;
@@ -1367,7 +1367,7 @@  static int bnxt_re_init_user_srq(struct bnxt_re_dev *rdev,
 
 	bytes = (qplib_srq->max_wqe * BNXT_QPLIB_MAX_RQE_ENTRY_SIZE);
 	bytes = PAGE_ALIGN(bytes);
-	umem = ib_umem_get(context, ureq.srqva, bytes,
+	umem = ib_umem_get(udata, ureq.srqva, bytes,
 			   IB_ACCESS_LOCAL_WRITE, 1);
 	if (IS_ERR(umem))
 		return PTR_ERR(umem);
@@ -2619,7 +2619,7 @@  struct ib_cq *bnxt_re_create_cq(struct ib_device *ibdev,
 			goto fail;
 		}
 
-		cq->umem = ib_umem_get(context, req.cq_va,
+		cq->umem = ib_umem_get(udata, req.cq_va,
 				       entries * sizeof(struct cq_base),
 				       IB_ACCESS_LOCAL_WRITE, 1);
 		if (IS_ERR(cq->umem)) {
@@ -3586,7 +3586,7 @@  struct ib_mr *bnxt_re_reg_user_mr(struct ib_pd *ib_pd, u64 start, u64 length,
 	/* The fixed portion of the rkey is the same as the lkey */
 	mr->ib_mr.rkey = mr->qplib_mr.rkey;
 
-	umem = ib_umem_get(ib_pd->uobject->context, start, length,
+	umem = ib_umem_get(udata, start, length,
 			   mr_access_flags, 0);
 	if (IS_ERR(umem)) {
 		dev_err(rdev_to_dev(rdev), "Failed to get umem");
diff --git a/drivers/infiniband/hw/cxgb3/iwch_provider.c b/drivers/infiniband/hw/cxgb3/iwch_provider.c
index b34b1a1bd94b..92ee6761a3bd 100644
--- a/drivers/infiniband/hw/cxgb3/iwch_provider.c
+++ b/drivers/infiniband/hw/cxgb3/iwch_provider.c
@@ -540,7 +540,7 @@  static struct ib_mr *iwch_reg_user_mr(struct ib_pd *pd, u64 start, u64 length,
 
 	mhp->rhp = rhp;
 
-	mhp->umem = ib_umem_get(pd->uobject->context, start, length, acc, 0);
+	mhp->umem = ib_umem_get(udata, start, length, acc, 0);
 	if (IS_ERR(mhp->umem)) {
 		err = PTR_ERR(mhp->umem);
 		kfree(mhp);
diff --git a/drivers/infiniband/hw/cxgb4/mem.c b/drivers/infiniband/hw/cxgb4/mem.c
index 7b76e6f81aeb..96760a36b9fc 100644
--- a/drivers/infiniband/hw/cxgb4/mem.c
+++ b/drivers/infiniband/hw/cxgb4/mem.c
@@ -537,7 +537,7 @@  struct ib_mr *c4iw_reg_user_mr(struct ib_pd *pd, u64 start, u64 length,
 
 	mhp->rhp = rhp;
 
-	mhp->umem = ib_umem_get(pd->uobject->context, start, length, acc, 0);
+	mhp->umem = ib_umem_get(udata, start, length, acc, 0);
 	if (IS_ERR(mhp->umem))
 		goto err_free_skb;
 
diff --git a/drivers/infiniband/hw/hns/hns_roce_cq.c b/drivers/infiniband/hw/hns/hns_roce_cq.c
index 3a485f50fede..0f86196fd619 100644
--- a/drivers/infiniband/hw/hns/hns_roce_cq.c
+++ b/drivers/infiniband/hw/hns/hns_roce_cq.c
@@ -215,7 +215,7 @@  void hns_roce_free_cq(struct hns_roce_dev *hr_dev, struct hns_roce_cq *hr_cq)
 EXPORT_SYMBOL_GPL(hns_roce_free_cq);
 
 static int hns_roce_ib_get_cq_umem(struct hns_roce_dev *hr_dev,
-				   struct ib_ucontext *context,
+				   struct ib_udata *udata,
 				   struct hns_roce_cq_buf *buf,
 				   struct ib_umem **umem, u64 buf_addr, int cqe)
 {
@@ -223,7 +223,7 @@  static int hns_roce_ib_get_cq_umem(struct hns_roce_dev *hr_dev,
 	u32 page_shift;
 	u32 npages;
 
-	*umem = ib_umem_get(context, buf_addr, cqe * hr_dev->caps.cq_entry_sz,
+	*umem = ib_umem_get(udata, buf_addr, cqe * hr_dev->caps.cq_entry_sz,
 			    IB_ACCESS_LOCAL_WRITE, 1);
 	if (IS_ERR(*umem))
 		return PTR_ERR(*umem);
@@ -347,7 +347,7 @@  struct ib_cq *hns_roce_ib_create_cq(struct ib_device *ib_dev,
 		}
 
 		/* Get user space address, write it into mtt table */
-		ret = hns_roce_ib_get_cq_umem(hr_dev, context, &hr_cq->hr_buf,
+		ret = hns_roce_ib_get_cq_umem(hr_dev, udata, &hr_cq->hr_buf,
 					      &hr_cq->umem, ucmd.buf_addr,
 					      cq_entries);
 		if (ret) {
diff --git a/drivers/infiniband/hw/hns/hns_roce_db.c b/drivers/infiniband/hw/hns/hns_roce_db.c
index e2f93c1ce86a..78704925f3f4 100644
--- a/drivers/infiniband/hw/hns/hns_roce_db.c
+++ b/drivers/infiniband/hw/hns/hns_roce_db.c
@@ -8,7 +8,8 @@ 
 #include <rdma/ib_umem.h>
 #include "hns_roce_device.h"
 
-int hns_roce_db_map_user(struct hns_roce_ucontext *context, unsigned long virt,
+int hns_roce_db_map_user(struct hns_roce_ucontext *context,
+			 struct ib_udata *udata, unsigned long virt,
 			 struct hns_roce_db *db)
 {
 	struct hns_roce_user_db_page *page;
@@ -28,7 +29,7 @@  int hns_roce_db_map_user(struct hns_roce_ucontext *context, unsigned long virt,
 
 	refcount_set(&page->refcount, 1);
 	page->user_virt = (virt & PAGE_MASK);
-	page->umem = ib_umem_get(&context->ibucontext, virt & PAGE_MASK,
+	page->umem = ib_umem_get(udata, virt & PAGE_MASK,
 				 PAGE_SIZE, 0, 0);
 	if (IS_ERR(page->umem)) {
 		ret = PTR_ERR(page->umem);
diff --git a/drivers/infiniband/hw/hns/hns_roce_mr.c b/drivers/infiniband/hw/hns/hns_roce_mr.c
index ee5991bd4171..c7e3d819c5e9 100644
--- a/drivers/infiniband/hw/hns/hns_roce_mr.c
+++ b/drivers/infiniband/hw/hns/hns_roce_mr.c
@@ -1110,7 +1110,7 @@  struct ib_mr *hns_roce_reg_user_mr(struct ib_pd *pd, u64 start, u64 length,
 	if (!mr)
 		return ERR_PTR(-ENOMEM);
 
-	mr->umem = ib_umem_get(pd->uobject->context, start, length,
+	mr->umem = ib_umem_get(udata, start, length,
 			       access_flags, 0);
 	if (IS_ERR(mr->umem)) {
 		ret = PTR_ERR(mr->umem);
@@ -1220,7 +1220,7 @@  int hns_roce_rereg_user_mr(struct ib_mr *ibmr, int flags, u64 start, u64 length,
 		}
 		ib_umem_release(mr->umem);
 
-		mr->umem = ib_umem_get(ibmr->uobject->context, start, length,
+		mr->umem = ib_umem_get(udata, start, length,
 				       mr_access_flags, 0);
 		if (IS_ERR(mr->umem)) {
 			ret = PTR_ERR(mr->umem);
diff --git a/drivers/infiniband/hw/hns/hns_roce_qp.c b/drivers/infiniband/hw/hns/hns_roce_qp.c
index 54031c5b53fa..492af0ce7d86 100644
--- a/drivers/infiniband/hw/hns/hns_roce_qp.c
+++ b/drivers/infiniband/hw/hns/hns_roce_qp.c
@@ -612,7 +612,7 @@  static int hns_roce_create_qp_common(struct hns_roce_dev *hr_dev,
 			goto err_rq_sge_list;
 		}
 
-		hr_qp->umem = ib_umem_get(ib_pd->uobject->context,
+		hr_qp->umem = ib_umem_get(udata,
 					  ucmd.buf_addr, hr_qp->buff_size, 0,
 					  0);
 		if (IS_ERR(hr_qp->umem)) {
diff --git a/drivers/infiniband/hw/hns/hns_roce_srq.c b/drivers/infiniband/hw/hns/hns_roce_srq.c
index 960b1946c365..28a1940f8978 100644
--- a/drivers/infiniband/hw/hns/hns_roce_srq.c
+++ b/drivers/infiniband/hw/hns/hns_roce_srq.c
@@ -252,7 +252,7 @@  struct ib_srq *hns_roce_create_srq(struct ib_pd *pd,
 			goto err_srq;
 		}
 
-		srq->umem = ib_umem_get(pd->uobject->context, ucmd.buf_addr,
+		srq->umem = ib_umem_get(udata, ucmd.buf_addr,
 					srq_buf_size, 0, 0);
 		if (IS_ERR(srq->umem)) {
 			ret = PTR_ERR(srq->umem);
@@ -280,7 +280,7 @@  struct ib_srq *hns_roce_create_srq(struct ib_pd *pd,
 			goto err_srq_mtt;
 
 		/* config index queue BA */
-		srq->idx_que.umem = ib_umem_get(pd->uobject->context,
+		srq->idx_que.umem = ib_umem_get(udata,
 						ucmd.que_addr,
 						srq->idx_que.buf_size, 0, 0);
 		if (IS_ERR(srq->idx_que.umem)) {
diff --git a/drivers/infiniband/hw/i40iw/i40iw_verbs.c b/drivers/infiniband/hw/i40iw/i40iw_verbs.c
index 0b675b0742c2..80b66df95362 100644
--- a/drivers/infiniband/hw/i40iw/i40iw_verbs.c
+++ b/drivers/infiniband/hw/i40iw/i40iw_verbs.c
@@ -1852,7 +1852,7 @@  static struct ib_mr *i40iw_reg_user_mr(struct ib_pd *pd,
 
 	if (length > I40IW_MAX_MR_SIZE)
 		return ERR_PTR(-EINVAL);
-	region = ib_umem_get(pd->uobject->context, start, length, acc, 0);
+	region = ib_umem_get(udata, start, length, acc, 0);
 	if (IS_ERR(region))
 		return (struct ib_mr *)region;
 
diff --git a/drivers/infiniband/hw/mlx4/cq.c b/drivers/infiniband/hw/mlx4/cq.c
index 82adc0d1d30e..4bfe1d779836 100644
--- a/drivers/infiniband/hw/mlx4/cq.c
+++ b/drivers/infiniband/hw/mlx4/cq.c
@@ -134,7 +134,7 @@  static void mlx4_ib_free_cq_buf(struct mlx4_ib_dev *dev, struct mlx4_ib_cq_buf *
 	mlx4_buf_free(dev->dev, (cqe + 1) * buf->entry_size, &buf->buf);
 }
 
-static int mlx4_ib_get_cq_umem(struct mlx4_ib_dev *dev, struct ib_ucontext *context,
+static int mlx4_ib_get_cq_umem(struct mlx4_ib_dev *dev, struct ib_udata *udata,
 			       struct mlx4_ib_cq_buf *buf, struct ib_umem **umem,
 			       u64 buf_addr, int cqe)
 {
@@ -143,7 +143,7 @@  static int mlx4_ib_get_cq_umem(struct mlx4_ib_dev *dev, struct ib_ucontext *cont
 	int shift;
 	int n;
 
-	*umem = ib_umem_get(context, buf_addr, cqe * cqe_size,
+	*umem = ib_umem_get(udata, buf_addr, cqe * cqe_size,
 			    IB_ACCESS_LOCAL_WRITE, 1);
 	if (IS_ERR(*umem))
 		return PTR_ERR(*umem);
@@ -211,13 +211,13 @@  struct ib_cq *mlx4_ib_create_cq(struct ib_device *ibdev,
 			goto err_cq;
 		}
 
-		err = mlx4_ib_get_cq_umem(dev, context, &cq->buf, &cq->umem,
+		err = mlx4_ib_get_cq_umem(dev, udata, &cq->buf, &cq->umem,
 					  ucmd.buf_addr, entries);
 		if (err)
 			goto err_cq;
 
-		err = mlx4_ib_db_map_user(to_mucontext(context), ucmd.db_addr,
-					  &cq->db);
+		err = mlx4_ib_db_map_user(to_mucontext(context), udata,
+					  ucmd.db_addr, &cq->db);
 		if (err)
 			goto err_mtt;
 
@@ -329,7 +329,7 @@  static int mlx4_alloc_resize_umem(struct mlx4_ib_dev *dev, struct mlx4_ib_cq *cq
 	if (!cq->resize_buf)
 		return -ENOMEM;
 
-	err = mlx4_ib_get_cq_umem(dev, cq->umem->context, &cq->resize_buf->buf,
+	err = mlx4_ib_get_cq_umem(dev, udata, &cq->resize_buf->buf,
 				  &cq->resize_umem, ucmd.buf_addr, entries);
 	if (err) {
 		kfree(cq->resize_buf);
diff --git a/drivers/infiniband/hw/mlx4/doorbell.c b/drivers/infiniband/hw/mlx4/doorbell.c
index c51740986367..1243bc4dc546 100644
--- a/drivers/infiniband/hw/mlx4/doorbell.c
+++ b/drivers/infiniband/hw/mlx4/doorbell.c
@@ -41,7 +41,8 @@  struct mlx4_ib_user_db_page {
 	int			refcnt;
 };
 
-int mlx4_ib_db_map_user(struct mlx4_ib_ucontext *context, unsigned long virt,
+int mlx4_ib_db_map_user(struct mlx4_ib_ucontext *context,
+			struct ib_udata *udata, unsigned long virt,
 			struct mlx4_db *db)
 {
 	struct mlx4_ib_user_db_page *page;
@@ -61,7 +62,7 @@  int mlx4_ib_db_map_user(struct mlx4_ib_ucontext *context, unsigned long virt,
 
 	page->user_virt = (virt & PAGE_MASK);
 	page->refcnt    = 0;
-	page->umem      = ib_umem_get(&context->ibucontext, virt & PAGE_MASK,
+	page->umem      = ib_umem_get(udata, virt & PAGE_MASK,
 				      PAGE_SIZE, 0, 0);
 	if (IS_ERR(page->umem)) {
 		err = PTR_ERR(page->umem);
diff --git a/drivers/infiniband/hw/mlx4/mlx4_ib.h b/drivers/infiniband/hw/mlx4/mlx4_ib.h
index 5cb52424912e..8f4fb78c5a26 100644
--- a/drivers/infiniband/hw/mlx4/mlx4_ib.h
+++ b/drivers/infiniband/hw/mlx4/mlx4_ib.h
@@ -722,7 +722,8 @@  static inline u8 mlx4_ib_bond_next_port(struct mlx4_ib_dev *dev)
 int mlx4_ib_init_sriov(struct mlx4_ib_dev *dev);
 void mlx4_ib_close_sriov(struct mlx4_ib_dev *dev);
 
-int mlx4_ib_db_map_user(struct mlx4_ib_ucontext *context, unsigned long virt,
+int mlx4_ib_db_map_user(struct mlx4_ib_ucontext *context,
+			struct ib_udata *udata, unsigned long virt,
 			struct mlx4_db *db);
 void mlx4_ib_db_unmap_user(struct mlx4_ib_ucontext *context, struct mlx4_db *db);
 
diff --git a/drivers/infiniband/hw/mlx4/mr.c b/drivers/infiniband/hw/mlx4/mr.c
index c7c85c22e4e3..56639ecd53ad 100644
--- a/drivers/infiniband/hw/mlx4/mr.c
+++ b/drivers/infiniband/hw/mlx4/mr.c
@@ -367,7 +367,8 @@  int mlx4_ib_umem_calc_optimal_mtt_size(struct ib_umem *umem, u64 start_va,
 	return block_shift;
 }
 
-static struct ib_umem *mlx4_get_umem_mr(struct ib_ucontext *context, u64 start,
+static struct ib_umem *mlx4_get_umem_mr(struct ib_ucontext *context,
+					struct ib_udata *udata, u64 start,
 					u64 length, u64 virt_addr,
 					int access_flags)
 {
@@ -398,7 +399,7 @@  static struct ib_umem *mlx4_get_umem_mr(struct ib_ucontext *context, u64 start,
 		up_read(&current->mm->mmap_sem);
 	}
 
-	return ib_umem_get(context, start, length, access_flags, 0);
+	return ib_umem_get(udata, start, length, access_flags, 0);
 }
 
 struct ib_mr *mlx4_ib_reg_user_mr(struct ib_pd *pd, u64 start, u64 length,
@@ -415,7 +416,7 @@  struct ib_mr *mlx4_ib_reg_user_mr(struct ib_pd *pd, u64 start, u64 length,
 	if (!mr)
 		return ERR_PTR(-ENOMEM);
 
-	mr->umem = mlx4_get_umem_mr(pd->uobject->context, start, length,
+	mr->umem = mlx4_get_umem_mr(pd->uobject->context, udata, start, length,
 				    virt_addr, access_flags);
 	if (IS_ERR(mr->umem)) {
 		err = PTR_ERR(mr->umem);
@@ -506,8 +507,8 @@  int mlx4_ib_rereg_user_mr(struct ib_mr *mr, int flags,
 		mlx4_mr_rereg_mem_cleanup(dev->dev, &mmr->mmr);
 		ib_umem_release(mmr->umem);
 		mmr->umem =
-			mlx4_get_umem_mr(mr->uobject->context, start, length,
-					 virt_addr, mr_access_flags);
+			mlx4_get_umem_mr(mr->uobject->context, udata, start,
+					 length, virt_addr, mr_access_flags);
 		if (IS_ERR(mmr->umem)) {
 			err = PTR_ERR(mmr->umem);
 			/* Prevent mlx4_ib_dereg_mr from free'ing invalid pointer */
diff --git a/drivers/infiniband/hw/mlx4/qp.c b/drivers/infiniband/hw/mlx4/qp.c
index 24ee30f1cb45..027a31ed8051 100644
--- a/drivers/infiniband/hw/mlx4/qp.c
+++ b/drivers/infiniband/hw/mlx4/qp.c
@@ -1015,7 +1015,7 @@  static int create_qp_common(struct mlx4_ib_dev *dev, struct ib_pd *pd,
 				(qp->sq.wqe_cnt << qp->sq.wqe_shift);
 		}
 
-		qp->umem = ib_umem_get(pd->uobject->context,
+		qp->umem = ib_umem_get(udata,
 				(src == MLX4_IB_QP_SRC) ? ucmd.qp.buf_addr :
 				ucmd.wq.buf_addr, qp->buf_size, 0, 0);
 		if (IS_ERR(qp->umem)) {
@@ -1035,7 +1035,8 @@  static int create_qp_common(struct mlx4_ib_dev *dev, struct ib_pd *pd,
 			goto err_mtt;
 
 		if (qp_has_rq(init_attr)) {
-			err = mlx4_ib_db_map_user(to_mucontext(pd->uobject->context),
+			err = mlx4_ib_db_map_user(
+				to_mucontext(pd->uobject->context), udata,
 				(src == MLX4_IB_QP_SRC) ? ucmd.qp.db_addr :
 				ucmd.wq.db_addr, &qp->db);
 			if (err)
diff --git a/drivers/infiniband/hw/mlx4/srq.c b/drivers/infiniband/hw/mlx4/srq.c
index 4456f1b8921d..3751847dcc22 100644
--- a/drivers/infiniband/hw/mlx4/srq.c
+++ b/drivers/infiniband/hw/mlx4/srq.c
@@ -113,7 +113,7 @@  struct ib_srq *mlx4_ib_create_srq(struct ib_pd *pd,
 			goto err_srq;
 		}
 
-		srq->umem = ib_umem_get(pd->uobject->context, ucmd.buf_addr,
+		srq->umem = ib_umem_get(udata, ucmd.buf_addr,
 					buf_size, 0, 0);
 		if (IS_ERR(srq->umem)) {
 			err = PTR_ERR(srq->umem);
@@ -130,7 +130,7 @@  struct ib_srq *mlx4_ib_create_srq(struct ib_pd *pd,
 			goto err_mtt;
 
 		err = mlx4_ib_db_map_user(to_mucontext(pd->uobject->context),
-					  ucmd.db_addr, &srq->db);
+					  udata, ucmd.db_addr, &srq->db);
 		if (err)
 			goto err_mtt;
 	} else {
diff --git a/drivers/infiniband/hw/mlx5/cq.c b/drivers/infiniband/hw/mlx5/cq.c
index 95a29e85522e..2a142beb65f1 100644
--- a/drivers/infiniband/hw/mlx5/cq.c
+++ b/drivers/infiniband/hw/mlx5/cq.c
@@ -707,7 +707,7 @@  static int create_cq_user(struct mlx5_ib_dev *dev, struct ib_udata *udata,
 
 	*cqe_size = ucmd.cqe_size;
 
-	cq->buf.umem = ib_umem_get(context, ucmd.buf_addr,
+	cq->buf.umem = ib_umem_get(udata, ucmd.buf_addr,
 				   entries * ucmd.cqe_size,
 				   IB_ACCESS_LOCAL_WRITE, 1);
 	if (IS_ERR(cq->buf.umem)) {
@@ -715,7 +715,7 @@  static int create_cq_user(struct mlx5_ib_dev *dev, struct ib_udata *udata,
 		return err;
 	}
 
-	err = mlx5_ib_db_map_user(to_mucontext(context), ucmd.db_addr,
+	err = mlx5_ib_db_map_user(to_mucontext(context), udata, ucmd.db_addr,
 				  &cq->db);
 	if (err)
 		goto err_umem;
@@ -1111,7 +1111,6 @@  static int resize_user(struct mlx5_ib_dev *dev, struct mlx5_ib_cq *cq,
 	struct ib_umem *umem;
 	int err;
 	int npages;
-	struct ib_ucontext *context = cq->buf.umem->context;
 
 	err = ib_copy_from_udata(&ucmd, udata, sizeof(ucmd));
 	if (err)
@@ -1124,7 +1123,7 @@  static int resize_user(struct mlx5_ib_dev *dev, struct mlx5_ib_cq *cq,
 	if (ucmd.cqe_size && SIZE_MAX / ucmd.cqe_size <= entries - 1)
 		return -EINVAL;
 
-	umem = ib_umem_get(context, ucmd.buf_addr,
+	umem = ib_umem_get(udata, ucmd.buf_addr,
 			   (size_t)ucmd.cqe_size * entries,
 			   IB_ACCESS_LOCAL_WRITE, 1);
 	if (IS_ERR(umem)) {
diff --git a/drivers/infiniband/hw/mlx5/devx.c b/drivers/infiniband/hw/mlx5/devx.c
index 5271469aad10..093a6cebec8a 100644
--- a/drivers/infiniband/hw/mlx5/devx.c
+++ b/drivers/infiniband/hw/mlx5/devx.c
@@ -1195,7 +1195,7 @@  static int devx_umem_get(struct mlx5_ib_dev *dev, struct ib_ucontext *ucontext,
 	if (err)
 		return err;
 
-	obj->umem = ib_umem_get(ucontext, addr, size, access, 0);
+	obj->umem = ib_umem_get(&attrs->driver_udata, addr, size, access, 0);
 	if (IS_ERR(obj->umem))
 		return PTR_ERR(obj->umem);
 
diff --git a/drivers/infiniband/hw/mlx5/doorbell.c b/drivers/infiniband/hw/mlx5/doorbell.c
index a0e4e6ddb71a..d31cab8ea9ad 100644
--- a/drivers/infiniband/hw/mlx5/doorbell.c
+++ b/drivers/infiniband/hw/mlx5/doorbell.c
@@ -43,7 +43,8 @@  struct mlx5_ib_user_db_page {
 	int			refcnt;
 };
 
-int mlx5_ib_db_map_user(struct mlx5_ib_ucontext *context, unsigned long virt,
+int mlx5_ib_db_map_user(struct mlx5_ib_ucontext *context,
+			struct ib_udata *udata, unsigned long virt,
 			struct mlx5_db *db)
 {
 	struct mlx5_ib_user_db_page *page;
@@ -63,7 +64,7 @@  int mlx5_ib_db_map_user(struct mlx5_ib_ucontext *context, unsigned long virt,
 
 	page->user_virt = (virt & PAGE_MASK);
 	page->refcnt    = 0;
-	page->umem      = ib_umem_get(&context->ibucontext, virt & PAGE_MASK,
+	page->umem      = ib_umem_get(udata, virt & PAGE_MASK,
 				      PAGE_SIZE, 0, 0);
 	if (IS_ERR(page->umem)) {
 		err = PTR_ERR(page->umem);
diff --git a/drivers/infiniband/hw/mlx5/mlx5_ib.h b/drivers/infiniband/hw/mlx5/mlx5_ib.h
index 9b4e2554889a..99c40782d806 100644
--- a/drivers/infiniband/hw/mlx5/mlx5_ib.h
+++ b/drivers/infiniband/hw/mlx5/mlx5_ib.h
@@ -1032,7 +1032,8 @@  to_mflow_act(struct ib_flow_action *ibact)
 	return container_of(ibact, struct mlx5_ib_flow_action, ib_action);
 }
 
-int mlx5_ib_db_map_user(struct mlx5_ib_ucontext *context, unsigned long virt,
+int mlx5_ib_db_map_user(struct mlx5_ib_ucontext *context,
+			struct ib_udata *udata, unsigned long virt,
 			struct mlx5_db *db);
 void mlx5_ib_db_unmap_user(struct mlx5_ib_ucontext *context, struct mlx5_db *db);
 void __mlx5_ib_cq_clean(struct mlx5_ib_cq *cq, u32 qpn, struct mlx5_ib_srq *srq);
@@ -1098,6 +1099,7 @@  int mlx5_ib_dealloc_mw(struct ib_mw *mw);
 int mlx5_ib_update_xlt(struct mlx5_ib_mr *mr, u64 idx, int npages,
 		       int page_shift, int flags);
 struct mlx5_ib_mr *mlx5_ib_alloc_implicit_mr(struct mlx5_ib_pd *pd,
+					     struct ib_udata *udata,
 					     int access_flags);
 void mlx5_ib_free_implicit_mr(struct mlx5_ib_mr *mr);
 int mlx5_ib_rereg_user_mr(struct ib_mr *ib_mr, int flags, u64 start,
diff --git a/drivers/infiniband/hw/mlx5/mr.c b/drivers/infiniband/hw/mlx5/mr.c
index fd6ea1f75085..9e6baa1a569b 100644
--- a/drivers/infiniband/hw/mlx5/mr.c
+++ b/drivers/infiniband/hw/mlx5/mr.c
@@ -847,18 +847,17 @@  static int mr_cache_max_order(struct mlx5_ib_dev *dev)
 	return MLX5_MAX_UMR_SHIFT;
 }
 
-static int mr_umem_get(struct ib_pd *pd, u64 start, u64 length,
-		       int access_flags, struct ib_umem **umem,
-		       int *npages, int *page_shift, int *ncont,
-		       int *order)
+static int mr_umem_get(struct mlx5_ib_dev *dev, struct ib_udata *udata,
+		       u64 start, u64 length, int access_flags,
+		       struct ib_umem **umem, int *npages, int *page_shift,
+		       int *ncont, int *order)
 {
-	struct mlx5_ib_dev *dev = to_mdev(pd->device);
 	struct ib_umem *u;
 	int err;
 
 	*umem = NULL;
 
-	u = ib_umem_get(pd->uobject->context, start, length, access_flags, 0);
+	u = ib_umem_get(udata, start, length, access_flags, 0);
 	err = PTR_ERR_OR_ZERO(u);
 	if (err) {
 		mlx5_ib_dbg(dev, "umem get failed (%d)\n", err);
@@ -1337,15 +1336,15 @@  struct ib_mr *mlx5_ib_reg_user_mr(struct ib_pd *pd, u64 start, u64 length,
 		    !(dev->odp_caps.general_caps & IB_ODP_SUPPORT_IMPLICIT))
 			return ERR_PTR(-EINVAL);
 
-		mr = mlx5_ib_alloc_implicit_mr(to_mpd(pd), access_flags);
+		mr = mlx5_ib_alloc_implicit_mr(to_mpd(pd), udata, access_flags);
 		if (IS_ERR(mr))
 			return ERR_CAST(mr);
 		return &mr->ibmr;
 	}
 #endif
 
-	err = mr_umem_get(pd, start, length, access_flags, &umem, &npages,
-			   &page_shift, &ncont, &order);
+	err = mr_umem_get(dev, udata, start, length, access_flags, &umem,
+			  &npages, &page_shift, &ncont, &order);
 
 	if (err < 0)
 		return ERR_PTR(err);
@@ -1495,8 +1494,9 @@  int mlx5_ib_rereg_user_mr(struct ib_mr *ib_mr, int flags, u64 start,
 		flags |= IB_MR_REREG_TRANS;
 		ib_umem_release(mr->umem);
 		mr->umem = NULL;
-		err = mr_umem_get(pd, addr, len, access_flags, &mr->umem,
-				  &npages, &page_shift, &ncont, &order);
+		err = mr_umem_get(dev, udata, addr, len, access_flags,
+				  &mr->umem, &npages, &page_shift, &ncont,
+				  &order);
 		if (err)
 			goto err;
 	}
diff --git a/drivers/infiniband/hw/mlx5/odp.c b/drivers/infiniband/hw/mlx5/odp.c
index 80fa2438db8f..23255ee542af 100644
--- a/drivers/infiniband/hw/mlx5/odp.c
+++ b/drivers/infiniband/hw/mlx5/odp.c
@@ -492,13 +492,13 @@  static struct ib_umem_odp *implicit_mr_get_data(struct mlx5_ib_mr *mr,
 }
 
 struct mlx5_ib_mr *mlx5_ib_alloc_implicit_mr(struct mlx5_ib_pd *pd,
+					     struct ib_udata *udata,
 					     int access_flags)
 {
-	struct ib_ucontext *ctx = pd->ibpd.uobject->context;
 	struct mlx5_ib_mr *imr;
 	struct ib_umem *umem;
 
-	umem = ib_umem_get(ctx, 0, 0, IB_ACCESS_ON_DEMAND, 0);
+	umem = ib_umem_get(udata, 0, 0, IB_ACCESS_ON_DEMAND, 0);
 	if (IS_ERR(umem))
 		return ERR_CAST(umem);
 
diff --git a/drivers/infiniband/hw/mlx5/qp.c b/drivers/infiniband/hw/mlx5/qp.c
index b26ddb147643..2463d309c5a8 100644
--- a/drivers/infiniband/hw/mlx5/qp.c
+++ b/drivers/infiniband/hw/mlx5/qp.c
@@ -646,7 +646,7 @@  int bfregn_to_uar_index(struct mlx5_ib_dev *dev,
 }
 
 static int mlx5_ib_umem_get(struct mlx5_ib_dev *dev,
-			    struct ib_pd *pd,
+			    struct ib_udata *udata,
 			    unsigned long addr, size_t size,
 			    struct ib_umem **umem,
 			    int *npages, int *page_shift, int *ncont,
@@ -654,7 +654,7 @@  static int mlx5_ib_umem_get(struct mlx5_ib_dev *dev,
 {
 	int err;
 
-	*umem = ib_umem_get(pd->uobject->context, addr, size, 0, 0);
+	*umem = ib_umem_get(udata, addr, size, 0, 0);
 	if (IS_ERR(*umem)) {
 		mlx5_ib_dbg(dev, "umem_get failed\n");
 		return PTR_ERR(*umem);
@@ -695,10 +695,9 @@  static void destroy_user_rq(struct mlx5_ib_dev *dev, struct ib_pd *pd,
 }
 
 static int create_user_rq(struct mlx5_ib_dev *dev, struct ib_pd *pd,
-			  struct mlx5_ib_rwq *rwq,
+			  struct ib_udata *udata, struct mlx5_ib_rwq *rwq,
 			  struct mlx5_ib_create_wq *ucmd)
 {
-	struct mlx5_ib_ucontext *context;
 	int page_shift = 0;
 	int npages;
 	u32 offset = 0;
@@ -708,8 +707,7 @@  static int create_user_rq(struct mlx5_ib_dev *dev, struct ib_pd *pd,
 	if (!ucmd->buf_addr)
 		return -EINVAL;
 
-	context = to_mucontext(pd->uobject->context);
-	rwq->umem = ib_umem_get(pd->uobject->context, ucmd->buf_addr,
+	rwq->umem = ib_umem_get(udata, ucmd->buf_addr,
 			       rwq->buf_size, 0, 0);
 	if (IS_ERR(rwq->umem)) {
 		mlx5_ib_dbg(dev, "umem_get failed\n");
@@ -735,7 +733,8 @@  static int create_user_rq(struct mlx5_ib_dev *dev, struct ib_pd *pd,
 		    (unsigned long long)ucmd->buf_addr, rwq->buf_size,
 		    npages, page_shift, ncont, offset);
 
-	err = mlx5_ib_db_map_user(context, ucmd->db_addr, &rwq->db);
+	err = mlx5_ib_db_map_user(to_mucontext(pd->uobject->context), udata,
+				  ucmd->db_addr, &rwq->db);
 	if (err) {
 		mlx5_ib_dbg(dev, "map failed\n");
 		goto err_umem;
@@ -819,7 +818,7 @@  static int create_user_qp(struct mlx5_ib_dev *dev, struct ib_pd *pd,
 
 	if (ucmd.buf_addr && ubuffer->buf_size) {
 		ubuffer->buf_addr = ucmd.buf_addr;
-		err = mlx5_ib_umem_get(dev, pd, ubuffer->buf_addr,
+		err = mlx5_ib_umem_get(dev, udata, ubuffer->buf_addr,
 				       ubuffer->buf_size,
 				       &ubuffer->umem, &npages, &page_shift,
 				       &ncont, &offset);
@@ -855,7 +854,7 @@  static int create_user_qp(struct mlx5_ib_dev *dev, struct ib_pd *pd,
 		resp->bfreg_index = MLX5_IB_INVALID_BFREG;
 	qp->bfregn = bfregn;
 
-	err = mlx5_ib_db_map_user(context, ucmd.db_addr, &qp->db);
+	err = mlx5_ib_db_map_user(context, udata, ucmd.db_addr, &qp->db);
 	if (err) {
 		mlx5_ib_dbg(dev, "map failed\n");
 		goto err_free;
@@ -1118,6 +1117,7 @@  static void destroy_flow_rule_vport_sq(struct mlx5_ib_dev *dev,
 }
 
 static int create_raw_packet_qp_sq(struct mlx5_ib_dev *dev,
+				   struct ib_udata *udata,
 				   struct mlx5_ib_sq *sq, void *qpin,
 				   struct ib_pd *pd)
 {
@@ -1134,7 +1134,7 @@  static int create_raw_packet_qp_sq(struct mlx5_ib_dev *dev,
 	int ncont = 0;
 	u32 offset = 0;
 
-	err = mlx5_ib_umem_get(dev, pd, ubuffer->buf_addr, ubuffer->buf_size,
+	err = mlx5_ib_umem_get(dev, udata, ubuffer->buf_addr, ubuffer->buf_size,
 			       &sq->ubuffer.umem, &npages, &page_shift,
 			       &ncont, &offset);
 	if (err)
@@ -1373,7 +1373,7 @@  static int create_raw_packet_qp(struct mlx5_ib_dev *dev, struct mlx5_ib_qp *qp,
 		if (err)
 			return err;
 
-		err = create_raw_packet_qp_sq(dev, sq, in, pd);
+		err = create_raw_packet_qp_sq(dev, udata, sq, in, pd);
 		if (err)
 			goto err_destroy_tis;
 
@@ -5792,7 +5792,7 @@  static int prepare_user_rq(struct ib_pd *pd,
 		return err;
 	}
 
-	err = create_user_rq(dev, pd, rwq, &ucmd);
+	err = create_user_rq(dev, pd, udata, rwq, &ucmd);
 	if (err) {
 		mlx5_ib_dbg(dev, "err %d\n", err);
 		return err;
diff --git a/drivers/infiniband/hw/mlx5/srq.c b/drivers/infiniband/hw/mlx5/srq.c
index 4e8d18009f58..d02383d35750 100644
--- a/drivers/infiniband/hw/mlx5/srq.c
+++ b/drivers/infiniband/hw/mlx5/srq.c
@@ -79,7 +79,7 @@  static int create_srq_user(struct ib_pd *pd, struct mlx5_ib_srq *srq,
 
 	srq->wq_sig = !!(ucmd.flags & MLX5_SRQ_FLAG_SIGNATURE);
 
-	srq->umem = ib_umem_get(pd->uobject->context, ucmd.buf_addr, buf_size,
+	srq->umem = ib_umem_get(udata, ucmd.buf_addr, buf_size,
 				0, 0);
 	if (IS_ERR(srq->umem)) {
 		mlx5_ib_dbg(dev, "failed umem get, size %d\n", buf_size);
@@ -104,7 +104,7 @@  static int create_srq_user(struct ib_pd *pd, struct mlx5_ib_srq *srq,
 
 	mlx5_ib_populate_pas(dev, srq->umem, page_shift, in->pas, 0);
 
-	err = mlx5_ib_db_map_user(to_mucontext(pd->uobject->context),
+	err = mlx5_ib_db_map_user(to_mucontext(pd->uobject->context), udata,
 				  ucmd.db_addr, &srq->db);
 	if (err) {
 		mlx5_ib_dbg(dev, "map doorbell failed\n");
diff --git a/drivers/infiniband/hw/mthca/mthca_provider.c b/drivers/infiniband/hw/mthca/mthca_provider.c
index 443521cf8107..2347f3ebb4dd 100644
--- a/drivers/infiniband/hw/mthca/mthca_provider.c
+++ b/drivers/infiniband/hw/mthca/mthca_provider.c
@@ -930,7 +930,7 @@  static struct ib_mr *mthca_reg_user_mr(struct ib_pd *pd, u64 start, u64 length,
 	if (!mr)
 		return ERR_PTR(-ENOMEM);
 
-	mr->umem = ib_umem_get(pd->uobject->context, start, length, acc,
+	mr->umem = ib_umem_get(udata, start, length, acc,
 			       ucmd.mr_attrs & MTHCA_MR_DMASYNC);
 
 	if (IS_ERR(mr->umem)) {
diff --git a/drivers/infiniband/hw/nes/nes_verbs.c b/drivers/infiniband/hw/nes/nes_verbs.c
index 4e7f08ee1907..feb4d259aab9 100644
--- a/drivers/infiniband/hw/nes/nes_verbs.c
+++ b/drivers/infiniband/hw/nes/nes_verbs.c
@@ -2134,7 +2134,7 @@  static struct ib_mr *nes_reg_user_mr(struct ib_pd *pd, u64 start, u64 length,
 	u8 stag_key;
 	int first_page = 1;
 
-	region = ib_umem_get(pd->uobject->context, start, length, acc, 0);
+	region = ib_umem_get(udata, start, length, acc, 0);
 	if (IS_ERR(region)) {
 		return (struct ib_mr *)region;
 	}
diff --git a/drivers/infiniband/hw/ocrdma/ocrdma_verbs.c b/drivers/infiniband/hw/ocrdma/ocrdma_verbs.c
index c46bed0c5513..a8b40cb7679e 100644
--- a/drivers/infiniband/hw/ocrdma/ocrdma_verbs.c
+++ b/drivers/infiniband/hw/ocrdma/ocrdma_verbs.c
@@ -916,7 +916,7 @@  struct ib_mr *ocrdma_reg_user_mr(struct ib_pd *ibpd, u64 start, u64 len,
 	mr = kzalloc(sizeof(*mr), GFP_KERNEL);
 	if (!mr)
 		return ERR_PTR(status);
-	mr->umem = ib_umem_get(ibpd->uobject->context, start, len, acc, 0);
+	mr->umem = ib_umem_get(udata, start, len, acc, 0);
 	if (IS_ERR(mr->umem)) {
 		status = -EFAULT;
 		goto umem_err;
diff --git a/drivers/infiniband/hw/qedr/verbs.c b/drivers/infiniband/hw/qedr/verbs.c
index 8056121e9f69..70908516de9e 100644
--- a/drivers/infiniband/hw/qedr/verbs.c
+++ b/drivers/infiniband/hw/qedr/verbs.c
@@ -736,7 +736,7 @@  static inline int qedr_align_cq_entries(int entries)
 	return aligned_size / QEDR_CQE_SIZE;
 }
 
-static inline int qedr_init_user_queue(struct ib_ucontext *ib_ctx,
+static inline int qedr_init_user_queue(struct ib_udata *udata,
 				       struct qedr_dev *dev,
 				       struct qedr_userq *q,
 				       u64 buf_addr, size_t buf_len,
@@ -748,7 +748,7 @@  static inline int qedr_init_user_queue(struct ib_ucontext *ib_ctx,
 
 	q->buf_addr = buf_addr;
 	q->buf_len = buf_len;
-	q->umem = ib_umem_get(ib_ctx, q->buf_addr, q->buf_len, access, dmasync);
+	q->umem = ib_umem_get(udata, q->buf_addr, q->buf_len, access, dmasync);
 	if (IS_ERR(q->umem)) {
 		DP_ERR(dev, "create user queue: failed ib_umem_get, got %ld\n",
 		       PTR_ERR(q->umem));
@@ -905,9 +905,9 @@  struct ib_cq *qedr_create_cq(struct ib_device *ibdev,
 
 		cq->cq_type = QEDR_CQ_TYPE_USER;
 
-		rc = qedr_init_user_queue(ib_ctx, dev, &cq->q, ureq.addr,
-					  ureq.len, IB_ACCESS_LOCAL_WRITE,
-					  1, 1);
+		rc = qedr_init_user_queue(udata, dev, &cq->q, ureq.addr,
+					  ureq.len, IB_ACCESS_LOCAL_WRITE, 1,
+					  1);
 		if (rc)
 			goto err0;
 
@@ -1344,7 +1344,7 @@  static void qedr_free_srq_kernel_params(struct qedr_srq *srq)
 			  hw_srq->phy_prod_pair_addr);
 }
 
-static int qedr_init_srq_user_params(struct ib_ucontext *ib_ctx,
+static int qedr_init_srq_user_params(struct ib_udata *udata,
 				     struct qedr_srq *srq,
 				     struct qedr_create_srq_ureq *ureq,
 				     int access, int dmasync)
@@ -1352,12 +1352,12 @@  static int qedr_init_srq_user_params(struct ib_ucontext *ib_ctx,
 	struct scatterlist *sg;
 	int rc;
 
-	rc = qedr_init_user_queue(ib_ctx, srq->dev, &srq->usrq, ureq->srq_addr,
+	rc = qedr_init_user_queue(udata, srq->dev, &srq->usrq, ureq->srq_addr,
 				  ureq->srq_len, access, dmasync, 1);
 	if (rc)
 		return rc;
 
-	srq->prod_umem = ib_umem_get(ib_ctx, ureq->prod_pair_addr,
+	srq->prod_umem = ib_umem_get(udata, ureq->prod_pair_addr,
 				     sizeof(struct rdma_srq_producers),
 				     access, dmasync);
 	if (IS_ERR(srq->prod_umem)) {
@@ -1468,7 +1468,7 @@  struct ib_srq *qedr_create_srq(struct ib_pd *ibpd,
 			goto err0;
 		}
 
-		rc = qedr_init_srq_user_params(ib_ctx, srq, &ureq, 0, 0);
+		rc = qedr_init_srq_user_params(udata, srq, &ureq, 0, 0);
 		if (rc)
 			goto err0;
 
@@ -1714,14 +1714,14 @@  static int qedr_create_user_qp(struct qedr_dev *dev,
 	}
 
 	/* SQ - read access only (0), dma sync not required (0) */
-	rc = qedr_init_user_queue(ib_ctx, dev, &qp->usq, ureq.sq_addr,
+	rc = qedr_init_user_queue(udata, dev, &qp->usq, ureq.sq_addr,
 				  ureq.sq_len, 0, 0, alloc_and_init);
 	if (rc)
 		return rc;
 
 	if (!qp->srq) {
 		/* RQ - read access only (0), dma sync not required (0) */
-		rc = qedr_init_user_queue(ib_ctx, dev, &qp->urq, ureq.rq_addr,
+		rc = qedr_init_user_queue(udata, dev, &qp->urq, ureq.rq_addr,
 					  ureq.rq_len, 0, 0, alloc_and_init);
 		if (rc)
 			return rc;
@@ -2719,7 +2719,7 @@  struct ib_mr *qedr_reg_user_mr(struct ib_pd *ibpd, u64 start, u64 len,
 
 	mr->type = QEDR_MR_USER;
 
-	mr->umem = ib_umem_get(ibpd->uobject->context, start, len, acc, 0);
+	mr->umem = ib_umem_get(udata, start, len, acc, 0);
 	if (IS_ERR(mr->umem)) {
 		rc = -EFAULT;
 		goto err0;
diff --git a/drivers/infiniband/hw/vmw_pvrdma/pvrdma_cq.c b/drivers/infiniband/hw/vmw_pvrdma/pvrdma_cq.c
index 0f004c737620..104c7db4704f 100644
--- a/drivers/infiniband/hw/vmw_pvrdma/pvrdma_cq.c
+++ b/drivers/infiniband/hw/vmw_pvrdma/pvrdma_cq.c
@@ -141,7 +141,7 @@  struct ib_cq *pvrdma_create_cq(struct ib_device *ibdev,
 			goto err_cq;
 		}
 
-		cq->umem = ib_umem_get(context, ucmd.buf_addr, ucmd.buf_size,
+		cq->umem = ib_umem_get(udata, ucmd.buf_addr, ucmd.buf_size,
 				       IB_ACCESS_LOCAL_WRITE, 1);
 		if (IS_ERR(cq->umem)) {
 			ret = PTR_ERR(cq->umem);
diff --git a/drivers/infiniband/hw/vmw_pvrdma/pvrdma_mr.c b/drivers/infiniband/hw/vmw_pvrdma/pvrdma_mr.c
index fa96fa4fb829..1721e1f254db 100644
--- a/drivers/infiniband/hw/vmw_pvrdma/pvrdma_mr.c
+++ b/drivers/infiniband/hw/vmw_pvrdma/pvrdma_mr.c
@@ -126,7 +126,7 @@  struct ib_mr *pvrdma_reg_user_mr(struct ib_pd *pd, u64 start, u64 length,
 		return ERR_PTR(-EINVAL);
 	}
 
-	umem = ib_umem_get(pd->uobject->context, start,
+	umem = ib_umem_get(udata, start,
 			   length, access_flags, 0);
 	if (IS_ERR(umem)) {
 		dev_warn(&dev->pdev->dev,
diff --git a/drivers/infiniband/hw/vmw_pvrdma/pvrdma_qp.c b/drivers/infiniband/hw/vmw_pvrdma/pvrdma_qp.c
index 3acf74cbe266..80ae52d34332 100644
--- a/drivers/infiniband/hw/vmw_pvrdma/pvrdma_qp.c
+++ b/drivers/infiniband/hw/vmw_pvrdma/pvrdma_qp.c
@@ -262,7 +262,7 @@  struct ib_qp *pvrdma_create_qp(struct ib_pd *pd,
 
 			if (!is_srq) {
 				/* set qp->sq.wqe_cnt, shift, buf_size.. */
-				qp->rumem = ib_umem_get(pd->uobject->context,
+				qp->rumem = ib_umem_get(udata,
 							ucmd.rbuf_addr,
 							ucmd.rbuf_size, 0, 0);
 				if (IS_ERR(qp->rumem)) {
@@ -275,7 +275,7 @@  struct ib_qp *pvrdma_create_qp(struct ib_pd *pd,
 				qp->srq = to_vsrq(init_attr->srq);
 			}
 
-			qp->sumem = ib_umem_get(pd->uobject->context,
+			qp->sumem = ib_umem_get(udata,
 						ucmd.sbuf_addr,
 						ucmd.sbuf_size, 0, 0);
 			if (IS_ERR(qp->sumem)) {
diff --git a/drivers/infiniband/hw/vmw_pvrdma/pvrdma_srq.c b/drivers/infiniband/hw/vmw_pvrdma/pvrdma_srq.c
index 06ba7c7a2235..a42f0ca295e6 100644
--- a/drivers/infiniband/hw/vmw_pvrdma/pvrdma_srq.c
+++ b/drivers/infiniband/hw/vmw_pvrdma/pvrdma_srq.c
@@ -153,7 +153,7 @@  struct ib_srq *pvrdma_create_srq(struct ib_pd *pd,
 		goto err_srq;
 	}
 
-	srq->umem = ib_umem_get(pd->uobject->context,
+	srq->umem = ib_umem_get(udata,
 				ucmd.buf_addr,
 				ucmd.buf_size, 0, 0);
 	if (IS_ERR(srq->umem)) {
diff --git a/drivers/infiniband/sw/rdmavt/mr.c b/drivers/infiniband/sw/rdmavt/mr.c
index 49c9541050d4..6d7082135a56 100644
--- a/drivers/infiniband/sw/rdmavt/mr.c
+++ b/drivers/infiniband/sw/rdmavt/mr.c
@@ -388,7 +388,7 @@  struct ib_mr *rvt_reg_user_mr(struct ib_pd *pd, u64 start, u64 length,
 	if (length == 0)
 		return ERR_PTR(-EINVAL);
 
-	umem = ib_umem_get(pd->uobject->context, start, length,
+	umem = ib_umem_get(udata, start, length,
 			   mr_access_flags, 0);
 	if (IS_ERR(umem))
 		return (void *)umem;
diff --git a/drivers/infiniband/sw/rxe/rxe_mr.c b/drivers/infiniband/sw/rxe/rxe_mr.c
index 9d3916b93f23..2438093776a0 100644
--- a/drivers/infiniband/sw/rxe/rxe_mr.c
+++ b/drivers/infiniband/sw/rxe/rxe_mr.c
@@ -171,7 +171,7 @@  int rxe_mem_init_user(struct rxe_pd *pd, u64 start,
 	void			*vaddr;
 	int err;
 
-	umem = ib_umem_get(pd->ibpd.uobject->context, start, length, access, 0);
+	umem = ib_umem_get(udata, start, length, access, 0);
 	if (IS_ERR(umem)) {
 		pr_warn("err %d from rxe_umem_get\n",
 			(int)PTR_ERR(umem));
diff --git a/include/rdma/ib_umem.h b/include/rdma/ib_umem.h
index 5d3755ec5afa..18b5f1c4e9d0 100644
--- a/include/rdma/ib_umem.h
+++ b/include/rdma/ib_umem.h
@@ -36,6 +36,7 @@ 
 #include <linux/list.h>
 #include <linux/scatterlist.h>
 #include <linux/workqueue.h>
+#include <rdma/ib_verbs.h>
 
 struct ib_ucontext;
 struct ib_umem_odp;
@@ -80,7 +81,7 @@  static inline size_t ib_umem_num_pages(struct ib_umem *umem)
 
 #ifdef CONFIG_INFINIBAND_USER_MEM
 
-struct ib_umem *ib_umem_get(struct ib_ucontext *context, unsigned long addr,
+struct ib_umem *ib_umem_get(struct ib_udata *udata, unsigned long addr,
 			    size_t size, int access, int dmasync);
 void ib_umem_release(struct ib_umem *umem);
 int ib_umem_page_count(struct ib_umem *umem);
@@ -91,7 +92,7 @@  int ib_umem_copy_from(void *dst, struct ib_umem *umem, size_t offset,
 
 #include <linux/err.h>
 
-static inline struct ib_umem *ib_umem_get(struct ib_ucontext *context,
+static inline struct ib_umem *ib_umem_get(struct ib_udata *udata,
 					  unsigned long addr, size_t size,
 					  int access, int dmasync) {
 	return ERR_PTR(-EINVAL);
diff --git a/include/rdma/ib_verbs.h b/include/rdma/ib_verbs.h
index cd8c5886a1e6..fa663587b9ff 100644
--- a/include/rdma/ib_verbs.h
+++ b/include/rdma/ib_verbs.h
@@ -2619,6 +2619,7 @@  struct ib_client {
 
 struct ib_uverbs {
 	/* uverbs callbacks used by ib_core */
+	struct ib_ucontext *(*rdma_get_ucontext)(struct ib_udata *udata);
 };
 
 struct ib_device *ib_alloc_device(size_t size);
@@ -4201,6 +4202,7 @@  void rdma_roce_rescan_device(struct ib_device *ibdev);
 
 struct ib_ucontext *ib_uverbs_get_ucontext_file(struct ib_uverbs_file *ufile);
 
+struct ib_ucontext *rdma_get_ucontext(struct ib_udata *udata);
 
 int uverbs_destroy_def_handler(struct uverbs_attr_bundle *attrs);