Properly handle MuxRPC errors

This commit is contained in:
Pedro Ferreira 2017-07-30 14:08:18 +02:00
parent dda0b488c6
commit 3d652d11cd
3 changed files with 58 additions and 29 deletions

View File

@ -3,21 +3,38 @@ from functools import wraps
from ssb.packet_stream import PSMessageType from ssb.packet_stream import PSMessageType
class MuxRPCRequestHandler(object): class MuxRPCAPIException(Exception):
pass
class MuxRPCHandler(object):
def check_message(self, msg):
body = msg.body
if isinstance(body, dict) and 'name' in body and body['name'] == 'Error':
raise MuxRPCAPIException(body['message'])
class MuxRPCRequestHandler(MuxRPCHandler):
def __init__(self, ps_handler): def __init__(self, ps_handler):
self.ps_handler = ps_handler self.ps_handler = ps_handler
def __await__(self): def __await__(self):
return self.ps_handler.__await__() msg = (yield from self.ps_handler.__await__())
self.check_message(msg)
return msg
class MuxRPCSourceHandler(object): class MuxRPCSourceHandler(MuxRPCHandler):
def __init__(self, ps_handler): def __init__(self, ps_handler):
self.ps_handler = ps_handler self.ps_handler = ps_handler
async def __aiter__(self): async def __aiter__(self):
async for msg in self.ps_handler: async for msg in self.ps_handler:
yield msg try:
self.check_message(msg)
yield msg
except MuxRPCAPIException:
raise
class MuxRPCSinkHandlerMixin(object): class MuxRPCSinkHandlerMixin(object):
@ -33,7 +50,7 @@ class MuxRPCDuplexHandler(MuxRPCSinkHandlerMixin, MuxRPCSourceHandler):
self.req = req self.req = req
class MuxRPCSinkHandler(MuxRPCSinkHandlerMixin): class MuxRPCSinkHandler(MuxRPCHandler, MuxRPCSinkHandlerMixin):
def __init__(self, connection, req): def __init__(self, connection, req):
self.connection = connection self.connection = connection
self.req = req self.req = req
@ -50,10 +67,6 @@ def _get_appropriate_api_handler(type_, connection, ps_handler, req):
return MuxRPCDuplexHandler(ps_handler, connection, req) return MuxRPCDuplexHandler(ps_handler, connection, req)
class MuxRPCAPIException(Exception):
pass
class MuxRPCRequest(object): class MuxRPCRequest(object):
@classmethod @classmethod
def from_message(cls, message): def from_message(cls, message):
@ -113,6 +126,8 @@ class MuxRPCAPI(object):
handler(connection, request) handler(connection, request)
def call(self, name, args, type_='sync'): def call(self, name, args, type_='sync'):
if not self.connection.is_connected:
raise Exception('not connected')
old_counter = self.connection.req_counter old_counter = self.connection.req_counter
ps_handler = self.connection.send({ ps_handler = self.connection.send({
'name': name.split('.'), 'name': name.split('.'),

View File

@ -1,6 +1,6 @@
import logging import logging
import struct import struct
from asyncio import Lock, Queue from asyncio import Event, Queue
from enum import Enum from enum import Enum
from time import time from time import time
@ -42,23 +42,20 @@ class PSRequestHandler(object):
def __init__(self, req): def __init__(self, req):
super(PSRequestHandler).__init__() super(PSRequestHandler).__init__()
self.req = req self.req = req
self.lock = Lock() self.event = Event()
self._msg = None self._msg = None
async def process(self, msg): async def process(self, msg):
self._msg = msg self._msg = msg
self.lock.release() self.event.set()
async def stop(self): async def stop(self):
self._msg = None if not self.event.is_set():
if self.lock.locked(): self.event.set()
self.lock.release()
def __await__(self): def __await__(self):
yield from self.lock.acquire() # wait until 'process' is called
# try second acquire, which will only be granted yield from self.event.wait()
# when 'process' is called
yield from self.lock.acquire()
return self._msg return self._msg
@ -104,6 +101,14 @@ class PSConnection(object):
def __init__(self): def __init__(self):
self._event_map = {} self._event_map = {}
self.req_counter = 1 self.req_counter = 1
self._connected = False
async def _on_connect(self):
self._connected = True
@property
def is_connected(self):
return self._connected
async def read(self): async def read(self):
try: try:
@ -159,11 +164,6 @@ class PSConnection(object):
logger.debug('WRITE HDR: %s', header) logger.debug('WRITE HDR: %s', header)
logger.debug('WRITE DATA: %s', msg.data) logger.debug('WRITE DATA: %s', msg.data)
def on_connect(self, cb):
async def _on_connect():
await cb()
self.connection.on_connect(_on_connect)
def send(self, data, msg_type=PSMessageType.JSON, stream=False, end_err=False, req=None): def send(self, data, msg_type=PSMessageType.JSON, stream=False, end_err=False, req=None):
update_counter = False update_counter = False
if req is None: if req is None:
@ -185,12 +185,17 @@ class PSConnection(object):
self.req_counter += 1 self.req_counter += 1
return handler return handler
def disconnect(self):
self._connected = False
self.connection.disconnect()
class PSClient(PSConnection): class PSClient(PSConnection):
def __init__(self, host, port, client_kp, server_pub_key, ephemeral_key=None, application_key=None, loop=None): def __init__(self, host, port, client_kp, server_pub_key, ephemeral_key=None, application_key=None, loop=None):
super(PSClient, self).__init__() super(PSClient, self).__init__()
self.connection = SHSClient(host, port, client_kp, server_pub_key, ephemeral_key=ephemeral_key, self.connection = SHSClient(host, port, client_kp, server_pub_key, ephemeral_key=ephemeral_key,
application_key=application_key, loop=loop) application_key=application_key, loop=loop)
self.connection.on_connect(self._on_connect)
self.loop = loop self.loop = loop
def connect(self): def connect(self):
@ -201,7 +206,13 @@ class PSServer(PSConnection):
def __init__(self, host, port, client_kp, application_key=None, loop=None): def __init__(self, host, port, client_kp, application_key=None, loop=None):
super(PSClient, self).__init__() super(PSClient, self).__init__()
self.connection = SHSServer(host, port, client_kp, application_key=application_key, loop=loop) self.connection = SHSServer(host, port, client_kp, application_key=application_key, loop=loop)
self.connection.on_connect(self._on_connect)
self.loop = loop self.loop = loop
def on_connect(self, cb):
async def _on_connect():
await cb()
self.connection.on_connect(_on_connect)
def listen(self): def listen(self):
self.connection.listen() self.connection.listen()

View File

@ -9,7 +9,7 @@ import yaml
from colorlog import ColoredFormatter from colorlog import ColoredFormatter
from nacl.signing import SigningKey from nacl.signing import SigningKey
from ssb.muxrpc import MuxRPCAPI from ssb.muxrpc import MuxRPCAPI, MuxRPCAPIException
from ssb.packet_stream import PSClient, PSMessageType from ssb.packet_stream import PSClient, PSMessageType
@ -41,7 +41,10 @@ async def main():
}], 'source'): }], 'source'):
print('> RESPONSE:', msg) print('> RESPONSE:', msg)
print('> RESPONSE:', await api.call('whoami', [], 'sync')) try:
print('> RESPONSE:', await api.call('whoami', [], 'sync'))
except MuxRPCAPIException as e:
print(e)
handler = api.call('gossip.ping', [], 'duplex') handler = api.call('gossip.ping', [], 'duplex')
handler.send(struct.pack('l', int(time.time() * 1000)), msg_type=PSMessageType.BUFFER) handler.send(struct.pack('l', int(time.time() * 1000)), msg_type=PSMessageType.BUFFER)
@ -72,11 +75,11 @@ logger.setLevel(logging.DEBUG)
logger.addHandler(ch) logger.addHandler(ch)
server_pub_key = b64decode(config['public'][:-8]) server_pub_key = b64decode(config['public'][:-8])
server_prv_key = b64decode(config['private'][:-8]) # server_prv_key = b64decode(config['private'][:-8])
sign = SigningKey(server_prv_key[:32]) # sign = SigningKey(server_prv_key[:32])
loop = get_event_loop() loop = get_event_loop()
packet_stream = PSClient('127.0.0.1', 8008, sign, server_pub_key, loop=loop) packet_stream = PSClient('127.0.0.1', 8008, SigningKey.generate(), server_pub_key, loop=loop)
packet_stream.connect() packet_stream.connect()
api.add_connection(packet_stream) api.add_connection(packet_stream)