diff mbox series

[2/3] python/sepolicy: Stop rejecting aliases in sepolicy commands

Message ID 20181016082541.21615-3-vmojzis@redhat.com (mailing list archive)
State Not Applicable
Headers show
Series [1/3] python/sepolicy: Fix "info" to search aliases as well | expand

Commit Message

Vit Mojzis Oct. 16, 2018, 8:25 a.m. UTC
Fix CheckDomain and CheckPortType classes to properly deal with aliases.

Resolves: rhbz#1600009
---
 python/sepolicy/sepolicy.py          |  8 +++-----
 python/sepolicy/sepolicy/__init__.py | 10 +++++++++-
 2 files changed, 12 insertions(+), 6 deletions(-)
diff mbox series

Patch

diff --git a/python/sepolicy/sepolicy.py b/python/sepolicy/sepolicy.py
index a000c1ad..01380fbe 100755
--- a/python/sepolicy/sepolicy.py
+++ b/python/sepolicy/sepolicy.py
@@ -60,8 +60,6 @@  class CheckPath(argparse.Action):
 class CheckType(argparse.Action):
 
     def __call__(self, parser, namespace, values, option_string=None):
-        domains = sepolicy.get_all_domains()
-
         if isinstance(values, str):
             setattr(namespace, self.dest, values)
         else:
@@ -103,7 +101,7 @@  class CheckDomain(argparse.Action):
         domains = sepolicy.get_all_domains()
 
         if isinstance(values, str):
-            if values not in domains:
+            if sepolicy.get_real_type_name(values) not in domains:
                 raise ValueError("%s must be an SELinux process domain:\nValid domains: %s" % (values, ", ".join(domains)))
             setattr(namespace, self.dest, values)
         else:
@@ -112,7 +110,7 @@  class CheckDomain(argparse.Action):
                 newval = []
 
             for v in values:
-                if v not in domains:
+                if sepolicy.get_real_type_name(v) not in domains:
                     raise ValueError("%s must be an SELinux process domain:\nValid domains: %s" % (v, ", ".join(domains)))
                 newval.append(v)
             setattr(namespace, self.dest, newval)
@@ -167,7 +165,7 @@  class CheckPortType(argparse.Action):
         if not newval:
             newval = []
         for v in values:
-            if v not in port_types:
+            if sepolicy.get_real_type_name(v) not in port_types:
                 raise ValueError("%s must be an SELinux port type:\nValid port types: %s" % (v, ", ".join(port_types)))
             newval.append(v)
         setattr(namespace, self.dest, values)
diff --git a/python/sepolicy/sepolicy/__init__.py b/python/sepolicy/sepolicy/__init__.py
index 8484b28c..0da3917b 100644
--- a/python/sepolicy/sepolicy/__init__.py
+++ b/python/sepolicy/sepolicy/__init__.py
@@ -447,6 +447,14 @@  def get_file_types(setype):
     return mpaths
 
 
+# determine if entered type is an alias
+# and return corresponding type name
+def get_real_type_name(name):
+    try:
+        return next(info(TYPE, name))["name"]
+    except (RuntimeError, StopIteration):
+        return None
+
 def get_writable_files(setype):
     file_types = get_all_file_types()
     all_writes = []
@@ -1061,7 +1069,7 @@  def gen_short_name(setype):
         domainname = setype[:-2]
     else:
         domainname = setype
-    if domainname + "_t" not in all_domains:
+    if get_real_type_name(domainname + "_t") not in all_domains:
         raise ValueError("domain %s_t does not exist" % domainname)
     if domainname[-1] == 'd':
         short_name = domainname[:-1] + "_"