#!/usr/bin/python
# -*- mode: python; coding: utf-8; after-save-hook: (lambda () (let ((command (if (and (boundp 'tramp-file-name-structure) (string-match (car tramp-file-name-structure) (buffer-file-name))) (tramp-file-name-localname (tramp-dissect-file-name (buffer-file-name))) (buffer-file-name)))) (if (= (shell-command (format "%s --check" (shell-quote-argument command)) "*Test*") 0) (let ((w (get-buffer-window "*Test*"))) (if w (delete-window w)) (kill-buffer "*Test*")) (display-buffer "*Test*")))); -*-
#
# Mandos Monitor - Control and monitor the Mandos server
#
# Copyright © 2008-2019 Teddy Hogeborn
# Copyright © 2008-2019 Björn Påhlsson
#
# This file is part of Mandos.
#
# Mandos is free software: you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
#     Mandos is distributed in the hope that it will be useful, but
#     WITHOUT ANY WARRANTY; without even the implied warranty of
#     MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#     GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Mandos.  If not, see <http://www.gnu.org/licenses/>.
#
# Contact the authors at <mandos@recompile.se>.
#

from __future__ import (division, absolute_import, print_function,
                        unicode_literals)

try:
    from future_builtins import *
except ImportError:
    pass

import sys
import argparse
import locale
import datetime
import re
import os
import collections
import json
import unittest
import logging
import io
import tempfile
import contextlib
import abc

import dbus as dbus_python

# Show warnings by default
if not sys.warnoptions:
    import warnings
    warnings.simplefilter("default")

log = logging.getLogger(sys.argv[0])
logging.basicConfig(level="INFO",  # Show info level messages
                    format="%(message)s")  # Show basic log messages

logging.captureWarnings(True)   # Show warnings via the logging system

if sys.version_info.major == 2:
    # Python 2 compatibility: "str" means the text type, and
    # io.StringIO is replaced by StringIO.StringIO
    str = unicode
    import StringIO
    io.StringIO = StringIO.StringIO

try:
    locale.setlocale(locale.LC_ALL, "")
except locale.Error as e:
    # The environment may name a locale which is not available on
    # this system; that should not stop the program from running, so
    # carry on with the default locale instead.
    log.warning("Could not set locale: %s", e)

version = "1.8.3"


def main():
    """Run mandos-ctl.

    Parses and validates the command line, fetches all clients and
    their properties from the Mandos server over D-Bus, selects the
    clients named on the command line (or all of them), and runs the
    requested commands on that selection.  Exits with status 1 if the
    server cannot be reached or a named client is unknown to it.
    """
    argument_parser = argparse.ArgumentParser()
    add_command_line_options(argument_parser)

    opts = argument_parser.parse_args()
    check_option_syntax(argument_parser, opts)

    if opts.debug:
        log.setLevel(logging.DEBUG)

    mandos_bus = dbus_python_adapter.CachingBus(dbus_python)

    try:
        known_clients = mandos_bus.get_clients_and_properties()
    except dbus.ConnectFailed as err:
        log.critical("Could not connect to Mandos server: %s", err)
        sys.exit(1)
    except dbus.Error as err:
        log.critical(
            "Failed to access Mandos server through D-Bus:\n%s", err)
        sys.exit(1)

    # Map of (client object path: properties) to operate on
    wanted_names = opts.client
    if wanted_names:
        selected = {}
        for wanted in wanted_names:
            found = next(((path, props)
                          for path, props in known_clients.items()
                          if props["Name"] == wanted), None)
            if found is None:
                log.critical("Client not found on server: %r", wanted)
                sys.exit(1)
            selected[found[0]] = found[1]
    else:
        selected = known_clients

    for cmd in commands_from_options(opts):
        cmd.run(selected, mandos_bus)


def add_command_line_options(parser):
    """Register every mandos-ctl command line option on *parser*.

    Options are added in the order --help lists them.  Options which
    contradict each other are put in argparse mutually exclusive
    groups, so argparse itself rejects such combinations.
    """
    add = parser.add_argument

    add("--version", action="version",
        version="%(prog)s {}".format(version),
        help="show version number and exit")
    add("-a", "--all", action="store_true",
        help="Select all clients")
    add("-v", "--verbose", action="store_true",
        help="Print all fields")
    add("-j", "--dump-json", action="store_true",
        help="Dump client data in JSON format")

    enabling = parser.add_mutually_exclusive_group()
    enabling.add_argument("-e", "--enable", action="store_true",
                          help="Enable client")
    enabling.add_argument("-d", "--disable", action="store_true",
                          help="disable client")

    add("-b", "--bump-timeout", action="store_true",
        help="Bump timeout for client")

    checker_control = parser.add_mutually_exclusive_group()
    checker_control.add_argument("--start-checker",
                                 action="store_true",
                                 help="Start checker for client")
    checker_control.add_argument("--stop-checker",
                                 action="store_true",
                                 help="Stop checker for client")

    add("-V", "--is-enabled", action="store_true",
        help="Check if client is enabled")
    add("-r", "--remove", action="store_true",
        help="Remove client")
    add("-c", "--checker",
        help="Set checker command for client")
    add("-t", "--timeout", type=string_to_delta,
        help="Set timeout for client")
    add("--extended-timeout", type=string_to_delta,
        help="Set extended timeout for client")
    add("-i", "--interval", type=string_to_delta,
        help="Set checker interval for client")

    # Both options write to the same destination; it stays None when
    # neither option is given.
    default_approval = parser.add_mutually_exclusive_group()
    default_approval.add_argument(
        "--approve-by-default", action="store_true",
        default=None, dest="approved_by_default",
        help="Set client to be approved by default")
    default_approval.add_argument(
        "--deny-by-default", action="store_false",
        dest="approved_by_default",
        help="Set client to be denied by default")

    add("--approval-delay", type=string_to_delta,
        help="Set delay before client approve/deny")
    add("--approval-duration", type=string_to_delta,
        help="Set duration of one client approval")
    add("-H", "--host", help="Set host for client")
    add("-s", "--secret", type=argparse.FileType(mode="rb"),
        help="Set password blob (file) for client")

    approval = parser.add_mutually_exclusive_group()
    approval.add_argument("-A", "--approve", action="store_true",
                          help="Approve any current client request")
    approval.add_argument("-D", "--deny", action="store_true",
                          help="Deny any current client request")

    add("--debug", action="store_true",
        help="Debug mode (show D-Bus commands)")
    add("--check", action="store_true",
        help="Run self-test")
    add("client", nargs="*", help="Client name")


def string_to_delta(interval):
    """Convert the string *interval* to a datetime.timedelta.

    An RFC 3339 duration (e.g. "PT5M") is tried first.  If that parse
    fails, a warning is logged and the string is parsed with the old
    pre-1.6.1 interval syntax (e.g. "5m") instead.
    """
    try:
        delta = rfc3339_duration_to_delta(interval)
    except ValueError as err:
        log.warning("%s - Parsing as pre-1.6.1 interval instead",
                    ' '.join(err.args))
        delta = parse_pre_1_6_1_interval(interval)
    return delta


def rfc3339_duration_to_delta(duration):
    """Parse an RFC 3339 "duration" and return a datetime.timedelta

    Months are counted as 4 weeks and years as 52 weeks.  Raises
    ValueError if the string is not a valid duration.

    (The examples compare with "==" instead of showing the repr, since
    the repr of datetime.timedelta differs between Python versions.)

    >>> rfc3339_duration_to_delta("P7D") == datetime.timedelta(7)
    True
    >>> rfc3339_duration_to_delta("PT60S") == datetime.timedelta(0, 60)
    True
    >>> rfc3339_duration_to_delta("PT60M") == datetime.timedelta(0, 3600)
    True
    >>> rfc3339_duration_to_delta("P60M") == datetime.timedelta(1680)
    True
    >>> rfc3339_duration_to_delta("PT24H") == datetime.timedelta(1)
    True
    >>> rfc3339_duration_to_delta("P1W") == datetime.timedelta(7)
    True
    >>> rfc3339_duration_to_delta("PT5M30S") == datetime.timedelta(0, 330)
    True
    >>> rfc3339_duration_to_delta("P1DT3M20S") == datetime.timedelta(1, 200)
    True
    >>> # Can not be empty:
    >>> rfc3339_duration_to_delta("")
    Traceback (most recent call last):
    ...
    ValueError: Invalid RFC 3339 duration: ""
    >>> # Must start with "P":
    >>> rfc3339_duration_to_delta("1D")
    Traceback (most recent call last):
    ...
    ValueError: Invalid RFC 3339 duration: "1D"
    >>> # Must use correct order
    >>> rfc3339_duration_to_delta("PT1S2M")
    Traceback (most recent call last):
    ...
    ValueError: Invalid RFC 3339 duration: "PT1S2M"
    >>> # Time needs time marker
    >>> rfc3339_duration_to_delta("P1H2S")
    Traceback (most recent call last):
    ...
    ValueError: Invalid RFC 3339 duration: "P1H2S"
    >>> # Weeks can not be combined with anything else
    >>> rfc3339_duration_to_delta("P1D2W")
    Traceback (most recent call last):
    ...
    ValueError: Invalid RFC 3339 duration: "P1D2W"
    >>> rfc3339_duration_to_delta("P2W2H")
    Traceback (most recent call last):
    ...
    ValueError: Invalid RFC 3339 duration: "P2W2H"
    """

    # Parsing an RFC 3339 duration with regular expressions is not
    # possible - there would have to be multiple places for the same
    # values, like seconds.  The current code, while more esoteric, is
    # cleaner without depending on a parsing library.  If Python had a
    # built-in library for parsing we would use it, but we'd like to
    # avoid excessive use of external libraries.

    # New type for defining tokens, syntax, and semantics all-in-one
    Token = collections.namedtuple("Token", (
        "regexp",  # To match token; if "value" is not None, must have
                   # a "group" containing digits
        "value",   # datetime.timedelta or None
        "followers"))           # Tokens valid after this token
    # RFC 3339 "duration" tokens, syntax, and semantics; taken from
    # the "duration" ABNF definition in RFC 3339, Appendix A.
    token_end = Token(re.compile(r"$"), None, frozenset())
    token_second = Token(re.compile(r"(\d+)S"),
                         datetime.timedelta(seconds=1),
                         frozenset((token_end, )))
    token_minute = Token(re.compile(r"(\d+)M"),
                         datetime.timedelta(minutes=1),
                         frozenset((token_second, token_end)))
    token_hour = Token(re.compile(r"(\d+)H"),
                       datetime.timedelta(hours=1),
                       frozenset((token_minute, token_end)))
    token_time = Token(re.compile(r"T"),
                       None,
                       frozenset((token_hour, token_minute,
                                  token_second)))
    token_day = Token(re.compile(r"(\d+)D"),
                      datetime.timedelta(days=1),
                      frozenset((token_time, token_end)))
    token_month = Token(re.compile(r"(\d+)M"),
                        datetime.timedelta(weeks=4),
                        frozenset((token_day, token_end)))
    token_year = Token(re.compile(r"(\d+)Y"),
                       datetime.timedelta(weeks=52),
                       frozenset((token_month, token_end)))
    token_week = Token(re.compile(r"(\d+)W"),
                       datetime.timedelta(weeks=1),
                       frozenset((token_end, )))
    token_duration = Token(re.compile(r"P"), None,
                           frozenset((token_year, token_month,
                                      token_day, token_time,
                                      token_week)))
    # Define starting values:
    # Value so far
    value = datetime.timedelta()
    found_token = None
    # Following valid tokens
    followers = frozenset((token_duration, ))
    # String left to parse
    s = duration
    # Loop until end token is found
    while found_token is not token_end:
        # Search for any currently valid tokens.  Within each set of
        # followers the tokens end in different letters, so at most
        # one of them can match.
        for token in followers:
            match = token.regexp.match(s)
            if match is not None:
                # Token found
                if token.value is not None:
                    # Value found, parse digits
                    factor = int(match.group(1), 10)
                    # Add to value so far
                    value += factor * token.value
                # Strip the token from the start of the string; the
                # match is anchored at position 0, so slicing at its
                # end is enough (no need to run the regexp again)
                s = s[match.end():]
                # Go to found token
                found_token = token
                # Set valid next tokens
                followers = found_token.followers
                break
        else:
            # No currently valid tokens were found
            raise ValueError("Invalid RFC 3339 duration: \"{}\""
                             .format(duration))
    # End token found
    return value


def parse_pre_1_6_1_interval(interval):
    """Parse an interval string as documented by Mandos before 1.6.1,
    and return a datetime.timedelta

    The string is scanned for runs of digits, each optionally followed
    by a unit letter: "d" (days), "s" (seconds), "m" (minutes), "h"
    (hours) or "w" (weeks).  Digits without a unit letter count as
    milliseconds.  Any other characters are ignored, and all values
    found are added together.

    (The examples compare with "==" instead of showing the repr, since
    the repr of datetime.timedelta differs between Python versions.)

    >>> parse_pre_1_6_1_interval('7d') == datetime.timedelta(7)
    True
    >>> parse_pre_1_6_1_interval('60s') == datetime.timedelta(0, 60)
    True
    >>> parse_pre_1_6_1_interval('60m') == datetime.timedelta(0, 3600)
    True
    >>> parse_pre_1_6_1_interval('24h') == datetime.timedelta(1)
    True
    >>> parse_pre_1_6_1_interval('1w') == datetime.timedelta(7)
    True
    >>> parse_pre_1_6_1_interval('5m 30s') == datetime.timedelta(0, 330)
    True
    >>> parse_pre_1_6_1_interval('') == datetime.timedelta(0)
    True
    >>> # Ignore unknown characters, allow any order and repetitions
    >>> (parse_pre_1_6_1_interval('2dxy7zz11y3m5m')
    ...  == datetime.timedelta(2, 480, 18000))
    True

    """

    # datetime.timedelta keyword argument for each unit suffix; the
    # regexp below can only produce these suffixes
    units = {"d": "days",
             "s": "seconds",
             "m": "minutes",
             "h": "hours",
             "w": "weeks",
             "": "milliseconds"}
    regexp = re.compile(r"(\d+)([dsmhw]?)")

    value = datetime.timedelta(0)
    for num, suffix in regexp.findall(interval):
        value += datetime.timedelta(**{units[suffix]: int(num)})
    return value


def check_option_syntax(parser, options):
    """Apply additional restrictions on options, not expressible in
argparse

    The first violated rule is reported through parser.error().  The
    options object is only read, never modified.
    """

    def has_actions(options, ignore=()):
        """Return True if any action option was given, not counting
        the options whose attribute names are listed in *ignore*"""
        actions = {
            "enable": options.enable,
            "disable": options.disable,
            "bump_timeout": options.bump_timeout,
            "start_checker": options.start_checker,
            "stop_checker": options.stop_checker,
            "is_enabled": options.is_enabled,
            "remove": options.remove,
            "checker": options.checker is not None,
            "timeout": options.timeout is not None,
            "extended_timeout": options.extended_timeout is not None,
            "interval": options.interval is not None,
            "approved_by_default":
                options.approved_by_default is not None,
            "approval_delay": options.approval_delay is not None,
            "approval_duration": options.approval_duration is not None,
            "host": options.host is not None,
            "secret": options.secret is not None,
            "approve": options.approve,
            "deny": options.deny,
        }
        return any(given for name, given in actions.items()
                   if name not in ignore)

    if has_actions(options) and not (options.client or options.all):
        parser.error("Options require clients names or --all.")
    if options.verbose and has_actions(options):
        parser.error("--verbose can only be used alone.")
    if options.dump_json and (options.verbose
                              or has_actions(options)):
        parser.error("--dump-json can only be used alone.")
    if options.all and not has_actions(options):
        parser.error("--all requires an action.")
    if options.is_enabled and len(options.client) > 1:
        parser.error("--is-enabled requires exactly one client")
    # Any other action given together with --remove requires --deny.
    # Checked without modifying options, so the caller's options stay
    # intact even when parser.error() raises.
    if (options.remove and not options.deny
        and has_actions(options, ignore=("remove",))):
        parser.error("--remove can only be combined with --deny")



class dbus(object):
    """Abstract D-Bus access, independent of any D-Bus library.

    The bus classes in here rely on a subclass to provide a
    call_method(methodname, busname, objectpath, interface, *args)
    method which performs the actual D-Bus call.
    """

    class SystemBus(object):
        """Generic D-Bus helper methods built on call_method()"""

        object_manager_iface = "org.freedesktop.DBus.ObjectManager"
        properties_iface = "org.freedesktop.DBus.Properties"

        def get_managed_objects(self, busname, objectpath):
            """Call ObjectManager.GetManagedObjects on objectpath and
            return its result"""
            iface = self.object_manager_iface
            return self.call_method("GetManagedObjects", busname,
                                    objectpath, iface)

        def set_property(self, busname, objectpath, interface, key,
                         value):
            """Use Properties.Set to set property "key" of
            "interface" on objectpath to "value" """
            arguments = (interface, key, value)
            self.call_method("Set", busname, objectpath,
                             self.properties_iface, *arguments)


    class MandosBus(SystemBus):
        """Calls to the Mandos server object and its client objects"""
        busname_domain = "se.recompile"
        busname = busname_domain + ".Mandos"
        server_interface = busname_domain + ".Mandos"
        client_interface = busname_domain + ".Mandos.Client"
        server_path = "/"
        # Only used to build the names above; not a class attribute
        del busname_domain

        def get_clients_and_properties(self):
            """Return a dict mapping each client object path to the
            properties of its client interface.  Managed objects
            without the client interface are left out."""
            clients = {}
            objects = self.get_managed_objects(self.busname,
                                               self.server_path)
            for objpath, interfaces in objects.items():
                if self.client_interface in interfaces:
                    clients[objpath] = interfaces[self.client_interface]
            return clients

        def set_client_property(self, objectpath, key, value):
            """Set property "key" of the client at objectpath"""
            return self.set_property(self.busname, objectpath,
                                     self.client_interface, key,
                                     value)

        def call_client_method(self, objectpath, method, *args):
            """Call "method" on the client at objectpath"""
            return self.call_method(method, self.busname, objectpath,
                                    self.client_interface, *args)

        def call_server_method(self, method, *args):
            """Call "method" on the Mandos server object"""
            return self.call_method(method, self.busname,
                                    self.server_path,
                                    self.server_interface, *args)

    class Error(Exception):
        """Base class for errors from the D-Bus layer"""

    class ConnectFailed(Error):
        """Connecting to a D-Bus object failed"""


class dbus_python_adapter(object):
    """Implementation of the abstract "dbus" classes using the
    dbus-python library"""

    class SystemBus(dbus.MandosBus):
        """Use dbus-python"""

        def __init__(self, module=dbus_python):
            # The module is a parameter so that another object with
            # the same interface can be supplied instead
            self.dbus_python = module
            self.bus = module.SystemBus()

        @contextlib.contextmanager
        def convert_exception(self, exception_class=dbus.Error):
            """Turn dbus-python exceptions raised inside the "with"
            block into exception_class, keeping the original as the
            __cause__ (like "raise ... from" would)"""
            try:
                yield
            except self.dbus_python.exceptions.DBusException as original:
                converted = exception_class(*original.args)
                converted.__cause__ = original
                raise converted

        def call_method(self, methodname, busname, objectpath,
                        interface, *args):
            """Call methodname on the object at objectpath through
            interface, log the call at debug level, and return the
            result after type_filter()"""
            remote_object = self.get_object(busname, objectpath)
            log.debug("D-Bus: %s:%s:%s.%s(%s)", busname, objectpath,
                      interface, methodname,
                      ", ".join(repr(arg) for arg in args))
            bound_method = getattr(remote_object, methodname)
            with self.convert_exception(), \
                 dbus_python_adapter.SilenceLogger("dbus.proxies"):
                result = bound_method(*args, dbus_interface=interface)
            return self.type_filter(result)

        def get_object(self, busname, objectpath):
            """Return the dbus-python proxy for objectpath; failures
            are raised as dbus.ConnectFailed"""
            log.debug("D-Bus: Connect to: (busname=%r, path=%r)",
                      busname, objectpath)
            with self.convert_exception(dbus.ConnectFailed):
                proxy = self.bus.get_object(busname, objectpath)
            return proxy

        def type_filter(self, value):
            """Convert the most bothersome types to Python types"""
            module = self.dbus_python
            if isinstance(value, module.Boolean):
                return bool(value)
            if isinstance(value, module.ObjectPath):
                return str(value)
            if isinstance(value, module.Dictionary):
                # Convert both keys and values, recursively
                return {self.type_filter(dict_key):
                        self.type_filter(dict_value)
                        for dict_key, dict_value in value.items()}
            return value


    class SilenceLogger(object):
        "Simple context manager to silence a particular logger"

        class NullFilter(logging.Filter):
            "Filter which rejects every log record"
            def filter(self, record):
                return False

        # One shared filter instance, added on entry, removed on exit
        nullfilter = NullFilter()

        def __init__(self, loggername):
            self.logger = logging.getLogger(loggername)

        def __enter__(self):
            self.logger.addFilter(self.nullfilter)

        def __exit__(self, exc_type, exc_val, exc_tb):
            self.logger.removeFilter(self.nullfilter)


    class CachingBus(SystemBus):
        """A caching layer for dbus_python_adapter.SystemBus"""

        def __init__(self, *args, **kwargs):
            self.object_cache = {}
            super(dbus_python_adapter.CachingBus,
                  self).__init__(*args, **kwargs)

        def get_object(self, busname, objectpath):
            """Like SystemBus.get_object(), but reuse proxies fetched
            earlier for the same (busname, objectpath)"""
            cache_key = (busname, objectpath)
            try:
                return self.object_cache[cache_key]
            except KeyError:
                parent = super(dbus_python_adapter.CachingBus, self)
                proxy = parent.get_object(busname, objectpath)
                self.object_cache[cache_key] = proxy
                return proxy


def commands_from_options(options):
    """Translate parsed command line options into command objects.

    Returns a list of command instances, in the order they are to be
    run.  If no option asking for an action was given, the list holds
    a single command.PrintTable (verbose if --verbose was given).
    """

    commands = []

    # Simple on/off options, each mapping to a command class taking
    # no arguments.  The order here is the order the commands run in.
    flag_commands = (
        ("is_enabled", command.IsEnabled),
        ("approve", command.Approve),
        ("deny", command.Deny),
        ("remove", command.Remove),
        ("dump_json", command.DumpJSON),
        ("enable", command.Enable),
        ("disable", command.Disable),
        ("bump_timeout", command.BumpTimeout),
        ("start_checker", command.StartChecker),
        ("stop_checker", command.StopChecker),
    )
    for attribute, command_class in flag_commands:
        if getattr(options, attribute):
            commands.append(command_class())

    if options.approved_by_default is not None:
        if options.approved_by_default:
            commands.append(command.ApproveByDefault())
        else:
            commands.append(command.DenyByDefault())

    # Options carrying a value, passed to the command's constructor.
    # Compare with None explicitly: a zero interval such as "PT0S" is
    # a falsy datetime.timedelta, but still a value to be set.
    value_commands = (
        ("checker", command.SetChecker),
        ("host", command.SetHost),
        ("secret", command.SetSecret),
        ("timeout", command.SetTimeout),
        ("extended_timeout", command.SetExtendedTimeout),
        ("interval", command.SetInterval),
        ("approval_delay", command.SetApprovalDelay),
        ("approval_duration", command.SetApprovalDuration),
    )
    for attribute, command_class in value_commands:
        value = getattr(options, attribute)
        if value is not None:
            commands.append(command_class(value))

    # If no command option has been given, show table of clients,
    # optionally verbosely
    if not commands:
        commands.append(command.PrintTable(verbose=options.verbose))

    return commands


class command(object):
    """A namespace for command classes"""

    class Base(object):
        """Abstract base class for commands"""
        def run(self, clients, bus=None):
            """Normal commands should implement run_on_one_client(),
but commands which want to operate on all clients at the same time can
override this run() method instead.

"clients" maps D-Bus object paths to dicts of client properties.  The
bus is stored as self.bus for use by run_on_one_client().
"""
            self.bus = bus
            for client, properties in clients.items():
                self.run_on_one_client(client, properties)


    class IsEnabled(Base):
        """Exit with status 0 if the client is enabled, else 1"""
        def run(self, clients, bus=None):
            # Only the first client is examined; check_option_syntax()
            # rejects --is-enabled with more than one client name.
            # NOTE(review): raises StopIteration if "clients" is empty
            properties = next(iter(clients.values()))
            if properties["Enabled"]:
                sys.exit(0)
            sys.exit(1)


    class Approve(Base):
        """Approve any current request by the client"""
        def run_on_one_client(self, client, properties):
            self.bus.call_client_method(client, "Approve", True)


    class Deny(Base):
        """Deny any current request by the client"""
        def run_on_one_client(self, client, properties):
            self.bus.call_client_method(client, "Approve", False)


    class Remove(Base):
        """Remove the clients from the Mandos server"""
        def run(self, clients, bus):
            # Iterate over a fixed snapshot of the object paths
            for clientpath in frozenset(clients.keys()):
                bus.call_server_method("RemoveClient", clientpath)


    class Output(Base):
        """Abstract class for commands outputting client details"""
        all_keywords = ("Name", "Enabled", "Timeout", "LastCheckedOK",
                        "Created", "Interval", "Host", "KeyID",
                        "Fingerprint", "CheckerRunning",
                        "LastEnabled", "ApprovalPending",
                        "ApprovedByDefault", "LastApprovalRequest",
                        "ApprovalDelay", "ApprovalDuration",
                        "Checker", "ExtendedTimeout", "Expires",
                        "LastCheckerStatus")


    class DumpJSON(Output):
        """Print all properties of the clients as one JSON object,
        keyed by client name"""
        def run(self, clients, bus=None):
            data = {properties["Name"]:
                    {key: properties[key]
                     for key in self.all_keywords}
                    for properties in clients.values()}
            print(json.dumps(data, indent=4, separators=(',', ': ')))


    class PrintTable(Output):
        """Print a table of clients, with a few columns by default or
        with all columns if verbose"""
        def __init__(self, verbose=False):
            self.verbose = verbose

        def run(self, clients, bus=None):
            default_keywords = ("Name", "Enabled", "Timeout",
                                "LastCheckedOK")
            keywords = default_keywords
            if self.verbose:
                keywords = self.all_keywords
            print(self.TableOfClients(clients.values(), keywords))

        class TableOfClients(object):
            """Text table with one column per keyword and one row per
            client; str() returns the formatted table"""
            tableheaders = {
                "Name": "Name",
                "Enabled": "Enabled",
                "Timeout": "Timeout",
                "LastCheckedOK": "Last Successful Check",
                "LastApprovalRequest": "Last Approval Request",
                "Created": "Created",
                "Interval": "Interval",
                "Host": "Host",
                "Fingerprint": "Fingerprint",
                "KeyID": "Key ID",
                "CheckerRunning": "Check Is Running",
                "LastEnabled": "Last Enabled",
                "ApprovalPending": "Approval Is Pending",
                "ApprovedByDefault": "Approved By Default",
                "ApprovalDelay": "Approval Delay",
                "ApprovalDuration": "Approval Duration",
                "Checker": "Checker",
                "ExtendedTimeout": "Extended Timeout",
                "Expires": "Expires",
                "LastCheckerStatus": "Last Checker Status",
            }

            def __init__(self, clients, keywords):
                self.clients = clients
                self.keywords = keywords

            def __str__(self):
                return "\n".join(self.rows())

            if sys.version_info.major == 2:
                # On Python 2, "str" is unicode (set at the top of the
                # file), so __unicode__ gives text and __str__ encodes
                __unicode__ = __str__
                def __str__(self):
                    return str(self).encode(
                        locale.getpreferredencoding())

            def rows(self):
                format_string = self.row_formatting_string()
                rows = [self.header_line(format_string)]
                rows.extend(self.client_line(client, format_string)
                            for client in self.clients)
                return rows

            def row_formatting_string(self):
                "Format string used to format table rows"
                def column_width(key):
                    # Wide enough for the header and every value.  The
                    # header is always in the list, so max() gets at
                    # least one item even when there are no clients.
                    return max([len(self.tableheaders[key])]
                               + [len(self.string_from_client(client,
                                                              key))
                                  for client in self.clients])
                return " ".join("{{{key}:{width}}}".format(
                    key=key, width=column_width(key))
                                for key in self.keywords)

            def string_from_client(self, client, key):
                return self.valuetostring(client[key], key)

            @classmethod
            def valuetostring(cls, value, keyword):
                """Booleans become "Yes"/"No", millisecond durations
                are formatted by milliseconds_to_string()"""
                if isinstance(value, bool):
                    return "Yes" if value else "No"
                if keyword in ("Timeout", "Interval", "ApprovalDelay",
                               "ApprovalDuration", "ExtendedTimeout"):
                    return cls.milliseconds_to_string(value)
                return str(value)

            def header_line(self, format_string):
                return format_string.format(**self.tableheaders)

            def client_line(self, client, format_string):
                return format_string.format(
                    **{key: self.string_from_client(client, key)
                       for key in self.keywords})

            @staticmethod
            def milliseconds_to_string(ms):
                """Format a number of milliseconds as HH:MM:SS,
                prefixed by "<days>T" when at least one day"""
                td = datetime.timedelta(0, 0, 0, ms)
                return ("{days}{hours:02}:{minutes:02}:{seconds:02}"
                        .format(days="{}T".format(td.days)
                                if td.days else "",
                                hours=td.seconds // 3600,
                                minutes=(td.seconds % 3600) // 60,
                                seconds=td.seconds % 60))


    class PropertySetter(Base):
        "Abstract class for Actions for setting one client property"

        def run_on_one_client(self, client, properties=None):
            """Set the Client's D-Bus property"""
            self.bus.set_client_property(client, self.propname,
                                         self.value_to_set)

        @property
        def propname(self):
            # Subclasses override this with a class attribute
            raise NotImplementedError()


    class Enable(PropertySetter):
        propname = "Enabled"
        value_to_set = True


    class Disable(PropertySetter):
        propname = "Enabled"
        value_to_set = False


    class BumpTimeout(PropertySetter):
        propname = "LastCheckedOK"
        value_to_set = ""


    class StartChecker(PropertySetter):
        propname = "CheckerRunning"
        value_to_set = True


    class StopChecker(PropertySetter):
        propname = "CheckerRunning"
        value_to_set = False


    class ApproveByDefault(PropertySetter):
        propname = "ApprovedByDefault"
        value_to_set = True


    class DenyByDefault(PropertySetter):
        propname = "ApprovedByDefault"
        value_to_set = False


    class PropertySetterValue(PropertySetter):
        """Abstract class for PropertySetter receiving a value as
constructor argument instead of a class attribute."""
        def __init__(self, value):
            self.value_to_set = value


    class SetChecker(PropertySetterValue):
        propname = "Checker"


    class SetHost(PropertySetterValue):
        propname = "Host"


    class SetSecret(PropertySetterValue):
        propname = "Secret"

        @property
        def value_to_set(self):
            return self._vts

        @value_to_set.setter
        def value_to_set(self, value):
            """When setting, read data from supplied file object"""
            self._vts = value.read()
            value.close()


    class PropertySetterValueMilliseconds(PropertySetterValue):
        """Abstract class for PropertySetterValue taking a value
argument as a datetime.timedelta() but should store it as
milliseconds."""

        @property
        def value_to_set(self):
            return self._vts

        @value_to_set.setter
        def value_to_set(self, value):
            "When setting, convert value from a datetime.timedelta"
            self._vts = int(round(value.total_seconds() * 1000))


    class SetTimeout(PropertySetterValueMilliseconds):
        propname = "Timeout"


    class SetExtendedTimeout(PropertySetterValueMilliseconds):
        propname = "ExtendedTimeout"


    class SetInterval(PropertySetterValueMilliseconds):
        propname = "Interval"


    class SetApprovalDelay(PropertySetterValueMilliseconds):
        propname = "ApprovalDelay"


    class SetApprovalDuration(PropertySetterValueMilliseconds):
        propname = "ApprovalDuration"



class TestCaseWithAssertLogs(unittest.TestCase):
    """TestCase which provides assertLogs() also on Pythons older than
3.4, where unittest.TestCase.assertLogs does not exist"""

    if not hasattr(unittest.TestCase, "assertLogs"):
        @contextlib.contextmanager
        def assertLogs(self, logger, level=logging.INFO):
            """Capture records of at least LEVEL logged to the logger
object LOGGER inside the block, yielding a watcher with "records" and
"output" lists.  Fails unless at least one record was captured."""
            capturing_handler = self.CapturingLevelHandler(level)
            old_level = logger.level
            old_propagate = logger.propagate
            logger.addHandler(capturing_handler)
            logger.setLevel(level)
            # Keep captured records from reaching ancestor handlers
            logger.propagate = False
            try:
                yield capturing_handler.watcher
            finally:
                logger.propagate = old_propagate
                logger.removeHandler(capturing_handler)
                logger.setLevel(old_level)
            self.assertGreater(len(capturing_handler.watcher.records),
                               0)

        class CapturingLevelHandler(logging.Handler):
            """Handler which stores every record it handles, and its
formatted text, in a LoggingWatcher"""
            def __init__(self, level, *args, **kwargs):
                # Pass LEVEL on to the handler itself: records which
                # propagate from descendant loggers with a lower level
                # of their own are not filtered by the level of the
                # logger this handler is attached to.
                logging.Handler.__init__(self, level, *args, **kwargs)
                self.watcher = self.LoggingWatcher([], [])
            def emit(self, record):
                self.watcher.records.append(record)
                self.watcher.output.append(self.format(record))

            LoggingWatcher = collections.namedtuple("LoggingWatcher",
                                                    ("records",
                                                     "output"))


class Unique(object):
    """Sentinel class; every instance is a distinct object which is
only equal to itself.  Used instead of unittest.mock.sentinel, which
only exists from Python 3.3."""


class Test_string_to_delta(TestCaseWithAssertLogs):
    """Tests for string_to_delta().

Only basic RFC 3339 durations are checked here; the doc string of
rfc3339_duration_to_delta() holds more comprehensive tests, which are
run by doctest."""

    def test_rfc3339_zero_seconds(self):
        zero = datetime.timedelta()
        self.assertEqual(zero, string_to_delta("PT0S"))

    def test_rfc3339_zero_days(self):
        zero = datetime.timedelta()
        self.assertEqual(zero, string_to_delta("P0D"))

    def test_rfc3339_one_second(self):
        one_second = datetime.timedelta(seconds=1)
        self.assertEqual(one_second, string_to_delta("PT1S"))

    def test_rfc3339_two_hours(self):
        two_hours = datetime.timedelta(hours=2)
        self.assertEqual(two_hours, string_to_delta("PT2H"))

    def test_falls_back_to_pre_1_6_1_with_warning(self):
        # The old, non-RFC 3339 syntax must still work, but warn
        with self.assertLogs(log, logging.WARNING):
            parsed = string_to_delta("2h")
        self.assertEqual(datetime.timedelta(hours=2), parsed)


class Test_check_option_syntax(unittest.TestCase):
    """Tests for check_option_syntax()"""

    # Options which count as actions.  This mostly corresponds to the
    # definition from has_actions() in check_option_syntax().  The
    # values stick to the correct types, but are otherwise not
    # important, since they are never used.
    actions = {
        "enable": True,
        "disable": True,
        "bump_timeout": True,
        "start_checker": True,
        "stop_checker": True,
        "is_enabled": True,
        "remove": True,
        "checker": "x",
        "timeout": datetime.timedelta(),
        "extended_timeout": datetime.timedelta(),
        "interval": datetime.timedelta(),
        "approved_by_default": True,
        "approval_delay": datetime.timedelta(),
        "approval_duration": datetime.timedelta(),
        "host": "x",
        "secret": io.BytesIO(b"x"),
        "approve": True,
        "deny": True,
    }

    def setUp(self):
        self.parser = argparse.ArgumentParser()
        add_command_line_options(self.parser)

    def check_option_syntax(self, options):
        check_option_syntax(self.parser, options)

    def options_with_each_action(self, exclude=()):
        """Generate, for each action not in EXCLUDE, freshly parsed
default options with only that one action set."""
        for name, val in self.actions.items():
            if name in exclude:
                continue
            opts = self.parser.parse_args()
            setattr(opts, name, val)
            yield opts

    @contextlib.contextmanager
    def assertParseError(self):
        """Assert that the block makes argparse exit with an error,
hiding the message argparse prints to stderr."""
        with self.assertRaises(SystemExit) as cm:
            with self.redirect_stderr_to_devnull():
                yield
        # Exit code from argparse is guaranteed to be "2".  Reference:
        # https://docs.python.org/3/library
        # /argparse.html#exiting-methods
        self.assertEqual(2, cm.exception.code)

    @staticmethod
    @contextlib.contextmanager
    def redirect_stderr_to_devnull():
        """Temporarily point sys.stderr at os.devnull"""
        saved_stderr = sys.stderr
        with open(os.devnull, "w") as devnull:
            sys.stderr = devnull
            try:
                yield
            finally:
                sys.stderr = saved_stderr

    def test_actions_requires_client_or_all(self):
        for opts in self.options_with_each_action():
            with self.assertParseError():
                self.check_option_syntax(opts)

    def test_actions_all_conflicts_with_verbose(self):
        for opts in self.options_with_each_action():
            opts.all = True
            opts.verbose = True
            with self.assertParseError():
                self.check_option_syntax(opts)

    def test_actions_with_client_conflicts_with_verbose(self):
        for opts in self.options_with_each_action():
            opts.verbose = True
            opts.client = ["client"]
            with self.assertParseError():
                self.check_option_syntax(opts)

    def test_dump_json_conflicts_with_verbose(self):
        opts = self.parser.parse_args()
        opts.dump_json = True
        opts.verbose = True
        with self.assertParseError():
            self.check_option_syntax(opts)

    def test_dump_json_conflicts_with_action(self):
        for opts in self.options_with_each_action():
            opts.dump_json = True
            with self.assertParseError():
                self.check_option_syntax(opts)

    def test_all_can_not_be_alone(self):
        opts = self.parser.parse_args()
        opts.all = True
        with self.assertParseError():
            self.check_option_syntax(opts)

    def test_all_is_ok_with_any_action(self):
        for opts in self.options_with_each_action():
            opts.all = True
            self.check_option_syntax(opts)

    def test_any_action_is_ok_with_one_client(self):
        for opts in self.options_with_each_action():
            opts.client = ["client"]
            self.check_option_syntax(opts)

    def test_one_client_with_all_actions_except_is_enabled(self):
        opts = self.parser.parse_args()
        for name, val in self.actions.items():
            if name != "is_enabled":
                setattr(opts, name, val)
        opts.client = ["client"]
        self.check_option_syntax(opts)

    def test_two_clients_with_all_actions_except_is_enabled(self):
        opts = self.parser.parse_args()
        for name, val in self.actions.items():
            if name != "is_enabled":
                setattr(opts, name, val)
        opts.client = ["client1", "client2"]
        self.check_option_syntax(opts)

    def test_two_clients_are_ok_with_actions_except_is_enabled(self):
        for opts in self.options_with_each_action(
                exclude={"is_enabled"}):
            opts.client = ["client1", "client2"]
            self.check_option_syntax(opts)

    def test_is_enabled_fails_without_client(self):
        opts = self.parser.parse_args()
        opts.is_enabled = True
        with self.assertParseError():
            self.check_option_syntax(opts)

    def test_is_enabled_fails_with_two_clients(self):
        opts = self.parser.parse_args()
        opts.is_enabled = True
        opts.client = ["client1", "client2"]
        with self.assertParseError():
            self.check_option_syntax(opts)

    def test_remove_can_only_be_combined_with_action_deny(self):
        for opts in self.options_with_each_action(
                exclude={"remove", "deny"}):
            opts.all = True
            opts.remove = True
            with self.assertParseError():
                self.check_option_syntax(opts)


class Test_dbus_exceptions(unittest.TestCase):
    """Tests for the exception classes in the dbus namespace"""

    def test_dbus_ConnectFailed_is_Error(self):
        def raise_connect_failed():
            raise dbus.ConnectFailed()
        self.assertRaises(dbus.Error, raise_connect_failed)


class Test_dbus_MandosBus(unittest.TestCase):
    """Tests for dbus.MandosBus, with call_method() replaced by a
mock which records its arguments"""

    class MockMandosBus(dbus.MandosBus):
        """MandosBus whose call_method() only records each call and
returns call_method_return.  dbus.MandosBus.__init__() is not
called; the bus name, paths and interfaces are set directly."""
        def __init__(self):
            self._name = "se.recompile.Mandos"
            self._server_path = "/"
            self._server_interface = "se.recompile.Mandos"
            self._client_interface = "se.recompile.Mandos.Client"
            self.calls = []
            self.call_method_return = Unique()

        def call_method(self, methodname, busname, objectpath,
                        interface, *args):
            recorded = (methodname, busname, objectpath, interface,
                        args)
            self.calls.append(recorded)
            return self.call_method_return

    def setUp(self):
        self.bus = self.MockMandosBus()

    def assert_called(self, methodname, objectpath, interface,
                      args=()):
        """Assert that the mock bus recorded a call of METHODNAME on
the Mandos bus name with these arguments"""
        expected = (methodname, self.bus._name, objectpath, interface,
                    tuple(args))
        self.assertIn(expected, self.bus.calls)

    def test_set_client_property(self):
        self.bus.set_client_property("objectpath", "key", "value")
        self.assert_called("Set", "objectpath",
                           "org.freedesktop.DBus.Properties",
                           (self.bus._client_interface, "key",
                            "value"))

    def test_call_client_method(self):
        result = self.bus.call_client_method("objectpath",
                                             "methodname")
        self.assertIs(self.bus.call_method_return, result)
        self.assert_called("methodname", "objectpath",
                           self.bus._client_interface)

    def test_call_client_method_with_args(self):
        extra = (Unique(), Unique())
        result = self.bus.call_client_method("objectpath",
                                             "methodname", *extra)
        self.assertIs(self.bus.call_method_return, result)
        self.assert_called("methodname", "objectpath",
                           self.bus._client_interface, extra)

    def test_get_clients_and_properties(self):
        client_iface = self.bus._client_interface
        self.bus.call_method_return = {
            "objectpath": {
                client_iface: {
                    "key": "value",
                    "bool": True,
                },
                "irrelevant_interface": {
                    "key": "othervalue",
                    "bool": False,
                },
            },
            "other_objectpath": {
                "other_irrelevant_interface": {
                    "key": "value 3",
                    "bool": None,
                },
            },
        }
        result = self.bus.get_clients_and_properties()
        # Only objects with the client interface should remain, and
        # only with that interface's properties
        self.assertDictEqual({"objectpath": {"key": "value",
                                             "bool": True}},
                             result)
        self.assert_called("GetManagedObjects",
                           self.bus._server_path,
                           "org.freedesktop.DBus.ObjectManager")

    def test_call_server_method(self):
        result = self.bus.call_server_method("methodname")
        self.assertIs(self.bus.call_method_return, result)
        self.assert_called("methodname", self.bus._server_path,
                           self.bus._server_interface)

    def test_call_server_method_with_args(self):
        extra = (Unique(), Unique())
        result = self.bus.call_server_method("methodname", *extra)
        self.assertIs(self.bus.call_method_return, result)
        self.assert_called("methodname", self.bus._server_path,
                           self.bus._server_interface, extra)


class Test_dbus_python_adapter_SystemBus(TestCaseWithAssertLogs):
    """Tests for dbus_python_adapter.SystemBus, using mock dbus-python
modules"""

    def MockDBusPython_func(self, func):
        """Return a mock dbus-python module.  Objects from its
SystemBus have a single method, "methodname", which asserts that its
only keyword argument is dbus_interface="interface" and then returns
FUNC called with the positional arguments."""
        class mock_dbus_python(object):
            """mock dbus-python module"""
            class exceptions(object):
                """Pseudo-namespace"""
                class DBusException(Exception):
                    pass
            class SystemBus(object):
                @staticmethod
                def get_object(busname, objectpath):
                    DBusObject = collections.namedtuple(
                        "DBusObject", ("methodname",))
                    def method(*args, **kwargs):
                        self.assertEqual({"dbus_interface":
                                          "interface"},
                                         kwargs)
                        return func(*args)
                    return DBusObject(methodname=method)
            class Boolean(object):
                def __init__(self, value):
                    self.value = bool(value)
                def __bool__(self):
                    return self.value
                if sys.version_info.major == 2:
                    __nonzero__ = __bool__
            class ObjectPath(str):
                pass
            class Dictionary(dict):
                pass
        return mock_dbus_python

    def call_method(self, bus, methodname, busname, objectpath,
                    interface, *args):
        """Call BUS.call_method(), asserting that at least one message
of level DEBUG or above is logged to log while doing so"""
        with self.assertLogs(log, logging.DEBUG):
            return bus.call_method(methodname, busname, objectpath,
                                   interface, *args)

    def test_call_method_returns(self):
        expected_method_return = Unique()
        method_args = (Unique(), Unique())
        def func(*args):
            self.assertEqual(len(method_args), len(args))
            for marg, arg in zip(method_args, args):
                self.assertIs(marg, arg)
            return expected_method_return
        mock_dbus_python = self.MockDBusPython_func(func)
        bus = dbus_python_adapter.SystemBus(mock_dbus_python)
        ret = self.call_method(bus, "methodname", "busname",
                               "objectpath", "interface",
                               *method_args)
        self.assertIs(ret, expected_method_return)

    def test_call_method_filters_bool_true(self):
        def func():
            return method_return
        mock_dbus_python = self.MockDBusPython_func(func)
        bus = dbus_python_adapter.SystemBus(mock_dbus_python)
        method_return = mock_dbus_python.Boolean(True)
        ret = self.call_method(bus, "methodname", "busname",
                               "objectpath", "interface")
        self.assertTrue(ret)
        self.assertNotIsInstance(ret, mock_dbus_python.Boolean)

    def test_call_method_filters_bool_false(self):
        def func():
            return method_return
        mock_dbus_python = self.MockDBusPython_func(func)
        bus = dbus_python_adapter.SystemBus(mock_dbus_python)
        method_return = mock_dbus_python.Boolean(False)
        ret = self.call_method(bus, "methodname", "busname",
                               "objectpath", "interface")
        self.assertFalse(ret)
        self.assertNotIsInstance(ret, mock_dbus_python.Boolean)

    def test_call_method_filters_objectpath(self):
        def func():
            return method_return
        mock_dbus_python = self.MockDBusPython_func(func)
        bus = dbus_python_adapter.SystemBus(mock_dbus_python)
        method_return = mock_dbus_python.ObjectPath("objectpath")
        ret = self.call_method(bus, "methodname", "busname",
                               "objectpath", "interface")
        self.assertEqual("objectpath", ret)
        self.assertIsNot("objectpath", ret)
        self.assertNotIsInstance(ret, mock_dbus_python.ObjectPath)

    def test_call_method_filters_booleans_in_dict(self):
        def func():
            return method_return
        mock_dbus_python = self.MockDBusPython_func(func)
        bus = dbus_python_adapter.SystemBus(mock_dbus_python)
        method_return = mock_dbus_python.Dictionary(
        {mock_dbus_python.Boolean(True):
         mock_dbus_python.Boolean(False),
         mock_dbus_python.Boolean(False):
         mock_dbus_python.Boolean(True)})
        ret = self.call_method(bus, "methodname", "busname",
                               "objectpath", "interface")
        expected_method_return = {True: False,
                                  False: True}
        self.assertEqual(expected_method_return, ret)
        self.assertNotIsInstance(ret, mock_dbus_python.Dictionary)

    def test_call_method_filters_objectpaths_in_dict(self):
        def func():
            return method_return
        mock_dbus_python = self.MockDBusPython_func(func)
        bus = dbus_python_adapter.SystemBus(mock_dbus_python)
        method_return = mock_dbus_python.Dictionary(
        {mock_dbus_python.ObjectPath("objectpath_key_1"):
         mock_dbus_python.ObjectPath("objectpath_value_1"),
         mock_dbus_python.ObjectPath("objectpath_key_2"):
         mock_dbus_python.ObjectPath("objectpath_value_2")})
        ret = self.call_method(bus, "methodname", "busname",
                               "objectpath", "interface")
        expected_method_return = {str(key): str(value)
                                  for key, value in
                                  method_return.items()}
        self.assertEqual(expected_method_return, ret)
        self.assertIsInstance(ret, dict)
        self.assertNotIsInstance(ret, mock_dbus_python.Dictionary)

    def test_call_method_filters_dict_in_dict(self):
        def func():
            return method_return
        mock_dbus_python = self.MockDBusPython_func(func)
        bus = dbus_python_adapter.SystemBus(mock_dbus_python)
        method_return = mock_dbus_python.Dictionary(
        {"key1": mock_dbus_python.Dictionary({"key11": "value11",
                                              "key12": "value12"}),
         "key2": mock_dbus_python.Dictionary({"key21": "value21",
                                              "key22": "value22"})})
        ret = self.call_method(bus, "methodname", "busname",
                               "objectpath", "interface")
        expected_method_return = {
            "key1": {"key11": "value11",
                     "key12": "value12"},
            "key2": {"key21": "value21",
                     "key22": "value22"},
        }
        self.assertEqual(expected_method_return, ret)
        self.assertIsInstance(ret, dict)
        self.assertNotIsInstance(ret, mock_dbus_python.Dictionary)
        for key, value in ret.items():
            self.assertIsInstance(value, dict)
            self.assertEqual(expected_method_return[key], value)
            self.assertNotIsInstance(value,
                                     mock_dbus_python.Dictionary)

    def test_call_method_filters_dict_three_deep(self):
        def func():
            return method_return
        mock_dbus_python = self.MockDBusPython_func(func)
        bus = dbus_python_adapter.SystemBus(mock_dbus_python)
        method_return = mock_dbus_python.Dictionary(
            {"key1":
             mock_dbus_python.Dictionary(
                 {"key2":
                  mock_dbus_python.Dictionary(
                      {"key3":
                       mock_dbus_python.Boolean(True),
                       }),
                  }),
             })
        ret = self.call_method(bus, "methodname", "busname",
                               "objectpath", "interface")
        expected_method_return = {"key1": {"key2": {"key3": True}}}
        self.assertEqual(expected_method_return, ret)
        self.assertIsInstance(ret, dict)
        self.assertNotIsInstance(ret, mock_dbus_python.Dictionary)
        self.assertIsInstance(ret["key1"], dict)
        self.assertNotIsInstance(ret["key1"],
                                 mock_dbus_python.Dictionary)
        self.assertIsInstance(ret["key1"]["key2"], dict)
        self.assertNotIsInstance(ret["key1"]["key2"],
                                 mock_dbus_python.Dictionary)
        self.assertTrue(ret["key1"]["key2"]["key3"])
        self.assertNotIsInstance(ret["key1"]["key2"]["key3"],
                                 mock_dbus_python.Boolean)

    def test_call_method_handles_exception(self):
        dbus_logger = logging.getLogger("dbus.proxies")

        def func():
            dbus_logger.error("Test")
            raise mock_dbus_python.exceptions.DBusException()

        mock_dbus_python = self.MockDBusPython_func(func)
        bus = dbus_python_adapter.SystemBus(mock_dbus_python)

        class CountingHandler(logging.Handler):
            count = 0
            def emit(self, record):
                self.count += 1

        counting_handler = CountingHandler()

        dbus_logger.addHandler(counting_handler)

        try:
            with self.assertRaises(dbus.Error) as e:
                self.call_method(bus, "methodname", "busname",
                                 "objectpath", "interface")
        finally:
            # Detach the handler again, so it does not stay on the
            # "dbus.proxies" logger after this test
            dbus_logger.removeHandler(counting_handler)

        # "e" is the assertRaises context; the raised exception itself
        # is in e.exception
        self.assertNotIsInstance(e.exception, dbus.ConnectFailed)

        # Make sure the dbus logger was suppressed
        self.assertEqual(0, counting_handler.count)

    def test_get_object_converts_to_correct_exception(self):
        bus = dbus_python_adapter.SystemBus(
            self.fake_dbus_python_raises_exception_on_connect)
        with self.assertRaises(dbus.ConnectFailed):
            self.call_method(bus, "methodname", "busname",
                             "objectpath", "interface")

    class fake_dbus_python_raises_exception_on_connect(object):
        """fake dbus-python module whose SystemBus().get_object()
always raises DBusException"""
        class exceptions(object):
            """Pseudo-namespace"""
            class DBusException(Exception):
                pass

        @classmethod
        def SystemBus(cls):
            def get_object(busname, objectpath):
                raise cls.exceptions.DBusException()
            Bus = collections.namedtuple("Bus", ["get_object"])
            return Bus(get_object=get_object)


class Test_dbus_python_adapter_CachingBus(unittest.TestCase):
    """Tests for dbus_python_adapter.CachingBus: get_object() must
return the same object for the same (busname, objectpath), and
distinct objects otherwise."""
    class mock_dbus_python(object):
        """mock dbus-python modules"""
        class SystemBus(object):
            @staticmethod
            def get_object(busname, objectpath):
                # A new object on every call, so any identity between
                # returned objects must come from the cache
                return Unique()

    def setUp(self):
        self.bus = dbus_python_adapter.CachingBus(
            self.mock_dbus_python)

    def test_returns_distinct_objectpaths(self):
        obj1 = self.bus.get_object("busname", "objectpath1")
        self.assertIsInstance(obj1, Unique)
        obj2 = self.bus.get_object("busname", "objectpath2")
        self.assertIsInstance(obj2, Unique)
        self.assertIsNot(obj1, obj2)

    def test_returns_distinct_busnames(self):
        obj1 = self.bus.get_object("busname1", "objectpath")
        self.assertIsInstance(obj1, Unique)
        obj2 = self.bus.get_object("busname2", "objectpath")
        self.assertIsInstance(obj2, Unique)
        self.assertIsNot(obj1, obj2)

    def test_returns_distinct_both(self):
        # Both busname and objectpath differ (previously this was an
        # exact copy of test_returns_distinct_busnames)
        obj1 = self.bus.get_object("busname1", "objectpath1")
        self.assertIsInstance(obj1, Unique)
        obj2 = self.bus.get_object("busname2", "objectpath2")
        self.assertIsInstance(obj2, Unique)
        self.assertIsNot(obj1, obj2)

    def test_returns_same(self):
        obj1 = self.bus.get_object("busname", "objectpath")
        self.assertIsInstance(obj1, Unique)
        obj2 = self.bus.get_object("busname", "objectpath")
        self.assertIsInstance(obj2, Unique)
        self.assertIs(obj1, obj2)

    def test_returns_same_old(self):
        obj1 = self.bus.get_object("busname1", "objectpath1")
        self.assertIsInstance(obj1, Unique)
        obj2 = self.bus.get_object("busname2", "objectpath2")
        self.assertIsInstance(obj2, Unique)
        obj1b = self.bus.get_object("busname1", "objectpath1")
        self.assertIsInstance(obj1b, Unique)
        self.assertIsNot(obj1, obj2)
        self.assertIsNot(obj2, obj1b)
        self.assertIs(obj1, obj1b)


class Test_commands_from_options(unittest.TestCase):
    """Tests for commands_from_options(), using command line arguments
parsed by a parser set up by add_command_line_options()"""

    def setUp(self):
        self.parser = argparse.ArgumentParser()
        add_command_line_options(self.parser)

    def commands_from_args(self, args):
        """Parse ARGS, check their syntax, and return the commands
commands_from_options() makes from them"""
        opts = self.parser.parse_args(args)
        check_option_syntax(self.parser, opts)
        return commands_from_options(opts)

    def assert_command_from_args(self, args, command_cls,
                                 **cmd_attrs):
        """Assert that parsing ARGS should result in an instance of
COMMAND_CLS with (optionally) all supplied attributes (CMD_ATTRS)."""
        cmds = self.commands_from_args(args)
        self.assertEqual(1, len(cmds))
        cmd = cmds[0]
        self.assertIsInstance(cmd, command_cls)
        for attr_name, expected in cmd_attrs.items():
            self.assertEqual(expected, getattr(cmd, attr_name))

    def assert_deny_then_remove(self, args):
        """Assert that ARGS result in exactly a Deny command followed
by a Remove command"""
        cmds = self.commands_from_args(args)
        self.assertEqual(2, len(cmds))
        self.assertIsInstance(cmds[0], command.Deny)
        self.assertIsInstance(cmds[1], command.Remove)

    def assert_secret_from_tempfile(self, option):
        """Assert that OPTION with a temporary file name results in a
SetSecret command whose value is the file contents"""
        content = b"secret\0xyzzy\nbar"
        with tempfile.NamedTemporaryFile(mode="r+b") as tmp:
            tmp.write(content)
            tmp.seek(0)
            self.assert_command_from_args([option, tmp.name, "client"],
                                          command.SetSecret,
                                          value_to_set=content)

    def test_is_enabled(self):
        self.assert_command_from_args(["--is-enabled", "client"],
                                      command.IsEnabled)

    def test_is_enabled_short(self):
        self.assert_command_from_args(["-V", "client"],
                                      command.IsEnabled)

    def test_approve(self):
        self.assert_command_from_args(["--approve", "client"],
                                      command.Approve)

    def test_approve_short(self):
        self.assert_command_from_args(["-A", "client"],
                                      command.Approve)

    def test_deny(self):
        self.assert_command_from_args(["--deny", "client"],
                                      command.Deny)

    def test_deny_short(self):
        self.assert_command_from_args(["-D", "client"], command.Deny)

    def test_remove(self):
        self.assert_command_from_args(["--remove", "client"],
                                      command.Remove)

    def test_deny_before_remove(self):
        self.assert_deny_then_remove(["--deny", "--remove", "client"])

    def test_deny_before_remove_reversed(self):
        # Deny must come first even if --remove is given first
        self.assert_deny_then_remove(["--remove", "--deny", "--all"])

    def test_remove_short(self):
        self.assert_command_from_args(["-r", "client"],
                                      command.Remove)

    def test_dump_json(self):
        self.assert_command_from_args(["--dump-json"],
                                      command.DumpJSON)

    def test_enable(self):
        self.assert_command_from_args(["--enable", "client"],
                                      command.Enable)

    def test_enable_short(self):
        self.assert_command_from_args(["-e", "client"],
                                      command.Enable)

    def test_disable(self):
        self.assert_command_from_args(["--disable", "client"],
                                      command.Disable)

    def test_disable_short(self):
        self.assert_command_from_args(["-d", "client"],
                                      command.Disable)

    def test_bump_timeout(self):
        self.assert_command_from_args(["--bump-timeout", "client"],
                                      command.BumpTimeout)

    def test_bump_timeout_short(self):
        self.assert_command_from_args(["-b", "client"],
                                      command.BumpTimeout)

    def test_start_checker(self):
        self.assert_command_from_args(["--start-checker", "client"],
                                      command.StartChecker)

    def test_stop_checker(self):
        self.assert_command_from_args(["--stop-checker", "client"],
                                      command.StopChecker)

    def test_approve_by_default(self):
        self.assert_command_from_args(["--approve-by-default",
                                       "client"],
                                      command.ApproveByDefault)

    def test_deny_by_default(self):
        self.assert_command_from_args(["--deny-by-default", "client"],
                                      command.DenyByDefault)

    def test_checker(self):
        self.assert_command_from_args(["--checker", ":", "client"],
                                      command.SetChecker,
                                      value_to_set=":")

    def test_checker_empty(self):
        self.assert_command_from_args(["--checker", "", "client"],
                                      command.SetChecker,
                                      value_to_set="")

    def test_checker_short(self):
        self.assert_command_from_args(["-c", ":", "client"],
                                      command.SetChecker,
                                      value_to_set=":")

    def test_host(self):
        self.assert_command_from_args(
            ["--host", "client.example.org", "client"],
            command.SetHost, value_to_set="client.example.org")

    def test_host_short(self):
        self.assert_command_from_args(
            ["-H", "client.example.org", "client"], command.SetHost,
            value_to_set="client.example.org")

    def test_secret_devnull(self):
        self.assert_command_from_args(["--secret", os.path.devnull,
                                       "client"], command.SetSecret,
                                      value_to_set=b"")

    def test_secret_tempfile(self):
        self.assert_secret_from_tempfile("--secret")

    def test_secret_devnull_short(self):
        self.assert_command_from_args(["-s", os.path.devnull,
                                       "client"], command.SetSecret,
                                      value_to_set=b"")

    def test_secret_tempfile_short(self):
        self.assert_secret_from_tempfile("-s")

    def test_timeout(self):
        self.assert_command_from_args(["--timeout", "PT5M", "client"],
                                      command.SetTimeout,
                                      value_to_set=300000)

    def test_timeout_short(self):
        self.assert_command_from_args(["-t", "PT5M", "client"],
                                      command.SetTimeout,
                                      value_to_set=300000)

    def test_extended_timeout(self):
        self.assert_command_from_args(["--extended-timeout", "PT15M",
                                       "client"],
                                      command.SetExtendedTimeout,
                                      value_to_set=900000)

    def test_interval(self):
        self.assert_command_from_args(["--interval", "PT2M",
                                       "client"], command.SetInterval,
                                      value_to_set=120000)

    def test_interval_short(self):
        self.assert_command_from_args(["-i", "PT2M", "client"],
                                      command.SetInterval,
                                      value_to_set=120000)

    def test_approval_delay(self):
        self.assert_command_from_args(["--approval-delay", "PT30S",
                                       "client"],
                                      command.SetApprovalDelay,
                                      value_to_set=30000)

    def test_approval_duration(self):
        self.assert_command_from_args(["--approval-duration", "PT1S",
                                       "client"],
                                      command.SetApprovalDuration,
                                      value_to_set=1000)

    def test_print_table(self):
        self.assert_command_from_args([], command.PrintTable,
                                      verbose=False)

    def test_print_table_verbose(self):
        self.assert_command_from_args(["--verbose"],
                                      command.PrintTable,
                                      verbose=True)

    def test_print_table_verbose_short(self):
        self.assert_command_from_args(["-v"], command.PrintTable,
                                      verbose=True)


class TestCommand(unittest.TestCase):
    """Abstract class for tests of command classes

    setUp() gives every test a fresh FakeMandosBus in self.bus.
    """

    class FakeMandosBus(dbus.MandosBus):
        """In-memory stand-in for dbus.MandosBus with two clients.

        Each call_method() invocation is appended to self.calls, and
        its effect (property Set, RemoveClient, Approve) is applied to
        the in-memory client property dicts.  Any other call raises
        ValueError.
        """
        def __init__(self, testcase):
            self.testcase = testcase
            self.calls = []
            self.client_properties = {
                "Name": "foo",
                "KeyID": ("92ed150794387c03ce684574b1139a65"
                          "94a34f895daaaf09fd8ea90a27cddb12"),
                "Secret": b"secret",
                "Host": "foo.example.org",
                "Enabled": True,
                "Timeout": 300000,
                "LastCheckedOK": "2019-02-03T00:00:00",
                "Created": "2019-01-02T00:00:00",
                "Interval": 120000,
                "Fingerprint": ("778827225BA7DE539C5A"
                                "7CFA59CFF7CDBD9A5920"),
                "CheckerRunning": False,
                "LastEnabled": "2019-01-03T00:00:00",
                "ApprovalPending": False,
                "ApprovedByDefault": True,
                "LastApprovalRequest": "",
                "ApprovalDelay": 0,
                "ApprovalDuration": 1000,
                "Checker": "fping -q -- %(host)s",
                "ExtendedTimeout": 900000,
                "Expires": "2019-02-04T00:00:00",
                "LastCheckerStatus": 0,
            }
            self.other_client_properties = {
                "Name": "barbar",
                "KeyID": ("0558568eedd67d622f5c83b35a115f79"
                          "6ab612cff5ad227247e46c2b020f441c"),
                "Secret": b"secretbar",
                "Host": "192.0.2.3",
                "Enabled": True,
                "Timeout": 300000,
                "LastCheckedOK": "2019-02-04T00:00:00",
                "Created": "2019-01-03T00:00:00",
                "Interval": 120000,
                "Fingerprint": ("3E393AEAEFB84C7E89E2"
                                "F547B3A107558FCA3A27"),
                "CheckerRunning": True,
                "LastEnabled": "2019-01-04T00:00:00",
                "ApprovalPending": False,
                "ApprovedByDefault": False,
                "LastApprovalRequest": "2019-01-03T00:00:00",
                "ApprovalDelay": 30000,
                "ApprovalDuration": 93785000,
                "Checker": ":",
                "ExtendedTimeout": 900000,
                "Expires": "2019-02-05T00:00:00",
                "LastCheckerStatus": -2,
            }
            # Object path -> property dict, in a fixed order
            self.clients = collections.OrderedDict((
                ("client_objectpath", self.client_properties),
                ("other_client_objectpath",
                 self.other_client_properties),
            ))
            self.one_client = {
                "client_objectpath": self.client_properties,
            }

        def call_method(self, methodname, busname, objectpath,
                        interface, *args):
            """Record a D-Bus method call and emulate its effect."""
            self.testcase.assertEqual("se.recompile.Mandos", busname)
            self.calls.append((methodname, busname, objectpath,
                               interface, args))
            if (interface == "org.freedesktop.DBus.Properties"
                and methodname == "Set"):
                self.testcase.assertEqual(3, len(args))
                prop_interface, prop_name, prop_value = args
                self.testcase.assertEqual(
                    "se.recompile.Mandos.Client", prop_interface)
                self.clients[objectpath][prop_name] = prop_value
                return
            if interface == "se.recompile.Mandos":
                self.testcase.assertEqual("RemoveClient", methodname)
                self.testcase.assertEqual(1, len(args))
                del self.clients[args[0]]
                return
            if (interface == "se.recompile.Mandos.Client"
                and methodname == "Approve"):
                self.testcase.assertEqual(1, len(args))
                return
            # Anything else is a call this fake does not know about
            raise ValueError()

    def setUp(self):
        self.bus = self.FakeMandosBus(self)

class TestBaseCommands(TestCommand):
    """Tests of the IsEnabled, Approve, Deny, Remove, DumpJSON and
    PrintTable command classes, run against the FakeMandosBus."""

    def test_IsEnabled_exits_successfully(self):
        with self.assertRaises(SystemExit) as e:
            command.IsEnabled().run(self.bus.one_client)
        # Both sys.exit() (code None) and sys.exit(0) mean success
        self.assertIn(e.exception.code, (0, None))

    def test_IsEnabled_exits_with_failure(self):
        self.bus.client_properties["Enabled"] = False
        with self.assertRaises(SystemExit) as e:
            command.IsEnabled().run(self.bus.one_client)
        if isinstance(e.exception.code, int):
            self.assertNotEqual(0, e.exception.code)
        else:
            # A non-integer exit code (e.g. a message) is a failure
            # as long as it is not None
            self.assertIsNotNone(e.exception.code)

    def test_Approve(self):
        busname = "se.recompile.Mandos"
        client_interface = "se.recompile.Mandos.Client"
        command.Approve().run(self.bus.clients, self.bus)
        for clientpath in self.bus.clients:
            self.assertIn(("Approve", busname, clientpath,
                           client_interface, (True,)), self.bus.calls)

    def test_Deny(self):
        busname = "se.recompile.Mandos"
        client_interface = "se.recompile.Mandos.Client"
        command.Deny().run(self.bus.clients, self.bus)
        for clientpath in self.bus.clients:
            self.assertIn(("Approve", busname, clientpath,
                           client_interface, (False,)),
                          self.bus.calls)

    def test_Remove(self):
        # The fake bus deletes entries from self.bus.clients when it
        # gets RemoveClient, so iterating self.bus.clients afterwards
        # would check nothing.  Snapshot the clients first, and run
        # the command on the snapshot so that the dict being iterated
        # is not the one being mutated.
        orig_clients = self.bus.clients.copy()
        command.Remove().run(orig_clients, self.bus)
        self.assertFalse(self.bus.clients)
        for clientpath in orig_clients:
            self.assertIn(("RemoveClient", dbus_busname,
                           dbus_server_path, dbus_server_interface,
                           (clientpath,)), self.bus.calls)

    expected_json = {
        "foo": {
            "Name": "foo",
            "KeyID": ("92ed150794387c03ce684574b1139a65"
                      "94a34f895daaaf09fd8ea90a27cddb12"),
            "Host": "foo.example.org",
            "Enabled": True,
            "Timeout": 300000,
            "LastCheckedOK": "2019-02-03T00:00:00",
            "Created": "2019-01-02T00:00:00",
            "Interval": 120000,
            "Fingerprint": ("778827225BA7DE539C5A"
                            "7CFA59CFF7CDBD9A5920"),
            "CheckerRunning": False,
            "LastEnabled": "2019-01-03T00:00:00",
            "ApprovalPending": False,
            "ApprovedByDefault": True,
            "LastApprovalRequest": "",
            "ApprovalDelay": 0,
            "ApprovalDuration": 1000,
            "Checker": "fping -q -- %(host)s",
            "ExtendedTimeout": 900000,
            "Expires": "2019-02-04T00:00:00",
            "LastCheckerStatus": 0,
        },
        "barbar": {
            "Name": "barbar",
            "KeyID": ("0558568eedd67d622f5c83b35a115f79"
                      "6ab612cff5ad227247e46c2b020f441c"),
            "Host": "192.0.2.3",
            "Enabled": True,
            "Timeout": 300000,
            "LastCheckedOK": "2019-02-04T00:00:00",
            "Created": "2019-01-03T00:00:00",
            "Interval": 120000,
            "Fingerprint": ("3E393AEAEFB84C7E89E2"
                            "F547B3A107558FCA3A27"),
            "CheckerRunning": True,
            "LastEnabled": "2019-01-04T00:00:00",
            "ApprovalPending": False,
            "ApprovedByDefault": False,
            "LastApprovalRequest": "2019-01-03T00:00:00",
            "ApprovalDelay": 30000,
            "ApprovalDuration": 93785000,
            "Checker": ":",
            "ExtendedTimeout": 900000,
            "Expires": "2019-02-05T00:00:00",
            "LastCheckerStatus": -2,
        },
    }

    def test_DumpJSON_normal(self):
        with self.capture_stdout_to_buffer() as buffer:
            command.DumpJSON().run(self.bus.clients)
        json_data = json.loads(buffer.getvalue())
        self.assertDictEqual(self.expected_json, json_data)

    @staticmethod
    @contextlib.contextmanager
    def capture_stdout_to_buffer():
        """Temporarily replace sys.stdout with a StringIO, yielding it."""
        capture_buffer = io.StringIO()
        old_stdout = sys.stdout
        sys.stdout = capture_buffer
        try:
            yield capture_buffer
        finally:
            sys.stdout = old_stdout

    def test_DumpJSON_one_client(self):
        with self.capture_stdout_to_buffer() as buffer:
            command.DumpJSON().run(self.bus.one_client)
        json_data = json.loads(buffer.getvalue())
        expected_json = {"foo": self.expected_json["foo"]}
        self.assertDictEqual(expected_json, json_data)

    def test_PrintTable_normal(self):
        with self.capture_stdout_to_buffer() as buffer:
            command.PrintTable().run(self.bus.clients)
        expected_output = "\n".join((
            "Name   Enabled Timeout  Last Successful Check",
            "foo    Yes     00:05:00 2019-02-03T00:00:00  ",
            "barbar Yes     00:05:00 2019-02-04T00:00:00  ",
        )) + "\n"
        self.assertEqual(expected_output, buffer.getvalue())

    def test_PrintTable_verbose(self):
        with self.capture_stdout_to_buffer() as buffer:
            command.PrintTable(verbose=True).run(self.bus.clients)
        # Each tuple is one column: header, then one cell per client;
        # the expected output is these columns laid side by side.
        columns = (
            (
                "Name   ",
                "foo    ",
                "barbar ",
            ),(
                "Enabled ",
                "Yes     ",
                "Yes     ",
            ),(
                "Timeout  ",
                "00:05:00 ",
                "00:05:00 ",
            ),(
                "Last Successful Check ",
                "2019-02-03T00:00:00   ",
                "2019-02-04T00:00:00   ",
            ),(
                "Created             ",
                "2019-01-02T00:00:00 ",
                "2019-01-03T00:00:00 ",
            ),(
                "Interval ",
                "00:02:00 ",
                "00:02:00 ",
            ),(
                "Host            ",
                "foo.example.org ",
                "192.0.2.3       ",
            ),(
                ("Key ID                                             "
                 "              "),
                ("92ed150794387c03ce684574b1139a6594a34f895daaaf09fd8"
                 "ea90a27cddb12 "),
                ("0558568eedd67d622f5c83b35a115f796ab612cff5ad227247e"
                 "46c2b020f441c "),
            ),(
                "Fingerprint                              ",
                "778827225BA7DE539C5A7CFA59CFF7CDBD9A5920 ",
                "3E393AEAEFB84C7E89E2F547B3A107558FCA3A27 ",
            ),(
                "Check Is Running ",
                "No               ",
                "Yes              ",
            ),(
                "Last Enabled        ",
                "2019-01-03T00:00:00 ",
                "2019-01-04T00:00:00 ",
            ),(
                "Approval Is Pending ",
                "No                  ",
                "No                  ",
            ),(
                "Approved By Default ",
                "Yes                 ",
                "No                  ",
            ),(
                "Last Approval Request ",
                "                      ",
                "2019-01-03T00:00:00   ",
            ),(
                "Approval Delay ",
                "00:00:00       ",
                "00:00:30       ",
            ),(
                "Approval Duration ",
                "00:00:01          ",
                "1T02:03:05        ",
            ),(
                "Checker              ",
                "fping -q -- %(host)s ",
                ":                    ",
            ),(
                "Extended Timeout ",
                "00:15:00         ",
                "00:15:00         ",
            ),(
                "Expires             ",
                "2019-02-04T00:00:00 ",
                "2019-02-05T00:00:00 ",
            ),(
                "Last Checker Status",
                "0                  ",
                "-2                 ",
            )
        )
        num_lines = max(len(rows) for rows in columns)
        expected_output = ("\n".join("".join(rows[line]
                                             for rows in columns)
                                     for line in range(num_lines))
                           + "\n")
        self.assertEqual(expected_output, buffer.getvalue())

    def test_PrintTable_one_client(self):
        with self.capture_stdout_to_buffer() as buffer:
            command.PrintTable().run(self.bus.one_client)
        expected_output = "\n".join((
            "Name Enabled Timeout  Last Successful Check",
            "foo  Yes     00:05:00 2019-02-03T00:00:00  ",
        )) + "\n"
        self.assertEqual(expected_output, buffer.getvalue())


class TestPropertySetterCmd(TestCommand):
    """Abstract class for tests of command.PropertySetter classes

    Concrete subclasses set these class attributes:
      command       -- the command class under test
      propname      -- the client property the command should set
      values_to_set -- values to run the command with, one run each
      values_to_get -- (optional) the property value expected after
                       each run, pairwise with values_to_set; defaults
                       to values_to_set itself
    """

    def runTest(self):
        if not hasattr(self, "command"):
            return              # Abstract TestCase class
        values_to_get = getattr(self, "values_to_get",
                                self.values_to_set)
        # zip() would silently drop unmatched entries and thereby skip
        # checks, so insist that the two lists pair up exactly
        self.assertEqual(len(self.values_to_set), len(values_to_get))
        for value_to_set, value_to_get in zip(self.values_to_set,
                                              values_to_get):
            # Overwrite the property with a Unique() marker first, so
            # a command that does not set it at all is detected
            for clientpath in self.bus.clients:
                self.bus.clients[clientpath][self.propname] = Unique()
            self.run_command(value_to_set, self.bus.clients)
            for clientpath in self.bus.clients:
                value = self.bus.clients[clientpath][self.propname]
                self.assertNotIsInstance(value, Unique)
                self.assertEqual(value_to_get, value)

    def run_command(self, value, clients):
        # Commands of this kind take no argument; *value* is unused
        self.command().run(clients, self.bus)


class TestEnableCmd(TestPropertySetterCmd):
    """command.Enable should set Enabled to True on every client."""
    propname = "Enabled"
    command = command.Enable
    values_to_set = [True]


class TestDisableCmd(TestPropertySetterCmd):
    """command.Disable should set Enabled to False on every client."""
    propname = "Enabled"
    command = command.Disable
    values_to_set = [False]


class TestBumpTimeoutCmd(TestPropertySetterCmd):
    """command.BumpTimeout should set LastCheckedOK to the empty
    string on every client."""
    propname = "LastCheckedOK"
    command = command.BumpTimeout
    values_to_set = [""]


class TestStartCheckerCmd(TestPropertySetterCmd):
    """command.StartChecker should set CheckerRunning to True."""
    propname = "CheckerRunning"
    command = command.StartChecker
    values_to_set = [True]


class TestStopCheckerCmd(TestPropertySetterCmd):
    """command.StopChecker should set CheckerRunning to False."""
    propname = "CheckerRunning"
    command = command.StopChecker
    values_to_set = [False]


class TestApproveByDefaultCmd(TestPropertySetterCmd):
    """command.ApproveByDefault should set ApprovedByDefault to True."""
    propname = "ApprovedByDefault"
    command = command.ApproveByDefault
    values_to_set = [True]


class TestDenyByDefaultCmd(TestPropertySetterCmd):
    """command.DenyByDefault should set ApprovedByDefault to False."""
    propname = "ApprovedByDefault"
    command = command.DenyByDefault
    values_to_set = [False]


class TestPropertySetterValueCmd(TestPropertySetterCmd):
    """Abstract class for tests of PropertySetterValueCmd classes

    Unlike its parent, the command under test is constructed with the
    value to set as its single argument.
    """

    def run_command(self, value, clients):
        cmd = self.command(value)
        cmd.run(clients, self.bus)


class TestSetCheckerCmd(TestPropertySetterValueCmd):
    """command.SetChecker should store its argument as Checker."""
    propname = "Checker"
    command = command.SetChecker
    values_to_set = ["", ":", "fping -q -- %s"]


class TestSetHostCmd(TestPropertySetterValueCmd):
    """command.SetHost should store its argument as Host."""
    propname = "Host"
    command = command.SetHost
    values_to_set = ["192.0.2.3", "client.example.org"]


class TestSetSecretCmd(TestPropertySetterValueCmd):
    """command.SetSecret is given file-like objects; the Secret
    property should end up holding their byte contents."""
    propname = "Secret"
    command = command.SetSecret
    values_to_set = [io.BytesIO(b""),
                     io.BytesIO(b"secret\0xyzzy\nbar")]
    values_to_get = [buf.getvalue() for buf in values_to_set]


class TestSetTimeoutCmd(TestPropertySetterValueCmd):
    """command.SetTimeout takes a timedelta; Timeout should end up
    holding the same duration in milliseconds."""
    propname = "Timeout"
    command = command.SetTimeout
    values_to_set = [datetime.timedelta(),
                     datetime.timedelta(minutes=5),
                     datetime.timedelta(seconds=1),
                     datetime.timedelta(weeks=1),
                     datetime.timedelta(weeks=52)]
    values_to_get = [delta.total_seconds() * 1000
                     for delta in values_to_set]


class TestSetExtendedTimeoutCmd(TestPropertySetterValueCmd):
    """command.SetExtendedTimeout takes a timedelta; ExtendedTimeout
    should end up holding the same duration in milliseconds."""
    propname = "ExtendedTimeout"
    command = command.SetExtendedTimeout
    values_to_set = [datetime.timedelta(),
                     datetime.timedelta(minutes=5),
                     datetime.timedelta(seconds=1),
                     datetime.timedelta(weeks=1),
                     datetime.timedelta(weeks=52)]
    values_to_get = [delta.total_seconds() * 1000
                     for delta in values_to_set]


class TestSetIntervalCmd(TestPropertySetterValueCmd):
    """command.SetInterval takes a timedelta; Interval should end up
    holding the same duration in milliseconds."""
    propname = "Interval"
    command = command.SetInterval
    values_to_set = [datetime.timedelta(),
                     datetime.timedelta(minutes=5),
                     datetime.timedelta(seconds=1),
                     datetime.timedelta(weeks=1),
                     datetime.timedelta(weeks=52)]
    values_to_get = [delta.total_seconds() * 1000
                     for delta in values_to_set]


class TestSetApprovalDelayCmd(TestPropertySetterValueCmd):
    """command.SetApprovalDelay takes a timedelta; ApprovalDelay
    should end up holding the same duration in milliseconds."""
    propname = "ApprovalDelay"
    command = command.SetApprovalDelay
    values_to_set = [datetime.timedelta(),
                     datetime.timedelta(minutes=5),
                     datetime.timedelta(seconds=1),
                     datetime.timedelta(weeks=1),
                     datetime.timedelta(weeks=52)]
    values_to_get = [delta.total_seconds() * 1000
                     for delta in values_to_set]


class TestSetApprovalDurationCmd(TestPropertySetterValueCmd):
    """command.SetApprovalDuration takes a timedelta; ApprovalDuration
    should end up holding the same duration in milliseconds."""
    propname = "ApprovalDuration"
    command = command.SetApprovalDuration
    values_to_set = [datetime.timedelta(),
                     datetime.timedelta(minutes=5),
                     datetime.timedelta(seconds=1),
                     datetime.timedelta(weeks=1),
                     datetime.timedelta(weeks=52)]
    values_to_get = [delta.total_seconds() * 1000
                     for delta in values_to_set]



def should_only_run_tests():
    """Return True if --check is on the command line, else False.

    When --check is present it is stripped from sys.argv, leaving all
    other arguments in place, so that unittest.main() does not see it.
    """
    check_parser = argparse.ArgumentParser(add_help=False)
    check_parser.add_argument("--check", action='store_true')
    options, remaining = check_parser.parse_known_args()
    if not options.check:
        return False
    sys.argv[1:] = remaining
    return True

def load_tests(loader, tests, none):
    """unittest load_tests protocol hook.

    Extend the normally discovered *tests* with a suite built from the
    doctests in this module's docstrings, and return the same suite
    object.  *loader* and the third argument are not used.
    """
    import doctest
    doctest_suite = doctest.DocTestSuite()
    tests.addTests(doctest_suite)
    return tests

if __name__ == "__main__":
    # Run the self-tests when invoked with --check, otherwise act as
    # the normal command-line tool; always flush logging at exit.
    try:
        if not should_only_run_tests():
            main()
        else:
            # Call using ./tdd-python-script --check [--verbose]
            unittest.main()
    finally:
        logging.shutdown()
