new file mode 100755
@@ -0,0 +1,727 @@
+#!/usr/bin/python3
+#
+# oFono - Open Source Telephony
+# Copyright (C) 2023 Cruise, LLC
+#
+# SPDX-License-Identifier: GPL-2.0-only
+import xml.etree.ElementTree as ET
+import sys
+import json
+import bisect
+from argparse import ArgumentParser, FileType
+from pathlib import Path
+import random
+import struct
+import ctypes
+
+class ProviderInfo:
+ sort_order_map = { v : pos for pos, v in
+ enumerate( ['name',
+ 'apn',
+ 'type',
+ 'protocol',
+ 'mmsc',
+ 'mmsproxy',
+ 'authentication',
+ 'username',
+ 'password'] ) }
+
+ @classmethod
+ def rawimport(cls, entry):
+ if 'name' not in entry:
+ raise SystemExit('No name for entry: ' + str(entry))
+
+ info = ProviderInfo(entry['name'])
+
+ for networkid in entry.get('ids', []):
+ if not info.add_id(networkid):
+ raise SystemExit('Invalid network id: ' + str(networkid))
+
+ if 'spn' in entry:
+ if not info.set_spn(entry['spn']):
+ raise SystemExit('Invalid spn: ' + str(spn))
+
+ for apn in entry.get('apns', []):
+ if not info.add_context(apn):
+ raise SystemExit('Invalid apn: ' + str(apn))
+
+ if not info.is_valid():
+ raise SystemExit('Invalid entry: ' + str(entry))
+
+ return info
+
+ def __init__(self, name):
+ self.context_list = []
+ self.mccmnc_list = []
+ self.name = name
+ self.spn = None
+
+ @staticmethod
+ def is_valid_id(id_string, expected_lengths):
+ """
+ Check if the identifier string is valid.
+
+ Parameters:
+ - id_string: The id string to check.
+ - expected_lengths (tuple): A tuple representing the valid range of
+ lengths.
+
+ Returns:
+ - bool: True if the MCC string is valid, False otherwise.
+ """
+ if not id_string.isdigit():
+ return False
+
+ if len(id_string) not in expected_lengths:
+ return False
+
+ if int(id_string) == 0:
+ return False
+
+ return True
+
+ def add_mccmnc(self, mcc, mnc):
+ if not self.is_valid_id(mcc, (3,)) or not self.is_valid_id(mnc, (2, 3)):
+ return False
+
+ bisect.insort(self.mccmnc_list, mcc + mnc)
+ return True
+
+ def add_id(self, mccmnc):
+ if not self.is_valid_id(mccmnc, (5,6)):
+ return False
+
+ if int(mccmnc[:3]) == 0 or int(mccmnc[3:]) == 0:
+ return False
+
+ bisect.insort(self.mccmnc_list, mccmnc)
+ return True
+
+ def set_spn(self, spn):
+ if len(spn) == 0 or len(spn) > 254:
+ return False
+
+ self.spn = spn
+ return True
+
+ def add_context(self, info):
+ info = dict(sorted(info.items(),
+ key = lambda pair: self.sort_order_map[pair[0]]))
+ self.context_list.append(info)
+
+ return True
+
+ def is_valid(self):
+ return len(self.context_list) and len(self.mccmnc_list)
+
+ def __str__(self):
+ s = 'Provider \'' + self.name + '\''
+
+ if (self.spn != None):
+ s += ' [SPN:\'' + self.spn + '\']'
+
+ s+= ' ' + str(self.mccmnc_list) + '\n'
+
+ for context in self.context_list:
+ s += '\t' + str(context) + '\n'
+
+ return s
+
+class MobileBroadbandProviderInfo:
+ usage_to_type = { 'internet' : ['internet'],
+ 'mms' : ['mms'],
+ 'wap' : ['wap'],
+ 'mms-internet-hipri' : ['internet', 'mms'],
+ 'mms-internet-hipri-fota' : ['internet','mms'],
+ }
+ @classmethod
+ def type_from_usage(cls, usage):
+ return cls.usage_to_type[usage]
+
+ def __init__(self, xml_path):
+ self.tree = ET.parse(xml_path)
+
+ def parse(self, xml_path):
+ providers = []
+
+ try:
+ tree = ET.parse(xml_path)
+ root = tree.getroot()
+
+ for provider in root.findall('.//provider'):
+ name = provider.find('name')
+ if name is None or not name.text:
+ continue;
+
+ info = ProviderInfo(name.text)
+
+ for networkid in provider.findall('gsm/network-id'):
+ info.add_mccmnc(networkid.get('mcc'), networkid.get('mnc'))
+
+ for apn in provider.findall('gsm/apn'):
+ context = {}
+
+ context['apn'] = apn.get('value')
+ if context['apn'] == None:
+ continue
+
+ # Usage is missing for some APNs, skip such contexts for now
+ usage = apn.find('usage')
+ if usage is None or usage.get('type') is None:
+ continue;
+
+ context['type'] = self.type_from_usage(usage.get('type'))
+ if context['type'] == None:
+ sys.stderr.write("Unable to convert type: %s\n" %
+ usage.get('type'))
+ continue
+
+ if 'mms' in context['type']:
+ mmsc = apn.find('mmsc')
+
+ # Ignore MMS contexts with no MMSC since it is needed
+ # to send messages
+ if mmsc is None or not mmsc.text:
+ continue
+
+ context['mmsc'] = mmsc.text
+
+ mmsproxy = apn.find('mmsproxy')
+ if mmsproxy is not None and mmsproxy.text:
+ context['mmsproxy'] = mmsproxy.text
+
+ username = apn.find('username')
+ if username is not None and username.text:
+ context['username'] = username.text
+
+ password = apn.find('password')
+ if password is not None and password.text:
+ context['password'] = password.text
+
+ authentication = apn.find('authentication')
+ if authentication is not None:
+ context['authentication'] = authentication.get('method')
+
+ context_name = apn.find('name')
+ if context_name != None:
+ context['name'] = context_name.text
+
+ info.add_context(context)
+
+ if info.is_valid():
+ providers.append(info)
+
+ except ET.ParseError as e:
+ print(f"Error parsing XML: {e}")
+
+ return providers
+
+class ProvisionContext(ctypes.LittleEndianStructure):
+ _pack_ = 1
+ _fields_ = [
+ ('type', ctypes.c_uint32),
+ ('protocol', ctypes.c_uint32),
+ ('authentication', ctypes.c_uint32),
+ ('reserved', ctypes.c_uint32),
+ ('name_offset', ctypes.c_uint64),
+ ('apn_offset', ctypes.c_uint64),
+ ('username_offset', ctypes.c_uint64),
+ ('password_offset', ctypes.c_uint64),
+ ('mmsproxy_offset', ctypes.c_uint64),
+ ('mmsc_offset', ctypes.c_uint64)
+ ]
+
+ authentication_dict = { 'chap' : 0, 'pap' : 1, 'none' : 2 }
+ protocol_dict = { 'ipv4' : 0, 'ipv6' : 1, 'ipv4v6' : 2 }
+ attrs = ['name', 'apn', 'username', 'password', 'mmsproxy', 'mmsc']
+
+ @classmethod
+ def type_to_context_type(cls, types):
+ r = 0
+
+ for t in types:
+ if t == 'internet':
+ r |= 0x0001
+ elif t == 'mms':
+ r |= 0x0002
+ elif t == 'wap':
+ r |= 0x0004
+ elif t == 'ims':
+ r |= 0x0008
+ elif t == 'supl':
+ r |= 0x0010
+ elif t == 'ia':
+ r |= 0x0020
+
+ return r
+
+ def __init__(self, apn, strings):
+ self.type = self.type_to_context_type(apn['type'])
+ self.protocol = self.protocol_dict[apn.get('protocol', 'ipv4v6')]
+ self.authentication = self.authentication_dict[apn.get('authentication',
+ 'chap')]
+
+ for s in self.attrs:
+ offset = strings.add_string(apn.get(s, None))
+ setattr(self, s + '_offset', offset)
+
+class ProvisionData(ctypes.LittleEndianStructure):
+ _pack_ = 1
+ _fields_ = [
+ ('spn_offset', ctypes.c_uint64),
+ ('context_offset', ctypes.c_uint64)
+ ]
+
+ def __init__(self, spn, offset, strings):
+ self.spn_offset = strings.add_string(spn)
+ self.context_offset = offset
+
+class ProvisionNode(ctypes.LittleEndianStructure):
+ _pack_ = 1
+ _fields_ = [
+ ('bit_offsets', ctypes.c_uint64 * 2),
+ ('mccmnc', ctypes.c_uint32),
+ ('diff', ctypes.c_int32),
+ ('provision_data_count', ctypes.c_uint64)
+ ]
+
+ style = "bold"
+ fmt_connection = '\t"%s/%d" -> "%s/%d"[color="#%06x"];\n'
+ fmt_declaration = '\t"%s/%d"[style=%s, color="#%06x"];\n'
+ red = 0xff0000
+ green = 0x00ff00
+
+ def __init__(self, key, diff):
+ self.bit = [None, None]
+ self.key = key
+ self.diff = diff
+ self.entries = {}
+ self.node_offset = 0
+
+ def choose(self, key):
+ return (key >> (31 - self.diff)) & 1
+
+ def print_graphviz(self, f):
+ f.write(self.fmt_declaration % (format(self.key, '032b'),
+ self.diff, self.style,
+ random.randint(0, 0x00ffffff)))
+ f.write(self.fmt_connection % (format(self.key, '032b'), self.diff,
+ format(self.bit[0].key, '032b'),
+ self.bit[0].diff, self.red))
+ f.write(self.fmt_connection % (format(self.key, '032b'), self.diff,
+ format(self.bit[1].key, '032b'),
+ self.bit[1].diff, self.green))
+
+ if (self.diff < self.bit[0].diff):
+ self.bit[0].print_graphviz(f)
+
+ if (self.diff < self.bit[1].diff):
+ self.bit[1].print_graphviz(f)
+
+ def __str__(self):
+ s = format(self.key, '032b') + '/' + str(self.diff)
+ return s
+
+class MccMncTree:
+ @staticmethod
+ def clz(v):
+ count = 32
+ while count and v:
+ v = v >> 1
+ count = count - 1
+
+ return count
+
+ @staticmethod
+ def diff(key1, key2):
+ xor = key1 ^ key2;
+ return MccMncTree.clz(xor)
+
+ def __init__(self):
+ self.root = ProvisionNode(key = 0, diff = -1)
+ self.root.bit[0] = self.root
+ self.root.bit[1] = self.root
+ self.n_nodes = 1
+
+ def print_graphviz(self):
+ f = open("step%d.dot" % self.n_nodes, "w")
+ # Use 'dot -Tx11' to visualize
+ f.write('digraph trie {\n')
+ self.root.print_graphviz(f)
+ f.write('}\n')
+ f.close()
+
+ def find_closest(self, key):
+ parent = self.root
+ child = self.root.bit[0]
+
+ while parent.diff < child.diff:
+ parent = child
+ child = child.bit[child.choose(key)]
+
+ return child
+
+ def find(self, key):
+ found = self.find_closest(key)
+ if found.key == key:
+ return found
+
+ return None
+
+ def insert(self, key, attr, value):
+ node = self.find_closest(key);
+ if node.key == key:
+ node.entries[attr] = value
+ return
+
+ bit = self.diff(node.key, key)
+ parent = self.root
+ child = self.root.bit[0]
+
+ while (parent.diff < child.diff) and (child.diff < bit):
+ parent = child
+ child = child.bit[child.choose(key)]
+
+ node = ProvisionNode(key, bit)
+ bit = node.choose(key)
+ node.bit[bit] = node
+ node.bit[not bit] = child
+
+ node.entries[attr] = value
+
+ if parent == self.root:
+ self.root.bit[0] = node
+ else:
+ bit = parent.choose(key)
+ parent.bit[bit] = node
+
+ self.n_nodes += 1
+
+ def traverse_recursive(self, node, bit, visitor):
+ if node == self.root:
+ return
+
+ if node.diff <= bit:
+ visitor.visit(node)
+ return
+
+ self.traverse_recursive(node.bit[0], node.diff, visitor)
+ self.traverse_recursive(node.bit[1], node.diff, visitor)
+
+ def traverse(self, visitor):
+ self.traverse_recursive(self.root.bit[0], -1, visitor)
+
+class StringAccumulator:
+ def __init__(self):
+ self.data = bytearray(b'\x00') # So offsets are never 0 used for NULL
+ self.offsets = {}
+
+ def add_string(self, s):
+ if s is None:
+ return 0
+
+ if s in self.offsets:
+ return self.offsets[s]
+
+ offset = len(self.data)
+ self.data.extend(s.encode('utf-8'))
+ self.data.append(0)
+ self.offsets[s] = offset
+
+ return offset
+
+ def get_bytes(self):
+ return self.data
+
+class ProvisionDatabase(ctypes.LittleEndianStructure):
+ _pack_ = 1
+ _fields_ = [
+ ('version', ctypes.c_uint64),
+ ('file_size', ctypes.c_uint64),
+ ('header_size', ctypes.c_uint64),
+ ('node_struct_size', ctypes.c_uint64),
+ ('provision_data_struct_size', ctypes.c_uint64),
+ ('context_struct_size', ctypes.c_uint64),
+ ('nodes_offset', ctypes.c_uint64),
+ ('nodes_size', ctypes.c_uint64),
+ ('contexts_offset', ctypes.c_uint64),
+ ('contexts_size', ctypes.c_uint64),
+ ('strings_offset', ctypes.c_uint64),
+ ('strings_size', ctypes.c_uint64)
+ ]
+
+ class CalculateNodeOffsetVisitor:
+ def __init__(self):
+ self.current_offset = 0
+ def visit(self, node):
+ node.node_offset = self.current_offset
+
+ # Node data is followed by at least one ProvisionData object, with
+ # the only exception being root, which has no data by definition
+ self.current_offset += ctypes.sizeof(ProvisionNode)
+ self.current_offset += (ctypes.sizeof(ProvisionData) *
+ len(node.entries))
+
+ class SerializeVisitor:
+ def __init__(self, buffer):
+ self.buffer = buffer
+
+ def visit(self, node):
+ # Node doesn't quite fit the C structure definition, so do this
+ # manually by using struct.pack
+ self.buffer.extend(struct.pack('<QQIiQ',
+ node.bit[0].node_offset,
+ node.bit[1].node_offset,
+ node.key, node.diff,
+ len(node.entries)))
+
+ for spn in sorted(node.entries):
+ pd = node.entries[spn]
+ self.buffer.extend(bytes(pd))
+
+ def __init__(self, provider_infos):
+ self.strings = StringAccumulator()
+ self.contexts = bytearray()
+ self.tree = MccMncTree()
+
+ for info in provider_infos:
+ pd = ProvisionData(info.spn, len(self.contexts), self.strings)
+
+ self.contexts.extend(struct.pack('<Q', len(info.context_list)))
+
+ for context in info.context_list:
+ self.contexts.extend(bytes(ProvisionContext(context,
+ self.strings)))
+
+ for mccmnc in info.mccmnc_list:
+ # Sort None spns as '' so they're first in the list when
+ # the SerializeVisitor sorts the entries dict
+ spn = info.spn if info.spn is not None else ''
+
+ # 2 and 3 byte MNCs are treated differently, even if evaluate
+ # to the same integer. For example, 02 and 002 are different
+ # MNCs. In practice this doesn't actually happen except on
+ # test networks, but account for this possibility by using the
+ # upper 10 bits for the MCC, followed by a single bit which
+ # signifies whether a 3 byte MNC is used, followed by 10 bits
+ # of the MNC
+ key = int(mccmnc[:3]) << 11 | int(mccmnc[3:])
+ if len(mccmnc[3:]) == 3:
+ key |= 1 << 10
+
+ self.tree.insert(key, spn, pd)
+
+ visitor = self.CalculateNodeOffsetVisitor()
+ visitor.visit(self.tree.root)
+ self.tree.traverse(visitor)
+
+ self.version = 1
+ self.header_size = ctypes.sizeof(ProvisionDatabase)
+ self.file_size = self.header_size
+ self.node_struct_size = ctypes.sizeof(ProvisionNode)
+ self.provision_data_struct_size = ctypes.sizeof(ProvisionData)
+ self.context_struct_size = ctypes.sizeof(ProvisionContext)
+ self.nodes_offset = self.header_size
+ self.nodes_size = visitor.current_offset
+ self.file_size += self.nodes_size
+ self.contexts_offset = self.nodes_offset + self.nodes_size
+ self.contexts_size = len(self.contexts)
+ self.file_size += self.contexts_size
+ self.strings_offset = self.contexts_offset + self.contexts_size
+ self.strings_size = len(self.strings.get_bytes())
+ self.file_size += self.strings_size
+
+ def serialize(self):
+ buffer = bytearray()
+ buffer.extend(bytes(self))
+
+ visitor = self.SerializeVisitor(buffer)
+ visitor.visit(self.tree.root)
+ self.tree.traverse(visitor)
+
+ buffer.extend(self.contexts)
+ buffer.extend(self.strings.get_bytes())
+
+ return buffer
+
+class ProviderInfoJSONEncoder(json.JSONEncoder):
+ def default(self, obj):
+ if isinstance(obj, ProviderInfo):
+ asdict = { 'name' : obj.name, 'ids' : obj.mccmnc_list }
+
+ if (obj.spn != None):
+ asdict['spn'] = obj.spn
+
+ asdict.update({'apns' : obj.context_list})
+ return asdict
+
+ return json.JSONEncoder.default(self, obj)
+
+def mbpi_convert(args):
+ xml_path = '/usr/share/mobile-broadband-provider-info/serviceproviders.xml'
+ mbpi = MobileBroadbandProviderInfo(xml_path)
+ provider_infos = mbpi.parse(xml_path)
+
+ try:
+ if args.outfile is None:
+ out = sys.stdout
+ else:
+ out = args.outfile.open('w', encoding='utf-8')
+
+ with out as outfile:
+ json.dump(provider_infos, outfile, ensure_ascii=False, indent=2,
+ cls=ProviderInfoJSONEncoder)
+ except ValueError as e:
+ raise SystemExit(e)
+
+def generate(args):
+ try:
+ json_dict = json.load(args.infile)
+ except ValueError as e:
+ raise SystemExit(e)
+
+ provider_infos = []
+
+ for entry in json_dict:
+ info = ProviderInfo.rawimport(entry)
+ provider_infos.append(info)
+
+ db = ProvisionDatabase(provider_infos)
+ with args.outfile.open('wb') as outfile:
+ outfile.write(db.serialize())
+
+def selftest(args):
+ tree = MccMncTree()
+
+ # Generate random key this many times and insert it into the tree
+ # There will be some collisions, so count the number of items inserted
+ # Then run the lookup and make sure the same number of items can be found
+ times = 100000
+ for i in range(0, times):
+ key = random.randint(10000, 999999)
+ tree.insert(key, None, None)
+
+ n_inserted = tree.n_nodes
+ print("Created a tree with %d nodes (1 root)" % n_inserted)
+
+ n_found = 0
+ expected_keys = []
+
+ for i in range(10000, 1000000):
+ if tree.find(i):
+ n_found += 1
+ expected_keys.append(i)
+
+ expected_keys = sorted(expected_keys)
+
+ print("Found %d nodes (not including root)" % n_found)
+ assert n_found == n_inserted - 1
+
+ class GetKeysVisitor:
+ def __init__(self):
+ self.keys = []
+
+ def visit(self, node):
+ self.keys.append(node.key)
+
+ visitor = GetKeysVisitor()
+ tree.traverse(visitor)
+ assert visitor.keys == expected_keys
+
+ sample_json = """[
+ {
+ "name": "Operator XYZ",
+ "ids": [
+ "99955", "99956", "99901", "99902"
+ ],
+ "apns": [
+ {
+ "name": "Internet",
+ "apn": "internet",
+ "type": [
+ "internet"
+ ],
+ "authentication": "none",
+ "protocol": "ipv4"
+ },
+ {
+ "name": "IMS+MMS",
+ "apn": "imsmms",
+ "type": [
+ "ims", "mms"
+ ],
+ "mmsc": "foobar.mmsc:80",
+ "mmsproxy": "mms.proxy.net",
+ "authentication": "pap",
+ "protocol": "ipv6"
+ }
+ ]
+ },
+ {
+ "name": "Operator ZYX",
+ "ids": [
+ "99998", "99999", "99901", "99902"
+ ],
+ "spn": "ZYX",
+ "apns": [
+ {
+ "name": "ZYX",
+ "apn": "zyx",
+ "type": [
+ "internet"
+ ],
+ "authentication": "none",
+ "protocol": "ipv4"
+ }
+ ]
+ }
+ ]"""
+ try:
+ json_dict = json.loads(sample_json)
+ except ValueError as e:
+ raise SystemExit(e)
+
+ provider_infos = []
+
+ for entry in json_dict:
+ info = ProviderInfo.rawimport(entry)
+ provider_infos.append(info)
+
+ db = ProvisionDatabase(provider_infos)
+ expected_strings = bytearray(b'\x00Internet\x00internet\x00IMS+MMS\x00' +
+ b'imsmms\x00mms.proxy.net\x00foobar.mmsc:80\x00ZYX\x00zyx\x00')
+ assert(expected_strings == db.strings.get_bytes())
+
+ with open('sample.db', 'wb') as outfile:
+ outfile.write(db.serialize())
+
+if __name__ == "__main__":
+ parser = ArgumentParser(description='Parse command line arguments')
+ subparsers = parser.add_subparsers(title='subcommands', dest='command')
+
+ # mbpi-convert command
+ mbpi_convert_parser = subparsers.add_parser('mbpi-convert',
+ help='Convert mobile-broadband-provider-info database')
+ mbpi_convert_parser.add_argument('--outfile', type=Path, default=None,
+ help='Output file path', required=False,)
+ mbpi_convert_parser.set_defaults(func=mbpi_convert)
+
+ # generate command
+ generate_parser = subparsers.add_parser('generate',
+ help='Generate binary provider db')
+ generate_parser.add_argument('--outfile', type=Path, default='provision.db',
+ help='Output file path', required=False)
+ generate_parser.add_argument('--infile', type=FileType(encoding='utf-8'),
+ help='Input JSON db', default=sys.stdin)
+ generate_parser.set_defaults(func=generate)
+
+ # selftest command
+ selftest_parser = subparsers.add_parser('selftest',
+ help='Run self-tests')
+ selftest_parser.set_defaults(func=selftest)
+
+ args = parser.parse_args()
+ if hasattr(args, 'func'):
+ args.func(args)
+ else:
+ parser.print_help()