/mandos/trunk

To get this branch, use:
bzr branch http://bzr.recompile.se/loggerhead/mandos/trunk

« back to all changes in this revision

Viewing changes to mandos-ctl

  • Committer: Teddy Hogeborn
  • Date: 2019-03-17 21:29:32 UTC
  • Revision ID: teddy@recompile.se-20190317212932-r3libgz33mkb85rw
mandos-ctl: Refactor

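* mandos-ctl: For Python 2, use StringIO.StringIO as a replacement for
              io.StringIO (by assigning it to io.StringIO), since
              Python 2's io.StringIO accepts only unicode text and
              raises TypeError when print writes a byte string
              (Python 2 str) to it, which breaks redirecting
              sys.stdout to it.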
  (Output.run, Output.output): Remove.
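  (DumpJSON.output): Rename to "run" and change its signature to
                     match the removed Output.run, i.e.
                     run(clients, bus=None, mandos=None), taking the
                     clients dict instead of its values.  Also change
                     the code to print the result instead of returning
                     it as a string.
  (PrintTable.output): Same changes as for DumpJSON.output.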

Show diffs side-by-side

added

removed

Lines of Context:
61
61
 
62
62
if sys.version_info.major == 2:
63
63
    str = unicode
 
64
    import StringIO
 
65
    io.StringIO = StringIO.StringIO
64
66
 
65
67
locale.setlocale(locale.LC_ALL, "")
66
68
 
613
615
                        "Checker", "ExtendedTimeout", "Expires",
614
616
                        "LastCheckerStatus")
615
617
 
616
 
        def run(self, clients, bus=None, mandos=None):
617
 
            print(self.output(clients.values()))
618
 
 
619
 
        def output(self, clients):
620
 
            raise NotImplementedError()
621
 
 
622
618
 
623
619
    class DumpJSON(Output):
624
 
        def output(self, clients):
 
620
        def run(self, clients, bus=None, mandos=None):
625
621
            data = {client["Name"]:
626
622
                    {key: self.dbus_boolean_to_bool(client[key])
627
623
                     for key in self.all_keywords}
628
 
                    for client in clients}
629
 
            return json.dumps(data, indent=4, separators=(',', ': '))
 
624
                    for client in clients.values()}
 
625
            print(json.dumps(data, indent=4, separators=(',', ': ')))
630
626
 
631
627
        @staticmethod
632
628
        def dbus_boolean_to_bool(value):
639
635
        def __init__(self, verbose=False):
640
636
            self.verbose = verbose
641
637
 
642
 
        def output(self, clients):
 
638
        def run(self, clients, bus=None, mandos=None):
643
639
            default_keywords = ("Name", "Enabled", "Timeout",
644
640
                                "LastCheckedOK")
645
641
            keywords = default_keywords
646
642
            if self.verbose:
647
643
                keywords = self.all_keywords
648
 
            return str(self.TableOfClients(clients, keywords))
 
644
            print(self.TableOfClients(clients.values(), keywords))
649
645
 
650
646
        class TableOfClients(object):
651
647
            tableheaders = {
886
882
 
887
883
 
888
884
class Test_string_to_delta(TestCaseWithAssertLogs):
889
 
    def test_handles_basic_rfc3339(self):
890
 
        self.assertEqual(string_to_delta("PT0S"),
891
 
                         datetime.timedelta())
892
 
        self.assertEqual(string_to_delta("P0D"),
893
 
                         datetime.timedelta())
894
 
        self.assertEqual(string_to_delta("PT1S"),
895
 
                         datetime.timedelta(0, 1))
896
 
        self.assertEqual(string_to_delta("PT2H"),
897
 
                         datetime.timedelta(0, 7200))
 
885
    # Just test basic RFC 3339 functionality here, the doc string for
 
886
    # rfc3339_duration_to_delta() already has more comprehensive
 
887
    # tests, which is run by doctest.
 
888
 
 
889
    def test_rfc3339_zero_seconds(self):
 
890
        self.assertEqual(datetime.timedelta(),
 
891
                         string_to_delta("PT0S"))
 
892
 
 
893
    def test_rfc3339_zero_days(self):
 
894
        self.assertEqual(datetime.timedelta(), string_to_delta("P0D"))
 
895
 
 
896
    def test_rfc3339_one_second(self):
 
897
        self.assertEqual(datetime.timedelta(0, 1),
 
898
                         string_to_delta("PT1S"))
 
899
 
 
900
    def test_rfc3339_two_hours(self):
 
901
        self.assertEqual(datetime.timedelta(0, 7200),
 
902
                         string_to_delta("PT2H"))
898
903
 
899
904
    def test_falls_back_to_pre_1_6_1_with_warning(self):
900
905
        with self.assertLogs(log, logging.WARNING):
901
906
            value = string_to_delta("2h")
902
 
        self.assertEqual(value, datetime.timedelta(0, 7200))
 
907
        self.assertEqual(datetime.timedelta(0, 7200), value)
903
908
 
904
909
 
905
910
class Test_check_option_syntax(unittest.TestCase):
948
953
        # Exit code from argparse is guaranteed to be "2".  Reference:
949
954
        # https://docs.python.org/3/library
950
955
        # /argparse.html#exiting-methods
951
 
        self.assertEqual(e.exception.code, 2)
 
956
        self.assertEqual(2, e.exception.code)
952
957
 
953
958
    @staticmethod
954
959
    @contextlib.contextmanager
955
960
    def redirect_stderr_to_devnull():
956
 
        null = os.open(os.path.devnull, os.O_RDWR)
957
 
        stderrcopy = os.dup(sys.stderr.fileno())
958
 
        os.dup2(null, sys.stderr.fileno())
959
 
        os.close(null)
960
 
        try:
961
 
            yield
962
 
        finally:
963
 
            # restore stderr
964
 
            os.dup2(stderrcopy, sys.stderr.fileno())
965
 
            os.close(stderrcopy)
 
961
        old_stderr = sys.stderr
 
962
        with contextlib.closing(open(os.devnull, "w")) as null:
 
963
            sys.stderr = null
 
964
            try:
 
965
                yield
 
966
            finally:
 
967
                sys.stderr = old_stderr
966
968
 
967
969
    def check_option_syntax(self, options):
968
970
        check_option_syntax(self.parser, options)
1079
1081
            def get_object(mockbus_self, busname, dbus_path):
1080
1082
                # Note that "self" is still the testcase instance,
1081
1083
                # this MockBus instance is in "mockbus_self".
1082
 
                self.assertEqual(busname, dbus_busname)
1083
 
                self.assertEqual(dbus_path, server_dbus_path)
 
1084
                self.assertEqual(dbus_busname, busname)
 
1085
                self.assertEqual(server_dbus_path, dbus_path)
1084
1086
                mockbus_self.called = True
1085
1087
                return mockbus_self
1086
1088
 
1098
1100
                bus = get_mandos_dbus_object(bus=MockBusFailing())
1099
1101
 
1100
1102
        if isinstance(e.exception.code, int):
1101
 
            self.assertNotEqual(e.exception.code, 0)
 
1103
            self.assertNotEqual(0, e.exception.code)
1102
1104
        else:
1103
1105
            self.assertIsNotNone(e.exception.code)
1104
1106
 
1137
1139
            dbus_logger.removeFilter(counting_handler)
1138
1140
 
1139
1141
        # Make sure the dbus logger was suppressed
1140
 
        self.assertEqual(counting_handler.count, 0)
 
1142
        self.assertEqual(0, counting_handler.count)
1141
1143
 
1142
1144
        # Test that the dbus_logger still works
1143
1145
        with self.assertLogs(dbus_logger, logging.ERROR):
1144
1146
            dbus_logger.error("Test")
1145
1147
 
1146
1148
        if isinstance(e.exception.code, int):
1147
 
            self.assertNotEqual(e.exception.code, 0)
 
1149
            self.assertNotEqual(0, e.exception.code)
1148
1150
        else:
1149
1151
            self.assertIsNotNone(e.exception.code)
1150
1152
 
1165
1167
        options = self.parser.parse_args(args)
1166
1168
        check_option_syntax(self.parser, options)
1167
1169
        commands = commands_from_options(options)
1168
 
        self.assertEqual(len(commands), 1)
 
1170
        self.assertEqual(1, len(commands))
1169
1171
        command = commands[0]
1170
1172
        self.assertIsInstance(command, command_cls)
1171
1173
        for key, value in cmd_attrs.items():
1172
 
            self.assertEqual(getattr(command, key), value)
 
1174
            self.assertEqual(value, getattr(command, key))
1173
1175
 
1174
1176
    def test_is_enabled_short(self):
1175
1177
        self.assert_command_from_args(["-V", "foo"],
1197
1199
                                          "foo"])
1198
1200
        check_option_syntax(self.parser, options)
1199
1201
        commands = commands_from_options(options)
1200
 
        self.assertEqual(len(commands), 2)
 
1202
        self.assertEqual(2, len(commands))
1201
1203
        self.assertIsInstance(commands[0], command.Deny)
1202
1204
        self.assertIsInstance(commands[1], command.Remove)
1203
1205
 
1206
1208
                                          "--all"])
1207
1209
        check_option_syntax(self.parser, options)
1208
1210
        commands = commands_from_options(options)
1209
 
        self.assertEqual(len(commands), 2)
 
1211
        self.assertEqual(2, len(commands))
1210
1212
        self.assertIsInstance(commands[0], command.Deny)
1211
1213
        self.assertIsInstance(commands[1], command.Remove)
1212
1214
 
1372
1374
                self.attributes["Name"] = name
1373
1375
                self.calls = []
1374
1376
            def Set(self, interface, propname, value, dbus_interface):
1375
 
                testcase.assertEqual(interface, client_dbus_interface)
1376
 
                testcase.assertEqual(dbus_interface,
1377
 
                                     dbus.PROPERTIES_IFACE)
 
1377
                testcase.assertEqual(client_dbus_interface, interface)
 
1378
                testcase.assertEqual(dbus.PROPERTIES_IFACE,
 
1379
                                     dbus_interface)
1378
1380
                self.attributes[propname] = value
1379
 
            def Get(self, interface, propname, dbus_interface):
1380
 
                testcase.assertEqual(interface, client_dbus_interface)
1381
 
                testcase.assertEqual(dbus_interface,
1382
 
                                     dbus.PROPERTIES_IFACE)
1383
 
                return self.attributes[propname]
1384
1381
            def Approve(self, approve, dbus_interface):
1385
 
                testcase.assertEqual(dbus_interface,
1386
 
                                     client_dbus_interface)
 
1382
                testcase.assertEqual(client_dbus_interface,
 
1383
                                     dbus_interface)
1387
1384
                self.calls.append(("Approve", (approve,
1388
1385
                                               dbus_interface)))
1389
1386
        self.client = MockClient(
1446
1443
        class Bus(object):
1447
1444
            @staticmethod
1448
1445
            def get_object(client_bus_name, path):
1449
 
                self.assertEqual(client_bus_name, dbus_busname)
 
1446
                self.assertEqual(dbus_busname, client_bus_name)
1450
1447
                return {
1451
1448
                    # Note: "self" here is the TestCmd instance, not
1452
1449
                    # the Bus instance, since this is a static method!
1458
1455
 
1459
1456
class TestBaseCommands(TestCommand):
1460
1457
 
1461
 
    def test_is_enabled(self):
1462
 
        self.assertTrue(all(command.IsEnabled().is_enabled(client,
1463
 
                                                      properties)
1464
 
                            for client, properties
1465
 
                            in self.clients.items()))
1466
 
 
1467
 
    def test_is_enabled_run_exits_successfully(self):
 
1458
    def test_IsEnabled_exits_successfully(self):
1468
1459
        with self.assertRaises(SystemExit) as e:
1469
1460
            command.IsEnabled().run(self.one_client)
1470
1461
        if e.exception.code is not None:
1471
 
            self.assertEqual(e.exception.code, 0)
 
1462
            self.assertEqual(0, e.exception.code)
1472
1463
        else:
1473
1464
            self.assertIsNone(e.exception.code)
1474
1465
 
1475
 
    def test_is_enabled_run_exits_with_failure(self):
 
1466
    def test_IsEnabled_exits_with_failure(self):
1476
1467
        self.client.attributes["Enabled"] = dbus.Boolean(False)
1477
1468
        with self.assertRaises(SystemExit) as e:
1478
1469
            command.IsEnabled().run(self.one_client)
1479
1470
        if isinstance(e.exception.code, int):
1480
 
            self.assertNotEqual(e.exception.code, 0)
 
1471
            self.assertNotEqual(0, e.exception.code)
1481
1472
        else:
1482
1473
            self.assertIsNotNone(e.exception.code)
1483
1474
 
1484
 
    def test_approve(self):
 
1475
    def test_Approve(self):
1485
1476
        command.Approve().run(self.clients, self.bus)
1486
1477
        for clientpath in self.clients:
1487
1478
            client = self.bus.get_object(dbus_busname, clientpath)
1488
1479
            self.assertIn(("Approve", (True, client_dbus_interface)),
1489
1480
                          client.calls)
1490
1481
 
1491
 
    def test_deny(self):
 
1482
    def test_Deny(self):
1492
1483
        command.Deny().run(self.clients, self.bus)
1493
1484
        for clientpath in self.clients:
1494
1485
            client = self.bus.get_object(dbus_busname, clientpath)
1495
1486
            self.assertIn(("Approve", (False, client_dbus_interface)),
1496
1487
                          client.calls)
1497
1488
 
1498
 
    def test_remove(self):
 
1489
    def test_Remove(self):
1499
1490
        class MockMandos(object):
1500
1491
            def __init__(self):
1501
1492
                self.calls = []
1503
1494
                self.calls.append(("RemoveClient", (dbus_path,)))
1504
1495
        mandos = MockMandos()
1505
1496
        command.Remove().run(self.clients, self.bus, mandos)
1506
 
        self.assertEqual(len(mandos.calls), 2)
1507
1497
        for clientpath in self.clients:
1508
1498
            self.assertIn(("RemoveClient", (clientpath,)),
1509
1499
                          mandos.calls)
1560
1550
    }
1561
1551
 
1562
1552
    def test_DumpJSON_normal(self):
1563
 
        output = command.DumpJSON().output(self.clients.values())
1564
 
        json_data = json.loads(output)
1565
 
        self.assertDictEqual(json_data, self.expected_json)
 
1553
        with self.capture_stdout_to_buffer() as buffer:
 
1554
            command.DumpJSON().run(self.clients)
 
1555
        json_data = json.loads(buffer.getvalue())
 
1556
        self.assertDictEqual(self.expected_json, json_data)
 
1557
 
 
1558
    @staticmethod
 
1559
    @contextlib.contextmanager
 
1560
    def capture_stdout_to_buffer():
 
1561
        capture_buffer = io.StringIO()
 
1562
        old_stdout = sys.stdout
 
1563
        sys.stdout = capture_buffer
 
1564
        try:
 
1565
            yield capture_buffer
 
1566
        finally:
 
1567
            sys.stdout = old_stdout
1566
1568
 
1567
1569
    def test_DumpJSON_one_client(self):
1568
 
        output = command.DumpJSON().output(self.one_client.values())
1569
 
        json_data = json.loads(output)
 
1570
        with self.capture_stdout_to_buffer() as buffer:
 
1571
            command.DumpJSON().run(self.one_client)
 
1572
        json_data = json.loads(buffer.getvalue())
1570
1573
        expected_json = {"foo": self.expected_json["foo"]}
1571
 
        self.assertDictEqual(json_data, expected_json)
 
1574
        self.assertDictEqual(expected_json, json_data)
1572
1575
 
1573
1576
    def test_PrintTable_normal(self):
1574
 
        output = command.PrintTable().output(self.clients.values())
 
1577
        with self.capture_stdout_to_buffer() as buffer:
 
1578
            command.PrintTable().run(self.clients)
1575
1579
        expected_output = "\n".join((
1576
1580
            "Name   Enabled Timeout  Last Successful Check",
1577
1581
            "foo    Yes     00:05:00 2019-02-03T00:00:00  ",
1578
1582
            "barbar Yes     00:05:00 2019-02-04T00:00:00  ",
1579
 
        ))
1580
 
        self.assertEqual(output, expected_output)
 
1583
        )) + "\n"
 
1584
        self.assertEqual(expected_output, buffer.getvalue())
1581
1585
 
1582
1586
    def test_PrintTable_verbose(self):
1583
 
        output = command.PrintTable(verbose=True).output(
1584
 
            self.clients.values())
 
1587
        with self.capture_stdout_to_buffer() as buffer:
 
1588
            command.PrintTable(verbose=True).run(self.clients)
1585
1589
        columns = (
1586
1590
            (
1587
1591
                "Name   ",
1669
1673
            )
1670
1674
        )
1671
1675
        num_lines = max(len(rows) for rows in columns)
1672
 
        expected_output = "\n".join("".join(rows[line]
1673
 
                                            for rows in columns)
1674
 
                                    for line in range(num_lines))
1675
 
        self.assertEqual(output, expected_output)
 
1676
        expected_output = ("\n".join("".join(rows[line]
 
1677
                                             for rows in columns)
 
1678
                                     for line in range(num_lines))
 
1679
                           + "\n")
 
1680
        self.assertEqual(expected_output, buffer.getvalue())
1676
1681
 
1677
1682
    def test_PrintTable_one_client(self):
1678
 
        output = command.PrintTable().output(self.one_client.values())
 
1683
        with self.capture_stdout_to_buffer() as buffer:
 
1684
            command.PrintTable().run(self.one_client)
1679
1685
        expected_output = "\n".join((
1680
1686
            "Name Enabled Timeout  Last Successful Check",
1681
1687
            "foo  Yes     00:05:00 2019-02-03T00:00:00  ",
1682
 
        ))
1683
 
        self.assertEqual(output, expected_output)
 
1688
        )) + "\n"
 
1689
        self.assertEqual(expected_output, buffer.getvalue())
1684
1690
 
1685
1691
 
1686
1692
class TestPropertyCmd(TestCommand):
1695
1701
            for clientpath in self.clients:
1696
1702
                client = self.bus.get_object(dbus_busname, clientpath)
1697
1703
                old_value = client.attributes[self.propname]
1698
 
                self.assertNotIsInstance(old_value, self.Unique)
1699
1704
                client.attributes[self.propname] = self.Unique()
1700
1705
            self.run_command(value_to_set, self.clients)
1701
1706
            for clientpath in self.clients:
1702
1707
                client = self.bus.get_object(dbus_busname, clientpath)
1703
1708
                value = client.attributes[self.propname]
1704
1709
                self.assertNotIsInstance(value, self.Unique)
1705
 
                self.assertEqual(value, value_to_get)
 
1710
                self.assertEqual(value_to_get, value)
1706
1711
 
1707
1712
    class Unique(object):
1708
1713
        """Class for objects which exist only to be unique objects,