/mandos/trunk

To get this branch, use:
bzr branch http://bzr.recompile.se/loggerhead/mandos/trunk

« back to all changes in this revision

Viewing changes to mandos-ctl

  • Committer: Teddy Hogeborn
  • Date: 2019-03-13 21:38:35 UTC
  • Revision ID: teddy@recompile.se-20190313213835-vq9cv3lih2jp3au7
mandos-ctl: Refactor

* mandos-ctl (TestCmd.bus.Bus.get_object): Add explanatory comment
                                           about "self".
  (Unique): Move to "TestPropertyCmd.Unique".  All callers changed.

Show diffs side-by-side

added added

removed removed

Lines of Context:
93
93
    if options.debug:
94
94
        log.setLevel(logging.DEBUG)
95
95
 
96
 
    bus = dbus.SystemBus()
97
 
 
98
 
    mandos_dbus_object = get_mandos_dbus_object(bus)
99
 
 
100
 
    mandos_serv = dbus.Interface(
101
 
        mandos_dbus_object, dbus_interface=server_dbus_interface)
 
96
    try:
 
97
        bus = dbus.SystemBus()
 
98
        log.debug("D-Bus: Connect to: (busname=%r, path=%r)",
 
99
                  dbus_busname, server_dbus_path)
 
100
        mandos_dbus_objc = bus.get_object(dbus_busname,
 
101
                                          server_dbus_path)
 
102
    except dbus.exceptions.DBusException:
 
103
        log.critical("Could not connect to Mandos server")
 
104
        sys.exit(1)
 
105
 
 
106
    mandos_serv = dbus.Interface(mandos_dbus_objc,
 
107
                                 dbus_interface=server_dbus_interface)
102
108
    mandos_serv_object_manager = dbus.Interface(
103
 
        mandos_dbus_object, dbus_interface=dbus.OBJECT_MANAGER_IFACE)
104
 
 
105
 
    managed_objects = get_managed_objects(mandos_serv_object_manager)
106
 
 
107
 
    all_clients = {}
108
 
    for path, ifs_and_props in managed_objects.items():
109
 
        try:
110
 
            all_clients[path] = ifs_and_props[client_dbus_interface]
111
 
        except KeyError:
112
 
            pass
113
 
 
114
 
    # Compile dict of (clientpath: properties) to process
 
109
        mandos_dbus_objc, dbus_interface=dbus.OBJECT_MANAGER_IFACE)
 
110
 
 
111
    # Filter out log message from dbus module
 
112
    dbus_logger = logging.getLogger("dbus.proxies")
 
113
    class NullFilter(logging.Filter):
 
114
        def filter(self, record):
 
115
            return False
 
116
    dbus_filter = NullFilter()
 
117
    try:
 
118
        dbus_logger.addFilter(dbus_filter)
 
119
        log.debug("D-Bus: %s:%s:%s.GetManagedObjects()", dbus_busname,
 
120
                  server_dbus_path, dbus.OBJECT_MANAGER_IFACE)
 
121
        mandos_clients = {path: ifs_and_props[client_dbus_interface]
 
122
                          for path, ifs_and_props in
 
123
                          mandos_serv_object_manager
 
124
                          .GetManagedObjects().items()
 
125
                          if client_dbus_interface in ifs_and_props}
 
126
    except dbus.exceptions.DBusException as e:
 
127
        log.critical("Failed to access Mandos server through D-Bus:"
 
128
                     "\n%s", e)
 
129
        sys.exit(1)
 
130
    finally:
 
131
        # restore dbus logger
 
132
        dbus_logger.removeFilter(dbus_filter)
 
133
 
 
134
    # Compile dict of (clients: properties) to process
 
135
    clients = {}
 
136
 
115
137
    if not clientnames:
116
 
        clients = all_clients
 
138
        clients = {objpath: properties
 
139
                   for objpath, properties in mandos_clients.items()}
117
140
    else:
118
 
        clients = {}
119
141
        for name in clientnames:
120
 
            for objpath, properties in all_clients.items():
 
142
            for objpath, properties in mandos_clients.items():
121
143
                if properties["Name"] == name:
122
144
                    clients[objpath] = properties
123
145
                    break
424
446
        options.remove = True
425
447
 
426
448
 
427
 
def get_mandos_dbus_object(bus):
    """Fetch the Mandos server object from the given D-Bus *bus*.

    Calls bus.get_object(dbus_busname, server_dbus_path) and returns
    its result.  A DBusException raised by that call is logged as
    critical and the program exits with status 1.
    """
    log.debug("D-Bus: Connect to: (busname=%r, path=%r)",
              dbus_busname, server_dbus_path)
    connect_error = "Could not connect to Mandos server: %s"
    with if_dbus_exception_log_with_exception_and_exit(connect_error):
        return bus.get_object(dbus_busname, server_dbus_path)
435
 
 
436
 
 
437
 
@contextlib.contextmanager
def if_dbus_exception_log_with_exception_and_exit(*args, **kwargs):
    """Context manager turning a DBusException into a logged exit.

    If the managed block raises dbus.exceptions.DBusException, the
    caught exception is appended to the positional *args* (normally a
    %-style format string), log.critical() is called with those and
    *kwargs*, and sys.exit(1) is invoked.  Any other exception
    propagates unchanged.
    """
    try:
        yield
    except dbus.exceptions.DBusException as error:
        critical_args = args + (error,)
        log.critical(*critical_args, **kwargs)
        sys.exit(1)
444
 
 
445
 
 
446
 
def get_managed_objects(object_manager):
    """Return object_manager.GetManagedObjects().

    While the call runs, the "dbus.proxies" logger is silenced.  A
    DBusException is logged as critical and the program exits with
    status 1.
    """
    log.debug("D-Bus: %s:%s:%s.GetManagedObjects()", dbus_busname,
              server_dbus_path, dbus.OBJECT_MANAGER_IFACE)
    exit_on_dbus_error = if_dbus_exception_log_with_exception_and_exit(
        "Failed to access Mandos server through D-Bus:\n%s")
    # The error handler is entered first (outermost), then the logger
    # is silenced, exactly like two nested "with" statements.
    with exit_on_dbus_error, SilenceLogger("dbus.proxies"):
        return object_manager.GetManagedObjects()
454
 
 
455
 
 
456
 
class SilenceLogger(object):
    "Simple context manager to silence a particular logger"

    class NullFilter(logging.Filter):
        """Logging filter which rejects every record."""
        def filter(self, record):
            return False

    def __init__(self, loggername):
        """Prepare to silence the logger named *loggername*."""
        self.logger = logging.getLogger(loggername)
        # Each instance needs its own filter object: Logger.addFilter()
        # does not add a filter that is already attached, so with a
        # single shared filter the innermost of two nested contexts on
        # the same logger would remove it and un-silence the outer one.
        self.nullfilter = self.NullFilter()

    def __enter__(self):
        self.logger.addFilter(self.nullfilter)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.logger.removeFilter(self.nullfilter)
        # Returning None: exceptions from the block are not suppressed.
473
 
 
474
 
 
475
449
def commands_from_options(options):
476
450
 
477
451
    commands = []
616
590
        data = {client["Name"]:
617
591
                {key: self.dbus_boolean_to_bool(client[key])
618
592
                 for key in self.all_keywords}
619
 
                for client in clients}
 
593
                for client in clients.values()}
620
594
        return json.dumps(data, indent=4, separators=(',', ': '))
621
595
 
622
596
    @staticmethod
662
636
            "LastCheckerStatus": "Last Checker Status",
663
637
        }
664
638
 
665
 
        def __init__(self, clients, keywords):
 
639
        def __init__(self, clients, keywords, tableheaders=None):
666
640
            self.clients = clients
667
641
            self.keywords = keywords
 
642
            if tableheaders is not None:
 
643
                self.tableheaders = tableheaders
668
644
 
669
645
        def __str__(self):
670
646
            return "\n".join(self.rows())
839
815
 
840
816
 
841
817
 
842
 
class TestCaseWithAssertLogs(unittest.TestCase):
    """unittest.TestCase with assertLogs() backported for Python < 3.4

    unittest.TestCase.assertLogs only exists from Python 3.4.  The
    backport below is always defined (as _assertLogs_backport) so it
    can be exercised on any Python.  It is installed as assertLogs
    only when unittest.TestCase does not already provide it.
    """

    @contextlib.contextmanager
    def _assertLogs_backport(self, logger=None, level=logging.INFO):
        """Assert that at least one record of *level* or higher is logged.

        *logger* may be a logging.Logger, a logger name, or None for
        the root logger.  Yields a watcher with "records" (LogRecord
        objects) and "output" (formatted messages) lists.  While the
        block runs, the logger's level is set to *level* and
        propagation is disabled.  Both are restored afterwards.
        """
        if not isinstance(logger, logging.Logger):
            logger = logging.getLogger(logger)
        capturing_handler = self.CapturingLevelHandler(level)
        old_level = logger.level
        old_propagate = logger.propagate
        logger.addHandler(capturing_handler)
        logger.setLevel(level)
        logger.propagate = False
        try:
            yield capturing_handler.watcher
        finally:
            logger.propagate = old_propagate
            logger.removeHandler(capturing_handler)
            logger.setLevel(old_level)
        self.assertGreater(len(capturing_handler.watcher.records),
                           0)

    class CapturingLevelHandler(logging.Handler):
        """Handler storing each accepted record and its formatted text."""
        def __init__(self, level, *args, **kwargs):
            logging.Handler.__init__(self, *args, **kwargs)
            # Apply the threshold to the handler itself too.  Records
            # from child loggers with a lower explicit level propagate
            # here regardless of this logger's own level.
            self.setLevel(level)
            self.watcher = self.LoggingWatcher([], [])

        def emit(self, record):
            self.watcher.records.append(record)
            self.watcher.output.append(self.format(record))

        LoggingWatcher = collections.namedtuple("LoggingWatcher",
                                                ("records",
                                                 "output"))

    if not hasattr(unittest.TestCase, "assertLogs"):
        assertLogs = _assertLogs_backport
874
 
 
875
 
class Test_string_to_delta(TestCaseWithAssertLogs):
 
818
class Test_string_to_delta(unittest.TestCase):
876
819
    def test_handles_basic_rfc3339(self):
877
820
        self.assertEqual(string_to_delta("PT0S"),
878
821
                         datetime.timedelta())
884
827
                         datetime.timedelta(0, 7200))
885
828
 
886
829
    def test_falls_back_to_pre_1_6_1_with_warning(self):
887
 
        with self.assertLogs(log, logging.WARNING):
888
 
            value = string_to_delta("2h")
 
830
        # assertLogs only exists in Python 3.4
 
831
        if hasattr(self, "assertLogs"):
 
832
            with self.assertLogs(log, logging.WARNING):
 
833
                value = string_to_delta("2h")
 
834
        else:
 
835
            class WarningFilter(logging.Filter):
 
836
                """Don't show, but record the presence of, warnings"""
 
837
                def filter(self, record):
 
838
                    is_warning = record.levelno >= logging.WARNING
 
839
                    self.found = is_warning or getattr(self, "found",
 
840
                                                       False)
 
841
                    return not is_warning
 
842
            warning_filter = WarningFilter()
 
843
            log.addFilter(warning_filter)
 
844
            try:
 
845
                value = string_to_delta("2h")
 
846
            finally:
 
847
                log.removeFilter(warning_filter)
 
848
            self.assertTrue(getattr(warning_filter, "found", False))
889
849
        self.assertEqual(value, datetime.timedelta(0, 7200))
890
850
 
891
851
 
1021
981
                self.check_option_syntax(options)
1022
982
 
1023
983
 
1024
 
class Test_get_mandos_dbus_object(TestCaseWithAssertLogs):
    """Tests for get_mandos_dbus_object()."""

    def test_calls_and_returns_get_object_on_bus(self):
        testcase = self

        class MockBus(object):
            called = False

            def get_object(bus_self, busname, dbus_path):
                # "testcase" is the TestCase instance; the MockBus
                # instance itself is "bus_self".
                testcase.assertEqual(busname, dbus_busname)
                testcase.assertEqual(dbus_path, server_dbus_path)
                bus_self.called = True
                return bus_self

        returned = get_mandos_dbus_object(bus=MockBus())
        self.assertIsInstance(returned, MockBus)
        self.assertTrue(returned.called)

    def test_logs_and_exits_on_dbus_error(self):
        class MockBusFailing(object):
            def get_object(self, busname, dbus_path):
                raise dbus.exceptions.DBusException("Test")

        with self.assertLogs(log, logging.CRITICAL), \
             self.assertRaises(SystemExit) as e:
            get_mandos_dbus_object(bus=MockBusFailing())

        # sys.exit() may carry a non-int code; then it must be non-None.
        exit_code = e.exception.code
        if not isinstance(exit_code, int):
            self.assertIsNotNone(exit_code)
        else:
            self.assertNotEqual(exit_code, 0)
1053
 
 
1054
 
 
1055
 
class Test_get_managed_objects(TestCaseWithAssertLogs):
    """Tests for get_managed_objects()."""

    def test_calls_and_returns_GetManagedObjects(self):
        managed_objects = {"/clients/foo": {"Name": "foo"}}

        class MockObjectManager(object):
            def GetManagedObjects(self):
                return managed_objects

        retval = get_managed_objects(MockObjectManager())
        self.assertDictEqual(managed_objects, retval)

    def test_logs_and_exits_on_dbus_error(self):
        dbus_logger = logging.getLogger("dbus.proxies")

        class MockObjectManagerFailing(object):
            def GetManagedObjects(self):
                # Log on "dbus.proxies" before failing; this record is
                # expected to be silenced by get_managed_objects().
                dbus_logger.error("Test")
                raise dbus.exceptions.DBusException("Test")

        class CountingHandler(logging.Handler):
            count = 0

            def emit(self, record):
                self.count += 1

        counting_handler = CountingHandler()

        dbus_logger.addHandler(counting_handler)

        try:
            with self.assertLogs(log, logging.CRITICAL):
                with self.assertRaises(SystemExit) as e:
                    get_managed_objects(MockObjectManagerFailing())
        finally:
            # The handler was attached with addHandler(), so it must be
            # detached with removeHandler(); removeFilter() would leave
            # it attached to the global "dbus.proxies" logger.
            dbus_logger.removeHandler(counting_handler)
        self.assertEqual(counting_handler.count, 0)

        # Test that the dbus_logger still works
        with self.assertLogs(dbus_logger, logging.ERROR):
            dbus_logger.error("Test")

        # sys.exit() may carry a non-int code; then it must be non-None.
        if isinstance(e.exception.code, int):
            self.assertNotEqual(e.exception.code, 0)
        else:
            self.assertIsNotNone(e.exception.code)
1097
 
 
1098
 
 
1099
984
class Test_commands_from_options(unittest.TestCase):
1100
985
    def setUp(self):
1101
986
        self.parser = argparse.ArgumentParser()
1506
1391
        return super(TestDumpJSONCmd, self).setUp()
1507
1392
 
1508
1393
    def test_normal(self):
1509
 
        output = DumpJSONCmd().output(self.clients.values())
1510
 
        json_data = json.loads(output)
 
1394
        json_data = json.loads(DumpJSONCmd().output(self.clients))
1511
1395
        self.assertDictEqual(json_data, self.expected_json)
1512
1396
 
1513
1397
    def test_one_client(self):
1514
 
        output = DumpJSONCmd().output(self.one_client.values())
1515
 
        json_data = json.loads(output)
 
1398
        clients = self.one_client
 
1399
        json_data = json.loads(DumpJSONCmd().output(clients))
1516
1400
        expected_json = {"foo": self.expected_json["foo"]}
1517
1401
        self.assertDictEqual(json_data, expected_json)
1518
1402
 
1807
1691
    return tests
1808
1692
 
1809
1693
if __name__ == "__main__":
1810
 
    try:
1811
 
        if should_only_run_tests():
1812
 
            # Call using ./tdd-python-script --check [--verbose]
1813
 
            unittest.main()
1814
 
        else:
1815
 
            main()
1816
 
    finally:
1817
 
        logging.shutdown()
 
1694
    if should_only_run_tests():
 
1695
        # Call using ./tdd-python-script --check [--verbose]
 
1696
        unittest.main()
 
1697
    else:
 
1698
        main()