test: Bring the MockSHSSocket testing class in line with the secret-handshake library

This commit is contained in:
Gergely Polonkai 2023-11-19 07:41:48 +01:00
parent 377368509e
commit 11d09b76df
No known key found for this signature in database
GPG Key ID: 2D2885533B869ED4

View File

@ -25,7 +25,7 @@
from asyncio import Event, ensure_future, gather from asyncio import Event, ensure_future, gather
from asyncio.events import AbstractEventLoop from asyncio.events import AbstractEventLoop
import json import json
from typing import Any, AsyncIterator, Awaitable, Callable, Generator, List, Optional from typing import AsyncIterator, Awaitable, Callable, Generator, List, Optional
import pytest import pytest
from pytest_mock import MockerFixture from pytest_mock import MockerFixture
@ -66,24 +66,24 @@ MSG_BODY_2 = (
class MockSHSSocket(SHSDuplexStream): class MockSHSSocket(SHSDuplexStream):
"""A mocked SHS socket""" """A mocked SHS socket"""
def __init__(self, *args: Any, **kwargs: Any): # pylint: disable=unused-argument def __init__(self) -> None:
super().__init__() super().__init__()
self.input: List[bytes] = [] self.input: List[bytes] = []
self.output: List[bytes] = [] self.output: List[bytes] = []
self.is_connected: bool = False self.is_connected: bool = False
self._on_connect: List[Callable[[], Awaitable[None]]] = [] self._on_connect: List[Callable[[SHSDuplexStream], Awaitable[None]]] = []
def on_connect(self, cb: Callable[[], Awaitable[None]]) -> None: def on_connect(self, cb: Callable[[SHSDuplexStream], Awaitable[None]]) -> None:
"""Set the on_connect callback""" """Set the on_connect callback"""
self._on_connect.append(cb) self._on_connect.append(cb)
async def read(self) -> Optional[bytes]: async def read(self) -> bytes:
"""Read data from the socket""" """Read data from the socket"""
if not self.input: if not self.input:
return None raise StopAsyncIteration()
return self.input.pop(0) return self.input.pop(0)
@ -125,7 +125,7 @@ class MockSHSServer(MockSHSSocket):
self.is_connected = True self.is_connected = True
for cb in self._on_connect: for cb in self._on_connect:
ensure_future(cb()) ensure_future(cb(self))
@pytest.fixture @pytest.fixture
@ -148,7 +148,7 @@ async def test_on_connect(ps_server: MockSHSServer) -> None: # pylint: disable=
called = Event() called = Event()
async def _on_connect() -> None: async def _on_connect(_: SHSDuplexStream) -> None:
called.set() called.set()
ps_server.on_connect(_on_connect) ps_server.on_connect(_on_connect)