diff mbox series

[net-next,v1,1/2] tools: ynl: Use dict of predefined Structs to decode scalar types

Message ID 20230521170733.13151-2-donald.hunter@gmail.com (mailing list archive)
State Superseded
Delegated to: Netdev Maintainers
Headers show
Series tools: ynl: Add byte-order support for struct members | expand

Checks

Context Check Description
netdev/series_format success Posting correctly formatted
netdev/tree_selection success Clearly marked for net-next
netdev/fixes_present success Fixes tag not required for -next series
netdev/header_inline success No static functions without inline keyword in header files
netdev/build_32bit success Errors and warnings before: 8 this patch: 8
netdev/cc_maintainers success CCed 5 of 5 maintainers
netdev/build_clang success Errors and warnings before: 8 this patch: 8
netdev/verify_signedoff success Signed-off-by tag matches author and committer
netdev/deprecated_api success None detected
netdev/check_selftest success No net selftest shell script
netdev/verify_fixes success No Fixes tag
netdev/build_allmodconfig_warn success Errors and warnings before: 8 this patch: 8
netdev/checkpatch success total: 0 errors, 0 warnings, 0 checks, 201 lines checked
netdev/kdoc success Errors and warnings before: 0 this patch: 0
netdev/source_inline success Was 0 now: 0

Commit Message

Donald Hunter May 21, 2023, 5:07 p.m. UTC
Use a dict of predefined Struct() objects to decode scalar types in native,
big or little endian format. This removes the repetitive code for the
scalar variants and ensures all the signed variants are supported.

Signed-off-by: Donald Hunter <donald.hunter@gmail.com>
---
 tools/net/ynl/lib/ynl.py | 107 ++++++++++++++++++---------------------
 1 file changed, 48 insertions(+), 59 deletions(-)

Comments

Jakub Kicinski May 23, 2023, 2:37 a.m. UTC | #1
On Sun, 21 May 2023 18:07:32 +0100 Donald Hunter wrote:
> Use a dict of predefined Struct() objects to decode scalar types in native,
> big or little endian format. This removes the repetitive code for the
> scalar variants and ensures all the signed variants are supported.

> @@ -115,17 +116,17 @@ class NlAttr:
>          return self.raw
>  
>      def as_c_array(self, type):
> -        format, _ = self.type_formats[type]
> -        return list({ x[0] for x in struct.iter_unpack(format, self.raw) })
> +        format = self.get_format(type)
> +        return list({ x[0] for x in format.iter_unpack(self.raw) })

I probably asked about this before, and maybe not the question 
for this series but - why list({ ... }) and not [...]?

>          else:
> -            raise Exception(f'Unknown type at {space} {name} {value} {attr["type"]}')
> +            try:
> +                format = NlAttr.get_format(attr['type'], attr.byte_order)
> +                attr_payload = format.pack(int(value))
> +            except:
> +                raise Exception(f'Unknown type at {space} {name} {value} {attr["type"]}')

Could we do:

	elif attr["type"] in NlAttr.type_formats:

instead? Maybe my C brain treats exceptions as too exceptional..

> +            elif attr_spec["type"]:
> +                try:
> +                    decoded = attr.as_scalar(attr_spec['type'], attr_spec.byte_order)
> +                except:
> +                    raise Exception(f'Unknown {attr_spec["type"]} with name {attr_spec["name"]}')

Same here.

Nice cleanup!
Donald Hunter May 23, 2023, 8:21 a.m. UTC | #2
Jakub Kicinski <kuba@kernel.org> writes:

> On Sun, 21 May 2023 18:07:32 +0100 Donald Hunter wrote:
>> Use a dict of predefined Struct() objects to decode scalar types in native,
>> big or little endian format. This removes the repetitive code for the
>> scalar variants and ensures all the signed variants are supported.
>
>> @@ -115,17 +116,17 @@ class NlAttr:
>>          return self.raw
>>  
>>      def as_c_array(self, type):
>> -        format, _ = self.type_formats[type]
>> -        return list({ x[0] for x in struct.iter_unpack(format, self.raw) })
>> +        format = self.get_format(type)
>> +        return list({ x[0] for x in format.iter_unpack(self.raw) })
>
> I probably asked about this before, and maybe not the question 
> for this series but - why list({ ... }) and not [...]?

It looks like I cargo-culted something there, and it's just plain
wrong. Reading it now, it's clearly a set comprehension coerced into a
list, which could well change ordering.

I'll fix this in the next version.

>>          else:
>> -            raise Exception(f'Unknown type at {space} {name} {value} {attr["type"]}')
>> +            try:
>> +                format = NlAttr.get_format(attr['type'], attr.byte_order)
>> +                attr_payload = format.pack(int(value))
>> +            except:
>> +                raise Exception(f'Unknown type at {space} {name} {value} {attr["type"]}')
>
> Could we do:
>
> 	elif attr["type"] in NlAttr.type_formats:
>
> instead? Maybe my C brain treats exceptions as too exceptional..

Good suggestion, that's much cleaner.

>> +            elif attr_spec["type"]:
>> +                try:
>> +                    decoded = attr.as_scalar(attr_spec['type'], attr_spec.byte_order)
>> +                except:
>> +                    raise Exception(f'Unknown {attr_spec["type"]} with name {attr_spec["name"]}')
>
> Same here.

Ack.

> Nice cleanup!
diff mbox series

Patch

diff --git a/tools/net/ynl/lib/ynl.py b/tools/net/ynl/lib/ynl.py
index aa77bcae4807..62e31db07c1f 100644
--- a/tools/net/ynl/lib/ynl.py
+++ b/tools/net/ynl/lib/ynl.py
@@ -1,10 +1,12 @@ 
 # SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause
 
+from collections import namedtuple
 import functools
 import os
 import random
 import socket
 import struct
+from struct import Struct
 import yaml
 
 from .nlspec import SpecFamily
@@ -76,10 +78,17 @@  class NlError(Exception):
 
 
 class NlAttr:
-    type_formats = { 'u8' : ('B', 1), 's8' : ('b', 1),
-                     'u16': ('H', 2), 's16': ('h', 2),
-                     'u32': ('I', 4), 's32': ('i', 4),
-                     'u64': ('Q', 8), 's64': ('q', 8) }
+    ScalarFormat = namedtuple('ScalarFormat', ['native', 'big', 'little'])
+    type_formats = {
+        'u8' : ScalarFormat(Struct('B'), Struct("B"),  Struct("B")),
+        's8' : ScalarFormat(Struct('b'), Struct("b"),  Struct("b")),
+        'u16': ScalarFormat(Struct('H'), Struct(">H"), Struct("<H")),
+        's16': ScalarFormat(Struct('h'), Struct(">h"), Struct("<h")),
+        'u32': ScalarFormat(Struct('I'), Struct(">I"), Struct("<I")),
+        's32': ScalarFormat(Struct('i'), Struct(">i"), Struct("<i")),
+        'u64': ScalarFormat(Struct('Q'), Struct(">Q"), Struct("<Q")),
+        's64': ScalarFormat(Struct('q'), Struct(">q"), Struct("<q"))
+    }
 
     def __init__(self, raw, offset):
         self._len, self._type = struct.unpack("HH", raw[offset:offset + 4])
@@ -88,25 +97,17 @@  class NlAttr:
         self.full_len = (self.payload_len + 3) & ~3
         self.raw = raw[offset + 4:offset + self.payload_len]
 
-    def format_byte_order(byte_order):
+    @classmethod
+    def get_format(cls, attr_type, byte_order=None):
+        format = cls.type_formats[attr_type]
         if byte_order:
-            return ">" if byte_order == "big-endian" else "<"
-        return ""
+            return format.big if byte_order == "big-endian" \
+                else format.little
+        return format.native
 
-    def as_u8(self):
-        return struct.unpack("B", self.raw)[0]
-
-    def as_u16(self, byte_order=None):
-        endian = NlAttr.format_byte_order(byte_order)
-        return struct.unpack(f"{endian}H", self.raw)[0]
-
-    def as_u32(self, byte_order=None):
-        endian = NlAttr.format_byte_order(byte_order)
-        return struct.unpack(f"{endian}I", self.raw)[0]
-
-    def as_u64(self, byte_order=None):
-        endian = NlAttr.format_byte_order(byte_order)
-        return struct.unpack(f"{endian}Q", self.raw)[0]
+    def as_scalar(self, attr_type, byte_order=None):
+        format = self.get_format(attr_type, byte_order)
+        return format.unpack(self.raw)[0]
 
     def as_strz(self):
         return self.raw.decode('ascii')[:-1]
@@ -115,17 +116,17 @@  class NlAttr:
         return self.raw
 
     def as_c_array(self, type):
-        format, _ = self.type_formats[type]
-        return list({ x[0] for x in struct.iter_unpack(format, self.raw) })
+        format = self.get_format(type)
+        return list({ x[0] for x in format.iter_unpack(self.raw) })
 
     def as_struct(self, members):
         value = dict()
         offset = 0
         for m in members:
             # TODO: handle non-scalar members
-            format, size = self.type_formats[m.type]
-            decoded = struct.unpack_from(format, self.raw, offset)
-            offset += size
+            format = self.get_format(m.type)
+            decoded = format.unpack_from(self.raw, offset)
+            offset += format.size
             value[m.name] = decoded[0]
         return value
 
@@ -184,11 +185,11 @@  class NlMsg:
                 if extack.type == Netlink.NLMSGERR_ATTR_MSG:
                     self.extack['msg'] = extack.as_strz()
                 elif extack.type == Netlink.NLMSGERR_ATTR_MISS_TYPE:
-                    self.extack['miss-type'] = extack.as_u32()
+                    self.extack['miss-type'] = extack.as_scalar('u32')
                 elif extack.type == Netlink.NLMSGERR_ATTR_MISS_NEST:
-                    self.extack['miss-nest'] = extack.as_u32()
+                    self.extack['miss-nest'] = extack.as_scalar('u32')
                 elif extack.type == Netlink.NLMSGERR_ATTR_OFFS:
-                    self.extack['bad-attr-offs'] = extack.as_u32()
+                    self.extack['bad-attr-offs'] = extack.as_scalar('u32')
                 else:
                     if 'unknown' not in self.extack:
                         self.extack['unknown'] = []
@@ -272,11 +273,11 @@  def _genl_load_families():
                 fam = dict()
                 for attr in gm.raw_attrs:
                     if attr.type == Netlink.CTRL_ATTR_FAMILY_ID:
-                        fam['id'] = attr.as_u16()
+                        fam['id'] = attr.as_scalar('u16')
                     elif attr.type == Netlink.CTRL_ATTR_FAMILY_NAME:
                         fam['name'] = attr.as_strz()
                     elif attr.type == Netlink.CTRL_ATTR_MAXATTR:
-                        fam['maxattr'] = attr.as_u32()
+                        fam['maxattr'] = attr.as_scalar('u32')
                     elif attr.type == Netlink.CTRL_ATTR_MCAST_GROUPS:
                         fam['mcast'] = dict()
                         for entry in NlAttrs(attr.raw):
@@ -286,7 +287,7 @@  def _genl_load_families():
                                 if entry_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_NAME:
                                     mcast_name = entry_attr.as_strz()
                                 elif entry_attr.type == Netlink.CTRL_ATTR_MCAST_GRP_ID:
-                                    mcast_id = entry_attr.as_u32()
+                                    mcast_id = entry_attr.as_scalar('u32')
                             if mcast_name and mcast_id is not None:
                                 fam['mcast'][mcast_name] = mcast_id
                 if 'name' in fam and 'id' in fam:
@@ -304,9 +305,9 @@  class GenlMsg:
 
         self.fixed_header_attrs = dict()
         for m in fixed_header_members:
-            format, size = NlAttr.type_formats[m.type]
-            decoded = struct.unpack_from(format, nl_msg.raw, offset)
-            offset += size
+            format = NlAttr.get_format(m.type)
+            decoded = format.unpack_from(nl_msg.raw, offset)
+            offset += format.size
             self.fixed_header_attrs[m.name] = decoded[0]
 
         self.raw = nl_msg.raw[offset:]
@@ -381,23 +382,16 @@  class YnlFamily(SpecFamily):
                 attr_payload += self._add_attr(attr['nested-attributes'], subname, subvalue)
         elif attr["type"] == 'flag':
             attr_payload = b''
-        elif attr["type"] == 'u8':
-            attr_payload = struct.pack("B", int(value))
-        elif attr["type"] == 'u16':
-            endian = NlAttr.format_byte_order(attr.byte_order)
-            attr_payload = struct.pack(f"{endian}H", int(value))
-        elif attr["type"] == 'u32':
-            endian = NlAttr.format_byte_order(attr.byte_order)
-            attr_payload = struct.pack(f"{endian}I", int(value))
-        elif attr["type"] == 'u64':
-            endian = NlAttr.format_byte_order(attr.byte_order)
-            attr_payload = struct.pack(f"{endian}Q", int(value))
         elif attr["type"] == 'string':
             attr_payload = str(value).encode('ascii') + b'\x00'
         elif attr["type"] == 'binary':
             attr_payload = value
         else:
-            raise Exception(f'Unknown type at {space} {name} {value} {attr["type"]}')
+            try:
+                format = NlAttr.get_format(attr['type'], attr.byte_order)
+                attr_payload = format.pack(int(value))
+            except:
+                raise Exception(f'Unknown type at {space} {name} {value} {attr["type"]}')
 
         pad = b'\x00' * ((4 - len(attr_payload) % 4) % 4)
         return struct.pack('HH', len(attr_payload) + 4, nl_type) + attr_payload + pad
@@ -434,22 +428,17 @@  class YnlFamily(SpecFamily):
             if attr_spec["type"] == 'nest':
                 subdict = self._decode(NlAttrs(attr.raw), attr_spec['nested-attributes'])
                 decoded = subdict
-            elif attr_spec['type'] == 'u8':
-                decoded = attr.as_u8()
-            elif attr_spec['type'] == 'u16':
-                decoded = attr.as_u16(attr_spec.byte_order)
-            elif attr_spec['type'] == 'u32':
-                decoded = attr.as_u32(attr_spec.byte_order)
-            elif attr_spec['type'] == 'u64':
-                decoded = attr.as_u64(attr_spec.byte_order)
             elif attr_spec["type"] == 'string':
                 decoded = attr.as_strz()
             elif attr_spec["type"] == 'binary':
                 decoded = self._decode_binary(attr, attr_spec)
             elif attr_spec["type"] == 'flag':
                 decoded = True
-            else:
-                raise Exception(f'Unknown {attr.type} {attr_spec["name"]} {attr_spec["type"]}')
+            elif attr_spec["type"]:
+                try:
+                    decoded = attr.as_scalar(attr_spec['type'], attr_spec.byte_order)
+                except:
+                    raise Exception(f'Unknown {attr_spec["type"]} with name {attr_spec["name"]}')
 
             if not attr_spec.is_multi:
                 rsp[attr_spec['name']] = decoded
@@ -555,8 +544,8 @@  class YnlFamily(SpecFamily):
             fixed_header_members = self.consts[op.fixed_header].members
             for m in fixed_header_members:
                 value = vals.pop(m.name)
-                format, _ = NlAttr.type_formats[m.type]
-                msg += struct.pack(format, value)
+                format = NlAttr.get_format(m.type)
+                msg += format.pack(value)
         for name, value in vals.items():
             msg += self._add_attr(op.attr_set.name, name, value)
         msg = _genl_msg_finalize(msg)