Compare commits
5 Commits
network-te
...
main
Author | SHA1 | Date | |
---|---|---|---|
3f13152684 | |||
83add95c8a | |||
1ad7cb4e5e | |||
1d07f9ba02 | |||
5807e64462 |
@ -32,6 +32,7 @@ repos:
|
||||
require_serial: true
|
||||
- id: isort
|
||||
name: isort
|
||||
args: ["--check", "--diff"]
|
||||
entry: poetry run isort
|
||||
language: system
|
||||
require_serial: true
|
||||
|
@ -53,6 +53,7 @@ def get_stream_pair( # pylint: disable=too-many-arguments
|
||||
decrypt_nonce: bytes,
|
||||
encrypt_key: bytes,
|
||||
encrypt_nonce: bytes,
|
||||
# We have kwargs here to devour any extra parameters we get, e.g. from the output of SHSCryptoBase.get_box_keys()
|
||||
**kwargs: Any,
|
||||
) -> Tuple["UnboxStream", "BoxStream"]:
|
||||
"""Create a new duplex box stream"""
|
||||
|
@ -27,7 +27,7 @@ from typing import AsyncIterator, Awaitable, Callable, List, Optional
|
||||
|
||||
from nacl.public import PrivateKey
|
||||
from nacl.signing import SigningKey
|
||||
from typing_extensions import Self
|
||||
from typing_extensions import Self, deprecated
|
||||
|
||||
from .boxstream import BoxStream, UnboxStream, get_stream_pair
|
||||
from .crypto import SHSClientCrypto, SHSCryptoBase, SHSServerCrypto
|
||||
@ -91,10 +91,16 @@ class SHSEndpoint:
|
||||
|
||||
self._on_connect = cb
|
||||
|
||||
def close(self) -> None: # pragma: no cover
|
||||
"""Disconnect the endpoint"""
|
||||
|
||||
raise NotImplementedError()
|
||||
|
||||
@deprecated("Use close instead")
|
||||
def disconnect(self) -> None:
|
||||
"""Disconnect the endpoint"""
|
||||
|
||||
raise NotImplementedError
|
||||
self.close()
|
||||
|
||||
|
||||
class SHSServer(SHSEndpoint):
|
||||
@ -145,10 +151,14 @@ class SHSServer(SHSEndpoint):
|
||||
|
||||
await start_server(self.handle_connection, self.host, self.port)
|
||||
|
||||
def disconnect(self) -> None:
|
||||
def close(self) -> None:
|
||||
for connection in self.connections:
|
||||
connection.close()
|
||||
|
||||
@deprecated("Use close instead")
|
||||
def disconnect(self) -> None:
|
||||
self.close()
|
||||
|
||||
|
||||
class SHSServerConnection(SHSDuplexStream):
|
||||
"""SHS server connection"""
|
||||
@ -221,6 +231,3 @@ class SHSClient(SHSDuplexStream, SHSEndpoint):
|
||||
|
||||
if self._on_connect:
|
||||
await self._on_connect(self)
|
||||
|
||||
def disconnect(self) -> None:
|
||||
self.close()
|
||||
|
@ -22,7 +22,11 @@
|
||||
|
||||
"""Tests for the box stream"""
|
||||
|
||||
from secret_handshake.boxstream import HEADER_LENGTH, BoxStream, UnboxStream
|
||||
from asyncio import IncompleteReadError
|
||||
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from secret_handshake.boxstream import HEADER_LENGTH, BoxStream, UnboxStream, get_stream_pair
|
||||
|
||||
from .helpers import AsyncBuffer, async_comprehend
|
||||
from .test_crypto import CLIENT_ENCRYPT_KEY, CLIENT_ENCRYPT_NONCE
|
||||
@ -74,6 +78,18 @@ async def test_unboxstream() -> None:
|
||||
assert unbox_stream.closed
|
||||
|
||||
|
||||
async def test_unboxstream_header_read_error(mocker: MockerFixture) -> None:
|
||||
"""Test that we can handle errors during header read"""
|
||||
|
||||
buffer = AsyncBuffer()
|
||||
mocker.patch.object(buffer, "readexactly", side_effect=IncompleteReadError(b"", HEADER_LENGTH))
|
||||
|
||||
unbox_stream = UnboxStream(buffer, CLIENT_ENCRYPT_KEY, CLIENT_ENCRYPT_NONCE)
|
||||
|
||||
assert await unbox_stream.read() is None
|
||||
assert unbox_stream.closed is True
|
||||
|
||||
|
||||
async def test_long_packets() -> None:
|
||||
"""Test for receiving long packets"""
|
||||
|
||||
@ -94,3 +110,28 @@ async def test_long_packets() -> None:
|
||||
assert first_packet == data[:4096]
|
||||
second_packet = await unbox_stream.read()
|
||||
assert second_packet == data[4096:]
|
||||
|
||||
|
||||
def test_get_stream_pair() -> None:
|
||||
"""Test the get_stream_pair() function"""
|
||||
|
||||
read_buffer = AsyncBuffer()
|
||||
write_buffer = AsyncBuffer()
|
||||
|
||||
read_stream, write_stream = get_stream_pair(
|
||||
read_buffer,
|
||||
write_buffer,
|
||||
decrypt_key=b"d" * 32,
|
||||
decrypt_nonce=b"dnonce",
|
||||
encrypt_key=b"e" * 32,
|
||||
encrypt_nonce=b"enonce",
|
||||
)
|
||||
|
||||
assert isinstance(read_stream, UnboxStream)
|
||||
assert isinstance(write_stream, BoxStream)
|
||||
|
||||
assert read_stream.key == b"d" * 32
|
||||
assert read_stream.nonce == b"dnonce"
|
||||
|
||||
assert write_stream.key == b"e" * 32
|
||||
assert write_stream.nonce == b"enonce"
|
||||
|
@ -23,12 +23,15 @@
|
||||
"""Tests for the crypto components"""
|
||||
|
||||
import hashlib
|
||||
from typing import Literal
|
||||
|
||||
from nacl.exceptions import CryptoError
|
||||
from nacl.public import PrivateKey
|
||||
from nacl.signing import SigningKey
|
||||
import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from secret_handshake.crypto import SHSClientCrypto, SHSServerCrypto
|
||||
from secret_handshake.crypto import SHSClientCrypto, SHSError, SHSServerCrypto
|
||||
|
||||
APP_KEY = hashlib.sha256(b"app_key").digest()
|
||||
SERVER_KEY_SEED = b"\xcaw\x01\xc2cQ\xfd\x94\x9f\x14\x84\x0c0<l\xd8\xe4\xf5>\x12\\\x96\xcd\x9b\x0c\x02z&\x96!\xe0\xa2"
|
||||
@ -48,14 +51,22 @@ def server() -> SHSServerCrypto:
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client() -> SHSClientCrypto:
|
||||
def client(request: pytest.FixtureRequest) -> SHSClientCrypto:
|
||||
"""A testing SHS client"""
|
||||
|
||||
app_key = None
|
||||
|
||||
for marker in request.node.iter_markers(name="client_app_key"):
|
||||
app_key = marker.args[0]
|
||||
|
||||
if app_key is None:
|
||||
app_key = APP_KEY
|
||||
|
||||
client_key = SigningKey(CLIENT_KEY_SEED)
|
||||
server_key = SigningKey(SERVER_KEY_SEED)
|
||||
client_eph_key = PrivateKey(CLIENT_EPH_KEY_SEED)
|
||||
|
||||
return SHSClientCrypto(client_key, bytes(server_key.verify_key), client_eph_key, application_key=APP_KEY)
|
||||
return SHSClientCrypto(client_key, bytes(server_key.verify_key), client_eph_key, application_key=app_key)
|
||||
|
||||
|
||||
CLIENT_CHALLENGE = (
|
||||
@ -130,3 +141,67 @@ def test_handshake(client: SHSClientCrypto, server: SHSServerCrypto) -> None: #
|
||||
assert client_keys["shared_secret"] == server_keys["shared_secret"]
|
||||
assert client_keys["encrypt_key"] == server_keys["decrypt_key"]
|
||||
assert client_keys["encrypt_nonce"] == server_keys["decrypt_nonce"]
|
||||
|
||||
|
||||
@pytest.mark.client_app_key(b"a" * 32)
|
||||
def test_verify_challenge_different_app_keys(
|
||||
client: SHSClientCrypto, server: SHSServerCrypto # pylint: disable=redefined-outer-name
|
||||
) -> None:
|
||||
"""Test challenge verification when the application keys of the client and server don’t match"""
|
||||
|
||||
challenge = client.generate_challenge()
|
||||
assert not server.verify_challenge(challenge)
|
||||
|
||||
|
||||
def test_verify_challenge(
|
||||
client: SHSClientCrypto, server: SHSServerCrypto # pylint: disable=redefined-outer-name
|
||||
) -> None:
|
||||
"""Test challenge verification when the application keys of the client and server match"""
|
||||
|
||||
challenge = client.generate_challenge()
|
||||
assert server.verify_challenge(challenge)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("type_", ("server", "client"))
|
||||
@pytest.mark.parametrize("provide_key", (True, False))
|
||||
def test_clean(
|
||||
type_: Literal["client", "server"], provide_key: bool, request: pytest.FixtureRequest, mocker: MockerFixture
|
||||
) -> None:
|
||||
"""Test the clean() method"""
|
||||
|
||||
if type_ == "server":
|
||||
actor = request.getfixturevalue("server")
|
||||
elif type_ == "client": # pragma: no branch
|
||||
actor = request.getfixturevalue("client")
|
||||
|
||||
mocked_private_key = mocker.patch("secret_handshake.crypto.PrivateKey")
|
||||
mocked_private_key.generate = mocker.MagicMock(return_value=PrivateKey(b"g" * 32))
|
||||
|
||||
new_key = PrivateKey(b"p" * 32) if provide_key else None
|
||||
actor.clean(new_ephemeral_key=new_key)
|
||||
|
||||
assert actor.shared_secret is None
|
||||
assert actor.shared_hash is None
|
||||
assert actor.remote_ephemeral_key is None
|
||||
assert isinstance(actor.local_ephemeral_key, PrivateKey)
|
||||
|
||||
if provide_key:
|
||||
assert actor.local_ephemeral_key == new_key
|
||||
else:
|
||||
assert actor.local_ephemeral_key.encode() == b"g" * 32
|
||||
|
||||
|
||||
def test_failing_server_accept(
|
||||
client: SHSClientCrypto, server: SHSServerCrypto, mocker: MockerFixture # pylint: disable=redefined-outer-name
|
||||
) -> None:
|
||||
"""Test if verify_server_accept raises the correct exception type"""
|
||||
|
||||
server.verify_challenge(client.generate_challenge())
|
||||
client.verify_server_challenge(server.generate_challenge())
|
||||
server.verify_client_auth(client.generate_client_auth())
|
||||
server_accept = server.generate_accept()
|
||||
|
||||
mocker.patch("secret_handshake.crypto.crypto_box_open_afternm", side_effect=CryptoError())
|
||||
|
||||
with pytest.raises(SHSError):
|
||||
client.verify_server_accept(server_accept)
|
||||
|
@ -24,7 +24,7 @@
|
||||
|
||||
from asyncio import Event, wait_for
|
||||
import os
|
||||
from typing import Any, Awaitable, Callable, Tuple
|
||||
from typing import Any, Awaitable, Callable, Literal, Tuple
|
||||
|
||||
from nacl.signing import SigningKey
|
||||
import pytest
|
||||
@ -32,6 +32,7 @@ from pytest_mock import MockerFixture
|
||||
|
||||
from secret_handshake import SHSClient, SHSServer
|
||||
from secret_handshake.boxstream import BoxStreamKeys
|
||||
from secret_handshake.network import SHSClientException, SHSDuplexStream
|
||||
|
||||
from .helpers import AsyncBuffer
|
||||
|
||||
@ -144,7 +145,7 @@ async def test_client(mocker: MockerFixture) -> None:
|
||||
await client.open()
|
||||
reader.append(b"TEST")
|
||||
assert (await client.read()) == b"TEST"
|
||||
client.disconnect()
|
||||
client.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -154,7 +155,7 @@ async def test_server(mocker: MockerFixture) -> None:
|
||||
resolve = Event()
|
||||
|
||||
async def _on_connect(_: Any) -> None:
|
||||
server.disconnect()
|
||||
server.close()
|
||||
resolve.set()
|
||||
|
||||
_, _, _create_mock_server = _server_stream_mocker()
|
||||
@ -169,3 +170,146 @@ async def test_server(mocker: MockerFixture) -> None:
|
||||
|
||||
await server.listen()
|
||||
await wait_for(resolve.wait(), 5)
|
||||
|
||||
|
||||
def test_duplex_write(mocker: MockerFixture) -> None:
|
||||
"""Test the writing capabilities of the duplex stream"""
|
||||
|
||||
d_stream = SHSDuplexStream()
|
||||
d_stream.write_stream = mocker.AsyncMock()
|
||||
d_stream.write(b"thing")
|
||||
|
||||
assert d_stream.write_stream
|
||||
|
||||
d_stream.write_stream.write.assert_called_once_with(b"thing") # type: ignore[attr-defined]
|
||||
|
||||
|
||||
def test_duplex_close_no_write_stream() -> None:
|
||||
"""Test if SHSDuplexStream’s close method doesn’t fail if there is no write stream"""
|
||||
|
||||
d_stream = SHSDuplexStream()
|
||||
assert d_stream.write_stream is None
|
||||
d_stream.close()
|
||||
|
||||
# We cannot really do assertions here. If there is not set (it is None), the above call would fail
|
||||
|
||||
|
||||
def test_duplex_stream_aiter() -> None:
|
||||
"""Test if the __aiter__ method of SHSDuplexStream returns the stream itself"""
|
||||
|
||||
d_stream = SHSDuplexStream()
|
||||
assert d_stream.__aiter__() is d_stream # pylint: disable=unnecessary-dunder-call
|
||||
|
||||
|
||||
async def test_duplex_stream_anext(mocker: MockerFixture) -> None:
|
||||
"""Test if the __anext__ method of SHSDuplexStream reads from the stream"""
|
||||
|
||||
d_stream = SHSDuplexStream()
|
||||
mocker.patch.object(d_stream, "read", mocker.AsyncMock(return_value=b"test"))
|
||||
|
||||
assert await d_stream.__anext__() == b"test" # pylint: disable=unnecessary-dunder-call
|
||||
|
||||
|
||||
async def test_duplex_stream_anext_eof(mocker: MockerFixture) -> None:
|
||||
"""Test if SHSDuplexStream.__anext__ breaks iteration if there’s no data to read"""
|
||||
|
||||
d_stream = SHSDuplexStream()
|
||||
mocker.patch.object(d_stream, "read", mocker.AsyncMock(return_value=None))
|
||||
|
||||
with pytest.raises(StopAsyncIteration):
|
||||
assert await d_stream.__anext__() # pylint: disable=unnecessary-dunder-call
|
||||
|
||||
|
||||
@pytest.mark.parametrize("fail_type", ("verify_challenge", "verify_auth"))
|
||||
async def test_server_fail_handshake(
|
||||
fail_type: Literal["verify_challenge", "verify_auth"], mocker: MockerFixture
|
||||
) -> None:
|
||||
"""Test if a failing handshake results in an SHSClientException"""
|
||||
|
||||
server = SHSServer("127.0.0.1", 8754, SigningKey.generate())
|
||||
|
||||
if fail_type == "verify_challenge":
|
||||
expected_error = "Client challenge is not valid"
|
||||
elif fail_type == "verify_auth": # pragma: no branch
|
||||
expected_error = "Client auth is not valid"
|
||||
|
||||
mocker.patch.object(server.crypto, "verify_challenge", return_value=fail_type != "verify_challenge")
|
||||
mocker.patch.object(server.crypto, "verify_client_auth", return_value=fail_type != "verify_auth")
|
||||
|
||||
with pytest.raises(SHSClientException) as ctx:
|
||||
await server._handshake(AsyncBuffer(b"d" * 64), AsyncBuffer()) # pylint: disable=protected-access
|
||||
|
||||
assert str(ctx.value) == expected_error
|
||||
|
||||
|
||||
async def test_server_no_connect_callback(mocker: MockerFixture) -> None:
|
||||
"""Test if SHSServer.handle_connection works without an on_connect callback"""
|
||||
|
||||
server = SHSServer("127.0.0.1", 7429, SigningKey.generate())
|
||||
mocker.patch.object(server, "_handshake", return_value=None)
|
||||
mocker.patch.object(
|
||||
server.crypto,
|
||||
"get_box_keys",
|
||||
return_value={
|
||||
"decrypt_key": b"d" * 32,
|
||||
"decrypt_nonce": b"dnonce",
|
||||
"encrypt_key": b"e" * 32,
|
||||
"encrypt_nonce": b"enonce",
|
||||
},
|
||||
)
|
||||
|
||||
await server.handle_connection(AsyncBuffer(), AsyncBuffer())
|
||||
|
||||
# No assertion here. We should get here without a problem
|
||||
|
||||
|
||||
@pytest.mark.parametrize("fail_type", ("verify_challenge", "verify_accept"))
|
||||
async def test_client_fail_handshake(
|
||||
fail_type: Literal["verify_challenge", "verify_accept"], mocker: MockerFixture
|
||||
) -> None:
|
||||
"""Test if a failing handshake results in an SHSClientException"""
|
||||
|
||||
client = SHSClient("127.0.0.1", 8754, SigningKey.generate(), b"s" * 32)
|
||||
|
||||
if fail_type == "verify_challenge":
|
||||
expected_error = "Server challenge is not valid"
|
||||
elif fail_type == "verify_accept": # pragma: no branch
|
||||
expected_error = "Server accept is not valid"
|
||||
|
||||
mocker.patch.object(client.crypto, "verify_server_challenge", return_value=fail_type != "verify_challenge")
|
||||
mocker.patch.object(client.crypto, "verify_server_accept", return_value=fail_type != "verify_accept")
|
||||
mocker.patch.object(client.crypto, "generate_client_auth", return_value=b"ca" * 16)
|
||||
|
||||
with pytest.raises(SHSClientException) as ctx:
|
||||
await client._handshake(AsyncBuffer(b"d" * 64), AsyncBuffer()) # pylint: disable=protected-access
|
||||
|
||||
assert str(ctx.value) == expected_error
|
||||
|
||||
|
||||
@pytest.mark.parametrize("with_callback", (True, False))
|
||||
async def test_client_open(with_callback: bool, mocker: MockerFixture) -> None:
|
||||
"""Test if SHSServer.handle_connection works with and without an on_connect callback"""
|
||||
|
||||
client = SHSClient("127.0.0.1", 7429, SigningKey.generate(), SigningKey.generate().verify_key.encode())
|
||||
|
||||
mocker.patch("secret_handshake.network.open_connection", return_value=(AsyncBuffer(), AsyncBuffer()))
|
||||
mocker.patch.object(client, "_handshake", return_value=None)
|
||||
mocker.patch.object(
|
||||
client.crypto,
|
||||
"get_box_keys",
|
||||
return_value={
|
||||
"decrypt_key": b"d" * 32,
|
||||
"decrypt_nonce": b"dnonce",
|
||||
"encrypt_key": b"e" * 32,
|
||||
"encrypt_nonce": b"enonce",
|
||||
},
|
||||
)
|
||||
|
||||
if with_callback:
|
||||
callback = mocker.AsyncMock()
|
||||
client.on_connect(callback)
|
||||
|
||||
await client.open()
|
||||
|
||||
if with_callback:
|
||||
callback.assert_awaited_once_with(client)
|
||||
|
Loading…
Reference in New Issue
Block a user