diff --git a/tests/test_packet_stream.py b/tests/test_packet_stream.py index 98658d9..256626a 100644 --- a/tests/test_packet_stream.py +++ b/tests/test_packet_stream.py @@ -25,7 +25,7 @@ from asyncio import Event, ensure_future, gather from asyncio.events import AbstractEventLoop import json -from typing import Any, AsyncIterator, Awaitable, Callable, Generator, List, Optional +from typing import AsyncIterator, Awaitable, Callable, Generator, List, Optional import pytest from pytest_mock import MockerFixture @@ -66,24 +66,24 @@ MSG_BODY_2 = ( class MockSHSSocket(SHSDuplexStream): """A mocked SHS socket""" - def __init__(self, *args: Any, **kwargs: Any): # pylint: disable=unused-argument + def __init__(self) -> None: super().__init__() self.input: List[bytes] = [] self.output: List[bytes] = [] self.is_connected: bool = False - self._on_connect: List[Callable[[], Awaitable[None]]] = [] + self._on_connect: List[Callable[[SHSDuplexStream], Awaitable[None]]] = [] - def on_connect(self, cb: Callable[[], Awaitable[None]]) -> None: + def on_connect(self, cb: Callable[[SHSDuplexStream], Awaitable[None]]) -> None: """Set the on_connect callback""" self._on_connect.append(cb) - async def read(self) -> Optional[bytes]: + async def read(self) -> bytes: """Read data from the socket""" if not self.input: - return None + raise StopAsyncIteration() return self.input.pop(0) @@ -125,7 +125,7 @@ class MockSHSServer(MockSHSSocket): self.is_connected = True for cb in self._on_connect: - ensure_future(cb()) + ensure_future(cb(self)) @pytest.fixture @@ -148,7 +148,7 @@ async def test_on_connect(ps_server: MockSHSServer) -> None: # pylint: disable= called = Event() - async def _on_connect() -> None: + async def _on_connect(_: SHSDuplexStream) -> None: called.set() ps_server.on_connect(_on_connect)