Allow low level transport class to be overridden

This commit is contained in:
Pedro Ferreira 2017-07-31 23:15:52 +02:00
parent 7b2ca4b606
commit f2f4a829d2

View File

@ -127,7 +127,7 @@ class PSConnection(object):
return PSMessage.from_header_body(flags, req, body)
except StopAsyncIteration:
logger.debug('DISCONNECT')
await self.connection.disconnect()
self.connection.disconnect()
return None
async def read(self):
@ -199,10 +199,11 @@ class PSConnection(object):
class PSClient(PSConnection):
def __init__(self, host, port, client_kp, server_pub_key, ephemeral_key=None, application_key=None, loop=None):
def __init__(self, host, port, client_kp, server_pub_key, ephemeral_key=None, application_key=None, loop=None,
socket_class=SHSClient):
super(PSClient, self).__init__()
self.connection = SHSClient(host, port, client_kp, server_pub_key, ephemeral_key=ephemeral_key,
application_key=application_key, loop=loop)
self.connection = socket_class(host, port, client_kp, server_pub_key, ephemeral_key=ephemeral_key,
application_key=application_key, loop=loop)
self.connection.on_connect(self._on_connect)
self.loop = loop
@ -211,9 +212,9 @@ class PSClient(PSConnection):
class PSServer(PSConnection):
def __init__(self, host, port, client_kp, application_key=None, loop=None):
super(PSClient, self).__init__()
self.connection = SHSServer(host, port, client_kp, application_key=application_key, loop=loop)
def __init__(self, host, port, client_kp, application_key=None, loop=None, socket_class=SHSServer):
super(PSServer, self).__init__()
self.connection = socket_class(host, port, client_kp, application_key=application_key, loop=loop)
self.connection.on_connect(self._on_connect)
self.loop = loop