ci: Add mypy as a dev dependency and configure it

This commit is contained in:
Gergely Polonkai 2023-10-30 05:47:28 +01:00
parent 95039914ba
commit 4996931b54
No known key found for this signature in database
GPG Key ID: 2D2885533B869ED4
14 changed files with 350 additions and 145 deletions

View File

@ -36,3 +36,10 @@ repos:
language: system language: system
require_serial: true require_serial: true
types_or: [python, pyi] types_or: [python, pyi]
- id: mypy
name: mypy
entry: poetry run mypy
args: ["--strict"]
language: system
types_or: [python, pyi]
require_serial: true

View File

@ -13,7 +13,7 @@ with open(os.path.expanduser("~/.ssb/secret"), encoding="utf-8") as f:
config = yaml.safe_load(f) config = yaml.safe_load(f)
async def main(): async def main() -> None:
"""Main function to run""" """Main function to run"""
server_pub_key = b64decode(config["public"][:-8]) server_pub_key = b64decode(config["public"][:-8])

View File

@ -8,17 +8,18 @@ from nacl.signing import SigningKey
import yaml import yaml
from secret_handshake import SHSServer from secret_handshake import SHSServer
from secret_handshake.network import SHSDuplexStream
with open(os.path.expanduser("~/.ssb/secret"), encoding="utf-8") as f: with open(os.path.expanduser("~/.ssb/secret"), encoding="utf-8") as f:
config = yaml.safe_load(f) config = yaml.safe_load(f)
async def _on_connect(conn): async def _on_connect(conn: SHSDuplexStream) -> None:
async for msg in conn: async for msg in conn:
print(msg) print(msg)
async def main(): async def main() -> None:
"""Main function to run""" """Main function to run"""
server_keypair = SigningKey(b64decode(config["private"][:-8])[:32]) server_keypair = SigningKey(b64decode(config["private"][:-8])[:32])

69
poetry.lock generated
View File

@ -674,6 +674,51 @@ files = [
{file = "mccabe-0.7.0.tar.gz", hash = "sha256:348e0240c33b60bbdf4e523192ef919f28cb2c3d7d5c7794f74009290f236325"}, {file = "mccabe-0.7.0.tar.gz", hash = "sha256:348e0240c33b60bbdf4e523192ef919f28cb2c3d7d5c7794f74009290f236325"},
] ]
[[package]]
name = "mypy"
version = "1.6.1"
description = "Optional static typing for Python"
optional = false
python-versions = ">=3.8"
files = [
{file = "mypy-1.6.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e5012e5cc2ac628177eaac0e83d622b2dd499e28253d4107a08ecc59ede3fc2c"},
{file = "mypy-1.6.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:d8fbb68711905f8912e5af474ca8b78d077447d8f3918997fecbf26943ff3cbb"},
{file = "mypy-1.6.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:21a1ad938fee7d2d96ca666c77b7c494c3c5bd88dff792220e1afbebb2925b5e"},
{file = "mypy-1.6.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:b96ae2c1279d1065413965c607712006205a9ac541895004a1e0d4f281f2ff9f"},
{file = "mypy-1.6.1-cp310-cp310-win_amd64.whl", hash = "sha256:40b1844d2e8b232ed92e50a4bd11c48d2daa351f9deee6c194b83bf03e418b0c"},
{file = "mypy-1.6.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:81af8adaa5e3099469e7623436881eff6b3b06db5ef75e6f5b6d4871263547e5"},
{file = "mypy-1.6.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:8c223fa57cb154c7eab5156856c231c3f5eace1e0bed9b32a24696b7ba3c3245"},
{file = "mypy-1.6.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a8032e00ce71c3ceb93eeba63963b864bf635a18f6c0c12da6c13c450eedb183"},
{file = "mypy-1.6.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:4c46b51de523817a0045b150ed11b56f9fff55f12b9edd0f3ed35b15a2809de0"},
{file = "mypy-1.6.1-cp311-cp311-win_amd64.whl", hash = "sha256:19f905bcfd9e167159b3d63ecd8cb5e696151c3e59a1742e79bc3bcb540c42c7"},
{file = "mypy-1.6.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:82e469518d3e9a321912955cc702d418773a2fd1e91c651280a1bda10622f02f"},
{file = "mypy-1.6.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:d4473c22cc296425bbbce7e9429588e76e05bc7342da359d6520b6427bf76660"},
{file = "mypy-1.6.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:59a0d7d24dfb26729e0a068639a6ce3500e31d6655df8557156c51c1cb874ce7"},
{file = "mypy-1.6.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:cfd13d47b29ed3bbaafaff7d8b21e90d827631afda134836962011acb5904b71"},
{file = "mypy-1.6.1-cp312-cp312-win_amd64.whl", hash = "sha256:eb4f18589d196a4cbe5290b435d135dee96567e07c2b2d43b5c4621b6501531a"},
{file = "mypy-1.6.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:41697773aa0bf53ff917aa077e2cde7aa50254f28750f9b88884acea38a16169"},
{file = "mypy-1.6.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:7274b0c57737bd3476d2229c6389b2ec9eefeb090bbaf77777e9d6b1b5a9d143"},
{file = "mypy-1.6.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bbaf4662e498c8c2e352da5f5bca5ab29d378895fa2d980630656178bd607c46"},
{file = "mypy-1.6.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:bb8ccb4724f7d8601938571bf3f24da0da791fe2db7be3d9e79849cb64e0ae85"},
{file = "mypy-1.6.1-cp38-cp38-win_amd64.whl", hash = "sha256:68351911e85145f582b5aa6cd9ad666c8958bcae897a1bfda8f4940472463c45"},
{file = "mypy-1.6.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:49ae115da099dcc0922a7a895c1eec82c1518109ea5c162ed50e3b3594c71208"},
{file = "mypy-1.6.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:8b27958f8c76bed8edaa63da0739d76e4e9ad4ed325c814f9b3851425582a3cd"},
{file = "mypy-1.6.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:925cd6a3b7b55dfba252b7c4561892311c5358c6b5a601847015a1ad4eb7d332"},
{file = "mypy-1.6.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:8f57e6b6927a49550da3d122f0cb983d400f843a8a82e65b3b380d3d7259468f"},
{file = "mypy-1.6.1-cp39-cp39-win_amd64.whl", hash = "sha256:a43ef1c8ddfdb9575691720b6352761f3f53d85f1b57d7745701041053deff30"},
{file = "mypy-1.6.1-py3-none-any.whl", hash = "sha256:4cbe68ef919c28ea561165206a2dcb68591c50f3bcf777932323bc208d949cf1"},
{file = "mypy-1.6.1.tar.gz", hash = "sha256:4d01c00d09a0be62a4ca3f933e315455bde83f37f892ba4b08ce92f3cf44bcc1"},
]
[package.dependencies]
mypy-extensions = ">=1.0.0"
typing-extensions = ">=4.1.0"
[package.extras]
dmypy = ["psutil (>=4.0)"]
install-types = ["pip"]
reports = ["lxml"]
[[package]] [[package]]
name = "mypy-extensions" name = "mypy-extensions"
version = "1.0.0" version = "1.0.0"
@ -1243,6 +1288,28 @@ files = [
{file = "tomlkit-0.12.1.tar.gz", hash = "sha256:38e1ff8edb991273ec9f6181244a6a391ac30e9f5098e7535640ea6be97a7c86"}, {file = "tomlkit-0.12.1.tar.gz", hash = "sha256:38e1ff8edb991273ec9f6181244a6a391ac30e9f5098e7535640ea6be97a7c86"},
] ]
[[package]]
name = "types-pyyaml"
version = "6.0.12.12"
description = "Typing stubs for PyYAML"
optional = false
python-versions = "*"
files = [
{file = "types-PyYAML-6.0.12.12.tar.gz", hash = "sha256:334373d392fde0fdf95af5c3f1661885fa10c52167b14593eb856289e1855062"},
{file = "types_PyYAML-6.0.12.12-py3-none-any.whl", hash = "sha256:c05bc6c158facb0676674b7f11fe3960db4f389718e19e62bd2b84d6205cfd24"},
]
[[package]]
name = "typing-extensions"
version = "4.8.0"
description = "Backported and Experimental Type Hints for Python 3.8+"
optional = false
python-versions = ">=3.8"
files = [
{file = "typing_extensions-4.8.0-py3-none-any.whl", hash = "sha256:8f92fc8806f9a6b641eaa5318da32b44d401efaac0f6678c9bc448ba3605faa0"},
{file = "typing_extensions-4.8.0.tar.gz", hash = "sha256:df8e4339e9cb77357558cbdbceca33c303714cf861d1eef15e1070055ae8b7ef"},
]
[[package]] [[package]]
name = "urllib3" name = "urllib3"
version = "2.0.7" version = "2.0.7"
@ -1309,4 +1376,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p
[metadata] [metadata]
lock-version = "2.0" lock-version = "2.0"
python-versions = "^3.12" python-versions = "^3.12"
content-hash = "82c05c7c22b990f45fae70d36d0bed6f1c414b30af9e40318dfc383ebab9d86a" content-hash = "980ed766f11ffdaf6694f7800174de76cd21ac2d7a5ffbe27d8df1d1673f6119"

View File

@ -6,6 +6,7 @@ authors = ["Pedro Ferreira <pedro@dete.st>"]
license = "MIT" license = "MIT"
readme = "README.rst" readme = "README.rst"
packages = [{include = "secret_handshake"}] packages = [{include = "secret_handshake"}]
include = ["secret_handshake/py.typed"]
[tool.poetry.dependencies] [tool.poetry.dependencies]
python = "^3.12" python = "^3.12"
@ -24,6 +25,8 @@ pre-commit = "^3.5.0"
commitizen = "^3.12.0" commitizen = "^3.12.0"
black = "^23.10.1" black = "^23.10.1"
pylint = "^3.0.2" pylint = "^3.0.2"
mypy = "^1.6.1"
types-pyyaml = "^6.0.12.12"
[tool.poetry.group.docs.dependencies] [tool.poetry.group.docs.dependencies]
sphinx = "^7.2.6" sphinx = "^7.2.6"

View File

@ -1,8 +1,8 @@
"""Box stream utilities""" """Box stream utilities"""
from asyncio import IncompleteReadError from asyncio import IncompleteReadError, StreamReader, StreamWriter
import struct import struct
from typing import Tuple from typing import Any, AsyncIterator, Optional, Tuple, TypedDict
from nacl.secret import SecretBox from nacl.secret import SecretBox
@ -13,36 +13,51 @@ MAX_SEGMENT_SIZE = 4 * 1024
TERMINATION_HEADER = b"\x00" * 18 TERMINATION_HEADER = b"\x00" * 18
def get_stream_pair(reader, writer, **kwargs) -> Tuple["UnboxStream", "BoxStream"]: class BoxStreamKeys(TypedDict):
"""Dictionary to hold all box stream keys"""
decrypt_key: bytes
decrypt_nonce: bytes
encrypt_key: bytes
encrypt_nonce: bytes
shared_secret: bytes
def get_stream_pair( # pylint: disable=too-many-arguments
reader: StreamReader, # pylint: disable=unused-argument
writer: StreamWriter,
*,
decrypt_key: bytes,
decrypt_nonce: bytes,
encrypt_key: bytes,
encrypt_nonce: bytes,
**kwargs: Any,
) -> Tuple["UnboxStream", "BoxStream"]:
"""Create a new duplex box stream""" """Create a new duplex box stream"""
box_args = { read_stream = UnboxStream(reader, key=decrypt_key, nonce=decrypt_nonce)
"key": kwargs["encrypt_key"], write_stream = BoxStream(writer, key=encrypt_key, nonce=encrypt_nonce)
"nonce": kwargs["encrypt_nonce"],
} return read_stream, write_stream
unbox_args = {
"key": kwargs["decrypt_key"],
"nonce": kwargs["decrypt_nonce"],
}
return UnboxStream(reader, **unbox_args), BoxStream(writer, **box_args)
class UnboxStream: class UnboxStream:
"""Unboxing stream""" """Unboxing stream"""
def __init__(self, reader, key, nonce): def __init__(self, reader: StreamReader, key: bytes, nonce: bytes):
self.reader = reader self.reader = reader
self.key = key self.key = key
self.nonce = nonce self.nonce = nonce
self.closed = False self.closed = False
async def read(self): async def read(self) -> Optional[bytes]:
"""Read data from the stream""" """Read data from the stream"""
try: try:
data = await self.reader.readexactly(HEADER_LENGTH) data = await self.reader.readexactly(HEADER_LENGTH)
except IncompleteReadError: except IncompleteReadError:
self.closed = True self.closed = True
return None return None
box = SecretBox(self.key) box = SecretBox(self.key)
@ -51,6 +66,7 @@ class UnboxStream:
if header == TERMINATION_HEADER: if header == TERMINATION_HEADER:
self.closed = True self.closed = True
return None return None
length = struct.unpack(">H", header[:2])[0] length = struct.unpack(">H", header[:2])[0]
@ -61,12 +77,13 @@ class UnboxStream:
body = box.decrypt(mac + data, inc_nonce(self.nonce)) body = box.decrypt(mac + data, inc_nonce(self.nonce))
self.nonce = inc_nonce(inc_nonce(self.nonce)) self.nonce = inc_nonce(inc_nonce(self.nonce))
return body return body
def __aiter__(self): def __aiter__(self) -> AsyncIterator[bytes]:
return self return self
async def __anext__(self): async def __anext__(self) -> bytes:
data = await self.read() data = await self.read()
if data is None: if data is None:
@ -78,17 +95,17 @@ class UnboxStream:
class BoxStream: class BoxStream:
"""Box stream""" """Box stream"""
def __init__(self, writer, key, nonce): def __init__(self, writer: StreamWriter, key: bytes, nonce: bytes):
self.writer = writer self.writer = writer
self.key = key self.key = key
self.box = SecretBox(self.key) self.box = SecretBox(self.key)
self.nonce = nonce self.nonce = nonce
def write(self, data): def write(self, data: bytes) -> None:
"""Write data to the box stream""" """Write data to the box stream"""
for chunk in split_chunks(data, MAX_SEGMENT_SIZE): for chunk in split_chunks(data, MAX_SEGMENT_SIZE):
body = self.box.encrypt(chunk, inc_nonce(self.nonce))[24:] body = self.box.encrypt(bytes(chunk), inc_nonce(self.nonce))[24:]
header = struct.pack(">H", len(body) - 16) + body[:16] header = struct.pack(">H", len(body) - 16) + body[:16]
hdrbox = self.box.encrypt(header, self.nonce)[24:] hdrbox = self.box.encrypt(header, self.nonce)[24:]
@ -97,7 +114,7 @@ class BoxStream:
self.nonce = inc_nonce(inc_nonce(self.nonce)) self.nonce = inc_nonce(inc_nonce(self.nonce))
self.writer.write(body[16:]) self.writer.write(body[16:])
def close(self): def close(self) -> None:
"""Close the box stream""" """Close the box stream"""
self.writer.write(self.box.encrypt(b"\x00" * 18, self.nonce)[24:]) self.writer.write(self.box.encrypt(b"\x00" * 18, self.nonce)[24:])

View File

@ -28,7 +28,9 @@ from typing import Optional
from nacl.bindings import crypto_box_afternm, crypto_box_open_afternm, crypto_scalarmult from nacl.bindings import crypto_box_afternm, crypto_box_open_afternm, crypto_scalarmult
from nacl.exceptions import CryptoError from nacl.exceptions import CryptoError
from nacl.public import PrivateKey from nacl.public import PrivateKey
from nacl.signing import VerifyKey from nacl.signing import SigningKey, VerifyKey
from .boxstream import BoxStreamKeys
APPLICATION_KEY = b64decode("1KHLiKZvAvjbY1ziZEHMXawbCEIM6qwjCDm3VYRan/s=") APPLICATION_KEY = b64decode("1KHLiKZvAvjbY1ziZEHMXawbCEIM6qwjCDm3VYRan/s=")
@ -40,29 +42,38 @@ class SHSError(Exception):
class SHSCryptoBase: # pylint: disable=too-many-instance-attributes class SHSCryptoBase: # pylint: disable=too-many-instance-attributes
"""Base functions for SHS cryptography""" """Base functions for SHS cryptography"""
def __init__(self, local_key, ephemeral_key=None, application_key=None): def __init__(
self.local_key = local_key self,
local_key: SigningKey,
ephemeral_key: Optional[PrivateKey] = None,
application_key: Optional[bytes] = None,
):
self.application_key = application_key or APPLICATION_KEY self.application_key = application_key or APPLICATION_KEY
self.shared_hash = None self.b_alice: Optional[bytes] = None
self.remote_ephemeral_key = None self.box_secret: Optional[bytes] = None
self.shared_secret = None self.hello: Optional[bytes] = None
self.remote_app_hmac = None self.local_key = local_key
self.remote_pub_key = None self.remote_app_hmac: Optional[bytes] = None
self.box_secret = None self.remote_ephemeral_key: Optional[bytes] = None
self.remote_pub_key: Optional[VerifyKey] = None
self.shared_hash: Optional[bytes] = None
self.shared_secret: Optional[bytes] = None
self._reset_keys(ephemeral_key or PrivateKey.generate()) self._reset_keys(ephemeral_key or PrivateKey.generate())
def _reset_keys(self, ephemeral_key): def _reset_keys(self, ephemeral_key: PrivateKey) -> None:
self.local_ephemeral_key = ephemeral_key self.local_ephemeral_key = ephemeral_key
self.local_app_hmac = hmac.new( self.local_app_hmac = hmac.new(
self.application_key, bytes(ephemeral_key.public_key), digestmod="sha512" self.application_key, bytes(ephemeral_key.public_key), digestmod="sha512"
).digest()[:32] ).digest()[:32]
def generate_challenge(self): def generate_challenge(self) -> bytes:
"""Generate and return a challenge to be sent to the server.""" """Generate and return a challenge to be sent to the server."""
return self.local_app_hmac + bytes(self.local_ephemeral_key.public_key) return self.local_app_hmac + bytes(self.local_ephemeral_key.public_key)
def verify_challenge(self, data): def verify_challenge(self, data: bytes) -> bool:
"""Verify the correctness of challenge sent from the client.""" """Verify the correctness of challenge sent from the client."""
assert len(data) == 64 assert len(data) == 64
sent_hmac, remote_ephemeral_key = data[:32], data[32:] sent_hmac, remote_ephemeral_key = data[:32], data[32:]
@ -76,9 +87,10 @@ class SHSCryptoBase: # pylint: disable=too-many-instance-attributes
self.remote_ephemeral_key = remote_ephemeral_key self.remote_ephemeral_key = remote_ephemeral_key
# this is hash(a * b) # this is hash(a * b)
self.shared_hash = hashlib.sha256(self.shared_secret).digest() self.shared_hash = hashlib.sha256(self.shared_secret).digest()
return ok return ok
def clean(self, new_ephemeral_key=None): def clean(self, new_ephemeral_key: Optional[PrivateKey] = None) -> None:
"""Clean internal data""" """Clean internal data"""
self._reset_keys(new_ephemeral_key or PrivateKey.generate()) self._reset_keys(new_ephemeral_key or PrivateKey.generate())
@ -86,10 +98,15 @@ class SHSCryptoBase: # pylint: disable=too-many-instance-attributes
self.shared_hash = None self.shared_hash = None
self.remote_ephemeral_key = None self.remote_ephemeral_key = None
def get_box_keys(self): def get_box_keys(self) -> BoxStreamKeys:
"""Get the box streams keys""" """Get the box streams keys"""
assert self.box_secret
assert self.remote_app_hmac
assert self.remote_pub_key
shared_secret = hashlib.sha256(self.box_secret).digest() shared_secret = hashlib.sha256(self.box_secret).digest()
return { return {
"shared_secret": shared_secret, "shared_secret": shared_secret,
"encrypt_key": hashlib.sha256(shared_secret + bytes(self.remote_pub_key)).digest(), "encrypt_key": hashlib.sha256(shared_secret + bytes(self.remote_pub_key)).digest(),
@ -102,17 +119,15 @@ class SHSCryptoBase: # pylint: disable=too-many-instance-attributes
class SHSServerCrypto(SHSCryptoBase): class SHSServerCrypto(SHSCryptoBase):
"""SHS server crypto algorithm""" """SHS server crypto algorithm"""
def __init__(self, *args, **kwargs): def verify_client_auth(self, data: bytes) -> bool:
super().__init__(*args, **kwargs)
self.b_alice = None
self.hello = None
self.box_secret = None
self.remote_pub_key = None
def verify_client_auth(self, data):
"""Verify client authentication data""" """Verify client authentication data"""
assert self.remote_ephemeral_key
assert self.shared_hash
assert self.shared_secret
assert len(data) == 112 assert len(data) == 112
a_bob = crypto_scalarmult(bytes(self.local_key.to_curve25519_private_key()), self.remote_ephemeral_key) a_bob = crypto_scalarmult(bytes(self.local_key.to_curve25519_private_key()), self.remote_ephemeral_key)
box_secret = hashlib.sha256(self.application_key + self.shared_secret + a_bob).digest() box_secret = hashlib.sha256(self.application_key + self.shared_secret + a_bob).digest()
self.hello = crypto_box_open_afternm(data, b"\x00" * 24, box_secret) self.hello = crypto_box_open_afternm(data, b"\x00" * 24, box_secret)
@ -127,15 +142,21 @@ class SHSServerCrypto(SHSCryptoBase):
bytes(self.local_ephemeral_key), bytes(self.remote_pub_key.to_curve25519_public_key()) bytes(self.local_ephemeral_key), bytes(self.remote_pub_key.to_curve25519_public_key())
) )
self.box_secret = hashlib.sha256(self.application_key + self.shared_secret + a_bob + b_alice).digest()[:32] self.box_secret = hashlib.sha256(self.application_key + self.shared_secret + a_bob + b_alice).digest()[:32]
return True return True
def generate_accept(self): def generate_accept(self) -> bytes:
"""Generate an accept message""" """Generate an accept message"""
assert self.box_secret
assert self.hello
assert self.shared_hash
okay = self.local_key.sign(self.application_key + self.hello + self.shared_hash).signature okay = self.local_key.sign(self.application_key + self.hello + self.shared_hash).signature
return crypto_box_afternm(okay, b"\x00" * 24, self.box_secret) return crypto_box_afternm(okay, b"\x00" * 24, self.box_secret)
def clean(self, new_ephemeral_key=None): def clean(self, new_ephemeral_key: Optional[PrivateKey] = None) -> None:
super().clean(new_ephemeral_key=new_ephemeral_key) super().clean(new_ephemeral_key=new_ephemeral_key)
self.hello = None self.hello = None
self.b_alice = None self.b_alice = None
@ -152,21 +173,26 @@ class SHSClientCrypto(SHSCryptoBase):
def __init__( def __init__(
self, self,
local_key: PrivateKey, local_key: SigningKey,
server_pub_key: bytes, server_pub_key: bytes,
ephemeral_key: PrivateKey, ephemeral_key: PrivateKey,
application_key: Optional[bytes] = None, application_key: Optional[bytes] = None,
): ):
super().__init__(local_key, ephemeral_key, application_key) super().__init__(local_key, ephemeral_key, application_key)
self.remote_pub_key = VerifyKey(server_pub_key)
self.b_alice = None
self.a_bob = None
self.hello = None
self.box_secret = None
def verify_server_challenge(self, data): self.a_bob: Optional[bytes] = None
self.remote_pub_key = VerifyKey(server_pub_key)
def verify_server_challenge(self, data: bytes) -> bool:
"""Verify the correctness of challenge sent from the server.""" """Verify the correctness of challenge sent from the server."""
assert self.remote_pub_key
assert super().verify_challenge(data) assert super().verify_challenge(data)
assert self.shared_hash
assert self.shared_secret
curve_pkey = self.remote_pub_key.to_curve25519_public_key() curve_pkey = self.remote_pub_key.to_curve25519_public_key()
# a_bob is (a * B) # a_bob is (a * B)
@ -179,17 +205,30 @@ class SHSClientCrypto(SHSCryptoBase):
signed_message = self.local_key.sign(self.application_key + bytes(self.remote_pub_key) + self.shared_hash) signed_message = self.local_key.sign(self.application_key + bytes(self.remote_pub_key) + self.shared_hash)
message_to_box = signed_message.signature + bytes(self.local_key.verify_key) message_to_box = signed_message.signature + bytes(self.local_key.verify_key)
self.hello = message_to_box self.hello = message_to_box
return True return True
def generate_client_auth(self): def generate_client_auth(self) -> bytes:
"""Generate box[K|a*b|a*B](H)""" """Generate box[K|a*b|a*B](H)"""
assert self.box_secret
assert self.hello
nonce = b"\x00" * 24 nonce = b"\x00" * 24
# return box(K | a * b | a * B)[H] # return box(K | a * b | a * B)[H]
return crypto_box_afternm(self.hello, nonce, self.box_secret) return crypto_box_afternm(self.hello, nonce, self.box_secret)
def verify_server_accept(self, data): def verify_server_accept(self, data: bytes) -> bool:
"""Verify that the server's accept message is sane""" """Verify that the server's accept message is sane"""
assert self.a_bob
assert self.hello
assert self.remote_ephemeral_key
assert self.remote_pub_key
assert self.shared_hash
assert self.shared_secret
curve_lkey = self.local_key.to_curve25519_private_key() curve_lkey = self.local_key.to_curve25519_private_key()
# b_alice is (A * b) # b_alice is (A * b)
b_alice = crypto_scalarmult(bytes(curve_lkey), self.remote_ephemeral_key) b_alice = crypto_scalarmult(bytes(curve_lkey), self.remote_ephemeral_key)
@ -208,9 +247,10 @@ class SHSClientCrypto(SHSCryptoBase):
# we should have received sign(B)[K | H | hash(a * b)] # we should have received sign(B)[K | H | hash(a * b)]
# let's see if that signature can verify the reconstructed data on our side # let's see if that signature can verify the reconstructed data on our side
self.remote_pub_key.verify(self.application_key + self.hello + self.shared_hash, signature) self.remote_pub_key.verify(self.application_key + self.hello + self.shared_hash, signature)
return True return True
def clean(self, new_ephemeral_key=None): def clean(self, new_ephemeral_key: Optional[PrivateKey] = None) -> None:
super().clean(new_ephemeral_key=new_ephemeral_key) super().clean(new_ephemeral_key=new_ephemeral_key)
self.a_bob = None self.a_bob = None
self.b_alice = None self.b_alice = None

View File

@ -20,10 +20,15 @@
"""Networking functionality""" """Networking functionality"""
import asyncio from asyncio import StreamReader, StreamWriter, ensure_future, open_connection, start_server
from typing import AsyncIterator, Awaitable, Callable, List, Optional
from .boxstream import get_stream_pair from nacl.public import PrivateKey
from .crypto import SHSClientCrypto, SHSServerCrypto from nacl.signing import SigningKey
from typing_extensions import Self
from .boxstream import BoxStream, UnboxStream, get_stream_pair
from .crypto import SHSClientCrypto, SHSCryptoBase, SHSServerCrypto
class SHSClientException(Exception): class SHSClientException(Exception):
@ -33,32 +38,37 @@ class SHSClientException(Exception):
class SHSDuplexStream: class SHSDuplexStream:
"""SHS duplex stream""" """SHS duplex stream"""
def __init__(self): def __init__(self) -> None:
self.write_stream = None self.write_stream: Optional[BoxStream] = None
self.read_stream = None self.read_stream: Optional[UnboxStream] = None
self.is_connected = False self.is_connected = False
def write(self, data): def write(self, data: bytes) -> None:
"""Write data to the write stream""" """Write data to the write stream"""
assert self.write_stream
self.write_stream.write(data) self.write_stream.write(data)
async def read(self): async def read(self) -> Optional[bytes]:
"""Read data from the read stream""" """Read data from the read stream"""
assert self.read_stream
return await self.read_stream.read() return await self.read_stream.read()
def close(self): def close(self) -> None:
"""Close the duplex stream""" """Close the duplex stream"""
self.write_stream.close() if self.write_stream:
self.read_stream.close() self.write_stream.close()
self.is_connected = False self.is_connected = False
def __aiter__(self): def __aiter__(self) -> AsyncIterator[bytes]:
return self return self
async def __anext__(self): async def __anext__(self) -> bytes:
msg = await self.read() msg = await self.read()
if msg is None: if msg is None:
@ -70,45 +80,53 @@ class SHSDuplexStream:
class SHSEndpoint: class SHSEndpoint:
"""SHS endpoint""" """SHS endpoint"""
def __init__(self): def __init__(self) -> None:
self._on_connect = None self._on_connect: Optional[Callable[[SHSDuplexStream], Awaitable[None]]] = None
self.crypto = None self.crypto: Optional[SHSCryptoBase] = None
def on_connect(self, cb): def on_connect(self, cb: Callable[[SHSDuplexStream], Awaitable[None]]) -> None:
"""Set the function to be called when a new connection arrives""" """Set the function to be called when a new connection arrives"""
self._on_connect = cb self._on_connect = cb
def disconnect(self): def disconnect(self) -> None:
"""Disconnect the endpoint""" """Disconnect the endpoint"""
raise NotImplementedError raise NotImplementedError
class SHSServer(SHSEndpoint): class SHSServer(SHSEndpoint):
"""SHS server""" """SHS server"""
def __init__(self, host, port, server_kp, application_key=None): def __init__(self, host: str, port: int, server_kp: SigningKey, application_key: Optional[bytes] = None):
super().__init__() super().__init__()
self.host = host self.host = host
self.port = port self.port = port
self.crypto = SHSServerCrypto(server_kp, application_key=application_key) self.crypto: SHSServerCrypto = SHSServerCrypto(server_kp, application_key=application_key)
self.connections = [] self.connections: List[SHSServerConnection] = []
async def _handshake(self, reader: StreamReader, writer: StreamWriter) -> None:
assert self.crypto
async def _handshake(self, reader, writer):
data = await reader.readexactly(64) data = await reader.readexactly(64)
if not self.crypto.verify_challenge(data): if not self.crypto.verify_challenge(data):
raise SHSClientException("Client challenge is not valid") raise SHSClientException("Client challenge is not valid")
writer.write(self.crypto.generate_challenge()) writer.write(self.crypto.generate_challenge())
data = await reader.readexactly(112) data = await reader.readexactly(112)
if not self.crypto.verify_client_auth(data): if not self.crypto.verify_client_auth(data):
raise SHSClientException("Client auth is not valid") raise SHSClientException("Client auth is not valid")
writer.write(self.crypto.generate_accept()) writer.write(self.crypto.generate_accept())
async def handle_connection(self, reader, writer): async def handle_connection(self, reader: StreamReader, writer: StreamWriter) -> None:
"""Handle incoming connections""" """Handle incoming connections"""
assert self.crypto
self.crypto.clean() self.crypto.clean()
await self._handshake(reader, writer) await self._handshake(reader, writer)
keys = self.crypto.get_box_keys() keys = self.crypto.get_box_keys()
@ -118,14 +136,14 @@ class SHSServer(SHSEndpoint):
self.connections.append(conn) self.connections.append(conn)
if self._on_connect: if self._on_connect:
asyncio.ensure_future(self._on_connect(conn)) ensure_future(self._on_connect(conn))
async def listen(self): async def listen(self) -> None:
"""Listen for connections""" """Listen for connections"""
await asyncio.start_server(self.handle_connection, self.host, self.port) await start_server(self.handle_connection, self.host, self.port)
def disconnect(self): def disconnect(self) -> None:
for connection in self.connections: for connection in self.connections:
connection.close() connection.close()
@ -133,52 +151,63 @@ class SHSServer(SHSEndpoint):
class SHSServerConnection(SHSDuplexStream): class SHSServerConnection(SHSDuplexStream):
"""SHS server connection""" """SHS server connection"""
def __init__(self, read_stream, write_stream): def __init__(self, read_stream: UnboxStream, write_stream: BoxStream):
super().__init__() super().__init__()
self.read_stream = read_stream self.read_stream = read_stream
self.write_stream = write_stream self.write_stream = write_stream
@classmethod @classmethod
def from_byte_streams(cls, reader, writer, **keys): def from_byte_streams(cls, reader: StreamReader, writer: StreamWriter, **keys: bytes) -> Self:
"""Create a server connection from an existing byte stream""" """Create a server connection from an existing byte stream"""
reader, writer = get_stream_pair(reader, writer, **keys) box_reader, box_writer = get_stream_pair(reader, writer, **keys)
return cls(reader, writer) return cls(box_reader, box_writer)
class SHSClient(SHSDuplexStream, SHSEndpoint): class SHSClient(SHSDuplexStream, SHSEndpoint):
"""SHS client""" """SHS client"""
def __init__( # pylint: disable=too-many-arguments def __init__( # pylint: disable=too-many-arguments
self, host, port, client_kp, server_pub_key, ephemeral_key=None, application_key=None self,
host: str,
port: int,
client_kp: SigningKey,
server_pub_key: bytes,
ephemeral_key: Optional[PrivateKey] = None,
application_key: Optional[bytes] = None,
): ):
SHSDuplexStream.__init__(self) SHSDuplexStream.__init__(self)
SHSEndpoint.__init__(self) SHSEndpoint.__init__(self)
self.host = host self.host = host
self.port = port self.port = port
self.writer = None self.writer: Optional[StreamWriter] = None
self.crypto = SHSClientCrypto( self.crypto: SHSClientCrypto = SHSClientCrypto(
client_kp, server_pub_key, ephemeral_key=ephemeral_key, application_key=application_key client_kp,
server_pub_key,
ephemeral_key=ephemeral_key or PrivateKey.generate(),
application_key=application_key,
) )
async def _handshake(self, reader, writer): async def _handshake(self, reader: StreamReader, writer: StreamWriter) -> None:
writer.write(self.crypto.generate_challenge()) writer.write(self.crypto.generate_challenge())
data = await reader.readexactly(64) data = await reader.readexactly(64)
if not self.crypto.verify_server_challenge(data): if not self.crypto.verify_server_challenge(data):
raise SHSClientException("Server challenge is not valid") raise SHSClientException("Server challenge is not valid")
writer.write(self.crypto.generate_client_auth()) writer.write(self.crypto.generate_client_auth())
data = await reader.readexactly(80) data = await reader.readexactly(80)
if not self.crypto.verify_server_accept(data): if not self.crypto.verify_server_accept(data):
raise SHSClientException("Server accept is not valid") raise SHSClientException("Server accept is not valid")
async def open(self): async def open(self) -> None:
"""Open the TCP connection""" """Open the TCP connection"""
reader, writer = await asyncio.open_connection(self.host, self.port) reader, writer = await open_connection(self.host, self.port)
await self._handshake(reader, writer) await self._handshake(reader, writer)
keys = self.crypto.get_box_keys() keys = self.crypto.get_box_keys()
@ -189,7 +218,7 @@ class SHSClient(SHSDuplexStream, SHSEndpoint):
self.is_connected = True self.is_connected = True
if self._on_connect: if self._on_connect:
await self._on_connect() await self._on_connect(self)
def disconnect(self): def disconnect(self) -> None:
self.close() self.close()

View File

View File

@ -21,12 +21,14 @@
"""Utility functions""" """Utility functions"""
import struct import struct
from typing import Generator, Sequence, TypeVar
NONCE_SIZE = 24 NONCE_SIZE = 24
MAX_NONCE = 8 * NONCE_SIZE MAX_NONCE = 8 * NONCE_SIZE
T = TypeVar("T")
def inc_nonce(nonce): def inc_nonce(nonce: bytes) -> bytes:
"""Increment nonce""" """Increment nonce"""
num = bytes_to_long(nonce) + 1 num = bytes_to_long(nonce) + 1
@ -40,34 +42,38 @@ def inc_nonce(nonce):
return bnum return bnum
def split_chunks(seq, n): def split_chunks(seq: Sequence[T], n: int) -> Generator[Sequence[T], None, None]:
"""Split sequence in equal-sized chunks. """Split sequence in equal-sized chunks.
The last chunk is not padded.""" The last chunk is not padded."""
while seq: while seq:
yield seq[:n] yield seq[:n]
seq = seq[n:] seq = seq[n:]
# Stolen from PyCypto (Public Domain) # Stolen from PyCypto (Public Domain)
def b(s): def b(s: str) -> bytes:
"""Shorthand for s.encode("latin-1")""" """Shorthand for s.encode("latin-1")"""
return s.encode("latin-1") # utf-8 would cause some side-effects we don't want return s.encode("latin-1") # utf-8 would cause some side-effects we don't want
def long_to_bytes(n, blocksize=0): def long_to_bytes(n: int, blocksize: int = 0) -> bytes:
"""long_to_bytes(n:long, blocksize:int) : string """Convert a long integer to a byte string.
Convert a long integer to a byte string.
If optional blocksize is given and greater than zero, pad the front of the If optional ``blocksize`` is given and greater than zero, pad the front of the byte string with binary zeros so
byte string with binary zeros so that the length is a multiple of that the length is a multiple of blocksize.
blocksize.
""" """
# after much testing, this algorithm was deemed to be the fastest # after much testing, this algorithm was deemed to be the fastest
s = b("") s = b("")
pack = struct.pack pack = struct.pack
while n > 0: while n > 0:
s = pack(">I", n & 0xFFFFFFFF) + s s = pack(">I", n & 0xFFFFFFFF) + s
n = n >> 32 n = n >> 32
# strip off leading zeros # strip off leading zeros
for i, c in enumerate(s): for i, c in enumerate(s):
if c != b("\000")[0]: if c != b("\000")[0]:
@ -76,26 +82,33 @@ def long_to_bytes(n, blocksize=0):
# only happens when n == 0 # only happens when n == 0
s = b("\000") s = b("\000")
i = 0 i = 0
s = s[i:] s = s[i:]
# add back some pad bytes. this could be done more efficiently w.r.t. the # add back some pad bytes. this could be done more efficiently w.r.t. the
# de-padding being done above, but sigh... # de-padding being done above, but sigh...
if blocksize > 0 and len(s) % blocksize: if blocksize > 0 and len(s) % blocksize:
s = (blocksize - len(s) % blocksize) * b("\000") + s s = (blocksize - len(s) % blocksize) * b("\000") + s
return s return s
def bytes_to_long(s): def bytes_to_long(s: bytes) -> int:
"""bytes_to_long(string) : long """Convert a byte string to a long integer.
Convert a byte string to a long integer.
This is (essentially) the inverse of long_to_bytes(). This is (essentially) the inverse of ``long_to_bytes()``.
""" """
acc = 0 acc = 0
unpack = struct.unpack unpack = struct.unpack
length = len(s) length = len(s)
if length % 4: if length % 4:
extra = 4 - length % 4 extra = 4 - length % 4
s = b("\000") * extra + s s = b("\000") * extra + s
length = length + extra length = length + extra
for i in range(0, length, 4): for i in range(0, length, 4):
acc = (acc << 32) + unpack(">I", s[i : i + 4])[0] acc = (acc << 32) + unpack(">I", s[i : i + 4])[0]
return acc return acc

View File

@ -20,27 +20,34 @@
"""Helper utilities for testing""" """Helper utilities for testing"""
from asyncio import StreamReader, StreamWriter
from io import BytesIO from io import BytesIO
from typing import AsyncIterable, List, Optional, TypeVar
T = TypeVar("T")
class AsyncBuffer(BytesIO): class AsyncBuffer(BytesIO, StreamReader, StreamWriter): # type: ignore[misc]
"""Just a BytesIO with an async read method.""" """Just a BytesIO with an async read method."""
async def read(self, n=None): # pylint: disable=invalid-overridden-method async def read( # type: ignore[override] # pylint: disable=invalid-overridden-method
self, n: Optional[int] = None
) -> Optional[bytes]:
v = super().read(n) v = super().read(n)
return v return v
readexactly = read readexactly = read # type: ignore[assignment]
def append(self, data): def append(self, data: bytes) -> None:
"""Append data to the buffer without changing the current position.""" """Append data to the buffer without changing the current position."""
pos = self.tell() pos = self.tell()
self.write(data) self.write(data)
self.seek(pos) self.seek(pos)
async def async_comprehend(generator): async def async_comprehend(generator: AsyncIterable[T]) -> List[T]:
"""Emulate ``[elem async for elem in generator]``.""" """Emulate ``[elem async for elem in generator]``."""
results = [] results = []

View File

@ -38,8 +38,9 @@ MESSAGE_CLOSED = b"\xb1\x14hU'\xb5M\xa6\"\x03\x9duy\xa1\xd4evW,\xdcE\x18\xe4+ C4
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_boxstream(): async def test_boxstream() -> None:
"""Test stream boxing""" """Test stream boxing"""
buffer = AsyncBuffer() buffer = AsyncBuffer()
box_stream = BoxStream(buffer, CLIENT_ENCRYPT_KEY, CLIENT_ENCRYPT_NONCE) box_stream = BoxStream(buffer, CLIENT_ENCRYPT_KEY, CLIENT_ENCRYPT_NONCE)
box_stream.write(b"foo") box_stream.write(b"foo")
@ -63,7 +64,7 @@ async def test_boxstream():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_unboxstream(): async def test_unboxstream() -> None:
"""Test stream unboxing""" """Test stream unboxing"""
buffer = AsyncBuffer(MESSAGE_1 + MESSAGE_2 + MESSAGE_3 + MESSAGE_CLOSED) buffer = AsyncBuffer(MESSAGE_1 + MESSAGE_2 + MESSAGE_3 + MESSAGE_CLOSED)
@ -76,7 +77,7 @@ async def test_unboxstream():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_long_packets(): async def test_long_packets() -> None:
"""Test for receiving long packets""" """Test for receiving long packets"""
data_size = 6 * 1024 data_size = 6 * 1024

View File

@ -36,7 +36,7 @@ CLIENT_EPH_KEY_SEED = b"u8\xd0\xe3\x85d_Pz\x0c\xf5\xfd\x15\xce2p#\xb0\xf0\x9f\xe
@pytest.fixture @pytest.fixture
def server(): def server() -> SHSServerCrypto:
"""A testing SHS server""" """A testing SHS server"""
server_key = SigningKey(SERVER_KEY_SEED) server_key = SigningKey(SERVER_KEY_SEED)
@ -46,7 +46,7 @@ def server():
@pytest.fixture @pytest.fixture
def client(): def client() -> SHSClientCrypto:
"""A testing SHS client""" """A testing SHS client"""
client_key = SigningKey(CLIENT_KEY_SEED) client_key = SigningKey(CLIENT_KEY_SEED)
@ -90,7 +90,7 @@ CLIENT_ENCRYPT_NONCE = b"S\\\x06\x8d\xe5\xeb&*\xb8\x0bp\xb3Z\x8e\\\x85\x14\xaa\x
CLIENT_DECRYPT_NONCE = b"d\xe8\xccD\xec\xb9E\xbb\xaa\xa7\x7f\xe38\x15\x16\xef\xca\xd22u\x1d\xfe<\xe7" CLIENT_DECRYPT_NONCE = b"d\xe8\xccD\xec\xb9E\xbb\xaa\xa7\x7f\xe38\x15\x16\xef\xca\xd22u\x1d\xfe<\xe7"
def test_handshake(client, server): # pylint: disable=redefined-outer-name def test_handshake(client: SHSClientCrypto, server: SHSServerCrypto) -> None: # pylint: disable=redefined-outer-name
"""Test the handshake procedure""" """Test the handshake procedure"""
client_challenge = client.generate_challenge() client_challenge = client.generate_challenge()

View File

@ -22,11 +22,14 @@
from asyncio import Event, wait_for from asyncio import Event, wait_for
import os import os
from typing import Any, Awaitable, Callable, Tuple
from nacl.signing import SigningKey from nacl.signing import SigningKey
import pytest import pytest
from pytest_mock import MockerFixture
from secret_handshake import SHSClient, SHSServer from secret_handshake import SHSClient, SHSServer
from secret_handshake.boxstream import BoxStreamKeys
from .helpers import AsyncBuffer from .helpers import AsyncBuffer
@ -34,41 +37,42 @@ from .helpers import AsyncBuffer
class DummyCrypto: class DummyCrypto:
"""Dummy crypto module, pretends everything is fine.""" """Dummy crypto module, pretends everything is fine."""
def verify_server_challenge(self, _): def verify_server_challenge(self, _: bytes) -> bool:
"""Verify the server challenge""" """Verify the server challenge"""
return True return True
def verify_challenge(self, _): def verify_challenge(self, _: bytes) -> bool:
"""Verify the challenge data""" """Verify the challenge data"""
return True return True
def verify_server_accept(self, _): def verify_server_accept(self, _: bytes) -> bool:
"""Verify servers accept message""" """Verify servers accept message"""
return True return True
def generate_challenge(self): def generate_challenge(self) -> bytes:
"""Generate authentication challenge""" """Generate authentication challenge"""
return b"CHALLENGE" return b"CHALLENGE"
def generate_client_auth(self): def generate_client_auth(self) -> bytes:
"""Generate client authentication data""" """Generate client authentication data"""
return b"AUTH" return b"AUTH"
def verify_client_auth(self, _): def verify_client_auth(self, _: bytes) -> bool:
"""Verify client authentication data""" """Verify client authentication data"""
return True return True
def generate_accept(self): def generate_accept(self) -> bytes:
"""Generate an ACCEPT message""" """Generate an ACCEPT message"""
return b"ACCEPT" return b"ACCEPT"
def get_box_keys(self): def get_box_keys(self) -> BoxStreamKeys:
"""Get box keys""" """Get box keys"""
return { return {
@ -76,48 +80,64 @@ class DummyCrypto:
"encrypt_nonce": b"x" * 32, "encrypt_nonce": b"x" * 32,
"decrypt_key": b"x" * 32, "decrypt_key": b"x" * 32,
"decrypt_nonce": b"x" * 32, "decrypt_nonce": b"x" * 32,
"shared_secret": b"x" * 32,
} }
def clean(self): def clean(self) -> None:
"""Clean up internal data""" """Clean up internal data"""
def _dummy_boxstream(stream, **_): def _dummy_boxstream(stream: AsyncBuffer, **_: Any) -> AsyncBuffer:
"""Identity boxstream, no tansformation.""" """Identity boxstream, no transformation."""
return stream return stream
def _client_stream_mocker(): def _client_stream_mocker() -> (
Tuple[AsyncBuffer, AsyncBuffer, Callable[[str, int], Awaitable[Tuple[AsyncBuffer, AsyncBuffer]]]]
):
reader = AsyncBuffer(b"xxx") reader = AsyncBuffer(b"xxx")
writer = AsyncBuffer(b"xxx") writer = AsyncBuffer(b"xxx")
async def _create_mock_streams(host, port): # pylint: disable=unused-argument async def _create_mock_streams(
host: str, port: int # pylint: disable=unused-argument
) -> Tuple[AsyncBuffer, AsyncBuffer]:
return reader, writer return reader, writer
return reader, writer, _create_mock_streams return reader, writer, _create_mock_streams
def _server_stream_mocker(): def _server_stream_mocker() -> (
Tuple[
AsyncBuffer,
AsyncBuffer,
Callable[[Callable[[AsyncBuffer, AsyncBuffer], Awaitable[None]], str, int], Awaitable[None]],
]
):
reader = AsyncBuffer(b"xxx") reader = AsyncBuffer(b"xxx")
writer = AsyncBuffer(b"xxx") writer = AsyncBuffer(b"xxx")
async def _create_mock_server(cb, host, port): # pylint: disable=unused-argument async def _create_mock_server(
cb: Callable[[AsyncBuffer, AsyncBuffer], Awaitable[None]],
host: str, # pylint: disable=unused-argument
port: int, # pylint: disable=unused-argument
) -> None:
await cb(reader, writer) await cb(reader, writer)
return reader, writer, _create_mock_server return reader, writer, _create_mock_server
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_client(mocker): async def test_client(mocker: MockerFixture) -> None:
"""Test the client""" """Test the client"""
reader, _, _create_mock_streams = _client_stream_mocker() reader, _, _create_mock_streams = _client_stream_mocker()
mocker.patch("asyncio.open_connection", new=_create_mock_streams) mocker.patch("secret_handshake.network.open_connection", new=_create_mock_streams)
mocker.patch("secret_handshake.boxstream.BoxStream", new=_dummy_boxstream) mocker.patch("secret_handshake.boxstream.BoxStream", new=_dummy_boxstream)
mocker.patch("secret_handshake.boxstream.UnboxStream", new=_dummy_boxstream) mocker.patch("secret_handshake.boxstream.UnboxStream", new=_dummy_boxstream)
client = SHSClient("shop.local", 1111, SigningKey.generate(), os.urandom(32)) client = SHSClient("shop.local", 1111, SigningKey.generate(), os.urandom(32))
client.crypto = DummyCrypto() client.crypto = DummyCrypto() # type: ignore[assignment]
await client.open() await client.open()
reader.append(b"TEST") reader.append(b"TEST")
@ -126,22 +146,22 @@ async def test_client(mocker):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_server(mocker): async def test_server(mocker: MockerFixture) -> None:
"""Test the server""" """Test the server"""
resolve = Event() resolve = Event()
async def _on_connect(_): async def _on_connect(_: Any) -> None:
server.disconnect() server.disconnect()
resolve.set() resolve.set()
_, _, _create_mock_server = _server_stream_mocker() _, _, _create_mock_server = _server_stream_mocker()
mocker.patch("asyncio.start_server", new=_create_mock_server) mocker.patch("secret_handshake.network.start_server", new=_create_mock_server)
mocker.patch("secret_handshake.boxstream.BoxStream", new=_dummy_boxstream) mocker.patch("secret_handshake.boxstream.BoxStream", new=_dummy_boxstream)
mocker.patch("secret_handshake.boxstream.UnboxStream", new=_dummy_boxstream) mocker.patch("secret_handshake.boxstream.UnboxStream", new=_dummy_boxstream)
server = SHSServer("shop.local", 1111, SigningKey.generate(), os.urandom(32)) server = SHSServer("shop.local", 1111, SigningKey.generate(), os.urandom(32))
server.crypto = DummyCrypto() server.crypto = DummyCrypto() # type: ignore[assignment]
server.on_connect(_on_connect) server.on_connect(_on_connect)