test: Bring the MockSHSSocket testing class in line with the secret-handshake library
This commit is contained in:
		@@ -25,7 +25,7 @@
 | 
			
		||||
from asyncio import Event, ensure_future, gather
 | 
			
		||||
from asyncio.events import AbstractEventLoop
 | 
			
		||||
import json
 | 
			
		||||
from typing import Any, AsyncIterator, Awaitable, Callable, Generator, List, Optional
 | 
			
		||||
from typing import AsyncIterator, Awaitable, Callable, Generator, List, Optional
 | 
			
		||||
 | 
			
		||||
import pytest
 | 
			
		||||
from pytest_mock import MockerFixture
 | 
			
		||||
@@ -66,24 +66,24 @@ MSG_BODY_2 = (
 | 
			
		||||
class MockSHSSocket(SHSDuplexStream):
 | 
			
		||||
    """A mocked SHS socket"""
 | 
			
		||||
 | 
			
		||||
    def __init__(self, *args: Any, **kwargs: Any):  # pylint: disable=unused-argument
 | 
			
		||||
    def __init__(self) -> None:
 | 
			
		||||
        super().__init__()
 | 
			
		||||
 | 
			
		||||
        self.input: List[bytes] = []
 | 
			
		||||
        self.output: List[bytes] = []
 | 
			
		||||
        self.is_connected: bool = False
 | 
			
		||||
        self._on_connect: List[Callable[[], Awaitable[None]]] = []
 | 
			
		||||
        self._on_connect: List[Callable[[SHSDuplexStream], Awaitable[None]]] = []
 | 
			
		||||
 | 
			
		||||
    def on_connect(self, cb: Callable[[], Awaitable[None]]) -> None:
 | 
			
		||||
    def on_connect(self, cb: Callable[[SHSDuplexStream], Awaitable[None]]) -> None:
 | 
			
		||||
        """Set the on_connect callback"""
 | 
			
		||||
 | 
			
		||||
        self._on_connect.append(cb)
 | 
			
		||||
 | 
			
		||||
    async def read(self) -> Optional[bytes]:
 | 
			
		||||
    async def read(self) -> bytes:
 | 
			
		||||
        """Read data from the socket"""
 | 
			
		||||
 | 
			
		||||
        if not self.input:
 | 
			
		||||
            return None
 | 
			
		||||
            raise StopAsyncIteration()
 | 
			
		||||
 | 
			
		||||
        return self.input.pop(0)
 | 
			
		||||
 | 
			
		||||
@@ -125,7 +125,7 @@ class MockSHSServer(MockSHSSocket):
 | 
			
		||||
        self.is_connected = True
 | 
			
		||||
 | 
			
		||||
        for cb in self._on_connect:
 | 
			
		||||
            ensure_future(cb())
 | 
			
		||||
            ensure_future(cb(self))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.fixture
 | 
			
		||||
@@ -148,7 +148,7 @@ async def test_on_connect(ps_server: MockSHSServer) -> None:  # pylint: disable=
 | 
			
		||||
 | 
			
		||||
    called = Event()
 | 
			
		||||
 | 
			
		||||
    async def _on_connect() -> None:
 | 
			
		||||
    async def _on_connect(_: SHSDuplexStream) -> None:
 | 
			
		||||
        called.set()
 | 
			
		||||
 | 
			
		||||
    ps_server.on_connect(_on_connect)
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user