test: Bring the MockSHSSocket testing class in line with the secret-handshake library
This commit is contained in:
parent
377368509e
commit
11d09b76df
@ -25,7 +25,7 @@
|
|||||||
from asyncio import Event, ensure_future, gather
|
from asyncio import Event, ensure_future, gather
|
||||||
from asyncio.events import AbstractEventLoop
|
from asyncio.events import AbstractEventLoop
|
||||||
import json
|
import json
|
||||||
from typing import Any, AsyncIterator, Awaitable, Callable, Generator, List, Optional
|
from typing import AsyncIterator, Awaitable, Callable, Generator, List, Optional
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from pytest_mock import MockerFixture
|
from pytest_mock import MockerFixture
|
||||||
@ -66,24 +66,24 @@ MSG_BODY_2 = (
|
|||||||
class MockSHSSocket(SHSDuplexStream):
|
class MockSHSSocket(SHSDuplexStream):
|
||||||
"""A mocked SHS socket"""
|
"""A mocked SHS socket"""
|
||||||
|
|
||||||
def __init__(self, *args: Any, **kwargs: Any): # pylint: disable=unused-argument
|
def __init__(self) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.input: List[bytes] = []
|
self.input: List[bytes] = []
|
||||||
self.output: List[bytes] = []
|
self.output: List[bytes] = []
|
||||||
self.is_connected: bool = False
|
self.is_connected: bool = False
|
||||||
self._on_connect: List[Callable[[], Awaitable[None]]] = []
|
self._on_connect: List[Callable[[SHSDuplexStream], Awaitable[None]]] = []
|
||||||
|
|
||||||
def on_connect(self, cb: Callable[[], Awaitable[None]]) -> None:
|
def on_connect(self, cb: Callable[[SHSDuplexStream], Awaitable[None]]) -> None:
|
||||||
"""Set the on_connect callback"""
|
"""Set the on_connect callback"""
|
||||||
|
|
||||||
self._on_connect.append(cb)
|
self._on_connect.append(cb)
|
||||||
|
|
||||||
async def read(self) -> Optional[bytes]:
|
async def read(self) -> bytes:
|
||||||
"""Read data from the socket"""
|
"""Read data from the socket"""
|
||||||
|
|
||||||
if not self.input:
|
if not self.input:
|
||||||
return None
|
raise StopAsyncIteration()
|
||||||
|
|
||||||
return self.input.pop(0)
|
return self.input.pop(0)
|
||||||
|
|
||||||
@ -125,7 +125,7 @@ class MockSHSServer(MockSHSSocket):
|
|||||||
self.is_connected = True
|
self.is_connected = True
|
||||||
|
|
||||||
for cb in self._on_connect:
|
for cb in self._on_connect:
|
||||||
ensure_future(cb())
|
ensure_future(cb(self))
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@ -148,7 +148,7 @@ async def test_on_connect(ps_server: MockSHSServer) -> None: # pylint: disable=
|
|||||||
|
|
||||||
called = Event()
|
called = Event()
|
||||||
|
|
||||||
async def _on_connect() -> None:
|
async def _on_connect(_: SHSDuplexStream) -> None:
|
||||||
called.set()
|
called.set()
|
||||||
|
|
||||||
ps_server.on_connect(_on_connect)
|
ps_server.on_connect(_on_connect)
|
||||||
|
Loading…
Reference in New Issue
Block a user