diff --git a/ssb/tests/test_packet_stream.py b/ssb/tests/test_packet_stream.py index 6863a30..c6c6fb0 100644 --- a/ssb/tests/test_packet_stream.py +++ b/ssb/tests/test_packet_stream.py @@ -60,25 +60,25 @@ class MockSHSSocket(SHSSocket): def disconnect(self): self.is_connected = False - def _set_connected(self): - self.is_connected = True - for cb in self._on_connect: - ensure_future(cb(), loop=self.loop) - class MockSHSClient(MockSHSSocket): - connect = MockSHSSocket._set_connected + async def connect(self): + self.is_connected = True + for cb in self._on_connect: + await cb() class MockSHSServer(MockSHSSocket): - listen = MockSHSSocket._set_connected + def listen(self): + self.is_connected = True + for cb in self._on_connect: + ensure_future(cb(), loop=self.loop) @pytest.fixture def ps_client(event_loop): client = PSClient('fake.local', 1000, SigningKey.generate(), b'\00' * 32, socket_class=MockSHSClient, loop=event_loop) - client.connect() return client @@ -102,6 +102,7 @@ async def test_on_connect(ps_server): @pytest.mark.asyncio async def test_message_decoding(ps_client): + await ps_client.connect() assert ps_client.is_connected ps_client.connection.feed([ @@ -129,6 +130,7 @@ async def test_message_decoding(ps_client): @pytest.mark.asyncio async def test_message_encoding(ps_client): + await ps_client.connect() assert ps_client.is_connected ps_client.send({ @@ -142,16 +144,23 @@ async def test_message_encoding(ps_client): 'type': 'source' }, stream=True) - body = (b'{"name": ["createHistoryStream"], "args": [{"id": "@1+Iwm79DKvVBqYKFkhT6fWRbAVvNNVH4F2BSxwhYmx8=.ed25519"' - b', "seq": 1, "live": false, "keys": false}], "type": "source"}') + header, body = list(ps_client.connection.get_output()) - assert list(ps_client.connection.get_output()) == [b'\x0a\x00\x00\x00\xa6\x00\x00\x00\x01', body] + assert header == b'\x0a\x00\x00\x00\xa6\x00\x00\x00\x01' + assert json.loads(body.decode('utf-8')) == { + "name": ["createHistoryStream"], + "args": [ + {"id": "@1+Iwm79DKvVBqYKFkhT6fWRbAVvNNVH4F2BSxwhYmx8=.ed25519", "seq": 1, "live": False, "keys": False} + ], + "type": "source" + } @pytest.mark.asyncio async def test_message_stream(ps_client, mocker): mocker.patch.object(ps_client, 'register_handler', wraps=ps_client.register_handler) + await ps_client.connect() assert ps_client.is_connected ps_client.send({ @@ -223,8 +232,9 @@ async def test_message_request(ps_server, mocker): 'args': [] }) - assert (list(ps_server.connection.get_output()) == - [b'\x02\x00\x00\x00 \x00\x00\x00\x01', b'{"name": ["whoami"], "args": []}']) + header, body = list(ps_server.connection.get_output()) + assert header == b'\x02\x00\x00\x00 \x00\x00\x00\x01' + assert json.loads(body.decode('utf-8')) == {"name": ["whoami"], "args": []} assert ps_server.req_counter == 2 assert ps_server.register_handler.call_count == 1