From dc1389d634298137ad444bb351d604edd78918e1 Mon Sep 17 00:00:00 2001 From: Pedro Ferreira Date: Sun, 4 Feb 2018 22:18:36 +0100 Subject: [PATCH] Adapt code to match secret_handshake changes --- examples/test_client.py | 18 +++---- examples/test_server.py | 60 ++++++++++++--------- setup.py | 8 +-- ssb/feed/__init__.py | 3 ++ ssb/packet_stream.py | 78 ++++++++------------------- ssb/tests/test_packet_stream.py | 94 ++++++++++++++++++--------------- 6 files changed, 124 insertions(+), 137 deletions(-) create mode 100644 ssb/feed/__init__.py diff --git a/examples/test_client.py b/examples/test_client.py index 488dc28..58cbcd2 100644 --- a/examples/test_client.py +++ b/examples/test_client.py @@ -5,8 +5,9 @@ from asyncio import get_event_loop, gather, ensure_future from colorlog import ColoredFormatter +from secret_handshake.network import SHSClient from ssb.muxrpc import MuxRPCAPI, MuxRPCAPIException -from ssb.packet_stream import PSClient, PSMessageType +from ssb.packet_stream import PacketStream, PSMessageType from ssb.util import load_ssb_secret @@ -52,13 +53,15 @@ async def test_client(): f.write(data.data) -async def _main(packet_stream): - await packet_stream.connect() +async def main(): + client = SHSClient('127.0.0.1', 8008, keypair, bytes(keypair.verify_key)) + packet_stream = PacketStream(client) + await client.open() api.add_connection(packet_stream) await gather(ensure_future(api), test_client()) -def main(): +if __name__ == '__main__': # create console handler and set level to debug ch = logging.StreamHandler() ch.setLevel(logging.INFO) @@ -78,10 +81,5 @@ def main(): keypair = load_ssb_secret()['keypair'] loop = get_event_loop() - packet_stream = PSClient('127.0.0.1', 8008, keypair, bytes(keypair.verify_key), loop=loop) - loop.run_until_complete(_main(packet_stream)) + loop.run_until_complete(main()) loop.close() - - -if __name__ == '__main__': - main() diff --git a/examples/test_server.py b/examples/test_server.py index 3218b34..a4aac05 100644 --- a/examples/test_server.py +++ b/examples/test_server.py @@ -1,37 +1,47 @@ import logging -from asyncio import get_event_loop, ensure_future +from asyncio import gather, get_event_loop, ensure_future from colorlog import ColoredFormatter -from ssb.packet_stream import PSServer +from secret_handshake import SHSServer +from ssb.packet_stream import PacketStream +from ssb.protocol.streams import stream_api from ssb.util import load_ssb_secret -async def on_connect(): - ensure_future(packet_stream, loop=loop) +async def on_connect(conn): + packet_stream = PacketStream(conn) + stream_api.add_connection(packet_stream) -# create console handler and set level to debug -ch = logging.StreamHandler() -ch.setLevel(logging.DEBUG) - -# create formatter -formatter = ColoredFormatter('%(log_color)s%(levelname)s%(reset)s:%(bold_white)s%(name)s%(reset)s - ' - '%(cyan)s%(message)s%(reset)s') - -# add formatter to ch -ch.setFormatter(formatter) - -# add ch to logger -logger = logging.getLogger('packet_stream') -logger.setLevel(logging.DEBUG) -logger.addHandler(ch) + print('connect', conn) + async for msg in packet_stream: + print(msg) -loop = get_event_loop() +async def main(): + server = SHSServer('127.0.0.1', 8008, load_ssb_secret()['keypair']) + server.on_connect(on_connect) + await server.listen() -packet_stream = PSServer('127.0.0.1', 8008, load_ssb_secret()['keypair'], loop=loop) -packet_stream.on_connect(on_connect) -packet_stream.listen() -loop.run_forever() -loop.close() +if __name__ == '__main__': + # create console handler and set level to debug + ch = logging.StreamHandler() + ch.setLevel(logging.DEBUG) + + # create formatter + formatter = ColoredFormatter('%(log_color)s%(levelname)s%(reset)s:%(bold_white)s%(name)s%(reset)s - ' + '%(cyan)s%(message)s%(reset)s') + + # add formatter to ch + ch.setFormatter(formatter) + + # add ch to logger + logger = logging.getLogger('packet_stream') + logger.setLevel(logging.DEBUG) + logger.addHandler(ch) + + loop = get_event_loop() + loop.run_until_complete(main()) + loop.run_forever() + loop.close() diff --git a/setup.py b/setup.py index a096d5e..395018b 100644 --- a/setup.py +++ b/setup.py @@ -35,8 +35,7 @@ tests_require = [ 'pytest>=3.1.1', 'pytest-asyncio==0.6.0', 'asynctest==0.10.0', - 'pytest-mock==1.6.2', - 'async-generator==1.8' + 'pytest-mock==1.6.2' ] extras_require = { @@ -48,10 +47,11 @@ extras_require = { extras_require['all'] = sum((lst for lst in extras_require.values()), []) install_requires = [ + 'async-generator==1.8', 'pynacl==1.1.2', - 'simplejson==3.10.0', 'PyYAML==3.12', - 'secret-handshake' + 'secret-handshake', + 'simplejson==3.10.0' ] setup_requires = [ diff --git a/ssb/feed/__init__.py b/ssb/feed/__init__.py new file mode 100644 index 0000000..c3080c0 --- /dev/null +++ b/ssb/feed/__init__.py @@ -0,0 +1,3 @@ +from .models import Feed, LocalFeed, Message, LocalMessage, NoPrivateKeyException + +__all__ = ('Feed', 'LocalFeed', 'Message', 'LocalMessage', 'NoPrivateKeyException') diff --git a/ssb/packet_stream.py b/ssb/packet_stream.py index 1829616..3849ef8 100644 --- a/ssb/packet_stream.py +++ b/ssb/packet_stream.py @@ -99,18 +99,34 @@ class PSMessage(object): '~' if self.stream else '', '!' if self.end_err else '') -class PSConnection(object): - def __init__(self): - self._event_map = {} +class PacketStream(object): + def __init__(self, connection): + self.connection = connection self.req_counter = 1 - self._connected = False + self._event_map = {} - async def _on_connect(self): - self._connected = True + def register_handler(self, handler): + self._event_map[handler.req] = (time(), handler) @property def is_connected(self): - return self._connected + return self.connection.is_connected + + @async_generator + async def __aiter__(self): + while True: + msg = await self.read() + if not msg: + return + # filter out replies + if msg.req >= 0: + await yield_(msg) + + async def __await__(self): + async for data in self: + logger.info('RECV: %r', data) + if data is None: + return async def _read(self): try: @@ -147,25 +163,6 @@ class PSConnection(object): logger.info('RESPONSE [%d]: EOS', -msg.req) return msg - async def __await__(self): - async for data in self: - logger.info('RECV: %r', data) - if data is None: - return - - def register_handler(self, handler): - self._event_map[handler.req] = (time(), handler) - - @async_generator - async def __aiter__(self): - while True: - msg = await self.read() - if not msg: - return - # filter out replies - if msg.req >= 0: - await yield_(msg) - def _write(self, msg): logger.info('SEND [%d]: %r', msg.req, msg) header = struct.pack('>BIi', (int(msg.stream) << 3) | (int(msg.end_err) << 2) | msg.type.value, len(msg.data), @@ -199,32 +196,3 @@ class PSConnection(object): def disconnect(self): self._connected = False self.connection.disconnect() - - -class PSClient(PSConnection): - def __init__(self, host, port, client_kp, server_pub_key, ephemeral_key=None, application_key=None, loop=None, - socket_class=SHSClient): - super(PSClient, self).__init__() - self.connection = socket_class(host, port, client_kp, server_pub_key, ephemeral_key=ephemeral_key, - application_key=application_key, loop=loop) - self.connection.on_connect(self._on_connect) - self.loop = loop - - async def connect(self): - await self.connection.connect() - - -class PSServer(PSConnection): - def __init__(self, host, port, client_kp, application_key=None, loop=None, socket_class=SHSServer): - super(PSServer, self).__init__() - self.connection = socket_class(host, port, client_kp, application_key=application_key, loop=loop) - self.connection.on_connect(self._on_connect) - self.loop = loop - - def on_connect(self, cb): - async def _on_connect(): - await cb() - self.connection.on_connect(_on_connect) - - def listen(self): - self.connection.listen() diff --git a/ssb/tests/test_packet_stream.py b/ssb/tests/test_packet_stream.py index fe5ae54..02ccd0e 100644 --- a/ssb/tests/test_packet_stream.py +++ b/ssb/tests/test_packet_stream.py @@ -5,8 +5,8 @@ import pytest from asynctest import patch from nacl.signing import SigningKey -from secret_handshake.network import SHSSocket -from ssb.packet_stream import PSClient, PSServer, PSMessageType +from secret_handshake.network import SHSDuplexStream +from ssb.packet_stream import PacketStream, PSMessageType async def _collect_messages(generator): @@ -31,9 +31,9 @@ MSG_BODY_2 = (b'{"previous":"%iQRhPyqmNLpGaO1Tpm1I22jqnUEwRwkCTDbwAGtM+lY=.sha25 b'mAkqqMwFWfP+eBIbc7DZ835er6r6h9CwAg==.sig.ed25519"}') -class MockSHSSocket(SHSSocket): - def __init__(self, *args, loop=None, **kwargs): - super(MockSHSSocket, self).__init__(loop) +class MockSHSSocket(SHSDuplexStream): + def __init__(self, *args, **kwargs): + super(MockSHSSocket, self).__init__() self.input = [] self.output = [] self.is_connected = False @@ -74,19 +74,17 @@ class MockSHSServer(MockSHSSocket): def listen(self): self.is_connected = True for cb in self._on_connect: - ensure_future(cb(), loop=self.loop) + ensure_future(cb()) @pytest.fixture def ps_client(event_loop): - client = PSClient('fake.local', 1000, SigningKey.generate(), b'\00' * 32, socket_class=MockSHSClient, - loop=event_loop) - return client + return MockSHSClient() @pytest.fixture def ps_server(event_loop): - return PSServer('fake.local', 1000, SigningKey.generate(), socket_class=MockSHSServer, loop=event_loop) + return MockSHSServer() @pytest.mark.asyncio @@ -105,15 +103,18 @@ async def test_on_connect(ps_server): @pytest.mark.asyncio async def test_message_decoding(ps_client): await ps_client.connect() - assert ps_client.is_connected - ps_client.connection.feed([ + ps = PacketStream(ps_client) + + assert ps.is_connected + + ps_client.feed([ b'\n\x00\x00\x00\x9a\x00\x00\x04\xfb', b'{"name":["createHistoryStream"],"args":[{"id":"@omgyp7Pnrw+Qm0I6T6Fh5VvnKmodMXwnxTIesW2DgMg=.ed25519",' b'"seq":10,"live":true,"keys":false}],"type":"source"}' ]) - messages = (await _collect_messages(ps_client)) + messages = (await _collect_messages(ps)) assert len(messages) == 1 assert messages[0].type == PSMessageType.JSON assert messages[0].body == { @@ -133,9 +134,12 @@ async def test_message_decoding(ps_client): @pytest.mark.asyncio async def test_message_encoding(ps_client): await ps_client.connect() - assert ps_client.is_connected - ps_client.send({ + ps = PacketStream(ps_client) + + assert ps.is_connected + + ps.send({ 'name': ['createHistoryStream'], 'args': [{ 'id': "@1+Iwm79DKvVBqYKFkhT6fWRbAVvNNVH4F2BSxwhYmx8=.ed25519", @@ -146,7 +150,7 @@ async def test_message_encoding(ps_client): 'type': 'source' }, stream=True) - header, body = list(ps_client.connection.get_output()) + header, body = list(ps_client.get_output()) assert header == b'\x0a\x00\x00\x00\xa6\x00\x00\x00\x01' assert json.loads(body.decode('utf-8')) == { @@ -160,12 +164,14 @@ async def test_message_encoding(ps_client): @pytest.mark.asyncio async def test_message_stream(ps_client, mocker): - mocker.patch.object(ps_client, 'register_handler', wraps=ps_client.register_handler) - await ps_client.connect() - assert ps_client.is_connected - ps_client.send({ + ps = PacketStream(ps_client) + mocker.patch.object(ps, 'register_handler', wraps=ps.register_handler) + + assert ps.is_connected + + ps.send({ 'name': ['createHistoryStream'], 'args': [{ 'id': "@1+Iwm79DKvVBqYKFkhT6fWRbAVvNNVH4F2BSxwhYmx8=.ed25519", @@ -176,22 +182,22 @@ async def test_message_stream(ps_client, mocker): 'type': 'source' }, stream=True) - assert ps_client.req_counter == 2 - assert ps_client.register_handler.call_count == 1 - handler = list(ps_client._event_map.values())[0][1] + assert ps.req_counter == 2 + assert ps.register_handler.call_count == 1 + handler = list(ps._event_map.values())[0][1] with patch.object(handler, 'process') as mock_process: - ps_client.connection.feed([b'\n\x00\x00\x02\xc5\xff\xff\xff\xff', MSG_BODY_1]) - msg = await ps_client.read() + ps_client.feed([b'\n\x00\x00\x02\xc5\xff\xff\xff\xff', MSG_BODY_1]) + msg = await ps.read() assert mock_process.call_count == 1 # responses have negative req assert msg.req == -1 assert msg.body['previous'] == '%KTGP6W8vF80McRAZHYDWuKOD0KlNyKSq6Gb42iuV7Iw=.sha256' - assert ps_client.req_counter == 2 + assert ps.req_counter == 2 - stream_handler = ps_client.send({ + stream_handler = ps.send({ 'name': ['createHistoryStream'], 'args': [{ 'id': "@1+Iwm79DKvVBqYKFkhT6fWRbAVvNNVH4F2BSxwhYmx8=.ed25519", @@ -202,16 +208,16 @@ async def test_message_stream(ps_client, mocker): 'type': 'source' }, stream=True) - assert ps_client.req_counter == 3 - assert ps_client.register_handler.call_count == 2 - handler = list(ps_client._event_map.values())[1][1] + assert ps.req_counter == 3 + assert ps.register_handler.call_count == 2 + handler = list(ps._event_map.values())[1][1] with patch.object(handler, 'process', wraps=handler.process) as mock_process: - ps_client.connection.feed([b'\n\x00\x00\x02\xc5\xff\xff\xff\xfe', MSG_BODY_1, - b'\x0e\x00\x00\x023\xff\xff\xff\xfe', MSG_BODY_2]) + ps_client.feed([b'\n\x00\x00\x02\xc5\xff\xff\xff\xfe', MSG_BODY_1, + b'\x0e\x00\x00\x023\xff\xff\xff\xfe', MSG_BODY_2]) # execute both message polling and response handling loops - collected, handled = await gather(_collect_messages(ps_client), _collect_messages(stream_handler)) + collected, handled = await gather(_collect_messages(ps), _collect_messages(stream_handler)) # No messages collected, since they're all responses assert collected == [] @@ -227,28 +233,30 @@ async def test_message_stream(ps_client, mocker): async def test_message_request(ps_server, mocker): ps_server.listen() - mocker.patch.object(ps_server, 'register_handler', wraps=ps_server.register_handler) + ps = PacketStream(ps_server) - ps_server.send({ + mocker.patch.object(ps, 'register_handler', wraps=ps.register_handler) + + ps.send({ 'name': ['whoami'], 'args': [] }) - header, body = list(ps_server.connection.get_output()) + header, body = list(ps_server.get_output()) assert header == b'\x02\x00\x00\x00 \x00\x00\x00\x01' assert json.loads(body.decode('utf-8')) == {"name": ["whoami"], "args": []} - assert ps_server.req_counter == 2 - assert ps_server.register_handler.call_count == 1 - handler = list(ps_server._event_map.values())[0][1] + assert ps.req_counter == 2 + assert ps.register_handler.call_count == 1 + handler = list(ps._event_map.values())[0][1] with patch.object(handler, 'process') as mock_process: - ps_server.connection.feed([b'\x02\x00\x00\x00>\xff\xff\xff\xff', - b'{"id":"@1+Iwm79DKvVBqYKFkhT6fWRbAVvNNVH4F2BSxwhYmx8=.ed25519"}']) - msg = await ps_server.read() + ps_server.feed([b'\x02\x00\x00\x00>\xff\xff\xff\xff', + b'{"id":"@1+Iwm79DKvVBqYKFkhT6fWRbAVvNNVH4F2BSxwhYmx8=.ed25519"}']) + msg = await ps.read() assert mock_process.call_count == 1 # responses have negative req assert msg.req == -1 assert msg.body['id'] == '@1+Iwm79DKvVBqYKFkhT6fWRbAVvNNVH4F2BSxwhYmx8=.ed25519' - assert ps_server.req_counter == 2 + assert ps.req_counter == 2