Start using readexaclty()

Since read(N) is not guaranteed to read exactly N bytes
This commit is contained in:
Pedro Ferreira 2017-06-05 22:49:27 +02:00
parent 0a9ebc09bf
commit 18489de6d4
3 changed files with 11 additions and 8 deletions

View File

@ -1,4 +1,6 @@
import struct import struct
from asyncio import IncompleteReadError
from nacl.secret import SecretBox from nacl.secret import SecretBox
from .util import split_chunks, inc_nonce from .util import split_chunks, inc_nonce
@ -32,9 +34,9 @@ class UnboxStream(object):
self.closed = False self.closed = False
async def read(self): async def read(self):
data = await self.reader.read(HEADER_LENGTH) try:
data = await self.reader.readexactly(HEADER_LENGTH)
if not data: except IncompleteReadError:
self.closed = True self.closed = True
return None return None
@ -49,7 +51,7 @@ class UnboxStream(object):
length = struct.unpack('>H', header[:2])[0] length = struct.unpack('>H', header[:2])[0]
mac = header[2:] mac = header[2:]
data = await self.reader.read(length) data = await self.reader.readexactly(length)
body = box.decrypt(mac + data, inc_nonce(self.nonce)) body = box.decrypt(mac + data, inc_nonce(self.nonce))

View File

@ -59,13 +59,13 @@ class SHSServer(SHSSocket):
self.crypto = SHSServerCrypto(server_kp, application_key=application_key) self.crypto = SHSServerCrypto(server_kp, application_key=application_key)
async def _handshake(self, reader, writer): async def _handshake(self, reader, writer):
data = await reader.read(64) data = await reader.readexactly(64)
if not self.crypto.verify_challenge(data): if not self.crypto.verify_challenge(data):
raise SHSClientException('Client challenge is not valid') raise SHSClientException('Client challenge is not valid')
writer.write(self.crypto.generate_challenge()) writer.write(self.crypto.generate_challenge())
data = await reader.read(112) data = await reader.readexactly(112)
if not self.crypto.verify_client_auth(data): if not self.crypto.verify_client_auth(data):
raise SHSClientException('Client auth is not valid') raise SHSClientException('Client auth is not valid')
@ -99,13 +99,13 @@ class SHSClient(SHSSocket):
async def _handshake(self, reader, writer): async def _handshake(self, reader, writer):
writer.write(self.crypto.generate_challenge()) writer.write(self.crypto.generate_challenge())
data = await reader.read(64) data = await reader.readexactly(64)
if not self.crypto.verify_server_challenge(data): if not self.crypto.verify_server_challenge(data):
raise SHSClientException('Server challenge is not valid') raise SHSClientException('Server challenge is not valid')
writer.write(self.crypto.generate_client_auth()) writer.write(self.crypto.generate_client_auth())
data = await reader.read(80) data = await reader.readexactly(80)
if not self.crypto.verify_server_accept(data): if not self.crypto.verify_server_accept(data):
raise SHSClientException('Server accept is not valid') raise SHSClientException('Server accept is not valid')

View File

@ -37,6 +37,7 @@ class AsyncBuffer(BytesIO):
"""Just a BytesIO with an async read method.""" """Just a BytesIO with an async read method."""
async def read(self, n=None): async def read(self, n=None):
return super(AsyncBuffer, self).read(n) return super(AsyncBuffer, self).read(n)
readexactly = read
@pytest.mark.asyncio @pytest.mark.asyncio