@@ -8,6 +8,7 @@
#include <linux/ipv6.h>
#include <linux/mpls.h>
#include <linux/netconf.h>
+#include <linux/nospec.h>
#include <linux/vmalloc.h>
#include <linux/percpu.h>
#include <net/ip.h>
@@ -77,12 +78,13 @@ static void rtmsg_lfib(int event, u32 label, struct mpls_route *rt,
static struct mpls_route *mpls_route_input_rcu(struct net *net, unsigned index)
{
struct mpls_route *rt = NULL;
+ struct mpls_route __rcu **platform_label =
+ rcu_dereference(net->mpls.platform_label);
+ struct mpls_route __rcu **rtp;
- if (index < net->mpls.platform_labels) {
- struct mpls_route __rcu **platform_label =
- rcu_dereference(net->mpls.platform_label);
- rt = rcu_dereference(platform_label[index]);
- }
+ rtp = array_ptr(platform_label, index, net->mpls.platform_labels);
+ if (rtp)
+ rt = rcu_dereference(*rtp);
return rt;
}