diff --git a/ssb/packet_stream.py b/ssb/packet_stream.py index 5715f8e..2fcaf96 100644 --- a/ssb/packet_stream.py +++ b/ssb/packet_stream.py @@ -127,7 +127,7 @@ class PSConnection(object): return PSMessage.from_header_body(flags, req, body) except StopAsyncIteration: logger.debug('DISCONNECT') - await self.connection.disconnect() + self.connection.disconnect() return None async def read(self): @@ -199,10 +199,11 @@ class PSConnection(object): class PSClient(PSConnection): - def __init__(self, host, port, client_kp, server_pub_key, ephemeral_key=None, application_key=None, loop=None): + def __init__(self, host, port, client_kp, server_pub_key, ephemeral_key=None, application_key=None, loop=None, + socket_class=SHSClient): super(PSClient, self).__init__() - self.connection = SHSClient(host, port, client_kp, server_pub_key, ephemeral_key=ephemeral_key, - application_key=application_key, loop=loop) + self.connection = socket_class(host, port, client_kp, server_pub_key, ephemeral_key=ephemeral_key, + application_key=application_key, loop=loop) self.connection.on_connect(self._on_connect) self.loop = loop @@ -211,9 +212,9 @@ class PSClient(PSConnection): class PSServer(PSConnection): - def __init__(self, host, port, client_kp, application_key=None, loop=None): - super(PSClient, self).__init__() - self.connection = SHSServer(host, port, client_kp, application_key=application_key, loop=loop) + def __init__(self, host, port, client_kp, application_key=None, loop=None, socket_class=SHSServer): + super(PSServer, self).__init__() + self.connection = socket_class(host, port, client_kp, application_key=application_key, loop=loop) self.connection.on_connect(self._on_connect) self.loop = loop