"""WebSocket server and broadcast helpers for hbd. Provides an asyncio-based WebSocket server and a thread-safe broadcast function that other threads or synchronous code can call. """ import asyncio import json import logging from typing import Callable, Iterable, Optional from . import data import websockets logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) _connections = set() _loop: Optional[asyncio.AbstractEventLoop] = None _get_hosts: Optional[Callable[[], Iterable]] = None #_get_msgs: Optional[Callable[[], Iterable]] = None _verbose = False async def _handler(websocket, path=None): _connections.add(websocket) remote_address = websocket.remote_address if path is None: path = getattr(websocket, "path", None) logger.info("WebSocket connection from %s: %s", remote_address, path) try: # send initial hosts if _get_hosts: try: hosts = list(_get_hosts()) logger.debug("Sending %d hosts to new WebSocket client", len(hosts)) for h in hosts: jmsg = json.dumps({"type": "host", "data": h}) await websocket.send(jmsg) except Exception as e: logger.error("Error sending initial hosts: %s", e, exc_info=True) # send recent messages if data.msgs: try: # msgs = list(_get_msgs())[-100:] logger.debug("Sending %d recent messages to new WebSocket client", len(data.msgs)) for m in data.msgs: jmsg = json.dumps({"type": "message", "data": m}) await websocket.send(jmsg) except Exception as e: logger.error("Error sending initial messages: %s", e, exc_info=True) # keep connection open until client disconnects async for _ in websocket: # we don't expect meaningful incoming messages besides the initial # client 'hello' that some clients send; ignore for now if _verbose: logger.debug("received ws data: %s", _) except ( websockets.exceptions.ConnectionClosedOK, websockets.exceptions.ConnectionClosedError, ) as e: logger.info("WebSocket closed from %s: %r", remote_address, e) except Exception as e: logger.exception("WebSocket handler exception from %s: %s", remote_address, e) finally: logger.debug("Removing WebSocket connection from %s", remote_address) _connections.discard(websocket) async def start( host: str, ws_port: int, wss_port: Optional[int] = None, ssl_context=None, get_hosts: Optional[Callable] = None, # get_msgs: Optional[Callable] = None, config: dict = {}, ): """Start WebSocket servers and block until cancelled. This is intended to be awaited inside the main asyncio event loop. If `wss_port` and `ssl_context` are provided, a WSS server will also be started. """ global _loop, _get_hosts, _verbose _loop = asyncio.get_running_loop() _get_hosts = get_hosts _verbose = config.get("verbose", False), _debug = config.get("debug", 0), # Start servers and keep the server objects for clean shutdown running_servers = [] ws_server = await websockets.serve(_handler, host, ws_port) running_servers.append(ws_server) if wss_port and ssl_context: wss_server = await websockets.serve(_handler, host, wss_port, ssl=ssl_context) running_servers.append(wss_server) logger.info( "WebSocket server(s) started on port %s (wss %s)", ws_port, wss_port ) try: # Block until cancelled await asyncio.Future() except asyncio.CancelledError: pass finally: # Close all active browser connections so their handler coroutines exit active = list(_connections) if active: logger.info("Closing %d active WebSocket connection(s)...", len(active)) await asyncio.gather( *[ws.close() for ws in active], return_exceptions=True, ) # Stop the listening servers and wait for all handlers to finish for srv in running_servers: srv.close() await asyncio.gather( *[srv.wait_closed() for srv in running_servers], return_exceptions=True, ) logger.info("WebSocket server(s) stopped") def broadcast(typ: str, data) -> bool: """Thread-safe broadcast helper. Schedules coroutine(s) on the running loop to send message to all connected websockets. Returns False if server was not running. """ if not _loop: return False jmsg = json.dumps({"type": typ, "data": data}) to_close = [] for ws in list(_connections): if ws.state != websockets.protocol.State.OPEN: to_close.append(ws) continue try: asyncio.run_coroutine_threadsafe(ws.send(jmsg), _loop) except Exception: to_close.append(ws) logger.debug("ws.send exception: closed") for ws in to_close: try: asyncio.run_coroutine_threadsafe(ws.wait_closed(), _loop) except Exception: pass if ws in _connections: _connections.remove(ws) return True def connection_count() -> int: return len(_connections)