diff --git a/ssb/muxrpc.py b/ssb/muxrpc.py index 605609d..e70584f 100644 --- a/ssb/muxrpc.py +++ b/ssb/muxrpc.py @@ -3,21 +3,38 @@ from functools import wraps from ssb.packet_stream import PSMessageType -class MuxRPCRequestHandler(object): +class MuxRPCAPIException(Exception): + pass + + +class MuxRPCHandler(object): + def check_message(self, msg): + body = msg.body + if isinstance(body, dict) and 'name' in body and body['name'] == 'Error': + raise MuxRPCAPIException(body['message']) + + +class MuxRPCRequestHandler(MuxRPCHandler): def __init__(self, ps_handler): self.ps_handler = ps_handler def __await__(self): - return self.ps_handler.__await__() + msg = (yield from self.ps_handler.__await__()) + self.check_message(msg) + return msg -class MuxRPCSourceHandler(object): +class MuxRPCSourceHandler(MuxRPCHandler): def __init__(self, ps_handler): self.ps_handler = ps_handler async def __aiter__(self): async for msg in self.ps_handler: - yield msg + try: + self.check_message(msg) + yield msg + except MuxRPCAPIException: + raise class MuxRPCSinkHandlerMixin(object): @@ -33,7 +50,7 @@ class MuxRPCDuplexHandler(MuxRPCSinkHandlerMixin, MuxRPCSourceHandler): self.req = req -class MuxRPCSinkHandler(MuxRPCSinkHandlerMixin): +class MuxRPCSinkHandler(MuxRPCHandler, MuxRPCSinkHandlerMixin): def __init__(self, connection, req): self.connection = connection self.req = req @@ -50,10 +67,6 @@ def _get_appropriate_api_handler(type_, connection, ps_handler, req): return MuxRPCDuplexHandler(ps_handler, connection, req) -class MuxRPCAPIException(Exception): - pass - - class MuxRPCRequest(object): @classmethod def from_message(cls, message): @@ -113,6 +126,8 @@ class MuxRPCAPI(object): handler(connection, request) def call(self, name, args, type_='sync'): + if not self.connection.is_connected: + raise Exception('not connected') old_counter = self.connection.req_counter ps_handler = self.connection.send({ 'name': name.split('.'), diff --git a/ssb/packet_stream.py b/ssb/packet_stream.py index e094bdb..3377230 100644 --- a/ssb/packet_stream.py +++ b/ssb/packet_stream.py @@ -1,6 +1,6 @@ import logging import struct -from asyncio import Lock, Queue +from asyncio import Event, Queue from enum import Enum from time import time @@ -42,23 +42,20 @@ class PSRequestHandler(object): def __init__(self, req): super(PSRequestHandler).__init__() self.req = req - self.lock = Lock() + self.event = Event() self._msg = None async def process(self, msg): self._msg = msg - self.lock.release() + self.event.set() async def stop(self): - self._msg = None - if self.lock.locked(): - self.lock.release() + if not self.event.is_set(): + self.event.set() def __await__(self): - yield from self.lock.acquire() - # try second acquire, which will only be granted - # when 'process' is called - yield from self.lock.acquire() + # wait until 'process' is called + yield from self.event.wait() return self._msg @@ -104,6 +101,14 @@ class PSConnection(object): def __init__(self): self._event_map = {} self.req_counter = 1 + self._connected = False + + async def _on_connect(self): + self._connected = True + + @property + def is_connected(self): + return self._connected async def read(self): try: @@ -159,11 +164,6 @@ class PSConnection(object): logger.debug('WRITE HDR: %s', header) logger.debug('WRITE DATA: %s', msg.data) - def on_connect(self, cb): - async def _on_connect(): - await cb() - self.connection.on_connect(_on_connect) - def send(self, data, msg_type=PSMessageType.JSON, stream=False, end_err=False, req=None): update_counter = False if req is None: @@ -185,12 +185,17 @@ class PSConnection(object): self.req_counter += 1 return handler + def disconnect(self): + self._connected = False + self.connection.disconnect() + class PSClient(PSConnection): def __init__(self, host, port, client_kp, server_pub_key, ephemeral_key=None, application_key=None, loop=None): super(PSClient, self).__init__() self.connection = SHSClient(host, port, client_kp, server_pub_key, ephemeral_key=ephemeral_key, application_key=application_key, loop=loop) + self.connection.on_connect(self._on_connect) self.loop = loop def connect(self): @@ -201,7 +206,13 @@ class PSServer(PSConnection): def __init__(self, host, port, client_kp, application_key=None, loop=None): super(PSClient, self).__init__() self.connection = SHSServer(host, port, client_kp, application_key=application_key, loop=loop) + self.connection.on_connect(self._on_connect) self.loop = loop + def on_connect(self, cb): + async def _on_connect(): + await cb() + self.connection.on_connect(_on_connect) + def listen(self): self.connection.listen() diff --git a/test_client.py b/test_client.py index 1e298f2..69fb9ed 100644 --- a/test_client.py +++ b/test_client.py @@ -9,7 +9,7 @@ import yaml from colorlog import ColoredFormatter from nacl.signing import SigningKey -from ssb.muxrpc import MuxRPCAPI +from ssb.muxrpc import MuxRPCAPI, MuxRPCAPIException from ssb.packet_stream import PSClient, PSMessageType @@ -41,7 +41,10 @@ async def main(): }], 'source'): print('> RESPONSE:', msg) - print('> RESPONSE:', await api.call('whoami', [], 'sync')) + try: + print('> RESPONSE:', await api.call('whoami', [], 'sync')) + except MuxRPCAPIException as e: + print(e) handler = api.call('gossip.ping', [], 'duplex') handler.send(struct.pack('l', int(time.time() * 1000)), msg_type=PSMessageType.BUFFER) @@ -72,11 +75,11 @@ logger.setLevel(logging.DEBUG) logger.addHandler(ch) server_pub_key = b64decode(config['public'][:-8]) -server_prv_key = b64decode(config['private'][:-8]) -sign = SigningKey(server_prv_key[:32]) +# server_prv_key = b64decode(config['private'][:-8]) +# sign = SigningKey(server_prv_key[:32]) loop = get_event_loop() -packet_stream = PSClient('127.0.0.1', 8008, sign, server_pub_key, loop=loop) +packet_stream = PSClient('127.0.0.1', 8008, SigningKey.generate(), server_pub_key, loop=loop) packet_stream.connect() api.add_connection(packet_stream)