Properly handle MuxRPC errors
This commit is contained in:
parent
dda0b488c6
commit
3d652d11cd
@ -3,21 +3,38 @@ from functools import wraps
|
|||||||
from ssb.packet_stream import PSMessageType
|
from ssb.packet_stream import PSMessageType
|
||||||
|
|
||||||
|
|
||||||
class MuxRPCRequestHandler(object):
|
class MuxRPCAPIException(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class MuxRPCHandler(object):
|
||||||
|
def check_message(self, msg):
|
||||||
|
body = msg.body
|
||||||
|
if isinstance(body, dict) and 'name' in body and body['name'] == 'Error':
|
||||||
|
raise MuxRPCAPIException(body['message'])
|
||||||
|
|
||||||
|
|
||||||
|
class MuxRPCRequestHandler(MuxRPCHandler):
|
||||||
def __init__(self, ps_handler):
|
def __init__(self, ps_handler):
|
||||||
self.ps_handler = ps_handler
|
self.ps_handler = ps_handler
|
||||||
|
|
||||||
def __await__(self):
|
def __await__(self):
|
||||||
return self.ps_handler.__await__()
|
msg = (yield from self.ps_handler.__await__())
|
||||||
|
self.check_message(msg)
|
||||||
|
return msg
|
||||||
|
|
||||||
|
|
||||||
class MuxRPCSourceHandler(object):
|
class MuxRPCSourceHandler(MuxRPCHandler):
|
||||||
def __init__(self, ps_handler):
|
def __init__(self, ps_handler):
|
||||||
self.ps_handler = ps_handler
|
self.ps_handler = ps_handler
|
||||||
|
|
||||||
async def __aiter__(self):
|
async def __aiter__(self):
|
||||||
async for msg in self.ps_handler:
|
async for msg in self.ps_handler:
|
||||||
yield msg
|
try:
|
||||||
|
self.check_message(msg)
|
||||||
|
yield msg
|
||||||
|
except MuxRPCAPIException:
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
class MuxRPCSinkHandlerMixin(object):
|
class MuxRPCSinkHandlerMixin(object):
|
||||||
@ -33,7 +50,7 @@ class MuxRPCDuplexHandler(MuxRPCSinkHandlerMixin, MuxRPCSourceHandler):
|
|||||||
self.req = req
|
self.req = req
|
||||||
|
|
||||||
|
|
||||||
class MuxRPCSinkHandler(MuxRPCSinkHandlerMixin):
|
class MuxRPCSinkHandler(MuxRPCHandler, MuxRPCSinkHandlerMixin):
|
||||||
def __init__(self, connection, req):
|
def __init__(self, connection, req):
|
||||||
self.connection = connection
|
self.connection = connection
|
||||||
self.req = req
|
self.req = req
|
||||||
@ -50,10 +67,6 @@ def _get_appropriate_api_handler(type_, connection, ps_handler, req):
|
|||||||
return MuxRPCDuplexHandler(ps_handler, connection, req)
|
return MuxRPCDuplexHandler(ps_handler, connection, req)
|
||||||
|
|
||||||
|
|
||||||
class MuxRPCAPIException(Exception):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class MuxRPCRequest(object):
|
class MuxRPCRequest(object):
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_message(cls, message):
|
def from_message(cls, message):
|
||||||
@ -113,6 +126,8 @@ class MuxRPCAPI(object):
|
|||||||
handler(connection, request)
|
handler(connection, request)
|
||||||
|
|
||||||
def call(self, name, args, type_='sync'):
|
def call(self, name, args, type_='sync'):
|
||||||
|
if not self.connection.is_connected:
|
||||||
|
raise Exception('not connected')
|
||||||
old_counter = self.connection.req_counter
|
old_counter = self.connection.req_counter
|
||||||
ps_handler = self.connection.send({
|
ps_handler = self.connection.send({
|
||||||
'name': name.split('.'),
|
'name': name.split('.'),
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
import logging
|
import logging
|
||||||
import struct
|
import struct
|
||||||
from asyncio import Lock, Queue
|
from asyncio import Event, Queue
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from time import time
|
from time import time
|
||||||
|
|
||||||
@ -42,23 +42,20 @@ class PSRequestHandler(object):
|
|||||||
def __init__(self, req):
|
def __init__(self, req):
|
||||||
super(PSRequestHandler).__init__()
|
super(PSRequestHandler).__init__()
|
||||||
self.req = req
|
self.req = req
|
||||||
self.lock = Lock()
|
self.event = Event()
|
||||||
self._msg = None
|
self._msg = None
|
||||||
|
|
||||||
async def process(self, msg):
|
async def process(self, msg):
|
||||||
self._msg = msg
|
self._msg = msg
|
||||||
self.lock.release()
|
self.event.set()
|
||||||
|
|
||||||
async def stop(self):
|
async def stop(self):
|
||||||
self._msg = None
|
if not self.event.is_set():
|
||||||
if self.lock.locked():
|
self.event.set()
|
||||||
self.lock.release()
|
|
||||||
|
|
||||||
def __await__(self):
|
def __await__(self):
|
||||||
yield from self.lock.acquire()
|
# wait until 'process' is called
|
||||||
# try second acquire, which will only be granted
|
yield from self.event.wait()
|
||||||
# when 'process' is called
|
|
||||||
yield from self.lock.acquire()
|
|
||||||
return self._msg
|
return self._msg
|
||||||
|
|
||||||
|
|
||||||
@ -104,6 +101,14 @@ class PSConnection(object):
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
self._event_map = {}
|
self._event_map = {}
|
||||||
self.req_counter = 1
|
self.req_counter = 1
|
||||||
|
self._connected = False
|
||||||
|
|
||||||
|
async def _on_connect(self):
|
||||||
|
self._connected = True
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_connected(self):
|
||||||
|
return self._connected
|
||||||
|
|
||||||
async def read(self):
|
async def read(self):
|
||||||
try:
|
try:
|
||||||
@ -159,11 +164,6 @@ class PSConnection(object):
|
|||||||
logger.debug('WRITE HDR: %s', header)
|
logger.debug('WRITE HDR: %s', header)
|
||||||
logger.debug('WRITE DATA: %s', msg.data)
|
logger.debug('WRITE DATA: %s', msg.data)
|
||||||
|
|
||||||
def on_connect(self, cb):
|
|
||||||
async def _on_connect():
|
|
||||||
await cb()
|
|
||||||
self.connection.on_connect(_on_connect)
|
|
||||||
|
|
||||||
def send(self, data, msg_type=PSMessageType.JSON, stream=False, end_err=False, req=None):
|
def send(self, data, msg_type=PSMessageType.JSON, stream=False, end_err=False, req=None):
|
||||||
update_counter = False
|
update_counter = False
|
||||||
if req is None:
|
if req is None:
|
||||||
@ -185,12 +185,17 @@ class PSConnection(object):
|
|||||||
self.req_counter += 1
|
self.req_counter += 1
|
||||||
return handler
|
return handler
|
||||||
|
|
||||||
|
def disconnect(self):
|
||||||
|
self._connected = False
|
||||||
|
self.connection.disconnect()
|
||||||
|
|
||||||
|
|
||||||
class PSClient(PSConnection):
|
class PSClient(PSConnection):
|
||||||
def __init__(self, host, port, client_kp, server_pub_key, ephemeral_key=None, application_key=None, loop=None):
|
def __init__(self, host, port, client_kp, server_pub_key, ephemeral_key=None, application_key=None, loop=None):
|
||||||
super(PSClient, self).__init__()
|
super(PSClient, self).__init__()
|
||||||
self.connection = SHSClient(host, port, client_kp, server_pub_key, ephemeral_key=ephemeral_key,
|
self.connection = SHSClient(host, port, client_kp, server_pub_key, ephemeral_key=ephemeral_key,
|
||||||
application_key=application_key, loop=loop)
|
application_key=application_key, loop=loop)
|
||||||
|
self.connection.on_connect(self._on_connect)
|
||||||
self.loop = loop
|
self.loop = loop
|
||||||
|
|
||||||
def connect(self):
|
def connect(self):
|
||||||
@ -201,7 +206,13 @@ class PSServer(PSConnection):
|
|||||||
def __init__(self, host, port, client_kp, application_key=None, loop=None):
|
def __init__(self, host, port, client_kp, application_key=None, loop=None):
|
||||||
super(PSClient, self).__init__()
|
super(PSClient, self).__init__()
|
||||||
self.connection = SHSServer(host, port, client_kp, application_key=application_key, loop=loop)
|
self.connection = SHSServer(host, port, client_kp, application_key=application_key, loop=loop)
|
||||||
|
self.connection.on_connect(self._on_connect)
|
||||||
self.loop = loop
|
self.loop = loop
|
||||||
|
|
||||||
|
def on_connect(self, cb):
|
||||||
|
async def _on_connect():
|
||||||
|
await cb()
|
||||||
|
self.connection.on_connect(_on_connect)
|
||||||
|
|
||||||
def listen(self):
|
def listen(self):
|
||||||
self.connection.listen()
|
self.connection.listen()
|
||||||
|
@ -9,7 +9,7 @@ import yaml
|
|||||||
from colorlog import ColoredFormatter
|
from colorlog import ColoredFormatter
|
||||||
from nacl.signing import SigningKey
|
from nacl.signing import SigningKey
|
||||||
|
|
||||||
from ssb.muxrpc import MuxRPCAPI
|
from ssb.muxrpc import MuxRPCAPI, MuxRPCAPIException
|
||||||
from ssb.packet_stream import PSClient, PSMessageType
|
from ssb.packet_stream import PSClient, PSMessageType
|
||||||
|
|
||||||
|
|
||||||
@ -41,7 +41,10 @@ async def main():
|
|||||||
}], 'source'):
|
}], 'source'):
|
||||||
print('> RESPONSE:', msg)
|
print('> RESPONSE:', msg)
|
||||||
|
|
||||||
print('> RESPONSE:', await api.call('whoami', [], 'sync'))
|
try:
|
||||||
|
print('> RESPONSE:', await api.call('whoami', [], 'sync'))
|
||||||
|
except MuxRPCAPIException as e:
|
||||||
|
print(e)
|
||||||
|
|
||||||
handler = api.call('gossip.ping', [], 'duplex')
|
handler = api.call('gossip.ping', [], 'duplex')
|
||||||
handler.send(struct.pack('l', int(time.time() * 1000)), msg_type=PSMessageType.BUFFER)
|
handler.send(struct.pack('l', int(time.time() * 1000)), msg_type=PSMessageType.BUFFER)
|
||||||
@ -72,11 +75,11 @@ logger.setLevel(logging.DEBUG)
|
|||||||
logger.addHandler(ch)
|
logger.addHandler(ch)
|
||||||
|
|
||||||
server_pub_key = b64decode(config['public'][:-8])
|
server_pub_key = b64decode(config['public'][:-8])
|
||||||
server_prv_key = b64decode(config['private'][:-8])
|
# server_prv_key = b64decode(config['private'][:-8])
|
||||||
sign = SigningKey(server_prv_key[:32])
|
# sign = SigningKey(server_prv_key[:32])
|
||||||
|
|
||||||
loop = get_event_loop()
|
loop = get_event_loop()
|
||||||
packet_stream = PSClient('127.0.0.1', 8008, sign, server_pub_key, loop=loop)
|
packet_stream = PSClient('127.0.0.1', 8008, SigningKey.generate(), server_pub_key, loop=loop)
|
||||||
packet_stream.connect()
|
packet_stream.connect()
|
||||||
api.add_connection(packet_stream)
|
api.add_connection(packet_stream)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user