/mandos/trunk

To get this branch, use:
bzr branch http://bzr.recompile.se/loggerhead/mandos/trunk

« back to all changes in this revision

Viewing changes to mandos-ctl

  • Committer: Teddy Hogeborn
  • Date: 2019-03-16 17:11:12 UTC
  • Revision ID: teddy@recompile.se-20190316171112-0tpfk9ved7a2a5x2
mandos-ctl: Refactor and add a few more tests

* mandos-ctl (main): Remove comment and add empty line.
  (rfc3339_duration_to_delta): Change ValueError exception message to
                               quote the value with \"{}\" instead of
                               using {!r}, so that Python 2 and Python 3
                               output is the same (no u'' prefix).
  (Test_check_option_syntax.temporarily_suppress_stderr): Rename to
                                         "redirect_stderr_to_devnull".

  (Test_check_option_syntax
  .test_actions_except_is_enabled_are_ok_with_two_clients): Rename to
             "test_two_clients_are_ok_with_actions_except_is_enabled".
  (Test_check_option_syntax
   .test_one_client_with_all_actions_except_is_enabled): New.
  (Test_check_option_syntax
   .test_two_clients_with_all_actions_except_is_enabled): New.

Show diffs side-by-side

added added

removed removed

Lines of Context:
93
93
    if options.debug:
94
94
        log.setLevel(logging.DEBUG)
95
95
 
96
 
    try:
97
 
        bus = dbus.SystemBus()
98
 
        log.debug("D-Bus: Connect to: (busname=%r, path=%r)",
99
 
                  dbus_busname, server_dbus_path)
100
 
        mandos_dbus_objc = bus.get_object(dbus_busname,
101
 
                                          server_dbus_path)
102
 
    except dbus.exceptions.DBusException:
103
 
        log.critical("Could not connect to Mandos server")
104
 
        sys.exit(1)
105
 
 
106
 
    mandos_serv = dbus.Interface(mandos_dbus_objc,
107
 
                                 dbus_interface=server_dbus_interface)
 
96
    bus = dbus.SystemBus()
 
97
 
 
98
    mandos_dbus_object = get_mandos_dbus_object(bus)
 
99
 
 
100
    mandos_serv = dbus.Interface(
 
101
        mandos_dbus_object, dbus_interface=server_dbus_interface)
108
102
    mandos_serv_object_manager = dbus.Interface(
109
 
        mandos_dbus_objc, dbus_interface=dbus.OBJECT_MANAGER_IFACE)
110
 
 
111
 
    # Filter out log message from dbus module
112
 
    dbus_logger = logging.getLogger("dbus.proxies")
113
 
    class NullFilter(logging.Filter):
114
 
        def filter(self, record):
115
 
            return False
116
 
    dbus_filter = NullFilter()
117
 
    try:
118
 
        dbus_logger.addFilter(dbus_filter)
119
 
        log.debug("D-Bus: %s:%s:%s.GetManagedObjects()", dbus_busname,
120
 
                  server_dbus_path, dbus.OBJECT_MANAGER_IFACE)
121
 
        mandos_clients = {path: ifs_and_props[client_dbus_interface]
122
 
                          for path, ifs_and_props in
123
 
                          mandos_serv_object_manager
124
 
                          .GetManagedObjects().items()
125
 
                          if client_dbus_interface in ifs_and_props}
126
 
    except dbus.exceptions.DBusException as e:
127
 
        log.critical("Failed to access Mandos server through D-Bus:"
128
 
                     "\n%s", e)
129
 
        sys.exit(1)
130
 
    finally:
131
 
        # restore dbus logger
132
 
        dbus_logger.removeFilter(dbus_filter)
133
 
 
134
 
    # Compile dict of (clients: properties) to process
135
 
    clients = {}
136
 
 
 
103
        mandos_dbus_object, dbus_interface=dbus.OBJECT_MANAGER_IFACE)
 
104
 
 
105
    managed_objects = get_managed_objects(mandos_serv_object_manager)
 
106
 
 
107
    all_clients = {}
 
108
    for path, ifs_and_props in managed_objects.items():
 
109
        try:
 
110
            all_clients[path] = ifs_and_props[client_dbus_interface]
 
111
        except KeyError:
 
112
            pass
 
113
 
 
114
    # Compile dict of (clientpath: properties) to process
137
115
    if not clientnames:
138
 
        clients = {objpath: properties
139
 
                   for objpath, properties in mandos_clients.items()}
 
116
        clients = all_clients
140
117
    else:
 
118
        clients = {}
141
119
        for name in clientnames:
142
 
            for objpath, properties in mandos_clients.items():
 
120
            for objpath, properties in all_clients.items():
143
121
                if properties["Name"] == name:
144
122
                    clients[objpath] = properties
145
123
                    break
147
125
                log.critical("Client not found on server: %r", name)
148
126
                sys.exit(1)
149
127
 
150
 
    # Run all commands on clients
151
128
    commands = commands_from_options(options)
 
129
 
152
130
    for command in commands:
153
131
        command.run(clients, bus, mandos_serv)
154
132
 
254
232
    >>> rfc3339_duration_to_delta("")
255
233
    Traceback (most recent call last):
256
234
    ...
257
 
    ValueError: Invalid RFC 3339 duration: u''
 
235
    ValueError: Invalid RFC 3339 duration: ""
258
236
    >>> # Must start with "P":
259
237
    >>> rfc3339_duration_to_delta("1D")
260
238
    Traceback (most recent call last):
261
239
    ...
262
 
    ValueError: Invalid RFC 3339 duration: u'1D'
 
240
    ValueError: Invalid RFC 3339 duration: "1D"
263
241
    >>> # Must use correct order
264
242
    >>> rfc3339_duration_to_delta("PT1S2M")
265
243
    Traceback (most recent call last):
266
244
    ...
267
 
    ValueError: Invalid RFC 3339 duration: u'PT1S2M'
 
245
    ValueError: Invalid RFC 3339 duration: "PT1S2M"
268
246
    >>> # Time needs time marker
269
247
    >>> rfc3339_duration_to_delta("P1H2S")
270
248
    Traceback (most recent call last):
271
249
    ...
272
 
    ValueError: Invalid RFC 3339 duration: u'P1H2S'
 
250
    ValueError: Invalid RFC 3339 duration: "P1H2S"
273
251
    >>> # Weeks can not be combined with anything else
274
252
    >>> rfc3339_duration_to_delta("P1D2W")
275
253
    Traceback (most recent call last):
276
254
    ...
277
 
    ValueError: Invalid RFC 3339 duration: u'P1D2W'
 
255
    ValueError: Invalid RFC 3339 duration: "P1D2W"
278
256
    >>> rfc3339_duration_to_delta("P2W2H")
279
257
    Traceback (most recent call last):
280
258
    ...
281
 
    ValueError: Invalid RFC 3339 duration: u'P2W2H'
 
259
    ValueError: Invalid RFC 3339 duration: "P2W2H"
282
260
    """
283
261
 
284
262
    # Parsing an RFC 3339 duration with regular expressions is not
355
333
                break
356
334
        else:
357
335
            # No currently valid tokens were found
358
 
            raise ValueError("Invalid RFC 3339 duration: {!r}"
 
336
            raise ValueError("Invalid RFC 3339 duration: \"{}\""
359
337
                             .format(duration))
360
338
    # End token found
361
339
    return value
446
424
        options.remove = True
447
425
 
448
426
 
 
427
def get_mandos_dbus_object(bus):
 
428
    log.debug("D-Bus: Connect to: (busname=%r, path=%r)",
 
429
              dbus_busname, server_dbus_path)
 
430
    with if_dbus_exception_log_with_exception_and_exit(
 
431
            "Could not connect to Mandos server: %s"):
 
432
        mandos_dbus_object = bus.get_object(dbus_busname,
 
433
                                            server_dbus_path)
 
434
    return mandos_dbus_object
 
435
 
 
436
 
 
437
@contextlib.contextmanager
 
438
def if_dbus_exception_log_with_exception_and_exit(*args, **kwargs):
 
439
    try:
 
440
        yield
 
441
    except dbus.exceptions.DBusException as e:
 
442
        log.critical(*(args + (e,)), **kwargs)
 
443
        sys.exit(1)
 
444
 
 
445
 
 
446
def get_managed_objects(object_manager):
 
447
    log.debug("D-Bus: %s:%s:%s.GetManagedObjects()", dbus_busname,
 
448
              server_dbus_path, dbus.OBJECT_MANAGER_IFACE)
 
449
    with if_dbus_exception_log_with_exception_and_exit(
 
450
            "Failed to access Mandos server through D-Bus:\n%s"):
 
451
        with SilenceLogger("dbus.proxies"):
 
452
            managed_objects = object_manager.GetManagedObjects()
 
453
    return managed_objects
 
454
 
 
455
 
 
456
class SilenceLogger(object):
 
457
    "Simple context manager to silence a particular logger"
 
458
    def __init__(self, loggername):
 
459
        self.logger = logging.getLogger(loggername)
 
460
 
 
461
    def __enter__(self):
 
462
        self.logger.addFilter(self.nullfilter)
 
463
        return self
 
464
 
 
465
    class NullFilter(logging.Filter):
 
466
        def filter(self, record):
 
467
            return False
 
468
 
 
469
    nullfilter = NullFilter()
 
470
 
 
471
    def __exit__(self, exc_type, exc_val, exc_tb):
 
472
        self.logger.removeFilter(self.nullfilter)
 
473
 
 
474
 
449
475
def commands_from_options(options):
450
476
 
451
477
    commands = []
568
594
        self.mandos.RemoveClient(client.__dbus_object_path__)
569
595
 
570
596
 
571
 
class PrintCmd(Command):
572
 
    """Abstract class for commands printing client details"""
 
597
class OutputCmd(Command):
 
598
    """Abstract class for commands outputting client details"""
573
599
    all_keywords = ("Name", "Enabled", "Timeout", "LastCheckedOK",
574
600
                    "Created", "Interval", "Host", "KeyID",
575
601
                    "Fingerprint", "CheckerRunning", "LastEnabled",
577
603
                    "LastApprovalRequest", "ApprovalDelay",
578
604
                    "ApprovalDuration", "Checker", "ExtendedTimeout",
579
605
                    "Expires", "LastCheckerStatus")
 
606
 
580
607
    def run(self, clients, bus=None, mandos=None):
581
608
        print(self.output(clients.values()))
 
609
 
582
610
    def output(self, clients):
583
611
        raise NotImplementedError()
584
612
 
585
613
 
586
 
class DumpJSONCmd(PrintCmd):
 
614
class DumpJSONCmd(OutputCmd):
587
615
    def output(self, clients):
588
616
        data = {client["Name"]:
589
617
                {key: self.dbus_boolean_to_bool(client[key])
590
618
                 for key in self.all_keywords}
591
 
                for client in clients.values()}
 
619
                for client in clients}
592
620
        return json.dumps(data, indent=4, separators=(',', ': '))
 
621
 
593
622
    @staticmethod
594
623
    def dbus_boolean_to_bool(value):
595
624
        if isinstance(value, dbus.Boolean):
597
626
        return value
598
627
 
599
628
 
600
 
class PrintTableCmd(PrintCmd):
 
629
class PrintTableCmd(OutputCmd):
601
630
    def __init__(self, verbose=False):
602
631
        self.verbose = verbose
603
632
 
633
662
            "LastCheckerStatus": "Last Checker Status",
634
663
        }
635
664
 
636
 
        def __init__(self, clients, keywords, tableheaders=None):
 
665
        def __init__(self, clients, keywords):
637
666
            self.clients = clients
638
667
            self.keywords = keywords
639
 
            if tableheaders is not None:
640
 
                self.tableheaders = tableheaders
641
668
 
642
669
        def __str__(self):
643
670
            return "\n".join(self.rows())
696
723
 
697
724
class PropertyCmd(Command):
698
725
    """Abstract class for Actions for setting one client property"""
 
726
 
699
727
    def run_on_one_client(self, client, properties):
700
728
        """Set the Client's D-Bus property"""
701
729
        log.debug("D-Bus: %s:%s:%s.Set(%r, %r, %r)", dbus_busname,
707
735
        client.Set(client_dbus_interface, self.propname,
708
736
                   self.value_to_set,
709
737
                   dbus_interface=dbus.PROPERTIES_IFACE)
 
738
 
710
739
    @property
711
740
    def propname(self):
712
741
        raise NotImplementedError()
763
792
 
764
793
class SetSecretCmd(PropertyValueCmd):
765
794
    propname = "Secret"
 
795
 
766
796
    @property
767
797
    def value_to_set(self):
768
798
        return self._vts
 
799
 
769
800
    @value_to_set.setter
770
801
    def value_to_set(self, value):
771
802
        """When setting, read data from supplied file object"""
776
807
class MillisecondsPropertyValueArgumentCmd(PropertyValueCmd):
777
808
    """Abstract class for PropertyValueCmd taking a value argument as
778
809
a datetime.timedelta() but should store it as milliseconds."""
 
810
 
779
811
    @property
780
812
    def value_to_set(self):
781
813
        return self._vts
 
814
 
782
815
    @value_to_set.setter
783
816
    def value_to_set(self, value):
784
817
        """When setting, convert value from a datetime.timedelta"""
806
839
 
807
840
 
808
841
 
809
 
class Test_string_to_delta(unittest.TestCase):
 
842
class TestCaseWithAssertLogs(unittest.TestCase):
 
843
    """unittest.TestCase.assertLogs only exists in Python 3.4"""
 
844
 
 
845
    if not hasattr(unittest.TestCase, "assertLogs"):
 
846
        @contextlib.contextmanager
 
847
        def assertLogs(self, logger, level=logging.INFO):
 
848
            capturing_handler = self.CapturingLevelHandler(level)
 
849
            old_level = logger.level
 
850
            old_propagate = logger.propagate
 
851
            logger.addHandler(capturing_handler)
 
852
            logger.setLevel(level)
 
853
            logger.propagate = False
 
854
            try:
 
855
                yield capturing_handler.watcher
 
856
            finally:
 
857
                logger.propagate = old_propagate
 
858
                logger.removeHandler(capturing_handler)
 
859
                logger.setLevel(old_level)
 
860
            self.assertGreater(len(capturing_handler.watcher.records),
 
861
                               0)
 
862
 
 
863
        class CapturingLevelHandler(logging.Handler):
 
864
            def __init__(self, level, *args, **kwargs):
 
865
                logging.Handler.__init__(self, *args, **kwargs)
 
866
                self.watcher = self.LoggingWatcher([], [])
 
867
            def emit(self, record):
 
868
                self.watcher.records.append(record)
 
869
                self.watcher.output.append(self.format(record))
 
870
 
 
871
            LoggingWatcher = collections.namedtuple("LoggingWatcher",
 
872
                                                    ("records",
 
873
                                                     "output"))
 
874
 
 
875
 
 
876
class Test_string_to_delta(TestCaseWithAssertLogs):
810
877
    def test_handles_basic_rfc3339(self):
811
878
        self.assertEqual(string_to_delta("PT0S"),
812
879
                         datetime.timedelta())
816
883
                         datetime.timedelta(0, 1))
817
884
        self.assertEqual(string_to_delta("PT2H"),
818
885
                         datetime.timedelta(0, 7200))
 
886
 
819
887
    def test_falls_back_to_pre_1_6_1_with_warning(self):
820
 
        # assertLogs only exists in Python 3.4
821
 
        if hasattr(self, "assertLogs"):
822
 
            with self.assertLogs(log, logging.WARNING):
823
 
                value = string_to_delta("2h")
824
 
        else:
825
 
            class WarningFilter(logging.Filter):
826
 
                """Don't show, but record the presence of, warnings"""
827
 
                def filter(self, record):
828
 
                    is_warning = record.levelno >= logging.WARNING
829
 
                    self.found = is_warning or getattr(self, "found",
830
 
                                                       False)
831
 
                    return not is_warning
832
 
            warning_filter = WarningFilter()
833
 
            log.addFilter(warning_filter)
834
 
            try:
835
 
                value = string_to_delta("2h")
836
 
            finally:
837
 
                log.removeFilter(warning_filter)
838
 
            self.assertTrue(getattr(warning_filter, "found", False))
 
888
        with self.assertLogs(log, logging.WARNING):
 
889
            value = string_to_delta("2h")
839
890
        self.assertEqual(value, datetime.timedelta(0, 7200))
840
891
 
841
892
 
842
893
class Test_check_option_syntax(unittest.TestCase):
 
894
    def setUp(self):
 
895
        self.parser = argparse.ArgumentParser()
 
896
        add_command_line_options(self.parser)
 
897
 
 
898
    def test_actions_requires_client_or_all(self):
 
899
        for action, value in self.actions.items():
 
900
            options = self.parser.parse_args()
 
901
            setattr(options, action, value)
 
902
            with self.assertParseError():
 
903
                self.check_option_syntax(options)
 
904
 
843
905
    # This mostly corresponds to the definition from has_actions() in
844
906
    # check_option_syntax()
845
907
    actions = {
866
928
        "deny": True,
867
929
    }
868
930
 
869
 
    def setUp(self):
870
 
        self.parser = argparse.ArgumentParser()
871
 
        add_command_line_options(self.parser)
872
 
 
873
931
    @contextlib.contextmanager
874
932
    def assertParseError(self):
875
933
        with self.assertRaises(SystemExit) as e:
876
 
            with self.temporarily_suppress_stderr():
 
934
            with self.redirect_stderr_to_devnull():
877
935
                yield
878
936
        # Exit code from argparse is guaranteed to be "2".  Reference:
879
937
        # https://docs.python.org/3/library
882
940
 
883
941
    @staticmethod
884
942
    @contextlib.contextmanager
885
 
    def temporarily_suppress_stderr():
 
943
    def redirect_stderr_to_devnull():
886
944
        null = os.open(os.path.devnull, os.O_RDWR)
887
945
        stderrcopy = os.dup(sys.stderr.fileno())
888
946
        os.dup2(null, sys.stderr.fileno())
897
955
    def check_option_syntax(self, options):
898
956
        check_option_syntax(self.parser, options)
899
957
 
900
 
    def test_actions_requires_client_or_all(self):
 
958
    def test_actions_all_conflicts_with_verbose(self):
901
959
        for action, value in self.actions.items():
902
960
            options = self.parser.parse_args()
903
961
            setattr(options, action, value)
 
962
            options.all = True
 
963
            options.verbose = True
904
964
            with self.assertParseError():
905
965
                self.check_option_syntax(options)
906
966
 
907
 
    def test_actions_conflicts_with_verbose(self):
 
967
    def test_actions_with_client_conflicts_with_verbose(self):
908
968
        for action, value in self.actions.items():
909
969
            options = self.parser.parse_args()
910
970
            setattr(options, action, value)
911
971
            options.verbose = True
 
972
            options.client = ["foo"]
912
973
            with self.assertParseError():
913
974
                self.check_option_syntax(options)
914
975
 
940
1001
            options.all = True
941
1002
            self.check_option_syntax(options)
942
1003
 
 
1004
    def test_any_action_is_ok_with_one_client(self):
 
1005
        for action, value in self.actions.items():
 
1006
            options = self.parser.parse_args()
 
1007
            setattr(options, action, value)
 
1008
            options.client = ["foo"]
 
1009
            self.check_option_syntax(options)
 
1010
 
 
1011
    def test_one_client_with_all_actions_except_is_enabled(self):
 
1012
        options = self.parser.parse_args()
 
1013
        for action, value in self.actions.items():
 
1014
            if action == "is_enabled":
 
1015
                continue
 
1016
            setattr(options, action, value)
 
1017
        options.client = ["foo"]
 
1018
        self.check_option_syntax(options)
 
1019
 
 
1020
    def test_two_clients_with_all_actions_except_is_enabled(self):
 
1021
        options = self.parser.parse_args()
 
1022
        for action, value in self.actions.items():
 
1023
            if action == "is_enabled":
 
1024
                continue
 
1025
            setattr(options, action, value)
 
1026
        options.client = ["foo", "barbar"]
 
1027
        self.check_option_syntax(options)
 
1028
 
 
1029
    def test_two_clients_are_ok_with_actions_except_is_enabled(self):
 
1030
        for action, value in self.actions.items():
 
1031
            if action == "is_enabled":
 
1032
                continue
 
1033
            options = self.parser.parse_args()
 
1034
            setattr(options, action, value)
 
1035
            options.client = ["foo", "barbar"]
 
1036
            self.check_option_syntax(options)
 
1037
 
943
1038
    def test_is_enabled_fails_without_client(self):
944
1039
        options = self.parser.parse_args()
945
1040
        options.is_enabled = True
946
1041
        with self.assertParseError():
947
1042
            self.check_option_syntax(options)
948
1043
 
949
 
    def test_is_enabled_works_with_one_client(self):
950
 
        options = self.parser.parse_args()
951
 
        options.is_enabled = True
952
 
        options.client = ["foo"]
953
 
        self.check_option_syntax(options)
954
 
 
955
1044
    def test_is_enabled_fails_with_two_clients(self):
956
1045
        options = self.parser.parse_args()
957
1046
        options.is_enabled = True
971
1060
                self.check_option_syntax(options)
972
1061
 
973
1062
 
974
 
class Test_command_from_options(unittest.TestCase):
 
1063
class Test_get_mandos_dbus_object(TestCaseWithAssertLogs):
 
1064
    def test_calls_and_returns_get_object_on_bus(self):
 
1065
        class MockBus(object):
 
1066
            called = False
 
1067
            def get_object(mockbus_self, busname, dbus_path):
 
1068
                # Note that "self" is still the testcase instance,
 
1069
                # this MockBus instance is in "mockbus_self".
 
1070
                self.assertEqual(busname, dbus_busname)
 
1071
                self.assertEqual(dbus_path, server_dbus_path)
 
1072
                mockbus_self.called = True
 
1073
                return mockbus_self
 
1074
 
 
1075
        mockbus = get_mandos_dbus_object(bus=MockBus())
 
1076
        self.assertIsInstance(mockbus, MockBus)
 
1077
        self.assertTrue(mockbus.called)
 
1078
 
 
1079
    def test_logs_and_exits_on_dbus_error(self):
 
1080
        class MockBusFailing(object):
 
1081
            def get_object(self, busname, dbus_path):
 
1082
                raise dbus.exceptions.DBusException("Test")
 
1083
 
 
1084
        with self.assertLogs(log, logging.CRITICAL):
 
1085
            with self.assertRaises(SystemExit) as e:
 
1086
                bus = get_mandos_dbus_object(bus=MockBusFailing())
 
1087
 
 
1088
        if isinstance(e.exception.code, int):
 
1089
            self.assertNotEqual(e.exception.code, 0)
 
1090
        else:
 
1091
            self.assertIsNotNone(e.exception.code)
 
1092
 
 
1093
 
 
1094
class Test_get_managed_objects(TestCaseWithAssertLogs):
 
1095
    def test_calls_and_returns_GetManagedObjects(self):
 
1096
        managed_objects = {"/clients/foo": { "Name": "foo"}}
 
1097
        class MockObjectManager(object):
 
1098
            def GetManagedObjects(self):
 
1099
                return managed_objects
 
1100
        retval = get_managed_objects(MockObjectManager())
 
1101
        self.assertDictEqual(managed_objects, retval)
 
1102
 
 
1103
    def test_logs_and_exits_on_dbus_error(self):
 
1104
        dbus_logger = logging.getLogger("dbus.proxies")
 
1105
 
 
1106
        class MockObjectManagerFailing(object):
 
1107
            def GetManagedObjects(self):
 
1108
                dbus_logger.error("Test")
 
1109
                raise dbus.exceptions.DBusException("Test")
 
1110
 
 
1111
        class CountingHandler(logging.Handler):
 
1112
            count = 0
 
1113
            def emit(self, record):
 
1114
                self.count += 1
 
1115
 
 
1116
        counting_handler = CountingHandler()
 
1117
 
 
1118
        dbus_logger.addHandler(counting_handler)
 
1119
 
 
1120
        try:
 
1121
            with self.assertLogs(log, logging.CRITICAL) as watcher:
 
1122
                with self.assertRaises(SystemExit) as e:
 
1123
                    get_managed_objects(MockObjectManagerFailing())
 
1124
        finally:
 
1125
            dbus_logger.removeFilter(counting_handler)
 
1126
 
 
1127
        # Make sure the dbus logger was suppressed
 
1128
        self.assertEqual(counting_handler.count, 0)
 
1129
 
 
1130
        # Test that the dbus_logger still works
 
1131
        with self.assertLogs(dbus_logger, logging.ERROR):
 
1132
            dbus_logger.error("Test")
 
1133
 
 
1134
        if isinstance(e.exception.code, int):
 
1135
            self.assertNotEqual(e.exception.code, 0)
 
1136
        else:
 
1137
            self.assertIsNotNone(e.exception.code)
 
1138
 
 
1139
 
 
1140
class Test_commands_from_options(unittest.TestCase):
975
1141
    def setUp(self):
976
1142
        self.parser = argparse.ArgumentParser()
977
1143
        add_command_line_options(self.parser)
 
1144
 
 
1145
    def test_is_enabled(self):
 
1146
        self.assert_command_from_args(["--is-enabled", "foo"],
 
1147
                                      IsEnabledCmd)
 
1148
 
978
1149
    def assert_command_from_args(self, args, command_cls,
979
1150
                                 **cmd_attrs):
980
1151
        """Assert that parsing ARGS should result in an instance of
987
1158
        self.assertIsInstance(command, command_cls)
988
1159
        for key, value in cmd_attrs.items():
989
1160
            self.assertEqual(getattr(command, key), value)
990
 
    def test_print_table(self):
991
 
        self.assert_command_from_args([], PrintTableCmd,
992
 
                                      verbose=False)
993
 
 
994
 
    def test_print_table_verbose(self):
995
 
        self.assert_command_from_args(["--verbose"], PrintTableCmd,
996
 
                                      verbose=True)
997
 
 
998
 
    def test_print_table_verbose_short(self):
999
 
        self.assert_command_from_args(["-v"], PrintTableCmd,
1000
 
                                      verbose=True)
 
1161
 
 
1162
    def test_is_enabled_short(self):
 
1163
        self.assert_command_from_args(["-V", "foo"], IsEnabledCmd)
 
1164
 
 
1165
    def test_approve(self):
 
1166
        self.assert_command_from_args(["--approve", "foo"],
 
1167
                                      ApproveCmd)
 
1168
 
 
1169
    def test_approve_short(self):
 
1170
        self.assert_command_from_args(["-A", "foo"], ApproveCmd)
 
1171
 
 
1172
    def test_deny(self):
 
1173
        self.assert_command_from_args(["--deny", "foo"], DenyCmd)
 
1174
 
 
1175
    def test_deny_short(self):
 
1176
        self.assert_command_from_args(["-D", "foo"], DenyCmd)
 
1177
 
 
1178
    def test_remove(self):
 
1179
        self.assert_command_from_args(["--remove", "foo"],
 
1180
                                      RemoveCmd)
 
1181
 
 
1182
    def test_deny_before_remove(self):
 
1183
        options = self.parser.parse_args(["--deny", "--remove",
 
1184
                                          "foo"])
 
1185
        check_option_syntax(self.parser, options)
 
1186
        commands = commands_from_options(options)
 
1187
        self.assertEqual(len(commands), 2)
 
1188
        self.assertIsInstance(commands[0], DenyCmd)
 
1189
        self.assertIsInstance(commands[1], RemoveCmd)
 
1190
 
 
1191
    def test_deny_before_remove_reversed(self):
 
1192
        options = self.parser.parse_args(["--remove", "--deny",
 
1193
                                          "--all"])
 
1194
        check_option_syntax(self.parser, options)
 
1195
        commands = commands_from_options(options)
 
1196
        self.assertEqual(len(commands), 2)
 
1197
        self.assertIsInstance(commands[0], DenyCmd)
 
1198
        self.assertIsInstance(commands[1], RemoveCmd)
 
1199
 
 
1200
    def test_remove_short(self):
 
1201
        self.assert_command_from_args(["-r", "foo"], RemoveCmd)
 
1202
 
 
1203
    def test_dump_json(self):
 
1204
        self.assert_command_from_args(["--dump-json"], DumpJSONCmd)
1001
1205
 
1002
1206
    def test_enable(self):
1003
1207
        self.assert_command_from_args(["--enable", "foo"], EnableCmd)
1027
1231
        self.assert_command_from_args(["--stop-checker", "foo"],
1028
1232
                                      StopCheckerCmd)
1029
1233
 
1030
 
    def test_remove(self):
1031
 
        self.assert_command_from_args(["--remove", "foo"],
1032
 
                                      RemoveCmd)
 
1234
    def test_approve_by_default(self):
 
1235
        self.assert_command_from_args(["--approve-by-default", "foo"],
 
1236
                                      ApproveByDefaultCmd)
1033
1237
 
1034
 
    def test_remove_short(self):
1035
 
        self.assert_command_from_args(["-r", "foo"], RemoveCmd)
 
1238
    def test_deny_by_default(self):
 
1239
        self.assert_command_from_args(["--deny-by-default", "foo"],
 
1240
                                      DenyByDefaultCmd)
1036
1241
 
1037
1242
    def test_checker(self):
1038
1243
        self.assert_command_from_args(["--checker", ":", "foo"],
1046
1251
        self.assert_command_from_args(["-c", ":", "foo"],
1047
1252
                                      SetCheckerCmd, value_to_set=":")
1048
1253
 
 
1254
    def test_host(self):
 
1255
        self.assert_command_from_args(["--host", "foo.example.org",
 
1256
                                       "foo"], SetHostCmd,
 
1257
                                      value_to_set="foo.example.org")
 
1258
 
 
1259
    def test_host_short(self):
 
1260
        self.assert_command_from_args(["-H", "foo.example.org",
 
1261
                                       "foo"], SetHostCmd,
 
1262
                                      value_to_set="foo.example.org")
 
1263
 
 
1264
    def test_secret_devnull(self):
 
1265
        self.assert_command_from_args(["--secret", os.path.devnull,
 
1266
                                       "foo"], SetSecretCmd,
 
1267
                                      value_to_set=b"")
 
1268
 
 
1269
    def test_secret_tempfile(self):
 
1270
        with tempfile.NamedTemporaryFile(mode="r+b") as f:
 
1271
            value = b"secret\0xyzzy\nbar"
 
1272
            f.write(value)
 
1273
            f.seek(0)
 
1274
            self.assert_command_from_args(["--secret", f.name,
 
1275
                                           "foo"], SetSecretCmd,
 
1276
                                          value_to_set=value)
 
1277
 
 
1278
    def test_secret_devnull_short(self):
 
1279
        self.assert_command_from_args(["-s", os.path.devnull, "foo"],
 
1280
                                      SetSecretCmd, value_to_set=b"")
 
1281
 
 
1282
    def test_secret_tempfile_short(self):
 
1283
        with tempfile.NamedTemporaryFile(mode="r+b") as f:
 
1284
            value = b"secret\0xyzzy\nbar"
 
1285
            f.write(value)
 
1286
            f.seek(0)
 
1287
            self.assert_command_from_args(["-s", f.name, "foo"],
 
1288
                                          SetSecretCmd,
 
1289
                                          value_to_set=value)
 
1290
 
1049
1291
    def test_timeout(self):
1050
1292
        self.assert_command_from_args(["--timeout", "PT5M", "foo"],
1051
1293
                                      SetTimeoutCmd,
1072
1314
                                      SetIntervalCmd,
1073
1315
                                      value_to_set=120000)
1074
1316
 
1075
 
    def test_approve_by_default(self):
1076
 
        self.assert_command_from_args(["--approve-by-default", "foo"],
1077
 
                                      ApproveByDefaultCmd)
1078
 
 
1079
 
    def test_deny_by_default(self):
1080
 
        self.assert_command_from_args(["--deny-by-default", "foo"],
1081
 
                                      DenyByDefaultCmd)
1082
 
 
1083
1317
    def test_approval_delay(self):
1084
1318
        self.assert_command_from_args(["--approval-delay", "PT30S",
1085
1319
                                       "foo"], SetApprovalDelayCmd,
1090
1324
                                       "foo"], SetApprovalDurationCmd,
1091
1325
                                      value_to_set=1000)
1092
1326
 
1093
 
    def test_host(self):
1094
 
        self.assert_command_from_args(["--host", "foo.example.org",
1095
 
                                       "foo"], SetHostCmd,
1096
 
                                      value_to_set="foo.example.org")
1097
 
 
1098
 
    def test_host_short(self):
1099
 
        self.assert_command_from_args(["-H", "foo.example.org",
1100
 
                                       "foo"], SetHostCmd,
1101
 
                                      value_to_set="foo.example.org")
1102
 
 
1103
 
    def test_secret_devnull(self):
1104
 
        self.assert_command_from_args(["--secret", os.path.devnull,
1105
 
                                       "foo"], SetSecretCmd,
1106
 
                                      value_to_set=b"")
1107
 
 
1108
 
    def test_secret_tempfile(self):
1109
 
        with tempfile.NamedTemporaryFile(mode="r+b") as f:
1110
 
            value = b"secret\0xyzzy\nbar"
1111
 
            f.write(value)
1112
 
            f.seek(0)
1113
 
            self.assert_command_from_args(["--secret", f.name,
1114
 
                                           "foo"], SetSecretCmd,
1115
 
                                          value_to_set=value)
1116
 
 
1117
 
    def test_secret_devnull_short(self):
1118
 
        self.assert_command_from_args(["-s", os.path.devnull, "foo"],
1119
 
                                      SetSecretCmd, value_to_set=b"")
1120
 
 
1121
 
    def test_secret_tempfile_short(self):
1122
 
        with tempfile.NamedTemporaryFile(mode="r+b") as f:
1123
 
            value = b"secret\0xyzzy\nbar"
1124
 
            f.write(value)
1125
 
            f.seek(0)
1126
 
            self.assert_command_from_args(["-s", f.name, "foo"],
1127
 
                                          SetSecretCmd,
1128
 
                                          value_to_set=value)
1129
 
 
1130
 
    def test_approve(self):
1131
 
        self.assert_command_from_args(["--approve", "foo"],
1132
 
                                      ApproveCmd)
1133
 
 
1134
 
    def test_approve_short(self):
1135
 
        self.assert_command_from_args(["-A", "foo"], ApproveCmd)
1136
 
 
1137
 
    def test_deny(self):
1138
 
        self.assert_command_from_args(["--deny", "foo"], DenyCmd)
1139
 
 
1140
 
    def test_deny_short(self):
1141
 
        self.assert_command_from_args(["-D", "foo"], DenyCmd)
1142
 
 
1143
 
    def test_dump_json(self):
1144
 
        self.assert_command_from_args(["--dump-json"], DumpJSONCmd)
1145
 
 
1146
 
    def test_is_enabled(self):
1147
 
        self.assert_command_from_args(["--is-enabled", "foo"],
1148
 
                                      IsEnabledCmd)
1149
 
 
1150
 
    def test_is_enabled_short(self):
1151
 
        self.assert_command_from_args(["-V", "foo"], IsEnabledCmd)
1152
 
 
1153
 
    def test_deny_before_remove(self):
1154
 
        options = self.parser.parse_args(["--deny", "--remove",
1155
 
                                          "foo"])
1156
 
        check_option_syntax(self.parser, options)
1157
 
        commands = commands_from_options(options)
1158
 
        self.assertEqual(len(commands), 2)
1159
 
        self.assertIsInstance(commands[0], DenyCmd)
1160
 
        self.assertIsInstance(commands[1], RemoveCmd)
1161
 
 
1162
 
    def test_deny_before_remove_reversed(self):
1163
 
        options = self.parser.parse_args(["--remove", "--deny",
1164
 
                                          "--all"])
1165
 
        check_option_syntax(self.parser, options)
1166
 
        commands = commands_from_options(options)
1167
 
        self.assertEqual(len(commands), 2)
1168
 
        self.assertIsInstance(commands[0], DenyCmd)
1169
 
        self.assertIsInstance(commands[1], RemoveCmd)
 
1327
    def test_print_table(self):
 
1328
        self.assert_command_from_args([], PrintTableCmd,
 
1329
                                      verbose=False)
 
1330
 
 
1331
    def test_print_table_verbose(self):
 
1332
        self.assert_command_from_args(["--verbose"], PrintTableCmd,
 
1333
                                      verbose=True)
 
1334
 
 
1335
    def test_print_table_verbose_short(self):
 
1336
        self.assert_command_from_args(["-v"], PrintTableCmd,
 
1337
                                      verbose=True)
1170
1338
 
1171
1339
 
1172
1340
class TestCmd(unittest.TestCase):
1173
1341
    """Abstract class for tests of command classes"""
 
1342
 
1174
1343
    def setUp(self):
1175
1344
        testcase = self
1176
1345
        class MockClient(object):
1248
1417
                ("/clients/barbar", self.other_client.attributes),
1249
1418
            ])
1250
1419
        self.one_client = {"/clients/foo": self.client.attributes}
 
1420
 
1251
1421
    @property
1252
1422
    def bus(self):
1253
1423
        class Bus(object):
1255
1425
            def get_object(client_bus_name, path):
1256
1426
                self.assertEqual(client_bus_name, dbus_busname)
1257
1427
                return {
 
1428
                    # Note: "self" here is the TestCmd instance, not
 
1429
                    # the Bus instance, since this is a static method!
1258
1430
                    "/clients/foo": self.client,
1259
1431
                    "/clients/barbar": self.other_client,
1260
1432
                }[path]
1267
1439
                                                      properties)
1268
1440
                            for client, properties
1269
1441
                            in self.clients.items()))
 
1442
 
1270
1443
    def test_is_enabled_run_exits_successfully(self):
1271
1444
        with self.assertRaises(SystemExit) as e:
1272
1445
            IsEnabledCmd().run(self.one_client)
1274
1447
            self.assertEqual(e.exception.code, 0)
1275
1448
        else:
1276
1449
            self.assertIsNone(e.exception.code)
 
1450
 
1277
1451
    def test_is_enabled_run_exits_with_failure(self):
1278
1452
        self.client.attributes["Enabled"] = dbus.Boolean(False)
1279
1453
        with self.assertRaises(SystemExit) as e:
1301
1475
            self.assertIn(("Approve", (False, client_dbus_interface)),
1302
1476
                          client.calls)
1303
1477
 
 
1478
 
1304
1479
class TestRemoveCmd(TestCmd):
1305
1480
    def test_remove(self):
1306
1481
        class MockMandos(object):
1370
1545
            },
1371
1546
        }
1372
1547
        return super(TestDumpJSONCmd, self).setUp()
 
1548
 
1373
1549
    def test_normal(self):
1374
 
        json_data = json.loads(DumpJSONCmd().output(self.clients))
 
1550
        output = DumpJSONCmd().output(self.clients.values())
 
1551
        json_data = json.loads(output)
1375
1552
        self.assertDictEqual(json_data, self.expected_json)
 
1553
 
1376
1554
    def test_one_client(self):
1377
 
        clients = self.one_client
1378
 
        json_data = json.loads(DumpJSONCmd().output(clients))
 
1555
        output = DumpJSONCmd().output(self.one_client.values())
 
1556
        json_data = json.loads(output)
1379
1557
        expected_json = {"foo": self.expected_json["foo"]}
1380
1558
        self.assertDictEqual(json_data, expected_json)
1381
1559
 
1389
1567
            "barbar Yes     00:05:00 2019-02-04T00:00:00  ",
1390
1568
        ))
1391
1569
        self.assertEqual(output, expected_output)
 
1570
 
1392
1571
    def test_verbose(self):
1393
1572
        output = PrintTableCmd(verbose=True).output(
1394
1573
            self.clients.values())
1483
1662
                                            for rows in columns)
1484
1663
                                    for line in range(num_lines))
1485
1664
        self.assertEqual(output, expected_output)
 
1665
 
1486
1666
    def test_one_client(self):
1487
1667
        output = PrintTableCmd().output(self.one_client.values())
1488
1668
        expected_output = "\n".join((
1492
1672
        self.assertEqual(output, expected_output)
1493
1673
 
1494
1674
 
1495
 
class Unique(object):
1496
 
    """Class for objects which exist only to be unique objects, since
1497
 
unittest.mock.sentinel only exists in Python 3.3"""
1498
 
 
1499
 
 
1500
1675
class TestPropertyCmd(TestCmd):
1501
1676
    """Abstract class for tests of PropertyCmd classes"""
1502
1677
    def runTest(self):
1509
1684
            for clientpath in self.clients:
1510
1685
                client = self.bus.get_object(dbus_busname, clientpath)
1511
1686
                old_value = client.attributes[self.propname]
1512
 
                self.assertNotIsInstance(old_value, Unique)
1513
 
                client.attributes[self.propname] = Unique()
 
1687
                self.assertNotIsInstance(old_value, self.Unique)
 
1688
                client.attributes[self.propname] = self.Unique()
1514
1689
            self.run_command(value_to_set, self.clients)
1515
1690
            for clientpath in self.clients:
1516
1691
                client = self.bus.get_object(dbus_busname, clientpath)
1517
1692
                value = client.attributes[self.propname]
1518
 
                self.assertNotIsInstance(value, Unique)
 
1693
                self.assertNotIsInstance(value, self.Unique)
1519
1694
                self.assertEqual(value, value_to_get)
 
1695
 
 
1696
    class Unique(object):
 
1697
        """Class for objects which exist only to be unique objects,
 
1698
since unittest.mock.sentinel only exists in Python 3.3"""
 
1699
 
1520
1700
    def run_command(self, value, clients):
1521
1701
        self.command().run(clients, self.bus)
1522
1702
 
1523
1703
 
1524
 
class TestEnableCmd(TestCmd):
1525
 
    def test_enable(self):
1526
 
        for clientpath in self.clients:
1527
 
            client = self.bus.get_object(dbus_busname, clientpath)
1528
 
            client.attributes["Enabled"] = False
1529
 
 
1530
 
        EnableCmd().run(self.clients, self.bus)
1531
 
 
1532
 
        for clientpath in self.clients:
1533
 
            client = self.bus.get_object(dbus_busname, clientpath)
1534
 
            self.assertTrue(client.attributes["Enabled"])
1535
 
 
1536
 
 
1537
 
class TestDisableCmd(TestCmd):
1538
 
    def test_disable(self):
1539
 
        DisableCmd().run(self.clients, self.bus)
1540
 
        for clientpath in self.clients:
1541
 
            client = self.bus.get_object(dbus_busname, clientpath)
1542
 
            self.assertFalse(client.attributes["Enabled"])
 
1704
class TestEnableCmd(TestPropertyCmd):
 
1705
    command = EnableCmd
 
1706
    propname = "Enabled"
 
1707
    values_to_set = [dbus.Boolean(True)]
 
1708
 
 
1709
 
 
1710
class TestDisableCmd(TestPropertyCmd):
 
1711
    command = DisableCmd
 
1712
    propname = "Enabled"
 
1713
    values_to_set = [dbus.Boolean(False)]
1543
1714
 
1544
1715
 
1545
1716
class TestBumpTimeoutCmd(TestPropertyCmd):
1574
1745
 
1575
1746
class TestPropertyValueCmd(TestPropertyCmd):
1576
1747
    """Abstract class for tests of PropertyValueCmd classes"""
 
1748
 
1577
1749
    def runTest(self):
1578
1750
        if type(self) is TestPropertyValueCmd:
1579
1751
            return
1580
1752
        return super(TestPropertyValueCmd, self).runTest()
 
1753
 
1581
1754
    def run_command(self, value, clients):
1582
1755
        self.command(value).run(clients, self.bus)
1583
1756
 
1675
1848
    return tests
1676
1849
 
1677
1850
if __name__ == "__main__":
1678
 
    if should_only_run_tests():
1679
 
        # Call using ./tdd-python-script --check [--verbose]
1680
 
        unittest.main()
1681
 
    else:
1682
 
        main()
 
1851
    try:
 
1852
        if should_only_run_tests():
 
1853
            # Call using ./tdd-python-script --check [--verbose]
 
1854
            unittest.main()
 
1855
        else:
 
1856
            main()
 
1857
    finally:
 
1858
        logging.shutdown()