ci: Add and configure mypy, and make it happy

This commit is contained in:
Gergely Polonkai 2023-11-01 07:22:29 +01:00
parent 5aa4f16a5a
commit f6e58b7682
No known key found for this signature in database
GPG Key ID: 2D2885533B869ED4
8 changed files with 163 additions and 86 deletions

24
poetry.lock generated
View File

@ -1258,6 +1258,28 @@ files = [
{file = "tomlkit-0.12.1.tar.gz", hash = "sha256:38e1ff8edb991273ec9f6181244a6a391ac30e9f5098e7535640ea6be97a7c86"},
]
[[package]]
name = "types-pyyaml"
version = "6.0.12.12"
description = "Typing stubs for PyYAML"
optional = false
python-versions = "*"
files = [
{file = "types-PyYAML-6.0.12.12.tar.gz", hash = "sha256:334373d392fde0fdf95af5c3f1661885fa10c52167b14593eb856289e1855062"},
{file = "types_PyYAML-6.0.12.12-py3-none-any.whl", hash = "sha256:c05bc6c158facb0676674b7f11fe3960db4f389718e19e62bd2b84d6205cfd24"},
]
[[package]]
name = "types-simplejson"
version = "3.19.0.2"
description = "Typing stubs for simplejson"
optional = false
python-versions = "*"
files = [
{file = "types-simplejson-3.19.0.2.tar.gz", hash = "sha256:ebc81f886f89d99d6b80c726518aa2228bc77c26438f18fd81455e4f79f8ee1b"},
{file = "types_simplejson-3.19.0.2-py3-none-any.whl", hash = "sha256:8ba093dc7884f59b3e62aed217144085e675a269debc32678fd80e0b43b2b86f"},
]
[[package]]
name = "typing-extensions"
version = "4.8.0"
@ -1315,4 +1337,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p
[metadata]
lock-version = "2.0"
python-versions = "^3.9"
content-hash = "b503e50c4ab977c6785c68bc1e5bf2efb7ab88a1fd33770e84a7f612a85d2641"
content-hash = "d80cbfdf7923c50c95505a84d8ad75eae016ca81ae32a8b22d074569b0a0fcbd"

View File

@ -27,6 +27,8 @@ commitizen = "^3.12.0"
black = "^23.10.1"
pylint = "^3.0.2"
mypy = "^1.6.1"
types-pyyaml = "^6.0.12.12"
types-simplejson = "^3.19.0.2"
[tool.poetry.group.docs.dependencies]
Sphinx = "^2.1.1"

View File

@ -4,8 +4,11 @@ from base64 import b64encode
from collections import OrderedDict, namedtuple
import datetime
from hashlib import sha256
from typing import Any, Dict, Optional
from nacl.signing import SigningKey, VerifyKey
from simplejson import dumps, loads
from typing_extensions import Self
from ssb.util import tag
@ -16,7 +19,7 @@ class NoPrivateKeyException(Exception):
"""Exception to raise when a private key is not available"""
def to_ordered(data):
def to_ordered(data: Dict[str, Any]) -> OrderedDict[str, Any]:
"""Convert a dictionary to an ``OrderedDict``"""
smsg = OrderedMsg(**data)
@ -24,7 +27,7 @@ def to_ordered(data):
return OrderedDict((k, getattr(smsg, k)) for k in smsg._fields)
def get_millis_1970():
def get_millis_1970() -> int:
"""Get the UNIX timestamp in milliseconds"""
return int(datetime.datetime.utcnow().timestamp() * 1000)
@ -33,16 +36,16 @@ def get_millis_1970():
class Feed:
"""Base class for feeds"""
def __init__(self, public_key):
def __init__(self, public_key: VerifyKey):
self.public_key = public_key
@property
def id(self):
def id(self) -> str:
"""The identifier of the feed"""
return tag(self.public_key).decode("ascii")
def sign(self, msg):
def sign(self, msg: "Message") -> bytes:
"""Sign a message"""
raise NoPrivateKeyException("Cannot use remote identity to sign (no private key!)")
@ -51,16 +54,20 @@ class Feed:
class LocalFeed(Feed):
"""Class representing a local feed"""
def __init__(self, private_key): # pylint: disable=super-init-not-called
self.private_key = private_key
def __init__(self, private_key: SigningKey): # pylint: disable=super-init-not-called
self.private_key: SigningKey = private_key
@property
def public_key(self):
def public_key(self) -> VerifyKey:
"""The public key of the feed"""
return self.private_key.verify_key
def sign(self, msg):
@public_key.setter
def public_key(self, _: VerifyKey) -> None:
raise TypeError("Cannot set just the public key of a local feed")
def sign(self, msg: "Message") -> bytes:
"""Sign a message for this feed"""
return self.private_key.sign(msg).signature
@ -70,7 +77,13 @@ class Message:
"""Base class for SSB messages"""
def __init__( # pylint: disable=too-many-arguments
self, feed, content, signature=None, sequence=1, timestamp=None, previous=None
self,
feed: Feed,
content: Dict[str, Any],
signature: Optional[str] = None,
sequence: int = 1,
timestamp: Optional[int] = None,
previous: Optional["Message"] = None,
):
self.feed = feed
self.content = content
@ -81,15 +94,16 @@ class Message:
self.signature = signature
self.previous = previous
if self.previous:
self.sequence = self.previous.sequence + 1
self.sequence: int = self.previous.sequence + 1
else:
self.sequence = sequence
self.timestamp = get_millis_1970() if timestamp is None else timestamp
@classmethod
def parse(cls, data, feed):
def parse(cls, data: bytes, feed: Feed) -> Self:
"""Parse raw message data"""
obj = loads(data, object_pairs_hook=OrderedDict)
@ -97,12 +111,12 @@ class Message:
return msg
def serialize(self, add_signature=True):
def serialize(self, add_signature: bool = True) -> bytes:
"""Serialize the message"""
return dumps(self.to_dict(add_signature=add_signature), indent=2).encode("utf-8")
def to_dict(self, add_signature=True):
def to_dict(self, add_signature: bool = True) -> OrderedDict[str, Any]:
"""Convert the message to a dictionary"""
obj = to_ordered(
@ -121,20 +135,21 @@ class Message:
return obj
def verify(self, signature):
def verify(self, signature: str) -> bool:
"""Verify the signature of the message"""
return self.signature == signature
@property
def hash(self):
def hash(self) -> str:
"""The cryptographic hash of the message"""
hash_ = sha256(self.serialize()).digest()
return b64encode(hash_).decode("ascii") + ".sha256"
@property
def key(self):
def key(self) -> str:
"""The key of the message"""
return "%" + self.hash
@ -144,7 +159,13 @@ class LocalMessage(Message):
"""Class representing a local message"""
def __init__( # pylint: disable=too-many-arguments,super-init-not-called
self, feed, content, signature=None, sequence=1, timestamp=None, previous=None
self,
feed: Feed,
content: Dict[str, Any],
signature: Optional[str] = None,
sequence: int = 1,
timestamp: Optional[int] = None,
previous: Optional[Message] = None,
):
self.feed = feed
self.content = content
@ -162,7 +183,8 @@ class LocalMessage(Message):
else:
self.signature = signature
def _sign(self):
def _sign(self) -> str:
# ensure ordering of keys and indentation of 2 characters, like ssb-keys
data = self.serialize(add_signature=False)
return (b64encode(bytes(self.feed.sign(data))) + b".sig.ed25519").decode("ascii")

View File

@ -6,14 +6,15 @@ import logging
from math import ceil
import struct
from time import time
from typing import Any, AsyncIterator, Dict, Optional, Union
from typing import Any, AsyncIterator, Dict, Optional, Tuple, Union
from secret_handshake.network import SHSDuplexStream
import simplejson
from typing_extensions import Self
logger = logging.getLogger("packet_stream")
PSHandler = Union["PSRequestHandler", "PSStreamHandler"]
PSMessageData = Union[bytes, bool, Dict[str, Any], str]
logger = logging.getLogger("packet_stream")
class PSMessageType(Enum):
@ -31,12 +32,12 @@ class PSStreamHandler:
self.req = req
self.queue: Queue[Optional["PSMessage"]] = Queue()
async def process(self, msg):
async def process(self, msg: "PSMessage") -> None:
"""Process a pending message"""
await self.queue.put(msg)
async def stop(self):
async def stop(self) -> None:
"""Stop a pending request"""
await self.queue.put(None)
@ -59,15 +60,15 @@ class PSRequestHandler:
def __init__(self, req: int):
self.req = req
self.event = Event()
self._msg = None
self._msg: Optional[PSMessage] = None
async def process(self, msg):
async def process(self, msg: "PSMessage") -> None:
"""Process a message request"""
self._msg = msg
self.event.set()
async def stop(self):
async def stop(self) -> None:
"""Stop a pending event request"""
if not self.event.is_set():
@ -87,37 +88,44 @@ class PSMessage:
"""Packet Stream message"""
@classmethod
def from_header_body(cls, flags, req, body):
def from_header_body(cls, flags: int, req: int, body: bytes) -> Self:
"""Parse a raw message"""
type_ = PSMessageType(flags & 0x03)
if type_ == PSMessageType.TEXT:
body = body.decode("utf-8")
body_s = body.decode("utf-8")
elif type_ == PSMessageType.JSON:
body = simplejson.loads(body)
body_s = simplejson.loads(body)
return cls(type_, body, bool(flags & 0x08), bool(flags & 0x04), req=req)
return cls(type_, body_s, bool(flags & 0x08), bool(flags & 0x04), req=req)
@property
def data(self) -> bytes:
"""The raw message data"""
if self.type == PSMessageType.TEXT:
assert isinstance(self.body, str)
return self.body.encode("utf-8")
if self.type == PSMessageType.JSON:
return simplejson.dumps(self.body).encode("utf-8")
assert isinstance(self.body, bytes)
return self.body
def __init__(self, type_, body, stream, end_err, req=None): # pylint: disable=too-many-arguments
def __init__(
self, type_: PSMessageType, body: Any, stream: bool, end_err: bool, req: Optional[int] = None
): # pylint: disable=too-many-arguments
self.stream = stream
self.end_err = end_err
self.type = type_
self.body = body
self.req = req
def __repr__(self):
def __repr__(self) -> str:
if self.type == PSMessageType.BUFFER:
body = f"{len(self.body)} bytes"
else:
@ -136,16 +144,16 @@ class PacketStream:
def __init__(self, connection: SHSDuplexStream):
self.connection = connection
self.req_counter = 1
self._event_map = {}
self._event_map: Dict[int, Tuple[float, PSHandler]] = {}
self._connected = False
def register_handler(self, handler):
def register_handler(self, handler: PSHandler) -> None:
"""Register an RPC handler"""
self._event_map[handler.req] = (time(), handler)
@property
def is_connected(self):
def is_connected(self) -> bool:
"""Check if the stream is connected"""
return self.connection.is_connected
@ -159,22 +167,25 @@ class PacketStream:
if not msg:
raise StopAsyncIteration()
if msg.req >= 0:
if msg.req is not None and msg.req >= 0:
return msg
return None
async def __await__(self):
async def __await__(self) -> None:
async for data in self:
logger.info("RECV: %r", data)
if data is None:
return
async def _read(self):
async def _read(self) -> Optional[PSMessage]:
try:
header = await self.connection.read()
if not header or header == b"\x00" * 9:
return
return None
flags, length, req = struct.unpack(">BIi", header)
n_packets = ceil(length / 4096)
@ -182,30 +193,39 @@ class PacketStream:
body = b""
for _ in range(n_packets):
body += await self.connection.read()
read_data = await self.connection.read()
if read_data is not None:
body += read_data
logger.debug("READ %s %s", header, len(body))
return PSMessage.from_header_body(flags, req, body)
except StopAsyncIteration:
logger.debug("DISCONNECT")
self.connection.disconnect()
return None
async def read(self):
async def read(self) -> Optional[PSMessage]:
"""Read data from the packet stream"""
msg = await self._read()
if not msg:
return None
# check whether it's a reply and handle accordingly
if msg.req < 0:
if msg.req is not None and msg.req < 0:
_, handler = self._event_map[-msg.req]
await handler.process(msg)
logger.info("RESPONSE [%d]: %r", -msg.req, msg)
if msg.end_err:
await handler.stop()
del self._event_map[-msg.req]
logger.info("RESPONSE [%d]: EOS", -msg.req)
return msg
def _write(self, msg: PSMessage) -> None:
@ -225,7 +245,7 @@ class PacketStream:
stream: bool = False,
end_err: bool = False,
req: Optional[int] = None,
):
) -> PSHandler:
"""Send data through the packet stream"""
update_counter = False
@ -240,7 +260,7 @@ class PacketStream:
self._write(msg)
if stream:
handler = PSStreamHandler(self.req_counter)
handler: PSHandler = PSStreamHandler(self.req_counter)
else:
handler = PSRequestHandler(self.req_counter)
@ -251,7 +271,7 @@ class PacketStream:
return handler
def disconnect(self):
def disconnect(self) -> None:
"""Disconnect the stream"""
self._connected = False

View File

@ -4,12 +4,12 @@ from base64 import b64decode, b64encode
import os
from typing import TypedDict
from nacl.signing import SigningKey
from nacl.signing import SigningKey, VerifyKey
import yaml
class SSBSecret(TypedDict):
"""Dictionary to hold an SSB identity"""
"""Dictionary type to hold an SSB secret identity"""
keypair: SigningKey
id: str
@ -19,7 +19,7 @@ class ConfigException(Exception):
"""Exception to raise if there is a problem with the configuration data"""
def tag(key):
def tag(key: VerifyKey) -> bytes:
"""Create tag from public key."""
return b"@" + b64encode(bytes(key)) + b".ed25519"
@ -35,4 +35,5 @@ def load_ssb_secret() -> SSBSecret:
raise ConfigException("Algorithm not known: " + config["curve"])
server_prv_key = b64decode(config["private"][:-8])
return {"keypair": SigningKey(server_prv_key[:32]), "id": config["id"]}

View File

@ -25,7 +25,7 @@ SERIALIZED_M1 = b"""{
@pytest.fixture
def local_feed():
def local_feed() -> LocalFeed:
"""Fixture providing a local feed"""
secret = b64decode("Mz2qkNOP2K6upnqibWrR+z8pVUI1ReA1MLc7QMtF2qQ=")
@ -33,14 +33,14 @@ def local_feed():
@pytest.fixture
def remote_feed():
def remote_feed() -> Feed:
"""Fixture providing a remote feed"""
public = b64decode("I/4cyN/jPBbDsikbHzAEvmaYlaJK33lW3UhWjNXjyrU=")
return Feed(VerifyKey(public))
def test_local_feed():
def test_local_feed() -> None:
"""Test a local feed"""
secret = b64decode("Mz2qkNOP2K6upnqibWrR+z8pVUI1ReA1MLc7QMtF2qQ=")
@ -50,7 +50,7 @@ def test_local_feed():
assert feed.id == "@I/4cyN/jPBbDsikbHzAEvmaYlaJK33lW3UhWjNXjyrU=.ed25519"
def test_remote_feed():
def test_remote_feed() -> None:
"""Test a remote feed"""
public = b64decode("I/4cyN/jPBbDsikbHzAEvmaYlaJK33lW3UhWjNXjyrU=")
@ -69,7 +69,7 @@ def test_remote_feed():
feed.sign(m1)
def test_local_message(local_feed): # pylint: disable=redefined-outer-name
def test_local_message(local_feed: LocalFeed) -> None: # pylint: disable=redefined-outer-name
"""Test a local message"""
m1 = LocalMessage(
@ -102,7 +102,7 @@ def test_local_message(local_feed): # pylint: disable=redefined-outer-name
assert m2.key == "%nx13uks5GUwuKJC49PfYGMS/1pgGTtwwdWT7kbVaroM=.sha256"
def test_remote_message(remote_feed): # pylint: disable=redefined-outer-name
def test_remote_message(remote_feed: Feed) -> None: # pylint: disable=redefined-outer-name
"""Test a remote message"""
signature = "lPsQ9P10OgeyH6u0unFgiI2wV/RQ7Q2x2ebxnXYCzsJ055TBMXphRADTKhOMS2EkUxXQ9k3amj5fnWPudGxwBQ==.sig.ed25519"
@ -136,7 +136,7 @@ def test_remote_message(remote_feed): # pylint: disable=redefined-outer-name
assert m2.key == "%nx13uks5GUwuKJC49PfYGMS/1pgGTtwwdWT7kbVaroM=.sha256"
def test_remote_no_signature(remote_feed): # pylint: disable=redefined-outer-name
def test_remote_no_signature(remote_feed: Feed) -> None: # pylint: disable=redefined-outer-name
"""Test remote feed without a signature"""
with pytest.raises(ValueError):
@ -150,7 +150,7 @@ def test_remote_no_signature(remote_feed): # pylint: disable=redefined-outer-na
)
def test_serialize(local_feed): # pylint: disable=redefined-outer-name
def test_serialize(local_feed: LocalFeed) -> None: # pylint: disable=redefined-outer-name
"""Test feed serialization"""
m1 = LocalMessage(
@ -162,7 +162,7 @@ def test_serialize(local_feed): # pylint: disable=redefined-outer-name
assert m1.serialize() == SERIALIZED_M1
def test_parse(local_feed): # pylint: disable=redefined-outer-name
def test_parse(local_feed: LocalFeed) -> None: # pylint: disable=redefined-outer-name
"""Test feed parsing"""
m1 = LocalMessage.parse(SERIALIZED_M1, local_feed)

View File

@ -1,18 +1,23 @@
"""Tests for the packet stream"""
from asyncio import Event, ensure_future, gather
from asyncio.events import AbstractEventLoop
import json
from typing import AsyncGenerator, Awaitable, Callable, Generator, List
import pytest
from pytest_mock import MockerFixture
from secret_handshake.network import SHSDuplexStream
from ssb.packet_stream import PacketStream, PSMessageType
from ssb.packet_stream import PacketStream, PSMessage, PSMessageType
async def _collect_messages(generator):
async def _collect_messages(generator: AsyncGenerator[PSMessage, None]) -> List[PSMessage]:
results = []
async for msg in generator:
results.append(msg)
return results
@ -39,45 +44,47 @@ MSG_BODY_2 = (
class MockSHSSocket(SHSDuplexStream):
"""A mocked SHS socket"""
def __init__(self, *args, **kwargs): # pylint: disable=unused-argument
def __init__(self): # pylint: disable=unused-argument
super().__init__()
self.input = []
self.output = []
self.input: List[bytes] = []
self.output: List[bytes] = []
self.is_connected = False
self._on_connect = []
self._on_connect: List[Callable[[SHSDuplexStream], Awaitable[None]]] = []
def on_connect(self, cb):
def on_connect(self, cb: Callable[[SHSDuplexStream], Awaitable[None]]) -> None:
"""Set the on_connect callback"""
self._on_connect.append(cb)
async def read(self):
async def read(self) -> bytes:
"""Read data from the socket"""
if not self.input:
raise StopAsyncIteration
return self.input.pop(0)
def write(self, data):
def write(self, data: bytes) -> None:
"""Write data to the socket"""
self.output.append(data)
def feed(self, input_):
"""Get the connections feed"""
def feed(self, input_: List[bytes]) -> None:
"""Feed data into the connection"""
self.input += input_
def get_output(self):
def get_output(self) -> Generator[bytes, None, None]:
"""Get the output of a call"""
while True:
if not self.output:
break
yield self.output.pop(0)
def disconnect(self):
def disconnect(self) -> None:
"""Disconnect from the remote party"""
self.is_connected = False
@ -86,48 +93,48 @@ class MockSHSSocket(SHSDuplexStream):
class MockSHSClient(MockSHSSocket):
"""A mocked SHS client"""
async def connect(self):
async def connect(self) -> None:
"""Connect to a SHS server"""
self.is_connected = True
for cb in self._on_connect:
await cb()
await cb(self)
class MockSHSServer(MockSHSSocket):
"""A mocked SHS server"""
def listen(self):
def listen(self) -> None:
"""Listen for new connections"""
self.is_connected = True
for cb in self._on_connect:
ensure_future(cb())
ensure_future(cb(self))
@pytest.fixture
def ps_client(event_loop): # pylint: disable=unused-argument
def ps_client(event_loop: AbstractEventLoop) -> MockSHSClient: # pylint: disable=unused-argument
"""Fixture to provide a mocked SHS client"""
return MockSHSClient()
@pytest.fixture
def ps_server(event_loop): # pylint: disable=unused-argument
def ps_server(event_loop: AbstractEventLoop) -> MockSHSServer: # pylint: disable=unused-argument
"""Fixture to provide a mocked SHS server"""
return MockSHSServer()
@pytest.mark.asyncio
async def test_on_connect(ps_server): # pylint: disable=redefined-outer-name
async def test_on_connect(ps_server: MockSHSServer) -> None: # pylint: disable=redefined-outer-name
"""Test the on_connect callback functionality"""
called = Event()
async def _on_connect():
async def _on_connect(_: SHSDuplexStream) -> None:
called.set()
ps_server.on_connect(_on_connect)
@ -137,7 +144,7 @@ async def test_on_connect(ps_server): # pylint: disable=redefined-outer-name
@pytest.mark.asyncio
async def test_message_decoding(ps_client): # pylint: disable=redefined-outer-name
async def test_message_decoding(ps_client: MockSHSClient) -> None: # pylint: disable=redefined-outer-name
"""Test message decoding"""
await ps_client.connect()
@ -167,7 +174,7 @@ async def test_message_decoding(ps_client): # pylint: disable=redefined-outer-n
@pytest.mark.asyncio
async def test_message_encoding(ps_client): # pylint: disable=redefined-outer-name
async def test_message_encoding(ps_client: MockSHSClient) -> None: # pylint: disable=redefined-outer-name
"""Test message encoding"""
await ps_client.connect()
@ -200,7 +207,7 @@ async def test_message_encoding(ps_client): # pylint: disable=redefined-outer-n
@pytest.mark.asyncio
async def test_message_stream(ps_client, mocker): # pylint: disable=redefined-outer-name
async def test_message_stream(ps_client: MockSHSClient, mocker: MockerFixture): # pylint: disable=redefined-outer-name
"""Test requesting a history stream"""
await ps_client.connect()
@ -272,7 +279,9 @@ async def test_message_stream(ps_client, mocker): # pylint: disable=redefined-o
@pytest.mark.asyncio
async def test_message_request(ps_server, mocker): # pylint: disable=redefined-outer-name
async def test_message_request(
ps_server: MockSHSServer, mocker: MockerFixture # pylint: disable=redefined-outer-name
) -> None:
"""Test message sending"""
ps_server.listen()

View File

@ -20,7 +20,7 @@ CONFIG_FILE = """
CONFIG_FILE_INVALID = CONFIG_FILE.replace("ed25519", "foo")
def test_load_secret():
def test_load_secret() -> None:
"""Test loading the SSB secret from a file"""
with patch("ssb.util.open", mock_open(read_data=CONFIG_FILE), create=True):
@ -33,8 +33,9 @@ def test_load_secret():
assert bytes(secret["keypair"].verify_key) == b64decode("rsYpBIcXsxjQAf0JNes+MHqT2DL+EfopWKAp4rGeEPQ=")
def test_load_exception():
def test_load_exception() -> None:
"""Test configuration loading if there is a problem with the file"""
with pytest.raises(ConfigException):
with patch("ssb.util.open", mock_open(read_data=CONFIG_FILE_INVALID), create=True):
load_ssb_secret()