ci: Lint source

This commit is contained in:
2023-10-29 09:55:39 +01:00
parent 53994b77a7
commit d28ca167f2
14 changed files with 267 additions and 65 deletions

View File

@@ -18,6 +18,7 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
"""Networking functionality"""
import asyncio
@@ -26,22 +27,30 @@ from .crypto import SHSClientCrypto, SHSServerCrypto
class SHSClientException(Exception):
pass
"""Base exception class for client errors"""
class SHSDuplexStream(object):
class SHSDuplexStream:
"""SHS duplex stream"""
def __init__(self):
self.write_stream = None
self.read_stream = None
self.is_connected = False
def write(self, data):
"""Write data to the write stream"""
self.write_stream.write(data)
async def read(self):
"""Read data from the read stream"""
return await self.read_stream.read()
def close(self):
"""Close the duplex stream"""
self.write_stream.close()
self.read_stream.close()
self.is_connected = False
@@ -58,21 +67,27 @@ class SHSDuplexStream(object):
return msg
class SHSEndpoint(object):
class SHSEndpoint:
"""SHS endpoint"""
def __init__(self):
self._on_connect = None
self.crypto = None
def on_connect(self, cb):
"""Set the function to be called when a new connection arrives"""
self._on_connect = cb
def disconnect(self):
"""Disconnect the endpoint"""
raise NotImplementedError
class SHSServer(SHSEndpoint):
"""SHS server"""
def __init__(self, host, port, server_kp, application_key=None):
super(SHSServer, self).__init__()
super().__init__()
self.host = host
self.port = port
self.crypto = SHSServerCrypto(server_kp, application_key=application_key)
@@ -92,6 +107,8 @@ class SHSServer(SHSEndpoint):
writer.write(self.crypto.generate_accept())
async def handle_connection(self, reader, writer):
"""Handle incoming connections"""
self.crypto.clean()
await self._handshake(reader, writer)
keys = self.crypto.get_box_keys()
@@ -104,6 +121,8 @@ class SHSServer(SHSEndpoint):
asyncio.ensure_future(self._on_connect(conn))
async def listen(self):
"""Listen for connections"""
await asyncio.start_server(self.handle_connection, self.host, self.port)
def disconnect(self):
@@ -112,23 +131,33 @@ class SHSServer(SHSEndpoint):
class SHSServerConnection(SHSDuplexStream):
"""SHS server connection"""
def __init__(self, read_stream, write_stream):
super(SHSServerConnection, self).__init__()
super().__init__()
self.read_stream = read_stream
self.write_stream = write_stream
@classmethod
def from_byte_streams(cls, reader, writer, **keys):
"""Create a server connection from an existing byte stream"""
reader, writer = get_stream_pair(reader, writer, **keys)
return cls(reader, writer)
class SHSClient(SHSDuplexStream, SHSEndpoint):
def __init__(self, host, port, client_kp, server_pub_key, ephemeral_key=None, application_key=None):
"""SHS client"""
def __init__( # pylint: disable=too-many-arguments
self, host, port, client_kp, server_pub_key, ephemeral_key=None, application_key=None
):
SHSDuplexStream.__init__(self)
SHSEndpoint.__init__(self)
self.host = host
self.port = port
self.writer = None
self.crypto = SHSClientCrypto(
client_kp, server_pub_key, ephemeral_key=ephemeral_key, application_key=application_key
)
@@ -147,6 +176,8 @@ class SHSClient(SHSDuplexStream, SHSEndpoint):
raise SHSClientException("Server accept is not valid")
async def open(self):
"""Open the TCP connection"""
reader, writer = await asyncio.open_connection(self.host, self.port)
await self._handshake(reader, writer)
@@ -156,6 +187,7 @@ class SHSClient(SHSDuplexStream, SHSEndpoint):
self.read_stream, self.write_stream = get_stream_pair(reader, writer, **keys)
self.writer = writer
self.is_connected = True
if self._on_connect:
await self._on_connect()