Allow low level transport class to be overridden

This commit is contained in:
Pedro Ferreira 2017-07-31 23:15:52 +02:00
parent 7b2ca4b606
commit f2f4a829d2

View File

@ -127,7 +127,7 @@ class PSConnection(object):
return PSMessage.from_header_body(flags, req, body) return PSMessage.from_header_body(flags, req, body)
except StopAsyncIteration: except StopAsyncIteration:
logger.debug('DISCONNECT') logger.debug('DISCONNECT')
await self.connection.disconnect() self.connection.disconnect()
return None return None
async def read(self): async def read(self):
@ -199,9 +199,10 @@ class PSConnection(object):
class PSClient(PSConnection): class PSClient(PSConnection):
def __init__(self, host, port, client_kp, server_pub_key, ephemeral_key=None, application_key=None, loop=None): def __init__(self, host, port, client_kp, server_pub_key, ephemeral_key=None, application_key=None, loop=None,
socket_class=SHSClient):
super(PSClient, self).__init__() super(PSClient, self).__init__()
self.connection = SHSClient(host, port, client_kp, server_pub_key, ephemeral_key=ephemeral_key, self.connection = socket_class(host, port, client_kp, server_pub_key, ephemeral_key=ephemeral_key,
application_key=application_key, loop=loop) application_key=application_key, loop=loop)
self.connection.on_connect(self._on_connect) self.connection.on_connect(self._on_connect)
self.loop = loop self.loop = loop
@ -211,9 +212,9 @@ class PSClient(PSConnection):
class PSServer(PSConnection): class PSServer(PSConnection):
def __init__(self, host, port, client_kp, application_key=None, loop=None): def __init__(self, host, port, client_kp, application_key=None, loop=None, socket_class=SHSServer):
super(PSClient, self).__init__() super(PSServer, self).__init__()
self.connection = SHSServer(host, port, client_kp, application_key=application_key, loop=loop) self.connection = socket_class(host, port, client_kp, application_key=application_key, loop=loop)
self.connection.on_connect(self._on_connect) self.connection.on_connect(self._on_connect)
self.loop = loop self.loop = loop