diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index db88ec5..8d45122 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -24,3 +24,9 @@ repos: language: system require_serial: true types_or: [python, pyi] + - id: pylint + name: pylint + entry: poetry run pylint + language: system + types: [python] + require_serial: true diff --git a/examples/test_client.py b/examples/test_client.py index bb7404d..8c0939c 100644 --- a/examples/test_client.py +++ b/examples/test_client.py @@ -1,3 +1,7 @@ +"""Example SSB Client""" + +import base64 +import hashlib import logging import struct import time @@ -10,26 +14,28 @@ from ssb.muxrpc import MuxRPCAPI, MuxRPCAPIException from ssb.packet_stream import PacketStream, PSMessageType from ssb.util import load_ssb_secret -import hashlib -import base64 - - api = MuxRPCAPI() @api.define("createHistoryStream") -def create_history_stream(connection, msg): +def create_history_stream(connection, msg): # pylint: disable=unused-argument + """Handle the createHistoryStream RPC call""" + print("create_history_stream", msg) # msg = PSMessage(PSMessageType.JSON, True, stream=True, end_err=True, req=-req) # connection.write(msg) @api.define("blobs.createWants") -def create_wants(connection, msg): +def create_wants(connection, msg): # pylint: disable=unused-argument + """Handle the createWants RPC call""" + print("create_wants", msg) async def test_client(): + """The actual client implementation""" + async for msg in api.call( "createHistoryStream", [{"id": "@1+Iwm79DKvVBqYKFkhT6fWRbAVvNNVH4F2BSxwhYmx8=.ed25519", "seq": 1, "live": False, "keys": False}], @@ -63,7 +69,9 @@ async def test_client(): f.write(img_data) -async def main(): +async def main(keypair): + """The main function to run""" + client = SHSClient("127.0.0.1", 8008, keypair, bytes(keypair.verify_key)) packet_stream = PacketStream(client) await client.open() @@ -89,8 +97,8 @@ if __name__ == "__main__": logger.setLevel(logging.INFO) logger.addHandler(ch) - keypair = load_ssb_secret()["keypair"] + ssb_keypair = load_ssb_secret()["keypair"] loop = get_event_loop() - loop.run_until_complete(main()) + loop.run_until_complete(main(ssb_keypair)) loop.close() diff --git a/examples/test_server.py b/examples/test_server.py index 4527791..b91b3f6 100644 --- a/examples/test_server.py +++ b/examples/test_server.py @@ -1,5 +1,7 @@ +"""Test SSB server""" + import logging -from asyncio import gather, get_event_loop, ensure_future +from asyncio import get_event_loop from colorlog import ColoredFormatter @@ -12,6 +14,8 @@ api = MuxRPCAPI() async def on_connect(conn): + """Incoming connection handler""" + packet_stream = PacketStream(conn) api.add_connection(packet_stream) @@ -22,6 +26,7 @@ async def on_connect(conn): async def main(): + """The main function to run""" server = SHSServer("127.0.0.1", 8008, load_ssb_secret()["keypair"]) server.on_connect(on_connect) await server.listen() diff --git a/poetry.lock b/poetry.lock index 72529fd..a219809 100644 --- a/poetry.lock +++ b/poetry.lock @@ -25,6 +25,20 @@ files = [ [package.extras] test = ["coverage", "mypy", "pexpect", "ruff", "wheel"] +[[package]] +name = "astroid" +version = "3.0.1" +description = "An abstract syntax tree for Python with inference support." +optional = false +python-versions = ">=3.8.0" +files = [ + {file = "astroid-3.0.1-py3-none-any.whl", hash = "sha256:7d5895c9825e18079c5aeac0572bc2e4c83205c95d416e0b4fee8bc361d2d9ca"}, + {file = "astroid-3.0.1.tar.gz", hash = "sha256:86b0bb7d7da0be1a7c4aedb7974e391b32d4ed89e33de6ed6902b4b15c97577e"}, +] + +[package.dependencies] +typing-extensions = {version = ">=4.0.0", markers = "python_version < \"3.11\""} + [[package]] name = "babel" version = "2.13.1" @@ -419,6 +433,20 @@ files = [ {file = "decli-0.6.1.tar.gz", hash = "sha256:ed88ccb947701e8e5509b7945fda56e150e2ac74a69f25d47ac85ef30ab0c0f0"}, ] +[[package]] +name = "dill" +version = "0.3.7" +description = "serialize all of Python" +optional = false +python-versions = ">=3.7" +files = [ + {file = "dill-0.3.7-py3-none-any.whl", hash = "sha256:76b122c08ef4ce2eedcd4d1abd8e641114bfc6c2867f49f3c41facf65bf19f5e"}, + {file = "dill-0.3.7.tar.gz", hash = "sha256:cc1c8b182eb3013e24bd475ff2e9295af86c1a38eb1aff128dac8962a9ce3c03"}, +] + +[package.extras] +graph = ["objgraph (>=1.7.2)"] + [[package]] name = "docutils" version = "0.17.1" @@ -599,6 +627,17 @@ files = [ {file = "MarkupSafe-2.1.3.tar.gz", hash = "sha256:af598ed32d6ae86f1b747b82783958b1a4ab8f617b06fe68795c7f026abbdcad"}, ] +[[package]] +name = "mccabe" +version = "0.7.0" +description = "McCabe checker, plugin for flake8" +optional = false +python-versions = ">=3.6" +files = [ + {file = "mccabe-0.7.0-py2.py3-none-any.whl", hash = "sha256:6c2d30ab6be0e4a46919781807b4f0d834ebdd6c6e3dca0bda5a15f863427b6e"}, + {file = "mccabe-0.7.0.tar.gz", hash = "sha256:348e0240c33b60bbdf4e523192ef919f28cb2c3d7d5c7794f74009290f236325"}, +] + [[package]] name = "mypy-extensions" version = "1.0.0" @@ -712,6 +751,36 @@ files = [ [package.extras] plugins = ["importlib-metadata"] +[[package]] +name = "pylint" +version = "3.0.2" +description = "python code static checker" +optional = false +python-versions = ">=3.8.0" +files = [ + {file = "pylint-3.0.2-py3-none-any.whl", hash = "sha256:60ed5f3a9ff8b61839ff0348b3624ceeb9e6c2a92c514d81c9cc273da3b6bcda"}, + {file = "pylint-3.0.2.tar.gz", hash = "sha256:0d4c286ef6d2f66c8bfb527a7f8a629009e42c99707dec821a03e1b51a4c1496"}, +] + +[package.dependencies] +astroid = ">=3.0.1,<=3.1.0-dev0" +colorama = {version = ">=0.4.5", markers = "sys_platform == \"win32\""} +dill = [ + {version = ">=0.2", markers = "python_version < \"3.11\""}, + {version = ">=0.3.7", markers = "python_version >= \"3.12\""}, + {version = ">=0.3.6", markers = "python_version >= \"3.11\" and python_version < \"3.12\""}, +] +isort = ">=4.2.5,<6" +mccabe = ">=0.6,<0.8" +platformdirs = ">=2.2.0" +tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""} +tomlkit = ">=0.10.1" +typing-extensions = {version = ">=3.10.0", markers = "python_version < \"3.10\""} + +[package.extras] +spelling = ["pyenchant (>=3.2,<4.0)"] +testutils = ["gitpython (>3)"] + [[package]] name = "pynacl" version = "1.5.0" @@ -1200,4 +1269,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "a69623d229f05becfdd7a18072ae96970994ceebdc193d7840aa704ba0d86169" +content-hash = "63b3d6f54c99a6722a3d0d5cf9eac68bdb5ef0ea7c58957dd76494529870186c" diff --git a/pyproject.toml b/pyproject.toml index a96d327..6a3078c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,6 +25,7 @@ pytest-cov = "^4.1.0" pytest-mock = "^3.12.0" commitizen = "^3.12.0" black = "^23.10.1" +pylint = "^3.0.2" [tool.poetry.group.docs.dependencies] Sphinx = "^2.1.1" @@ -41,6 +42,9 @@ skip_covered = true fail_under = 70 omit = ["examples/*"] +[tool.pylint.format] +max-line-length = 120 + [tool.pytest.ini_options] addopts = ["--cov=.", "--no-cov-on-fail"] python_files = ["tests/test_*.py"] diff --git a/ssb/feed/__init__.py b/ssb/feed/__init__.py index b8c7cb6..ccbf48f 100644 --- a/ssb/feed/__init__.py +++ b/ssb/feed/__init__.py @@ -1,3 +1,5 @@ +"""Feed related functionality""" + from .models import Feed, LocalFeed, Message, LocalMessage, NoPrivateKeyException __all__ = ("Feed", "LocalFeed", "Message", "LocalMessage", "NoPrivateKeyException") diff --git a/ssb/feed/models.py b/ssb/feed/models.py index 816e86f..9344f5b 100644 --- a/ssb/feed/models.py +++ b/ssb/feed/models.py @@ -1,3 +1,5 @@ +"""Feed models""" + import datetime from base64 import b64encode from collections import namedtuple, OrderedDict @@ -12,44 +14,65 @@ OrderedMsg = namedtuple("OrderedMsg", ("previous", "author", "sequence", "timest class NoPrivateKeyException(Exception): - pass + """Exception to raise when a private key is not available""" def to_ordered(data): + """Convert a dictionary to an ``OrderedDict``""" + smsg = OrderedMsg(**data) + return OrderedDict((k, getattr(smsg, k)) for k in smsg._fields) def get_millis_1970(): + """Get the UNIX timestamp in milliseconds""" + return int(datetime.datetime.utcnow().timestamp() * 1000) -class Feed(object): +class Feed: + """Base class for feeds""" + def __init__(self, public_key): self.public_key = public_key @property def id(self): + """The identifier of the feed""" + return tag(self.public_key).decode("ascii") def sign(self, msg): + """Sign a message""" + raise NoPrivateKeyException("Cannot use remote identity to sign (no private key!)") class LocalFeed(Feed): - def __init__(self, private_key): + """Class representing a local feed""" + + def __init__(self, private_key): # pylint: disable=super-init-not-called self.private_key = private_key @property def public_key(self): + """The public key of the feed""" + return self.private_key.verify_key def sign(self, msg): + """Sign a message for this feed""" + return self.private_key.sign(msg).signature -class Message(object): - def __init__(self, feed, content, signature, sequence=1, timestamp=None, previous=None): +class Message: + """Base class for SSB messages""" + + def __init__( # pylint: disable=too-many-arguments + self, feed, content, signature=None, sequence=1, timestamp=None, previous=None + ): self.feed = feed self.content = content @@ -68,14 +91,21 @@ class Message(object): @classmethod def parse(cls, data, feed): + """Parse raw message data""" + obj = loads(data, object_pairs_hook=OrderedDict) msg = cls(feed, obj["content"], timestamp=obj["timestamp"]) + return msg def serialize(self, add_signature=True): + """Serialize the message""" + return dumps(self.to_dict(add_signature=add_signature), indent=2).encode("utf-8") def to_dict(self, add_signature=True): + """Convert the message to a dictionary""" + obj = to_ordered( { "previous": self.previous.key if self.previous else None, @@ -89,23 +119,34 @@ class Message(object): if add_signature: obj["signature"] = self.signature + return obj def verify(self, signature): + """Verify the signature of the message""" + return self.signature == signature @property def hash(self): - hash = sha256(self.serialize()).digest() - return b64encode(hash).decode("ascii") + ".sha256" + """The cryptographic hash of the message""" + + hash_ = sha256(self.serialize()).digest() + return b64encode(hash_).decode("ascii") + ".sha256" @property def key(self): + """The key of the message""" + return "%" + self.hash class LocalMessage(Message): - def __init__(self, feed, content, signature=None, sequence=1, timestamp=None, previous=None): + """Class representing a local message""" + + def __init__( # pylint: disable=too-many-arguments,super-init-not-called + self, feed, content, signature=None, sequence=1, timestamp=None, previous=None + ): self.feed = feed self.content = content diff --git a/ssb/muxrpc.py b/ssb/muxrpc.py index a86894a..57f3d04 100644 --- a/ssb/muxrpc.py +++ b/ssb/muxrpc.py @@ -1,14 +1,20 @@ +"""MuxRPC""" + from functools import wraps from ssb.packet_stream import PSMessageType class MuxRPCAPIException(Exception): - pass + """Exception to raise on MuxRPC API errors""" -class MuxRPCHandler: +class MuxRPCHandler: # pylint: disable=too-few-public-methods + """Base MuxRPC handler class""" + def check_message(self, msg): + """Check message validity""" + body = msg.body if isinstance(body, dict) and "name" in body and body["name"] == "Error": @@ -16,6 +22,8 @@ class MuxRPCHandler: class MuxRPCRequestHandler(MuxRPCHandler): + """Base class for MuxRPC request handlers""" + def __init__(self, ps_handler): self.ps_handler = ps_handler @@ -26,6 +34,8 @@ class MuxRPCRequestHandler(MuxRPCHandler): class MuxRPCSourceHandler(MuxRPCHandler): + """MuxRPC handler for sources""" + def __init__(self, ps_handler): self.ps_handler = ps_handler @@ -39,39 +49,60 @@ class MuxRPCSourceHandler(MuxRPCHandler): return msg -class MuxRPCSinkHandlerMixin(object): +class MuxRPCSinkHandlerMixin: # pylint: disable=too-few-public-methods + """Mixin for sink-type MuxRPC handlers""" + def send(self, msg, msg_type=PSMessageType.JSON, end=False): + """Send a message through the stream""" + self.connection.send(msg, stream=True, msg_type=msg_type, req=self.req, end_err=end) class MuxRPCDuplexHandler(MuxRPCSinkHandlerMixin, MuxRPCSourceHandler): + """MuxRPC handler for duplex streams""" + def __init__(self, ps_handler, connection, req): - super(MuxRPCDuplexHandler, self).__init__(ps_handler) + super().__init__(ps_handler) + self.connection = connection self.req = req class MuxRPCSinkHandler(MuxRPCHandler, MuxRPCSinkHandlerMixin): + """MuxRPC handler for sinks""" + def __init__(self, connection, req): self.connection = connection self.req = req def _get_appropriate_api_handler(type_, connection, ps_handler, req): + """Find the appropriate MuxRPC handler""" + if type_ in {"sync", "async"}: return MuxRPCRequestHandler(ps_handler) - elif type_ == "source": + + if type_ == "source": return MuxRPCSourceHandler(ps_handler) - elif type_ == "sink": + + if type_ == "sink": return MuxRPCSinkHandler(connection, req) - elif type_ == "duplex": + + if type_ == "duplex": return MuxRPCDuplexHandler(ps_handler, connection, req) + return None + + +class MuxRPCRequest: + """MuxRPC request""" -class MuxRPCRequest(object): @classmethod def from_message(cls, message): + """Initialise a request from a raw packet stream message""" + body = message.body + return cls(".".join(body["name"]), body["args"]) def __init__(self, name, args): @@ -79,22 +110,28 @@ class MuxRPCRequest(object): self.args = args def __repr__(self): - return "".format(self) + return f"" -class MuxRPCMessage(object): +class MuxRPCMessage: + """MuxRPC message""" + @classmethod def from_message(cls, message): + """Initialise a MuxRPC message from a raw packet stream message""" + return cls(message.body) def __init__(self, body): self.body = body def __repr__(self): - return "".format(self) + return f"" -class MuxRPCAPI(object): +class MuxRPCAPI: + """Generit MuxRPC API""" + def __init__(self): self.handlers = {} self.connection = None @@ -109,9 +146,13 @@ class MuxRPCAPI(object): self.process(self.connection, MuxRPCRequest.from_message(req_message)) def add_connection(self, connection): + """Set the packet stream connection of this RPC API""" + self.connection = connection def define(self, name): + """Decorator to define an RPC method handler""" + def _handle(f): self.handlers[name] = f @@ -124,14 +165,20 @@ class MuxRPCAPI(object): return _handle def process(self, connection, request): + """Process an incoming request""" + handler = self.handlers.get(request.name) + if not handler: - raise MuxRPCAPIException("Method {} not found!".format(request.name)) + raise MuxRPCAPIException(f"Method {request.name} not found!") + handler(connection, request) def call(self, name, args, type_="sync"): + """Call an RPC method""" + if not self.connection.is_connected: - raise Exception("not connected") + raise Exception("not connected") # pylint: disable=broad-exception-raised old_counter = self.connection.req_counter ps_handler = self.connection.send( diff --git a/ssb/packet_stream.py b/ssb/packet_stream.py index 59b4a7b..969f7c8 100644 --- a/ssb/packet_stream.py +++ b/ssb/packet_stream.py @@ -1,3 +1,5 @@ +"""Packet streams""" + import logging import struct from asyncio import Event, Queue @@ -7,28 +9,34 @@ from math import ceil import simplejson -from secret_handshake import SHSClient, SHSServer - logger = logging.getLogger("packet_stream") class PSMessageType(Enum): + """Available message types""" + BUFFER = 0 TEXT = 1 JSON = 2 -class PSStreamHandler(object): +class PSStreamHandler: + """Packet stream handler""" + def __init__(self, req): super(PSStreamHandler).__init__() self.req = req self.queue = Queue() async def process(self, msg): + """Process a pending message""" + await self.queue.put(msg) async def stop(self): + """Stop a pending request""" + await self.queue.put(None) def __aiter__(self): @@ -43,30 +51,39 @@ class PSStreamHandler(object): return elem -class PSRequestHandler(object): +class PSRequestHandler: + """Packet stream request handler""" + def __init__(self, req): - super(PSRequestHandler).__init__() self.req = req self.event = Event() self._msg = None async def process(self, msg): + """Process a message request""" + self._msg = msg self.event.set() async def stop(self): + """Stop a pending event request""" + if not self.event.is_set(): self.event.set() - def __await__(self): + async def __await__(self): # wait until 'process' is called - yield from self.event.wait().__await__() + await self.event.wait() + return self._msg -class PSMessage(object): +class PSMessage: + """Packet Stream message""" + @classmethod def from_header_body(cls, flags, req, body): + """Parse a raw message""" type_ = PSMessageType(flags & 0x03) if type_ == PSMessageType.TEXT: @@ -78,13 +95,17 @@ class PSMessage(object): @property def data(self): + """The raw message data""" + if self.type == PSMessageType.TEXT: return self.body.encode("utf-8") - elif self.type == PSMessageType.JSON: + + if self.type == PSMessageType.JSON: return simplejson.dumps(self.body).encode("utf-8") + return self.body - def __init__(self, type_, body, stream, end_err, req=None): + def __init__(self, type_, body, stream, end_err, req=None): # pylint: disable=too-many-arguments self.stream = stream self.end_err = end_err self.type = type_ @@ -93,29 +114,35 @@ class PSMessage(object): def __repr__(self): if self.type == PSMessageType.BUFFER: - body = "{} bytes".format(len(self.body)) + body = f"{len(self.body)} bytes" else: body = self.body - return "".format( - self.type.name, - body, - "" if self.req is None else " [{}]".format(self.req), - "~" if self.stream else "", - "!" if self.end_err else "", - ) + + req = "" if self.req is None else f" [{self.req}]" + is_stream = "~" if self.stream else "" + err = "!" if self.end_err else "" + + return f"" -class PacketStream(object): +class PacketStream: + """SSB Packet stream""" + def __init__(self, connection): self.connection = connection self.req_counter = 1 self._event_map = {} + self._connected = False def register_handler(self, handler): + """Register an RPC handler""" + self._event_map[handler.req] = (time(), handler) @property def is_connected(self): + """Check if the stream is connected""" + return self.connection.is_connected def __aiter__(self): @@ -148,7 +175,8 @@ class PacketStream(object): n_packets = ceil(length / 4096) body = b"" - for n in range(n_packets): + + for _ in range(n_packets): body += await self.connection.read() logger.debug("READ %s %s", header, len(body)) @@ -159,12 +187,14 @@ class PacketStream(object): return None async def read(self): + """Read data from the packet stream""" + msg = await self._read() if not msg: return None # check whether it's a reply and handle accordingly if msg.req < 0: - t, handler = self._event_map[-msg.req] + _, handler = self._event_map[-msg.req] await handler.process(msg) logger.info("RESPONSE [%d]: %r", -msg.req, msg) if msg.end_err: @@ -183,7 +213,11 @@ class PacketStream(object): logger.debug("WRITE HDR: %s", header) logger.debug("WRITE DATA: %s", msg.data) - def send(self, data, msg_type=PSMessageType.JSON, stream=False, end_err=False, req=None): + def send( # pylint: disable=too-many-arguments + self, data, msg_type=PSMessageType.JSON, stream=False, end_err=False, req=None + ): + """Send data through the packet stream""" + update_counter = False if req is None: update_counter = True @@ -205,5 +239,7 @@ class PacketStream(object): return handler def disconnect(self): + """Disconnect the stream""" + self._connected = False self.connection.disconnect() diff --git a/ssb/util.py b/ssb/util.py index 9efab92..d97cb93 100644 --- a/ssb/util.py +++ b/ssb/util.py @@ -1,16 +1,18 @@ -import os -import yaml -from base64 import b64decode, b64encode +"""Utility functions""" +from base64 import b64decode, b64encode +import os + +import yaml from nacl.signing import SigningKey class ConfigException(Exception): - pass + """Exception to raise if there is a problem with the configuration data""" def tag(key): - """Create tag from publick key.""" + """Create tag from public key.""" return b"@" + b64encode(bytes(key)) + b".ed25519" @@ -18,7 +20,7 @@ def tag(key): def load_ssb_secret(): """Load SSB keys from ~/.ssb""" - with open(os.path.expanduser("~/.ssb/secret")) as f: + with open(os.path.expanduser("~/.ssb/secret"), encoding="utf-8") as f: config = yaml.load(f, Loader=yaml.SafeLoader) if config["curve"] != "ed25519": diff --git a/tests/test_feed.py b/tests/test_feed.py index dab0cd9..d111cf1 100644 --- a/tests/test_feed.py +++ b/tests/test_feed.py @@ -1,3 +1,5 @@ +"""Tests for the feed functionality""" + from base64 import b64decode from collections import OrderedDict @@ -25,17 +27,23 @@ SERIALIZED_M1 = b"""{ @pytest.fixture def local_feed(): + """Fixture providing a local feed""" + secret = b64decode("Mz2qkNOP2K6upnqibWrR+z8pVUI1ReA1MLc7QMtF2qQ=") return LocalFeed(SigningKey(secret)) @pytest.fixture def remote_feed(): + """Fixture providing a remote feed""" + public = b64decode("I/4cyN/jPBbDsikbHzAEvmaYlaJK33lW3UhWjNXjyrU=") return Feed(VerifyKey(public)) def test_local_feed(): + """Test a local feed""" + secret = b64decode("Mz2qkNOP2K6upnqibWrR+z8pVUI1ReA1MLc7QMtF2qQ=") feed = LocalFeed(SigningKey(secret)) assert bytes(feed.private_key) == secret @@ -44,6 +52,8 @@ def test_local_feed(): def test_remote_feed(): + """Test a remote feed""" + public = b64decode("I/4cyN/jPBbDsikbHzAEvmaYlaJK33lW3UhWjNXjyrU=") feed = Feed(VerifyKey(public)) assert bytes(feed.public_key) == public @@ -60,7 +70,9 @@ def test_remote_feed(): feed.sign(m1) -def test_local_message(local_feed): +def test_local_message(local_feed): # pylint: disable=redefined-outer-name + """Test a local message""" + m1 = LocalMessage( local_feed, OrderedDict([("type", "about"), ("about", local_feed.id), ("name", "neo"), ("description", "The Chosen One")]), @@ -91,7 +103,9 @@ def test_local_message(local_feed): assert m2.key == "%nx13uks5GUwuKJC49PfYGMS/1pgGTtwwdWT7kbVaroM=.sha256" -def test_remote_message(remote_feed): +def test_remote_message(remote_feed): # pylint: disable=redefined-outer-name + """Test a remote message""" + signature = "lPsQ9P10OgeyH6u0unFgiI2wV/RQ7Q2x2ebxnXYCzsJ055TBMXphRADTKhOMS2EkUxXQ9k3amj5fnWPudGxwBQ==.sig.ed25519" m1 = Message( remote_feed, @@ -123,7 +137,9 @@ def test_remote_message(remote_feed): assert m2.key == "%nx13uks5GUwuKJC49PfYGMS/1pgGTtwwdWT7kbVaroM=.sha256" -def test_remote_no_signature(remote_feed): +def test_remote_no_signature(remote_feed): # pylint: disable=redefined-outer-name + """Test remote feed without a signature""" + with pytest.raises(ValueError): Message( remote_feed, @@ -135,7 +151,9 @@ def test_remote_no_signature(remote_feed): ) -def test_serialize(local_feed): +def test_serialize(local_feed): # pylint: disable=redefined-outer-name + """Test feed serialization""" + m1 = LocalMessage( local_feed, OrderedDict([("type", "about"), ("about", local_feed.id), ("name", "neo"), ("description", "The Chosen One")]), @@ -145,7 +163,9 @@ def test_serialize(local_feed): assert m1.serialize() == SERIALIZED_M1 -def test_parse(local_feed): +def test_parse(local_feed): # pylint: disable=redefined-outer-name + """Test feed parsing""" + m1 = LocalMessage.parse(SERIALIZED_M1, local_feed) assert m1.content == {"type": "about", "about": local_feed.id, "name": "neo", "description": "The Chosen One"} assert m1.timestamp == 1495706260190 diff --git a/tests/test_packet_stream.py b/tests/test_packet_stream.py index 6db5395..5b5e016 100644 --- a/tests/test_packet_stream.py +++ b/tests/test_packet_stream.py @@ -1,8 +1,9 @@ +"""Tests for the packet stream""" + import json from asyncio import ensure_future, gather, Event import pytest -from nacl.signing import SigningKey from secret_handshake.network import SHSDuplexStream from ssb.packet_stream import PacketStream, PSMessageType @@ -36,63 +37,94 @@ MSG_BODY_2 = ( class MockSHSSocket(SHSDuplexStream): - def __init__(self, *args, **kwargs): - super(MockSHSSocket, self).__init__() + """A mocked SHS socket""" + + def __init__(self, *args, **kwargs): # pylint: disable=unused-argument + super().__init__() + self.input = [] self.output = [] self.is_connected = False self._on_connect = [] def on_connect(self, cb): + """Set the on_connect callback""" + self._on_connect.append(cb) async def read(self): + """Read data from the socket""" + if not self.input: raise StopAsyncIteration return self.input.pop(0) def write(self, data): + """Write data to the socket""" + self.output.append(data) - def feed(self, input): - self.input += input + def feed(self, input_): + """Get the connection’s feed""" + + self.input += input_ def get_output(self): + """Get the output of a call""" + while True: if not self.output: break yield self.output.pop(0) def disconnect(self): + """Disconnect from the remote party""" + self.is_connected = False class MockSHSClient(MockSHSSocket): + """A mocked SHS client""" + async def connect(self): + """Connect to a SHS server""" + self.is_connected = True + for cb in self._on_connect: await cb() class MockSHSServer(MockSHSSocket): + """A mocked SHS server""" + def listen(self): + """Listen for new connections""" + self.is_connected = True + for cb in self._on_connect: ensure_future(cb()) @pytest.fixture -def ps_client(event_loop): +def ps_client(event_loop): # pylint: disable=unused-argument + """Fixture to provide a mocked SHS client""" + return MockSHSClient() @pytest.fixture -def ps_server(event_loop): +def ps_server(event_loop): # pylint: disable=unused-argument + """Fixture to provide a mocked SHS server""" + return MockSHSServer() @pytest.mark.asyncio -async def test_on_connect(ps_server): +async def test_on_connect(ps_server): # pylint: disable=redefined-outer-name + """Test the on_connect callback functionality""" + called = Event() async def _on_connect(): @@ -105,7 +137,9 @@ async def test_on_connect(ps_server): @pytest.mark.asyncio -async def test_message_decoding(ps_client): +async def test_message_decoding(ps_client): # pylint: disable=redefined-outer-name + """Test message decoding""" + await ps_client.connect() ps = PacketStream(ps_client) @@ -133,7 +167,9 @@ async def test_message_decoding(ps_client): @pytest.mark.asyncio -async def test_message_encoding(ps_client): +async def test_message_encoding(ps_client): # pylint: disable=redefined-outer-name + """Test message encoding""" + await ps_client.connect() ps = PacketStream(ps_client) @@ -164,7 +200,9 @@ async def test_message_encoding(ps_client): @pytest.mark.asyncio -async def test_message_stream(ps_client, mocker): +async def test_message_stream(ps_client, mocker): # pylint: disable=redefined-outer-name + """Test requesting a history stream""" + await ps_client.connect() ps = PacketStream(ps_client) @@ -184,8 +222,8 @@ async def test_message_stream(ps_client, mocker): ) assert ps.req_counter == 2 - assert ps.register_handler.call_count == 1 - handler = list(ps._event_map.values())[0][1] + assert ps.register_handler.call_count == 1 # pylint: disable=no-member + handler = list(ps._event_map.values())[0][1] # pylint: disable=protected-access mock_process = mocker.patch.object(handler, "process") @@ -211,8 +249,8 @@ async def test_message_stream(ps_client, mocker): ) assert ps.req_counter == 3 - assert ps.register_handler.call_count == 2 - handler = list(ps._event_map.values())[1][1] + assert ps.register_handler.call_count == 2 # pylint: disable=no-member + handler = list(ps._event_map.values())[1][1] # pylint: disable=protected-access mock_process = mocker.patch.object(handler, "process", wraps=handler.process) @@ -234,7 +272,9 @@ async def test_message_stream(ps_client, mocker): @pytest.mark.asyncio -async def test_message_request(ps_server, mocker): +async def test_message_request(ps_server, mocker): # pylint: disable=redefined-outer-name + """Test message sending""" + ps_server.listen() ps = PacketStream(ps_server) @@ -248,8 +288,8 @@ async def test_message_request(ps_server, mocker): assert json.loads(body.decode("utf-8")) == {"name": ["whoami"], "args": []} assert ps.req_counter == 2 - assert ps.register_handler.call_count == 1 - handler = list(ps._event_map.values())[0][1] + assert ps.register_handler.call_count == 1 # pylint: disable=no-member + handler = list(ps._event_map.values())[0][1] # pylint: disable=protected-access mock_process = mocker.patch.object(handler, "process") diff --git a/tests/test_util.py b/tests/test_util.py index 5b40ca5..3766f28 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -1,3 +1,5 @@ +"""Test for utility functions""" + from base64 import b64decode from unittest.mock import mock_open, patch @@ -20,6 +22,8 @@ CONFIG_FILE_INVALID = CONFIG_FILE.replace("ed25519", "foo") def test_load_secret(): + """Test loading the SSB secret from a file""" + with patch("ssb.util.open", mock_open(read_data=CONFIG_FILE), create=True): secret = load_ssb_secret() @@ -31,6 +35,7 @@ def test_load_secret(): def test_load_exception(): + """Test configuration loading if there is a problem with the file""" with pytest.raises(ConfigException): with patch("ssb.util.open", mock_open(read_data=CONFIG_FILE_INVALID), create=True): load_ssb_secret()