ci: Add and configure mypy, and make it happy
This commit is contained in:
parent
f2a54b5ce6
commit
1c1e57d868
@ -37,6 +37,13 @@ repos:
|
|||||||
language: system
|
language: system
|
||||||
require_serial: true
|
require_serial: true
|
||||||
types_or: [python, pyi]
|
types_or: [python, pyi]
|
||||||
|
- id: mypy
|
||||||
|
name: mypy
|
||||||
|
entry: poetry run mypy
|
||||||
|
args: ["--strict"]
|
||||||
|
language: system
|
||||||
|
types_or: [python, pyi]
|
||||||
|
require_serial: true
|
||||||
- id: reuse
|
- id: reuse
|
||||||
name: reuse
|
name: reuse
|
||||||
entry: poetry run reuse
|
entry: poetry run reuse
|
||||||
|
@ -30,9 +30,10 @@ import struct
|
|||||||
import time
|
import time
|
||||||
|
|
||||||
from colorlog import ColoredFormatter
|
from colorlog import ColoredFormatter
|
||||||
|
from nacl.signing import SigningKey
|
||||||
from secret_handshake.network import SHSClient
|
from secret_handshake.network import SHSClient
|
||||||
|
|
||||||
from ssb.muxrpc import MuxRPCAPI, MuxRPCAPIException
|
from ssb.muxrpc import MuxRPCAPI, MuxRPCAPIException, MuxRPCRequest
|
||||||
from ssb.packet_stream import PacketStream, PSMessageType
|
from ssb.packet_stream import PacketStream, PSMessageType
|
||||||
from ssb.util import load_ssb_secret
|
from ssb.util import load_ssb_secret
|
||||||
|
|
||||||
@ -40,7 +41,7 @@ api = MuxRPCAPI()
|
|||||||
|
|
||||||
|
|
||||||
@api.define("createHistoryStream")
|
@api.define("createHistoryStream")
|
||||||
def create_history_stream(connection, msg): # pylint: disable=unused-argument
|
def create_history_stream(connection: PacketStream, msg: MuxRPCRequest) -> None: # pylint: disable=unused-argument
|
||||||
"""Handle the createHistoryStream RPC call"""
|
"""Handle the createHistoryStream RPC call"""
|
||||||
|
|
||||||
print("create_history_stream", msg)
|
print("create_history_stream", msg)
|
||||||
@ -49,13 +50,13 @@ def create_history_stream(connection, msg): # pylint: disable=unused-argument
|
|||||||
|
|
||||||
|
|
||||||
@api.define("blobs.createWants")
|
@api.define("blobs.createWants")
|
||||||
def create_wants(connection, msg): # pylint: disable=unused-argument
|
def create_wants(connection: PacketStream, msg: MuxRPCRequest) -> None: # pylint: disable=unused-argument
|
||||||
"""Handle the createWants RPC call"""
|
"""Handle the createWants RPC call"""
|
||||||
|
|
||||||
print("create_wants", msg)
|
print("create_wants", msg)
|
||||||
|
|
||||||
|
|
||||||
async def test_client():
|
async def test_client() -> None:
|
||||||
"""The actual client implementation"""
|
"""The actual client implementation"""
|
||||||
|
|
||||||
async for msg in api.call(
|
async for msg in api.call(
|
||||||
@ -90,6 +91,8 @@ async def test_client():
|
|||||||
|
|
||||||
img_data = b""
|
img_data = b""
|
||||||
async for msg in api.call("blobs.get", ["&kqZ52sDcJSHOx7m4Ww80kK1KIZ65gpGnqwZlfaIVWWM=.sha256"], "source"):
|
async for msg in api.call("blobs.get", ["&kqZ52sDcJSHOx7m4Ww80kK1KIZ65gpGnqwZlfaIVWWM=.sha256"], "source"):
|
||||||
|
assert msg
|
||||||
|
|
||||||
if msg.type.name == "BUFFER":
|
if msg.type.name == "BUFFER":
|
||||||
img_data += msg.data
|
img_data += msg.data
|
||||||
if msg.type.name == "JSON" and msg.data == b"true":
|
if msg.type.name == "JSON" and msg.data == b"true":
|
||||||
@ -101,7 +104,7 @@ async def test_client():
|
|||||||
f.write(img_data)
|
f.write(img_data)
|
||||||
|
|
||||||
|
|
||||||
async def main(keypair):
|
async def main(keypair: SigningKey) -> None:
|
||||||
"""The main function to run"""
|
"""The main function to run"""
|
||||||
|
|
||||||
client = SHSClient("127.0.0.1", 8008, keypair, bytes(keypair.verify_key))
|
client = SHSClient("127.0.0.1", 8008, keypair, bytes(keypair.verify_key))
|
||||||
|
@ -27,6 +27,7 @@ import logging
|
|||||||
|
|
||||||
from colorlog import ColoredFormatter
|
from colorlog import ColoredFormatter
|
||||||
from secret_handshake import SHSServer
|
from secret_handshake import SHSServer
|
||||||
|
from secret_handshake.network import SHSDuplexStream
|
||||||
|
|
||||||
from ssb.muxrpc import MuxRPCAPI
|
from ssb.muxrpc import MuxRPCAPI
|
||||||
from ssb.packet_stream import PacketStream
|
from ssb.packet_stream import PacketStream
|
||||||
@ -35,7 +36,7 @@ from ssb.util import load_ssb_secret
|
|||||||
api = MuxRPCAPI()
|
api = MuxRPCAPI()
|
||||||
|
|
||||||
|
|
||||||
async def on_connect(conn):
|
async def on_connect(conn: SHSDuplexStream) -> None:
|
||||||
"""Incoming connection handler"""
|
"""Incoming connection handler"""
|
||||||
|
|
||||||
packet_stream = PacketStream(conn)
|
packet_stream = PacketStream(conn)
|
||||||
@ -46,7 +47,7 @@ async def on_connect(conn):
|
|||||||
print(msg)
|
print(msg)
|
||||||
|
|
||||||
|
|
||||||
async def main():
|
async def main() -> None:
|
||||||
"""The main function to run"""
|
"""The main function to run"""
|
||||||
|
|
||||||
server = SHSServer("127.0.0.1", 8008, load_ssb_secret()["keypair"])
|
server = SHSServer("127.0.0.1", 8008, load_ssb_secret()["keypair"])
|
||||||
|
71
poetry.lock
generated
71
poetry.lock
generated
@ -649,6 +649,53 @@ files = [
|
|||||||
{file = "mccabe-0.7.0.tar.gz", hash = "sha256:348e0240c33b60bbdf4e523192ef919f28cb2c3d7d5c7794f74009290f236325"},
|
{file = "mccabe-0.7.0.tar.gz", hash = "sha256:348e0240c33b60bbdf4e523192ef919f28cb2c3d7d5c7794f74009290f236325"},
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "mypy"
|
||||||
|
version = "1.7.0"
|
||||||
|
description = "Optional static typing for Python"
|
||||||
|
optional = false
|
||||||
|
python-versions = ">=3.8"
|
||||||
|
files = [
|
||||||
|
{file = "mypy-1.7.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:5da84d7bf257fd8f66b4f759a904fd2c5a765f70d8b52dde62b521972a0a2357"},
|
||||||
|
{file = "mypy-1.7.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a3637c03f4025f6405737570d6cbfa4f1400eb3c649317634d273687a09ffc2f"},
|
||||||
|
{file = "mypy-1.7.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b633f188fc5ae1b6edca39dae566974d7ef4e9aaaae00bc36efe1f855e5173ac"},
|
||||||
|
{file = "mypy-1.7.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:d6ed9a3997b90c6f891138e3f83fb8f475c74db4ccaa942a1c7bf99e83a989a1"},
|
||||||
|
{file = "mypy-1.7.0-cp310-cp310-win_amd64.whl", hash = "sha256:1fe46e96ae319df21359c8db77e1aecac8e5949da4773c0274c0ef3d8d1268a9"},
|
||||||
|
{file = "mypy-1.7.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:df67fbeb666ee8828f675fee724cc2cbd2e4828cc3df56703e02fe6a421b7401"},
|
||||||
|
{file = "mypy-1.7.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:a79cdc12a02eb526d808a32a934c6fe6df07b05f3573d210e41808020aed8b5d"},
|
||||||
|
{file = "mypy-1.7.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f65f385a6f43211effe8c682e8ec3f55d79391f70a201575def73d08db68ead1"},
|
||||||
|
{file = "mypy-1.7.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:0e81ffd120ee24959b449b647c4b2fbfcf8acf3465e082b8d58fd6c4c2b27e46"},
|
||||||
|
{file = "mypy-1.7.0-cp311-cp311-win_amd64.whl", hash = "sha256:f29386804c3577c83d76520abf18cfcd7d68264c7e431c5907d250ab502658ee"},
|
||||||
|
{file = "mypy-1.7.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:87c076c174e2c7ef8ab416c4e252d94c08cd4980a10967754f91571070bf5fbe"},
|
||||||
|
{file = "mypy-1.7.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:6cb8d5f6d0fcd9e708bb190b224089e45902cacef6f6915481806b0c77f7786d"},
|
||||||
|
{file = "mypy-1.7.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d93e76c2256aa50d9c82a88e2f569232e9862c9982095f6d54e13509f01222fc"},
|
||||||
|
{file = "mypy-1.7.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:cddee95dea7990e2215576fae95f6b78a8c12f4c089d7e4367564704e99118d3"},
|
||||||
|
{file = "mypy-1.7.0-cp312-cp312-win_amd64.whl", hash = "sha256:d01921dbd691c4061a3e2ecdbfbfad029410c5c2b1ee88946bf45c62c6c91210"},
|
||||||
|
{file = "mypy-1.7.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:185cff9b9a7fec1f9f7d8352dff8a4c713b2e3eea9c6c4b5ff7f0edf46b91e41"},
|
||||||
|
{file = "mypy-1.7.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:7a7b1e399c47b18feb6f8ad4a3eef3813e28c1e871ea7d4ea5d444b2ac03c418"},
|
||||||
|
{file = "mypy-1.7.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fc9fe455ad58a20ec68599139ed1113b21f977b536a91b42bef3ffed5cce7391"},
|
||||||
|
{file = "mypy-1.7.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:d0fa29919d2e720c8dbaf07d5578f93d7b313c3e9954c8ec05b6d83da592e5d9"},
|
||||||
|
{file = "mypy-1.7.0-cp38-cp38-win_amd64.whl", hash = "sha256:2b53655a295c1ed1af9e96b462a736bf083adba7b314ae775563e3fb4e6795f5"},
|
||||||
|
{file = "mypy-1.7.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:c1b06b4b109e342f7dccc9efda965fc3970a604db70f8560ddfdee7ef19afb05"},
|
||||||
|
{file = "mypy-1.7.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:bf7a2f0a6907f231d5e41adba1a82d7d88cf1f61a70335889412dec99feeb0f8"},
|
||||||
|
{file = "mypy-1.7.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:551d4a0cdcbd1d2cccdcc7cb516bb4ae888794929f5b040bb51aae1846062901"},
|
||||||
|
{file = "mypy-1.7.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:55d28d7963bef00c330cb6461db80b0b72afe2f3c4e2963c99517cf06454e665"},
|
||||||
|
{file = "mypy-1.7.0-cp39-cp39-win_amd64.whl", hash = "sha256:870bd1ffc8a5862e593185a4c169804f2744112b4a7c55b93eb50f48e7a77010"},
|
||||||
|
{file = "mypy-1.7.0-py3-none-any.whl", hash = "sha256:96650d9a4c651bc2a4991cf46f100973f656d69edc7faf91844e87fe627f7e96"},
|
||||||
|
{file = "mypy-1.7.0.tar.gz", hash = "sha256:1e280b5697202efa698372d2f39e9a6713a0395a756b1c6bd48995f8d72690dc"},
|
||||||
|
]
|
||||||
|
|
||||||
|
[package.dependencies]
|
||||||
|
mypy-extensions = ">=1.0.0"
|
||||||
|
tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""}
|
||||||
|
typing-extensions = ">=4.1.0"
|
||||||
|
|
||||||
|
[package.extras]
|
||||||
|
dmypy = ["psutil (>=4.0)"]
|
||||||
|
install-types = ["pip"]
|
||||||
|
mypyc = ["setuptools (>=50)"]
|
||||||
|
reports = ["lxml"]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "mypy-extensions"
|
name = "mypy-extensions"
|
||||||
version = "1.0.0"
|
version = "1.0.0"
|
||||||
@ -1253,6 +1300,28 @@ files = [
|
|||||||
{file = "tomlkit-0.12.1.tar.gz", hash = "sha256:38e1ff8edb991273ec9f6181244a6a391ac30e9f5098e7535640ea6be97a7c86"},
|
{file = "tomlkit-0.12.1.tar.gz", hash = "sha256:38e1ff8edb991273ec9f6181244a6a391ac30e9f5098e7535640ea6be97a7c86"},
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "types-pyyaml"
|
||||||
|
version = "6.0.12.12"
|
||||||
|
description = "Typing stubs for PyYAML"
|
||||||
|
optional = false
|
||||||
|
python-versions = "*"
|
||||||
|
files = [
|
||||||
|
{file = "types-PyYAML-6.0.12.12.tar.gz", hash = "sha256:334373d392fde0fdf95af5c3f1661885fa10c52167b14593eb856289e1855062"},
|
||||||
|
{file = "types_PyYAML-6.0.12.12-py3-none-any.whl", hash = "sha256:c05bc6c158facb0676674b7f11fe3960db4f389718e19e62bd2b84d6205cfd24"},
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "types-simplejson"
|
||||||
|
version = "3.19.0.2"
|
||||||
|
description = "Typing stubs for simplejson"
|
||||||
|
optional = false
|
||||||
|
python-versions = "*"
|
||||||
|
files = [
|
||||||
|
{file = "types-simplejson-3.19.0.2.tar.gz", hash = "sha256:ebc81f886f89d99d6b80c726518aa2228bc77c26438f18fd81455e4f79f8ee1b"},
|
||||||
|
{file = "types_simplejson-3.19.0.2-py3-none-any.whl", hash = "sha256:8ba093dc7884f59b3e62aed217144085e675a269debc32678fd80e0b43b2b86f"},
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "typing-extensions"
|
name = "typing-extensions"
|
||||||
version = "4.8.0"
|
version = "4.8.0"
|
||||||
@ -1310,4 +1379,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p
|
|||||||
[metadata]
|
[metadata]
|
||||||
lock-version = "2.0"
|
lock-version = "2.0"
|
||||||
python-versions = "^3.9"
|
python-versions = "^3.9"
|
||||||
content-hash = "d57dc0c074d7daf70507fda1fc9641cf367b6dc8f02b34a5fceafe6b45c0f4f9"
|
content-hash = "98384046072d2dd4f649a93231ee6a84e5b21be34f15d5d2196cd3832f15ebca"
|
||||||
|
@ -8,6 +8,7 @@ description = "Secure Scuttlebutt library in Python"
|
|||||||
authors = ["PyScuttleButt Contributors <pedro@dete.st>"]
|
authors = ["PyScuttleButt Contributors <pedro@dete.st>"]
|
||||||
license = "MIT"
|
license = "MIT"
|
||||||
readme = "README.rst"
|
readme = "README.rst"
|
||||||
|
include = ["ssb/py.typed"]
|
||||||
|
|
||||||
[tool.poetry.dependencies]
|
[tool.poetry.dependencies]
|
||||||
python = "^3.9"
|
python = "^3.9"
|
||||||
@ -23,12 +24,15 @@ check-manifest = "^0.39"
|
|||||||
commitizen = "^3.12.0"
|
commitizen = "^3.12.0"
|
||||||
coverage = "^7.3.2"
|
coverage = "^7.3.2"
|
||||||
isort = "^5.12.0"
|
isort = "^5.12.0"
|
||||||
|
mypy = "^1.6.1"
|
||||||
pep257 = "^0.7.0"
|
pep257 = "^0.7.0"
|
||||||
pylint = "^3.0.2"
|
pylint = "^3.0.2"
|
||||||
pytest = "^7.4.3"
|
pytest = "^7.4.3"
|
||||||
pytest-asyncio = "^0.21.1"
|
pytest-asyncio = "^0.21.1"
|
||||||
pytest-cov = "^4.1.0"
|
pytest-cov = "^4.1.0"
|
||||||
pytest-mock = "^3.12.0"
|
pytest-mock = "^3.12.0"
|
||||||
|
types-pyyaml = "^6.0.12.12"
|
||||||
|
types-simplejson = "^3.19.0.2"
|
||||||
|
|
||||||
[tool.poetry.group.docs.dependencies]
|
[tool.poetry.group.docs.dependencies]
|
||||||
Sphinx = "^2.1.1"
|
Sphinx = "^2.1.1"
|
||||||
|
@ -26,8 +26,11 @@ from base64 import b64encode
|
|||||||
from collections import OrderedDict, namedtuple
|
from collections import OrderedDict, namedtuple
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from hashlib import sha256
|
from hashlib import sha256
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
from nacl.signing import SigningKey, VerifyKey
|
||||||
from simplejson import dumps, loads
|
from simplejson import dumps, loads
|
||||||
|
from typing_extensions import Self
|
||||||
|
|
||||||
from ssb.util import tag
|
from ssb.util import tag
|
||||||
|
|
||||||
@ -38,7 +41,7 @@ class NoPrivateKeyException(Exception):
|
|||||||
"""Exception to raise when a private key is not available"""
|
"""Exception to raise when a private key is not available"""
|
||||||
|
|
||||||
|
|
||||||
def to_ordered(data):
|
def to_ordered(data: Dict[str, Any]) -> OrderedDict[str, Any]:
|
||||||
"""Convert a dictionary to an ``OrderedDict``"""
|
"""Convert a dictionary to an ``OrderedDict``"""
|
||||||
|
|
||||||
smsg = OrderedMsg(**data)
|
smsg = OrderedMsg(**data)
|
||||||
@ -46,7 +49,7 @@ def to_ordered(data):
|
|||||||
return OrderedDict((k, getattr(smsg, k)) for k in smsg._fields)
|
return OrderedDict((k, getattr(smsg, k)) for k in smsg._fields)
|
||||||
|
|
||||||
|
|
||||||
def get_millis_1970():
|
def get_millis_1970() -> int:
|
||||||
"""Get the UNIX timestamp in milliseconds"""
|
"""Get the UNIX timestamp in milliseconds"""
|
||||||
|
|
||||||
return int(datetime.utcnow().timestamp() * 1000)
|
return int(datetime.utcnow().timestamp() * 1000)
|
||||||
@ -55,16 +58,16 @@ def get_millis_1970():
|
|||||||
class Feed:
|
class Feed:
|
||||||
"""Base class for feeds"""
|
"""Base class for feeds"""
|
||||||
|
|
||||||
def __init__(self, public_key):
|
def __init__(self, public_key: VerifyKey) -> None:
|
||||||
self.public_key = public_key
|
self.public_key = public_key
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def id(self):
|
def id(self) -> str:
|
||||||
"""The identifier of the feed"""
|
"""The identifier of the feed"""
|
||||||
|
|
||||||
return tag(self.public_key).decode("ascii")
|
return tag(self.public_key).decode("ascii")
|
||||||
|
|
||||||
def sign(self, msg):
|
def sign(self, msg: bytes) -> bytes:
|
||||||
"""Sign a message"""
|
"""Sign a message"""
|
||||||
|
|
||||||
raise NoPrivateKeyException("Cannot use remote identity to sign (no private key!)")
|
raise NoPrivateKeyException("Cannot use remote identity to sign (no private key!)")
|
||||||
@ -73,16 +76,20 @@ class Feed:
|
|||||||
class LocalFeed(Feed):
|
class LocalFeed(Feed):
|
||||||
"""Class representing a local feed"""
|
"""Class representing a local feed"""
|
||||||
|
|
||||||
def __init__(self, private_key): # pylint: disable=super-init-not-called
|
def __init__(self, private_key: SigningKey) -> None: # pylint: disable=super-init-not-called
|
||||||
self.private_key = private_key
|
self.private_key = private_key
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def public_key(self):
|
def public_key(self) -> VerifyKey:
|
||||||
"""The public key of the feed"""
|
"""The public key of the feed"""
|
||||||
|
|
||||||
return self.private_key.verify_key
|
return self.private_key.verify_key
|
||||||
|
|
||||||
def sign(self, msg):
|
@public_key.setter
|
||||||
|
def public_key(self, key: VerifyKey) -> None:
|
||||||
|
raise TypeError("Can not set only the public key for a local feed")
|
||||||
|
|
||||||
|
def sign(self, msg: bytes) -> bytes:
|
||||||
"""Sign a message for this feed"""
|
"""Sign a message for this feed"""
|
||||||
|
|
||||||
return self.private_key.sign(msg).signature
|
return self.private_key.sign(msg).signature
|
||||||
@ -92,25 +99,34 @@ class Message:
|
|||||||
"""Base class for SSB messages"""
|
"""Base class for SSB messages"""
|
||||||
|
|
||||||
def __init__( # pylint: disable=too-many-arguments
|
def __init__( # pylint: disable=too-many-arguments
|
||||||
self, feed, content, signature=None, sequence=1, timestamp=None, previous=None
|
self,
|
||||||
|
feed: Feed,
|
||||||
|
content: Dict[str, Any],
|
||||||
|
signature: Optional[str] = None,
|
||||||
|
sequence: int = 1,
|
||||||
|
timestamp: Optional[int] = None,
|
||||||
|
previous: Optional["Message"] = None,
|
||||||
):
|
):
|
||||||
self.feed = feed
|
self.feed = feed
|
||||||
self.content = content
|
self.content = content
|
||||||
|
|
||||||
if signature is None:
|
|
||||||
raise ValueError("signature can't be None")
|
|
||||||
self.signature = signature
|
self.signature = signature
|
||||||
|
|
||||||
self.previous = previous
|
self.previous = previous
|
||||||
|
self.timestamp = get_millis_1970() if timestamp is None else timestamp
|
||||||
|
|
||||||
if self.previous:
|
if self.previous:
|
||||||
self.sequence = self.previous.sequence + 1
|
self.sequence: int = self.previous.sequence + 1
|
||||||
else:
|
else:
|
||||||
self.sequence = sequence
|
self.sequence = sequence
|
||||||
|
|
||||||
self.timestamp = get_millis_1970() if timestamp is None else timestamp
|
self._check_signature()
|
||||||
|
|
||||||
|
def _check_signature(self) -> None:
|
||||||
|
if self.signature is None:
|
||||||
|
raise ValueError("signature can't be None")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def parse(cls, data, feed):
|
def parse(cls, data: bytes, feed: Feed) -> Self:
|
||||||
"""Parse raw message data"""
|
"""Parse raw message data"""
|
||||||
|
|
||||||
obj = loads(data, object_pairs_hook=OrderedDict)
|
obj = loads(data, object_pairs_hook=OrderedDict)
|
||||||
@ -118,12 +134,12 @@ class Message:
|
|||||||
|
|
||||||
return msg
|
return msg
|
||||||
|
|
||||||
def serialize(self, add_signature=True):
|
def serialize(self, add_signature: bool = True) -> bytes:
|
||||||
"""Serialize the message"""
|
"""Serialize the message"""
|
||||||
|
|
||||||
return dumps(self.to_dict(add_signature=add_signature), indent=2).encode("utf-8")
|
return dumps(self.to_dict(add_signature=add_signature), indent=2).encode("utf-8")
|
||||||
|
|
||||||
def to_dict(self, add_signature=True):
|
def to_dict(self, add_signature: bool = True) -> OrderedDict[str, Any]:
|
||||||
"""Convert the message to a dictionary"""
|
"""Convert the message to a dictionary"""
|
||||||
|
|
||||||
obj = to_ordered(
|
obj = to_ordered(
|
||||||
@ -142,20 +158,21 @@ class Message:
|
|||||||
|
|
||||||
return obj
|
return obj
|
||||||
|
|
||||||
def verify(self, signature):
|
def verify(self, signature: str) -> bool:
|
||||||
"""Verify the signature of the message"""
|
"""Verify the signature of the message"""
|
||||||
|
|
||||||
return self.signature == signature
|
return self.signature == signature
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def hash(self):
|
def hash(self) -> str:
|
||||||
"""The cryptographic hash of the message"""
|
"""The cryptographic hash of the message"""
|
||||||
|
|
||||||
hash_ = sha256(self.serialize()).digest()
|
hash_ = sha256(self.serialize()).digest()
|
||||||
|
|
||||||
return b64encode(hash_).decode("ascii") + ".sha256"
|
return b64encode(hash_).decode("ascii") + ".sha256"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def key(self):
|
def key(self) -> str:
|
||||||
"""The key of the message"""
|
"""The key of the message"""
|
||||||
|
|
||||||
return "%" + self.hash
|
return "%" + self.hash
|
||||||
@ -165,25 +182,21 @@ class LocalMessage(Message):
|
|||||||
"""Class representing a local message"""
|
"""Class representing a local message"""
|
||||||
|
|
||||||
def __init__( # pylint: disable=too-many-arguments,super-init-not-called
|
def __init__( # pylint: disable=too-many-arguments,super-init-not-called
|
||||||
self, feed, content, signature=None, sequence=1, timestamp=None, previous=None
|
self,
|
||||||
|
feed: LocalFeed,
|
||||||
|
content: Dict[str, Any],
|
||||||
|
signature: Optional[str] = None,
|
||||||
|
sequence: int = 1,
|
||||||
|
timestamp: Optional[int] = None,
|
||||||
|
previous: Optional["LocalMessage"] = None,
|
||||||
):
|
):
|
||||||
self.feed = feed
|
super().__init__(feed, content, signature=signature, sequence=sequence, timestamp=timestamp, previous=previous)
|
||||||
self.content = content
|
|
||||||
|
|
||||||
self.previous = previous
|
def _check_signature(self) -> None:
|
||||||
if self.previous:
|
if self.signature is None:
|
||||||
self.sequence = self.previous.sequence + 1
|
|
||||||
else:
|
|
||||||
self.sequence = sequence
|
|
||||||
|
|
||||||
self.timestamp = get_millis_1970() if timestamp is None else timestamp
|
|
||||||
|
|
||||||
if signature is None:
|
|
||||||
self.signature = self._sign()
|
self.signature = self._sign()
|
||||||
else:
|
|
||||||
self.signature = signature
|
|
||||||
|
|
||||||
def _sign(self):
|
def _sign(self) -> str:
|
||||||
# ensure ordering of keys and indentation of 2 characters, like ssb-keys
|
# ensure ordering of keys and indentation of 2 characters, like ssb-keys
|
||||||
data = self.serialize(add_signature=False)
|
data = self.serialize(add_signature=False)
|
||||||
return (b64encode(bytes(self.feed.sign(data))) + b".sig.ed25519").decode("ascii")
|
return (b64encode(bytes(self.feed.sign(data))) + b".sig.ed25519").decode("ascii")
|
||||||
|
108
ssb/muxrpc.py
108
ssb/muxrpc.py
@ -22,7 +22,16 @@
|
|||||||
|
|
||||||
"""MuxRPC"""
|
"""MuxRPC"""
|
||||||
|
|
||||||
from ssb.packet_stream import PSMessageType
|
from typing import Any, AsyncIterator, Callable, Dict, List, Literal, Optional, Union
|
||||||
|
|
||||||
|
from typing_extensions import Self
|
||||||
|
|
||||||
|
from .packet_stream import PacketStream, PSMessage, PSMessageType, PSRequestHandler, PSStreamHandler
|
||||||
|
|
||||||
|
MuxRPCJSON = Dict[str, Any]
|
||||||
|
MuxRPCCallType = Literal["async", "duplex", "sink", "source", "sync"]
|
||||||
|
MuxRPCRequestHandlerType = Callable[[PacketStream, "MuxRPCRequest"], None]
|
||||||
|
MuxRPCRequestParam = Union[bytes, str, MuxRPCJSON] # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
|
||||||
class MuxRPCAPIException(Exception):
|
class MuxRPCAPIException(Exception):
|
||||||
@ -32,7 +41,7 @@ class MuxRPCAPIException(Exception):
|
|||||||
class MuxRPCHandler: # pylint: disable=too-few-public-methods
|
class MuxRPCHandler: # pylint: disable=too-few-public-methods
|
||||||
"""Base MuxRPC handler class"""
|
"""Base MuxRPC handler class"""
|
||||||
|
|
||||||
def check_message(self, msg):
|
def check_message(self, msg: PSMessage) -> None:
|
||||||
"""Check message validity"""
|
"""Check message validity"""
|
||||||
|
|
||||||
body = msg.body
|
body = msg.body
|
||||||
@ -40,34 +49,53 @@ class MuxRPCHandler: # pylint: disable=too-few-public-methods
|
|||||||
if isinstance(body, dict) and "name" in body and body["name"] == "Error":
|
if isinstance(body, dict) and "name" in body and body["name"] == "Error":
|
||||||
raise MuxRPCAPIException(body["message"])
|
raise MuxRPCAPIException(body["message"])
|
||||||
|
|
||||||
|
def __aiter__(self) -> AsyncIterator[Optional[PSMessage]]:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
class MuxRPCRequestHandler(MuxRPCHandler):
|
async def __anext__(self) -> Optional[PSMessage]:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def send(self, msg: Any, msg_type: PSMessageType = PSMessageType.JSON, end: bool = False) -> None:
|
||||||
|
"""Send a message through the stream"""
|
||||||
|
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
async def get_response(self) -> PSMessage:
|
||||||
|
"""Get the response for an RPC request"""
|
||||||
|
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
|
class MuxRPCRequestHandler(MuxRPCHandler): # pylint: disable=abstract-method
|
||||||
"""MuxRPC handler for incoming RPC requests"""
|
"""MuxRPC handler for incoming RPC requests"""
|
||||||
|
|
||||||
def __init__(self, ps_handler):
|
def __init__(self, ps_handler: PSRequestHandler):
|
||||||
self.ps_handler = ps_handler
|
self.ps_handler = ps_handler
|
||||||
|
|
||||||
async def get_response(self):
|
async def get_response(self) -> PSMessage:
|
||||||
"""Get the response data"""
|
"""Get the response data"""
|
||||||
|
|
||||||
msg = await self.ps_handler
|
msg = await self.ps_handler.__anext__()
|
||||||
|
|
||||||
self.check_message(msg)
|
self.check_message(msg)
|
||||||
|
|
||||||
return msg
|
return msg
|
||||||
|
|
||||||
|
|
||||||
class MuxRPCSourceHandler(MuxRPCHandler):
|
class MuxRPCSourceHandler(MuxRPCHandler): # pylint: disable=abstract-method
|
||||||
"""MuxRPC handler for source-type RPC requests"""
|
"""MuxRPC handler for source-type RPC requests"""
|
||||||
|
|
||||||
def __init__(self, ps_handler):
|
def __init__(self, ps_handler: PSStreamHandler):
|
||||||
self.ps_handler = ps_handler
|
self.ps_handler = ps_handler
|
||||||
|
|
||||||
def __aiter__(self):
|
def __aiter__(self) -> AsyncIterator[Optional[PSMessage]]:
|
||||||
return self
|
return self
|
||||||
|
|
||||||
async def __anext__(self):
|
async def __anext__(self) -> Optional[PSMessage]:
|
||||||
msg = await self.ps_handler.__anext__()
|
msg = await self.ps_handler.__anext__()
|
||||||
|
|
||||||
|
assert msg
|
||||||
|
|
||||||
self.check_message(msg)
|
self.check_message(msg)
|
||||||
|
|
||||||
return msg
|
return msg
|
||||||
@ -76,64 +104,74 @@ class MuxRPCSourceHandler(MuxRPCHandler):
|
|||||||
class MuxRPCSinkHandlerMixin: # pylint: disable=too-few-public-methods
|
class MuxRPCSinkHandlerMixin: # pylint: disable=too-few-public-methods
|
||||||
"""Mixin for sink-type MuxRPC handlers"""
|
"""Mixin for sink-type MuxRPC handlers"""
|
||||||
|
|
||||||
def send(self, msg, msg_type=PSMessageType.JSON, end=False):
|
connection: PacketStream
|
||||||
|
req: int
|
||||||
|
|
||||||
|
def send(self, msg: Any, msg_type: PSMessageType = PSMessageType.JSON, end: bool = False) -> None:
|
||||||
"""Send a message through the stream"""
|
"""Send a message through the stream"""
|
||||||
|
|
||||||
self.connection.send(msg, stream=True, msg_type=msg_type, req=self.req, end_err=end)
|
self.connection.send(msg, stream=True, msg_type=msg_type, req=self.req, end_err=end)
|
||||||
|
|
||||||
|
|
||||||
class MuxRPCDuplexHandler(MuxRPCSinkHandlerMixin, MuxRPCSourceHandler):
|
class MuxRPCDuplexHandler(MuxRPCSinkHandlerMixin, MuxRPCSourceHandler): # pylint: disable=abstract-method
|
||||||
"""MuxRPC handler for duplex streams"""
|
"""MuxRPC handler for duplex streams"""
|
||||||
|
|
||||||
def __init__(self, ps_handler, connection, req):
|
def __init__(self, ps_handler: PSStreamHandler, connection: PacketStream, req: int):
|
||||||
super().__init__(ps_handler)
|
super().__init__(ps_handler)
|
||||||
|
|
||||||
self.connection = connection
|
self.connection = connection
|
||||||
self.req = req
|
self.req = req
|
||||||
|
|
||||||
|
|
||||||
class MuxRPCSinkHandler(MuxRPCHandler, MuxRPCSinkHandlerMixin):
|
class MuxRPCSinkHandler(MuxRPCHandler, MuxRPCSinkHandlerMixin): # pylint: disable=abstract-method
|
||||||
"""MuxRPC handler for sinks"""
|
"""MuxRPC handler for sinks"""
|
||||||
|
|
||||||
def __init__(self, connection, req):
|
def __init__(self, connection: PacketStream, req: int):
|
||||||
self.connection = connection
|
self.connection = connection
|
||||||
self.req = req
|
self.req = req
|
||||||
|
|
||||||
|
|
||||||
def _get_appropriate_api_handler(type_, connection, ps_handler, req):
|
def _get_appropriate_api_handler(
|
||||||
|
type_: MuxRPCCallType, connection: PacketStream, ps_handler: Union[PSRequestHandler, PSStreamHandler], req: int
|
||||||
|
) -> MuxRPCHandler:
|
||||||
"""Find the appropriate MuxRPC handler"""
|
"""Find the appropriate MuxRPC handler"""
|
||||||
|
|
||||||
if type_ in {"sync", "async"}:
|
if type_ in {"sync", "async"}:
|
||||||
|
assert isinstance(ps_handler, PSRequestHandler)
|
||||||
return MuxRPCRequestHandler(ps_handler)
|
return MuxRPCRequestHandler(ps_handler)
|
||||||
|
|
||||||
if type_ == "source":
|
if type_ == "source":
|
||||||
|
assert isinstance(ps_handler, PSStreamHandler)
|
||||||
return MuxRPCSourceHandler(ps_handler)
|
return MuxRPCSourceHandler(ps_handler)
|
||||||
|
|
||||||
if type_ == "sink":
|
if type_ == "sink":
|
||||||
return MuxRPCSinkHandler(connection, req)
|
return MuxRPCSinkHandler(connection, req)
|
||||||
|
|
||||||
if type_ == "duplex":
|
if type_ == "duplex":
|
||||||
|
assert isinstance(ps_handler, PSStreamHandler)
|
||||||
return MuxRPCDuplexHandler(ps_handler, connection, req)
|
return MuxRPCDuplexHandler(ps_handler, connection, req)
|
||||||
|
|
||||||
return None
|
raise TypeError(f"Unknown request type {type_}")
|
||||||
|
|
||||||
|
|
||||||
class MuxRPCRequest:
|
class MuxRPCRequest:
|
||||||
"""MuxRPC request"""
|
"""MuxRPC request"""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_message(cls, message):
|
def from_message(cls, message: PSMessage) -> Self:
|
||||||
"""Initialise a request from a raw packet stream message"""
|
"""Initialise a request from a raw packet stream message"""
|
||||||
|
|
||||||
body = message.body
|
body = message.body
|
||||||
|
|
||||||
|
assert isinstance(body, dict)
|
||||||
|
|
||||||
return cls(".".join(body["name"]), body["args"])
|
return cls(".".join(body["name"]), body["args"])
|
||||||
|
|
||||||
def __init__(self, name, args):
|
def __init__(self, name: str, args: List[MuxRPCRequestParam]):
|
||||||
self.name = name
|
self.name = name
|
||||||
self.args = args
|
self.args = args
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self) -> str:
|
||||||
return f"<MuxRPCRequest {self.name} {self.args}>"
|
return f"<MuxRPCRequest {self.name} {self.args}>"
|
||||||
|
|
||||||
|
|
||||||
@ -141,28 +179,30 @@ class MuxRPCMessage:
|
|||||||
"""MuxRPC message"""
|
"""MuxRPC message"""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_message(cls, message):
|
def from_message(cls, message: PSMessage) -> Self:
|
||||||
"""Initialise a MuxRPC message from a raw packet stream message"""
|
"""Initialise a MuxRPC message from a raw packet stream message"""
|
||||||
|
|
||||||
return cls(message.body)
|
return cls(message.body)
|
||||||
|
|
||||||
def __init__(self, body):
|
def __init__(self, body: Union[bytes, str, Dict[str, Any]]):
|
||||||
self.body = body
|
self.body = body
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self) -> str:
|
||||||
return f"<MuxRPCMessage {self.body}>"
|
return f"<MuxRPCMessage {self.body!r}>"
|
||||||
|
|
||||||
|
|
||||||
class MuxRPCAPI:
|
class MuxRPCAPI:
|
||||||
"""Generic MuxRPC API"""
|
"""Generic MuxRPC API"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self) -> None:
|
||||||
self.handlers = {}
|
self.handlers: Dict[str, MuxRPCRequestHandlerType] = {}
|
||||||
self.connection = None
|
self.connection: Optional[PacketStream] = None
|
||||||
|
|
||||||
async def process_messages(self):
|
async def process_messages(self) -> None:
|
||||||
"""Continuously process incoming messages"""
|
"""Continuously process incoming messages"""
|
||||||
|
|
||||||
|
assert self.connection
|
||||||
|
|
||||||
async for req_message in self.connection:
|
async for req_message in self.connection:
|
||||||
if req_message is None:
|
if req_message is None:
|
||||||
return
|
return
|
||||||
@ -172,22 +212,22 @@ class MuxRPCAPI:
|
|||||||
if isinstance(body, dict) and body.get("name"):
|
if isinstance(body, dict) and body.get("name"):
|
||||||
self.process(self.connection, MuxRPCRequest.from_message(req_message))
|
self.process(self.connection, MuxRPCRequest.from_message(req_message))
|
||||||
|
|
||||||
def add_connection(self, connection):
|
def add_connection(self, connection: PacketStream) -> None:
|
||||||
"""Set the packet stream connection of this RPC API"""
|
"""Set the packet stream connection of this RPC API"""
|
||||||
|
|
||||||
self.connection = connection
|
self.connection = connection
|
||||||
|
|
||||||
def define(self, name):
|
def define(self, name: str) -> Callable[[MuxRPCRequestHandlerType], MuxRPCRequestHandlerType]:
|
||||||
"""Decorator to define an RPC method handler"""
|
"""Decorator to define an RPC method handler"""
|
||||||
|
|
||||||
def _handle(f):
|
def _handle(f: MuxRPCRequestHandlerType) -> MuxRPCRequestHandlerType:
|
||||||
self.handlers[name] = f
|
self.handlers[name] = f
|
||||||
|
|
||||||
return f
|
return f
|
||||||
|
|
||||||
return _handle
|
return _handle
|
||||||
|
|
||||||
def process(self, connection, request):
|
def process(self, connection: PacketStream, request: MuxRPCRequest) -> None:
|
||||||
"""Process an incoming request"""
|
"""Process an incoming request"""
|
||||||
|
|
||||||
handler = self.handlers.get(request.name)
|
handler = self.handlers.get(request.name)
|
||||||
@ -197,9 +237,11 @@ class MuxRPCAPI:
|
|||||||
|
|
||||||
handler(connection, request)
|
handler(connection, request)
|
||||||
|
|
||||||
def call(self, name, args, type_="sync"):
|
def call(self, name: str, args: List[MuxRPCRequestParam], type_: MuxRPCCallType = "sync") -> MuxRPCHandler:
|
||||||
"""Call an RPC method"""
|
"""Call an RPC method"""
|
||||||
|
|
||||||
|
assert self.connection
|
||||||
|
|
||||||
if not self.connection.is_connected:
|
if not self.connection.is_connected:
|
||||||
raise Exception("not connected") # pylint: disable=broad-exception-raised
|
raise Exception("not connected") # pylint: disable=broad-exception-raised
|
||||||
|
|
||||||
|
@ -28,9 +28,13 @@ import logging
|
|||||||
from math import ceil
|
from math import ceil
|
||||||
import struct
|
import struct
|
||||||
from time import time
|
from time import time
|
||||||
|
from typing import Any, AsyncIterator, Dict, Optional, Tuple, Union
|
||||||
|
|
||||||
|
from secret_handshake.network import SHSDuplexStream
|
||||||
import simplejson
|
import simplejson
|
||||||
|
from typing_extensions import Self
|
||||||
|
|
||||||
|
PSMessageData = Union[bytes, bool, Dict[str, Any], str]
|
||||||
logger = logging.getLogger("packet_stream")
|
logger = logging.getLogger("packet_stream")
|
||||||
|
|
||||||
|
|
||||||
@ -45,25 +49,27 @@ class PSMessageType(Enum):
|
|||||||
class PSStreamHandler:
|
class PSStreamHandler:
|
||||||
"""Packet stream handler"""
|
"""Packet stream handler"""
|
||||||
|
|
||||||
def __init__(self, req):
|
def __init__(self, req: int):
|
||||||
super(PSStreamHandler).__init__()
|
super().__init__()
|
||||||
self.req = req
|
self.req = req
|
||||||
self.queue = Queue()
|
self.queue: Queue["PSMessage"] = Queue()
|
||||||
|
|
||||||
async def process(self, msg):
|
async def process(self, msg: "PSMessage") -> None:
|
||||||
"""Process a pending message"""
|
"""Process a pending message"""
|
||||||
|
|
||||||
await self.queue.put(msg)
|
await self.queue.put(msg)
|
||||||
|
|
||||||
async def stop(self):
|
async def stop(self) -> None:
|
||||||
"""Stop a pending request"""
|
"""Stop a pending request"""
|
||||||
|
|
||||||
await self.queue.put(None)
|
# We use the None value internally to signal __anext__ that the stream can be closed. It is not used otherwise,
|
||||||
|
# hence the typing ignore
|
||||||
|
await self.queue.put(None) # type: ignore[arg-type]
|
||||||
|
|
||||||
def __aiter__(self):
|
def __aiter__(self) -> AsyncIterator[Optional["PSMessage"]]:
|
||||||
return self
|
return self
|
||||||
|
|
||||||
async def __anext__(self):
|
async def __anext__(self) -> Optional["PSMessage"]:
|
||||||
elem = await self.queue.get()
|
elem = await self.queue.get()
|
||||||
|
|
||||||
if not elem:
|
if not elem:
|
||||||
@ -75,30 +81,32 @@ class PSStreamHandler:
|
|||||||
class PSRequestHandler:
|
class PSRequestHandler:
|
||||||
"""Packet stream request handler"""
|
"""Packet stream request handler"""
|
||||||
|
|
||||||
def __init__(self, req):
|
def __init__(self, req: int):
|
||||||
self.req = req
|
self.req = req
|
||||||
self.event = Event()
|
self.event = Event()
|
||||||
self._msg = None
|
self._msg: Optional["PSMessage"] = None
|
||||||
|
|
||||||
async def process(self, msg):
|
async def process(self, msg: "PSMessage") -> None:
|
||||||
"""Process a message request"""
|
"""Process a message request"""
|
||||||
|
|
||||||
self._msg = msg
|
self._msg = msg
|
||||||
self.event.set()
|
self.event.set()
|
||||||
|
|
||||||
async def stop(self):
|
async def stop(self) -> None:
|
||||||
"""Stop a pending event request"""
|
"""Stop a pending event request"""
|
||||||
|
|
||||||
if not self.event.is_set():
|
if not self.event.is_set():
|
||||||
self.event.set()
|
self.event.set()
|
||||||
|
|
||||||
def __aiter__(self):
|
def __aiter__(self) -> AsyncIterator["PSMessage"]:
|
||||||
return self
|
return self
|
||||||
|
|
||||||
async def __anext__(self):
|
async def __anext__(self) -> "PSMessage":
|
||||||
# wait until 'process' is called
|
# wait until 'process' is called
|
||||||
await self.event.wait()
|
await self.event.wait()
|
||||||
|
|
||||||
|
assert self._msg
|
||||||
|
|
||||||
return self._msg
|
return self._msg
|
||||||
|
|
||||||
|
|
||||||
@ -106,42 +114,55 @@ class PSMessage:
|
|||||||
"""Packet Stream message"""
|
"""Packet Stream message"""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_header_body(cls, flags, req, body):
|
def from_header_body(cls, flags: int, req: int, body: bytes) -> Self:
|
||||||
"""Parse a raw message"""
|
"""Parse a raw message"""
|
||||||
|
|
||||||
type_ = PSMessageType(flags & 0x03)
|
type_ = PSMessageType(flags & 0x03)
|
||||||
|
|
||||||
if type_ == PSMessageType.TEXT:
|
if type_ == PSMessageType.TEXT:
|
||||||
body = body.decode("utf-8")
|
decoded_body: Union[str, Dict[str, Any], bytes] = body.decode("utf-8")
|
||||||
elif type_ == PSMessageType.JSON:
|
elif type_ == PSMessageType.JSON:
|
||||||
body = simplejson.loads(body)
|
decoded_body = simplejson.loads(body)
|
||||||
|
else:
|
||||||
|
decoded_body = body
|
||||||
|
|
||||||
return cls(type_, body, bool(flags & 0x08), bool(flags & 0x04), req=req)
|
return cls(type_, decoded_body, bool(flags & 0x08), bool(flags & 0x04), req=req)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def data(self):
|
def data(self) -> bytes:
|
||||||
"""The raw message data"""
|
"""The raw message data"""
|
||||||
|
|
||||||
if self.type == PSMessageType.TEXT:
|
if self.type == PSMessageType.TEXT:
|
||||||
|
assert isinstance(self.body, str)
|
||||||
return self.body.encode("utf-8")
|
return self.body.encode("utf-8")
|
||||||
|
|
||||||
if self.type == PSMessageType.JSON:
|
if self.type == PSMessageType.JSON:
|
||||||
|
assert isinstance(self.body, dict)
|
||||||
return simplejson.dumps(self.body).encode("utf-8")
|
return simplejson.dumps(self.body).encode("utf-8")
|
||||||
|
|
||||||
|
assert isinstance(self.body, bytes)
|
||||||
|
|
||||||
return self.body
|
return self.body
|
||||||
|
|
||||||
def __init__(self, type_, body, stream, end_err, req=None): # pylint: disable=too-many-arguments
|
def __init__(
|
||||||
|
self,
|
||||||
|
type_: PSMessageType,
|
||||||
|
body: Union[bytes, str, Dict[str, Any]],
|
||||||
|
stream: bool,
|
||||||
|
end_err: bool,
|
||||||
|
req: Optional[int] = None,
|
||||||
|
): # pylint: disable=too-many-arguments
|
||||||
self.stream = stream
|
self.stream = stream
|
||||||
self.end_err = end_err
|
self.end_err = end_err
|
||||||
self.type = type_
|
self.type = type_
|
||||||
self.body = body
|
self.body = body
|
||||||
self.req = req
|
self.req = req
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self) -> str:
|
||||||
if self.type == PSMessageType.BUFFER:
|
if self.type == PSMessageType.BUFFER:
|
||||||
body = f"{len(self.body)} bytes"
|
body = f"{len(self.body)} bytes"
|
||||||
else:
|
else:
|
||||||
body = self.body
|
body = str(self.body)
|
||||||
|
|
||||||
req = "" if self.req is None else f" [{self.req}]"
|
req = "" if self.req is None else f" [{self.req}]"
|
||||||
is_stream = "~" if self.stream else ""
|
is_stream = "~" if self.stream else ""
|
||||||
@ -153,79 +174,90 @@ class PSMessage:
|
|||||||
class PacketStream:
|
class PacketStream:
|
||||||
"""SSB Packet stream"""
|
"""SSB Packet stream"""
|
||||||
|
|
||||||
def __init__(self, connection):
|
def __init__(self, connection: SHSDuplexStream):
|
||||||
self.connection = connection
|
self.connection = connection
|
||||||
self.req_counter = 1
|
self.req_counter = 1
|
||||||
self._event_map = {}
|
self._event_map: Dict[int, Tuple[float, Union[PSRequestHandler, PSStreamHandler]]] = {}
|
||||||
self._connected = False
|
self._connected = False
|
||||||
|
|
||||||
def register_handler(self, handler):
|
def register_handler(self, handler: Union[PSRequestHandler, PSStreamHandler]) -> None:
|
||||||
"""Register an RPC handler"""
|
"""Register an RPC handler"""
|
||||||
|
|
||||||
self._event_map[handler.req] = (time(), handler)
|
self._event_map[handler.req] = (time(), handler)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_connected(self):
|
def is_connected(self) -> bool:
|
||||||
"""Check if the stream is connected"""
|
"""Check if the stream is connected"""
|
||||||
|
|
||||||
return self.connection.is_connected
|
return self.connection.is_connected
|
||||||
|
|
||||||
def __aiter__(self):
|
def __aiter__(self) -> AsyncIterator[Optional[PSMessage]]:
|
||||||
return self
|
return self
|
||||||
|
|
||||||
async def __anext__(self):
|
async def __anext__(self) -> PSMessage:
|
||||||
while True:
|
while True:
|
||||||
msg = await self.read()
|
msg = await self.read()
|
||||||
|
|
||||||
if not msg:
|
if not msg:
|
||||||
raise StopAsyncIteration()
|
raise StopAsyncIteration()
|
||||||
|
|
||||||
if msg.req >= 0:
|
if msg.req is not None and msg.req >= 0:
|
||||||
logger.info("RECV: %r", msg)
|
logger.info("RECV: %r", msg)
|
||||||
|
|
||||||
return msg
|
return msg
|
||||||
|
|
||||||
async def _read(self):
|
async def _read(self) -> Optional[PSMessage]:
|
||||||
try:
|
try:
|
||||||
header = await self.connection.read()
|
header = await self.connection.read()
|
||||||
|
|
||||||
if not header or header == b"\x00" * 9:
|
if not header or header == b"\x00" * 9:
|
||||||
return
|
return None
|
||||||
|
|
||||||
flags, length, req = struct.unpack(">BIi", header)
|
flags, length, req = struct.unpack(">BIi", header)
|
||||||
|
|
||||||
n_packets = ceil(length / 4096)
|
n_packets = ceil(length / 4096)
|
||||||
|
|
||||||
body = b""
|
body = b""
|
||||||
|
|
||||||
for _ in range(n_packets):
|
for _ in range(n_packets):
|
||||||
body += await self.connection.read()
|
read_data = await self.connection.read()
|
||||||
|
|
||||||
|
if not read_data:
|
||||||
|
logger.debug("DISCONNECT")
|
||||||
|
self.connection.disconnect()
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
body += read_data
|
||||||
|
|
||||||
logger.debug("READ %s %s", header, len(body))
|
logger.debug("READ %s %s", header, len(body))
|
||||||
|
|
||||||
return PSMessage.from_header_body(flags, req, body)
|
return PSMessage.from_header_body(flags, req, body)
|
||||||
except StopAsyncIteration:
|
except StopAsyncIteration:
|
||||||
logger.debug("DISCONNECT")
|
logger.debug("DISCONNECT")
|
||||||
self.connection.disconnect()
|
self.connection.disconnect()
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def read(self):
|
async def read(self) -> Optional[PSMessage]:
|
||||||
"""Read data from the packet stream"""
|
"""Read data from the packet stream"""
|
||||||
|
|
||||||
msg = await self._read()
|
msg = await self._read()
|
||||||
|
|
||||||
if not msg:
|
if not msg:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# check whether it's a reply and handle accordingly
|
# check whether it's a reply and handle accordingly
|
||||||
if msg.req < 0:
|
if msg.req is not None and msg.req < 0:
|
||||||
_, handler = self._event_map[-msg.req]
|
_, handler = self._event_map[-msg.req]
|
||||||
await handler.process(msg)
|
await handler.process(msg)
|
||||||
logger.info("RESPONSE [%d]: %r", -msg.req, msg)
|
logger.info("RESPONSE [%d]: %r", -msg.req, msg)
|
||||||
|
|
||||||
if msg.end_err:
|
if msg.end_err:
|
||||||
await handler.stop()
|
await handler.stop()
|
||||||
del self._event_map[-msg.req]
|
del self._event_map[-msg.req]
|
||||||
logger.info("RESPONSE [%d]: EOS", -msg.req)
|
logger.info("RESPONSE [%d]: EOS", -msg.req)
|
||||||
|
|
||||||
return msg
|
return msg
|
||||||
|
|
||||||
def _write(self, msg):
|
def _write(self, msg: PSMessage) -> None:
|
||||||
logger.info("SEND [%d]: %r", msg.req, msg)
|
logger.info("SEND [%d]: %r", msg.req, msg)
|
||||||
header = struct.pack(
|
header = struct.pack(
|
||||||
">BIi",
|
">BIi",
|
||||||
@ -239,11 +271,17 @@ class PacketStream:
|
|||||||
logger.debug("WRITE DATA: %s", msg.data)
|
logger.debug("WRITE DATA: %s", msg.data)
|
||||||
|
|
||||||
def send( # pylint: disable=too-many-arguments
|
def send( # pylint: disable=too-many-arguments
|
||||||
self, data, msg_type=PSMessageType.JSON, stream=False, end_err=False, req=None
|
self,
|
||||||
):
|
data: Union[bytes, str, Dict[str, Any]],
|
||||||
|
msg_type: PSMessageType = PSMessageType.JSON,
|
||||||
|
stream: bool = False,
|
||||||
|
end_err: bool = False,
|
||||||
|
req: Optional[int] = None,
|
||||||
|
) -> Union[PSRequestHandler, PSStreamHandler]:
|
||||||
"""Send data through the packet stream"""
|
"""Send data through the packet stream"""
|
||||||
|
|
||||||
update_counter = False
|
update_counter = False
|
||||||
|
|
||||||
if req is None:
|
if req is None:
|
||||||
update_counter = True
|
update_counter = True
|
||||||
req = self.req_counter
|
req = self.req_counter
|
||||||
@ -254,16 +292,18 @@ class PacketStream:
|
|||||||
self._write(msg)
|
self._write(msg)
|
||||||
|
|
||||||
if stream:
|
if stream:
|
||||||
handler = PSStreamHandler(self.req_counter)
|
handler: Union[PSRequestHandler, PSStreamHandler] = PSStreamHandler(self.req_counter)
|
||||||
else:
|
else:
|
||||||
handler = PSRequestHandler(self.req_counter)
|
handler = PSRequestHandler(self.req_counter)
|
||||||
|
|
||||||
self.register_handler(handler)
|
self.register_handler(handler)
|
||||||
|
|
||||||
if update_counter:
|
if update_counter:
|
||||||
self.req_counter += 1
|
self.req_counter += 1
|
||||||
|
|
||||||
return handler
|
return handler
|
||||||
|
|
||||||
def disconnect(self):
|
def disconnect(self) -> None:
|
||||||
"""Disconnect the stream"""
|
"""Disconnect the stream"""
|
||||||
|
|
||||||
self._connected = False
|
self._connected = False
|
||||||
|
0
ssb/py.typed
Normal file
0
ssb/py.typed
Normal file
15
ssb/util.py
15
ssb/util.py
@ -24,23 +24,30 @@
|
|||||||
|
|
||||||
from base64 import b64decode, b64encode
|
from base64 import b64decode, b64encode
|
||||||
import os
|
import os
|
||||||
from typing import Optional
|
from typing import Optional, TypedDict
|
||||||
|
|
||||||
from nacl.signing import SigningKey
|
from nacl.signing import SigningKey, VerifyKey
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
|
|
||||||
|
class SSBSecret(TypedDict):
|
||||||
|
"""Dictionary to hold an SSB identity"""
|
||||||
|
|
||||||
|
keypair: SigningKey
|
||||||
|
id: str
|
||||||
|
|
||||||
|
|
||||||
class ConfigException(Exception):
|
class ConfigException(Exception):
|
||||||
"""Exception to raise if there is a problem with the configuration data"""
|
"""Exception to raise if there is a problem with the configuration data"""
|
||||||
|
|
||||||
|
|
||||||
def tag(key):
|
def tag(key: VerifyKey) -> bytes:
|
||||||
"""Create tag from public key"""
|
"""Create tag from public key"""
|
||||||
|
|
||||||
return b"@" + b64encode(bytes(key)) + b".ed25519"
|
return b"@" + b64encode(bytes(key)) + b".ed25519"
|
||||||
|
|
||||||
|
|
||||||
def load_ssb_secret(filename: Optional[str] = None):
|
def load_ssb_secret(filename: Optional[str] = None) -> SSBSecret:
|
||||||
"""Load SSB keys from ``filename`` or, if unset, from ``~/.ssb/secret``"""
|
"""Load SSB keys from ``filename`` or, if unset, from ``~/.ssb/secret``"""
|
||||||
|
|
||||||
filename = filename or os.path.expanduser("~/.ssb/secret")
|
filename = filename or os.path.expanduser("~/.ssb/secret")
|
||||||
|
@ -50,7 +50,7 @@ SERIALIZED_M1 = b"""{
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture()
|
@pytest.fixture()
|
||||||
def local_feed():
|
def local_feed() -> LocalFeed:
|
||||||
"""Fixture providing a local feed"""
|
"""Fixture providing a local feed"""
|
||||||
|
|
||||||
secret = b64decode("Mz2qkNOP2K6upnqibWrR+z8pVUI1ReA1MLc7QMtF2qQ=")
|
secret = b64decode("Mz2qkNOP2K6upnqibWrR+z8pVUI1ReA1MLc7QMtF2qQ=")
|
||||||
@ -58,14 +58,14 @@ def local_feed():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture()
|
@pytest.fixture()
|
||||||
def remote_feed():
|
def remote_feed() -> Feed:
|
||||||
"""Fixture providing a remote feed"""
|
"""Fixture providing a remote feed"""
|
||||||
|
|
||||||
public = b64decode("I/4cyN/jPBbDsikbHzAEvmaYlaJK33lW3UhWjNXjyrU=")
|
public = b64decode("I/4cyN/jPBbDsikbHzAEvmaYlaJK33lW3UhWjNXjyrU=")
|
||||||
return Feed(VerifyKey(public))
|
return Feed(VerifyKey(public))
|
||||||
|
|
||||||
|
|
||||||
def test_local_feed():
|
def test_local_feed() -> None:
|
||||||
"""Test a local feed"""
|
"""Test a local feed"""
|
||||||
|
|
||||||
secret = b64decode("Mz2qkNOP2K6upnqibWrR+z8pVUI1ReA1MLc7QMtF2qQ=")
|
secret = b64decode("Mz2qkNOP2K6upnqibWrR+z8pVUI1ReA1MLc7QMtF2qQ=")
|
||||||
@ -75,7 +75,7 @@ def test_local_feed():
|
|||||||
assert feed.id == "@I/4cyN/jPBbDsikbHzAEvmaYlaJK33lW3UhWjNXjyrU=.ed25519"
|
assert feed.id == "@I/4cyN/jPBbDsikbHzAEvmaYlaJK33lW3UhWjNXjyrU=.ed25519"
|
||||||
|
|
||||||
|
|
||||||
def test_remote_feed():
|
def test_remote_feed() -> None:
|
||||||
"""Test a remote feed"""
|
"""Test a remote feed"""
|
||||||
|
|
||||||
public = b64decode("I/4cyN/jPBbDsikbHzAEvmaYlaJK33lW3UhWjNXjyrU=")
|
public = b64decode("I/4cyN/jPBbDsikbHzAEvmaYlaJK33lW3UhWjNXjyrU=")
|
||||||
@ -85,35 +85,21 @@ def test_remote_feed():
|
|||||||
|
|
||||||
m1 = Message(
|
m1 = Message(
|
||||||
feed,
|
feed,
|
||||||
OrderedDict(
|
OrderedDict([("type", "about"), ("about", feed.id), ("name", "neo"), ("description", "The Chosen One")]),
|
||||||
[
|
|
||||||
("type", "about"),
|
|
||||||
("about", feed.id),
|
|
||||||
("name", "neo"),
|
|
||||||
("description", "The Chosen One"),
|
|
||||||
]
|
|
||||||
),
|
|
||||||
"foo",
|
"foo",
|
||||||
timestamp=1495706260190,
|
timestamp=1495706260190,
|
||||||
)
|
)
|
||||||
|
|
||||||
with pytest.raises(NoPrivateKeyException):
|
with pytest.raises(NoPrivateKeyException):
|
||||||
feed.sign(m1)
|
feed.sign(m1.serialize())
|
||||||
|
|
||||||
|
|
||||||
def test_local_message(local_feed): # pylint: disable=redefined-outer-name
|
def test_local_message(local_feed: LocalFeed) -> None: # pylint: disable=redefined-outer-name
|
||||||
"""Test a local message"""
|
"""Test a local message"""
|
||||||
|
|
||||||
m1 = LocalMessage(
|
m1 = LocalMessage(
|
||||||
local_feed,
|
local_feed,
|
||||||
OrderedDict(
|
OrderedDict([("type", "about"), ("about", local_feed.id), ("name", "neo"), ("description", "The Chosen One")]),
|
||||||
[
|
|
||||||
("type", "about"),
|
|
||||||
("about", local_feed.id),
|
|
||||||
("name", "neo"),
|
|
||||||
("description", "The Chosen One"),
|
|
||||||
]
|
|
||||||
),
|
|
||||||
timestamp=1495706260190,
|
timestamp=1495706260190,
|
||||||
)
|
)
|
||||||
assert m1.timestamp == 1495706260190
|
assert m1.timestamp == 1495706260190
|
||||||
@ -148,20 +134,13 @@ def test_local_message(local_feed): # pylint: disable=redefined-outer-name
|
|||||||
assert m2.key == "%nx13uks5GUwuKJC49PfYGMS/1pgGTtwwdWT7kbVaroM=.sha256"
|
assert m2.key == "%nx13uks5GUwuKJC49PfYGMS/1pgGTtwwdWT7kbVaroM=.sha256"
|
||||||
|
|
||||||
|
|
||||||
def test_remote_message(remote_feed): # pylint: disable=redefined-outer-name
|
def test_remote_message(remote_feed: Feed) -> None: # pylint: disable=redefined-outer-name
|
||||||
"""Test a remote message"""
|
"""Test a remote message"""
|
||||||
|
|
||||||
signature = "lPsQ9P10OgeyH6u0unFgiI2wV/RQ7Q2x2ebxnXYCzsJ055TBMXphRADTKhOMS2EkUxXQ9k3amj5fnWPudGxwBQ==.sig.ed25519"
|
signature = "lPsQ9P10OgeyH6u0unFgiI2wV/RQ7Q2x2ebxnXYCzsJ055TBMXphRADTKhOMS2EkUxXQ9k3amj5fnWPudGxwBQ==.sig.ed25519"
|
||||||
m1 = Message(
|
m1 = Message(
|
||||||
remote_feed,
|
remote_feed,
|
||||||
OrderedDict(
|
OrderedDict([("type", "about"), ("about", remote_feed.id), ("name", "neo"), ("description", "The Chosen One")]),
|
||||||
[
|
|
||||||
("type", "about"),
|
|
||||||
("about", remote_feed.id),
|
|
||||||
("name", "neo"),
|
|
||||||
("description", "The Chosen One"),
|
|
||||||
]
|
|
||||||
),
|
|
||||||
signature,
|
signature,
|
||||||
timestamp=1495706260190,
|
timestamp=1495706260190,
|
||||||
)
|
)
|
||||||
@ -175,12 +154,7 @@ def test_remote_message(remote_feed): # pylint: disable=redefined-outer-name
|
|||||||
m2 = Message(
|
m2 = Message(
|
||||||
remote_feed,
|
remote_feed,
|
||||||
OrderedDict(
|
OrderedDict(
|
||||||
[
|
[("type", "about"), ("about", remote_feed.id), ("name", "morpheus"), ("description", "Dude with big jaw")]
|
||||||
("type", "about"),
|
|
||||||
("about", remote_feed.id),
|
|
||||||
("name", "morpheus"),
|
|
||||||
("description", "Dude with big jaw"),
|
|
||||||
]
|
|
||||||
),
|
),
|
||||||
signature,
|
signature,
|
||||||
previous=m1,
|
previous=m1,
|
||||||
@ -194,54 +168,37 @@ def test_remote_message(remote_feed): # pylint: disable=redefined-outer-name
|
|||||||
assert m2.key == "%nx13uks5GUwuKJC49PfYGMS/1pgGTtwwdWT7kbVaroM=.sha256"
|
assert m2.key == "%nx13uks5GUwuKJC49PfYGMS/1pgGTtwwdWT7kbVaroM=.sha256"
|
||||||
|
|
||||||
|
|
||||||
def test_remote_no_signature(remote_feed): # pylint: disable=redefined-outer-name
|
def test_remote_no_signature(remote_feed: Feed) -> None: # pylint: disable=redefined-outer-name
|
||||||
"""Test remote feed without a signature"""
|
"""Test remote feed without a signature"""
|
||||||
|
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
Message(
|
Message(
|
||||||
remote_feed,
|
remote_feed,
|
||||||
OrderedDict(
|
OrderedDict(
|
||||||
[
|
[("type", "about"), ("about", remote_feed.id), ("name", "neo"), ("description", "The Chosen One")]
|
||||||
("type", "about"),
|
|
||||||
("about", remote_feed.id),
|
|
||||||
("name", "neo"),
|
|
||||||
("description", "The Chosen One"),
|
|
||||||
]
|
|
||||||
),
|
),
|
||||||
None,
|
None,
|
||||||
timestamp=1495706260190,
|
timestamp=1495706260190,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_serialize(local_feed): # pylint: disable=redefined-outer-name
|
def test_serialize(local_feed: LocalFeed) -> None: # pylint: disable=redefined-outer-name
|
||||||
"""Test feed serialization"""
|
"""Test feed serialization"""
|
||||||
|
|
||||||
m1 = LocalMessage(
|
m1 = LocalMessage(
|
||||||
local_feed,
|
local_feed,
|
||||||
OrderedDict(
|
OrderedDict([("type", "about"), ("about", local_feed.id), ("name", "neo"), ("description", "The Chosen One")]),
|
||||||
[
|
|
||||||
("type", "about"),
|
|
||||||
("about", local_feed.id),
|
|
||||||
("name", "neo"),
|
|
||||||
("description", "The Chosen One"),
|
|
||||||
]
|
|
||||||
),
|
|
||||||
timestamp=1495706260190,
|
timestamp=1495706260190,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert m1.serialize() == SERIALIZED_M1
|
assert m1.serialize() == SERIALIZED_M1
|
||||||
|
|
||||||
|
|
||||||
def test_parse(local_feed): # pylint: disable=redefined-outer-name
|
def test_parse(local_feed: LocalFeed) -> None: # pylint: disable=redefined-outer-name
|
||||||
"""Test feed parsing"""
|
"""Test feed parsing"""
|
||||||
|
|
||||||
m1 = LocalMessage.parse(SERIALIZED_M1, local_feed)
|
m1 = LocalMessage.parse(SERIALIZED_M1, local_feed)
|
||||||
assert m1.content == {
|
assert m1.content == {"type": "about", "about": local_feed.id, "name": "neo", "description": "The Chosen One"}
|
||||||
"type": "about",
|
|
||||||
"about": local_feed.id,
|
|
||||||
"name": "neo",
|
|
||||||
"description": "The Chosen One",
|
|
||||||
}
|
|
||||||
assert m1.timestamp == 1495706260190
|
assert m1.timestamp == 1495706260190
|
||||||
|
|
||||||
|
|
||||||
@ -252,15 +209,15 @@ def test_local_unsigned(local_feed: LocalFeed, mocker: MockerFixture) -> None:
|
|||||||
mocked_dt.utcnow = mocker.MagicMock(return_value=datetime(2023, 3, 7, 11, 45, 54, 0, tzinfo=timezone.utc))
|
mocked_dt.utcnow = mocker.MagicMock(return_value=datetime(2023, 3, 7, 11, 45, 54, 0, tzinfo=timezone.utc))
|
||||||
mocker.patch("ssb.feed.models.datetime", mocked_dt)
|
mocker.patch("ssb.feed.models.datetime", mocked_dt)
|
||||||
|
|
||||||
msg = LocalMessage(local_feed, b"test")
|
msg = LocalMessage(local_feed, OrderedDict({"test": True}))
|
||||||
|
|
||||||
assert msg.feed == local_feed
|
assert msg.feed == local_feed
|
||||||
assert msg.content == b"test"
|
assert msg.content == {"test": True}
|
||||||
assert msg.sequence == 1
|
assert msg.sequence == 1
|
||||||
assert msg.previous is None
|
assert msg.previous is None
|
||||||
assert msg.timestamp == 1678189554000
|
assert msg.timestamp == 1678189554000
|
||||||
assert msg.signature == (
|
assert msg.signature == (
|
||||||
"SxZsBINzsuQqmB6JLmXyr22+FRY33bp3wj1MwjAOU3MqifGqfc3W/2T5D4qel5mqrgJt9IT8c3QayB1suj82AQ==.sig.ed25519"
|
"WjkA5rjzsYDHqeavEPcbNAbRMp5NRFDBNATMWgcsccso8sfwhaWnIEvQW79fA5YgKKybzlIsCMWHherToEI2DA==.sig.ed25519"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -23,18 +23,23 @@
|
|||||||
"""Tests for the packet stream"""
|
"""Tests for the packet stream"""
|
||||||
|
|
||||||
from asyncio import Event, ensure_future, gather
|
from asyncio import Event, ensure_future, gather
|
||||||
|
from asyncio.events import AbstractEventLoop
|
||||||
import json
|
import json
|
||||||
|
from typing import Any, AsyncIterator, Awaitable, Callable, Generator, List, Optional
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from pytest_mock import MockerFixture
|
||||||
from secret_handshake.network import SHSDuplexStream
|
from secret_handshake.network import SHSDuplexStream
|
||||||
|
|
||||||
from ssb.packet_stream import PacketStream, PSMessageType
|
from ssb.packet_stream import PacketStream, PSMessage, PSMessageType
|
||||||
|
|
||||||
|
|
||||||
async def _collect_messages(generator):
|
async def _collect_messages(generator: AsyncIterator[Optional[PSMessage]]) -> List[Optional["PSMessage"]]:
|
||||||
results = []
|
results = []
|
||||||
|
|
||||||
async for msg in generator:
|
async for msg in generator:
|
||||||
results.append(msg)
|
results.append(msg)
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
@ -61,45 +66,47 @@ MSG_BODY_2 = (
|
|||||||
class MockSHSSocket(SHSDuplexStream):
|
class MockSHSSocket(SHSDuplexStream):
|
||||||
"""A mocked SHS socket"""
|
"""A mocked SHS socket"""
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs): # pylint: disable=unused-argument
|
def __init__(self, *args: Any, **kwargs: Any): # pylint: disable=unused-argument
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.input = []
|
self.input: List[bytes] = []
|
||||||
self.output = []
|
self.output: List[bytes] = []
|
||||||
self.is_connected = False
|
self.is_connected: bool = False
|
||||||
self._on_connect = []
|
self._on_connect: List[Callable[[], Awaitable[None]]] = []
|
||||||
|
|
||||||
def on_connect(self, cb):
|
def on_connect(self, cb: Callable[[], Awaitable[None]]) -> None:
|
||||||
"""Set the on_connect callback"""
|
"""Set the on_connect callback"""
|
||||||
|
|
||||||
self._on_connect.append(cb)
|
self._on_connect.append(cb)
|
||||||
|
|
||||||
async def read(self):
|
async def read(self) -> Optional[bytes]:
|
||||||
"""Read data from the socket"""
|
"""Read data from the socket"""
|
||||||
|
|
||||||
if not self.input:
|
if not self.input:
|
||||||
raise StopAsyncIteration
|
return None
|
||||||
|
|
||||||
return self.input.pop(0)
|
return self.input.pop(0)
|
||||||
|
|
||||||
def write(self, data):
|
def write(self, data: bytes) -> None:
|
||||||
"""Write data to the socket"""
|
"""Write data to the socket"""
|
||||||
|
|
||||||
self.output.append(data)
|
self.output.append(data)
|
||||||
|
|
||||||
def feed(self, input_):
|
def feed(self, input_: List[bytes]) -> None:
|
||||||
"""Get the connection’s feed"""
|
"""Feed data into the connection"""
|
||||||
|
|
||||||
self.input += input_
|
self.input += input_
|
||||||
|
|
||||||
def get_output(self):
|
def get_output(self) -> Generator[bytes, None, None]:
|
||||||
"""Get the output of a call"""
|
"""Get the output of a call"""
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
if not self.output:
|
if not self.output:
|
||||||
break
|
break
|
||||||
|
|
||||||
yield self.output.pop(0)
|
yield self.output.pop(0)
|
||||||
|
|
||||||
def disconnect(self):
|
def disconnect(self) -> None:
|
||||||
"""Disconnect from the remote party"""
|
"""Disconnect from the remote party"""
|
||||||
|
|
||||||
self.is_connected = False
|
self.is_connected = False
|
||||||
@ -108,7 +115,7 @@ class MockSHSSocket(SHSDuplexStream):
|
|||||||
class MockSHSClient(MockSHSSocket):
|
class MockSHSClient(MockSHSSocket):
|
||||||
"""A mocked SHS client"""
|
"""A mocked SHS client"""
|
||||||
|
|
||||||
async def connect(self):
|
async def connect(self) -> None:
|
||||||
"""Connect to a SHS server"""
|
"""Connect to a SHS server"""
|
||||||
|
|
||||||
self.is_connected = True
|
self.is_connected = True
|
||||||
@ -120,7 +127,7 @@ class MockSHSClient(MockSHSSocket):
|
|||||||
class MockSHSServer(MockSHSSocket):
|
class MockSHSServer(MockSHSSocket):
|
||||||
"""A mocked SHS server"""
|
"""A mocked SHS server"""
|
||||||
|
|
||||||
def listen(self):
|
def listen(self) -> None:
|
||||||
"""Listen for new connections"""
|
"""Listen for new connections"""
|
||||||
|
|
||||||
self.is_connected = True
|
self.is_connected = True
|
||||||
@ -130,26 +137,26 @@ class MockSHSServer(MockSHSSocket):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def ps_client(event_loop): # pylint: disable=unused-argument
|
def ps_client(event_loop: AbstractEventLoop) -> MockSHSClient: # pylint: disable=unused-argument
|
||||||
"""Fixture to provide a mocked SHS client"""
|
"""Fixture to provide a mocked SHS client"""
|
||||||
|
|
||||||
return MockSHSClient()
|
return MockSHSClient()
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def ps_server(event_loop): # pylint: disable=unused-argument
|
def ps_server(event_loop: AbstractEventLoop) -> MockSHSServer: # pylint: disable=unused-argument
|
||||||
"""Fixture to provide a mocked SHS server"""
|
"""Fixture to provide a mocked SHS server"""
|
||||||
|
|
||||||
return MockSHSServer()
|
return MockSHSServer()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_on_connect(ps_server): # pylint: disable=redefined-outer-name
|
async def test_on_connect(ps_server: MockSHSServer) -> None: # pylint: disable=redefined-outer-name
|
||||||
"""Test the on_connect callback functionality"""
|
"""Test the on_connect callback functionality"""
|
||||||
|
|
||||||
called = Event()
|
called = Event()
|
||||||
|
|
||||||
async def _on_connect():
|
async def _on_connect() -> None:
|
||||||
called.set()
|
called.set()
|
||||||
|
|
||||||
ps_server.on_connect(_on_connect)
|
ps_server.on_connect(_on_connect)
|
||||||
@ -159,7 +166,7 @@ async def test_on_connect(ps_server): # pylint: disable=redefined-outer-name
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_message_decoding(ps_client): # pylint: disable=redefined-outer-name
|
async def test_message_decoding(ps_client: MockSHSClient) -> None: # pylint: disable=redefined-outer-name
|
||||||
"""Test message decoding"""
|
"""Test message decoding"""
|
||||||
|
|
||||||
await ps_client.connect()
|
await ps_client.connect()
|
||||||
@ -178,6 +185,7 @@ async def test_message_decoding(ps_client): # pylint: disable=redefined-outer-n
|
|||||||
|
|
||||||
messages = await _collect_messages(ps)
|
messages = await _collect_messages(ps)
|
||||||
assert len(messages) == 1
|
assert len(messages) == 1
|
||||||
|
assert messages[0]
|
||||||
assert messages[0].type == PSMessageType.JSON
|
assert messages[0].type == PSMessageType.JSON
|
||||||
assert messages[0].body == {
|
assert messages[0].body == {
|
||||||
"name": ["createHistoryStream"],
|
"name": ["createHistoryStream"],
|
||||||
@ -194,7 +202,7 @@ async def test_message_decoding(ps_client): # pylint: disable=redefined-outer-n
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_message_encoding(ps_client): # pylint: disable=redefined-outer-name
|
async def test_message_encoding(ps_client: MockSHSClient) -> None: # pylint: disable=redefined-outer-name
|
||||||
"""Test message encoding"""
|
"""Test message encoding"""
|
||||||
|
|
||||||
await ps_client.connect()
|
await ps_client.connect()
|
||||||
@ -237,7 +245,9 @@ async def test_message_encoding(ps_client): # pylint: disable=redefined-outer-n
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_message_stream(ps_client, mocker): # pylint: disable=redefined-outer-name
|
async def test_message_stream(
|
||||||
|
ps_client: MockSHSClient, mocker: MockerFixture # pylint: disable=redefined-outer-name
|
||||||
|
) -> None: # pylint: disable=redefined-outer-name
|
||||||
"""Test requesting a history stream"""
|
"""Test requesting a history stream"""
|
||||||
|
|
||||||
await ps_client.connect()
|
await ps_client.connect()
|
||||||
@ -264,7 +274,7 @@ async def test_message_stream(ps_client, mocker): # pylint: disable=redefined-o
|
|||||||
)
|
)
|
||||||
|
|
||||||
assert ps.req_counter == 2
|
assert ps.req_counter == 2
|
||||||
assert ps.register_handler.call_count == 1 # pylint: disable=no-member
|
assert ps.register_handler.call_count == 1 # type: ignore[attr-defined] # pylint: disable=no-member
|
||||||
handler = list(ps._event_map.values())[0][1] # pylint: disable=protected-access
|
handler = list(ps._event_map.values())[0][1] # pylint: disable=protected-access
|
||||||
mock_process = mocker.patch.object(handler, "process")
|
mock_process = mocker.patch.object(handler, "process")
|
||||||
|
|
||||||
@ -273,6 +283,8 @@ async def test_message_stream(ps_client, mocker): # pylint: disable=redefined-o
|
|||||||
assert mock_process.await_count == 1
|
assert mock_process.await_count == 1
|
||||||
|
|
||||||
# responses have negative req
|
# responses have negative req
|
||||||
|
assert msg
|
||||||
|
assert isinstance(msg.body, dict)
|
||||||
assert msg.req == -1
|
assert msg.req == -1
|
||||||
assert msg.body["previous"] == "%KTGP6W8vF80McRAZHYDWuKOD0KlNyKSq6Gb42iuV7Iw=.sha256"
|
assert msg.body["previous"] == "%KTGP6W8vF80McRAZHYDWuKOD0KlNyKSq6Gb42iuV7Iw=.sha256"
|
||||||
|
|
||||||
@ -295,7 +307,7 @@ async def test_message_stream(ps_client, mocker): # pylint: disable=redefined-o
|
|||||||
)
|
)
|
||||||
|
|
||||||
assert ps.req_counter == 3
|
assert ps.req_counter == 3
|
||||||
assert ps.register_handler.call_count == 2 # pylint: disable=no-member
|
assert ps.register_handler.call_count == 2 # type: ignore[attr-defined] # pylint: disable=no-member
|
||||||
handler = list(ps._event_map.values())[1][1] # pylint: disable=protected-access
|
handler = list(ps._event_map.values())[1][1] # pylint: disable=protected-access
|
||||||
|
|
||||||
mock_process = mocker.patch.object(handler, "process", wraps=handler.process)
|
mock_process = mocker.patch.object(handler, "process", wraps=handler.process)
|
||||||
@ -318,11 +330,14 @@ async def test_message_stream(ps_client, mocker): # pylint: disable=redefined-o
|
|||||||
|
|
||||||
for msg in handled:
|
for msg in handled:
|
||||||
# responses have negative req
|
# responses have negative req
|
||||||
|
assert msg
|
||||||
assert msg.req == -2
|
assert msg.req == -2
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_message_request(ps_server, mocker): # pylint: disable=redefined-outer-name
|
async def test_message_request(
|
||||||
|
ps_server: MockSHSServer, mocker: MockerFixture # pylint: disable=redefined-outer-name
|
||||||
|
) -> None: # pylint: disable=redefined-outer-name
|
||||||
"""Test message sending"""
|
"""Test message sending"""
|
||||||
|
|
||||||
ps_server.listen()
|
ps_server.listen()
|
||||||
@ -338,7 +353,7 @@ async def test_message_request(ps_server, mocker): # pylint: disable=redefined-
|
|||||||
assert json.loads(body.decode("utf-8")) == {"name": ["whoami"], "args": []}
|
assert json.loads(body.decode("utf-8")) == {"name": ["whoami"], "args": []}
|
||||||
|
|
||||||
assert ps.req_counter == 2
|
assert ps.req_counter == 2
|
||||||
assert ps.register_handler.call_count == 1 # pylint: disable=no-member
|
assert ps.register_handler.call_count == 1 # type: ignore[attr-defined] # pylint: disable=no-member
|
||||||
handler = list(ps._event_map.values())[0][1] # pylint: disable=protected-access
|
handler = list(ps._event_map.values())[0][1] # pylint: disable=protected-access
|
||||||
mock_process = mocker.patch.object(handler, "process")
|
mock_process = mocker.patch.object(handler, "process")
|
||||||
ps_server.feed(
|
ps_server.feed(
|
||||||
@ -351,6 +366,8 @@ async def test_message_request(ps_server, mocker): # pylint: disable=redefined-
|
|||||||
assert mock_process.await_count == 1
|
assert mock_process.await_count == 1
|
||||||
|
|
||||||
# responses have negative req
|
# responses have negative req
|
||||||
|
assert msg
|
||||||
|
assert isinstance(msg.body, dict)
|
||||||
assert msg.req == -1
|
assert msg.req == -1
|
||||||
assert msg.body["id"] == "@1+Iwm79DKvVBqYKFkhT6fWRbAVvNNVH4F2BSxwhYmx8=.ed25519"
|
assert msg.body["id"] == "@1+Iwm79DKvVBqYKFkhT6fWRbAVvNNVH4F2BSxwhYmx8=.ed25519"
|
||||||
assert ps.req_counter == 2
|
assert ps.req_counter == 2
|
||||||
|
@ -42,7 +42,7 @@ CONFIG_FILE = """
|
|||||||
CONFIG_FILE_INVALID = CONFIG_FILE.replace("ed25519", "foo")
|
CONFIG_FILE_INVALID = CONFIG_FILE.replace("ed25519", "foo")
|
||||||
|
|
||||||
|
|
||||||
def test_load_secret():
|
def test_load_secret() -> None:
|
||||||
"""Test loading the SSB secret from a file"""
|
"""Test loading the SSB secret from a file"""
|
||||||
|
|
||||||
with patch("ssb.util.open", mock_open(read_data=CONFIG_FILE), create=True):
|
with patch("ssb.util.open", mock_open(read_data=CONFIG_FILE), create=True):
|
||||||
@ -55,7 +55,7 @@ def test_load_secret():
|
|||||||
assert bytes(secret["keypair"].verify_key) == b64decode("rsYpBIcXsxjQAf0JNes+MHqT2DL+EfopWKAp4rGeEPQ=")
|
assert bytes(secret["keypair"].verify_key) == b64decode("rsYpBIcXsxjQAf0JNes+MHqT2DL+EfopWKAp4rGeEPQ=")
|
||||||
|
|
||||||
|
|
||||||
def test_load_exception():
|
def test_load_exception() -> None:
|
||||||
"""Test configuration loading if there is a problem with the file"""
|
"""Test configuration loading if there is a problem with the file"""
|
||||||
|
|
||||||
with pytest.raises(ConfigException):
|
with pytest.raises(ConfigException):
|
||||||
|
Loading…
Reference in New Issue
Block a user