Start using readexaclty()

Since read(N) is not guaranteed to read exactly N bytes
This commit is contained in:
Pedro Ferreira 2017-06-05 22:49:27 +02:00
parent 0a9ebc09bf
commit 18489de6d4
3 changed files with 11 additions and 8 deletions

View File

@ -1,4 +1,6 @@
import struct
from asyncio import IncompleteReadError
from nacl.secret import SecretBox
from .util import split_chunks, inc_nonce
@ -32,9 +34,9 @@ class UnboxStream(object):
self.closed = False
async def read(self):
data = await self.reader.read(HEADER_LENGTH)
if not data:
try:
data = await self.reader.readexactly(HEADER_LENGTH)
except IncompleteReadError:
self.closed = True
return None
@ -49,7 +51,7 @@ class UnboxStream(object):
length = struct.unpack('>H', header[:2])[0]
mac = header[2:]
data = await self.reader.read(length)
data = await self.reader.readexactly(length)
body = box.decrypt(mac + data, inc_nonce(self.nonce))

View File

@ -59,13 +59,13 @@ class SHSServer(SHSSocket):
self.crypto = SHSServerCrypto(server_kp, application_key=application_key)
async def _handshake(self, reader, writer):
data = await reader.read(64)
data = await reader.readexactly(64)
if not self.crypto.verify_challenge(data):
raise SHSClientException('Client challenge is not valid')
writer.write(self.crypto.generate_challenge())
data = await reader.read(112)
data = await reader.readexactly(112)
if not self.crypto.verify_client_auth(data):
raise SHSClientException('Client auth is not valid')
@ -99,13 +99,13 @@ class SHSClient(SHSSocket):
async def _handshake(self, reader, writer):
writer.write(self.crypto.generate_challenge())
data = await reader.read(64)
data = await reader.readexactly(64)
if not self.crypto.verify_server_challenge(data):
raise SHSClientException('Server challenge is not valid')
writer.write(self.crypto.generate_client_auth())
data = await reader.read(80)
data = await reader.readexactly(80)
if not self.crypto.verify_server_accept(data):
raise SHSClientException('Server accept is not valid')

View File

@ -37,6 +37,7 @@ class AsyncBuffer(BytesIO):
"""Just a BytesIO with an async read method."""
async def read(self, n=None):
return super(AsyncBuffer, self).read(n)
readexactly = read
@pytest.mark.asyncio