Extract secret-handshake into separate lib

This commit is contained in:
Pedro Ferreira 2017-06-05 00:27:44 +02:00
parent 761b503c95
commit 21af5fba09
9 changed files with 72 additions and 393 deletions

View File

@ -1,3 +1,4 @@
pynacl pynacl
simplejson simplejson
pytest pytest
secret-handshake

View File

@ -1,77 +0,0 @@
import struct
from nacl.secret import SecretBox
from .util import bytes_to_long, long_to_bytes
NONCE_SIZE = 24
HEADER_LENGTH = 2 + 16 + 16
MAX_SEGMENT_SIZE = 4 * 1024
TERMINATION_HEADER = (b'\x00' * HEADER_LENGTH)
MAX_NONCE = (8 * NONCE_SIZE)
def inc_nonce(nonce):
num = bytes_to_long(nonce) + 1
if num > 2**MAX_NONCE:
num = 0
bnum = long_to_bytes(num)
bnum = b'\x00' * (NONCE_SIZE - len(bnum)) + bnum
return bnum
def get_stream_pair(reader, writer, **kwargs):
shared_secret = kwargs.pop('shared_secret')
return UnboxStream(reader, shared_secret, **kwargs), BoxStream(writer, shared_secret, **kwargs)
class UnboxStream(object):
def __init__(self, reader, shared_secret, **key_data):
self.reader = reader
self.decrypt_key = key_data.get('decrypt_key')
self.decrypt_nonce = key_data.get('decrypt_nonce')
async def process(self):
while True:
data = await self.reader.read(HEADER_LENGTH)
if not data:
break
box = SecretBox(self.decrypt_key)
header = box.decrypt(data, self.decrypt_nonce)
if header == TERMINATION_HEADER:
return
length = struct.unpack('>H', header[:2])[0]
mac = header[2:]
data = await self.reader.read(length)
self.decrypt_nonce = inc_nonce(self.decrypt_nonce)
body = box.decrypt(mac + data, self.decrypt_nonce)
self.decrypt_nonce = inc_nonce(self.decrypt_nonce)
yield body
print('Disconnect')
class BoxStream(object):
def __init__(self, writer, shared_secret, **key_data):
self.writer = writer
self.encrypt_key = key_data.get('decrypt_key')
self.encrypt_nonce = key_data.get('decrypt_nonce')
async def write(self, data):
box = SecretBox(self.encrypt_key)
# XXX: This nonce logic is almost for sure wrong
self.encrypt_nonce = inc_nonce(self.encrypt_nonce)
body = box.encrypt(data, self.encrypt_nonce)
header = struct.pack('>H', len(body)) + body[:16]
self.writer.write(box.encrypt(header, self.encrypt_nonce))
self.encrypt_nonce = inc_nonce(self.encrypt_nonce)
self.writer.write(body)

View File

@ -1,7 +1,7 @@
import struct import struct
from enum import Enum from enum import Enum
from .shs.socket import SHSClient, SHSServer from secret_handshake import SHSClient, SHSServer
import simplejson import simplejson
@ -33,17 +33,25 @@ class PSMessage(object):
class PSSocket(object): class PSSocket(object):
async def read(self): async def read(self):
while True: try:
try: header = await self.connection.read()
header = await self.connection.read().__anext__() if not header:
body = await self.connection.read().__anext__() return
flags, length, req = struct.unpack('>BIi', header) body = await self.connection.read()
yield PSMessage(bool(flags & 0x08), bool(flags & 0x04), flags & 0x03, body) flags, length, req = struct.unpack('>BIi', header)
except StopAsyncIteration: return PSMessage(bool(flags & 0x08), bool(flags & 0x04), flags & 0x03, body)
await self.connection.disconnect() except StopAsyncIteration:
break await self.connection.disconnect()
return None
async def write(self, type_, data, req=0): async def __aiter__(self):
while True:
data = await self.read()
if data is None:
return
yield data
def write(self, type_, data, req=0):
type_ = PSMessageType[type_] type_ = PSMessageType[type_]
if type_ == PSMessageType.JSON: if type_ == PSMessageType.JSON:
data = simplejson.dumps(data) data = simplejson.dumps(data)
@ -51,27 +59,29 @@ class PSSocket(object):
# XXX: Not yet handling flags that nicely # XXX: Not yet handling flags that nicely
header = struct.pack('>BIi', 0x08 | type_.value, len(data), req) header = struct.pack('>BIi', 0x08 | type_.value, len(data), req)
await self.connection.write(header) self.connection.write(header)
await self.connection.write(data.encode('utf-8')) self.connection.write(data.encode('utf-8'))
class PSClient(PSSocket): class PSClient(PSSocket):
def __init__(self, host, port, client_kp, server_pub_key, ephemeral_key=None, application_key=None): def __init__(self, host, port, client_kp, server_pub_key, ephemeral_key=None, application_key=None, loop=None):
self.connection = SHSClient(host, port, client_kp, server_pub_key, ephemeral_key=ephemeral_key, self.connection = SHSClient(host, port, client_kp, server_pub_key, ephemeral_key=ephemeral_key,
application_key=application_key) application_key=application_key, loop=loop)
self.loop = loop
async def connect(self, loop=None): def connect(self):
await self.connection.connect(loop=loop) self.connection.connect()
class PSServer(PSSocket): class PSServer(PSSocket):
def __init__(self, host, port, client_kp, application_key=None): def __init__(self, host, port, client_kp, application_key=None, loop=None):
self.connection = SHSServer(host, port, client_kp, application_key=application_key) self.connection = SHSServer(host, port, client_kp, application_key=application_key, loop=loop)
self.loop = loop
async def listen(self, loop=None): def listen(self):
await self.connection.listen(loop=loop) self.connection.listen()
def on_connect(self, handler): def on_connect(self, cb):
async def _on_connect(): async def _on_connect():
await handler(self) await cb(self)
self.connection._on_connect = _on_connect self.connection.on_connect(_on_connect)

View File

@ -1,4 +0,0 @@
from .crypto import SHSClientCrypto, SHSServerCrypto
from .socket import SHSClient, SHSServer
__all__ = ('SHSClient', 'SHSClientCrypto', 'SHSServer', 'SHSServerCrypto')

View File

@ -1,160 +0,0 @@
import hashlib
import hmac
from base64 import b64decode
from nacl.bindings import crypto_scalarmult, crypto_box_afternm, crypto_box_open_afternm
from nacl.exceptions import CryptoError
from nacl.public import PrivateKey
from nacl.signing import VerifyKey
APPLICATION_KEY = b64decode('1KHLiKZvAvjbY1ziZEHMXawbCEIM6qwjCDm3VYRan/s=')
class SHSError(Exception):
"""A SHS exception."""
pass
class SHSCryptoBase(object):
def __init__(self, local_key, ephemeral_key=None, application_key=None):
self.local_key = local_key
self.application_key = application_key or APPLICATION_KEY
self._reset_keys(ephemeral_key or PrivateKey.generate())
def _reset_keys(self, ephemeral_key):
self.local_ephemeral_key = ephemeral_key
self.local_app_hmac = (hmac.new(self.application_key, bytes(ephemeral_key.public_key), digestmod='sha512')
.digest()[:32])
def generate_challenge(self):
"""Generate and return a challenge to be sent to the server."""
return self.local_app_hmac + bytes(self.local_ephemeral_key.public_key)
def verify_challenge(self, data):
"""Verify the correctness of challenge sent from the client."""
assert len(data) == 64
sent_hmac, remote_ephemeral_key = data[:32], data[32:]
h = hmac.new(self.application_key, remote_ephemeral_key, digestmod='sha512')
self.remote_app_hmac = h.digest()[:32]
ok = self.remote_app_hmac == sent_hmac
if ok:
# this is (a * b)
self.shared_secret = crypto_scalarmult(bytes(self.local_ephemeral_key), remote_ephemeral_key)
self.remote_ephemeral_key = remote_ephemeral_key
# this is hash(a * b)
self.shared_hash = hashlib.sha256(self.shared_secret).digest()
return ok
def clean(self, new_ephemeral_key=None):
self._reset_keys(new_ephemeral_key or PrivateKey.generate())
self.shared_secret = None
self.shared_hash = None
self.remote_ephemeral_key = None
def get_box_keys(self):
shared_secret = hashlib.sha256(self.box_secret).digest()
return {
'shared_secret': shared_secret,
'encrypt_key': hashlib.sha256(shared_secret + bytes(self.remote_pub_key)).digest(),
'decrypt_key': hashlib.sha256(shared_secret + bytes(self.local_key.verify_key)).digest(),
'encrypt_nonce': self.remote_app_hmac[:24],
'decrypt_nonce': self.local_app_hmac[:24]
}
class SHSServerCrypto(SHSCryptoBase):
def verify_client_auth(self, data):
assert len(data) == 112
a_bob = crypto_scalarmult(bytes(self.local_key.to_curve25519_private_key()), self.remote_ephemeral_key)
box_secret = hashlib.sha256(self.application_key + self.shared_secret + a_bob).digest()
self.hello = crypto_box_open_afternm(data, b'\x00' * 24, box_secret)
signature, public_key = self.hello[:64], self.hello[64:]
signed = self.application_key + bytes(self.local_key.verify_key) + self.shared_hash
pkey = VerifyKey(public_key)
# will raise an exception if verification fails
pkey.verify(signed, signature)
self.remote_pub_key = pkey
b_alice = crypto_scalarmult(bytes(self.local_ephemeral_key),
bytes(self.remote_pub_key.to_curve25519_public_key()))
self.box_secret = hashlib.sha256(self.application_key + self.shared_secret + a_bob + b_alice).digest()[:32]
return True
def generate_accept(self):
okay = self.local_key.sign(self.application_key + self.hello + self.shared_hash).signature
d = crypto_box_afternm(okay, b'\x00' * 24, self.box_secret)
return d
def clean(self, new_ephemeral_key=None):
super(SHSServerCrypto, self).clean(new_ephemeral_key=new_ephemeral_key)
self.hello = None
self.local_lterm_shared = None
class SHSClientCrypto(SHSCryptoBase):
"""An object that encapsulates all the SHS client-side crypto.
:param local_key: the :class:`ssb.keys.KeyPair` used by the client
:param local_ephemeral_key: a fresh local :class:`nacl.public.PrivateKey`
:param server_pub_key: the server's public key (``byte`` string)
:param application_key: the unique application key (``byte`` string), defaults to SSB's
"""
def __init__(self, local_key, server_pub_key, ephemeral_key, application_key=None):
super(SHSClientCrypto, self).__init__(local_key, ephemeral_key, application_key)
self.remote_pub_key = VerifyKey(server_pub_key)
def verify_server_challenge(self, data):
"""Verify the correctness of challenge sent from the server."""
# TODO: use super.verify_challenge and add extra logic
return super(SHSClientCrypto, self).verify_challenge(data)
def generate_client_auth(self):
"""Generate box[K|a*b|a*B](H)"""
curve_pkey = self.remote_pub_key.to_curve25519_public_key()
# remote_lterm_shared is (a * B)
remote_lterm_shared = crypto_scalarmult(bytes(self.local_ephemeral_key), bytes(curve_pkey))
self.remote_lterm_shared = remote_lterm_shared
# this shall be hash(K | a * b | a * B)
box_secret = hashlib.sha256(self.application_key + self.shared_secret + remote_lterm_shared).digest()
# and message_to_box will correspond to H = sign(A)[K | Bp | hash(a * b)] | Ap
signed_message = self.local_key.sign(self.application_key + bytes(self.remote_pub_key) + self.shared_hash)
message_to_box = signed_message.signature + bytes(self.local_key.verify_key)
self.client_auth = message_to_box
nonce = b"\x00" * 24
# return box(K | a * b | a * B)[H]
return crypto_box_afternm(message_to_box, nonce, box_secret)
def verify_server_accept(self, data):
"""Verify that the server's accept message is sane"""
curve_lkey = self.local_key.to_curve25519_private_key()
# local_lterm_shared is (A * b)
local_lterm_shared = crypto_scalarmult(bytes(curve_lkey), self.remote_ephemeral_key)
self.local_lterm_shared = local_lterm_shared
# this is hash(K | a * b | a * B | A * b)
self.box_secret = hashlib.sha256(self.application_key + self.shared_secret + self.remote_lterm_shared +
local_lterm_shared).digest()
nonce = b"\x00" * 24
try:
# let's use the box secret to unbox our encrypted message
signature = crypto_box_open_afternm(data, nonce, self.box_secret)
except CryptoError:
raise SHSError('Error decrypting server acceptance message')
# we should have received sign(B)[K | H | hash(a * b)]
# let's see if that signature can verify the reconstructed data on our side
self.remote_pub_key.verify(self.application_key + self.client_auth + self.shared_hash, signature)
return True
def clean(self, new_ephemeral_key=None):
super(SHSClientCrypto, self).clean(new_ephemeral_key=new_ephemeral_key)
self.remote_lterm_shared = None
self.local_lterm_shared = None

View File

@ -1,88 +0,0 @@
from asyncio import open_connection, start_server
from ..boxstream import get_stream_pair
from .crypto import SHSClientCrypto, SHSServerCrypto
class SHSClientException(Exception):
pass
class SHSSocket(object):
async def read(self):
async for msg in self.read_stream.process():
yield msg
async def write(self, data):
await self.write_stream.write(data)
async def disconnect(self):
self.writer.close()
class SHSServer(SHSSocket):
def __init__(self, host, port, server_kp, application_key=None):
self.host = host
self.port = port
self.crypto = SHSServerCrypto(server_kp.private_key, application_key=application_key)
self._on_connect = None
async def _handshake(self, reader, writer):
data = await reader.read(64)
if not self.crypto.verify_challenge(data):
raise SHSClientException('Client challenge is not valid')
writer.write(self.crypto.generate_challenge())
data = await reader.read(112)
if not self.crypto.verify_client_auth(data):
raise SHSClientException('Client auth is not valid')
writer.write(self.crypto.generate_accept())
async def handle_connection(self, reader, writer):
self.crypto.clean()
await self._handshake(reader, writer)
keys = self.crypto.get_box_keys()
self.crypto.clean()
self.read_stream, self.write_stream = get_stream_pair(reader, writer, **keys)
self.writer = writer
if self._on_connect:
await self._on_connect()
async def listen(self, loop=None):
await start_server(self.handle_connection, self.host, self.port, loop=loop)
class SHSClient(SHSSocket):
def __init__(self, host, port, client_kp, server_pub_key, ephemeral_key=None, application_key=None):
self.host = host
self.port = port
self.crypto = SHSClientCrypto(client_kp.private_key, server_pub_key, ephemeral_key=ephemeral_key,
application_key=application_key)
async def _handshake(self, reader, writer):
writer.write(self.crypto.generate_challenge())
data = await reader.read(64)
if not self.crypto.verify_server_challenge(data):
raise SHSClientException('Server challenge is not valid')
writer.write(self.crypto.generate_client_auth())
data = await reader.read(80)
if not self.crypto.verify_server_accept(data):
raise SHSClientException('Server accept is not valid')
async def connect(self, loop=None):
reader, writer = await open_connection(self.host, self.port, loop=loop)
await self._handshake(reader, writer)
keys = self.crypto.get_box_keys()
self.crypto.clean()
self.read_stream, self.write_stream = get_stream_pair(reader, writer, **keys)
self.writer = writer

View File

@ -1,17 +0,0 @@
import hashlib
import pytest
from nacl.public import PrivateKey
from ssb.shs import SecretHandShake
@pytest.fixture()
def appkey():
return hashlib.sha256(b'app_key').digest()
def test_client_challenge(appkey):
pk = PrivateKey.generate()
shs = SecretHandShake(pk, application_key=appkey)
assert shs.client_challenge

View File

@ -1,19 +1,28 @@
import os
from asyncio import get_event_loop from asyncio import get_event_loop
from base64 import b64decode from base64 import b64decode
from ssb.keys import KeyPair import yaml
from nacl.signing import SigningKey
from ssb.packet_stream import PSClient from ssb.packet_stream import PSClient
server_pub_key = b64decode('--- your public key ---')
with open(os.path.expanduser('~/.ssb/secret')) as f:
config = yaml.load(f)
server_pub_key = b64decode(config['public'][:-8])
async def main(loop): async def main(loop):
await packet_stream.connect(loop) async for msg in packet_stream:
async for msg in packet_stream.read():
print(msg) print(msg)
print('bye') print('bye')
packet_stream = PSClient('127.0.0.1', 8008, KeyPair(), server_pub_key)
loop = get_event_loop() loop = get_event_loop()
packet_stream = PSClient('127.0.0.1', 8008, SigningKey.generate(), server_pub_key, loop=loop)
packet_stream.connect()
loop.run_until_complete(main(loop)) loop.run_until_complete(main(loop))
loop.close() loop.close()

View File

@ -1,31 +1,36 @@
import os
from asyncio import get_event_loop from asyncio import get_event_loop
from base64 import b64decode from base64 import b64decode
from ssb.keys import KeyPair import yaml
from nacl.signing import SigningKey
from ssb.packet_stream import PSServer from ssb.packet_stream import PSServer
priv_key = b64decode('--- your private key ---') with open(os.path.expanduser('~/.ssb/secret')) as f:
config = yaml.load(f)
async def main(loop):
await packet_stream.listen(loop)
async def on_connect(server): async def on_connect(server):
await server.write('JSON', {"name": ["createHistoryStream"], server.write('JSON', {
"args": [{ "name": ["createHistoryStream"],
"id": "@/Odg52x38pt7OivNnxK1Lm+H/+x6yV4DhMeXHBQRYc0=.ed25519", "args": [{
"seq": 9, "id": "@/Odg52x38pt7OivNnxK1Lm+H/+x6yV4DhMeXHBQRYc0=.ed25519",
"live": True, "seq": 9,
"keys": False "live": True,
}], "keys": False
"type": "source"}, req=1) }],
print(await server.read().__anext__()) "type": "source"}, req=1)
print(await server.read())
server.write('JSON', {})
packet_stream = PSServer('127.0.0.1', 8008, KeyPair(priv_key[:32]))
packet_stream.on_connect(on_connect)
loop = get_event_loop() loop = get_event_loop()
loop.run_until_complete(main(loop))
server_keypair = SigningKey(b64decode(config['private'][:-8])[:32])
packet_stream = PSServer('127.0.0.1', 8008, server_keypair, loop=loop)
packet_stream.on_connect(on_connect)
packet_stream.listen()
loop.run_forever() loop.run_forever()
loop.close() loop.close()