308 lines
9.9 KiB
Python
308 lines
9.9 KiB
Python
# -*- coding: utf-8 -*-
|
|
"""
|
|
hyper/ssl_compat
|
|
~~~~~~~~~
|
|
|
|
Shoves pyOpenSSL into an API that looks like the standard Python 3.x ssl
|
|
module.
|
|
|
|
Currently exposes exactly those attributes, classes, and methods that we
|
|
actually use in hyper (all method signatures are complete, however). May be
|
|
expanded to something more general-purpose in the future.
|
|
"""
|
|
try:
|
|
import StringIO as BytesIO
|
|
except ImportError:
|
|
from io import BytesIO
|
|
import errno
|
|
import socket
|
|
import time
|
|
|
|
from OpenSSL import SSL as ossl
|
|
from service_identity.pyopenssl import verify_hostname as _verify
|
|
|
|
# Certificate verification modes, expressed as pyOpenSSL verify flags.
CERT_NONE = ossl.VERIFY_NONE
CERT_REQUIRED = ossl.VERIFY_PEER | ossl.VERIFY_FAIL_IF_NO_PEER_CERT

# stdlib ``ssl`` constant name -> pyOpenSSL attribute name. A constant is
# only published when the installed pyOpenSSL provides a truthy value for it.
_OPENSSL_ATTRS = {
    'OP_NO_COMPRESSION': 'OP_NO_COMPRESSION',
    'PROTOCOL_TLSv1_2': 'TLSv1_2_METHOD',
    'PROTOCOL_SSLv23': 'SSLv23_METHOD',
}

# At module scope globals() is the module namespace, so each available
# constant becomes a module-level name.
for external, internal in _OPENSSL_ATTRS.items():
    value = getattr(ossl, internal, None)
    if value:
        globals()[external] = value

# OP_ALL is the bitmask with bit 31 and bits 0-9 set.
# TODO: Find out the names of these other flags.
OP_ALL = 0
for bit in (31,) + tuple(range(10)):
    OP_ALL |= 1 << bit

HAS_NPN = True
|
|
|
|
|
|
def _proxy(method):
    """
    Build a method that forwards every call to ``self._conn.<method>``.

    :param method: name of the attribute to look up on ``self._conn``.
    :returns: a function taking ``self`` plus arbitrary positional and
        keyword arguments, which returns whatever the forwarded call returns.
    """
    def inner(self, *args, **kwargs):
        target = getattr(self._conn, method)
        result = target(*args, **kwargs)
        return result

    return inner
|
|
|
|
# These mirror the values of the same-named constants in Python's ``ssl``
# package. hyper/http20/connection.py references them, so they must be
# defined here for hyper to work on Python versions older than 2.7.9.
SSL_ERROR_WANT_READ, SSL_ERROR_WANT_WRITE = 2, 3
|
|
|
|
|
|
# TODO missing some attributes
class SSLError(OSError):
    """Base error raised by this ssl-compatibility layer."""


class CertificateError(SSLError):
    """Error type for certificate problems; a subclass of ``SSLError``."""
|
|
|
|
|
|
def verify_hostname(ssl_sock, server_hostname):
    """
    Check the peer certificate of *ssl_sock* against *server_hostname*.

    A method nearly compatible with the stdlib's ``match_hostname``. The check
    itself is delegated to ``service_identity.pyopenssl.verify_hostname``,
    which is applied to the wrapped pyOpenSSL connection (``ssl_sock._conn``).

    :param ssl_sock: an ``SSLSocket`` from this module.
    :param server_hostname: the expected hostname, as text or ASCII bytes.
    :returns: whatever ``service_identity``'s ``verify_hostname`` returns.
    """
    hostname = (
        server_hostname.decode('ascii')
        if isinstance(server_hostname, bytes)
        else server_hostname
    )
    return _verify(ssl_sock._conn, hostname)
|
|
|
|
|
|
class SSLSocket(object):
    """
    Socket-like wrapper around a pyOpenSSL ``Connection`` that mimics the
    parts of the stdlib ``ssl.SSLSocket`` API used by hyper.

    Instances are normally created by ``SSLContext.wrap_socket``.
    """

    # Seconds _safe_ssl_call keeps retrying a call that raises
    # WantReadError/WantWriteError before giving up with socket.timeout.
    SSL_TIMEOUT = 3
    # Seconds to sleep between those retries.
    SSL_RETRY = .01

    def __init__(self, conn, server_side, do_handshake_on_connect,
                 suppress_ragged_eofs, server_hostname, check_hostname):
        self._conn = conn
        self._do_handshake_on_connect = do_handshake_on_connect
        self._suppress_ragged_eofs = suppress_ragged_eofs
        self._check_hostname = check_hostname
        # Set unconditionally: do_handshake() reads it whenever
        # check_hostname is true, including on server-side sockets.
        self._server_hostname = server_hostname

        if server_side:
            self._conn.set_accept_state()
        else:
            if server_hostname:
                self._conn.set_tlsext_host_name(
                    self._encode_hostname(server_hostname)
                )
            # FIXME does this override do_handshake_on_connect=False?
            self._conn.set_connect_state()

        if self.connected and self._do_handshake_on_connect:
            self.do_handshake()

    @staticmethod
    def _encode_hostname(hostname):
        """Return *hostname* as the ASCII (IDNA) bytes that SNI requires."""
        if isinstance(hostname, bytes):
            return hostname
        return hostname.encode('idna')

    @property
    def connected(self):
        """True if the underlying socket has a peer, False on ENOTCONN."""
        try:
            self._conn.getpeername()
        except socket.error as e:
            if e.errno != errno.ENOTCONN:
                # It's an exception other than the one we expected if we're
                # not connected.
                raise
            return False
        return True

    # Lovingly stolen from CherryPy
    # (http://svn.cherrypy.org/tags/cherrypy-3.2.1/cherrypy/wsgiserver/ssl_pyopenssl.py).
    def _safe_ssl_call(self, suppress_ragged_eofs, call, *args, **kwargs):
        """
        Invoke ``call(*args, **kwargs)`` with SSL error-trapping.

        WantReadError/WantWriteError are retried every ``SSL_RETRY`` seconds.
        Once more than ``SSL_TIMEOUT`` seconds have elapsed, ``socket.timeout``
        is raised. When *suppress_ragged_eofs* is true, an unexpected EOF or a
        clean TLS close (``ZeroReturnError``) yields ``b''``. Any other
        pyOpenSSL error is re-raised as ``socket.error``.
        """
        start = time.time()
        while True:
            try:
                return call(*args, **kwargs)
            except (ossl.WantReadError, ossl.WantWriteError):
                # Sleep and try again. This is dangerous, because it means
                # the rest of the stack has no way of differentiating
                # between a "new handshake" error and "client dropped".
                # Note this isn't an endless loop: there's a timeout below.
                time.sleep(self.SSL_RETRY)
            except ossl.ZeroReturnError:
                # The peer closed the TLS session cleanly. pyOpenSSL raises
                # this with no args, so the generic branch below would fail
                # with IndexError on e.args[0].
                if suppress_ragged_eofs:
                    return b''
                raise socket.error('TLS connection closed by peer')
            except ossl.Error as e:
                if suppress_ragged_eofs and e.args == (-1, 'Unexpected EOF'):
                    return b''
                # Guard against arg-less errors rather than raising
                # IndexError.
                raise socket.error(e.args[0] if e.args else str(e))

            if time.time() - start > self.SSL_TIMEOUT:
                raise socket.timeout('timed out')

    def connect(self, address):
        """Connect to *address*, handshaking if so configured."""
        self._conn.connect(address)
        if self._do_handshake_on_connect:
            self.do_handshake()

    def do_handshake(self):
        """
        Perform the TLS handshake, then verify the peer's hostname if
        ``check_hostname`` was requested.

        :raises ValueError: if hostname checking is on but no server
            hostname was supplied.
        """
        if self._check_hostname and not self._server_hostname:
            raise ValueError("check_hostname requires server_hostname")
        self._safe_ssl_call(False, self._conn.do_handshake)
        if self._check_hostname:
            verify_hostname(self, self._server_hostname)

    def recv(self, bufsize, flags=None):
        """Read up to *bufsize* bytes; ``b''`` on suppressed EOF/close."""
        return self._safe_ssl_call(
            self._suppress_ragged_eofs,
            self._conn.recv,
            bufsize,
            flags
        )

    def recv_into(self, buffer, bufsize=None, flags=None):
        """Read into *buffer* and return the number of bytes written."""
        # A temporary recv_into implementation. Should be replaced when
        # PyOpenSSL has merged pyca/pyopenssl#121.
        if bufsize is None:
            bufsize = len(buffer)

        data = self.recv(bufsize, flags)
        data_len = len(data)
        buffer[0:data_len] = data
        return data_len

    def send(self, data, flags=None):
        return self._safe_ssl_call(False, self._conn.send, data, flags)

    def sendall(self, data, flags=None):
        return self._safe_ssl_call(False, self._conn.sendall, data, flags)

    @staticmethod
    def _decode_protocol(proto):
        """Return a negotiated protocol name as text, or None if empty."""
        if isinstance(proto, bytes):
            proto = proto.decode('ascii')
        return proto if proto else None

    def selected_npn_protocol(self):
        return self._decode_protocol(self._conn.get_next_proto_negotiated())

    def selected_alpn_protocol(self):
        return self._decode_protocol(self._conn.get_alpn_proto_negotiated())

    def getpeercert(self, binary_form=False):
        """
        Return the peer certificate, emulating ``ssl.SSLSocket.getpeercert``.

        :param binary_form: if true, return the DER-encoded certificate bytes
            instead of a dict.
        :returns: None when the peer presented no certificate. Otherwise, a
            dict approximating the stdlib layout, or DER bytes.
        """
        cert = self._conn.get_peer_certificate()
        if cert is None:
            return None

        if binary_form:
            from OpenSSL import crypto
            return crypto.dump_certificate(crypto.FILETYPE_ASN1, cert)

        aliases = dict(
            C='countryName',
            ST='stateOrProvinceName',
            L='localityName',
            O='organizationName',
            OU='organizationalUnitName',
            CN='commonName',
        )

        def to_components(name):
            # TODO Verify that these are actually *supposed* to all be
            # single-element tuples, and that's not just a quirk of the
            # examples I've seen.
            components = []
            for key, value in name.get_components():
                key = key.decode('utf-8')
                components.append(
                    ((aliases.get(key, key), value.decode('utf-8')),)
                )
            return tuple(components)

        # The standard getpeercert() takes the nice X509 object tree returned
        # by OpenSSL and turns it into a dict according to some format it
        # seems to have made up on the spot. Here, we do our best to emulate
        # that.
        result = dict(
            issuer=to_components(cert.get_issuer()),
            subject=to_components(cert.get_subject()),
            # OpenSSL's version field is zero-based (2 means X.509v3); the
            # stdlib reports the one-based number.
            version=cert.get_version() + 1,
            serialNumber=cert.get_serial_number(),
            notBefore=cert.get_notBefore(),
            notAfter=cert.get_notAfter(),
        )
        # TODO extensions, including subjectAltName
        # (see _decode_certificate in _ssl.c)
        return result

    # Plain socket operations forwarded straight to the wrapped connection.
    # Defined explicitly because writing to locals() in a class body is not
    # a supported way to create attributes.
    accept = _proxy('accept')
    bind = _proxy('bind')
    close = _proxy('close')
    getsockname = _proxy('getsockname')
    listen = _proxy('listen')
    fileno = _proxy('fileno')
|
|
|
|
|
|
class SSLContext(object):
    """
    Minimal pyOpenSSL-backed stand-in for the stdlib ``ssl.SSLContext``.

    Wraps an ``OpenSSL.SSL.Context`` and produces ``SSLSocket`` objects via
    ``wrap_socket``.
    """

    def __init__(self, protocol):
        self.protocol = protocol
        self._ctx = ossl.Context(protocol)
        self.options = OP_ALL
        self.check_hostname = False
        # Encoded NPN protocol names set by set_npn_protocols().
        self.npn_protos = []

    @property
    def options(self):
        """The option bitmask most recently assigned to this context."""
        return self._options

    @options.setter
    def options(self, value):
        # NOTE: pyOpenSSL's Context.set_options only *adds* bits and never
        # clears previously-set ones. Clearing a bit here changes what the
        # getter reports, but not the underlying OpenSSL context.
        self._options = value
        self._ctx.set_options(value)

    @property
    def verify_mode(self):
        return self._ctx.get_verify_mode()

    @verify_mode.setter
    def verify_mode(self, value):
        # The callback simply passes through OpenSSL's own verdict (``ok``).
        # TODO verify exception is raised on failure
        self._ctx.set_verify(
            value, lambda conn, cert, errnum, errdepth, ok: ok
        )

    def set_default_verify_paths(self):
        self._ctx.set_default_verify_paths()

    def load_verify_locations(self, cafile=None, capath=None, cadata=None):
        """
        Load trusted CA certificates, mirroring the stdlib signature.

        :param cafile: path to a file of CA certificates (text or bytes).
        :param capath: path to a directory of CA certificates.
        :param cadata: PEM text (one or more certificates) or DER bytes.
        :raises TypeError: if all three arguments are omitted.
        """
        if cafile is None and capath is None and cadata is None:
            raise TypeError("cafile, capath and cadata cannot be all omitted")

        # pyOpenSSL rejects (None, None), so only call it when given a path.
        if cafile is not None or capath is not None:
            self._ctx.load_verify_locations(
                self._encode_path(cafile), self._encode_path(capath)
            )

        if cadata is not None:
            self._load_cadata(cadata)

    @staticmethod
    def _encode_path(path):
        """Return *path* as bytes; None and bytes pass through unchanged."""
        if path is None or isinstance(path, bytes):
            return path
        return path.encode('utf-8')

    def _load_cadata(self, cadata):
        """
        Add in-memory CA certificates to this context's trust store.

        Data containing a PEM "BEGIN CERTIFICATE" marker is split into its
        individual PEM certificates. Anything else is parsed as a single DER
        certificate.
        """
        # pyOpenSSL's load_verify_locations only accepts file paths, so
        # certificates are parsed here and added to the X509 store directly.
        from OpenSSL import crypto

        if not isinstance(cadata, bytes):
            cadata = cadata.encode('ascii')

        marker = b'-----BEGIN CERTIFICATE-----'
        if marker in cadata:
            certs = [
                crypto.load_certificate(crypto.FILETYPE_PEM, marker + chunk)
                for chunk in cadata.split(marker)[1:]
            ]
        else:
            certs = [crypto.load_certificate(crypto.FILETYPE_ASN1, cadata)]

        store = self._ctx.get_cert_store()
        for cert in certs:
            store.add_cert(cert)

    def load_cert_chain(self, certfile, keyfile=None, password=None):
        """
        Load the local certificate and private key.

        :param certfile: certificate file path.
        :param keyfile: private key file path; defaults to *certfile*.
        :param password: key password as str, bytes, or bytearray, or a
            callable returning one of those (as in the stdlib).
        """
        self._ctx.use_certificate_file(certfile)
        if password is not None:
            def passwd_cb(max_length, prompt_twice, userdata):
                secret = password() if callable(password) else password
                # pyOpenSSL requires the callback to return bytes.
                if isinstance(secret, bytearray):
                    return bytes(secret)
                if not isinstance(secret, bytes):
                    return secret.encode('utf-8')
                return secret

            self._ctx.set_passwd_cb(passwd_cb)
        self._ctx.use_privatekey_file(keyfile or certfile)

    @staticmethod
    def _encode_protocols(protocols):
        """Return *protocols* as ASCII bytes (bytes items pass through)."""
        return [
            p if isinstance(p, bytes) else p.encode('ascii')
            for p in protocols
        ]

    def set_npn_protocols(self, protocols):
        """Advertise *protocols* (in preference order) for NPN selection."""
        wanted = self._encode_protocols(protocols)
        self.npn_protos = wanted
        # Older versions stored the list here; kept for compatibility.
        self.protocols = wanted

        def cb(conn, offered):
            # Pick the first of *our* protocols, in our preference order,
            # that the server also offered; b'' when there is no overlap.
            offered = set(offered)
            for proto in wanted:
                if proto in offered:
                    return proto
            return b''

        self._ctx.set_npn_select_callback(cb)

    def set_alpn_protocols(self, protocols):
        """Advertise *protocols* via ALPN."""
        self._ctx.set_alpn_protos(self._encode_protocols(protocols))

    def wrap_socket(self,
                    sock,
                    server_side=False,
                    do_handshake_on_connect=True,
                    suppress_ragged_eofs=True,
                    server_hostname=None):
        """Wrap *sock* in a pyOpenSSL connection and return an SSLSocket."""
        conn = ossl.Connection(self._ctx, sock)
        return SSLSocket(conn, server_side, do_handshake_on_connect,
                         suppress_ragged_eofs, server_hostname,
                         # TODO what if this is changed after the fact?
                         self.check_hostname)
|