ci: Add and configure mypy, and make it happy
This commit is contained in:
@@ -26,8 +26,11 @@ from base64 import b64encode
|
||||
from collections import OrderedDict, namedtuple
|
||||
from datetime import datetime
|
||||
from hashlib import sha256
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from nacl.signing import SigningKey, VerifyKey
|
||||
from simplejson import dumps, loads
|
||||
from typing_extensions import Self
|
||||
|
||||
from ssb.util import tag
|
||||
|
||||
@@ -38,7 +41,7 @@ class NoPrivateKeyException(Exception):
|
||||
"""Exception to raise when a private key is not available"""
|
||||
|
||||
|
||||
def to_ordered(data):
|
||||
def to_ordered(data: Dict[str, Any]) -> OrderedDict[str, Any]:
|
||||
"""Convert a dictionary to an ``OrderedDict``"""
|
||||
|
||||
smsg = OrderedMsg(**data)
|
||||
@@ -46,7 +49,7 @@ def to_ordered(data):
|
||||
return OrderedDict((k, getattr(smsg, k)) for k in smsg._fields)
|
||||
|
||||
|
||||
def get_millis_1970():
|
||||
def get_millis_1970() -> int:
|
||||
"""Get the UNIX timestamp in milliseconds"""
|
||||
|
||||
return int(datetime.utcnow().timestamp() * 1000)
|
||||
@@ -55,16 +58,16 @@ def get_millis_1970():
|
||||
class Feed:
|
||||
"""Base class for feeds"""
|
||||
|
||||
def __init__(self, public_key):
|
||||
def __init__(self, public_key: VerifyKey) -> None:
|
||||
self.public_key = public_key
|
||||
|
||||
@property
|
||||
def id(self):
|
||||
def id(self) -> str:
|
||||
"""The identifier of the feed"""
|
||||
|
||||
return tag(self.public_key).decode("ascii")
|
||||
|
||||
def sign(self, msg):
|
||||
def sign(self, msg: bytes) -> bytes:
|
||||
"""Sign a message"""
|
||||
|
||||
raise NoPrivateKeyException("Cannot use remote identity to sign (no private key!)")
|
||||
@@ -73,16 +76,20 @@ class Feed:
|
||||
class LocalFeed(Feed):
|
||||
"""Class representing a local feed"""
|
||||
|
||||
def __init__(self, private_key): # pylint: disable=super-init-not-called
|
||||
def __init__(self, private_key: SigningKey) -> None: # pylint: disable=super-init-not-called
|
||||
self.private_key = private_key
|
||||
|
||||
@property
|
||||
def public_key(self):
|
||||
def public_key(self) -> VerifyKey:
|
||||
"""The public key of the feed"""
|
||||
|
||||
return self.private_key.verify_key
|
||||
|
||||
def sign(self, msg):
|
||||
@public_key.setter
|
||||
def public_key(self, key: VerifyKey) -> None:
|
||||
raise TypeError("Can not set only the public key for a local feed")
|
||||
|
||||
def sign(self, msg: bytes) -> bytes:
|
||||
"""Sign a message for this feed"""
|
||||
|
||||
return self.private_key.sign(msg).signature
|
||||
@@ -92,25 +99,34 @@ class Message:
|
||||
"""Base class for SSB messages"""
|
||||
|
||||
def __init__( # pylint: disable=too-many-arguments
|
||||
self, feed, content, signature=None, sequence=1, timestamp=None, previous=None
|
||||
self,
|
||||
feed: Feed,
|
||||
content: Dict[str, Any],
|
||||
signature: Optional[str] = None,
|
||||
sequence: int = 1,
|
||||
timestamp: Optional[int] = None,
|
||||
previous: Optional["Message"] = None,
|
||||
):
|
||||
self.feed = feed
|
||||
self.content = content
|
||||
|
||||
if signature is None:
|
||||
raise ValueError("signature can't be None")
|
||||
self.signature = signature
|
||||
|
||||
self.previous = previous
|
||||
self.timestamp = get_millis_1970() if timestamp is None else timestamp
|
||||
|
||||
if self.previous:
|
||||
self.sequence = self.previous.sequence + 1
|
||||
self.sequence: int = self.previous.sequence + 1
|
||||
else:
|
||||
self.sequence = sequence
|
||||
|
||||
self.timestamp = get_millis_1970() if timestamp is None else timestamp
|
||||
self._check_signature()
|
||||
|
||||
def _check_signature(self) -> None:
|
||||
if self.signature is None:
|
||||
raise ValueError("signature can't be None")
|
||||
|
||||
@classmethod
|
||||
def parse(cls, data, feed):
|
||||
def parse(cls, data: bytes, feed: Feed) -> Self:
|
||||
"""Parse raw message data"""
|
||||
|
||||
obj = loads(data, object_pairs_hook=OrderedDict)
|
||||
@@ -118,12 +134,12 @@ class Message:
|
||||
|
||||
return msg
|
||||
|
||||
def serialize(self, add_signature=True):
|
||||
def serialize(self, add_signature: bool = True) -> bytes:
|
||||
"""Serialize the message"""
|
||||
|
||||
return dumps(self.to_dict(add_signature=add_signature), indent=2).encode("utf-8")
|
||||
|
||||
def to_dict(self, add_signature=True):
|
||||
def to_dict(self, add_signature: bool = True) -> OrderedDict[str, Any]:
|
||||
"""Convert the message to a dictionary"""
|
||||
|
||||
obj = to_ordered(
|
||||
@@ -142,20 +158,21 @@ class Message:
|
||||
|
||||
return obj
|
||||
|
||||
def verify(self, signature):
|
||||
def verify(self, signature: str) -> bool:
|
||||
"""Verify the signature of the message"""
|
||||
|
||||
return self.signature == signature
|
||||
|
||||
@property
|
||||
def hash(self):
|
||||
def hash(self) -> str:
|
||||
"""The cryptographic hash of the message"""
|
||||
|
||||
hash_ = sha256(self.serialize()).digest()
|
||||
|
||||
return b64encode(hash_).decode("ascii") + ".sha256"
|
||||
|
||||
@property
|
||||
def key(self):
|
||||
def key(self) -> str:
|
||||
"""The key of the message"""
|
||||
|
||||
return "%" + self.hash
|
||||
@@ -165,25 +182,21 @@ class LocalMessage(Message):
|
||||
"""Class representing a local message"""
|
||||
|
||||
def __init__( # pylint: disable=too-many-arguments,super-init-not-called
|
||||
self, feed, content, signature=None, sequence=1, timestamp=None, previous=None
|
||||
self,
|
||||
feed: LocalFeed,
|
||||
content: Dict[str, Any],
|
||||
signature: Optional[str] = None,
|
||||
sequence: int = 1,
|
||||
timestamp: Optional[int] = None,
|
||||
previous: Optional["LocalMessage"] = None,
|
||||
):
|
||||
self.feed = feed
|
||||
self.content = content
|
||||
super().__init__(feed, content, signature=signature, sequence=sequence, timestamp=timestamp, previous=previous)
|
||||
|
||||
self.previous = previous
|
||||
if self.previous:
|
||||
self.sequence = self.previous.sequence + 1
|
||||
else:
|
||||
self.sequence = sequence
|
||||
|
||||
self.timestamp = get_millis_1970() if timestamp is None else timestamp
|
||||
|
||||
if signature is None:
|
||||
def _check_signature(self) -> None:
|
||||
if self.signature is None:
|
||||
self.signature = self._sign()
|
||||
else:
|
||||
self.signature = signature
|
||||
|
||||
def _sign(self):
|
||||
def _sign(self) -> str:
|
||||
# ensure ordering of keys and indentation of 2 characters, like ssb-keys
|
||||
data = self.serialize(add_signature=False)
|
||||
return (b64encode(bytes(self.feed.sign(data))) + b".sig.ed25519").decode("ascii")
|
||||
|
108
ssb/muxrpc.py
108
ssb/muxrpc.py
@@ -22,7 +22,16 @@
|
||||
|
||||
"""MuxRPC"""
|
||||
|
||||
from ssb.packet_stream import PSMessageType
|
||||
from typing import Any, AsyncIterator, Callable, Dict, List, Literal, Optional, Union
|
||||
|
||||
from typing_extensions import Self
|
||||
|
||||
from .packet_stream import PacketStream, PSMessage, PSMessageType, PSRequestHandler, PSStreamHandler
|
||||
|
||||
MuxRPCJSON = Dict[str, Any]
|
||||
MuxRPCCallType = Literal["async", "duplex", "sink", "source", "sync"]
|
||||
MuxRPCRequestHandlerType = Callable[[PacketStream, "MuxRPCRequest"], None]
|
||||
MuxRPCRequestParam = Union[bytes, str, MuxRPCJSON] # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class MuxRPCAPIException(Exception):
|
||||
@@ -32,7 +41,7 @@ class MuxRPCAPIException(Exception):
|
||||
class MuxRPCHandler: # pylint: disable=too-few-public-methods
|
||||
"""Base MuxRPC handler class"""
|
||||
|
||||
def check_message(self, msg):
|
||||
def check_message(self, msg: PSMessage) -> None:
|
||||
"""Check message validity"""
|
||||
|
||||
body = msg.body
|
||||
@@ -40,34 +49,53 @@ class MuxRPCHandler: # pylint: disable=too-few-public-methods
|
||||
if isinstance(body, dict) and "name" in body and body["name"] == "Error":
|
||||
raise MuxRPCAPIException(body["message"])
|
||||
|
||||
def __aiter__(self) -> AsyncIterator[Optional[PSMessage]]:
|
||||
raise NotImplementedError()
|
||||
|
||||
class MuxRPCRequestHandler(MuxRPCHandler):
|
||||
async def __anext__(self) -> Optional[PSMessage]:
|
||||
raise NotImplementedError()
|
||||
|
||||
def send(self, msg: Any, msg_type: PSMessageType = PSMessageType.JSON, end: bool = False) -> None:
|
||||
"""Send a message through the stream"""
|
||||
|
||||
raise NotImplementedError()
|
||||
|
||||
async def get_response(self) -> PSMessage:
|
||||
"""Get the response for an RPC request"""
|
||||
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class MuxRPCRequestHandler(MuxRPCHandler): # pylint: disable=abstract-method
|
||||
"""MuxRPC handler for incoming RPC requests"""
|
||||
|
||||
def __init__(self, ps_handler):
|
||||
def __init__(self, ps_handler: PSRequestHandler):
|
||||
self.ps_handler = ps_handler
|
||||
|
||||
async def get_response(self):
|
||||
async def get_response(self) -> PSMessage:
|
||||
"""Get the response data"""
|
||||
|
||||
msg = await self.ps_handler
|
||||
msg = await self.ps_handler.__anext__()
|
||||
|
||||
self.check_message(msg)
|
||||
|
||||
return msg
|
||||
|
||||
|
||||
class MuxRPCSourceHandler(MuxRPCHandler):
|
||||
class MuxRPCSourceHandler(MuxRPCHandler): # pylint: disable=abstract-method
|
||||
"""MuxRPC handler for source-type RPC requests"""
|
||||
|
||||
def __init__(self, ps_handler):
|
||||
def __init__(self, ps_handler: PSStreamHandler):
|
||||
self.ps_handler = ps_handler
|
||||
|
||||
def __aiter__(self):
|
||||
def __aiter__(self) -> AsyncIterator[Optional[PSMessage]]:
|
||||
return self
|
||||
|
||||
async def __anext__(self):
|
||||
async def __anext__(self) -> Optional[PSMessage]:
|
||||
msg = await self.ps_handler.__anext__()
|
||||
|
||||
assert msg
|
||||
|
||||
self.check_message(msg)
|
||||
|
||||
return msg
|
||||
@@ -76,64 +104,74 @@ class MuxRPCSourceHandler(MuxRPCHandler):
|
||||
class MuxRPCSinkHandlerMixin: # pylint: disable=too-few-public-methods
|
||||
"""Mixin for sink-type MuxRPC handlers"""
|
||||
|
||||
def send(self, msg, msg_type=PSMessageType.JSON, end=False):
|
||||
connection: PacketStream
|
||||
req: int
|
||||
|
||||
def send(self, msg: Any, msg_type: PSMessageType = PSMessageType.JSON, end: bool = False) -> None:
|
||||
"""Send a message through the stream"""
|
||||
|
||||
self.connection.send(msg, stream=True, msg_type=msg_type, req=self.req, end_err=end)
|
||||
|
||||
|
||||
class MuxRPCDuplexHandler(MuxRPCSinkHandlerMixin, MuxRPCSourceHandler):
|
||||
class MuxRPCDuplexHandler(MuxRPCSinkHandlerMixin, MuxRPCSourceHandler): # pylint: disable=abstract-method
|
||||
"""MuxRPC handler for duplex streams"""
|
||||
|
||||
def __init__(self, ps_handler, connection, req):
|
||||
def __init__(self, ps_handler: PSStreamHandler, connection: PacketStream, req: int):
|
||||
super().__init__(ps_handler)
|
||||
|
||||
self.connection = connection
|
||||
self.req = req
|
||||
|
||||
|
||||
class MuxRPCSinkHandler(MuxRPCHandler, MuxRPCSinkHandlerMixin):
|
||||
class MuxRPCSinkHandler(MuxRPCHandler, MuxRPCSinkHandlerMixin): # pylint: disable=abstract-method
|
||||
"""MuxRPC handler for sinks"""
|
||||
|
||||
def __init__(self, connection, req):
|
||||
def __init__(self, connection: PacketStream, req: int):
|
||||
self.connection = connection
|
||||
self.req = req
|
||||
|
||||
|
||||
def _get_appropriate_api_handler(type_, connection, ps_handler, req):
|
||||
def _get_appropriate_api_handler(
|
||||
type_: MuxRPCCallType, connection: PacketStream, ps_handler: Union[PSRequestHandler, PSStreamHandler], req: int
|
||||
) -> MuxRPCHandler:
|
||||
"""Find the appropriate MuxRPC handler"""
|
||||
|
||||
if type_ in {"sync", "async"}:
|
||||
assert isinstance(ps_handler, PSRequestHandler)
|
||||
return MuxRPCRequestHandler(ps_handler)
|
||||
|
||||
if type_ == "source":
|
||||
assert isinstance(ps_handler, PSStreamHandler)
|
||||
return MuxRPCSourceHandler(ps_handler)
|
||||
|
||||
if type_ == "sink":
|
||||
return MuxRPCSinkHandler(connection, req)
|
||||
|
||||
if type_ == "duplex":
|
||||
assert isinstance(ps_handler, PSStreamHandler)
|
||||
return MuxRPCDuplexHandler(ps_handler, connection, req)
|
||||
|
||||
return None
|
||||
raise TypeError(f"Unknown request type {type_}")
|
||||
|
||||
|
||||
class MuxRPCRequest:
|
||||
"""MuxRPC request"""
|
||||
|
||||
@classmethod
|
||||
def from_message(cls, message):
|
||||
def from_message(cls, message: PSMessage) -> Self:
|
||||
"""Initialise a request from a raw packet stream message"""
|
||||
|
||||
body = message.body
|
||||
|
||||
assert isinstance(body, dict)
|
||||
|
||||
return cls(".".join(body["name"]), body["args"])
|
||||
|
||||
def __init__(self, name, args):
|
||||
def __init__(self, name: str, args: List[MuxRPCRequestParam]):
|
||||
self.name = name
|
||||
self.args = args
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return f"<MuxRPCRequest {self.name} {self.args}>"
|
||||
|
||||
|
||||
@@ -141,28 +179,30 @@ class MuxRPCMessage:
|
||||
"""MuxRPC message"""
|
||||
|
||||
@classmethod
|
||||
def from_message(cls, message):
|
||||
def from_message(cls, message: PSMessage) -> Self:
|
||||
"""Initialise a MuxRPC message from a raw packet stream message"""
|
||||
|
||||
return cls(message.body)
|
||||
|
||||
def __init__(self, body):
|
||||
def __init__(self, body: Union[bytes, str, Dict[str, Any]]):
|
||||
self.body = body
|
||||
|
||||
def __repr__(self):
|
||||
return f"<MuxRPCMessage {self.body}>"
|
||||
def __repr__(self) -> str:
|
||||
return f"<MuxRPCMessage {self.body!r}>"
|
||||
|
||||
|
||||
class MuxRPCAPI:
|
||||
"""Generic MuxRPC API"""
|
||||
|
||||
def __init__(self):
|
||||
self.handlers = {}
|
||||
self.connection = None
|
||||
def __init__(self) -> None:
|
||||
self.handlers: Dict[str, MuxRPCRequestHandlerType] = {}
|
||||
self.connection: Optional[PacketStream] = None
|
||||
|
||||
async def process_messages(self):
|
||||
async def process_messages(self) -> None:
|
||||
"""Continuously process incoming messages"""
|
||||
|
||||
assert self.connection
|
||||
|
||||
async for req_message in self.connection:
|
||||
if req_message is None:
|
||||
return
|
||||
@@ -172,22 +212,22 @@ class MuxRPCAPI:
|
||||
if isinstance(body, dict) and body.get("name"):
|
||||
self.process(self.connection, MuxRPCRequest.from_message(req_message))
|
||||
|
||||
def add_connection(self, connection):
|
||||
def add_connection(self, connection: PacketStream) -> None:
|
||||
"""Set the packet stream connection of this RPC API"""
|
||||
|
||||
self.connection = connection
|
||||
|
||||
def define(self, name):
|
||||
def define(self, name: str) -> Callable[[MuxRPCRequestHandlerType], MuxRPCRequestHandlerType]:
|
||||
"""Decorator to define an RPC method handler"""
|
||||
|
||||
def _handle(f):
|
||||
def _handle(f: MuxRPCRequestHandlerType) -> MuxRPCRequestHandlerType:
|
||||
self.handlers[name] = f
|
||||
|
||||
return f
|
||||
|
||||
return _handle
|
||||
|
||||
def process(self, connection, request):
|
||||
def process(self, connection: PacketStream, request: MuxRPCRequest) -> None:
|
||||
"""Process an incoming request"""
|
||||
|
||||
handler = self.handlers.get(request.name)
|
||||
@@ -197,9 +237,11 @@ class MuxRPCAPI:
|
||||
|
||||
handler(connection, request)
|
||||
|
||||
def call(self, name, args, type_="sync"):
|
||||
def call(self, name: str, args: List[MuxRPCRequestParam], type_: MuxRPCCallType = "sync") -> MuxRPCHandler:
|
||||
"""Call an RPC method"""
|
||||
|
||||
assert self.connection
|
||||
|
||||
if not self.connection.is_connected:
|
||||
raise Exception("not connected") # pylint: disable=broad-exception-raised
|
||||
|
||||
|
@@ -28,9 +28,13 @@ import logging
|
||||
from math import ceil
|
||||
import struct
|
||||
from time import time
|
||||
from typing import Any, AsyncIterator, Dict, Optional, Tuple, Union
|
||||
|
||||
from secret_handshake.network import SHSDuplexStream
|
||||
import simplejson
|
||||
from typing_extensions import Self
|
||||
|
||||
PSMessageData = Union[bytes, bool, Dict[str, Any], str]
|
||||
logger = logging.getLogger("packet_stream")
|
||||
|
||||
|
||||
@@ -45,25 +49,27 @@ class PSMessageType(Enum):
|
||||
class PSStreamHandler:
|
||||
"""Packet stream handler"""
|
||||
|
||||
def __init__(self, req):
|
||||
super(PSStreamHandler).__init__()
|
||||
def __init__(self, req: int):
|
||||
super().__init__()
|
||||
self.req = req
|
||||
self.queue = Queue()
|
||||
self.queue: Queue["PSMessage"] = Queue()
|
||||
|
||||
async def process(self, msg):
|
||||
async def process(self, msg: "PSMessage") -> None:
|
||||
"""Process a pending message"""
|
||||
|
||||
await self.queue.put(msg)
|
||||
|
||||
async def stop(self):
|
||||
async def stop(self) -> None:
|
||||
"""Stop a pending request"""
|
||||
|
||||
await self.queue.put(None)
|
||||
# We use the None value internally to signal __anext__ that the stream can be closed. It is not used otherwise,
|
||||
# hence the typing ignore
|
||||
await self.queue.put(None) # type: ignore[arg-type]
|
||||
|
||||
def __aiter__(self):
|
||||
def __aiter__(self) -> AsyncIterator[Optional["PSMessage"]]:
|
||||
return self
|
||||
|
||||
async def __anext__(self):
|
||||
async def __anext__(self) -> Optional["PSMessage"]:
|
||||
elem = await self.queue.get()
|
||||
|
||||
if not elem:
|
||||
@@ -75,30 +81,32 @@ class PSStreamHandler:
|
||||
class PSRequestHandler:
|
||||
"""Packet stream request handler"""
|
||||
|
||||
def __init__(self, req):
|
||||
def __init__(self, req: int):
|
||||
self.req = req
|
||||
self.event = Event()
|
||||
self._msg = None
|
||||
self._msg: Optional["PSMessage"] = None
|
||||
|
||||
async def process(self, msg):
|
||||
async def process(self, msg: "PSMessage") -> None:
|
||||
"""Process a message request"""
|
||||
|
||||
self._msg = msg
|
||||
self.event.set()
|
||||
|
||||
async def stop(self):
|
||||
async def stop(self) -> None:
|
||||
"""Stop a pending event request"""
|
||||
|
||||
if not self.event.is_set():
|
||||
self.event.set()
|
||||
|
||||
def __aiter__(self):
|
||||
def __aiter__(self) -> AsyncIterator["PSMessage"]:
|
||||
return self
|
||||
|
||||
async def __anext__(self):
|
||||
async def __anext__(self) -> "PSMessage":
|
||||
# wait until 'process' is called
|
||||
await self.event.wait()
|
||||
|
||||
assert self._msg
|
||||
|
||||
return self._msg
|
||||
|
||||
|
||||
@@ -106,42 +114,55 @@ class PSMessage:
|
||||
"""Packet Stream message"""
|
||||
|
||||
@classmethod
|
||||
def from_header_body(cls, flags, req, body):
|
||||
def from_header_body(cls, flags: int, req: int, body: bytes) -> Self:
|
||||
"""Parse a raw message"""
|
||||
|
||||
type_ = PSMessageType(flags & 0x03)
|
||||
|
||||
if type_ == PSMessageType.TEXT:
|
||||
body = body.decode("utf-8")
|
||||
decoded_body: Union[str, Dict[str, Any], bytes] = body.decode("utf-8")
|
||||
elif type_ == PSMessageType.JSON:
|
||||
body = simplejson.loads(body)
|
||||
decoded_body = simplejson.loads(body)
|
||||
else:
|
||||
decoded_body = body
|
||||
|
||||
return cls(type_, body, bool(flags & 0x08), bool(flags & 0x04), req=req)
|
||||
return cls(type_, decoded_body, bool(flags & 0x08), bool(flags & 0x04), req=req)
|
||||
|
||||
@property
|
||||
def data(self):
|
||||
def data(self) -> bytes:
|
||||
"""The raw message data"""
|
||||
|
||||
if self.type == PSMessageType.TEXT:
|
||||
assert isinstance(self.body, str)
|
||||
return self.body.encode("utf-8")
|
||||
|
||||
if self.type == PSMessageType.JSON:
|
||||
assert isinstance(self.body, dict)
|
||||
return simplejson.dumps(self.body).encode("utf-8")
|
||||
|
||||
assert isinstance(self.body, bytes)
|
||||
|
||||
return self.body
|
||||
|
||||
def __init__(self, type_, body, stream, end_err, req=None): # pylint: disable=too-many-arguments
|
||||
def __init__(
|
||||
self,
|
||||
type_: PSMessageType,
|
||||
body: Union[bytes, str, Dict[str, Any]],
|
||||
stream: bool,
|
||||
end_err: bool,
|
||||
req: Optional[int] = None,
|
||||
): # pylint: disable=too-many-arguments
|
||||
self.stream = stream
|
||||
self.end_err = end_err
|
||||
self.type = type_
|
||||
self.body = body
|
||||
self.req = req
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
if self.type == PSMessageType.BUFFER:
|
||||
body = f"{len(self.body)} bytes"
|
||||
else:
|
||||
body = self.body
|
||||
body = str(self.body)
|
||||
|
||||
req = "" if self.req is None else f" [{self.req}]"
|
||||
is_stream = "~" if self.stream else ""
|
||||
@@ -153,79 +174,90 @@ class PSMessage:
|
||||
class PacketStream:
|
||||
"""SSB Packet stream"""
|
||||
|
||||
def __init__(self, connection):
|
||||
def __init__(self, connection: SHSDuplexStream):
|
||||
self.connection = connection
|
||||
self.req_counter = 1
|
||||
self._event_map = {}
|
||||
self._event_map: Dict[int, Tuple[float, Union[PSRequestHandler, PSStreamHandler]]] = {}
|
||||
self._connected = False
|
||||
|
||||
def register_handler(self, handler):
|
||||
def register_handler(self, handler: Union[PSRequestHandler, PSStreamHandler]) -> None:
|
||||
"""Register an RPC handler"""
|
||||
|
||||
self._event_map[handler.req] = (time(), handler)
|
||||
|
||||
@property
|
||||
def is_connected(self):
|
||||
def is_connected(self) -> bool:
|
||||
"""Check if the stream is connected"""
|
||||
|
||||
return self.connection.is_connected
|
||||
|
||||
def __aiter__(self):
|
||||
def __aiter__(self) -> AsyncIterator[Optional[PSMessage]]:
|
||||
return self
|
||||
|
||||
async def __anext__(self):
|
||||
async def __anext__(self) -> PSMessage:
|
||||
while True:
|
||||
msg = await self.read()
|
||||
|
||||
if not msg:
|
||||
raise StopAsyncIteration()
|
||||
|
||||
if msg.req >= 0:
|
||||
if msg.req is not None and msg.req >= 0:
|
||||
logger.info("RECV: %r", msg)
|
||||
|
||||
return msg
|
||||
|
||||
async def _read(self):
|
||||
async def _read(self) -> Optional[PSMessage]:
|
||||
try:
|
||||
header = await self.connection.read()
|
||||
|
||||
if not header or header == b"\x00" * 9:
|
||||
return
|
||||
return None
|
||||
|
||||
flags, length, req = struct.unpack(">BIi", header)
|
||||
|
||||
n_packets = ceil(length / 4096)
|
||||
|
||||
body = b""
|
||||
|
||||
for _ in range(n_packets):
|
||||
body += await self.connection.read()
|
||||
read_data = await self.connection.read()
|
||||
|
||||
if not read_data:
|
||||
logger.debug("DISCONNECT")
|
||||
self.connection.disconnect()
|
||||
|
||||
return None
|
||||
|
||||
body += read_data
|
||||
|
||||
logger.debug("READ %s %s", header, len(body))
|
||||
|
||||
return PSMessage.from_header_body(flags, req, body)
|
||||
except StopAsyncIteration:
|
||||
logger.debug("DISCONNECT")
|
||||
self.connection.disconnect()
|
||||
return None
|
||||
|
||||
async def read(self):
|
||||
async def read(self) -> Optional[PSMessage]:
|
||||
"""Read data from the packet stream"""
|
||||
|
||||
msg = await self._read()
|
||||
|
||||
if not msg:
|
||||
return None
|
||||
|
||||
# check whether it's a reply and handle accordingly
|
||||
if msg.req < 0:
|
||||
if msg.req is not None and msg.req < 0:
|
||||
_, handler = self._event_map[-msg.req]
|
||||
await handler.process(msg)
|
||||
logger.info("RESPONSE [%d]: %r", -msg.req, msg)
|
||||
|
||||
if msg.end_err:
|
||||
await handler.stop()
|
||||
del self._event_map[-msg.req]
|
||||
logger.info("RESPONSE [%d]: EOS", -msg.req)
|
||||
|
||||
return msg
|
||||
|
||||
def _write(self, msg):
|
||||
def _write(self, msg: PSMessage) -> None:
|
||||
logger.info("SEND [%d]: %r", msg.req, msg)
|
||||
header = struct.pack(
|
||||
">BIi",
|
||||
@@ -239,11 +271,17 @@ class PacketStream:
|
||||
logger.debug("WRITE DATA: %s", msg.data)
|
||||
|
||||
def send( # pylint: disable=too-many-arguments
|
||||
self, data, msg_type=PSMessageType.JSON, stream=False, end_err=False, req=None
|
||||
):
|
||||
self,
|
||||
data: Union[bytes, str, Dict[str, Any]],
|
||||
msg_type: PSMessageType = PSMessageType.JSON,
|
||||
stream: bool = False,
|
||||
end_err: bool = False,
|
||||
req: Optional[int] = None,
|
||||
) -> Union[PSRequestHandler, PSStreamHandler]:
|
||||
"""Send data through the packet stream"""
|
||||
|
||||
update_counter = False
|
||||
|
||||
if req is None:
|
||||
update_counter = True
|
||||
req = self.req_counter
|
||||
@@ -254,16 +292,18 @@ class PacketStream:
|
||||
self._write(msg)
|
||||
|
||||
if stream:
|
||||
handler = PSStreamHandler(self.req_counter)
|
||||
handler: Union[PSRequestHandler, PSStreamHandler] = PSStreamHandler(self.req_counter)
|
||||
else:
|
||||
handler = PSRequestHandler(self.req_counter)
|
||||
|
||||
self.register_handler(handler)
|
||||
|
||||
if update_counter:
|
||||
self.req_counter += 1
|
||||
|
||||
return handler
|
||||
|
||||
def disconnect(self):
|
||||
def disconnect(self) -> None:
|
||||
"""Disconnect the stream"""
|
||||
|
||||
self._connected = False
|
||||
|
0
ssb/py.typed
Normal file
0
ssb/py.typed
Normal file
15
ssb/util.py
15
ssb/util.py
@@ -24,23 +24,30 @@
|
||||
|
||||
from base64 import b64decode, b64encode
|
||||
import os
|
||||
from typing import Optional
|
||||
from typing import Optional, TypedDict
|
||||
|
||||
from nacl.signing import SigningKey
|
||||
from nacl.signing import SigningKey, VerifyKey
|
||||
import yaml
|
||||
|
||||
|
||||
class SSBSecret(TypedDict):
|
||||
"""Dictionary to hold an SSB identity"""
|
||||
|
||||
keypair: SigningKey
|
||||
id: str
|
||||
|
||||
|
||||
class ConfigException(Exception):
|
||||
"""Exception to raise if there is a problem with the configuration data"""
|
||||
|
||||
|
||||
def tag(key):
|
||||
def tag(key: VerifyKey) -> bytes:
|
||||
"""Create tag from public key"""
|
||||
|
||||
return b"@" + b64encode(bytes(key)) + b".ed25519"
|
||||
|
||||
|
||||
def load_ssb_secret(filename: Optional[str] = None):
|
||||
def load_ssb_secret(filename: Optional[str] = None) -> SSBSecret:
|
||||
"""Load SSB keys from ``filename`` or, if unset, from ``~/.ssb/secret``"""
|
||||
|
||||
filename = filename or os.path.expanduser("~/.ssb/secret")
|
||||
|
Reference in New Issue
Block a user