# -*- coding: utf-8 -*-
"""
hyper/ssl_compat
~~~~~~~~~~~~~~~~

Shoves pyOpenSSL into an API that looks like the standard Python 3.x ssl
module.

Currently exposes exactly those attributes, classes, and methods that we
actually use in hyper (all method signatures are complete, however). May be
expanded to something more general-purpose in the future.
"""
import errno
import re
import socket
import time
# ``BytesIO`` is not used by this module any more. The name is still imported
# so that it remains available to anything that picked it up from here.
from io import BytesIO  # noqa: F401

from OpenSSL import SSL as ossl
from OpenSSL import crypto
from service_identity.pyopenssl import verify_hostname as _verify

CERT_NONE = ossl.VERIFY_NONE
CERT_REQUIRED = ossl.VERIFY_PEER | ossl.VERIFY_FAIL_IF_NO_PEER_CERT

# Maps the stdlib-style name exported by this module to the pyOpenSSL
# attribute that provides its value.
_OPENSSL_ATTRS = dict(
    OP_NO_COMPRESSION='OP_NO_COMPRESSION',
    PROTOCOL_TLSv1_2='TLSv1_2_METHOD',
    PROTOCOL_SSLv23='SSLv23_METHOD',
)

# Export only the constants that the installed pyOpenSSL actually defines.
for external, internal in _OPENSSL_ATTRS.items():
    value = getattr(ossl, internal, None)
    if value is not None:
        globals()[external] = value

OP_ALL = 0
# TODO: Find out the names of these other flags.
for bit in [31] + list(range(10)):
    OP_ALL |= 1 << bit

# NPN is only usable when the installed pyOpenSSL exposes the NPN callbacks.
HAS_NPN = hasattr(ossl.Context, 'set_npn_select_callback')

# Short X.509 attribute names, as reported by ``X509Name.get_components()``,
# mapped to the long names used in the dict returned by ``getpeercert()``.
_X509_NAME_ALIASES = dict(
    C='countryName',
    ST='stateOrProvinceName',
    L='localityName',
    O='organizationName',
    OU='organizationalUnitName',
    CN='commonName',
)

# Matches each PEM certificate block inside a ``cadata`` blob.
_PEM_CERT_RE = re.compile(
    b'-----BEGIN CERTIFICATE-----.*?-----END CERTIFICATE-----',
    re.DOTALL,
)


def _to_bytes(value, encoding='utf-8'):
    """
    Return ``value`` as bytes.

    ``None`` and ``bytes`` are passed through untouched, a ``bytearray`` is
    copied into ``bytes``, and anything else is assumed to be text and is
    encoded with ``encoding``.
    """
    if value is None or isinstance(value, bytes):
        return value
    if isinstance(value, bytearray):
        return bytes(value)
    return value.encode(encoding)


def _load_cadata(cadata):
    """
    Parse ``cadata`` into a list of ``OpenSSL.crypto.X509`` objects.

    Every PEM ``CERTIFICATE`` block found in the data is loaded. If the data
    contains no PEM block it is treated as a single DER-encoded certificate.
    NOTE(review): several concatenated DER certificates are not split; only
    one certificate is loaded in that case.
    """
    if isinstance(cadata, memoryview):
        cadata = cadata.tobytes()
    elif isinstance(cadata, bytearray):
        cadata = bytes(cadata)
    elif not isinstance(cadata, bytes):
        # Text input can only be PEM, which is plain ASCII.
        cadata = cadata.encode('ascii')

    pem_blocks = _PEM_CERT_RE.findall(cadata)
    if pem_blocks:
        return [
            crypto.load_certificate(crypto.FILETYPE_PEM, block)
            for block in pem_blocks
        ]
    return [crypto.load_certificate(crypto.FILETYPE_ASN1, cadata)]


def _proxy(method):
    """
    Build a method that forwards to ``self._conn.<method>`` unchanged.
    """
    def inner(self, *args, **kwargs):
        return getattr(self._conn, method)(*args, **kwargs)
    return inner


# Referenced in hyper/http20/connection.py. These values come
# from the python ssl package, and must be defined in this file
# for hyper to work in python versions <2.7.9
SSL_ERROR_WANT_READ = 2
SSL_ERROR_WANT_WRITE = 3


# TODO missing some attributes
class SSLError(OSError):
    pass


class CertificateError(SSLError):
    pass


def verify_hostname(ssl_sock, server_hostname):
    """
    A method nearly compatible with the stdlib's match_hostname.

    ``server_hostname`` may be text or ASCII bytes. The check itself is
    delegated to ``service_identity`` against the wrapped connection.
    """
    if isinstance(server_hostname, bytes):
        server_hostname = server_hostname.decode('ascii')
    return _verify(ssl_sock._conn, server_hostname)


class SSLSocket(object):
    """
    Wraps a pyOpenSSL ``Connection`` in (part of) the ``ssl.SSLSocket`` API.
    """
    # Maximum time, in seconds, spent retrying a call that keeps reporting
    # WantRead/WantWrite before ``socket.timeout`` is raised.
    SSL_TIMEOUT = 3
    # Pause, in seconds, between two such retries.
    SSL_RETRY = .01

    def __init__(self, conn, server_side, do_handshake_on_connect,
                 suppress_ragged_eofs, server_hostname, check_hostname):
        self._conn = conn
        self._do_handshake_on_connect = do_handshake_on_connect
        self._suppress_ragged_eofs = suppress_ragged_eofs
        self._check_hostname = check_hostname
        # Always recorded, so that do_handshake() can rely on the attribute
        # existing whichever branch is taken below.
        self._server_hostname = server_hostname

        if server_side:
            self._conn.set_accept_state()
        else:
            if server_hostname:
                # SNI wants bytes; accept a hostname that already is bytes.
                self._conn.set_tlsext_host_name(_to_bytes(server_hostname))
            # FIXME does this override do_handshake_on_connect=False?
            self._conn.set_connect_state()

        if self.connected and self._do_handshake_on_connect:
            self.do_handshake()

    @property
    def connected(self):
        """
        Whether the underlying socket currently has a peer.
        """
        try:
            self._conn.getpeername()
        except socket.error as e:
            if e.errno != errno.ENOTCONN:
                # It's an exception other than the one we expected if we're
                # not connected.
                raise
            return False
        return True

    # Lovingly stolen from CherryPy
    # (http://svn.cherrypy.org/tags/cherrypy-3.2.1/cherrypy/wsgiserver/ssl_pyopenssl.py).
    def _safe_ssl_call(self, suppress_ragged_eofs, call, *args, **kwargs):
        """
        Wrap the given call with SSL error-trapping.

        ``call(*args, **kwargs)`` is retried every ``SSL_RETRY`` seconds
        while it raises WantReadError/WantWriteError, for at most
        ``SSL_TIMEOUT`` seconds, after which ``socket.timeout`` is raised.
        Any other pyOpenSSL error is re-raised as ``socket.error``, except
        for an 'Unexpected EOF', which yields ``b''`` when
        ``suppress_ragged_eofs`` is true.
        """
        start = time.time()
        while True:
            try:
                return call(*args, **kwargs)
            except (ossl.WantReadError, ossl.WantWriteError):
                # Sleep and try again. This is dangerous, because it means
                # the rest of the stack has no way of differentiating
                # between a "new handshake" error and "client dropped".
                # Note this isn't an endless loop: there's a timeout below.
                time.sleep(self.SSL_RETRY)
            except ossl.Error as e:
                if suppress_ragged_eofs and e.args == (-1, 'Unexpected EOF'):
                    return b''
                # Some pyOpenSSL errors carry no arguments at all; fall back
                # to the exception name rather than failing on e.args[0].
                detail = e.args[0] if e.args else type(e).__name__
                raise socket.error(detail)

            if time.time() - start > self.SSL_TIMEOUT:
                raise socket.timeout('timed out')

    def connect(self, address):
        """
        Connect the underlying socket, handshaking afterwards if configured.
        """
        self._conn.connect(address)
        if self._do_handshake_on_connect:
            self.do_handshake()

    def do_handshake(self):
        """
        Perform the TLS handshake and, if requested, the hostname check.

        Raises ``ValueError`` when hostname checking is enabled but no
        server hostname was supplied.
        """
        if self._check_hostname and not self._server_hostname:
            raise ValueError('check_hostname requires server_hostname')
        self._safe_ssl_call(False, self._conn.do_handshake)
        if self._check_hostname:
            verify_hostname(self, self._server_hostname)

    def recv(self, bufsize, flags=None):
        """
        Receive up to ``bufsize`` bytes; ``b''`` signals end of stream.
        """
        def read():
            try:
                return self._conn.recv(bufsize, flags)
            except ossl.ZeroReturnError:
                # The peer shut the TLS session down cleanly: report EOF the
                # way a plain socket does.
                return b''

        return self._safe_ssl_call(self._suppress_ragged_eofs, read)

    def recv_into(self, buffer, bufsize=None, flags=None):
        """
        Receive into ``buffer`` and return the number of bytes written.

        ``bufsize`` of ``None`` or ``0`` means "as much as fits"; it is never
        allowed to exceed the size of ``buffer``.
        """
        # A temporary recv_into implementation. Should be replaced when
        # PyOpenSSL has merged pyca/pyopenssl#121.
        capacity = len(buffer)
        if capacity == 0:
            return 0
        if not bufsize or bufsize > capacity:
            bufsize = capacity

        data = self.recv(bufsize, flags)
        data_len = len(data)
        buffer[0:data_len] = data
        return data_len

    def send(self, data, flags=None):
        """
        Send ``data``, returning whatever the wrapped ``send`` returns.
        """
        return self._safe_ssl_call(False, self._conn.send, data, flags)

    def sendall(self, data, flags=None):
        """
        Send all of ``data`` through the wrapped connection.
        """
        return self._safe_ssl_call(False, self._conn.sendall, data, flags)

    def selected_npn_protocol(self):
        """
        Return the protocol negotiated via NPN as text, or ``None``.
        """
        getter = getattr(self._conn, 'get_next_proto_negotiated', None)
        if getter is None:
            # This pyOpenSSL has no NPN support, so nothing was negotiated.
            return None
        proto = getter()
        if isinstance(proto, bytes):
            proto = proto.decode('ascii')
        return proto if proto else None

    def selected_alpn_protocol(self):
        """
        Return the protocol negotiated via ALPN as text, or ``None``.
        """
        proto = self._conn.get_alpn_proto_negotiated()
        if isinstance(proto, bytes):
            proto = proto.decode('ascii')
        return proto if proto else None

    def getpeercert(self, binary_form=False):
        """
        Describe the peer's certificate.

        Returns ``None`` when the peer presented no certificate. With
        ``binary_form`` true the DER encoding of the certificate is
        returned; otherwise a dict with the keys ``issuer``, ``subject``,
        ``version``, ``serialNumber``, ``notBefore`` and ``notAfter``.
        NOTE(review): the serial number and validity dates are passed
        through exactly as pyOpenSSL reports them, not in the stdlib's
        textual format.
        """
        def resolve_alias(alias):
            return _X509_NAME_ALIASES.get(alias, alias)

        def to_components(name):
            # TODO Verify that these are actually *supposed* to all be
            # single-element tuples, and that's not just a quirk of the
            # examples I've seen.
            return tuple(
                ((resolve_alias(k.decode('utf-8')), v.decode('utf-8')),)
                for k, v in name.get_components()
            )

        # The standard getpeercert() takes the nice X509 object tree returned
        # by OpenSSL and turns it into a dict according to some format it
        # seems to have made up on the spot. Here, we do our best to emulate
        # that.
        cert = self._conn.get_peer_certificate()
        if cert is None:
            return None
        if binary_form:
            return crypto.dump_certificate(crypto.FILETYPE_ASN1, cert)

        result = dict(
            issuer=to_components(cert.get_issuer()),
            subject=to_components(cert.get_subject()),
            # pyOpenSSL counts versions from zero (v3 == 2); the stdlib
            # reports them one-based.
            version=cert.get_version() + 1,
            serialNumber=cert.get_serial_number(),
            notBefore=cert.get_notBefore(),
            notAfter=cert.get_notAfter(),
        )
        # TODO extensions, including subjectAltName
        # (see _decode_certificate in _ssl.c)
        return result

    # Plain socket operations are handed straight to the connection.
    accept = _proxy('accept')
    bind = _proxy('bind')
    close = _proxy('close')
    getsockname = _proxy('getsockname')
    listen = _proxy('listen')
    fileno = _proxy('fileno')


class SSLContext(object):
    """
    Wraps a pyOpenSSL ``Context`` in (part of) the ``ssl.SSLContext`` API.
    """

    def __init__(self, protocol):
        self.protocol = protocol
        self._ctx = ossl.Context(protocol)
        self.options = OP_ALL
        self.check_hostname = False
        self.npn_protos = []
        # Protocols offered via NPN, in order of preference (bytes).
        self.protocols = []

    @property
    def options(self):
        return self._options

    @options.setter
    def options(self, value):
        self._options = value
        self._ctx.set_options(value)

    @property
    def verify_mode(self):
        return self._ctx.get_verify_mode()

    @verify_mode.setter
    def verify_mode(self, value):
        # TODO verify exception is raised on failure
        # The callback simply echoes OpenSSL's own verdict for each
        # certificate in the chain.
        self._ctx.set_verify(
            value, lambda conn, cert, errnum, errdepth, ok: ok
        )

    def set_default_verify_paths(self):
        self._ctx.set_default_verify_paths()

    def load_verify_locations(self, cafile=None, capath=None, cadata=None):
        """
        Add trusted CA certificates from a file, a directory and/or memory.

        ``cafile`` and ``capath`` may be text or bytes. ``cadata`` may be
        PEM text/bytes holding one or more certificates, or a single
        DER-encoded certificate.
        """
        if cafile is not None or capath is not None:
            # Only call into pyOpenSSL when there is a location to load, so
            # that ``cadata`` can be used on its own.
            self._ctx.load_verify_locations(
                _to_bytes(cafile), _to_bytes(capath)
            )

        if cadata is not None:
            store = self._ctx.get_cert_store()
            for cert in _load_cadata(cadata):
                store.add_cert(cert)

    def load_cert_chain(self, certfile, keyfile=None, password=None):
        """
        Load our certificate chain and the matching private key.

        ``keyfile`` defaults to ``certfile``. ``password`` may be text,
        bytes, or a callable returning either; it is handed to pyOpenSSL as
        bytes.
        """
        self._ctx.use_certificate_chain_file(certfile)

        if password is not None:
            def passwd_cb(max_length, prompt_twice, userdata):
                secret = password() if callable(password) else password
                return _to_bytes(secret)

            self._ctx.set_passwd_cb(passwd_cb)

        self._ctx.use_privatekey_file(keyfile or certfile)

    def set_npn_protocols(self, protocols):
        """
        Offer ``protocols`` (text, most preferred first) via NPN.

        Raises ``NotImplementedError`` when ``HAS_NPN`` is false.
        """
        if not HAS_NPN:
            raise NotImplementedError(
                'NPN is not supported by the installed pyOpenSSL'
            )

        self.protocols = [p.encode('ascii') for p in protocols]
        self.npn_protos = self.protocols

        def cb(conn, protos):
            # Detect the overlapping set of protocols.
            overlap = set(protos) & set(self.protocols)

            # Select the overlapping option that comes first in our own
            # list, i.e. the one we prefer most.
            for p in self.protocols:
                if p in overlap:
                    return p
            return b''

        self._ctx.set_npn_select_callback(cb)

    def set_alpn_protocols(self, protocols):
        """
        Offer ``protocols`` (text) via ALPN.
        """
        protocols = [p.encode('ascii') for p in protocols]
        self._ctx.set_alpn_protos(protocols)

    def wrap_socket(self, sock, server_side=False,
                    do_handshake_on_connect=True, suppress_ragged_eofs=True,
                    server_hostname=None):
        """
        Wrap ``sock`` in an ``SSLSocket`` that uses this context.
        """
        conn = ossl.Connection(self._ctx, sock)
        return SSLSocket(conn, server_side, do_handshake_on_connect,
                         suppress_ragged_eofs, server_hostname,
                         # TODO what if this is changed after the fact?
                         self.check_hostname)