ci: Add and configure PyLint, and make it happy

This commit is contained in:
Gergely Polonkai 2023-11-01 05:57:08 +01:00
parent b30aa39d6b
commit 3aa2794b92
No known key found for this signature in database
GPG Key ID: 2D2885533B869ED4
13 changed files with 369 additions and 84 deletions

View File

@ -24,3 +24,9 @@ repos:
language: system language: system
require_serial: true require_serial: true
types_or: [python, pyi] types_or: [python, pyi]
- id: pylint
name: pylint
entry: poetry run pylint
language: system
types: [python]
require_serial: true

View File

@ -1,3 +1,7 @@
"""Example SSB Client"""
import base64
import hashlib
import logging import logging
import struct import struct
import time import time
@ -10,26 +14,28 @@ from ssb.muxrpc import MuxRPCAPI, MuxRPCAPIException
from ssb.packet_stream import PacketStream, PSMessageType from ssb.packet_stream import PacketStream, PSMessageType
from ssb.util import load_ssb_secret from ssb.util import load_ssb_secret
import hashlib
import base64
api = MuxRPCAPI() api = MuxRPCAPI()
@api.define("createHistoryStream") @api.define("createHistoryStream")
def create_history_stream(connection, msg): def create_history_stream(connection, msg): # pylint: disable=unused-argument
"""Handle the createHistoryStream RPC call"""
print("create_history_stream", msg) print("create_history_stream", msg)
# msg = PSMessage(PSMessageType.JSON, True, stream=True, end_err=True, req=-req) # msg = PSMessage(PSMessageType.JSON, True, stream=True, end_err=True, req=-req)
# connection.write(msg) # connection.write(msg)
@api.define("blobs.createWants") @api.define("blobs.createWants")
def create_wants(connection, msg): def create_wants(connection, msg): # pylint: disable=unused-argument
"""Handle the createWants RPC call"""
print("create_wants", msg) print("create_wants", msg)
async def test_client(): async def test_client():
"""The actual client implementation"""
async for msg in api.call( async for msg in api.call(
"createHistoryStream", "createHistoryStream",
[{"id": "@1+Iwm79DKvVBqYKFkhT6fWRbAVvNNVH4F2BSxwhYmx8=.ed25519", "seq": 1, "live": False, "keys": False}], [{"id": "@1+Iwm79DKvVBqYKFkhT6fWRbAVvNNVH4F2BSxwhYmx8=.ed25519", "seq": 1, "live": False, "keys": False}],
@ -63,7 +69,9 @@ async def test_client():
f.write(img_data) f.write(img_data)
async def main(): async def main(keypair):
"""The main function to run"""
client = SHSClient("127.0.0.1", 8008, keypair, bytes(keypair.verify_key)) client = SHSClient("127.0.0.1", 8008, keypair, bytes(keypair.verify_key))
packet_stream = PacketStream(client) packet_stream = PacketStream(client)
await client.open() await client.open()
@ -89,8 +97,8 @@ if __name__ == "__main__":
logger.setLevel(logging.INFO) logger.setLevel(logging.INFO)
logger.addHandler(ch) logger.addHandler(ch)
keypair = load_ssb_secret()["keypair"] ssb_keypair = load_ssb_secret()["keypair"]
loop = get_event_loop() loop = get_event_loop()
loop.run_until_complete(main()) loop.run_until_complete(main(ssb_keypair))
loop.close() loop.close()

View File

@ -1,5 +1,7 @@
"""Test SSB server"""
import logging import logging
from asyncio import gather, get_event_loop, ensure_future from asyncio import get_event_loop
from colorlog import ColoredFormatter from colorlog import ColoredFormatter
@ -12,6 +14,8 @@ api = MuxRPCAPI()
async def on_connect(conn): async def on_connect(conn):
"""Incoming connection handler"""
packet_stream = PacketStream(conn) packet_stream = PacketStream(conn)
api.add_connection(packet_stream) api.add_connection(packet_stream)
@ -22,6 +26,7 @@ async def on_connect(conn):
async def main(): async def main():
"""The main function to run"""
server = SHSServer("127.0.0.1", 8008, load_ssb_secret()["keypair"]) server = SHSServer("127.0.0.1", 8008, load_ssb_secret()["keypair"])
server.on_connect(on_connect) server.on_connect(on_connect)
await server.listen() await server.listen()

71
poetry.lock generated
View File

@ -25,6 +25,20 @@ files = [
[package.extras] [package.extras]
test = ["coverage", "mypy", "pexpect", "ruff", "wheel"] test = ["coverage", "mypy", "pexpect", "ruff", "wheel"]
[[package]]
name = "astroid"
version = "3.0.1"
description = "An abstract syntax tree for Python with inference support."
optional = false
python-versions = ">=3.8.0"
files = [
{file = "astroid-3.0.1-py3-none-any.whl", hash = "sha256:7d5895c9825e18079c5aeac0572bc2e4c83205c95d416e0b4fee8bc361d2d9ca"},
{file = "astroid-3.0.1.tar.gz", hash = "sha256:86b0bb7d7da0be1a7c4aedb7974e391b32d4ed89e33de6ed6902b4b15c97577e"},
]
[package.dependencies]
typing-extensions = {version = ">=4.0.0", markers = "python_version < \"3.11\""}
[[package]] [[package]]
name = "babel" name = "babel"
version = "2.13.1" version = "2.13.1"
@ -419,6 +433,20 @@ files = [
{file = "decli-0.6.1.tar.gz", hash = "sha256:ed88ccb947701e8e5509b7945fda56e150e2ac74a69f25d47ac85ef30ab0c0f0"}, {file = "decli-0.6.1.tar.gz", hash = "sha256:ed88ccb947701e8e5509b7945fda56e150e2ac74a69f25d47ac85ef30ab0c0f0"},
] ]
[[package]]
name = "dill"
version = "0.3.7"
description = "serialize all of Python"
optional = false
python-versions = ">=3.7"
files = [
{file = "dill-0.3.7-py3-none-any.whl", hash = "sha256:76b122c08ef4ce2eedcd4d1abd8e641114bfc6c2867f49f3c41facf65bf19f5e"},
{file = "dill-0.3.7.tar.gz", hash = "sha256:cc1c8b182eb3013e24bd475ff2e9295af86c1a38eb1aff128dac8962a9ce3c03"},
]
[package.extras]
graph = ["objgraph (>=1.7.2)"]
[[package]] [[package]]
name = "docutils" name = "docutils"
version = "0.17.1" version = "0.17.1"
@ -599,6 +627,17 @@ files = [
{file = "MarkupSafe-2.1.3.tar.gz", hash = "sha256:af598ed32d6ae86f1b747b82783958b1a4ab8f617b06fe68795c7f026abbdcad"}, {file = "MarkupSafe-2.1.3.tar.gz", hash = "sha256:af598ed32d6ae86f1b747b82783958b1a4ab8f617b06fe68795c7f026abbdcad"},
] ]
[[package]]
name = "mccabe"
version = "0.7.0"
description = "McCabe checker, plugin for flake8"
optional = false
python-versions = ">=3.6"
files = [
{file = "mccabe-0.7.0-py2.py3-none-any.whl", hash = "sha256:6c2d30ab6be0e4a46919781807b4f0d834ebdd6c6e3dca0bda5a15f863427b6e"},
{file = "mccabe-0.7.0.tar.gz", hash = "sha256:348e0240c33b60bbdf4e523192ef919f28cb2c3d7d5c7794f74009290f236325"},
]
[[package]] [[package]]
name = "mypy-extensions" name = "mypy-extensions"
version = "1.0.0" version = "1.0.0"
@ -712,6 +751,36 @@ files = [
[package.extras] [package.extras]
plugins = ["importlib-metadata"] plugins = ["importlib-metadata"]
[[package]]
name = "pylint"
version = "3.0.2"
description = "python code static checker"
optional = false
python-versions = ">=3.8.0"
files = [
{file = "pylint-3.0.2-py3-none-any.whl", hash = "sha256:60ed5f3a9ff8b61839ff0348b3624ceeb9e6c2a92c514d81c9cc273da3b6bcda"},
{file = "pylint-3.0.2.tar.gz", hash = "sha256:0d4c286ef6d2f66c8bfb527a7f8a629009e42c99707dec821a03e1b51a4c1496"},
]
[package.dependencies]
astroid = ">=3.0.1,<=3.1.0-dev0"
colorama = {version = ">=0.4.5", markers = "sys_platform == \"win32\""}
dill = [
{version = ">=0.2", markers = "python_version < \"3.11\""},
{version = ">=0.3.7", markers = "python_version >= \"3.12\""},
{version = ">=0.3.6", markers = "python_version >= \"3.11\" and python_version < \"3.12\""},
]
isort = ">=4.2.5,<6"
mccabe = ">=0.6,<0.8"
platformdirs = ">=2.2.0"
tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""}
tomlkit = ">=0.10.1"
typing-extensions = {version = ">=3.10.0", markers = "python_version < \"3.10\""}
[package.extras]
spelling = ["pyenchant (>=3.2,<4.0)"]
testutils = ["gitpython (>3)"]
[[package]] [[package]]
name = "pynacl" name = "pynacl"
version = "1.5.0" version = "1.5.0"
@ -1200,4 +1269,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p
[metadata] [metadata]
lock-version = "2.0" lock-version = "2.0"
python-versions = "^3.9" python-versions = "^3.9"
content-hash = "a69623d229f05becfdd7a18072ae96970994ceebdc193d7840aa704ba0d86169" content-hash = "63b3d6f54c99a6722a3d0d5cf9eac68bdb5ef0ea7c58957dd76494529870186c"

View File

@ -25,6 +25,7 @@ pytest-cov = "^4.1.0"
pytest-mock = "^3.12.0" pytest-mock = "^3.12.0"
commitizen = "^3.12.0" commitizen = "^3.12.0"
black = "^23.10.1" black = "^23.10.1"
pylint = "^3.0.2"
[tool.poetry.group.docs.dependencies] [tool.poetry.group.docs.dependencies]
Sphinx = "^2.1.1" Sphinx = "^2.1.1"
@ -41,6 +42,9 @@ skip_covered = true
fail_under = 70 fail_under = 70
omit = ["examples/*"] omit = ["examples/*"]
[tool.pylint.format]
max-line-length = 120
[tool.pytest.ini_options] [tool.pytest.ini_options]
addopts = ["--cov=.", "--no-cov-on-fail"] addopts = ["--cov=.", "--no-cov-on-fail"]
python_files = ["tests/test_*.py"] python_files = ["tests/test_*.py"]

View File

@ -1,3 +1,5 @@
"""Feed related functionality"""
from .models import Feed, LocalFeed, Message, LocalMessage, NoPrivateKeyException from .models import Feed, LocalFeed, Message, LocalMessage, NoPrivateKeyException
__all__ = ("Feed", "LocalFeed", "Message", "LocalMessage", "NoPrivateKeyException") __all__ = ("Feed", "LocalFeed", "Message", "LocalMessage", "NoPrivateKeyException")

View File

@ -1,3 +1,5 @@
"""Feed models"""
import datetime import datetime
from base64 import b64encode from base64 import b64encode
from collections import namedtuple, OrderedDict from collections import namedtuple, OrderedDict
@ -12,44 +14,65 @@ OrderedMsg = namedtuple("OrderedMsg", ("previous", "author", "sequence", "timest
class NoPrivateKeyException(Exception): class NoPrivateKeyException(Exception):
pass """Exception to raise when a private key is not available"""
def to_ordered(data): def to_ordered(data):
"""Convert a dictionary to an ``OrderedDict``"""
smsg = OrderedMsg(**data) smsg = OrderedMsg(**data)
return OrderedDict((k, getattr(smsg, k)) for k in smsg._fields) return OrderedDict((k, getattr(smsg, k)) for k in smsg._fields)
def get_millis_1970(): def get_millis_1970():
"""Get the UNIX timestamp in milliseconds"""
return int(datetime.datetime.utcnow().timestamp() * 1000) return int(datetime.datetime.utcnow().timestamp() * 1000)
class Feed(object): class Feed:
"""Base class for feeds"""
def __init__(self, public_key): def __init__(self, public_key):
self.public_key = public_key self.public_key = public_key
@property @property
def id(self): def id(self):
"""The identifier of the feed"""
return tag(self.public_key).decode("ascii") return tag(self.public_key).decode("ascii")
def sign(self, msg): def sign(self, msg):
"""Sign a message"""
raise NoPrivateKeyException("Cannot use remote identity to sign (no private key!)") raise NoPrivateKeyException("Cannot use remote identity to sign (no private key!)")
class LocalFeed(Feed): class LocalFeed(Feed):
def __init__(self, private_key): """Class representing a local feed"""
def __init__(self, private_key): # pylint: disable=super-init-not-called
self.private_key = private_key self.private_key = private_key
@property @property
def public_key(self): def public_key(self):
"""The public key of the feed"""
return self.private_key.verify_key return self.private_key.verify_key
def sign(self, msg): def sign(self, msg):
"""Sign a message for this feed"""
return self.private_key.sign(msg).signature return self.private_key.sign(msg).signature
class Message(object): class Message:
def __init__(self, feed, content, signature, sequence=1, timestamp=None, previous=None): """Base class for SSB messages"""
def __init__( # pylint: disable=too-many-arguments
self, feed, content, signature=None, sequence=1, timestamp=None, previous=None
):
self.feed = feed self.feed = feed
self.content = content self.content = content
@ -68,14 +91,21 @@ class Message(object):
@classmethod @classmethod
def parse(cls, data, feed): def parse(cls, data, feed):
"""Parse raw message data"""
obj = loads(data, object_pairs_hook=OrderedDict) obj = loads(data, object_pairs_hook=OrderedDict)
msg = cls(feed, obj["content"], timestamp=obj["timestamp"]) msg = cls(feed, obj["content"], timestamp=obj["timestamp"])
return msg return msg
def serialize(self, add_signature=True): def serialize(self, add_signature=True):
"""Serialize the message"""
return dumps(self.to_dict(add_signature=add_signature), indent=2).encode("utf-8") return dumps(self.to_dict(add_signature=add_signature), indent=2).encode("utf-8")
def to_dict(self, add_signature=True): def to_dict(self, add_signature=True):
"""Convert the message to a dictionary"""
obj = to_ordered( obj = to_ordered(
{ {
"previous": self.previous.key if self.previous else None, "previous": self.previous.key if self.previous else None,
@ -89,23 +119,34 @@ class Message(object):
if add_signature: if add_signature:
obj["signature"] = self.signature obj["signature"] = self.signature
return obj return obj
def verify(self, signature): def verify(self, signature):
"""Verify the signature of the message"""
return self.signature == signature return self.signature == signature
@property @property
def hash(self): def hash(self):
hash = sha256(self.serialize()).digest() """The cryptographic hash of the message"""
return b64encode(hash).decode("ascii") + ".sha256"
hash_ = sha256(self.serialize()).digest()
return b64encode(hash_).decode("ascii") + ".sha256"
@property @property
def key(self): def key(self):
"""The key of the message"""
return "%" + self.hash return "%" + self.hash
class LocalMessage(Message): class LocalMessage(Message):
def __init__(self, feed, content, signature=None, sequence=1, timestamp=None, previous=None): """Class representing a local message"""
def __init__( # pylint: disable=too-many-arguments,super-init-not-called
self, feed, content, signature=None, sequence=1, timestamp=None, previous=None
):
self.feed = feed self.feed = feed
self.content = content self.content = content

View File

@ -1,14 +1,20 @@
"""MuxRPC"""
from functools import wraps from functools import wraps
from ssb.packet_stream import PSMessageType from ssb.packet_stream import PSMessageType
class MuxRPCAPIException(Exception): class MuxRPCAPIException(Exception):
pass """Exception to raise on MuxRPC API errors"""
class MuxRPCHandler: class MuxRPCHandler: # pylint: disable=too-few-public-methods
"""Base MuxRPC handler class"""
def check_message(self, msg): def check_message(self, msg):
"""Check message validity"""
body = msg.body body = msg.body
if isinstance(body, dict) and "name" in body and body["name"] == "Error": if isinstance(body, dict) and "name" in body and body["name"] == "Error":
@ -16,6 +22,8 @@ class MuxRPCHandler:
class MuxRPCRequestHandler(MuxRPCHandler): class MuxRPCRequestHandler(MuxRPCHandler):
"""Base class for MuxRPC request handlers"""
def __init__(self, ps_handler): def __init__(self, ps_handler):
self.ps_handler = ps_handler self.ps_handler = ps_handler
@ -26,6 +34,8 @@ class MuxRPCRequestHandler(MuxRPCHandler):
class MuxRPCSourceHandler(MuxRPCHandler): class MuxRPCSourceHandler(MuxRPCHandler):
"""MuxRPC handler for sources"""
def __init__(self, ps_handler): def __init__(self, ps_handler):
self.ps_handler = ps_handler self.ps_handler = ps_handler
@ -39,39 +49,60 @@ class MuxRPCSourceHandler(MuxRPCHandler):
return msg return msg
class MuxRPCSinkHandlerMixin(object): class MuxRPCSinkHandlerMixin: # pylint: disable=too-few-public-methods
"""Mixin for sink-type MuxRPC handlers"""
def send(self, msg, msg_type=PSMessageType.JSON, end=False): def send(self, msg, msg_type=PSMessageType.JSON, end=False):
"""Send a message through the stream"""
self.connection.send(msg, stream=True, msg_type=msg_type, req=self.req, end_err=end) self.connection.send(msg, stream=True, msg_type=msg_type, req=self.req, end_err=end)
class MuxRPCDuplexHandler(MuxRPCSinkHandlerMixin, MuxRPCSourceHandler): class MuxRPCDuplexHandler(MuxRPCSinkHandlerMixin, MuxRPCSourceHandler):
"""MuxRPC handler for duplex streams"""
def __init__(self, ps_handler, connection, req): def __init__(self, ps_handler, connection, req):
super(MuxRPCDuplexHandler, self).__init__(ps_handler) super().__init__(ps_handler)
self.connection = connection self.connection = connection
self.req = req self.req = req
class MuxRPCSinkHandler(MuxRPCHandler, MuxRPCSinkHandlerMixin): class MuxRPCSinkHandler(MuxRPCHandler, MuxRPCSinkHandlerMixin):
"""MuxRPC handler for sinks"""
def __init__(self, connection, req): def __init__(self, connection, req):
self.connection = connection self.connection = connection
self.req = req self.req = req
def _get_appropriate_api_handler(type_, connection, ps_handler, req): def _get_appropriate_api_handler(type_, connection, ps_handler, req):
"""Find the appropriate MuxRPC handler"""
if type_ in {"sync", "async"}: if type_ in {"sync", "async"}:
return MuxRPCRequestHandler(ps_handler) return MuxRPCRequestHandler(ps_handler)
elif type_ == "source":
if type_ == "source":
return MuxRPCSourceHandler(ps_handler) return MuxRPCSourceHandler(ps_handler)
elif type_ == "sink":
if type_ == "sink":
return MuxRPCSinkHandler(connection, req) return MuxRPCSinkHandler(connection, req)
elif type_ == "duplex":
if type_ == "duplex":
return MuxRPCDuplexHandler(ps_handler, connection, req) return MuxRPCDuplexHandler(ps_handler, connection, req)
return None
class MuxRPCRequest:
"""MuxRPC request"""
class MuxRPCRequest(object):
@classmethod @classmethod
def from_message(cls, message): def from_message(cls, message):
"""Initialise a request from a raw packet stream message"""
body = message.body body = message.body
return cls(".".join(body["name"]), body["args"]) return cls(".".join(body["name"]), body["args"])
def __init__(self, name, args): def __init__(self, name, args):
@ -79,22 +110,28 @@ class MuxRPCRequest(object):
self.args = args self.args = args
def __repr__(self): def __repr__(self):
return "<MuxRPCRequest {0.name} {0.args}>".format(self) return f"<MuxRPCRequest {self.name} {self.args}>"
class MuxRPCMessage(object): class MuxRPCMessage:
"""MuxRPC message"""
@classmethod @classmethod
def from_message(cls, message): def from_message(cls, message):
"""Initialise a MuxRPC message from a raw packet stream message"""
return cls(message.body) return cls(message.body)
def __init__(self, body): def __init__(self, body):
self.body = body self.body = body
def __repr__(self): def __repr__(self):
return "<MuxRPCMessage {0.body}}>".format(self) return f"<MuxRPCMessage {self.body}>"
class MuxRPCAPI(object): class MuxRPCAPI:
"""Generit MuxRPC API"""
def __init__(self): def __init__(self):
self.handlers = {} self.handlers = {}
self.connection = None self.connection = None
@ -109,9 +146,13 @@ class MuxRPCAPI(object):
self.process(self.connection, MuxRPCRequest.from_message(req_message)) self.process(self.connection, MuxRPCRequest.from_message(req_message))
def add_connection(self, connection): def add_connection(self, connection):
"""Set the packet stream connection of this RPC API"""
self.connection = connection self.connection = connection
def define(self, name): def define(self, name):
"""Decorator to define an RPC method handler"""
def _handle(f): def _handle(f):
self.handlers[name] = f self.handlers[name] = f
@ -124,14 +165,20 @@ class MuxRPCAPI(object):
return _handle return _handle
def process(self, connection, request): def process(self, connection, request):
"""Process an incoming request"""
handler = self.handlers.get(request.name) handler = self.handlers.get(request.name)
if not handler: if not handler:
raise MuxRPCAPIException("Method {} not found!".format(request.name)) raise MuxRPCAPIException(f"Method {request.name} not found!")
handler(connection, request) handler(connection, request)
def call(self, name, args, type_="sync"): def call(self, name, args, type_="sync"):
"""Call an RPC method"""
if not self.connection.is_connected: if not self.connection.is_connected:
raise Exception("not connected") raise Exception("not connected") # pylint: disable=broad-exception-raised
old_counter = self.connection.req_counter old_counter = self.connection.req_counter
ps_handler = self.connection.send( ps_handler = self.connection.send(

View File

@ -1,3 +1,5 @@
"""Packet streams"""
import logging import logging
import struct import struct
from asyncio import Event, Queue from asyncio import Event, Queue
@ -7,28 +9,34 @@ from math import ceil
import simplejson import simplejson
from secret_handshake import SHSClient, SHSServer
logger = logging.getLogger("packet_stream") logger = logging.getLogger("packet_stream")
class PSMessageType(Enum): class PSMessageType(Enum):
"""Available message types"""
BUFFER = 0 BUFFER = 0
TEXT = 1 TEXT = 1
JSON = 2 JSON = 2
class PSStreamHandler(object): class PSStreamHandler:
"""Packet stream handler"""
def __init__(self, req): def __init__(self, req):
super(PSStreamHandler).__init__() super(PSStreamHandler).__init__()
self.req = req self.req = req
self.queue = Queue() self.queue = Queue()
async def process(self, msg): async def process(self, msg):
"""Process a pending message"""
await self.queue.put(msg) await self.queue.put(msg)
async def stop(self): async def stop(self):
"""Stop a pending request"""
await self.queue.put(None) await self.queue.put(None)
def __aiter__(self): def __aiter__(self):
@ -43,30 +51,39 @@ class PSStreamHandler(object):
return elem return elem
class PSRequestHandler(object): class PSRequestHandler:
"""Packet stream request handler"""
def __init__(self, req): def __init__(self, req):
super(PSRequestHandler).__init__()
self.req = req self.req = req
self.event = Event() self.event = Event()
self._msg = None self._msg = None
async def process(self, msg): async def process(self, msg):
"""Process a message request"""
self._msg = msg self._msg = msg
self.event.set() self.event.set()
async def stop(self): async def stop(self):
"""Stop a pending event request"""
if not self.event.is_set(): if not self.event.is_set():
self.event.set() self.event.set()
def __await__(self): async def __await__(self):
# wait until 'process' is called # wait until 'process' is called
yield from self.event.wait().__await__() await self.event.wait()
return self._msg return self._msg
class PSMessage(object): class PSMessage:
"""Packet Stream message"""
@classmethod @classmethod
def from_header_body(cls, flags, req, body): def from_header_body(cls, flags, req, body):
"""Parse a raw message"""
type_ = PSMessageType(flags & 0x03) type_ = PSMessageType(flags & 0x03)
if type_ == PSMessageType.TEXT: if type_ == PSMessageType.TEXT:
@ -78,13 +95,17 @@ class PSMessage(object):
@property @property
def data(self): def data(self):
"""The raw message data"""
if self.type == PSMessageType.TEXT: if self.type == PSMessageType.TEXT:
return self.body.encode("utf-8") return self.body.encode("utf-8")
elif self.type == PSMessageType.JSON:
if self.type == PSMessageType.JSON:
return simplejson.dumps(self.body).encode("utf-8") return simplejson.dumps(self.body).encode("utf-8")
return self.body return self.body
def __init__(self, type_, body, stream, end_err, req=None): def __init__(self, type_, body, stream, end_err, req=None): # pylint: disable=too-many-arguments
self.stream = stream self.stream = stream
self.end_err = end_err self.end_err = end_err
self.type = type_ self.type = type_
@ -93,29 +114,35 @@ class PSMessage(object):
def __repr__(self): def __repr__(self):
if self.type == PSMessageType.BUFFER: if self.type == PSMessageType.BUFFER:
body = "{} bytes".format(len(self.body)) body = f"{len(self.body)} bytes"
else: else:
body = self.body body = self.body
return "<PSMessage ({}): {}{} {}{}>".format(
self.type.name, req = "" if self.req is None else f" [{self.req}]"
body, is_stream = "~" if self.stream else ""
"" if self.req is None else " [{}]".format(self.req), err = "!" if self.end_err else ""
"~" if self.stream else "",
"!" if self.end_err else "", return f"<PSMessage ({self.type.name}): {body}{req} {is_stream}{err}>"
)
class PacketStream(object): class PacketStream:
"""SSB Packet stream"""
def __init__(self, connection): def __init__(self, connection):
self.connection = connection self.connection = connection
self.req_counter = 1 self.req_counter = 1
self._event_map = {} self._event_map = {}
self._connected = False
def register_handler(self, handler): def register_handler(self, handler):
"""Register an RPC handler"""
self._event_map[handler.req] = (time(), handler) self._event_map[handler.req] = (time(), handler)
@property @property
def is_connected(self): def is_connected(self):
"""Check if the stream is connected"""
return self.connection.is_connected return self.connection.is_connected
def __aiter__(self): def __aiter__(self):
@ -148,7 +175,8 @@ class PacketStream(object):
n_packets = ceil(length / 4096) n_packets = ceil(length / 4096)
body = b"" body = b""
for n in range(n_packets):
for _ in range(n_packets):
body += await self.connection.read() body += await self.connection.read()
logger.debug("READ %s %s", header, len(body)) logger.debug("READ %s %s", header, len(body))
@ -159,12 +187,14 @@ class PacketStream(object):
return None return None
async def read(self): async def read(self):
"""Read data from the packet stream"""
msg = await self._read() msg = await self._read()
if not msg: if not msg:
return None return None
# check whether it's a reply and handle accordingly # check whether it's a reply and handle accordingly
if msg.req < 0: if msg.req < 0:
t, handler = self._event_map[-msg.req] _, handler = self._event_map[-msg.req]
await handler.process(msg) await handler.process(msg)
logger.info("RESPONSE [%d]: %r", -msg.req, msg) logger.info("RESPONSE [%d]: %r", -msg.req, msg)
if msg.end_err: if msg.end_err:
@ -183,7 +213,11 @@ class PacketStream(object):
logger.debug("WRITE HDR: %s", header) logger.debug("WRITE HDR: %s", header)
logger.debug("WRITE DATA: %s", msg.data) logger.debug("WRITE DATA: %s", msg.data)
def send(self, data, msg_type=PSMessageType.JSON, stream=False, end_err=False, req=None): def send( # pylint: disable=too-many-arguments
self, data, msg_type=PSMessageType.JSON, stream=False, end_err=False, req=None
):
"""Send data through the packet stream"""
update_counter = False update_counter = False
if req is None: if req is None:
update_counter = True update_counter = True
@ -205,5 +239,7 @@ class PacketStream(object):
return handler return handler
def disconnect(self): def disconnect(self):
"""Disconnect the stream"""
self._connected = False self._connected = False
self.connection.disconnect() self.connection.disconnect()

View File

@ -1,16 +1,18 @@
import os """Utility functions"""
import yaml
from base64 import b64decode, b64encode
from base64 import b64decode, b64encode
import os
import yaml
from nacl.signing import SigningKey from nacl.signing import SigningKey
class ConfigException(Exception): class ConfigException(Exception):
pass """Exception to raise if there is a problem with the configuration data"""
def tag(key): def tag(key):
"""Create tag from publick key.""" """Create tag from public key."""
return b"@" + b64encode(bytes(key)) + b".ed25519" return b"@" + b64encode(bytes(key)) + b".ed25519"
@ -18,7 +20,7 @@ def tag(key):
def load_ssb_secret(): def load_ssb_secret():
"""Load SSB keys from ~/.ssb""" """Load SSB keys from ~/.ssb"""
with open(os.path.expanduser("~/.ssb/secret")) as f: with open(os.path.expanduser("~/.ssb/secret"), encoding="utf-8") as f:
config = yaml.load(f, Loader=yaml.SafeLoader) config = yaml.load(f, Loader=yaml.SafeLoader)
if config["curve"] != "ed25519": if config["curve"] != "ed25519":

View File

@ -1,3 +1,5 @@
"""Tests for the feed functionality"""
from base64 import b64decode from base64 import b64decode
from collections import OrderedDict from collections import OrderedDict
@ -25,17 +27,23 @@ SERIALIZED_M1 = b"""{
@pytest.fixture @pytest.fixture
def local_feed(): def local_feed():
"""Fixture providing a local feed"""
secret = b64decode("Mz2qkNOP2K6upnqibWrR+z8pVUI1ReA1MLc7QMtF2qQ=") secret = b64decode("Mz2qkNOP2K6upnqibWrR+z8pVUI1ReA1MLc7QMtF2qQ=")
return LocalFeed(SigningKey(secret)) return LocalFeed(SigningKey(secret))
@pytest.fixture @pytest.fixture
def remote_feed(): def remote_feed():
"""Fixture providing a remote feed"""
public = b64decode("I/4cyN/jPBbDsikbHzAEvmaYlaJK33lW3UhWjNXjyrU=") public = b64decode("I/4cyN/jPBbDsikbHzAEvmaYlaJK33lW3UhWjNXjyrU=")
return Feed(VerifyKey(public)) return Feed(VerifyKey(public))
def test_local_feed(): def test_local_feed():
"""Test a local feed"""
secret = b64decode("Mz2qkNOP2K6upnqibWrR+z8pVUI1ReA1MLc7QMtF2qQ=") secret = b64decode("Mz2qkNOP2K6upnqibWrR+z8pVUI1ReA1MLc7QMtF2qQ=")
feed = LocalFeed(SigningKey(secret)) feed = LocalFeed(SigningKey(secret))
assert bytes(feed.private_key) == secret assert bytes(feed.private_key) == secret
@ -44,6 +52,8 @@ def test_local_feed():
def test_remote_feed(): def test_remote_feed():
"""Test a remote feed"""
public = b64decode("I/4cyN/jPBbDsikbHzAEvmaYlaJK33lW3UhWjNXjyrU=") public = b64decode("I/4cyN/jPBbDsikbHzAEvmaYlaJK33lW3UhWjNXjyrU=")
feed = Feed(VerifyKey(public)) feed = Feed(VerifyKey(public))
assert bytes(feed.public_key) == public assert bytes(feed.public_key) == public
@ -60,7 +70,9 @@ def test_remote_feed():
feed.sign(m1) feed.sign(m1)
def test_local_message(local_feed): def test_local_message(local_feed): # pylint: disable=redefined-outer-name
"""Test a local message"""
m1 = LocalMessage( m1 = LocalMessage(
local_feed, local_feed,
OrderedDict([("type", "about"), ("about", local_feed.id), ("name", "neo"), ("description", "The Chosen One")]), OrderedDict([("type", "about"), ("about", local_feed.id), ("name", "neo"), ("description", "The Chosen One")]),
@ -91,7 +103,9 @@ def test_local_message(local_feed):
assert m2.key == "%nx13uks5GUwuKJC49PfYGMS/1pgGTtwwdWT7kbVaroM=.sha256" assert m2.key == "%nx13uks5GUwuKJC49PfYGMS/1pgGTtwwdWT7kbVaroM=.sha256"
def test_remote_message(remote_feed): def test_remote_message(remote_feed): # pylint: disable=redefined-outer-name
"""Test a remote message"""
signature = "lPsQ9P10OgeyH6u0unFgiI2wV/RQ7Q2x2ebxnXYCzsJ055TBMXphRADTKhOMS2EkUxXQ9k3amj5fnWPudGxwBQ==.sig.ed25519" signature = "lPsQ9P10OgeyH6u0unFgiI2wV/RQ7Q2x2ebxnXYCzsJ055TBMXphRADTKhOMS2EkUxXQ9k3amj5fnWPudGxwBQ==.sig.ed25519"
m1 = Message( m1 = Message(
remote_feed, remote_feed,
@ -123,7 +137,9 @@ def test_remote_message(remote_feed):
assert m2.key == "%nx13uks5GUwuKJC49PfYGMS/1pgGTtwwdWT7kbVaroM=.sha256" assert m2.key == "%nx13uks5GUwuKJC49PfYGMS/1pgGTtwwdWT7kbVaroM=.sha256"
def test_remote_no_signature(remote_feed): def test_remote_no_signature(remote_feed): # pylint: disable=redefined-outer-name
"""Test remote feed without a signature"""
with pytest.raises(ValueError): with pytest.raises(ValueError):
Message( Message(
remote_feed, remote_feed,
@ -135,7 +151,9 @@ def test_remote_no_signature(remote_feed):
) )
def test_serialize(local_feed): def test_serialize(local_feed): # pylint: disable=redefined-outer-name
"""Test feed serialization"""
m1 = LocalMessage( m1 = LocalMessage(
local_feed, local_feed,
OrderedDict([("type", "about"), ("about", local_feed.id), ("name", "neo"), ("description", "The Chosen One")]), OrderedDict([("type", "about"), ("about", local_feed.id), ("name", "neo"), ("description", "The Chosen One")]),
@ -145,7 +163,9 @@ def test_serialize(local_feed):
assert m1.serialize() == SERIALIZED_M1 assert m1.serialize() == SERIALIZED_M1
def test_parse(local_feed): def test_parse(local_feed): # pylint: disable=redefined-outer-name
"""Test feed parsing"""
m1 = LocalMessage.parse(SERIALIZED_M1, local_feed) m1 = LocalMessage.parse(SERIALIZED_M1, local_feed)
assert m1.content == {"type": "about", "about": local_feed.id, "name": "neo", "description": "The Chosen One"} assert m1.content == {"type": "about", "about": local_feed.id, "name": "neo", "description": "The Chosen One"}
assert m1.timestamp == 1495706260190 assert m1.timestamp == 1495706260190

View File

@ -1,8 +1,9 @@
"""Tests for the packet stream"""
import json import json
from asyncio import ensure_future, gather, Event from asyncio import ensure_future, gather, Event
import pytest import pytest
from nacl.signing import SigningKey
from secret_handshake.network import SHSDuplexStream from secret_handshake.network import SHSDuplexStream
from ssb.packet_stream import PacketStream, PSMessageType from ssb.packet_stream import PacketStream, PSMessageType
@ -36,63 +37,94 @@ MSG_BODY_2 = (
class MockSHSSocket(SHSDuplexStream): class MockSHSSocket(SHSDuplexStream):
def __init__(self, *args, **kwargs): """A mocked SHS socket"""
super(MockSHSSocket, self).__init__()
def __init__(self, *args, **kwargs): # pylint: disable=unused-argument
super().__init__()
self.input = [] self.input = []
self.output = [] self.output = []
self.is_connected = False self.is_connected = False
self._on_connect = [] self._on_connect = []
def on_connect(self, cb): def on_connect(self, cb):
"""Set the on_connect callback"""
self._on_connect.append(cb) self._on_connect.append(cb)
async def read(self): async def read(self):
"""Read data from the socket"""
if not self.input: if not self.input:
raise StopAsyncIteration raise StopAsyncIteration
return self.input.pop(0) return self.input.pop(0)
def write(self, data): def write(self, data):
"""Write data to the socket"""
self.output.append(data) self.output.append(data)
def feed(self, input): def feed(self, input_):
self.input += input """Get the connections feed"""
self.input += input_
def get_output(self): def get_output(self):
"""Get the output of a call"""
while True: while True:
if not self.output: if not self.output:
break break
yield self.output.pop(0) yield self.output.pop(0)
def disconnect(self): def disconnect(self):
"""Disconnect from the remote party"""
self.is_connected = False self.is_connected = False
class MockSHSClient(MockSHSSocket): class MockSHSClient(MockSHSSocket):
"""A mocked SHS client"""
async def connect(self): async def connect(self):
"""Connect to a SHS server"""
self.is_connected = True self.is_connected = True
for cb in self._on_connect: for cb in self._on_connect:
await cb() await cb()
class MockSHSServer(MockSHSSocket): class MockSHSServer(MockSHSSocket):
"""A mocked SHS server"""
def listen(self): def listen(self):
"""Listen for new connections"""
self.is_connected = True self.is_connected = True
for cb in self._on_connect: for cb in self._on_connect:
ensure_future(cb()) ensure_future(cb())
@pytest.fixture @pytest.fixture
def ps_client(event_loop): def ps_client(event_loop): # pylint: disable=unused-argument
"""Fixture to provide a mocked SHS client"""
return MockSHSClient() return MockSHSClient()
@pytest.fixture @pytest.fixture
def ps_server(event_loop): def ps_server(event_loop): # pylint: disable=unused-argument
"""Fixture to provide a mocked SHS server"""
return MockSHSServer() return MockSHSServer()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_on_connect(ps_server): async def test_on_connect(ps_server): # pylint: disable=redefined-outer-name
"""Test the on_connect callback functionality"""
called = Event() called = Event()
async def _on_connect(): async def _on_connect():
@ -105,7 +137,9 @@ async def test_on_connect(ps_server):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_message_decoding(ps_client): async def test_message_decoding(ps_client): # pylint: disable=redefined-outer-name
"""Test message decoding"""
await ps_client.connect() await ps_client.connect()
ps = PacketStream(ps_client) ps = PacketStream(ps_client)
@ -133,7 +167,9 @@ async def test_message_decoding(ps_client):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_message_encoding(ps_client): async def test_message_encoding(ps_client): # pylint: disable=redefined-outer-name
"""Test message encoding"""
await ps_client.connect() await ps_client.connect()
ps = PacketStream(ps_client) ps = PacketStream(ps_client)
@ -164,7 +200,9 @@ async def test_message_encoding(ps_client):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_message_stream(ps_client, mocker): async def test_message_stream(ps_client, mocker): # pylint: disable=redefined-outer-name
"""Test requesting a history stream"""
await ps_client.connect() await ps_client.connect()
ps = PacketStream(ps_client) ps = PacketStream(ps_client)
@ -184,8 +222,8 @@ async def test_message_stream(ps_client, mocker):
) )
assert ps.req_counter == 2 assert ps.req_counter == 2
assert ps.register_handler.call_count == 1 assert ps.register_handler.call_count == 1 # pylint: disable=no-member
handler = list(ps._event_map.values())[0][1] handler = list(ps._event_map.values())[0][1] # pylint: disable=protected-access
mock_process = mocker.patch.object(handler, "process") mock_process = mocker.patch.object(handler, "process")
@ -211,8 +249,8 @@ async def test_message_stream(ps_client, mocker):
) )
assert ps.req_counter == 3 assert ps.req_counter == 3
assert ps.register_handler.call_count == 2 assert ps.register_handler.call_count == 2 # pylint: disable=no-member
handler = list(ps._event_map.values())[1][1] handler = list(ps._event_map.values())[1][1] # pylint: disable=protected-access
mock_process = mocker.patch.object(handler, "process", wraps=handler.process) mock_process = mocker.patch.object(handler, "process", wraps=handler.process)
@ -234,7 +272,9 @@ async def test_message_stream(ps_client, mocker):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_message_request(ps_server, mocker): async def test_message_request(ps_server, mocker): # pylint: disable=redefined-outer-name
"""Test message sending"""
ps_server.listen() ps_server.listen()
ps = PacketStream(ps_server) ps = PacketStream(ps_server)
@ -248,8 +288,8 @@ async def test_message_request(ps_server, mocker):
assert json.loads(body.decode("utf-8")) == {"name": ["whoami"], "args": []} assert json.loads(body.decode("utf-8")) == {"name": ["whoami"], "args": []}
assert ps.req_counter == 2 assert ps.req_counter == 2
assert ps.register_handler.call_count == 1 assert ps.register_handler.call_count == 1 # pylint: disable=no-member
handler = list(ps._event_map.values())[0][1] handler = list(ps._event_map.values())[0][1] # pylint: disable=protected-access
mock_process = mocker.patch.object(handler, "process") mock_process = mocker.patch.object(handler, "process")

View File

@ -1,3 +1,5 @@
"""Test for utility functions"""
from base64 import b64decode from base64 import b64decode
from unittest.mock import mock_open, patch from unittest.mock import mock_open, patch
@ -20,6 +22,8 @@ CONFIG_FILE_INVALID = CONFIG_FILE.replace("ed25519", "foo")
def test_load_secret(): def test_load_secret():
"""Test loading the SSB secret from a file"""
with patch("ssb.util.open", mock_open(read_data=CONFIG_FILE), create=True): with patch("ssb.util.open", mock_open(read_data=CONFIG_FILE), create=True):
secret = load_ssb_secret() secret = load_ssb_secret()
@ -31,6 +35,7 @@ def test_load_secret():
def test_load_exception(): def test_load_exception():
"""Test configuration loading if there is a problem with the file"""
with pytest.raises(ConfigException): with pytest.raises(ConfigException):
with patch("ssb.util.open", mock_open(read_data=CONFIG_FILE_INVALID), create=True): with patch("ssb.util.open", mock_open(read_data=CONFIG_FILE_INVALID), create=True):
load_ssb_secret() load_ssb_secret()