Source code for websocket.server

"""Usage

>>> loop = asyncio.get_event_loop()
>>> socket = WebSocketServer("localhost", 3001, loop=loop)
...
>>> @socket.connection
... async def on_connection(client: Client):
...     logger.info(f'Connection from {client.addr, client.port}')
...     logger.info(f'All clients: {socket.clients}')
...
...     @client.message
...     async def on_message(reader: WebSocketReader):
...         await client.writer.send(await reader.get())
...
>>> with socket as server:
...     print(f'Serving on {server.sockets[0].getsockname()}')
...     loop.run_forever()
...
>>> loop.close()
"""

import asyncio
import logging
import ssl

import time

from .client import Client, HANDLERS
from .enums import State
from .http import handshake
from .reasons import Reasons
from .stream.reader import WebSocketReader

logger = logging.getLogger(__name__)


class WebSocketServer:
    """Asyncio WebSocket server that performs the handshake and tracks clients.

    Use it as a context manager: entering starts listening (and the heartbeat
    task when a timeout is configured), exiting disconnects every client and
    closes the listening server.

    :ivar addr: The server IPv4 or IPv6 address.
    :type addr: str
    :ivar port: The server port.
    :type port: int
    :ivar client_timeout: The timeout in seconds for clients, if they don't
        respond to pings. Set to 0 to disable heartbeat.
    :type client_timeout: float
    :ivar certs: SSL certificates
    :type certs: (certfile, keyfile)
    :ivar clients: All of the connected clients.
    :type clients: {(str, int): Client}
    :ivar loop: The event loop to run in.
    :type loop: AbstractEventLoop
    :ivar server: The listening server, ``None`` until the context is entered.
    :ivar keepalive_task: The heartbeat task, ``None`` when heartbeat is off.
    """

    NEWLINE = b'\r\n'

    def __init__(self, addr, port, certs=None, loop=None, timeout=120):
        """Store the configuration; nothing is started until ``__enter__``.

        :param addr: The address to listen on.
        :param port: The port to listen on.
        :param certs: Optional ``(certfile, keyfile)`` pair enabling TLS.
        :param loop: The event loop to use, defaults to the current one.
        :param timeout: Client heartbeat timeout in seconds, 0 disables it.
        """
        if loop is None:
            self.loop = asyncio.get_event_loop()
        else:
            self.loop = loop

        self._on_connection = None
        self.addr = addr
        self.port = port
        self.server = None
        self.certs = certs
        self.client_timeout = timeout
        self.clients = {}
        self.keepalive_task = None

    async def keepalive(self, timeout):
        """Periodically ping idle clients and drop non-responsive ones.

        Runs until cancelled. Every ``timeout / 2`` seconds each client is
        checked: one silent for longer than ``timeout`` is disconnected, one
        silent for longer than half of it is pinged.

        :param timeout: Seconds of silence after which a client is dropped.
        """
        # Keep the interval a float: truncating to int would turn timeouts
        # below two seconds into sleep(0), i.e. a busy loop.
        half = timeout / 2
        try:
            while True:
                await asyncio.sleep(half)
                logger.info("Sending heartbeats")
                cur_time = time.time()
                # Iterate over a snapshot, clients may be removed concurrently.
                for client in list(self.clients.values()):
                    diff = cur_time - client.last_message
                    if diff > timeout:
                        logger.warning(f"Cleaning up non-responsive client {client.addr, client.port}")
                        # disconnect_client is a coroutine; schedule it so it
                        # actually runs without stalling the other heartbeats.
                        self.loop.create_task(self.disconnect_client(
                            client,
                            code=Reasons.POLICY_VIOLATION.value.code,
                            reason='Client did not respond to heartbeat.'))
                    elif diff > half:
                        # NOTE(review): if writer.ping is a coroutine function
                        # this call needs to be awaited/scheduled - confirm.
                        client.writer.ping('heartbeat')
        except asyncio.CancelledError:
            # Cancellation is the normal way to stop the heartbeat.
            pass

    def __enter__(self):
        """Start the server when entering the context manager.

        :return: The listening server returned by ``asyncio.start_server``.
        """
        context = None
        if self.certs:
            crt, key = self.certs
            # CLIENT_AUTH is the purpose for server-side sockets; SERVER_AUTH
            # builds a client-side context that cannot accept connections.
            context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
            context.load_cert_chain(certfile=crt, keyfile=key)

        # No explicit loop argument (removed in Python 3.10): the coroutine
        # runs on self.loop and binds to it as the running loop.
        coro = asyncio.start_server(self.socket_connect, self.addr, self.port, ssl=context)
        self.server = self.loop.run_until_complete(coro)

        if self.client_timeout > 0:
            self.keepalive_task = self.loop.create_task(self.keepalive(self.client_timeout))

        return self.server

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Stop server when exiting context manager"""
        if self.keepalive_task is not None:
            self.keepalive_task.cancel()

        self.loop.run_until_complete(self.disconnect_all())
        self.server.close()
        self.loop.run_until_complete(self.server.wait_closed())

    def connection(self, fn):
        """Decorator for registering the on_connection callback.

        :param fn: The callback to register.

        The callback should be async and take one parameter,
        :class:`~websocket.client.Client`. This callback is called when a new
        client connects with the websocket.

        >>> @socket.connection
        ... async def on_connection(client: Client):
        ...     for other_client in socket.clients.values():
        ...         other_client.writer.send("New client connected.")
        ...
        ...     @client.message
        ...     async def on_message(reader: WebSocketReader):
        ...         await client.writer.send(await reader.get())

        :return: ``fn`` itself, so the decorated name stays usable.
        """
        self._on_connection = fn
        return fn

    async def connect_client(self, client):
        """Run the on_connection callback, then register the client.

        The client is added to :attr:`clients` only after the callback has
        finished, so the callback sees just the other clients.

        :param client: The client that completed the handshake.
        """
        if self._on_connection is not None:
            await self._on_connection(client)
        self.clients[client.addr, client.port] = client

    async def disconnect_all(self, timeout=1):
        """Disconnect every registered client, cancelling stragglers.

        :param timeout: Seconds to wait for the disconnects to complete.
        """
        tasks = [self.loop.create_task(self.disconnect_client(client))
                 for client in self.clients.values()]
        if not tasks:
            # asyncio.wait raises ValueError when given nothing to wait for.
            return

        done, pending = await asyncio.wait(tasks, timeout=timeout)

        for task in done:
            # Retrieve failures so they are reported instead of being dropped.
            if not task.cancelled() and task.exception() is not None:
                logger.warning(f"Error while disconnecting a client: {task.exception()!r}")

        number_pending = len(pending)
        if number_pending > 0:
            logger.warning(f"{number_pending} futures failed to disconnect in {timeout} second(s), cancelling them.")
            for future in pending:
                future.cancel()

    async def disconnect_client(self, client, code=Reasons.NORMAL.value.code, reason=''):
        """This method is the only clean way to close a connection with a client.

        >>> @socket.connection
        ... async def on_connection(client: Client):
        ...     print("Client connected, disconnecting it...")
        ...     await socket.disconnect_client(client)

        :param client: The client to disconnect from the server.
        :param code: The code to close the connection with, make sure it is
            valid. Default is :attr:`websocket.reasons.Reasons.NORMAL.value.code`
        :type code: bytes
        :param reason: The reason for closing the connection, may be ''.
            Should not be longer than 123 characters.
        :type reason: str
        """
        await client.close(code, reason)
        # The client may not be registered yet (disconnect from within the
        # on_connection callback) or may already be gone; tolerate both.
        self.delete_client(client.addr, client.port)

    def delete_client(self, addr, port):
        """Forget the client at ``(addr, port)``; a no-op if it is unknown.

        :param addr: The client address.
        :param port: The client port.
        """
        self.clients.pop((addr, port), None)

    async def socket_connect(self, reader, writer):
        """Serve one incoming connection: handshake, then dispatch frames.

        This is the callback handed to ``asyncio.start_server``. After a
        successful handshake it reads the first byte of every frame, rejects
        frames with RSV bits set and passes the rest to the handler registered
        for the opcode. The transport is always closed and the client
        forgotten when this coroutine ends.

        :param reader: The stream reader of the connection.
        :param writer: The stream writer of the connection.
        """
        addr, port, *_ = writer.get_extra_info('peername')
        try:
            logger.debug(f"Client {addr, port} connected, attempting handshake.")
            state = await self.handle_handshake(reader, writer)
            if state != State.OPEN:
                logger.warning(f"Handshake with client {addr, port} failed.")
                return

            logger.debug(f"Handshake with client {addr, port} successful.")
            client = Client(state, addr, port, writer, self.loop)
            self.loop.create_task(self.connect_client(client))

            while client.state != State.CLOSING:
                # The read is exposed as a task so it can be cancelled from
                # elsewhere to make this loop re-check the client state.
                client.read_task = self.loop.create_task(reader.readexactly(1))
                try:
                    data = (await client.read_task)[0]
                    client.tick()

                    if data & WebSocketReader.RSV_BITS > 0:
                        logger.warning("No extension defining RSV meaning has been negotiated")
                        # NOTE(review): not awaited in the original either -
                        # confirm ensure_clean_close is a plain method.
                        client.ensure_clean_close()
                        # NOTE(review): passes Reasons.PROTOCOL_ERROR.value,
                        # not .value.code as elsewhere - confirm the signature.
                        await client.close_with_read(reader, Reasons.PROTOCOL_ERROR.value, "RSV bit(s) set")
                        continue

                    # Find the correct handler based on the opcode
                    # Then pass it its arguments, including the fin flag
                    handler = HANDLERS[data & WebSocketReader.OP_CODE_BITS]
                    await handler(client, reader, (data & WebSocketReader.FIN_BIT) != 0)
                except asyncio.CancelledError:
                    continue  # Someone has cancelled the read task, check for new state
        except ConnectionResetError:
            logger.warning(f"Client {addr, port} has forcibly closed the connection.")
        except asyncio.IncompleteReadError:
            # readexactly hit EOF: the peer went away without a close frame.
            logger.warning(f"Client {addr, port} closed the connection unexpectedly.")
        finally:
            logger.debug(f"Closing connection with client {addr, port}.")
            writer.close()
            self.delete_client(addr, port)

    def wait(self, fut, timeout):
        """Helper method for creating a future that times out after a timeout.

        :param fut: The future to time.
        :param timeout: The timeout in seconds.
        :return: future
        """
        return asyncio.wait_for(fut, timeout=timeout)

    async def handle_handshake(self, reader, writer):
        """Read the HTTP upgrade request and answer it.

        Every line must arrive within one second. On success the handshake
        response is written, otherwise the bad-request response is.

        :param reader: The stream reader of the connection.
        :param writer: The stream writer of the connection.
        :return: :attr:`State.OPEN` on success, :attr:`State.CLOSING` otherwise.
        """
        try:
            request_line = await self.wait(reader.readuntil(WebSocketServer.NEWLINE), 1)
            request = handshake.Request(request_line.decode())

            header = await self.wait(reader.readuntil(WebSocketServer.NEWLINE), 1)
            # A bare CRLF terminates the header section.
            while header != WebSocketServer.NEWLINE:
                request.header(header.decode())
                header = await self.wait(reader.readuntil(WebSocketServer.NEWLINE), 1)

            response = request.validate_websocket_request()
            writer.write(response)
            return State.OPEN
        except handshake.BadRequestException:
            logger.exception("Failed to handshake.")
            writer.write(handshake.BadRequestException.RESPONSE)
            return State.CLOSING
        except (asyncio.IncompleteReadError, asyncio.LimitOverrunError,
                asyncio.TimeoutError, UnicodeDecodeError):
            # Over-long or undecodable lines are not a valid HTTP request
            # either, so they get the same bad-request answer.
            logger.exception("Could not read stream, not http. Possibly https?.")
            writer.write(handshake.BadRequestException.RESPONSE)
            return State.CLOSING
        finally:
            await writer.drain()