test: Bring the MockSHSSocket testing class in line with the secret-handshake library

This commit is contained in:
Gergely Polonkai 2023-11-19 07:41:48 +01:00
parent 377368509e
commit 11d09b76df
No known key found for this signature in database
GPG Key ID: 2D2885533B869ED4
1 changed files with 8 additions and 8 deletions

View File

@ -25,7 +25,7 @@
from asyncio import Event, ensure_future, gather
from asyncio.events import AbstractEventLoop
import json
from typing import Any, AsyncIterator, Awaitable, Callable, Generator, List, Optional
from typing import AsyncIterator, Awaitable, Callable, Generator, List, Optional
import pytest
from pytest_mock import MockerFixture
@ -66,24 +66,24 @@ MSG_BODY_2 = (
class MockSHSSocket(SHSDuplexStream):
"""A mocked SHS socket"""
def __init__(self, *args: Any, **kwargs: Any): # pylint: disable=unused-argument
def __init__(self) -> None:
super().__init__()
self.input: List[bytes] = []
self.output: List[bytes] = []
self.is_connected: bool = False
self._on_connect: List[Callable[[], Awaitable[None]]] = []
self._on_connect: List[Callable[[SHSDuplexStream], Awaitable[None]]] = []
def on_connect(self, cb: Callable[[], Awaitable[None]]) -> None:
def on_connect(self, cb: Callable[[SHSDuplexStream], Awaitable[None]]) -> None:
"""Set the on_connect callback"""
self._on_connect.append(cb)
async def read(self) -> Optional[bytes]:
async def read(self) -> bytes:
"""Read data from the socket"""
if not self.input:
return None
raise StopAsyncIteration()
return self.input.pop(0)
@ -125,7 +125,7 @@ class MockSHSServer(MockSHSSocket):
self.is_connected = True
for cb in self._on_connect:
ensure_future(cb())
ensure_future(cb(self))
@pytest.fixture
@ -148,7 +148,7 @@ async def test_on_connect(ps_server: MockSHSServer) -> None: # pylint: disable=
called = Event()
async def _on_connect() -> None:
async def _on_connect(_: SHSDuplexStream) -> None:
called.set()
ps_server.on_connect(_on_connect)