From ecb67ebcf429510788e0435dad9416d6d5a5b40b Mon Sep 17 00:00:00 2001 From: Pedro Ferreira Date: Sat, 29 Jul 2017 11:54:03 +0200 Subject: [PATCH] Partially working implementation (yay!) --- ssb/api.py | 50 +++++++++++++++++ ssb/packet_stream.py | 131 +++++++++++++++++++++++++++++++++---------- test_client.py | 62 +++++++++++++++++--- 3 files changed, 204 insertions(+), 39 deletions(-) create mode 100644 ssb/api.py diff --git a/ssb/api.py b/ssb/api.py new file mode 100644 index 0000000..df7376e --- /dev/null +++ b/ssb/api.py @@ -0,0 +1,50 @@ +from functools import wraps + + +class MuxRPCAPIException(Exception): + pass + + +class MuxRPCRequest(object): + @classmethod + def from_message(cls, message): + body = message.body + return cls('.'.join(body['name']), body['args']) + + def __init__(self, name, args): + self.name = name + self.args = args + + def __repr__(self): + return ''.format(self) + + +class MuxRPCAPI(object): + def __init__(self): + self.handlers = {} + self.connection = None + + async def __await__(self): + async for req_message in self.connection: + if req_message is None: + return + self.process(self.connection, MuxRPCRequest.from_message(req_message)) + + def add_connection(self, connection): + self.connection = connection + + def define(self, name): + def _handle(f): + self.handlers[name] = f + + @wraps(f) + def _f(*args, **kwargs): + return f(*args, **kwargs) + return f + return _handle + + def process(self, connection, request): + handler = self.handlers.get(request.name) + if not handler: + raise MuxRPCAPIException('Method {} not found!'.format(request.name)) + handler(connection, request) diff --git a/ssb/packet_stream.py b/ssb/packet_stream.py index b8f33fd..978362a 100644 --- a/ssb/packet_stream.py +++ b/ssb/packet_stream.py @@ -1,70 +1,145 @@ +import logging import struct +from asyncio import Queue from enum import Enum +from time import time from secret_handshake import SHSClient, SHSServer import simplejson +logger = logging.getLogger('packet_stream') + + class PSMessageType(Enum): BUFFER = 0 TEXT = 1 JSON = 2 +class PSStreamHandler(object): + def __init__(self, req): + super(PSStreamHandler).__init__() + self.req = req + self.queue = Queue() + + async def process(self, msg): + await self.queue.put(msg) + + async def stop(self): + await self.queue.put(None) + + async def __aiter__(self): + while True: + elem = await self.queue.get() + if not elem: + return + yield elem + + class PSMessage(object): - def __init__(self, stream, end_err, type_, body): - self.stream = stream - self.end_err = end_err - self.type = PSMessageType(type_) - self.body = body + + @classmethod + def from_header_body(cls, header, body): + flags, length, req = struct.unpack('>BIi', header) + type_ = PSMessageType(flags & 0x03) + + if type_ == PSMessageType.TEXT: + body = body.decode('utf-8') + elif type_ == PSMessageType.JSON: + body = simplejson.loads(body) + + return cls(type_, body, bool(flags & 0x08), bool(flags & 0x04), req=req) @property def data(self): if self.type == PSMessageType.TEXT: - return self.body.decode('utf-8') + return self.body.encode('utf-8') elif self.type == PSMessageType.JSON: - return simplejson.loads(self.body) - return self.body + return simplejson.dumps(self.body) + + def __init__(self, type_, body, stream, end_err, req=None): + self.stream = stream + self.end_err = end_err + self.type = type_ + self.body = body + self.req = req def __repr__(self): - return ''.format(self.type.name, self.data) + return ''.format(self.type.name, self.body, + '' if self.req is None else ' [{}]'.format(self.req), + '~' if self.stream else '', '!' if self.end_err else '') -class PSSocket(object): +class PSConnection(object): + def __init__(self): + self._event_map = {} + self.req_counter = 1 + async def read(self): try: header = await self.connection.read() if not header: return body = await self.connection.read() - flags, length, req = struct.unpack('>BIi', header) - return PSMessage(bool(flags & 0x08), bool(flags & 0x04), flags & 0x03, body) + logger.debug('READ %s %s', header, body) + return PSMessage.from_header_body(header, body) except StopAsyncIteration: + logger.debug('DISCONNECT') await self.connection.disconnect() return None - async def __aiter__(self): - while True: - data = await self.read() + async def __await__(self): + async for data in self: + logger.info('RECV: %r', data) if data is None: return - yield data - def write(self, type_, data, req=0): - type_ = PSMessageType[type_] - if type_ == PSMessageType.JSON: - data = simplejson.dumps(data) + def register_handler(self, handler): + self._event_map[handler.req] = (time(), handler) - # XXX: Not yet handling flags that nicely + async def __aiter__(self): + while True: + msg = await self.read() + if not msg: + return + if msg.req < 0: + t, handler = self._event_map[-msg.req] + if msg.end_err: + await handler.stop() + del self._event_map[-msg.req] + logger.info('REQ: %d END', msg.req) + else: + logger.info('REQ: %d ELEM: %r', msg.req, msg) + await handler.process(msg) + else: + yield msg - header = struct.pack('>BIi', 0x08 | type_.value, len(data), req) + def write(self, msg): + logger.info('SEND: %r (%d)', msg, msg.req) + header = struct.pack('>BIi', (int(msg.stream) << 3) | (int(msg.end_err) << 2) | msg.type.value, len(msg.data), + msg.req) self.connection.write(header) - self.connection.write(data.encode('utf-8')) + self.connection.write(msg.data.encode('utf-8')) + logger.info('WRITE: %s', header) + + def on_connect(self, cb): + async def _on_connect(): + await cb() + self.connection.on_connect(_on_connect) + + def stream(self, data): + msg = PSMessage(PSMessageType.JSON, data, stream=True, end_err=False, req=self.req_counter) + self.write(msg) + handler = PSStreamHandler(self.req_counter) + self.register_handler(handler) + return handler -class PSClient(PSSocket): +class PSClient(PSConnection): def __init__(self, host, port, client_kp, server_pub_key, ephemeral_key=None, application_key=None, loop=None): + super(PSClient, self).__init__() self.connection = SHSClient(host, port, client_kp, server_pub_key, ephemeral_key=ephemeral_key, application_key=application_key, loop=loop) self.loop = loop @@ -73,15 +148,11 @@ class PSClient(PSSocket): self.connection.connect() -class PSServer(PSSocket): +class PSServer(PSConnection): def __init__(self, host, port, client_kp, application_key=None, loop=None): + super(PSClient, self).__init__() self.connection = SHSServer(host, port, client_kp, application_key=application_key, loop=loop) self.loop = loop def listen(self): self.connection.listen() - - def on_connect(self, cb): - async def _on_connect(): - await cb(self) - self.connection.on_connect(_on_connect) diff --git a/test_client.py b/test_client.py index 54644f0..a352d8c 100644 --- a/test_client.py +++ b/test_client.py @@ -1,10 +1,13 @@ +import logging import os -from asyncio import get_event_loop +from asyncio import get_event_loop, gather, ensure_future from base64 import b64decode import yaml +from colorlog import ColoredFormatter from nacl.signing import SigningKey +from ssb.api import MuxRPCAPI from ssb.packet_stream import PSClient @@ -12,17 +15,58 @@ with open(os.path.expanduser('~/.ssb/secret')) as f: config = yaml.load(f) +api = MuxRPCAPI() + + +@api.define('createHistoryStream') +def create_history_stream(connection, msg): + print('create_history_stream', msg) + # msg = PSMessage(PSMessageType.JSON, True, stream=True, end_err=True, req=-req) + # connection.write(msg) + + +@api.define('blobs.createWants') +def create_wants(connection, msg): + print('create_wants', msg) + + +async def main(): + handler = packet_stream.stream({ + 'name': 'createHistoryStream', + 'args': [{ + 'id': "@1+Iwm79DKvVBqYKFkhT6fWRbAVvNNVH4F2BSxwhYmx8=.ed25519", + 'seq': 1, + 'live': False, + 'keys': False + }], + 'type': 'source' + }) + async for msg in handler: + print('> RESPONSE:', msg) + + +# create console handler and set level to debug +ch = logging.StreamHandler() +ch.setLevel(logging.INFO) + +# create formatter +formatter = ColoredFormatter('%(log_color)s%(levelname)s%(reset)s:%(bold_white)s%(name)s%(reset)s - ' + '%(cyan)s%(message)s%(reset)s') + +# add formatter to ch +ch.setFormatter(formatter) + +# add ch to logger +logger = logging.getLogger('packet_stream') +logger.setLevel(logging.DEBUG) +logger.addHandler(ch) + server_pub_key = b64decode(config['public'][:-8]) - -async def main(loop): - async for msg in packet_stream: - print(msg) - print('bye') - loop = get_event_loop() - packet_stream = PSClient('127.0.0.1', 8008, SigningKey.generate(), server_pub_key, loop=loop) packet_stream.connect() -loop.run_until_complete(main(loop)) +api.add_connection(packet_stream) + +loop.run_until_complete(gather(ensure_future(api), main())) loop.close()