Working examples for duplex, requests, source

This commit is contained in:
Pedro Ferreira 2017-07-30 10:14:26 +02:00
parent ecb67ebcf4
commit e2e893018b
3 changed files with 144 additions and 27 deletions

View File

@ -1,5 +1,54 @@
from functools import wraps from functools import wraps
from ssb.packet_stream import PSMessageType
class MuxRPCRequestHandler(object):
def __init__(self, ps_handler):
self.ps_handler = ps_handler
def __await__(self):
return self.ps_handler.__await__()
class MuxRPCSourceHandler(object):
def __init__(self, ps_handler):
self.ps_handler = ps_handler
async def __aiter__(self):
async for msg in self.ps_handler:
yield msg
class MuxRPCSinkHandlerMixin(object):
def send(self, msg, msg_type=PSMessageType.JSON):
self.connection.send(msg, stream=True, msg_type=msg_type, req=self.req)
class MuxRPCDuplexHandler(MuxRPCSinkHandlerMixin, MuxRPCSourceHandler):
def __init__(self, ps_handler, connection, req):
super(MuxRPCDuplexHandler, self).__init__(ps_handler)
self.connection = connection
self.req = req
class MuxRPCSinkHandler(MuxRPCSinkHandlerMixin):
def __init__(self, connection, req):
self.connection = connection
self.req = req
def _get_appropriate_api_handler(type_, connection, ps_handler, req):
if type_ in {'sync', 'async'}:
return MuxRPCRequestHandler(ps_handler)
elif type_ == 'source':
return MuxRPCSourceHandler(ps_handler)
elif type_ == 'sink':
return MuxRPCSinkHandler(connection, req)
elif type_ == 'duplex':
return MuxRPCDuplexHandler(ps_handler, connection, req)
class MuxRPCAPIException(Exception): class MuxRPCAPIException(Exception):
pass pass
@ -19,6 +68,18 @@ class MuxRPCRequest(object):
return '<MuxRPCRequest {0.name} {0.args}>'.format(self) return '<MuxRPCRequest {0.name} {0.args}>'.format(self)
class MuxRPCMessage(object):
@classmethod
def from_message(cls, message):
return cls(message.body)
def __init__(self, body):
self.body = body
def __repr__(self):
return '<MuxRPCMessage {0.body}}>'.format(self)
class MuxRPCAPI(object): class MuxRPCAPI(object):
def __init__(self): def __init__(self):
self.handlers = {} self.handlers = {}
@ -26,9 +87,11 @@ class MuxRPCAPI(object):
async def __await__(self): async def __await__(self):
async for req_message in self.connection: async for req_message in self.connection:
body = req_message.body
if req_message is None: if req_message is None:
return return
self.process(self.connection, MuxRPCRequest.from_message(req_message)) if isinstance(body, dict) and body.get('name'):
self.process(self.connection, MuxRPCRequest.from_message(req_message))
def add_connection(self, connection): def add_connection(self, connection):
self.connection = connection self.connection = connection
@ -48,3 +111,12 @@ class MuxRPCAPI(object):
if not handler: if not handler:
raise MuxRPCAPIException('Method {} not found!'.format(request.name)) raise MuxRPCAPIException('Method {} not found!'.format(request.name))
handler(connection, request) handler(connection, request)
def call(self, name, args, type_='sync'):
old_counter = self.connection.req_counter
ps_handler = self.connection.send({
'name': name.split('.'),
'args': args,
'type': type_
}, stream=type_ in {'sink', 'source', 'duplex'})
return _get_appropriate_api_handler(type_, self.connection, ps_handler, old_counter)

View File

@ -1,6 +1,6 @@
import logging import logging
import struct import struct
from asyncio import Queue from asyncio import Lock, Queue
from enum import Enum from enum import Enum
from time import time from time import time
@ -38,6 +38,30 @@ class PSStreamHandler(object):
yield elem yield elem
class PSRequestHandler(object):
def __init__(self, req):
super(PSRequestHandler).__init__()
self.req = req
self.lock = Lock()
self._msg = None
async def process(self, msg):
self._msg = msg
self.lock.release()
async def stop(self):
self._msg = None
if self.lock.locked():
self.lock.release()
def __await__(self):
yield from self.lock.acquire()
# try second acquire, which will only be granted
# when 'process' is called
yield from self.lock.acquire()
return self._msg
class PSMessage(object): class PSMessage(object):
@classmethod @classmethod
@ -57,7 +81,8 @@ class PSMessage(object):
if self.type == PSMessageType.TEXT: if self.type == PSMessageType.TEXT:
return self.body.encode('utf-8') return self.body.encode('utf-8')
elif self.type == PSMessageType.JSON: elif self.type == PSMessageType.JSON:
return simplejson.dumps(self.body) return simplejson.dumps(self.body).encode('utf-8')
return self.body
def __init__(self, type_, body, stream, end_err, req=None): def __init__(self, type_, body, stream, end_err, req=None):
self.stream = stream self.stream = stream
@ -106,34 +131,48 @@ class PSConnection(object):
return return
if msg.req < 0: if msg.req < 0:
t, handler = self._event_map[-msg.req] t, handler = self._event_map[-msg.req]
await handler.process(msg)
logger.info('RESPONSE [%d]: %r', -msg.req, msg)
if msg.end_err: if msg.end_err:
await handler.stop() await handler.stop()
del self._event_map[-msg.req] del self._event_map[-msg.req]
logger.info('REQ: %d END', msg.req) logger.info('RESPONSE [%d]: EOS', -msg.req)
else:
logger.info('REQ: %d ELEM: %r', msg.req, msg)
await handler.process(msg)
else: else:
yield msg yield msg
def write(self, msg): def _write(self, msg):
logger.info('SEND: %r (%d)', msg, msg.req) logger.info('SEND [%d]: %r', msg.req, msg)
header = struct.pack('>BIi', (int(msg.stream) << 3) | (int(msg.end_err) << 2) | msg.type.value, len(msg.data), header = struct.pack('>BIi', (int(msg.stream) << 3) | (int(msg.end_err) << 2) | msg.type.value, len(msg.data),
msg.req) msg.req)
self.connection.write(header) self.connection.write(header)
self.connection.write(msg.data.encode('utf-8')) self.connection.write(msg.data)
logger.info('WRITE: %s', header) logger.debug('WRITE HDR: %s', header)
logger.debug('WRITE DATA: %s', msg.data)
def on_connect(self, cb): def on_connect(self, cb):
async def _on_connect(): async def _on_connect():
await cb() await cb()
self.connection.on_connect(_on_connect) self.connection.on_connect(_on_connect)
def stream(self, data): def send(self, data, msg_type=PSMessageType.JSON, stream=False, end_err=False, req=None):
msg = PSMessage(PSMessageType.JSON, data, stream=True, end_err=False, req=self.req_counter) update_counter = False
self.write(msg) if req is None:
handler = PSStreamHandler(self.req_counter) update_counter = True
req = self.req_counter
msg = PSMessage(msg_type, data, stream=stream, end_err=end_err, req=req)
# send request
self._write(msg)
if stream:
handler = PSStreamHandler(self.req_counter)
else:
handler = PSRequestHandler(self.req_counter)
self.register_handler(handler) self.register_handler(handler)
if update_counter:
self.req_counter += 1
return handler return handler

View File

@ -1,5 +1,7 @@
import logging import logging
import os import os
import struct
import time
from asyncio import get_event_loop, gather, ensure_future from asyncio import get_event_loop, gather, ensure_future
from base64 import b64decode from base64 import b64decode
@ -8,7 +10,7 @@ from colorlog import ColoredFormatter
from nacl.signing import SigningKey from nacl.signing import SigningKey
from ssb.api import MuxRPCAPI from ssb.api import MuxRPCAPI
from ssb.packet_stream import PSClient from ssb.packet_stream import PSClient, PSMessageType
with open(os.path.expanduser('~/.ssb/secret')) as f: with open(os.path.expanduser('~/.ssb/secret')) as f:
@ -31,16 +33,18 @@ def create_wants(connection, msg):
async def main(): async def main():
handler = packet_stream.stream({ async for msg in api.call('createHistoryStream', [{
'name': 'createHistoryStream', 'id': "@1+Iwm79DKvVBqYKFkhT6fWRbAVvNNVH4F2BSxwhYmx8=.ed25519",
'args': [{ 'seq': 1,
'id': "@1+Iwm79DKvVBqYKFkhT6fWRbAVvNNVH4F2BSxwhYmx8=.ed25519", 'live': False,
'seq': 1, 'keys': False
'live': False, }], 'source'):
'keys': False print('> RESPONSE:', msg)
}],
'type': 'source' print('> RESPONSE:', await api.call('whoami', [], 'sync'))
})
handler = api.call('gossip.ping', [], 'duplex')
handler.send(struct.pack('l', int(time.time() * 1000)), msg_type=PSMessageType.BUFFER)
async for msg in handler: async for msg in handler:
print('> RESPONSE:', msg) print('> RESPONSE:', msg)
@ -62,9 +66,11 @@ logger.setLevel(logging.DEBUG)
logger.addHandler(ch) logger.addHandler(ch)
server_pub_key = b64decode(config['public'][:-8]) server_pub_key = b64decode(config['public'][:-8])
server_prv_key = b64decode(config['private'][:-8])
sign = SigningKey(server_prv_key[:32])
loop = get_event_loop() loop = get_event_loop()
packet_stream = PSClient('127.0.0.1', 8008, SigningKey.generate(), server_pub_key, loop=loop) packet_stream = PSClient('127.0.0.1', 8008, sign, server_pub_key, loop=loop)
packet_stream.connect() packet_stream.connect()
api.add_connection(packet_stream) api.add_connection(packet_stream)