test: Add tests for secret_handshake.network
This commit is contained in:
parent
1d07f9ba02
commit
1ad7cb4e5e
@ -91,7 +91,7 @@ class SHSEndpoint:
|
||||
|
||||
self._on_connect = cb
|
||||
|
||||
def disconnect(self) -> None:
|
||||
def disconnect(self) -> None: # pragma: no cover
|
||||
"""Disconnect the endpoint"""
|
||||
|
||||
raise NotImplementedError
|
||||
|
@ -24,7 +24,7 @@
|
||||
|
||||
from asyncio import Event, wait_for
|
||||
import os
|
||||
from typing import Any, Awaitable, Callable, Tuple
|
||||
from typing import Any, Awaitable, Callable, Literal, Tuple
|
||||
|
||||
from nacl.signing import SigningKey
|
||||
import pytest
|
||||
@ -32,6 +32,7 @@ from pytest_mock import MockerFixture
|
||||
|
||||
from secret_handshake import SHSClient, SHSServer
|
||||
from secret_handshake.boxstream import BoxStreamKeys
|
||||
from secret_handshake.network import SHSClientException, SHSDuplexStream
|
||||
|
||||
from .helpers import AsyncBuffer
|
||||
|
||||
@ -169,3 +170,146 @@ async def test_server(mocker: MockerFixture) -> None:
|
||||
|
||||
await server.listen()
|
||||
await wait_for(resolve.wait(), 5)
|
||||
|
||||
|
||||
def test_duplex_write(mocker: MockerFixture) -> None:
|
||||
"""Test the writing capabilities of the duplex stream"""
|
||||
|
||||
d_stream = SHSDuplexStream()
|
||||
d_stream.write_stream = mocker.AsyncMock()
|
||||
d_stream.write(b"thing")
|
||||
|
||||
assert d_stream.write_stream
|
||||
|
||||
d_stream.write_stream.write.assert_called_once_with(b"thing") # type: ignore[attr-defined]
|
||||
|
||||
|
||||
def test_duplex_close_no_write_stream() -> None:
|
||||
"""Test if SHSDuplexStream’s close method doesn’t fail if there is no write stream"""
|
||||
|
||||
d_stream = SHSDuplexStream()
|
||||
assert d_stream.write_stream is None
|
||||
d_stream.close()
|
||||
|
||||
# We cannot really do assertions here. If there is not set (it is None), the above call would fail
|
||||
|
||||
|
||||
def test_duplex_stream_aiter() -> None:
|
||||
"""Test if the __aiter__ method of SHSDuplexStream returns the stream itself"""
|
||||
|
||||
d_stream = SHSDuplexStream()
|
||||
assert d_stream.__aiter__() is d_stream # pylint: disable=unnecessary-dunder-call
|
||||
|
||||
|
||||
async def test_duplex_stream_anext(mocker: MockerFixture) -> None:
|
||||
"""Test if the __anext__ method of SHSDuplexStream reads from the stream"""
|
||||
|
||||
d_stream = SHSDuplexStream()
|
||||
mocker.patch.object(d_stream, "read", mocker.AsyncMock(return_value=b"test"))
|
||||
|
||||
assert await d_stream.__anext__() == b"test" # pylint: disable=unnecessary-dunder-call
|
||||
|
||||
|
||||
async def test_duplex_stream_anext_eof(mocker: MockerFixture) -> None:
|
||||
"""Test if SHSDuplexStream.__anext__ breaks iteration if there’s no data to read"""
|
||||
|
||||
d_stream = SHSDuplexStream()
|
||||
mocker.patch.object(d_stream, "read", mocker.AsyncMock(return_value=None))
|
||||
|
||||
with pytest.raises(StopAsyncIteration):
|
||||
assert await d_stream.__anext__() # pylint: disable=unnecessary-dunder-call
|
||||
|
||||
|
||||
@pytest.mark.parametrize("fail_type", ("verify_challenge", "verify_auth"))
|
||||
async def test_server_fail_handshake(
|
||||
fail_type: Literal["verify_challenge", "verify_auth"], mocker: MockerFixture
|
||||
) -> None:
|
||||
"""Test if a failing handshake results in an SHSClientException"""
|
||||
|
||||
server = SHSServer("127.0.0.1", 8754, SigningKey.generate())
|
||||
|
||||
if fail_type == "verify_challenge":
|
||||
expected_error = "Client challenge is not valid"
|
||||
elif fail_type == "verify_auth": # pragma: no branch
|
||||
expected_error = "Client auth is not valid"
|
||||
|
||||
mocker.patch.object(server.crypto, "verify_challenge", return_value=fail_type != "verify_challenge")
|
||||
mocker.patch.object(server.crypto, "verify_client_auth", return_value=fail_type != "verify_auth")
|
||||
|
||||
with pytest.raises(SHSClientException) as ctx:
|
||||
await server._handshake(AsyncBuffer(b"d" * 64), AsyncBuffer()) # pylint: disable=protected-access
|
||||
|
||||
assert str(ctx.value) == expected_error
|
||||
|
||||
|
||||
async def test_server_no_connect_callback(mocker: MockerFixture) -> None:
|
||||
"""Test if SHSServer.handle_connection works without an on_connect callback"""
|
||||
|
||||
server = SHSServer("127.0.0.1", 7429, SigningKey.generate())
|
||||
mocker.patch.object(server, "_handshake", return_value=None)
|
||||
mocker.patch.object(
|
||||
server.crypto,
|
||||
"get_box_keys",
|
||||
return_value={
|
||||
"decrypt_key": b"d" * 32,
|
||||
"decrypt_nonce": b"dnonce",
|
||||
"encrypt_key": b"e" * 32,
|
||||
"encrypt_nonce": b"enonce",
|
||||
},
|
||||
)
|
||||
|
||||
await server.handle_connection(AsyncBuffer(), AsyncBuffer())
|
||||
|
||||
# No assertion here. We should get here without a problem
|
||||
|
||||
|
||||
@pytest.mark.parametrize("fail_type", ("verify_challenge", "verify_accept"))
|
||||
async def test_client_fail_handshake(
|
||||
fail_type: Literal["verify_challenge", "verify_accept"], mocker: MockerFixture
|
||||
) -> None:
|
||||
"""Test if a failing handshake results in an SHSClientException"""
|
||||
|
||||
client = SHSClient("127.0.0.1", 8754, SigningKey.generate(), b"s" * 32)
|
||||
|
||||
if fail_type == "verify_challenge":
|
||||
expected_error = "Server challenge is not valid"
|
||||
elif fail_type == "verify_accept": # pragma: no branch
|
||||
expected_error = "Server accept is not valid"
|
||||
|
||||
mocker.patch.object(client.crypto, "verify_server_challenge", return_value=fail_type != "verify_challenge")
|
||||
mocker.patch.object(client.crypto, "verify_server_accept", return_value=fail_type != "verify_accept")
|
||||
mocker.patch.object(client.crypto, "generate_client_auth", return_value=b"ca" * 16)
|
||||
|
||||
with pytest.raises(SHSClientException) as ctx:
|
||||
await client._handshake(AsyncBuffer(b"d" * 64), AsyncBuffer()) # pylint: disable=protected-access
|
||||
|
||||
assert str(ctx.value) == expected_error
|
||||
|
||||
|
||||
@pytest.mark.parametrize("with_callback", (True, False))
|
||||
async def test_client_open(with_callback: bool, mocker: MockerFixture) -> None:
|
||||
"""Test if SHSServer.handle_connection works with and without an on_connect callback"""
|
||||
|
||||
client = SHSClient("127.0.0.1", 7429, SigningKey.generate(), SigningKey.generate().verify_key.encode())
|
||||
|
||||
mocker.patch("secret_handshake.network.open_connection", return_value=(AsyncBuffer(), AsyncBuffer()))
|
||||
mocker.patch.object(client, "_handshake", return_value=None)
|
||||
mocker.patch.object(
|
||||
client.crypto,
|
||||
"get_box_keys",
|
||||
return_value={
|
||||
"decrypt_key": b"d" * 32,
|
||||
"decrypt_nonce": b"dnonce",
|
||||
"encrypt_key": b"e" * 32,
|
||||
"encrypt_nonce": b"enonce",
|
||||
},
|
||||
)
|
||||
|
||||
if with_callback:
|
||||
callback = mocker.AsyncMock()
|
||||
client.on_connect(callback)
|
||||
|
||||
await client.open()
|
||||
|
||||
if with_callback:
|
||||
callback.assert_awaited_once_with(client)
|
||||
|
Loading…
Reference in New Issue
Block a user