"""Packet streams""" from asyncio import Event, Queue from enum import Enum import logging from math import ceil import struct from time import time from typing import Any, AsyncIterator, Dict, Optional, Tuple, Union from secret_handshake.network import SHSDuplexStream import simplejson from typing_extensions import Self PSHandler = Union["PSRequestHandler", "PSStreamHandler"] PSMessageData = Union[bytes, bool, Dict[str, Any], str] logger = logging.getLogger("packet_stream") class PSMessageType(Enum): """Available message types""" BUFFER = 0 TEXT = 1 JSON = 2 class PSStreamHandler: """Packet stream handler""" def __init__(self, req: int): self.req = req self.queue: Queue[Optional["PSMessage"]] = Queue() async def process(self, msg: "PSMessage") -> None: """Process a pending message""" await self.queue.put(msg) async def stop(self) -> None: """Stop a pending request""" await self.queue.put(None) def __aiter__(self) -> AsyncIterator[Optional["PSMessage"]]: return self async def __anext__(self) -> Optional["PSMessage"]: elem = await self.queue.get() if not elem: raise StopAsyncIteration() return elem class PSRequestHandler: """Packet stream request handler""" def __init__(self, req: int): self.req = req self.event = Event() self._msg: Optional[PSMessage] = None async def process(self, msg: "PSMessage") -> None: """Process a message request""" self._msg = msg self.event.set() async def stop(self) -> None: """Stop a pending event request""" if not self.event.is_set(): self.event.set() def __aiter__(self): return self async def __anext__(self) -> Optional["PSMessage"]: # wait until 'process' is called await self.event.wait() return self._msg class PSMessage: """Packet Stream message""" @classmethod def from_header_body(cls, flags: int, req: int, body: bytes) -> Self: """Parse a raw message""" type_ = PSMessageType(flags & 0x03) if type_ == PSMessageType.TEXT: body_s = body.decode("utf-8") elif type_ == PSMessageType.JSON: body_s = simplejson.loads(body) return cls(type_, body_s, bool(flags & 0x08), bool(flags & 0x04), req=req) @property def data(self) -> bytes: """The raw message data""" if self.type == PSMessageType.TEXT: assert isinstance(self.body, str) return self.body.encode("utf-8") if self.type == PSMessageType.JSON: return simplejson.dumps(self.body).encode("utf-8") assert isinstance(self.body, bytes) return self.body def __init__( self, type_: PSMessageType, body: Any, stream: bool, end_err: bool, req: Optional[int] = None ): # pylint: disable=too-many-arguments self.stream = stream self.end_err = end_err self.type = type_ self.body = body self.req = req def __repr__(self) -> str: if self.type == PSMessageType.BUFFER: body = f"{len(self.body)} bytes" else: body = self.body req = "" if self.req is None else f" [{self.req}]" is_stream = "~" if self.stream else "" err = "!" if self.end_err else "" return f"" class PacketStream: """SSB Packet stream""" def __init__(self, connection: SHSDuplexStream): self.connection = connection self.req_counter = 1 self._event_map: Dict[int, Tuple[float, PSHandler]] = {} self._connected = False def register_handler(self, handler: PSHandler) -> None: """Register an RPC handler""" self._event_map[handler.req] = (time(), handler) @property def is_connected(self) -> bool: """Check if the stream is connected""" return self.connection.is_connected def __aiter__(self) -> AsyncIterator[Optional[PSMessage]]: return self async def __anext__(self) -> Optional[PSMessage]: msg = await self.read() if not msg: raise StopAsyncIteration() if msg.req is not None and msg.req >= 0: return msg return None async def __await__(self) -> None: async for data in self: logger.info("RECV: %r", data) if data is None: return async def _read(self) -> Optional[PSMessage]: try: header = await self.connection.read() if not header or header == b"\x00" * 9: return None flags, length, req = struct.unpack(">BIi", header) n_packets = ceil(length / 4096) body = b"" for _ in range(n_packets): read_data = await self.connection.read() if read_data is not None: body += read_data logger.debug("READ %s %s", header, len(body)) return PSMessage.from_header_body(flags, req, body) except StopAsyncIteration: logger.debug("DISCONNECT") self.connection.disconnect() return None async def read(self) -> Optional[PSMessage]: """Read data from the packet stream""" msg = await self._read() if not msg: return None # check whether it's a reply and handle accordingly if msg.req is not None and msg.req < 0: _, handler = self._event_map[-msg.req] await handler.process(msg) logger.info("RESPONSE [%d]: %r", -msg.req, msg) if msg.end_err: await handler.stop() del self._event_map[-msg.req] logger.info("RESPONSE [%d]: EOS", -msg.req) return msg def _write(self, msg: PSMessage) -> None: logger.info("SEND [%d]: %r", msg.req, msg) header = struct.pack( ">BIi", (int(msg.stream) << 3) | (int(msg.end_err) << 2) | msg.type.value, len(msg.data), msg.req ) self.connection.write(header) self.connection.write(msg.data) logger.debug("WRITE HDR: %s", header) logger.debug("WRITE DATA: %s", msg.data) def send( # pylint: disable=too-many-arguments self, data: Any, msg_type: PSMessageType = PSMessageType.JSON, stream: bool = False, end_err: bool = False, req: Optional[int] = None, ) -> PSHandler: """Send data through the packet stream""" update_counter = False if req is None: update_counter = True req = self.req_counter msg = PSMessage(msg_type, data, stream=stream, end_err=end_err, req=req) # send request self._write(msg) if stream: handler: PSHandler = PSStreamHandler(self.req_counter) else: handler = PSRequestHandler(self.req_counter) self.register_handler(handler) if update_counter: self.req_counter += 1 return handler def disconnect(self) -> None: """Disconnect the stream""" self._connected = False self.connection.disconnect()