diff mbox series

[04/10] RDMA/device: Add ib_device_set_netdev() as an alternative to get_netdev

Message ID 20190213041256.22437-5-jgg@ziepe.ca (mailing list archive)
State Accepted
Delegated to: Jason Gunthorpe
Headers show
Series Revise device handling in rxe | expand

Commit Message

Jason Gunthorpe Feb. 13, 2019, 4:12 a.m. UTC
From: Jason Gunthorpe <jgg@mellanox.com>

The associated netdev should not actually be very dynamic, so for most
drivers there is no reason for a callback like get_netdev. Provide an API to
inform the core code about the netdev affiliation and use a core-maintained
data structure instead.

This makes the core code aware of the ndev relationship, which will allow
some new APIs to be built around it.

This also uses locking that makes some kind of sense; many drivers had a
confusing RCU lock, or were missing locking entirely, which isn't right.

Signed-off-by: Jason Gunthorpe <jgg@mellanox.com>
---
 drivers/infiniband/core/cache.c     |  24 ++--
 drivers/infiniband/core/core_priv.h |   3 +
 drivers/infiniband/core/device.c    | 166 +++++++++++++++++++++++++---
 drivers/infiniband/core/nldev.c     |   4 +-
 drivers/infiniband/core/verbs.c     |   5 +-
 include/rdma/ib_verbs.h             |   7 ++
 6 files changed, 171 insertions(+), 38 deletions(-)

Comments

Ira Weiny Feb. 15, 2019, 7:03 p.m. UTC | #1
On Tue, Feb 12, 2019 at 09:12:50PM -0700, Jason Gunthorpe wrote:
> From: Jason Gunthorpe <jgg@mellanox.com>
> 
> The associated netdev should not actually be very dynamic, so for most
> drivers there is no reason for a callback like this. Provide an API to
> inform the core code about the net dev affiliation and use a core
> maintained data structure instead.
> 
> This allows the core code to be more aware of the ndev relationship which
> will allow some new APIs based around this.
> 
> This also uses locking that makes some kind of sense, many drivers had a
> confusing RCU lock, or missing locking which isn't right.
> 
> Signed-off-by: Jason Gunthorpe <jgg@mellanox.com>
> ---
>  drivers/infiniband/core/cache.c     |  24 ++--
>  drivers/infiniband/core/core_priv.h |   3 +
>  drivers/infiniband/core/device.c    | 166 +++++++++++++++++++++++++---
>  drivers/infiniband/core/nldev.c     |   4 +-
>  drivers/infiniband/core/verbs.c     |   5 +-
>  include/rdma/ib_verbs.h             |   7 ++
>  6 files changed, 171 insertions(+), 38 deletions(-)
> 
> diff --git a/drivers/infiniband/core/cache.c b/drivers/infiniband/core/cache.c
> index a28dc1901c8000..e191f3d86b41d5 100644
> --- a/drivers/infiniband/core/cache.c
> +++ b/drivers/infiniband/core/cache.c
> @@ -547,21 +547,19 @@ int ib_cache_gid_add(struct ib_device *ib_dev, u8 port,
>  	unsigned long mask;
>  	int ret;
>  
> -	if (ib_dev->ops.get_netdev) {
> -		idev = ib_dev->ops.get_netdev(ib_dev, port);
> -		if (idev && attr->ndev != idev) {
> -			union ib_gid default_gid;
> -
> -			/* Adding default GIDs in not permitted */
> -			make_default_gid(idev, &default_gid);
> -			if (!memcmp(gid, &default_gid, sizeof(*gid))) {
> -				dev_put(idev);
> -				return -EPERM;
> -			}
> -		}
> -		if (idev)
> +	idev = ib_device_get_netdev(ib_dev, port);
> +	if (idev && attr->ndev != idev) {
> +		union ib_gid default_gid;
> +
> +		/* Adding default GIDs in not permitted */

NIT: "is not"?

> +		make_default_gid(idev, &default_gid);
> +		if (!memcmp(gid, &default_gid, sizeof(*gid))) {
>  			dev_put(idev);
> +			return -EPERM;
> +		}
>  	}
> +	if (idev)
> +		dev_put(idev);
>  
>  	mask = GID_ATTR_FIND_MASK_GID |
>  	       GID_ATTR_FIND_MASK_GID_TYPE |
> diff --git a/drivers/infiniband/core/core_priv.h b/drivers/infiniband/core/core_priv.h
> index a1826f4c2e23ee..8aa4872e07a0da 100644
> --- a/drivers/infiniband/core/core_priv.h
> +++ b/drivers/infiniband/core/core_priv.h
> @@ -64,6 +64,9 @@ typedef void (*roce_netdev_callback)(struct ib_device *device, u8 port,
>  typedef bool (*roce_netdev_filter)(struct ib_device *device, u8 port,
>  				   struct net_device *idev, void *cookie);
>  
> +struct net_device *ib_device_get_netdev(struct ib_device *ib_dev,
> +					unsigned int port);
> +
>  void ib_enum_roce_netdev(struct ib_device *ib_dev,
>  			 roce_netdev_filter filter,
>  			 void *filter_cookie,
> diff --git a/drivers/infiniband/core/device.c b/drivers/infiniband/core/device.c
> index 58591408bb1b35..14c91f9af6ccc9 100644
> --- a/drivers/infiniband/core/device.c
> +++ b/drivers/infiniband/core/device.c
> @@ -133,6 +133,7 @@ static void *xan_find_marked(struct xarray *xa, unsigned long *indexp,
>  	     !xa_is_err(entry);                                                \
>  	     (index)++, entry = xan_find_marked(xa, &(index), filter))
>  
> +static void free_netdevs(struct ib_device *ib_dev);
>  static int ib_security_change(struct notifier_block *nb, unsigned long event,
>  			      void *lsm_data);
>  static void ib_policy_change_task(struct work_struct *work);
> @@ -289,6 +290,7 @@ static void ib_device_release(struct device *device)
>  {
>  	struct ib_device *dev = container_of(device, struct ib_device, dev);
>  
> +	free_netdevs(dev);
>  	WARN_ON(refcount_read(&dev->refcount));
>  	ib_cache_release_one(dev);
>  	ib_security_release_port_pkey_list(dev);
> @@ -365,6 +367,9 @@ EXPORT_SYMBOL(_ib_alloc_device);
>   */
>  void ib_dealloc_device(struct ib_device *device)
>  {
> +	/* Expedite releasing netdev references */
> +	free_netdevs(device);
> +
>  	WARN_ON(!xa_empty(&device->client_data));
>  	WARN_ON(refcount_read(&device->refcount));
>  	rdma_restrack_clean(device);
> @@ -454,16 +459,16 @@ static void remove_client_context(struct ib_device *device,
>  	up_read(&device->client_data_rwsem);
>  }
>  
> -static int verify_immutable(const struct ib_device *dev, u8 port)
> -{
> -	return WARN_ON(!rdma_cap_ib_mad(dev, port) &&
> -			    rdma_max_mad_size(dev, port) != 0);
> -}
> -
> -static int setup_port_data(struct ib_device *device)
> +static int alloc_port_data(struct ib_device *device)
>  {
>  	unsigned int port;
> -	int ret;
> +
> +	if (device->port_data)
> +		return 0;
> +
> +	/* This can only be called once the physical port range is defined */
> +	if (WARN_ON(!device->phys_port_cnt))
> +		return -EINVAL;

Was this port stuff supposed to be part of the previous patch?

>  
>  	/*
>  	 * device->port_data is indexed directly by the port number to make
> @@ -482,6 +487,28 @@ static int setup_port_data(struct ib_device *device)
>  
>  		spin_lock_init(&pdata->pkey_list_lock);
>  		INIT_LIST_HEAD(&pdata->pkey_list);
> +		spin_lock_init(&pdata->netdev_lock);
> +	}
> +	return 0;
> +}
> +
> +static int verify_immutable(const struct ib_device *dev, u8 port)
> +{
> +	return WARN_ON(!rdma_cap_ib_mad(dev, port) &&
> +			    rdma_max_mad_size(dev, port) != 0);
> +}
> +
> +static int setup_port_data(struct ib_device *device)
> +{
> +	unsigned int port;
> +	int ret;
> +
> +	ret = alloc_port_data(device);
> +	if (ret)
> +		return ret;
> +
> +	rdma_for_each_port (device, port) {
> +		struct ib_port_data *pdata = &device->port_data[port];

Same question as above?

>  
>  		ret = device->ops.get_port_immutable(device, port,
>  						     &pdata->immutable);
> @@ -675,6 +702,9 @@ static void disable_device(struct ib_device *device)
>  	/* Pairs with refcount_set in enable_device */
>  	ib_device_put(device);
>  	wait_for_completion(&device->unreg_completion);
> +
> +	/* Expedite removing unregistered pointers from the hash table */
> +	free_netdevs(device);
>  }
>  
>  /*
> @@ -998,6 +1028,114 @@ int ib_query_port(struct ib_device *device,
>  }
>  EXPORT_SYMBOL(ib_query_port);
>  
> +/**
> + * ib_device_set_netdev - Associate the ib_dev with an underlying net_device
> + * @ib_dev: Device to modify
> + * @ndev: net_device to affiliate, may be NULL
> + * @port: IB port the net_device is connected to
> + *
> + * Drivers should use this to link the ib_device to a netdev so the netdev
> + * shows up in interfaces like ib_enum_roce_netdev. Only one netdev may be
> + * affiliated with any port.
> + *
> + * The caller must ensure that the given ndev is not unregistered or
> + * unregistering, and that either the ib_device is unregistered or
> + * ib_device_set_netdev() is called with NULL when the ndev sends a
> + * NETDEV_UNREGISTER event.
> + */
> +int ib_device_set_netdev(struct ib_device *ib_dev, struct net_device *ndev,
> +			 unsigned int port)
> +{
> +	struct net_device *old_ndev;
> +	struct ib_port_data *pdata;
> +	unsigned long flags;
> +	int ret;
> +
> +	/*
> +	 * Drivers wish to call this before ib_register_driver, so we have to
> +	 * setup the port data early.
> +	 */
> +	ret = alloc_port_data(ib_dev);
> +	if (ret)
> +		return ret;
> +
> +	if (!rdma_is_port_valid(ib_dev, port))
> +		return -EINVAL;
> +
> +	pdata = &ib_dev->port_data[port];
> +	spin_lock_irqsave(&pdata->netdev_lock, flags);
> +	if (pdata->netdev == ndev) {
> +		spin_unlock_irqrestore(&pdata->netdev_lock, flags);
> +		return 0;
> +	}
> +	old_ndev = pdata->netdev;
> +
> +	if (ndev)
> +		dev_hold(ndev);
> +	pdata->netdev = ndev;
> +	spin_unlock_irqrestore(&pdata->netdev_lock, flags);
> +
> +	if (old_ndev)
> +		dev_put(old_ndev);
> +
> +	return 0;
> +}
> +EXPORT_SYMBOL(ib_device_set_netdev);
> +
> +static void free_netdevs(struct ib_device *ib_dev)
> +{
> +	unsigned long flags;
> +	unsigned int port;
> +
> +	rdma_for_each_port (ib_dev, port) {
> +		struct ib_port_data *pdata = &ib_dev->port_data[port];
> +
> +		spin_lock_irqsave(&pdata->netdev_lock, flags);
> +		if (pdata->netdev) {
> +			dev_put(pdata->netdev);
> +			pdata->netdev = NULL;
> +		}
> +		spin_unlock_irqrestore(&pdata->netdev_lock, flags);
> +	}
> +}
> +
> +struct net_device *ib_device_get_netdev(struct ib_device *ib_dev,
> +					unsigned int port)
> +{
> +	struct ib_port_data *pdata;
> +	struct net_device *res;
> +
> +	if (!rdma_is_port_valid(ib_dev, port))
> +		return NULL;
> +
> +	pdata = &ib_dev->port_data[port];
> +
> +	/*
> +	 * New drivers should use ib_device_set_netdev() not the legacy
> +	 * get_netdev().
> +	 */
> +	if (ib_dev->ops.get_netdev)
> +		res = ib_dev->ops.get_netdev(ib_dev, port);
> +	else {
> +		spin_lock(&pdata->netdev_lock);
> +		res = pdata->netdev;
> +		if (res)
> +			dev_hold(res);
> +		spin_unlock(&pdata->netdev_lock);
> +	}
> +
> +	/*
> +	 * If we are starting to unregister expedite things by preventing
> +	 * propagation of an unregistering netdev.
> +	 */
> +	if (res && res->reg_state != NETREG_REGISTERED) {
> +		dev_put(res);
> +		return NULL;
> +	}
> +
> +	return res;
> +}
> +
>  /**
>   * ib_enum_roce_netdev - enumerate all RoCE ports
>   * @ib_dev : IB device we want to query
> @@ -1020,16 +1158,8 @@ void ib_enum_roce_netdev(struct ib_device *ib_dev,
>  
>  	rdma_for_each_port (ib_dev, port)
>  		if (rdma_protocol_roce(ib_dev, port)) {
> -			struct net_device *idev = NULL;
> -
> -			if (ib_dev->ops.get_netdev)
> -				idev = ib_dev->ops.get_netdev(ib_dev, port);
> -
> -			if (idev &&
> -			    idev->reg_state >= NETREG_UNREGISTERED) {
> -				dev_put(idev);
> -				idev = NULL;
> -			}
> +			struct net_device *idev =
> +				ib_device_get_netdev(ib_dev, port);
>  
>  			if (filter(ib_dev, port, idev, filter_cookie))
>  				cb(ib_dev, port, idev, cookie);
> diff --git a/drivers/infiniband/core/nldev.c b/drivers/infiniband/core/nldev.c
> index 6708277aad7e4c..a0e635c5715905 100644
> --- a/drivers/infiniband/core/nldev.c
> +++ b/drivers/infiniband/core/nldev.c
> @@ -262,9 +262,7 @@ static int fill_port_info(struct sk_buff *msg,
>  	if (nla_put_u8(msg, RDMA_NLDEV_ATTR_PORT_PHYS_STATE, attr.phys_state))
>  		return -EMSGSIZE;
>  
> -	if (device->ops.get_netdev)
> -		netdev = device->ops.get_netdev(device, port);
> -
> +	netdev = ib_device_get_netdev(device, port);
>  	if (netdev && net_eq(dev_net(netdev), net)) {
>  		ret = nla_put_u32(msg,
>  				  RDMA_NLDEV_ATTR_NDEV_INDEX, netdev->ifindex);
> diff --git a/drivers/infiniband/core/verbs.c b/drivers/infiniband/core/verbs.c
> index de5d895a50544c..5a5e83f5f0fc4c 100644
> --- a/drivers/infiniband/core/verbs.c
> +++ b/drivers/infiniband/core/verbs.c
> @@ -1723,10 +1723,7 @@ int ib_get_eth_speed(struct ib_device *dev, u8 port_num, u8 *speed, u8 *width)
>  	if (rdma_port_get_link_layer(dev, port_num) != IB_LINK_LAYER_ETHERNET)
>  		return -EINVAL;
>  
> -	if (!dev->ops.get_netdev)
> -		return -EOPNOTSUPP;
> -
> -	netdev = dev->ops.get_netdev(dev, port_num);
> +	netdev = ib_device_get_netdev(dev, port_num);
>  	if (!netdev)
>  		return -ENODEV;
>  
> diff --git a/include/rdma/ib_verbs.h b/include/rdma/ib_verbs.h
> index e5eea5a50e27f2..358bda4a76ff42 100644
> --- a/include/rdma/ib_verbs.h
> +++ b/include/rdma/ib_verbs.h
> @@ -2204,6 +2204,9 @@ struct ib_port_data {
>  	struct list_head pkey_list;
>  
>  	struct ib_port_cache cache;
> +
> +	spinlock_t netdev_lock;
> +	struct net_device *netdev;
>  };
>  
>  /* rdma netdev type - specifies protocol type */
> @@ -3997,6 +4000,10 @@ void ib_device_put(struct ib_device *device);
>  struct net_device *ib_get_net_dev_by_params(struct ib_device *dev, u8 port,
>  					    u16 pkey, const union ib_gid *gid,
>  					    const struct sockaddr *addr);
> +int ib_device_set_netdev(struct ib_device *ib_dev, struct net_device *ndev,
> +			 unsigned int port);
> +struct net_device *ib_device_netdev(struct ib_device *dev, u8 port);
> +
>  struct ib_wq *ib_create_wq(struct ib_pd *pd,
>  			   struct ib_wq_init_attr *init_attr);
>  int ib_destroy_wq(struct ib_wq *wq);
> -- 
> 2.20.1
>
Jason Gunthorpe Feb. 15, 2019, 9:59 p.m. UTC | #2
On Fri, Feb 15, 2019 at 11:03:16AM -0800, Ira Weiny wrote:
> > diff --git a/drivers/infiniband/core/cache.c b/drivers/infiniband/core/cache.c
> > index a28dc1901c8000..e191f3d86b41d5 100644
> > +++ b/drivers/infiniband/core/cache.c
> > @@ -547,21 +547,19 @@ int ib_cache_gid_add(struct ib_device *ib_dev, u8 port,
> >  	unsigned long mask;
> >  	int ret;
> >  
> > -	if (ib_dev->ops.get_netdev) {
> > -		idev = ib_dev->ops.get_netdev(ib_dev, port);
> > -		if (idev && attr->ndev != idev) {
> > -			union ib_gid default_gid;
> > -
> > -			/* Adding default GIDs in not permitted */
> > -			make_default_gid(idev, &default_gid);
> > -			if (!memcmp(gid, &default_gid, sizeof(*gid))) {
> > -				dev_put(idev);
> > -				return -EPERM;
> > -			}
> > -		}
> > -		if (idev)
> > +	idev = ib_device_get_netdev(ib_dev, port);
> > +	if (idev && attr->ndev != idev) {
> > +		union ib_gid default_gid;
> > +
> > +		/* Adding default GIDs in not permitted */
> 
> NIT: "is not"?

Okay, I can fold that into this patch.

> > -static int setup_port_data(struct ib_device *device)
> > +static int alloc_port_data(struct ib_device *device)
> >  {
> >  	unsigned int port;
> > -	int ret;
> > +
> > +	if (device->port_data)
> > +		return 0;
> > +
> > +	/* This can only be called once the physical port range is defined */
> > +	if (WARN_ON(!device->phys_port_cnt))
> > +		return -EINVAL;
> 
> Was this port stuff supposed to be part of the previous patch?

No, this is splitting one function into alloc & setup, the new alloc
function has this extra protection as it would be easy to call it out
of order in the driver conversions.

Jason
diff mbox series

Patch

diff --git a/drivers/infiniband/core/cache.c b/drivers/infiniband/core/cache.c
index a28dc1901c8000..e191f3d86b41d5 100644
--- a/drivers/infiniband/core/cache.c
+++ b/drivers/infiniband/core/cache.c
@@ -547,21 +547,19 @@  int ib_cache_gid_add(struct ib_device *ib_dev, u8 port,
 	unsigned long mask;
 	int ret;

-	if (ib_dev->ops.get_netdev) {
-		idev = ib_dev->ops.get_netdev(ib_dev, port);
-		if (idev && attr->ndev != idev) {
-			union ib_gid default_gid;
-
-			/* Adding default GIDs in not permitted */
-			make_default_gid(idev, &default_gid);
-			if (!memcmp(gid, &default_gid, sizeof(*gid))) {
-				dev_put(idev);
-				return -EPERM;
-			}
-		}
-		if (idev)
+	idev = ib_device_get_netdev(ib_dev, port);
+	if (idev && attr->ndev != idev) {
+		union ib_gid default_gid;
+
+		/* Adding default GIDs is not permitted */
+		make_default_gid(idev, &default_gid);
+		if (!memcmp(gid, &default_gid, sizeof(*gid))) {
 			dev_put(idev);
+			return -EPERM;
+		}
 	}
+	if (idev)
+		dev_put(idev);

 	mask = GID_ATTR_FIND_MASK_GID |
 	       GID_ATTR_FIND_MASK_GID_TYPE |
diff --git a/drivers/infiniband/core/core_priv.h b/drivers/infiniband/core/core_priv.h
index a1826f4c2e23ee..8aa4872e07a0da 100644
--- a/drivers/infiniband/core/core_priv.h
+++ b/drivers/infiniband/core/core_priv.h
@@ -64,6 +64,13 @@  typedef void (*roce_netdev_callback)(struct ib_device *device, u8 port,
 typedef bool (*roce_netdev_filter)(struct ib_device *device, u8 port,
 				   struct net_device *idev, void *cookie);

+/*
+ * Returns the port's net_device with a reference held (release it with
+ * dev_put()), or NULL.
+ */
+struct net_device *ib_device_get_netdev(struct ib_device *ib_dev,
+					unsigned int port);
+
 void ib_enum_roce_netdev(struct ib_device *ib_dev,
 			 roce_netdev_filter filter,
 			 void *filter_cookie,
diff --git a/drivers/infiniband/core/device.c b/drivers/infiniband/core/device.c
index 58591408bb1b35..14c91f9af6ccc9 100644
--- a/drivers/infiniband/core/device.c
+++ b/drivers/infiniband/core/device.c
@@ -133,6 +133,7 @@  static void *xan_find_marked(struct xarray *xa, unsigned long *indexp,
 	     !xa_is_err(entry);                                                \
 	     (index)++, entry = xan_find_marked(xa, &(index), filter))

+static void free_netdevs(struct ib_device *ib_dev);
 static int ib_security_change(struct notifier_block *nb, unsigned long event,
 			      void *lsm_data);
 static void ib_policy_change_task(struct work_struct *work);
@@ -289,6 +290,7 @@  static void ib_device_release(struct device *device)
 {
 	struct ib_device *dev = container_of(device, struct ib_device, dev);

+	free_netdevs(dev);
 	WARN_ON(refcount_read(&dev->refcount));
 	ib_cache_release_one(dev);
 	ib_security_release_port_pkey_list(dev);
@@ -365,6 +367,9 @@  EXPORT_SYMBOL(_ib_alloc_device);
  */
 void ib_dealloc_device(struct ib_device *device)
 {
+	/* Expedite releasing netdev references */
+	free_netdevs(device);
+
 	WARN_ON(!xa_empty(&device->client_data));
 	WARN_ON(refcount_read(&device->refcount));
 	rdma_restrack_clean(device);
@@ -454,16 +459,16 @@  static void remove_client_context(struct ib_device *device,
 	up_read(&device->client_data_rwsem);
 }

-static int verify_immutable(const struct ib_device *dev, u8 port)
-{
-	return WARN_ON(!rdma_cap_ib_mad(dev, port) &&
-			    rdma_max_mad_size(dev, port) != 0);
-}
-
-static int setup_port_data(struct ib_device *device)
+static int alloc_port_data(struct ib_device *device)
 {
 	unsigned int port;
-	int ret;
+
+	if (device->port_data)
+		return 0;
+
+	/* This can only be called once the physical port range is defined */
+	if (WARN_ON(!device->phys_port_cnt))
+		return -EINVAL;

 	/*
 	 * device->port_data is indexed directly by the port number to make
@@ -482,6 +487,28 @@  static int setup_port_data(struct ib_device *device)

 		spin_lock_init(&pdata->pkey_list_lock);
 		INIT_LIST_HEAD(&pdata->pkey_list);
+		spin_lock_init(&pdata->netdev_lock);
+	}
+	return 0;
+}
+
+static int verify_immutable(const struct ib_device *dev, u8 port)
+{
+	return WARN_ON(!rdma_cap_ib_mad(dev, port) &&
+			    rdma_max_mad_size(dev, port) != 0);
+}
+
+static int setup_port_data(struct ib_device *device)
+{
+	unsigned int port;
+	int ret;
+
+	ret = alloc_port_data(device);
+	if (ret)
+		return ret;
+
+	rdma_for_each_port (device, port) {
+		struct ib_port_data *pdata = &device->port_data[port];

 		ret = device->ops.get_port_immutable(device, port,
 						     &pdata->immutable);
@@ -675,6 +702,9 @@  static void disable_device(struct ib_device *device)
 	/* Pairs with refcount_set in enable_device */
 	ib_device_put(device);
 	wait_for_completion(&device->unreg_completion);
+
+	/* Expedite releasing netdev references */
+	free_netdevs(device);
 }

 /*
@@ -998,6 +1028,139 @@  int ib_query_port(struct ib_device *device,
 }
 EXPORT_SYMBOL(ib_query_port);

+/**
+ * ib_device_set_netdev - Associate the ib_dev with an underlying net_device
+ * @ib_dev: Device to modify
+ * @ndev: net_device to affiliate, may be NULL
+ * @port: IB port the net_device is connected to
+ *
+ * Drivers should use this to link the ib_device to a netdev so the netdev
+ * shows up in interfaces like ib_enum_roce_netdev. Only one netdev may be
+ * affiliated with any port.
+ *
+ * The caller must ensure that the given ndev is not unregistered or
+ * unregistering, and that either the ib_device is unregistered or
+ * ib_device_set_netdev() is called with NULL when the ndev sends a
+ * NETDEV_UNREGISTER event.
+ *
+ * Return: 0 on success, -EINVAL if @port is not a valid port, or the error
+ * returned by alloc_port_data().
+ */
+int ib_device_set_netdev(struct ib_device *ib_dev, struct net_device *ndev,
+			 unsigned int port)
+{
+	struct net_device *old_ndev;
+	struct ib_port_data *pdata;
+	unsigned long flags;
+	int ret;
+
+	/*
+	 * Drivers wish to call this before the device is registered, so we
+	 * have to set up the port data early.
+	 */
+	ret = alloc_port_data(ib_dev);
+	if (ret)
+		return ret;
+
+	if (!rdma_is_port_valid(ib_dev, port))
+		return -EINVAL;
+
+	pdata = &ib_dev->port_data[port];
+	spin_lock_irqsave(&pdata->netdev_lock, flags);
+	if (pdata->netdev == ndev) {
+		spin_unlock_irqrestore(&pdata->netdev_lock, flags);
+		return 0;
+	}
+	old_ndev = pdata->netdev;
+
+	if (ndev)
+		dev_hold(ndev);
+	pdata->netdev = ndev;
+	spin_unlock_irqrestore(&pdata->netdev_lock, flags);
+
+	/* Drop the reference that was held on the previous netdev, if any */
+	if (old_ndev)
+		dev_put(old_ndev);
+
+	return 0;
+}
+EXPORT_SYMBOL(ib_device_set_netdev);
+
+/*
+ * Drop the netdev reference held by each port. Each pointer is cleared once
+ * released, so calling this more than once is harmless.
+ */
+static void free_netdevs(struct ib_device *ib_dev)
+{
+	unsigned long flags;
+	unsigned int port;
+
+	/*
+	 * port_data is only allocated by alloc_port_data(); a device freed
+	 * before that ever ran has no netdev references to release.
+	 */
+	if (!ib_dev->port_data)
+		return;
+
+	rdma_for_each_port (ib_dev, port) {
+		struct ib_port_data *pdata = &ib_dev->port_data[port];
+
+		spin_lock_irqsave(&pdata->netdev_lock, flags);
+		if (pdata->netdev) {
+			dev_put(pdata->netdev);
+			pdata->netdev = NULL;
+		}
+		spin_unlock_irqrestore(&pdata->netdev_lock, flags);
+	}
+}
+
+/**
+ * ib_device_get_netdev - Get the net_device affiliated with a port
+ * @ib_dev: Device to query
+ * @port: Port number to query
+ *
+ * Return: the net_device with a reference held, which the caller must release
+ * with dev_put(), or NULL if @port is invalid, no netdev is affiliated, or
+ * the netdev is not in the NETREG_REGISTERED state.
+ */
+struct net_device *ib_device_get_netdev(struct ib_device *ib_dev,
+					unsigned int port)
+{
+	struct net_device *res;
+
+	if (!rdma_is_port_valid(ib_dev, port))
+		return NULL;
+
+	/*
+	 * New drivers should use ib_device_set_netdev() not the legacy
+	 * get_netdev().
+	 */
+	if (ib_dev->ops.get_netdev) {
+		res = ib_dev->ops.get_netdev(ib_dev, port);
+	} else {
+		struct ib_port_data *pdata = &ib_dev->port_data[port];
+		unsigned long flags;
+
+		/* Same irqsave locking as ib_device_set_netdev() */
+		spin_lock_irqsave(&pdata->netdev_lock, flags);
+		res = pdata->netdev;
+		if (res)
+			dev_hold(res);
+		spin_unlock_irqrestore(&pdata->netdev_lock, flags);
+	}
+
+	/*
+	 * If we are starting to unregister expedite things by preventing
+	 * propagation of an unregistering netdev.
+	 */
+	if (res && res->reg_state != NETREG_REGISTERED) {
+		dev_put(res);
+		return NULL;
+	}
+
+	return res;
+}
+
 /**
  * ib_enum_roce_netdev - enumerate all RoCE ports
  * @ib_dev : IB device we want to query
@@ -1020,16 +1183,8 @@  void ib_enum_roce_netdev(struct ib_device *ib_dev,

 	rdma_for_each_port (ib_dev, port)
 		if (rdma_protocol_roce(ib_dev, port)) {
-			struct net_device *idev = NULL;
-
-			if (ib_dev->ops.get_netdev)
-				idev = ib_dev->ops.get_netdev(ib_dev, port);
-
-			if (idev &&
-			    idev->reg_state >= NETREG_UNREGISTERED) {
-				dev_put(idev);
-				idev = NULL;
-			}
+			struct net_device *idev =
+				ib_device_get_netdev(ib_dev, port);

 			if (filter(ib_dev, port, idev, filter_cookie))
 				cb(ib_dev, port, idev, cookie);
diff --git a/drivers/infiniband/core/nldev.c b/drivers/infiniband/core/nldev.c
index 6708277aad7e4c..a0e635c5715905 100644
--- a/drivers/infiniband/core/nldev.c
+++ b/drivers/infiniband/core/nldev.c
@@ -262,9 +262,7 @@  static int fill_port_info(struct sk_buff *msg,
 	if (nla_put_u8(msg, RDMA_NLDEV_ATTR_PORT_PHYS_STATE, attr.phys_state))
 		return -EMSGSIZE;

-	if (device->ops.get_netdev)
-		netdev = device->ops.get_netdev(device, port);
-
+	netdev = ib_device_get_netdev(device, port); /* ref held or NULL */
 	if (netdev && net_eq(dev_net(netdev), net)) {
 		ret = nla_put_u32(msg,
 				  RDMA_NLDEV_ATTR_NDEV_INDEX, netdev->ifindex);
diff --git a/drivers/infiniband/core/verbs.c b/drivers/infiniband/core/verbs.c
index de5d895a50544c..5a5e83f5f0fc4c 100644
--- a/drivers/infiniband/core/verbs.c
+++ b/drivers/infiniband/core/verbs.c
@@ -1723,10 +1723,7 @@  int ib_get_eth_speed(struct ib_device *dev, u8 port_num, u8 *speed, u8 *width)
 	if (rdma_port_get_link_layer(dev, port_num) != IB_LINK_LAYER_ETHERNET)
 		return -EINVAL;

-	if (!dev->ops.get_netdev)
-		return -EOPNOTSUPP;
-
-	netdev = dev->ops.get_netdev(dev, port_num);
+	netdev = ib_device_get_netdev(dev, port_num); /* ref held or NULL */
 	if (!netdev)
 		return -ENODEV;

diff --git a/include/rdma/ib_verbs.h b/include/rdma/ib_verbs.h
index e5eea5a50e27f2..358bda4a76ff42 100644
--- a/include/rdma/ib_verbs.h
+++ b/include/rdma/ib_verbs.h
@@ -2204,6 +2204,10 @@  struct ib_port_data {
 	struct list_head pkey_list;

 	struct ib_port_cache cache;
+
+	/* netdev_lock protects netdev; a reference is held on netdev */
+	spinlock_t netdev_lock;
+	struct net_device *netdev;
 };

 /* rdma netdev type - specifies protocol type */
@@ -3997,6 +4001,9 @@  void ib_device_put(struct ib_device *device);
 struct net_device *ib_get_net_dev_by_params(struct ib_device *dev, u8 port,
 					    u16 pkey, const union ib_gid *gid,
 					    const struct sockaddr *addr);
+int ib_device_set_netdev(struct ib_device *ib_dev, struct net_device *ndev,
+			 unsigned int port);
+
 struct ib_wq *ib_create_wq(struct ib_pd *pd,
 			   struct ib_wq_init_attr *init_attr);
 int ib_destroy_wq(struct ib_wq *wq);