test: Bring the MockSHSSocket testing class in line with the secret-handshake library
This commit is contained in:
		@@ -25,7 +25,7 @@
 | 
				
			|||||||
from asyncio import Event, ensure_future, gather
 | 
					from asyncio import Event, ensure_future, gather
 | 
				
			||||||
from asyncio.events import AbstractEventLoop
 | 
					from asyncio.events import AbstractEventLoop
 | 
				
			||||||
import json
 | 
					import json
 | 
				
			||||||
from typing import Any, AsyncIterator, Awaitable, Callable, Generator, List, Optional
 | 
					from typing import AsyncIterator, Awaitable, Callable, Generator, List, Optional
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import pytest
 | 
					import pytest
 | 
				
			||||||
from pytest_mock import MockerFixture
 | 
					from pytest_mock import MockerFixture
 | 
				
			||||||
@@ -66,24 +66,24 @@ MSG_BODY_2 = (
 | 
				
			|||||||
class MockSHSSocket(SHSDuplexStream):
 | 
					class MockSHSSocket(SHSDuplexStream):
 | 
				
			||||||
    """A mocked SHS socket"""
 | 
					    """A mocked SHS socket"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def __init__(self, *args: Any, **kwargs: Any):  # pylint: disable=unused-argument
 | 
					    def __init__(self) -> None:
 | 
				
			||||||
        super().__init__()
 | 
					        super().__init__()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        self.input: List[bytes] = []
 | 
					        self.input: List[bytes] = []
 | 
				
			||||||
        self.output: List[bytes] = []
 | 
					        self.output: List[bytes] = []
 | 
				
			||||||
        self.is_connected: bool = False
 | 
					        self.is_connected: bool = False
 | 
				
			||||||
        self._on_connect: List[Callable[[], Awaitable[None]]] = []
 | 
					        self._on_connect: List[Callable[[SHSDuplexStream], Awaitable[None]]] = []
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def on_connect(self, cb: Callable[[], Awaitable[None]]) -> None:
 | 
					    def on_connect(self, cb: Callable[[SHSDuplexStream], Awaitable[None]]) -> None:
 | 
				
			||||||
        """Set the on_connect callback"""
 | 
					        """Set the on_connect callback"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        self._on_connect.append(cb)
 | 
					        self._on_connect.append(cb)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    async def read(self) -> Optional[bytes]:
 | 
					    async def read(self) -> bytes:
 | 
				
			||||||
        """Read data from the socket"""
 | 
					        """Read data from the socket"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if not self.input:
 | 
					        if not self.input:
 | 
				
			||||||
            return None
 | 
					            raise StopAsyncIteration()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        return self.input.pop(0)
 | 
					        return self.input.pop(0)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -125,7 +125,7 @@ class MockSHSServer(MockSHSSocket):
 | 
				
			|||||||
        self.is_connected = True
 | 
					        self.is_connected = True
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        for cb in self._on_connect:
 | 
					        for cb in self._on_connect:
 | 
				
			||||||
            ensure_future(cb())
 | 
					            ensure_future(cb(self))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@pytest.fixture
 | 
					@pytest.fixture
 | 
				
			||||||
@@ -148,7 +148,7 @@ async def test_on_connect(ps_server: MockSHSServer) -> None:  # pylint: disable=
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    called = Event()
 | 
					    called = Event()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    async def _on_connect() -> None:
 | 
					    async def _on_connect(_: SHSDuplexStream) -> None:
 | 
				
			||||||
        called.set()
 | 
					        called.set()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    ps_server.on_connect(_on_connect)
 | 
					    ps_server.on_connect(_on_connect)
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user