diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 54db536..ce259a8 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -36,3 +36,10 @@ repos: language: system require_serial: true types_or: [python, pyi] + - id: mypy + name: mypy + entry: poetry run mypy + args: ["--strict"] + language: system + types_or: [python, pyi] + require_serial: true diff --git a/examples/test_client.py b/examples/test_client.py index 745cccc..2474e5e 100644 --- a/examples/test_client.py +++ b/examples/test_client.py @@ -13,7 +13,7 @@ with open(os.path.expanduser("~/.ssb/secret"), encoding="utf-8") as f: config = yaml.safe_load(f) -async def main(): +async def main() -> None: """Main function to run""" server_pub_key = b64decode(config["public"][:-8]) diff --git a/examples/test_server.py b/examples/test_server.py index d9255ae..b8f75dd 100644 --- a/examples/test_server.py +++ b/examples/test_server.py @@ -8,17 +8,18 @@ from nacl.signing import SigningKey import yaml from secret_handshake import SHSServer +from secret_handshake.network import SHSDuplexStream with open(os.path.expanduser("~/.ssb/secret"), encoding="utf-8") as f: config = yaml.safe_load(f) -async def _on_connect(conn): +async def _on_connect(conn: SHSDuplexStream) -> None: async for msg in conn: print(msg) -async def main(): +async def main() -> None: """Main function to run""" server_keypair = SigningKey(b64decode(config["private"][:-8])[:32]) diff --git a/poetry.lock b/poetry.lock index ceb3525..6ad999c 100644 --- a/poetry.lock +++ b/poetry.lock @@ -674,6 +674,51 @@ files = [ {file = "mccabe-0.7.0.tar.gz", hash = "sha256:348e0240c33b60bbdf4e523192ef919f28cb2c3d7d5c7794f74009290f236325"}, ] +[[package]] +name = "mypy" +version = "1.6.1" +description = "Optional static typing for Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "mypy-1.6.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e5012e5cc2ac628177eaac0e83d622b2dd499e28253d4107a08ecc59ede3fc2c"}, + {file = "mypy-1.6.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:d8fbb68711905f8912e5af474ca8b78d077447d8f3918997fecbf26943ff3cbb"}, + {file = "mypy-1.6.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:21a1ad938fee7d2d96ca666c77b7c494c3c5bd88dff792220e1afbebb2925b5e"}, + {file = "mypy-1.6.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:b96ae2c1279d1065413965c607712006205a9ac541895004a1e0d4f281f2ff9f"}, + {file = "mypy-1.6.1-cp310-cp310-win_amd64.whl", hash = "sha256:40b1844d2e8b232ed92e50a4bd11c48d2daa351f9deee6c194b83bf03e418b0c"}, + {file = "mypy-1.6.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:81af8adaa5e3099469e7623436881eff6b3b06db5ef75e6f5b6d4871263547e5"}, + {file = "mypy-1.6.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:8c223fa57cb154c7eab5156856c231c3f5eace1e0bed9b32a24696b7ba3c3245"}, + {file = "mypy-1.6.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a8032e00ce71c3ceb93eeba63963b864bf635a18f6c0c12da6c13c450eedb183"}, + {file = "mypy-1.6.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:4c46b51de523817a0045b150ed11b56f9fff55f12b9edd0f3ed35b15a2809de0"}, + {file = "mypy-1.6.1-cp311-cp311-win_amd64.whl", hash = "sha256:19f905bcfd9e167159b3d63ecd8cb5e696151c3e59a1742e79bc3bcb540c42c7"}, + {file = "mypy-1.6.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:82e469518d3e9a321912955cc702d418773a2fd1e91c651280a1bda10622f02f"}, + {file = "mypy-1.6.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:d4473c22cc296425bbbce7e9429588e76e05bc7342da359d6520b6427bf76660"}, + {file = "mypy-1.6.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:59a0d7d24dfb26729e0a068639a6ce3500e31d6655df8557156c51c1cb874ce7"}, + {file = "mypy-1.6.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:cfd13d47b29ed3bbaafaff7d8b21e90d827631afda134836962011acb5904b71"}, + {file = "mypy-1.6.1-cp312-cp312-win_amd64.whl", hash = "sha256:eb4f18589d196a4cbe5290b435d135dee96567e07c2b2d43b5c4621b6501531a"}, + {file = "mypy-1.6.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:41697773aa0bf53ff917aa077e2cde7aa50254f28750f9b88884acea38a16169"}, + {file = "mypy-1.6.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:7274b0c57737bd3476d2229c6389b2ec9eefeb090bbaf77777e9d6b1b5a9d143"}, + {file = "mypy-1.6.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bbaf4662e498c8c2e352da5f5bca5ab29d378895fa2d980630656178bd607c46"}, + {file = "mypy-1.6.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:bb8ccb4724f7d8601938571bf3f24da0da791fe2db7be3d9e79849cb64e0ae85"}, + {file = "mypy-1.6.1-cp38-cp38-win_amd64.whl", hash = "sha256:68351911e85145f582b5aa6cd9ad666c8958bcae897a1bfda8f4940472463c45"}, + {file = "mypy-1.6.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:49ae115da099dcc0922a7a895c1eec82c1518109ea5c162ed50e3b3594c71208"}, + {file = "mypy-1.6.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:8b27958f8c76bed8edaa63da0739d76e4e9ad4ed325c814f9b3851425582a3cd"}, + {file = "mypy-1.6.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:925cd6a3b7b55dfba252b7c4561892311c5358c6b5a601847015a1ad4eb7d332"}, + {file = "mypy-1.6.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:8f57e6b6927a49550da3d122f0cb983d400f843a8a82e65b3b380d3d7259468f"}, + {file = "mypy-1.6.1-cp39-cp39-win_amd64.whl", hash = "sha256:a43ef1c8ddfdb9575691720b6352761f3f53d85f1b57d7745701041053deff30"}, + {file = "mypy-1.6.1-py3-none-any.whl", hash = "sha256:4cbe68ef919c28ea561165206a2dcb68591c50f3bcf777932323bc208d949cf1"}, + {file = "mypy-1.6.1.tar.gz", hash = "sha256:4d01c00d09a0be62a4ca3f933e315455bde83f37f892ba4b08ce92f3cf44bcc1"}, +] + +[package.dependencies] +mypy-extensions = ">=1.0.0" +typing-extensions = ">=4.1.0" + +[package.extras] +dmypy = ["psutil (>=4.0)"] +install-types = ["pip"] +reports = ["lxml"] + [[package]] name = "mypy-extensions" version = "1.0.0" @@ -1243,6 +1288,28 @@ files = [ {file = "tomlkit-0.12.1.tar.gz", hash = "sha256:38e1ff8edb991273ec9f6181244a6a391ac30e9f5098e7535640ea6be97a7c86"}, ] +[[package]] +name = "types-pyyaml" +version = "6.0.12.12" +description = "Typing stubs for PyYAML" +optional = false +python-versions = "*" +files = [ + {file = "types-PyYAML-6.0.12.12.tar.gz", hash = "sha256:334373d392fde0fdf95af5c3f1661885fa10c52167b14593eb856289e1855062"}, + {file = "types_PyYAML-6.0.12.12-py3-none-any.whl", hash = "sha256:c05bc6c158facb0676674b7f11fe3960db4f389718e19e62bd2b84d6205cfd24"}, +] + +[[package]] +name = "typing-extensions" +version = "4.8.0" +description = "Backported and Experimental Type Hints for Python 3.8+" +optional = false +python-versions = ">=3.8" +files = [ + {file = "typing_extensions-4.8.0-py3-none-any.whl", hash = "sha256:8f92fc8806f9a6b641eaa5318da32b44d401efaac0f6678c9bc448ba3605faa0"}, + {file = "typing_extensions-4.8.0.tar.gz", hash = "sha256:df8e4339e9cb77357558cbdbceca33c303714cf861d1eef15e1070055ae8b7ef"}, +] + [[package]] name = "urllib3" version = "2.0.7" @@ -1309,4 +1376,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = "^3.12" -content-hash = "82c05c7c22b990f45fae70d36d0bed6f1c414b30af9e40318dfc383ebab9d86a" +content-hash = "980ed766f11ffdaf6694f7800174de76cd21ac2d7a5ffbe27d8df1d1673f6119" diff --git a/pyproject.toml b/pyproject.toml index d5c80bd..4037afb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,6 +6,7 @@ authors = ["Pedro Ferreira "] license = "MIT" readme = "README.rst" packages = [{include = "secret_handshake"}] +include = ["secret_handshake/py.typed"] [tool.poetry.dependencies] python = "^3.12" @@ -24,6 +25,8 @@ pre-commit = "^3.5.0" commitizen = "^3.12.0" black = "^23.10.1" pylint = "^3.0.2" +mypy = "^1.6.1" +types-pyyaml = "^6.0.12.12" [tool.poetry.group.docs.dependencies] sphinx = "^7.2.6" diff --git a/secret_handshake/boxstream.py b/secret_handshake/boxstream.py index 67695a0..e294e43 100644 --- a/secret_handshake/boxstream.py +++ b/secret_handshake/boxstream.py @@ -1,8 +1,8 @@ """Box stream utilities""" -from asyncio import IncompleteReadError +from asyncio import IncompleteReadError, StreamReader, StreamWriter import struct -from typing import Tuple +from typing import Any, AsyncIterator, Optional, Tuple, TypedDict from nacl.secret import SecretBox @@ -13,36 +13,51 @@ MAX_SEGMENT_SIZE = 4 * 1024 TERMINATION_HEADER = b"\x00" * 18 -def get_stream_pair(reader, writer, **kwargs) -> Tuple["UnboxStream", "BoxStream"]: +class BoxStreamKeys(TypedDict): + """Dictionary to hold all box stream keys""" + + decrypt_key: bytes + decrypt_nonce: bytes + encrypt_key: bytes + encrypt_nonce: bytes + shared_secret: bytes + + +def get_stream_pair( # pylint: disable=too-many-arguments + reader: StreamReader, # pylint: disable=unused-argument + writer: StreamWriter, + *, + decrypt_key: bytes, + decrypt_nonce: bytes, + encrypt_key: bytes, + encrypt_nonce: bytes, + **kwargs: Any, +) -> Tuple["UnboxStream", "BoxStream"]: """Create a new duplex box stream""" - box_args = { - "key": kwargs["encrypt_key"], - "nonce": kwargs["encrypt_nonce"], - } - unbox_args = { - "key": kwargs["decrypt_key"], - "nonce": kwargs["decrypt_nonce"], - } - return UnboxStream(reader, **unbox_args), BoxStream(writer, **box_args) + read_stream = UnboxStream(reader, key=decrypt_key, nonce=decrypt_nonce) + write_stream = BoxStream(writer, key=encrypt_key, nonce=encrypt_nonce) + + return read_stream, write_stream class UnboxStream: """Unboxing stream""" - def __init__(self, reader, key, nonce): + def __init__(self, reader: StreamReader, key: bytes, nonce: bytes): self.reader = reader self.key = key self.nonce = nonce self.closed = False - async def read(self): + async def read(self) -> Optional[bytes]: """Read data from the stream""" try: data = await self.reader.readexactly(HEADER_LENGTH) except IncompleteReadError: self.closed = True + return None box = SecretBox(self.key) @@ -51,6 +66,7 @@ class UnboxStream: if header == TERMINATION_HEADER: self.closed = True + return None length = struct.unpack(">H", header[:2])[0] @@ -61,12 +77,13 @@ class UnboxStream: body = box.decrypt(mac + data, inc_nonce(self.nonce)) self.nonce = inc_nonce(inc_nonce(self.nonce)) + return body - def __aiter__(self): + def __aiter__(self) -> AsyncIterator[bytes]: return self - async def __anext__(self): + async def __anext__(self) -> bytes: data = await self.read() if data is None: @@ -78,17 +95,17 @@ class UnboxStream: class BoxStream: """Box stream""" - def __init__(self, writer, key, nonce): + def __init__(self, writer: StreamWriter, key: bytes, nonce: bytes): self.writer = writer self.key = key self.box = SecretBox(self.key) self.nonce = nonce - def write(self, data): + def write(self, data: bytes) -> None: """Write data to the box stream""" for chunk in split_chunks(data, MAX_SEGMENT_SIZE): - body = self.box.encrypt(chunk, inc_nonce(self.nonce))[24:] + body = self.box.encrypt(bytes(chunk), inc_nonce(self.nonce))[24:] header = struct.pack(">H", len(body) - 16) + body[:16] hdrbox = self.box.encrypt(header, self.nonce)[24:] @@ -97,7 +114,7 @@ class BoxStream: self.nonce = inc_nonce(inc_nonce(self.nonce)) self.writer.write(body[16:]) - def close(self): + def close(self) -> None: """Close the box stream""" self.writer.write(self.box.encrypt(b"\x00" * 18, self.nonce)[24:]) diff --git a/secret_handshake/crypto.py b/secret_handshake/crypto.py index 879b272..851fbc4 100644 --- a/secret_handshake/crypto.py +++ b/secret_handshake/crypto.py @@ -28,7 +28,9 @@ from typing import Optional from nacl.bindings import crypto_box_afternm, crypto_box_open_afternm, crypto_scalarmult from nacl.exceptions import CryptoError from nacl.public import PrivateKey -from nacl.signing import VerifyKey +from nacl.signing import SigningKey, VerifyKey + +from .boxstream import BoxStreamKeys APPLICATION_KEY = b64decode("1KHLiKZvAvjbY1ziZEHMXawbCEIM6qwjCDm3VYRan/s=") @@ -40,29 +42,38 @@ class SHSError(Exception): class SHSCryptoBase: # pylint: disable=too-many-instance-attributes """Base functions for SHS cryptography""" - def __init__(self, local_key, ephemeral_key=None, application_key=None): - self.local_key = local_key + def __init__( + self, + local_key: SigningKey, + ephemeral_key: Optional[PrivateKey] = None, + application_key: Optional[bytes] = None, + ): self.application_key = application_key or APPLICATION_KEY - self.shared_hash = None - self.remote_ephemeral_key = None - self.shared_secret = None - self.remote_app_hmac = None - self.remote_pub_key = None - self.box_secret = None + self.b_alice: Optional[bytes] = None + self.box_secret: Optional[bytes] = None + self.hello: Optional[bytes] = None + self.local_key = local_key + self.remote_app_hmac: Optional[bytes] = None + self.remote_ephemeral_key: Optional[bytes] = None + self.remote_pub_key: Optional[VerifyKey] = None + self.shared_hash: Optional[bytes] = None + self.shared_secret: Optional[bytes] = None + self._reset_keys(ephemeral_key or PrivateKey.generate()) - def _reset_keys(self, ephemeral_key): + def _reset_keys(self, ephemeral_key: PrivateKey) -> None: self.local_ephemeral_key = ephemeral_key self.local_app_hmac = hmac.new( self.application_key, bytes(ephemeral_key.public_key), digestmod="sha512" ).digest()[:32] - def generate_challenge(self): + def generate_challenge(self) -> bytes: """Generate and return a challenge to be sent to the server.""" return self.local_app_hmac + bytes(self.local_ephemeral_key.public_key) - def verify_challenge(self, data): + def verify_challenge(self, data: bytes) -> bool: """Verify the correctness of challenge sent from the client.""" + assert len(data) == 64 sent_hmac, remote_ephemeral_key = data[:32], data[32:] @@ -76,9 +87,10 @@ class SHSCryptoBase: # pylint: disable=too-many-instance-attributes self.remote_ephemeral_key = remote_ephemeral_key # this is hash(a * b) self.shared_hash = hashlib.sha256(self.shared_secret).digest() + return ok - def clean(self, new_ephemeral_key=None): + def clean(self, new_ephemeral_key: Optional[PrivateKey] = None) -> None: """Clean internal data""" self._reset_keys(new_ephemeral_key or PrivateKey.generate()) @@ -86,10 +98,15 @@ class SHSCryptoBase: # pylint: disable=too-many-instance-attributes self.shared_hash = None self.remote_ephemeral_key = None - def get_box_keys(self): + def get_box_keys(self) -> BoxStreamKeys: """Get the box stream’s keys""" + assert self.box_secret + assert self.remote_app_hmac + assert self.remote_pub_key + shared_secret = hashlib.sha256(self.box_secret).digest() + return { "shared_secret": shared_secret, "encrypt_key": hashlib.sha256(shared_secret + bytes(self.remote_pub_key)).digest(), @@ -102,17 +119,15 @@ class SHSCryptoBase: # pylint: disable=too-many-instance-attributes class SHSServerCrypto(SHSCryptoBase): """SHS server crypto algorithm""" - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.b_alice = None - self.hello = None - self.box_secret = None - self.remote_pub_key = None - - def verify_client_auth(self, data): + def verify_client_auth(self, data: bytes) -> bool: """Verify client authentication data""" + assert self.remote_ephemeral_key + assert self.shared_hash + assert self.shared_secret + assert len(data) == 112 + a_bob = crypto_scalarmult(bytes(self.local_key.to_curve25519_private_key()), self.remote_ephemeral_key) box_secret = hashlib.sha256(self.application_key + self.shared_secret + a_bob).digest() self.hello = crypto_box_open_afternm(data, b"\x00" * 24, box_secret) @@ -127,15 +142,21 @@ class SHSServerCrypto(SHSCryptoBase): bytes(self.local_ephemeral_key), bytes(self.remote_pub_key.to_curve25519_public_key()) ) self.box_secret = hashlib.sha256(self.application_key + self.shared_secret + a_bob + b_alice).digest()[:32] + return True - def generate_accept(self): + def generate_accept(self) -> bytes: """Generate an accept message""" + assert self.box_secret + assert self.hello + assert self.shared_hash + okay = self.local_key.sign(self.application_key + self.hello + self.shared_hash).signature + return crypto_box_afternm(okay, b"\x00" * 24, self.box_secret) - def clean(self, new_ephemeral_key=None): + def clean(self, new_ephemeral_key: Optional[PrivateKey] = None) -> None: super().clean(new_ephemeral_key=new_ephemeral_key) self.hello = None self.b_alice = None @@ -152,21 +173,26 @@ class SHSClientCrypto(SHSCryptoBase): def __init__( self, - local_key: PrivateKey, + local_key: SigningKey, server_pub_key: bytes, ephemeral_key: PrivateKey, application_key: Optional[bytes] = None, ): super().__init__(local_key, ephemeral_key, application_key) - self.remote_pub_key = VerifyKey(server_pub_key) - self.b_alice = None - self.a_bob = None - self.hello = None - self.box_secret = None - def verify_server_challenge(self, data): + self.a_bob: Optional[bytes] = None + self.remote_pub_key = VerifyKey(server_pub_key) + + def verify_server_challenge(self, data: bytes) -> bool: """Verify the correctness of challenge sent from the server.""" + + assert self.remote_pub_key + assert super().verify_challenge(data) + + assert self.shared_hash + assert self.shared_secret + curve_pkey = self.remote_pub_key.to_curve25519_public_key() # a_bob is (a * B) @@ -179,17 +205,30 @@ class SHSClientCrypto(SHSCryptoBase): signed_message = self.local_key.sign(self.application_key + bytes(self.remote_pub_key) + self.shared_hash) message_to_box = signed_message.signature + bytes(self.local_key.verify_key) self.hello = message_to_box + return True - def generate_client_auth(self): + def generate_client_auth(self) -> bytes: """Generate box[K|a*b|a*B](H)""" + assert self.box_secret + assert self.hello + nonce = b"\x00" * 24 + # return box(K | a * b | a * B)[H] return crypto_box_afternm(self.hello, nonce, self.box_secret) - def verify_server_accept(self, data): + def verify_server_accept(self, data: bytes) -> bool: """Verify that the server's accept message is sane""" + + assert self.a_bob + assert self.hello + assert self.remote_ephemeral_key + assert self.remote_pub_key + assert self.shared_hash + assert self.shared_secret + curve_lkey = self.local_key.to_curve25519_private_key() # b_alice is (A * b) b_alice = crypto_scalarmult(bytes(curve_lkey), self.remote_ephemeral_key) @@ -208,9 +247,10 @@ class SHSClientCrypto(SHSCryptoBase): # we should have received sign(B)[K | H | hash(a * b)] # let's see if that signature can verify the reconstructed data on our side self.remote_pub_key.verify(self.application_key + self.hello + self.shared_hash, signature) + return True - def clean(self, new_ephemeral_key=None): + def clean(self, new_ephemeral_key: Optional[PrivateKey] = None) -> None: super().clean(new_ephemeral_key=new_ephemeral_key) self.a_bob = None self.b_alice = None diff --git a/secret_handshake/network.py b/secret_handshake/network.py index 1e665df..d92627c 100644 --- a/secret_handshake/network.py +++ b/secret_handshake/network.py @@ -20,10 +20,15 @@ """Networking functionality""" -import asyncio +from asyncio import StreamReader, StreamWriter, ensure_future, open_connection, start_server +from typing import AsyncIterator, Awaitable, Callable, List, Optional -from .boxstream import get_stream_pair -from .crypto import SHSClientCrypto, SHSServerCrypto +from nacl.public import PrivateKey +from nacl.signing import SigningKey +from typing_extensions import Self + +from .boxstream import BoxStream, UnboxStream, get_stream_pair +from .crypto import SHSClientCrypto, SHSCryptoBase, SHSServerCrypto class SHSClientException(Exception): @@ -33,32 +38,37 @@ class SHSClientException(Exception): class SHSDuplexStream: """SHS duplex stream""" - def __init__(self): - self.write_stream = None - self.read_stream = None + def __init__(self) -> None: + self.write_stream: Optional[BoxStream] = None + self.read_stream: Optional[UnboxStream] = None self.is_connected = False - def write(self, data): + def write(self, data: bytes) -> None: """Write data to the write stream""" + assert self.write_stream + self.write_stream.write(data) - async def read(self): + async def read(self) -> Optional[bytes]: """Read data from the read stream""" + assert self.read_stream + return await self.read_stream.read() - def close(self): + def close(self) -> None: """Close the duplex stream""" - self.write_stream.close() - self.read_stream.close() + if self.write_stream: + self.write_stream.close() + self.is_connected = False - def __aiter__(self): + def __aiter__(self) -> AsyncIterator[bytes]: return self - async def __anext__(self): + async def __anext__(self) -> bytes: msg = await self.read() if msg is None: @@ -70,45 +80,53 @@ class SHSDuplexStream: class SHSEndpoint: """SHS endpoint""" - def __init__(self): - self._on_connect = None - self.crypto = None + def __init__(self) -> None: + self._on_connect: Optional[Callable[[SHSDuplexStream], Awaitable[None]]] = None + self.crypto: Optional[SHSCryptoBase] = None - def on_connect(self, cb): + def on_connect(self, cb: Callable[[SHSDuplexStream], Awaitable[None]]) -> None: """Set the function to be called when a new connection arrives""" + self._on_connect = cb - def disconnect(self): + def disconnect(self) -> None: """Disconnect the endpoint""" + raise NotImplementedError class SHSServer(SHSEndpoint): """SHS server""" - def __init__(self, host, port, server_kp, application_key=None): + def __init__(self, host: str, port: int, server_kp: SigningKey, application_key: Optional[bytes] = None): super().__init__() self.host = host self.port = port - self.crypto = SHSServerCrypto(server_kp, application_key=application_key) - self.connections = [] + self.crypto: SHSServerCrypto = SHSServerCrypto(server_kp, application_key=application_key) + self.connections: List[SHSServerConnection] = [] + + async def _handshake(self, reader: StreamReader, writer: StreamWriter) -> None: + assert self.crypto - async def _handshake(self, reader, writer): data = await reader.readexactly(64) + if not self.crypto.verify_challenge(data): raise SHSClientException("Client challenge is not valid") writer.write(self.crypto.generate_challenge()) data = await reader.readexactly(112) + if not self.crypto.verify_client_auth(data): raise SHSClientException("Client auth is not valid") writer.write(self.crypto.generate_accept()) - async def handle_connection(self, reader, writer): + async def handle_connection(self, reader: StreamReader, writer: StreamWriter) -> None: """Handle incoming connections""" + assert self.crypto + self.crypto.clean() await self._handshake(reader, writer) keys = self.crypto.get_box_keys() @@ -118,14 +136,14 @@ class SHSServer(SHSEndpoint): self.connections.append(conn) if self._on_connect: - asyncio.ensure_future(self._on_connect(conn)) + ensure_future(self._on_connect(conn)) - async def listen(self): + async def listen(self) -> None: """Listen for connections""" - await asyncio.start_server(self.handle_connection, self.host, self.port) + await start_server(self.handle_connection, self.host, self.port) - def disconnect(self): + def disconnect(self) -> None: for connection in self.connections: connection.close() @@ -133,52 +151,63 @@ class SHSServer(SHSEndpoint): class SHSServerConnection(SHSDuplexStream): """SHS server connection""" - def __init__(self, read_stream, write_stream): + def __init__(self, read_stream: UnboxStream, write_stream: BoxStream): super().__init__() + self.read_stream = read_stream self.write_stream = write_stream @classmethod - def from_byte_streams(cls, reader, writer, **keys): + def from_byte_streams(cls, reader: StreamReader, writer: StreamWriter, **keys: bytes) -> Self: """Create a server connection from an existing byte stream""" - reader, writer = get_stream_pair(reader, writer, **keys) + box_reader, box_writer = get_stream_pair(reader, writer, **keys) - return cls(reader, writer) + return cls(box_reader, box_writer) class SHSClient(SHSDuplexStream, SHSEndpoint): """SHS client""" def __init__( # pylint: disable=too-many-arguments - self, host, port, client_kp, server_pub_key, ephemeral_key=None, application_key=None + self, + host: str, + port: int, + client_kp: SigningKey, + server_pub_key: bytes, + ephemeral_key: Optional[PrivateKey] = None, + application_key: Optional[bytes] = None, ): SHSDuplexStream.__init__(self) SHSEndpoint.__init__(self) self.host = host self.port = port - self.writer = None - self.crypto = SHSClientCrypto( - client_kp, server_pub_key, ephemeral_key=ephemeral_key, application_key=application_key + self.writer: Optional[StreamWriter] = None + self.crypto: SHSClientCrypto = SHSClientCrypto( + client_kp, + server_pub_key, + ephemeral_key=ephemeral_key or PrivateKey.generate(), + application_key=application_key, ) - async def _handshake(self, reader, writer): + async def _handshake(self, reader: StreamReader, writer: StreamWriter) -> None: writer.write(self.crypto.generate_challenge()) data = await reader.readexactly(64) + if not self.crypto.verify_server_challenge(data): raise SHSClientException("Server challenge is not valid") writer.write(self.crypto.generate_client_auth()) - data = await reader.readexactly(80) + if not self.crypto.verify_server_accept(data): raise SHSClientException("Server accept is not valid") - async def open(self): + async def open(self) -> None: """Open the TCP connection""" - reader, writer = await asyncio.open_connection(self.host, self.port) + reader, writer = await open_connection(self.host, self.port) await self._handshake(reader, writer) keys = self.crypto.get_box_keys() @@ -189,7 +218,7 @@ class SHSClient(SHSDuplexStream, SHSEndpoint): self.is_connected = True if self._on_connect: - await self._on_connect() + await self._on_connect(self) - def disconnect(self): + def disconnect(self) -> None: self.close() diff --git a/secret_handshake/py.typed b/secret_handshake/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/secret_handshake/util.py b/secret_handshake/util.py index 69e17d2..7be80b5 100644 --- a/secret_handshake/util.py +++ b/secret_handshake/util.py @@ -21,12 +21,14 @@ """Utility functions""" import struct +from typing import Generator, Sequence, TypeVar NONCE_SIZE = 24 MAX_NONCE = 8 * NONCE_SIZE +T = TypeVar("T") -def inc_nonce(nonce): +def inc_nonce(nonce: bytes) -> bytes: """Increment nonce""" num = bytes_to_long(nonce) + 1 @@ -40,34 +42,38 @@ def inc_nonce(nonce): return bnum -def split_chunks(seq, n): +def split_chunks(seq: Sequence[T], n: int) -> Generator[Sequence[T], None, None]: """Split sequence in equal-sized chunks. + The last chunk is not padded.""" + while seq: yield seq[:n] seq = seq[n:] # Stolen from PyCypto (Public Domain) -def b(s): +def b(s: str) -> bytes: """Shorthand for s.encode("latin-1")""" return s.encode("latin-1") # utf-8 would cause some side-effects we don't want -def long_to_bytes(n, blocksize=0): - """long_to_bytes(n:long, blocksize:int) : string - Convert a long integer to a byte string. - If optional blocksize is given and greater than zero, pad the front of the - byte string with binary zeros so that the length is a multiple of - blocksize. +def long_to_bytes(n: int, blocksize: int = 0) -> bytes: + """Convert a long integer to a byte string. + + If optional ``blocksize`` is given and greater than zero, pad the front of the byte string with binary zeros so + that the length is a multiple of blocksize. """ + # after much testing, this algorithm was deemed to be the fastest s = b("") pack = struct.pack + while n > 0: s = pack(">I", n & 0xFFFFFFFF) + s n = n >> 32 + # strip off leading zeros for i, c in enumerate(s): if c != b("\000")[0]: @@ -76,26 +82,33 @@ def long_to_bytes(n, blocksize=0): # only happens when n == 0 s = b("\000") i = 0 + s = s[i:] + # add back some pad bytes. this could be done more efficiently w.r.t. the # de-padding being done above, but sigh... if blocksize > 0 and len(s) % blocksize: s = (blocksize - len(s) % blocksize) * b("\000") + s + return s -def bytes_to_long(s): - """bytes_to_long(string) : long - Convert a byte string to a long integer. - This is (essentially) the inverse of long_to_bytes(). +def bytes_to_long(s: bytes) -> int: + """Convert a byte string to a long integer. + + This is (essentially) the inverse of ``long_to_bytes()``. """ + acc = 0 unpack = struct.unpack length = len(s) + if length % 4: extra = 4 - length % 4 s = b("\000") * extra + s length = length + extra + for i in range(0, length, 4): acc = (acc << 32) + unpack(">I", s[i : i + 4])[0] + return acc diff --git a/tests/helpers.py b/tests/helpers.py index f244a26..7fadfa7 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -20,27 +20,34 @@ """Helper utilities for testing""" +from asyncio import StreamReader, StreamWriter from io import BytesIO +from typing import AsyncIterable, List, Optional, TypeVar + +T = TypeVar("T") -class AsyncBuffer(BytesIO): +class AsyncBuffer(BytesIO, StreamReader, StreamWriter): # type: ignore[misc] """Just a BytesIO with an async read method.""" - async def read(self, n=None): # pylint: disable=invalid-overridden-method + async def read( # type: ignore[override] # pylint: disable=invalid-overridden-method + self, n: Optional[int] = None + ) -> Optional[bytes]: v = super().read(n) return v - readexactly = read + readexactly = read # type: ignore[assignment] - def append(self, data): + def append(self, data: bytes) -> None: """Append data to the buffer without changing the current position.""" + pos = self.tell() self.write(data) self.seek(pos) -async def async_comprehend(generator): +async def async_comprehend(generator: AsyncIterable[T]) -> List[T]: """Emulate ``[elem async for elem in generator]``.""" results = [] diff --git a/tests/test_boxstream.py b/tests/test_boxstream.py index a92d02a..a341881 100644 --- a/tests/test_boxstream.py +++ b/tests/test_boxstream.py @@ -38,8 +38,9 @@ MESSAGE_CLOSED = b"\xb1\x14hU'\xb5M\xa6\"\x03\x9duy\xa1\xd4evW,\xdcE\x18\xe4+ C4 @pytest.mark.asyncio -async def test_boxstream(): +async def test_boxstream() -> None: """Test stream boxing""" + buffer = AsyncBuffer() box_stream = BoxStream(buffer, CLIENT_ENCRYPT_KEY, CLIENT_ENCRYPT_NONCE) box_stream.write(b"foo") @@ -63,7 +64,7 @@ async def test_boxstream(): @pytest.mark.asyncio -async def test_unboxstream(): +async def test_unboxstream() -> None: """Test stream unboxing""" buffer = AsyncBuffer(MESSAGE_1 + MESSAGE_2 + MESSAGE_3 + MESSAGE_CLOSED) @@ -76,7 +77,7 @@ async def test_unboxstream(): @pytest.mark.asyncio -async def test_long_packets(): +async def test_long_packets() -> None: """Test for receiving long packets""" data_size = 6 * 1024 diff --git a/tests/test_crypto.py b/tests/test_crypto.py index 54d9cbf..cd8365e 100644 --- a/tests/test_crypto.py +++ b/tests/test_crypto.py @@ -36,7 +36,7 @@ CLIENT_EPH_KEY_SEED = b"u8\xd0\xe3\x85d_Pz\x0c\xf5\xfd\x15\xce2p#\xb0\xf0\x9f\xe @pytest.fixture -def server(): +def server() -> SHSServerCrypto: """A testing SHS server""" server_key = SigningKey(SERVER_KEY_SEED) @@ -46,7 +46,7 @@ def server(): @pytest.fixture -def client(): +def client() -> SHSClientCrypto: """A testing SHS client""" client_key = SigningKey(CLIENT_KEY_SEED) @@ -90,7 +90,7 @@ CLIENT_ENCRYPT_NONCE = b"S\\\x06\x8d\xe5\xeb&*\xb8\x0bp\xb3Z\x8e\\\x85\x14\xaa\x CLIENT_DECRYPT_NONCE = b"d\xe8\xccD\xec\xb9E\xbb\xaa\xa7\x7f\xe38\x15\x16\xef\xca\xd22u\x1d\xfe<\xe7" -def test_handshake(client, server): # pylint: disable=redefined-outer-name +def test_handshake(client: SHSClientCrypto, server: SHSServerCrypto) -> None: # pylint: disable=redefined-outer-name """Test the handshake procedure""" client_challenge = client.generate_challenge() diff --git a/tests/test_network.py b/tests/test_network.py index af209b9..38b86db 100644 --- a/tests/test_network.py +++ b/tests/test_network.py @@ -22,11 +22,14 @@ from asyncio import Event, wait_for import os +from typing import Any, Awaitable, Callable, Tuple from nacl.signing import SigningKey import pytest +from pytest_mock import MockerFixture from secret_handshake import SHSClient, SHSServer +from secret_handshake.boxstream import BoxStreamKeys from .helpers import AsyncBuffer @@ -34,41 +37,42 @@ from .helpers import AsyncBuffer class DummyCrypto: """Dummy crypto module, pretends everything is fine.""" - def verify_server_challenge(self, _): + def verify_server_challenge(self, _: bytes) -> bool: """Verify the server challenge""" return True - def verify_challenge(self, _): + def verify_challenge(self, _: bytes) -> bool: """Verify the challenge data""" return True - def verify_server_accept(self, _): + def verify_server_accept(self, _: bytes) -> bool: """Verify server’s accept message""" + return True - def generate_challenge(self): + def generate_challenge(self) -> bytes: """Generate authentication challenge""" return b"CHALLENGE" - def generate_client_auth(self): + def generate_client_auth(self) -> bytes: """Generate client authentication data""" return b"AUTH" - def verify_client_auth(self, _): + def verify_client_auth(self, _: bytes) -> bool: """Verify client authentication data""" return True - def generate_accept(self): + def generate_accept(self) -> bytes: """Generate an ACCEPT message""" return b"ACCEPT" - def get_box_keys(self): + def get_box_keys(self) -> BoxStreamKeys: """Get box keys""" return { @@ -76,48 +80,64 @@ class DummyCrypto: "encrypt_nonce": b"x" * 32, "decrypt_key": b"x" * 32, "decrypt_nonce": b"x" * 32, + "shared_secret": b"x" * 32, } - def clean(self): + def clean(self) -> None: """Clean up internal data""" -def _dummy_boxstream(stream, **_): - """Identity boxstream, no tansformation.""" +def _dummy_boxstream(stream: AsyncBuffer, **_: Any) -> AsyncBuffer: + """Identity boxstream, no transformation.""" + return stream -def _client_stream_mocker(): +def _client_stream_mocker() -> ( + Tuple[AsyncBuffer, AsyncBuffer, Callable[[str, int], Awaitable[Tuple[AsyncBuffer, AsyncBuffer]]]] +): reader = AsyncBuffer(b"xxx") writer = AsyncBuffer(b"xxx") - async def _create_mock_streams(host, port): # pylint: disable=unused-argument + async def _create_mock_streams( + host: str, port: int # pylint: disable=unused-argument + ) -> Tuple[AsyncBuffer, AsyncBuffer]: return reader, writer return reader, writer, _create_mock_streams -def _server_stream_mocker(): +def _server_stream_mocker() -> ( + Tuple[ + AsyncBuffer, + AsyncBuffer, + Callable[[Callable[[AsyncBuffer, AsyncBuffer], Awaitable[None]], str, int], Awaitable[None]], + ] +): reader = AsyncBuffer(b"xxx") writer = AsyncBuffer(b"xxx") - async def _create_mock_server(cb, host, port): # pylint: disable=unused-argument + async def _create_mock_server( + cb: Callable[[AsyncBuffer, AsyncBuffer], Awaitable[None]], + host: str, # pylint: disable=unused-argument + port: int, # pylint: disable=unused-argument + ) -> None: await cb(reader, writer) return reader, writer, _create_mock_server @pytest.mark.asyncio -async def test_client(mocker): +async def test_client(mocker: MockerFixture) -> None: """Test the client""" reader, _, _create_mock_streams = _client_stream_mocker() - mocker.patch("asyncio.open_connection", new=_create_mock_streams) + mocker.patch("secret_handshake.network.open_connection", new=_create_mock_streams) mocker.patch("secret_handshake.boxstream.BoxStream", new=_dummy_boxstream) mocker.patch("secret_handshake.boxstream.UnboxStream", new=_dummy_boxstream) client = SHSClient("shop.local", 1111, SigningKey.generate(), os.urandom(32)) - client.crypto = DummyCrypto() + client.crypto = DummyCrypto() # type: ignore[assignment] await client.open() reader.append(b"TEST") @@ -126,22 +146,22 @@ async def test_client(mocker): @pytest.mark.asyncio -async def test_server(mocker): +async def test_server(mocker: MockerFixture) -> None: """Test the server""" resolve = Event() - async def _on_connect(_): + async def _on_connect(_: Any) -> None: server.disconnect() resolve.set() _, _, _create_mock_server = _server_stream_mocker() - mocker.patch("asyncio.start_server", new=_create_mock_server) + mocker.patch("secret_handshake.network.start_server", new=_create_mock_server) mocker.patch("secret_handshake.boxstream.BoxStream", new=_dummy_boxstream) mocker.patch("secret_handshake.boxstream.UnboxStream", new=_dummy_boxstream) server = SHSServer("shop.local", 1111, SigningKey.generate(), os.urandom(32)) - server.crypto = DummyCrypto() + server.crypto = DummyCrypto() # type: ignore[assignment] server.on_connect(_on_connect)