"""WebSocket handler and broadcast helpers for hbd.

WebSocket connections are served through the regular HTTP port via the
/ws route registered in http.py (aiohttp WebSocketResponse upgrade).
The separate standalone WebSocket server on ws_port is no longer used.
"""

import asyncio
import json
import logging
from typing import Callable, Iterable, Optional

from . import data

logger = logging.getLogger(__name__)

# Map of WebSocket → User object. The value is None when auth is disabled,
# or when the client presented no valid session token.
_connections: dict = {}
_loop: Optional[asyncio.AbstractEventLoop] = None
_get_hosts: Optional[Callable[[], Iterable]] = None
_verbose: bool = False

# Broadcast message types that are delivered only to clients allowed to see
# the host named in the payload.
_HOST_SCOPED_TYPES = ("host", "plugin")


def setup(
    loop: asyncio.AbstractEventLoop,
    get_hosts: Optional[Callable[[], Iterable]] = None,
    verbose: bool = False,
):
    """Register the running loop and initial-state callback.

    Call this once from _run_async before starting the HTTP server.

    Args:
        loop: Event loop on which broadcast() schedules its sends.
        get_hosts: Optional callable returning the current host records;
            each new connection is sent one ``host`` message per record the
            connected user may see.
        verbose: When True, incoming text frames are logged at DEBUG level.
    """
    global _loop, _get_hosts, _verbose
    _loop = loop
    _get_hosts = get_hosts
    _verbose = verbose


def _user_can_see_host(user, host_name: str) -> bool:
    """Return True if *user* may see updates for *host_name*.

    With user accounts disabled, everyone may see every host. Otherwise the
    user must be an admin, or a manager of the named host. A connection
    without an authenticated user is denied while accounts are enabled.
    """
    from . import hbdclass, users as users_mod

    if not users_mod.users_enabled():
        return True
    if user is None:
        # Auth is on but the client has no valid session: fail closed rather
        # than exposing every host to anonymous connections.
        return False
    if user.admin:
        return True
    host = hbdclass.Host.hosts.get(host_name)
    if host is None:
        return False
    return host.is_manager(user.username)


def _get_token(request) -> str:
    """Extract session token from request (mirrors logic in http.py).

    Checks, in order: an ``Authorization: Bearer`` header, an
    ``X-Auth-Token`` header, then the ``hbd_session`` cookie. Returns ""
    when none is present.
    """
    auth = request.headers.get("Authorization", "")
    if auth.startswith("Bearer "):
        return auth[7:].strip()
    token = request.headers.get("X-Auth-Token", "")
    if token:
        return token
    return request.cookies.get("hbd_session", "")


async def handler(request):
    """aiohttp WebSocket upgrade handler — register as GET /ws.

    Upgrades the request and registers the socket together with the session
    user resolved from the request token. It then sends the current host
    state (only hosts the user may see) and the recent messages in
    data.msgs, and keeps the socket open until the client disconnects.
    Incoming text frames are ignored apart from optional debug logging.
    """
    from aiohttp import WSMsgType, web

    from . import users as users_mod

    ws = web.WebSocketResponse()
    await ws.prepare(request)

    token = _get_token(request)
    user = users_mod.get_session_user(token) if token else None
    _connections[ws] = user
    remote = request.remote
    logger.info("WebSocket connected from %s", remote)

    try:
        # Send current host state, filtered to hosts this user may see
        if _get_hosts:
            try:
                for h in list(_get_hosts()):
                    host_name = h.get("raw_name") or h.get("name", "")
                    if _user_can_see_host(user, host_name):
                        await ws.send_str(json.dumps({"type": "host", "data": h}))
            except Exception as e:
                logger.error("Error sending initial hosts: %s", e)

        # Send recent messages. Iterate a snapshot: sends await, and the
        # underlying container may be appended to while we are suspended.
        if data.msgs:
            try:
                for m in list(data.msgs):
                    await ws.send_str(json.dumps({"type": "message", "data": m}))
            except Exception as e:
                logger.error("Error sending initial messages: %s", e)

        # Keep connection open, ignore incoming frames
        async for msg in ws:
            if msg.type == WSMsgType.TEXT:
                if _verbose:
                    logger.debug("ws recv from %s: %s", remote, msg.data)
            elif msg.type in (WSMsgType.ERROR, WSMsgType.CLOSE):
                break
    except Exception as e:
        logger.exception("WebSocket handler error from %s: %s", remote, e)
    finally:
        _connections.pop(ws, None)
        logger.info("WebSocket disconnected from %s", remote)

    return ws


def broadcast(typ: str, payload) -> bool:
    """Thread-safe broadcast to all connected WebSocket clients.

    For host and plugin updates, only sends to clients whose user has
    manager-or-higher access to that host. This filtering is applied even
    when the payload names no host, so such updates never leak to
    unauthorized clients. Other message types are broadcast to all clients.

    Can be called from any thread; schedules sends on the event loop.

    Returns:
        False if setup() has not registered a loop yet or the loop is
        closed. True once the send has been scheduled; delivery itself
        happens asynchronously.
    """
    loop = _loop
    if loop is None or loop.is_closed():
        return False

    restricted = typ in _HOST_SCOPED_TYPES
    host_name = ""
    if restricted:
        # "" for an unnamed payload: never matches a real host, so only
        # admins (or everyone, when auth is disabled) receive it.
        host_name = (
            payload.get("raw_name")
            or payload.get("host")
            or payload.get("name")
            or ""
        )
    jmsg = json.dumps({"type": typ, "data": payload})

    async def _send_all():
        dead = set()
        for ws, user in list(_connections.items()):
            if ws.closed:
                dead.add(ws)
                continue
            if restricted:
                try:
                    allowed = _user_can_see_host(user, host_name)
                except Exception:
                    # A failed access check denies this message only; it
                    # must not unregister an otherwise healthy client.
                    logger.exception("Access check failed for host %r", host_name)
                    continue
                if not allowed:
                    continue
            try:
                await ws.send_str(jmsg)
            except Exception:
                dead.add(ws)
        for ws in dead:
            _connections.pop(ws, None)

    coro = _send_all()
    try:
        asyncio.run_coroutine_threadsafe(coro, loop)
    except RuntimeError:
        # The loop was closed between the check above and scheduling; close
        # the coroutine so it is not reported as never awaited.
        coro.close()
        return False
    return True


def connection_count() -> int:
    """Return the number of currently registered WebSocket connections."""
    return len(_connections)