/mandos/trunk

To get this branch, use:
bzr branch http://bzr.recompile.se/loggerhead/mandos/trunk

« back to all changes in this revision

Viewing changes to mandos-ctl

  • Committer: Teddy Hogeborn
  • Date: 2019-03-15 23:55:53 UTC
  • Revision ID: teddy@recompile.se-20190315235553-znbpfn7d8o84tyt6
mandos-ctl: Refactor

* mandos-ctl (main): Do some minor refactoring.

Show diffs side-by-side

added added

removed removed

Lines of Context:
102
102
    mandos_serv_object_manager = dbus.Interface(
103
103
        mandos_dbus_object, dbus_interface=dbus.OBJECT_MANAGER_IFACE)
104
104
 
105
 
    managed_objects = get_managed_objects(mandos_serv_object_manager)
 
105
    log.debug("D-Bus: %s:%s:%s.GetManagedObjects()", dbus_busname,
 
106
              server_dbus_path, dbus.OBJECT_MANAGER_IFACE)
 
107
    try:
 
108
        with SilenceLogger("dbus.proxies"):
 
109
            managed_objects = (mandos_serv_object_manager
 
110
                               .GetManagedObjects())
 
111
    except dbus.exceptions.DBusException as e:
 
112
        log.critical("Failed to access Mandos server through D-Bus:"
 
113
                     "\n%s", e)
 
114
        sys.exit(1)
106
115
 
107
116
    all_clients = {}
108
117
    for path, ifs_and_props in managed_objects.items():
125
134
                log.critical("Client not found on server: %r", name)
126
135
                sys.exit(1)
127
136
 
 
137
    # Run all commands on clients
128
138
    commands = commands_from_options(options)
129
 
 
130
139
    for command in commands:
131
140
        command.run(clients, bus, mandos_serv)
132
141
 
232
241
    >>> rfc3339_duration_to_delta("")
233
242
    Traceback (most recent call last):
234
243
    ...
235
 
    ValueError: Invalid RFC 3339 duration: ""
 
244
    ValueError: Invalid RFC 3339 duration: u''
236
245
    >>> # Must start with "P":
237
246
    >>> rfc3339_duration_to_delta("1D")
238
247
    Traceback (most recent call last):
239
248
    ...
240
 
    ValueError: Invalid RFC 3339 duration: "1D"
 
249
    ValueError: Invalid RFC 3339 duration: u'1D'
241
250
    >>> # Must use correct order
242
251
    >>> rfc3339_duration_to_delta("PT1S2M")
243
252
    Traceback (most recent call last):
244
253
    ...
245
 
    ValueError: Invalid RFC 3339 duration: "PT1S2M"
 
254
    ValueError: Invalid RFC 3339 duration: u'PT1S2M'
246
255
    >>> # Time needs time marker
247
256
    >>> rfc3339_duration_to_delta("P1H2S")
248
257
    Traceback (most recent call last):
249
258
    ...
250
 
    ValueError: Invalid RFC 3339 duration: "P1H2S"
 
259
    ValueError: Invalid RFC 3339 duration: u'P1H2S'
251
260
    >>> # Weeks can not be combined with anything else
252
261
    >>> rfc3339_duration_to_delta("P1D2W")
253
262
    Traceback (most recent call last):
254
263
    ...
255
 
    ValueError: Invalid RFC 3339 duration: "P1D2W"
 
264
    ValueError: Invalid RFC 3339 duration: u'P1D2W'
256
265
    >>> rfc3339_duration_to_delta("P2W2H")
257
266
    Traceback (most recent call last):
258
267
    ...
259
 
    ValueError: Invalid RFC 3339 duration: "P2W2H"
 
268
    ValueError: Invalid RFC 3339 duration: u'P2W2H'
260
269
    """
261
270
 
262
271
    # Parsing an RFC 3339 duration with regular expressions is not
333
342
                break
334
343
        else:
335
344
            # No currently valid tokens were found
336
 
            raise ValueError("Invalid RFC 3339 duration: \"{}\""
 
345
            raise ValueError("Invalid RFC 3339 duration: {!r}"
337
346
                             .format(duration))
338
347
    # End token found
339
348
    return value
425
434
 
426
435
 
427
436
def get_mandos_dbus_object(bus):
428
 
    log.debug("D-Bus: Connect to: (busname=%r, path=%r)",
429
 
              dbus_busname, server_dbus_path)
430
 
    with if_dbus_exception_log_with_exception_and_exit(
431
 
            "Could not connect to Mandos server: %s"):
 
437
    try:
 
438
        log.debug("D-Bus: Connect to: (busname=%r, path=%r)",
 
439
                  dbus_busname, server_dbus_path)
432
440
        mandos_dbus_object = bus.get_object(dbus_busname,
433
441
                                            server_dbus_path)
 
442
    except dbus.exceptions.DBusException:
 
443
        log.critical("Could not connect to Mandos server")
 
444
        sys.exit(1)
 
445
 
434
446
    return mandos_dbus_object
435
447
 
436
448
 
437
 
@contextlib.contextmanager
438
 
def if_dbus_exception_log_with_exception_and_exit(*args, **kwargs):
439
 
    try:
440
 
        yield
441
 
    except dbus.exceptions.DBusException as e:
442
 
        log.critical(*(args + (e,)), **kwargs)
443
 
        sys.exit(1)
444
 
 
445
 
 
446
 
def get_managed_objects(object_manager):
447
 
    log.debug("D-Bus: %s:%s:%s.GetManagedObjects()", dbus_busname,
448
 
              server_dbus_path, dbus.OBJECT_MANAGER_IFACE)
449
 
    with if_dbus_exception_log_with_exception_and_exit(
450
 
            "Failed to access Mandos server through D-Bus:\n%s"):
451
 
        with SilenceLogger("dbus.proxies"):
452
 
            managed_objects = object_manager.GetManagedObjects()
453
 
    return managed_objects
454
 
 
455
 
 
456
449
class SilenceLogger(object):
457
450
    "Simple context manager to silence a particular logger"
458
451
    def __init__(self, loggername):
839
832
 
840
833
 
841
834
 
842
 
class TestCaseWithAssertLogs(unittest.TestCase):
843
 
    """unittest.TestCase.assertLogs only exists in Python 3.4"""
844
 
 
845
 
    if not hasattr(unittest.TestCase, "assertLogs"):
846
 
        @contextlib.contextmanager
847
 
        def assertLogs(self, logger, level=logging.INFO):
848
 
            capturing_handler = self.CapturingLevelHandler(level)
849
 
            old_level = logger.level
850
 
            old_propagate = logger.propagate
851
 
            logger.addHandler(capturing_handler)
852
 
            logger.setLevel(level)
853
 
            logger.propagate = False
854
 
            try:
855
 
                yield capturing_handler.watcher
856
 
            finally:
857
 
                logger.propagate = old_propagate
858
 
                logger.removeHandler(capturing_handler)
859
 
                logger.setLevel(old_level)
860
 
            self.assertGreater(len(capturing_handler.watcher.records),
861
 
                               0)
862
 
 
863
 
        class CapturingLevelHandler(logging.Handler):
864
 
            def __init__(self, level, *args, **kwargs):
865
 
                logging.Handler.__init__(self, *args, **kwargs)
866
 
                self.watcher = self.LoggingWatcher([], [])
867
 
            def emit(self, record):
868
 
                self.watcher.records.append(record)
869
 
                self.watcher.output.append(self.format(record))
870
 
 
871
 
            LoggingWatcher = collections.namedtuple("LoggingWatcher",
872
 
                                                    ("records",
873
 
                                                     "output"))
874
 
 
875
 
 
876
 
class Test_string_to_delta(TestCaseWithAssertLogs):
 
835
class Test_string_to_delta(unittest.TestCase):
877
836
    def test_handles_basic_rfc3339(self):
878
837
        self.assertEqual(string_to_delta("PT0S"),
879
838
                         datetime.timedelta())
885
844
                         datetime.timedelta(0, 7200))
886
845
 
887
846
    def test_falls_back_to_pre_1_6_1_with_warning(self):
888
 
        with self.assertLogs(log, logging.WARNING):
889
 
            value = string_to_delta("2h")
 
847
        # assertLogs only exists in Python 3.4
 
848
        if hasattr(self, "assertLogs"):
 
849
            with self.assertLogs(log, logging.WARNING):
 
850
                value = string_to_delta("2h")
 
851
        else:
 
852
            class WarningFilter(logging.Filter):
 
853
                """Don't show, but record the presence of, warnings"""
 
854
                def filter(self, record):
 
855
                    is_warning = record.levelno >= logging.WARNING
 
856
                    self.found = is_warning or getattr(self, "found",
 
857
                                                       False)
 
858
                    return not is_warning
 
859
            warning_filter = WarningFilter()
 
860
            log.addFilter(warning_filter)
 
861
            try:
 
862
                value = string_to_delta("2h")
 
863
            finally:
 
864
                log.removeFilter(warning_filter)
 
865
            self.assertTrue(getattr(warning_filter, "found", False))
890
866
        self.assertEqual(value, datetime.timedelta(0, 7200))
891
867
 
892
868
 
931
907
    @contextlib.contextmanager
932
908
    def assertParseError(self):
933
909
        with self.assertRaises(SystemExit) as e:
934
 
            with self.redirect_stderr_to_devnull():
 
910
            with self.temporarily_suppress_stderr():
935
911
                yield
936
912
        # Exit code from argparse is guaranteed to be "2".  Reference:
937
913
        # https://docs.python.org/3/library
940
916
 
941
917
    @staticmethod
942
918
    @contextlib.contextmanager
943
 
    def redirect_stderr_to_devnull():
 
919
    def temporarily_suppress_stderr():
944
920
        null = os.open(os.path.devnull, os.O_RDWR)
945
921
        stderrcopy = os.dup(sys.stderr.fileno())
946
922
        os.dup2(null, sys.stderr.fileno())
955
931
    def check_option_syntax(self, options):
956
932
        check_option_syntax(self.parser, options)
957
933
 
958
 
    def test_actions_all_conflicts_with_verbose(self):
959
 
        for action, value in self.actions.items():
960
 
            options = self.parser.parse_args()
961
 
            setattr(options, action, value)
962
 
            options.all = True
963
 
            options.verbose = True
964
 
            with self.assertParseError():
965
 
                self.check_option_syntax(options)
966
 
 
967
 
    def test_actions_with_client_conflicts_with_verbose(self):
968
 
        for action, value in self.actions.items():
969
 
            options = self.parser.parse_args()
970
 
            setattr(options, action, value)
971
 
            options.verbose = True
972
 
            options.client = ["foo"]
 
934
    def test_actions_conflicts_with_verbose(self):
 
935
        for action, value in self.actions.items():
 
936
            options = self.parser.parse_args()
 
937
            setattr(options, action, value)
 
938
            options.verbose = True
973
939
            with self.assertParseError():
974
940
                self.check_option_syntax(options)
975
941
 
1001
967
            options.all = True
1002
968
            self.check_option_syntax(options)
1003
969
 
1004
 
    def test_any_action_is_ok_with_one_client(self):
1005
 
        for action, value in self.actions.items():
1006
 
            options = self.parser.parse_args()
1007
 
            setattr(options, action, value)
1008
 
            options.client = ["foo"]
 
970
    def test_is_enabled_fails_without_client(self):
 
971
        options = self.parser.parse_args()
 
972
        options.is_enabled = True
 
973
        with self.assertParseError():
1009
974
            self.check_option_syntax(options)
1010
975
 
1011
 
    def test_one_client_with_all_actions_except_is_enabled(self):
 
976
    def test_is_enabled_works_with_one_client(self):
1012
977
        options = self.parser.parse_args()
1013
 
        for action, value in self.actions.items():
1014
 
            if action == "is_enabled":
1015
 
                continue
1016
 
            setattr(options, action, value)
 
978
        options.is_enabled = True
1017
979
        options.client = ["foo"]
1018
980
        self.check_option_syntax(options)
1019
981
 
1020
 
    def test_two_clients_with_all_actions_except_is_enabled(self):
1021
 
        options = self.parser.parse_args()
1022
 
        for action, value in self.actions.items():
1023
 
            if action == "is_enabled":
1024
 
                continue
1025
 
            setattr(options, action, value)
1026
 
        options.client = ["foo", "barbar"]
1027
 
        self.check_option_syntax(options)
1028
 
 
1029
 
    def test_two_clients_are_ok_with_actions_except_is_enabled(self):
1030
 
        for action, value in self.actions.items():
1031
 
            if action == "is_enabled":
1032
 
                continue
1033
 
            options = self.parser.parse_args()
1034
 
            setattr(options, action, value)
1035
 
            options.client = ["foo", "barbar"]
1036
 
            self.check_option_syntax(options)
1037
 
 
1038
 
    def test_is_enabled_fails_without_client(self):
1039
 
        options = self.parser.parse_args()
1040
 
        options.is_enabled = True
1041
 
        with self.assertParseError():
1042
 
            self.check_option_syntax(options)
1043
 
 
1044
982
    def test_is_enabled_fails_with_two_clients(self):
1045
983
        options = self.parser.parse_args()
1046
984
        options.is_enabled = True
1060
998
                self.check_option_syntax(options)
1061
999
 
1062
1000
 
1063
 
class Test_get_mandos_dbus_object(TestCaseWithAssertLogs):
 
1001
class Test_get_mandos_dbus_object(unittest.TestCase):
1064
1002
    def test_calls_and_returns_get_object_on_bus(self):
1065
1003
        class MockBus(object):
1066
1004
            called = False
1081
1019
            def get_object(self, busname, dbus_path):
1082
1020
                raise dbus.exceptions.DBusException("Test")
1083
1021
 
1084
 
        with self.assertLogs(log, logging.CRITICAL):
1085
 
            with self.assertRaises(SystemExit) as e:
1086
 
                bus = get_mandos_dbus_object(bus=MockBusFailing())
1087
 
 
 
1022
        # assertLogs only exists in Python 3.4
 
1023
        if hasattr(self, "assertLogs"):
 
1024
            with self.assertLogs(log, logging.CRITICAL):
 
1025
                with self.assertRaises(SystemExit) as e:
 
1026
                    bus = get_mandos_dbus_object(bus=MockBus())
 
1027
        else:
 
1028
            critical_filter = self.CriticalFilter()
 
1029
            log.addFilter(critical_filter)
 
1030
            try:
 
1031
                with self.assertRaises(SystemExit) as e:
 
1032
                    get_mandos_dbus_object(bus=MockBusFailing())
 
1033
            finally:
 
1034
                log.removeFilter(critical_filter)
 
1035
            self.assertTrue(critical_filter.found)
1088
1036
        if isinstance(e.exception.code, int):
1089
1037
            self.assertNotEqual(e.exception.code, 0)
1090
1038
        else:
1091
1039
            self.assertIsNotNone(e.exception.code)
1092
1040
 
1093
 
 
1094
 
class Test_get_managed_objects(TestCaseWithAssertLogs):
1095
 
    def test_calls_and_returns_GetManagedObjects(self):
1096
 
        managed_objects = {"/clients/foo": { "Name": "foo"}}
1097
 
        class MockObjectManager(object):
1098
 
            def GetManagedObjects(self):
1099
 
                return managed_objects
1100
 
        retval = get_managed_objects(MockObjectManager())
1101
 
        self.assertDictEqual(managed_objects, retval)
1102
 
 
1103
 
    def test_logs_and_exits_on_dbus_error(self):
1104
 
        dbus_logger = logging.getLogger("dbus.proxies")
1105
 
 
1106
 
        class MockObjectManagerFailing(object):
1107
 
            def GetManagedObjects(self):
1108
 
                dbus_logger.error("Test")
1109
 
                raise dbus.exceptions.DBusException("Test")
1110
 
 
1111
 
        class CountingHandler(logging.Handler):
1112
 
            count = 0
1113
 
            def emit(self, record):
1114
 
                self.count += 1
1115
 
 
1116
 
        counting_handler = CountingHandler()
1117
 
 
1118
 
        dbus_logger.addHandler(counting_handler)
1119
 
 
 
1041
    class CriticalFilter(logging.Filter):
 
1042
        """Don't show, but register, critical messages"""
 
1043
        found = False
 
1044
        def filter(self, record):
 
1045
            is_critical = record.levelno >= logging.CRITICAL
 
1046
            self.found = is_critical or self.found
 
1047
            return not is_critical
 
1048
 
 
1049
 
 
1050
class Test_SilenceLogger(unittest.TestCase):
 
1051
    loggername = "mandos-ctl.Test_SilenceLogger"
 
1052
    log = logging.getLogger(loggername)
 
1053
    log.propagate = False
 
1054
    log.addHandler(logging.NullHandler())
 
1055
 
 
1056
    def setUp(self):
 
1057
        self.counting_filter = self.CountingFilter()
 
1058
 
 
1059
    class CountingFilter(logging.Filter):
 
1060
        "Count number of records"
 
1061
        count = 0
 
1062
        def filter(self, record):
 
1063
            self.count += 1
 
1064
            return True
 
1065
 
 
1066
    def test_should_filter_records_only_when_active(self):
1120
1067
        try:
1121
 
            with self.assertLogs(log, logging.CRITICAL) as watcher:
1122
 
                with self.assertRaises(SystemExit) as e:
1123
 
                    get_managed_objects(MockObjectManagerFailing())
 
1068
            with SilenceLogger(self.loggername):
 
1069
                self.log.addFilter(self.counting_filter)
 
1070
                self.log.info("Filtered log message 1")
 
1071
            self.log.info("Non-filtered message 2")
 
1072
            self.log.info("Non-filtered message 3")
1124
1073
        finally:
1125
 
            dbus_logger.removeFilter(counting_handler)
1126
 
 
1127
 
        # Make sure the dbus logger was suppressed
1128
 
        self.assertEqual(counting_handler.count, 0)
1129
 
 
1130
 
        # Test that the dbus_logger still works
1131
 
        with self.assertLogs(dbus_logger, logging.ERROR):
1132
 
            dbus_logger.error("Test")
1133
 
 
1134
 
        if isinstance(e.exception.code, int):
1135
 
            self.assertNotEqual(e.exception.code, 0)
1136
 
        else:
1137
 
            self.assertIsNotNone(e.exception.code)
 
1074
            self.log.removeFilter(self.counting_filter)
 
1075
        self.assertEqual(self.counting_filter.count, 2)
1138
1076
 
1139
1077
 
1140
1078
class Test_commands_from_options(unittest.TestCase):