diff --git a/secret_handshake/boxstream.py b/secret_handshake/boxstream.py index 83ce882..133a6e8 100644 --- a/secret_handshake/boxstream.py +++ b/secret_handshake/boxstream.py @@ -53,6 +53,7 @@ def get_stream_pair( # pylint: disable=too-many-arguments decrypt_nonce: bytes, encrypt_key: bytes, encrypt_nonce: bytes, + # We have kwargs here to devour any extra parameters we get, e.g. from the output of SHSCryptoBase.get_box_keys() **kwargs: Any, ) -> Tuple["UnboxStream", "BoxStream"]: """Create a new duplex box stream""" diff --git a/tests/test_boxstream.py b/tests/test_boxstream.py index 43f1253..55c8a78 100644 --- a/tests/test_boxstream.py +++ b/tests/test_boxstream.py @@ -22,7 +22,11 @@ """Tests for the box stream""" -from secret_handshake.boxstream import HEADER_LENGTH, BoxStream, UnboxStream +from asyncio import IncompleteReadError + +from pytest_mock import MockerFixture + +from secret_handshake.boxstream import HEADER_LENGTH, BoxStream, UnboxStream, get_stream_pair from .helpers import AsyncBuffer, async_comprehend from .test_crypto import CLIENT_ENCRYPT_KEY, CLIENT_ENCRYPT_NONCE @@ -74,6 +78,18 @@ async def test_unboxstream() -> None: assert unbox_stream.closed +async def test_unboxstream_header_read_error(mocker: MockerFixture) -> None: + """Test that we can handle errors during header read""" + + buffer = AsyncBuffer() + mocker.patch.object(buffer, "readexactly", side_effect=IncompleteReadError(b"", HEADER_LENGTH)) + + unbox_stream = UnboxStream(buffer, CLIENT_ENCRYPT_KEY, CLIENT_ENCRYPT_NONCE) + + assert await unbox_stream.read() is None + assert unbox_stream.closed is True + + async def test_long_packets() -> None: """Test for receiving long packets""" @@ -94,3 +110,28 @@ async def test_long_packets() -> None: assert first_packet == data[:4096] second_packet = await unbox_stream.read() assert second_packet == data[4096:] + + +def test_get_stream_pair() -> None: + """Test the get_stream_pair() function""" + + read_buffer = AsyncBuffer() + write_buffer = AsyncBuffer() + + read_stream, write_stream = get_stream_pair( + read_buffer, + write_buffer, + decrypt_key=b"d" * 32, + decrypt_nonce=b"dnonce", + encrypt_key=b"e" * 32, + encrypt_nonce=b"enonce", + ) + + assert isinstance(read_stream, UnboxStream) + assert isinstance(write_stream, BoxStream) + + assert read_stream.key == b"d" * 32 + assert read_stream.nonce == b"dnonce" + + assert write_stream.key == b"e" * 32 + assert write_stream.nonce == b"enonce"