"""WebSocket handler and broadcast helpers for hbd.

WebSocket connections are served through the regular HTTP port via the /ws
route registered in http.py (aiohttp WebSocketResponse upgrade). The separate
standalone WebSocket server on ws_port is no longer used.
"""

import asyncio
import json
import logging
from typing import Callable, Iterable, Optional

from . import data

logger = logging.getLogger(__name__)

# Currently open WebSocketResponse objects; added/removed by handler(),
# pruned by _send_all() when a send fails or the socket reports closed.
_connections: set = set()
# Event loop registered via setup(); broadcast() schedules sends onto it.
_loop: Optional[asyncio.AbstractEventLoop] = None
# Optional callback returning the current hosts, replayed to new clients.
_get_hosts: Optional[Callable[[], Iterable]] = None
# When True, incoming text frames are logged at DEBUG level.
_verbose: bool = False


def setup(
    loop: asyncio.AbstractEventLoop,
    get_hosts: Optional[Callable[[], Iterable]] = None,
    verbose: bool = False,
):
    """Register the running loop and initial-state callback.

    Call this once from _run_async before starting the HTTP server.

    Args:
        loop: Event loop that broadcast() schedules sends onto.
        get_hosts: Optional callable returning the current hosts; each one is
            sent to a newly connected client as a ``"host"`` message.
        verbose: Log every incoming text frame at DEBUG level.
    """
    global _loop, _get_hosts, _verbose
    _loop = loop
    _get_hosts = get_hosts
    _verbose = verbose


async def _send_initial_state(ws) -> None:
    """Replay current hosts and recent messages to a freshly connected client.

    Host and message replay are guarded separately so a failure in one does
    not prevent the other from being sent.
    """
    if _get_hosts:
        try:
            # Snapshot first so the source can change while we await sends.
            for h in list(_get_hosts()):
                await ws.send_str(json.dumps({"type": "host", "data": h}))
        except Exception as e:
            logger.error("Error sending initial hosts: %s", e)

    msgs = data.msgs
    if msgs:
        try:
            # Iterate a snapshot: data.msgs may be appended to from other
            # threads while this coroutine is suspended in send_str().
            for m in list(msgs):
                await ws.send_str(json.dumps({"type": "message", "data": m}))
        except Exception as e:
            logger.error("Error sending initial messages: %s", e)


async def handler(request):
    """aiohttp WebSocket upgrade handler — register as GET /ws.

    Upgrades the request, registers the socket for broadcasts, replays the
    current state, then keeps the connection open, ignoring incoming frames
    (optionally logging text frames when verbose). The socket is always
    unregistered on exit.
    """
    # Imported lazily, as in the original; hoisted out of the receive loop so
    # the import statement is not executed per frame.
    from aiohttp import WSMsgType, web

    ws = web.WebSocketResponse()
    await ws.prepare(request)
    _connections.add(ws)
    remote = request.remote
    logger.info("WebSocket connected from %s", remote)
    try:
        await _send_initial_state(ws)

        # Keep connection open, ignore incoming frames
        async for msg in ws:
            if msg.type == WSMsgType.TEXT:
                if _verbose:
                    logger.debug("ws recv from %s: %s", remote, msg.data)
            elif msg.type in (WSMsgType.ERROR, WSMsgType.CLOSE):
                break
    except Exception as e:
        logger.exception("WebSocket handler error from %s: %s", remote, e)
    finally:
        _connections.discard(ws)
        logger.info("WebSocket disconnected from %s", remote)
    return ws


async def _send_all(jmsg: str) -> None:
    """Send an already-serialized message to every registered client.

    Clients that report closed, or whose send raises, are removed from the
    registry after the pass.
    """
    dead = set()
    for ws in list(_connections):
        try:
            if not ws.closed:
                await ws.send_str(jmsg)
            else:
                dead.add(ws)
        except Exception:
            dead.add(ws)
    for ws in dead:
        _connections.discard(ws)


def broadcast(typ: str, payload) -> bool:
    """Thread-safe broadcast to all connected WebSocket clients.

    Can be called from any thread; schedules sends on the event loop.

    Args:
        typ: Value for the message's ``"type"`` field.
        payload: JSON-serializable value for the message's ``"data"`` field.

    Returns:
        True if the send was scheduled; False if no loop has been registered
        yet or the registered loop is closed.

    Raises:
        TypeError: If ``payload`` is not JSON-serializable (raised here, in
            the caller's thread, before anything is scheduled).
    """
    # Snapshot the global so a concurrent setup() can't swap it mid-call.
    loop = _loop
    if loop is None or loop.is_closed():
        return False
    jmsg = json.dumps({"type": typ, "data": payload})
    coro = _send_all(jmsg)
    try:
        asyncio.run_coroutine_threadsafe(coro, loop)
    except RuntimeError:
        # The loop closed between the check above and scheduling; close the
        # coroutine so it doesn't emit a "never awaited" warning.
        coro.close()
        return False
    return True


def connection_count() -> int:
    """Return the number of currently registered WebSocket connections."""
    return len(_connections)