Properly handle MuxRPC errors

This commit is contained in:
Pedro Ferreira
2017-07-30 14:08:18 +02:00
parent dda0b488c6
commit 3d652d11cd
3 changed files with 58 additions and 29 deletions

View File

@@ -1,6 +1,6 @@
import logging
import struct
from asyncio import Lock, Queue
from asyncio import Event, Queue
from enum import Enum
from time import time
@@ -42,23 +42,20 @@ class PSRequestHandler(object):
def __init__(self, req):
super(PSRequestHandler).__init__()
self.req = req
self.lock = Lock()
self.event = Event()
self._msg = None
async def process(self, msg):
self._msg = msg
self.lock.release()
self.event.set()
async def stop(self):
self._msg = None
if self.lock.locked():
self.lock.release()
if not self.event.is_set():
self.event.set()
def __await__(self):
yield from self.lock.acquire()
# try second acquire, which will only be granted
# when 'process' is called
yield from self.lock.acquire()
# wait until 'process' is called
yield from self.event.wait()
return self._msg
@@ -104,6 +101,14 @@ class PSConnection(object):
def __init__(self):
self._event_map = {}
self.req_counter = 1
self._connected = False
async def _on_connect(self):
self._connected = True
@property
def is_connected(self):
return self._connected
async def read(self):
try:
@@ -159,11 +164,6 @@ class PSConnection(object):
logger.debug('WRITE HDR: %s', header)
logger.debug('WRITE DATA: %s', msg.data)
def on_connect(self, cb):
async def _on_connect():
await cb()
self.connection.on_connect(_on_connect)
def send(self, data, msg_type=PSMessageType.JSON, stream=False, end_err=False, req=None):
update_counter = False
if req is None:
@@ -185,12 +185,17 @@ class PSConnection(object):
self.req_counter += 1
return handler
def disconnect(self):
self._connected = False
self.connection.disconnect()
class PSClient(PSConnection):
def __init__(self, host, port, client_kp, server_pub_key, ephemeral_key=None, application_key=None, loop=None):
super(PSClient, self).__init__()
self.connection = SHSClient(host, port, client_kp, server_pub_key, ephemeral_key=ephemeral_key,
application_key=application_key, loop=loop)
self.connection.on_connect(self._on_connect)
self.loop = loop
def connect(self):
@@ -201,7 +206,13 @@ class PSServer(PSConnection):
def __init__(self, host, port, client_kp, application_key=None, loop=None):
super(PSClient, self).__init__()
self.connection = SHSServer(host, port, client_kp, application_key=application_key, loop=loop)
self.connection.on_connect(self._on_connect)
self.loop = loop
def on_connect(self, cb):
async def _on_connect():
await cb()
self.connection.on_connect(_on_connect)
def listen(self):
self.connection.listen()