@@ -41,6 +41,7 @@
#include "core_priv.h"
#include "cma_priv.h"
#include "restrack.h"
+#include "uverbs.h"
/*
* Sort array elements by the netlink attribute name
@@ -141,6 +142,8 @@ static const struct nla_policy nldev_policy[RDMA_NLDEV_ATTR_MAX] = {
[RDMA_NLDEV_ATTR_UVERBS_DRIVER_ID] = { .type = NLA_U32 },
[RDMA_NLDEV_NET_NS_FD] = { .type = NLA_U32 },
[RDMA_NLDEV_SYS_ATTR_NETNS_MODE] = { .type = NLA_U8 },
+ [RDMA_NLDEV_ATTR_RES_CTX] = { .type = NLA_NESTED },
+ [RDMA_NLDEV_ATTR_RES_CTX_ENTRY] = { .type = NLA_NESTED },
};
static int put_driver_name_print_type(struct sk_buff *msg, const char *name,
@@ -611,11 +614,84 @@ static int fill_res_mr_entry(struct sk_buff *msg, bool has_cap_net_admin,
err: return -EMSGSIZE;
}
+struct context_id {
+ struct list_head list;
+ u32 id;
+};
+
+static void pd_context(struct ib_pd *pd, struct list_head *list, int *count)
+{
+ struct ib_device *device = pd->device;
+ struct rdma_restrack_entry *res;
+ struct rdma_restrack_root *rt;
+ struct ib_uverbs_file *ufile;
+ struct ib_ucontext *ucontext;
+ struct ib_uobject *uobj;
+ unsigned long flags;
+ unsigned long id;
+ bool found;
+
+ rt = &device->res[RDMA_RESTRACK_CTX];
+
+ xa_lock(&rt->xa);
+
+ xa_for_each(&rt->xa, id, res) {
+ if (!rdma_is_visible_in_pid_ns(res))
+ continue;
+
+ if (!rdma_restrack_get(res))
+ continue;
+
+ xa_unlock(&rt->xa);
+
+ ucontext = container_of(res, struct ib_ucontext, res);
+ ufile = ucontext->ufile;
+ found = false;
+
+ /* See locking requirements in struct ib_uverbs_file */
+ down_read(&ufile->hw_destroy_rwsem);
+ spin_lock_irqsave(&ufile->uobjects_lock, flags);
+
+ list_for_each_entry(uobj, &ufile->uobjects, list) {
+ if (uobj->object == pd) {
+ found = true;
+ goto found;
+ }
+ }
+
+found: spin_unlock_irqrestore(&ufile->uobjects_lock, flags);
+ up_read(&ufile->hw_destroy_rwsem);
+
+ if (found) {
+ struct context_id *ctx_id =
+ kmalloc(sizeof(*ctx_id), GFP_KERNEL);
+
+ if (WARN_ON_ONCE(!ctx_id))
+ goto next;
+
+ ctx_id->id = ucontext->res.id;
+ list_add(&ctx_id->list, list);
+ (*count)++;
+ }
+
+next: rdma_restrack_put(res);
+ xa_lock(&rt->xa);
+ }
+
+ xa_unlock(&rt->xa);
+}
+
static int fill_res_pd_entry(struct sk_buff *msg, bool has_cap_net_admin,
struct rdma_restrack_entry *res, uint32_t port)
{
struct ib_pd *pd = container_of(res, struct ib_pd, res);
struct ib_device *dev = pd->device;
+ struct nlattr *table_attr = NULL;
+ struct nlattr *entry_attr = NULL;
+ struct context_id *ctx_id;
+ struct context_id *tmp;
+ LIST_HEAD(pd_context_ids);
+ int ctx_count = 0;
if (has_cap_net_admin) {
if (nla_put_u32(msg, RDMA_NLDEV_ATTR_RES_LOCAL_DMA_LKEY,
@@ -633,10 +709,38 @@ static int fill_res_pd_entry(struct sk_buff *msg, bool has_cap_net_admin,
if (nla_put_u32(msg, RDMA_NLDEV_ATTR_RES_PDN, res->id))
goto err;
- if (!rdma_is_kernel_res(res) &&
- nla_put_u32(msg, RDMA_NLDEV_ATTR_RES_CTXN,
- pd->uobject->context->res.id))
- goto err;
+ if (!rdma_is_kernel_res(res)) {
+ pd_context(pd, &pd_context_ids, &ctx_count);
+ if (ctx_count == 1) {
+ /* user pd, not shared */
+ ctx_id = list_first_entry(&pd_context_ids,
+ struct context_id, list);
+ if (nla_put_u32(msg, RDMA_NLDEV_ATTR_RES_CTXN,
+ ctx_id->id))
+ goto err;
+ } else if (ctx_count > 1) {
+ /* user pd, shared */
+ table_attr = nla_nest_start(msg,
+ RDMA_NLDEV_ATTR_RES_CTX);
+ if (!table_attr)
+ goto err;
+
+ list_for_each_entry(ctx_id, &pd_context_ids, list) {
+ entry_attr = nla_nest_start(msg,
+ RDMA_NLDEV_ATTR_RES_CTX_ENTRY);
+ if (!entry_attr)
+ goto err;
+ if (nla_put_u32(msg, RDMA_NLDEV_ATTR_RES_CTXN,
+ ctx_id->id))
+ goto err;
+ nla_nest_end(msg, entry_attr);
+ entry_attr = NULL;
+ }
+
+ nla_nest_end(msg, table_attr);
+ table_attr = NULL;
+ }
+ }
if (fill_res_name_pid(msg, res))
goto err;
@@ -644,9 +748,22 @@ static int fill_res_pd_entry(struct sk_buff *msg, bool has_cap_net_admin,
if (fill_res_entry(dev, msg, res))
goto err;
+ list_for_each_entry_safe(ctx_id, tmp, &pd_context_ids, list)
+ kfree(ctx_id);
+
return 0;
-err: return -EMSGSIZE;
+err:
+ if (entry_attr)
+ nla_nest_end(msg, entry_attr);
+
+ if (table_attr)
+ nla_nest_end(msg, table_attr);
+
+ list_for_each_entry_safe(ctx_id, tmp, &pd_context_ids, list)
+ kfree(ctx_id);
+
+ return -EMSGSIZE;
}
static int fill_stat_counter_mode(struct sk_buff *msg,
@@ -525,6 +525,9 @@ enum rdma_nldev_attr {
*/
RDMA_NLDEV_ATTR_DEV_DIM, /* u8 */
+ RDMA_NLDEV_ATTR_RES_CTX, /* nested table */
+ RDMA_NLDEV_ATTR_RES_CTX_ENTRY, /* nested table */
+
/*
* Always the end
*/