/mandos/trunk

To get this branch, use:
bzr branch http://bzr.recompile.se/loggerhead/mandos/trunk

« back to all changes in this revision

Viewing changes to mandos

  • Committer: Teddy Hogeborn
  • Date: 2019-08-24 14:43:51 UTC
  • Revision ID: teddy@recompile.se-20190824144351-2y0l31jpj496vrtu
Server: Add scaffolding for tests

* mandos: Add code to run tests via the unittest module, similar to
          the code in mandos-ctl.  Also shut down logging on exit.

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
1
#!/usr/bin/python
2
 
# -*- mode: python; coding: utf-8 -*-
 
2
# -*- mode: python; after-save-hook: (lambda () (let ((command (if (fboundp 'file-local-name) (file-local-name (buffer-file-name)) (or (file-remote-p (buffer-file-name) 'localname) (buffer-file-name))))) (if (= (progn (if (get-buffer "*Test*") (kill-buffer "*Test*")) (process-file-shell-command (format "%s --check" (shell-quote-argument command)) nil "*Test*")) 0) (let ((w (get-buffer-window "*Test*"))) (if w (delete-window w))) (progn (with-current-buffer "*Test*" (compilation-mode)) (display-buffer "*Test*" '(display-buffer-in-side-window)))))); coding: utf-8 -*-
3
3
#
4
4
# Mandos server - give out binary blobs to connecting clients.
5
5
#
77
77
import itertools
78
78
import collections
79
79
import codecs
 
80
import unittest
80
81
 
81
82
import dbus
82
83
import dbus.service
 
84
import gi
83
85
from gi.repository import GLib
84
86
from dbus.mainloop.glib import DBusGMainLoop
85
87
import ctypes
87
89
import xml.dom.minidom
88
90
import inspect
89
91
 
 
92
if sys.version_info.major == 2:
 
93
    __metaclass__ = type
 
94
 
90
95
# Try to find the value of SO_BINDTODEVICE:
91
96
try:
92
97
    # This is where SO_BINDTODEVICE is in Python 3.3 (or 3.4?) and
115
120
if sys.version_info.major == 2:
116
121
    str = unicode
117
122
 
118
 
version = "1.7.20"
 
123
if sys.version_info < (3, 2):
 
124
    configparser.Configparser = configparser.SafeConfigParser
 
125
 
 
126
version = "1.8.8"
119
127
stored_state_file = "clients.pickle"
120
128
 
121
129
logger = logging.getLogger()
179
187
    pass
180
188
 
181
189
 
182
 
class PGPEngine(object):
 
190
class PGPEngine:
183
191
    """A simple class for OpenPGP symmetric encryption & decryption"""
184
192
 
185
193
    def __init__(self):
275
283
 
276
284
 
277
285
# Pretend that we have an Avahi module
278
 
class Avahi(object):
279
 
    """This isn't so much a class as it is a module-like namespace.
280
 
    It is instantiated once, and simulates having an Avahi module."""
 
286
class avahi:
 
287
    """This isn't so much a class as it is a module-like namespace."""
281
288
    IF_UNSPEC = -1               # avahi-common/address.h
282
289
    PROTO_UNSPEC = -1            # avahi-common/address.h
283
290
    PROTO_INET = 0               # avahi-common/address.h
287
294
    DBUS_INTERFACE_SERVER = DBUS_NAME + ".Server"
288
295
    DBUS_PATH_SERVER = "/"
289
296
 
290
 
    def string_array_to_txt_array(self, t):
 
297
    @staticmethod
 
298
    def string_array_to_txt_array(t):
291
299
        return dbus.Array((dbus.ByteArray(s.encode("utf-8"))
292
300
                           for s in t), signature="ay")
293
301
    ENTRY_GROUP_ESTABLISHED = 2  # avahi-common/defs.h
298
306
    SERVER_RUNNING = 2           # avahi-common/defs.h
299
307
    SERVER_COLLISION = 3         # avahi-common/defs.h
300
308
    SERVER_FAILURE = 4           # avahi-common/defs.h
301
 
avahi = Avahi()
302
309
 
303
310
 
304
311
class AvahiError(Exception):
316
323
    pass
317
324
 
318
325
 
319
 
class AvahiService(object):
 
326
class AvahiService:
320
327
    """An Avahi (Zeroconf) service.
321
328
 
322
329
    Attributes:
504
511
 
505
512
 
506
513
# Pretend that we have a GnuTLS module
507
 
class GnuTLS(object):
508
 
    """This isn't so much a class as it is a module-like namespace.
509
 
    It is instantiated once, and simulates having a GnuTLS module."""
 
514
class gnutls:
 
515
    """This isn't so much a class as it is a module-like namespace."""
510
516
 
511
517
    library = ctypes.util.find_library("gnutls")
512
518
    if library is None:
513
519
        library = ctypes.util.find_library("gnutls-deb0")
514
520
    _library = ctypes.cdll.LoadLibrary(library)
515
521
    del library
516
 
    _need_version = b"3.3.0"
517
 
    _tls_rawpk_version = b"3.6.6"
518
 
 
519
 
    def __init__(self):
520
 
        # Need to use "self" here, since this method is called before
521
 
        # the assignment to the "gnutls" global variable happens.
522
 
        if self.check_version(self._need_version) is None:
523
 
            raise self.Error("Needs GnuTLS {} or later"
524
 
                             .format(self._need_version))
525
522
 
526
523
    # Unless otherwise indicated, the constants and types below are
527
524
    # all from the gnutls/gnutls.h C header file.
569
566
 
570
567
    # Exceptions
571
568
    class Error(Exception):
572
 
        # We need to use the class name "GnuTLS" here, since this
573
 
        # exception might be raised from within GnuTLS.__init__,
574
 
        # which is called before the assignment to the "gnutls"
575
 
        # global variable has happened.
576
569
        def __init__(self, message=None, code=None, args=()):
577
570
            # Default usage is by a message string, but if a return
578
571
            # code is passed, convert it to a string with
579
572
            # gnutls.strerror()
580
573
            self.code = code
581
574
            if message is None and code is not None:
582
 
                message = GnuTLS.strerror(code)
583
 
            return super(GnuTLS.Error, self).__init__(
 
575
                message = gnutls.strerror(code)
 
576
            return super(gnutls.Error, self).__init__(
584
577
                message, *args)
585
578
 
586
579
    class CertificateSecurityError(Error):
587
580
        pass
588
581
 
589
582
    # Classes
590
 
    class Credentials(object):
 
583
    class Credentials:
591
584
        def __init__(self):
592
585
            self._c_object = gnutls.certificate_credentials_t()
593
586
            gnutls.certificate_allocate_credentials(
597
590
        def __del__(self):
598
591
            gnutls.certificate_free_credentials(self._c_object)
599
592
 
600
 
    class ClientSession(object):
 
593
    class ClientSession:
601
594
        def __init__(self, socket, credentials=None):
602
595
            self._c_object = gnutls.session_t()
603
596
            gnutls_flags = gnutls.CLIENT
604
 
            if gnutls.check_version("3.5.6"):
 
597
            if gnutls.check_version(b"3.5.6"):
605
598
                gnutls_flags |= gnutls.NO_TICKETS
606
599
            if gnutls.has_rawpk:
607
600
                gnutls_flags |= gnutls.ENABLE_RAWPK
744
737
    check_version.argtypes = [ctypes.c_char_p]
745
738
    check_version.restype = ctypes.c_char_p
746
739
 
 
740
    _need_version = b"3.3.0"
 
741
    if check_version(_need_version) is None:
 
742
        raise self.Error("Needs GnuTLS {} or later"
 
743
                         .format(_need_version))
 
744
 
 
745
    _tls_rawpk_version = b"3.6.6"
747
746
    has_rawpk = bool(check_version(_tls_rawpk_version))
748
747
 
749
748
    if has_rawpk:
803
802
                                                    ctypes.c_size_t)]
804
803
        openpgp_crt_get_fingerprint.restype = _error_code
805
804
 
806
 
    if check_version("3.6.4"):
 
805
    if check_version(b"3.6.4"):
807
806
        certificate_type_get2 = _library.gnutls_certificate_type_get2
808
807
        certificate_type_get2.argtypes = [session_t, ctypes.c_int]
809
808
        certificate_type_get2.restype = _error_code
810
809
 
811
810
    # Remove non-public functions
812
811
    del _error_code, _retry_on_error
813
 
# Create the global "gnutls" object, simulating a module
814
 
gnutls = GnuTLS()
815
812
 
816
813
 
817
814
def call_pipe(connection,       # : multiprocessing.Connection
825
822
    connection.close()
826
823
 
827
824
 
828
 
class Client(object):
 
825
class Client:
829
826
    """A representation of a client host served by this server.
830
827
 
831
828
    Attributes:
832
829
    approved:   bool(); 'None' if not yet approved/disapproved
833
830
    approval_delay: datetime.timedelta(); Time to wait for approval
834
831
    approval_duration: datetime.timedelta(); Duration of one approval
835
 
    checker:    subprocess.Popen(); a running checker process used
836
 
                                    to see if the client lives.
837
 
                                    'None' if no process is running.
 
832
    checker: multiprocessing.Process(); a running checker process used
 
833
             to see if the client lives. 'None' if no process is
 
834
             running.
838
835
    checker_callback_tag: a GLib event source tag, or None
839
836
    checker_command: string; External command which is run to check
840
837
                     if client lives.  %() expansions are done at
1047
1044
    def checker_callback(self, source, condition, connection,
1048
1045
                         command):
1049
1046
        """The checker has completed, so take appropriate actions."""
1050
 
        self.checker_callback_tag = None
1051
 
        self.checker = None
1052
1047
        # Read return code from connection (see call_pipe)
1053
1048
        returncode = connection.recv()
1054
1049
        connection.close()
 
1050
        self.checker.join()
 
1051
        self.checker_callback_tag = None
 
1052
        self.checker = None
1055
1053
 
1056
1054
        if returncode >= 0:
1057
1055
            self.last_checker_status = returncode
2219
2217
    del _interface
2220
2218
 
2221
2219
 
2222
 
class ProxyClient(object):
 
2220
class ProxyClient:
2223
2221
    def __init__(self, child_pipe, key_id, fpr, address):
2224
2222
        self._pipe = child_pipe
2225
2223
        self._pipe.send(('init', key_id, fpr, address))
2298
2296
            approval_required = False
2299
2297
            try:
2300
2298
                if gnutls.has_rawpk:
2301
 
                    fpr = ""
 
2299
                    fpr = b""
2302
2300
                    try:
2303
2301
                        key_id = self.key_id(
2304
2302
                            self.peer_certificate(session))
2308
2306
                    logger.debug("Key ID: %s", key_id)
2309
2307
 
2310
2308
                else:
2311
 
                    key_id = ""
 
2309
                    key_id = b""
2312
2310
                    try:
2313
2311
                        fpr = self.fingerprint(
2314
2312
                            self.peer_certificate(session))
2498
2496
        return hex_fpr
2499
2497
 
2500
2498
 
2501
 
class MultiprocessingMixIn(object):
 
2499
class MultiprocessingMixIn:
2502
2500
    """Like socketserver.ThreadingMixIn, but with multiprocessing"""
2503
2501
 
2504
2502
    def sub_process_main(self, request, address):
2516
2514
        return proc
2517
2515
 
2518
2516
 
2519
 
class MultiprocessingMixInWithPipe(MultiprocessingMixIn, object):
 
2517
class MultiprocessingMixInWithPipe(MultiprocessingMixIn):
2520
2518
    """ adds a pipe to the MixIn """
2521
2519
 
2522
2520
    def process_request(self, request, client_address):
2537
2535
 
2538
2536
 
2539
2537
class IPv6_TCPServer(MultiprocessingMixInWithPipe,
2540
 
                     socketserver.TCPServer, object):
 
2538
                     socketserver.TCPServer):
2541
2539
    """IPv6-capable TCP server.  Accepts 'None' as address and/or port
2542
2540
 
2543
2541
    Attributes:
2616
2614
                    raise
2617
2615
        # Only bind(2) the socket if we really need to.
2618
2616
        if self.server_address[0] or self.server_address[1]:
 
2617
            if self.server_address[1]:
 
2618
                self.allow_reuse_address = True
2619
2619
            if not self.server_address[0]:
2620
2620
                if self.address_family == socket.AF_INET6:
2621
2621
                    any_address = "::"  # in6addr_any
2700
2700
            address = request[3]
2701
2701
 
2702
2702
            for c in self.clients.values():
 
2703
                if key_id == "E3B0C44298FC1C149AFBF4C8996FB92427AE41E4649B934CA495991B7852B855":
 
2704
                    continue
2703
2705
                if key_id and c.key_id == key_id:
2704
2706
                    client = c
2705
2707
                    break
2974
2976
 
2975
2977
    options = parser.parse_args()
2976
2978
 
2977
 
    if options.check:
2978
 
        import doctest
2979
 
        fail_count, test_count = doctest.testmod()
2980
 
        sys.exit(os.EX_OK if fail_count == 0 else 1)
2981
 
 
2982
2979
    # Default values for config file for server-global settings
2983
2980
    if gnutls.has_rawpk:
2984
2981
        priority = ("SECURE128:!CTYPE-X.509:+CTYPE-RAWPK:!RSA"
3004
3001
    del priority
3005
3002
 
3006
3003
    # Parse config file for server-global settings
3007
 
    server_config = configparser.SafeConfigParser(server_defaults)
 
3004
    server_config = configparser.ConfigParser(server_defaults)
3008
3005
    del server_defaults
3009
3006
    server_config.read(os.path.join(options.configdir, "mandos.conf"))
3010
 
    # Convert the SafeConfigParser object to a dict
 
3007
    # Convert the ConfigParser object to a dict
3011
3008
    server_settings = server_config.defaults()
3012
3009
    # Use the appropriate methods on the non-string config options
3013
3010
    for option in ("debug", "use_dbus", "use_ipv6", "restore",
3085
3082
                                  server_settings["servicename"])))
3086
3083
 
3087
3084
    # Parse config file with clients
3088
 
    client_config = configparser.SafeConfigParser(Client
3089
 
                                                  .client_defaults)
 
3085
    client_config = configparser.ConfigParser(Client.client_defaults)
3090
3086
    client_config.read(os.path.join(server_settings["configdir"],
3091
3087
                                    "clients.conf"))
3092
3088
 
3163
3159
        # Close all input and output, do double fork, etc.
3164
3160
        daemon()
3165
3161
 
3166
 
    # multiprocessing will use threads, so before we use GLib we need
3167
 
    # to inform GLib that threads will be used.
3168
 
    GLib.threads_init()
 
3162
    if gi.version_info < (3, 10, 2):
 
3163
        # multiprocessing will use threads, so before we use GLib we
 
3164
        # need to inform GLib that threads will be used.
 
3165
        GLib.threads_init()
3169
3166
 
3170
3167
    global main_loop
3171
3168
    # From the Avahi example code
3251
3248
                        for k in ("name", "host"):
3252
3249
                            if isinstance(value[k], bytes):
3253
3250
                                value[k] = value[k].decode("utf-8")
3254
 
                        if not value.has_key("key_id"):
 
3251
                        if "key_id" not in value:
3255
3252
                            value["key_id"] = ""
3256
 
                        elif not value.has_key("fingerprint"):
 
3253
                        elif "fingerprint" not in value:
3257
3254
                            value["fingerprint"] = ""
3258
3255
                    #  old_client_settings
3259
3256
                    # .keys()
3618
3615
    # Must run before the D-Bus bus name gets deregistered
3619
3616
    cleanup()
3620
3617
 
 
3618
 
 
3619
def should_only_run_tests():
 
3620
    parser = argparse.ArgumentParser(add_help=False)
 
3621
    parser.add_argument("--check", action='store_true')
 
3622
    args, unknown_args = parser.parse_known_args()
 
3623
    run_tests = args.check
 
3624
    if run_tests:
 
3625
        # Remove --check argument from sys.argv
 
3626
        sys.argv[1:] = unknown_args
 
3627
    return run_tests
 
3628
 
 
3629
# Add all tests from doctest strings
 
3630
def load_tests(loader, tests, none):
 
3631
    import doctest
 
3632
    tests.addTests(doctest.DocTestSuite())
 
3633
    return tests
3621
3634
 
3622
3635
if __name__ == '__main__':
3623
 
    main()
 
3636
    try:
 
3637
        if should_only_run_tests():
 
3638
            # Call using ./mandos --check [--verbose]
 
3639
            unittest.main()
 
3640
        else:
 
3641
            main()
 
3642
    finally:
 
3643
        logging.shutdown()