diff --git a/examples/test_client.py b/examples/test_client.py index 962ea69..852541d 100644 --- a/examples/test_client.py +++ b/examples/test_client.py @@ -14,15 +14,14 @@ with open(os.path.expanduser('~/.ssb/secret')) as f: async def main(): + server_pub_key = b64decode(config['public'][:-8]) + client = SHSClient('localhost', 8008, SigningKey.generate(), server_pub_key) + await client.open() + async for msg in client: print(msg) loop = get_event_loop() - -server_pub_key = b64decode(config['public'][:-8]) -client = SHSClient('localhost', 8008, SigningKey.generate(), server_pub_key, loop=loop) -client.connect() loop.run_until_complete(main()) - loop.close() diff --git a/examples/test_server.py b/examples/test_server.py index 4e6721c..1d0399e 100644 --- a/examples/test_server.py +++ b/examples/test_server.py @@ -13,17 +13,18 @@ with open(os.path.expanduser('~/.ssb/secret')) as f: config = yaml.load(f) -async def main(): - async for msg in server: +async def _on_connect(conn): + async for msg in conn: print(msg) +async def main(): + server_keypair = SigningKey(b64decode(config['private'][:-8])[:32]) + server = SHSServer('localhost', 8008, server_keypair) + server.on_connect(_on_connect) + await server.listen() + loop = get_event_loop() - -server_keypair = SigningKey(b64decode(config['private'][:-8])[:32]) -server = SHSServer('localhost', 8008, server_keypair, loop=loop) -server.on_connect(main) -server.listen() - +loop.run_until_complete(main()) loop.run_forever() loop.close() diff --git a/secret_handshake/network.py b/secret_handshake/network.py index 6875bf0..b510263 100644 --- a/secret_handshake/network.py +++ b/secret_handshake/network.py @@ -19,7 +19,7 @@ # SOFTWARE. -from asyncio import open_connection, start_server, ensure_future +import asyncio from async_generator import async_generator, yield_ @@ -31,10 +31,10 @@ class SHSClientException(Exception): pass -class SHSSocket(object): - def __init__(self, loop): - self.loop = loop - self._on_connect = None +class SHSDuplexStream(object): + def __init__(self): + self.write_stream = None + self.read_stream = None def write(self, data): self.write_stream.write(data) @@ -42,24 +42,35 @@ class SHSSocket(object): async def read(self): return await self.read_stream.read() - def disconnect(self): - self.writer.close() + def close(self): + self.write_stream.close() + self.read_stream.close() @async_generator async def __aiter__(self): async for msg in self.read_stream: await yield_(msg) + +class SHSEndpoint(object): + def __init__(self): + self._on_connect = None + self.crypto = None + def on_connect(self, cb): self._on_connect = cb + def disconnect(self): + raise NotImplementedError -class SHSServer(SHSSocket): - def __init__(self, host, port, server_kp, application_key=None, loop=None): - super(SHSServer, self).__init__(loop) + +class SHSServer(SHSEndpoint): + def __init__(self, host, port, server_kp, application_key=None): + super(SHSServer, self).__init__() self.host = host self.port = port self.crypto = SHSServerCrypto(server_kp, application_key=application_key) + self.connections = [] async def _handshake(self, reader, writer): data = await reader.readexactly(64) @@ -77,23 +88,39 @@ class SHSServer(SHSSocket): async def handle_connection(self, reader, writer): self.crypto.clean() await self._handshake(reader, writer) - keys = self.crypto.get_box_keys() self.crypto.clean() - self.read_stream, self.write_stream = get_stream_pair(reader, writer, **keys) - self.writer = writer + conn = SHSServerConnection.from_byte_streams(reader, writer, **keys) + self.connections.append(conn) if self._on_connect: - ensure_future(self._on_connect(), loop=self.loop) + asyncio.ensure_future(self._on_connect(conn)) - def listen(self): - self.loop.run_until_complete(start_server(self.handle_connection, self.host, self.port, loop=self.loop)) + async def listen(self): + await asyncio.start_server(self.handle_connection, self.host, self.port) + + def disconnect(self): + for connection in self.connections: + connection.close() -class SHSClient(SHSSocket): - def __init__(self, host, port, client_kp, server_pub_key, ephemeral_key=None, application_key=None, loop=None): - super(SHSClient, self).__init__(loop) +class SHSServerConnection(SHSDuplexStream): + def __init__(self, read_stream, write_stream): + super(SHSServerConnection, self).__init__() + self.read_stream = read_stream + self.write_stream = write_stream + + @classmethod + def from_byte_streams(cls, reader, writer, **keys): + reader, writer = get_stream_pair(reader, writer, **keys) + return cls(reader, writer) + + +class SHSClient(SHSDuplexStream, SHSEndpoint): + def __init__(self, host, port, client_kp, server_pub_key, ephemeral_key=None, application_key=None): + SHSDuplexStream.__init__(self) + SHSEndpoint.__init__(self) self.host = host self.port = port self.crypto = SHSClientCrypto(client_kp, server_pub_key, ephemeral_key=ephemeral_key, @@ -112,8 +139,8 @@ class SHSClient(SHSSocket): if not self.crypto.verify_server_accept(data): raise SHSClientException('Server accept is not valid') - async def connect(self): - reader, writer = await open_connection(self.host, self.port, loop=self.loop) + async def open(self): + reader, writer = await asyncio.open_connection(self.host, self.port) await self._handshake(reader, writer) keys = self.crypto.get_box_keys() @@ -123,3 +150,6 @@ class SHSClient(SHSSocket): self.writer = writer if self._on_connect: await self._on_connect() + + def disconnect(self): + self.close() diff --git a/secret_handshake/test_boxstream.py b/secret_handshake/test_boxstream.py index 1e7b477..fc0ad98 100644 --- a/secret_handshake/test_boxstream.py +++ b/secret_handshake/test_boxstream.py @@ -20,11 +20,10 @@ import pytest -from io import BytesIO from .test_crypto import (CLIENT_ENCRYPT_KEY, CLIENT_ENCRYPT_NONCE) from secret_handshake.boxstream import BoxStream, UnboxStream, HEADER_LENGTH -from secret_handshake.util import async_comprehend +from secret_handshake.util import AsyncBuffer, async_comprehend MESSAGE_1 = (b'\xcev\xedE\x06l\x02\x13\xc8\x17V\xfa\x8bZ?\x88B%O\xb0L\x9f\x8e\x8c0y\x1dv\xc0\xc9\xf6\x9d\xc2\xdf\xdb' b'\xee\x9d') @@ -34,13 +33,6 @@ MESSAGE_3 = (b'\xcbYY\xf1\x0f\xa5O\x13r\xa6"\x15\xc5\x9d\r.*\x0b\x92\x10m\xa6(\x MESSAGE_CLOSED = b'\xb1\x14hU\'\xb5M\xa6"\x03\x9duy\xa1\xd4evW,\xdcE\x18\xe4+ C4\xe8h\x96\xed\xc5\x94\x80' -class AsyncBuffer(BytesIO): - """Just a BytesIO with an async read method.""" - async def read(self, n=None): - return super(AsyncBuffer, self).read(n) - readexactly = read - - @pytest.mark.asyncio async def test_boxstream(): buffer = AsyncBuffer() diff --git a/secret_handshake/test_network.py b/secret_handshake/test_network.py new file mode 100644 index 0000000..f19081d --- /dev/null +++ b/secret_handshake/test_network.py @@ -0,0 +1,129 @@ +# Copyright (c) 2017 PySecretHandshake contributors (see AUTHORS for more details) +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +import os +from asyncio import Event, wait_for + +import pytest +from nacl.signing import SigningKey + +from secret_handshake.util import AsyncBuffer + + +class DummyCrypto(object): + """Dummy crypto module, pretends everything is fine.""" + def verify_server_challenge(self, data): + return True + + def verify_challenge(self, data): + return True + + def verify_server_accept(self, data): + return True + + def generate_challenge(self): + return b'CHALLENGE' + + def generate_client_auth(self): + return b'AUTH' + + def verify_client_auth(self, data): + return True + + def generate_accept(self): + return b'ACCEPT' + + def get_box_keys(self): + return { + 'encrypt_key': b'x' * 32, + 'encrypt_nonce': b'x' * 32, + 'decrypt_key': b'x' * 32, + 'decrypt_nonce': b'x' * 32 + } + + def clean(self): + return + + +def _dummy_boxstream(stream, **kwargs): + """Identity boxstream, no tansformation.""" + return stream + + +def _client_stream_mocker(): + reader = AsyncBuffer(b'xxx') + writer = AsyncBuffer(b'xxx') + + async def _create_mock_streams(host, port): + return reader, writer + + return reader, writer, _create_mock_streams + + +def _server_stream_mocker(): + reader = AsyncBuffer(b'xxx') + writer = AsyncBuffer(b'xxx') + + async def _create_mock_server(cb, host, port): + await cb(reader, writer) + + return reader, writer, _create_mock_server + + +@pytest.mark.asyncio +async def test_client(mocker): + reader, writer, _create_mock_streams = _client_stream_mocker() + mocker.patch('asyncio.open_connection', new=_create_mock_streams) + mocker.patch('secret_handshake.boxstream.BoxStream', new=_dummy_boxstream) + mocker.patch('secret_handshake.boxstream.UnboxStream', new=_dummy_boxstream) + + from secret_handshake import SHSClient + + client = SHSClient('shop.local', 1111, SigningKey.generate(), os.urandom(32)) + client.crypto = DummyCrypto() + + await client.open() + reader.append(b'TEST') + assert (await client.read()) == b'TEST' + client.disconnect() + + +@pytest.mark.asyncio +async def test_server(mocker): + from secret_handshake import SHSServer + + resolve = Event() + + async def _on_connect(conn): + server.disconnect() + resolve.set() + + reader, writer, _create_mock_server = _server_stream_mocker() + mocker.patch('asyncio.start_server', new=_create_mock_server) + mocker.patch('secret_handshake.boxstream.BoxStream', new=_dummy_boxstream) + mocker.patch('secret_handshake.boxstream.UnboxStream', new=_dummy_boxstream) + + server = SHSServer('shop.local', 1111, SigningKey.generate(), os.urandom(32)) + server.crypto = DummyCrypto() + + server.on_connect(_on_connect) + + await server.listen() + await wait_for(resolve.wait(), 5) diff --git a/secret_handshake/util.py b/secret_handshake/util.py index 8972632..d0df4bd 100644 --- a/secret_handshake/util.py +++ b/secret_handshake/util.py @@ -1,9 +1,46 @@ +# Copyright (c) 2017 PySecretHandshake contributors (see AUTHORS for more details) +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + + import struct +from io import BytesIO + NONCE_SIZE = 24 MAX_NONCE = (8 * NONCE_SIZE) +class AsyncBuffer(BytesIO): + """Just a BytesIO with an async read method.""" + async def read(self, n=None): + v = super(AsyncBuffer, self).read(n) + return v + readexactly = read + + def append(self, data): + """Append data to the buffer without changing the current position.""" + pos = self.tell() + self.write(data) + self.seek(pos) + + async def async_comprehend(generator): """Emulate ``[elem async for elem in generator]``.""" results = []