diff --git a/ssb/tests/test_packet_stream.py b/ssb/tests/test_packet_stream.py index 04ad937..6863a30 100644 --- a/ssb/tests/test_packet_stream.py +++ b/ssb/tests/test_packet_stream.py @@ -1,8 +1,9 @@ import pytest -from asyncio import gather +from asyncio import ensure_future, gather, Event from asynctest import patch from nacl.signing import SigningKey +from secret_handshake.network import SHSSocket from ssb.packet_stream import PSClient, PSServer, PSMessageType @@ -28,8 +29,9 @@ MSG_BODY_2 = (b'{"previous":"%iQRhPyqmNLpGaO1Tpm1I22jqnUEwRwkCTDbwAGtM+lY=.sha25 b'mAkqqMwFWfP+eBIbc7DZ835er6r6h9CwAg==.sig.ed25519"}') -class MockSHSSocket(object): - def __init__(self, *args, **kwargs): +class MockSHSSocket(SHSSocket): + def __init__(self, *args, loop=None, **kwargs): + super(MockSHSSocket, self).__init__(loop) self.input = [] self.output = [] self.is_connected = False @@ -61,7 +63,7 @@ class MockSHSSocket(object): def _set_connected(self): self.is_connected = True for cb in self._on_connect: - self.event_loop.run_until_complete(cb()) + ensure_future(cb(), loop=self.loop) class MockSHSClient(MockSHSSocket): @@ -74,18 +76,28 @@ class MockSHSServer(MockSHSSocket): @pytest.fixture def ps_client(event_loop): - client = PSClient('fake.local', 1000, SigningKey.generate(), b'\00' * 32, socket_class=MockSHSClient) - client.connection.event_loop = event_loop + client = PSClient('fake.local', 1000, SigningKey.generate(), b'\00' * 32, socket_class=MockSHSClient, + loop=event_loop) client.connect() return client @pytest.fixture def ps_server(event_loop): - server = PSServer('fake.local', 1000, SigningKey.generate(), socket_class=MockSHSServer) - server.connection.event_loop = event_loop - server.listen() - return server + return PSServer('fake.local', 1000, SigningKey.generate(), socket_class=MockSHSServer, loop=event_loop) + + +@pytest.mark.asyncio +async def test_on_connect(ps_server): + called = Event() + + async def _on_connect(): + called.set() + + ps_server.on_connect(_on_connect) + ps_server.listen() + await called.wait() + assert ps_server.is_connected @pytest.mark.asyncio @@ -202,9 +214,9 @@ async def test_message_stream(ps_client, mocker): @pytest.mark.asyncio async def test_message_request(ps_server, mocker): - mocker.patch.object(ps_server, 'register_handler', wraps=ps_server.register_handler) + ps_server.listen() - assert ps_server.is_connected + mocker.patch.object(ps_server, 'register_handler', wraps=ps_server.register_handler) ps_server.send({ 'name': ['whoami'],