ci: Add and configure mypy, and make it happy

This commit is contained in:
Gergely Polonkai 2023-11-01 07:22:29 +01:00
parent 5aa4f16a5a
commit f6e58b7682
No known key found for this signature in database
GPG Key ID: 2D2885533B869ED4
8 changed files with 163 additions and 86 deletions

24
poetry.lock generated
View File

@ -1258,6 +1258,28 @@ files = [
{file = "tomlkit-0.12.1.tar.gz", hash = "sha256:38e1ff8edb991273ec9f6181244a6a391ac30e9f5098e7535640ea6be97a7c86"}, {file = "tomlkit-0.12.1.tar.gz", hash = "sha256:38e1ff8edb991273ec9f6181244a6a391ac30e9f5098e7535640ea6be97a7c86"},
] ]
[[package]]
name = "types-pyyaml"
version = "6.0.12.12"
description = "Typing stubs for PyYAML"
optional = false
python-versions = "*"
files = [
{file = "types-PyYAML-6.0.12.12.tar.gz", hash = "sha256:334373d392fde0fdf95af5c3f1661885fa10c52167b14593eb856289e1855062"},
{file = "types_PyYAML-6.0.12.12-py3-none-any.whl", hash = "sha256:c05bc6c158facb0676674b7f11fe3960db4f389718e19e62bd2b84d6205cfd24"},
]
[[package]]
name = "types-simplejson"
version = "3.19.0.2"
description = "Typing stubs for simplejson"
optional = false
python-versions = "*"
files = [
{file = "types-simplejson-3.19.0.2.tar.gz", hash = "sha256:ebc81f886f89d99d6b80c726518aa2228bc77c26438f18fd81455e4f79f8ee1b"},
{file = "types_simplejson-3.19.0.2-py3-none-any.whl", hash = "sha256:8ba093dc7884f59b3e62aed217144085e675a269debc32678fd80e0b43b2b86f"},
]
[[package]] [[package]]
name = "typing-extensions" name = "typing-extensions"
version = "4.8.0" version = "4.8.0"
@ -1315,4 +1337,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p
[metadata] [metadata]
lock-version = "2.0" lock-version = "2.0"
python-versions = "^3.9" python-versions = "^3.9"
content-hash = "b503e50c4ab977c6785c68bc1e5bf2efb7ab88a1fd33770e84a7f612a85d2641" content-hash = "d80cbfdf7923c50c95505a84d8ad75eae016ca81ae32a8b22d074569b0a0fcbd"

View File

@ -27,6 +27,8 @@ commitizen = "^3.12.0"
black = "^23.10.1" black = "^23.10.1"
pylint = "^3.0.2" pylint = "^3.0.2"
mypy = "^1.6.1" mypy = "^1.6.1"
types-pyyaml = "^6.0.12.12"
types-simplejson = "^3.19.0.2"
[tool.poetry.group.docs.dependencies] [tool.poetry.group.docs.dependencies]
Sphinx = "^2.1.1" Sphinx = "^2.1.1"

View File

@ -4,8 +4,11 @@ from base64 import b64encode
from collections import OrderedDict, namedtuple from collections import OrderedDict, namedtuple
import datetime import datetime
from hashlib import sha256 from hashlib import sha256
from typing import Any, Dict, Optional
from nacl.signing import SigningKey, VerifyKey
from simplejson import dumps, loads from simplejson import dumps, loads
from typing_extensions import Self
from ssb.util import tag from ssb.util import tag
@ -16,7 +19,7 @@ class NoPrivateKeyException(Exception):
"""Exception to raise when a private key is not available""" """Exception to raise when a private key is not available"""
def to_ordered(data): def to_ordered(data: Dict[str, Any]) -> OrderedDict[str, Any]:
"""Convert a dictionary to an ``OrderedDict``""" """Convert a dictionary to an ``OrderedDict``"""
smsg = OrderedMsg(**data) smsg = OrderedMsg(**data)
@ -24,7 +27,7 @@ def to_ordered(data):
return OrderedDict((k, getattr(smsg, k)) for k in smsg._fields) return OrderedDict((k, getattr(smsg, k)) for k in smsg._fields)
def get_millis_1970(): def get_millis_1970() -> int:
"""Get the UNIX timestamp in milliseconds""" """Get the UNIX timestamp in milliseconds"""
return int(datetime.datetime.utcnow().timestamp() * 1000) return int(datetime.datetime.utcnow().timestamp() * 1000)
@ -33,16 +36,16 @@ def get_millis_1970():
class Feed: class Feed:
"""Base class for feeds""" """Base class for feeds"""
def __init__(self, public_key): def __init__(self, public_key: VerifyKey):
self.public_key = public_key self.public_key = public_key
@property @property
def id(self): def id(self) -> str:
"""The identifier of the feed""" """The identifier of the feed"""
return tag(self.public_key).decode("ascii") return tag(self.public_key).decode("ascii")
def sign(self, msg): def sign(self, msg: "Message") -> bytes:
"""Sign a message""" """Sign a message"""
raise NoPrivateKeyException("Cannot use remote identity to sign (no private key!)") raise NoPrivateKeyException("Cannot use remote identity to sign (no private key!)")
@ -51,16 +54,20 @@ class Feed:
class LocalFeed(Feed): class LocalFeed(Feed):
"""Class representing a local feed""" """Class representing a local feed"""
def __init__(self, private_key): # pylint: disable=super-init-not-called def __init__(self, private_key: SigningKey): # pylint: disable=super-init-not-called
self.private_key = private_key self.private_key: SigningKey = private_key
@property @property
def public_key(self): def public_key(self) -> VerifyKey:
"""The public key of the feed""" """The public key of the feed"""
return self.private_key.verify_key return self.private_key.verify_key
def sign(self, msg): @public_key.setter
def public_key(self, _: VerifyKey) -> None:
raise TypeError("Cannot set just the public key of a local feed")
def sign(self, msg: "Message") -> bytes:
"""Sign a message for this feed""" """Sign a message for this feed"""
return self.private_key.sign(msg).signature return self.private_key.sign(msg).signature
@ -70,7 +77,13 @@ class Message:
"""Base class for SSB messages""" """Base class for SSB messages"""
def __init__( # pylint: disable=too-many-arguments def __init__( # pylint: disable=too-many-arguments
self, feed, content, signature=None, sequence=1, timestamp=None, previous=None self,
feed: Feed,
content: Dict[str, Any],
signature: Optional[str] = None,
sequence: int = 1,
timestamp: Optional[int] = None,
previous: Optional["Message"] = None,
): ):
self.feed = feed self.feed = feed
self.content = content self.content = content
@ -81,15 +94,16 @@ class Message:
self.signature = signature self.signature = signature
self.previous = previous self.previous = previous
if self.previous: if self.previous:
self.sequence = self.previous.sequence + 1 self.sequence: int = self.previous.sequence + 1
else: else:
self.sequence = sequence self.sequence = sequence
self.timestamp = get_millis_1970() if timestamp is None else timestamp self.timestamp = get_millis_1970() if timestamp is None else timestamp
@classmethod @classmethod
def parse(cls, data, feed): def parse(cls, data: bytes, feed: Feed) -> Self:
"""Parse raw message data""" """Parse raw message data"""
obj = loads(data, object_pairs_hook=OrderedDict) obj = loads(data, object_pairs_hook=OrderedDict)
@ -97,12 +111,12 @@ class Message:
return msg return msg
def serialize(self, add_signature=True): def serialize(self, add_signature: bool = True) -> bytes:
"""Serialize the message""" """Serialize the message"""
return dumps(self.to_dict(add_signature=add_signature), indent=2).encode("utf-8") return dumps(self.to_dict(add_signature=add_signature), indent=2).encode("utf-8")
def to_dict(self, add_signature=True): def to_dict(self, add_signature: bool = True) -> OrderedDict[str, Any]:
"""Convert the message to a dictionary""" """Convert the message to a dictionary"""
obj = to_ordered( obj = to_ordered(
@ -121,20 +135,21 @@ class Message:
return obj return obj
def verify(self, signature): def verify(self, signature: str) -> bool:
"""Verify the signature of the message""" """Verify the signature of the message"""
return self.signature == signature return self.signature == signature
@property @property
def hash(self): def hash(self) -> str:
"""The cryptographic hash of the message""" """The cryptographic hash of the message"""
hash_ = sha256(self.serialize()).digest() hash_ = sha256(self.serialize()).digest()
return b64encode(hash_).decode("ascii") + ".sha256" return b64encode(hash_).decode("ascii") + ".sha256"
@property @property
def key(self): def key(self) -> str:
"""The key of the message""" """The key of the message"""
return "%" + self.hash return "%" + self.hash
@ -144,7 +159,13 @@ class LocalMessage(Message):
"""Class representing a local message""" """Class representing a local message"""
def __init__( # pylint: disable=too-many-arguments,super-init-not-called def __init__( # pylint: disable=too-many-arguments,super-init-not-called
self, feed, content, signature=None, sequence=1, timestamp=None, previous=None self,
feed: Feed,
content: Dict[str, Any],
signature: Optional[str] = None,
sequence: int = 1,
timestamp: Optional[int] = None,
previous: Optional[Message] = None,
): ):
self.feed = feed self.feed = feed
self.content = content self.content = content
@ -162,7 +183,8 @@ class LocalMessage(Message):
else: else:
self.signature = signature self.signature = signature
def _sign(self): def _sign(self) -> str:
# ensure ordering of keys and indentation of 2 characters, like ssb-keys # ensure ordering of keys and indentation of 2 characters, like ssb-keys
data = self.serialize(add_signature=False) data = self.serialize(add_signature=False)
return (b64encode(bytes(self.feed.sign(data))) + b".sig.ed25519").decode("ascii") return (b64encode(bytes(self.feed.sign(data))) + b".sig.ed25519").decode("ascii")

View File

@ -6,14 +6,15 @@ import logging
from math import ceil from math import ceil
import struct import struct
from time import time from time import time
from typing import Any, AsyncIterator, Dict, Optional, Union from typing import Any, AsyncIterator, Dict, Optional, Tuple, Union
from secret_handshake.network import SHSDuplexStream from secret_handshake.network import SHSDuplexStream
import simplejson import simplejson
from typing_extensions import Self
logger = logging.getLogger("packet_stream") PSHandler = Union["PSRequestHandler", "PSStreamHandler"]
PSMessageData = Union[bytes, bool, Dict[str, Any], str] PSMessageData = Union[bytes, bool, Dict[str, Any], str]
logger = logging.getLogger("packet_stream")
class PSMessageType(Enum): class PSMessageType(Enum):
@ -31,12 +32,12 @@ class PSStreamHandler:
self.req = req self.req = req
self.queue: Queue[Optional["PSMessage"]] = Queue() self.queue: Queue[Optional["PSMessage"]] = Queue()
async def process(self, msg): async def process(self, msg: "PSMessage") -> None:
"""Process a pending message""" """Process a pending message"""
await self.queue.put(msg) await self.queue.put(msg)
async def stop(self): async def stop(self) -> None:
"""Stop a pending request""" """Stop a pending request"""
await self.queue.put(None) await self.queue.put(None)
@ -59,15 +60,15 @@ class PSRequestHandler:
def __init__(self, req: int): def __init__(self, req: int):
self.req = req self.req = req
self.event = Event() self.event = Event()
self._msg = None self._msg: Optional[PSMessage] = None
async def process(self, msg): async def process(self, msg: "PSMessage") -> None:
"""Process a message request""" """Process a message request"""
self._msg = msg self._msg = msg
self.event.set() self.event.set()
async def stop(self): async def stop(self) -> None:
"""Stop a pending event request""" """Stop a pending event request"""
if not self.event.is_set(): if not self.event.is_set():
@ -87,37 +88,44 @@ class PSMessage:
"""Packet Stream message""" """Packet Stream message"""
@classmethod @classmethod
def from_header_body(cls, flags, req, body): def from_header_body(cls, flags: int, req: int, body: bytes) -> Self:
"""Parse a raw message""" """Parse a raw message"""
type_ = PSMessageType(flags & 0x03) type_ = PSMessageType(flags & 0x03)
if type_ == PSMessageType.TEXT: if type_ == PSMessageType.TEXT:
body = body.decode("utf-8") body_s = body.decode("utf-8")
elif type_ == PSMessageType.JSON: elif type_ == PSMessageType.JSON:
body = simplejson.loads(body) body_s = simplejson.loads(body)
return cls(type_, body, bool(flags & 0x08), bool(flags & 0x04), req=req) return cls(type_, body_s, bool(flags & 0x08), bool(flags & 0x04), req=req)
@property @property
def data(self) -> bytes: def data(self) -> bytes:
"""The raw message data""" """The raw message data"""
if self.type == PSMessageType.TEXT: if self.type == PSMessageType.TEXT:
assert isinstance(self.body, str)
return self.body.encode("utf-8") return self.body.encode("utf-8")
if self.type == PSMessageType.JSON: if self.type == PSMessageType.JSON:
return simplejson.dumps(self.body).encode("utf-8") return simplejson.dumps(self.body).encode("utf-8")
assert isinstance(self.body, bytes)
return self.body return self.body
def __init__(self, type_, body, stream, end_err, req=None): # pylint: disable=too-many-arguments def __init__(
self, type_: PSMessageType, body: Any, stream: bool, end_err: bool, req: Optional[int] = None
): # pylint: disable=too-many-arguments
self.stream = stream self.stream = stream
self.end_err = end_err self.end_err = end_err
self.type = type_ self.type = type_
self.body = body self.body = body
self.req = req self.req = req
def __repr__(self): def __repr__(self) -> str:
if self.type == PSMessageType.BUFFER: if self.type == PSMessageType.BUFFER:
body = f"{len(self.body)} bytes" body = f"{len(self.body)} bytes"
else: else:
@ -136,16 +144,16 @@ class PacketStream:
def __init__(self, connection: SHSDuplexStream): def __init__(self, connection: SHSDuplexStream):
self.connection = connection self.connection = connection
self.req_counter = 1 self.req_counter = 1
self._event_map = {} self._event_map: Dict[int, Tuple[float, PSHandler]] = {}
self._connected = False self._connected = False
def register_handler(self, handler): def register_handler(self, handler: PSHandler) -> None:
"""Register an RPC handler""" """Register an RPC handler"""
self._event_map[handler.req] = (time(), handler) self._event_map[handler.req] = (time(), handler)
@property @property
def is_connected(self): def is_connected(self) -> bool:
"""Check if the stream is connected""" """Check if the stream is connected"""
return self.connection.is_connected return self.connection.is_connected
@ -159,22 +167,25 @@ class PacketStream:
if not msg: if not msg:
raise StopAsyncIteration() raise StopAsyncIteration()
if msg.req >= 0: if msg.req is not None and msg.req >= 0:
return msg return msg
return None return None
async def __await__(self): async def __await__(self) -> None:
async for data in self: async for data in self:
logger.info("RECV: %r", data) logger.info("RECV: %r", data)
if data is None: if data is None:
return return
async def _read(self): async def _read(self) -> Optional[PSMessage]:
try: try:
header = await self.connection.read() header = await self.connection.read()
if not header or header == b"\x00" * 9: if not header or header == b"\x00" * 9:
return return None
flags, length, req = struct.unpack(">BIi", header) flags, length, req = struct.unpack(">BIi", header)
n_packets = ceil(length / 4096) n_packets = ceil(length / 4096)
@ -182,30 +193,39 @@ class PacketStream:
body = b"" body = b""
for _ in range(n_packets): for _ in range(n_packets):
body += await self.connection.read() read_data = await self.connection.read()
if read_data is not None:
body += read_data
logger.debug("READ %s %s", header, len(body)) logger.debug("READ %s %s", header, len(body))
return PSMessage.from_header_body(flags, req, body) return PSMessage.from_header_body(flags, req, body)
except StopAsyncIteration: except StopAsyncIteration:
logger.debug("DISCONNECT") logger.debug("DISCONNECT")
self.connection.disconnect() self.connection.disconnect()
return None return None
async def read(self): async def read(self) -> Optional[PSMessage]:
"""Read data from the packet stream""" """Read data from the packet stream"""
msg = await self._read() msg = await self._read()
if not msg: if not msg:
return None return None
# check whether it's a reply and handle accordingly # check whether it's a reply and handle accordingly
if msg.req < 0: if msg.req is not None and msg.req < 0:
_, handler = self._event_map[-msg.req] _, handler = self._event_map[-msg.req]
await handler.process(msg) await handler.process(msg)
logger.info("RESPONSE [%d]: %r", -msg.req, msg) logger.info("RESPONSE [%d]: %r", -msg.req, msg)
if msg.end_err: if msg.end_err:
await handler.stop() await handler.stop()
del self._event_map[-msg.req] del self._event_map[-msg.req]
logger.info("RESPONSE [%d]: EOS", -msg.req) logger.info("RESPONSE [%d]: EOS", -msg.req)
return msg return msg
def _write(self, msg: PSMessage) -> None: def _write(self, msg: PSMessage) -> None:
@ -225,7 +245,7 @@ class PacketStream:
stream: bool = False, stream: bool = False,
end_err: bool = False, end_err: bool = False,
req: Optional[int] = None, req: Optional[int] = None,
): ) -> PSHandler:
"""Send data through the packet stream""" """Send data through the packet stream"""
update_counter = False update_counter = False
@ -240,7 +260,7 @@ class PacketStream:
self._write(msg) self._write(msg)
if stream: if stream:
handler = PSStreamHandler(self.req_counter) handler: PSHandler = PSStreamHandler(self.req_counter)
else: else:
handler = PSRequestHandler(self.req_counter) handler = PSRequestHandler(self.req_counter)
@ -251,7 +271,7 @@ class PacketStream:
return handler return handler
def disconnect(self): def disconnect(self) -> None:
"""Disconnect the stream""" """Disconnect the stream"""
self._connected = False self._connected = False

View File

@ -4,12 +4,12 @@ from base64 import b64decode, b64encode
import os import os
from typing import TypedDict from typing import TypedDict
from nacl.signing import SigningKey from nacl.signing import SigningKey, VerifyKey
import yaml import yaml
class SSBSecret(TypedDict): class SSBSecret(TypedDict):
"""Dictionary to hold an SSB identity""" """Dictionary type to hold an SSB secret identity"""
keypair: SigningKey keypair: SigningKey
id: str id: str
@ -19,7 +19,7 @@ class ConfigException(Exception):
"""Exception to raise if there is a problem with the configuration data""" """Exception to raise if there is a problem with the configuration data"""
def tag(key): def tag(key: VerifyKey) -> bytes:
"""Create tag from public key.""" """Create tag from public key."""
return b"@" + b64encode(bytes(key)) + b".ed25519" return b"@" + b64encode(bytes(key)) + b".ed25519"
@ -35,4 +35,5 @@ def load_ssb_secret() -> SSBSecret:
raise ConfigException("Algorithm not known: " + config["curve"]) raise ConfigException("Algorithm not known: " + config["curve"])
server_prv_key = b64decode(config["private"][:-8]) server_prv_key = b64decode(config["private"][:-8])
return {"keypair": SigningKey(server_prv_key[:32]), "id": config["id"]} return {"keypair": SigningKey(server_prv_key[:32]), "id": config["id"]}

View File

@ -25,7 +25,7 @@ SERIALIZED_M1 = b"""{
@pytest.fixture @pytest.fixture
def local_feed(): def local_feed() -> LocalFeed:
"""Fixture providing a local feed""" """Fixture providing a local feed"""
secret = b64decode("Mz2qkNOP2K6upnqibWrR+z8pVUI1ReA1MLc7QMtF2qQ=") secret = b64decode("Mz2qkNOP2K6upnqibWrR+z8pVUI1ReA1MLc7QMtF2qQ=")
@ -33,14 +33,14 @@ def local_feed():
@pytest.fixture @pytest.fixture
def remote_feed(): def remote_feed() -> Feed:
"""Fixture providing a remote feed""" """Fixture providing a remote feed"""
public = b64decode("I/4cyN/jPBbDsikbHzAEvmaYlaJK33lW3UhWjNXjyrU=") public = b64decode("I/4cyN/jPBbDsikbHzAEvmaYlaJK33lW3UhWjNXjyrU=")
return Feed(VerifyKey(public)) return Feed(VerifyKey(public))
def test_local_feed(): def test_local_feed() -> None:
"""Test a local feed""" """Test a local feed"""
secret = b64decode("Mz2qkNOP2K6upnqibWrR+z8pVUI1ReA1MLc7QMtF2qQ=") secret = b64decode("Mz2qkNOP2K6upnqibWrR+z8pVUI1ReA1MLc7QMtF2qQ=")
@ -50,7 +50,7 @@ def test_local_feed():
assert feed.id == "@I/4cyN/jPBbDsikbHzAEvmaYlaJK33lW3UhWjNXjyrU=.ed25519" assert feed.id == "@I/4cyN/jPBbDsikbHzAEvmaYlaJK33lW3UhWjNXjyrU=.ed25519"
def test_remote_feed(): def test_remote_feed() -> None:
"""Test a remote feed""" """Test a remote feed"""
public = b64decode("I/4cyN/jPBbDsikbHzAEvmaYlaJK33lW3UhWjNXjyrU=") public = b64decode("I/4cyN/jPBbDsikbHzAEvmaYlaJK33lW3UhWjNXjyrU=")
@ -69,7 +69,7 @@ def test_remote_feed():
feed.sign(m1) feed.sign(m1)
def test_local_message(local_feed): # pylint: disable=redefined-outer-name def test_local_message(local_feed: LocalFeed) -> None: # pylint: disable=redefined-outer-name
"""Test a local message""" """Test a local message"""
m1 = LocalMessage( m1 = LocalMessage(
@ -102,7 +102,7 @@ def test_local_message(local_feed): # pylint: disable=redefined-outer-name
assert m2.key == "%nx13uks5GUwuKJC49PfYGMS/1pgGTtwwdWT7kbVaroM=.sha256" assert m2.key == "%nx13uks5GUwuKJC49PfYGMS/1pgGTtwwdWT7kbVaroM=.sha256"
def test_remote_message(remote_feed): # pylint: disable=redefined-outer-name def test_remote_message(remote_feed: Feed) -> None: # pylint: disable=redefined-outer-name
"""Test a remote message""" """Test a remote message"""
signature = "lPsQ9P10OgeyH6u0unFgiI2wV/RQ7Q2x2ebxnXYCzsJ055TBMXphRADTKhOMS2EkUxXQ9k3amj5fnWPudGxwBQ==.sig.ed25519" signature = "lPsQ9P10OgeyH6u0unFgiI2wV/RQ7Q2x2ebxnXYCzsJ055TBMXphRADTKhOMS2EkUxXQ9k3amj5fnWPudGxwBQ==.sig.ed25519"
@ -136,7 +136,7 @@ def test_remote_message(remote_feed): # pylint: disable=redefined-outer-name
assert m2.key == "%nx13uks5GUwuKJC49PfYGMS/1pgGTtwwdWT7kbVaroM=.sha256" assert m2.key == "%nx13uks5GUwuKJC49PfYGMS/1pgGTtwwdWT7kbVaroM=.sha256"
def test_remote_no_signature(remote_feed): # pylint: disable=redefined-outer-name def test_remote_no_signature(remote_feed: Feed) -> None: # pylint: disable=redefined-outer-name
"""Test remote feed without a signature""" """Test remote feed without a signature"""
with pytest.raises(ValueError): with pytest.raises(ValueError):
@ -150,7 +150,7 @@ def test_remote_no_signature(remote_feed): # pylint: disable=redefined-outer-na
) )
def test_serialize(local_feed): # pylint: disable=redefined-outer-name def test_serialize(local_feed: LocalFeed) -> None: # pylint: disable=redefined-outer-name
"""Test feed serialization""" """Test feed serialization"""
m1 = LocalMessage( m1 = LocalMessage(
@ -162,7 +162,7 @@ def test_serialize(local_feed): # pylint: disable=redefined-outer-name
assert m1.serialize() == SERIALIZED_M1 assert m1.serialize() == SERIALIZED_M1
def test_parse(local_feed): # pylint: disable=redefined-outer-name def test_parse(local_feed: LocalFeed) -> None: # pylint: disable=redefined-outer-name
"""Test feed parsing""" """Test feed parsing"""
m1 = LocalMessage.parse(SERIALIZED_M1, local_feed) m1 = LocalMessage.parse(SERIALIZED_M1, local_feed)

View File

@ -1,18 +1,23 @@
"""Tests for the packet stream""" """Tests for the packet stream"""
from asyncio import Event, ensure_future, gather from asyncio import Event, ensure_future, gather
from asyncio.events import AbstractEventLoop
import json import json
from typing import AsyncGenerator, Awaitable, Callable, Generator, List
import pytest import pytest
from pytest_mock import MockerFixture
from secret_handshake.network import SHSDuplexStream from secret_handshake.network import SHSDuplexStream
from ssb.packet_stream import PacketStream, PSMessageType from ssb.packet_stream import PacketStream, PSMessage, PSMessageType
async def _collect_messages(generator): async def _collect_messages(generator: AsyncGenerator[PSMessage, None]) -> List[PSMessage]:
results = [] results = []
async for msg in generator: async for msg in generator:
results.append(msg) results.append(msg)
return results return results
@ -39,45 +44,47 @@ MSG_BODY_2 = (
class MockSHSSocket(SHSDuplexStream): class MockSHSSocket(SHSDuplexStream):
"""A mocked SHS socket""" """A mocked SHS socket"""
def __init__(self, *args, **kwargs): # pylint: disable=unused-argument def __init__(self): # pylint: disable=unused-argument
super().__init__() super().__init__()
self.input = [] self.input: List[bytes] = []
self.output = [] self.output: List[bytes] = []
self.is_connected = False self.is_connected = False
self._on_connect = [] self._on_connect: List[Callable[[SHSDuplexStream], Awaitable[None]]] = []
def on_connect(self, cb): def on_connect(self, cb: Callable[[SHSDuplexStream], Awaitable[None]]) -> None:
"""Set the on_connect callback""" """Set the on_connect callback"""
self._on_connect.append(cb) self._on_connect.append(cb)
async def read(self): async def read(self) -> bytes:
"""Read data from the socket""" """Read data from the socket"""
if not self.input: if not self.input:
raise StopAsyncIteration raise StopAsyncIteration
return self.input.pop(0) return self.input.pop(0)
def write(self, data): def write(self, data: bytes) -> None:
"""Write data to the socket""" """Write data to the socket"""
self.output.append(data) self.output.append(data)
def feed(self, input_): def feed(self, input_: List[bytes]) -> None:
"""Get the connections feed""" """Feed data into the connection"""
self.input += input_ self.input += input_
def get_output(self): def get_output(self) -> Generator[bytes, None, None]:
"""Get the output of a call""" """Get the output of a call"""
while True: while True:
if not self.output: if not self.output:
break break
yield self.output.pop(0) yield self.output.pop(0)
def disconnect(self): def disconnect(self) -> None:
"""Disconnect from the remote party""" """Disconnect from the remote party"""
self.is_connected = False self.is_connected = False
@ -86,48 +93,48 @@ class MockSHSSocket(SHSDuplexStream):
class MockSHSClient(MockSHSSocket): class MockSHSClient(MockSHSSocket):
"""A mocked SHS client""" """A mocked SHS client"""
async def connect(self): async def connect(self) -> None:
"""Connect to a SHS server""" """Connect to a SHS server"""
self.is_connected = True self.is_connected = True
for cb in self._on_connect: for cb in self._on_connect:
await cb() await cb(self)
class MockSHSServer(MockSHSSocket): class MockSHSServer(MockSHSSocket):
"""A mocked SHS server""" """A mocked SHS server"""
def listen(self): def listen(self) -> None:
"""Listen for new connections""" """Listen for new connections"""
self.is_connected = True self.is_connected = True
for cb in self._on_connect: for cb in self._on_connect:
ensure_future(cb()) ensure_future(cb(self))
@pytest.fixture @pytest.fixture
def ps_client(event_loop): # pylint: disable=unused-argument def ps_client(event_loop: AbstractEventLoop) -> MockSHSClient: # pylint: disable=unused-argument
"""Fixture to provide a mocked SHS client""" """Fixture to provide a mocked SHS client"""
return MockSHSClient() return MockSHSClient()
@pytest.fixture @pytest.fixture
def ps_server(event_loop): # pylint: disable=unused-argument def ps_server(event_loop: AbstractEventLoop) -> MockSHSServer: # pylint: disable=unused-argument
"""Fixture to provide a mocked SHS server""" """Fixture to provide a mocked SHS server"""
return MockSHSServer() return MockSHSServer()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_on_connect(ps_server): # pylint: disable=redefined-outer-name async def test_on_connect(ps_server: MockSHSServer) -> None: # pylint: disable=redefined-outer-name
"""Test the on_connect callback functionality""" """Test the on_connect callback functionality"""
called = Event() called = Event()
async def _on_connect(): async def _on_connect(_: SHSDuplexStream) -> None:
called.set() called.set()
ps_server.on_connect(_on_connect) ps_server.on_connect(_on_connect)
@ -137,7 +144,7 @@ async def test_on_connect(ps_server): # pylint: disable=redefined-outer-name
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_message_decoding(ps_client): # pylint: disable=redefined-outer-name async def test_message_decoding(ps_client: MockSHSClient) -> None: # pylint: disable=redefined-outer-name
"""Test message decoding""" """Test message decoding"""
await ps_client.connect() await ps_client.connect()
@ -167,7 +174,7 @@ async def test_message_decoding(ps_client): # pylint: disable=redefined-outer-n
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_message_encoding(ps_client): # pylint: disable=redefined-outer-name async def test_message_encoding(ps_client: MockSHSClient) -> None: # pylint: disable=redefined-outer-name
"""Test message encoding""" """Test message encoding"""
await ps_client.connect() await ps_client.connect()
@ -200,7 +207,7 @@ async def test_message_encoding(ps_client): # pylint: disable=redefined-outer-n
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_message_stream(ps_client, mocker): # pylint: disable=redefined-outer-name async def test_message_stream(ps_client: MockSHSClient, mocker: MockerFixture): # pylint: disable=redefined-outer-name
"""Test requesting a history stream""" """Test requesting a history stream"""
await ps_client.connect() await ps_client.connect()
@ -272,7 +279,9 @@ async def test_message_stream(ps_client, mocker): # pylint: disable=redefined-o
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_message_request(ps_server, mocker): # pylint: disable=redefined-outer-name async def test_message_request(
ps_server: MockSHSServer, mocker: MockerFixture # pylint: disable=redefined-outer-name
) -> None:
"""Test message sending""" """Test message sending"""
ps_server.listen() ps_server.listen()

View File

@ -20,7 +20,7 @@ CONFIG_FILE = """
CONFIG_FILE_INVALID = CONFIG_FILE.replace("ed25519", "foo") CONFIG_FILE_INVALID = CONFIG_FILE.replace("ed25519", "foo")
def test_load_secret(): def test_load_secret() -> None:
"""Test loading the SSB secret from a file""" """Test loading the SSB secret from a file"""
with patch("ssb.util.open", mock_open(read_data=CONFIG_FILE), create=True): with patch("ssb.util.open", mock_open(read_data=CONFIG_FILE), create=True):
@ -33,8 +33,9 @@ def test_load_secret():
assert bytes(secret["keypair"].verify_key) == b64decode("rsYpBIcXsxjQAf0JNes+MHqT2DL+EfopWKAp4rGeEPQ=") assert bytes(secret["keypair"].verify_key) == b64decode("rsYpBIcXsxjQAf0JNes+MHqT2DL+EfopWKAp4rGeEPQ=")
def test_load_exception(): def test_load_exception() -> None:
"""Test configuration loading if there is a problem with the file""" """Test configuration loading if there is a problem with the file"""
with pytest.raises(ConfigException): with pytest.raises(ConfigException):
with patch("ssb.util.open", mock_open(read_data=CONFIG_FILE_INVALID), create=True): with patch("ssb.util.open", mock_open(read_data=CONFIG_FILE_INVALID), create=True):
load_ssb_secret() load_ssb_secret()