"""WebSocket server and broadcast helpers for hbd. Provides an asyncio-based WebSocket server and a thread-safe broadcast function that other threads or synchronous code can call. """ import asyncio import json import logging from typing import Callable, Iterable, Optional import websockets logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) _connections = set() _loop: Optional[asyncio.AbstractEventLoop] = None _get_hosts: Optional[Callable[[], Iterable]] = None _get_msgs: Optional[Callable[[], Iterable]] = None _verbose = False async def _handler(websocket, path=None): _connections.add(websocket) remote_address = websocket.remote_address if path is None: path = getattr(websocket, "path", None) logger.info("WebSocket connection from %s: %s", remote_address, path) try: # send initial hosts if _get_hosts: try: hosts = list(_get_hosts()) logger.debug("Sending %d hosts to new WebSocket client", len(hosts)) for h in hosts: jmsg = json.dumps({"type": "host", "data": h}) await websocket.send(jmsg) except Exception as e: logger.error("Error sending initial hosts: %s", e, exc_info=True) # send recent messages if _get_msgs: try: msgs = list(_get_msgs())[-100:] logger.debug("Sending %d recent messages to new WebSocket client", len(msgs)) for m in msgs: jmsg = json.dumps({"type": "message", "data": m}) await websocket.send(jmsg) except Exception as e: logger.error("Error sending initial messages: %s", e, exc_info=True) # keep connection open until client disconnects async for _ in websocket: # we don't expect meaningful incoming messages besides the initial # client 'hello' that some clients send; ignore for now if _verbose: logger.debug("received ws data: %s", _) except ( websockets.exceptions.ConnectionClosedOK, websockets.exceptions.ConnectionClosedError, ) as e: logger.info("WebSocket closed from %s: %r", remote_address, e) except Exception as e: logger.exception("WebSocket handler exception from %s: %s", remote_address, e) finally: logger.debug("Removing WebSocket connection from %s", remote_address) try: _connections.remove(websocket) except KeyError: pass await websocket.wait_closed() async def start( host: str, ws_port: int, wss_port: Optional[int] = None, ssl_context=None, get_hosts: Optional[Callable] = None, get_msgs: Optional[Callable] = None, config: dict = {}, ): """Start WebSocket servers and block until cancelled. This is intended to be awaited inside the main asyncio event loop. If `wss_port` and `ssl_context` are provided, a WSS server will also be started. """ global _loop, _get_hosts, _get_msgs, _verbose _loop = asyncio.get_running_loop() _get_hosts = get_hosts _get_msgs = get_msgs _verbose = config.get("verbose", False), _debug = config.get("debug", False), servers = [] # plain WebSocket websockets_logger = logging.getLogger("websockets.server") websockets_logger.setLevel(logging.DEBUG if _debug > 2 else logging.INFO) # regular WebSocket ws_server = websockets.serve(_handler, host, ws_port) # , subprotocols=["hbd"]) servers.append(ws_server) # secure WebSocket (optional) if wss_port and ssl_context: wss_server = websockets.serve( _handler, host, wss_port, ssl=ssl_context ) # , subprotocols=["hbd"]) servers.append(wss_server) # await starting of all servers for srv in servers: await srv logger.info( "WebSocket server(s) started on port %s (wss %s)", ws_port, wss_port ) # block forever (until loop is stopped or cancelled) await asyncio.Future() def broadcast(typ: str, data) -> bool: """Thread-safe broadcast helper. Schedules coroutine(s) on the running loop to send message to all connected websockets. Returns False if server was not running. """ if not _loop: return False jmsg = json.dumps({"type": typ, "data": data}) to_close = [] for ws in list(_connections): if ws.state != websockets.protocol.State.OPEN: to_close.append(ws) continue try: asyncio.run_coroutine_threadsafe(ws.send(jmsg), _loop) except Exception: to_close.append(ws) logger.debug("ws.send exception: closed") for ws in to_close: try: asyncio.run_coroutine_threadsafe(ws.wait_closed(), _loop) except Exception: pass if ws in _connections: _connections.remove(ws) return True def connection_count() -> int: return len(_connections)