Handle data buffers > 4K
This commit is contained in:
parent
6be0cb2b5a
commit
b423894b96
@ -1,24 +1,11 @@
|
|||||||
import struct
|
import struct
|
||||||
from nacl.secret import SecretBox
|
from nacl.secret import SecretBox
|
||||||
|
|
||||||
from .util import bytes_to_long, long_to_bytes
|
from .util import split_chunks, inc_nonce
|
||||||
|
|
||||||
NONCE_SIZE = 24
|
|
||||||
HEADER_LENGTH = 2 + 16 + 16
|
HEADER_LENGTH = 2 + 16 + 16
|
||||||
MAX_SEGMENT_SIZE = 4 * 1024
|
MAX_SEGMENT_SIZE = 4 * 1024
|
||||||
TERMINATION_HEADER = (b'\x00' * 18)
|
TERMINATION_HEADER = (b'\x00' * 18)
|
||||||
MAX_NONCE = (8 * NONCE_SIZE)
|
|
||||||
|
|
||||||
# TODO: Implement handling of messages > 4k
|
|
||||||
|
|
||||||
|
|
||||||
def inc_nonce(nonce):
|
|
||||||
num = bytes_to_long(nonce) + 1
|
|
||||||
if num > 2 ** MAX_NONCE:
|
|
||||||
num = 0
|
|
||||||
bnum = long_to_bytes(num)
|
|
||||||
bnum = b'\x00' * (NONCE_SIZE - len(bnum)) + bnum
|
|
||||||
return bnum
|
|
||||||
|
|
||||||
|
|
||||||
def get_stream_pair(reader, writer, **kwargs):
|
def get_stream_pair(reader, writer, **kwargs):
|
||||||
@ -88,14 +75,15 @@ class BoxStream(object):
|
|||||||
|
|
||||||
# XXX: This nonce logic is almost for sure wrong
|
# XXX: This nonce logic is almost for sure wrong
|
||||||
|
|
||||||
body = self.box.encrypt(data, inc_nonce(self.nonce))[24:]
|
for chunk in split_chunks(data, MAX_SEGMENT_SIZE):
|
||||||
header = struct.pack('>H', len(body) - 16) + body[:16]
|
body = self.box.encrypt(chunk, inc_nonce(self.nonce))[24:]
|
||||||
|
header = struct.pack('>H', len(body) - 16) + body[:16]
|
||||||
|
|
||||||
hdrbox = self.box.encrypt(header, self.nonce)[24:]
|
hdrbox = self.box.encrypt(header, self.nonce)[24:]
|
||||||
self.writer.write(hdrbox)
|
self.writer.write(hdrbox)
|
||||||
|
|
||||||
self.nonce = inc_nonce(inc_nonce(self.nonce))
|
self.nonce = inc_nonce(inc_nonce(self.nonce))
|
||||||
self.writer.write(body[16:])
|
self.writer.write(body[16:])
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
self.writer.write(self.box.encrypt(b'\x00' * 18, self.nonce)[24:])
|
self.writer.write(self.box.encrypt(b'\x00' * 18, self.nonce)[24:])
|
||||||
|
@ -23,7 +23,7 @@ import pytest
|
|||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
|
|
||||||
from .test_crypto import (CLIENT_ENCRYPT_KEY, CLIENT_ENCRYPT_NONCE)
|
from .test_crypto import (CLIENT_ENCRYPT_KEY, CLIENT_ENCRYPT_NONCE)
|
||||||
from secret_handshake.boxstream import BoxStream, UnboxStream
|
from secret_handshake.boxstream import BoxStream, UnboxStream, HEADER_LENGTH
|
||||||
|
|
||||||
MESSAGE_1 = (b'\xcev\xedE\x06l\x02\x13\xc8\x17V\xfa\x8bZ?\x88B%O\xb0L\x9f\x8e\x8c0y\x1dv\xc0\xc9\xf6\x9d\xc2\xdf\xdb'
|
MESSAGE_1 = (b'\xcev\xedE\x06l\x02\x13\xc8\x17V\xfa\x8bZ?\x88B%O\xb0L\x9f\x8e\x8c0y\x1dv\xc0\xc9\xf6\x9d\xc2\xdf\xdb'
|
||||||
b'\xee\x9d')
|
b'\xee\x9d')
|
||||||
@ -72,3 +72,24 @@ async def test_unboxstream():
|
|||||||
assert not unbox_stream.closed
|
assert not unbox_stream.closed
|
||||||
assert [msg async for msg in unbox_stream] == [b'foo', b'foo', b'bar']
|
assert [msg async for msg in unbox_stream] == [b'foo', b'foo', b'bar']
|
||||||
assert unbox_stream.closed
|
assert unbox_stream.closed
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_long_packets():
|
||||||
|
data_size = 6 * 1024
|
||||||
|
data = bytes(n % 256 for n in range(data_size))
|
||||||
|
|
||||||
|
# box 6K buffer
|
||||||
|
buffer = AsyncBuffer()
|
||||||
|
box_stream = BoxStream(buffer, CLIENT_ENCRYPT_KEY, CLIENT_ENCRYPT_NONCE)
|
||||||
|
box_stream.write(data)
|
||||||
|
# the size overhead corresponds to the two packet headers
|
||||||
|
assert buffer.tell() == data_size + (HEADER_LENGTH * 2)
|
||||||
|
buffer.seek(0)
|
||||||
|
|
||||||
|
# now let's unbox it and check whether it's OK
|
||||||
|
unbox_stream = UnboxStream(buffer, CLIENT_ENCRYPT_KEY, CLIENT_ENCRYPT_NONCE)
|
||||||
|
first_packet = await unbox_stream.read()
|
||||||
|
assert first_packet == data[:4096]
|
||||||
|
second_packet = await unbox_stream.read()
|
||||||
|
assert second_packet == data[4096:]
|
||||||
|
@ -1,8 +1,27 @@
|
|||||||
import struct
|
import struct
|
||||||
|
|
||||||
|
NONCE_SIZE = 24
|
||||||
|
MAX_NONCE = (8 * NONCE_SIZE)
|
||||||
|
|
||||||
|
|
||||||
|
def inc_nonce(nonce):
|
||||||
|
num = bytes_to_long(nonce) + 1
|
||||||
|
if num > 2 ** MAX_NONCE:
|
||||||
|
num = 0
|
||||||
|
bnum = long_to_bytes(num)
|
||||||
|
bnum = b'\x00' * (NONCE_SIZE - len(bnum)) + bnum
|
||||||
|
return bnum
|
||||||
|
|
||||||
|
|
||||||
|
def split_chunks(seq, n):
|
||||||
|
"""Split sequence in equal-sized chunks.
|
||||||
|
The last chunk is not padded."""
|
||||||
|
while seq:
|
||||||
|
yield seq[:n]
|
||||||
|
seq = seq[n:]
|
||||||
|
|
||||||
|
|
||||||
# Stolen from PyCypto (Public Domain)
|
# Stolen from PyCypto (Public Domain)
|
||||||
|
|
||||||
|
|
||||||
def b(s):
|
def b(s):
|
||||||
return s.encode("latin-1") # utf-8 would cause some side-effects we don't want
|
return s.encode("latin-1") # utf-8 would cause some side-effects we don't want
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user