ci: Lint source

This commit is contained in:
2023-10-29 09:55:39 +01:00
parent 53994b77a7
commit d28ca167f2
14 changed files with 267 additions and 65 deletions

View File

@@ -18,6 +18,7 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
"""Secret Handshake"""
from .network import SHSClient, SHSServer

View File

@@ -1,5 +1,8 @@
import struct
"""Box stream utilities"""
from asyncio import IncompleteReadError
import struct
from typing import Tuple
from nacl.secret import SecretBox
@@ -10,11 +13,9 @@ MAX_SEGMENT_SIZE = 4 * 1024
TERMINATION_HEADER = b"\x00" * 18
def get_stream_pair(reader, writer, **kwargs):
"""Return a tuple with `(unbox_stream, box_stream)` (reader/writer).
def get_stream_pair(reader, writer, **kwargs) -> Tuple["UnboxStream", "BoxStream"]:
"""Create a new duplex box stream"""
:return: (:class:`secret_handshake.boxstream.UnboxStream`,
:class:`secret_handshake.boxstream.BoxStream`)"""
box_args = {
"key": kwargs["encrypt_key"],
"nonce": kwargs["encrypt_nonce"],
@@ -26,7 +27,9 @@ def get_stream_pair(reader, writer, **kwargs):
return UnboxStream(reader, **unbox_args), BoxStream(writer, **box_args)
class UnboxStream(object):
class UnboxStream:
"""Unboxing stream"""
def __init__(self, reader, key, nonce):
self.reader = reader
self.key = key
@@ -34,6 +37,8 @@ class UnboxStream(object):
self.closed = False
async def read(self):
"""Read data from the stream"""
try:
data = await self.reader.readexactly(HEADER_LENGTH)
except IncompleteReadError:
@@ -70,7 +75,9 @@ class UnboxStream(object):
return data
class BoxStream(object):
class BoxStream:
"""Box stream"""
def __init__(self, writer, key, nonce):
self.writer = writer
self.key = key
@@ -78,6 +85,8 @@ class BoxStream(object):
self.nonce = nonce
def write(self, data):
"""Write data to the box stream"""
for chunk in split_chunks(data, MAX_SEGMENT_SIZE):
body = self.box.encrypt(chunk, inc_nonce(self.nonce))[24:]
header = struct.pack(">H", len(body) - 16) + body[:16]
@@ -89,4 +98,6 @@ class BoxStream(object):
self.writer.write(body[16:])
def close(self):
"""Close the box stream"""
self.writer.write(self.box.encrypt(b"\x00" * 18, self.nonce)[24:])

View File

@@ -18,10 +18,12 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
"""Cryptography functions"""
from base64 import b64decode
import hashlib
import hmac
from base64 import b64decode
from typing import Optional
from nacl.bindings import crypto_box_afternm, crypto_box_open_afternm, crypto_scalarmult
from nacl.exceptions import CryptoError
@@ -34,13 +36,19 @@ APPLICATION_KEY = b64decode("1KHLiKZvAvjbY1ziZEHMXawbCEIM6qwjCDm3VYRan/s=")
class SHSError(Exception):
"""A SHS exception."""
pass
class SHSCryptoBase: # pylint: disable=too-many-instance-attributes
"""Base functions for SHS cryptography"""
class SHSCryptoBase(object):
def __init__(self, local_key, ephemeral_key=None, application_key=None):
self.local_key = local_key
self.application_key = application_key or APPLICATION_KEY
self.shared_hash = None
self.remote_ephemeral_key = None
self.shared_secret = None
self.remote_app_hmac = None
self.remote_pub_key = None
self.box_secret = None
self._reset_keys(ephemeral_key or PrivateKey.generate())
def _reset_keys(self, ephemeral_key):
@@ -71,12 +79,16 @@ class SHSCryptoBase(object):
return ok
def clean(self, new_ephemeral_key=None):
"""Clean internal data"""
self._reset_keys(new_ephemeral_key or PrivateKey.generate())
self.shared_secret = None
self.shared_hash = None
self.remote_ephemeral_key = None
def get_box_keys(self):
"""Get the box streams keys"""
shared_secret = hashlib.sha256(self.box_secret).digest()
return {
"shared_secret": shared_secret,
@@ -88,7 +100,18 @@ class SHSCryptoBase(object):
class SHSServerCrypto(SHSCryptoBase):
"""SHS server crypto algorithm"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.b_alice = None
self.hello = None
self.box_secret = None
self.remote_pub_key = None
def verify_client_auth(self, data):
"""Verify client authentication data"""
assert len(data) == 112
a_bob = crypto_scalarmult(bytes(self.local_key.to_curve25519_private_key()), self.remote_ephemeral_key)
box_secret = hashlib.sha256(self.application_key + self.shared_secret + a_bob).digest()
@@ -107,12 +130,13 @@ class SHSServerCrypto(SHSCryptoBase):
return True
def generate_accept(self):
"""Generate an accept message"""
okay = self.local_key.sign(self.application_key + self.hello + self.shared_hash).signature
d = crypto_box_afternm(okay, b"\x00" * 24, self.box_secret)
return d
return crypto_box_afternm(okay, b"\x00" * 24, self.box_secret)
def clean(self, new_ephemeral_key=None):
super(SHSServerCrypto, self).clean(new_ephemeral_key=new_ephemeral_key)
super().clean(new_ephemeral_key=new_ephemeral_key)
self.hello = None
self.b_alice = None
@@ -120,19 +144,29 @@ class SHSServerCrypto(SHSCryptoBase):
class SHSClientCrypto(SHSCryptoBase):
"""An object that encapsulates all the SHS client-side crypto.
:param local_key: the keypair used by the client (:class:`nacl.public.PrivateKey` object)
:param server_pub_key: the server's public key (``byte`` string)
:param ephemeral_key: a fresh local :class:`nacl.public.PrivateKey`
:param application_key: the unique application key (``byte`` string), defaults to SSB's
:param local_key: the keypair used by the client
:param server_pub_key: the server's public key
:param ephemeral_key: a fresh local private key
:param application_key: the unique application key, defaults to SSB's
"""
def __init__(self, local_key, server_pub_key, ephemeral_key, application_key=None):
super(SHSClientCrypto, self).__init__(local_key, ephemeral_key, application_key)
def __init__(
self,
local_key: PrivateKey,
server_pub_key: bytes,
ephemeral_key: PrivateKey,
application_key: Optional[bytes] = None,
):
super().__init__(local_key, ephemeral_key, application_key)
self.remote_pub_key = VerifyKey(server_pub_key)
self.b_alice = None
self.a_bob = None
self.hello = None
self.box_secret = None
def verify_server_challenge(self, data):
"""Verify the correctness of challenge sent from the server."""
assert super(SHSClientCrypto, self).verify_challenge(data)
assert super().verify_challenge(data)
curve_pkey = self.remote_pub_key.to_curve25519_public_key()
# a_bob is (a * B)
@@ -168,8 +202,8 @@ class SHSClientCrypto(SHSCryptoBase):
try:
# let's use the box secret to unbox our encrypted message
signature = crypto_box_open_afternm(data, nonce, self.box_secret)
except CryptoError:
raise SHSError("Error decrypting server acceptance message")
except CryptoError as exc:
raise SHSError("Error decrypting server acceptance message") from exc
# we should have received sign(B)[K | H | hash(a * b)]
# let's see if that signature can verify the reconstructed data on our side
@@ -177,6 +211,6 @@ class SHSClientCrypto(SHSCryptoBase):
return True
def clean(self, new_ephemeral_key=None):
super(SHSClientCrypto, self).clean(new_ephemeral_key=new_ephemeral_key)
super().clean(new_ephemeral_key=new_ephemeral_key)
self.a_bob = None
self.b_alice = None

View File

@@ -18,6 +18,7 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
"""Networking functionality"""
import asyncio
@@ -26,22 +27,30 @@ from .crypto import SHSClientCrypto, SHSServerCrypto
class SHSClientException(Exception):
pass
"""Base exception class for client errors"""
class SHSDuplexStream(object):
class SHSDuplexStream:
"""SHS duplex stream"""
def __init__(self):
self.write_stream = None
self.read_stream = None
self.is_connected = False
def write(self, data):
"""Write data to the write stream"""
self.write_stream.write(data)
async def read(self):
"""Read data from the read stream"""
return await self.read_stream.read()
def close(self):
"""Close the duplex stream"""
self.write_stream.close()
self.read_stream.close()
self.is_connected = False
@@ -58,21 +67,27 @@ class SHSDuplexStream(object):
return msg
class SHSEndpoint(object):
class SHSEndpoint:
"""SHS endpoint"""
def __init__(self):
self._on_connect = None
self.crypto = None
def on_connect(self, cb):
"""Set the function to be called when a new connection arrives"""
self._on_connect = cb
def disconnect(self):
"""Disconnect the endpoint"""
raise NotImplementedError
class SHSServer(SHSEndpoint):
"""SHS server"""
def __init__(self, host, port, server_kp, application_key=None):
super(SHSServer, self).__init__()
super().__init__()
self.host = host
self.port = port
self.crypto = SHSServerCrypto(server_kp, application_key=application_key)
@@ -92,6 +107,8 @@ class SHSServer(SHSEndpoint):
writer.write(self.crypto.generate_accept())
async def handle_connection(self, reader, writer):
"""Handle incoming connections"""
self.crypto.clean()
await self._handshake(reader, writer)
keys = self.crypto.get_box_keys()
@@ -104,6 +121,8 @@ class SHSServer(SHSEndpoint):
asyncio.ensure_future(self._on_connect(conn))
async def listen(self):
"""Listen for connections"""
await asyncio.start_server(self.handle_connection, self.host, self.port)
def disconnect(self):
@@ -112,23 +131,33 @@ class SHSServer(SHSEndpoint):
class SHSServerConnection(SHSDuplexStream):
"""SHS server connection"""
def __init__(self, read_stream, write_stream):
super(SHSServerConnection, self).__init__()
super().__init__()
self.read_stream = read_stream
self.write_stream = write_stream
@classmethod
def from_byte_streams(cls, reader, writer, **keys):
"""Create a server connection from an existing byte stream"""
reader, writer = get_stream_pair(reader, writer, **keys)
return cls(reader, writer)
class SHSClient(SHSDuplexStream, SHSEndpoint):
def __init__(self, host, port, client_kp, server_pub_key, ephemeral_key=None, application_key=None):
"""SHS client"""
def __init__( # pylint: disable=too-many-arguments
self, host, port, client_kp, server_pub_key, ephemeral_key=None, application_key=None
):
SHSDuplexStream.__init__(self)
SHSEndpoint.__init__(self)
self.host = host
self.port = port
self.writer = None
self.crypto = SHSClientCrypto(
client_kp, server_pub_key, ephemeral_key=ephemeral_key, application_key=application_key
)
@@ -147,6 +176,8 @@ class SHSClient(SHSDuplexStream, SHSEndpoint):
raise SHSClientException("Server accept is not valid")
async def open(self):
"""Open the TCP connection"""
reader, writer = await asyncio.open_connection(self.host, self.port)
await self._handshake(reader, writer)
@@ -156,6 +187,7 @@ class SHSClient(SHSDuplexStream, SHSEndpoint):
self.read_stream, self.write_stream = get_stream_pair(reader, writer, **keys)
self.writer = writer
self.is_connected = True
if self._on_connect:
await self._on_connect()

View File

@@ -18,6 +18,7 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
"""Utility functions"""
import struct
@@ -26,11 +27,16 @@ MAX_NONCE = 8 * NONCE_SIZE
def inc_nonce(nonce):
"""Increment nonce"""
num = bytes_to_long(nonce) + 1
if num > 2**MAX_NONCE:
num = 0
bnum = long_to_bytes(num)
bnum = b"\x00" * (NONCE_SIZE - len(bnum)) + bnum
return bnum
@@ -44,6 +50,8 @@ def split_chunks(seq, n):
# Stolen from PyCypto (Public Domain)
def b(s):
"""Shorthand for s.encode("latin-1")"""
return s.encode("latin-1") # utf-8 would cause some side-effects we don't want
@@ -61,8 +69,8 @@ def long_to_bytes(n, blocksize=0):
s = pack(">I", n & 0xFFFFFFFF) + s
n = n >> 32
# strip off leading zeros
for i in range(len(s)):
if s[i] != b("\000")[0]:
for i, c in enumerate(s):
if c != b("\000")[0]:
break
else:
# only happens when n == 0