Refactor server code (use connections)
Past version was unusable with >1 client.
This commit is contained in:
		| @@ -14,15 +14,14 @@ with open(os.path.expanduser('~/.ssb/secret')) as f: | ||||
|  | ||||
|  | ||||
| async def main(): | ||||
|     server_pub_key = b64decode(config['public'][:-8]) | ||||
|     client = SHSClient('localhost', 8008, SigningKey.generate(), server_pub_key) | ||||
|     await client.open() | ||||
|  | ||||
|     async for msg in client: | ||||
|         print(msg) | ||||
|  | ||||
|  | ||||
| loop = get_event_loop() | ||||
|  | ||||
| server_pub_key = b64decode(config['public'][:-8]) | ||||
| client = SHSClient('localhost', 8008, SigningKey.generate(), server_pub_key, loop=loop) | ||||
| client.connect() | ||||
| loop.run_until_complete(main()) | ||||
|  | ||||
| loop.close() | ||||
|   | ||||
| @@ -13,17 +13,18 @@ with open(os.path.expanduser('~/.ssb/secret')) as f: | ||||
|     config = yaml.load(f) | ||||
|  | ||||
|  | ||||
| async def main(): | ||||
|     async for msg in server: | ||||
| async def _on_connect(conn): | ||||
|     async for msg in conn: | ||||
|         print(msg) | ||||
|  | ||||
|  | ||||
| loop = get_event_loop() | ||||
|  | ||||
| async def main(): | ||||
|     server_keypair = SigningKey(b64decode(config['private'][:-8])[:32]) | ||||
| server = SHSServer('localhost', 8008, server_keypair, loop=loop) | ||||
| server.on_connect(main) | ||||
| server.listen() | ||||
|     server = SHSServer('localhost', 8008, server_keypair) | ||||
|     server.on_connect(_on_connect) | ||||
|     await server.listen() | ||||
|  | ||||
| loop = get_event_loop() | ||||
| loop.run_until_complete(main()) | ||||
| loop.run_forever() | ||||
| loop.close() | ||||
|   | ||||
| @@ -19,7 +19,7 @@ | ||||
| # SOFTWARE. | ||||
|  | ||||
|  | ||||
| from asyncio import open_connection, start_server, ensure_future | ||||
| import asyncio | ||||
|  | ||||
| from async_generator import async_generator, yield_ | ||||
|  | ||||
| @@ -31,10 +31,10 @@ class SHSClientException(Exception): | ||||
|     pass | ||||
|  | ||||
|  | ||||
| class SHSSocket(object): | ||||
|     def __init__(self, loop): | ||||
|         self.loop = loop | ||||
|         self._on_connect = None | ||||
| class SHSDuplexStream(object): | ||||
|     def __init__(self): | ||||
|         self.write_stream = None | ||||
|         self.read_stream = None | ||||
|  | ||||
|     def write(self, data): | ||||
|         self.write_stream.write(data) | ||||
| @@ -42,24 +42,35 @@ class SHSSocket(object): | ||||
|     async def read(self): | ||||
|         return await self.read_stream.read() | ||||
|  | ||||
|     def disconnect(self): | ||||
|         self.writer.close() | ||||
|     def close(self): | ||||
|         self.write_stream.close() | ||||
|         self.read_stream.close() | ||||
|  | ||||
|     @async_generator | ||||
|     async def __aiter__(self): | ||||
|         async for msg in self.read_stream: | ||||
|             await yield_(msg) | ||||
|  | ||||
|  | ||||
| class SHSEndpoint(object): | ||||
|     def __init__(self): | ||||
|         self._on_connect = None | ||||
|         self.crypto = None | ||||
|  | ||||
|     def on_connect(self, cb): | ||||
|         self._on_connect = cb | ||||
|  | ||||
|     def disconnect(self): | ||||
|         raise NotImplementedError | ||||
|  | ||||
| class SHSServer(SHSSocket): | ||||
|     def __init__(self, host, port, server_kp, application_key=None, loop=None): | ||||
|         super(SHSServer, self).__init__(loop) | ||||
|  | ||||
| class SHSServer(SHSEndpoint): | ||||
|     def __init__(self, host, port, server_kp, application_key=None): | ||||
|         super(SHSServer, self).__init__() | ||||
|         self.host = host | ||||
|         self.port = port | ||||
|         self.crypto = SHSServerCrypto(server_kp, application_key=application_key) | ||||
|         self.connections = [] | ||||
|  | ||||
|     async def _handshake(self, reader, writer): | ||||
|         data = await reader.readexactly(64) | ||||
| @@ -77,23 +88,39 @@ class SHSServer(SHSSocket): | ||||
|     async def handle_connection(self, reader, writer): | ||||
|         self.crypto.clean() | ||||
|         await self._handshake(reader, writer) | ||||
|  | ||||
|         keys = self.crypto.get_box_keys() | ||||
|         self.crypto.clean() | ||||
|  | ||||
|         self.read_stream, self.write_stream = get_stream_pair(reader, writer, **keys) | ||||
|         self.writer = writer | ||||
|         conn = SHSServerConnection.from_byte_streams(reader, writer, **keys) | ||||
|         self.connections.append(conn) | ||||
|  | ||||
|         if self._on_connect: | ||||
|             ensure_future(self._on_connect(), loop=self.loop) | ||||
|             asyncio.ensure_future(self._on_connect(conn)) | ||||
|  | ||||
|     def listen(self): | ||||
|         self.loop.run_until_complete(start_server(self.handle_connection, self.host, self.port, loop=self.loop)) | ||||
|     async def listen(self): | ||||
|         await asyncio.start_server(self.handle_connection, self.host, self.port) | ||||
|  | ||||
|     def disconnect(self): | ||||
|         for connection in self.connections: | ||||
|             connection.close() | ||||
|  | ||||
|  | ||||
| class SHSClient(SHSSocket): | ||||
|     def __init__(self, host, port, client_kp, server_pub_key, ephemeral_key=None, application_key=None, loop=None): | ||||
|         super(SHSClient, self).__init__(loop) | ||||
| class SHSServerConnection(SHSDuplexStream): | ||||
|     def __init__(self, read_stream, write_stream): | ||||
|         super(SHSServerConnection, self).__init__() | ||||
|         self.read_stream = read_stream | ||||
|         self.write_stream = write_stream | ||||
|  | ||||
|     @classmethod | ||||
|     def from_byte_streams(cls, reader, writer, **keys): | ||||
|         reader, writer = get_stream_pair(reader, writer, **keys) | ||||
|         return cls(reader, writer) | ||||
|  | ||||
|  | ||||
| class SHSClient(SHSDuplexStream, SHSEndpoint): | ||||
|     def __init__(self, host, port, client_kp, server_pub_key, ephemeral_key=None, application_key=None): | ||||
|         SHSDuplexStream.__init__(self) | ||||
|         SHSEndpoint.__init__(self) | ||||
|         self.host = host | ||||
|         self.port = port | ||||
|         self.crypto = SHSClientCrypto(client_kp, server_pub_key, ephemeral_key=ephemeral_key, | ||||
| @@ -112,8 +139,8 @@ class SHSClient(SHSSocket): | ||||
|         if not self.crypto.verify_server_accept(data): | ||||
|             raise SHSClientException('Server accept is not valid') | ||||
|  | ||||
|     async def connect(self): | ||||
|         reader, writer = await open_connection(self.host, self.port, loop=self.loop) | ||||
|     async def open(self): | ||||
|         reader, writer = await asyncio.open_connection(self.host, self.port) | ||||
|         await self._handshake(reader, writer) | ||||
|  | ||||
|         keys = self.crypto.get_box_keys() | ||||
| @@ -123,3 +150,6 @@ class SHSClient(SHSSocket): | ||||
|         self.writer = writer | ||||
|         if self._on_connect: | ||||
|             await self._on_connect() | ||||
|  | ||||
|     def disconnect(self): | ||||
|         self.close() | ||||
|   | ||||
| @@ -20,11 +20,10 @@ | ||||
|  | ||||
|  | ||||
| import pytest | ||||
| from io import BytesIO | ||||
|  | ||||
| from .test_crypto import (CLIENT_ENCRYPT_KEY, CLIENT_ENCRYPT_NONCE) | ||||
| from secret_handshake.boxstream import BoxStream, UnboxStream, HEADER_LENGTH | ||||
| from secret_handshake.util import async_comprehend | ||||
| from secret_handshake.util import AsyncBuffer, async_comprehend | ||||
|  | ||||
| MESSAGE_1 = (b'\xcev\xedE\x06l\x02\x13\xc8\x17V\xfa\x8bZ?\x88B%O\xb0L\x9f\x8e\x8c0y\x1dv\xc0\xc9\xf6\x9d\xc2\xdf\xdb' | ||||
|              b'\xee\x9d') | ||||
| @@ -34,13 +33,6 @@ MESSAGE_3 = (b'\xcbYY\xf1\x0f\xa5O\x13r\xa6"\x15\xc5\x9d\r.*\x0b\x92\x10m\xa6(\x | ||||
| MESSAGE_CLOSED = b'\xb1\x14hU\'\xb5M\xa6"\x03\x9duy\xa1\xd4evW,\xdcE\x18\xe4+ C4\xe8h\x96\xed\xc5\x94\x80' | ||||
|  | ||||
|  | ||||
| class AsyncBuffer(BytesIO): | ||||
|     """Just a BytesIO with an async read method.""" | ||||
|     async def read(self, n=None): | ||||
|         return super(AsyncBuffer, self).read(n) | ||||
|     readexactly = read | ||||
|  | ||||
|  | ||||
| @pytest.mark.asyncio | ||||
| async def test_boxstream(): | ||||
|     buffer = AsyncBuffer() | ||||
|   | ||||
							
								
								
									
										129
									
								
								secret_handshake/test_network.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										129
									
								
								secret_handshake/test_network.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,129 @@ | ||||
| # Copyright (c) 2017 PySecretHandshake contributors (see AUTHORS for more details) | ||||
| # | ||||
| # Permission is hereby granted, free of charge, to any person obtaining a copy | ||||
| # of this software and associated documentation files (the "Software"), to deal | ||||
| # in the Software without restriction, including without limitation the rights | ||||
| # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||||
| # copies of the Software, and to permit persons to whom the Software is | ||||
| # furnished to do so, subject to the following conditions: | ||||
| # | ||||
| # The above copyright notice and this permission notice shall be included in all | ||||
| # copies or substantial portions of the Software. | ||||
| # | ||||
| # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||||
| # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||||
| # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||||
| # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||||
| # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||||
| # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | ||||
| # SOFTWARE. | ||||
|  | ||||
| import os | ||||
| from asyncio import Event, wait_for | ||||
|  | ||||
| import pytest | ||||
| from nacl.signing import SigningKey | ||||
|  | ||||
| from secret_handshake.util import AsyncBuffer | ||||
|  | ||||
|  | ||||
| class DummyCrypto(object): | ||||
|     """Dummy crypto module, pretends everything is fine.""" | ||||
|     def verify_server_challenge(self, data): | ||||
|         return True | ||||
|  | ||||
|     def verify_challenge(self, data): | ||||
|         return True | ||||
|  | ||||
|     def verify_server_accept(self, data): | ||||
|         return True | ||||
|  | ||||
|     def generate_challenge(self): | ||||
|         return b'CHALLENGE' | ||||
|  | ||||
|     def generate_client_auth(self): | ||||
|         return b'AUTH' | ||||
|  | ||||
|     def verify_client_auth(self, data): | ||||
|         return True | ||||
|  | ||||
|     def generate_accept(self): | ||||
|         return b'ACCEPT' | ||||
|  | ||||
|     def get_box_keys(self): | ||||
|         return { | ||||
|             'encrypt_key': b'x' * 32, | ||||
|             'encrypt_nonce': b'x' * 32, | ||||
|             'decrypt_key': b'x' * 32, | ||||
|             'decrypt_nonce': b'x' * 32 | ||||
|         } | ||||
|  | ||||
|     def clean(self): | ||||
|         return | ||||
|  | ||||
|  | ||||
| def _dummy_boxstream(stream, **kwargs): | ||||
|     """Identity boxstream, no tansformation.""" | ||||
|     return stream | ||||
|  | ||||
|  | ||||
| def _client_stream_mocker(): | ||||
|     reader = AsyncBuffer(b'xxx') | ||||
|     writer = AsyncBuffer(b'xxx') | ||||
|  | ||||
|     async def _create_mock_streams(host, port): | ||||
|         return reader, writer | ||||
|  | ||||
|     return reader, writer, _create_mock_streams | ||||
|  | ||||
|  | ||||
| def _server_stream_mocker(): | ||||
|     reader = AsyncBuffer(b'xxx') | ||||
|     writer = AsyncBuffer(b'xxx') | ||||
|  | ||||
|     async def _create_mock_server(cb, host, port): | ||||
|         await cb(reader, writer) | ||||
|  | ||||
|     return reader, writer, _create_mock_server | ||||
|  | ||||
|  | ||||
| @pytest.mark.asyncio | ||||
| async def test_client(mocker): | ||||
|     reader, writer, _create_mock_streams = _client_stream_mocker() | ||||
|     mocker.patch('asyncio.open_connection', new=_create_mock_streams) | ||||
|     mocker.patch('secret_handshake.boxstream.BoxStream', new=_dummy_boxstream) | ||||
|     mocker.patch('secret_handshake.boxstream.UnboxStream', new=_dummy_boxstream) | ||||
|  | ||||
|     from secret_handshake import SHSClient | ||||
|  | ||||
|     client = SHSClient('shop.local', 1111, SigningKey.generate(), os.urandom(32)) | ||||
|     client.crypto = DummyCrypto() | ||||
|  | ||||
|     await client.open() | ||||
|     reader.append(b'TEST') | ||||
|     assert (await client.read()) == b'TEST' | ||||
|     client.disconnect() | ||||
|  | ||||
|  | ||||
| @pytest.mark.asyncio | ||||
| async def test_server(mocker): | ||||
|     from secret_handshake import SHSServer | ||||
|  | ||||
|     resolve = Event() | ||||
|  | ||||
|     async def _on_connect(conn): | ||||
|         server.disconnect() | ||||
|         resolve.set() | ||||
|  | ||||
|     reader, writer, _create_mock_server = _server_stream_mocker() | ||||
|     mocker.patch('asyncio.start_server', new=_create_mock_server) | ||||
|     mocker.patch('secret_handshake.boxstream.BoxStream', new=_dummy_boxstream) | ||||
|     mocker.patch('secret_handshake.boxstream.UnboxStream', new=_dummy_boxstream) | ||||
|  | ||||
|     server = SHSServer('shop.local', 1111, SigningKey.generate(), os.urandom(32)) | ||||
|     server.crypto = DummyCrypto() | ||||
|  | ||||
|     server.on_connect(_on_connect) | ||||
|  | ||||
|     await server.listen() | ||||
|     await wait_for(resolve.wait(), 5) | ||||
| @@ -1,9 +1,46 @@ | ||||
| # Copyright (c) 2017 PySecretHandshake contributors (see AUTHORS for more details) | ||||
| # | ||||
| # Permission is hereby granted, free of charge, to any person obtaining a copy | ||||
| # of this software and associated documentation files (the "Software"), to deal | ||||
| # in the Software without restriction, including without limitation the rights | ||||
| # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||||
| # copies of the Software, and to permit persons to whom the Software is | ||||
| # furnished to do so, subject to the following conditions: | ||||
| # | ||||
| # The above copyright notice and this permission notice shall be included in all | ||||
| # copies or substantial portions of the Software. | ||||
| # | ||||
| # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||||
| # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||||
| # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||||
| # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||||
| # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||||
| # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | ||||
| # SOFTWARE. | ||||
|  | ||||
|  | ||||
| import struct | ||||
| from io import BytesIO | ||||
|  | ||||
|  | ||||
| NONCE_SIZE = 24 | ||||
| MAX_NONCE = (8 * NONCE_SIZE) | ||||
|  | ||||
|  | ||||
| class AsyncBuffer(BytesIO): | ||||
|     """Just a BytesIO with an async read method.""" | ||||
|     async def read(self, n=None): | ||||
|         v = super(AsyncBuffer, self).read(n) | ||||
|         return v | ||||
|     readexactly = read | ||||
|  | ||||
|     def append(self, data): | ||||
|         """Append data to the buffer without changing the current position.""" | ||||
|         pos = self.tell() | ||||
|         self.write(data) | ||||
|         self.seek(pos) | ||||
|  | ||||
|  | ||||
| async def async_comprehend(generator): | ||||
|     """Emulate ``[elem async for elem in generator]``.""" | ||||
|     results = [] | ||||
|   | ||||
		Reference in New Issue
	
	Block a user