From d51f27d88315228571881e6f40fe7e5e22bf39ce Mon Sep 17 00:00:00 2001 From: Gergely Polonkai Date: Wed, 1 Nov 2023 06:03:06 +0100 Subject: [PATCH] ci: Add and configure PyLint, and make it happy --- .pre-commit-config.yaml | 6 +++ examples/test_client.py | 20 ++++++--- examples/test_server.py | 10 ++++- poetry.lock | 71 ++++++++++++++++++++++++++++++- pyproject.toml | 4 ++ ssb/feed/__init__.py | 2 + ssb/feed/models.py | 57 +++++++++++++++++++++---- ssb/muxrpc.py | 85 ++++++++++++++++++++++++++++--------- ssb/packet_stream.py | 83 +++++++++++++++++++++++++++--------- ssb/util.py | 10 +++-- tests/test_feed.py | 30 ++++++++++--- tests/test_packet_stream.py | 76 +++++++++++++++++++++++++-------- tests/test_util.py | 6 +++ 13 files changed, 379 insertions(+), 81 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 84e8298..258706a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -24,6 +24,12 @@ repos: language: system require_serial: true types_or: [python, pyi] + - id: pylint + name: pylint + entry: poetry run pylint + language: system + types: [python] + require_serial: true - id: isort name: isort args: ["--check", "--diff"] diff --git a/examples/test_client.py b/examples/test_client.py index 5006852..624b53e 100644 --- a/examples/test_client.py +++ b/examples/test_client.py @@ -20,6 +20,8 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. +"""Example SSB Client""" + from asyncio import ensure_future, gather, get_event_loop import base64 import hashlib @@ -38,18 +40,24 @@ api = MuxRPCAPI() @api.define("createHistoryStream") -def create_history_stream(connection, msg): +def create_history_stream(connection, msg): # pylint: disable=unused-argument + """Handle the createHistoryStream RPC call""" + print("create_history_stream", msg) # msg = PSMessage(PSMessageType.JSON, True, stream=True, end_err=True, req=-req) # connection.write(msg) @api.define("blobs.createWants") -def create_wants(connection, msg): +def create_wants(connection, msg): # pylint: disable=unused-argument + """Handle the createWants RPC call""" + print("create_wants", msg) async def test_client(): + """The actual client implementation""" + async for msg in api.call( "createHistoryStream", [ @@ -90,7 +98,9 @@ async def test_client(): f.write(img_data) -async def main(): +async def main(keypair): + """The main function to run""" + client = SHSClient("127.0.0.1", 8008, keypair, bytes(keypair.verify_key)) packet_stream = PacketStream(client) await client.open() @@ -116,8 +126,8 @@ if __name__ == "__main__": logger.setLevel(logging.INFO) logger.addHandler(ch) - keypair = load_ssb_secret()["keypair"] + ssb_keypair = load_ssb_secret()["keypair"] loop = get_event_loop() - loop.run_until_complete(main()) + loop.run_until_complete(main(ssb_keypair)) loop.close() diff --git a/examples/test_server.py b/examples/test_server.py index 46ac6b4..d1726e2 100644 --- a/examples/test_server.py +++ b/examples/test_server.py @@ -20,7 +20,9 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. -from asyncio import ensure_future, gather, get_event_loop +"""Test SSB server""" + +from asyncio import get_event_loop import logging from colorlog import ColoredFormatter @@ -34,6 +36,8 @@ api = MuxRPCAPI() async def on_connect(conn): + """Incoming connection handler""" + packet_stream = PacketStream(conn) api.add_connection(packet_stream) @@ -43,6 +47,8 @@ async def on_connect(conn): async def main(): + """The main function to run""" + server = SHSServer("127.0.0.1", 8008, load_ssb_secret()["keypair"]) server.on_connect(on_connect) await server.listen() @@ -55,7 +61,7 @@ if __name__ == "__main__": # create formatter formatter = ColoredFormatter( - "%(log_color)s%(levelname)s%(reset)s:%(bold_white)s%(name)s%(reset)s - " "%(cyan)s%(message)s%(reset)s" + "%(log_color)s%(levelname)s%(reset)s:%(bold_white)s%(name)s%(reset)s - %(cyan)s%(message)s%(reset)s" ) # add formatter to ch diff --git a/poetry.lock b/poetry.lock index 95092f8..7861176 100644 --- a/poetry.lock +++ b/poetry.lock @@ -25,6 +25,20 @@ files = [ [package.extras] test = ["coverage", "mypy", "pexpect", "ruff", "wheel"] +[[package]] +name = "astroid" +version = "3.0.1" +description = "An abstract syntax tree for Python with inference support." +optional = false +python-versions = ">=3.8.0" +files = [ + {file = "astroid-3.0.1-py3-none-any.whl", hash = "sha256:7d5895c9825e18079c5aeac0572bc2e4c83205c95d416e0b4fee8bc361d2d9ca"}, + {file = "astroid-3.0.1.tar.gz", hash = "sha256:86b0bb7d7da0be1a7c4aedb7974e391b32d4ed89e33de6ed6902b4b15c97577e"}, +] + +[package.dependencies] +typing-extensions = {version = ">=4.0.0", markers = "python_version < \"3.11\""} + [[package]] name = "async-generator" version = "1.8" @@ -430,6 +444,20 @@ files = [ {file = "decli-0.6.1.tar.gz", hash = "sha256:ed88ccb947701e8e5509b7945fda56e150e2ac74a69f25d47ac85ef30ab0c0f0"}, ] +[[package]] +name = "dill" +version = "0.3.7" +description = "serialize all of Python" +optional = false +python-versions = ">=3.7" +files = [ + {file = "dill-0.3.7-py3-none-any.whl", hash = "sha256:76b122c08ef4ce2eedcd4d1abd8e641114bfc6c2867f49f3c41facf65bf19f5e"}, + {file = "dill-0.3.7.tar.gz", hash = "sha256:cc1c8b182eb3013e24bd475ff2e9295af86c1a38eb1aff128dac8962a9ce3c03"}, +] + +[package.extras] +graph = ["objgraph (>=1.7.2)"] + [[package]] name = "docutils" version = "0.17.1" @@ -610,6 +638,17 @@ files = [ {file = "MarkupSafe-2.1.3.tar.gz", hash = "sha256:af598ed32d6ae86f1b747b82783958b1a4ab8f617b06fe68795c7f026abbdcad"}, ] +[[package]] +name = "mccabe" +version = "0.7.0" +description = "McCabe checker, plugin for flake8" +optional = false +python-versions = ">=3.6" +files = [ + {file = "mccabe-0.7.0-py2.py3-none-any.whl", hash = "sha256:6c2d30ab6be0e4a46919781807b4f0d834ebdd6c6e3dca0bda5a15f863427b6e"}, + {file = "mccabe-0.7.0.tar.gz", hash = "sha256:348e0240c33b60bbdf4e523192ef919f28cb2c3d7d5c7794f74009290f236325"}, +] + [[package]] name = "mypy-extensions" version = "1.0.0" @@ -723,6 +762,36 @@ files = [ [package.extras] plugins = ["importlib-metadata"] +[[package]] +name = "pylint" +version = "3.0.2" +description = "python code static checker" +optional = false +python-versions = ">=3.8.0" +files = [ + {file = "pylint-3.0.2-py3-none-any.whl", hash = "sha256:60ed5f3a9ff8b61839ff0348b3624ceeb9e6c2a92c514d81c9cc273da3b6bcda"}, + {file = "pylint-3.0.2.tar.gz", hash = "sha256:0d4c286ef6d2f66c8bfb527a7f8a629009e42c99707dec821a03e1b51a4c1496"}, +] + +[package.dependencies] +astroid = ">=3.0.1,<=3.1.0-dev0" +colorama = {version = ">=0.4.5", markers = "sys_platform == \"win32\""} +dill = [ + {version = ">=0.2", markers = "python_version < \"3.11\""}, + {version = ">=0.3.7", markers = "python_version >= \"3.12\""}, + {version = ">=0.3.6", markers = "python_version >= \"3.11\" and python_version < \"3.12\""}, +] +isort = ">=4.2.5,<6" +mccabe = ">=0.6,<0.8" +platformdirs = ">=2.2.0" +tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""} +tomlkit = ">=0.10.1" +typing-extensions = {version = ">=3.10.0", markers = "python_version < \"3.10\""} + +[package.extras] +spelling = ["pyenchant (>=3.2,<4.0)"] +testutils = ["gitpython (>3)"] + [[package]] name = "pynacl" version = "1.1.2" @@ -1241,4 +1310,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "4eb60a723d8be38d3d197522f58d19f4bdae3887bd006d13ab87fb058c75b467" +content-hash = "bd8b3213143f1abe13f580d28e2d42ee3d663c2d010548e7acd27be04912308a" diff --git a/pyproject.toml b/pyproject.toml index b00849c..da65dd5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,6 +25,7 @@ commitizen = "^3.12.0" coverage = "^7.3.2" isort = "^5.12.0" pep257 = "^0.7.0" +pylint = "^3.0.2" pytest = "^7.4.3" pytest-asyncio = "^0.21.1" pytest-cov = "^4.1.0" @@ -50,6 +51,9 @@ force_sort_within_sections = true line_length = 120 profile = "black" +[tool.pylint.format] +max-line-length = 120 + [tool.pytest.ini_options] addopts = ["--cov=.", "--no-cov-on-fail"] python_files = ["tests/test_*.py"] diff --git a/ssb/feed/__init__.py b/ssb/feed/__init__.py index 75be4ca..d63a339 100644 --- a/ssb/feed/__init__.py +++ b/ssb/feed/__init__.py @@ -20,6 +20,8 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. +"""Feed related functionality""" + from .models import Feed, LocalFeed, LocalMessage, Message, NoPrivateKeyException __all__ = ("Feed", "LocalFeed", "Message", "LocalMessage", "NoPrivateKeyException") diff --git a/ssb/feed/models.py b/ssb/feed/models.py index 9114c4b..a52c8f8 100644 --- a/ssb/feed/models.py +++ b/ssb/feed/models.py @@ -20,6 +20,8 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. +"""Feed models""" + from base64 import b64encode from collections import OrderedDict, namedtuple import datetime @@ -33,44 +35,65 @@ OrderedMsg = namedtuple("OrderedMsg", ("previous", "author", "sequence", "timest class NoPrivateKeyException(Exception): - pass + """Exception to raise when a private key is not available""" def to_ordered(data): + """Convert a dictionary to an ``OrderedDict``""" + smsg = OrderedMsg(**data) + return OrderedDict((k, getattr(smsg, k)) for k in smsg._fields) def get_millis_1970(): + """Get the UNIX timestamp in milliseconds""" + return int(datetime.datetime.utcnow().timestamp() * 1000) -class Feed(object): +class Feed: + """Base class for feeds""" + def __init__(self, public_key): self.public_key = public_key @property def id(self): + """The identifier of the feed""" + return tag(self.public_key).decode("ascii") def sign(self, msg): + """Sign a message""" + raise NoPrivateKeyException("Cannot use remote identity to sign (no private key!)") class LocalFeed(Feed): - def __init__(self, private_key): + """Class representing a local feed""" + + def __init__(self, private_key): # pylint: disable=super-init-not-called self.private_key = private_key @property def public_key(self): + """The public key of the feed""" + return self.private_key.verify_key def sign(self, msg): + """Sign a message for this feed""" + return self.private_key.sign(msg).signature -class Message(object): - def __init__(self, feed, content, signature, sequence=1, timestamp=None, previous=None): +class Message: + """Base class for SSB messages""" + + def __init__( # pylint: disable=too-many-arguments + self, feed, content, signature=None, sequence=1, timestamp=None, previous=None + ): self.feed = feed self.content = content @@ -88,14 +111,21 @@ class Message(object): @classmethod def parse(cls, data, feed): + """Parse raw message data""" + obj = loads(data, object_pairs_hook=OrderedDict) msg = cls(feed, obj["content"], timestamp=obj["timestamp"]) + return msg def serialize(self, add_signature=True): + """Serialize the message""" + return dumps(self.to_dict(add_signature=add_signature), indent=2).encode("utf-8") def to_dict(self, add_signature=True): + """Convert the message to a dictionary""" + obj = to_ordered( { "previous": self.previous.key if self.previous else None, @@ -109,23 +139,34 @@ class Message(object): if add_signature: obj["signature"] = self.signature + return obj def verify(self, signature): + """Verify the signature of the message""" + return self.signature == signature @property def hash(self): - hash = sha256(self.serialize()).digest() - return b64encode(hash).decode("ascii") + ".sha256" + """The cryptographic hash of the message""" + + hash_ = sha256(self.serialize()).digest() + return b64encode(hash_).decode("ascii") + ".sha256" @property def key(self): + """The key of the message""" + return "%" + self.hash class LocalMessage(Message): - def __init__(self, feed, content, signature=None, sequence=1, timestamp=None, previous=None): + """Class representing a local message""" + + def __init__( # pylint: disable=too-many-arguments,super-init-not-called + self, feed, content, signature=None, sequence=1, timestamp=None, previous=None + ): self.feed = feed self.content = content diff --git a/ssb/muxrpc.py b/ssb/muxrpc.py index a1636d6..6244a73 100644 --- a/ssb/muxrpc.py +++ b/ssb/muxrpc.py @@ -20,23 +20,32 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. +"""MuxRPC""" + from async_generator import async_generator, yield_ from ssb.packet_stream import PSMessageType class MuxRPCAPIException(Exception): - pass + """Exception to raise on MuxRPC API errors""" -class MuxRPCHandler(object): +class MuxRPCHandler: # pylint: disable=too-few-public-methods + """Base MuxRPC handler class""" + def check_message(self, msg): + """Check message validity""" + body = msg.body + if isinstance(body, dict) and "name" in body and body["name"] == "Error": raise MuxRPCAPIException(body["message"]) class MuxRPCRequestHandler(MuxRPCHandler): + """MuxRPC handler for incoming RPC requests""" + def __init__(self, ps_handler): self.ps_handler = ps_handler @@ -47,52 +56,72 @@ class MuxRPCRequestHandler(MuxRPCHandler): class MuxRPCSourceHandler(MuxRPCHandler): + """MuxRPC handler for source-type RPC requests""" + def __init__(self, ps_handler): self.ps_handler = ps_handler @async_generator async def __aiter__(self): async for msg in self.ps_handler: - try: - self.check_message(msg) - await yield_(msg) - except MuxRPCAPIException: - raise + self.check_message(msg) + await yield_(msg) -class MuxRPCSinkHandlerMixin(object): +class MuxRPCSinkHandlerMixin: # pylint: disable=too-few-public-methods + """Mixin for sink-type MuxRPC handlers""" + def send(self, msg, msg_type=PSMessageType.JSON, end=False): + """Send a message through the stream""" + self.connection.send(msg, stream=True, msg_type=msg_type, req=self.req, end_err=end) class MuxRPCDuplexHandler(MuxRPCSinkHandlerMixin, MuxRPCSourceHandler): + """MuxRPC handler for duplex streams""" + def __init__(self, ps_handler, connection, req): - super(MuxRPCDuplexHandler, self).__init__(ps_handler) + super().__init__(ps_handler) + self.connection = connection self.req = req class MuxRPCSinkHandler(MuxRPCHandler, MuxRPCSinkHandlerMixin): + """MuxRPC handler for sinks""" + def __init__(self, connection, req): self.connection = connection self.req = req def _get_appropriate_api_handler(type_, connection, ps_handler, req): + """Find the appropriate MuxRPC handler""" + if type_ in {"sync", "async"}: return MuxRPCRequestHandler(ps_handler) - elif type_ == "source": + + if type_ == "source": return MuxRPCSourceHandler(ps_handler) - elif type_ == "sink": + + if type_ == "sink": return MuxRPCSinkHandler(connection, req) - elif type_ == "duplex": + + if type_ == "duplex": return MuxRPCDuplexHandler(ps_handler, connection, req) + return None + + +class MuxRPCRequest: + """MuxRPC request""" -class MuxRPCRequest(object): @classmethod def from_message(cls, message): + """Initialise a request from a raw packet stream message""" + body = message.body + return cls(".".join(body["name"]), body["args"]) def __init__(self, name, args): @@ -100,22 +129,28 @@ class MuxRPCRequest(object): self.args = args def __repr__(self): - return "".format(self) + return f"" -class MuxRPCMessage(object): +class MuxRPCMessage: + """MuxRPC message""" + @classmethod def from_message(cls, message): + """Initialise a MuxRPC message from a raw packet stream message""" + return cls(message.body) def __init__(self, body): self.body = body def __repr__(self): - return "".format(self) + return f"" -class MuxRPCAPI(object): +class MuxRPCAPI: + """Generic MuxRPC API""" + def __init__(self): self.handlers = {} self.connection = None @@ -129,9 +164,13 @@ class MuxRPCAPI(object): self.process(self.connection, MuxRPCRequest.from_message(req_message)) def add_connection(self, connection): + """Set the packet stream connection of this RPC API""" + self.connection = connection def define(self, name): + """Decorator to define an RPC method handler""" + def _handle(f): self.handlers[name] = f @@ -140,17 +179,25 @@ class MuxRPCAPI(object): return _handle def process(self, connection, request): + """Process an incoming request""" + handler = self.handlers.get(request.name) + if not handler: - raise MuxRPCAPIException("Method {} not found!".format(request.name)) + raise MuxRPCAPIException(f"Method {request.name} not found!") + handler(connection, request) def call(self, name, args, type_="sync"): + """Call an RPC method""" + if not self.connection.is_connected: - raise Exception("not connected") + raise Exception("not connected") # pylint: disable=broad-exception-raised + old_counter = self.connection.req_counter ps_handler = self.connection.send( {"name": name.split("."), "args": args, "type": type_}, stream=type_ in {"sink", "source", "duplex"}, ) + return _get_appropriate_api_handler(type_, self.connection, ps_handler, old_counter) diff --git a/ssb/packet_stream.py b/ssb/packet_stream.py index 8e3c390..e0ba88a 100644 --- a/ssb/packet_stream.py +++ b/ssb/packet_stream.py @@ -20,6 +20,8 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. +"""Packet streams""" + from asyncio import Event, Queue from enum import Enum import logging @@ -28,28 +30,35 @@ import struct from time import time from async_generator import async_generator, yield_ -from secret_handshake import SHSClient, SHSServer import simplejson logger = logging.getLogger("packet_stream") class PSMessageType(Enum): + """Available message types""" + BUFFER = 0 TEXT = 1 JSON = 2 -class PSStreamHandler(object): +class PSStreamHandler: + """Packet stream handler""" + def __init__(self, req): super(PSStreamHandler).__init__() self.req = req self.queue = Queue() async def process(self, msg): + """Process a pending message""" + await self.queue.put(msg) async def stop(self): + """Stop a pending request""" + await self.queue.put(None) @async_generator @@ -61,30 +70,40 @@ class PSStreamHandler(object): await yield_(elem) -class PSRequestHandler(object): +class PSRequestHandler: + """Packet stream request handler""" + def __init__(self, req): - super(PSRequestHandler).__init__() self.req = req self.event = Event() self._msg = None async def process(self, msg): + """Process a message request""" + self._msg = msg self.event.set() async def stop(self): + """Stop a pending event request""" + if not self.event.is_set(): self.event.set() def __await__(self): # wait until 'process' is called - yield from self.event.wait().__await__() + yield from self.event.wait().__await__() # pylint: disable=no-member + return self._msg -class PSMessage(object): +class PSMessage: + """Packet Stream message""" + @classmethod def from_header_body(cls, flags, req, body): + """Parse a raw message""" + type_ = PSMessageType(flags & 0x03) if type_ == PSMessageType.TEXT: @@ -96,13 +115,17 @@ class PSMessage(object): @property def data(self): + """The raw message data""" + if self.type == PSMessageType.TEXT: return self.body.encode("utf-8") - elif self.type == PSMessageType.JSON: + + if self.type == PSMessageType.JSON: return simplejson.dumps(self.body).encode("utf-8") + return self.body - def __init__(self, type_, body, stream, end_err, req=None): + def __init__(self, type_, body, stream, end_err, req=None): # pylint: disable=too-many-arguments self.stream = stream self.end_err = end_err self.type = type_ @@ -111,37 +134,45 @@ class PSMessage(object): def __repr__(self): if self.type == PSMessageType.BUFFER: - body = "{} bytes".format(len(self.body)) + body = f"{len(self.body)} bytes" else: body = self.body - return "".format( - self.type.name, - body, - "" if self.req is None else " [{}]".format(self.req), - "~" if self.stream else "", - "!" if self.end_err else "", - ) + + req = "" if self.req is None else f" [{self.req}]" + is_stream = "~" if self.stream else "" + err = "!" if self.end_err else "" + + return f"" -class PacketStream(object): +class PacketStream: + """SSB Packet stream""" + def __init__(self, connection): self.connection = connection self.req_counter = 1 self._event_map = {} + self._connected = False def register_handler(self, handler): + """Register an RPC handler""" + self._event_map[handler.req] = (time(), handler) @property def is_connected(self): + """Check if the stream is connected""" + return self.connection.is_connected @async_generator async def __aiter__(self): while True: msg = await self.read() + if not msg: return + # filter out replies if msg.req >= 0: await yield_(msg) @@ -149,20 +180,24 @@ class PacketStream(object): async def __await__(self): async for data in self: logger.info("RECV: %r", data) + if data is None: return async def _read(self): try: header = await self.connection.read() + if not header or header == b"\x00" * 9: return + flags, length, req = struct.unpack(">BIi", header) n_packets = ceil(length / 4096) body = b"" - for n in range(n_packets): + + for _ in range(n_packets): body += await self.connection.read() logger.debug("READ %s %s", header, len(body)) @@ -173,12 +208,14 @@ class PacketStream(object): return None async def read(self): + """Read data from the packet stream""" + msg = await self._read() if not msg: return None # check whether it's a reply and handle accordingly if msg.req < 0: - t, handler = self._event_map[-msg.req] + _, handler = self._event_map[-msg.req] await handler.process(msg) logger.info("RESPONSE [%d]: %r", -msg.req, msg) if msg.end_err: @@ -200,7 +237,11 @@ class PacketStream(object): logger.debug("WRITE HDR: %s", header) logger.debug("WRITE DATA: %s", msg.data) - def send(self, data, msg_type=PSMessageType.JSON, stream=False, end_err=False, req=None): + def send( # pylint: disable=too-many-arguments + self, data, msg_type=PSMessageType.JSON, stream=False, end_err=False, req=None + ): + """Send data through the packet stream""" + update_counter = False if req is None: update_counter = True @@ -222,5 +263,7 @@ class PacketStream(object): return handler def disconnect(self): + """Disconnect the stream""" + self._connected = False self.connection.disconnect() diff --git a/ssb/util.py b/ssb/util.py index e6e0270..55a431c 100644 --- a/ssb/util.py +++ b/ssb/util.py @@ -20,6 +20,8 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. +"""Utility functions""" + from base64 import b64decode, b64encode import os @@ -28,17 +30,19 @@ import yaml class ConfigException(Exception): - pass + """Exception to raise if there is a problem with the configuration data""" def tag(key): - """Create tag from publick key.""" + """Create tag from public key""" + return b"@" + b64encode(bytes(key)) + b".ed25519" def load_ssb_secret(): """Load SSB keys from ~/.ssb""" - with open(os.path.expanduser("~/.ssb/secret")) as f: + + with open(os.path.expanduser("~/.ssb/secret"), encoding="utf-8") as f: config = yaml.load(f, Loader=yaml.SafeLoader) if config["curve"] != "ed25519": diff --git a/tests/test_feed.py b/tests/test_feed.py index c32521e..5ff4378 100644 --- a/tests/test_feed.py +++ b/tests/test_feed.py @@ -20,6 +20,8 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. +"""Tests for the feed functionality""" + from base64 import b64decode from collections import OrderedDict @@ -46,17 +48,23 @@ SERIALIZED_M1 = b"""{ @pytest.fixture() def local_feed(): + """Fixture providing a local feed""" + secret = b64decode("Mz2qkNOP2K6upnqibWrR+z8pVUI1ReA1MLc7QMtF2qQ=") return LocalFeed(SigningKey(secret)) @pytest.fixture() def remote_feed(): + """Fixture providing a remote feed""" + public = b64decode("I/4cyN/jPBbDsikbHzAEvmaYlaJK33lW3UhWjNXjyrU=") return Feed(VerifyKey(public)) def test_local_feed(): + """Test a local feed""" + secret = b64decode("Mz2qkNOP2K6upnqibWrR+z8pVUI1ReA1MLc7QMtF2qQ=") feed = LocalFeed(SigningKey(secret)) assert bytes(feed.private_key) == secret @@ -65,6 +73,8 @@ def test_local_feed(): def test_remote_feed(): + """Test a remote feed""" + public = b64decode("I/4cyN/jPBbDsikbHzAEvmaYlaJK33lW3UhWjNXjyrU=") feed = Feed(VerifyKey(public)) assert bytes(feed.public_key) == public @@ -88,7 +98,9 @@ def test_remote_feed(): feed.sign(m1) -def test_local_message(local_feed): +def test_local_message(local_feed): # pylint: disable=redefined-outer-name + """Test a local message""" + m1 = LocalMessage( local_feed, OrderedDict( @@ -133,7 +145,9 @@ def test_local_message(local_feed): assert m2.key == "%nx13uks5GUwuKJC49PfYGMS/1pgGTtwwdWT7kbVaroM=.sha256" -def test_remote_message(remote_feed): +def test_remote_message(remote_feed): # pylint: disable=redefined-outer-name + """Test a remote message""" + signature = "lPsQ9P10OgeyH6u0unFgiI2wV/RQ7Q2x2ebxnXYCzsJ055TBMXphRADTKhOMS2EkUxXQ9k3amj5fnWPudGxwBQ==.sig.ed25519" m1 = Message( remote_feed, @@ -177,7 +191,9 @@ def test_remote_message(remote_feed): assert m2.key == "%nx13uks5GUwuKJC49PfYGMS/1pgGTtwwdWT7kbVaroM=.sha256" -def test_remote_no_signature(remote_feed): +def test_remote_no_signature(remote_feed): # pylint: disable=redefined-outer-name + """Test remote feed without a signature""" + with pytest.raises(ValueError): Message( remote_feed, @@ -194,7 +210,9 @@ def test_remote_no_signature(remote_feed): ) -def test_serialize(local_feed): +def test_serialize(local_feed): # pylint: disable=redefined-outer-name + """Test feed serialization""" + m1 = LocalMessage( local_feed, OrderedDict( @@ -211,7 +229,9 @@ def test_serialize(local_feed): assert m1.serialize() == SERIALIZED_M1 -def test_parse(local_feed): +def test_parse(local_feed): # pylint: disable=redefined-outer-name + """Test feed parsing""" + m1 = LocalMessage.parse(SERIALIZED_M1, local_feed) assert m1.content == { "type": "about", diff --git a/tests/test_packet_stream.py b/tests/test_packet_stream.py index 79f95c3..2e54202 100644 --- a/tests/test_packet_stream.py +++ b/tests/test_packet_stream.py @@ -20,10 +20,11 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. +"""Tests for the packet stream""" + from asyncio import Event, ensure_future, gather import json -from nacl.signing import SigningKey import pytest from secret_handshake.network import SHSDuplexStream @@ -58,63 +59,94 @@ MSG_BODY_2 = ( class MockSHSSocket(SHSDuplexStream): - def __init__(self, *args, **kwargs): - super(MockSHSSocket, self).__init__() + """A mocked SHS socket""" + + def __init__(self, *args, **kwargs): # pylint: disable=unused-argument + super().__init__() + self.input = [] self.output = [] self.is_connected = False self._on_connect = [] def on_connect(self, cb): + """Set the on_connect callback""" + self._on_connect.append(cb) async def read(self): + """Read data from the socket""" + if not self.input: raise StopAsyncIteration return self.input.pop(0) def write(self, data): + """Write data to the socket""" + self.output.append(data) - def feed(self, input): - self.input += input + def feed(self, input_): + """Get the connection’s feed""" + + self.input += input_ def get_output(self): + """Get the output of a call""" + while True: if not self.output: break yield self.output.pop(0) def disconnect(self): + """Disconnect from the remote party""" + self.is_connected = False class MockSHSClient(MockSHSSocket): + """A mocked SHS client""" + async def connect(self): + """Connect to a SHS server""" + self.is_connected = True + for cb in self._on_connect: await cb() class MockSHSServer(MockSHSSocket): + """A mocked SHS server""" + def listen(self): + """Listen for new connections""" + self.is_connected = True + for cb in self._on_connect: ensure_future(cb()) @pytest.fixture -def ps_client(event_loop): +def ps_client(event_loop): # pylint: disable=unused-argument + """Fixture to provide a mocked SHS client""" + return MockSHSClient() @pytest.fixture -def ps_server(event_loop): +def ps_server(event_loop): # pylint: disable=unused-argument + """Fixture to provide a mocked SHS server""" + return MockSHSServer() @pytest.mark.asyncio -async def test_on_connect(ps_server): +async def test_on_connect(ps_server): # pylint: disable=redefined-outer-name + """Test the on_connect callback functionality""" + called = Event() async def _on_connect(): @@ -127,7 +159,9 @@ async def test_on_connect(ps_server): @pytest.mark.asyncio -async def test_message_decoding(ps_client): +async def test_message_decoding(ps_client): # pylint: disable=redefined-outer-name + """Test message decoding""" + await ps_client.connect() ps = PacketStream(ps_client) @@ -160,7 +194,9 @@ async def test_message_decoding(ps_client): @pytest.mark.asyncio -async def test_message_encoding(ps_client): +async def test_message_encoding(ps_client): # pylint: disable=redefined-outer-name + """Test message encoding""" + await ps_client.connect() ps = PacketStream(ps_client) @@ -201,7 +237,9 @@ async def test_message_encoding(ps_client): @pytest.mark.asyncio -async def test_message_stream(ps_client, mocker): +async def test_message_stream(ps_client, mocker): # pylint: disable=redefined-outer-name + """Test requesting a history stream""" + await ps_client.connect() ps = PacketStream(ps_client) @@ -226,8 +264,8 @@ async def test_message_stream(ps_client, mocker): ) assert ps.req_counter == 2 - assert ps.register_handler.call_count == 1 - handler = list(ps._event_map.values())[0][1] + assert ps.register_handler.call_count == 1 # pylint: disable=no-member + handler = list(ps._event_map.values())[0][1] # pylint: disable=protected-access mock_process = mocker.AsyncMock() mocker.patch.object(handler, "process", mock_process) @@ -259,8 +297,8 @@ async def test_message_stream(ps_client, mocker): ) assert ps.req_counter == 3 - assert ps.register_handler.call_count == 2 - handler = list(ps._event_map.values())[1][1] + assert ps.register_handler.call_count == 2 # pylint: disable=no-member + handler = list(ps._event_map.values())[1][1] # pylint: disable=protected-access mock_process = mocker.patch.object(handler, "process", wraps=handler.process) ps_client.feed( @@ -286,7 +324,9 @@ async def test_message_stream(ps_client, mocker): @pytest.mark.asyncio -async def test_message_request(ps_server, mocker): +async def test_message_request(ps_server, mocker): # pylint: disable=redefined-outer-name + """Test message sending""" + ps_server.listen() ps = PacketStream(ps_server) @@ -300,8 +340,8 @@ async def test_message_request(ps_server, mocker): assert json.loads(body.decode("utf-8")) == {"name": ["whoami"], "args": []} assert ps.req_counter == 2 - assert ps.register_handler.call_count == 1 - handler = list(ps._event_map.values())[0][1] + assert ps.register_handler.call_count == 1 # pylint: disable=no-member + handler = list(ps._event_map.values())[0][1] # pylint: disable=protected-access mock_process = mocker.AsyncMock() mocker.patch.object(handler, "process", mock_process) diff --git a/tests/test_util.py b/tests/test_util.py index a5f33c8..8122dd6 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -20,6 +20,8 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. +"""Tests for the utility functions""" + from base64 import b64decode from unittest.mock import mock_open, patch @@ -41,6 +43,8 @@ CONFIG_FILE_INVALID = CONFIG_FILE.replace("ed25519", "foo") def test_load_secret(): + """Test loading the SSB secret from a file""" + with patch("ssb.util.open", mock_open(read_data=CONFIG_FILE), create=True): secret = load_ssb_secret() @@ -52,6 +56,8 @@ def test_load_secret(): def test_load_exception(): + """Test configuration loading if there is a problem with the file""" + with pytest.raises(ConfigException): with patch("ssb.util.open", mock_open(read_data=CONFIG_FILE_INVALID), create=True): load_ssb_secret()