Compare commits

..

1 Commits

Author SHA1 Message Date
56ca50a0df test: Add tests for secret_handshake.network 2023-11-03 20:21:23 +01:00
6 changed files with 15 additions and 256 deletions

View File

@ -32,7 +32,6 @@ repos:
require_serial: true require_serial: true
- id: isort - id: isort
name: isort name: isort
args: ["--check", "--diff"]
entry: poetry run isort entry: poetry run isort
language: system language: system
require_serial: true require_serial: true

View File

@ -53,7 +53,6 @@ def get_stream_pair( # pylint: disable=too-many-arguments
decrypt_nonce: bytes, decrypt_nonce: bytes,
encrypt_key: bytes, encrypt_key: bytes,
encrypt_nonce: bytes, encrypt_nonce: bytes,
# We have kwargs here to devour any extra parameters we get, e.g. from the output of SHSCryptoBase.get_box_keys()
**kwargs: Any, **kwargs: Any,
) -> Tuple["UnboxStream", "BoxStream"]: ) -> Tuple["UnboxStream", "BoxStream"]:
"""Create a new duplex box stream""" """Create a new duplex box stream"""

View File

@ -27,7 +27,7 @@ from typing import AsyncIterator, Awaitable, Callable, List, Optional
from nacl.public import PrivateKey from nacl.public import PrivateKey
from nacl.signing import SigningKey from nacl.signing import SigningKey
from typing_extensions import Self, deprecated from typing_extensions import Self
from .boxstream import BoxStream, UnboxStream, get_stream_pair from .boxstream import BoxStream, UnboxStream, get_stream_pair
from .crypto import SHSClientCrypto, SHSCryptoBase, SHSServerCrypto from .crypto import SHSClientCrypto, SHSCryptoBase, SHSServerCrypto
@ -91,16 +91,10 @@ class SHSEndpoint:
self._on_connect = cb self._on_connect = cb
def close(self) -> None: # pragma: no cover
"""Disconnect the endpoint"""
raise NotImplementedError()
@deprecated("Use close instead")
def disconnect(self) -> None: def disconnect(self) -> None:
"""Disconnect the endpoint""" """Disconnect the endpoint"""
self.close() raise NotImplementedError
class SHSServer(SHSEndpoint): class SHSServer(SHSEndpoint):
@ -151,14 +145,10 @@ class SHSServer(SHSEndpoint):
await start_server(self.handle_connection, self.host, self.port) await start_server(self.handle_connection, self.host, self.port)
def close(self) -> None: def disconnect(self) -> None:
for connection in self.connections: for connection in self.connections:
connection.close() connection.close()
@deprecated("Use close instead")
def disconnect(self) -> None:
self.close()
class SHSServerConnection(SHSDuplexStream): class SHSServerConnection(SHSDuplexStream):
"""SHS server connection""" """SHS server connection"""
@ -231,3 +221,6 @@ class SHSClient(SHSDuplexStream, SHSEndpoint):
if self._on_connect: if self._on_connect:
await self._on_connect(self) await self._on_connect(self)
def disconnect(self) -> None:
self.close()

View File

@ -22,11 +22,7 @@
"""Tests for the box stream""" """Tests for the box stream"""
from asyncio import IncompleteReadError from secret_handshake.boxstream import HEADER_LENGTH, BoxStream, UnboxStream
from pytest_mock import MockerFixture
from secret_handshake.boxstream import HEADER_LENGTH, BoxStream, UnboxStream, get_stream_pair
from .helpers import AsyncBuffer, async_comprehend from .helpers import AsyncBuffer, async_comprehend
from .test_crypto import CLIENT_ENCRYPT_KEY, CLIENT_ENCRYPT_NONCE from .test_crypto import CLIENT_ENCRYPT_KEY, CLIENT_ENCRYPT_NONCE
@ -78,18 +74,6 @@ async def test_unboxstream() -> None:
assert unbox_stream.closed assert unbox_stream.closed
async def test_unboxstream_header_read_error(mocker: MockerFixture) -> None:
"""Test that we can handle errors during header read"""
buffer = AsyncBuffer()
mocker.patch.object(buffer, "readexactly", side_effect=IncompleteReadError(b"", HEADER_LENGTH))
unbox_stream = UnboxStream(buffer, CLIENT_ENCRYPT_KEY, CLIENT_ENCRYPT_NONCE)
assert await unbox_stream.read() is None
assert unbox_stream.closed is True
async def test_long_packets() -> None: async def test_long_packets() -> None:
"""Test for receiving long packets""" """Test for receiving long packets"""
@ -110,28 +94,3 @@ async def test_long_packets() -> None:
assert first_packet == data[:4096] assert first_packet == data[:4096]
second_packet = await unbox_stream.read() second_packet = await unbox_stream.read()
assert second_packet == data[4096:] assert second_packet == data[4096:]
def test_get_stream_pair() -> None:
"""Test the get_stream_pair() function"""
read_buffer = AsyncBuffer()
write_buffer = AsyncBuffer()
read_stream, write_stream = get_stream_pair(
read_buffer,
write_buffer,
decrypt_key=b"d" * 32,
decrypt_nonce=b"dnonce",
encrypt_key=b"e" * 32,
encrypt_nonce=b"enonce",
)
assert isinstance(read_stream, UnboxStream)
assert isinstance(write_stream, BoxStream)
assert read_stream.key == b"d" * 32
assert read_stream.nonce == b"dnonce"
assert write_stream.key == b"e" * 32
assert write_stream.nonce == b"enonce"

View File

@ -23,15 +23,12 @@
"""Tests for the crypto components""" """Tests for the crypto components"""
import hashlib import hashlib
from typing import Literal
from nacl.exceptions import CryptoError
from nacl.public import PrivateKey from nacl.public import PrivateKey
from nacl.signing import SigningKey from nacl.signing import SigningKey
import pytest import pytest
from pytest_mock import MockerFixture
from secret_handshake.crypto import SHSClientCrypto, SHSError, SHSServerCrypto from secret_handshake.crypto import SHSClientCrypto, SHSServerCrypto
APP_KEY = hashlib.sha256(b"app_key").digest() APP_KEY = hashlib.sha256(b"app_key").digest()
SERVER_KEY_SEED = b"\xcaw\x01\xc2cQ\xfd\x94\x9f\x14\x84\x0c0<l\xd8\xe4\xf5>\x12\\\x96\xcd\x9b\x0c\x02z&\x96!\xe0\xa2" SERVER_KEY_SEED = b"\xcaw\x01\xc2cQ\xfd\x94\x9f\x14\x84\x0c0<l\xd8\xe4\xf5>\x12\\\x96\xcd\x9b\x0c\x02z&\x96!\xe0\xa2"
@ -51,22 +48,14 @@ def server() -> SHSServerCrypto:
@pytest.fixture @pytest.fixture
def client(request: pytest.FixtureRequest) -> SHSClientCrypto: def client() -> SHSClientCrypto:
"""A testing SHS client""" """A testing SHS client"""
app_key = None
for marker in request.node.iter_markers(name="client_app_key"):
app_key = marker.args[0]
if app_key is None:
app_key = APP_KEY
client_key = SigningKey(CLIENT_KEY_SEED) client_key = SigningKey(CLIENT_KEY_SEED)
server_key = SigningKey(SERVER_KEY_SEED) server_key = SigningKey(SERVER_KEY_SEED)
client_eph_key = PrivateKey(CLIENT_EPH_KEY_SEED) client_eph_key = PrivateKey(CLIENT_EPH_KEY_SEED)
return SHSClientCrypto(client_key, bytes(server_key.verify_key), client_eph_key, application_key=app_key) return SHSClientCrypto(client_key, bytes(server_key.verify_key), client_eph_key, application_key=APP_KEY)
CLIENT_CHALLENGE = ( CLIENT_CHALLENGE = (
@ -141,67 +130,3 @@ def test_handshake(client: SHSClientCrypto, server: SHSServerCrypto) -> None: #
assert client_keys["shared_secret"] == server_keys["shared_secret"] assert client_keys["shared_secret"] == server_keys["shared_secret"]
assert client_keys["encrypt_key"] == server_keys["decrypt_key"] assert client_keys["encrypt_key"] == server_keys["decrypt_key"]
assert client_keys["encrypt_nonce"] == server_keys["decrypt_nonce"] assert client_keys["encrypt_nonce"] == server_keys["decrypt_nonce"]
@pytest.mark.client_app_key(b"a" * 32)
def test_verify_challenge_different_app_keys(
client: SHSClientCrypto, server: SHSServerCrypto # pylint: disable=redefined-outer-name
) -> None:
"""Test challenge verification when the application keys of the client and server dont match"""
challenge = client.generate_challenge()
assert not server.verify_challenge(challenge)
def test_verify_challenge(
client: SHSClientCrypto, server: SHSServerCrypto # pylint: disable=redefined-outer-name
) -> None:
"""Test challenge verification when the application keys of the client and server match"""
challenge = client.generate_challenge()
assert server.verify_challenge(challenge)
@pytest.mark.parametrize("type_", ("server", "client"))
@pytest.mark.parametrize("provide_key", (True, False))
def test_clean(
type_: Literal["client", "server"], provide_key: bool, request: pytest.FixtureRequest, mocker: MockerFixture
) -> None:
"""Test the clean() method"""
if type_ == "server":
actor = request.getfixturevalue("server")
elif type_ == "client": # pragma: no branch
actor = request.getfixturevalue("client")
mocked_private_key = mocker.patch("secret_handshake.crypto.PrivateKey")
mocked_private_key.generate = mocker.MagicMock(return_value=PrivateKey(b"g" * 32))
new_key = PrivateKey(b"p" * 32) if provide_key else None
actor.clean(new_ephemeral_key=new_key)
assert actor.shared_secret is None
assert actor.shared_hash is None
assert actor.remote_ephemeral_key is None
assert isinstance(actor.local_ephemeral_key, PrivateKey)
if provide_key:
assert actor.local_ephemeral_key == new_key
else:
assert actor.local_ephemeral_key.encode() == b"g" * 32
def test_failing_server_accept(
client: SHSClientCrypto, server: SHSServerCrypto, mocker: MockerFixture # pylint: disable=redefined-outer-name
) -> None:
"""Test if verify_server_accept raises the correct exception type"""
server.verify_challenge(client.generate_challenge())
client.verify_server_challenge(server.generate_challenge())
server.verify_client_auth(client.generate_client_auth())
server_accept = server.generate_accept()
mocker.patch("secret_handshake.crypto.crypto_box_open_afternm", side_effect=CryptoError())
with pytest.raises(SHSError):
client.verify_server_accept(server_accept)

View File

@ -24,7 +24,7 @@
from asyncio import Event, wait_for from asyncio import Event, wait_for
import os import os
from typing import Any, Awaitable, Callable, Literal, Tuple from typing import Any, Awaitable, Callable, Tuple
from nacl.signing import SigningKey from nacl.signing import SigningKey
import pytest import pytest
@ -32,7 +32,7 @@ from pytest_mock import MockerFixture
from secret_handshake import SHSClient, SHSServer from secret_handshake import SHSClient, SHSServer
from secret_handshake.boxstream import BoxStreamKeys from secret_handshake.boxstream import BoxStreamKeys
from secret_handshake.network import SHSClientException, SHSDuplexStream from secret_handshake.network import SHSDuplexStream
from .helpers import AsyncBuffer from .helpers import AsyncBuffer
@ -145,7 +145,7 @@ async def test_client(mocker: MockerFixture) -> None:
await client.open() await client.open()
reader.append(b"TEST") reader.append(b"TEST")
assert (await client.read()) == b"TEST" assert (await client.read()) == b"TEST"
client.close() client.disconnect()
@pytest.mark.asyncio @pytest.mark.asyncio
@ -155,7 +155,7 @@ async def test_server(mocker: MockerFixture) -> None:
resolve = Event() resolve = Event()
async def _on_connect(_: Any) -> None: async def _on_connect(_: Any) -> None:
server.close() server.disconnect()
resolve.set() resolve.set()
_, _, _create_mock_server = _server_stream_mocker() _, _, _create_mock_server = _server_stream_mocker()
@ -179,9 +179,7 @@ def test_duplex_write(mocker: MockerFixture) -> None:
d_stream.write_stream = mocker.AsyncMock() d_stream.write_stream = mocker.AsyncMock()
d_stream.write(b"thing") d_stream.write(b"thing")
assert d_stream.write_stream d_stream.write_stream.write.assert_called_once_with(b"thing")
d_stream.write_stream.write.assert_called_once_with(b"thing") # type: ignore[attr-defined]
def test_duplex_close_no_write_stream() -> None: def test_duplex_close_no_write_stream() -> None:
@ -199,117 +197,3 @@ def test_duplex_stream_aiter() -> None:
d_stream = SHSDuplexStream() d_stream = SHSDuplexStream()
assert d_stream.__aiter__() is d_stream # pylint: disable=unnecessary-dunder-call assert d_stream.__aiter__() is d_stream # pylint: disable=unnecessary-dunder-call
async def test_duplex_stream_anext(mocker: MockerFixture) -> None:
"""Test if the __anext__ method of SHSDuplexStream reads from the stream"""
d_stream = SHSDuplexStream()
mocker.patch.object(d_stream, "read", mocker.AsyncMock(return_value=b"test"))
assert await d_stream.__anext__() == b"test" # pylint: disable=unnecessary-dunder-call
async def test_duplex_stream_anext_eof(mocker: MockerFixture) -> None:
"""Test if SHSDuplexStream.__anext__ breaks iteration if theres no data to read"""
d_stream = SHSDuplexStream()
mocker.patch.object(d_stream, "read", mocker.AsyncMock(return_value=None))
with pytest.raises(StopAsyncIteration):
assert await d_stream.__anext__() # pylint: disable=unnecessary-dunder-call
@pytest.mark.parametrize("fail_type", ("verify_challenge", "verify_auth"))
async def test_server_fail_handshake(
fail_type: Literal["verify_challenge", "verify_auth"], mocker: MockerFixture
) -> None:
"""Test if a failing handshake results in an SHSClientException"""
server = SHSServer("127.0.0.1", 8754, SigningKey.generate())
if fail_type == "verify_challenge":
expected_error = "Client challenge is not valid"
elif fail_type == "verify_auth": # pragma: no branch
expected_error = "Client auth is not valid"
mocker.patch.object(server.crypto, "verify_challenge", return_value=fail_type != "verify_challenge")
mocker.patch.object(server.crypto, "verify_client_auth", return_value=fail_type != "verify_auth")
with pytest.raises(SHSClientException) as ctx:
await server._handshake(AsyncBuffer(b"d" * 64), AsyncBuffer()) # pylint: disable=protected-access
assert str(ctx.value) == expected_error
async def test_server_no_connect_callback(mocker: MockerFixture) -> None:
"""Test if SHSServer.handle_connection works without an on_connect callback"""
server = SHSServer("127.0.0.1", 7429, SigningKey.generate())
mocker.patch.object(server, "_handshake", return_value=None)
mocker.patch.object(
server.crypto,
"get_box_keys",
return_value={
"decrypt_key": b"d" * 32,
"decrypt_nonce": b"dnonce",
"encrypt_key": b"e" * 32,
"encrypt_nonce": b"enonce",
},
)
await server.handle_connection(AsyncBuffer(), AsyncBuffer())
# No assertion here. We should get here without a problem
@pytest.mark.parametrize("fail_type", ("verify_challenge", "verify_accept"))
async def test_client_fail_handshake(
fail_type: Literal["verify_challenge", "verify_accept"], mocker: MockerFixture
) -> None:
"""Test if a failing handshake results in an SHSClientException"""
client = SHSClient("127.0.0.1", 8754, SigningKey.generate(), b"s" * 32)
if fail_type == "verify_challenge":
expected_error = "Server challenge is not valid"
elif fail_type == "verify_accept": # pragma: no branch
expected_error = "Server accept is not valid"
mocker.patch.object(client.crypto, "verify_server_challenge", return_value=fail_type != "verify_challenge")
mocker.patch.object(client.crypto, "verify_server_accept", return_value=fail_type != "verify_accept")
mocker.patch.object(client.crypto, "generate_client_auth", return_value=b"ca" * 16)
with pytest.raises(SHSClientException) as ctx:
await client._handshake(AsyncBuffer(b"d" * 64), AsyncBuffer()) # pylint: disable=protected-access
assert str(ctx.value) == expected_error
@pytest.mark.parametrize("with_callback", (True, False))
async def test_client_open(with_callback: bool, mocker: MockerFixture) -> None:
"""Test if SHSServer.handle_connection works with and without an on_connect callback"""
client = SHSClient("127.0.0.1", 7429, SigningKey.generate(), SigningKey.generate().verify_key.encode())
mocker.patch("secret_handshake.network.open_connection", return_value=(AsyncBuffer(), AsyncBuffer()))
mocker.patch.object(client, "_handshake", return_value=None)
mocker.patch.object(
client.crypto,
"get_box_keys",
return_value={
"decrypt_key": b"d" * 32,
"decrypt_nonce": b"dnonce",
"encrypt_key": b"e" * 32,
"encrypt_nonce": b"enonce",
},
)
if with_callback:
callback = mocker.AsyncMock()
client.on_connect(callback)
await client.open()
if with_callback:
callback.assert_awaited_once_with(client)