diff mbox series

[4/6] tools: Add provision.py

Message ID 20231213215132.287577-4-denkenz@gmail.com (mailing list archive)
State Superseded
Headers show
Series [1/6] build: Fix typo that breaks --fsanitize=leak check | expand

Commit Message

Denis Kenzior Dec. 13, 2023, 9:51 p.m. UTC
Introduce a new tool that will convert the intermediate JSON format
(documented in doc/provision.rst) to the binary provisioning format,
which will be used by ofonod's default provisioning plugin for
automatically setting up carrier specific settings.

The tool also supports import of and conversion of
'mobile-broadband-provider-info' XML files to the new intermediate JSON
file format.  This is accomplished using:
  % tools/provision.py mbpi-convert --outfile=provision.json

Conversion of JSON intermediate format to binary format is accomplished
using:
  % tools/provision.py generate --infile=provision.json

By default, the output will be placed in the same directory in the file
'provision.db'.  Alternatively, the output file can be specified using
the --outfile option.

Finally, the tool supports a simple selftest method, which can be
invoked as follows:
  % tools/provision.py selftest
---
 tools/provision.py | 727 +++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 727 insertions(+)
 create mode 100755 tools/provision.py
diff mbox series

Patch

diff --git a/tools/provision.py b/tools/provision.py
new file mode 100755
index 00000000..79273277
--- /dev/null
+++ b/tools/provision.py
@@ -0,0 +1,727 @@ 
+#!/usr/bin/python3
+#
+# oFono - Open Source Telephony
+# Copyright (C) 2023  Cruise, LLC
+#
+# SPDX-License-Identifier: GPL-2.0-only
+import xml.etree.ElementTree as ET
+import sys
+import json
+import bisect
+from argparse import ArgumentParser, FileType
+from pathlib import Path
+import random
+import struct
+import ctypes
+
+class ProviderInfo:
+    sort_order_map = { v : pos for pos, v in
+                      enumerate( ['name',
+                                  'apn',
+                                  'type',
+                                  'protocol',
+                                  'mmsc',
+                                  'mmsproxy',
+                                  'authentication',
+                                  'username',
+                                  'password'] ) }
+
+    @classmethod
+    def rawimport(cls, entry):
+        if 'name' not in entry:
+            raise SystemExit('No name for entry: ' + str(entry))
+
+        info = ProviderInfo(entry['name'])
+
+        for networkid in entry.get('ids', []):
+            if not info.add_id(networkid):
+                raise SystemExit('Invalid network id: ' + str(networkid))
+
+        if 'spn' in entry:
+            if not info.set_spn(entry['spn']):
+                raise SystemExit('Invalid spn: ' + str(spn))
+
+        for apn in entry.get('apns', []):
+            if not info.add_context(apn):
+                raise SystemExit('Invalid apn: ' + str(apn))
+
+        if not info.is_valid():
+            raise SystemExit('Invalid entry: ' + str(entry))
+
+        return info
+
+    def __init__(self, name):
+        self.context_list = []
+        self.mccmnc_list = []
+        self.name = name
+        self.spn = None
+
+    @staticmethod
+    def is_valid_id(id_string, expected_lengths):
+        """
+        Check if the identifier string is valid.
+
+        Parameters:
+        - id_string: The id string to check.
+        - expected_lengths (tuple): A tuple representing the valid range of
+          lengths.
+
+        Returns:
+        - bool: True if the MCC string is valid, False otherwise.
+        """
+        if not id_string.isdigit():
+            return False
+
+        if len(id_string) not in expected_lengths:
+            return False
+
+        if int(id_string) == 0:
+            return False
+
+        return True
+
+    def add_mccmnc(self, mcc, mnc):
+        if not self.is_valid_id(mcc, (3,)) or not self.is_valid_id(mnc, (2, 3)):
+            return False
+
+        bisect.insort(self.mccmnc_list, mcc + mnc)
+        return True
+
+    def add_id(self, mccmnc):
+        if not self.is_valid_id(mccmnc, (5,6)):
+            return False
+
+        if int(mccmnc[:3]) == 0 or int(mccmnc[3:]) == 0:
+            return False
+
+        bisect.insort(self.mccmnc_list, mccmnc)
+        return True
+
+    def set_spn(self, spn):
+        if len(spn) == 0 or len(spn) > 254:
+            return False
+
+        self.spn = spn
+        return True
+
+    def add_context(self, info):
+        info = dict(sorted(info.items(),
+                      key = lambda pair: self.sort_order_map[pair[0]]))
+        self.context_list.append(info)
+
+        return True
+
+    def is_valid(self):
+        return len(self.context_list) and len(self.mccmnc_list)
+
+    def __str__(self):
+        s = 'Provider \'' + self.name + '\''
+
+        if (self.spn != None):
+            s += ' [SPN:\'' + self.spn + '\']'
+
+        s+= ' ' + str(self.mccmnc_list) + '\n'
+
+        for context in self.context_list:
+            s += '\t' + str(context) + '\n'
+
+        return s
+
+class MobileBroadbandProviderInfo:
+    usage_to_type = { 'internet' : ['internet'],
+                        'mms' : ['mms'],
+                        'wap' : ['wap'],
+                        'mms-internet-hipri' : ['internet', 'mms'],
+                        'mms-internet-hipri-fota' : ['internet','mms'],
+                     }
+    @classmethod
+    def type_from_usage(cls, usage):
+        return cls.usage_to_type[usage]
+
+    def __init__(self, xml_path):
+        self.tree = ET.parse(xml_path)
+
+    def parse(self, xml_path):
+        providers = []
+
+        try:
+            tree = ET.parse(xml_path)
+            root = tree.getroot()
+
+            for provider in root.findall('.//provider'):
+                name = provider.find('name')
+                if name is None or not name.text:
+                    continue;
+
+                info = ProviderInfo(name.text)
+
+                for networkid in provider.findall('gsm/network-id'):
+                    info.add_mccmnc(networkid.get('mcc'), networkid.get('mnc'))
+
+                for apn in provider.findall('gsm/apn'):
+                    context = {}
+
+                    context['apn'] = apn.get('value')
+                    if context['apn'] == None:
+                        continue
+
+                    # Usage is missing for some APNs, skip such contexts for now
+                    usage = apn.find('usage')
+                    if usage is None or usage.get('type') is None:
+                        continue;
+
+                    context['type'] = self.type_from_usage(usage.get('type'))
+                    if context['type'] == None:
+                        sys.stderr.write("Unable to convert type: %s\n" %
+                                            usage.get('type'))
+                        continue
+
+                    if 'mms' in context['type']:
+                        mmsc = apn.find('mmsc')
+
+                        # Ignore MMS contexts with no MMSC since it is needed
+                        # to send messages
+                        if mmsc is None or not mmsc.text:
+                            continue
+
+                        context['mmsc'] = mmsc.text
+
+                        mmsproxy = apn.find('mmsproxy')
+                        if mmsproxy is not None and mmsproxy.text:
+                            context['mmsproxy'] = mmsproxy.text
+
+                    username = apn.find('username')
+                    if username is not None and username.text:
+                        context['username'] = username.text
+
+                    password = apn.find('password')
+                    if password is not None and password.text:
+                        context['password'] = password.text
+
+                    authentication = apn.find('authentication')
+                    if authentication is not None:
+                        context['authentication'] = authentication.get('method')
+
+                    context_name = apn.find('name')
+                    if context_name != None:
+                        context['name'] = context_name.text
+
+                    info.add_context(context)
+
+                if info.is_valid():
+                    providers.append(info)
+
+        except ET.ParseError as e:
+            print(f"Error parsing XML: {e}")
+
+        return providers
+
+class ProvisionContext(ctypes.LittleEndianStructure):
+    _pack_ = 1
+    _fields_ = [
+        ('type', ctypes.c_uint32),
+        ('protocol', ctypes.c_uint32),
+        ('authentication', ctypes.c_uint32),
+        ('reserved', ctypes.c_uint32),
+        ('name_offset', ctypes.c_uint64),
+        ('apn_offset', ctypes.c_uint64),
+        ('username_offset', ctypes.c_uint64),
+        ('password_offset', ctypes.c_uint64),
+        ('mmsproxy_offset', ctypes.c_uint64),
+        ('mmsc_offset', ctypes.c_uint64)
+    ]
+
+    authentication_dict = { 'chap' : 0, 'pap' : 1, 'none' : 2 }
+    protocol_dict = { 'ipv4' : 0, 'ipv6' : 1, 'ipv4v6' : 2 }
+    attrs = ['name', 'apn', 'username', 'password', 'mmsproxy', 'mmsc']
+
+    @classmethod
+    def type_to_context_type(cls, types):
+        r = 0
+
+        for t in types:
+            if t == 'internet':
+                r |= 0x0001
+            elif t == 'mms':
+                r |= 0x0002
+            elif t == 'wap':
+                r |= 0x0004
+            elif t == 'ims':
+                r |= 0x0008
+            elif t == 'supl':
+                r |= 0x0010
+            elif t == 'ia':
+                r |= 0x0020
+
+        return r
+
+    def __init__(self, apn, strings):
+        self.type = self.type_to_context_type(apn['type'])
+        self.protocol = self.protocol_dict[apn.get('protocol', 'ipv4v6')]
+        self.authentication = self.authentication_dict[apn.get('authentication',
+                                                               'chap')]
+
+        for s in self.attrs:
+            offset = strings.add_string(apn.get(s, None))
+            setattr(self, s + '_offset', offset)
+
+class ProvisionData(ctypes.LittleEndianStructure):
+    _pack_ = 1
+    _fields_ = [
+        ('spn_offset', ctypes.c_uint64),
+        ('context_offset', ctypes.c_uint64)
+    ]
+
+    def __init__(self, spn, offset, strings):
+        self.spn_offset = strings.add_string(spn)
+        self.context_offset = offset
+
+class ProvisionNode(ctypes.LittleEndianStructure):
+    _pack_ = 1
+    _fields_ = [
+        ('bit_offsets', ctypes.c_uint64 * 2),
+        ('mccmnc', ctypes.c_uint32),
+        ('diff', ctypes.c_int32),
+        ('provision_data_count', ctypes.c_uint64)
+    ]
+
+    style = "bold"
+    fmt_connection = '\t"%s/%d" -> "%s/%d"[color="#%06x"];\n'
+    fmt_declaration = '\t"%s/%d"[style=%s, color="#%06x"];\n'
+    red = 0xff0000
+    green = 0x00ff00
+
+    def __init__(self, key, diff):
+        self.bit = [None, None]
+        self.key = key
+        self.diff = diff
+        self.entries = {}
+        self.node_offset = 0
+
+    def choose(self, key):
+        return (key >> (31 - self.diff)) & 1
+
+    def print_graphviz(self, f):
+        f.write(self.fmt_declaration % (format(self.key, '032b'),
+                                        self.diff, self.style,
+                                        random.randint(0, 0x00ffffff)))
+        f.write(self.fmt_connection % (format(self.key, '032b'), self.diff,
+                                       format(self.bit[0].key, '032b'),
+                                       self.bit[0].diff, self.red))
+        f.write(self.fmt_connection % (format(self.key, '032b'), self.diff,
+                                       format(self.bit[1].key, '032b'),
+                                       self.bit[1].diff, self.green))
+
+        if (self.diff < self.bit[0].diff):
+            self.bit[0].print_graphviz(f)
+
+        if (self.diff < self.bit[1].diff):
+            self.bit[1].print_graphviz(f)
+
+    def __str__(self):
+        s = format(self.key, '032b') + '/' + str(self.diff)
+        return s
+
+class MccMncTree:
+    @staticmethod
+    def clz(v):
+        count = 32
+        while count and v:
+            v = v >> 1
+            count = count - 1
+
+        return count
+
+    @staticmethod
+    def diff(key1, key2):
+        xor = key1 ^ key2;
+        return MccMncTree.clz(xor)
+
+    def __init__(self):
+        self.root = ProvisionNode(key = 0, diff = -1)
+        self.root.bit[0] = self.root
+        self.root.bit[1] = self.root
+        self.n_nodes = 1
+
+    def print_graphviz(self):
+        f = open("step%d.dot" % self.n_nodes, "w")
+        # Use 'dot -Tx11' to visualize
+        f.write('digraph trie {\n')
+        self.root.print_graphviz(f)
+        f.write('}\n')
+        f.close()
+
+    def find_closest(self, key):
+        parent = self.root
+        child = self.root.bit[0]
+
+        while parent.diff < child.diff:
+            parent = child
+            child = child.bit[child.choose(key)]
+
+        return child
+
+    def find(self, key):
+        found = self.find_closest(key)
+        if found.key == key:
+            return found
+
+        return None
+
+    def insert(self, key, attr, value):
+        node = self.find_closest(key);
+        if node.key == key:
+            node.entries[attr] = value
+            return
+
+        bit = self.diff(node.key, key)
+        parent = self.root
+        child = self.root.bit[0]
+
+        while (parent.diff < child.diff) and (child.diff < bit):
+            parent = child
+            child = child.bit[child.choose(key)]
+
+        node = ProvisionNode(key, bit)
+        bit = node.choose(key)
+        node.bit[bit] = node
+        node.bit[not bit] = child
+
+        node.entries[attr] = value
+
+        if parent == self.root:
+            self.root.bit[0] = node
+        else:
+            bit = parent.choose(key)
+            parent.bit[bit] = node
+
+        self.n_nodes += 1
+
+    def traverse_recursive(self, node, bit, visitor):
+        if node == self.root:
+            return
+
+        if node.diff <= bit:
+            visitor.visit(node)
+            return
+
+        self.traverse_recursive(node.bit[0], node.diff, visitor)
+        self.traverse_recursive(node.bit[1], node.diff, visitor)
+
+    def traverse(self, visitor):
+        self.traverse_recursive(self.root.bit[0], -1, visitor)
+
+class StringAccumulator:
+    def __init__(self):
+        self.data = bytearray(b'\x00') # So offsets are never 0 used for NULL
+        self.offsets = {}
+
+    def add_string(self, s):
+        if s is None:
+            return 0
+
+        if s in self.offsets:
+            return self.offsets[s]
+
+        offset = len(self.data)
+        self.data.extend(s.encode('utf-8'))
+        self.data.append(0)
+        self.offsets[s] = offset
+
+        return offset
+
+    def get_bytes(self):
+        return self.data
+
+class ProvisionDatabase(ctypes.LittleEndianStructure):
+    _pack_ = 1
+    _fields_ = [
+        ('version', ctypes.c_uint64),
+        ('file_size', ctypes.c_uint64),
+        ('header_size', ctypes.c_uint64),
+        ('node_struct_size', ctypes.c_uint64),
+        ('provision_data_struct_size', ctypes.c_uint64),
+        ('context_struct_size', ctypes.c_uint64),
+        ('nodes_offset', ctypes.c_uint64),
+        ('nodes_size', ctypes.c_uint64),
+        ('contexts_offset', ctypes.c_uint64),
+        ('contexts_size', ctypes.c_uint64),
+        ('strings_offset', ctypes.c_uint64),
+        ('strings_size', ctypes.c_uint64)
+    ]
+
+    class CalculateNodeOffsetVisitor:
+        def __init__(self):
+            self.current_offset = 0
+        def visit(self, node):
+            node.node_offset = self.current_offset
+
+            # Node data is followed by at least one ProvisionData object, with
+            # the only exception being root, which has no data by definition
+            self.current_offset += ctypes.sizeof(ProvisionNode)
+            self.current_offset += (ctypes.sizeof(ProvisionData) *
+                                    len(node.entries))
+
+    class SerializeVisitor:
+        def __init__(self, buffer):
+            self.buffer = buffer
+
+        def visit(self, node):
+            # Node doesn't quite fit the C structure definition, so do this
+            # manually by using struct.pack
+            self.buffer.extend(struct.pack('<QQIiQ',
+                                           node.bit[0].node_offset,
+                                           node.bit[1].node_offset,
+                                           node.key, node.diff,
+                                           len(node.entries)))
+
+            for spn in sorted(node.entries):
+                pd = node.entries[spn]
+                self.buffer.extend(bytes(pd))
+
+    def __init__(self, provider_infos):
+        self.strings = StringAccumulator()
+        self.contexts = bytearray()
+        self.tree = MccMncTree()
+
+        for info in provider_infos:
+            pd = ProvisionData(info.spn, len(self.contexts), self.strings)
+
+            self.contexts.extend(struct.pack('<Q', len(info.context_list)))
+
+            for context in info.context_list:
+                self.contexts.extend(bytes(ProvisionContext(context,
+                                                            self.strings)))
+
+            for mccmnc in info.mccmnc_list:
+                # Sort None spns as '' so they're first in the list when
+                # the SerializeVisitor sorts the entries dict
+                spn = info.spn if info.spn is not None else ''
+
+                # 2 and 3 byte MNCs are treated differently, even if evaluate
+                # to the same integer.  For example, 02 and 002 are different
+                # MNCs.  In practice this doesn't actually happen except on
+                # test networks, but account for this possibility by using the
+                # upper 10 bits for the MCC, followed by a single bit which
+                # signifies whether a 3 byte MNC is used, followed by 10 bits
+                # of the MNC
+                key = int(mccmnc[:3]) << 11 | int(mccmnc[3:])
+                if len(mccmnc[3:]) == 3:
+                    key |= 1 << 10
+
+                self.tree.insert(key, spn, pd)
+
+        visitor = self.CalculateNodeOffsetVisitor()
+        visitor.visit(self.tree.root)
+        self.tree.traverse(visitor)
+
+        self.version = 1
+        self.header_size = ctypes.sizeof(ProvisionDatabase)
+        self.file_size = self.header_size
+        self.node_struct_size = ctypes.sizeof(ProvisionNode)
+        self.provision_data_struct_size = ctypes.sizeof(ProvisionData)
+        self.context_struct_size = ctypes.sizeof(ProvisionContext)
+        self.nodes_offset = self.header_size
+        self.nodes_size = visitor.current_offset
+        self.file_size += self.nodes_size
+        self.contexts_offset = self.nodes_offset + self.nodes_size
+        self.contexts_size = len(self.contexts)
+        self.file_size += self.contexts_size
+        self.strings_offset = self.contexts_offset + self.contexts_size
+        self.strings_size = len(self.strings.get_bytes())
+        self.file_size += self.strings_size
+
+    def serialize(self):
+        buffer = bytearray()
+        buffer.extend(bytes(self))
+
+        visitor = self.SerializeVisitor(buffer)
+        visitor.visit(self.tree.root)
+        self.tree.traverse(visitor)
+
+        buffer.extend(self.contexts)
+        buffer.extend(self.strings.get_bytes())
+
+        return buffer
+
+class ProviderInfoJSONEncoder(json.JSONEncoder):
+    def default(self, obj):
+        if isinstance(obj, ProviderInfo):
+            asdict = { 'name' : obj.name, 'ids' : obj.mccmnc_list }
+
+            if (obj.spn != None):
+                asdict['spn'] = obj.spn
+
+            asdict.update({'apns' : obj.context_list})
+            return asdict
+
+        return json.JSONEncoder.default(self, obj)
+
+def mbpi_convert(args):
+    xml_path = '/usr/share/mobile-broadband-provider-info/serviceproviders.xml'
+    mbpi = MobileBroadbandProviderInfo(xml_path)
+    provider_infos = mbpi.parse(xml_path)
+
+    try:
+        if args.outfile is None:
+            out = sys.stdout
+        else:
+            out = args.outfile.open('w', encoding='utf-8')
+
+        with out as outfile:
+            json.dump(provider_infos, outfile, ensure_ascii=False, indent=2,
+                      cls=ProviderInfoJSONEncoder)
+    except ValueError as e:
+        raise SystemExit(e)
+
+def generate(args):
+    try:
+        json_dict = json.load(args.infile)
+    except ValueError as e:
+        raise SystemExit(e)
+
+    provider_infos = []
+
+    for entry in json_dict:
+        info = ProviderInfo.rawimport(entry)
+        provider_infos.append(info)
+
+    db = ProvisionDatabase(provider_infos)
+    with args.outfile.open('wb') as outfile:
+        outfile.write(db.serialize())
+
+def selftest(args):
+    tree = MccMncTree()
+
+    # Generate random key this many times and insert it into the tree
+    # There will be some collisions, so count the number of items inserted
+    # Then run the lookup and make sure the same number of items can be found
+    times = 100000
+    for i in range(0, times):
+        key = random.randint(10000, 999999)
+        tree.insert(key, None, None)
+
+    n_inserted = tree.n_nodes
+    print("Created a tree with %d nodes (1 root)" % n_inserted)
+
+    n_found = 0
+    expected_keys = []
+
+    for i in range(10000, 1000000):
+        if tree.find(i):
+            n_found += 1
+            expected_keys.append(i)
+
+    expected_keys = sorted(expected_keys)
+
+    print("Found %d nodes (not including root)" % n_found)
+    assert n_found == n_inserted - 1
+
+    class GetKeysVisitor:
+        def __init__(self):
+            self.keys = []
+
+        def visit(self, node):
+            self.keys.append(node.key)
+
+    visitor = GetKeysVisitor()
+    tree.traverse(visitor)
+    assert visitor.keys == expected_keys
+
+    sample_json = """[
+    {
+      "name": "Operator XYZ",
+      "ids": [
+        "99955", "99956", "99901", "99902"
+      ],
+      "apns": [
+        {
+          "name": "Internet",
+          "apn": "internet",
+          "type": [
+            "internet"
+          ],
+          "authentication": "none",
+          "protocol": "ipv4"
+        },
+        {
+          "name": "IMS+MMS",
+          "apn": "imsmms",
+          "type": [
+            "ims", "mms"
+          ],
+          "mmsc": "foobar.mmsc:80",
+          "mmsproxy": "mms.proxy.net",
+          "authentication": "pap",
+          "protocol": "ipv6"
+        }
+      ]
+    },
+    {
+      "name": "Operator ZYX",
+      "ids": [
+        "99998", "99999", "99901", "99902"
+      ],
+      "spn": "ZYX",
+      "apns": [
+        {
+          "name": "ZYX",
+          "apn": "zyx",
+          "type": [
+            "internet"
+          ],
+          "authentication": "none",
+          "protocol": "ipv4"
+        }
+      ]
+    }
+    ]"""
+    try:
+        json_dict = json.loads(sample_json)
+    except ValueError as e:
+        raise SystemExit(e)
+
+    provider_infos = []
+
+    for entry in json_dict:
+        info = ProviderInfo.rawimport(entry)
+        provider_infos.append(info)
+
+    db = ProvisionDatabase(provider_infos)
+    expected_strings = bytearray(b'\x00Internet\x00internet\x00IMS+MMS\x00' +
+                b'imsmms\x00mms.proxy.net\x00foobar.mmsc:80\x00ZYX\x00zyx\x00')
+    assert(expected_strings == db.strings.get_bytes())
+
+    with open('sample.db', 'wb') as outfile:
+        outfile.write(db.serialize())
+
+if __name__ == "__main__":
+    parser = ArgumentParser(description='Parse command line arguments')
+    subparsers = parser.add_subparsers(title='subcommands', dest='command')
+
+    # mbpi-convert command
+    mbpi_convert_parser = subparsers.add_parser('mbpi-convert',
+                        help='Convert mobile-broadband-provider-info database')
+    mbpi_convert_parser.add_argument('--outfile', type=Path, default=None,
+                        help='Output file path', required=False,)
+    mbpi_convert_parser.set_defaults(func=mbpi_convert)
+
+    # generate command
+    generate_parser = subparsers.add_parser('generate',
+                                            help='Generate binary provider db')
+    generate_parser.add_argument('--outfile', type=Path, default='provision.db',
+                                 help='Output file path', required=False)
+    generate_parser.add_argument('--infile', type=FileType(encoding='utf-8'),
+                                 help='Input JSON db', default=sys.stdin)
+    generate_parser.set_defaults(func=generate)
+
+    # selftest command
+    selftest_parser = subparsers.add_parser('selftest',
+                        help='Run self-tests')
+    selftest_parser.set_defaults(func=selftest)
+
+    args = parser.parse_args()
+    if hasattr(args, 'func'):
+        args.func(args)
+    else:
+        parser.print_help()