PySecretHandshake/secret_handshake/network.py
Gergely Polonkai 83add95c8a
feat: Rename the disconnect method of each SHSEndpoint to close
The two verbs have the same meaning in this context, and this unifies the two interfaces, which helps type checking in
code using this library.
2023-11-13 06:36:51 +01:00

234 lines
7.3 KiB
Python

# SPDX-License-Identifier: MIT
#
# Copyright (c) 2017 PySecretHandshake contributors (see AUTHORS for more details)
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
"""Networking functionality"""
from asyncio import StreamReader, StreamWriter, ensure_future, open_connection, start_server
from typing import AsyncIterator, Awaitable, Callable, List, Optional
from nacl.public import PrivateKey
from nacl.signing import SigningKey
from typing_extensions import Self, deprecated
from .boxstream import BoxStream, UnboxStream, get_stream_pair
from .crypto import SHSClientCrypto, SHSCryptoBase, SHSServerCrypto
class SHSClientException(Exception):
"""Base exception class for client errors"""
class SHSDuplexStream:
"""SHS duplex stream"""
def __init__(self) -> None:
self.write_stream: Optional[BoxStream] = None
self.read_stream: Optional[UnboxStream] = None
self.is_connected = False
def write(self, data: bytes) -> None:
"""Write data to the write stream"""
assert self.write_stream
self.write_stream.write(data)
async def read(self) -> Optional[bytes]:
"""Read data from the read stream"""
assert self.read_stream
return await self.read_stream.read()
def close(self) -> None:
"""Close the duplex stream"""
if self.write_stream:
self.write_stream.close()
self.is_connected = False
def __aiter__(self) -> AsyncIterator[bytes]:
return self
async def __anext__(self) -> bytes:
msg = await self.read()
if msg is None:
raise StopAsyncIteration()
return msg
class SHSEndpoint:
"""SHS endpoint"""
def __init__(self) -> None:
self._on_connect: Optional[Callable[[SHSDuplexStream], Awaitable[None]]] = None
self.crypto: Optional[SHSCryptoBase] = None
def on_connect(self, cb: Callable[[SHSDuplexStream], Awaitable[None]]) -> None:
"""Set the function to be called when a new connection arrives"""
self._on_connect = cb
def close(self) -> None: # pragma: no cover
"""Disconnect the endpoint"""
raise NotImplementedError()
@deprecated("Use close instead")
def disconnect(self) -> None:
"""Disconnect the endpoint"""
self.close()
class SHSServer(SHSEndpoint):
"""SHS server"""
def __init__(self, host: str, port: int, server_kp: SigningKey, application_key: Optional[bytes] = None):
super().__init__()
self.host = host
self.port = port
self.crypto: SHSServerCrypto = SHSServerCrypto(server_kp, application_key=application_key)
self.connections: List[SHSServerConnection] = []
async def _handshake(self, reader: StreamReader, writer: StreamWriter) -> None:
assert self.crypto
data = await reader.readexactly(64)
if not self.crypto.verify_challenge(data):
raise SHSClientException("Client challenge is not valid")
writer.write(self.crypto.generate_challenge())
data = await reader.readexactly(112)
if not self.crypto.verify_client_auth(data):
raise SHSClientException("Client auth is not valid")
writer.write(self.crypto.generate_accept())
async def handle_connection(self, reader: StreamReader, writer: StreamWriter) -> None:
"""Handle incoming connections"""
assert self.crypto
self.crypto.clean()
await self._handshake(reader, writer)
keys = self.crypto.get_box_keys()
self.crypto.clean()
conn = SHSServerConnection.from_byte_streams(reader, writer, **keys)
self.connections.append(conn)
if self._on_connect:
ensure_future(self._on_connect(conn))
async def listen(self) -> None:
"""Listen for connections"""
await start_server(self.handle_connection, self.host, self.port)
def close(self) -> None:
for connection in self.connections:
connection.close()
@deprecated("Use close instead")
def disconnect(self) -> None:
self.close()
class SHSServerConnection(SHSDuplexStream):
"""SHS server connection"""
def __init__(self, read_stream: UnboxStream, write_stream: BoxStream):
super().__init__()
self.read_stream = read_stream
self.write_stream = write_stream
@classmethod
def from_byte_streams(cls, reader: StreamReader, writer: StreamWriter, **keys: bytes) -> Self:
"""Create a server connection from an existing byte stream"""
box_reader, box_writer = get_stream_pair(reader, writer, **keys)
return cls(box_reader, box_writer)
class SHSClient(SHSDuplexStream, SHSEndpoint):
"""SHS client"""
def __init__( # pylint: disable=too-many-arguments
self,
host: str,
port: int,
client_kp: SigningKey,
server_pub_key: bytes,
ephemeral_key: Optional[PrivateKey] = None,
application_key: Optional[bytes] = None,
):
SHSDuplexStream.__init__(self)
SHSEndpoint.__init__(self)
self.host = host
self.port = port
self.writer: Optional[StreamWriter] = None
self.crypto: SHSClientCrypto = SHSClientCrypto(
client_kp,
server_pub_key,
ephemeral_key=ephemeral_key or PrivateKey.generate(),
application_key=application_key,
)
async def _handshake(self, reader: StreamReader, writer: StreamWriter) -> None:
writer.write(self.crypto.generate_challenge())
data = await reader.readexactly(64)
if not self.crypto.verify_server_challenge(data):
raise SHSClientException("Server challenge is not valid")
writer.write(self.crypto.generate_client_auth())
data = await reader.readexactly(80)
if not self.crypto.verify_server_accept(data):
raise SHSClientException("Server accept is not valid")
async def open(self) -> None:
"""Open the TCP connection"""
reader, writer = await open_connection(self.host, self.port)
await self._handshake(reader, writer)
keys = self.crypto.get_box_keys()
self.crypto.clean()
self.read_stream, self.write_stream = get_stream_pair(reader, writer, **keys)
self.writer = writer
self.is_connected = True
if self._on_connect:
await self._on_connect(self)