Source code for websocket.client

"""
Do not create an instance of the Client class yourself; instead, listen for new connections with
:meth:`~websocket.server.WebSocketServer.connection`.

>>> @socket.connection
... async def on_connection(client: Client):
...     # Here you can use the client, register callbacks on it or send it messages
...     await client.writer.ping()
"""
import asyncio
import logging

import time

from .enums import DataType, State
from .reasons import Reasons, Reason
from .stream.reader import WebSocketReader
from .stream.writer import WebSocketWriter

logger = logging.getLogger(__name__)


class NoCallbackException(Exception):
    """Raised when a callback is needed but none has been registered."""


class UnexpectedFrameException(Exception):
    """Raised when a frame of one kind arrives while another kind was expected.

    :param client: The client that sent the frame; its ``addr`` and ``port``
        attributes are included in the error message.
    :param recv: The kind of frame that was received. Must expose ``name``.
    :param expect: The kind of frame that was expected. Must expose ``name``.

    :ivar received: The kind of frame that was received.
    :ivar recieved: Misspelled alias of ``received``, kept so existing callers
        that read the old attribute name keep working.
    :ivar expected: The kind of frame that was expected.
    :ivar client: The client that sent the frame.
    """

    def __init__(self, client, recv, expect):
        super().__init__(f"Received unexpected {recv.name.lower()} frame from client {client.addr, client.port}, "
                         f"expected {expect.name.lower()}.")

        self.received = recv
        # Backward-compatible alias for the original (misspelled) attribute.
        self.recieved = recv
        self.expected = expect
        self.client = client


class ConnectionClosed(Exception):
    """Signals that the connection is being closed in the middle of a message."""

    def __init__(self):
        message = "Closing connection in middle of message."
        super().__init__(message)


class Client:
    """
    Per-connection state and frame handlers for one connected client.

    The handler methods defined here are the ones registered, by opcode, in
    the module-level ``HANDLERS`` table.

    :ivar addr: IPv4 or IPv6 address of the client.
    :type addr: str
    :ivar port: The port the client opened its socket on.
    :type port: int
    :ivar writer: The writer used for writing frames to the client.
    :type writer: WebSocketWriter
    """

    def __init__(self, state, addr, port, writer, loop):
        """Set up connection state and install the default callbacks.

        :param state: Initial connection state (stored as ``self.state``).
        :param addr: Address of the client.
        :param port: Port of the client.
        :param writer: Object handed to :class:`WebSocketWriter` together
            with ``loop``.
        :param loop: Event loop used to schedule callbacks and helper tasks.
        """
        self.last_message = time.time()
        self.state = state
        self.addr = addr
        self.port = port
        self.data_type = DataType.NONE
        self.writer = WebSocketWriter(writer, loop)
        self._reader = None
        self.read_task = None
        # Data type of the fragmented message currently in progress, or
        # DataType.NONE when no fragmented message is being received.
        self.continuation = DataType.NONE
        self.server_has_initiated_close = False
        self._loop = loop

        # Default callbacks; users replace them through the decorators below.
        @self.message
        async def on_message(reader):
            raise NoCallbackException("No message callback defined.")

        @self.ping
        async def on_ping(payload, length):
            await self.writer.pong(length, payload)

        @self.pong
        async def on_pong(payload, length):
            pass

        @self.closed
        async def on_closed(code, reason):
            pass

    def message(self, fn):
        """Decorator for registering the on_message callback.

        :param fn: The callback to register.
        :return: ``fn`` itself, so the decorated name stays usable.

        The callback should be async and take one parameter, a
        :class:`~websocket.stream.reader.WebSocketReader`

        This callback is called when the server receives a valid data frame,
        if an exception occurs after the first valid frame e.g. if a text
        frame contains invalid utf-8, or if it's an invalid fragmented
        message, then we send the exception to the reader with
        :meth:`~websocket.stream.buffer.Buffer.set_exception`.

        >>> @client.message
        ... async def on_message(reader: WebSocketReader):
        ...     print("Got message " + await reader.get())
        """
        self.on_message = fn
        return fn

    def ping(self, fn):
        """Decorator for registering the on_ping callback.

        :param fn: The callback to register.
        :return: ``fn`` itself, so the decorated name stays usable.

        If you set this callback you will override the default behaviour of
        sending pongs back to the client when receiving pings. If you want to
        keep this behaviour call
        :meth:`~websocket.stream.writer.WebSocketWriter.pong`.

        The callback should be async and take two parameters, :class:`bytes`
        payload, and :class:`int` length.
        This callback is called when we receive a valid ping from the client.

        >>> @client.ping
        ... async def on_ping(payload: bytes, length: int):
        ...     print("Received ping from client")
        ...     await client.writer.pong(length, payload)
        """
        self.on_ping = fn
        return fn

    def pong(self, fn):
        """Decorator for registering the on_pong callback.

        :param fn: The callback to register.
        :return: ``fn`` itself, so the decorated name stays usable.

        The callback should be async and take two parameters, :class:`bytes`
        payload, and :class:`int` length.
        This callback is called when we receive a valid pong from the client.

        >>> @client.pong
        ... async def on_pong(payload: bytes, length: int):
        ...     print("Received pong from client")
        """
        self.on_pong = fn
        return fn

    def closed(self, fn):
        """Decorator for registering the on_closed callback.

        :param fn: The callback to register.
        :return: ``fn`` itself, so the decorated name stays usable.

        The callback should be async and take two parameters, :class:`bytes`
        code of length 2, and :class:`str` reason.
        This callback is called when the connection with this client is
        closing.

        >>> @client.closed
        ... async def on_closed(code: bytes, reason: str):
        ...     print("Connection with client is closing for " + reason)
        """
        self.on_closed = fn
        return fn

    async def close_with_read(self, reader, code, reason):
        """Start closing the connection while still consuming the current frame.

        The close is scheduled first, then the payload of the frame being
        handled is fed from ``reader`` into a throwaway buffer and read back.

        :param reader: Stream the current frame's payload is fed from.
        :param code: Close code passed on to :meth:`close`.
        :param reason: Close reason passed on to :meth:`close`.
        :return: The data read from the frame.
        """
        close = asyncio.ensure_future(self.close(code, reason), loop=self._loop)
        buffer = WebSocketReader(DataType.BINARY, self, self._loop)
        length = await buffer.feed(reader)
        buffer.done()
        # NOTE(review): "1"/"2" look like leftover trace markers; kept as-is.
        logger.debug("1")
        data = await buffer.read(length)
        logger.debug("2")
        await close
        return data

    async def close(self, code: bytes, reason: str):
        """Initiate the close from the server side, at most once.

        On the first call the ``on_closed`` callback is scheduled, the
        connection is marked as closing and a close frame is written. Later
        calls do nothing.

        NOTE(review): the nesting of this method was reconstructed from
        flattened text; confirm that the flag assignment and the write were
        meant to sit inside the guard.

        :param code: Close code handed to the callback and the writer.
        :param reason: Close reason handed to the callback and the writer.
        """
        if not self.server_has_initiated_close:
            asyncio.ensure_future(self.on_closed(code, reason), loop=self._loop)
            self.server_has_initiated_close = True
            await self.writer.close(code, reason)
            # TODO: Kill in 5 secs if client dont respond

    async def _read_message(self, reader, fin):
        """Feed one data frame into the current reader and track fragmentation.

        :param reader: Stream the frame's payload is fed from.
        :param fin: Whether this frame is the final one of its message.
        """
        await self._reader.feed(reader)
        if fin:
            self.continuation = DataType.NONE
            self._reader.done()
        else:
            # More fragments are expected; remember what kind of message it is.
            self.continuation = self._reader.data_type

    @staticmethod
    def handle_data(kind):
        """Build the handler for data frames of the given kind.

        :param kind: The :class:`DataType` the returned handler is for.
        :return: An async ``handler(self, reader, fin)`` function.
        """
        async def handler(self, reader, fin):
            if self.continuation != DataType.NONE:
                # A new message started while a fragmented one is unfinished.
                self._reader.set_exception(UnexpectedFrameException(self, kind, DataType.CONTINUATION))
                self._reader.done()
                await self.close_with_read(reader, Reasons.PROTOCOL_ERROR.value, "expected continuation frame")
                return

            logger.debug(f"Received {kind.name.lower()} data frame from client {self.addr, self.port}.")
            # NOTE(review): this sets ``type`` while __init__ initialises
            # ``data_type``; confirm which attribute is the intended one.
            self.type = kind
            self._reader = WebSocketReader(kind, self, self._loop)
            # The callback runs concurrently and consumes the reader as
            # frames are fed into it.
            self._loop.create_task(self.on_message(self._reader))
            return await self._read_message(reader, fin)

        return handler

    async def handle_continuation(self, reader, fin):
        """Handle a continuation frame.

        Closes the connection with a protocol error when no fragmented
        message is in progress; otherwise feeds the frame into the current
        message.

        :param reader: Stream the frame's payload is fed from.
        :param fin: Whether this frame is the final one of its message.
        """
        if self.continuation == DataType.NONE:
            # self.continuation is DataType.NONE on this path, and that is the
            # name interpolated into both messages below.
            logger.debug("Received unexpected continuation data frame from client "
                         f"{self.addr, self.port}, expected {self.continuation.name.lower()}.")
            await self.close_with_read(reader, Reasons.PROTOCOL_ERROR.value,
                                       f"expected {self.continuation.name.lower()} frame")
            return

        logger.debug(f"Received continuation frame from client {self.addr, self.port}.")
        await self._read_message(reader, fin)

    def ensure_clean_close(self):
        """Fail the in-progress fragmented message, if there is one.

        The current reader receives a :class:`ConnectionClosed` exception and
        is marked done. Nothing happens when no fragmented message is in
        progress (in which case ``self._reader`` may still be ``None``).
        """
        if self.continuation != DataType.NONE:
            self._reader.set_exception(ConnectionClosed())
            self._reader.done()

    @staticmethod
    def handle_ping_or_pong(kind):
        """Build the handler for ping or pong control frames.

        :param kind: ``DataType.PING`` or ``DataType.PONG``.
        :return: An async ``handler(self, reader, fin)`` function.
        """
        async def handler(self, reader, fin):
            buffer = WebSocketReader(DataType.BINARY, self, self._loop)
            # Start consuming the frame right away so it is read even on the
            # error paths below.
            feed = asyncio.ensure_future(buffer.feed_once(reader), loop=self._loop)

            if not fin or self.server_has_initiated_close:
                if not fin:
                    logger.warning(f"Received fragmented {kind.name.lower()} from client {self.addr, self.port}.")
                    self.ensure_clean_close()
                    await self.close(Reasons.PROTOCOL_ERROR.value, "fragmented control frame")
                else:
                    logger.warning(f"Received {kind.name.lower()} from client {self.addr, self.port} after server "
                                   "initiated close.")
                    self.ensure_clean_close()
                    await self.close(Reasons.POLICY_VIOLATION.value, "control frame after close")
                await feed
                return

            length = await feed
            # Control frame payloads longer than 125 bytes are rejected.
            if length > 125:
                logger.warning(f"{kind.name.lower()} payload too long({length} bytes).")
                self.ensure_clean_close()
                await self.close(Reasons.PROTOCOL_ERROR.value, "control frame too long")
                return

            logger.debug(f"Received {kind.name.lower()} from client {self.addr, self.port}.")
            data = await buffer.read(length)
            if kind is DataType.PING:
                self._loop.create_task(self.on_ping(data, length))
            elif kind is DataType.PONG:
                self._loop.create_task(self.on_pong(data, length))
            buffer.done()

        return handler

    async def handle_close(self, reader, fin):
        """Handle a close frame sent by the client.

        If the server has not initiated a close yet, the close is answered
        via :meth:`close`. In every case the state becomes
        ``State.CLOSING`` and the read task, if any, is cancelled.

        :param reader: Stream the frame's payload is fed from.
        :param fin: Unused.
        """
        logger.debug(f"Received close from client {self.addr, self.port}.")
        buffer = WebSocketReader(DataType.BINARY, self, self._loop)
        length = await buffer.feed_once(reader)
        reason = await buffer.read(length)

        if not self.server_has_initiated_close:
            if length > WebSocketWriter.MAX_LEN_7:
                code, reason = Reasons.PROTOCOL_ERROR.value, "control frame too long"
            else:
                code, reason = Reason.from_bytes(reason, length)
                # A close frame without a status code is answered as normal.
                if code == Reasons.NO_STATUS.value:
                    code = Reasons.NORMAL.value
            self.ensure_clean_close()
            await self.close(code, reason)

        self.state = State.CLOSING
        if self.read_task is not None:
            self.read_task.cancel()

    async def handle_undefined(self, reader, fin):
        """Handle a frame with an opcode that has no dedicated handler.

        :param reader: Stream the frame's payload is fed from.
        :param fin: Unused.
        """
        logger.debug(f"Received invalid opcode from client {self.addr, self.port}.")
        await self.close_with_read(reader, Reasons.PROTOCOL_ERROR.value, "invalid opcode")

    def tick(self):
        """Record the current time as the moment of the last message."""
        self.last_message = time.time()
# Opcode -> frame handler table. Every 4-bit opcode starts out mapped to
# Client.handle_undefined; the opcodes with a dedicated handler are then
# overridden.
HANDLERS = dict.fromkeys(range(0, 1 << 4), Client.handle_undefined)
HANDLERS[DataType.CONTINUATION.value] = Client.handle_continuation
HANDLERS[DataType.TEXT.value] = Client.handle_data(DataType.TEXT)
HANDLERS[DataType.BINARY.value] = Client.handle_data(DataType.BINARY)
HANDLERS[DataType.CLOSE.value] = Client.handle_close
HANDLERS[DataType.PING.value] = Client.handle_ping_or_pong(DataType.PING)
HANDLERS[DataType.PONG.value] = Client.handle_ping_or_pong(DataType.PONG)