ci: Add mypy as a dev dependency and configure it

This commit is contained in:
Gergely Polonkai 2023-10-30 05:47:28 +01:00
parent 95039914ba
commit 4996931b54
No known key found for this signature in database
GPG Key ID: 2D2885533B869ED4
14 changed files with 350 additions and 145 deletions

View File

@ -36,3 +36,10 @@ repos:
language: system
require_serial: true
types_or: [python, pyi]
- id: mypy
name: mypy
entry: poetry run mypy
args: ["--strict"]
language: system
types_or: [python, pyi]
require_serial: true

View File

@ -13,7 +13,7 @@ with open(os.path.expanduser("~/.ssb/secret"), encoding="utf-8") as f:
config = yaml.safe_load(f)
async def main():
async def main() -> None:
"""Main function to run"""
server_pub_key = b64decode(config["public"][:-8])

View File

@ -8,17 +8,18 @@ from nacl.signing import SigningKey
import yaml
from secret_handshake import SHSServer
from secret_handshake.network import SHSDuplexStream
with open(os.path.expanduser("~/.ssb/secret"), encoding="utf-8") as f:
config = yaml.safe_load(f)
async def _on_connect(conn):
async def _on_connect(conn: SHSDuplexStream) -> None:
async for msg in conn:
print(msg)
async def main():
async def main() -> None:
"""Main function to run"""
server_keypair = SigningKey(b64decode(config["private"][:-8])[:32])

69
poetry.lock generated
View File

@ -674,6 +674,51 @@ files = [
{file = "mccabe-0.7.0.tar.gz", hash = "sha256:348e0240c33b60bbdf4e523192ef919f28cb2c3d7d5c7794f74009290f236325"},
]
[[package]]
name = "mypy"
version = "1.6.1"
description = "Optional static typing for Python"
optional = false
python-versions = ">=3.8"
files = [
{file = "mypy-1.6.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e5012e5cc2ac628177eaac0e83d622b2dd499e28253d4107a08ecc59ede3fc2c"},
{file = "mypy-1.6.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:d8fbb68711905f8912e5af474ca8b78d077447d8f3918997fecbf26943ff3cbb"},
{file = "mypy-1.6.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:21a1ad938fee7d2d96ca666c77b7c494c3c5bd88dff792220e1afbebb2925b5e"},
{file = "mypy-1.6.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:b96ae2c1279d1065413965c607712006205a9ac541895004a1e0d4f281f2ff9f"},
{file = "mypy-1.6.1-cp310-cp310-win_amd64.whl", hash = "sha256:40b1844d2e8b232ed92e50a4bd11c48d2daa351f9deee6c194b83bf03e418b0c"},
{file = "mypy-1.6.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:81af8adaa5e3099469e7623436881eff6b3b06db5ef75e6f5b6d4871263547e5"},
{file = "mypy-1.6.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:8c223fa57cb154c7eab5156856c231c3f5eace1e0bed9b32a24696b7ba3c3245"},
{file = "mypy-1.6.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a8032e00ce71c3ceb93eeba63963b864bf635a18f6c0c12da6c13c450eedb183"},
{file = "mypy-1.6.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:4c46b51de523817a0045b150ed11b56f9fff55f12b9edd0f3ed35b15a2809de0"},
{file = "mypy-1.6.1-cp311-cp311-win_amd64.whl", hash = "sha256:19f905bcfd9e167159b3d63ecd8cb5e696151c3e59a1742e79bc3bcb540c42c7"},
{file = "mypy-1.6.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:82e469518d3e9a321912955cc702d418773a2fd1e91c651280a1bda10622f02f"},
{file = "mypy-1.6.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:d4473c22cc296425bbbce7e9429588e76e05bc7342da359d6520b6427bf76660"},
{file = "mypy-1.6.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:59a0d7d24dfb26729e0a068639a6ce3500e31d6655df8557156c51c1cb874ce7"},
{file = "mypy-1.6.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:cfd13d47b29ed3bbaafaff7d8b21e90d827631afda134836962011acb5904b71"},
{file = "mypy-1.6.1-cp312-cp312-win_amd64.whl", hash = "sha256:eb4f18589d196a4cbe5290b435d135dee96567e07c2b2d43b5c4621b6501531a"},
{file = "mypy-1.6.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:41697773aa0bf53ff917aa077e2cde7aa50254f28750f9b88884acea38a16169"},
{file = "mypy-1.6.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:7274b0c57737bd3476d2229c6389b2ec9eefeb090bbaf77777e9d6b1b5a9d143"},
{file = "mypy-1.6.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bbaf4662e498c8c2e352da5f5bca5ab29d378895fa2d980630656178bd607c46"},
{file = "mypy-1.6.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:bb8ccb4724f7d8601938571bf3f24da0da791fe2db7be3d9e79849cb64e0ae85"},
{file = "mypy-1.6.1-cp38-cp38-win_amd64.whl", hash = "sha256:68351911e85145f582b5aa6cd9ad666c8958bcae897a1bfda8f4940472463c45"},
{file = "mypy-1.6.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:49ae115da099dcc0922a7a895c1eec82c1518109ea5c162ed50e3b3594c71208"},
{file = "mypy-1.6.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:8b27958f8c76bed8edaa63da0739d76e4e9ad4ed325c814f9b3851425582a3cd"},
{file = "mypy-1.6.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:925cd6a3b7b55dfba252b7c4561892311c5358c6b5a601847015a1ad4eb7d332"},
{file = "mypy-1.6.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:8f57e6b6927a49550da3d122f0cb983d400f843a8a82e65b3b380d3d7259468f"},
{file = "mypy-1.6.1-cp39-cp39-win_amd64.whl", hash = "sha256:a43ef1c8ddfdb9575691720b6352761f3f53d85f1b57d7745701041053deff30"},
{file = "mypy-1.6.1-py3-none-any.whl", hash = "sha256:4cbe68ef919c28ea561165206a2dcb68591c50f3bcf777932323bc208d949cf1"},
{file = "mypy-1.6.1.tar.gz", hash = "sha256:4d01c00d09a0be62a4ca3f933e315455bde83f37f892ba4b08ce92f3cf44bcc1"},
]
[package.dependencies]
mypy-extensions = ">=1.0.0"
typing-extensions = ">=4.1.0"
[package.extras]
dmypy = ["psutil (>=4.0)"]
install-types = ["pip"]
reports = ["lxml"]
[[package]]
name = "mypy-extensions"
version = "1.0.0"
@ -1243,6 +1288,28 @@ files = [
{file = "tomlkit-0.12.1.tar.gz", hash = "sha256:38e1ff8edb991273ec9f6181244a6a391ac30e9f5098e7535640ea6be97a7c86"},
]
[[package]]
name = "types-pyyaml"
version = "6.0.12.12"
description = "Typing stubs for PyYAML"
optional = false
python-versions = "*"
files = [
{file = "types-PyYAML-6.0.12.12.tar.gz", hash = "sha256:334373d392fde0fdf95af5c3f1661885fa10c52167b14593eb856289e1855062"},
{file = "types_PyYAML-6.0.12.12-py3-none-any.whl", hash = "sha256:c05bc6c158facb0676674b7f11fe3960db4f389718e19e62bd2b84d6205cfd24"},
]
[[package]]
name = "typing-extensions"
version = "4.8.0"
description = "Backported and Experimental Type Hints for Python 3.8+"
optional = false
python-versions = ">=3.8"
files = [
{file = "typing_extensions-4.8.0-py3-none-any.whl", hash = "sha256:8f92fc8806f9a6b641eaa5318da32b44d401efaac0f6678c9bc448ba3605faa0"},
{file = "typing_extensions-4.8.0.tar.gz", hash = "sha256:df8e4339e9cb77357558cbdbceca33c303714cf861d1eef15e1070055ae8b7ef"},
]
[[package]]
name = "urllib3"
version = "2.0.7"
@ -1309,4 +1376,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p
[metadata]
lock-version = "2.0"
python-versions = "^3.12"
content-hash = "82c05c7c22b990f45fae70d36d0bed6f1c414b30af9e40318dfc383ebab9d86a"
content-hash = "980ed766f11ffdaf6694f7800174de76cd21ac2d7a5ffbe27d8df1d1673f6119"

View File

@ -6,6 +6,7 @@ authors = ["Pedro Ferreira <pedro@dete.st>"]
license = "MIT"
readme = "README.rst"
packages = [{include = "secret_handshake"}]
include = ["secret_handshake/py.typed"]
[tool.poetry.dependencies]
python = "^3.12"
@ -24,6 +25,8 @@ pre-commit = "^3.5.0"
commitizen = "^3.12.0"
black = "^23.10.1"
pylint = "^3.0.2"
mypy = "^1.6.1"
types-pyyaml = "^6.0.12.12"
[tool.poetry.group.docs.dependencies]
sphinx = "^7.2.6"

View File

@ -1,8 +1,8 @@
"""Box stream utilities"""
from asyncio import IncompleteReadError
from asyncio import IncompleteReadError, StreamReader, StreamWriter
import struct
from typing import Tuple
from typing import Any, AsyncIterator, Optional, Tuple, TypedDict
from nacl.secret import SecretBox
@ -13,36 +13,51 @@ MAX_SEGMENT_SIZE = 4 * 1024
TERMINATION_HEADER = b"\x00" * 18
def get_stream_pair(reader, writer, **kwargs) -> Tuple["UnboxStream", "BoxStream"]:
class BoxStreamKeys(TypedDict):
"""Dictionary to hold all box stream keys"""
decrypt_key: bytes
decrypt_nonce: bytes
encrypt_key: bytes
encrypt_nonce: bytes
shared_secret: bytes
def get_stream_pair( # pylint: disable=too-many-arguments
reader: StreamReader, # pylint: disable=unused-argument
writer: StreamWriter,
*,
decrypt_key: bytes,
decrypt_nonce: bytes,
encrypt_key: bytes,
encrypt_nonce: bytes,
**kwargs: Any,
) -> Tuple["UnboxStream", "BoxStream"]:
"""Create a new duplex box stream"""
box_args = {
"key": kwargs["encrypt_key"],
"nonce": kwargs["encrypt_nonce"],
}
unbox_args = {
"key": kwargs["decrypt_key"],
"nonce": kwargs["decrypt_nonce"],
}
return UnboxStream(reader, **unbox_args), BoxStream(writer, **box_args)
read_stream = UnboxStream(reader, key=decrypt_key, nonce=decrypt_nonce)
write_stream = BoxStream(writer, key=encrypt_key, nonce=encrypt_nonce)
return read_stream, write_stream
class UnboxStream:
"""Unboxing stream"""
def __init__(self, reader, key, nonce):
def __init__(self, reader: StreamReader, key: bytes, nonce: bytes):
self.reader = reader
self.key = key
self.nonce = nonce
self.closed = False
async def read(self):
async def read(self) -> Optional[bytes]:
"""Read data from the stream"""
try:
data = await self.reader.readexactly(HEADER_LENGTH)
except IncompleteReadError:
self.closed = True
return None
box = SecretBox(self.key)
@ -51,6 +66,7 @@ class UnboxStream:
if header == TERMINATION_HEADER:
self.closed = True
return None
length = struct.unpack(">H", header[:2])[0]
@ -61,12 +77,13 @@ class UnboxStream:
body = box.decrypt(mac + data, inc_nonce(self.nonce))
self.nonce = inc_nonce(inc_nonce(self.nonce))
return body
def __aiter__(self):
def __aiter__(self) -> AsyncIterator[bytes]:
return self
async def __anext__(self):
async def __anext__(self) -> bytes:
data = await self.read()
if data is None:
@ -78,17 +95,17 @@ class UnboxStream:
class BoxStream:
"""Box stream"""
def __init__(self, writer, key, nonce):
def __init__(self, writer: StreamWriter, key: bytes, nonce: bytes):
self.writer = writer
self.key = key
self.box = SecretBox(self.key)
self.nonce = nonce
def write(self, data):
def write(self, data: bytes) -> None:
"""Write data to the box stream"""
for chunk in split_chunks(data, MAX_SEGMENT_SIZE):
body = self.box.encrypt(chunk, inc_nonce(self.nonce))[24:]
body = self.box.encrypt(bytes(chunk), inc_nonce(self.nonce))[24:]
header = struct.pack(">H", len(body) - 16) + body[:16]
hdrbox = self.box.encrypt(header, self.nonce)[24:]
@ -97,7 +114,7 @@ class BoxStream:
self.nonce = inc_nonce(inc_nonce(self.nonce))
self.writer.write(body[16:])
def close(self):
def close(self) -> None:
"""Close the box stream"""
self.writer.write(self.box.encrypt(b"\x00" * 18, self.nonce)[24:])

View File

@ -28,7 +28,9 @@ from typing import Optional
from nacl.bindings import crypto_box_afternm, crypto_box_open_afternm, crypto_scalarmult
from nacl.exceptions import CryptoError
from nacl.public import PrivateKey
from nacl.signing import VerifyKey
from nacl.signing import SigningKey, VerifyKey
from .boxstream import BoxStreamKeys
APPLICATION_KEY = b64decode("1KHLiKZvAvjbY1ziZEHMXawbCEIM6qwjCDm3VYRan/s=")
@ -40,29 +42,38 @@ class SHSError(Exception):
class SHSCryptoBase: # pylint: disable=too-many-instance-attributes
"""Base functions for SHS cryptography"""
def __init__(self, local_key, ephemeral_key=None, application_key=None):
self.local_key = local_key
def __init__(
self,
local_key: SigningKey,
ephemeral_key: Optional[PrivateKey] = None,
application_key: Optional[bytes] = None,
):
self.application_key = application_key or APPLICATION_KEY
self.shared_hash = None
self.remote_ephemeral_key = None
self.shared_secret = None
self.remote_app_hmac = None
self.remote_pub_key = None
self.box_secret = None
self.b_alice: Optional[bytes] = None
self.box_secret: Optional[bytes] = None
self.hello: Optional[bytes] = None
self.local_key = local_key
self.remote_app_hmac: Optional[bytes] = None
self.remote_ephemeral_key: Optional[bytes] = None
self.remote_pub_key: Optional[VerifyKey] = None
self.shared_hash: Optional[bytes] = None
self.shared_secret: Optional[bytes] = None
self._reset_keys(ephemeral_key or PrivateKey.generate())
def _reset_keys(self, ephemeral_key):
def _reset_keys(self, ephemeral_key: PrivateKey) -> None:
self.local_ephemeral_key = ephemeral_key
self.local_app_hmac = hmac.new(
self.application_key, bytes(ephemeral_key.public_key), digestmod="sha512"
).digest()[:32]
def generate_challenge(self):
def generate_challenge(self) -> bytes:
"""Generate and return a challenge to be sent to the server."""
return self.local_app_hmac + bytes(self.local_ephemeral_key.public_key)
def verify_challenge(self, data):
def verify_challenge(self, data: bytes) -> bool:
"""Verify the correctness of challenge sent from the client."""
assert len(data) == 64
sent_hmac, remote_ephemeral_key = data[:32], data[32:]
@ -76,9 +87,10 @@ class SHSCryptoBase: # pylint: disable=too-many-instance-attributes
self.remote_ephemeral_key = remote_ephemeral_key
# this is hash(a * b)
self.shared_hash = hashlib.sha256(self.shared_secret).digest()
return ok
def clean(self, new_ephemeral_key=None):
def clean(self, new_ephemeral_key: Optional[PrivateKey] = None) -> None:
"""Clean internal data"""
self._reset_keys(new_ephemeral_key or PrivateKey.generate())
@ -86,10 +98,15 @@ class SHSCryptoBase: # pylint: disable=too-many-instance-attributes
self.shared_hash = None
self.remote_ephemeral_key = None
def get_box_keys(self):
def get_box_keys(self) -> BoxStreamKeys:
"""Get the box streams keys"""
assert self.box_secret
assert self.remote_app_hmac
assert self.remote_pub_key
shared_secret = hashlib.sha256(self.box_secret).digest()
return {
"shared_secret": shared_secret,
"encrypt_key": hashlib.sha256(shared_secret + bytes(self.remote_pub_key)).digest(),
@ -102,17 +119,15 @@ class SHSCryptoBase: # pylint: disable=too-many-instance-attributes
class SHSServerCrypto(SHSCryptoBase):
"""SHS server crypto algorithm"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.b_alice = None
self.hello = None
self.box_secret = None
self.remote_pub_key = None
def verify_client_auth(self, data):
def verify_client_auth(self, data: bytes) -> bool:
"""Verify client authentication data"""
assert self.remote_ephemeral_key
assert self.shared_hash
assert self.shared_secret
assert len(data) == 112
a_bob = crypto_scalarmult(bytes(self.local_key.to_curve25519_private_key()), self.remote_ephemeral_key)
box_secret = hashlib.sha256(self.application_key + self.shared_secret + a_bob).digest()
self.hello = crypto_box_open_afternm(data, b"\x00" * 24, box_secret)
@ -127,15 +142,21 @@ class SHSServerCrypto(SHSCryptoBase):
bytes(self.local_ephemeral_key), bytes(self.remote_pub_key.to_curve25519_public_key())
)
self.box_secret = hashlib.sha256(self.application_key + self.shared_secret + a_bob + b_alice).digest()[:32]
return True
def generate_accept(self):
def generate_accept(self) -> bytes:
"""Generate an accept message"""
assert self.box_secret
assert self.hello
assert self.shared_hash
okay = self.local_key.sign(self.application_key + self.hello + self.shared_hash).signature
return crypto_box_afternm(okay, b"\x00" * 24, self.box_secret)
def clean(self, new_ephemeral_key=None):
def clean(self, new_ephemeral_key: Optional[PrivateKey] = None) -> None:
super().clean(new_ephemeral_key=new_ephemeral_key)
self.hello = None
self.b_alice = None
@ -152,21 +173,26 @@ class SHSClientCrypto(SHSCryptoBase):
def __init__(
self,
local_key: PrivateKey,
local_key: SigningKey,
server_pub_key: bytes,
ephemeral_key: PrivateKey,
application_key: Optional[bytes] = None,
):
super().__init__(local_key, ephemeral_key, application_key)
self.remote_pub_key = VerifyKey(server_pub_key)
self.b_alice = None
self.a_bob = None
self.hello = None
self.box_secret = None
def verify_server_challenge(self, data):
self.a_bob: Optional[bytes] = None
self.remote_pub_key = VerifyKey(server_pub_key)
def verify_server_challenge(self, data: bytes) -> bool:
"""Verify the correctness of challenge sent from the server."""
assert self.remote_pub_key
assert super().verify_challenge(data)
assert self.shared_hash
assert self.shared_secret
curve_pkey = self.remote_pub_key.to_curve25519_public_key()
# a_bob is (a * B)
@ -179,17 +205,30 @@ class SHSClientCrypto(SHSCryptoBase):
signed_message = self.local_key.sign(self.application_key + bytes(self.remote_pub_key) + self.shared_hash)
message_to_box = signed_message.signature + bytes(self.local_key.verify_key)
self.hello = message_to_box
return True
def generate_client_auth(self):
def generate_client_auth(self) -> bytes:
"""Generate box[K|a*b|a*B](H)"""
assert self.box_secret
assert self.hello
nonce = b"\x00" * 24
# return box(K | a * b | a * B)[H]
return crypto_box_afternm(self.hello, nonce, self.box_secret)
def verify_server_accept(self, data):
def verify_server_accept(self, data: bytes) -> bool:
"""Verify that the server's accept message is sane"""
assert self.a_bob
assert self.hello
assert self.remote_ephemeral_key
assert self.remote_pub_key
assert self.shared_hash
assert self.shared_secret
curve_lkey = self.local_key.to_curve25519_private_key()
# b_alice is (A * b)
b_alice = crypto_scalarmult(bytes(curve_lkey), self.remote_ephemeral_key)
@ -208,9 +247,10 @@ class SHSClientCrypto(SHSCryptoBase):
# we should have received sign(B)[K | H | hash(a * b)]
# let's see if that signature can verify the reconstructed data on our side
self.remote_pub_key.verify(self.application_key + self.hello + self.shared_hash, signature)
return True
def clean(self, new_ephemeral_key=None):
def clean(self, new_ephemeral_key: Optional[PrivateKey] = None) -> None:
super().clean(new_ephemeral_key=new_ephemeral_key)
self.a_bob = None
self.b_alice = None

View File

@ -20,10 +20,15 @@
"""Networking functionality"""
import asyncio
from asyncio import StreamReader, StreamWriter, ensure_future, open_connection, start_server
from typing import AsyncIterator, Awaitable, Callable, List, Optional
from .boxstream import get_stream_pair
from .crypto import SHSClientCrypto, SHSServerCrypto
from nacl.public import PrivateKey
from nacl.signing import SigningKey
from typing_extensions import Self
from .boxstream import BoxStream, UnboxStream, get_stream_pair
from .crypto import SHSClientCrypto, SHSCryptoBase, SHSServerCrypto
class SHSClientException(Exception):
@ -33,32 +38,37 @@ class SHSClientException(Exception):
class SHSDuplexStream:
"""SHS duplex stream"""
def __init__(self):
self.write_stream = None
self.read_stream = None
def __init__(self) -> None:
self.write_stream: Optional[BoxStream] = None
self.read_stream: Optional[UnboxStream] = None
self.is_connected = False
def write(self, data):
def write(self, data: bytes) -> None:
"""Write data to the write stream"""
assert self.write_stream
self.write_stream.write(data)
async def read(self):
async def read(self) -> Optional[bytes]:
"""Read data from the read stream"""
assert self.read_stream
return await self.read_stream.read()
def close(self):
def close(self) -> None:
"""Close the duplex stream"""
self.write_stream.close()
self.read_stream.close()
if self.write_stream:
self.write_stream.close()
self.is_connected = False
def __aiter__(self):
def __aiter__(self) -> AsyncIterator[bytes]:
return self
async def __anext__(self):
async def __anext__(self) -> bytes:
msg = await self.read()
if msg is None:
@ -70,45 +80,53 @@ class SHSDuplexStream:
class SHSEndpoint:
"""SHS endpoint"""
def __init__(self):
self._on_connect = None
self.crypto = None
def __init__(self) -> None:
self._on_connect: Optional[Callable[[SHSDuplexStream], Awaitable[None]]] = None
self.crypto: Optional[SHSCryptoBase] = None
def on_connect(self, cb):
def on_connect(self, cb: Callable[[SHSDuplexStream], Awaitable[None]]) -> None:
"""Set the function to be called when a new connection arrives"""
self._on_connect = cb
def disconnect(self):
def disconnect(self) -> None:
"""Disconnect the endpoint"""
raise NotImplementedError
class SHSServer(SHSEndpoint):
"""SHS server"""
def __init__(self, host, port, server_kp, application_key=None):
def __init__(self, host: str, port: int, server_kp: SigningKey, application_key: Optional[bytes] = None):
super().__init__()
self.host = host
self.port = port
self.crypto = SHSServerCrypto(server_kp, application_key=application_key)
self.connections = []
self.crypto: SHSServerCrypto = SHSServerCrypto(server_kp, application_key=application_key)
self.connections: List[SHSServerConnection] = []
async def _handshake(self, reader: StreamReader, writer: StreamWriter) -> None:
assert self.crypto
async def _handshake(self, reader, writer):
data = await reader.readexactly(64)
if not self.crypto.verify_challenge(data):
raise SHSClientException("Client challenge is not valid")
writer.write(self.crypto.generate_challenge())
data = await reader.readexactly(112)
if not self.crypto.verify_client_auth(data):
raise SHSClientException("Client auth is not valid")
writer.write(self.crypto.generate_accept())
async def handle_connection(self, reader, writer):
async def handle_connection(self, reader: StreamReader, writer: StreamWriter) -> None:
"""Handle incoming connections"""
assert self.crypto
self.crypto.clean()
await self._handshake(reader, writer)
keys = self.crypto.get_box_keys()
@ -118,14 +136,14 @@ class SHSServer(SHSEndpoint):
self.connections.append(conn)
if self._on_connect:
asyncio.ensure_future(self._on_connect(conn))
ensure_future(self._on_connect(conn))
async def listen(self):
async def listen(self) -> None:
"""Listen for connections"""
await asyncio.start_server(self.handle_connection, self.host, self.port)
await start_server(self.handle_connection, self.host, self.port)
def disconnect(self):
def disconnect(self) -> None:
for connection in self.connections:
connection.close()
@ -133,52 +151,63 @@ class SHSServer(SHSEndpoint):
class SHSServerConnection(SHSDuplexStream):
"""SHS server connection"""
def __init__(self, read_stream, write_stream):
def __init__(self, read_stream: UnboxStream, write_stream: BoxStream):
super().__init__()
self.read_stream = read_stream
self.write_stream = write_stream
@classmethod
def from_byte_streams(cls, reader, writer, **keys):
def from_byte_streams(cls, reader: StreamReader, writer: StreamWriter, **keys: bytes) -> Self:
"""Create a server connection from an existing byte stream"""
reader, writer = get_stream_pair(reader, writer, **keys)
box_reader, box_writer = get_stream_pair(reader, writer, **keys)
return cls(reader, writer)
return cls(box_reader, box_writer)
class SHSClient(SHSDuplexStream, SHSEndpoint):
"""SHS client"""
def __init__( # pylint: disable=too-many-arguments
self, host, port, client_kp, server_pub_key, ephemeral_key=None, application_key=None
self,
host: str,
port: int,
client_kp: SigningKey,
server_pub_key: bytes,
ephemeral_key: Optional[PrivateKey] = None,
application_key: Optional[bytes] = None,
):
SHSDuplexStream.__init__(self)
SHSEndpoint.__init__(self)
self.host = host
self.port = port
self.writer = None
self.crypto = SHSClientCrypto(
client_kp, server_pub_key, ephemeral_key=ephemeral_key, application_key=application_key
self.writer: Optional[StreamWriter] = None
self.crypto: SHSClientCrypto = SHSClientCrypto(
client_kp,
server_pub_key,
ephemeral_key=ephemeral_key or PrivateKey.generate(),
application_key=application_key,
)
async def _handshake(self, reader, writer):
async def _handshake(self, reader: StreamReader, writer: StreamWriter) -> None:
writer.write(self.crypto.generate_challenge())
data = await reader.readexactly(64)
if not self.crypto.verify_server_challenge(data):
raise SHSClientException("Server challenge is not valid")
writer.write(self.crypto.generate_client_auth())
data = await reader.readexactly(80)
if not self.crypto.verify_server_accept(data):
raise SHSClientException("Server accept is not valid")
async def open(self):
async def open(self) -> None:
"""Open the TCP connection"""
reader, writer = await asyncio.open_connection(self.host, self.port)
reader, writer = await open_connection(self.host, self.port)
await self._handshake(reader, writer)
keys = self.crypto.get_box_keys()
@ -189,7 +218,7 @@ class SHSClient(SHSDuplexStream, SHSEndpoint):
self.is_connected = True
if self._on_connect:
await self._on_connect()
await self._on_connect(self)
def disconnect(self):
def disconnect(self) -> None:
self.close()

View File

View File

@ -21,12 +21,14 @@
"""Utility functions"""
import struct
from typing import Generator, Sequence, TypeVar
NONCE_SIZE = 24
MAX_NONCE = 8 * NONCE_SIZE
T = TypeVar("T")
def inc_nonce(nonce):
def inc_nonce(nonce: bytes) -> bytes:
"""Increment nonce"""
num = bytes_to_long(nonce) + 1
@ -40,34 +42,38 @@ def inc_nonce(nonce):
return bnum
def split_chunks(seq, n):
def split_chunks(seq: Sequence[T], n: int) -> Generator[Sequence[T], None, None]:
"""Split sequence in equal-sized chunks.
The last chunk is not padded."""
while seq:
yield seq[:n]
seq = seq[n:]
# Stolen from PyCypto (Public Domain)
def b(s):
def b(s: str) -> bytes:
"""Shorthand for s.encode("latin-1")"""
return s.encode("latin-1") # utf-8 would cause some side-effects we don't want
def long_to_bytes(n, blocksize=0):
"""long_to_bytes(n:long, blocksize:int) : string
Convert a long integer to a byte string.
If optional blocksize is given and greater than zero, pad the front of the
byte string with binary zeros so that the length is a multiple of
blocksize.
def long_to_bytes(n: int, blocksize: int = 0) -> bytes:
"""Convert a long integer to a byte string.
If optional ``blocksize`` is given and greater than zero, pad the front of the byte string with binary zeros so
that the length is a multiple of blocksize.
"""
# after much testing, this algorithm was deemed to be the fastest
s = b("")
pack = struct.pack
while n > 0:
s = pack(">I", n & 0xFFFFFFFF) + s
n = n >> 32
# strip off leading zeros
for i, c in enumerate(s):
if c != b("\000")[0]:
@ -76,26 +82,33 @@ def long_to_bytes(n, blocksize=0):
# only happens when n == 0
s = b("\000")
i = 0
s = s[i:]
# add back some pad bytes. this could be done more efficiently w.r.t. the
# de-padding being done above, but sigh...
if blocksize > 0 and len(s) % blocksize:
s = (blocksize - len(s) % blocksize) * b("\000") + s
return s
def bytes_to_long(s):
"""bytes_to_long(string) : long
Convert a byte string to a long integer.
This is (essentially) the inverse of long_to_bytes().
def bytes_to_long(s: bytes) -> int:
"""Convert a byte string to a long integer.
This is (essentially) the inverse of ``long_to_bytes()``.
"""
acc = 0
unpack = struct.unpack
length = len(s)
if length % 4:
extra = 4 - length % 4
s = b("\000") * extra + s
length = length + extra
for i in range(0, length, 4):
acc = (acc << 32) + unpack(">I", s[i : i + 4])[0]
return acc

View File

@ -20,27 +20,34 @@
"""Helper utilities for testing"""
from asyncio import StreamReader, StreamWriter
from io import BytesIO
from typing import AsyncIterable, List, Optional, TypeVar
T = TypeVar("T")
class AsyncBuffer(BytesIO):
class AsyncBuffer(BytesIO, StreamReader, StreamWriter): # type: ignore[misc]
"""Just a BytesIO with an async read method."""
async def read(self, n=None): # pylint: disable=invalid-overridden-method
async def read( # type: ignore[override] # pylint: disable=invalid-overridden-method
self, n: Optional[int] = None
) -> Optional[bytes]:
v = super().read(n)
return v
readexactly = read
readexactly = read # type: ignore[assignment]
def append(self, data):
def append(self, data: bytes) -> None:
"""Append data to the buffer without changing the current position."""
pos = self.tell()
self.write(data)
self.seek(pos)
async def async_comprehend(generator):
async def async_comprehend(generator: AsyncIterable[T]) -> List[T]:
"""Emulate ``[elem async for elem in generator]``."""
results = []

View File

@ -38,8 +38,9 @@ MESSAGE_CLOSED = b"\xb1\x14hU'\xb5M\xa6\"\x03\x9duy\xa1\xd4evW,\xdcE\x18\xe4+ C4
@pytest.mark.asyncio
async def test_boxstream():
async def test_boxstream() -> None:
"""Test stream boxing"""
buffer = AsyncBuffer()
box_stream = BoxStream(buffer, CLIENT_ENCRYPT_KEY, CLIENT_ENCRYPT_NONCE)
box_stream.write(b"foo")
@ -63,7 +64,7 @@ async def test_boxstream():
@pytest.mark.asyncio
async def test_unboxstream():
async def test_unboxstream() -> None:
"""Test stream unboxing"""
buffer = AsyncBuffer(MESSAGE_1 + MESSAGE_2 + MESSAGE_3 + MESSAGE_CLOSED)
@ -76,7 +77,7 @@ async def test_unboxstream():
@pytest.mark.asyncio
async def test_long_packets():
async def test_long_packets() -> None:
"""Test for receiving long packets"""
data_size = 6 * 1024

View File

@ -36,7 +36,7 @@ CLIENT_EPH_KEY_SEED = b"u8\xd0\xe3\x85d_Pz\x0c\xf5\xfd\x15\xce2p#\xb0\xf0\x9f\xe
@pytest.fixture
def server():
def server() -> SHSServerCrypto:
"""A testing SHS server"""
server_key = SigningKey(SERVER_KEY_SEED)
@ -46,7 +46,7 @@ def server():
@pytest.fixture
def client():
def client() -> SHSClientCrypto:
"""A testing SHS client"""
client_key = SigningKey(CLIENT_KEY_SEED)
@ -90,7 +90,7 @@ CLIENT_ENCRYPT_NONCE = b"S\\\x06\x8d\xe5\xeb&*\xb8\x0bp\xb3Z\x8e\\\x85\x14\xaa\x
CLIENT_DECRYPT_NONCE = b"d\xe8\xccD\xec\xb9E\xbb\xaa\xa7\x7f\xe38\x15\x16\xef\xca\xd22u\x1d\xfe<\xe7"
def test_handshake(client, server): # pylint: disable=redefined-outer-name
def test_handshake(client: SHSClientCrypto, server: SHSServerCrypto) -> None: # pylint: disable=redefined-outer-name
"""Test the handshake procedure"""
client_challenge = client.generate_challenge()

View File

@ -22,11 +22,14 @@
from asyncio import Event, wait_for
import os
from typing import Any, Awaitable, Callable, Tuple
from nacl.signing import SigningKey
import pytest
from pytest_mock import MockerFixture
from secret_handshake import SHSClient, SHSServer
from secret_handshake.boxstream import BoxStreamKeys
from .helpers import AsyncBuffer
@ -34,41 +37,42 @@ from .helpers import AsyncBuffer
class DummyCrypto:
"""Dummy crypto module, pretends everything is fine."""
def verify_server_challenge(self, _):
def verify_server_challenge(self, _: bytes) -> bool:
"""Verify the server challenge"""
return True
def verify_challenge(self, _):
def verify_challenge(self, _: bytes) -> bool:
"""Verify the challenge data"""
return True
def verify_server_accept(self, _):
def verify_server_accept(self, _: bytes) -> bool:
"""Verify servers accept message"""
return True
def generate_challenge(self):
def generate_challenge(self) -> bytes:
"""Generate authentication challenge"""
return b"CHALLENGE"
def generate_client_auth(self):
def generate_client_auth(self) -> bytes:
"""Generate client authentication data"""
return b"AUTH"
def verify_client_auth(self, _):
def verify_client_auth(self, _: bytes) -> bool:
"""Verify client authentication data"""
return True
def generate_accept(self):
def generate_accept(self) -> bytes:
"""Generate an ACCEPT message"""
return b"ACCEPT"
def get_box_keys(self):
def get_box_keys(self) -> BoxStreamKeys:
"""Get box keys"""
return {
@ -76,48 +80,64 @@ class DummyCrypto:
"encrypt_nonce": b"x" * 32,
"decrypt_key": b"x" * 32,
"decrypt_nonce": b"x" * 32,
"shared_secret": b"x" * 32,
}
def clean(self):
def clean(self) -> None:
"""Clean up internal data"""
def _dummy_boxstream(stream, **_):
"""Identity boxstream, no tansformation."""
def _dummy_boxstream(stream: AsyncBuffer, **_: Any) -> AsyncBuffer:
"""Identity boxstream, no transformation."""
return stream
def _client_stream_mocker():
def _client_stream_mocker() -> (
Tuple[AsyncBuffer, AsyncBuffer, Callable[[str, int], Awaitable[Tuple[AsyncBuffer, AsyncBuffer]]]]
):
reader = AsyncBuffer(b"xxx")
writer = AsyncBuffer(b"xxx")
async def _create_mock_streams(host, port): # pylint: disable=unused-argument
async def _create_mock_streams(
host: str, port: int # pylint: disable=unused-argument
) -> Tuple[AsyncBuffer, AsyncBuffer]:
return reader, writer
return reader, writer, _create_mock_streams
def _server_stream_mocker():
def _server_stream_mocker() -> (
Tuple[
AsyncBuffer,
AsyncBuffer,
Callable[[Callable[[AsyncBuffer, AsyncBuffer], Awaitable[None]], str, int], Awaitable[None]],
]
):
reader = AsyncBuffer(b"xxx")
writer = AsyncBuffer(b"xxx")
async def _create_mock_server(cb, host, port): # pylint: disable=unused-argument
async def _create_mock_server(
cb: Callable[[AsyncBuffer, AsyncBuffer], Awaitable[None]],
host: str, # pylint: disable=unused-argument
port: int, # pylint: disable=unused-argument
) -> None:
await cb(reader, writer)
return reader, writer, _create_mock_server
@pytest.mark.asyncio
async def test_client(mocker):
async def test_client(mocker: MockerFixture) -> None:
"""Test the client"""
reader, _, _create_mock_streams = _client_stream_mocker()
mocker.patch("asyncio.open_connection", new=_create_mock_streams)
mocker.patch("secret_handshake.network.open_connection", new=_create_mock_streams)
mocker.patch("secret_handshake.boxstream.BoxStream", new=_dummy_boxstream)
mocker.patch("secret_handshake.boxstream.UnboxStream", new=_dummy_boxstream)
client = SHSClient("shop.local", 1111, SigningKey.generate(), os.urandom(32))
client.crypto = DummyCrypto()
client.crypto = DummyCrypto() # type: ignore[assignment]
await client.open()
reader.append(b"TEST")
@ -126,22 +146,22 @@ async def test_client(mocker):
@pytest.mark.asyncio
async def test_server(mocker):
async def test_server(mocker: MockerFixture) -> None:
"""Test the server"""
resolve = Event()
async def _on_connect(_):
async def _on_connect(_: Any) -> None:
server.disconnect()
resolve.set()
_, _, _create_mock_server = _server_stream_mocker()
mocker.patch("asyncio.start_server", new=_create_mock_server)
mocker.patch("secret_handshake.network.start_server", new=_create_mock_server)
mocker.patch("secret_handshake.boxstream.BoxStream", new=_dummy_boxstream)
mocker.patch("secret_handshake.boxstream.UnboxStream", new=_dummy_boxstream)
server = SHSServer("shop.local", 1111, SigningKey.generate(), os.urandom(32))
server.crypto = DummyCrypto()
server.crypto = DummyCrypto() # type: ignore[assignment]
server.on_connect(_on_connect)