ci: Add and configure PyLint, and make it happy
This commit is contained in:
parent
e0cd456e77
commit
d51f27d883
@ -24,6 +24,12 @@ repos:
|
|||||||
language: system
|
language: system
|
||||||
require_serial: true
|
require_serial: true
|
||||||
types_or: [python, pyi]
|
types_or: [python, pyi]
|
||||||
|
- id: pylint
|
||||||
|
name: pylint
|
||||||
|
entry: poetry run pylint
|
||||||
|
language: system
|
||||||
|
types: [python]
|
||||||
|
require_serial: true
|
||||||
- id: isort
|
- id: isort
|
||||||
name: isort
|
name: isort
|
||||||
args: ["--check", "--diff"]
|
args: ["--check", "--diff"]
|
||||||
|
@ -20,6 +20,8 @@
|
|||||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
# SOFTWARE.
|
# SOFTWARE.
|
||||||
|
|
||||||
|
"""Example SSB Client"""
|
||||||
|
|
||||||
from asyncio import ensure_future, gather, get_event_loop
|
from asyncio import ensure_future, gather, get_event_loop
|
||||||
import base64
|
import base64
|
||||||
import hashlib
|
import hashlib
|
||||||
@ -38,18 +40,24 @@ api = MuxRPCAPI()
|
|||||||
|
|
||||||
|
|
||||||
@api.define("createHistoryStream")
|
@api.define("createHistoryStream")
|
||||||
def create_history_stream(connection, msg):
|
def create_history_stream(connection, msg): # pylint: disable=unused-argument
|
||||||
|
"""Handle the createHistoryStream RPC call"""
|
||||||
|
|
||||||
print("create_history_stream", msg)
|
print("create_history_stream", msg)
|
||||||
# msg = PSMessage(PSMessageType.JSON, True, stream=True, end_err=True, req=-req)
|
# msg = PSMessage(PSMessageType.JSON, True, stream=True, end_err=True, req=-req)
|
||||||
# connection.write(msg)
|
# connection.write(msg)
|
||||||
|
|
||||||
|
|
||||||
@api.define("blobs.createWants")
|
@api.define("blobs.createWants")
|
||||||
def create_wants(connection, msg):
|
def create_wants(connection, msg): # pylint: disable=unused-argument
|
||||||
|
"""Handle the createWants RPC call"""
|
||||||
|
|
||||||
print("create_wants", msg)
|
print("create_wants", msg)
|
||||||
|
|
||||||
|
|
||||||
async def test_client():
|
async def test_client():
|
||||||
|
"""The actual client implementation"""
|
||||||
|
|
||||||
async for msg in api.call(
|
async for msg in api.call(
|
||||||
"createHistoryStream",
|
"createHistoryStream",
|
||||||
[
|
[
|
||||||
@ -90,7 +98,9 @@ async def test_client():
|
|||||||
f.write(img_data)
|
f.write(img_data)
|
||||||
|
|
||||||
|
|
||||||
async def main():
|
async def main(keypair):
|
||||||
|
"""The main function to run"""
|
||||||
|
|
||||||
client = SHSClient("127.0.0.1", 8008, keypair, bytes(keypair.verify_key))
|
client = SHSClient("127.0.0.1", 8008, keypair, bytes(keypair.verify_key))
|
||||||
packet_stream = PacketStream(client)
|
packet_stream = PacketStream(client)
|
||||||
await client.open()
|
await client.open()
|
||||||
@ -116,8 +126,8 @@ if __name__ == "__main__":
|
|||||||
logger.setLevel(logging.INFO)
|
logger.setLevel(logging.INFO)
|
||||||
logger.addHandler(ch)
|
logger.addHandler(ch)
|
||||||
|
|
||||||
keypair = load_ssb_secret()["keypair"]
|
ssb_keypair = load_ssb_secret()["keypair"]
|
||||||
|
|
||||||
loop = get_event_loop()
|
loop = get_event_loop()
|
||||||
loop.run_until_complete(main())
|
loop.run_until_complete(main(ssb_keypair))
|
||||||
loop.close()
|
loop.close()
|
||||||
|
@ -20,7 +20,9 @@
|
|||||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
# SOFTWARE.
|
# SOFTWARE.
|
||||||
|
|
||||||
from asyncio import ensure_future, gather, get_event_loop
|
"""Test SSB server"""
|
||||||
|
|
||||||
|
from asyncio import get_event_loop
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from colorlog import ColoredFormatter
|
from colorlog import ColoredFormatter
|
||||||
@ -34,6 +36,8 @@ api = MuxRPCAPI()
|
|||||||
|
|
||||||
|
|
||||||
async def on_connect(conn):
|
async def on_connect(conn):
|
||||||
|
"""Incoming connection handler"""
|
||||||
|
|
||||||
packet_stream = PacketStream(conn)
|
packet_stream = PacketStream(conn)
|
||||||
api.add_connection(packet_stream)
|
api.add_connection(packet_stream)
|
||||||
|
|
||||||
@ -43,6 +47,8 @@ async def on_connect(conn):
|
|||||||
|
|
||||||
|
|
||||||
async def main():
|
async def main():
|
||||||
|
"""The main function to run"""
|
||||||
|
|
||||||
server = SHSServer("127.0.0.1", 8008, load_ssb_secret()["keypair"])
|
server = SHSServer("127.0.0.1", 8008, load_ssb_secret()["keypair"])
|
||||||
server.on_connect(on_connect)
|
server.on_connect(on_connect)
|
||||||
await server.listen()
|
await server.listen()
|
||||||
@ -55,7 +61,7 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
# create formatter
|
# create formatter
|
||||||
formatter = ColoredFormatter(
|
formatter = ColoredFormatter(
|
||||||
"%(log_color)s%(levelname)s%(reset)s:%(bold_white)s%(name)s%(reset)s - " "%(cyan)s%(message)s%(reset)s"
|
"%(log_color)s%(levelname)s%(reset)s:%(bold_white)s%(name)s%(reset)s - %(cyan)s%(message)s%(reset)s"
|
||||||
)
|
)
|
||||||
|
|
||||||
# add formatter to ch
|
# add formatter to ch
|
||||||
|
71
poetry.lock
generated
71
poetry.lock
generated
@ -25,6 +25,20 @@ files = [
|
|||||||
[package.extras]
|
[package.extras]
|
||||||
test = ["coverage", "mypy", "pexpect", "ruff", "wheel"]
|
test = ["coverage", "mypy", "pexpect", "ruff", "wheel"]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "astroid"
|
||||||
|
version = "3.0.1"
|
||||||
|
description = "An abstract syntax tree for Python with inference support."
|
||||||
|
optional = false
|
||||||
|
python-versions = ">=3.8.0"
|
||||||
|
files = [
|
||||||
|
{file = "astroid-3.0.1-py3-none-any.whl", hash = "sha256:7d5895c9825e18079c5aeac0572bc2e4c83205c95d416e0b4fee8bc361d2d9ca"},
|
||||||
|
{file = "astroid-3.0.1.tar.gz", hash = "sha256:86b0bb7d7da0be1a7c4aedb7974e391b32d4ed89e33de6ed6902b4b15c97577e"},
|
||||||
|
]
|
||||||
|
|
||||||
|
[package.dependencies]
|
||||||
|
typing-extensions = {version = ">=4.0.0", markers = "python_version < \"3.11\""}
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "async-generator"
|
name = "async-generator"
|
||||||
version = "1.8"
|
version = "1.8"
|
||||||
@ -430,6 +444,20 @@ files = [
|
|||||||
{file = "decli-0.6.1.tar.gz", hash = "sha256:ed88ccb947701e8e5509b7945fda56e150e2ac74a69f25d47ac85ef30ab0c0f0"},
|
{file = "decli-0.6.1.tar.gz", hash = "sha256:ed88ccb947701e8e5509b7945fda56e150e2ac74a69f25d47ac85ef30ab0c0f0"},
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "dill"
|
||||||
|
version = "0.3.7"
|
||||||
|
description = "serialize all of Python"
|
||||||
|
optional = false
|
||||||
|
python-versions = ">=3.7"
|
||||||
|
files = [
|
||||||
|
{file = "dill-0.3.7-py3-none-any.whl", hash = "sha256:76b122c08ef4ce2eedcd4d1abd8e641114bfc6c2867f49f3c41facf65bf19f5e"},
|
||||||
|
{file = "dill-0.3.7.tar.gz", hash = "sha256:cc1c8b182eb3013e24bd475ff2e9295af86c1a38eb1aff128dac8962a9ce3c03"},
|
||||||
|
]
|
||||||
|
|
||||||
|
[package.extras]
|
||||||
|
graph = ["objgraph (>=1.7.2)"]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "docutils"
|
name = "docutils"
|
||||||
version = "0.17.1"
|
version = "0.17.1"
|
||||||
@ -610,6 +638,17 @@ files = [
|
|||||||
{file = "MarkupSafe-2.1.3.tar.gz", hash = "sha256:af598ed32d6ae86f1b747b82783958b1a4ab8f617b06fe68795c7f026abbdcad"},
|
{file = "MarkupSafe-2.1.3.tar.gz", hash = "sha256:af598ed32d6ae86f1b747b82783958b1a4ab8f617b06fe68795c7f026abbdcad"},
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "mccabe"
|
||||||
|
version = "0.7.0"
|
||||||
|
description = "McCabe checker, plugin for flake8"
|
||||||
|
optional = false
|
||||||
|
python-versions = ">=3.6"
|
||||||
|
files = [
|
||||||
|
{file = "mccabe-0.7.0-py2.py3-none-any.whl", hash = "sha256:6c2d30ab6be0e4a46919781807b4f0d834ebdd6c6e3dca0bda5a15f863427b6e"},
|
||||||
|
{file = "mccabe-0.7.0.tar.gz", hash = "sha256:348e0240c33b60bbdf4e523192ef919f28cb2c3d7d5c7794f74009290f236325"},
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "mypy-extensions"
|
name = "mypy-extensions"
|
||||||
version = "1.0.0"
|
version = "1.0.0"
|
||||||
@ -723,6 +762,36 @@ files = [
|
|||||||
[package.extras]
|
[package.extras]
|
||||||
plugins = ["importlib-metadata"]
|
plugins = ["importlib-metadata"]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "pylint"
|
||||||
|
version = "3.0.2"
|
||||||
|
description = "python code static checker"
|
||||||
|
optional = false
|
||||||
|
python-versions = ">=3.8.0"
|
||||||
|
files = [
|
||||||
|
{file = "pylint-3.0.2-py3-none-any.whl", hash = "sha256:60ed5f3a9ff8b61839ff0348b3624ceeb9e6c2a92c514d81c9cc273da3b6bcda"},
|
||||||
|
{file = "pylint-3.0.2.tar.gz", hash = "sha256:0d4c286ef6d2f66c8bfb527a7f8a629009e42c99707dec821a03e1b51a4c1496"},
|
||||||
|
]
|
||||||
|
|
||||||
|
[package.dependencies]
|
||||||
|
astroid = ">=3.0.1,<=3.1.0-dev0"
|
||||||
|
colorama = {version = ">=0.4.5", markers = "sys_platform == \"win32\""}
|
||||||
|
dill = [
|
||||||
|
{version = ">=0.2", markers = "python_version < \"3.11\""},
|
||||||
|
{version = ">=0.3.7", markers = "python_version >= \"3.12\""},
|
||||||
|
{version = ">=0.3.6", markers = "python_version >= \"3.11\" and python_version < \"3.12\""},
|
||||||
|
]
|
||||||
|
isort = ">=4.2.5,<6"
|
||||||
|
mccabe = ">=0.6,<0.8"
|
||||||
|
platformdirs = ">=2.2.0"
|
||||||
|
tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""}
|
||||||
|
tomlkit = ">=0.10.1"
|
||||||
|
typing-extensions = {version = ">=3.10.0", markers = "python_version < \"3.10\""}
|
||||||
|
|
||||||
|
[package.extras]
|
||||||
|
spelling = ["pyenchant (>=3.2,<4.0)"]
|
||||||
|
testutils = ["gitpython (>3)"]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "pynacl"
|
name = "pynacl"
|
||||||
version = "1.1.2"
|
version = "1.1.2"
|
||||||
@ -1241,4 +1310,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p
|
|||||||
[metadata]
|
[metadata]
|
||||||
lock-version = "2.0"
|
lock-version = "2.0"
|
||||||
python-versions = "^3.9"
|
python-versions = "^3.9"
|
||||||
content-hash = "4eb60a723d8be38d3d197522f58d19f4bdae3887bd006d13ab87fb058c75b467"
|
content-hash = "bd8b3213143f1abe13f580d28e2d42ee3d663c2d010548e7acd27be04912308a"
|
||||||
|
@ -25,6 +25,7 @@ commitizen = "^3.12.0"
|
|||||||
coverage = "^7.3.2"
|
coverage = "^7.3.2"
|
||||||
isort = "^5.12.0"
|
isort = "^5.12.0"
|
||||||
pep257 = "^0.7.0"
|
pep257 = "^0.7.0"
|
||||||
|
pylint = "^3.0.2"
|
||||||
pytest = "^7.4.3"
|
pytest = "^7.4.3"
|
||||||
pytest-asyncio = "^0.21.1"
|
pytest-asyncio = "^0.21.1"
|
||||||
pytest-cov = "^4.1.0"
|
pytest-cov = "^4.1.0"
|
||||||
@ -50,6 +51,9 @@ force_sort_within_sections = true
|
|||||||
line_length = 120
|
line_length = 120
|
||||||
profile = "black"
|
profile = "black"
|
||||||
|
|
||||||
|
[tool.pylint.format]
|
||||||
|
max-line-length = 120
|
||||||
|
|
||||||
[tool.pytest.ini_options]
|
[tool.pytest.ini_options]
|
||||||
addopts = ["--cov=.", "--no-cov-on-fail"]
|
addopts = ["--cov=.", "--no-cov-on-fail"]
|
||||||
python_files = ["tests/test_*.py"]
|
python_files = ["tests/test_*.py"]
|
||||||
|
@ -20,6 +20,8 @@
|
|||||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
# SOFTWARE.
|
# SOFTWARE.
|
||||||
|
|
||||||
|
"""Feed related functionality"""
|
||||||
|
|
||||||
from .models import Feed, LocalFeed, LocalMessage, Message, NoPrivateKeyException
|
from .models import Feed, LocalFeed, LocalMessage, Message, NoPrivateKeyException
|
||||||
|
|
||||||
__all__ = ("Feed", "LocalFeed", "Message", "LocalMessage", "NoPrivateKeyException")
|
__all__ = ("Feed", "LocalFeed", "Message", "LocalMessage", "NoPrivateKeyException")
|
||||||
|
@ -20,6 +20,8 @@
|
|||||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
# SOFTWARE.
|
# SOFTWARE.
|
||||||
|
|
||||||
|
"""Feed models"""
|
||||||
|
|
||||||
from base64 import b64encode
|
from base64 import b64encode
|
||||||
from collections import OrderedDict, namedtuple
|
from collections import OrderedDict, namedtuple
|
||||||
import datetime
|
import datetime
|
||||||
@ -33,44 +35,65 @@ OrderedMsg = namedtuple("OrderedMsg", ("previous", "author", "sequence", "timest
|
|||||||
|
|
||||||
|
|
||||||
class NoPrivateKeyException(Exception):
|
class NoPrivateKeyException(Exception):
|
||||||
pass
|
"""Exception to raise when a private key is not available"""
|
||||||
|
|
||||||
|
|
||||||
def to_ordered(data):
|
def to_ordered(data):
|
||||||
|
"""Convert a dictionary to an ``OrderedDict``"""
|
||||||
|
|
||||||
smsg = OrderedMsg(**data)
|
smsg = OrderedMsg(**data)
|
||||||
|
|
||||||
return OrderedDict((k, getattr(smsg, k)) for k in smsg._fields)
|
return OrderedDict((k, getattr(smsg, k)) for k in smsg._fields)
|
||||||
|
|
||||||
|
|
||||||
def get_millis_1970():
|
def get_millis_1970():
|
||||||
|
"""Get the UNIX timestamp in milliseconds"""
|
||||||
|
|
||||||
return int(datetime.datetime.utcnow().timestamp() * 1000)
|
return int(datetime.datetime.utcnow().timestamp() * 1000)
|
||||||
|
|
||||||
|
|
||||||
class Feed(object):
|
class Feed:
|
||||||
|
"""Base class for feeds"""
|
||||||
|
|
||||||
def __init__(self, public_key):
|
def __init__(self, public_key):
|
||||||
self.public_key = public_key
|
self.public_key = public_key
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def id(self):
|
def id(self):
|
||||||
|
"""The identifier of the feed"""
|
||||||
|
|
||||||
return tag(self.public_key).decode("ascii")
|
return tag(self.public_key).decode("ascii")
|
||||||
|
|
||||||
def sign(self, msg):
|
def sign(self, msg):
|
||||||
|
"""Sign a message"""
|
||||||
|
|
||||||
raise NoPrivateKeyException("Cannot use remote identity to sign (no private key!)")
|
raise NoPrivateKeyException("Cannot use remote identity to sign (no private key!)")
|
||||||
|
|
||||||
|
|
||||||
class LocalFeed(Feed):
|
class LocalFeed(Feed):
|
||||||
def __init__(self, private_key):
|
"""Class representing a local feed"""
|
||||||
|
|
||||||
|
def __init__(self, private_key): # pylint: disable=super-init-not-called
|
||||||
self.private_key = private_key
|
self.private_key = private_key
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def public_key(self):
|
def public_key(self):
|
||||||
|
"""The public key of the feed"""
|
||||||
|
|
||||||
return self.private_key.verify_key
|
return self.private_key.verify_key
|
||||||
|
|
||||||
def sign(self, msg):
|
def sign(self, msg):
|
||||||
|
"""Sign a message for this feed"""
|
||||||
|
|
||||||
return self.private_key.sign(msg).signature
|
return self.private_key.sign(msg).signature
|
||||||
|
|
||||||
|
|
||||||
class Message(object):
|
class Message:
|
||||||
def __init__(self, feed, content, signature, sequence=1, timestamp=None, previous=None):
|
"""Base class for SSB messages"""
|
||||||
|
|
||||||
|
def __init__( # pylint: disable=too-many-arguments
|
||||||
|
self, feed, content, signature=None, sequence=1, timestamp=None, previous=None
|
||||||
|
):
|
||||||
self.feed = feed
|
self.feed = feed
|
||||||
self.content = content
|
self.content = content
|
||||||
|
|
||||||
@ -88,14 +111,21 @@ class Message(object):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def parse(cls, data, feed):
|
def parse(cls, data, feed):
|
||||||
|
"""Parse raw message data"""
|
||||||
|
|
||||||
obj = loads(data, object_pairs_hook=OrderedDict)
|
obj = loads(data, object_pairs_hook=OrderedDict)
|
||||||
msg = cls(feed, obj["content"], timestamp=obj["timestamp"])
|
msg = cls(feed, obj["content"], timestamp=obj["timestamp"])
|
||||||
|
|
||||||
return msg
|
return msg
|
||||||
|
|
||||||
def serialize(self, add_signature=True):
|
def serialize(self, add_signature=True):
|
||||||
|
"""Serialize the message"""
|
||||||
|
|
||||||
return dumps(self.to_dict(add_signature=add_signature), indent=2).encode("utf-8")
|
return dumps(self.to_dict(add_signature=add_signature), indent=2).encode("utf-8")
|
||||||
|
|
||||||
def to_dict(self, add_signature=True):
|
def to_dict(self, add_signature=True):
|
||||||
|
"""Convert the message to a dictionary"""
|
||||||
|
|
||||||
obj = to_ordered(
|
obj = to_ordered(
|
||||||
{
|
{
|
||||||
"previous": self.previous.key if self.previous else None,
|
"previous": self.previous.key if self.previous else None,
|
||||||
@ -109,23 +139,34 @@ class Message(object):
|
|||||||
|
|
||||||
if add_signature:
|
if add_signature:
|
||||||
obj["signature"] = self.signature
|
obj["signature"] = self.signature
|
||||||
|
|
||||||
return obj
|
return obj
|
||||||
|
|
||||||
def verify(self, signature):
|
def verify(self, signature):
|
||||||
|
"""Verify the signature of the message"""
|
||||||
|
|
||||||
return self.signature == signature
|
return self.signature == signature
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def hash(self):
|
def hash(self):
|
||||||
hash = sha256(self.serialize()).digest()
|
"""The cryptographic hash of the message"""
|
||||||
return b64encode(hash).decode("ascii") + ".sha256"
|
|
||||||
|
hash_ = sha256(self.serialize()).digest()
|
||||||
|
return b64encode(hash_).decode("ascii") + ".sha256"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def key(self):
|
def key(self):
|
||||||
|
"""The key of the message"""
|
||||||
|
|
||||||
return "%" + self.hash
|
return "%" + self.hash
|
||||||
|
|
||||||
|
|
||||||
class LocalMessage(Message):
|
class LocalMessage(Message):
|
||||||
def __init__(self, feed, content, signature=None, sequence=1, timestamp=None, previous=None):
|
"""Class representing a local message"""
|
||||||
|
|
||||||
|
def __init__( # pylint: disable=too-many-arguments,super-init-not-called
|
||||||
|
self, feed, content, signature=None, sequence=1, timestamp=None, previous=None
|
||||||
|
):
|
||||||
self.feed = feed
|
self.feed = feed
|
||||||
self.content = content
|
self.content = content
|
||||||
|
|
||||||
|
@ -20,23 +20,32 @@
|
|||||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
# SOFTWARE.
|
# SOFTWARE.
|
||||||
|
|
||||||
|
"""MuxRPC"""
|
||||||
|
|
||||||
from async_generator import async_generator, yield_
|
from async_generator import async_generator, yield_
|
||||||
|
|
||||||
from ssb.packet_stream import PSMessageType
|
from ssb.packet_stream import PSMessageType
|
||||||
|
|
||||||
|
|
||||||
class MuxRPCAPIException(Exception):
|
class MuxRPCAPIException(Exception):
|
||||||
pass
|
"""Exception to raise on MuxRPC API errors"""
|
||||||
|
|
||||||
|
|
||||||
class MuxRPCHandler(object):
|
class MuxRPCHandler: # pylint: disable=too-few-public-methods
|
||||||
|
"""Base MuxRPC handler class"""
|
||||||
|
|
||||||
def check_message(self, msg):
|
def check_message(self, msg):
|
||||||
|
"""Check message validity"""
|
||||||
|
|
||||||
body = msg.body
|
body = msg.body
|
||||||
|
|
||||||
if isinstance(body, dict) and "name" in body and body["name"] == "Error":
|
if isinstance(body, dict) and "name" in body and body["name"] == "Error":
|
||||||
raise MuxRPCAPIException(body["message"])
|
raise MuxRPCAPIException(body["message"])
|
||||||
|
|
||||||
|
|
||||||
class MuxRPCRequestHandler(MuxRPCHandler):
|
class MuxRPCRequestHandler(MuxRPCHandler):
|
||||||
|
"""MuxRPC handler for incoming RPC requests"""
|
||||||
|
|
||||||
def __init__(self, ps_handler):
|
def __init__(self, ps_handler):
|
||||||
self.ps_handler = ps_handler
|
self.ps_handler = ps_handler
|
||||||
|
|
||||||
@ -47,52 +56,72 @@ class MuxRPCRequestHandler(MuxRPCHandler):
|
|||||||
|
|
||||||
|
|
||||||
class MuxRPCSourceHandler(MuxRPCHandler):
|
class MuxRPCSourceHandler(MuxRPCHandler):
|
||||||
|
"""MuxRPC handler for source-type RPC requests"""
|
||||||
|
|
||||||
def __init__(self, ps_handler):
|
def __init__(self, ps_handler):
|
||||||
self.ps_handler = ps_handler
|
self.ps_handler = ps_handler
|
||||||
|
|
||||||
@async_generator
|
@async_generator
|
||||||
async def __aiter__(self):
|
async def __aiter__(self):
|
||||||
async for msg in self.ps_handler:
|
async for msg in self.ps_handler:
|
||||||
try:
|
self.check_message(msg)
|
||||||
self.check_message(msg)
|
await yield_(msg)
|
||||||
await yield_(msg)
|
|
||||||
except MuxRPCAPIException:
|
|
||||||
raise
|
|
||||||
|
|
||||||
|
|
||||||
class MuxRPCSinkHandlerMixin(object):
|
class MuxRPCSinkHandlerMixin: # pylint: disable=too-few-public-methods
|
||||||
|
"""Mixin for sink-type MuxRPC handlers"""
|
||||||
|
|
||||||
def send(self, msg, msg_type=PSMessageType.JSON, end=False):
|
def send(self, msg, msg_type=PSMessageType.JSON, end=False):
|
||||||
|
"""Send a message through the stream"""
|
||||||
|
|
||||||
self.connection.send(msg, stream=True, msg_type=msg_type, req=self.req, end_err=end)
|
self.connection.send(msg, stream=True, msg_type=msg_type, req=self.req, end_err=end)
|
||||||
|
|
||||||
|
|
||||||
class MuxRPCDuplexHandler(MuxRPCSinkHandlerMixin, MuxRPCSourceHandler):
|
class MuxRPCDuplexHandler(MuxRPCSinkHandlerMixin, MuxRPCSourceHandler):
|
||||||
|
"""MuxRPC handler for duplex streams"""
|
||||||
|
|
||||||
def __init__(self, ps_handler, connection, req):
|
def __init__(self, ps_handler, connection, req):
|
||||||
super(MuxRPCDuplexHandler, self).__init__(ps_handler)
|
super().__init__(ps_handler)
|
||||||
|
|
||||||
self.connection = connection
|
self.connection = connection
|
||||||
self.req = req
|
self.req = req
|
||||||
|
|
||||||
|
|
||||||
class MuxRPCSinkHandler(MuxRPCHandler, MuxRPCSinkHandlerMixin):
|
class MuxRPCSinkHandler(MuxRPCHandler, MuxRPCSinkHandlerMixin):
|
||||||
|
"""MuxRPC handler for sinks"""
|
||||||
|
|
||||||
def __init__(self, connection, req):
|
def __init__(self, connection, req):
|
||||||
self.connection = connection
|
self.connection = connection
|
||||||
self.req = req
|
self.req = req
|
||||||
|
|
||||||
|
|
||||||
def _get_appropriate_api_handler(type_, connection, ps_handler, req):
|
def _get_appropriate_api_handler(type_, connection, ps_handler, req):
|
||||||
|
"""Find the appropriate MuxRPC handler"""
|
||||||
|
|
||||||
if type_ in {"sync", "async"}:
|
if type_ in {"sync", "async"}:
|
||||||
return MuxRPCRequestHandler(ps_handler)
|
return MuxRPCRequestHandler(ps_handler)
|
||||||
elif type_ == "source":
|
|
||||||
|
if type_ == "source":
|
||||||
return MuxRPCSourceHandler(ps_handler)
|
return MuxRPCSourceHandler(ps_handler)
|
||||||
elif type_ == "sink":
|
|
||||||
|
if type_ == "sink":
|
||||||
return MuxRPCSinkHandler(connection, req)
|
return MuxRPCSinkHandler(connection, req)
|
||||||
elif type_ == "duplex":
|
|
||||||
|
if type_ == "duplex":
|
||||||
return MuxRPCDuplexHandler(ps_handler, connection, req)
|
return MuxRPCDuplexHandler(ps_handler, connection, req)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class MuxRPCRequest:
|
||||||
|
"""MuxRPC request"""
|
||||||
|
|
||||||
class MuxRPCRequest(object):
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_message(cls, message):
|
def from_message(cls, message):
|
||||||
|
"""Initialise a request from a raw packet stream message"""
|
||||||
|
|
||||||
body = message.body
|
body = message.body
|
||||||
|
|
||||||
return cls(".".join(body["name"]), body["args"])
|
return cls(".".join(body["name"]), body["args"])
|
||||||
|
|
||||||
def __init__(self, name, args):
|
def __init__(self, name, args):
|
||||||
@ -100,22 +129,28 @@ class MuxRPCRequest(object):
|
|||||||
self.args = args
|
self.args = args
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return "<MuxRPCRequest {0.name} {0.args}>".format(self)
|
return f"<MuxRPCRequest {self.name} {self.args}>"
|
||||||
|
|
||||||
|
|
||||||
class MuxRPCMessage(object):
|
class MuxRPCMessage:
|
||||||
|
"""MuxRPC message"""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_message(cls, message):
|
def from_message(cls, message):
|
||||||
|
"""Initialise a MuxRPC message from a raw packet stream message"""
|
||||||
|
|
||||||
return cls(message.body)
|
return cls(message.body)
|
||||||
|
|
||||||
def __init__(self, body):
|
def __init__(self, body):
|
||||||
self.body = body
|
self.body = body
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return "<MuxRPCMessage {0.body}}>".format(self)
|
return f"<MuxRPCMessage {self.body}>"
|
||||||
|
|
||||||
|
|
||||||
class MuxRPCAPI(object):
|
class MuxRPCAPI:
|
||||||
|
"""Generic MuxRPC API"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.handlers = {}
|
self.handlers = {}
|
||||||
self.connection = None
|
self.connection = None
|
||||||
@ -129,9 +164,13 @@ class MuxRPCAPI(object):
|
|||||||
self.process(self.connection, MuxRPCRequest.from_message(req_message))
|
self.process(self.connection, MuxRPCRequest.from_message(req_message))
|
||||||
|
|
||||||
def add_connection(self, connection):
|
def add_connection(self, connection):
|
||||||
|
"""Set the packet stream connection of this RPC API"""
|
||||||
|
|
||||||
self.connection = connection
|
self.connection = connection
|
||||||
|
|
||||||
def define(self, name):
|
def define(self, name):
|
||||||
|
"""Decorator to define an RPC method handler"""
|
||||||
|
|
||||||
def _handle(f):
|
def _handle(f):
|
||||||
self.handlers[name] = f
|
self.handlers[name] = f
|
||||||
|
|
||||||
@ -140,17 +179,25 @@ class MuxRPCAPI(object):
|
|||||||
return _handle
|
return _handle
|
||||||
|
|
||||||
def process(self, connection, request):
|
def process(self, connection, request):
|
||||||
|
"""Process an incoming request"""
|
||||||
|
|
||||||
handler = self.handlers.get(request.name)
|
handler = self.handlers.get(request.name)
|
||||||
|
|
||||||
if not handler:
|
if not handler:
|
||||||
raise MuxRPCAPIException("Method {} not found!".format(request.name))
|
raise MuxRPCAPIException(f"Method {request.name} not found!")
|
||||||
|
|
||||||
handler(connection, request)
|
handler(connection, request)
|
||||||
|
|
||||||
def call(self, name, args, type_="sync"):
|
def call(self, name, args, type_="sync"):
|
||||||
|
"""Call an RPC method"""
|
||||||
|
|
||||||
if not self.connection.is_connected:
|
if not self.connection.is_connected:
|
||||||
raise Exception("not connected")
|
raise Exception("not connected") # pylint: disable=broad-exception-raised
|
||||||
|
|
||||||
old_counter = self.connection.req_counter
|
old_counter = self.connection.req_counter
|
||||||
ps_handler = self.connection.send(
|
ps_handler = self.connection.send(
|
||||||
{"name": name.split("."), "args": args, "type": type_},
|
{"name": name.split("."), "args": args, "type": type_},
|
||||||
stream=type_ in {"sink", "source", "duplex"},
|
stream=type_ in {"sink", "source", "duplex"},
|
||||||
)
|
)
|
||||||
|
|
||||||
return _get_appropriate_api_handler(type_, self.connection, ps_handler, old_counter)
|
return _get_appropriate_api_handler(type_, self.connection, ps_handler, old_counter)
|
||||||
|
@ -20,6 +20,8 @@
|
|||||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
# SOFTWARE.
|
# SOFTWARE.
|
||||||
|
|
||||||
|
"""Packet streams"""
|
||||||
|
|
||||||
from asyncio import Event, Queue
|
from asyncio import Event, Queue
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
import logging
|
import logging
|
||||||
@ -28,28 +30,35 @@ import struct
|
|||||||
from time import time
|
from time import time
|
||||||
|
|
||||||
from async_generator import async_generator, yield_
|
from async_generator import async_generator, yield_
|
||||||
from secret_handshake import SHSClient, SHSServer
|
|
||||||
import simplejson
|
import simplejson
|
||||||
|
|
||||||
logger = logging.getLogger("packet_stream")
|
logger = logging.getLogger("packet_stream")
|
||||||
|
|
||||||
|
|
||||||
class PSMessageType(Enum):
|
class PSMessageType(Enum):
|
||||||
|
"""Available message types"""
|
||||||
|
|
||||||
BUFFER = 0
|
BUFFER = 0
|
||||||
TEXT = 1
|
TEXT = 1
|
||||||
JSON = 2
|
JSON = 2
|
||||||
|
|
||||||
|
|
||||||
class PSStreamHandler(object):
|
class PSStreamHandler:
|
||||||
|
"""Packet stream handler"""
|
||||||
|
|
||||||
def __init__(self, req):
|
def __init__(self, req):
|
||||||
super(PSStreamHandler).__init__()
|
super(PSStreamHandler).__init__()
|
||||||
self.req = req
|
self.req = req
|
||||||
self.queue = Queue()
|
self.queue = Queue()
|
||||||
|
|
||||||
async def process(self, msg):
|
async def process(self, msg):
|
||||||
|
"""Process a pending message"""
|
||||||
|
|
||||||
await self.queue.put(msg)
|
await self.queue.put(msg)
|
||||||
|
|
||||||
async def stop(self):
|
async def stop(self):
|
||||||
|
"""Stop a pending request"""
|
||||||
|
|
||||||
await self.queue.put(None)
|
await self.queue.put(None)
|
||||||
|
|
||||||
@async_generator
|
@async_generator
|
||||||
@ -61,30 +70,40 @@ class PSStreamHandler(object):
|
|||||||
await yield_(elem)
|
await yield_(elem)
|
||||||
|
|
||||||
|
|
||||||
class PSRequestHandler(object):
|
class PSRequestHandler:
|
||||||
|
"""Packet stream request handler"""
|
||||||
|
|
||||||
def __init__(self, req):
|
def __init__(self, req):
|
||||||
super(PSRequestHandler).__init__()
|
|
||||||
self.req = req
|
self.req = req
|
||||||
self.event = Event()
|
self.event = Event()
|
||||||
self._msg = None
|
self._msg = None
|
||||||
|
|
||||||
async def process(self, msg):
|
async def process(self, msg):
|
||||||
|
"""Process a message request"""
|
||||||
|
|
||||||
self._msg = msg
|
self._msg = msg
|
||||||
self.event.set()
|
self.event.set()
|
||||||
|
|
||||||
async def stop(self):
|
async def stop(self):
|
||||||
|
"""Stop a pending event request"""
|
||||||
|
|
||||||
if not self.event.is_set():
|
if not self.event.is_set():
|
||||||
self.event.set()
|
self.event.set()
|
||||||
|
|
||||||
def __await__(self):
|
def __await__(self):
|
||||||
# wait until 'process' is called
|
# wait until 'process' is called
|
||||||
yield from self.event.wait().__await__()
|
yield from self.event.wait().__await__() # pylint: disable=no-member
|
||||||
|
|
||||||
return self._msg
|
return self._msg
|
||||||
|
|
||||||
|
|
||||||
class PSMessage(object):
|
class PSMessage:
|
||||||
|
"""Packet Stream message"""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_header_body(cls, flags, req, body):
|
def from_header_body(cls, flags, req, body):
|
||||||
|
"""Parse a raw message"""
|
||||||
|
|
||||||
type_ = PSMessageType(flags & 0x03)
|
type_ = PSMessageType(flags & 0x03)
|
||||||
|
|
||||||
if type_ == PSMessageType.TEXT:
|
if type_ == PSMessageType.TEXT:
|
||||||
@ -96,13 +115,17 @@ class PSMessage(object):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def data(self):
|
def data(self):
|
||||||
|
"""The raw message data"""
|
||||||
|
|
||||||
if self.type == PSMessageType.TEXT:
|
if self.type == PSMessageType.TEXT:
|
||||||
return self.body.encode("utf-8")
|
return self.body.encode("utf-8")
|
||||||
elif self.type == PSMessageType.JSON:
|
|
||||||
|
if self.type == PSMessageType.JSON:
|
||||||
return simplejson.dumps(self.body).encode("utf-8")
|
return simplejson.dumps(self.body).encode("utf-8")
|
||||||
|
|
||||||
return self.body
|
return self.body
|
||||||
|
|
||||||
def __init__(self, type_, body, stream, end_err, req=None):
|
def __init__(self, type_, body, stream, end_err, req=None): # pylint: disable=too-many-arguments
|
||||||
self.stream = stream
|
self.stream = stream
|
||||||
self.end_err = end_err
|
self.end_err = end_err
|
||||||
self.type = type_
|
self.type = type_
|
||||||
@ -111,37 +134,45 @@ class PSMessage(object):
|
|||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
if self.type == PSMessageType.BUFFER:
|
if self.type == PSMessageType.BUFFER:
|
||||||
body = "{} bytes".format(len(self.body))
|
body = f"{len(self.body)} bytes"
|
||||||
else:
|
else:
|
||||||
body = self.body
|
body = self.body
|
||||||
return "<PSMessage ({}): {}{} {}{}>".format(
|
|
||||||
self.type.name,
|
req = "" if self.req is None else f" [{self.req}]"
|
||||||
body,
|
is_stream = "~" if self.stream else ""
|
||||||
"" if self.req is None else " [{}]".format(self.req),
|
err = "!" if self.end_err else ""
|
||||||
"~" if self.stream else "",
|
|
||||||
"!" if self.end_err else "",
|
return f"<PSMessage ({self.type.name}): {body}{req} {is_stream}{err}>"
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class PacketStream(object):
|
class PacketStream:
|
||||||
|
"""SSB Packet stream"""
|
||||||
|
|
||||||
def __init__(self, connection):
|
def __init__(self, connection):
|
||||||
self.connection = connection
|
self.connection = connection
|
||||||
self.req_counter = 1
|
self.req_counter = 1
|
||||||
self._event_map = {}
|
self._event_map = {}
|
||||||
|
self._connected = False
|
||||||
|
|
||||||
def register_handler(self, handler):
|
def register_handler(self, handler):
|
||||||
|
"""Register an RPC handler"""
|
||||||
|
|
||||||
self._event_map[handler.req] = (time(), handler)
|
self._event_map[handler.req] = (time(), handler)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_connected(self):
|
def is_connected(self):
|
||||||
|
"""Check if the stream is connected"""
|
||||||
|
|
||||||
return self.connection.is_connected
|
return self.connection.is_connected
|
||||||
|
|
||||||
@async_generator
|
@async_generator
|
||||||
async def __aiter__(self):
|
async def __aiter__(self):
|
||||||
while True:
|
while True:
|
||||||
msg = await self.read()
|
msg = await self.read()
|
||||||
|
|
||||||
if not msg:
|
if not msg:
|
||||||
return
|
return
|
||||||
|
|
||||||
# filter out replies
|
# filter out replies
|
||||||
if msg.req >= 0:
|
if msg.req >= 0:
|
||||||
await yield_(msg)
|
await yield_(msg)
|
||||||
@ -149,20 +180,24 @@ class PacketStream(object):
|
|||||||
async def __await__(self):
|
async def __await__(self):
|
||||||
async for data in self:
|
async for data in self:
|
||||||
logger.info("RECV: %r", data)
|
logger.info("RECV: %r", data)
|
||||||
|
|
||||||
if data is None:
|
if data is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
async def _read(self):
|
async def _read(self):
|
||||||
try:
|
try:
|
||||||
header = await self.connection.read()
|
header = await self.connection.read()
|
||||||
|
|
||||||
if not header or header == b"\x00" * 9:
|
if not header or header == b"\x00" * 9:
|
||||||
return
|
return
|
||||||
|
|
||||||
flags, length, req = struct.unpack(">BIi", header)
|
flags, length, req = struct.unpack(">BIi", header)
|
||||||
|
|
||||||
n_packets = ceil(length / 4096)
|
n_packets = ceil(length / 4096)
|
||||||
|
|
||||||
body = b""
|
body = b""
|
||||||
for n in range(n_packets):
|
|
||||||
|
for _ in range(n_packets):
|
||||||
body += await self.connection.read()
|
body += await self.connection.read()
|
||||||
|
|
||||||
logger.debug("READ %s %s", header, len(body))
|
logger.debug("READ %s %s", header, len(body))
|
||||||
@ -173,12 +208,14 @@ class PacketStream(object):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
async def read(self):
|
async def read(self):
|
||||||
|
"""Read data from the packet stream"""
|
||||||
|
|
||||||
msg = await self._read()
|
msg = await self._read()
|
||||||
if not msg:
|
if not msg:
|
||||||
return None
|
return None
|
||||||
# check whether it's a reply and handle accordingly
|
# check whether it's a reply and handle accordingly
|
||||||
if msg.req < 0:
|
if msg.req < 0:
|
||||||
t, handler = self._event_map[-msg.req]
|
_, handler = self._event_map[-msg.req]
|
||||||
await handler.process(msg)
|
await handler.process(msg)
|
||||||
logger.info("RESPONSE [%d]: %r", -msg.req, msg)
|
logger.info("RESPONSE [%d]: %r", -msg.req, msg)
|
||||||
if msg.end_err:
|
if msg.end_err:
|
||||||
@ -200,7 +237,11 @@ class PacketStream(object):
|
|||||||
logger.debug("WRITE HDR: %s", header)
|
logger.debug("WRITE HDR: %s", header)
|
||||||
logger.debug("WRITE DATA: %s", msg.data)
|
logger.debug("WRITE DATA: %s", msg.data)
|
||||||
|
|
||||||
def send(self, data, msg_type=PSMessageType.JSON, stream=False, end_err=False, req=None):
|
def send( # pylint: disable=too-many-arguments
|
||||||
|
self, data, msg_type=PSMessageType.JSON, stream=False, end_err=False, req=None
|
||||||
|
):
|
||||||
|
"""Send data through the packet stream"""
|
||||||
|
|
||||||
update_counter = False
|
update_counter = False
|
||||||
if req is None:
|
if req is None:
|
||||||
update_counter = True
|
update_counter = True
|
||||||
@ -222,5 +263,7 @@ class PacketStream(object):
|
|||||||
return handler
|
return handler
|
||||||
|
|
||||||
def disconnect(self):
|
def disconnect(self):
|
||||||
|
"""Disconnect the stream"""
|
||||||
|
|
||||||
self._connected = False
|
self._connected = False
|
||||||
self.connection.disconnect()
|
self.connection.disconnect()
|
||||||
|
10
ssb/util.py
10
ssb/util.py
@ -20,6 +20,8 @@
|
|||||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
# SOFTWARE.
|
# SOFTWARE.
|
||||||
|
|
||||||
|
"""Utility functions"""
|
||||||
|
|
||||||
from base64 import b64decode, b64encode
|
from base64 import b64decode, b64encode
|
||||||
import os
|
import os
|
||||||
|
|
||||||
@ -28,17 +30,19 @@ import yaml
|
|||||||
|
|
||||||
|
|
||||||
class ConfigException(Exception):
|
class ConfigException(Exception):
|
||||||
pass
|
"""Exception to raise if there is a problem with the configuration data"""
|
||||||
|
|
||||||
|
|
||||||
def tag(key):
|
def tag(key):
|
||||||
"""Create tag from publick key."""
|
"""Create tag from public key"""
|
||||||
|
|
||||||
return b"@" + b64encode(bytes(key)) + b".ed25519"
|
return b"@" + b64encode(bytes(key)) + b".ed25519"
|
||||||
|
|
||||||
|
|
||||||
def load_ssb_secret():
|
def load_ssb_secret():
|
||||||
"""Load SSB keys from ~/.ssb"""
|
"""Load SSB keys from ~/.ssb"""
|
||||||
with open(os.path.expanduser("~/.ssb/secret")) as f:
|
|
||||||
|
with open(os.path.expanduser("~/.ssb/secret"), encoding="utf-8") as f:
|
||||||
config = yaml.load(f, Loader=yaml.SafeLoader)
|
config = yaml.load(f, Loader=yaml.SafeLoader)
|
||||||
|
|
||||||
if config["curve"] != "ed25519":
|
if config["curve"] != "ed25519":
|
||||||
|
@ -20,6 +20,8 @@
|
|||||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
# SOFTWARE.
|
# SOFTWARE.
|
||||||
|
|
||||||
|
"""Tests for the feed functionality"""
|
||||||
|
|
||||||
from base64 import b64decode
|
from base64 import b64decode
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
|
|
||||||
@ -46,17 +48,23 @@ SERIALIZED_M1 = b"""{
|
|||||||
|
|
||||||
@pytest.fixture()
|
@pytest.fixture()
|
||||||
def local_feed():
|
def local_feed():
|
||||||
|
"""Fixture providing a local feed"""
|
||||||
|
|
||||||
secret = b64decode("Mz2qkNOP2K6upnqibWrR+z8pVUI1ReA1MLc7QMtF2qQ=")
|
secret = b64decode("Mz2qkNOP2K6upnqibWrR+z8pVUI1ReA1MLc7QMtF2qQ=")
|
||||||
return LocalFeed(SigningKey(secret))
|
return LocalFeed(SigningKey(secret))
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture()
|
@pytest.fixture()
|
||||||
def remote_feed():
|
def remote_feed():
|
||||||
|
"""Fixture providing a remote feed"""
|
||||||
|
|
||||||
public = b64decode("I/4cyN/jPBbDsikbHzAEvmaYlaJK33lW3UhWjNXjyrU=")
|
public = b64decode("I/4cyN/jPBbDsikbHzAEvmaYlaJK33lW3UhWjNXjyrU=")
|
||||||
return Feed(VerifyKey(public))
|
return Feed(VerifyKey(public))
|
||||||
|
|
||||||
|
|
||||||
def test_local_feed():
|
def test_local_feed():
|
||||||
|
"""Test a local feed"""
|
||||||
|
|
||||||
secret = b64decode("Mz2qkNOP2K6upnqibWrR+z8pVUI1ReA1MLc7QMtF2qQ=")
|
secret = b64decode("Mz2qkNOP2K6upnqibWrR+z8pVUI1ReA1MLc7QMtF2qQ=")
|
||||||
feed = LocalFeed(SigningKey(secret))
|
feed = LocalFeed(SigningKey(secret))
|
||||||
assert bytes(feed.private_key) == secret
|
assert bytes(feed.private_key) == secret
|
||||||
@ -65,6 +73,8 @@ def test_local_feed():
|
|||||||
|
|
||||||
|
|
||||||
def test_remote_feed():
|
def test_remote_feed():
|
||||||
|
"""Test a remote feed"""
|
||||||
|
|
||||||
public = b64decode("I/4cyN/jPBbDsikbHzAEvmaYlaJK33lW3UhWjNXjyrU=")
|
public = b64decode("I/4cyN/jPBbDsikbHzAEvmaYlaJK33lW3UhWjNXjyrU=")
|
||||||
feed = Feed(VerifyKey(public))
|
feed = Feed(VerifyKey(public))
|
||||||
assert bytes(feed.public_key) == public
|
assert bytes(feed.public_key) == public
|
||||||
@ -88,7 +98,9 @@ def test_remote_feed():
|
|||||||
feed.sign(m1)
|
feed.sign(m1)
|
||||||
|
|
||||||
|
|
||||||
def test_local_message(local_feed):
|
def test_local_message(local_feed): # pylint: disable=redefined-outer-name
|
||||||
|
"""Test a local message"""
|
||||||
|
|
||||||
m1 = LocalMessage(
|
m1 = LocalMessage(
|
||||||
local_feed,
|
local_feed,
|
||||||
OrderedDict(
|
OrderedDict(
|
||||||
@ -133,7 +145,9 @@ def test_local_message(local_feed):
|
|||||||
assert m2.key == "%nx13uks5GUwuKJC49PfYGMS/1pgGTtwwdWT7kbVaroM=.sha256"
|
assert m2.key == "%nx13uks5GUwuKJC49PfYGMS/1pgGTtwwdWT7kbVaroM=.sha256"
|
||||||
|
|
||||||
|
|
||||||
def test_remote_message(remote_feed):
|
def test_remote_message(remote_feed): # pylint: disable=redefined-outer-name
|
||||||
|
"""Test a remote message"""
|
||||||
|
|
||||||
signature = "lPsQ9P10OgeyH6u0unFgiI2wV/RQ7Q2x2ebxnXYCzsJ055TBMXphRADTKhOMS2EkUxXQ9k3amj5fnWPudGxwBQ==.sig.ed25519"
|
signature = "lPsQ9P10OgeyH6u0unFgiI2wV/RQ7Q2x2ebxnXYCzsJ055TBMXphRADTKhOMS2EkUxXQ9k3amj5fnWPudGxwBQ==.sig.ed25519"
|
||||||
m1 = Message(
|
m1 = Message(
|
||||||
remote_feed,
|
remote_feed,
|
||||||
@ -177,7 +191,9 @@ def test_remote_message(remote_feed):
|
|||||||
assert m2.key == "%nx13uks5GUwuKJC49PfYGMS/1pgGTtwwdWT7kbVaroM=.sha256"
|
assert m2.key == "%nx13uks5GUwuKJC49PfYGMS/1pgGTtwwdWT7kbVaroM=.sha256"
|
||||||
|
|
||||||
|
|
||||||
def test_remote_no_signature(remote_feed):
|
def test_remote_no_signature(remote_feed): # pylint: disable=redefined-outer-name
|
||||||
|
"""Test remote feed without a signature"""
|
||||||
|
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
Message(
|
Message(
|
||||||
remote_feed,
|
remote_feed,
|
||||||
@ -194,7 +210,9 @@ def test_remote_no_signature(remote_feed):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_serialize(local_feed):
|
def test_serialize(local_feed): # pylint: disable=redefined-outer-name
|
||||||
|
"""Test feed serialization"""
|
||||||
|
|
||||||
m1 = LocalMessage(
|
m1 = LocalMessage(
|
||||||
local_feed,
|
local_feed,
|
||||||
OrderedDict(
|
OrderedDict(
|
||||||
@ -211,7 +229,9 @@ def test_serialize(local_feed):
|
|||||||
assert m1.serialize() == SERIALIZED_M1
|
assert m1.serialize() == SERIALIZED_M1
|
||||||
|
|
||||||
|
|
||||||
def test_parse(local_feed):
|
def test_parse(local_feed): # pylint: disable=redefined-outer-name
|
||||||
|
"""Test feed parsing"""
|
||||||
|
|
||||||
m1 = LocalMessage.parse(SERIALIZED_M1, local_feed)
|
m1 = LocalMessage.parse(SERIALIZED_M1, local_feed)
|
||||||
assert m1.content == {
|
assert m1.content == {
|
||||||
"type": "about",
|
"type": "about",
|
||||||
|
@ -20,10 +20,11 @@
|
|||||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
# SOFTWARE.
|
# SOFTWARE.
|
||||||
|
|
||||||
|
"""Tests for the packet stream"""
|
||||||
|
|
||||||
from asyncio import Event, ensure_future, gather
|
from asyncio import Event, ensure_future, gather
|
||||||
import json
|
import json
|
||||||
|
|
||||||
from nacl.signing import SigningKey
|
|
||||||
import pytest
|
import pytest
|
||||||
from secret_handshake.network import SHSDuplexStream
|
from secret_handshake.network import SHSDuplexStream
|
||||||
|
|
||||||
@ -58,63 +59,94 @@ MSG_BODY_2 = (
|
|||||||
|
|
||||||
|
|
||||||
class MockSHSSocket(SHSDuplexStream):
|
class MockSHSSocket(SHSDuplexStream):
|
||||||
def __init__(self, *args, **kwargs):
|
"""A mocked SHS socket"""
|
||||||
super(MockSHSSocket, self).__init__()
|
|
||||||
|
def __init__(self, *args, **kwargs): # pylint: disable=unused-argument
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
self.input = []
|
self.input = []
|
||||||
self.output = []
|
self.output = []
|
||||||
self.is_connected = False
|
self.is_connected = False
|
||||||
self._on_connect = []
|
self._on_connect = []
|
||||||
|
|
||||||
def on_connect(self, cb):
|
def on_connect(self, cb):
|
||||||
|
"""Set the on_connect callback"""
|
||||||
|
|
||||||
self._on_connect.append(cb)
|
self._on_connect.append(cb)
|
||||||
|
|
||||||
async def read(self):
|
async def read(self):
|
||||||
|
"""Read data from the socket"""
|
||||||
|
|
||||||
if not self.input:
|
if not self.input:
|
||||||
raise StopAsyncIteration
|
raise StopAsyncIteration
|
||||||
return self.input.pop(0)
|
return self.input.pop(0)
|
||||||
|
|
||||||
def write(self, data):
|
def write(self, data):
|
||||||
|
"""Write data to the socket"""
|
||||||
|
|
||||||
self.output.append(data)
|
self.output.append(data)
|
||||||
|
|
||||||
def feed(self, input):
|
def feed(self, input_):
|
||||||
self.input += input
|
"""Get the connection’s feed"""
|
||||||
|
|
||||||
|
self.input += input_
|
||||||
|
|
||||||
def get_output(self):
|
def get_output(self):
|
||||||
|
"""Get the output of a call"""
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
if not self.output:
|
if not self.output:
|
||||||
break
|
break
|
||||||
yield self.output.pop(0)
|
yield self.output.pop(0)
|
||||||
|
|
||||||
def disconnect(self):
|
def disconnect(self):
|
||||||
|
"""Disconnect from the remote party"""
|
||||||
|
|
||||||
self.is_connected = False
|
self.is_connected = False
|
||||||
|
|
||||||
|
|
||||||
class MockSHSClient(MockSHSSocket):
|
class MockSHSClient(MockSHSSocket):
|
||||||
|
"""A mocked SHS client"""
|
||||||
|
|
||||||
async def connect(self):
|
async def connect(self):
|
||||||
|
"""Connect to a SHS server"""
|
||||||
|
|
||||||
self.is_connected = True
|
self.is_connected = True
|
||||||
|
|
||||||
for cb in self._on_connect:
|
for cb in self._on_connect:
|
||||||
await cb()
|
await cb()
|
||||||
|
|
||||||
|
|
||||||
class MockSHSServer(MockSHSSocket):
|
class MockSHSServer(MockSHSSocket):
|
||||||
|
"""A mocked SHS server"""
|
||||||
|
|
||||||
def listen(self):
|
def listen(self):
|
||||||
|
"""Listen for new connections"""
|
||||||
|
|
||||||
self.is_connected = True
|
self.is_connected = True
|
||||||
|
|
||||||
for cb in self._on_connect:
|
for cb in self._on_connect:
|
||||||
ensure_future(cb())
|
ensure_future(cb())
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def ps_client(event_loop):
|
def ps_client(event_loop): # pylint: disable=unused-argument
|
||||||
|
"""Fixture to provide a mocked SHS client"""
|
||||||
|
|
||||||
return MockSHSClient()
|
return MockSHSClient()
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def ps_server(event_loop):
|
def ps_server(event_loop): # pylint: disable=unused-argument
|
||||||
|
"""Fixture to provide a mocked SHS server"""
|
||||||
|
|
||||||
return MockSHSServer()
|
return MockSHSServer()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_on_connect(ps_server):
|
async def test_on_connect(ps_server): # pylint: disable=redefined-outer-name
|
||||||
|
"""Test the on_connect callback functionality"""
|
||||||
|
|
||||||
called = Event()
|
called = Event()
|
||||||
|
|
||||||
async def _on_connect():
|
async def _on_connect():
|
||||||
@ -127,7 +159,9 @@ async def test_on_connect(ps_server):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_message_decoding(ps_client):
|
async def test_message_decoding(ps_client): # pylint: disable=redefined-outer-name
|
||||||
|
"""Test message decoding"""
|
||||||
|
|
||||||
await ps_client.connect()
|
await ps_client.connect()
|
||||||
|
|
||||||
ps = PacketStream(ps_client)
|
ps = PacketStream(ps_client)
|
||||||
@ -160,7 +194,9 @@ async def test_message_decoding(ps_client):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_message_encoding(ps_client):
|
async def test_message_encoding(ps_client): # pylint: disable=redefined-outer-name
|
||||||
|
"""Test message encoding"""
|
||||||
|
|
||||||
await ps_client.connect()
|
await ps_client.connect()
|
||||||
|
|
||||||
ps = PacketStream(ps_client)
|
ps = PacketStream(ps_client)
|
||||||
@ -201,7 +237,9 @@ async def test_message_encoding(ps_client):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_message_stream(ps_client, mocker):
|
async def test_message_stream(ps_client, mocker): # pylint: disable=redefined-outer-name
|
||||||
|
"""Test requesting a history stream"""
|
||||||
|
|
||||||
await ps_client.connect()
|
await ps_client.connect()
|
||||||
|
|
||||||
ps = PacketStream(ps_client)
|
ps = PacketStream(ps_client)
|
||||||
@ -226,8 +264,8 @@ async def test_message_stream(ps_client, mocker):
|
|||||||
)
|
)
|
||||||
|
|
||||||
assert ps.req_counter == 2
|
assert ps.req_counter == 2
|
||||||
assert ps.register_handler.call_count == 1
|
assert ps.register_handler.call_count == 1 # pylint: disable=no-member
|
||||||
handler = list(ps._event_map.values())[0][1]
|
handler = list(ps._event_map.values())[0][1] # pylint: disable=protected-access
|
||||||
mock_process = mocker.AsyncMock()
|
mock_process = mocker.AsyncMock()
|
||||||
|
|
||||||
mocker.patch.object(handler, "process", mock_process)
|
mocker.patch.object(handler, "process", mock_process)
|
||||||
@ -259,8 +297,8 @@ async def test_message_stream(ps_client, mocker):
|
|||||||
)
|
)
|
||||||
|
|
||||||
assert ps.req_counter == 3
|
assert ps.req_counter == 3
|
||||||
assert ps.register_handler.call_count == 2
|
assert ps.register_handler.call_count == 2 # pylint: disable=no-member
|
||||||
handler = list(ps._event_map.values())[1][1]
|
handler = list(ps._event_map.values())[1][1] # pylint: disable=protected-access
|
||||||
|
|
||||||
mock_process = mocker.patch.object(handler, "process", wraps=handler.process)
|
mock_process = mocker.patch.object(handler, "process", wraps=handler.process)
|
||||||
ps_client.feed(
|
ps_client.feed(
|
||||||
@ -286,7 +324,9 @@ async def test_message_stream(ps_client, mocker):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_message_request(ps_server, mocker):
|
async def test_message_request(ps_server, mocker): # pylint: disable=redefined-outer-name
|
||||||
|
"""Test message sending"""
|
||||||
|
|
||||||
ps_server.listen()
|
ps_server.listen()
|
||||||
|
|
||||||
ps = PacketStream(ps_server)
|
ps = PacketStream(ps_server)
|
||||||
@ -300,8 +340,8 @@ async def test_message_request(ps_server, mocker):
|
|||||||
assert json.loads(body.decode("utf-8")) == {"name": ["whoami"], "args": []}
|
assert json.loads(body.decode("utf-8")) == {"name": ["whoami"], "args": []}
|
||||||
|
|
||||||
assert ps.req_counter == 2
|
assert ps.req_counter == 2
|
||||||
assert ps.register_handler.call_count == 1
|
assert ps.register_handler.call_count == 1 # pylint: disable=no-member
|
||||||
handler = list(ps._event_map.values())[0][1]
|
handler = list(ps._event_map.values())[0][1] # pylint: disable=protected-access
|
||||||
mock_process = mocker.AsyncMock()
|
mock_process = mocker.AsyncMock()
|
||||||
|
|
||||||
mocker.patch.object(handler, "process", mock_process)
|
mocker.patch.object(handler, "process", mock_process)
|
||||||
|
@ -20,6 +20,8 @@
|
|||||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
# SOFTWARE.
|
# SOFTWARE.
|
||||||
|
|
||||||
|
"""Tests for the utility functions"""
|
||||||
|
|
||||||
from base64 import b64decode
|
from base64 import b64decode
|
||||||
from unittest.mock import mock_open, patch
|
from unittest.mock import mock_open, patch
|
||||||
|
|
||||||
@ -41,6 +43,8 @@ CONFIG_FILE_INVALID = CONFIG_FILE.replace("ed25519", "foo")
|
|||||||
|
|
||||||
|
|
||||||
def test_load_secret():
|
def test_load_secret():
|
||||||
|
"""Test loading the SSB secret from a file"""
|
||||||
|
|
||||||
with patch("ssb.util.open", mock_open(read_data=CONFIG_FILE), create=True):
|
with patch("ssb.util.open", mock_open(read_data=CONFIG_FILE), create=True):
|
||||||
secret = load_ssb_secret()
|
secret = load_ssb_secret()
|
||||||
|
|
||||||
@ -52,6 +56,8 @@ def test_load_secret():
|
|||||||
|
|
||||||
|
|
||||||
def test_load_exception():
|
def test_load_exception():
|
||||||
|
"""Test configuration loading if there is a problem with the file"""
|
||||||
|
|
||||||
with pytest.raises(ConfigException):
|
with pytest.raises(ConfigException):
|
||||||
with patch("ssb.util.open", mock_open(read_data=CONFIG_FILE_INVALID), create=True):
|
with patch("ssb.util.open", mock_open(read_data=CONFIG_FILE_INVALID), create=True):
|
||||||
load_ssb_secret()
|
load_ssb_secret()
|
||||||
|
Loading…
Reference in New Issue
Block a user