Start using readexaclty()
Since read(N) is not guaranteed to read exactly N bytes
This commit is contained in:
parent
0a9ebc09bf
commit
18489de6d4
@ -1,4 +1,6 @@
|
|||||||
import struct
|
import struct
|
||||||
|
from asyncio import IncompleteReadError
|
||||||
|
|
||||||
from nacl.secret import SecretBox
|
from nacl.secret import SecretBox
|
||||||
|
|
||||||
from .util import split_chunks, inc_nonce
|
from .util import split_chunks, inc_nonce
|
||||||
@ -32,9 +34,9 @@ class UnboxStream(object):
|
|||||||
self.closed = False
|
self.closed = False
|
||||||
|
|
||||||
async def read(self):
|
async def read(self):
|
||||||
data = await self.reader.read(HEADER_LENGTH)
|
try:
|
||||||
|
data = await self.reader.readexactly(HEADER_LENGTH)
|
||||||
if not data:
|
except IncompleteReadError:
|
||||||
self.closed = True
|
self.closed = True
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@ -49,7 +51,7 @@ class UnboxStream(object):
|
|||||||
length = struct.unpack('>H', header[:2])[0]
|
length = struct.unpack('>H', header[:2])[0]
|
||||||
mac = header[2:]
|
mac = header[2:]
|
||||||
|
|
||||||
data = await self.reader.read(length)
|
data = await self.reader.readexactly(length)
|
||||||
|
|
||||||
body = box.decrypt(mac + data, inc_nonce(self.nonce))
|
body = box.decrypt(mac + data, inc_nonce(self.nonce))
|
||||||
|
|
||||||
|
@ -59,13 +59,13 @@ class SHSServer(SHSSocket):
|
|||||||
self.crypto = SHSServerCrypto(server_kp, application_key=application_key)
|
self.crypto = SHSServerCrypto(server_kp, application_key=application_key)
|
||||||
|
|
||||||
async def _handshake(self, reader, writer):
|
async def _handshake(self, reader, writer):
|
||||||
data = await reader.read(64)
|
data = await reader.readexactly(64)
|
||||||
if not self.crypto.verify_challenge(data):
|
if not self.crypto.verify_challenge(data):
|
||||||
raise SHSClientException('Client challenge is not valid')
|
raise SHSClientException('Client challenge is not valid')
|
||||||
|
|
||||||
writer.write(self.crypto.generate_challenge())
|
writer.write(self.crypto.generate_challenge())
|
||||||
|
|
||||||
data = await reader.read(112)
|
data = await reader.readexactly(112)
|
||||||
if not self.crypto.verify_client_auth(data):
|
if not self.crypto.verify_client_auth(data):
|
||||||
raise SHSClientException('Client auth is not valid')
|
raise SHSClientException('Client auth is not valid')
|
||||||
|
|
||||||
@ -99,13 +99,13 @@ class SHSClient(SHSSocket):
|
|||||||
async def _handshake(self, reader, writer):
|
async def _handshake(self, reader, writer):
|
||||||
writer.write(self.crypto.generate_challenge())
|
writer.write(self.crypto.generate_challenge())
|
||||||
|
|
||||||
data = await reader.read(64)
|
data = await reader.readexactly(64)
|
||||||
if not self.crypto.verify_server_challenge(data):
|
if not self.crypto.verify_server_challenge(data):
|
||||||
raise SHSClientException('Server challenge is not valid')
|
raise SHSClientException('Server challenge is not valid')
|
||||||
|
|
||||||
writer.write(self.crypto.generate_client_auth())
|
writer.write(self.crypto.generate_client_auth())
|
||||||
|
|
||||||
data = await reader.read(80)
|
data = await reader.readexactly(80)
|
||||||
if not self.crypto.verify_server_accept(data):
|
if not self.crypto.verify_server_accept(data):
|
||||||
raise SHSClientException('Server accept is not valid')
|
raise SHSClientException('Server accept is not valid')
|
||||||
|
|
||||||
|
@ -37,6 +37,7 @@ class AsyncBuffer(BytesIO):
|
|||||||
"""Just a BytesIO with an async read method."""
|
"""Just a BytesIO with an async read method."""
|
||||||
async def read(self, n=None):
|
async def read(self, n=None):
|
||||||
return super(AsyncBuffer, self).read(n)
|
return super(AsyncBuffer, self).read(n)
|
||||||
|
readexactly = read
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
Loading…
Reference in New Issue
Block a user