diff --git a/.travis.yml b/.travis.yml index 6d7672d..c504f44 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,5 +1,6 @@ language: python python: + - "3.5" - "3.6.1" - "3.7-dev" install: diff --git a/setup.py b/setup.py index 51f916c..35fa9d7 100644 --- a/setup.py +++ b/setup.py @@ -35,7 +35,8 @@ tests_require = [ 'pytest>=3.1.1', 'pytest-asyncio==0.6.0', 'asynctest==0.10.0', - 'pytest-mock==1.6.2' + 'pytest-mock==1.6.2', + 'async-generator==1.8' ] extras_require = { diff --git a/ssb/muxrpc.py b/ssb/muxrpc.py index e70584f..1dc0055 100644 --- a/ssb/muxrpc.py +++ b/ssb/muxrpc.py @@ -1,5 +1,7 @@ from functools import wraps +from async_generator import async_generator, yield_ + from ssb.packet_stream import PSMessageType @@ -28,11 +30,12 @@ class MuxRPCSourceHandler(MuxRPCHandler): def __init__(self, ps_handler): self.ps_handler = ps_handler + @async_generator async def __aiter__(self): async for msg in self.ps_handler: try: self.check_message(msg) - yield msg + await yield_(msg) except MuxRPCAPIException: raise diff --git a/ssb/packet_stream.py b/ssb/packet_stream.py index fd415dd..1829616 100644 --- a/ssb/packet_stream.py +++ b/ssb/packet_stream.py @@ -4,9 +4,10 @@ from asyncio import Event, Queue from enum import Enum from time import time -from secret_handshake import SHSClient, SHSServer - import simplejson +from async_generator import async_generator, yield_ + +from secret_handshake import SHSClient, SHSServer logger = logging.getLogger('packet_stream') @@ -30,12 +31,13 @@ class PSStreamHandler(object): async def stop(self): await self.queue.put(None) + @async_generator async def __aiter__(self): while True: elem = await self.queue.get() if not elem: return - yield elem + await yield_(elem) class PSRequestHandler(object): @@ -154,6 +156,7 @@ class PSConnection(object): def register_handler(self, handler): self._event_map[handler.req] = (time(), handler) + @async_generator async def __aiter__(self): while True: msg = await self.read() @@ -161,7 +164,7 @@ class PSConnection(object): return # filter out replies if msg.req >= 0: - yield msg + await yield_(msg) def _write(self, msg): logger.info('SEND [%d]: %r', msg.req, msg) diff --git a/ssb/tests/test_packet_stream.py b/ssb/tests/test_packet_stream.py index c6c6fb0..fe5ae54 100644 --- a/ssb/tests/test_packet_stream.py +++ b/ssb/tests/test_packet_stream.py @@ -1,5 +1,7 @@ -import pytest +import json from asyncio import ensure_future, gather, Event + +import pytest from asynctest import patch from nacl.signing import SigningKey