From 1c1e57d86835da997f7f241c3997ad5de6cfde58 Mon Sep 17 00:00:00 2001 From: Gergely Polonkai Date: Wed, 1 Nov 2023 07:22:29 +0100 Subject: [PATCH] ci: Add and configure mypy, and make it happy --- .pre-commit-config.yaml | 7 +++ examples/test_client.py | 13 ++-- examples/test_server.py | 5 +- poetry.lock | 71 ++++++++++++++++++++- pyproject.toml | 4 ++ ssb/feed/models.py | 83 +++++++++++++----------- ssb/muxrpc.py | 108 +++++++++++++++++++++---------- ssb/packet_stream.py | 122 ++++++++++++++++++++++++------------ ssb/py.typed | 0 ssb/util.py | 15 +++-- tests/test_feed.py | 83 ++++++------------------ tests/test_packet_stream.py | 73 ++++++++++++--------- tests/test_util.py | 4 +- 13 files changed, 374 insertions(+), 214 deletions(-) create mode 100644 ssb/py.typed diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 258706a..a01b6cc 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -37,6 +37,13 @@ repos: language: system require_serial: true types_or: [python, pyi] + - id: mypy + name: mypy + entry: poetry run mypy + args: ["--strict"] + language: system + types_or: [python, pyi] + require_serial: true - id: reuse name: reuse entry: poetry run reuse diff --git a/examples/test_client.py b/examples/test_client.py index 536a301..cc9b1ac 100644 --- a/examples/test_client.py +++ b/examples/test_client.py @@ -30,9 +30,10 @@ import struct import time from colorlog import ColoredFormatter +from nacl.signing import SigningKey from secret_handshake.network import SHSClient -from ssb.muxrpc import MuxRPCAPI, MuxRPCAPIException +from ssb.muxrpc import MuxRPCAPI, MuxRPCAPIException, MuxRPCRequest from ssb.packet_stream import PacketStream, PSMessageType from ssb.util import load_ssb_secret @@ -40,7 +41,7 @@ api = MuxRPCAPI() @api.define("createHistoryStream") -def create_history_stream(connection, msg): # pylint: disable=unused-argument +def create_history_stream(connection: PacketStream, msg: MuxRPCRequest) -> None: # pylint: disable=unused-argument """Handle the createHistoryStream RPC call""" print("create_history_stream", msg) @@ -49,13 +50,13 @@ def create_history_stream(connection, msg): # pylint: disable=unused-argument @api.define("blobs.createWants") -def create_wants(connection, msg): # pylint: disable=unused-argument +def create_wants(connection: PacketStream, msg: MuxRPCRequest) -> None: # pylint: disable=unused-argument """Handle the createWants RPC call""" print("create_wants", msg) -async def test_client(): +async def test_client() -> None: """The actual client implementation""" async for msg in api.call( @@ -90,6 +91,8 @@ async def test_client(): img_data = b"" async for msg in api.call("blobs.get", ["&kqZ52sDcJSHOx7m4Ww80kK1KIZ65gpGnqwZlfaIVWWM=.sha256"], "source"): + assert msg + if msg.type.name == "BUFFER": img_data += msg.data if msg.type.name == "JSON" and msg.data == b"true": @@ -101,7 +104,7 @@ async def test_client(): f.write(img_data) -async def main(keypair): +async def main(keypair: SigningKey) -> None: """The main function to run""" client = SHSClient("127.0.0.1", 8008, keypair, bytes(keypair.verify_key)) diff --git a/examples/test_server.py b/examples/test_server.py index d1726e2..7854f51 100644 --- a/examples/test_server.py +++ b/examples/test_server.py @@ -27,6 +27,7 @@ import logging from colorlog import ColoredFormatter from secret_handshake import SHSServer +from secret_handshake.network import SHSDuplexStream from ssb.muxrpc import MuxRPCAPI from ssb.packet_stream import PacketStream @@ -35,7 +36,7 @@ from ssb.util import load_ssb_secret api = MuxRPCAPI() -async def on_connect(conn): +async def on_connect(conn: SHSDuplexStream) -> None: """Incoming connection handler""" packet_stream = PacketStream(conn) @@ -46,7 +47,7 @@ async def on_connect(conn): print(msg) -async def main(): +async def main() -> None: """The main function to run""" server = SHSServer("127.0.0.1", 8008, load_ssb_secret()["keypair"]) diff --git a/poetry.lock b/poetry.lock index 8c3db10..691b6bf 100644 --- a/poetry.lock +++ b/poetry.lock @@ -649,6 +649,53 @@ files = [ {file = "mccabe-0.7.0.tar.gz", hash = "sha256:348e0240c33b60bbdf4e523192ef919f28cb2c3d7d5c7794f74009290f236325"}, ] +[[package]] +name = "mypy" +version = "1.7.0" +description = "Optional static typing for Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "mypy-1.7.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:5da84d7bf257fd8f66b4f759a904fd2c5a765f70d8b52dde62b521972a0a2357"}, + {file = "mypy-1.7.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a3637c03f4025f6405737570d6cbfa4f1400eb3c649317634d273687a09ffc2f"}, + {file = "mypy-1.7.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b633f188fc5ae1b6edca39dae566974d7ef4e9aaaae00bc36efe1f855e5173ac"}, + {file = "mypy-1.7.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:d6ed9a3997b90c6f891138e3f83fb8f475c74db4ccaa942a1c7bf99e83a989a1"}, + {file = "mypy-1.7.0-cp310-cp310-win_amd64.whl", hash = "sha256:1fe46e96ae319df21359c8db77e1aecac8e5949da4773c0274c0ef3d8d1268a9"}, + {file = "mypy-1.7.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:df67fbeb666ee8828f675fee724cc2cbd2e4828cc3df56703e02fe6a421b7401"}, + {file = "mypy-1.7.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:a79cdc12a02eb526d808a32a934c6fe6df07b05f3573d210e41808020aed8b5d"}, + {file = "mypy-1.7.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f65f385a6f43211effe8c682e8ec3f55d79391f70a201575def73d08db68ead1"}, + {file = "mypy-1.7.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:0e81ffd120ee24959b449b647c4b2fbfcf8acf3465e082b8d58fd6c4c2b27e46"}, + {file = "mypy-1.7.0-cp311-cp311-win_amd64.whl", hash = "sha256:f29386804c3577c83d76520abf18cfcd7d68264c7e431c5907d250ab502658ee"}, + {file = "mypy-1.7.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:87c076c174e2c7ef8ab416c4e252d94c08cd4980a10967754f91571070bf5fbe"}, + {file = "mypy-1.7.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:6cb8d5f6d0fcd9e708bb190b224089e45902cacef6f6915481806b0c77f7786d"}, + {file = "mypy-1.7.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d93e76c2256aa50d9c82a88e2f569232e9862c9982095f6d54e13509f01222fc"}, + {file = "mypy-1.7.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:cddee95dea7990e2215576fae95f6b78a8c12f4c089d7e4367564704e99118d3"}, + {file = "mypy-1.7.0-cp312-cp312-win_amd64.whl", hash = "sha256:d01921dbd691c4061a3e2ecdbfbfad029410c5c2b1ee88946bf45c62c6c91210"}, + {file = "mypy-1.7.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:185cff9b9a7fec1f9f7d8352dff8a4c713b2e3eea9c6c4b5ff7f0edf46b91e41"}, + {file = "mypy-1.7.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:7a7b1e399c47b18feb6f8ad4a3eef3813e28c1e871ea7d4ea5d444b2ac03c418"}, + {file = "mypy-1.7.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fc9fe455ad58a20ec68599139ed1113b21f977b536a91b42bef3ffed5cce7391"}, + {file = "mypy-1.7.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:d0fa29919d2e720c8dbaf07d5578f93d7b313c3e9954c8ec05b6d83da592e5d9"}, + {file = "mypy-1.7.0-cp38-cp38-win_amd64.whl", hash = "sha256:2b53655a295c1ed1af9e96b462a736bf083adba7b314ae775563e3fb4e6795f5"}, + {file = "mypy-1.7.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:c1b06b4b109e342f7dccc9efda965fc3970a604db70f8560ddfdee7ef19afb05"}, + {file = "mypy-1.7.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:bf7a2f0a6907f231d5e41adba1a82d7d88cf1f61a70335889412dec99feeb0f8"}, + {file = "mypy-1.7.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:551d4a0cdcbd1d2cccdcc7cb516bb4ae888794929f5b040bb51aae1846062901"}, + {file = "mypy-1.7.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:55d28d7963bef00c330cb6461db80b0b72afe2f3c4e2963c99517cf06454e665"}, + {file = "mypy-1.7.0-cp39-cp39-win_amd64.whl", hash = "sha256:870bd1ffc8a5862e593185a4c169804f2744112b4a7c55b93eb50f48e7a77010"}, + {file = "mypy-1.7.0-py3-none-any.whl", hash = "sha256:96650d9a4c651bc2a4991cf46f100973f656d69edc7faf91844e87fe627f7e96"}, + {file = "mypy-1.7.0.tar.gz", hash = "sha256:1e280b5697202efa698372d2f39e9a6713a0395a756b1c6bd48995f8d72690dc"}, +] + +[package.dependencies] +mypy-extensions = ">=1.0.0" +tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""} +typing-extensions = ">=4.1.0" + +[package.extras] +dmypy = ["psutil (>=4.0)"] +install-types = ["pip"] +mypyc = ["setuptools (>=50)"] +reports = ["lxml"] + [[package]] name = "mypy-extensions" version = "1.0.0" @@ -1253,6 +1300,28 @@ files = [ {file = "tomlkit-0.12.1.tar.gz", hash = "sha256:38e1ff8edb991273ec9f6181244a6a391ac30e9f5098e7535640ea6be97a7c86"}, ] +[[package]] +name = "types-pyyaml" +version = "6.0.12.12" +description = "Typing stubs for PyYAML" +optional = false +python-versions = "*" +files = [ + {file = "types-PyYAML-6.0.12.12.tar.gz", hash = "sha256:334373d392fde0fdf95af5c3f1661885fa10c52167b14593eb856289e1855062"}, + {file = "types_PyYAML-6.0.12.12-py3-none-any.whl", hash = "sha256:c05bc6c158facb0676674b7f11fe3960db4f389718e19e62bd2b84d6205cfd24"}, +] + +[[package]] +name = "types-simplejson" +version = "3.19.0.2" +description = "Typing stubs for simplejson" +optional = false +python-versions = "*" +files = [ + {file = "types-simplejson-3.19.0.2.tar.gz", hash = "sha256:ebc81f886f89d99d6b80c726518aa2228bc77c26438f18fd81455e4f79f8ee1b"}, + {file = "types_simplejson-3.19.0.2-py3-none-any.whl", hash = "sha256:8ba093dc7884f59b3e62aed217144085e675a269debc32678fd80e0b43b2b86f"}, +] + [[package]] name = "typing-extensions" version = "4.8.0" @@ -1310,4 +1379,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "d57dc0c074d7daf70507fda1fc9641cf367b6dc8f02b34a5fceafe6b45c0f4f9" +content-hash = "98384046072d2dd4f649a93231ee6a84e5b21be34f15d5d2196cd3832f15ebca" diff --git a/pyproject.toml b/pyproject.toml index 553054b..2da6b4a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,6 +8,7 @@ description = "Secure Scuttlebutt library in Python" authors = ["PyScuttleButt Contributors "] license = "MIT" readme = "README.rst" +include = ["ssb/py.typed"] [tool.poetry.dependencies] python = "^3.9" @@ -23,12 +24,15 @@ check-manifest = "^0.39" commitizen = "^3.12.0" coverage = "^7.3.2" isort = "^5.12.0" +mypy = "^1.6.1" pep257 = "^0.7.0" pylint = "^3.0.2" pytest = "^7.4.3" pytest-asyncio = "^0.21.1" pytest-cov = "^4.1.0" pytest-mock = "^3.12.0" +types-pyyaml = "^6.0.12.12" +types-simplejson = "^3.19.0.2" [tool.poetry.group.docs.dependencies] Sphinx = "^2.1.1" diff --git a/ssb/feed/models.py b/ssb/feed/models.py index 8a2fba0..5b9a9de 100644 --- a/ssb/feed/models.py +++ b/ssb/feed/models.py @@ -26,8 +26,11 @@ from base64 import b64encode from collections import OrderedDict, namedtuple from datetime import datetime from hashlib import sha256 +from typing import Any, Dict, Optional +from nacl.signing import SigningKey, VerifyKey from simplejson import dumps, loads +from typing_extensions import Self from ssb.util import tag @@ -38,7 +41,7 @@ class NoPrivateKeyException(Exception): """Exception to raise when a private key is not available""" -def to_ordered(data): +def to_ordered(data: Dict[str, Any]) -> OrderedDict[str, Any]: """Convert a dictionary to an ``OrderedDict``""" smsg = OrderedMsg(**data) @@ -46,7 +49,7 @@ def to_ordered(data): return OrderedDict((k, getattr(smsg, k)) for k in smsg._fields) -def get_millis_1970(): +def get_millis_1970() -> int: """Get the UNIX timestamp in milliseconds""" return int(datetime.utcnow().timestamp() * 1000) @@ -55,16 +58,16 @@ def get_millis_1970(): class Feed: """Base class for feeds""" - def __init__(self, public_key): + def __init__(self, public_key: VerifyKey) -> None: self.public_key = public_key @property - def id(self): + def id(self) -> str: """The identifier of the feed""" return tag(self.public_key).decode("ascii") - def sign(self, msg): + def sign(self, msg: bytes) -> bytes: """Sign a message""" raise NoPrivateKeyException("Cannot use remote identity to sign (no private key!)") @@ -73,16 +76,20 @@ class Feed: class LocalFeed(Feed): """Class representing a local feed""" - def __init__(self, private_key): # pylint: disable=super-init-not-called + def __init__(self, private_key: SigningKey) -> None: # pylint: disable=super-init-not-called self.private_key = private_key @property - def public_key(self): + def public_key(self) -> VerifyKey: """The public key of the feed""" return self.private_key.verify_key - def sign(self, msg): + @public_key.setter + def public_key(self, key: VerifyKey) -> None: + raise TypeError("Can not set only the public key for a local feed") + + def sign(self, msg: bytes) -> bytes: """Sign a message for this feed""" return self.private_key.sign(msg).signature @@ -92,25 +99,34 @@ class Message: """Base class for SSB messages""" def __init__( # pylint: disable=too-many-arguments - self, feed, content, signature=None, sequence=1, timestamp=None, previous=None + self, + feed: Feed, + content: Dict[str, Any], + signature: Optional[str] = None, + sequence: int = 1, + timestamp: Optional[int] = None, + previous: Optional["Message"] = None, ): self.feed = feed self.content = content - if signature is None: - raise ValueError("signature can't be None") self.signature = signature - self.previous = previous + self.timestamp = get_millis_1970() if timestamp is None else timestamp + if self.previous: - self.sequence = self.previous.sequence + 1 + self.sequence: int = self.previous.sequence + 1 else: self.sequence = sequence - self.timestamp = get_millis_1970() if timestamp is None else timestamp + self._check_signature() + + def _check_signature(self) -> None: + if self.signature is None: + raise ValueError("signature can't be None") @classmethod - def parse(cls, data, feed): + def parse(cls, data: bytes, feed: Feed) -> Self: """Parse raw message data""" obj = loads(data, object_pairs_hook=OrderedDict) @@ -118,12 +134,12 @@ class Message: return msg - def serialize(self, add_signature=True): + def serialize(self, add_signature: bool = True) -> bytes: """Serialize the message""" return dumps(self.to_dict(add_signature=add_signature), indent=2).encode("utf-8") - def to_dict(self, add_signature=True): + def to_dict(self, add_signature: bool = True) -> OrderedDict[str, Any]: """Convert the message to a dictionary""" obj = to_ordered( @@ -142,20 +158,21 @@ class Message: return obj - def verify(self, signature): + def verify(self, signature: str) -> bool: """Verify the signature of the message""" return self.signature == signature @property - def hash(self): + def hash(self) -> str: """The cryptographic hash of the message""" hash_ = sha256(self.serialize()).digest() + return b64encode(hash_).decode("ascii") + ".sha256" @property - def key(self): + def key(self) -> str: """The key of the message""" return "%" + self.hash @@ -165,25 +182,21 @@ class LocalMessage(Message): """Class representing a local message""" def __init__( # pylint: disable=too-many-arguments,super-init-not-called - self, feed, content, signature=None, sequence=1, timestamp=None, previous=None + self, + feed: LocalFeed, + content: Dict[str, Any], + signature: Optional[str] = None, + sequence: int = 1, + timestamp: Optional[int] = None, + previous: Optional["LocalMessage"] = None, ): - self.feed = feed - self.content = content + super().__init__(feed, content, signature=signature, sequence=sequence, timestamp=timestamp, previous=previous) - self.previous = previous - if self.previous: - self.sequence = self.previous.sequence + 1 - else: - self.sequence = sequence - - self.timestamp = get_millis_1970() if timestamp is None else timestamp - - if signature is None: + def _check_signature(self) -> None: + if self.signature is None: self.signature = self._sign() - else: - self.signature = signature - def _sign(self): + def _sign(self) -> str: # ensure ordering of keys and indentation of 2 characters, like ssb-keys data = self.serialize(add_signature=False) return (b64encode(bytes(self.feed.sign(data))) + b".sig.ed25519").decode("ascii") diff --git a/ssb/muxrpc.py b/ssb/muxrpc.py index a059196..46d23cd 100644 --- a/ssb/muxrpc.py +++ b/ssb/muxrpc.py @@ -22,7 +22,16 @@ """MuxRPC""" -from ssb.packet_stream import PSMessageType +from typing import Any, AsyncIterator, Callable, Dict, List, Literal, Optional, Union + +from typing_extensions import Self + +from .packet_stream import PacketStream, PSMessage, PSMessageType, PSRequestHandler, PSStreamHandler + +MuxRPCJSON = Dict[str, Any] +MuxRPCCallType = Literal["async", "duplex", "sink", "source", "sync"] +MuxRPCRequestHandlerType = Callable[[PacketStream, "MuxRPCRequest"], None] +MuxRPCRequestParam = Union[bytes, str, MuxRPCJSON] # pylint: disable=invalid-name class MuxRPCAPIException(Exception): @@ -32,7 +41,7 @@ class MuxRPCAPIException(Exception): class MuxRPCHandler: # pylint: disable=too-few-public-methods """Base MuxRPC handler class""" - def check_message(self, msg): + def check_message(self, msg: PSMessage) -> None: """Check message validity""" body = msg.body @@ -40,34 +49,53 @@ class MuxRPCHandler: # pylint: disable=too-few-public-methods if isinstance(body, dict) and "name" in body and body["name"] == "Error": raise MuxRPCAPIException(body["message"]) + def __aiter__(self) -> AsyncIterator[Optional[PSMessage]]: + raise NotImplementedError() -class MuxRPCRequestHandler(MuxRPCHandler): + async def __anext__(self) -> Optional[PSMessage]: + raise NotImplementedError() + + def send(self, msg: Any, msg_type: PSMessageType = PSMessageType.JSON, end: bool = False) -> None: + """Send a message through the stream""" + + raise NotImplementedError() + + async def get_response(self) -> PSMessage: + """Get the response for an RPC request""" + + raise NotImplementedError() + + +class MuxRPCRequestHandler(MuxRPCHandler): # pylint: disable=abstract-method """MuxRPC handler for incoming RPC requests""" - def __init__(self, ps_handler): + def __init__(self, ps_handler: PSRequestHandler): self.ps_handler = ps_handler - async def get_response(self): + async def get_response(self) -> PSMessage: """Get the response data""" - msg = await self.ps_handler + msg = await self.ps_handler.__anext__() self.check_message(msg) return msg -class MuxRPCSourceHandler(MuxRPCHandler): +class MuxRPCSourceHandler(MuxRPCHandler): # pylint: disable=abstract-method """MuxRPC handler for source-type RPC requests""" - def __init__(self, ps_handler): + def __init__(self, ps_handler: PSStreamHandler): self.ps_handler = ps_handler - def __aiter__(self): + def __aiter__(self) -> AsyncIterator[Optional[PSMessage]]: return self - async def __anext__(self): + async def __anext__(self) -> Optional[PSMessage]: msg = await self.ps_handler.__anext__() + + assert msg + self.check_message(msg) return msg @@ -76,64 +104,74 @@ class MuxRPCSourceHandler(MuxRPCHandler): class MuxRPCSinkHandlerMixin: # pylint: disable=too-few-public-methods """Mixin for sink-type MuxRPC handlers""" - def send(self, msg, msg_type=PSMessageType.JSON, end=False): + connection: PacketStream + req: int + + def send(self, msg: Any, msg_type: PSMessageType = PSMessageType.JSON, end: bool = False) -> None: """Send a message through the stream""" self.connection.send(msg, stream=True, msg_type=msg_type, req=self.req, end_err=end) -class MuxRPCDuplexHandler(MuxRPCSinkHandlerMixin, MuxRPCSourceHandler): +class MuxRPCDuplexHandler(MuxRPCSinkHandlerMixin, MuxRPCSourceHandler): # pylint: disable=abstract-method """MuxRPC handler for duplex streams""" - def __init__(self, ps_handler, connection, req): + def __init__(self, ps_handler: PSStreamHandler, connection: PacketStream, req: int): super().__init__(ps_handler) self.connection = connection self.req = req -class MuxRPCSinkHandler(MuxRPCHandler, MuxRPCSinkHandlerMixin): +class MuxRPCSinkHandler(MuxRPCHandler, MuxRPCSinkHandlerMixin): # pylint: disable=abstract-method """MuxRPC handler for sinks""" - def __init__(self, connection, req): + def __init__(self, connection: PacketStream, req: int): self.connection = connection self.req = req -def _get_appropriate_api_handler(type_, connection, ps_handler, req): +def _get_appropriate_api_handler( + type_: MuxRPCCallType, connection: PacketStream, ps_handler: Union[PSRequestHandler, PSStreamHandler], req: int +) -> MuxRPCHandler: """Find the appropriate MuxRPC handler""" if type_ in {"sync", "async"}: + assert isinstance(ps_handler, PSRequestHandler) return MuxRPCRequestHandler(ps_handler) if type_ == "source": + assert isinstance(ps_handler, PSStreamHandler) return MuxRPCSourceHandler(ps_handler) if type_ == "sink": return MuxRPCSinkHandler(connection, req) if type_ == "duplex": + assert isinstance(ps_handler, PSStreamHandler) return MuxRPCDuplexHandler(ps_handler, connection, req) - return None + raise TypeError(f"Unknown request type {type_}") class MuxRPCRequest: """MuxRPC request""" @classmethod - def from_message(cls, message): + def from_message(cls, message: PSMessage) -> Self: """Initialise a request from a raw packet stream message""" body = message.body + assert isinstance(body, dict) + return cls(".".join(body["name"]), body["args"]) - def __init__(self, name, args): + def __init__(self, name: str, args: List[MuxRPCRequestParam]): self.name = name self.args = args - def __repr__(self): + def __repr__(self) -> str: return f"" @@ -141,28 +179,30 @@ class MuxRPCMessage: """MuxRPC message""" @classmethod - def from_message(cls, message): + def from_message(cls, message: PSMessage) -> Self: """Initialise a MuxRPC message from a raw packet stream message""" return cls(message.body) - def __init__(self, body): + def __init__(self, body: Union[bytes, str, Dict[str, Any]]): self.body = body - def __repr__(self): - return f"" + def __repr__(self) -> str: + return f"" class MuxRPCAPI: """Generic MuxRPC API""" - def __init__(self): - self.handlers = {} - self.connection = None + def __init__(self) -> None: + self.handlers: Dict[str, MuxRPCRequestHandlerType] = {} + self.connection: Optional[PacketStream] = None - async def process_messages(self): + async def process_messages(self) -> None: """Continuously process incoming messages""" + assert self.connection + async for req_message in self.connection: if req_message is None: return @@ -172,22 +212,22 @@ class MuxRPCAPI: if isinstance(body, dict) and body.get("name"): self.process(self.connection, MuxRPCRequest.from_message(req_message)) - def add_connection(self, connection): + def add_connection(self, connection: PacketStream) -> None: """Set the packet stream connection of this RPC API""" self.connection = connection - def define(self, name): + def define(self, name: str) -> Callable[[MuxRPCRequestHandlerType], MuxRPCRequestHandlerType]: """Decorator to define an RPC method handler""" - def _handle(f): + def _handle(f: MuxRPCRequestHandlerType) -> MuxRPCRequestHandlerType: self.handlers[name] = f return f return _handle - def process(self, connection, request): + def process(self, connection: PacketStream, request: MuxRPCRequest) -> None: """Process an incoming request""" handler = self.handlers.get(request.name) @@ -197,9 +237,11 @@ class MuxRPCAPI: handler(connection, request) - def call(self, name, args, type_="sync"): + def call(self, name: str, args: List[MuxRPCRequestParam], type_: MuxRPCCallType = "sync") -> MuxRPCHandler: """Call an RPC method""" + assert self.connection + if not self.connection.is_connected: raise Exception("not connected") # pylint: disable=broad-exception-raised diff --git a/ssb/packet_stream.py b/ssb/packet_stream.py index f61941e..7055ab0 100644 --- a/ssb/packet_stream.py +++ b/ssb/packet_stream.py @@ -28,9 +28,13 @@ import logging from math import ceil import struct from time import time +from typing import Any, AsyncIterator, Dict, Optional, Tuple, Union +from secret_handshake.network import SHSDuplexStream import simplejson +from typing_extensions import Self +PSMessageData = Union[bytes, bool, Dict[str, Any], str] logger = logging.getLogger("packet_stream") @@ -45,25 +49,27 @@ class PSMessageType(Enum): class PSStreamHandler: """Packet stream handler""" - def __init__(self, req): - super(PSStreamHandler).__init__() + def __init__(self, req: int): + super().__init__() self.req = req - self.queue = Queue() + self.queue: Queue["PSMessage"] = Queue() - async def process(self, msg): + async def process(self, msg: "PSMessage") -> None: """Process a pending message""" await self.queue.put(msg) - async def stop(self): + async def stop(self) -> None: """Stop a pending request""" - await self.queue.put(None) + # We use the None value internally to signal __anext__ that the stream can be closed. It is not used otherwise, + # hence the typing ignore + await self.queue.put(None) # type: ignore[arg-type] - def __aiter__(self): + def __aiter__(self) -> AsyncIterator[Optional["PSMessage"]]: return self - async def __anext__(self): + async def __anext__(self) -> Optional["PSMessage"]: elem = await self.queue.get() if not elem: @@ -75,30 +81,32 @@ class PSStreamHandler: class PSRequestHandler: """Packet stream request handler""" - def __init__(self, req): + def __init__(self, req: int): self.req = req self.event = Event() - self._msg = None + self._msg: Optional["PSMessage"] = None - async def process(self, msg): + async def process(self, msg: "PSMessage") -> None: """Process a message request""" self._msg = msg self.event.set() - async def stop(self): + async def stop(self) -> None: """Stop a pending event request""" if not self.event.is_set(): self.event.set() - def __aiter__(self): + def __aiter__(self) -> AsyncIterator["PSMessage"]: return self - async def __anext__(self): + async def __anext__(self) -> "PSMessage": # wait until 'process' is called await self.event.wait() + assert self._msg + return self._msg @@ -106,42 +114,55 @@ class PSMessage: """Packet Stream message""" @classmethod - def from_header_body(cls, flags, req, body): + def from_header_body(cls, flags: int, req: int, body: bytes) -> Self: """Parse a raw message""" type_ = PSMessageType(flags & 0x03) if type_ == PSMessageType.TEXT: - body = body.decode("utf-8") + decoded_body: Union[str, Dict[str, Any], bytes] = body.decode("utf-8") elif type_ == PSMessageType.JSON: - body = simplejson.loads(body) + decoded_body = simplejson.loads(body) + else: + decoded_body = body - return cls(type_, body, bool(flags & 0x08), bool(flags & 0x04), req=req) + return cls(type_, decoded_body, bool(flags & 0x08), bool(flags & 0x04), req=req) @property - def data(self): + def data(self) -> bytes: """The raw message data""" if self.type == PSMessageType.TEXT: + assert isinstance(self.body, str) return self.body.encode("utf-8") if self.type == PSMessageType.JSON: + assert isinstance(self.body, dict) return simplejson.dumps(self.body).encode("utf-8") + assert isinstance(self.body, bytes) + return self.body - def __init__(self, type_, body, stream, end_err, req=None): # pylint: disable=too-many-arguments + def __init__( + self, + type_: PSMessageType, + body: Union[bytes, str, Dict[str, Any]], + stream: bool, + end_err: bool, + req: Optional[int] = None, + ): # pylint: disable=too-many-arguments self.stream = stream self.end_err = end_err self.type = type_ self.body = body self.req = req - def __repr__(self): + def __repr__(self) -> str: if self.type == PSMessageType.BUFFER: body = f"{len(self.body)} bytes" else: - body = self.body + body = str(self.body) req = "" if self.req is None else f" [{self.req}]" is_stream = "~" if self.stream else "" @@ -153,79 +174,90 @@ class PSMessage: class PacketStream: """SSB Packet stream""" - def __init__(self, connection): + def __init__(self, connection: SHSDuplexStream): self.connection = connection self.req_counter = 1 - self._event_map = {} + self._event_map: Dict[int, Tuple[float, Union[PSRequestHandler, PSStreamHandler]]] = {} self._connected = False - def register_handler(self, handler): + def register_handler(self, handler: Union[PSRequestHandler, PSStreamHandler]) -> None: """Register an RPC handler""" self._event_map[handler.req] = (time(), handler) @property - def is_connected(self): + def is_connected(self) -> bool: """Check if the stream is connected""" return self.connection.is_connected - def __aiter__(self): + def __aiter__(self) -> AsyncIterator[Optional[PSMessage]]: return self - async def __anext__(self): + async def __anext__(self) -> PSMessage: while True: msg = await self.read() if not msg: raise StopAsyncIteration() - if msg.req >= 0: + if msg.req is not None and msg.req >= 0: logger.info("RECV: %r", msg) return msg - async def _read(self): + async def _read(self) -> Optional[PSMessage]: try: header = await self.connection.read() if not header or header == b"\x00" * 9: - return + return None flags, length, req = struct.unpack(">BIi", header) - n_packets = ceil(length / 4096) - body = b"" for _ in range(n_packets): - body += await self.connection.read() + read_data = await self.connection.read() + + if not read_data: + logger.debug("DISCONNECT") + self.connection.disconnect() + + return None + + body += read_data logger.debug("READ %s %s", header, len(body)) + return PSMessage.from_header_body(flags, req, body) except StopAsyncIteration: logger.debug("DISCONNECT") self.connection.disconnect() return None - async def read(self): + async def read(self) -> Optional[PSMessage]: """Read data from the packet stream""" msg = await self._read() + if not msg: return None + # check whether it's a reply and handle accordingly - if msg.req < 0: + if msg.req is not None and msg.req < 0: _, handler = self._event_map[-msg.req] await handler.process(msg) logger.info("RESPONSE [%d]: %r", -msg.req, msg) + if msg.end_err: await handler.stop() del self._event_map[-msg.req] logger.info("RESPONSE [%d]: EOS", -msg.req) + return msg - def _write(self, msg): + def _write(self, msg: PSMessage) -> None: logger.info("SEND [%d]: %r", msg.req, msg) header = struct.pack( ">BIi", @@ -239,11 +271,17 @@ class PacketStream: logger.debug("WRITE DATA: %s", msg.data) def send( # pylint: disable=too-many-arguments - self, data, msg_type=PSMessageType.JSON, stream=False, end_err=False, req=None - ): + self, + data: Union[bytes, str, Dict[str, Any]], + msg_type: PSMessageType = PSMessageType.JSON, + stream: bool = False, + end_err: bool = False, + req: Optional[int] = None, + ) -> Union[PSRequestHandler, PSStreamHandler]: """Send data through the packet stream""" update_counter = False + if req is None: update_counter = True req = self.req_counter @@ -254,16 +292,18 @@ class PacketStream: self._write(msg) if stream: - handler = PSStreamHandler(self.req_counter) + handler: Union[PSRequestHandler, PSStreamHandler] = PSStreamHandler(self.req_counter) else: handler = PSRequestHandler(self.req_counter) + self.register_handler(handler) if update_counter: self.req_counter += 1 + return handler - def disconnect(self): + def disconnect(self) -> None: """Disconnect the stream""" self._connected = False diff --git a/ssb/py.typed b/ssb/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/ssb/util.py b/ssb/util.py index 09ad4e5..c0460f3 100644 --- a/ssb/util.py +++ b/ssb/util.py @@ -24,23 +24,30 @@ from base64 import b64decode, b64encode import os -from typing import Optional +from typing import Optional, TypedDict -from nacl.signing import SigningKey +from nacl.signing import SigningKey, VerifyKey import yaml +class SSBSecret(TypedDict): + """Dictionary to hold an SSB identity""" + + keypair: SigningKey + id: str + + class ConfigException(Exception): """Exception to raise if there is a problem with the configuration data""" -def tag(key): +def tag(key: VerifyKey) -> bytes: """Create tag from public key""" return b"@" + b64encode(bytes(key)) + b".ed25519" -def load_ssb_secret(filename: Optional[str] = None): +def load_ssb_secret(filename: Optional[str] = None) -> SSBSecret: """Load SSB keys from ``filename`` or, if unset, from ``~/.ssb/secret``""" filename = filename or os.path.expanduser("~/.ssb/secret") diff --git a/tests/test_feed.py b/tests/test_feed.py index e76b2c1..9685b7d 100644 --- a/tests/test_feed.py +++ b/tests/test_feed.py @@ -50,7 +50,7 @@ SERIALIZED_M1 = b"""{ @pytest.fixture() -def local_feed(): +def local_feed() -> LocalFeed: """Fixture providing a local feed""" secret = b64decode("Mz2qkNOP2K6upnqibWrR+z8pVUI1ReA1MLc7QMtF2qQ=") @@ -58,14 +58,14 @@ def local_feed(): @pytest.fixture() -def remote_feed(): +def remote_feed() -> Feed: """Fixture providing a remote feed""" public = b64decode("I/4cyN/jPBbDsikbHzAEvmaYlaJK33lW3UhWjNXjyrU=") return Feed(VerifyKey(public)) -def test_local_feed(): +def test_local_feed() -> None: """Test a local feed""" secret = b64decode("Mz2qkNOP2K6upnqibWrR+z8pVUI1ReA1MLc7QMtF2qQ=") @@ -75,7 +75,7 @@ def test_local_feed(): assert feed.id == "@I/4cyN/jPBbDsikbHzAEvmaYlaJK33lW3UhWjNXjyrU=.ed25519" -def test_remote_feed(): +def test_remote_feed() -> None: """Test a remote feed""" public = b64decode("I/4cyN/jPBbDsikbHzAEvmaYlaJK33lW3UhWjNXjyrU=") @@ -85,35 +85,21 @@ def test_remote_feed(): m1 = Message( feed, - OrderedDict( - [ - ("type", "about"), - ("about", feed.id), - ("name", "neo"), - ("description", "The Chosen One"), - ] - ), + OrderedDict([("type", "about"), ("about", feed.id), ("name", "neo"), ("description", "The Chosen One")]), "foo", timestamp=1495706260190, ) with pytest.raises(NoPrivateKeyException): - feed.sign(m1) + feed.sign(m1.serialize()) -def test_local_message(local_feed): # pylint: disable=redefined-outer-name +def test_local_message(local_feed: LocalFeed) -> None: # pylint: disable=redefined-outer-name """Test a local message""" m1 = LocalMessage( local_feed, - OrderedDict( - [ - ("type", "about"), - ("about", local_feed.id), - ("name", "neo"), - ("description", "The Chosen One"), - ] - ), + OrderedDict([("type", "about"), ("about", local_feed.id), ("name", "neo"), ("description", "The Chosen One")]), timestamp=1495706260190, ) assert m1.timestamp == 1495706260190 @@ -148,20 +134,13 @@ def test_local_message(local_feed): # pylint: disable=redefined-outer-name assert m2.key == "%nx13uks5GUwuKJC49PfYGMS/1pgGTtwwdWT7kbVaroM=.sha256" -def test_remote_message(remote_feed): # pylint: disable=redefined-outer-name +def test_remote_message(remote_feed: Feed) -> None: # pylint: disable=redefined-outer-name """Test a remote message""" signature = "lPsQ9P10OgeyH6u0unFgiI2wV/RQ7Q2x2ebxnXYCzsJ055TBMXphRADTKhOMS2EkUxXQ9k3amj5fnWPudGxwBQ==.sig.ed25519" m1 = Message( remote_feed, - OrderedDict( - [ - ("type", "about"), - ("about", remote_feed.id), - ("name", "neo"), - ("description", "The Chosen One"), - ] - ), + OrderedDict([("type", "about"), ("about", remote_feed.id), ("name", "neo"), ("description", "The Chosen One")]), signature, timestamp=1495706260190, ) @@ -175,12 +154,7 @@ def test_remote_message(remote_feed): # pylint: disable=redefined-outer-name m2 = Message( remote_feed, OrderedDict( - [ - ("type", "about"), - ("about", remote_feed.id), - ("name", "morpheus"), - ("description", "Dude with big jaw"), - ] + [("type", "about"), ("about", remote_feed.id), ("name", "morpheus"), ("description", "Dude with big jaw")] ), signature, previous=m1, @@ -194,54 +168,37 @@ def test_remote_message(remote_feed): # pylint: disable=redefined-outer-name assert m2.key == "%nx13uks5GUwuKJC49PfYGMS/1pgGTtwwdWT7kbVaroM=.sha256" -def test_remote_no_signature(remote_feed): # pylint: disable=redefined-outer-name +def test_remote_no_signature(remote_feed: Feed) -> None: # pylint: disable=redefined-outer-name """Test remote feed without a signature""" with pytest.raises(ValueError): Message( remote_feed, OrderedDict( - [ - ("type", "about"), - ("about", remote_feed.id), - ("name", "neo"), - ("description", "The Chosen One"), - ] + [("type", "about"), ("about", remote_feed.id), ("name", "neo"), ("description", "The Chosen One")] ), None, timestamp=1495706260190, ) -def test_serialize(local_feed): # pylint: disable=redefined-outer-name +def test_serialize(local_feed: LocalFeed) -> None: # pylint: disable=redefined-outer-name """Test feed serialization""" m1 = LocalMessage( local_feed, - OrderedDict( - [ - ("type", "about"), - ("about", local_feed.id), - ("name", "neo"), - ("description", "The Chosen One"), - ] - ), + OrderedDict([("type", "about"), ("about", local_feed.id), ("name", "neo"), ("description", "The Chosen One")]), timestamp=1495706260190, ) assert m1.serialize() == SERIALIZED_M1 -def test_parse(local_feed): # pylint: disable=redefined-outer-name +def test_parse(local_feed: LocalFeed) -> None: # pylint: disable=redefined-outer-name """Test feed parsing""" m1 = LocalMessage.parse(SERIALIZED_M1, local_feed) - assert m1.content == { - "type": "about", - "about": local_feed.id, - "name": "neo", - "description": "The Chosen One", - } + assert m1.content == {"type": "about", "about": local_feed.id, "name": "neo", "description": "The Chosen One"} assert m1.timestamp == 1495706260190 @@ -252,15 +209,15 @@ def test_local_unsigned(local_feed: LocalFeed, mocker: MockerFixture) -> None: mocked_dt.utcnow = mocker.MagicMock(return_value=datetime(2023, 3, 7, 11, 45, 54, 0, tzinfo=timezone.utc)) mocker.patch("ssb.feed.models.datetime", mocked_dt) - msg = LocalMessage(local_feed, b"test") + msg = LocalMessage(local_feed, OrderedDict({"test": True})) assert msg.feed == local_feed - assert msg.content == b"test" + assert msg.content == {"test": True} assert msg.sequence == 1 assert msg.previous is None assert msg.timestamp == 1678189554000 assert msg.signature == ( - "SxZsBINzsuQqmB6JLmXyr22+FRY33bp3wj1MwjAOU3MqifGqfc3W/2T5D4qel5mqrgJt9IT8c3QayB1suj82AQ==.sig.ed25519" + "WjkA5rjzsYDHqeavEPcbNAbRMp5NRFDBNATMWgcsccso8sfwhaWnIEvQW79fA5YgKKybzlIsCMWHherToEI2DA==.sig.ed25519" ) diff --git a/tests/test_packet_stream.py b/tests/test_packet_stream.py index 215c82c..a972718 100644 --- a/tests/test_packet_stream.py +++ b/tests/test_packet_stream.py @@ -23,18 +23,23 @@ """Tests for the packet stream""" from asyncio import Event, ensure_future, gather +from asyncio.events import AbstractEventLoop import json +from typing import Any, AsyncIterator, Awaitable, Callable, Generator, List, Optional import pytest +from pytest_mock import MockerFixture from secret_handshake.network import SHSDuplexStream -from ssb.packet_stream import PacketStream, PSMessageType +from ssb.packet_stream import PacketStream, PSMessage, PSMessageType -async def _collect_messages(generator): +async def _collect_messages(generator: AsyncIterator[Optional[PSMessage]]) -> List[Optional["PSMessage"]]: results = [] + async for msg in generator: results.append(msg) + return results @@ -61,45 +66,47 @@ MSG_BODY_2 = ( class MockSHSSocket(SHSDuplexStream): """A mocked SHS socket""" - def __init__(self, *args, **kwargs): # pylint: disable=unused-argument + def __init__(self, *args: Any, **kwargs: Any): # pylint: disable=unused-argument super().__init__() - self.input = [] - self.output = [] - self.is_connected = False - self._on_connect = [] + self.input: List[bytes] = [] + self.output: List[bytes] = [] + self.is_connected: bool = False + self._on_connect: List[Callable[[], Awaitable[None]]] = [] - def on_connect(self, cb): + def on_connect(self, cb: Callable[[], Awaitable[None]]) -> None: """Set the on_connect callback""" self._on_connect.append(cb) - async def read(self): + async def read(self) -> Optional[bytes]: """Read data from the socket""" if not self.input: - raise StopAsyncIteration + return None + return self.input.pop(0) - def write(self, data): + def write(self, data: bytes) -> None: """Write data to the socket""" self.output.append(data) - def feed(self, input_): - """Get the connection’s feed""" + def feed(self, input_: List[bytes]) -> None: + """Feed data into the connection""" self.input += input_ - def get_output(self): + def get_output(self) -> Generator[bytes, None, None]: """Get the output of a call""" while True: if not self.output: break + yield self.output.pop(0) - def disconnect(self): + def disconnect(self) -> None: """Disconnect from the remote party""" self.is_connected = False @@ -108,7 +115,7 @@ class MockSHSSocket(SHSDuplexStream): class MockSHSClient(MockSHSSocket): """A mocked SHS client""" - async def connect(self): + async def connect(self) -> None: """Connect to a SHS server""" self.is_connected = True @@ -120,7 +127,7 @@ class MockSHSClient(MockSHSSocket): class MockSHSServer(MockSHSSocket): """A mocked SHS server""" - def listen(self): + def listen(self) -> None: """Listen for new connections""" self.is_connected = True @@ -130,26 +137,26 @@ class MockSHSServer(MockSHSSocket): @pytest.fixture -def ps_client(event_loop): # pylint: disable=unused-argument +def ps_client(event_loop: AbstractEventLoop) -> MockSHSClient: # pylint: disable=unused-argument """Fixture to provide a mocked SHS client""" return MockSHSClient() @pytest.fixture -def ps_server(event_loop): # pylint: disable=unused-argument +def ps_server(event_loop: AbstractEventLoop) -> MockSHSServer: # pylint: disable=unused-argument """Fixture to provide a mocked SHS server""" return MockSHSServer() @pytest.mark.asyncio -async def test_on_connect(ps_server): # pylint: disable=redefined-outer-name +async def test_on_connect(ps_server: MockSHSServer) -> None: # pylint: disable=redefined-outer-name """Test the on_connect callback functionality""" called = Event() - async def _on_connect(): + async def _on_connect() -> None: called.set() ps_server.on_connect(_on_connect) @@ -159,7 +166,7 @@ async def test_on_connect(ps_server): # pylint: disable=redefined-outer-name @pytest.mark.asyncio -async def test_message_decoding(ps_client): # pylint: disable=redefined-outer-name +async def test_message_decoding(ps_client: MockSHSClient) -> None: # pylint: disable=redefined-outer-name """Test message decoding""" await ps_client.connect() @@ -178,6 +185,7 @@ async def test_message_decoding(ps_client): # pylint: disable=redefined-outer-n messages = await _collect_messages(ps) assert len(messages) == 1 + assert messages[0] assert messages[0].type == PSMessageType.JSON assert messages[0].body == { "name": ["createHistoryStream"], @@ -194,7 +202,7 @@ async def test_message_decoding(ps_client): # pylint: disable=redefined-outer-n @pytest.mark.asyncio -async def test_message_encoding(ps_client): # pylint: disable=redefined-outer-name +async def test_message_encoding(ps_client: MockSHSClient) -> None: # pylint: disable=redefined-outer-name """Test message encoding""" await ps_client.connect() @@ -237,7 +245,9 @@ async def test_message_encoding(ps_client): # pylint: disable=redefined-outer-n @pytest.mark.asyncio -async def test_message_stream(ps_client, mocker): # pylint: disable=redefined-outer-name +async def test_message_stream( + ps_client: MockSHSClient, mocker: MockerFixture # pylint: disable=redefined-outer-name +) -> None: # pylint: disable=redefined-outer-name """Test requesting a history stream""" await ps_client.connect() @@ -264,7 +274,7 @@ async def test_message_stream(ps_client, mocker): # pylint: disable=redefined-o ) assert ps.req_counter == 2 - assert ps.register_handler.call_count == 1 # pylint: disable=no-member + assert ps.register_handler.call_count == 1 # type: ignore[attr-defined] # pylint: disable=no-member handler = list(ps._event_map.values())[0][1] # pylint: disable=protected-access mock_process = mocker.patch.object(handler, "process") @@ -273,6 +283,8 @@ async def test_message_stream(ps_client, mocker): # pylint: disable=redefined-o assert mock_process.await_count == 1 # responses have negative req + assert msg + assert isinstance(msg.body, dict) assert msg.req == -1 assert msg.body["previous"] == "%KTGP6W8vF80McRAZHYDWuKOD0KlNyKSq6Gb42iuV7Iw=.sha256" @@ -295,7 +307,7 @@ async def test_message_stream(ps_client, mocker): # pylint: disable=redefined-o ) assert ps.req_counter == 3 - assert ps.register_handler.call_count == 2 # pylint: disable=no-member + assert ps.register_handler.call_count == 2 # type: ignore[attr-defined] # pylint: disable=no-member handler = list(ps._event_map.values())[1][1] # pylint: disable=protected-access mock_process = mocker.patch.object(handler, "process", wraps=handler.process) @@ -318,11 +330,14 @@ async def test_message_stream(ps_client, mocker): # pylint: disable=redefined-o for msg in handled: # responses have negative req + assert msg assert msg.req == -2 @pytest.mark.asyncio -async def test_message_request(ps_server, mocker): # pylint: disable=redefined-outer-name +async def test_message_request( + ps_server: MockSHSServer, mocker: MockerFixture # pylint: disable=redefined-outer-name +) -> None: # pylint: disable=redefined-outer-name """Test message sending""" ps_server.listen() @@ -338,7 +353,7 @@ async def test_message_request(ps_server, mocker): # pylint: disable=redefined- assert json.loads(body.decode("utf-8")) == {"name": ["whoami"], "args": []} assert ps.req_counter == 2 - assert ps.register_handler.call_count == 1 # pylint: disable=no-member + assert ps.register_handler.call_count == 1 # type: ignore[attr-defined] # pylint: disable=no-member handler = list(ps._event_map.values())[0][1] # pylint: disable=protected-access mock_process = mocker.patch.object(handler, "process") ps_server.feed( @@ -351,6 +366,8 @@ async def test_message_request(ps_server, mocker): # pylint: disable=redefined- assert mock_process.await_count == 1 # responses have negative req + assert msg + assert isinstance(msg.body, dict) assert msg.req == -1 assert msg.body["id"] == "@1+Iwm79DKvVBqYKFkhT6fWRbAVvNNVH4F2BSxwhYmx8=.ed25519" assert ps.req_counter == 2 diff --git a/tests/test_util.py b/tests/test_util.py index 8122dd6..c2188f5 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -42,7 +42,7 @@ CONFIG_FILE = """ CONFIG_FILE_INVALID = CONFIG_FILE.replace("ed25519", "foo") -def test_load_secret(): +def test_load_secret() -> None: """Test loading the SSB secret from a file""" with patch("ssb.util.open", mock_open(read_data=CONFIG_FILE), create=True): @@ -55,7 +55,7 @@ def test_load_secret(): assert bytes(secret["keypair"].verify_key) == b64decode("rsYpBIcXsxjQAf0JNes+MHqT2DL+EfopWKAp4rGeEPQ=") -def test_load_exception(): +def test_load_exception() -> None: """Test configuration loading if there is a problem with the file""" with pytest.raises(ConfigException):