Source code for websocket.stream.reader

"""
You should not make an instance of the WebSocketReader class yourself; instead, use the instance that is passed to a
callback registered with :meth:`~websocket.client.Client.message`:

>>> @client.message
... async def on_message(reader: WebSocketReader):
...     # Read from the reader here...
...     print(await reader.get())
"""

import asyncio
import codecs
import logging
import struct

from websocket.reasons import Reasons
from websocket.stream import buffer
from websocket.stream.writer import WebSocketWriter
from ..enums import DataType

# Module logger; WebSocketReader.feed() warns through it when a client frame arrives without a mask.
logger = logging.getLogger(__name__)


class WebSocketReader(buffer.Buffer):
    """Buffer that receives the payload of one incoming WebSocket message.

    :meth:`feed` reads a frame's payload from the connection in chunks of at most :attr:`BUFFER_SIZE`
    bytes and queues them (at most :attr:`QUE_MAXSIZE` chunks wait at a time). A background task
    unmasks every chunk, checks that it is valid UTF-8 when the message is text, and writes it into
    this buffer, from which :meth:`get` reads. :meth:`done` finishes the message.

    :ivar data_type: The type of data frame the client sent us, this is the default kind for :meth:`get`.
    :type data_type: :class:`~websocket.enums.DataType`
    """

    #: Maximum number of payload bytes read from the connection per queued chunk.
    BUFFER_SIZE = 1024
    #: Maximum number of chunks waiting in :attr:`que`; :meth:`feed` waits while it is full.
    QUE_MAXSIZE = 12
    # Frame header bit masks (see RFC 6455, section 5.2); feed() applies MASK_BIT to the
    # mask-flag/payload-length byte.
    MASK_BIT = 1 << 7
    FIN_BIT = 1 << 7
    RSV_BITS = 0b111 << 4
    OP_CODE_BITS = 0b1111
    # Incremental decoder class, so UTF-8 sequences split across chunks are still validated correctly.
    decoder_factory = codecs.getincrementaldecoder('utf8')

    # Queued by adone() after the last chunk to tell the processor task to stop.
    _END_OF_MESSAGE = object()

    def __init__(self, kind, client, loop):
        super().__init__(self.BUFFER_SIZE * self.QUE_MAXSIZE, loop)
        self.data_type = kind
        self.client = client
        self.decoder = self.decoder_factory()
        # (masked payload chunk, chunk length, 4-byte mask) tuples waiting for the processor task.
        self.que = asyncio.Queue(self.QUE_MAXSIZE)
        # Set to False by adone(); the processor stops on the end-of-message marker, not on this flag.
        self.reading = True
        process = self.process_text if self.data_type is DataType.TEXT else self.process_binary
        self.processor = asyncio.ensure_future(process(), loop=self._loop)

    async def get(self, kind=None):
        """Reads all of the bytes from the stream.

        :param kind: Specifies the type of data returned, default is
            :attr:`~websocket.stream.reader.WebSocketReader.data_type`
        :type kind: :class:`~websocket.enums.DataType`
        :return: :class:`bytes` if kind is DataType.BINARY, :class:`str` if kind is DataType.TEXT
            (``None`` for any other kind)
        """
        if kind is None:
            kind = self.data_type
        data = await self.read()
        if kind == DataType.TEXT:
            return data.decode()
        if kind == DataType.BINARY:
            return data
        return None

    def done(self):
        """Signal that the whole payload has been fed; the actual finishing runs as a task (:meth:`adone`)."""
        asyncio.ensure_future(self.adone(), loop=self._loop)

    async def adone(self):
        """Finish the message once every queued chunk has been processed.

        Waits for the processor task to handle every queued chunk, performs the final UTF-8 check for
        text messages and marks the buffer as EOF. If the payload is not valid UTF-8, the exception is
        stored on the buffer and the connection is closed with ``Reasons.INCONSISTENT_DATA``.
        Any other exception raised by the processor task propagates.
        """
        self.reading = False
        try:
            await self._finish_processor()
            if self.data_type is DataType.TEXT:
                # Flush the decoder: this fails if the payload ended in the middle of a character.
                self.decoder.decode(b'', True)
            self.feed_eof()
        except UnicodeDecodeError as e:
            self.set_exception(e)
            reason = f"{e.object[e.start:e.end]} at {e.start}-{e.end}: {e.reason}"
            await self.client.close(Reasons.INCONSISTENT_DATA.value, reason[:WebSocketWriter.MAX_LEN_7])

    async def _finish_processor(self):
        """Queue the end-of-message marker and wait for the processor task, re-raising its exception."""
        if not self.processor.done():
            # Do not cancel the processor: it may be in the middle of writing a chunk it already took
            # off the queue. Queue the end marker instead, but stop waiting for a free queue slot if the
            # processor finishes first, because then nothing would ever free one.
            put = asyncio.ensure_future(self.que.put(self._END_OF_MESSAGE), loop=self._loop)
            await asyncio.wait([put, self.processor], return_when=asyncio.FIRST_COMPLETED)
            put.cancel()  # no-op when the marker was already queued
        await self.processor

    async def process_text(self):
        """Processor task for text messages: unmask queued chunks, validate them as UTF-8, buffer them."""
        await self._process(validate_utf8=True)

    async def process_binary(self):
        """Processor task for binary messages: unmask queued chunks and buffer them."""
        await self._process(validate_utf8=False)

    async def _process(self, validate_utf8):
        """Drain :attr:`que` until the end-of-message marker queued by :meth:`adone`.

        On invalid UTF-8, the error is stored on the buffer immediately. The remaining chunks are then
        discarded rather than left in the queue, so :meth:`feed` is never stuck waiting on a full
        queue, and the error is raised once the end marker arrives.
        """
        error = None
        try:
            while True:
                item = await self.que.get()
                if item is self._END_OF_MESSAGE:
                    break
                if error is not None:
                    continue  # payload already rejected; keep the queue moving
                chunk, _length, mask = item  # _length == len(chunk): feed() reads it with readexactly()
                chunk = self._unmask(chunk, mask)
                if validate_utf8:
                    try:
                        self.decoder.decode(chunk)
                    except UnicodeDecodeError as e:
                        error = e
                        self.set_exception(e)
                        continue
                await self.write(chunk)
        except asyncio.CancelledError:
            pass
        if error is not None:
            raise error

    @staticmethod
    def _unmask(data, mask):
        """Return ``data`` XOR-ed with the repeating 4-byte ``mask``, as a :class:`bytearray`.

        The chunk is XOR-ed as one big integer instead of byte by byte in a Python loop.
        """
        size = len(data)
        key = (mask * (size // 4 + 1))[:size]
        unmasked = int.from_bytes(data, 'big') ^ int.from_bytes(key, 'big')
        return bytearray(unmasked.to_bytes(size, 'big'))

    async def feed_once(self, reader):
        """Feed the payload of a single frame, then call :meth:`done`.

        :return: the payload length in bytes
        """
        length = await self.feed(reader)
        self.done()
        return length

    async def feed(self, reader):
        """Read one frame's payload from ``reader`` and queue it, in chunks, for the processor task.

        The first byte read is treated as the mask-flag/payload-length byte. The caller must already
        have consumed the byte holding FIN, RSV and the opcode. The extended length, the masking key
        and the payload are then read here.

        :return: the payload length in bytes
        :raises Exception: if the frame is not masked (its payload is read and discarded first)
        """
        header = await reader.readexactly(1)
        masked = header[0] & self.MASK_BIT
        length = header[0] & ~self.MASK_BIT
        # "The form '!' is available for those poor souls who claim they can’t remember whether network byte order is
        # big-endian or little-endian."
        # - <https://docs.python.org/3/library/struct.html>
        if length == 126:
            length, = struct.unpack('!H', await reader.readexactly(2))
        elif length == 127:
            length, = struct.unpack('!Q', await reader.readexactly(8))

        if not masked:
            # TODO: Reject frame per <https://tools.ietf.org/html/rfc6455#section-5.1>
            logger.warning("Received message from client without mask.")
            await reader.readexactly(length)
            raise Exception("Received message from client without mask.")  # TODO: HANDLE

        mask = await reader.readexactly(4)
        offset = 0
        while offset < length:
            size = min(self.BUFFER_SIZE, length - offset)
            # Each chunk is unmasked starting from its own index 0, so rotate the key to line up with
            # the chunk's position in the payload (a no-op while BUFFER_SIZE is a multiple of 4).
            shift = offset % 4
            chunk_mask = mask[shift:] + mask[:shift]
            await self.que.put((await reader.readexactly(size), size, chunk_mask))
            offset += size
        return length