Partially working implementation (yay!)

This commit is contained in:
Pedro Ferreira 2017-07-29 11:54:03 +02:00
parent 21af5fba09
commit ecb67ebcf4
3 changed files with 204 additions and 39 deletions

50
ssb/api.py Normal file
View File

@ -0,0 +1,50 @@
from functools import wraps
class MuxRPCAPIException(Exception):
pass
class MuxRPCRequest(object):
@classmethod
def from_message(cls, message):
body = message.body
return cls('.'.join(body['name']), body['args'])
def __init__(self, name, args):
self.name = name
self.args = args
def __repr__(self):
return '<MuxRPCRequest {0.name} {0.args}>'.format(self)
class MuxRPCAPI(object):
def __init__(self):
self.handlers = {}
self.connection = None
async def __await__(self):
async for req_message in self.connection:
if req_message is None:
return
self.process(self.connection, MuxRPCRequest.from_message(req_message))
def add_connection(self, connection):
self.connection = connection
def define(self, name):
def _handle(f):
self.handlers[name] = f
@wraps(f)
def _f(*args, **kwargs):
return f(*args, **kwargs)
return f
return _handle
def process(self, connection, request):
handler = self.handlers.get(request.name)
if not handler:
raise MuxRPCAPIException('Method {} not found!'.format(request.name))
handler(connection, request)

View File

@ -1,70 +1,145 @@
import logging
import struct import struct
from asyncio import Queue
from enum import Enum from enum import Enum
from time import time
from secret_handshake import SHSClient, SHSServer from secret_handshake import SHSClient, SHSServer
import simplejson import simplejson
logger = logging.getLogger('packet_stream')
class PSMessageType(Enum): class PSMessageType(Enum):
BUFFER = 0 BUFFER = 0
TEXT = 1 TEXT = 1
JSON = 2 JSON = 2
class PSStreamHandler(object):
def __init__(self, req):
super(PSStreamHandler).__init__()
self.req = req
self.queue = Queue()
async def process(self, msg):
await self.queue.put(msg)
async def stop(self):
await self.queue.put(None)
async def __aiter__(self):
while True:
elem = await self.queue.get()
if not elem:
return
yield elem
class PSMessage(object): class PSMessage(object):
def __init__(self, stream, end_err, type_, body):
self.stream = stream @classmethod
self.end_err = end_err def from_header_body(cls, header, body):
self.type = PSMessageType(type_) flags, length, req = struct.unpack('>BIi', header)
self.body = body type_ = PSMessageType(flags & 0x03)
if type_ == PSMessageType.TEXT:
body = body.decode('utf-8')
elif type_ == PSMessageType.JSON:
body = simplejson.loads(body)
return cls(type_, body, bool(flags & 0x08), bool(flags & 0x04), req=req)
@property @property
def data(self): def data(self):
if self.type == PSMessageType.TEXT: if self.type == PSMessageType.TEXT:
return self.body.decode('utf-8') return self.body.encode('utf-8')
elif self.type == PSMessageType.JSON: elif self.type == PSMessageType.JSON:
return simplejson.loads(self.body) return simplejson.dumps(self.body)
return self.body
def __init__(self, type_, body, stream, end_err, req=None):
self.stream = stream
self.end_err = end_err
self.type = type_
self.body = body
self.req = req
def __repr__(self): def __repr__(self):
return '<PSMessage ({}): {}>'.format(self.type.name, self.data) return '<PSMessage ({}): {}{} {}{}>'.format(self.type.name, self.body,
'' if self.req is None else ' [{}]'.format(self.req),
'~' if self.stream else '', '!' if self.end_err else '')
class PSSocket(object): class PSConnection(object):
def __init__(self):
self._event_map = {}
self.req_counter = 1
async def read(self): async def read(self):
try: try:
header = await self.connection.read() header = await self.connection.read()
if not header: if not header:
return return
body = await self.connection.read() body = await self.connection.read()
flags, length, req = struct.unpack('>BIi', header) logger.debug('READ %s %s', header, body)
return PSMessage(bool(flags & 0x08), bool(flags & 0x04), flags & 0x03, body) return PSMessage.from_header_body(header, body)
except StopAsyncIteration: except StopAsyncIteration:
logger.debug('DISCONNECT')
await self.connection.disconnect() await self.connection.disconnect()
return None return None
async def __aiter__(self): async def __await__(self):
while True: async for data in self:
data = await self.read() logger.info('RECV: %r', data)
if data is None: if data is None:
return return
yield data
def write(self, type_, data, req=0): def register_handler(self, handler):
type_ = PSMessageType[type_] self._event_map[handler.req] = (time(), handler)
if type_ == PSMessageType.JSON:
data = simplejson.dumps(data)
# XXX: Not yet handling flags that nicely async def __aiter__(self):
while True:
msg = await self.read()
if not msg:
return
if msg.req < 0:
t, handler = self._event_map[-msg.req]
if msg.end_err:
await handler.stop()
del self._event_map[-msg.req]
logger.info('REQ: %d END', msg.req)
else:
logger.info('REQ: %d ELEM: %r', msg.req, msg)
await handler.process(msg)
else:
yield msg
header = struct.pack('>BIi', 0x08 | type_.value, len(data), req) def write(self, msg):
logger.info('SEND: %r (%d)', msg, msg.req)
header = struct.pack('>BIi', (int(msg.stream) << 3) | (int(msg.end_err) << 2) | msg.type.value, len(msg.data),
msg.req)
self.connection.write(header) self.connection.write(header)
self.connection.write(data.encode('utf-8')) self.connection.write(msg.data.encode('utf-8'))
logger.info('WRITE: %s', header)
def on_connect(self, cb):
async def _on_connect():
await cb()
self.connection.on_connect(_on_connect)
def stream(self, data):
msg = PSMessage(PSMessageType.JSON, data, stream=True, end_err=False, req=self.req_counter)
self.write(msg)
handler = PSStreamHandler(self.req_counter)
self.register_handler(handler)
return handler
class PSClient(PSSocket): class PSClient(PSConnection):
def __init__(self, host, port, client_kp, server_pub_key, ephemeral_key=None, application_key=None, loop=None): def __init__(self, host, port, client_kp, server_pub_key, ephemeral_key=None, application_key=None, loop=None):
super(PSClient, self).__init__()
self.connection = SHSClient(host, port, client_kp, server_pub_key, ephemeral_key=ephemeral_key, self.connection = SHSClient(host, port, client_kp, server_pub_key, ephemeral_key=ephemeral_key,
application_key=application_key, loop=loop) application_key=application_key, loop=loop)
self.loop = loop self.loop = loop
@ -73,15 +148,11 @@ class PSClient(PSSocket):
self.connection.connect() self.connection.connect()
class PSServer(PSSocket): class PSServer(PSConnection):
def __init__(self, host, port, client_kp, application_key=None, loop=None): def __init__(self, host, port, client_kp, application_key=None, loop=None):
super(PSClient, self).__init__()
self.connection = SHSServer(host, port, client_kp, application_key=application_key, loop=loop) self.connection = SHSServer(host, port, client_kp, application_key=application_key, loop=loop)
self.loop = loop self.loop = loop
def listen(self): def listen(self):
self.connection.listen() self.connection.listen()
def on_connect(self, cb):
async def _on_connect():
await cb(self)
self.connection.on_connect(_on_connect)

View File

@ -1,10 +1,13 @@
import logging
import os import os
from asyncio import get_event_loop from asyncio import get_event_loop, gather, ensure_future
from base64 import b64decode from base64 import b64decode
import yaml import yaml
from colorlog import ColoredFormatter
from nacl.signing import SigningKey from nacl.signing import SigningKey
from ssb.api import MuxRPCAPI
from ssb.packet_stream import PSClient from ssb.packet_stream import PSClient
@ -12,17 +15,58 @@ with open(os.path.expanduser('~/.ssb/secret')) as f:
config = yaml.load(f) config = yaml.load(f)
api = MuxRPCAPI()
@api.define('createHistoryStream')
def create_history_stream(connection, msg):
print('create_history_stream', msg)
# msg = PSMessage(PSMessageType.JSON, True, stream=True, end_err=True, req=-req)
# connection.write(msg)
@api.define('blobs.createWants')
def create_wants(connection, msg):
print('create_wants', msg)
async def main():
handler = packet_stream.stream({
'name': 'createHistoryStream',
'args': [{
'id': "@1+Iwm79DKvVBqYKFkhT6fWRbAVvNNVH4F2BSxwhYmx8=.ed25519",
'seq': 1,
'live': False,
'keys': False
}],
'type': 'source'
})
async for msg in handler:
print('> RESPONSE:', msg)
# create console handler and set level to debug
ch = logging.StreamHandler()
ch.setLevel(logging.INFO)
# create formatter
formatter = ColoredFormatter('%(log_color)s%(levelname)s%(reset)s:%(bold_white)s%(name)s%(reset)s - '
'%(cyan)s%(message)s%(reset)s')
# add formatter to ch
ch.setFormatter(formatter)
# add ch to logger
logger = logging.getLogger('packet_stream')
logger.setLevel(logging.DEBUG)
logger.addHandler(ch)
server_pub_key = b64decode(config['public'][:-8]) server_pub_key = b64decode(config['public'][:-8])
async def main(loop):
async for msg in packet_stream:
print(msg)
print('bye')
loop = get_event_loop() loop = get_event_loop()
packet_stream = PSClient('127.0.0.1', 8008, SigningKey.generate(), server_pub_key, loop=loop) packet_stream = PSClient('127.0.0.1', 8008, SigningKey.generate(), server_pub_key, loop=loop)
packet_stream.connect() packet_stream.connect()
loop.run_until_complete(main(loop)) api.add_connection(packet_stream)
loop.run_until_complete(gather(ensure_future(api), main()))
loop.close() loop.close()