diff --git a/examples/test_client.py b/examples/test_client.py index da5cad2..e050c3a 100644 --- a/examples/test_client.py +++ b/examples/test_client.py @@ -59,8 +59,11 @@ async def test_client() -> None: img_data = b"" async for msg in api.call("blobs.get", ["&kqZ52sDcJSHOx7m4Ww80kK1KIZ65gpGnqwZlfaIVWWM=.sha256"], "source"): + assert msg + if msg.type.name == "BUFFER": img_data += msg.data + if msg.type.name == "JSON" and msg.data == b"true": assert ( base64.b64encode(hashlib.sha256(img_data).digest()) == b"kqZ52sDcJSHOx7m4Ww80kK1KIZ65gpGnqwZlfaIVWWM=" diff --git a/ssb/muxrpc.py b/ssb/muxrpc.py index 1f5f83b..7bb9b20 100644 --- a/ssb/muxrpc.py +++ b/ssb/muxrpc.py @@ -27,13 +27,13 @@ class MuxRPCHandler: # pylint: disable=too-few-public-methods if isinstance(body, dict) and "name" in body and body["name"] == "Error": raise MuxRPCAPIException(body["message"]) - def __await__(self): + def __await__(self) -> Generator[Optional[PSMessage], None, None]: raise NotImplementedError() - def __aiter__(self): + def __aiter__(self) -> AsyncIterator[Optional[PSMessage]]: raise NotImplementedError() - async def __anext__(self): + async def __anext__(self) -> Optional[PSMessage]: raise NotImplementedError() def send(self, msg: Any, msg_type: PSMessageType = PSMessageType.JSON, end: bool = False) -> None: @@ -48,8 +48,14 @@ class MuxRPCRequestHandler(MuxRPCHandler): # pylint: disable=abstract-method def __init__(self, ps_handler: PSRequestHandler): self.ps_handler = ps_handler - def __await__(self): - msg = yield from self.ps_handler.__await__() + def __aiter__(self) -> AsyncIterator[Optional[PSMessage]]: + return self + + async def __anext__(self) -> Optional[PSMessage]: + msg = await self.ps_handler.__anext__() + + assert msg + self.check_message(msg) return msg @@ -77,12 +83,14 @@ class MuxRPCSourceHandler(MuxRPCHandler): # pylint: disable=abstract-method class MuxRPCSinkHandlerMixin: # pylint: disable=too-few-public-methods """Mixin for sink-type MuxRPC handlers""" - connection: PacketStream - req: int + connection: Optional[PacketStream] + req: Optional[int] def send(self, msg: Any, msg_type: PSMessageType = PSMessageType.JSON, end: bool = False) -> None: """Send a message through the stream""" + assert self.connection + self.connection.send(msg, stream=True, msg_type=msg_type, req=self.req, end_err=end) @@ -155,7 +163,7 @@ class MuxRPCMessage: return cls(message.body) - def __init__(self, body): + def __init__(self, body: PSMessage): self.body = body def __repr__(self) -> str: diff --git a/ssb/packet_stream.py b/ssb/packet_stream.py index b1bcaff..bdb21d1 100644 --- a/ssb/packet_stream.py +++ b/ssb/packet_stream.py @@ -27,10 +27,9 @@ class PSMessageType(Enum): class PSStreamHandler: """Packet stream handler""" - def __init__(self, req): - super(PSStreamHandler).__init__() + def __init__(self, req: int): self.req = req - self.queue = Queue() + self.queue: Queue[Optional["PSMessage"]] = Queue() async def process(self, msg): """Process a pending message""" @@ -42,7 +41,7 @@ class PSStreamHandler: await self.queue.put(None) - def __aiter__(self): + def __aiter__(self) -> AsyncIterator[Optional["PSMessage"]]: return self async def __anext__(self) -> Optional["PSMessage"]: @@ -57,7 +56,7 @@ class PSStreamHandler: class PSRequestHandler: """Packet stream request handler""" - def __init__(self, req): + def __init__(self, req: int): self.req = req self.event = Event() self._msg = None @@ -74,7 +73,10 @@ class PSRequestHandler: if not self.event.is_set(): self.event.set() - async def __await__(self): + def __aiter__(self): + return self + + async def __anext__(self) -> Optional["PSMessage"]: # wait until 'process' is called await self.event.wait()