ci: Add and configure mypy, and make it happy

This commit is contained in:
2023-11-01 07:22:29 +01:00
parent f2a54b5ce6
commit 1c1e57d868
13 changed files with 374 additions and 214 deletions

View File

@@ -26,8 +26,11 @@ from base64 import b64encode
from collections import OrderedDict, namedtuple
from datetime import datetime
from hashlib import sha256
from typing import Any, Dict, Optional
from nacl.signing import SigningKey, VerifyKey
from simplejson import dumps, loads
from typing_extensions import Self
from ssb.util import tag
@@ -38,7 +41,7 @@ class NoPrivateKeyException(Exception):
"""Exception to raise when a private key is not available"""
def to_ordered(data):
def to_ordered(data: Dict[str, Any]) -> OrderedDict[str, Any]:
"""Convert a dictionary to an ``OrderedDict``"""
smsg = OrderedMsg(**data)
@@ -46,7 +49,7 @@ def to_ordered(data):
return OrderedDict((k, getattr(smsg, k)) for k in smsg._fields)
def get_millis_1970():
def get_millis_1970() -> int:
"""Get the UNIX timestamp in milliseconds"""
return int(datetime.utcnow().timestamp() * 1000)
@@ -55,16 +58,16 @@ def get_millis_1970():
class Feed:
"""Base class for feeds"""
def __init__(self, public_key):
def __init__(self, public_key: VerifyKey) -> None:
self.public_key = public_key
@property
def id(self):
def id(self) -> str:
"""The identifier of the feed"""
return tag(self.public_key).decode("ascii")
def sign(self, msg):
def sign(self, msg: bytes) -> bytes:
"""Sign a message"""
raise NoPrivateKeyException("Cannot use remote identity to sign (no private key!)")
@@ -73,16 +76,20 @@ class Feed:
class LocalFeed(Feed):
"""Class representing a local feed"""
def __init__(self, private_key): # pylint: disable=super-init-not-called
def __init__(self, private_key: SigningKey) -> None: # pylint: disable=super-init-not-called
self.private_key = private_key
@property
def public_key(self):
def public_key(self) -> VerifyKey:
"""The public key of the feed"""
return self.private_key.verify_key
def sign(self, msg):
@public_key.setter
def public_key(self, key: VerifyKey) -> None:
raise TypeError("Can not set only the public key for a local feed")
def sign(self, msg: bytes) -> bytes:
"""Sign a message for this feed"""
return self.private_key.sign(msg).signature
@@ -92,25 +99,34 @@ class Message:
"""Base class for SSB messages"""
def __init__( # pylint: disable=too-many-arguments
self, feed, content, signature=None, sequence=1, timestamp=None, previous=None
self,
feed: Feed,
content: Dict[str, Any],
signature: Optional[str] = None,
sequence: int = 1,
timestamp: Optional[int] = None,
previous: Optional["Message"] = None,
):
self.feed = feed
self.content = content
if signature is None:
raise ValueError("signature can't be None")
self.signature = signature
self.previous = previous
self.timestamp = get_millis_1970() if timestamp is None else timestamp
if self.previous:
self.sequence = self.previous.sequence + 1
self.sequence: int = self.previous.sequence + 1
else:
self.sequence = sequence
self.timestamp = get_millis_1970() if timestamp is None else timestamp
self._check_signature()
def _check_signature(self) -> None:
if self.signature is None:
raise ValueError("signature can't be None")
@classmethod
def parse(cls, data, feed):
def parse(cls, data: bytes, feed: Feed) -> Self:
"""Parse raw message data"""
obj = loads(data, object_pairs_hook=OrderedDict)
@@ -118,12 +134,12 @@ class Message:
return msg
def serialize(self, add_signature=True):
def serialize(self, add_signature: bool = True) -> bytes:
"""Serialize the message"""
return dumps(self.to_dict(add_signature=add_signature), indent=2).encode("utf-8")
def to_dict(self, add_signature=True):
def to_dict(self, add_signature: bool = True) -> OrderedDict[str, Any]:
"""Convert the message to a dictionary"""
obj = to_ordered(
@@ -142,20 +158,21 @@ class Message:
return obj
def verify(self, signature):
def verify(self, signature: str) -> bool:
"""Verify the signature of the message"""
return self.signature == signature
@property
def hash(self):
def hash(self) -> str:
"""The cryptographic hash of the message"""
hash_ = sha256(self.serialize()).digest()
return b64encode(hash_).decode("ascii") + ".sha256"
@property
def key(self):
def key(self) -> str:
"""The key of the message"""
return "%" + self.hash
@@ -165,25 +182,21 @@ class LocalMessage(Message):
"""Class representing a local message"""
def __init__( # pylint: disable=too-many-arguments,super-init-not-called
self, feed, content, signature=None, sequence=1, timestamp=None, previous=None
self,
feed: LocalFeed,
content: Dict[str, Any],
signature: Optional[str] = None,
sequence: int = 1,
timestamp: Optional[int] = None,
previous: Optional["LocalMessage"] = None,
):
self.feed = feed
self.content = content
super().__init__(feed, content, signature=signature, sequence=sequence, timestamp=timestamp, previous=previous)
self.previous = previous
if self.previous:
self.sequence = self.previous.sequence + 1
else:
self.sequence = sequence
self.timestamp = get_millis_1970() if timestamp is None else timestamp
if signature is None:
def _check_signature(self) -> None:
if self.signature is None:
self.signature = self._sign()
else:
self.signature = signature
def _sign(self):
def _sign(self) -> str:
# ensure ordering of keys and indentation of 2 characters, like ssb-keys
data = self.serialize(add_signature=False)
return (b64encode(bytes(self.feed.sign(data))) + b".sig.ed25519").decode("ascii")

View File

@@ -22,7 +22,16 @@
"""MuxRPC"""
from ssb.packet_stream import PSMessageType
from typing import Any, AsyncIterator, Callable, Dict, List, Literal, Optional, Union
from typing_extensions import Self
from .packet_stream import PacketStream, PSMessage, PSMessageType, PSRequestHandler, PSStreamHandler
MuxRPCJSON = Dict[str, Any]
MuxRPCCallType = Literal["async", "duplex", "sink", "source", "sync"]
MuxRPCRequestHandlerType = Callable[[PacketStream, "MuxRPCRequest"], None]
MuxRPCRequestParam = Union[bytes, str, MuxRPCJSON] # pylint: disable=invalid-name
class MuxRPCAPIException(Exception):
@@ -32,7 +41,7 @@ class MuxRPCAPIException(Exception):
class MuxRPCHandler: # pylint: disable=too-few-public-methods
"""Base MuxRPC handler class"""
def check_message(self, msg):
def check_message(self, msg: PSMessage) -> None:
"""Check message validity"""
body = msg.body
@@ -40,34 +49,53 @@ class MuxRPCHandler: # pylint: disable=too-few-public-methods
if isinstance(body, dict) and "name" in body and body["name"] == "Error":
raise MuxRPCAPIException(body["message"])
def __aiter__(self) -> AsyncIterator[Optional[PSMessage]]:
raise NotImplementedError()
class MuxRPCRequestHandler(MuxRPCHandler):
async def __anext__(self) -> Optional[PSMessage]:
raise NotImplementedError()
def send(self, msg: Any, msg_type: PSMessageType = PSMessageType.JSON, end: bool = False) -> None:
"""Send a message through the stream"""
raise NotImplementedError()
async def get_response(self) -> PSMessage:
"""Get the response for an RPC request"""
raise NotImplementedError()
class MuxRPCRequestHandler(MuxRPCHandler): # pylint: disable=abstract-method
"""MuxRPC handler for incoming RPC requests"""
def __init__(self, ps_handler):
def __init__(self, ps_handler: PSRequestHandler):
self.ps_handler = ps_handler
async def get_response(self):
async def get_response(self) -> PSMessage:
"""Get the response data"""
msg = await self.ps_handler
msg = await self.ps_handler.__anext__()
self.check_message(msg)
return msg
class MuxRPCSourceHandler(MuxRPCHandler):
class MuxRPCSourceHandler(MuxRPCHandler): # pylint: disable=abstract-method
"""MuxRPC handler for source-type RPC requests"""
def __init__(self, ps_handler):
def __init__(self, ps_handler: PSStreamHandler):
self.ps_handler = ps_handler
def __aiter__(self):
def __aiter__(self) -> AsyncIterator[Optional[PSMessage]]:
return self
async def __anext__(self):
async def __anext__(self) -> Optional[PSMessage]:
msg = await self.ps_handler.__anext__()
assert msg
self.check_message(msg)
return msg
@@ -76,64 +104,74 @@ class MuxRPCSourceHandler(MuxRPCHandler):
class MuxRPCSinkHandlerMixin: # pylint: disable=too-few-public-methods
"""Mixin for sink-type MuxRPC handlers"""
def send(self, msg, msg_type=PSMessageType.JSON, end=False):
connection: PacketStream
req: int
def send(self, msg: Any, msg_type: PSMessageType = PSMessageType.JSON, end: bool = False) -> None:
"""Send a message through the stream"""
self.connection.send(msg, stream=True, msg_type=msg_type, req=self.req, end_err=end)
class MuxRPCDuplexHandler(MuxRPCSinkHandlerMixin, MuxRPCSourceHandler):
class MuxRPCDuplexHandler(MuxRPCSinkHandlerMixin, MuxRPCSourceHandler): # pylint: disable=abstract-method
"""MuxRPC handler for duplex streams"""
def __init__(self, ps_handler, connection, req):
def __init__(self, ps_handler: PSStreamHandler, connection: PacketStream, req: int):
super().__init__(ps_handler)
self.connection = connection
self.req = req
class MuxRPCSinkHandler(MuxRPCHandler, MuxRPCSinkHandlerMixin):
class MuxRPCSinkHandler(MuxRPCHandler, MuxRPCSinkHandlerMixin): # pylint: disable=abstract-method
"""MuxRPC handler for sinks"""
def __init__(self, connection, req):
def __init__(self, connection: PacketStream, req: int):
self.connection = connection
self.req = req
def _get_appropriate_api_handler(type_, connection, ps_handler, req):
def _get_appropriate_api_handler(
type_: MuxRPCCallType, connection: PacketStream, ps_handler: Union[PSRequestHandler, PSStreamHandler], req: int
) -> MuxRPCHandler:
"""Find the appropriate MuxRPC handler"""
if type_ in {"sync", "async"}:
assert isinstance(ps_handler, PSRequestHandler)
return MuxRPCRequestHandler(ps_handler)
if type_ == "source":
assert isinstance(ps_handler, PSStreamHandler)
return MuxRPCSourceHandler(ps_handler)
if type_ == "sink":
return MuxRPCSinkHandler(connection, req)
if type_ == "duplex":
assert isinstance(ps_handler, PSStreamHandler)
return MuxRPCDuplexHandler(ps_handler, connection, req)
return None
raise TypeError(f"Unknown request type {type_}")
class MuxRPCRequest:
"""MuxRPC request"""
@classmethod
def from_message(cls, message):
def from_message(cls, message: PSMessage) -> Self:
"""Initialise a request from a raw packet stream message"""
body = message.body
assert isinstance(body, dict)
return cls(".".join(body["name"]), body["args"])
def __init__(self, name, args):
def __init__(self, name: str, args: List[MuxRPCRequestParam]):
self.name = name
self.args = args
def __repr__(self):
def __repr__(self) -> str:
return f"<MuxRPCRequest {self.name} {self.args}>"
@@ -141,28 +179,30 @@ class MuxRPCMessage:
"""MuxRPC message"""
@classmethod
def from_message(cls, message):
def from_message(cls, message: PSMessage) -> Self:
"""Initialise a MuxRPC message from a raw packet stream message"""
return cls(message.body)
def __init__(self, body):
def __init__(self, body: Union[bytes, str, Dict[str, Any]]):
self.body = body
def __repr__(self):
return f"<MuxRPCMessage {self.body}>"
def __repr__(self) -> str:
return f"<MuxRPCMessage {self.body!r}>"
class MuxRPCAPI:
"""Generic MuxRPC API"""
def __init__(self):
self.handlers = {}
self.connection = None
def __init__(self) -> None:
self.handlers: Dict[str, MuxRPCRequestHandlerType] = {}
self.connection: Optional[PacketStream] = None
async def process_messages(self):
async def process_messages(self) -> None:
"""Continuously process incoming messages"""
assert self.connection
async for req_message in self.connection:
if req_message is None:
return
@@ -172,22 +212,22 @@ class MuxRPCAPI:
if isinstance(body, dict) and body.get("name"):
self.process(self.connection, MuxRPCRequest.from_message(req_message))
def add_connection(self, connection):
def add_connection(self, connection: PacketStream) -> None:
"""Set the packet stream connection of this RPC API"""
self.connection = connection
def define(self, name):
def define(self, name: str) -> Callable[[MuxRPCRequestHandlerType], MuxRPCRequestHandlerType]:
"""Decorator to define an RPC method handler"""
def _handle(f):
def _handle(f: MuxRPCRequestHandlerType) -> MuxRPCRequestHandlerType:
self.handlers[name] = f
return f
return _handle
def process(self, connection, request):
def process(self, connection: PacketStream, request: MuxRPCRequest) -> None:
"""Process an incoming request"""
handler = self.handlers.get(request.name)
@@ -197,9 +237,11 @@ class MuxRPCAPI:
handler(connection, request)
def call(self, name, args, type_="sync"):
def call(self, name: str, args: List[MuxRPCRequestParam], type_: MuxRPCCallType = "sync") -> MuxRPCHandler:
"""Call an RPC method"""
assert self.connection
if not self.connection.is_connected:
raise Exception("not connected") # pylint: disable=broad-exception-raised

View File

@@ -28,9 +28,13 @@ import logging
from math import ceil
import struct
from time import time
from typing import Any, AsyncIterator, Dict, Optional, Tuple, Union
from secret_handshake.network import SHSDuplexStream
import simplejson
from typing_extensions import Self
PSMessageData = Union[bytes, bool, Dict[str, Any], str]
logger = logging.getLogger("packet_stream")
@@ -45,25 +49,27 @@ class PSMessageType(Enum):
class PSStreamHandler:
"""Packet stream handler"""
def __init__(self, req):
super(PSStreamHandler).__init__()
def __init__(self, req: int):
super().__init__()
self.req = req
self.queue = Queue()
self.queue: Queue["PSMessage"] = Queue()
async def process(self, msg):
async def process(self, msg: "PSMessage") -> None:
"""Process a pending message"""
await self.queue.put(msg)
async def stop(self):
async def stop(self) -> None:
"""Stop a pending request"""
await self.queue.put(None)
# We use the None value internally to signal __anext__ that the stream can be closed. It is not used otherwise,
# hence the typing ignore
await self.queue.put(None) # type: ignore[arg-type]
def __aiter__(self):
def __aiter__(self) -> AsyncIterator[Optional["PSMessage"]]:
return self
async def __anext__(self):
async def __anext__(self) -> Optional["PSMessage"]:
elem = await self.queue.get()
if not elem:
@@ -75,30 +81,32 @@ class PSStreamHandler:
class PSRequestHandler:
"""Packet stream request handler"""
def __init__(self, req):
def __init__(self, req: int):
self.req = req
self.event = Event()
self._msg = None
self._msg: Optional["PSMessage"] = None
async def process(self, msg):
async def process(self, msg: "PSMessage") -> None:
"""Process a message request"""
self._msg = msg
self.event.set()
async def stop(self):
async def stop(self) -> None:
"""Stop a pending event request"""
if not self.event.is_set():
self.event.set()
def __aiter__(self):
def __aiter__(self) -> AsyncIterator["PSMessage"]:
return self
async def __anext__(self):
async def __anext__(self) -> "PSMessage":
# wait until 'process' is called
await self.event.wait()
assert self._msg
return self._msg
@@ -106,42 +114,55 @@ class PSMessage:
"""Packet Stream message"""
@classmethod
def from_header_body(cls, flags, req, body):
def from_header_body(cls, flags: int, req: int, body: bytes) -> Self:
"""Parse a raw message"""
type_ = PSMessageType(flags & 0x03)
if type_ == PSMessageType.TEXT:
body = body.decode("utf-8")
decoded_body: Union[str, Dict[str, Any], bytes] = body.decode("utf-8")
elif type_ == PSMessageType.JSON:
body = simplejson.loads(body)
decoded_body = simplejson.loads(body)
else:
decoded_body = body
return cls(type_, body, bool(flags & 0x08), bool(flags & 0x04), req=req)
return cls(type_, decoded_body, bool(flags & 0x08), bool(flags & 0x04), req=req)
@property
def data(self):
def data(self) -> bytes:
"""The raw message data"""
if self.type == PSMessageType.TEXT:
assert isinstance(self.body, str)
return self.body.encode("utf-8")
if self.type == PSMessageType.JSON:
assert isinstance(self.body, dict)
return simplejson.dumps(self.body).encode("utf-8")
assert isinstance(self.body, bytes)
return self.body
def __init__(self, type_, body, stream, end_err, req=None): # pylint: disable=too-many-arguments
def __init__(
self,
type_: PSMessageType,
body: Union[bytes, str, Dict[str, Any]],
stream: bool,
end_err: bool,
req: Optional[int] = None,
): # pylint: disable=too-many-arguments
self.stream = stream
self.end_err = end_err
self.type = type_
self.body = body
self.req = req
def __repr__(self):
def __repr__(self) -> str:
if self.type == PSMessageType.BUFFER:
body = f"{len(self.body)} bytes"
else:
body = self.body
body = str(self.body)
req = "" if self.req is None else f" [{self.req}]"
is_stream = "~" if self.stream else ""
@@ -153,79 +174,90 @@ class PSMessage:
class PacketStream:
"""SSB Packet stream"""
def __init__(self, connection):
def __init__(self, connection: SHSDuplexStream):
self.connection = connection
self.req_counter = 1
self._event_map = {}
self._event_map: Dict[int, Tuple[float, Union[PSRequestHandler, PSStreamHandler]]] = {}
self._connected = False
def register_handler(self, handler):
def register_handler(self, handler: Union[PSRequestHandler, PSStreamHandler]) -> None:
"""Register an RPC handler"""
self._event_map[handler.req] = (time(), handler)
@property
def is_connected(self):
def is_connected(self) -> bool:
"""Check if the stream is connected"""
return self.connection.is_connected
def __aiter__(self):
def __aiter__(self) -> AsyncIterator[Optional[PSMessage]]:
return self
async def __anext__(self):
async def __anext__(self) -> PSMessage:
while True:
msg = await self.read()
if not msg:
raise StopAsyncIteration()
if msg.req >= 0:
if msg.req is not None and msg.req >= 0:
logger.info("RECV: %r", msg)
return msg
async def _read(self):
async def _read(self) -> Optional[PSMessage]:
try:
header = await self.connection.read()
if not header or header == b"\x00" * 9:
return
return None
flags, length, req = struct.unpack(">BIi", header)
n_packets = ceil(length / 4096)
body = b""
for _ in range(n_packets):
body += await self.connection.read()
read_data = await self.connection.read()
if not read_data:
logger.debug("DISCONNECT")
self.connection.disconnect()
return None
body += read_data
logger.debug("READ %s %s", header, len(body))
return PSMessage.from_header_body(flags, req, body)
except StopAsyncIteration:
logger.debug("DISCONNECT")
self.connection.disconnect()
return None
async def read(self):
async def read(self) -> Optional[PSMessage]:
"""Read data from the packet stream"""
msg = await self._read()
if not msg:
return None
# check whether it's a reply and handle accordingly
if msg.req < 0:
if msg.req is not None and msg.req < 0:
_, handler = self._event_map[-msg.req]
await handler.process(msg)
logger.info("RESPONSE [%d]: %r", -msg.req, msg)
if msg.end_err:
await handler.stop()
del self._event_map[-msg.req]
logger.info("RESPONSE [%d]: EOS", -msg.req)
return msg
def _write(self, msg):
def _write(self, msg: PSMessage) -> None:
logger.info("SEND [%d]: %r", msg.req, msg)
header = struct.pack(
">BIi",
@@ -239,11 +271,17 @@ class PacketStream:
logger.debug("WRITE DATA: %s", msg.data)
def send( # pylint: disable=too-many-arguments
self, data, msg_type=PSMessageType.JSON, stream=False, end_err=False, req=None
):
self,
data: Union[bytes, str, Dict[str, Any]],
msg_type: PSMessageType = PSMessageType.JSON,
stream: bool = False,
end_err: bool = False,
req: Optional[int] = None,
) -> Union[PSRequestHandler, PSStreamHandler]:
"""Send data through the packet stream"""
update_counter = False
if req is None:
update_counter = True
req = self.req_counter
@@ -254,16 +292,18 @@ class PacketStream:
self._write(msg)
if stream:
handler = PSStreamHandler(self.req_counter)
handler: Union[PSRequestHandler, PSStreamHandler] = PSStreamHandler(self.req_counter)
else:
handler = PSRequestHandler(self.req_counter)
self.register_handler(handler)
if update_counter:
self.req_counter += 1
return handler
def disconnect(self):
def disconnect(self) -> None:
"""Disconnect the stream"""
self._connected = False

0
ssb/py.typed Normal file
View File

View File

@@ -24,23 +24,30 @@
from base64 import b64decode, b64encode
import os
from typing import Optional
from typing import Optional, TypedDict
from nacl.signing import SigningKey
from nacl.signing import SigningKey, VerifyKey
import yaml
class SSBSecret(TypedDict):
"""Dictionary to hold an SSB identity"""
keypair: SigningKey
id: str
class ConfigException(Exception):
"""Exception to raise if there is a problem with the configuration data"""
def tag(key):
def tag(key: VerifyKey) -> bytes:
"""Create tag from public key"""
return b"@" + b64encode(bytes(key)) + b".ed25519"
def load_ssb_secret(filename: Optional[str] = None):
def load_ssb_secret(filename: Optional[str] = None) -> SSBSecret:
"""Load SSB keys from ``filename`` or, if unset, from ``~/.ssb/secret``"""
filename = filename or os.path.expanduser("~/.ssb/secret")