diff --git a/secret_handshake/network.py b/secret_handshake/network.py index 5580b4b..3a06711 100644 --- a/secret_handshake/network.py +++ b/secret_handshake/network.py @@ -19,7 +19,7 @@ # SOFTWARE. -from asyncio import open_connection, start_server +from asyncio import open_connection, start_server, ensure_future from .boxstream import get_stream_pair from .crypto import SHSClientCrypto, SHSServerCrypto @@ -32,6 +32,7 @@ class SHSClientException(Exception): class SHSSocket(object): def __init__(self, loop): self.loop = loop + self._on_connect = None def write(self, data): self.write_stream.write(data) @@ -46,6 +47,9 @@ class SHSSocket(object): async for msg in self.read_stream: yield msg + def on_connect(self, cb): + self._on_connect = cb + class SHSServer(SHSSocket): def __init__(self, host, port, server_kp, application_key=None, loop=None): @@ -53,7 +57,6 @@ class SHSServer(SHSSocket): self.host = host self.port = port self.crypto = SHSServerCrypto(server_kp, application_key=application_key) - self._on_connect = None async def _handshake(self, reader, writer): data = await reader.read(64) @@ -79,14 +82,11 @@ class SHSServer(SHSSocket): self.writer = writer if self._on_connect: - await self._on_connect() + ensure_future(self._on_connect(), loop=self.loop) def listen(self): self.loop.run_until_complete(start_server(self.handle_connection, self.host, self.port, loop=self.loop)) - def on_connect(self, cb): - self._on_connect = cb - class SHSClient(SHSSocket): def __init__(self, host, port, client_kp, server_pub_key, ephemeral_key=None, application_key=None, loop=None): @@ -118,3 +118,5 @@ class SHSClient(SHSSocket): self.read_stream, self.write_stream = get_stream_pair(reader, writer, **keys) self.writer = writer + if self._on_connect: + ensure_future(self._on_connect(), loop=self.loop)