=== modified file 'mandos' --- mandos 2016-02-28 03:01:43 +0000 +++ mandos 2016-02-28 10:59:18 +0000 @@ -450,6 +450,8 @@ # Constants E_SUCCESS = 0 + E_INTERRUPTED = -52 + E_AGAIN = -28 CRT_OPENPGP = 2 CLIENT = 2 SHUT_RDWR = 0 @@ -483,11 +485,12 @@ # We need to use the class name "GnuTLS" here, since this # exception might be raised from within GnuTLS.__init__, # which is called before the assignment to the "gnutls" - # global variable happens. + # global variable has happened. def __init__(self, message = None, code = None, args=()): # Default usage is by a message string, but if a return # code is passed, convert it to a string with # gnutls.strerror() + self.code = code if message is None and code is not None: message = GnuTLS.strerror(code) return super(GnuTLS.Error, self).__init__( @@ -531,14 +534,16 @@ def send(self, data): data = bytes(data) - if not data: - return 0 - return gnutls.record_send(self._c_object, data, len(data)) + data_len = len(data) + while data_len > 0: + data_len -= gnutls.record_send(self._c_object, + data[-data_len:], + data_len) def bye(self): return gnutls.bye(self._c_object, gnutls.SHUT_RDWR) - # Error handling function + # Error handling functions def _error_code(result): """A function to raise exceptions on errors, suitable for the 'restype' attribute on ctypes functions""" @@ -548,6 +553,15 @@ raise gnutls.CertificateSecurityError(code = result) raise gnutls.Error(code = result) + def _retry_on_error(result, func, arguments): + """A function to retry on some errors, suitable + for the 'errcheck' attribute on ctypes functions""" + while result < 0: + if result not in (gnutls.E_INTERRUPTED, gnutls.E_AGAIN): + return _error_code(result) + result = func(*arguments) + return result + # Unless otherwise indicated, the function declarations below are # all from the gnutls/gnutls.h C header file. @@ -569,6 +583,7 @@ record_send.argtypes = [session_t, ctypes.c_void_p, ctypes.c_size_t] record_send.restype = ctypes.c_ssize_t + record_send.errcheck = _retry_on_error certificate_allocate_credentials = ( _library.gnutls_certificate_allocate_credentials) @@ -620,6 +635,7 @@ handshake = _library.gnutls_handshake handshake.argtypes = [session_t] handshake.restype = _error_code + handshake.errcheck = _retry_on_error transport_set_ptr = _library.gnutls_transport_set_ptr transport_set_ptr.argtypes = [session_t, transport_ptr_t] @@ -628,6 +644,7 @@ bye = _library.gnutls_bye bye.argtypes = [session_t, close_request_t] bye.restype = _error_code + bye.errcheck = _retry_on_error check_version = _library.gnutls_check_version check_version.argtypes = [ctypes.c_char_p] @@ -662,8 +679,8 @@ ctypes.c_size_t)] openpgp_crt_get_fingerprint.restype = _error_code - # Remove non-public function - del _error_code + # Remove non-public functions + del _error_code, _retry_on_error # Create the global "gnutls" object, simulating a module gnutls = GnuTLS() @@ -2215,18 +2232,12 @@ else: delay -= time2 - time - sent_size = 0 - while sent_size < len(client.secret): - try: - sent = session.send(client.secret[sent_size:]) - except gnutls.Error as error: - logger.warning("gnutls send failed", - exc_info=error) - return - logger.debug("Sent: %d, remaining: %d", sent, - len(client.secret) - (sent_size - + sent)) - sent_size += sent + try: + session.send(client.secret) + except gnutls.Error as error: + logger.warning("gnutls send failed", + exc_info = error) + return logger.info("Sending secret to %s", client.name) # bump the timeout using extended_timeout