Compare commits

..

1 Commits

Author SHA1 Message Date
56ca50a0df test: Add tests for secret_handshake.network 2023-11-03 20:21:23 +01:00
6 changed files with 15 additions and 256 deletions

View File

@ -32,7 +32,6 @@ repos:
require_serial: true
- id: isort
name: isort
args: ["--check", "--diff"]
entry: poetry run isort
language: system
require_serial: true

View File

@ -53,7 +53,6 @@ def get_stream_pair( # pylint: disable=too-many-arguments
decrypt_nonce: bytes,
encrypt_key: bytes,
encrypt_nonce: bytes,
# We have kwargs here to devour any extra parameters we get, e.g. from the output of SHSCryptoBase.get_box_keys()
**kwargs: Any,
) -> Tuple["UnboxStream", "BoxStream"]:
"""Create a new duplex box stream"""

View File

@ -27,7 +27,7 @@ from typing import AsyncIterator, Awaitable, Callable, List, Optional
from nacl.public import PrivateKey
from nacl.signing import SigningKey
from typing_extensions import Self, deprecated
from typing_extensions import Self
from .boxstream import BoxStream, UnboxStream, get_stream_pair
from .crypto import SHSClientCrypto, SHSCryptoBase, SHSServerCrypto
@ -91,16 +91,10 @@ class SHSEndpoint:
self._on_connect = cb
def close(self) -> None: # pragma: no cover
"""Disconnect the endpoint"""
raise NotImplementedError()
@deprecated("Use close instead")
def disconnect(self) -> None:
"""Disconnect the endpoint"""
self.close()
raise NotImplementedError
class SHSServer(SHSEndpoint):
@ -151,14 +145,10 @@ class SHSServer(SHSEndpoint):
await start_server(self.handle_connection, self.host, self.port)
def close(self) -> None:
def disconnect(self) -> None:
for connection in self.connections:
connection.close()
@deprecated("Use close instead")
def disconnect(self) -> None:
self.close()
class SHSServerConnection(SHSDuplexStream):
"""SHS server connection"""
@ -231,3 +221,6 @@ class SHSClient(SHSDuplexStream, SHSEndpoint):
if self._on_connect:
await self._on_connect(self)
def disconnect(self) -> None:
self.close()

View File

@ -22,11 +22,7 @@
"""Tests for the box stream"""
from asyncio import IncompleteReadError
from pytest_mock import MockerFixture
from secret_handshake.boxstream import HEADER_LENGTH, BoxStream, UnboxStream, get_stream_pair
from secret_handshake.boxstream import HEADER_LENGTH, BoxStream, UnboxStream
from .helpers import AsyncBuffer, async_comprehend
from .test_crypto import CLIENT_ENCRYPT_KEY, CLIENT_ENCRYPT_NONCE
@ -78,18 +74,6 @@ async def test_unboxstream() -> None:
assert unbox_stream.closed
async def test_unboxstream_header_read_error(mocker: MockerFixture) -> None:
"""Test that we can handle errors during header read"""
buffer = AsyncBuffer()
mocker.patch.object(buffer, "readexactly", side_effect=IncompleteReadError(b"", HEADER_LENGTH))
unbox_stream = UnboxStream(buffer, CLIENT_ENCRYPT_KEY, CLIENT_ENCRYPT_NONCE)
assert await unbox_stream.read() is None
assert unbox_stream.closed is True
async def test_long_packets() -> None:
"""Test for receiving long packets"""
@ -110,28 +94,3 @@ async def test_long_packets() -> None:
assert first_packet == data[:4096]
second_packet = await unbox_stream.read()
assert second_packet == data[4096:]
def test_get_stream_pair() -> None:
"""Test the get_stream_pair() function"""
read_buffer = AsyncBuffer()
write_buffer = AsyncBuffer()
read_stream, write_stream = get_stream_pair(
read_buffer,
write_buffer,
decrypt_key=b"d" * 32,
decrypt_nonce=b"dnonce",
encrypt_key=b"e" * 32,
encrypt_nonce=b"enonce",
)
assert isinstance(read_stream, UnboxStream)
assert isinstance(write_stream, BoxStream)
assert read_stream.key == b"d" * 32
assert read_stream.nonce == b"dnonce"
assert write_stream.key == b"e" * 32
assert write_stream.nonce == b"enonce"

View File

@ -23,15 +23,12 @@
"""Tests for the crypto components"""
import hashlib
from typing import Literal
from nacl.exceptions import CryptoError
from nacl.public import PrivateKey
from nacl.signing import SigningKey
import pytest
from pytest_mock import MockerFixture
from secret_handshake.crypto import SHSClientCrypto, SHSError, SHSServerCrypto
from secret_handshake.crypto import SHSClientCrypto, SHSServerCrypto
APP_KEY = hashlib.sha256(b"app_key").digest()
SERVER_KEY_SEED = b"\xcaw\x01\xc2cQ\xfd\x94\x9f\x14\x84\x0c0<l\xd8\xe4\xf5>\x12\\\x96\xcd\x9b\x0c\x02z&\x96!\xe0\xa2"
@ -51,22 +48,14 @@ def server() -> SHSServerCrypto:
@pytest.fixture
def client(request: pytest.FixtureRequest) -> SHSClientCrypto:
def client() -> SHSClientCrypto:
"""A testing SHS client"""
app_key = None
for marker in request.node.iter_markers(name="client_app_key"):
app_key = marker.args[0]
if app_key is None:
app_key = APP_KEY
client_key = SigningKey(CLIENT_KEY_SEED)
server_key = SigningKey(SERVER_KEY_SEED)
client_eph_key = PrivateKey(CLIENT_EPH_KEY_SEED)
return SHSClientCrypto(client_key, bytes(server_key.verify_key), client_eph_key, application_key=app_key)
return SHSClientCrypto(client_key, bytes(server_key.verify_key), client_eph_key, application_key=APP_KEY)
CLIENT_CHALLENGE = (
@ -141,67 +130,3 @@ def test_handshake(client: SHSClientCrypto, server: SHSServerCrypto) -> None: #
assert client_keys["shared_secret"] == server_keys["shared_secret"]
assert client_keys["encrypt_key"] == server_keys["decrypt_key"]
assert client_keys["encrypt_nonce"] == server_keys["decrypt_nonce"]
@pytest.mark.client_app_key(b"a" * 32)
def test_verify_challenge_different_app_keys(
client: SHSClientCrypto, server: SHSServerCrypto # pylint: disable=redefined-outer-name
) -> None:
"""Test challenge verification when the application keys of the client and server dont match"""
challenge = client.generate_challenge()
assert not server.verify_challenge(challenge)
def test_verify_challenge(
client: SHSClientCrypto, server: SHSServerCrypto # pylint: disable=redefined-outer-name
) -> None:
"""Test challenge verification when the application keys of the client and server match"""
challenge = client.generate_challenge()
assert server.verify_challenge(challenge)
@pytest.mark.parametrize("type_", ("server", "client"))
@pytest.mark.parametrize("provide_key", (True, False))
def test_clean(
type_: Literal["client", "server"], provide_key: bool, request: pytest.FixtureRequest, mocker: MockerFixture
) -> None:
"""Test the clean() method"""
if type_ == "server":
actor = request.getfixturevalue("server")
elif type_ == "client": # pragma: no branch
actor = request.getfixturevalue("client")
mocked_private_key = mocker.patch("secret_handshake.crypto.PrivateKey")
mocked_private_key.generate = mocker.MagicMock(return_value=PrivateKey(b"g" * 32))
new_key = PrivateKey(b"p" * 32) if provide_key else None
actor.clean(new_ephemeral_key=new_key)
assert actor.shared_secret is None
assert actor.shared_hash is None
assert actor.remote_ephemeral_key is None
assert isinstance(actor.local_ephemeral_key, PrivateKey)
if provide_key:
assert actor.local_ephemeral_key == new_key
else:
assert actor.local_ephemeral_key.encode() == b"g" * 32
def test_failing_server_accept(
client: SHSClientCrypto, server: SHSServerCrypto, mocker: MockerFixture # pylint: disable=redefined-outer-name
) -> None:
"""Test if verify_server_accept raises the correct exception type"""
server.verify_challenge(client.generate_challenge())
client.verify_server_challenge(server.generate_challenge())
server.verify_client_auth(client.generate_client_auth())
server_accept = server.generate_accept()
mocker.patch("secret_handshake.crypto.crypto_box_open_afternm", side_effect=CryptoError())
with pytest.raises(SHSError):
client.verify_server_accept(server_accept)

View File

@ -24,7 +24,7 @@
from asyncio import Event, wait_for
import os
from typing import Any, Awaitable, Callable, Literal, Tuple
from typing import Any, Awaitable, Callable, Tuple
from nacl.signing import SigningKey
import pytest
@ -32,7 +32,7 @@ from pytest_mock import MockerFixture
from secret_handshake import SHSClient, SHSServer
from secret_handshake.boxstream import BoxStreamKeys
from secret_handshake.network import SHSClientException, SHSDuplexStream
from secret_handshake.network import SHSDuplexStream
from .helpers import AsyncBuffer
@ -145,7 +145,7 @@ async def test_client(mocker: MockerFixture) -> None:
await client.open()
reader.append(b"TEST")
assert (await client.read()) == b"TEST"
client.close()
client.disconnect()
@pytest.mark.asyncio
@ -155,7 +155,7 @@ async def test_server(mocker: MockerFixture) -> None:
resolve = Event()
async def _on_connect(_: Any) -> None:
server.close()
server.disconnect()
resolve.set()
_, _, _create_mock_server = _server_stream_mocker()
@ -179,9 +179,7 @@ def test_duplex_write(mocker: MockerFixture) -> None:
d_stream.write_stream = mocker.AsyncMock()
d_stream.write(b"thing")
assert d_stream.write_stream
d_stream.write_stream.write.assert_called_once_with(b"thing") # type: ignore[attr-defined]
d_stream.write_stream.write.assert_called_once_with(b"thing")
def test_duplex_close_no_write_stream() -> None:
@ -199,117 +197,3 @@ def test_duplex_stream_aiter() -> None:
d_stream = SHSDuplexStream()
assert d_stream.__aiter__() is d_stream # pylint: disable=unnecessary-dunder-call
async def test_duplex_stream_anext(mocker: MockerFixture) -> None:
"""Test if the __anext__ method of SHSDuplexStream reads from the stream"""
d_stream = SHSDuplexStream()
mocker.patch.object(d_stream, "read", mocker.AsyncMock(return_value=b"test"))
assert await d_stream.__anext__() == b"test" # pylint: disable=unnecessary-dunder-call
async def test_duplex_stream_anext_eof(mocker: MockerFixture) -> None:
"""Test if SHSDuplexStream.__anext__ breaks iteration if theres no data to read"""
d_stream = SHSDuplexStream()
mocker.patch.object(d_stream, "read", mocker.AsyncMock(return_value=None))
with pytest.raises(StopAsyncIteration):
assert await d_stream.__anext__() # pylint: disable=unnecessary-dunder-call
@pytest.mark.parametrize("fail_type", ("verify_challenge", "verify_auth"))
async def test_server_fail_handshake(
fail_type: Literal["verify_challenge", "verify_auth"], mocker: MockerFixture
) -> None:
"""Test if a failing handshake results in an SHSClientException"""
server = SHSServer("127.0.0.1", 8754, SigningKey.generate())
if fail_type == "verify_challenge":
expected_error = "Client challenge is not valid"
elif fail_type == "verify_auth": # pragma: no branch
expected_error = "Client auth is not valid"
mocker.patch.object(server.crypto, "verify_challenge", return_value=fail_type != "verify_challenge")
mocker.patch.object(server.crypto, "verify_client_auth", return_value=fail_type != "verify_auth")
with pytest.raises(SHSClientException) as ctx:
await server._handshake(AsyncBuffer(b"d" * 64), AsyncBuffer()) # pylint: disable=protected-access
assert str(ctx.value) == expected_error
async def test_server_no_connect_callback(mocker: MockerFixture) -> None:
"""Test if SHSServer.handle_connection works without an on_connect callback"""
server = SHSServer("127.0.0.1", 7429, SigningKey.generate())
mocker.patch.object(server, "_handshake", return_value=None)
mocker.patch.object(
server.crypto,
"get_box_keys",
return_value={
"decrypt_key": b"d" * 32,
"decrypt_nonce": b"dnonce",
"encrypt_key": b"e" * 32,
"encrypt_nonce": b"enonce",
},
)
await server.handle_connection(AsyncBuffer(), AsyncBuffer())
# No assertion here. We should get here without a problem
@pytest.mark.parametrize("fail_type", ("verify_challenge", "verify_accept"))
async def test_client_fail_handshake(
fail_type: Literal["verify_challenge", "verify_accept"], mocker: MockerFixture
) -> None:
"""Test if a failing handshake results in an SHSClientException"""
client = SHSClient("127.0.0.1", 8754, SigningKey.generate(), b"s" * 32)
if fail_type == "verify_challenge":
expected_error = "Server challenge is not valid"
elif fail_type == "verify_accept": # pragma: no branch
expected_error = "Server accept is not valid"
mocker.patch.object(client.crypto, "verify_server_challenge", return_value=fail_type != "verify_challenge")
mocker.patch.object(client.crypto, "verify_server_accept", return_value=fail_type != "verify_accept")
mocker.patch.object(client.crypto, "generate_client_auth", return_value=b"ca" * 16)
with pytest.raises(SHSClientException) as ctx:
await client._handshake(AsyncBuffer(b"d" * 64), AsyncBuffer()) # pylint: disable=protected-access
assert str(ctx.value) == expected_error
@pytest.mark.parametrize("with_callback", (True, False))
async def test_client_open(with_callback: bool, mocker: MockerFixture) -> None:
"""Test if SHSServer.handle_connection works with and without an on_connect callback"""
client = SHSClient("127.0.0.1", 7429, SigningKey.generate(), SigningKey.generate().verify_key.encode())
mocker.patch("secret_handshake.network.open_connection", return_value=(AsyncBuffer(), AsyncBuffer()))
mocker.patch.object(client, "_handshake", return_value=None)
mocker.patch.object(
client.crypto,
"get_box_keys",
return_value={
"decrypt_key": b"d" * 32,
"decrypt_nonce": b"dnonce",
"encrypt_key": b"e" * 32,
"encrypt_nonce": b"enonce",
},
)
if with_callback:
callback = mocker.AsyncMock()
client.on_connect(callback)
await client.open()
if with_callback:
callback.assert_awaited_once_with(client)