Improvements in on_connect()

This commit is contained in:
Pedro Ferreira 2017-06-05 20:41:44 +02:00
parent 37c7d34fee
commit 0a9ebc09bf

View File

@ -19,7 +19,7 @@
# SOFTWARE. # SOFTWARE.
from asyncio import open_connection, start_server from asyncio import open_connection, start_server, ensure_future
from .boxstream import get_stream_pair from .boxstream import get_stream_pair
from .crypto import SHSClientCrypto, SHSServerCrypto from .crypto import SHSClientCrypto, SHSServerCrypto
@ -32,6 +32,7 @@ class SHSClientException(Exception):
class SHSSocket(object): class SHSSocket(object):
def __init__(self, loop): def __init__(self, loop):
self.loop = loop self.loop = loop
self._on_connect = None
def write(self, data): def write(self, data):
self.write_stream.write(data) self.write_stream.write(data)
@ -46,6 +47,9 @@ class SHSSocket(object):
async for msg in self.read_stream: async for msg in self.read_stream:
yield msg yield msg
def on_connect(self, cb):
self._on_connect = cb
class SHSServer(SHSSocket): class SHSServer(SHSSocket):
def __init__(self, host, port, server_kp, application_key=None, loop=None): def __init__(self, host, port, server_kp, application_key=None, loop=None):
@ -53,7 +57,6 @@ class SHSServer(SHSSocket):
self.host = host self.host = host
self.port = port self.port = port
self.crypto = SHSServerCrypto(server_kp, application_key=application_key) self.crypto = SHSServerCrypto(server_kp, application_key=application_key)
self._on_connect = None
async def _handshake(self, reader, writer): async def _handshake(self, reader, writer):
data = await reader.read(64) data = await reader.read(64)
@ -79,14 +82,11 @@ class SHSServer(SHSSocket):
self.writer = writer self.writer = writer
if self._on_connect: if self._on_connect:
await self._on_connect() ensure_future(self._on_connect(), loop=self.loop)
def listen(self): def listen(self):
self.loop.run_until_complete(start_server(self.handle_connection, self.host, self.port, loop=self.loop)) self.loop.run_until_complete(start_server(self.handle_connection, self.host, self.port, loop=self.loop))
def on_connect(self, cb):
self._on_connect = cb
class SHSClient(SHSSocket): class SHSClient(SHSSocket):
def __init__(self, host, port, client_kp, server_pub_key, ephemeral_key=None, application_key=None, loop=None): def __init__(self, host, port, client_kp, server_pub_key, ephemeral_key=None, application_key=None, loop=None):
@ -118,3 +118,5 @@ class SHSClient(SHSSocket):
self.read_stream, self.write_stream = get_stream_pair(reader, writer, **keys) self.read_stream, self.write_stream = get_stream_pair(reader, writer, **keys)
self.writer = writer self.writer = writer
if self._on_connect:
ensure_future(self._on_connect(), loop=self.loop)