diff mbox

[v1,09/12] IB/cma: validate routing of incoming requests

Message ID 1434976961-27424-10-git-send-email-haggaie@mellanox.com (mailing list archive)
State Superseded
Headers show

Commit Message

Haggai Eran June 22, 2015, 12:42 p.m. UTC
Pass incoming request parameters through the relevant IPv4/IPv6 routing
tables and make sure the network stack is configured to handle such
requests.

Signed-off-by: Haggai Eran <haggaie@mellanox.com>
---
 drivers/infiniband/core/cma.c | 95 +++++++++++++++++++++++++++++++++++++++++--
 1 file changed, 92 insertions(+), 3 deletions(-)
diff mbox

Patch

diff --git a/drivers/infiniband/core/cma.c b/drivers/infiniband/core/cma.c
index 4701167f5117..759697d6e627 100644
--- a/drivers/infiniband/core/cma.c
+++ b/drivers/infiniband/core/cma.c
@@ -46,6 +46,8 @@ 
 
 #include <net/tcp.h>
 #include <net/ipv6.h>
+#include <net/ip_fib.h>
+#include <net/ip6_route.h>
 
 #include <rdma/rdma_cm.h>
 #include <rdma/rdma_cm_ib.h>
@@ -1062,15 +1064,97 @@  static int cma_save_req_info(const struct ib_cm_event *ib_event,
 	return 0;
 }
 
+static bool validate_ipv4_net_dev(struct net_device *net_dev,
+				  const struct sockaddr_in *dst_addr,
+				  const struct sockaddr_in *src_addr)
+{
+	__be32 daddr = dst_addr->sin_addr.s_addr,
+	       saddr = src_addr->sin_addr.s_addr;
+	struct fib_result res;
+	struct flowi4 fl4;
+	int err;
+	bool ret;
+
+	if (ipv4_is_multicast(saddr) || ipv4_is_lbcast(saddr) ||
+	    ipv4_is_lbcast(daddr) || ipv4_is_zeronet(saddr) ||
+	    ipv4_is_zeronet(daddr) || ipv4_is_loopback(daddr) ||
+	    ipv4_is_loopback(saddr))
+		return false;
+
+	memset(&fl4, 0, sizeof(fl4));
+	fl4.flowi4_iif = net_dev->ifindex;
+	fl4.daddr = daddr;
+	fl4.saddr = saddr;
+
+	rcu_read_lock();
+	err = fib_lookup(dev_net(net_dev), &fl4, &res);
+	if (err)
+		return false;
+
+	ret = FIB_RES_DEV(res) == net_dev;
+	rcu_read_unlock();
+
+	return ret;
+}
+
+static bool validate_ipv6_net_dev(struct net_device *net_dev,
+				  const struct sockaddr_in6 *dst_addr,
+				  const struct sockaddr_in6 *src_addr)
+{
+#if IS_ENABLED(CONFIG_IPV6)
+	const int strict = ipv6_addr_type(&dst_addr->sin6_addr) &
+			   IPV6_ADDR_LINKLOCAL;
+	struct rt6_info *rt = rt6_lookup(dev_net(net_dev), &dst_addr->sin6_addr,
+					 &src_addr->sin6_addr, net_dev->ifindex,
+					 strict);
+	bool ret;
+
+	if (!rt)
+		return false;
+
+	ret = rt->rt6i_idev->dev == net_dev;
+	ip6_rt_put(rt);
+
+	return ret;
+#else
+	return false;
+#endif
+}
+
+static bool validate_net_dev(struct net_device *net_dev,
+			     const struct sockaddr *daddr,
+			     const struct sockaddr *saddr)
+{
+	const struct sockaddr_in *daddr4 = (const struct sockaddr_in *)daddr;
+	const struct sockaddr_in *saddr4 = (const struct sockaddr_in *)saddr;
+	const struct sockaddr_in6 *daddr6 = (const struct sockaddr_in6 *)daddr;
+	const struct sockaddr_in6 *saddr6 = (const struct sockaddr_in6 *)saddr;
+
+	switch (daddr->sa_family) {
+	case AF_INET:
+		return saddr->sa_family == AF_INET &&
+		       validate_ipv4_net_dev(net_dev, daddr4, saddr4);
+
+	case AF_INET6:
+		return saddr->sa_family == AF_INET6 &&
+		       validate_ipv6_net_dev(net_dev, daddr6, saddr6);
+
+	default:
+		return false;
+	}
+}
+
 static struct net_device *cma_get_net_dev(struct ib_cm_event *ib_event,
 					  const struct cma_req_info *req)
 {
-	struct sockaddr_storage listen_addr_storage;
-	struct sockaddr *listen_addr = (struct sockaddr *)&listen_addr_storage;
+	struct sockaddr_storage listen_addr_storage, src_addr_storage;
+	struct sockaddr *listen_addr = (struct sockaddr *)&listen_addr_storage,
+			*src_addr = (struct sockaddr *)&src_addr_storage;
 	struct net_device *net_dev;
 	int err;
 
-	err = cma_save_ip_info(listen_addr, NULL, ib_event, req->service_id);
+	err = cma_save_ip_info(listen_addr, src_addr, ib_event,
+			       req->service_id);
 	if (err)
 		return ERR_PTR(err);
 
@@ -1079,6 +1163,11 @@  static struct net_device *cma_get_net_dev(struct ib_cm_event *ib_event,
 	if (!net_dev)
 		return ERR_PTR(-ENODEV);
 
+	if (!validate_net_dev(net_dev, listen_addr, src_addr)) {
+		dev_put(net_dev);
+		return ERR_PTR(-EHOSTUNREACH);
+	}
+
 	return net_dev;
 }