ci: Add and configure mypy, and make it happy

This commit is contained in:
2023-11-01 07:22:29 +01:00
parent 5aa4f16a5a
commit f6e58b7682
8 changed files with 163 additions and 86 deletions

View File

@@ -4,8 +4,11 @@ from base64 import b64encode
from collections import OrderedDict, namedtuple
import datetime
from hashlib import sha256
from typing import Any, Dict, Optional
from nacl.signing import SigningKey, VerifyKey
from simplejson import dumps, loads
from typing_extensions import Self
from ssb.util import tag
@@ -16,7 +19,7 @@ class NoPrivateKeyException(Exception):
"""Exception to raise when a private key is not available"""
def to_ordered(data):
def to_ordered(data: Dict[str, Any]) -> OrderedDict[str, Any]:
"""Convert a dictionary to an ``OrderedDict``"""
smsg = OrderedMsg(**data)
@@ -24,7 +27,7 @@ def to_ordered(data):
return OrderedDict((k, getattr(smsg, k)) for k in smsg._fields)
def get_millis_1970():
def get_millis_1970() -> int:
"""Get the UNIX timestamp in milliseconds"""
return int(datetime.datetime.utcnow().timestamp() * 1000)
@@ -33,16 +36,16 @@ def get_millis_1970():
class Feed:
"""Base class for feeds"""
def __init__(self, public_key):
def __init__(self, public_key: VerifyKey):
self.public_key = public_key
@property
def id(self):
def id(self) -> str:
"""The identifier of the feed"""
return tag(self.public_key).decode("ascii")
def sign(self, msg):
def sign(self, msg: "Message") -> bytes:
"""Sign a message"""
raise NoPrivateKeyException("Cannot use remote identity to sign (no private key!)")
@@ -51,16 +54,20 @@ class Feed:
class LocalFeed(Feed):
"""Class representing a local feed"""
def __init__(self, private_key): # pylint: disable=super-init-not-called
self.private_key = private_key
def __init__(self, private_key: SigningKey): # pylint: disable=super-init-not-called
self.private_key: SigningKey = private_key
@property
def public_key(self):
def public_key(self) -> VerifyKey:
"""The public key of the feed"""
return self.private_key.verify_key
def sign(self, msg):
@public_key.setter
def public_key(self, _: VerifyKey) -> None:
raise TypeError("Cannot set just the public key of a local feed")
def sign(self, msg: "Message") -> bytes:
"""Sign a message for this feed"""
return self.private_key.sign(msg).signature
@@ -70,7 +77,13 @@ class Message:
"""Base class for SSB messages"""
def __init__( # pylint: disable=too-many-arguments
self, feed, content, signature=None, sequence=1, timestamp=None, previous=None
self,
feed: Feed,
content: Dict[str, Any],
signature: Optional[str] = None,
sequence: int = 1,
timestamp: Optional[int] = None,
previous: Optional["Message"] = None,
):
self.feed = feed
self.content = content
@@ -81,15 +94,16 @@ class Message:
self.signature = signature
self.previous = previous
if self.previous:
self.sequence = self.previous.sequence + 1
self.sequence: int = self.previous.sequence + 1
else:
self.sequence = sequence
self.timestamp = get_millis_1970() if timestamp is None else timestamp
@classmethod
def parse(cls, data, feed):
def parse(cls, data: bytes, feed: Feed) -> Self:
"""Parse raw message data"""
obj = loads(data, object_pairs_hook=OrderedDict)
@@ -97,12 +111,12 @@ class Message:
return msg
def serialize(self, add_signature=True):
def serialize(self, add_signature: bool = True) -> bytes:
"""Serialize the message"""
return dumps(self.to_dict(add_signature=add_signature), indent=2).encode("utf-8")
def to_dict(self, add_signature=True):
def to_dict(self, add_signature: bool = True) -> OrderedDict[str, Any]:
"""Convert the message to a dictionary"""
obj = to_ordered(
@@ -121,20 +135,21 @@ class Message:
return obj
def verify(self, signature):
def verify(self, signature: str) -> bool:
"""Verify the signature of the message"""
return self.signature == signature
@property
def hash(self):
def hash(self) -> str:
"""The cryptographic hash of the message"""
hash_ = sha256(self.serialize()).digest()
return b64encode(hash_).decode("ascii") + ".sha256"
@property
def key(self):
def key(self) -> str:
"""The key of the message"""
return "%" + self.hash
@@ -144,7 +159,13 @@ class LocalMessage(Message):
"""Class representing a local message"""
def __init__( # pylint: disable=too-many-arguments,super-init-not-called
self, feed, content, signature=None, sequence=1, timestamp=None, previous=None
self,
feed: Feed,
content: Dict[str, Any],
signature: Optional[str] = None,
sequence: int = 1,
timestamp: Optional[int] = None,
previous: Optional[Message] = None,
):
self.feed = feed
self.content = content
@@ -162,7 +183,8 @@ class LocalMessage(Message):
else:
self.signature = signature
def _sign(self):
def _sign(self) -> str:
# ensure ordering of keys and indentation of 2 characters, like ssb-keys
data = self.serialize(add_signature=False)
return (b64encode(bytes(self.feed.sign(data))) + b".sig.ed25519").decode("ascii")

View File

@@ -6,14 +6,15 @@ import logging
from math import ceil
import struct
from time import time
from typing import Any, AsyncIterator, Dict, Optional, Union
from typing import Any, AsyncIterator, Dict, Optional, Tuple, Union
from secret_handshake.network import SHSDuplexStream
import simplejson
from typing_extensions import Self
logger = logging.getLogger("packet_stream")
PSHandler = Union["PSRequestHandler", "PSStreamHandler"]
PSMessageData = Union[bytes, bool, Dict[str, Any], str]
logger = logging.getLogger("packet_stream")
class PSMessageType(Enum):
@@ -31,12 +32,12 @@ class PSStreamHandler:
self.req = req
self.queue: Queue[Optional["PSMessage"]] = Queue()
async def process(self, msg):
async def process(self, msg: "PSMessage") -> None:
"""Process a pending message"""
await self.queue.put(msg)
async def stop(self):
async def stop(self) -> None:
"""Stop a pending request"""
await self.queue.put(None)
@@ -59,15 +60,15 @@ class PSRequestHandler:
def __init__(self, req: int):
self.req = req
self.event = Event()
self._msg = None
self._msg: Optional[PSMessage] = None
async def process(self, msg):
async def process(self, msg: "PSMessage") -> None:
"""Process a message request"""
self._msg = msg
self.event.set()
async def stop(self):
async def stop(self) -> None:
"""Stop a pending event request"""
if not self.event.is_set():
@@ -87,37 +88,44 @@ class PSMessage:
"""Packet Stream message"""
@classmethod
def from_header_body(cls, flags, req, body):
def from_header_body(cls, flags: int, req: int, body: bytes) -> Self:
"""Parse a raw message"""
type_ = PSMessageType(flags & 0x03)
if type_ == PSMessageType.TEXT:
body = body.decode("utf-8")
body_s = body.decode("utf-8")
elif type_ == PSMessageType.JSON:
body = simplejson.loads(body)
body_s = simplejson.loads(body)
return cls(type_, body, bool(flags & 0x08), bool(flags & 0x04), req=req)
return cls(type_, body_s, bool(flags & 0x08), bool(flags & 0x04), req=req)
@property
def data(self) -> bytes:
"""The raw message data"""
if self.type == PSMessageType.TEXT:
assert isinstance(self.body, str)
return self.body.encode("utf-8")
if self.type == PSMessageType.JSON:
return simplejson.dumps(self.body).encode("utf-8")
assert isinstance(self.body, bytes)
return self.body
def __init__(self, type_, body, stream, end_err, req=None): # pylint: disable=too-many-arguments
def __init__(
self, type_: PSMessageType, body: Any, stream: bool, end_err: bool, req: Optional[int] = None
): # pylint: disable=too-many-arguments
self.stream = stream
self.end_err = end_err
self.type = type_
self.body = body
self.req = req
def __repr__(self):
def __repr__(self) -> str:
if self.type == PSMessageType.BUFFER:
body = f"{len(self.body)} bytes"
else:
@@ -136,16 +144,16 @@ class PacketStream:
def __init__(self, connection: SHSDuplexStream):
self.connection = connection
self.req_counter = 1
self._event_map = {}
self._event_map: Dict[int, Tuple[float, PSHandler]] = {}
self._connected = False
def register_handler(self, handler):
def register_handler(self, handler: PSHandler) -> None:
"""Register an RPC handler"""
self._event_map[handler.req] = (time(), handler)
@property
def is_connected(self):
def is_connected(self) -> bool:
"""Check if the stream is connected"""
return self.connection.is_connected
@@ -159,22 +167,25 @@ class PacketStream:
if not msg:
raise StopAsyncIteration()
if msg.req >= 0:
if msg.req is not None and msg.req >= 0:
return msg
return None
async def __await__(self):
async def __await__(self) -> None:
async for data in self:
logger.info("RECV: %r", data)
if data is None:
return
async def _read(self):
async def _read(self) -> Optional[PSMessage]:
try:
header = await self.connection.read()
if not header or header == b"\x00" * 9:
return
return None
flags, length, req = struct.unpack(">BIi", header)
n_packets = ceil(length / 4096)
@@ -182,30 +193,39 @@ class PacketStream:
body = b""
for _ in range(n_packets):
body += await self.connection.read()
read_data = await self.connection.read()
if read_data is not None:
body += read_data
logger.debug("READ %s %s", header, len(body))
return PSMessage.from_header_body(flags, req, body)
except StopAsyncIteration:
logger.debug("DISCONNECT")
self.connection.disconnect()
return None
async def read(self):
async def read(self) -> Optional[PSMessage]:
"""Read data from the packet stream"""
msg = await self._read()
if not msg:
return None
# check whether it's a reply and handle accordingly
if msg.req < 0:
if msg.req is not None and msg.req < 0:
_, handler = self._event_map[-msg.req]
await handler.process(msg)
logger.info("RESPONSE [%d]: %r", -msg.req, msg)
if msg.end_err:
await handler.stop()
del self._event_map[-msg.req]
logger.info("RESPONSE [%d]: EOS", -msg.req)
return msg
def _write(self, msg: PSMessage) -> None:
@@ -225,7 +245,7 @@ class PacketStream:
stream: bool = False,
end_err: bool = False,
req: Optional[int] = None,
):
) -> PSHandler:
"""Send data through the packet stream"""
update_counter = False
@@ -240,7 +260,7 @@ class PacketStream:
self._write(msg)
if stream:
handler = PSStreamHandler(self.req_counter)
handler: PSHandler = PSStreamHandler(self.req_counter)
else:
handler = PSRequestHandler(self.req_counter)
@@ -251,7 +271,7 @@ class PacketStream:
return handler
def disconnect(self):
def disconnect(self) -> None:
"""Disconnect the stream"""
self._connected = False

View File

@@ -4,12 +4,12 @@ from base64 import b64decode, b64encode
import os
from typing import TypedDict
from nacl.signing import SigningKey
from nacl.signing import SigningKey, VerifyKey
import yaml
class SSBSecret(TypedDict):
"""Dictionary to hold an SSB identity"""
"""Dictionary type to hold an SSB secret identity"""
keypair: SigningKey
id: str
@@ -19,7 +19,7 @@ class ConfigException(Exception):
"""Exception to raise if there is a problem with the configuration data"""
def tag(key):
def tag(key: VerifyKey) -> bytes:
"""Create tag from public key."""
return b"@" + b64encode(bytes(key)) + b".ed25519"
@@ -35,4 +35,5 @@ def load_ssb_secret() -> SSBSecret:
raise ConfigException("Algorithm not known: " + config["curve"])
server_prv_key = b64decode(config["private"][:-8])
return {"keypair": SigningKey(server_prv_key[:32]), "id": config["id"]}