diff --git a/poetry.lock b/poetry.lock index 8f78f5b..962a840 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1258,6 +1258,28 @@ files = [ {file = "tomlkit-0.12.1.tar.gz", hash = "sha256:38e1ff8edb991273ec9f6181244a6a391ac30e9f5098e7535640ea6be97a7c86"}, ] +[[package]] +name = "types-pyyaml" +version = "6.0.12.12" +description = "Typing stubs for PyYAML" +optional = false +python-versions = "*" +files = [ + {file = "types-PyYAML-6.0.12.12.tar.gz", hash = "sha256:334373d392fde0fdf95af5c3f1661885fa10c52167b14593eb856289e1855062"}, + {file = "types_PyYAML-6.0.12.12-py3-none-any.whl", hash = "sha256:c05bc6c158facb0676674b7f11fe3960db4f389718e19e62bd2b84d6205cfd24"}, +] + +[[package]] +name = "types-simplejson" +version = "3.19.0.2" +description = "Typing stubs for simplejson" +optional = false +python-versions = "*" +files = [ + {file = "types-simplejson-3.19.0.2.tar.gz", hash = "sha256:ebc81f886f89d99d6b80c726518aa2228bc77c26438f18fd81455e4f79f8ee1b"}, + {file = "types_simplejson-3.19.0.2-py3-none-any.whl", hash = "sha256:8ba093dc7884f59b3e62aed217144085e675a269debc32678fd80e0b43b2b86f"}, +] + [[package]] name = "typing-extensions" version = "4.8.0" @@ -1315,4 +1337,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "b503e50c4ab977c6785c68bc1e5bf2efb7ab88a1fd33770e84a7f612a85d2641" +content-hash = "d80cbfdf7923c50c95505a84d8ad75eae016ca81ae32a8b22d074569b0a0fcbd" diff --git a/pyproject.toml b/pyproject.toml index ba91826..4f4ba04 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,6 +27,8 @@ commitizen = "^3.12.0" black = "^23.10.1" pylint = "^3.0.2" mypy = "^1.6.1" +types-pyyaml = "^6.0.12.12" +types-simplejson = "^3.19.0.2" [tool.poetry.group.docs.dependencies] Sphinx = "^2.1.1" diff --git a/ssb/feed/models.py b/ssb/feed/models.py index 958fd3c..93ba35e 100644 --- a/ssb/feed/models.py +++ b/ssb/feed/models.py @@ -4,8 +4,11 @@ from base64 import b64encode from collections import OrderedDict, namedtuple import datetime from hashlib import sha256 +from typing import Any, Dict, Optional +from nacl.signing import SigningKey, VerifyKey from simplejson import dumps, loads +from typing_extensions import Self from ssb.util import tag @@ -16,7 +19,7 @@ class NoPrivateKeyException(Exception): """Exception to raise when a private key is not available""" -def to_ordered(data): +def to_ordered(data: Dict[str, Any]) -> OrderedDict[str, Any]: """Convert a dictionary to an ``OrderedDict``""" smsg = OrderedMsg(**data) @@ -24,7 +27,7 @@ def to_ordered(data): return OrderedDict((k, getattr(smsg, k)) for k in smsg._fields) -def get_millis_1970(): +def get_millis_1970() -> int: """Get the UNIX timestamp in milliseconds""" return int(datetime.datetime.utcnow().timestamp() * 1000) @@ -33,16 +36,16 @@ def get_millis_1970(): class Feed: """Base class for feeds""" - def __init__(self, public_key): + def __init__(self, public_key: VerifyKey): self.public_key = public_key @property - def id(self): + def id(self) -> str: """The identifier of the feed""" return tag(self.public_key).decode("ascii") - def sign(self, msg): + def sign(self, msg: "Message") -> bytes: """Sign a message""" raise NoPrivateKeyException("Cannot use remote identity to sign (no private key!)") @@ -51,16 +54,20 @@ class Feed: class LocalFeed(Feed): """Class representing a local feed""" - def __init__(self, private_key): # pylint: disable=super-init-not-called - self.private_key = private_key + def __init__(self, private_key: SigningKey): # pylint: disable=super-init-not-called + self.private_key: SigningKey = private_key @property - def public_key(self): + def public_key(self) -> VerifyKey: """The public key of the feed""" return self.private_key.verify_key - def sign(self, msg): + @public_key.setter + def public_key(self, _: VerifyKey) -> None: + raise TypeError("Cannot set just the public key of a local feed") + + def sign(self, msg: "Message") -> bytes: """Sign a message for this feed""" return self.private_key.sign(msg).signature @@ -70,7 +77,13 @@ class Message: """Base class for SSB messages""" def __init__( # pylint: disable=too-many-arguments - self, feed, content, signature=None, sequence=1, timestamp=None, previous=None + self, + feed: Feed, + content: Dict[str, Any], + signature: Optional[str] = None, + sequence: int = 1, + timestamp: Optional[int] = None, + previous: Optional["Message"] = None, ): self.feed = feed self.content = content @@ -81,15 +94,16 @@ class Message: self.signature = signature self.previous = previous + if self.previous: - self.sequence = self.previous.sequence + 1 + self.sequence: int = self.previous.sequence + 1 else: self.sequence = sequence self.timestamp = get_millis_1970() if timestamp is None else timestamp @classmethod - def parse(cls, data, feed): + def parse(cls, data: bytes, feed: Feed) -> Self: """Parse raw message data""" obj = loads(data, object_pairs_hook=OrderedDict) @@ -97,12 +111,12 @@ class Message: return msg - def serialize(self, add_signature=True): + def serialize(self, add_signature: bool = True) -> bytes: """Serialize the message""" return dumps(self.to_dict(add_signature=add_signature), indent=2).encode("utf-8") - def to_dict(self, add_signature=True): + def to_dict(self, add_signature: bool = True) -> OrderedDict[str, Any]: """Convert the message to a dictionary""" obj = to_ordered( @@ -121,20 +135,21 @@ class Message: return obj - def verify(self, signature): + def verify(self, signature: str) -> bool: """Verify the signature of the message""" return self.signature == signature @property - def hash(self): + def hash(self) -> str: """The cryptographic hash of the message""" hash_ = sha256(self.serialize()).digest() + return b64encode(hash_).decode("ascii") + ".sha256" @property - def key(self): + def key(self) -> str: """The key of the message""" return "%" + self.hash @@ -144,7 +159,13 @@ class LocalMessage(Message): """Class representing a local message""" def __init__( # pylint: disable=too-many-arguments,super-init-not-called - self, feed, content, signature=None, sequence=1, timestamp=None, previous=None + self, + feed: Feed, + content: Dict[str, Any], + signature: Optional[str] = None, + sequence: int = 1, + timestamp: Optional[int] = None, + previous: Optional[Message] = None, ): self.feed = feed self.content = content @@ -162,7 +183,8 @@ class LocalMessage(Message): else: self.signature = signature - def _sign(self): + def _sign(self) -> str: # ensure ordering of keys and indentation of 2 characters, like ssb-keys data = self.serialize(add_signature=False) + return (b64encode(bytes(self.feed.sign(data))) + b".sig.ed25519").decode("ascii") diff --git a/ssb/packet_stream.py b/ssb/packet_stream.py index bdb21d1..819a093 100644 --- a/ssb/packet_stream.py +++ b/ssb/packet_stream.py @@ -6,14 +6,15 @@ import logging from math import ceil import struct from time import time -from typing import Any, AsyncIterator, Dict, Optional, Union +from typing import Any, AsyncIterator, Dict, Optional, Tuple, Union from secret_handshake.network import SHSDuplexStream import simplejson +from typing_extensions import Self -logger = logging.getLogger("packet_stream") - +PSHandler = Union["PSRequestHandler", "PSStreamHandler"] PSMessageData = Union[bytes, bool, Dict[str, Any], str] +logger = logging.getLogger("packet_stream") class PSMessageType(Enum): @@ -31,12 +32,12 @@ class PSStreamHandler: self.req = req self.queue: Queue[Optional["PSMessage"]] = Queue() - async def process(self, msg): + async def process(self, msg: "PSMessage") -> None: """Process a pending message""" await self.queue.put(msg) - async def stop(self): + async def stop(self) -> None: """Stop a pending request""" await self.queue.put(None) @@ -59,15 +60,15 @@ class PSRequestHandler: def __init__(self, req: int): self.req = req self.event = Event() - self._msg = None + self._msg: Optional[PSMessage] = None - async def process(self, msg): + async def process(self, msg: "PSMessage") -> None: """Process a message request""" self._msg = msg self.event.set() - async def stop(self): + async def stop(self) -> None: """Stop a pending event request""" if not self.event.is_set(): @@ -87,37 +88,44 @@ class PSMessage: """Packet Stream message""" @classmethod - def from_header_body(cls, flags, req, body): + def from_header_body(cls, flags: int, req: int, body: bytes) -> Self: """Parse a raw message""" + type_ = PSMessageType(flags & 0x03) if type_ == PSMessageType.TEXT: - body = body.decode("utf-8") + body_s = body.decode("utf-8") elif type_ == PSMessageType.JSON: - body = simplejson.loads(body) + body_s = simplejson.loads(body) - return cls(type_, body, bool(flags & 0x08), bool(flags & 0x04), req=req) + return cls(type_, body_s, bool(flags & 0x08), bool(flags & 0x04), req=req) @property def data(self) -> bytes: """The raw message data""" if self.type == PSMessageType.TEXT: + assert isinstance(self.body, str) + return self.body.encode("utf-8") if self.type == PSMessageType.JSON: return simplejson.dumps(self.body).encode("utf-8") + assert isinstance(self.body, bytes) + return self.body - def __init__(self, type_, body, stream, end_err, req=None): # pylint: disable=too-many-arguments + def __init__( + self, type_: PSMessageType, body: Any, stream: bool, end_err: bool, req: Optional[int] = None + ): # pylint: disable=too-many-arguments self.stream = stream self.end_err = end_err self.type = type_ self.body = body self.req = req - def __repr__(self): + def __repr__(self) -> str: if self.type == PSMessageType.BUFFER: body = f"{len(self.body)} bytes" else: @@ -136,16 +144,16 @@ class PacketStream: def __init__(self, connection: SHSDuplexStream): self.connection = connection self.req_counter = 1 - self._event_map = {} + self._event_map: Dict[int, Tuple[float, PSHandler]] = {} self._connected = False - def register_handler(self, handler): + def register_handler(self, handler: PSHandler) -> None: """Register an RPC handler""" self._event_map[handler.req] = (time(), handler) @property - def is_connected(self): + def is_connected(self) -> bool: """Check if the stream is connected""" return self.connection.is_connected @@ -159,22 +167,25 @@ class PacketStream: if not msg: raise StopAsyncIteration() - if msg.req >= 0: + if msg.req is not None and msg.req >= 0: return msg return None - async def __await__(self): + async def __await__(self) -> None: async for data in self: logger.info("RECV: %r", data) + if data is None: return - async def _read(self): + async def _read(self) -> Optional[PSMessage]: try: header = await self.connection.read() + if not header or header == b"\x00" * 9: - return + return None + flags, length, req = struct.unpack(">BIi", header) n_packets = ceil(length / 4096) @@ -182,30 +193,39 @@ class PacketStream: body = b"" for _ in range(n_packets): - body += await self.connection.read() + read_data = await self.connection.read() + + if read_data is not None: + body += read_data logger.debug("READ %s %s", header, len(body)) + return PSMessage.from_header_body(flags, req, body) except StopAsyncIteration: logger.debug("DISCONNECT") self.connection.disconnect() + return None - async def read(self): + async def read(self) -> Optional[PSMessage]: """Read data from the packet stream""" msg = await self._read() + if not msg: return None + # check whether it's a reply and handle accordingly - if msg.req < 0: + if msg.req is not None and msg.req < 0: _, handler = self._event_map[-msg.req] await handler.process(msg) logger.info("RESPONSE [%d]: %r", -msg.req, msg) + if msg.end_err: await handler.stop() del self._event_map[-msg.req] logger.info("RESPONSE [%d]: EOS", -msg.req) + return msg def _write(self, msg: PSMessage) -> None: @@ -225,7 +245,7 @@ class PacketStream: stream: bool = False, end_err: bool = False, req: Optional[int] = None, - ): + ) -> PSHandler: """Send data through the packet stream""" update_counter = False @@ -240,7 +260,7 @@ class PacketStream: self._write(msg) if stream: - handler = PSStreamHandler(self.req_counter) + handler: PSHandler = PSStreamHandler(self.req_counter) else: handler = PSRequestHandler(self.req_counter) @@ -251,7 +271,7 @@ class PacketStream: return handler - def disconnect(self): + def disconnect(self) -> None: """Disconnect the stream""" self._connected = False diff --git a/ssb/util.py b/ssb/util.py index 181c3d1..aad4afa 100644 --- a/ssb/util.py +++ b/ssb/util.py @@ -4,12 +4,12 @@ from base64 import b64decode, b64encode import os from typing import TypedDict -from nacl.signing import SigningKey +from nacl.signing import SigningKey, VerifyKey import yaml class SSBSecret(TypedDict): - """Dictionary to hold an SSB identity""" + """Dictionary type to hold an SSB secret identity""" keypair: SigningKey id: str @@ -19,7 +19,7 @@ class ConfigException(Exception): """Exception to raise if there is a problem with the configuration data""" -def tag(key): +def tag(key: VerifyKey) -> bytes: """Create tag from public key.""" return b"@" + b64encode(bytes(key)) + b".ed25519" @@ -35,4 +35,5 @@ def load_ssb_secret() -> SSBSecret: raise ConfigException("Algorithm not known: " + config["curve"]) server_prv_key = b64decode(config["private"][:-8]) + return {"keypair": SigningKey(server_prv_key[:32]), "id": config["id"]} diff --git a/tests/test_feed.py b/tests/test_feed.py index 63c0b3a..5f3989a 100644 --- a/tests/test_feed.py +++ b/tests/test_feed.py @@ -25,7 +25,7 @@ SERIALIZED_M1 = b"""{ @pytest.fixture -def local_feed(): +def local_feed() -> LocalFeed: """Fixture providing a local feed""" secret = b64decode("Mz2qkNOP2K6upnqibWrR+z8pVUI1ReA1MLc7QMtF2qQ=") @@ -33,14 +33,14 @@ def local_feed(): @pytest.fixture -def remote_feed(): +def remote_feed() -> Feed: """Fixture providing a remote feed""" public = b64decode("I/4cyN/jPBbDsikbHzAEvmaYlaJK33lW3UhWjNXjyrU=") return Feed(VerifyKey(public)) -def test_local_feed(): +def test_local_feed() -> None: """Test a local feed""" secret = b64decode("Mz2qkNOP2K6upnqibWrR+z8pVUI1ReA1MLc7QMtF2qQ=") @@ -50,7 +50,7 @@ def test_local_feed(): assert feed.id == "@I/4cyN/jPBbDsikbHzAEvmaYlaJK33lW3UhWjNXjyrU=.ed25519" -def test_remote_feed(): +def test_remote_feed() -> None: """Test a remote feed""" public = b64decode("I/4cyN/jPBbDsikbHzAEvmaYlaJK33lW3UhWjNXjyrU=") @@ -69,7 +69,7 @@ def test_remote_feed(): feed.sign(m1) -def test_local_message(local_feed): # pylint: disable=redefined-outer-name +def test_local_message(local_feed: LocalFeed) -> None: # pylint: disable=redefined-outer-name """Test a local message""" m1 = LocalMessage( @@ -102,7 +102,7 @@ def test_local_message(local_feed): # pylint: disable=redefined-outer-name assert m2.key == "%nx13uks5GUwuKJC49PfYGMS/1pgGTtwwdWT7kbVaroM=.sha256" -def test_remote_message(remote_feed): # pylint: disable=redefined-outer-name +def test_remote_message(remote_feed: Feed) -> None: # pylint: disable=redefined-outer-name """Test a remote message""" signature = "lPsQ9P10OgeyH6u0unFgiI2wV/RQ7Q2x2ebxnXYCzsJ055TBMXphRADTKhOMS2EkUxXQ9k3amj5fnWPudGxwBQ==.sig.ed25519" @@ -136,7 +136,7 @@ def test_remote_message(remote_feed): # pylint: disable=redefined-outer-name assert m2.key == "%nx13uks5GUwuKJC49PfYGMS/1pgGTtwwdWT7kbVaroM=.sha256" -def test_remote_no_signature(remote_feed): # pylint: disable=redefined-outer-name +def test_remote_no_signature(remote_feed: Feed) -> None: # pylint: disable=redefined-outer-name """Test remote feed without a signature""" with pytest.raises(ValueError): @@ -150,7 +150,7 @@ def test_remote_no_signature(remote_feed): # pylint: disable=redefined-outer-na ) -def test_serialize(local_feed): # pylint: disable=redefined-outer-name +def test_serialize(local_feed: LocalFeed) -> None: # pylint: disable=redefined-outer-name """Test feed serialization""" m1 = LocalMessage( @@ -162,7 +162,7 @@ def test_serialize(local_feed): # pylint: disable=redefined-outer-name assert m1.serialize() == SERIALIZED_M1 -def test_parse(local_feed): # pylint: disable=redefined-outer-name +def test_parse(local_feed: LocalFeed) -> None: # pylint: disable=redefined-outer-name """Test feed parsing""" m1 = LocalMessage.parse(SERIALIZED_M1, local_feed) diff --git a/tests/test_packet_stream.py b/tests/test_packet_stream.py index 5dff5cf..4e7a55f 100644 --- a/tests/test_packet_stream.py +++ b/tests/test_packet_stream.py @@ -1,18 +1,23 @@ """Tests for the packet stream""" from asyncio import Event, ensure_future, gather +from asyncio.events import AbstractEventLoop import json +from typing import AsyncGenerator, Awaitable, Callable, Generator, List import pytest +from pytest_mock import MockerFixture from secret_handshake.network import SHSDuplexStream -from ssb.packet_stream import PacketStream, PSMessageType +from ssb.packet_stream import PacketStream, PSMessage, PSMessageType -async def _collect_messages(generator): +async def _collect_messages(generator: AsyncGenerator[PSMessage, None]) -> List[PSMessage]: results = [] + async for msg in generator: results.append(msg) + return results @@ -39,45 +44,47 @@ MSG_BODY_2 = ( class MockSHSSocket(SHSDuplexStream): """A mocked SHS socket""" - def __init__(self, *args, **kwargs): # pylint: disable=unused-argument + def __init__(self): # pylint: disable=unused-argument super().__init__() - self.input = [] - self.output = [] + self.input: List[bytes] = [] + self.output: List[bytes] = [] self.is_connected = False - self._on_connect = [] + self._on_connect: List[Callable[[SHSDuplexStream], Awaitable[None]]] = [] - def on_connect(self, cb): + def on_connect(self, cb: Callable[[SHSDuplexStream], Awaitable[None]]) -> None: """Set the on_connect callback""" self._on_connect.append(cb) - async def read(self): + async def read(self) -> bytes: """Read data from the socket""" if not self.input: raise StopAsyncIteration + return self.input.pop(0) - def write(self, data): + def write(self, data: bytes) -> None: """Write data to the socket""" self.output.append(data) - def feed(self, input_): - """Get the connection’s feed""" + def feed(self, input_: List[bytes]) -> None: + """Feed data into the connection""" self.input += input_ - def get_output(self): + def get_output(self) -> Generator[bytes, None, None]: """Get the output of a call""" while True: if not self.output: break + yield self.output.pop(0) - def disconnect(self): + def disconnect(self) -> None: """Disconnect from the remote party""" self.is_connected = False @@ -86,48 +93,48 @@ class MockSHSSocket(SHSDuplexStream): class MockSHSClient(MockSHSSocket): """A mocked SHS client""" - async def connect(self): + async def connect(self) -> None: """Connect to a SHS server""" self.is_connected = True for cb in self._on_connect: - await cb() + await cb(self) class MockSHSServer(MockSHSSocket): """A mocked SHS server""" - def listen(self): + def listen(self) -> None: """Listen for new connections""" self.is_connected = True for cb in self._on_connect: - ensure_future(cb()) + ensure_future(cb(self)) @pytest.fixture -def ps_client(event_loop): # pylint: disable=unused-argument +def ps_client(event_loop: AbstractEventLoop) -> MockSHSClient: # pylint: disable=unused-argument """Fixture to provide a mocked SHS client""" return MockSHSClient() @pytest.fixture -def ps_server(event_loop): # pylint: disable=unused-argument +def ps_server(event_loop: AbstractEventLoop) -> MockSHSServer: # pylint: disable=unused-argument """Fixture to provide a mocked SHS server""" return MockSHSServer() @pytest.mark.asyncio -async def test_on_connect(ps_server): # pylint: disable=redefined-outer-name +async def test_on_connect(ps_server: MockSHSServer) -> None: # pylint: disable=redefined-outer-name """Test the on_connect callback functionality""" called = Event() - async def _on_connect(): + async def _on_connect(_: SHSDuplexStream) -> None: called.set() ps_server.on_connect(_on_connect) @@ -137,7 +144,7 @@ async def test_on_connect(ps_server): # pylint: disable=redefined-outer-name @pytest.mark.asyncio -async def test_message_decoding(ps_client): # pylint: disable=redefined-outer-name +async def test_message_decoding(ps_client: MockSHSClient) -> None: # pylint: disable=redefined-outer-name """Test message decoding""" await ps_client.connect() @@ -167,7 +174,7 @@ async def test_message_decoding(ps_client): # pylint: disable=redefined-outer-n @pytest.mark.asyncio -async def test_message_encoding(ps_client): # pylint: disable=redefined-outer-name +async def test_message_encoding(ps_client: MockSHSClient) -> None: # pylint: disable=redefined-outer-name """Test message encoding""" await ps_client.connect() @@ -200,7 +207,7 @@ async def test_message_encoding(ps_client): # pylint: disable=redefined-outer-n @pytest.mark.asyncio -async def test_message_stream(ps_client, mocker): # pylint: disable=redefined-outer-name +async def test_message_stream(ps_client: MockSHSClient, mocker: MockerFixture): # pylint: disable=redefined-outer-name """Test requesting a history stream""" await ps_client.connect() @@ -272,7 +279,9 @@ async def test_message_stream(ps_client, mocker): # pylint: disable=redefined-o @pytest.mark.asyncio -async def test_message_request(ps_server, mocker): # pylint: disable=redefined-outer-name +async def test_message_request( + ps_server: MockSHSServer, mocker: MockerFixture # pylint: disable=redefined-outer-name +) -> None: """Test message sending""" ps_server.listen() diff --git a/tests/test_util.py b/tests/test_util.py index d3594f8..b3d5a66 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -20,7 +20,7 @@ CONFIG_FILE = """ CONFIG_FILE_INVALID = CONFIG_FILE.replace("ed25519", "foo") -def test_load_secret(): +def test_load_secret() -> None: """Test loading the SSB secret from a file""" with patch("ssb.util.open", mock_open(read_data=CONFIG_FILE), create=True): @@ -33,8 +33,9 @@ def test_load_secret(): assert bytes(secret["keypair"].verify_key) == b64decode("rsYpBIcXsxjQAf0JNes+MHqT2DL+EfopWKAp4rGeEPQ=") -def test_load_exception(): +def test_load_exception() -> None: """Test configuration loading if there is a problem with the file""" + with pytest.raises(ConfigException): with patch("ssb.util.open", mock_open(read_data=CONFIG_FILE_INVALID), create=True): load_ssb_secret()