pyssb/ssb/packet_stream.py

279 lines
7.2 KiB
Python

"""Packet streams"""
from asyncio import Event, Queue
from enum import Enum
import logging
from math import ceil
import struct
from time import time
from typing import Any, AsyncIterator, Dict, Generator, Optional, Tuple, Union

from secret_handshake.network import SHSDuplexStream
import simplejson
from typing_extensions import Self
logger = logging.getLogger("packet_stream")

# Either handler kind that PacketStream.send can hand back to a caller.
PSHandler = Union["PSRequestHandler", "PSStreamHandler"]
# Presumably the set of decoded message-body value types (not referenced in the visible code).
PSMessageData = Union[bytes, bool, Dict[str, Any], str]
class PSMessageType(Enum):
    """Body encodings a packet-stream message can carry.

    The value of each member is what the two low bits of a packet header's flag
    byte (``flags & 0x03``) hold for that encoding.
    """

    BUFFER = 0  # body kept as raw bytes
    TEXT = 1  # body is UTF-8 text
    JSON = 2  # body is a JSON document
class PSStreamHandler:
    """Packet stream handler for streaming requests.

    Incoming messages are queued by :meth:`process` and consumed with
    ``async for``; :meth:`stop` enqueues a ``None`` sentinel, and iteration ends
    once everything queued before it has been delivered.
    """

    def __init__(self, req: int):
        self.req = req
        # ``None`` is the end-of-stream sentinel pushed by ``stop``.
        self.queue: Queue[Optional["PSMessage"]] = Queue()

    async def process(self, msg: "PSMessage") -> None:
        """Queue an incoming message for the consumer"""
        await self.queue.put(msg)

    async def stop(self) -> None:
        """Stop a pending request by enqueuing the end-of-stream sentinel"""
        await self.queue.put(None)

    def __aiter__(self) -> AsyncIterator["PSMessage"]:
        return self

    async def __anext__(self) -> "PSMessage":
        """Return the next queued message; raise StopAsyncIteration at the sentinel."""
        elem = await self.queue.get()
        # Compare against the sentinel explicitly instead of relying on truthiness,
        # so a message object that happens to be falsy cannot end the stream early.
        if elem is None:
            raise StopAsyncIteration()
        return elem
class PSRequestHandler:
    """Handler for a request that expects a single (non-stream) response.

    ``process`` stores the response and releases waiters; ``stop`` only releases
    waiters, so a waiter that never saw a response receives ``None``.
    """

    def __init__(self, req: int):
        self.req = req
        # Set once a response has arrived or the request has been stopped.
        self.event = Event()
        self._msg: Optional["PSMessage"] = None

    async def process(self, msg: "PSMessage") -> None:
        """Record the response and wake whoever is waiting for it"""
        self._msg = msg
        self.event.set()

    async def stop(self) -> None:
        """Wake any waiter without recording a response; no-op if already woken"""
        if self.event.is_set():
            return
        self.event.set()

    def __aiter__(self):
        return self

    async def __anext__(self) -> Optional["PSMessage"]:
        """Wait for ``process`` or ``stop`` and hand back the stored response.

        NOTE(review): this never raises StopAsyncIteration, so an ``async for``
        over the handler keeps receiving the same response and does not end.
        """
        await self.event.wait()
        response = self._msg
        return response
class PSMessage:
    """Packet Stream message.

    Holds a decoded body together with the header flags needed to encode it
    again for the wire.
    """

    @classmethod
    def from_header_body(cls, flags: int, req: int, body: bytes) -> Self:
        """Parse a raw message.

        :param flags: header flag byte; ``flags & 0x03`` selects the body type,
            bit ``0x08`` is the stream flag and bit ``0x04`` the end/error flag
        :param req: request number taken from the header
        :param body: undecoded body bytes
        :returns: the decoded message (BUFFER bodies stay raw ``bytes``)
        :raises ValueError: if the type bits do not name a :class:`PSMessageType`
        :raises UnicodeDecodeError: if a TEXT body is not valid UTF-8
        """
        type_ = PSMessageType(flags & 0x03)
        body_s: Any
        if type_ == PSMessageType.TEXT:
            body_s = body.decode("utf-8")
        elif type_ == PSMessageType.JSON:
            body_s = simplejson.loads(body)
        else:
            # BUFFER bodies are passed through untouched. Without this branch
            # ``body_s`` was never bound and every binary message raised
            # UnboundLocalError.
            body_s = body
        return cls(type_, body_s, bool(flags & 0x08), bool(flags & 0x04), req=req)

    @property
    def data(self) -> bytes:
        """The raw message data: the body encoded for the wire.

        TEXT bodies are UTF-8 encoded and JSON bodies are serialised on every
        access; BUFFER bodies are returned as-is.
        """
        if self.type == PSMessageType.TEXT:
            assert isinstance(self.body, str)
            return self.body.encode("utf-8")
        if self.type == PSMessageType.JSON:
            return simplejson.dumps(self.body).encode("utf-8")
        assert isinstance(self.body, bytes)
        return self.body

    def __init__(
        self, type_: PSMessageType, body: Any, stream: bool, end_err: bool, req: Optional[int] = None
    ):  # pylint: disable=too-many-arguments
        """Create a message.

        :param type_: how ``body`` is encoded on the wire
        :param body: decoded body (``str`` for TEXT, ``bytes`` for BUFFER, a
            JSON-serialisable value for JSON)
        :param stream: value of the header's stream flag
        :param end_err: value of the header's end/error flag
        :param req: request number, or ``None`` if not yet assigned
        """
        self.stream = stream
        self.end_err = end_err
        self.type = type_
        self.body = body
        self.req = req

    def __repr__(self) -> str:
        if self.type == PSMessageType.BUFFER:
            body = f"{len(self.body)} bytes"
        else:
            body = self.body
        req = "" if self.req is None else f" [{self.req}]"
        is_stream = "~" if self.stream else ""
        err = "!" if self.end_err else ""
        return f"<PSMessage ({self.type.name}): {body}{req} {is_stream}{err}>"
class PacketStream:
    """SSB Packet stream multiplexed over an SHS duplex connection.

    ``send`` numbers outgoing requests with increasing positive request numbers.
    Incoming messages with a negative request number are treated as responses
    and dispatched to the handler registered under the matching positive number.
    """

    def __init__(self, connection: SHSDuplexStream):
        self.connection = connection
        # Next request number handed out by ``send`` when no explicit ``req`` is given.
        self.req_counter = 1
        # request number -> (registration time, handler awaiting its responses)
        self._event_map: Dict[int, Tuple[float, PSHandler]] = {}
        self._connected = False

    def register_handler(self, handler: PSHandler) -> None:
        """Register an RPC handler under its request number, replacing any previous one"""
        self._event_map[handler.req] = (time(), handler)

    @property
    def is_connected(self) -> bool:
        """Check if the stream is connected (delegates to the underlying connection)"""
        return self.connection.is_connected

    def __aiter__(self) -> AsyncIterator[Optional[PSMessage]]:
        return self

    async def __anext__(self) -> Optional[PSMessage]:
        """Return the next incoming request, or ``None`` for a response.

        Responses (negative request numbers) have already been dispatched to
        their handlers by :meth:`read`, so they surface here as ``None``.
        Iteration ends when :meth:`read` returns ``None``.
        """
        msg = await self.read()
        if msg is None:
            raise StopAsyncIteration()
        if msg.req is not None and msg.req >= 0:
            return msg
        return None

    def __await__(self) -> Generator[Any, None, None]:
        """Support ``await stream``: consume and log messages until a ``None`` arrives.

        ``__await__`` must return an iterator. The previous ``async def`` version
        returned a coroutine, which made ``await stream`` raise TypeError.
        """
        return self._consume().__await__()

    async def _consume(self) -> None:
        """Log received messages; stop at end of stream or on the first ``None``."""
        async for data in self:
            logger.info("RECV: %r", data)
            if data is None:
                return

    async def _read(self) -> Optional[PSMessage]:
        """Read one message (a 9-byte header followed by its body) from the connection.

        Returns ``None`` when the connection yields no header, on an all-zero
        header, or when the connection raises StopAsyncIteration (the connection
        is then disconnected).
        """
        try:
            header = await self.connection.read()
            if not header or header == b"\x00" * 9:
                return None
            flags, length, req = struct.unpack(">BIi", header)
            # The body is read as ceil(length / 4096) separate chunks from the connection.
            n_packets = ceil(length / 4096)
            chunks = []
            for _ in range(n_packets):
                read_data = await self.connection.read()
                if read_data is not None:
                    chunks.append(read_data)
            body = b"".join(chunks)
            logger.debug("READ %s %s", header, len(body))
            return PSMessage.from_header_body(flags, req, body)
        except StopAsyncIteration:
            logger.debug("DISCONNECT")
            self.connection.disconnect()
            return None

    async def read(self) -> Optional[PSMessage]:
        """Read data from the packet stream, dispatching responses to their handlers.

        A message with a negative request number answers our request ``-req``.
        It is passed to the registered handler. If it carries the end/error
        flag, the handler is also stopped and unregistered. A response for an
        unknown request number is logged and returned without being dispatched.
        """
        msg = await self._read()
        if msg is None:
            return None
        if msg.req is None or msg.req >= 0:
            return msg

        req = -msg.req
        entry = self._event_map.get(req)
        if entry is None:
            # A peer may answer a request we never made or already closed; don't
            # let that abort the read loop with a KeyError.
            logger.warning("RESPONSE [%d] for unknown request: %r", req, msg)
            return msg
        _, handler = entry
        await handler.process(msg)
        logger.info("RESPONSE [%d]: %r", req, msg)
        if msg.end_err:
            await handler.stop()
            self._event_map.pop(req, None)
            logger.info("RESPONSE [%d]: EOS", req)
        return msg

    def _write(self, msg: PSMessage) -> None:
        """Encode ``msg`` as a 9-byte header followed by its body and write both."""
        logger.info("SEND [%d]: %r", msg.req, msg)
        # ``msg.data`` re-encodes the body on every access, so compute it once.
        data = msg.data
        flags = (int(msg.stream) << 3) | (int(msg.end_err) << 2) | msg.type.value
        header = struct.pack(">BIi", flags, len(data), msg.req)
        self.connection.write(header)
        self.connection.write(data)
        logger.debug("WRITE HDR: %s", header)
        logger.debug("WRITE DATA: %s", data)

    def send(  # pylint: disable=too-many-arguments
        self,
        data: Any,
        msg_type: PSMessageType = PSMessageType.JSON,
        stream: bool = False,
        end_err: bool = False,
        req: Optional[int] = None,
    ) -> PSHandler:
        """Send data through the packet stream.

        :param data: message body
        :param msg_type: body encoding
        :param stream: set the stream flag (and return a stream handler)
        :param end_err: set the end/error flag
        :param req: request number to send under. When omitted, a new request
            number is allocated and a handler for its responses is registered.
            With an explicit ``req`` nothing is registered; response handling for
            that number is the caller's responsibility.
        :returns: a stream or single-response handler whose ``req`` is the
            request number the message was sent with
        """
        new_request = req is None
        if req is None:
            req = self.req_counter
        msg = PSMessage(msg_type, data, stream=stream, end_err=end_err, req=req)
        # send request
        self._write(msg)
        # Build the handler from the request number actually used for this message.
        # Previously it always used ``self.req_counter``, which was wrong whenever
        # an explicit ``req`` was passed.
        handler: PSHandler = PSStreamHandler(req) if stream else PSRequestHandler(req)
        if new_request:
            self.register_handler(handler)
            self.req_counter += 1
        return handler

    def disconnect(self) -> None:
        """Disconnect the stream"""
        self._connected = False
        self.connection.disconnect()