Compare commits
12 Commits
sphinx-doc
...
main
Author | SHA1 | Date | |
---|---|---|---|
3f13152684 | |||
83add95c8a | |||
1ad7cb4e5e | |||
1d07f9ba02 | |||
5807e64462 | |||
9ea816f832 | |||
c8b07ef913 | |||
995f0dabed | |||
88516b230b | |||
5820cd3e5c | |||
86a9fa300c | |||
5a3af65927 |
@ -32,6 +32,7 @@ repos:
|
||||
require_serial: true
|
||||
- id: isort
|
||||
name: isort
|
||||
args: ["--check", "--diff"]
|
||||
entry: poetry run isort
|
||||
language: system
|
||||
require_serial: true
|
||||
|
55
poetry.lock
generated
55
poetry.lock
generated
@ -36,6 +36,9 @@ files = [
|
||||
{file = "astroid-3.0.1.tar.gz", hash = "sha256:86b0bb7d7da0be1a7c4aedb7974e391b32d4ed89e33de6ed6902b4b15c97577e"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
typing-extensions = {version = ">=4.0.0", markers = "python_version < \"3.11\""}
|
||||
|
||||
[[package]]
|
||||
name = "babel"
|
||||
version = "2.13.1"
|
||||
@ -100,6 +103,8 @@ mypy-extensions = ">=0.4.3"
|
||||
packaging = ">=22.0"
|
||||
pathspec = ">=0.9.0"
|
||||
platformdirs = ">=2"
|
||||
tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""}
|
||||
typing-extensions = {version = ">=4.0.1", markers = "python_version < \"3.11\""}
|
||||
|
||||
[package.extras]
|
||||
colorama = ["colorama (>=0.4.3)"]
|
||||
@ -131,8 +136,10 @@ files = [
|
||||
|
||||
[package.dependencies]
|
||||
colorama = {version = "*", markers = "os_name == \"nt\""}
|
||||
importlib-metadata = {version = ">=4.6", markers = "python_version < \"3.10\""}
|
||||
packaging = ">=19.0"
|
||||
pyproject_hooks = "*"
|
||||
tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""}
|
||||
|
||||
[package.extras]
|
||||
docs = ["furo (>=2023.08.17)", "sphinx (>=7.0,<8.0)", "sphinx-argparse-cli (>=1.5)", "sphinx-autodoc-typehints (>=1.10)", "sphinx-issues (>=3.0.0)"]
|
||||
@ -350,6 +357,7 @@ files = [
|
||||
[package.dependencies]
|
||||
build = ">=0.1"
|
||||
setuptools = "*"
|
||||
tomli = {version = "*", markers = "python_version < \"3.11\""}
|
||||
|
||||
[package.extras]
|
||||
test = ["mock (>=3.0.0)", "pytest"]
|
||||
@ -464,6 +472,9 @@ files = [
|
||||
{file = "coverage-7.3.2.tar.gz", hash = "sha256:be32ad29341b0170e795ca590e1c07e81fc061cb5b10c74ce7203491484404ef"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
tomli = {version = "*", optional = true, markers = "python_full_version <= \"3.11.0a6\" and extra == \"toml\""}
|
||||
|
||||
[package.extras]
|
||||
toml = ["tomli"]
|
||||
|
||||
@ -514,6 +525,20 @@ files = [
|
||||
{file = "docutils-0.20.1.tar.gz", hash = "sha256:f08a4e276c3a1583a86dce3e34aba3fe04d02bba2dd51ed16106244e8a923e3b"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "exceptiongroup"
|
||||
version = "1.1.3"
|
||||
description = "Backport of PEP 654 (exception groups)"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "exceptiongroup-1.1.3-py3-none-any.whl", hash = "sha256:343280667a4585d195ca1cf9cef84a4e178c4b6cf2274caef9859782b567d5e3"},
|
||||
{file = "exceptiongroup-1.1.3.tar.gz", hash = "sha256:097acd85d473d75af5bb98e41b61ff7fe35efe6675e4f9370ec6ec5126d160e9"},
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
test = ["pytest (>=6)"]
|
||||
|
||||
[[package]]
|
||||
name = "filelock"
|
||||
version = "3.13.1"
|
||||
@ -766,6 +791,7 @@ files = [
|
||||
|
||||
[package.dependencies]
|
||||
mypy-extensions = ">=1.0.0"
|
||||
tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""}
|
||||
typing-extensions = ">=4.1.0"
|
||||
|
||||
[package.extras]
|
||||
@ -938,11 +964,17 @@ files = [
|
||||
[package.dependencies]
|
||||
astroid = ">=3.0.1,<=3.1.0-dev0"
|
||||
colorama = {version = ">=0.4.5", markers = "sys_platform == \"win32\""}
|
||||
dill = {version = ">=0.3.7", markers = "python_version >= \"3.12\""}
|
||||
dill = [
|
||||
{version = ">=0.2", markers = "python_version < \"3.11\""},
|
||||
{version = ">=0.3.7", markers = "python_version >= \"3.12\""},
|
||||
{version = ">=0.3.6", markers = "python_version >= \"3.11\" and python_version < \"3.12\""},
|
||||
]
|
||||
isort = ">=4.2.5,<6"
|
||||
mccabe = ">=0.6,<0.8"
|
||||
platformdirs = ">=2.2.0"
|
||||
tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""}
|
||||
tomlkit = ">=0.10.1"
|
||||
typing-extensions = {version = ">=3.10.0", markers = "python_version < \"3.10\""}
|
||||
|
||||
[package.extras]
|
||||
spelling = ["pyenchant (>=3.2,<4.0)"]
|
||||
@ -985,6 +1017,9 @@ files = [
|
||||
{file = "pyproject_hooks-1.0.0.tar.gz", hash = "sha256:f271b298b97f5955d53fb12b72c1fb1948c22c1a6b70b315c54cedaca0264ef5"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""}
|
||||
|
||||
[[package]]
|
||||
name = "pytest"
|
||||
version = "7.4.3"
|
||||
@ -998,9 +1033,11 @@ files = [
|
||||
|
||||
[package.dependencies]
|
||||
colorama = {version = "*", markers = "sys_platform == \"win32\""}
|
||||
exceptiongroup = {version = ">=1.0.0rc8", markers = "python_version < \"3.11\""}
|
||||
iniconfig = "*"
|
||||
packaging = "*"
|
||||
pluggy = ">=0.12,<2.0"
|
||||
tomli = {version = ">=1.0.0", markers = "python_version < \"3.11\""}
|
||||
|
||||
[package.extras]
|
||||
testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"]
|
||||
@ -1228,6 +1265,7 @@ babel = ">=2.9"
|
||||
colorama = {version = ">=0.4.5", markers = "sys_platform == \"win32\""}
|
||||
docutils = ">=0.18.1,<0.21"
|
||||
imagesize = ">=1.3"
|
||||
importlib-metadata = {version = ">=4.8", markers = "python_version < \"3.10\""}
|
||||
Jinja2 = ">=3.0"
|
||||
packaging = ">=21.0"
|
||||
Pygments = ">=2.14"
|
||||
@ -1363,6 +1401,17 @@ files = [
|
||||
[package.extras]
|
||||
tests = ["pytest", "pytest-cov"]
|
||||
|
||||
[[package]]
|
||||
name = "tomli"
|
||||
version = "2.0.1"
|
||||
description = "A lil' TOML parser"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "tomli-2.0.1-py3-none-any.whl", hash = "sha256:939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc"},
|
||||
{file = "tomli-2.0.1.tar.gz", hash = "sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tomlkit"
|
||||
version = "0.12.1"
|
||||
@ -1461,5 +1510,5 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p
|
||||
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "^3.12"
|
||||
content-hash = "d369c4e8658ac40afe61eee66a6a00eb6a7ecc6ffbf5875c83f5eb7df4d1a961"
|
||||
python-versions = "^3.9"
|
||||
content-hash = "69fae4847ecb272da1713aa16b4d0d782e8fde2b1aaf5baebeb8b8a7bae3d1f3"
|
||||
|
@ -12,7 +12,7 @@ packages = [{include = "secret_handshake"}]
|
||||
include = ["secret_handshake/py.typed"]
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = "^3.12"
|
||||
python = "^3.9"
|
||||
PyNaCl = "^1.5.0"
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
|
@ -53,6 +53,7 @@ def get_stream_pair( # pylint: disable=too-many-arguments
|
||||
decrypt_nonce: bytes,
|
||||
encrypt_key: bytes,
|
||||
encrypt_nonce: bytes,
|
||||
# We have kwargs here to devour any extra parameters we get, e.g. from the output of SHSCryptoBase.get_box_keys()
|
||||
**kwargs: Any,
|
||||
) -> Tuple["UnboxStream", "BoxStream"]:
|
||||
"""Create a new duplex box stream"""
|
||||
|
@ -27,7 +27,7 @@ from typing import AsyncIterator, Awaitable, Callable, List, Optional
|
||||
|
||||
from nacl.public import PrivateKey
|
||||
from nacl.signing import SigningKey
|
||||
from typing_extensions import Self
|
||||
from typing_extensions import Self, deprecated
|
||||
|
||||
from .boxstream import BoxStream, UnboxStream, get_stream_pair
|
||||
from .crypto import SHSClientCrypto, SHSCryptoBase, SHSServerCrypto
|
||||
@ -91,10 +91,16 @@ class SHSEndpoint:
|
||||
|
||||
self._on_connect = cb
|
||||
|
||||
def close(self) -> None: # pragma: no cover
|
||||
"""Disconnect the endpoint"""
|
||||
|
||||
raise NotImplementedError()
|
||||
|
||||
@deprecated("Use close instead")
|
||||
def disconnect(self) -> None:
|
||||
"""Disconnect the endpoint"""
|
||||
|
||||
raise NotImplementedError
|
||||
self.close()
|
||||
|
||||
|
||||
class SHSServer(SHSEndpoint):
|
||||
@ -145,10 +151,14 @@ class SHSServer(SHSEndpoint):
|
||||
|
||||
await start_server(self.handle_connection, self.host, self.port)
|
||||
|
||||
def disconnect(self) -> None:
|
||||
def close(self) -> None:
|
||||
for connection in self.connections:
|
||||
connection.close()
|
||||
|
||||
@deprecated("Use close instead")
|
||||
def disconnect(self) -> None:
|
||||
self.close()
|
||||
|
||||
|
||||
class SHSServerConnection(SHSDuplexStream):
|
||||
"""SHS server connection"""
|
||||
@ -221,6 +231,3 @@ class SHSClient(SHSDuplexStream, SHSEndpoint):
|
||||
|
||||
if self._on_connect:
|
||||
await self._on_connect(self)
|
||||
|
||||
def disconnect(self) -> None:
|
||||
self.close()
|
||||
|
@ -35,7 +35,7 @@ def inc_nonce(nonce: bytes) -> bytes:
|
||||
|
||||
num = bytes_to_long(nonce) + 1
|
||||
|
||||
if num > 2**MAX_NONCE:
|
||||
if num > 2**MAX_NONCE - 1:
|
||||
num = 0
|
||||
|
||||
bnum = long_to_bytes(num)
|
||||
@ -44,21 +44,17 @@ def inc_nonce(nonce: bytes) -> bytes:
|
||||
return bnum
|
||||
|
||||
|
||||
def split_chunks(seq: Sequence[T], n: int) -> Generator[Sequence[T], None, None]:
|
||||
def split_chunks(seq: Sequence[T], chunk_size: int) -> Generator[Sequence[T], None, None]:
|
||||
"""Split sequence in equal-sized chunks.
|
||||
|
||||
The last chunk is not padded."""
|
||||
|
||||
if chunk_size <= 0:
|
||||
raise ValueError("chunk_size must be greater than zero")
|
||||
|
||||
while seq:
|
||||
yield seq[:n]
|
||||
seq = seq[n:]
|
||||
|
||||
|
||||
# Stolen from PyCypto (Public Domain)
|
||||
def b(s: str) -> bytes:
|
||||
"""Shorthand for s.encode("latin-1")"""
|
||||
|
||||
return s.encode("latin-1") # utf-8 would cause some side-effects we don't want
|
||||
yield seq[:chunk_size]
|
||||
seq = seq[chunk_size:]
|
||||
|
||||
|
||||
def long_to_bytes(n: int, blocksize: int = 0) -> bytes:
|
||||
@ -69,7 +65,7 @@ def long_to_bytes(n: int, blocksize: int = 0) -> bytes:
|
||||
"""
|
||||
|
||||
# after much testing, this algorithm was deemed to be the fastest
|
||||
s = b("")
|
||||
s = b""
|
||||
pack = struct.pack
|
||||
|
||||
while n > 0:
|
||||
@ -78,11 +74,11 @@ def long_to_bytes(n: int, blocksize: int = 0) -> bytes:
|
||||
|
||||
# strip off leading zeros
|
||||
for i, c in enumerate(s):
|
||||
if c != b("\000")[0]:
|
||||
if c != 0:
|
||||
break
|
||||
else:
|
||||
# only happens when n == 0
|
||||
s = b("\000")
|
||||
s = b"\x00"
|
||||
i = 0
|
||||
|
||||
s = s[i:]
|
||||
@ -90,7 +86,7 @@ def long_to_bytes(n: int, blocksize: int = 0) -> bytes:
|
||||
# add back some pad bytes. this could be done more efficiently w.r.t. the
|
||||
# de-padding being done above, but sigh...
|
||||
if blocksize > 0 and len(s) % blocksize:
|
||||
s = (blocksize - len(s) % blocksize) * b("\000") + s
|
||||
s = (blocksize - len(s) % blocksize) * b"\x00" + s
|
||||
|
||||
return s
|
||||
|
||||
@ -107,7 +103,7 @@ def bytes_to_long(s: bytes) -> int:
|
||||
|
||||
if length % 4:
|
||||
extra = 4 - length % 4
|
||||
s = b("\000") * extra + s
|
||||
s = b"\x00" * extra + s
|
||||
length = length + extra
|
||||
|
||||
for i in range(0, length, 4):
|
||||
|
@ -22,7 +22,11 @@
|
||||
|
||||
"""Tests for the box stream"""
|
||||
|
||||
from secret_handshake.boxstream import HEADER_LENGTH, BoxStream, UnboxStream
|
||||
from asyncio import IncompleteReadError
|
||||
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from secret_handshake.boxstream import HEADER_LENGTH, BoxStream, UnboxStream, get_stream_pair
|
||||
|
||||
from .helpers import AsyncBuffer, async_comprehend
|
||||
from .test_crypto import CLIENT_ENCRYPT_KEY, CLIENT_ENCRYPT_NONCE
|
||||
@ -74,6 +78,18 @@ async def test_unboxstream() -> None:
|
||||
assert unbox_stream.closed
|
||||
|
||||
|
||||
async def test_unboxstream_header_read_error(mocker: MockerFixture) -> None:
|
||||
"""Test that we can handle errors during header read"""
|
||||
|
||||
buffer = AsyncBuffer()
|
||||
mocker.patch.object(buffer, "readexactly", side_effect=IncompleteReadError(b"", HEADER_LENGTH))
|
||||
|
||||
unbox_stream = UnboxStream(buffer, CLIENT_ENCRYPT_KEY, CLIENT_ENCRYPT_NONCE)
|
||||
|
||||
assert await unbox_stream.read() is None
|
||||
assert unbox_stream.closed is True
|
||||
|
||||
|
||||
async def test_long_packets() -> None:
|
||||
"""Test for receiving long packets"""
|
||||
|
||||
@ -94,3 +110,28 @@ async def test_long_packets() -> None:
|
||||
assert first_packet == data[:4096]
|
||||
second_packet = await unbox_stream.read()
|
||||
assert second_packet == data[4096:]
|
||||
|
||||
|
||||
def test_get_stream_pair() -> None:
|
||||
"""Test the get_stream_pair() function"""
|
||||
|
||||
read_buffer = AsyncBuffer()
|
||||
write_buffer = AsyncBuffer()
|
||||
|
||||
read_stream, write_stream = get_stream_pair(
|
||||
read_buffer,
|
||||
write_buffer,
|
||||
decrypt_key=b"d" * 32,
|
||||
decrypt_nonce=b"dnonce",
|
||||
encrypt_key=b"e" * 32,
|
||||
encrypt_nonce=b"enonce",
|
||||
)
|
||||
|
||||
assert isinstance(read_stream, UnboxStream)
|
||||
assert isinstance(write_stream, BoxStream)
|
||||
|
||||
assert read_stream.key == b"d" * 32
|
||||
assert read_stream.nonce == b"dnonce"
|
||||
|
||||
assert write_stream.key == b"e" * 32
|
||||
assert write_stream.nonce == b"enonce"
|
||||
|
@ -23,12 +23,15 @@
|
||||
"""Tests for the crypto components"""
|
||||
|
||||
import hashlib
|
||||
from typing import Literal
|
||||
|
||||
from nacl.exceptions import CryptoError
|
||||
from nacl.public import PrivateKey
|
||||
from nacl.signing import SigningKey
|
||||
import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from secret_handshake.crypto import SHSClientCrypto, SHSServerCrypto
|
||||
from secret_handshake.crypto import SHSClientCrypto, SHSError, SHSServerCrypto
|
||||
|
||||
APP_KEY = hashlib.sha256(b"app_key").digest()
|
||||
SERVER_KEY_SEED = b"\xcaw\x01\xc2cQ\xfd\x94\x9f\x14\x84\x0c0<l\xd8\xe4\xf5>\x12\\\x96\xcd\x9b\x0c\x02z&\x96!\xe0\xa2"
|
||||
@ -48,14 +51,22 @@ def server() -> SHSServerCrypto:
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client() -> SHSClientCrypto:
|
||||
def client(request: pytest.FixtureRequest) -> SHSClientCrypto:
|
||||
"""A testing SHS client"""
|
||||
|
||||
app_key = None
|
||||
|
||||
for marker in request.node.iter_markers(name="client_app_key"):
|
||||
app_key = marker.args[0]
|
||||
|
||||
if app_key is None:
|
||||
app_key = APP_KEY
|
||||
|
||||
client_key = SigningKey(CLIENT_KEY_SEED)
|
||||
server_key = SigningKey(SERVER_KEY_SEED)
|
||||
client_eph_key = PrivateKey(CLIENT_EPH_KEY_SEED)
|
||||
|
||||
return SHSClientCrypto(client_key, bytes(server_key.verify_key), client_eph_key, application_key=APP_KEY)
|
||||
return SHSClientCrypto(client_key, bytes(server_key.verify_key), client_eph_key, application_key=app_key)
|
||||
|
||||
|
||||
CLIENT_CHALLENGE = (
|
||||
@ -130,3 +141,67 @@ def test_handshake(client: SHSClientCrypto, server: SHSServerCrypto) -> None: #
|
||||
assert client_keys["shared_secret"] == server_keys["shared_secret"]
|
||||
assert client_keys["encrypt_key"] == server_keys["decrypt_key"]
|
||||
assert client_keys["encrypt_nonce"] == server_keys["decrypt_nonce"]
|
||||
|
||||
|
||||
@pytest.mark.client_app_key(b"a" * 32)
|
||||
def test_verify_challenge_different_app_keys(
|
||||
client: SHSClientCrypto, server: SHSServerCrypto # pylint: disable=redefined-outer-name
|
||||
) -> None:
|
||||
"""Test challenge verification when the application keys of the client and server don’t match"""
|
||||
|
||||
challenge = client.generate_challenge()
|
||||
assert not server.verify_challenge(challenge)
|
||||
|
||||
|
||||
def test_verify_challenge(
|
||||
client: SHSClientCrypto, server: SHSServerCrypto # pylint: disable=redefined-outer-name
|
||||
) -> None:
|
||||
"""Test challenge verification when the application keys of the client and server match"""
|
||||
|
||||
challenge = client.generate_challenge()
|
||||
assert server.verify_challenge(challenge)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("type_", ("server", "client"))
|
||||
@pytest.mark.parametrize("provide_key", (True, False))
|
||||
def test_clean(
|
||||
type_: Literal["client", "server"], provide_key: bool, request: pytest.FixtureRequest, mocker: MockerFixture
|
||||
) -> None:
|
||||
"""Test the clean() method"""
|
||||
|
||||
if type_ == "server":
|
||||
actor = request.getfixturevalue("server")
|
||||
elif type_ == "client": # pragma: no branch
|
||||
actor = request.getfixturevalue("client")
|
||||
|
||||
mocked_private_key = mocker.patch("secret_handshake.crypto.PrivateKey")
|
||||
mocked_private_key.generate = mocker.MagicMock(return_value=PrivateKey(b"g" * 32))
|
||||
|
||||
new_key = PrivateKey(b"p" * 32) if provide_key else None
|
||||
actor.clean(new_ephemeral_key=new_key)
|
||||
|
||||
assert actor.shared_secret is None
|
||||
assert actor.shared_hash is None
|
||||
assert actor.remote_ephemeral_key is None
|
||||
assert isinstance(actor.local_ephemeral_key, PrivateKey)
|
||||
|
||||
if provide_key:
|
||||
assert actor.local_ephemeral_key == new_key
|
||||
else:
|
||||
assert actor.local_ephemeral_key.encode() == b"g" * 32
|
||||
|
||||
|
||||
def test_failing_server_accept(
|
||||
client: SHSClientCrypto, server: SHSServerCrypto, mocker: MockerFixture # pylint: disable=redefined-outer-name
|
||||
) -> None:
|
||||
"""Test if verify_server_accept raises the correct exception type"""
|
||||
|
||||
server.verify_challenge(client.generate_challenge())
|
||||
client.verify_server_challenge(server.generate_challenge())
|
||||
server.verify_client_auth(client.generate_client_auth())
|
||||
server_accept = server.generate_accept()
|
||||
|
||||
mocker.patch("secret_handshake.crypto.crypto_box_open_afternm", side_effect=CryptoError())
|
||||
|
||||
with pytest.raises(SHSError):
|
||||
client.verify_server_accept(server_accept)
|
||||
|
@ -24,7 +24,7 @@
|
||||
|
||||
from asyncio import Event, wait_for
|
||||
import os
|
||||
from typing import Any, Awaitable, Callable, Tuple
|
||||
from typing import Any, Awaitable, Callable, Literal, Tuple
|
||||
|
||||
from nacl.signing import SigningKey
|
||||
import pytest
|
||||
@ -32,6 +32,7 @@ from pytest_mock import MockerFixture
|
||||
|
||||
from secret_handshake import SHSClient, SHSServer
|
||||
from secret_handshake.boxstream import BoxStreamKeys
|
||||
from secret_handshake.network import SHSClientException, SHSDuplexStream
|
||||
|
||||
from .helpers import AsyncBuffer
|
||||
|
||||
@ -144,7 +145,7 @@ async def test_client(mocker: MockerFixture) -> None:
|
||||
await client.open()
|
||||
reader.append(b"TEST")
|
||||
assert (await client.read()) == b"TEST"
|
||||
client.disconnect()
|
||||
client.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -154,7 +155,7 @@ async def test_server(mocker: MockerFixture) -> None:
|
||||
resolve = Event()
|
||||
|
||||
async def _on_connect(_: Any) -> None:
|
||||
server.disconnect()
|
||||
server.close()
|
||||
resolve.set()
|
||||
|
||||
_, _, _create_mock_server = _server_stream_mocker()
|
||||
@ -169,3 +170,146 @@ async def test_server(mocker: MockerFixture) -> None:
|
||||
|
||||
await server.listen()
|
||||
await wait_for(resolve.wait(), 5)
|
||||
|
||||
|
||||
def test_duplex_write(mocker: MockerFixture) -> None:
|
||||
"""Test the writing capabilities of the duplex stream"""
|
||||
|
||||
d_stream = SHSDuplexStream()
|
||||
d_stream.write_stream = mocker.AsyncMock()
|
||||
d_stream.write(b"thing")
|
||||
|
||||
assert d_stream.write_stream
|
||||
|
||||
d_stream.write_stream.write.assert_called_once_with(b"thing") # type: ignore[attr-defined]
|
||||
|
||||
|
||||
def test_duplex_close_no_write_stream() -> None:
|
||||
"""Test if SHSDuplexStream’s close method doesn’t fail if there is no write stream"""
|
||||
|
||||
d_stream = SHSDuplexStream()
|
||||
assert d_stream.write_stream is None
|
||||
d_stream.close()
|
||||
|
||||
# We cannot really do assertions here. If there is not set (it is None), the above call would fail
|
||||
|
||||
|
||||
def test_duplex_stream_aiter() -> None:
|
||||
"""Test if the __aiter__ method of SHSDuplexStream returns the stream itself"""
|
||||
|
||||
d_stream = SHSDuplexStream()
|
||||
assert d_stream.__aiter__() is d_stream # pylint: disable=unnecessary-dunder-call
|
||||
|
||||
|
||||
async def test_duplex_stream_anext(mocker: MockerFixture) -> None:
|
||||
"""Test if the __anext__ method of SHSDuplexStream reads from the stream"""
|
||||
|
||||
d_stream = SHSDuplexStream()
|
||||
mocker.patch.object(d_stream, "read", mocker.AsyncMock(return_value=b"test"))
|
||||
|
||||
assert await d_stream.__anext__() == b"test" # pylint: disable=unnecessary-dunder-call
|
||||
|
||||
|
||||
async def test_duplex_stream_anext_eof(mocker: MockerFixture) -> None:
|
||||
"""Test if SHSDuplexStream.__anext__ breaks iteration if there’s no data to read"""
|
||||
|
||||
d_stream = SHSDuplexStream()
|
||||
mocker.patch.object(d_stream, "read", mocker.AsyncMock(return_value=None))
|
||||
|
||||
with pytest.raises(StopAsyncIteration):
|
||||
assert await d_stream.__anext__() # pylint: disable=unnecessary-dunder-call
|
||||
|
||||
|
||||
@pytest.mark.parametrize("fail_type", ("verify_challenge", "verify_auth"))
|
||||
async def test_server_fail_handshake(
|
||||
fail_type: Literal["verify_challenge", "verify_auth"], mocker: MockerFixture
|
||||
) -> None:
|
||||
"""Test if a failing handshake results in an SHSClientException"""
|
||||
|
||||
server = SHSServer("127.0.0.1", 8754, SigningKey.generate())
|
||||
|
||||
if fail_type == "verify_challenge":
|
||||
expected_error = "Client challenge is not valid"
|
||||
elif fail_type == "verify_auth": # pragma: no branch
|
||||
expected_error = "Client auth is not valid"
|
||||
|
||||
mocker.patch.object(server.crypto, "verify_challenge", return_value=fail_type != "verify_challenge")
|
||||
mocker.patch.object(server.crypto, "verify_client_auth", return_value=fail_type != "verify_auth")
|
||||
|
||||
with pytest.raises(SHSClientException) as ctx:
|
||||
await server._handshake(AsyncBuffer(b"d" * 64), AsyncBuffer()) # pylint: disable=protected-access
|
||||
|
||||
assert str(ctx.value) == expected_error
|
||||
|
||||
|
||||
async def test_server_no_connect_callback(mocker: MockerFixture) -> None:
|
||||
"""Test if SHSServer.handle_connection works without an on_connect callback"""
|
||||
|
||||
server = SHSServer("127.0.0.1", 7429, SigningKey.generate())
|
||||
mocker.patch.object(server, "_handshake", return_value=None)
|
||||
mocker.patch.object(
|
||||
server.crypto,
|
||||
"get_box_keys",
|
||||
return_value={
|
||||
"decrypt_key": b"d" * 32,
|
||||
"decrypt_nonce": b"dnonce",
|
||||
"encrypt_key": b"e" * 32,
|
||||
"encrypt_nonce": b"enonce",
|
||||
},
|
||||
)
|
||||
|
||||
await server.handle_connection(AsyncBuffer(), AsyncBuffer())
|
||||
|
||||
# No assertion here. We should get here without a problem
|
||||
|
||||
|
||||
@pytest.mark.parametrize("fail_type", ("verify_challenge", "verify_accept"))
|
||||
async def test_client_fail_handshake(
|
||||
fail_type: Literal["verify_challenge", "verify_accept"], mocker: MockerFixture
|
||||
) -> None:
|
||||
"""Test if a failing handshake results in an SHSClientException"""
|
||||
|
||||
client = SHSClient("127.0.0.1", 8754, SigningKey.generate(), b"s" * 32)
|
||||
|
||||
if fail_type == "verify_challenge":
|
||||
expected_error = "Server challenge is not valid"
|
||||
elif fail_type == "verify_accept": # pragma: no branch
|
||||
expected_error = "Server accept is not valid"
|
||||
|
||||
mocker.patch.object(client.crypto, "verify_server_challenge", return_value=fail_type != "verify_challenge")
|
||||
mocker.patch.object(client.crypto, "verify_server_accept", return_value=fail_type != "verify_accept")
|
||||
mocker.patch.object(client.crypto, "generate_client_auth", return_value=b"ca" * 16)
|
||||
|
||||
with pytest.raises(SHSClientException) as ctx:
|
||||
await client._handshake(AsyncBuffer(b"d" * 64), AsyncBuffer()) # pylint: disable=protected-access
|
||||
|
||||
assert str(ctx.value) == expected_error
|
||||
|
||||
|
||||
@pytest.mark.parametrize("with_callback", (True, False))
|
||||
async def test_client_open(with_callback: bool, mocker: MockerFixture) -> None:
|
||||
"""Test if SHSServer.handle_connection works with and without an on_connect callback"""
|
||||
|
||||
client = SHSClient("127.0.0.1", 7429, SigningKey.generate(), SigningKey.generate().verify_key.encode())
|
||||
|
||||
mocker.patch("secret_handshake.network.open_connection", return_value=(AsyncBuffer(), AsyncBuffer()))
|
||||
mocker.patch.object(client, "_handshake", return_value=None)
|
||||
mocker.patch.object(
|
||||
client.crypto,
|
||||
"get_box_keys",
|
||||
return_value={
|
||||
"decrypt_key": b"d" * 32,
|
||||
"decrypt_nonce": b"dnonce",
|
||||
"encrypt_key": b"e" * 32,
|
||||
"encrypt_nonce": b"enonce",
|
||||
},
|
||||
)
|
||||
|
||||
if with_callback:
|
||||
callback = mocker.AsyncMock()
|
||||
client.on_connect(callback)
|
||||
|
||||
await client.open()
|
||||
|
||||
if with_callback:
|
||||
callback.assert_awaited_once_with(client)
|
||||
|
124
tests/test_util.py
Normal file
124
tests/test_util.py
Normal file
@ -0,0 +1,124 @@
|
||||
# SPDX-License-Identifier: MIT
|
||||
#
|
||||
# Copyright (c) 2017 PySecretHandshake contributors (see AUTHORS for more details)
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the "Software"), to deal
|
||||
# in the Software without restriction, including without limitation the rights
|
||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
# copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in all
|
||||
# copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
# SOFTWARE.
|
||||
|
||||
"""Tests for utility functions"""
|
||||
|
||||
import math
|
||||
from types import GeneratorType
|
||||
from typing import List, Sequence, TypeVar
|
||||
|
||||
import pytest
|
||||
|
||||
from secret_handshake.util import bytes_to_long, inc_nonce, long_to_bytes, split_chunks
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"in_,out",
|
||||
(
|
||||
(b"\x00\x00\x00\x00", b"\x00" * 23 + b"\x01"),
|
||||
(b"\xff" * 24, b"\x00" * 24),
|
||||
),
|
||||
)
|
||||
def test_inc_nonce(in_: bytes, out: bytes) -> None:
|
||||
"""Test the inc_nonce function"""
|
||||
|
||||
result = inc_nonce(in_)
|
||||
|
||||
assert len(result) == 24
|
||||
assert result == out
|
||||
|
||||
|
||||
def test_split_chunks_is_generator() -> None:
|
||||
"""Test if split_chunks returns a generator"""
|
||||
|
||||
assert isinstance(split_chunks([], 1), GeneratorType)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("size", (-123, -1, 0))
|
||||
def test_nonpositive_chunk_size(size: int) -> None:
|
||||
"""Test if split_chunks() with non-positive chunk sizes raise an error"""
|
||||
|
||||
with pytest.raises(ValueError) as ctx:
|
||||
list(split_chunks(b"", size))
|
||||
|
||||
assert str(ctx.value) == "chunk_size must be greater than zero"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"in_,chunksize,out",
|
||||
(
|
||||
(b"asdfg", 2, [b"as", b"df", b"g"]),
|
||||
(b"asdfgh", 3, [b"asd", b"fgh"]),
|
||||
),
|
||||
)
|
||||
def test_split_chunks(in_: Sequence[T], chunksize: int, out: List[Sequence[T]]) -> None:
|
||||
"""Test if split_chunks splits the input into equal chunks"""
|
||||
|
||||
assert list(split_chunks(in_, chunksize)) == out
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"in_,out",
|
||||
(
|
||||
(0, b"\x00"),
|
||||
(1, b"\x01"),
|
||||
(4278255360, b"\xff\x00\xff\x00"),
|
||||
(65536, b"\x01\x00\x00"),
|
||||
(4546694913, b"\x01\x0f\x01\x0f\x01"),
|
||||
),
|
||||
)
|
||||
@pytest.mark.parametrize("blocksize", (0, 4, 6))
|
||||
def test_long_to_bytes(in_: int, out: bytes, blocksize: int) -> None:
|
||||
"""Test long_to_bytes"""
|
||||
|
||||
result = long_to_bytes(in_, blocksize=blocksize)
|
||||
|
||||
if blocksize:
|
||||
block_count = math.ceil(len(out) / blocksize)
|
||||
else:
|
||||
block_count = 1
|
||||
|
||||
padding = b"\x00" * blocksize * block_count
|
||||
expected = (padding + out)[-(blocksize * block_count) :]
|
||||
|
||||
if blocksize:
|
||||
assert not len(result) % blocksize
|
||||
|
||||
assert result == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"in_,out",
|
||||
(
|
||||
(b"\x00\x00\x00\x00", 0),
|
||||
(b"\x00\x00\x00\x01", 1),
|
||||
(b"\xff\x00\xff\x00", 4278255360),
|
||||
(b"\x01\x00\x00", 65536),
|
||||
(b"\x01\x0f\x01\x0f\x01", 4546694913),
|
||||
),
|
||||
)
|
||||
def test_bytes_to_long(in_: bytes, out: int) -> None:
|
||||
"""Test bytes_to_long"""
|
||||
|
||||
assert bytes_to_long(in_) == out
|
Loading…
Reference in New Issue
Block a user