=== modified file 'mandos-ctl' --- mandos-ctl 2019-03-16 00:23:20 +0000 +++ mandos-ctl 2019-03-16 04:32:51 +0000 @@ -839,7 +839,40 @@ -class Test_string_to_delta(unittest.TestCase): +class TestCaseWithAssertLogs(unittest.TestCase): + """unittest.TestCase.assertLogs only exists in Python 3.4""" + + if not hasattr(unittest.TestCase, "assertLogs"): + @contextlib.contextmanager + def assertLogs(self, logger, level=logging.INFO): + capturing_handler = self.CapturingLevelHandler(level) + old_level = logger.level + old_propagate = logger.propagate + logger.addHandler(capturing_handler) + logger.setLevel(level) + logger.propagate = False + try: + yield capturing_handler.watcher + finally: + logger.propagate = old_propagate + logger.removeHandler(capturing_handler) + logger.setLevel(old_level) + self.assertGreater(len(capturing_handler.watcher.records), + 0) + + class CapturingLevelHandler(logging.Handler): + def __init__(self, level, *args, **kwargs): + logging.Handler.__init__(self, *args, **kwargs) + self.watcher = self.LoggingWatcher([], []) + def emit(self, record): + self.watcher.records.append(record) + self.watcher.output.append(self.format(record)) + + LoggingWatcher = collections.namedtuple("LoggingWatcher", + ("records", + "output")) + +class Test_string_to_delta(TestCaseWithAssertLogs): def test_handles_basic_rfc3339(self): self.assertEqual(string_to_delta("PT0S"), datetime.timedelta()) @@ -851,25 +884,8 @@ datetime.timedelta(0, 7200)) def test_falls_back_to_pre_1_6_1_with_warning(self): - # assertLogs only exists in Python 3.4 - if hasattr(self, "assertLogs"): - with self.assertLogs(log, logging.WARNING): - value = string_to_delta("2h") - else: - class WarningFilter(logging.Filter): - """Don't show, but record the presence of, warnings""" - def filter(self, record): - is_warning = record.levelno >= logging.WARNING - self.found = is_warning or getattr(self, "found", - False) - return not is_warning - warning_filter = WarningFilter() - log.addFilter(warning_filter) - try: - value = string_to_delta("2h") - finally: - log.removeFilter(warning_filter) - self.assertTrue(getattr(warning_filter, "found", False)) + with self.assertLogs(log, logging.WARNING): + value = string_to_delta("2h") self.assertEqual(value, datetime.timedelta(0, 7200)) @@ -1005,7 +1021,7 @@ self.check_option_syntax(options) -class Test_get_mandos_dbus_object(unittest.TestCase): +class Test_get_mandos_dbus_object(TestCaseWithAssertLogs): def test_calls_and_returns_get_object_on_bus(self): class MockBus(object): called = False @@ -1026,35 +1042,17 @@ def get_object(self, busname, dbus_path): raise dbus.exceptions.DBusException("Test") - # assertLogs only exists in Python 3.4 - if hasattr(self, "assertLogs"): - with self.assertLogs(log, logging.CRITICAL): - with self.assertRaises(SystemExit) as e: - bus = get_mandos_dbus_object(bus=MockBus()) - else: - critical_filter = self.CriticalFilter() - log.addFilter(critical_filter) - try: - with self.assertRaises(SystemExit) as e: - get_mandos_dbus_object(bus=MockBusFailing()) - finally: - log.removeFilter(critical_filter) - self.assertTrue(critical_filter.found) + with self.assertLogs(log, logging.CRITICAL): + with self.assertRaises(SystemExit) as e: + bus = get_mandos_dbus_object(bus=MockBusFailing()) + if isinstance(e.exception.code, int): self.assertNotEqual(e.exception.code, 0) else: self.assertIsNotNone(e.exception.code) - class CriticalFilter(logging.Filter): - """Don't show, but register, critical messages""" - found = False - def filter(self, record): - is_critical = record.levelno >= logging.CRITICAL - self.found = is_critical or self.found - return not is_critical - - -class Test_get_managed_objects(unittest.TestCase): + +class Test_get_managed_objects(TestCaseWithAssertLogs): def test_calls_and_returns_GetManagedObjects(self): managed_objects = {"/clients/foo": { "Name": "foo"}} class MockObjectManager(object): @@ -1065,65 +1063,40 @@ self.assertDictEqual(managed_objects, retval) def test_logs_and_exits_on_dbus_error(self): + dbus_logger = logging.getLogger("dbus.proxies") + class MockObjectManagerFailing(object): @staticmethod def GetManagedObjects(): + dbus_logger.error("Test") raise dbus.exceptions.DBusException("Test") - if hasattr(self, "assertLogs"): - with self.assertLogs(log, logging.CRITICAL): - with self.assertRaises(SystemExit): - get_managed_objects(MockObjectManagerFailing()) - else: - critical_filter = self.CriticalFilter() - log.addFilter(critical_filter) - try: + class CountingHandler(logging.Handler): + count = 0 + def emit(self, record): + self.count += 1 + + counting_handler = CountingHandler() + + dbus_logger.addHandler(counting_handler) + + try: + with self.assertLogs(log, logging.CRITICAL) as watcher: with self.assertRaises(SystemExit) as e: get_managed_objects(MockObjectManagerFailing()) - finally: - log.removeFilter(critical_filter) - self.assertTrue(critical_filter.found) + finally: + dbus_logger.removeFilter(counting_handler) + self.assertEqual(counting_handler.count, 0) + + # Test that the dbus_logger still works + with self.assertLogs(dbus_logger, logging.ERROR): + dbus_logger.error("Test") + if isinstance(e.exception.code, int): self.assertNotEqual(e.exception.code, 0) else: self.assertIsNotNone(e.exception.code) - class CriticalFilter(logging.Filter): - """Don't show, but register, critical messages""" - found = False - def filter(self, record): - is_critical = record.levelno >= logging.CRITICAL - self.found = is_critical or self.found - return not is_critical - - -class Test_SilenceLogger(unittest.TestCase): - loggername = "mandos-ctl.Test_SilenceLogger" - log = logging.getLogger(loggername) - log.propagate = False - log.addHandler(logging.NullHandler()) - - def setUp(self): - self.counting_filter = self.CountingFilter() - - class CountingFilter(logging.Filter): - "Count number of records" - count = 0 - def filter(self, record): - self.count += 1 - return True - - def test_should_filter_records_only_when_active(self): - try: - with SilenceLogger(self.loggername): - self.log.addFilter(self.counting_filter) - self.log.info("Filtered log message 1") - self.log.info("Non-filtered message 2") - self.log.info("Non-filtered message 3") - finally: - self.log.removeFilter(self.counting_filter) - self.assertEqual(self.counting_filter.count, 2) - class Test_commands_from_options(unittest.TestCase): def setUp(self):