/mandos/trunk

To get this branch, use:
bzr branch http://bzr.recompile.se/loggerhead/mandos/trunk

« back to all changes in this revision

Viewing changes to mandos

  • Committer: Teddy Hogeborn
  • Date: 2012-05-26 22:56:38 UTC
  • mfrom: (589.1.1 socket-option)
  • Revision ID: teddy@recompile.se-20120526225638-4hvqyrvmj0036lfn
Merge "--socket" option for server.

This is suggested by the GNU Coding Standards' Table of Long Options,
and will probably also allow socket activation (e.g. by systemd(8)).

Show diffs side-by-side

added added

removed removed

Lines of Context:
34
34
from __future__ import (division, absolute_import, print_function,
35
35
                        unicode_literals)
36
36
 
 
37
from future_builtins import *
 
38
 
37
39
import SocketServer as socketserver
38
40
import socket
39
41
import argparse
86
88
    except ImportError:
87
89
        SO_BINDTODEVICE = None
88
90
 
89
 
version = "1.5.3"
 
91
version = "1.5.4"
90
92
stored_state_file = "clients.pickle"
91
93
 
92
94
logger = logging.getLogger()
149
151
    def __enter__(self):
150
152
        return self
151
153
    
152
 
    def __exit__ (self, exc_type, exc_value, traceback):
 
154
    def __exit__(self, exc_type, exc_value, traceback):
153
155
        self._cleanup()
154
156
        return False
155
157
    
377
379
                                 self.server_state_changed)
378
380
        self.server_state_changed(self.server.GetState())
379
381
 
 
382
 
380
383
class AvahiServiceToSyslog(AvahiService):
381
384
    def rename(self):
382
385
        """Add the new name to the syslog messages"""
387
390
                                .format(self.name)))
388
391
        return ret
389
392
 
 
393
 
390
394
def timedelta_to_milliseconds(td):
391
395
    "Convert a datetime.timedelta() to milliseconds"
392
396
    return ((td.days * 24 * 60 * 60 * 1000)
393
397
            + (td.seconds * 1000)
394
398
            + (td.microseconds // 1000))
395
399
 
 
400
 
396
401
class Client(object):
397
402
    """A representation of a client host served by this server.
398
403
    
437
442
    """
438
443
    
439
444
    runtime_expansions = ("approval_delay", "approval_duration",
440
 
                          "created", "enabled", "fingerprint",
441
 
                          "host", "interval", "last_checked_ok",
 
445
                          "created", "enabled", "expires",
 
446
                          "fingerprint", "host", "interval",
 
447
                          "last_approval_request", "last_checked_ok",
442
448
                          "last_enabled", "name", "timeout")
443
449
    client_defaults = { "timeout": "5m",
444
450
                        "extended_timeout": "15m",
570
576
        if getattr(self, "enabled", False):
571
577
            # Already enabled
572
578
            return
573
 
        self.send_changedstate()
574
579
        self.expires = datetime.datetime.utcnow() + self.timeout
575
580
        self.enabled = True
576
581
        self.last_enabled = datetime.datetime.utcnow()
577
582
        self.init_checker()
 
583
        self.send_changedstate()
578
584
    
579
585
    def disable(self, quiet=True):
580
586
        """Disable this client."""
581
587
        if not getattr(self, "enabled", False):
582
588
            return False
583
589
        if not quiet:
584
 
            self.send_changedstate()
585
 
        if not quiet:
586
590
            logger.info("Disabling client %s", self.name)
587
 
        if getattr(self, "disable_initiator_tag", False):
 
591
        if getattr(self, "disable_initiator_tag", None) is not None:
588
592
            gobject.source_remove(self.disable_initiator_tag)
589
593
            self.disable_initiator_tag = None
590
594
        self.expires = None
591
 
        if getattr(self, "checker_initiator_tag", False):
 
595
        if getattr(self, "checker_initiator_tag", None) is not None:
592
596
            gobject.source_remove(self.checker_initiator_tag)
593
597
            self.checker_initiator_tag = None
594
598
        self.stop_checker()
595
599
        self.enabled = False
 
600
        if not quiet:
 
601
            self.send_changedstate()
596
602
        # Do not run this again if called by a gobject.timeout_add
597
603
        return False
598
604
    
602
608
    def init_checker(self):
603
609
        # Schedule a new checker to be started an 'interval' from now,
604
610
        # and every interval from then on.
 
611
        if self.checker_initiator_tag is not None:
 
612
            gobject.source_remove(self.checker_initiator_tag)
605
613
        self.checker_initiator_tag = (gobject.timeout_add
606
614
                                      (self.interval_milliseconds(),
607
615
                                       self.start_checker))
608
616
        # Schedule a disable() when 'timeout' has passed
 
617
        if self.disable_initiator_tag is not None:
 
618
            gobject.source_remove(self.disable_initiator_tag)
609
619
        self.disable_initiator_tag = (gobject.timeout_add
610
620
                                   (self.timeout_milliseconds(),
611
621
                                    self.disable))
642
652
            timeout = self.timeout
643
653
        if self.disable_initiator_tag is not None:
644
654
            gobject.source_remove(self.disable_initiator_tag)
 
655
            self.disable_initiator_tag = None
645
656
        if getattr(self, "enabled", False):
646
657
            self.disable_initiator_tag = (gobject.timeout_add
647
658
                                          (timedelta_to_milliseconds
680
691
                                      self.current_checker_command)
681
692
        # Start a new checker if needed
682
693
        if self.checker is None:
 
694
            # Escape attributes for the shell
 
695
            escaped_attrs = dict(
 
696
                (attr, re.escape(unicode(getattr(self, attr))))
 
697
                for attr in
 
698
                self.runtime_expansions)
683
699
            try:
684
 
                # In case checker_command has exactly one % operator
685
 
                command = self.checker_command % self.host
686
 
            except TypeError:
687
 
                # Escape attributes for the shell
688
 
                escaped_attrs = dict(
689
 
                    (attr,
690
 
                     re.escape(unicode(str(getattr(self, attr, "")),
691
 
                                       errors=
692
 
                                       'replace')))
693
 
                    for attr in
694
 
                    self.runtime_expansions)
695
 
                
696
 
                try:
697
 
                    command = self.checker_command % escaped_attrs
698
 
                except TypeError as error:
699
 
                    logger.error('Could not format string "%s"',
700
 
                                 self.checker_command, exc_info=error)
701
 
                    return True # Try again later
 
700
                command = self.checker_command % escaped_attrs
 
701
            except TypeError as error:
 
702
                logger.error('Could not format string "%s"',
 
703
                             self.checker_command, exc_info=error)
 
704
                return True # Try again later
702
705
            self.current_checker_command = command
703
706
            try:
704
707
                logger.info("Starting checker %r for %s",
710
713
                self.checker = subprocess.Popen(command,
711
714
                                                close_fds=True,
712
715
                                                shell=True, cwd="/")
713
 
                self.checker_callback_tag = (gobject.child_watch_add
714
 
                                             (self.checker.pid,
715
 
                                              self.checker_callback,
716
 
                                              data=command))
717
 
                # The checker may have completed before the gobject
718
 
                # watch was added.  Check for this.
719
 
                pid, status = os.waitpid(self.checker.pid, os.WNOHANG)
720
 
                if pid:
721
 
                    gobject.source_remove(self.checker_callback_tag)
722
 
                    self.checker_callback(pid, status, command)
723
716
            except OSError as error:
724
717
                logger.error("Failed to start subprocess",
725
718
                             exc_info=error)
 
719
            self.checker_callback_tag = (gobject.child_watch_add
 
720
                                         (self.checker.pid,
 
721
                                          self.checker_callback,
 
722
                                          data=command))
 
723
            # The checker may have completed before the gobject
 
724
            # watch was added.  Check for this.
 
725
            pid, status = os.waitpid(self.checker.pid, os.WNOHANG)
 
726
            if pid:
 
727
                gobject.source_remove(self.checker_callback_tag)
 
728
                self.checker_callback(pid, status, command)
726
729
        # Re-run this periodically if run by gobject.timeout_add
727
730
        return True
728
731
    
1013
1016
        return xmlstring
1014
1017
 
1015
1018
 
1016
 
def datetime_to_dbus (dt, variant_level=0):
 
1019
def datetime_to_dbus(dt, variant_level=0):
1017
1020
    """Convert a UTC datetime.datetime() to a D-Bus type."""
1018
1021
    if dt is None:
1019
1022
        return dbus.String("", variant_level = variant_level)
1027
1030
    interface names according to the "alt_interface_names" mapping.
1028
1031
    Usage:
1029
1032
    
1030
 
    @alternate_dbus_names({"org.example.Interface":
1031
 
                               "net.example.AlternateInterface"})
 
1033
    @alternate_dbus_interfaces({"org.example.Interface":
 
1034
                                    "net.example.AlternateInterface"})
1032
1035
    class SampleDBusObject(dbus.service.Object):
1033
1036
        @dbus.service.method("org.example.Interface")
1034
1037
        def SampleDBusMethod():
1333
1336
        return False
1334
1337
    
1335
1338
    def approve(self, value=True):
1336
 
        self.send_changedstate()
1337
1339
        self.approved = value
1338
1340
        gobject.timeout_add(timedelta_to_milliseconds
1339
1341
                            (self.approval_duration),
1340
1342
                            self._reset_approved)
 
1343
        self.send_changedstate()
1341
1344
    
1342
1345
    ## D-Bus methods, signals & properties
1343
1346
    _interface = "se.recompile.Mandos.Client"
1527
1530
    def Timeout_dbus_property(self, value=None):
1528
1531
        if value is None:       # get
1529
1532
            return dbus.UInt64(self.timeout_milliseconds())
 
1533
        old_timeout = self.timeout
1530
1534
        self.timeout = datetime.timedelta(0, 0, 0, value)
1531
 
        # Reschedule timeout
 
1535
        # Reschedule disabling
1532
1536
        if self.enabled:
1533
1537
            now = datetime.datetime.utcnow()
1534
 
            time_to_die = timedelta_to_milliseconds(
1535
 
                (self.last_checked_ok + self.timeout) - now)
1536
 
            if time_to_die <= 0:
 
1538
            self.expires += self.timeout - old_timeout
 
1539
            if self.expires <= now:
1537
1540
                # The timeout has passed
1538
1541
                self.disable()
1539
1542
            else:
1540
 
                self.expires = (now +
1541
 
                                datetime.timedelta(milliseconds =
1542
 
                                                   time_to_die))
1543
1543
                if (getattr(self, "disable_initiator_tag", None)
1544
1544
                    is None):
1545
1545
                    return
1546
1546
                gobject.source_remove(self.disable_initiator_tag)
1547
 
                self.disable_initiator_tag = (gobject.timeout_add
1548
 
                                              (time_to_die,
1549
 
                                               self.disable))
 
1547
                self.disable_initiator_tag = (
 
1548
                    gobject.timeout_add(
 
1549
                        timedelta_to_milliseconds(self.expires - now),
 
1550
                        self.disable))
1550
1551
    
1551
1552
    # ExtendedTimeout - property
1552
1553
    @dbus_service_property(_interface, signature="t",
1740
1741
                    #wait until timeout or approved
1741
1742
                    time = datetime.datetime.now()
1742
1743
                    client.changedstate.acquire()
1743
 
                    (client.changedstate.wait
1744
 
                     (float(client.timedelta_to_milliseconds(delay)
1745
 
                            / 1000)))
 
1744
                    client.changedstate.wait(
 
1745
                        float(timedelta_to_milliseconds(delay)
 
1746
                              / 1000))
1746
1747
                    client.changedstate.release()
1747
1748
                    time2 = datetime.datetime.now()
1748
1749
                    if (time2 - time) >= delay:
1864
1865
    def process_request(self, request, address):
1865
1866
        """Start a new process to process the request."""
1866
1867
        proc = multiprocessing.Process(target = self.sub_process_main,
1867
 
                                       args = (request,
1868
 
                                               address))
 
1868
                                       args = (request, address))
1869
1869
        proc.start()
1870
1870
        return proc
1871
1871
 
1899
1899
        use_ipv6:       Boolean; to use IPv6 or not
1900
1900
    """
1901
1901
    def __init__(self, server_address, RequestHandlerClass,
1902
 
                 interface=None, use_ipv6=True):
 
1902
                 interface=None, use_ipv6=True, socketfd=None):
 
1903
        """If socketfd is set, use that file descriptor instead of
 
1904
        creating a new one with socket.socket().
 
1905
        """
1903
1906
        self.interface = interface
1904
1907
        if use_ipv6:
1905
1908
            self.address_family = socket.AF_INET6
 
1909
        if socketfd is not None:
 
1910
            # Save the file descriptor
 
1911
            self.socketfd = socketfd
 
1912
            # Save the original socket.socket() function
 
1913
            self.socket_socket = socket.socket
 
1914
            # To implement --socket, we monkey patch socket.socket.
 
1915
            # 
 
1916
            # (When socketserver.TCPServer is a new-style class, we
 
1917
            # could make self.socket into a property instead of monkey
 
1918
            # patching socket.socket.)
 
1919
            # 
 
1920
            # Create a one-time-only replacement for socket.socket()
 
1921
            @functools.wraps(socket.socket)
 
1922
            def socket_wrapper(*args, **kwargs):
 
1923
                # Restore original function so subsequent calls are
 
1924
                # not affected.
 
1925
                socket.socket = self.socket_socket
 
1926
                del self.socket_socket
 
1927
                # This time only, return a new socket object from the
 
1928
                # saved file descriptor.
 
1929
                return socket.fromfd(self.socketfd, *args, **kwargs)
 
1930
            # Replace socket.socket() function with wrapper
 
1931
            socket.socket = socket_wrapper
 
1932
        # The socketserver.TCPServer.__init__ will call
 
1933
        # socket.socket(), which might be our replacement,
 
1934
        # socket_wrapper(), if socketfd was set.
1906
1935
        socketserver.TCPServer.__init__(self, server_address,
1907
1936
                                        RequestHandlerClass)
 
1937
    
1908
1938
    def server_bind(self):
1909
1939
        """This overrides the normal server_bind() function
1910
1940
        to bind to an interface if one was specified, and also NOT to
1921
1951
                                           str(self.interface
1922
1952
                                               + '\0'))
1923
1953
                except socket.error as error:
1924
 
                    if error[0] == errno.EPERM:
 
1954
                    if error.errno == errno.EPERM:
1925
1955
                        logger.error("No permission to"
1926
1956
                                     " bind to interface %s",
1927
1957
                                     self.interface)
1928
 
                    elif error[0] == errno.ENOPROTOOPT:
 
1958
                    elif error.errno == errno.ENOPROTOOPT:
1929
1959
                        logger.error("SO_BINDTODEVICE not available;"
1930
1960
                                     " cannot bind to interface %s",
1931
1961
                                     self.interface)
 
1962
                    elif error.errno == errno.ENODEV:
 
1963
                        logger.error("Interface %s does not"
 
1964
                                     " exist, cannot bind",
 
1965
                                     self.interface)
1932
1966
                    else:
1933
1967
                        raise
1934
1968
        # Only bind(2) the socket if we really need to.
1964
1998
    """
1965
1999
    def __init__(self, server_address, RequestHandlerClass,
1966
2000
                 interface=None, use_ipv6=True, clients=None,
1967
 
                 gnutls_priority=None, use_dbus=True):
 
2001
                 gnutls_priority=None, use_dbus=True, socketfd=None):
1968
2002
        self.enabled = False
1969
2003
        self.clients = clients
1970
2004
        if self.clients is None:
1974
2008
        IPv6_TCPServer.__init__(self, server_address,
1975
2009
                                RequestHandlerClass,
1976
2010
                                interface = interface,
1977
 
                                use_ipv6 = use_ipv6)
 
2011
                                use_ipv6 = use_ipv6,
 
2012
                                socketfd = socketfd)
1978
2013
    def server_activate(self):
1979
2014
        if self.enabled:
1980
2015
            return socketserver.TCPServer.server_activate(self)
2161
2196
    parser.add_argument("--no-restore", action="store_false",
2162
2197
                        dest="restore", help="Do not restore stored"
2163
2198
                        " state")
 
2199
    parser.add_argument("--socket", type=int,
 
2200
                        help="Specify a file descriptor to a network"
 
2201
                        " socket to use instead of creating one")
2164
2202
    parser.add_argument("--statedir", metavar="DIR",
2165
2203
                        help="Directory to save/restore state in")
2166
2204
    
2183
2221
                        "use_ipv6": "True",
2184
2222
                        "debuglevel": "",
2185
2223
                        "restore": "True",
 
2224
                        "socket": "",
2186
2225
                        "statedir": "/var/lib/mandos"
2187
2226
                        }
2188
2227
    
2200
2239
    if server_settings["port"]:
2201
2240
        server_settings["port"] = server_config.getint("DEFAULT",
2202
2241
                                                       "port")
 
2242
    if server_settings["socket"]:
 
2243
        server_settings["socket"] = server_config.getint("DEFAULT",
 
2244
                                                         "socket")
 
2245
        # Later, stdin will, and stdout and stderr might, be dup'ed
 
2246
        # over with an opened os.devnull.  But we don't want this to
 
2247
        # happen with a supplied network socket.
 
2248
        if 0 <= server_settings["socket"] <= 2:
 
2249
            server_settings["socket"] = os.dup(server_settings
 
2250
                                               ["socket"])
2203
2251
    del server_config
2204
2252
    
2205
2253
    # Override the settings from the config file with command line
2207
2255
    for option in ("interface", "address", "port", "debug",
2208
2256
                   "priority", "servicename", "configdir",
2209
2257
                   "use_dbus", "use_ipv6", "debuglevel", "restore",
2210
 
                   "statedir"):
 
2258
                   "statedir", "socket"):
2211
2259
        value = getattr(options, option)
2212
2260
        if value is not None:
2213
2261
            server_settings[option] = value
2261
2309
                              use_ipv6=use_ipv6,
2262
2310
                              gnutls_priority=
2263
2311
                              server_settings["priority"],
2264
 
                              use_dbus=use_dbus)
 
2312
                              use_dbus=use_dbus,
 
2313
                              socketfd=(server_settings["socket"]
 
2314
                                        or None))
2265
2315
    if not debug:
2266
2316
        pidfilename = "/var/run/mandos.pid"
2267
2317
        try:
2284
2334
        os.setgid(gid)
2285
2335
        os.setuid(uid)
2286
2336
    except OSError as error:
2287
 
        if error[0] != errno.EPERM:
 
2337
        if error.errno != errno.EPERM:
2288
2338
            raise error
2289
2339
    
2290
2340
    if debug:
2312
2362
        # Close all input and output, do double fork, etc.
2313
2363
        daemon()
2314
2364
    
 
2365
    # multiprocessing will use threads, so before we use gobject we
 
2366
    # need to inform gobject that threads will be used.
2315
2367
    gobject.threads_init()
2316
2368
    
2317
2369
    global main_loop
2458
2510
            # "pidfile" was never created
2459
2511
            pass
2460
2512
        del pidfilename
2461
 
        signal.signal(signal.SIGINT, signal.SIG_IGN)
2462
2513
    
2463
2514
    signal.signal(signal.SIGHUP, lambda signum, frame: sys.exit())
2464
2515
    signal.signal(signal.SIGTERM, lambda signum, frame: sys.exit())
2564
2615
                del client_settings[client.name]["secret"]
2565
2616
        
2566
2617
        try:
2567
 
            tempfd, tempname = tempfile.mkstemp(suffix=".pickle",
2568
 
                                                prefix="clients-",
2569
 
                                                dir=os.path.dirname
2570
 
                                                (stored_state_path))
2571
 
            with os.fdopen(tempfd, "wb") as stored_state:
 
2618
            with (tempfile.NamedTemporaryFile
 
2619
                  (mode='wb', suffix=".pickle", prefix='clients-',
 
2620
                   dir=os.path.dirname(stored_state_path),
 
2621
                   delete=False)) as stored_state:
2572
2622
                pickle.dump((clients, client_settings), stored_state)
 
2623
                tempname=stored_state.name
2573
2624
            os.rename(tempname, stored_state_path)
2574
2625
        except (IOError, OSError) as e:
2575
2626
            if not debug: