/mandos/trunk

To get this branch, use:
bzr branch http://bzr.recompile.se/loggerhead/mandos/trunk

« back to all changes in this revision

Viewing changes to mandos

  • Committer: Teddy Hogeborn
  • Date: 2021-02-03 23:10:42 UTC
  • Revision ID: teddy@recompile.se-20210203231042-2z3egrvpo1zt7nej
mandos-ctl: Fix bad test for command.Remove and related minor issues

The test for command.Remove removes all clients from the spy server,
and then loops over all clients, looking for the corresponding Remove
command as recorded by the spy server.  But since there aren't
any clients left after they were removed, no assertions are made, and
the test therefore does nothing.  Fix this.

In tests for command.Approve and command.Deny, add checks that clients
were not somehow removed by the command (in which case, likewise, no
assertions are made).

Add related checks to TestPropertySetterCmd.runTest; i.e. test that a
sequence is not empty before looping over it and making assertions.

* mandos-ctl (TestBaseCommands.test_Remove): Save a copy of the
  original "clients" dict, and loop over those instead.  Add assertion
  that all clients were indeed removed.  Also fix the code that looks
  for the Remove command; now that the loop actually runs, this code
  must actually work.
  (TestBaseCommands.test_Approve, TestBaseCommands.test_Deny): Add
  assertion that there are still clients before looping over them.
  (TestPropertySetterCmd.runTest): Add assertion that the list of
  values to get is not empty before looping over them.  Also add a
  check that there are still clients before looping over them.

Show diffs side-by-side

added added

removed removed

Lines of Context:
563
563
    OPENPGP_FMT_RAW = 0         # gnutls/openpgp.h
564
564
 
565
565
    # Types
566
 
    class _session_int(ctypes.Structure):
 
566
    class session_int(ctypes.Structure):
567
567
        _fields_ = []
568
 
    session_t = ctypes.POINTER(_session_int)
 
568
    session_t = ctypes.POINTER(session_int)
569
569
 
570
570
    class certificate_credentials_st(ctypes.Structure):
571
571
        _fields_ = []
577
577
        _fields_ = [('data', ctypes.POINTER(ctypes.c_ubyte)),
578
578
                    ('size', ctypes.c_uint)]
579
579
 
580
 
    class _openpgp_crt_int(ctypes.Structure):
 
580
    class openpgp_crt_int(ctypes.Structure):
581
581
        _fields_ = []
582
 
    openpgp_crt_t = ctypes.POINTER(_openpgp_crt_int)
 
582
    openpgp_crt_t = ctypes.POINTER(openpgp_crt_int)
583
583
    openpgp_crt_fmt_t = ctypes.c_int  # gnutls/openpgp.h
584
584
    log_func = ctypes.CFUNCTYPE(None, ctypes.c_int, ctypes.c_char_p)
585
585
    credentials_type_t = ctypes.c_int
594
594
            # gnutls.strerror()
595
595
            self.code = code
596
596
            if message is None and code is not None:
597
 
                message = gnutls.strerror(code).decode(
598
 
                    "utf-8", errors="replace")
 
597
                message = gnutls.strerror(code)
599
598
            return super(gnutls.Error, self).__init__(
600
599
                message, *args)
601
600
 
602
601
    class CertificateSecurityError(Error):
603
602
        pass
604
603
 
605
 
    class PointerTo:
606
 
        def __init__(self, cls):
607
 
            self.cls = cls
608
 
 
609
 
        def from_param(self, obj):
610
 
            if not isinstance(obj, self.cls):
611
 
                raise TypeError("Not of type {}: {!r}"
612
 
                                .format(self.cls.__name__, obj))
613
 
            return ctypes.byref(obj.from_param(obj))
614
 
 
615
 
    class CastToVoidPointer:
616
 
        def __init__(self, cls):
617
 
            self.cls = cls
618
 
 
619
 
        def from_param(self, obj):
620
 
            if not isinstance(obj, self.cls):
621
 
                raise TypeError("Not of type {}: {!r}"
622
 
                                .format(self.cls.__name__, obj))
623
 
            return ctypes.cast(obj.from_param(obj), ctypes.c_void_p)
624
 
 
625
 
    class With_from_param:
626
 
        @classmethod
627
 
        def from_param(cls, obj):
628
 
            return obj._as_parameter_
629
 
 
630
604
    # Classes
631
 
    class Credentials(With_from_param):
 
605
    class Credentials:
632
606
        def __init__(self):
633
 
            self._as_parameter_ = gnutls.certificate_credentials_t()
634
 
            gnutls.certificate_allocate_credentials(self)
 
607
            self._c_object = gnutls.certificate_credentials_t()
 
608
            gnutls.certificate_allocate_credentials(
 
609
                ctypes.byref(self._c_object))
635
610
            self.type = gnutls.CRD_CERTIFICATE
636
611
 
637
612
        def __del__(self):
638
 
            gnutls.certificate_free_credentials(self)
 
613
            gnutls.certificate_free_credentials(self._c_object)
639
614
 
640
 
    class ClientSession(With_from_param):
 
615
    class ClientSession:
641
616
        def __init__(self, socket, credentials=None):
642
 
            self._as_parameter_ = gnutls.session_t()
 
617
            self._c_object = gnutls.session_t()
643
618
            gnutls_flags = gnutls.CLIENT
644
619
            if gnutls.check_version(b"3.5.6"):
645
620
                gnutls_flags |= gnutls.NO_TICKETS
646
621
            if gnutls.has_rawpk:
647
622
                gnutls_flags |= gnutls.ENABLE_RAWPK
648
 
            gnutls.init(self, gnutls_flags)
 
623
            gnutls.init(ctypes.byref(self._c_object), gnutls_flags)
649
624
            del gnutls_flags
650
 
            gnutls.set_default_priority(self)
651
 
            gnutls.transport_set_ptr(self, socket.fileno())
652
 
            gnutls.handshake_set_private_extensions(self, True)
 
625
            gnutls.set_default_priority(self._c_object)
 
626
            gnutls.transport_set_ptr(self._c_object, socket.fileno())
 
627
            gnutls.handshake_set_private_extensions(self._c_object,
 
628
                                                    True)
653
629
            self.socket = socket
654
630
            if credentials is None:
655
631
                credentials = gnutls.Credentials()
656
 
            gnutls.credentials_set(self, credentials.type,
657
 
                                   credentials)
 
632
            gnutls.credentials_set(self._c_object, credentials.type,
 
633
                                   ctypes.cast(credentials._c_object,
 
634
                                               ctypes.c_void_p))
658
635
            self.credentials = credentials
659
636
 
660
637
        def __del__(self):
661
 
            gnutls.deinit(self)
 
638
            gnutls.deinit(self._c_object)
662
639
 
663
640
        def handshake(self):
664
 
            return gnutls.handshake(self)
 
641
            return gnutls.handshake(self._c_object)
665
642
 
666
643
        def send(self, data):
667
644
            data = bytes(data)
668
645
            data_len = len(data)
669
646
            while data_len > 0:
670
 
                data_len -= gnutls.record_send(self, data[-data_len:],
 
647
                data_len -= gnutls.record_send(self._c_object,
 
648
                                               data[-data_len:],
671
649
                                               data_len)
672
650
 
673
651
        def bye(self):
674
 
            return gnutls.bye(self, gnutls.SHUT_RDWR)
 
652
            return gnutls.bye(self._c_object, gnutls.SHUT_RDWR)
675
653
 
676
654
    # Error handling functions
677
655
    def _error_code(result):
678
656
        """A function to raise exceptions on errors, suitable
679
657
        for the 'restype' attribute on ctypes functions"""
680
 
        if result >= gnutls.E_SUCCESS:
 
658
        if result >= 0:
681
659
            return result
682
660
        if result == gnutls.E_NO_CERTIFICATE_FOUND:
683
661
            raise gnutls.CertificateSecurityError(code=result)
684
662
        raise gnutls.Error(code=result)
685
663
 
686
 
    def _retry_on_error(result, func, arguments,
687
 
                        _error_code=_error_code):
 
664
    def _retry_on_error(result, func, arguments):
688
665
        """A function to retry on some errors, suitable
689
666
        for the 'errcheck' attribute on ctypes functions"""
690
 
        while result < gnutls.E_SUCCESS:
 
667
        while result < 0:
691
668
            if result not in (gnutls.E_INTERRUPTED, gnutls.E_AGAIN):
692
669
                return _error_code(result)
693
670
            result = func(*arguments)
698
675
 
699
676
    # Functions
700
677
    priority_set_direct = _library.gnutls_priority_set_direct
701
 
    priority_set_direct.argtypes = [ClientSession, ctypes.c_char_p,
 
678
    priority_set_direct.argtypes = [session_t, ctypes.c_char_p,
702
679
                                    ctypes.POINTER(ctypes.c_char_p)]
703
680
    priority_set_direct.restype = _error_code
704
681
 
705
682
    init = _library.gnutls_init
706
 
    init.argtypes = [PointerTo(ClientSession), ctypes.c_int]
 
683
    init.argtypes = [ctypes.POINTER(session_t), ctypes.c_int]
707
684
    init.restype = _error_code
708
685
 
709
686
    set_default_priority = _library.gnutls_set_default_priority
710
 
    set_default_priority.argtypes = [ClientSession]
 
687
    set_default_priority.argtypes = [session_t]
711
688
    set_default_priority.restype = _error_code
712
689
 
713
690
    record_send = _library.gnutls_record_send
714
 
    record_send.argtypes = [ClientSession, ctypes.c_void_p,
 
691
    record_send.argtypes = [session_t, ctypes.c_void_p,
715
692
                            ctypes.c_size_t]
716
693
    record_send.restype = ctypes.c_ssize_t
717
694
    record_send.errcheck = _retry_on_error
719
696
    certificate_allocate_credentials = (
720
697
        _library.gnutls_certificate_allocate_credentials)
721
698
    certificate_allocate_credentials.argtypes = [
722
 
        PointerTo(Credentials)]
 
699
        ctypes.POINTER(certificate_credentials_t)]
723
700
    certificate_allocate_credentials.restype = _error_code
724
701
 
725
702
    certificate_free_credentials = (
726
703
        _library.gnutls_certificate_free_credentials)
727
 
    certificate_free_credentials.argtypes = [Credentials]
 
704
    certificate_free_credentials.argtypes = [
 
705
        certificate_credentials_t]
728
706
    certificate_free_credentials.restype = None
729
707
 
730
708
    handshake_set_private_extensions = (
731
709
        _library.gnutls_handshake_set_private_extensions)
732
 
    handshake_set_private_extensions.argtypes = [ClientSession,
 
710
    handshake_set_private_extensions.argtypes = [session_t,
733
711
                                                 ctypes.c_int]
734
712
    handshake_set_private_extensions.restype = None
735
713
 
736
714
    credentials_set = _library.gnutls_credentials_set
737
 
    credentials_set.argtypes = [ClientSession, credentials_type_t,
738
 
                                CastToVoidPointer(Credentials)]
 
715
    credentials_set.argtypes = [session_t, credentials_type_t,
 
716
                                ctypes.c_void_p]
739
717
    credentials_set.restype = _error_code
740
718
 
741
719
    strerror = _library.gnutls_strerror
743
721
    strerror.restype = ctypes.c_char_p
744
722
 
745
723
    certificate_type_get = _library.gnutls_certificate_type_get
746
 
    certificate_type_get.argtypes = [ClientSession]
 
724
    certificate_type_get.argtypes = [session_t]
747
725
    certificate_type_get.restype = _error_code
748
726
 
749
727
    certificate_get_peers = _library.gnutls_certificate_get_peers
750
 
    certificate_get_peers.argtypes = [ClientSession,
 
728
    certificate_get_peers.argtypes = [session_t,
751
729
                                      ctypes.POINTER(ctypes.c_uint)]
752
730
    certificate_get_peers.restype = ctypes.POINTER(datum_t)
753
731
 
760
738
    global_set_log_function.restype = None
761
739
 
762
740
    deinit = _library.gnutls_deinit
763
 
    deinit.argtypes = [ClientSession]
 
741
    deinit.argtypes = [session_t]
764
742
    deinit.restype = None
765
743
 
766
744
    handshake = _library.gnutls_handshake
767
 
    handshake.argtypes = [ClientSession]
768
 
    handshake.restype = ctypes.c_int
 
745
    handshake.argtypes = [session_t]
 
746
    handshake.restype = _error_code
769
747
    handshake.errcheck = _retry_on_error
770
748
 
771
749
    transport_set_ptr = _library.gnutls_transport_set_ptr
772
 
    transport_set_ptr.argtypes = [ClientSession, transport_ptr_t]
 
750
    transport_set_ptr.argtypes = [session_t, transport_ptr_t]
773
751
    transport_set_ptr.restype = None
774
752
 
775
753
    bye = _library.gnutls_bye
776
 
    bye.argtypes = [ClientSession, close_request_t]
777
 
    bye.restype = ctypes.c_int
 
754
    bye.argtypes = [session_t, close_request_t]
 
755
    bye.restype = _error_code
778
756
    bye.errcheck = _retry_on_error
779
757
 
780
758
    check_version = _library.gnutls_check_version
854
832
 
855
833
    if check_version(b"3.6.4"):
856
834
        certificate_type_get2 = _library.gnutls_certificate_type_get2
857
 
        certificate_type_get2.argtypes = [ClientSession, ctypes.c_int]
 
835
        certificate_type_get2.argtypes = [session_t, ctypes.c_int]
858
836
        certificate_type_get2.restype = _error_code
859
837
 
860
838
    # Remove non-public functions
2320
2298
            priority = self.server.gnutls_priority
2321
2299
            if priority is None:
2322
2300
                priority = "NORMAL"
2323
 
            gnutls.priority_set_direct(session,
2324
 
                                       priority.encode("utf-8"), None)
 
2301
            gnutls.priority_set_direct(session._c_object,
 
2302
                                       priority.encode("utf-8"),
 
2303
                                       None)
2325
2304
 
2326
2305
            # Start communication using the Mandos protocol
2327
2306
            # Get protocol number
2354
2333
                    except (TypeError, gnutls.Error) as error:
2355
2334
                        logger.warning("Bad certificate: %s", error)
2356
2335
                        return
2357
 
                    logger.debug("Key ID: %s",
2358
 
                                 key_id.decode("utf-8",
2359
 
                                               errors="replace"))
 
2336
                    logger.debug("Key ID: %s", key_id)
2360
2337
 
2361
2338
                else:
2362
2339
                    key_id = b""
2454
2431
    def peer_certificate(session):
2455
2432
        "Return the peer's certificate as a bytestring"
2456
2433
        try:
2457
 
            cert_type = gnutls.certificate_type_get2(
2458
 
                session, gnutls.CTYPE_PEERS)
 
2434
            cert_type = gnutls.certificate_type_get2(session._c_object,
 
2435
                                                     gnutls.CTYPE_PEERS)
2459
2436
        except AttributeError:
2460
 
            cert_type = gnutls.certificate_type_get(session)
 
2437
            cert_type = gnutls.certificate_type_get(session._c_object)
2461
2438
        if gnutls.has_rawpk:
2462
2439
            valid_cert_types = frozenset((gnutls.CRT_RAWPK,))
2463
2440
        else:
2470
2447
            return b""
2471
2448
        list_size = ctypes.c_uint(1)
2472
2449
        cert_list = (gnutls.certificate_get_peers
2473
 
                     (session, ctypes.byref(list_size)))
 
2450
                     (session._c_object, ctypes.byref(list_size)))
2474
2451
        if not bool(cert_list) and list_size.value != 0:
2475
2452
            raise gnutls.Error("error getting peer certificate")
2476
2453
        if list_size.value == 0:
3201
3178
 
3202
3179
        @gnutls.log_func
3203
3180
        def debug_gnutls(level, string):
3204
 
            logger.debug("GnuTLS: %s",
3205
 
                         string[:-1].decode("utf-8",
3206
 
                                            errors="replace"))
 
3181
            logger.debug("GnuTLS: %s", string[:-1])
3207
3182
 
3208
3183
        gnutls.global_set_log_function(debug_gnutls)
3209
3184