Compare commits

...

12 Commits

Author SHA1 Message Date
3f13152684
fix(ci): Add the --check and --diff argument to isort’s pre-commit config 2023-11-16 06:44:34 +01:00
83add95c8a
feat: Rename the disconnect method of each SHSEndpoint to close
The two verbs have the same meaning in this context, and this unifies the two interfaces, which helps type checking in
code using this library.
2023-11-13 06:36:51 +01:00
1ad7cb4e5e
test: Add tests for secret_handshake.network 2023-11-06 05:55:17 +01:00
1d07f9ba02
test: Fully cover crypto.py with tests 2023-11-06 05:11:54 +01:00
5807e64462
test: Fully cover boxstream.py with tests 2023-11-06 04:42:01 +01:00
9ea816f832 fix: Make inc_nonce overflow at the correct value 2023-11-03 20:21:21 +01:00
c8b07ef913 fix: Make sure split_chunks doesn’t get 0 as chunk size 2023-11-03 20:19:41 +01:00
995f0dabed test: Add tests for util.inc_nonce 2023-11-03 20:13:29 +01:00
88516b230b test: Add tests for util.split_chunks 2023-11-03 20:01:34 +01:00
5820cd3e5c test: Add tests for long_to_bytes and bytes_to_long 2023-11-03 19:56:20 +01:00
86a9fa300c chore: Get rid of the util.b function
It doesn't really gave a use in the world of bytestrings.
2023-11-03 07:46:14 +01:00
5a3af65927
build: Lower Python version requirement to 3.9
The library could work even with 3.6, but some dependencies require 3.9.  Also, even though 3.8 is still supported until
late 2024, moving to higher Python versions might even result in better security.
2023-10-31 12:53:02 +01:00
10 changed files with 471 additions and 33 deletions

View File

@ -32,6 +32,7 @@ repos:
require_serial: true require_serial: true
- id: isort - id: isort
name: isort name: isort
args: ["--check", "--diff"]
entry: poetry run isort entry: poetry run isort
language: system language: system
require_serial: true require_serial: true

55
poetry.lock generated
View File

@ -36,6 +36,9 @@ files = [
{file = "astroid-3.0.1.tar.gz", hash = "sha256:86b0bb7d7da0be1a7c4aedb7974e391b32d4ed89e33de6ed6902b4b15c97577e"}, {file = "astroid-3.0.1.tar.gz", hash = "sha256:86b0bb7d7da0be1a7c4aedb7974e391b32d4ed89e33de6ed6902b4b15c97577e"},
] ]
[package.dependencies]
typing-extensions = {version = ">=4.0.0", markers = "python_version < \"3.11\""}
[[package]] [[package]]
name = "babel" name = "babel"
version = "2.13.1" version = "2.13.1"
@ -100,6 +103,8 @@ mypy-extensions = ">=0.4.3"
packaging = ">=22.0" packaging = ">=22.0"
pathspec = ">=0.9.0" pathspec = ">=0.9.0"
platformdirs = ">=2" platformdirs = ">=2"
tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""}
typing-extensions = {version = ">=4.0.1", markers = "python_version < \"3.11\""}
[package.extras] [package.extras]
colorama = ["colorama (>=0.4.3)"] colorama = ["colorama (>=0.4.3)"]
@ -131,8 +136,10 @@ files = [
[package.dependencies] [package.dependencies]
colorama = {version = "*", markers = "os_name == \"nt\""} colorama = {version = "*", markers = "os_name == \"nt\""}
importlib-metadata = {version = ">=4.6", markers = "python_version < \"3.10\""}
packaging = ">=19.0" packaging = ">=19.0"
pyproject_hooks = "*" pyproject_hooks = "*"
tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""}
[package.extras] [package.extras]
docs = ["furo (>=2023.08.17)", "sphinx (>=7.0,<8.0)", "sphinx-argparse-cli (>=1.5)", "sphinx-autodoc-typehints (>=1.10)", "sphinx-issues (>=3.0.0)"] docs = ["furo (>=2023.08.17)", "sphinx (>=7.0,<8.0)", "sphinx-argparse-cli (>=1.5)", "sphinx-autodoc-typehints (>=1.10)", "sphinx-issues (>=3.0.0)"]
@ -350,6 +357,7 @@ files = [
[package.dependencies] [package.dependencies]
build = ">=0.1" build = ">=0.1"
setuptools = "*" setuptools = "*"
tomli = {version = "*", markers = "python_version < \"3.11\""}
[package.extras] [package.extras]
test = ["mock (>=3.0.0)", "pytest"] test = ["mock (>=3.0.0)", "pytest"]
@ -464,6 +472,9 @@ files = [
{file = "coverage-7.3.2.tar.gz", hash = "sha256:be32ad29341b0170e795ca590e1c07e81fc061cb5b10c74ce7203491484404ef"}, {file = "coverage-7.3.2.tar.gz", hash = "sha256:be32ad29341b0170e795ca590e1c07e81fc061cb5b10c74ce7203491484404ef"},
] ]
[package.dependencies]
tomli = {version = "*", optional = true, markers = "python_full_version <= \"3.11.0a6\" and extra == \"toml\""}
[package.extras] [package.extras]
toml = ["tomli"] toml = ["tomli"]
@ -514,6 +525,20 @@ files = [
{file = "docutils-0.20.1.tar.gz", hash = "sha256:f08a4e276c3a1583a86dce3e34aba3fe04d02bba2dd51ed16106244e8a923e3b"}, {file = "docutils-0.20.1.tar.gz", hash = "sha256:f08a4e276c3a1583a86dce3e34aba3fe04d02bba2dd51ed16106244e8a923e3b"},
] ]
[[package]]
name = "exceptiongroup"
version = "1.1.3"
description = "Backport of PEP 654 (exception groups)"
optional = false
python-versions = ">=3.7"
files = [
{file = "exceptiongroup-1.1.3-py3-none-any.whl", hash = "sha256:343280667a4585d195ca1cf9cef84a4e178c4b6cf2274caef9859782b567d5e3"},
{file = "exceptiongroup-1.1.3.tar.gz", hash = "sha256:097acd85d473d75af5bb98e41b61ff7fe35efe6675e4f9370ec6ec5126d160e9"},
]
[package.extras]
test = ["pytest (>=6)"]
[[package]] [[package]]
name = "filelock" name = "filelock"
version = "3.13.1" version = "3.13.1"
@ -766,6 +791,7 @@ files = [
[package.dependencies] [package.dependencies]
mypy-extensions = ">=1.0.0" mypy-extensions = ">=1.0.0"
tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""}
typing-extensions = ">=4.1.0" typing-extensions = ">=4.1.0"
[package.extras] [package.extras]
@ -938,11 +964,17 @@ files = [
[package.dependencies] [package.dependencies]
astroid = ">=3.0.1,<=3.1.0-dev0" astroid = ">=3.0.1,<=3.1.0-dev0"
colorama = {version = ">=0.4.5", markers = "sys_platform == \"win32\""} colorama = {version = ">=0.4.5", markers = "sys_platform == \"win32\""}
dill = {version = ">=0.3.7", markers = "python_version >= \"3.12\""} dill = [
{version = ">=0.2", markers = "python_version < \"3.11\""},
{version = ">=0.3.7", markers = "python_version >= \"3.12\""},
{version = ">=0.3.6", markers = "python_version >= \"3.11\" and python_version < \"3.12\""},
]
isort = ">=4.2.5,<6" isort = ">=4.2.5,<6"
mccabe = ">=0.6,<0.8" mccabe = ">=0.6,<0.8"
platformdirs = ">=2.2.0" platformdirs = ">=2.2.0"
tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""}
tomlkit = ">=0.10.1" tomlkit = ">=0.10.1"
typing-extensions = {version = ">=3.10.0", markers = "python_version < \"3.10\""}
[package.extras] [package.extras]
spelling = ["pyenchant (>=3.2,<4.0)"] spelling = ["pyenchant (>=3.2,<4.0)"]
@ -985,6 +1017,9 @@ files = [
{file = "pyproject_hooks-1.0.0.tar.gz", hash = "sha256:f271b298b97f5955d53fb12b72c1fb1948c22c1a6b70b315c54cedaca0264ef5"}, {file = "pyproject_hooks-1.0.0.tar.gz", hash = "sha256:f271b298b97f5955d53fb12b72c1fb1948c22c1a6b70b315c54cedaca0264ef5"},
] ]
[package.dependencies]
tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""}
[[package]] [[package]]
name = "pytest" name = "pytest"
version = "7.4.3" version = "7.4.3"
@ -998,9 +1033,11 @@ files = [
[package.dependencies] [package.dependencies]
colorama = {version = "*", markers = "sys_platform == \"win32\""} colorama = {version = "*", markers = "sys_platform == \"win32\""}
exceptiongroup = {version = ">=1.0.0rc8", markers = "python_version < \"3.11\""}
iniconfig = "*" iniconfig = "*"
packaging = "*" packaging = "*"
pluggy = ">=0.12,<2.0" pluggy = ">=0.12,<2.0"
tomli = {version = ">=1.0.0", markers = "python_version < \"3.11\""}
[package.extras] [package.extras]
testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"] testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"]
@ -1228,6 +1265,7 @@ babel = ">=2.9"
colorama = {version = ">=0.4.5", markers = "sys_platform == \"win32\""} colorama = {version = ">=0.4.5", markers = "sys_platform == \"win32\""}
docutils = ">=0.18.1,<0.21" docutils = ">=0.18.1,<0.21"
imagesize = ">=1.3" imagesize = ">=1.3"
importlib-metadata = {version = ">=4.8", markers = "python_version < \"3.10\""}
Jinja2 = ">=3.0" Jinja2 = ">=3.0"
packaging = ">=21.0" packaging = ">=21.0"
Pygments = ">=2.14" Pygments = ">=2.14"
@ -1363,6 +1401,17 @@ files = [
[package.extras] [package.extras]
tests = ["pytest", "pytest-cov"] tests = ["pytest", "pytest-cov"]
[[package]]
name = "tomli"
version = "2.0.1"
description = "A lil' TOML parser"
optional = false
python-versions = ">=3.7"
files = [
{file = "tomli-2.0.1-py3-none-any.whl", hash = "sha256:939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc"},
{file = "tomli-2.0.1.tar.gz", hash = "sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f"},
]
[[package]] [[package]]
name = "tomlkit" name = "tomlkit"
version = "0.12.1" version = "0.12.1"
@ -1461,5 +1510,5 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p
[metadata] [metadata]
lock-version = "2.0" lock-version = "2.0"
python-versions = "^3.12" python-versions = "^3.9"
content-hash = "d369c4e8658ac40afe61eee66a6a00eb6a7ecc6ffbf5875c83f5eb7df4d1a961" content-hash = "69fae4847ecb272da1713aa16b4d0d782e8fde2b1aaf5baebeb8b8a7bae3d1f3"

View File

@ -12,7 +12,7 @@ packages = [{include = "secret_handshake"}]
include = ["secret_handshake/py.typed"] include = ["secret_handshake/py.typed"]
[tool.poetry.dependencies] [tool.poetry.dependencies]
python = "^3.12" python = "^3.9"
PyNaCl = "^1.5.0" PyNaCl = "^1.5.0"
[tool.poetry.group.dev.dependencies] [tool.poetry.group.dev.dependencies]

View File

@ -53,6 +53,7 @@ def get_stream_pair( # pylint: disable=too-many-arguments
decrypt_nonce: bytes, decrypt_nonce: bytes,
encrypt_key: bytes, encrypt_key: bytes,
encrypt_nonce: bytes, encrypt_nonce: bytes,
# We have kwargs here to devour any extra parameters we get, e.g. from the output of SHSCryptoBase.get_box_keys()
**kwargs: Any, **kwargs: Any,
) -> Tuple["UnboxStream", "BoxStream"]: ) -> Tuple["UnboxStream", "BoxStream"]:
"""Create a new duplex box stream""" """Create a new duplex box stream"""

View File

@ -27,7 +27,7 @@ from typing import AsyncIterator, Awaitable, Callable, List, Optional
from nacl.public import PrivateKey from nacl.public import PrivateKey
from nacl.signing import SigningKey from nacl.signing import SigningKey
from typing_extensions import Self from typing_extensions import Self, deprecated
from .boxstream import BoxStream, UnboxStream, get_stream_pair from .boxstream import BoxStream, UnboxStream, get_stream_pair
from .crypto import SHSClientCrypto, SHSCryptoBase, SHSServerCrypto from .crypto import SHSClientCrypto, SHSCryptoBase, SHSServerCrypto
@ -91,10 +91,16 @@ class SHSEndpoint:
self._on_connect = cb self._on_connect = cb
def close(self) -> None: # pragma: no cover
"""Disconnect the endpoint"""
raise NotImplementedError()
@deprecated("Use close instead")
def disconnect(self) -> None: def disconnect(self) -> None:
"""Disconnect the endpoint""" """Disconnect the endpoint"""
raise NotImplementedError self.close()
class SHSServer(SHSEndpoint): class SHSServer(SHSEndpoint):
@ -145,10 +151,14 @@ class SHSServer(SHSEndpoint):
await start_server(self.handle_connection, self.host, self.port) await start_server(self.handle_connection, self.host, self.port)
def disconnect(self) -> None: def close(self) -> None:
for connection in self.connections: for connection in self.connections:
connection.close() connection.close()
@deprecated("Use close instead")
def disconnect(self) -> None:
self.close()
class SHSServerConnection(SHSDuplexStream): class SHSServerConnection(SHSDuplexStream):
"""SHS server connection""" """SHS server connection"""
@ -221,6 +231,3 @@ class SHSClient(SHSDuplexStream, SHSEndpoint):
if self._on_connect: if self._on_connect:
await self._on_connect(self) await self._on_connect(self)
def disconnect(self) -> None:
self.close()

View File

@ -35,7 +35,7 @@ def inc_nonce(nonce: bytes) -> bytes:
num = bytes_to_long(nonce) + 1 num = bytes_to_long(nonce) + 1
if num > 2**MAX_NONCE: if num > 2**MAX_NONCE - 1:
num = 0 num = 0
bnum = long_to_bytes(num) bnum = long_to_bytes(num)
@ -44,21 +44,17 @@ def inc_nonce(nonce: bytes) -> bytes:
return bnum return bnum
def split_chunks(seq: Sequence[T], n: int) -> Generator[Sequence[T], None, None]: def split_chunks(seq: Sequence[T], chunk_size: int) -> Generator[Sequence[T], None, None]:
"""Split sequence in equal-sized chunks. """Split sequence in equal-sized chunks.
The last chunk is not padded.""" The last chunk is not padded."""
if chunk_size <= 0:
raise ValueError("chunk_size must be greater than zero")
while seq: while seq:
yield seq[:n] yield seq[:chunk_size]
seq = seq[n:] seq = seq[chunk_size:]
# Stolen from PyCypto (Public Domain)
def b(s: str) -> bytes:
"""Shorthand for s.encode("latin-1")"""
return s.encode("latin-1") # utf-8 would cause some side-effects we don't want
def long_to_bytes(n: int, blocksize: int = 0) -> bytes: def long_to_bytes(n: int, blocksize: int = 0) -> bytes:
@ -69,7 +65,7 @@ def long_to_bytes(n: int, blocksize: int = 0) -> bytes:
""" """
# after much testing, this algorithm was deemed to be the fastest # after much testing, this algorithm was deemed to be the fastest
s = b("") s = b""
pack = struct.pack pack = struct.pack
while n > 0: while n > 0:
@ -78,11 +74,11 @@ def long_to_bytes(n: int, blocksize: int = 0) -> bytes:
# strip off leading zeros # strip off leading zeros
for i, c in enumerate(s): for i, c in enumerate(s):
if c != b("\000")[0]: if c != 0:
break break
else: else:
# only happens when n == 0 # only happens when n == 0
s = b("\000") s = b"\x00"
i = 0 i = 0
s = s[i:] s = s[i:]
@ -90,7 +86,7 @@ def long_to_bytes(n: int, blocksize: int = 0) -> bytes:
# add back some pad bytes. this could be done more efficiently w.r.t. the # add back some pad bytes. this could be done more efficiently w.r.t. the
# de-padding being done above, but sigh... # de-padding being done above, but sigh...
if blocksize > 0 and len(s) % blocksize: if blocksize > 0 and len(s) % blocksize:
s = (blocksize - len(s) % blocksize) * b("\000") + s s = (blocksize - len(s) % blocksize) * b"\x00" + s
return s return s
@ -107,7 +103,7 @@ def bytes_to_long(s: bytes) -> int:
if length % 4: if length % 4:
extra = 4 - length % 4 extra = 4 - length % 4
s = b("\000") * extra + s s = b"\x00" * extra + s
length = length + extra length = length + extra
for i in range(0, length, 4): for i in range(0, length, 4):

View File

@ -22,7 +22,11 @@
"""Tests for the box stream""" """Tests for the box stream"""
from secret_handshake.boxstream import HEADER_LENGTH, BoxStream, UnboxStream from asyncio import IncompleteReadError
from pytest_mock import MockerFixture
from secret_handshake.boxstream import HEADER_LENGTH, BoxStream, UnboxStream, get_stream_pair
from .helpers import AsyncBuffer, async_comprehend from .helpers import AsyncBuffer, async_comprehend
from .test_crypto import CLIENT_ENCRYPT_KEY, CLIENT_ENCRYPT_NONCE from .test_crypto import CLIENT_ENCRYPT_KEY, CLIENT_ENCRYPT_NONCE
@ -74,6 +78,18 @@ async def test_unboxstream() -> None:
assert unbox_stream.closed assert unbox_stream.closed
async def test_unboxstream_header_read_error(mocker: MockerFixture) -> None:
"""Test that we can handle errors during header read"""
buffer = AsyncBuffer()
mocker.patch.object(buffer, "readexactly", side_effect=IncompleteReadError(b"", HEADER_LENGTH))
unbox_stream = UnboxStream(buffer, CLIENT_ENCRYPT_KEY, CLIENT_ENCRYPT_NONCE)
assert await unbox_stream.read() is None
assert unbox_stream.closed is True
async def test_long_packets() -> None: async def test_long_packets() -> None:
"""Test for receiving long packets""" """Test for receiving long packets"""
@ -94,3 +110,28 @@ async def test_long_packets() -> None:
assert first_packet == data[:4096] assert first_packet == data[:4096]
second_packet = await unbox_stream.read() second_packet = await unbox_stream.read()
assert second_packet == data[4096:] assert second_packet == data[4096:]
def test_get_stream_pair() -> None:
"""Test the get_stream_pair() function"""
read_buffer = AsyncBuffer()
write_buffer = AsyncBuffer()
read_stream, write_stream = get_stream_pair(
read_buffer,
write_buffer,
decrypt_key=b"d" * 32,
decrypt_nonce=b"dnonce",
encrypt_key=b"e" * 32,
encrypt_nonce=b"enonce",
)
assert isinstance(read_stream, UnboxStream)
assert isinstance(write_stream, BoxStream)
assert read_stream.key == b"d" * 32
assert read_stream.nonce == b"dnonce"
assert write_stream.key == b"e" * 32
assert write_stream.nonce == b"enonce"

View File

@ -23,12 +23,15 @@
"""Tests for the crypto components""" """Tests for the crypto components"""
import hashlib import hashlib
from typing import Literal
from nacl.exceptions import CryptoError
from nacl.public import PrivateKey from nacl.public import PrivateKey
from nacl.signing import SigningKey from nacl.signing import SigningKey
import pytest import pytest
from pytest_mock import MockerFixture
from secret_handshake.crypto import SHSClientCrypto, SHSServerCrypto from secret_handshake.crypto import SHSClientCrypto, SHSError, SHSServerCrypto
APP_KEY = hashlib.sha256(b"app_key").digest() APP_KEY = hashlib.sha256(b"app_key").digest()
SERVER_KEY_SEED = b"\xcaw\x01\xc2cQ\xfd\x94\x9f\x14\x84\x0c0<l\xd8\xe4\xf5>\x12\\\x96\xcd\x9b\x0c\x02z&\x96!\xe0\xa2" SERVER_KEY_SEED = b"\xcaw\x01\xc2cQ\xfd\x94\x9f\x14\x84\x0c0<l\xd8\xe4\xf5>\x12\\\x96\xcd\x9b\x0c\x02z&\x96!\xe0\xa2"
@ -48,14 +51,22 @@ def server() -> SHSServerCrypto:
@pytest.fixture @pytest.fixture
def client() -> SHSClientCrypto: def client(request: pytest.FixtureRequest) -> SHSClientCrypto:
"""A testing SHS client""" """A testing SHS client"""
app_key = None
for marker in request.node.iter_markers(name="client_app_key"):
app_key = marker.args[0]
if app_key is None:
app_key = APP_KEY
client_key = SigningKey(CLIENT_KEY_SEED) client_key = SigningKey(CLIENT_KEY_SEED)
server_key = SigningKey(SERVER_KEY_SEED) server_key = SigningKey(SERVER_KEY_SEED)
client_eph_key = PrivateKey(CLIENT_EPH_KEY_SEED) client_eph_key = PrivateKey(CLIENT_EPH_KEY_SEED)
return SHSClientCrypto(client_key, bytes(server_key.verify_key), client_eph_key, application_key=APP_KEY) return SHSClientCrypto(client_key, bytes(server_key.verify_key), client_eph_key, application_key=app_key)
CLIENT_CHALLENGE = ( CLIENT_CHALLENGE = (
@ -130,3 +141,67 @@ def test_handshake(client: SHSClientCrypto, server: SHSServerCrypto) -> None: #
assert client_keys["shared_secret"] == server_keys["shared_secret"] assert client_keys["shared_secret"] == server_keys["shared_secret"]
assert client_keys["encrypt_key"] == server_keys["decrypt_key"] assert client_keys["encrypt_key"] == server_keys["decrypt_key"]
assert client_keys["encrypt_nonce"] == server_keys["decrypt_nonce"] assert client_keys["encrypt_nonce"] == server_keys["decrypt_nonce"]
@pytest.mark.client_app_key(b"a" * 32)
def test_verify_challenge_different_app_keys(
client: SHSClientCrypto, server: SHSServerCrypto # pylint: disable=redefined-outer-name
) -> None:
"""Test challenge verification when the application keys of the client and server dont match"""
challenge = client.generate_challenge()
assert not server.verify_challenge(challenge)
def test_verify_challenge(
client: SHSClientCrypto, server: SHSServerCrypto # pylint: disable=redefined-outer-name
) -> None:
"""Test challenge verification when the application keys of the client and server match"""
challenge = client.generate_challenge()
assert server.verify_challenge(challenge)
@pytest.mark.parametrize("type_", ("server", "client"))
@pytest.mark.parametrize("provide_key", (True, False))
def test_clean(
type_: Literal["client", "server"], provide_key: bool, request: pytest.FixtureRequest, mocker: MockerFixture
) -> None:
"""Test the clean() method"""
if type_ == "server":
actor = request.getfixturevalue("server")
elif type_ == "client": # pragma: no branch
actor = request.getfixturevalue("client")
mocked_private_key = mocker.patch("secret_handshake.crypto.PrivateKey")
mocked_private_key.generate = mocker.MagicMock(return_value=PrivateKey(b"g" * 32))
new_key = PrivateKey(b"p" * 32) if provide_key else None
actor.clean(new_ephemeral_key=new_key)
assert actor.shared_secret is None
assert actor.shared_hash is None
assert actor.remote_ephemeral_key is None
assert isinstance(actor.local_ephemeral_key, PrivateKey)
if provide_key:
assert actor.local_ephemeral_key == new_key
else:
assert actor.local_ephemeral_key.encode() == b"g" * 32
def test_failing_server_accept(
client: SHSClientCrypto, server: SHSServerCrypto, mocker: MockerFixture # pylint: disable=redefined-outer-name
) -> None:
"""Test if verify_server_accept raises the correct exception type"""
server.verify_challenge(client.generate_challenge())
client.verify_server_challenge(server.generate_challenge())
server.verify_client_auth(client.generate_client_auth())
server_accept = server.generate_accept()
mocker.patch("secret_handshake.crypto.crypto_box_open_afternm", side_effect=CryptoError())
with pytest.raises(SHSError):
client.verify_server_accept(server_accept)

View File

@ -24,7 +24,7 @@
from asyncio import Event, wait_for from asyncio import Event, wait_for
import os import os
from typing import Any, Awaitable, Callable, Tuple from typing import Any, Awaitable, Callable, Literal, Tuple
from nacl.signing import SigningKey from nacl.signing import SigningKey
import pytest import pytest
@ -32,6 +32,7 @@ from pytest_mock import MockerFixture
from secret_handshake import SHSClient, SHSServer from secret_handshake import SHSClient, SHSServer
from secret_handshake.boxstream import BoxStreamKeys from secret_handshake.boxstream import BoxStreamKeys
from secret_handshake.network import SHSClientException, SHSDuplexStream
from .helpers import AsyncBuffer from .helpers import AsyncBuffer
@ -144,7 +145,7 @@ async def test_client(mocker: MockerFixture) -> None:
await client.open() await client.open()
reader.append(b"TEST") reader.append(b"TEST")
assert (await client.read()) == b"TEST" assert (await client.read()) == b"TEST"
client.disconnect() client.close()
@pytest.mark.asyncio @pytest.mark.asyncio
@ -154,7 +155,7 @@ async def test_server(mocker: MockerFixture) -> None:
resolve = Event() resolve = Event()
async def _on_connect(_: Any) -> None: async def _on_connect(_: Any) -> None:
server.disconnect() server.close()
resolve.set() resolve.set()
_, _, _create_mock_server = _server_stream_mocker() _, _, _create_mock_server = _server_stream_mocker()
@ -169,3 +170,146 @@ async def test_server(mocker: MockerFixture) -> None:
await server.listen() await server.listen()
await wait_for(resolve.wait(), 5) await wait_for(resolve.wait(), 5)
def test_duplex_write(mocker: MockerFixture) -> None:
"""Test the writing capabilities of the duplex stream"""
d_stream = SHSDuplexStream()
d_stream.write_stream = mocker.AsyncMock()
d_stream.write(b"thing")
assert d_stream.write_stream
d_stream.write_stream.write.assert_called_once_with(b"thing") # type: ignore[attr-defined]
def test_duplex_close_no_write_stream() -> None:
"""Test if SHSDuplexStreams close method doesnt fail if there is no write stream"""
d_stream = SHSDuplexStream()
assert d_stream.write_stream is None
d_stream.close()
# We cannot really do assertions here. If there is not set (it is None), the above call would fail
def test_duplex_stream_aiter() -> None:
"""Test if the __aiter__ method of SHSDuplexStream returns the stream itself"""
d_stream = SHSDuplexStream()
assert d_stream.__aiter__() is d_stream # pylint: disable=unnecessary-dunder-call
async def test_duplex_stream_anext(mocker: MockerFixture) -> None:
"""Test if the __anext__ method of SHSDuplexStream reads from the stream"""
d_stream = SHSDuplexStream()
mocker.patch.object(d_stream, "read", mocker.AsyncMock(return_value=b"test"))
assert await d_stream.__anext__() == b"test" # pylint: disable=unnecessary-dunder-call
async def test_duplex_stream_anext_eof(mocker: MockerFixture) -> None:
"""Test if SHSDuplexStream.__anext__ breaks iteration if theres no data to read"""
d_stream = SHSDuplexStream()
mocker.patch.object(d_stream, "read", mocker.AsyncMock(return_value=None))
with pytest.raises(StopAsyncIteration):
assert await d_stream.__anext__() # pylint: disable=unnecessary-dunder-call
@pytest.mark.parametrize("fail_type", ("verify_challenge", "verify_auth"))
async def test_server_fail_handshake(
fail_type: Literal["verify_challenge", "verify_auth"], mocker: MockerFixture
) -> None:
"""Test if a failing handshake results in an SHSClientException"""
server = SHSServer("127.0.0.1", 8754, SigningKey.generate())
if fail_type == "verify_challenge":
expected_error = "Client challenge is not valid"
elif fail_type == "verify_auth": # pragma: no branch
expected_error = "Client auth is not valid"
mocker.patch.object(server.crypto, "verify_challenge", return_value=fail_type != "verify_challenge")
mocker.patch.object(server.crypto, "verify_client_auth", return_value=fail_type != "verify_auth")
with pytest.raises(SHSClientException) as ctx:
await server._handshake(AsyncBuffer(b"d" * 64), AsyncBuffer()) # pylint: disable=protected-access
assert str(ctx.value) == expected_error
async def test_server_no_connect_callback(mocker: MockerFixture) -> None:
"""Test if SHSServer.handle_connection works without an on_connect callback"""
server = SHSServer("127.0.0.1", 7429, SigningKey.generate())
mocker.patch.object(server, "_handshake", return_value=None)
mocker.patch.object(
server.crypto,
"get_box_keys",
return_value={
"decrypt_key": b"d" * 32,
"decrypt_nonce": b"dnonce",
"encrypt_key": b"e" * 32,
"encrypt_nonce": b"enonce",
},
)
await server.handle_connection(AsyncBuffer(), AsyncBuffer())
# No assertion here. We should get here without a problem
@pytest.mark.parametrize("fail_type", ("verify_challenge", "verify_accept"))
async def test_client_fail_handshake(
fail_type: Literal["verify_challenge", "verify_accept"], mocker: MockerFixture
) -> None:
"""Test if a failing handshake results in an SHSClientException"""
client = SHSClient("127.0.0.1", 8754, SigningKey.generate(), b"s" * 32)
if fail_type == "verify_challenge":
expected_error = "Server challenge is not valid"
elif fail_type == "verify_accept": # pragma: no branch
expected_error = "Server accept is not valid"
mocker.patch.object(client.crypto, "verify_server_challenge", return_value=fail_type != "verify_challenge")
mocker.patch.object(client.crypto, "verify_server_accept", return_value=fail_type != "verify_accept")
mocker.patch.object(client.crypto, "generate_client_auth", return_value=b"ca" * 16)
with pytest.raises(SHSClientException) as ctx:
await client._handshake(AsyncBuffer(b"d" * 64), AsyncBuffer()) # pylint: disable=protected-access
assert str(ctx.value) == expected_error
@pytest.mark.parametrize("with_callback", (True, False))
async def test_client_open(with_callback: bool, mocker: MockerFixture) -> None:
"""Test if SHSServer.handle_connection works with and without an on_connect callback"""
client = SHSClient("127.0.0.1", 7429, SigningKey.generate(), SigningKey.generate().verify_key.encode())
mocker.patch("secret_handshake.network.open_connection", return_value=(AsyncBuffer(), AsyncBuffer()))
mocker.patch.object(client, "_handshake", return_value=None)
mocker.patch.object(
client.crypto,
"get_box_keys",
return_value={
"decrypt_key": b"d" * 32,
"decrypt_nonce": b"dnonce",
"encrypt_key": b"e" * 32,
"encrypt_nonce": b"enonce",
},
)
if with_callback:
callback = mocker.AsyncMock()
client.on_connect(callback)
await client.open()
if with_callback:
callback.assert_awaited_once_with(client)

124
tests/test_util.py Normal file
View File

@ -0,0 +1,124 @@
# SPDX-License-Identifier: MIT
#
# Copyright (c) 2017 PySecretHandshake contributors (see AUTHORS for more details)
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
"""Tests for utility functions"""
import math
from types import GeneratorType
from typing import List, Sequence, TypeVar
import pytest
from secret_handshake.util import bytes_to_long, inc_nonce, long_to_bytes, split_chunks
T = TypeVar("T")
@pytest.mark.parametrize(
"in_,out",
(
(b"\x00\x00\x00\x00", b"\x00" * 23 + b"\x01"),
(b"\xff" * 24, b"\x00" * 24),
),
)
def test_inc_nonce(in_: bytes, out: bytes) -> None:
"""Test the inc_nonce function"""
result = inc_nonce(in_)
assert len(result) == 24
assert result == out
def test_split_chunks_is_generator() -> None:
"""Test if split_chunks returns a generator"""
assert isinstance(split_chunks([], 1), GeneratorType)
@pytest.mark.parametrize("size", (-123, -1, 0))
def test_nonpositive_chunk_size(size: int) -> None:
"""Test if split_chunks() with non-positive chunk sizes raise an error"""
with pytest.raises(ValueError) as ctx:
list(split_chunks(b"", size))
assert str(ctx.value) == "chunk_size must be greater than zero"
@pytest.mark.parametrize(
"in_,chunksize,out",
(
(b"asdfg", 2, [b"as", b"df", b"g"]),
(b"asdfgh", 3, [b"asd", b"fgh"]),
),
)
def test_split_chunks(in_: Sequence[T], chunksize: int, out: List[Sequence[T]]) -> None:
"""Test if split_chunks splits the input into equal chunks"""
assert list(split_chunks(in_, chunksize)) == out
@pytest.mark.parametrize(
"in_,out",
(
(0, b"\x00"),
(1, b"\x01"),
(4278255360, b"\xff\x00\xff\x00"),
(65536, b"\x01\x00\x00"),
(4546694913, b"\x01\x0f\x01\x0f\x01"),
),
)
@pytest.mark.parametrize("blocksize", (0, 4, 6))
def test_long_to_bytes(in_: int, out: bytes, blocksize: int) -> None:
"""Test long_to_bytes"""
result = long_to_bytes(in_, blocksize=blocksize)
if blocksize:
block_count = math.ceil(len(out) / blocksize)
else:
block_count = 1
padding = b"\x00" * blocksize * block_count
expected = (padding + out)[-(blocksize * block_count) :]
if blocksize:
assert not len(result) % blocksize
assert result == expected
@pytest.mark.parametrize(
"in_,out",
(
(b"\x00\x00\x00\x00", 0),
(b"\x00\x00\x00\x01", 1),
(b"\xff\x00\xff\x00", 4278255360),
(b"\x01\x00\x00", 65536),
(b"\x01\x0f\x01\x0f\x01", 4546694913),
),
)
def test_bytes_to_long(in_: bytes, out: int) -> None:
"""Test bytes_to_long"""
assert bytes_to_long(in_) == out