test: Add tests for secret_handshake.network

This commit is contained in:
Gergely Polonkai 2023-11-03 19:51:50 +01:00
parent 1d07f9ba02
commit 1ad7cb4e5e
No known key found for this signature in database
GPG Key ID: 2D2885533B869ED4
2 changed files with 146 additions and 2 deletions

View File

@ -91,7 +91,7 @@ class SHSEndpoint:
self._on_connect = cb
def disconnect(self) -> None:
def disconnect(self) -> None: # pragma: no cover
"""Disconnect the endpoint"""
raise NotImplementedError

View File

@ -24,7 +24,7 @@
from asyncio import Event, wait_for
import os
from typing import Any, Awaitable, Callable, Tuple
from typing import Any, Awaitable, Callable, Literal, Tuple
from nacl.signing import SigningKey
import pytest
@ -32,6 +32,7 @@ from pytest_mock import MockerFixture
from secret_handshake import SHSClient, SHSServer
from secret_handshake.boxstream import BoxStreamKeys
from secret_handshake.network import SHSClientException, SHSDuplexStream
from .helpers import AsyncBuffer
@ -169,3 +170,146 @@ async def test_server(mocker: MockerFixture) -> None:
await server.listen()
await wait_for(resolve.wait(), 5)
def test_duplex_write(mocker: MockerFixture) -> None:
"""Test the writing capabilities of the duplex stream"""
d_stream = SHSDuplexStream()
d_stream.write_stream = mocker.AsyncMock()
d_stream.write(b"thing")
assert d_stream.write_stream
d_stream.write_stream.write.assert_called_once_with(b"thing") # type: ignore[attr-defined]
def test_duplex_close_no_write_stream() -> None:
"""Test if SHSDuplexStreams close method doesnt fail if there is no write stream"""
d_stream = SHSDuplexStream()
assert d_stream.write_stream is None
d_stream.close()
# We cannot really do assertions here. If there is not set (it is None), the above call would fail
def test_duplex_stream_aiter() -> None:
"""Test if the __aiter__ method of SHSDuplexStream returns the stream itself"""
d_stream = SHSDuplexStream()
assert d_stream.__aiter__() is d_stream # pylint: disable=unnecessary-dunder-call
async def test_duplex_stream_anext(mocker: MockerFixture) -> None:
"""Test if the __anext__ method of SHSDuplexStream reads from the stream"""
d_stream = SHSDuplexStream()
mocker.patch.object(d_stream, "read", mocker.AsyncMock(return_value=b"test"))
assert await d_stream.__anext__() == b"test" # pylint: disable=unnecessary-dunder-call
async def test_duplex_stream_anext_eof(mocker: MockerFixture) -> None:
"""Test if SHSDuplexStream.__anext__ breaks iteration if theres no data to read"""
d_stream = SHSDuplexStream()
mocker.patch.object(d_stream, "read", mocker.AsyncMock(return_value=None))
with pytest.raises(StopAsyncIteration):
assert await d_stream.__anext__() # pylint: disable=unnecessary-dunder-call
@pytest.mark.parametrize("fail_type", ("verify_challenge", "verify_auth"))
async def test_server_fail_handshake(
fail_type: Literal["verify_challenge", "verify_auth"], mocker: MockerFixture
) -> None:
"""Test if a failing handshake results in an SHSClientException"""
server = SHSServer("127.0.0.1", 8754, SigningKey.generate())
if fail_type == "verify_challenge":
expected_error = "Client challenge is not valid"
elif fail_type == "verify_auth": # pragma: no branch
expected_error = "Client auth is not valid"
mocker.patch.object(server.crypto, "verify_challenge", return_value=fail_type != "verify_challenge")
mocker.patch.object(server.crypto, "verify_client_auth", return_value=fail_type != "verify_auth")
with pytest.raises(SHSClientException) as ctx:
await server._handshake(AsyncBuffer(b"d" * 64), AsyncBuffer()) # pylint: disable=protected-access
assert str(ctx.value) == expected_error
async def test_server_no_connect_callback(mocker: MockerFixture) -> None:
"""Test if SHSServer.handle_connection works without an on_connect callback"""
server = SHSServer("127.0.0.1", 7429, SigningKey.generate())
mocker.patch.object(server, "_handshake", return_value=None)
mocker.patch.object(
server.crypto,
"get_box_keys",
return_value={
"decrypt_key": b"d" * 32,
"decrypt_nonce": b"dnonce",
"encrypt_key": b"e" * 32,
"encrypt_nonce": b"enonce",
},
)
await server.handle_connection(AsyncBuffer(), AsyncBuffer())
# No assertion here. We should get here without a problem
@pytest.mark.parametrize("fail_type", ("verify_challenge", "verify_accept"))
async def test_client_fail_handshake(
fail_type: Literal["verify_challenge", "verify_accept"], mocker: MockerFixture
) -> None:
"""Test if a failing handshake results in an SHSClientException"""
client = SHSClient("127.0.0.1", 8754, SigningKey.generate(), b"s" * 32)
if fail_type == "verify_challenge":
expected_error = "Server challenge is not valid"
elif fail_type == "verify_accept": # pragma: no branch
expected_error = "Server accept is not valid"
mocker.patch.object(client.crypto, "verify_server_challenge", return_value=fail_type != "verify_challenge")
mocker.patch.object(client.crypto, "verify_server_accept", return_value=fail_type != "verify_accept")
mocker.patch.object(client.crypto, "generate_client_auth", return_value=b"ca" * 16)
with pytest.raises(SHSClientException) as ctx:
await client._handshake(AsyncBuffer(b"d" * 64), AsyncBuffer()) # pylint: disable=protected-access
assert str(ctx.value) == expected_error
@pytest.mark.parametrize("with_callback", (True, False))
async def test_client_open(with_callback: bool, mocker: MockerFixture) -> None:
"""Test if SHSServer.handle_connection works with and without an on_connect callback"""
client = SHSClient("127.0.0.1", 7429, SigningKey.generate(), SigningKey.generate().verify_key.encode())
mocker.patch("secret_handshake.network.open_connection", return_value=(AsyncBuffer(), AsyncBuffer()))
mocker.patch.object(client, "_handshake", return_value=None)
mocker.patch.object(
client.crypto,
"get_box_keys",
return_value={
"decrypt_key": b"d" * 32,
"decrypt_nonce": b"dnonce",
"encrypt_key": b"e" * 32,
"encrypt_nonce": b"enonce",
},
)
if with_callback:
callback = mocker.AsyncMock()
client.on_connect(callback)
await client.open()
if with_callback:
callback.assert_awaited_once_with(client)