Handle split packets properly
This commit is contained in:
		| @@ -65,8 +65,7 @@ class PSRequestHandler(object): | ||||
| class PSMessage(object): | ||||
|  | ||||
|     @classmethod | ||||
|     def from_header_body(cls, header, body): | ||||
|         flags, length, req = struct.unpack('>BIi', header) | ||||
|     def from_header_body(cls, flags, req, body): | ||||
|         type_ = PSMessageType(flags & 0x03) | ||||
|  | ||||
|         if type_ == PSMessageType.TEXT: | ||||
| @@ -92,7 +91,11 @@ class PSMessage(object): | ||||
|         self.req = req | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return '<PSMessage ({}): {}{} {}{}>'.format(self.type.name, self.body, | ||||
|         if self.type == PSMessageType.BUFFER: | ||||
|             body = '{} bytes'.format(len(self.body)) | ||||
|         else: | ||||
|             body = self.body | ||||
|         return '<PSMessage ({}): {}{} {}{}>'.format(self.type.name, body, | ||||
|                                                     '' if self.req is None else ' [{}]'.format(self.req), | ||||
|                                                     '~' if self.stream else '', '!' if self.end_err else '') | ||||
|  | ||||
| @@ -107,9 +110,16 @@ class PSConnection(object): | ||||
|             header = await self.connection.read() | ||||
|             if not header: | ||||
|                 return | ||||
|             body = await self.connection.read() | ||||
|             logger.debug('READ %s %s', header, body) | ||||
|             return PSMessage.from_header_body(header, body) | ||||
|             flags, length, req = struct.unpack('>BIi', header) | ||||
|  | ||||
|             n_packets = length // 4096 + 1 | ||||
|  | ||||
|             body = b'' | ||||
|             for n in range(n_packets): | ||||
|                 body += await self.connection.read() | ||||
|  | ||||
|             logger.debug('READ %s %s', header, len(body)) | ||||
|             return PSMessage.from_header_body(flags, req, body) | ||||
|         except StopAsyncIteration: | ||||
|             logger.debug('DISCONNECT') | ||||
|             await self.connection.disconnect() | ||||
|   | ||||
| @@ -50,9 +50,10 @@ async def main(): | ||||
|         handler.send(True, end=True) | ||||
|         break | ||||
|  | ||||
|     handler = api.call('blobs.add', [], 'sink') | ||||
|     handler.send(b'dead0beef', msg_type=PSMessageType.BUFFER) | ||||
|     handler.send(True, end=True) | ||||
|     async for data in api.call('blobs.get', ['&/6q7JOKythgnnzoBI5xxvotCr5HeFkAIZSAuqHiZfLw=.sha256'], 'source'): | ||||
|         if data.type.name == 'BUFFER': | ||||
|             with open('./funny_img.png', 'wb') as f: | ||||
|                 f.write(data.data) | ||||
|  | ||||
| # create console handler and set level to debug | ||||
| ch = logging.StreamHandler() | ||||
|   | ||||
		Reference in New Issue
	
	Block a user