ci: Add and configure mypy, and make it happy
This commit is contained in:
parent
5aa4f16a5a
commit
f6e58b7682
24
poetry.lock
generated
24
poetry.lock
generated
@ -1258,6 +1258,28 @@ files = [
|
|||||||
{file = "tomlkit-0.12.1.tar.gz", hash = "sha256:38e1ff8edb991273ec9f6181244a6a391ac30e9f5098e7535640ea6be97a7c86"},
|
{file = "tomlkit-0.12.1.tar.gz", hash = "sha256:38e1ff8edb991273ec9f6181244a6a391ac30e9f5098e7535640ea6be97a7c86"},
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "types-pyyaml"
|
||||||
|
version = "6.0.12.12"
|
||||||
|
description = "Typing stubs for PyYAML"
|
||||||
|
optional = false
|
||||||
|
python-versions = "*"
|
||||||
|
files = [
|
||||||
|
{file = "types-PyYAML-6.0.12.12.tar.gz", hash = "sha256:334373d392fde0fdf95af5c3f1661885fa10c52167b14593eb856289e1855062"},
|
||||||
|
{file = "types_PyYAML-6.0.12.12-py3-none-any.whl", hash = "sha256:c05bc6c158facb0676674b7f11fe3960db4f389718e19e62bd2b84d6205cfd24"},
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "types-simplejson"
|
||||||
|
version = "3.19.0.2"
|
||||||
|
description = "Typing stubs for simplejson"
|
||||||
|
optional = false
|
||||||
|
python-versions = "*"
|
||||||
|
files = [
|
||||||
|
{file = "types-simplejson-3.19.0.2.tar.gz", hash = "sha256:ebc81f886f89d99d6b80c726518aa2228bc77c26438f18fd81455e4f79f8ee1b"},
|
||||||
|
{file = "types_simplejson-3.19.0.2-py3-none-any.whl", hash = "sha256:8ba093dc7884f59b3e62aed217144085e675a269debc32678fd80e0b43b2b86f"},
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "typing-extensions"
|
name = "typing-extensions"
|
||||||
version = "4.8.0"
|
version = "4.8.0"
|
||||||
@ -1315,4 +1337,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p
|
|||||||
[metadata]
|
[metadata]
|
||||||
lock-version = "2.0"
|
lock-version = "2.0"
|
||||||
python-versions = "^3.9"
|
python-versions = "^3.9"
|
||||||
content-hash = "b503e50c4ab977c6785c68bc1e5bf2efb7ab88a1fd33770e84a7f612a85d2641"
|
content-hash = "d80cbfdf7923c50c95505a84d8ad75eae016ca81ae32a8b22d074569b0a0fcbd"
|
||||||
|
@ -27,6 +27,8 @@ commitizen = "^3.12.0"
|
|||||||
black = "^23.10.1"
|
black = "^23.10.1"
|
||||||
pylint = "^3.0.2"
|
pylint = "^3.0.2"
|
||||||
mypy = "^1.6.1"
|
mypy = "^1.6.1"
|
||||||
|
types-pyyaml = "^6.0.12.12"
|
||||||
|
types-simplejson = "^3.19.0.2"
|
||||||
|
|
||||||
[tool.poetry.group.docs.dependencies]
|
[tool.poetry.group.docs.dependencies]
|
||||||
Sphinx = "^2.1.1"
|
Sphinx = "^2.1.1"
|
||||||
|
@ -4,8 +4,11 @@ from base64 import b64encode
|
|||||||
from collections import OrderedDict, namedtuple
|
from collections import OrderedDict, namedtuple
|
||||||
import datetime
|
import datetime
|
||||||
from hashlib import sha256
|
from hashlib import sha256
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
from nacl.signing import SigningKey, VerifyKey
|
||||||
from simplejson import dumps, loads
|
from simplejson import dumps, loads
|
||||||
|
from typing_extensions import Self
|
||||||
|
|
||||||
from ssb.util import tag
|
from ssb.util import tag
|
||||||
|
|
||||||
@ -16,7 +19,7 @@ class NoPrivateKeyException(Exception):
|
|||||||
"""Exception to raise when a private key is not available"""
|
"""Exception to raise when a private key is not available"""
|
||||||
|
|
||||||
|
|
||||||
def to_ordered(data):
|
def to_ordered(data: Dict[str, Any]) -> OrderedDict[str, Any]:
|
||||||
"""Convert a dictionary to an ``OrderedDict``"""
|
"""Convert a dictionary to an ``OrderedDict``"""
|
||||||
|
|
||||||
smsg = OrderedMsg(**data)
|
smsg = OrderedMsg(**data)
|
||||||
@ -24,7 +27,7 @@ def to_ordered(data):
|
|||||||
return OrderedDict((k, getattr(smsg, k)) for k in smsg._fields)
|
return OrderedDict((k, getattr(smsg, k)) for k in smsg._fields)
|
||||||
|
|
||||||
|
|
||||||
def get_millis_1970():
|
def get_millis_1970() -> int:
|
||||||
"""Get the UNIX timestamp in milliseconds"""
|
"""Get the UNIX timestamp in milliseconds"""
|
||||||
|
|
||||||
return int(datetime.datetime.utcnow().timestamp() * 1000)
|
return int(datetime.datetime.utcnow().timestamp() * 1000)
|
||||||
@ -33,16 +36,16 @@ def get_millis_1970():
|
|||||||
class Feed:
|
class Feed:
|
||||||
"""Base class for feeds"""
|
"""Base class for feeds"""
|
||||||
|
|
||||||
def __init__(self, public_key):
|
def __init__(self, public_key: VerifyKey):
|
||||||
self.public_key = public_key
|
self.public_key = public_key
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def id(self):
|
def id(self) -> str:
|
||||||
"""The identifier of the feed"""
|
"""The identifier of the feed"""
|
||||||
|
|
||||||
return tag(self.public_key).decode("ascii")
|
return tag(self.public_key).decode("ascii")
|
||||||
|
|
||||||
def sign(self, msg):
|
def sign(self, msg: "Message") -> bytes:
|
||||||
"""Sign a message"""
|
"""Sign a message"""
|
||||||
|
|
||||||
raise NoPrivateKeyException("Cannot use remote identity to sign (no private key!)")
|
raise NoPrivateKeyException("Cannot use remote identity to sign (no private key!)")
|
||||||
@ -51,16 +54,20 @@ class Feed:
|
|||||||
class LocalFeed(Feed):
|
class LocalFeed(Feed):
|
||||||
"""Class representing a local feed"""
|
"""Class representing a local feed"""
|
||||||
|
|
||||||
def __init__(self, private_key): # pylint: disable=super-init-not-called
|
def __init__(self, private_key: SigningKey): # pylint: disable=super-init-not-called
|
||||||
self.private_key = private_key
|
self.private_key: SigningKey = private_key
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def public_key(self):
|
def public_key(self) -> VerifyKey:
|
||||||
"""The public key of the feed"""
|
"""The public key of the feed"""
|
||||||
|
|
||||||
return self.private_key.verify_key
|
return self.private_key.verify_key
|
||||||
|
|
||||||
def sign(self, msg):
|
@public_key.setter
|
||||||
|
def public_key(self, _: VerifyKey) -> None:
|
||||||
|
raise TypeError("Cannot set just the public key of a local feed")
|
||||||
|
|
||||||
|
def sign(self, msg: "Message") -> bytes:
|
||||||
"""Sign a message for this feed"""
|
"""Sign a message for this feed"""
|
||||||
|
|
||||||
return self.private_key.sign(msg).signature
|
return self.private_key.sign(msg).signature
|
||||||
@ -70,7 +77,13 @@ class Message:
|
|||||||
"""Base class for SSB messages"""
|
"""Base class for SSB messages"""
|
||||||
|
|
||||||
def __init__( # pylint: disable=too-many-arguments
|
def __init__( # pylint: disable=too-many-arguments
|
||||||
self, feed, content, signature=None, sequence=1, timestamp=None, previous=None
|
self,
|
||||||
|
feed: Feed,
|
||||||
|
content: Dict[str, Any],
|
||||||
|
signature: Optional[str] = None,
|
||||||
|
sequence: int = 1,
|
||||||
|
timestamp: Optional[int] = None,
|
||||||
|
previous: Optional["Message"] = None,
|
||||||
):
|
):
|
||||||
self.feed = feed
|
self.feed = feed
|
||||||
self.content = content
|
self.content = content
|
||||||
@ -81,15 +94,16 @@ class Message:
|
|||||||
self.signature = signature
|
self.signature = signature
|
||||||
|
|
||||||
self.previous = previous
|
self.previous = previous
|
||||||
|
|
||||||
if self.previous:
|
if self.previous:
|
||||||
self.sequence = self.previous.sequence + 1
|
self.sequence: int = self.previous.sequence + 1
|
||||||
else:
|
else:
|
||||||
self.sequence = sequence
|
self.sequence = sequence
|
||||||
|
|
||||||
self.timestamp = get_millis_1970() if timestamp is None else timestamp
|
self.timestamp = get_millis_1970() if timestamp is None else timestamp
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def parse(cls, data, feed):
|
def parse(cls, data: bytes, feed: Feed) -> Self:
|
||||||
"""Parse raw message data"""
|
"""Parse raw message data"""
|
||||||
|
|
||||||
obj = loads(data, object_pairs_hook=OrderedDict)
|
obj = loads(data, object_pairs_hook=OrderedDict)
|
||||||
@ -97,12 +111,12 @@ class Message:
|
|||||||
|
|
||||||
return msg
|
return msg
|
||||||
|
|
||||||
def serialize(self, add_signature=True):
|
def serialize(self, add_signature: bool = True) -> bytes:
|
||||||
"""Serialize the message"""
|
"""Serialize the message"""
|
||||||
|
|
||||||
return dumps(self.to_dict(add_signature=add_signature), indent=2).encode("utf-8")
|
return dumps(self.to_dict(add_signature=add_signature), indent=2).encode("utf-8")
|
||||||
|
|
||||||
def to_dict(self, add_signature=True):
|
def to_dict(self, add_signature: bool = True) -> OrderedDict[str, Any]:
|
||||||
"""Convert the message to a dictionary"""
|
"""Convert the message to a dictionary"""
|
||||||
|
|
||||||
obj = to_ordered(
|
obj = to_ordered(
|
||||||
@ -121,20 +135,21 @@ class Message:
|
|||||||
|
|
||||||
return obj
|
return obj
|
||||||
|
|
||||||
def verify(self, signature):
|
def verify(self, signature: str) -> bool:
|
||||||
"""Verify the signature of the message"""
|
"""Verify the signature of the message"""
|
||||||
|
|
||||||
return self.signature == signature
|
return self.signature == signature
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def hash(self):
|
def hash(self) -> str:
|
||||||
"""The cryptographic hash of the message"""
|
"""The cryptographic hash of the message"""
|
||||||
|
|
||||||
hash_ = sha256(self.serialize()).digest()
|
hash_ = sha256(self.serialize()).digest()
|
||||||
|
|
||||||
return b64encode(hash_).decode("ascii") + ".sha256"
|
return b64encode(hash_).decode("ascii") + ".sha256"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def key(self):
|
def key(self) -> str:
|
||||||
"""The key of the message"""
|
"""The key of the message"""
|
||||||
|
|
||||||
return "%" + self.hash
|
return "%" + self.hash
|
||||||
@ -144,7 +159,13 @@ class LocalMessage(Message):
|
|||||||
"""Class representing a local message"""
|
"""Class representing a local message"""
|
||||||
|
|
||||||
def __init__( # pylint: disable=too-many-arguments,super-init-not-called
|
def __init__( # pylint: disable=too-many-arguments,super-init-not-called
|
||||||
self, feed, content, signature=None, sequence=1, timestamp=None, previous=None
|
self,
|
||||||
|
feed: Feed,
|
||||||
|
content: Dict[str, Any],
|
||||||
|
signature: Optional[str] = None,
|
||||||
|
sequence: int = 1,
|
||||||
|
timestamp: Optional[int] = None,
|
||||||
|
previous: Optional[Message] = None,
|
||||||
):
|
):
|
||||||
self.feed = feed
|
self.feed = feed
|
||||||
self.content = content
|
self.content = content
|
||||||
@ -162,7 +183,8 @@ class LocalMessage(Message):
|
|||||||
else:
|
else:
|
||||||
self.signature = signature
|
self.signature = signature
|
||||||
|
|
||||||
def _sign(self):
|
def _sign(self) -> str:
|
||||||
# ensure ordering of keys and indentation of 2 characters, like ssb-keys
|
# ensure ordering of keys and indentation of 2 characters, like ssb-keys
|
||||||
data = self.serialize(add_signature=False)
|
data = self.serialize(add_signature=False)
|
||||||
|
|
||||||
return (b64encode(bytes(self.feed.sign(data))) + b".sig.ed25519").decode("ascii")
|
return (b64encode(bytes(self.feed.sign(data))) + b".sig.ed25519").decode("ascii")
|
||||||
|
@ -6,14 +6,15 @@ import logging
|
|||||||
from math import ceil
|
from math import ceil
|
||||||
import struct
|
import struct
|
||||||
from time import time
|
from time import time
|
||||||
from typing import Any, AsyncIterator, Dict, Optional, Union
|
from typing import Any, AsyncIterator, Dict, Optional, Tuple, Union
|
||||||
|
|
||||||
from secret_handshake.network import SHSDuplexStream
|
from secret_handshake.network import SHSDuplexStream
|
||||||
import simplejson
|
import simplejson
|
||||||
|
from typing_extensions import Self
|
||||||
|
|
||||||
logger = logging.getLogger("packet_stream")
|
PSHandler = Union["PSRequestHandler", "PSStreamHandler"]
|
||||||
|
|
||||||
PSMessageData = Union[bytes, bool, Dict[str, Any], str]
|
PSMessageData = Union[bytes, bool, Dict[str, Any], str]
|
||||||
|
logger = logging.getLogger("packet_stream")
|
||||||
|
|
||||||
|
|
||||||
class PSMessageType(Enum):
|
class PSMessageType(Enum):
|
||||||
@ -31,12 +32,12 @@ class PSStreamHandler:
|
|||||||
self.req = req
|
self.req = req
|
||||||
self.queue: Queue[Optional["PSMessage"]] = Queue()
|
self.queue: Queue[Optional["PSMessage"]] = Queue()
|
||||||
|
|
||||||
async def process(self, msg):
|
async def process(self, msg: "PSMessage") -> None:
|
||||||
"""Process a pending message"""
|
"""Process a pending message"""
|
||||||
|
|
||||||
await self.queue.put(msg)
|
await self.queue.put(msg)
|
||||||
|
|
||||||
async def stop(self):
|
async def stop(self) -> None:
|
||||||
"""Stop a pending request"""
|
"""Stop a pending request"""
|
||||||
|
|
||||||
await self.queue.put(None)
|
await self.queue.put(None)
|
||||||
@ -59,15 +60,15 @@ class PSRequestHandler:
|
|||||||
def __init__(self, req: int):
|
def __init__(self, req: int):
|
||||||
self.req = req
|
self.req = req
|
||||||
self.event = Event()
|
self.event = Event()
|
||||||
self._msg = None
|
self._msg: Optional[PSMessage] = None
|
||||||
|
|
||||||
async def process(self, msg):
|
async def process(self, msg: "PSMessage") -> None:
|
||||||
"""Process a message request"""
|
"""Process a message request"""
|
||||||
|
|
||||||
self._msg = msg
|
self._msg = msg
|
||||||
self.event.set()
|
self.event.set()
|
||||||
|
|
||||||
async def stop(self):
|
async def stop(self) -> None:
|
||||||
"""Stop a pending event request"""
|
"""Stop a pending event request"""
|
||||||
|
|
||||||
if not self.event.is_set():
|
if not self.event.is_set():
|
||||||
@ -87,37 +88,44 @@ class PSMessage:
|
|||||||
"""Packet Stream message"""
|
"""Packet Stream message"""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_header_body(cls, flags, req, body):
|
def from_header_body(cls, flags: int, req: int, body: bytes) -> Self:
|
||||||
"""Parse a raw message"""
|
"""Parse a raw message"""
|
||||||
|
|
||||||
type_ = PSMessageType(flags & 0x03)
|
type_ = PSMessageType(flags & 0x03)
|
||||||
|
|
||||||
if type_ == PSMessageType.TEXT:
|
if type_ == PSMessageType.TEXT:
|
||||||
body = body.decode("utf-8")
|
body_s = body.decode("utf-8")
|
||||||
elif type_ == PSMessageType.JSON:
|
elif type_ == PSMessageType.JSON:
|
||||||
body = simplejson.loads(body)
|
body_s = simplejson.loads(body)
|
||||||
|
|
||||||
return cls(type_, body, bool(flags & 0x08), bool(flags & 0x04), req=req)
|
return cls(type_, body_s, bool(flags & 0x08), bool(flags & 0x04), req=req)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def data(self) -> bytes:
|
def data(self) -> bytes:
|
||||||
"""The raw message data"""
|
"""The raw message data"""
|
||||||
|
|
||||||
if self.type == PSMessageType.TEXT:
|
if self.type == PSMessageType.TEXT:
|
||||||
|
assert isinstance(self.body, str)
|
||||||
|
|
||||||
return self.body.encode("utf-8")
|
return self.body.encode("utf-8")
|
||||||
|
|
||||||
if self.type == PSMessageType.JSON:
|
if self.type == PSMessageType.JSON:
|
||||||
return simplejson.dumps(self.body).encode("utf-8")
|
return simplejson.dumps(self.body).encode("utf-8")
|
||||||
|
|
||||||
|
assert isinstance(self.body, bytes)
|
||||||
|
|
||||||
return self.body
|
return self.body
|
||||||
|
|
||||||
def __init__(self, type_, body, stream, end_err, req=None): # pylint: disable=too-many-arguments
|
def __init__(
|
||||||
|
self, type_: PSMessageType, body: Any, stream: bool, end_err: bool, req: Optional[int] = None
|
||||||
|
): # pylint: disable=too-many-arguments
|
||||||
self.stream = stream
|
self.stream = stream
|
||||||
self.end_err = end_err
|
self.end_err = end_err
|
||||||
self.type = type_
|
self.type = type_
|
||||||
self.body = body
|
self.body = body
|
||||||
self.req = req
|
self.req = req
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self) -> str:
|
||||||
if self.type == PSMessageType.BUFFER:
|
if self.type == PSMessageType.BUFFER:
|
||||||
body = f"{len(self.body)} bytes"
|
body = f"{len(self.body)} bytes"
|
||||||
else:
|
else:
|
||||||
@ -136,16 +144,16 @@ class PacketStream:
|
|||||||
def __init__(self, connection: SHSDuplexStream):
|
def __init__(self, connection: SHSDuplexStream):
|
||||||
self.connection = connection
|
self.connection = connection
|
||||||
self.req_counter = 1
|
self.req_counter = 1
|
||||||
self._event_map = {}
|
self._event_map: Dict[int, Tuple[float, PSHandler]] = {}
|
||||||
self._connected = False
|
self._connected = False
|
||||||
|
|
||||||
def register_handler(self, handler):
|
def register_handler(self, handler: PSHandler) -> None:
|
||||||
"""Register an RPC handler"""
|
"""Register an RPC handler"""
|
||||||
|
|
||||||
self._event_map[handler.req] = (time(), handler)
|
self._event_map[handler.req] = (time(), handler)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_connected(self):
|
def is_connected(self) -> bool:
|
||||||
"""Check if the stream is connected"""
|
"""Check if the stream is connected"""
|
||||||
|
|
||||||
return self.connection.is_connected
|
return self.connection.is_connected
|
||||||
@ -159,22 +167,25 @@ class PacketStream:
|
|||||||
if not msg:
|
if not msg:
|
||||||
raise StopAsyncIteration()
|
raise StopAsyncIteration()
|
||||||
|
|
||||||
if msg.req >= 0:
|
if msg.req is not None and msg.req >= 0:
|
||||||
return msg
|
return msg
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def __await__(self):
|
async def __await__(self) -> None:
|
||||||
async for data in self:
|
async for data in self:
|
||||||
logger.info("RECV: %r", data)
|
logger.info("RECV: %r", data)
|
||||||
|
|
||||||
if data is None:
|
if data is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
async def _read(self):
|
async def _read(self) -> Optional[PSMessage]:
|
||||||
try:
|
try:
|
||||||
header = await self.connection.read()
|
header = await self.connection.read()
|
||||||
|
|
||||||
if not header or header == b"\x00" * 9:
|
if not header or header == b"\x00" * 9:
|
||||||
return
|
return None
|
||||||
|
|
||||||
flags, length, req = struct.unpack(">BIi", header)
|
flags, length, req = struct.unpack(">BIi", header)
|
||||||
|
|
||||||
n_packets = ceil(length / 4096)
|
n_packets = ceil(length / 4096)
|
||||||
@ -182,30 +193,39 @@ class PacketStream:
|
|||||||
body = b""
|
body = b""
|
||||||
|
|
||||||
for _ in range(n_packets):
|
for _ in range(n_packets):
|
||||||
body += await self.connection.read()
|
read_data = await self.connection.read()
|
||||||
|
|
||||||
|
if read_data is not None:
|
||||||
|
body += read_data
|
||||||
|
|
||||||
logger.debug("READ %s %s", header, len(body))
|
logger.debug("READ %s %s", header, len(body))
|
||||||
|
|
||||||
return PSMessage.from_header_body(flags, req, body)
|
return PSMessage.from_header_body(flags, req, body)
|
||||||
except StopAsyncIteration:
|
except StopAsyncIteration:
|
||||||
logger.debug("DISCONNECT")
|
logger.debug("DISCONNECT")
|
||||||
self.connection.disconnect()
|
self.connection.disconnect()
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def read(self):
|
async def read(self) -> Optional[PSMessage]:
|
||||||
"""Read data from the packet stream"""
|
"""Read data from the packet stream"""
|
||||||
|
|
||||||
msg = await self._read()
|
msg = await self._read()
|
||||||
|
|
||||||
if not msg:
|
if not msg:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# check whether it's a reply and handle accordingly
|
# check whether it's a reply and handle accordingly
|
||||||
if msg.req < 0:
|
if msg.req is not None and msg.req < 0:
|
||||||
_, handler = self._event_map[-msg.req]
|
_, handler = self._event_map[-msg.req]
|
||||||
await handler.process(msg)
|
await handler.process(msg)
|
||||||
logger.info("RESPONSE [%d]: %r", -msg.req, msg)
|
logger.info("RESPONSE [%d]: %r", -msg.req, msg)
|
||||||
|
|
||||||
if msg.end_err:
|
if msg.end_err:
|
||||||
await handler.stop()
|
await handler.stop()
|
||||||
del self._event_map[-msg.req]
|
del self._event_map[-msg.req]
|
||||||
logger.info("RESPONSE [%d]: EOS", -msg.req)
|
logger.info("RESPONSE [%d]: EOS", -msg.req)
|
||||||
|
|
||||||
return msg
|
return msg
|
||||||
|
|
||||||
def _write(self, msg: PSMessage) -> None:
|
def _write(self, msg: PSMessage) -> None:
|
||||||
@ -225,7 +245,7 @@ class PacketStream:
|
|||||||
stream: bool = False,
|
stream: bool = False,
|
||||||
end_err: bool = False,
|
end_err: bool = False,
|
||||||
req: Optional[int] = None,
|
req: Optional[int] = None,
|
||||||
):
|
) -> PSHandler:
|
||||||
"""Send data through the packet stream"""
|
"""Send data through the packet stream"""
|
||||||
|
|
||||||
update_counter = False
|
update_counter = False
|
||||||
@ -240,7 +260,7 @@ class PacketStream:
|
|||||||
self._write(msg)
|
self._write(msg)
|
||||||
|
|
||||||
if stream:
|
if stream:
|
||||||
handler = PSStreamHandler(self.req_counter)
|
handler: PSHandler = PSStreamHandler(self.req_counter)
|
||||||
else:
|
else:
|
||||||
handler = PSRequestHandler(self.req_counter)
|
handler = PSRequestHandler(self.req_counter)
|
||||||
|
|
||||||
@ -251,7 +271,7 @@ class PacketStream:
|
|||||||
|
|
||||||
return handler
|
return handler
|
||||||
|
|
||||||
def disconnect(self):
|
def disconnect(self) -> None:
|
||||||
"""Disconnect the stream"""
|
"""Disconnect the stream"""
|
||||||
|
|
||||||
self._connected = False
|
self._connected = False
|
||||||
|
@ -4,12 +4,12 @@ from base64 import b64decode, b64encode
|
|||||||
import os
|
import os
|
||||||
from typing import TypedDict
|
from typing import TypedDict
|
||||||
|
|
||||||
from nacl.signing import SigningKey
|
from nacl.signing import SigningKey, VerifyKey
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
|
|
||||||
class SSBSecret(TypedDict):
|
class SSBSecret(TypedDict):
|
||||||
"""Dictionary to hold an SSB identity"""
|
"""Dictionary type to hold an SSB secret identity"""
|
||||||
|
|
||||||
keypair: SigningKey
|
keypair: SigningKey
|
||||||
id: str
|
id: str
|
||||||
@ -19,7 +19,7 @@ class ConfigException(Exception):
|
|||||||
"""Exception to raise if there is a problem with the configuration data"""
|
"""Exception to raise if there is a problem with the configuration data"""
|
||||||
|
|
||||||
|
|
||||||
def tag(key):
|
def tag(key: VerifyKey) -> bytes:
|
||||||
"""Create tag from public key."""
|
"""Create tag from public key."""
|
||||||
|
|
||||||
return b"@" + b64encode(bytes(key)) + b".ed25519"
|
return b"@" + b64encode(bytes(key)) + b".ed25519"
|
||||||
@ -35,4 +35,5 @@ def load_ssb_secret() -> SSBSecret:
|
|||||||
raise ConfigException("Algorithm not known: " + config["curve"])
|
raise ConfigException("Algorithm not known: " + config["curve"])
|
||||||
|
|
||||||
server_prv_key = b64decode(config["private"][:-8])
|
server_prv_key = b64decode(config["private"][:-8])
|
||||||
|
|
||||||
return {"keypair": SigningKey(server_prv_key[:32]), "id": config["id"]}
|
return {"keypair": SigningKey(server_prv_key[:32]), "id": config["id"]}
|
||||||
|
@ -25,7 +25,7 @@ SERIALIZED_M1 = b"""{
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def local_feed():
|
def local_feed() -> LocalFeed:
|
||||||
"""Fixture providing a local feed"""
|
"""Fixture providing a local feed"""
|
||||||
|
|
||||||
secret = b64decode("Mz2qkNOP2K6upnqibWrR+z8pVUI1ReA1MLc7QMtF2qQ=")
|
secret = b64decode("Mz2qkNOP2K6upnqibWrR+z8pVUI1ReA1MLc7QMtF2qQ=")
|
||||||
@ -33,14 +33,14 @@ def local_feed():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def remote_feed():
|
def remote_feed() -> Feed:
|
||||||
"""Fixture providing a remote feed"""
|
"""Fixture providing a remote feed"""
|
||||||
|
|
||||||
public = b64decode("I/4cyN/jPBbDsikbHzAEvmaYlaJK33lW3UhWjNXjyrU=")
|
public = b64decode("I/4cyN/jPBbDsikbHzAEvmaYlaJK33lW3UhWjNXjyrU=")
|
||||||
return Feed(VerifyKey(public))
|
return Feed(VerifyKey(public))
|
||||||
|
|
||||||
|
|
||||||
def test_local_feed():
|
def test_local_feed() -> None:
|
||||||
"""Test a local feed"""
|
"""Test a local feed"""
|
||||||
|
|
||||||
secret = b64decode("Mz2qkNOP2K6upnqibWrR+z8pVUI1ReA1MLc7QMtF2qQ=")
|
secret = b64decode("Mz2qkNOP2K6upnqibWrR+z8pVUI1ReA1MLc7QMtF2qQ=")
|
||||||
@ -50,7 +50,7 @@ def test_local_feed():
|
|||||||
assert feed.id == "@I/4cyN/jPBbDsikbHzAEvmaYlaJK33lW3UhWjNXjyrU=.ed25519"
|
assert feed.id == "@I/4cyN/jPBbDsikbHzAEvmaYlaJK33lW3UhWjNXjyrU=.ed25519"
|
||||||
|
|
||||||
|
|
||||||
def test_remote_feed():
|
def test_remote_feed() -> None:
|
||||||
"""Test a remote feed"""
|
"""Test a remote feed"""
|
||||||
|
|
||||||
public = b64decode("I/4cyN/jPBbDsikbHzAEvmaYlaJK33lW3UhWjNXjyrU=")
|
public = b64decode("I/4cyN/jPBbDsikbHzAEvmaYlaJK33lW3UhWjNXjyrU=")
|
||||||
@ -69,7 +69,7 @@ def test_remote_feed():
|
|||||||
feed.sign(m1)
|
feed.sign(m1)
|
||||||
|
|
||||||
|
|
||||||
def test_local_message(local_feed): # pylint: disable=redefined-outer-name
|
def test_local_message(local_feed: LocalFeed) -> None: # pylint: disable=redefined-outer-name
|
||||||
"""Test a local message"""
|
"""Test a local message"""
|
||||||
|
|
||||||
m1 = LocalMessage(
|
m1 = LocalMessage(
|
||||||
@ -102,7 +102,7 @@ def test_local_message(local_feed): # pylint: disable=redefined-outer-name
|
|||||||
assert m2.key == "%nx13uks5GUwuKJC49PfYGMS/1pgGTtwwdWT7kbVaroM=.sha256"
|
assert m2.key == "%nx13uks5GUwuKJC49PfYGMS/1pgGTtwwdWT7kbVaroM=.sha256"
|
||||||
|
|
||||||
|
|
||||||
def test_remote_message(remote_feed): # pylint: disable=redefined-outer-name
|
def test_remote_message(remote_feed: Feed) -> None: # pylint: disable=redefined-outer-name
|
||||||
"""Test a remote message"""
|
"""Test a remote message"""
|
||||||
|
|
||||||
signature = "lPsQ9P10OgeyH6u0unFgiI2wV/RQ7Q2x2ebxnXYCzsJ055TBMXphRADTKhOMS2EkUxXQ9k3amj5fnWPudGxwBQ==.sig.ed25519"
|
signature = "lPsQ9P10OgeyH6u0unFgiI2wV/RQ7Q2x2ebxnXYCzsJ055TBMXphRADTKhOMS2EkUxXQ9k3amj5fnWPudGxwBQ==.sig.ed25519"
|
||||||
@ -136,7 +136,7 @@ def test_remote_message(remote_feed): # pylint: disable=redefined-outer-name
|
|||||||
assert m2.key == "%nx13uks5GUwuKJC49PfYGMS/1pgGTtwwdWT7kbVaroM=.sha256"
|
assert m2.key == "%nx13uks5GUwuKJC49PfYGMS/1pgGTtwwdWT7kbVaroM=.sha256"
|
||||||
|
|
||||||
|
|
||||||
def test_remote_no_signature(remote_feed): # pylint: disable=redefined-outer-name
|
def test_remote_no_signature(remote_feed: Feed) -> None: # pylint: disable=redefined-outer-name
|
||||||
"""Test remote feed without a signature"""
|
"""Test remote feed without a signature"""
|
||||||
|
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
@ -150,7 +150,7 @@ def test_remote_no_signature(remote_feed): # pylint: disable=redefined-outer-na
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_serialize(local_feed): # pylint: disable=redefined-outer-name
|
def test_serialize(local_feed: LocalFeed) -> None: # pylint: disable=redefined-outer-name
|
||||||
"""Test feed serialization"""
|
"""Test feed serialization"""
|
||||||
|
|
||||||
m1 = LocalMessage(
|
m1 = LocalMessage(
|
||||||
@ -162,7 +162,7 @@ def test_serialize(local_feed): # pylint: disable=redefined-outer-name
|
|||||||
assert m1.serialize() == SERIALIZED_M1
|
assert m1.serialize() == SERIALIZED_M1
|
||||||
|
|
||||||
|
|
||||||
def test_parse(local_feed): # pylint: disable=redefined-outer-name
|
def test_parse(local_feed: LocalFeed) -> None: # pylint: disable=redefined-outer-name
|
||||||
"""Test feed parsing"""
|
"""Test feed parsing"""
|
||||||
|
|
||||||
m1 = LocalMessage.parse(SERIALIZED_M1, local_feed)
|
m1 = LocalMessage.parse(SERIALIZED_M1, local_feed)
|
||||||
|
@ -1,18 +1,23 @@
|
|||||||
"""Tests for the packet stream"""
|
"""Tests for the packet stream"""
|
||||||
|
|
||||||
from asyncio import Event, ensure_future, gather
|
from asyncio import Event, ensure_future, gather
|
||||||
|
from asyncio.events import AbstractEventLoop
|
||||||
import json
|
import json
|
||||||
|
from typing import AsyncGenerator, Awaitable, Callable, Generator, List
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from pytest_mock import MockerFixture
|
||||||
from secret_handshake.network import SHSDuplexStream
|
from secret_handshake.network import SHSDuplexStream
|
||||||
|
|
||||||
from ssb.packet_stream import PacketStream, PSMessageType
|
from ssb.packet_stream import PacketStream, PSMessage, PSMessageType
|
||||||
|
|
||||||
|
|
||||||
async def _collect_messages(generator):
|
async def _collect_messages(generator: AsyncGenerator[PSMessage, None]) -> List[PSMessage]:
|
||||||
results = []
|
results = []
|
||||||
|
|
||||||
async for msg in generator:
|
async for msg in generator:
|
||||||
results.append(msg)
|
results.append(msg)
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
@ -39,45 +44,47 @@ MSG_BODY_2 = (
|
|||||||
class MockSHSSocket(SHSDuplexStream):
|
class MockSHSSocket(SHSDuplexStream):
|
||||||
"""A mocked SHS socket"""
|
"""A mocked SHS socket"""
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs): # pylint: disable=unused-argument
|
def __init__(self): # pylint: disable=unused-argument
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.input = []
|
self.input: List[bytes] = []
|
||||||
self.output = []
|
self.output: List[bytes] = []
|
||||||
self.is_connected = False
|
self.is_connected = False
|
||||||
self._on_connect = []
|
self._on_connect: List[Callable[[SHSDuplexStream], Awaitable[None]]] = []
|
||||||
|
|
||||||
def on_connect(self, cb):
|
def on_connect(self, cb: Callable[[SHSDuplexStream], Awaitable[None]]) -> None:
|
||||||
"""Set the on_connect callback"""
|
"""Set the on_connect callback"""
|
||||||
|
|
||||||
self._on_connect.append(cb)
|
self._on_connect.append(cb)
|
||||||
|
|
||||||
async def read(self):
|
async def read(self) -> bytes:
|
||||||
"""Read data from the socket"""
|
"""Read data from the socket"""
|
||||||
|
|
||||||
if not self.input:
|
if not self.input:
|
||||||
raise StopAsyncIteration
|
raise StopAsyncIteration
|
||||||
|
|
||||||
return self.input.pop(0)
|
return self.input.pop(0)
|
||||||
|
|
||||||
def write(self, data):
|
def write(self, data: bytes) -> None:
|
||||||
"""Write data to the socket"""
|
"""Write data to the socket"""
|
||||||
|
|
||||||
self.output.append(data)
|
self.output.append(data)
|
||||||
|
|
||||||
def feed(self, input_):
|
def feed(self, input_: List[bytes]) -> None:
|
||||||
"""Get the connection’s feed"""
|
"""Feed data into the connection"""
|
||||||
|
|
||||||
self.input += input_
|
self.input += input_
|
||||||
|
|
||||||
def get_output(self):
|
def get_output(self) -> Generator[bytes, None, None]:
|
||||||
"""Get the output of a call"""
|
"""Get the output of a call"""
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
if not self.output:
|
if not self.output:
|
||||||
break
|
break
|
||||||
|
|
||||||
yield self.output.pop(0)
|
yield self.output.pop(0)
|
||||||
|
|
||||||
def disconnect(self):
|
def disconnect(self) -> None:
|
||||||
"""Disconnect from the remote party"""
|
"""Disconnect from the remote party"""
|
||||||
|
|
||||||
self.is_connected = False
|
self.is_connected = False
|
||||||
@ -86,48 +93,48 @@ class MockSHSSocket(SHSDuplexStream):
|
|||||||
class MockSHSClient(MockSHSSocket):
|
class MockSHSClient(MockSHSSocket):
|
||||||
"""A mocked SHS client"""
|
"""A mocked SHS client"""
|
||||||
|
|
||||||
async def connect(self):
|
async def connect(self) -> None:
|
||||||
"""Connect to a SHS server"""
|
"""Connect to a SHS server"""
|
||||||
|
|
||||||
self.is_connected = True
|
self.is_connected = True
|
||||||
|
|
||||||
for cb in self._on_connect:
|
for cb in self._on_connect:
|
||||||
await cb()
|
await cb(self)
|
||||||
|
|
||||||
|
|
||||||
class MockSHSServer(MockSHSSocket):
|
class MockSHSServer(MockSHSSocket):
|
||||||
"""A mocked SHS server"""
|
"""A mocked SHS server"""
|
||||||
|
|
||||||
def listen(self):
|
def listen(self) -> None:
|
||||||
"""Listen for new connections"""
|
"""Listen for new connections"""
|
||||||
|
|
||||||
self.is_connected = True
|
self.is_connected = True
|
||||||
|
|
||||||
for cb in self._on_connect:
|
for cb in self._on_connect:
|
||||||
ensure_future(cb())
|
ensure_future(cb(self))
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def ps_client(event_loop): # pylint: disable=unused-argument
|
def ps_client(event_loop: AbstractEventLoop) -> MockSHSClient: # pylint: disable=unused-argument
|
||||||
"""Fixture to provide a mocked SHS client"""
|
"""Fixture to provide a mocked SHS client"""
|
||||||
|
|
||||||
return MockSHSClient()
|
return MockSHSClient()
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def ps_server(event_loop): # pylint: disable=unused-argument
|
def ps_server(event_loop: AbstractEventLoop) -> MockSHSServer: # pylint: disable=unused-argument
|
||||||
"""Fixture to provide a mocked SHS server"""
|
"""Fixture to provide a mocked SHS server"""
|
||||||
|
|
||||||
return MockSHSServer()
|
return MockSHSServer()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_on_connect(ps_server): # pylint: disable=redefined-outer-name
|
async def test_on_connect(ps_server: MockSHSServer) -> None: # pylint: disable=redefined-outer-name
|
||||||
"""Test the on_connect callback functionality"""
|
"""Test the on_connect callback functionality"""
|
||||||
|
|
||||||
called = Event()
|
called = Event()
|
||||||
|
|
||||||
async def _on_connect():
|
async def _on_connect(_: SHSDuplexStream) -> None:
|
||||||
called.set()
|
called.set()
|
||||||
|
|
||||||
ps_server.on_connect(_on_connect)
|
ps_server.on_connect(_on_connect)
|
||||||
@ -137,7 +144,7 @@ async def test_on_connect(ps_server): # pylint: disable=redefined-outer-name
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_message_decoding(ps_client): # pylint: disable=redefined-outer-name
|
async def test_message_decoding(ps_client: MockSHSClient) -> None: # pylint: disable=redefined-outer-name
|
||||||
"""Test message decoding"""
|
"""Test message decoding"""
|
||||||
|
|
||||||
await ps_client.connect()
|
await ps_client.connect()
|
||||||
@ -167,7 +174,7 @@ async def test_message_decoding(ps_client): # pylint: disable=redefined-outer-n
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_message_encoding(ps_client): # pylint: disable=redefined-outer-name
|
async def test_message_encoding(ps_client: MockSHSClient) -> None: # pylint: disable=redefined-outer-name
|
||||||
"""Test message encoding"""
|
"""Test message encoding"""
|
||||||
|
|
||||||
await ps_client.connect()
|
await ps_client.connect()
|
||||||
@ -200,7 +207,7 @@ async def test_message_encoding(ps_client): # pylint: disable=redefined-outer-n
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_message_stream(ps_client, mocker): # pylint: disable=redefined-outer-name
|
async def test_message_stream(ps_client: MockSHSClient, mocker: MockerFixture): # pylint: disable=redefined-outer-name
|
||||||
"""Test requesting a history stream"""
|
"""Test requesting a history stream"""
|
||||||
|
|
||||||
await ps_client.connect()
|
await ps_client.connect()
|
||||||
@ -272,7 +279,9 @@ async def test_message_stream(ps_client, mocker): # pylint: disable=redefined-o
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_message_request(ps_server, mocker): # pylint: disable=redefined-outer-name
|
async def test_message_request(
|
||||||
|
ps_server: MockSHSServer, mocker: MockerFixture # pylint: disable=redefined-outer-name
|
||||||
|
) -> None:
|
||||||
"""Test message sending"""
|
"""Test message sending"""
|
||||||
|
|
||||||
ps_server.listen()
|
ps_server.listen()
|
||||||
|
@ -20,7 +20,7 @@ CONFIG_FILE = """
|
|||||||
CONFIG_FILE_INVALID = CONFIG_FILE.replace("ed25519", "foo")
|
CONFIG_FILE_INVALID = CONFIG_FILE.replace("ed25519", "foo")
|
||||||
|
|
||||||
|
|
||||||
def test_load_secret():
|
def test_load_secret() -> None:
|
||||||
"""Test loading the SSB secret from a file"""
|
"""Test loading the SSB secret from a file"""
|
||||||
|
|
||||||
with patch("ssb.util.open", mock_open(read_data=CONFIG_FILE), create=True):
|
with patch("ssb.util.open", mock_open(read_data=CONFIG_FILE), create=True):
|
||||||
@ -33,8 +33,9 @@ def test_load_secret():
|
|||||||
assert bytes(secret["keypair"].verify_key) == b64decode("rsYpBIcXsxjQAf0JNes+MHqT2DL+EfopWKAp4rGeEPQ=")
|
assert bytes(secret["keypair"].verify_key) == b64decode("rsYpBIcXsxjQAf0JNes+MHqT2DL+EfopWKAp4rGeEPQ=")
|
||||||
|
|
||||||
|
|
||||||
def test_load_exception():
|
def test_load_exception() -> None:
|
||||||
"""Test configuration loading if there is a problem with the file"""
|
"""Test configuration loading if there is a problem with the file"""
|
||||||
|
|
||||||
with pytest.raises(ConfigException):
|
with pytest.raises(ConfigException):
|
||||||
with patch("ssb.util.open", mock_open(read_data=CONFIG_FILE_INVALID), create=True):
|
with patch("ssb.util.open", mock_open(read_data=CONFIG_FILE_INVALID), create=True):
|
||||||
load_ssb_secret()
|
load_ssb_secret()
|
||||||
|
Loading…
Reference in New Issue
Block a user