"""WebSocket server and broadcast helpers for hbd. Provides an asyncio-based WebSocket server and a thread-safe broadcast function that other threads or synchronous code can call. """ import asyncio import json import logging from typing import Callable, Iterable, Optional import websockets logger = logging.getLogger(__name__) _connections = set() _loop: Optional[asyncio.AbstractEventLoop] = None _get_hosts: Optional[Callable[[], Iterable]] = None _get_msgs: Optional[Callable[[], Iterable]] = None _verbose = False async def _handler(websocket, path=None): _connections.add(websocket) remote_address = websocket.remote_address if path is None: path = getattr(websocket, "path", None) if _verbose: logger.info("DBG ws_serve: %s: %s", remote_address, path) try: # send initial hosts if _get_hosts: for h in _get_hosts(): jmsg = json.dumps({"type": "host", "data": h}) await websocket.send(jmsg) # send recent messages if _get_msgs: for m in list(_get_msgs())[-100:]: jmsg = json.dumps({"type": "message", "data": m}) await websocket.send(jmsg) # keep connection open until client disconnects async for _ in websocket: # we don't expect meaningful incoming messages besides the initial # client 'hello' that some clients send; ignore for now if _verbose: logger.debug("received ws data: %s", _) except ( websockets.exceptions.ConnectionClosedOK, websockets.exceptions.ConnectionClosedError, ) as e: if _verbose: logger.info("ws closed: %r", e) except Exception as e: logger.exception("ws handler exception: %s", e) finally: try: _connections.remove(websocket) except KeyError: pass await websocket.wait_closed() async def start( host: str, ws_port: int, wss_port: Optional[int] = None, ssl_context=None, get_hosts: Optional[Callable] = None, get_msgs: Optional[Callable] = None, verbose: bool = False, ): """Start WebSocket servers and block until cancelled. This is intended to be awaited inside the main asyncio event loop. If `wss_port` and `ssl_context` are provided, a WSS server will also be started. """ global _loop, _get_hosts, _get_msgs, _verbose _loop = asyncio.get_running_loop() _get_hosts = get_hosts _get_msgs = get_msgs _verbose = verbose servers = [] # plain WebSocket websockets_logger = logging.getLogger("websockets.server") websockets_logger.setLevel(logging.DEBUG if verbose else logging.INFO) # regular WebSocket ws_server = websockets.serve(_handler, host, ws_port) # , subprotocols=["hbd"]) servers.append(ws_server) # secure WebSocket (optional) if wss_port and ssl_context: wss_server = websockets.serve( _handler, host, wss_port, ssl=ssl_context ) # , subprotocols=["hbd"]) servers.append(wss_server) # await starting of all servers for srv in servers: await srv if _verbose: logger.info( "WebSocket server(s) started on port %s (wss %s)", ws_port, wss_port ) # block forever (until loop is stopped or cancelled) await asyncio.Future() def broadcast(typ: str, data) -> bool: """Thread-safe broadcast helper. Schedules coroutine(s) on the running loop to send message to all connected websockets. Returns False if server was not running. """ if not _loop: return False jmsg = json.dumps({"type": typ, "data": data}) to_close = [] for ws in list(_connections): if ws.state != websockets.protocol.State.OPEN: to_close.append(ws) continue try: asyncio.run_coroutine_threadsafe(ws.send(jmsg), _loop) except Exception: to_close.append(ws) logger.debug("ws.send exception: closed") for ws in to_close: try: asyncio.run_coroutine_threadsafe(ws.wait_closed(), _loop) except Exception: pass if ws in _connections: _connections.remove(ws) return True def connection_count() -> int: return len(_connections)