=== modified file 'mandos' --- mandos 2021-03-21 19:51:15 +0000 +++ mandos 2021-03-21 20:46:40 +0000 @@ -563,9 +563,9 @@ OPENPGP_FMT_RAW = 0 # gnutls/openpgp.h # Types - class session_int(ctypes.Structure): + class _session_int(ctypes.Structure): _fields_ = [] - session_t = ctypes.POINTER(session_int) + session_t = ctypes.POINTER(_session_int) class certificate_credentials_st(ctypes.Structure): _fields_ = [] @@ -577,9 +577,9 @@ _fields_ = [('data', ctypes.POINTER(ctypes.c_ubyte)), ('size', ctypes.c_uint)] - class openpgp_crt_int(ctypes.Structure): + class _openpgp_crt_int(ctypes.Structure): _fields_ = [] - openpgp_crt_t = ctypes.POINTER(openpgp_crt_int) + openpgp_crt_t = ctypes.POINTER(_openpgp_crt_int) openpgp_crt_fmt_t = ctypes.c_int # gnutls/openpgp.h log_func = ctypes.CFUNCTYPE(None, ctypes.c_int, ctypes.c_char_p) credentials_type_t = ctypes.c_int @@ -594,68 +594,90 @@ # gnutls.strerror() self.code = code if message is None and code is not None: - message = gnutls.strerror(code) + message = gnutls.strerror(code).decode( + "utf-8", errors="replace") return super(gnutls.Error, self).__init__( message, *args) class CertificateSecurityError(Error): pass + class PointerTo: + def __init__(self, cls): + self.cls = cls + + def from_param(self, obj): + if not isinstance(obj, self.cls): + raise TypeError("Not of type {}: {!r}" + .format(self.cls.__name__, obj)) + return ctypes.byref(obj.from_param(obj)) + + class CastToVoidPointer: + def __init__(self, cls): + self.cls = cls + + def from_param(self, obj): + if not isinstance(obj, self.cls): + raise TypeError("Not of type {}: {!r}" + .format(self.cls.__name__, obj)) + return ctypes.cast(obj.from_param(obj), ctypes.c_void_p) + + class With_from_param: + @classmethod + def from_param(cls, obj): + return obj._as_parameter_ + # Classes - class Credentials: + class Credentials(With_from_param): def __init__(self): - self._c_object = gnutls.certificate_credentials_t() - gnutls.certificate_allocate_credentials( - ctypes.byref(self._c_object)) + self._as_parameter_ = gnutls.certificate_credentials_t() + gnutls.certificate_allocate_credentials(self) self.type = gnutls.CRD_CERTIFICATE def __del__(self): - gnutls.certificate_free_credentials(self._c_object) + gnutls.certificate_free_credentials(self) - class ClientSession: + class ClientSession(With_from_param): def __init__(self, socket, credentials=None): - self._c_object = gnutls.session_t() + self._as_parameter_ = gnutls.session_t() gnutls_flags = gnutls.CLIENT if gnutls.check_version(b"3.5.6"): gnutls_flags |= gnutls.NO_TICKETS if gnutls.has_rawpk: gnutls_flags |= gnutls.ENABLE_RAWPK - gnutls.init(ctypes.byref(self._c_object), gnutls_flags) + gnutls.init(self, gnutls_flags) del gnutls_flags - gnutls.set_default_priority(self._c_object) - gnutls.transport_set_ptr(self._c_object, socket.fileno()) - gnutls.handshake_set_private_extensions(self._c_object, - True) + gnutls.set_default_priority(self) + gnutls.transport_set_ptr(self, socket.fileno()) + gnutls.handshake_set_private_extensions(self, True) self.socket = socket if credentials is None: credentials = gnutls.Credentials() - gnutls.credentials_set(self._c_object, credentials.type, - ctypes.cast(credentials._c_object, - ctypes.c_void_p)) + gnutls.credentials_set(self, credentials.type, + credentials) self.credentials = credentials def __del__(self): - gnutls.deinit(self._c_object) + gnutls.deinit(self) def handshake(self): - return gnutls.handshake(self._c_object) + return gnutls.handshake(self) def send(self, data): data = bytes(data) data_len = len(data) while data_len > 0: - data_len -= gnutls.record_send(self._c_object, - data[-data_len:], + data_len -= gnutls.record_send(self, data[-data_len:], data_len) def bye(self): - return gnutls.bye(self._c_object, gnutls.SHUT_RDWR) + return gnutls.bye(self, gnutls.SHUT_RDWR) # Error handling functions def _error_code(result): """A function to raise exceptions on errors, suitable for the 'restype' attribute on ctypes functions""" - if result >= 0: + if result >= gnutls.E_SUCCESS: return result if result == gnutls.E_NO_CERTIFICATE_FOUND: raise gnutls.CertificateSecurityError(code=result) @@ -665,7 +687,7 @@ _error_code=_error_code): """A function to retry on some errors, suitable for the 'errcheck' attribute on ctypes functions""" - while result < 0: + while result < gnutls.E_SUCCESS: if result not in (gnutls.E_INTERRUPTED, gnutls.E_AGAIN): return _error_code(result) result = func(*arguments) @@ -676,20 +698,20 @@ # Functions priority_set_direct = _library.gnutls_priority_set_direct - priority_set_direct.argtypes = [session_t, ctypes.c_char_p, + priority_set_direct.argtypes = [ClientSession, ctypes.c_char_p, ctypes.POINTER(ctypes.c_char_p)] priority_set_direct.restype = _error_code init = _library.gnutls_init - init.argtypes = [ctypes.POINTER(session_t), ctypes.c_int] + init.argtypes = [PointerTo(ClientSession), ctypes.c_int] init.restype = _error_code set_default_priority = _library.gnutls_set_default_priority - set_default_priority.argtypes = [session_t] + set_default_priority.argtypes = [ClientSession] set_default_priority.restype = _error_code record_send = _library.gnutls_record_send - record_send.argtypes = [session_t, ctypes.c_void_p, + record_send.argtypes = [ClientSession, ctypes.c_void_p, ctypes.c_size_t] record_send.restype = ctypes.c_ssize_t record_send.errcheck = _retry_on_error @@ -697,24 +719,23 @@ certificate_allocate_credentials = ( _library.gnutls_certificate_allocate_credentials) certificate_allocate_credentials.argtypes = [ - ctypes.POINTER(certificate_credentials_t)] + PointerTo(Credentials)] certificate_allocate_credentials.restype = _error_code certificate_free_credentials = ( _library.gnutls_certificate_free_credentials) - certificate_free_credentials.argtypes = [ - certificate_credentials_t] + certificate_free_credentials.argtypes = [Credentials] certificate_free_credentials.restype = None handshake_set_private_extensions = ( _library.gnutls_handshake_set_private_extensions) - handshake_set_private_extensions.argtypes = [session_t, + handshake_set_private_extensions.argtypes = [ClientSession, ctypes.c_int] handshake_set_private_extensions.restype = None credentials_set = _library.gnutls_credentials_set - credentials_set.argtypes = [session_t, credentials_type_t, - ctypes.c_void_p] + credentials_set.argtypes = [ClientSession, credentials_type_t, + CastToVoidPointer(Credentials)] credentials_set.restype = _error_code strerror = _library.gnutls_strerror @@ -722,11 +743,11 @@ strerror.restype = ctypes.c_char_p certificate_type_get = _library.gnutls_certificate_type_get - certificate_type_get.argtypes = [session_t] + certificate_type_get.argtypes = [ClientSession] certificate_type_get.restype = _error_code certificate_get_peers = _library.gnutls_certificate_get_peers - certificate_get_peers.argtypes = [session_t, + certificate_get_peers.argtypes = [ClientSession, ctypes.POINTER(ctypes.c_uint)] certificate_get_peers.restype = ctypes.POINTER(datum_t) @@ -739,21 +760,21 @@ global_set_log_function.restype = None deinit = _library.gnutls_deinit - deinit.argtypes = [session_t] + deinit.argtypes = [ClientSession] deinit.restype = None handshake = _library.gnutls_handshake - handshake.argtypes = [session_t] - handshake.restype = _error_code + handshake.argtypes = [ClientSession] + handshake.restype = ctypes.c_int handshake.errcheck = _retry_on_error transport_set_ptr = _library.gnutls_transport_set_ptr - transport_set_ptr.argtypes = [session_t, transport_ptr_t] + transport_set_ptr.argtypes = [ClientSession, transport_ptr_t] transport_set_ptr.restype = None bye = _library.gnutls_bye - bye.argtypes = [session_t, close_request_t] - bye.restype = _error_code + bye.argtypes = [ClientSession, close_request_t] + bye.restype = ctypes.c_int bye.errcheck = _retry_on_error check_version = _library.gnutls_check_version @@ -833,7 +854,7 @@ if check_version(b"3.6.4"): certificate_type_get2 = _library.gnutls_certificate_type_get2 - certificate_type_get2.argtypes = [session_t, ctypes.c_int] + certificate_type_get2.argtypes = [ClientSession, ctypes.c_int] certificate_type_get2.restype = _error_code # Remove non-public functions @@ -2299,9 +2320,8 @@ priority = self.server.gnutls_priority if priority is None: priority = "NORMAL" - gnutls.priority_set_direct(session._c_object, - priority.encode("utf-8"), - None) + gnutls.priority_set_direct(session, + priority.encode("utf-8"), None) # Start communication using the Mandos protocol # Get protocol number @@ -2334,7 +2354,9 @@ except (TypeError, gnutls.Error) as error: logger.warning("Bad certificate: %s", error) return - logger.debug("Key ID: %s", key_id) + logger.debug("Key ID: %s", + key_id.decode("utf-8", + errors="replace")) else: key_id = b"" @@ -2432,10 +2454,10 @@ def peer_certificate(session): "Return the peer's certificate as a bytestring" try: - cert_type = gnutls.certificate_type_get2(session._c_object, - gnutls.CTYPE_PEERS) + cert_type = gnutls.certificate_type_get2( + session, gnutls.CTYPE_PEERS) except AttributeError: - cert_type = gnutls.certificate_type_get(session._c_object) + cert_type = gnutls.certificate_type_get(session) if gnutls.has_rawpk: valid_cert_types = frozenset((gnutls.CRT_RAWPK,)) else: @@ -2448,7 +2470,7 @@ return b"" list_size = ctypes.c_uint(1) cert_list = (gnutls.certificate_get_peers - (session._c_object, ctypes.byref(list_size))) + (session, ctypes.byref(list_size))) if not bool(cert_list) and list_size.value != 0: raise gnutls.Error("error getting peer certificate") if list_size.value == 0: @@ -3179,7 +3201,9 @@ @gnutls.log_func def debug_gnutls(level, string): - logger.debug("GnuTLS: %s", string[:-1]) + logger.debug("GnuTLS: %s", + string[:-1].decode("utf-8", + errors="replace")) gnutls.global_set_log_function(debug_gnutls)