diff --git a/hbd/server/http.py b/hbd/server/http.py index f70137b..fba04e2 100644 --- a/hbd/server/http.py +++ b/hbd/server/http.py @@ -12,6 +12,7 @@ from . import data from . import notify as notify_mod from . import settings as settings_mod from . import users as users_mod +from . import ws as ws_mod logger = logging.getLogger(__name__) @@ -242,11 +243,9 @@ async def start( env = jinja2.Environment(loader=jinja2.FileSystemLoader(templates_dir)) host = config.get("hb_host", "localhost") extra_scripts = config.get("http_extra_scripts", "") - host = request.host.split(":")[0] - if config.get("wss_port"): - heartbeat_ws_url = f"wss://{host}:{config['wss_port']}/hbd" - else: - heartbeat_ws_url = f"ws://{host}:{config.get('ws_port', 50005)}/hbd" + host = request.host # includes port if non-standard + scheme = "wss" if request.secure else "ws" + heartbeat_ws_url = f"{scheme}://{host}/ws" tmpl = env.get_template("live.html") body = tmpl.render( title="Heartbeat", @@ -843,6 +842,7 @@ async def start( web.get("/settings", settings_page), web.get("/static/{path:.*}", static), web.get("/favicon.ico", favicon), + web.get("/ws", ws_mod.handler), ] ) diff --git a/hbd/server/main.py b/hbd/server/main.py index 270b2d3..ab46ef4 100644 --- a/hbd/server/main.py +++ b/hbd/server/main.py @@ -275,45 +275,17 @@ async def _run_async(config, config_path=None): except Exception as e: logger.exception("dns worker failed to start: %s", e) - # Start the websocket servers as a background task - if config.get("wss_port", None): - ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) - ssl_path = config.get("cert_path", "") - wss_pem = ssl_path + config.get("wss_pem", "") - wss_key = ssl_path + config.get("wss_key", "") - try: - ssl_context.load_cert_chain(wss_pem, keyfile=wss_key) - except FileNotFoundError: - logger.error("error: missing SSL keys %s or %s", wss_pem, wss_key) - sys.exit(1) - logger.info( - "Starting secure WebSocket server on port %s with cert %s", - config.get("wss_port", None), - wss_pem, - ) - else: - ssl_context = None - - try: - ws_port = config.get("ws_port", 50005) - logger.info("Starting WebSocket server on port %s", ws_port) - ws_task = asyncio.create_task( - ws_mod.start( - host=config.get("hbd_host", ""), - ws_port=ws_port, - wss_port=config.get("wss_port", None), - ssl_context=ssl_context, - get_hosts=lambda: [ - hbdclass.Host.hosts[h].stateinfo() - for h in sorted(hbdclass.Host.hosts) - ], -# get_msgs=lambda: msgs, - config=config, - ) - ) - logger.info("WebSocket task started") - except Exception as e: - logger.exception("websocket server failed to start: %s", e) + # Register WebSocket state — connections are now served through /ws on the HTTP port + ws_task = None + ws_mod.setup( + loop=loop, + get_hosts=lambda: [ + hbdclass.Host.hosts[h].stateinfo() + for h in sorted(hbdclass.Host.hosts) + ], + verbose=config.get("verbose", False), + ) + logger.info("WebSocket handler registered on /ws (HTTP port %s)", config.get("hbd_port", 50004)) # Periodic autosave task autosave_interval = config.get("autosave_interval", 300) # default: 5 minutes @@ -375,7 +347,7 @@ async def _run_async(config, config_path=None): except Exception as e: logger.warning("Error closing UDP transport: %s", e) - tasks_to_cancel = [http_task, ws_task, autosave] + tasks_to_cancel = [http_task, autosave] for task in tasks_to_cancel: if task: try: diff --git a/hbd/server/ws.py b/hbd/server/ws.py index 590438f..cf2ccd4 100644 --- a/hbd/server/ws.py +++ b/hbd/server/ws.py @@ -1,7 +1,8 @@ -"""WebSocket server and broadcast helpers for hbd. +"""WebSocket handler and broadcast helpers for hbd. -Provides an asyncio-based WebSocket server and a thread-safe broadcast -function that other threads or synchronous code can call. +WebSocket connections are served through the regular HTTP port via the +/ws route registered in http.py (aiohttp WebSocketResponse upgrade). +The separate standalone WebSocket server on ws_port is no longer used. """ import asyncio @@ -10,147 +11,99 @@ import logging from typing import Callable, Iterable, Optional from . import data -import websockets - logger = logging.getLogger(__name__) -logger.setLevel(logging.INFO) -_connections = set() + +_connections: set = set() _loop: Optional[asyncio.AbstractEventLoop] = None _get_hosts: Optional[Callable[[], Iterable]] = None -#_get_msgs: Optional[Callable[[], Iterable]] = None -_verbose = False +_verbose: bool = False -async def _handler(websocket, path=None): - _connections.add(websocket) - remote_address = websocket.remote_address - if path is None: - path = getattr(websocket, "path", None) - logger.info("WebSocket connection from %s: %s", remote_address, path) - try: - # send initial hosts - if _get_hosts: - try: - hosts = list(_get_hosts()) - logger.debug("Sending %d hosts to new WebSocket client", len(hosts)) - for h in hosts: - jmsg = json.dumps({"type": "host", "data": h}) - await websocket.send(jmsg) - except Exception as e: - logger.error("Error sending initial hosts: %s", e, exc_info=True) - # send recent messages - if data.msgs: - try: -# msgs = list(_get_msgs())[-100:] - logger.debug("Sending %d recent messages to new WebSocket client", len(data.msgs)) - for m in data.msgs: - jmsg = json.dumps({"type": "message", "data": m}) - await websocket.send(jmsg) - except Exception as e: - logger.error("Error sending initial messages: %s", e, exc_info=True) - - # keep connection open until client disconnects - async for _ in websocket: - # we don't expect meaningful incoming messages besides the initial - # client 'hello' that some clients send; ignore for now - if _verbose: - logger.debug("received ws data: %s", _) - - except ( - websockets.exceptions.ConnectionClosedOK, - websockets.exceptions.ConnectionClosedError, - ) as e: - logger.info("WebSocket closed from %s: %r", remote_address, e) - except Exception as e: - logger.exception("WebSocket handler exception from %s: %s", remote_address, e) - finally: - logger.debug("Removing WebSocket connection from %s", remote_address) - _connections.discard(websocket) - - -async def start( - host: str, - ws_port: int, - wss_port: Optional[int] = None, - ssl_context=None, - get_hosts: Optional[Callable] = None, -# get_msgs: Optional[Callable] = None, - config: dict = {}, +def setup( + loop: asyncio.AbstractEventLoop, + get_hosts: Optional[Callable[[], Iterable]] = None, + verbose: bool = False, ): - """Start WebSocket servers and block until cancelled. + """Register the running loop and initial-state callback. - This is intended to be awaited inside the main asyncio event loop. - If `wss_port` and `ssl_context` are provided, a WSS server will also be - started. + Call this once from _run_async before starting the HTTP server. """ global _loop, _get_hosts, _verbose - _loop = asyncio.get_running_loop() + _loop = loop _get_hosts = get_hosts - _verbose = config.get("verbose", False), - _debug = config.get("debug", 0), + _verbose = verbose - # Start servers and keep the server objects for clean shutdown - running_servers = [] - ws_server = await websockets.serve(_handler, host, ws_port) - running_servers.append(ws_server) - if wss_port and ssl_context: - wss_server = await websockets.serve(_handler, host, wss_port, ssl=ssl_context) - running_servers.append(wss_server) - logger.info( - "WebSocket server(s) started on port %s (wss %s)", ws_port, wss_port - ) +async def handler(request): + """aiohttp WebSocket upgrade handler — register as GET /ws.""" + from aiohttp import web + + ws = web.WebSocketResponse() + await ws.prepare(request) + + _connections.add(ws) + remote = request.remote + logger.info("WebSocket connected from %s", remote) try: - # Block until cancelled - await asyncio.Future() - except asyncio.CancelledError: - pass + # Send current host state to the new client + if _get_hosts: + try: + for h in list(_get_hosts()): + await ws.send_str(json.dumps({"type": "host", "data": h})) + except Exception as e: + logger.error("Error sending initial hosts: %s", e) + + # Send recent messages + if data.msgs: + try: + for m in data.msgs: + await ws.send_str(json.dumps({"type": "message", "data": m})) + except Exception as e: + logger.error("Error sending initial messages: %s", e) + + # Keep connection open, ignore incoming frames + async for msg in ws: + from aiohttp import WSMsgType + if msg.type == WSMsgType.TEXT: + if _verbose: + logger.debug("ws recv from %s: %s", remote, msg.data) + elif msg.type in (WSMsgType.ERROR, WSMsgType.CLOSE): + break + + except Exception as e: + logger.exception("WebSocket handler error from %s: %s", remote, e) finally: - # Close all active browser connections so their handler coroutines exit - active = list(_connections) - if active: - logger.info("Closing %d active WebSocket connection(s)...", len(active)) - await asyncio.gather( - *[ws.close() for ws in active], - return_exceptions=True, - ) - # Stop the listening servers and wait for all handlers to finish - for srv in running_servers: - srv.close() - await asyncio.gather( - *[srv.wait_closed() for srv in running_servers], - return_exceptions=True, - ) - logger.info("WebSocket server(s) stopped") + _connections.discard(ws) + logger.info("WebSocket disconnected from %s", remote) + + return ws -def broadcast(typ: str, data) -> bool: - """Thread-safe broadcast helper. +def broadcast(typ: str, payload) -> bool: + """Thread-safe broadcast to all connected WebSocket clients. - Schedules coroutine(s) on the running loop to send message to all - connected websockets. Returns False if server was not running. + Can be called from any thread; schedules sends on the event loop. + Returns False if the loop is not running yet. """ if not _loop: return False - jmsg = json.dumps({"type": typ, "data": data}) - to_close = [] - for ws in list(_connections): - if ws.state != websockets.protocol.State.OPEN: - to_close.append(ws) - continue - try: - asyncio.run_coroutine_threadsafe(ws.send(jmsg), _loop) - except Exception: - to_close.append(ws) - logger.debug("ws.send exception: closed") - for ws in to_close: - try: - asyncio.run_coroutine_threadsafe(ws.wait_closed(), _loop) - except Exception: - pass - if ws in _connections: - _connections.remove(ws) + jmsg = json.dumps({"type": typ, "data": payload}) + + async def _send_all(): + dead = set() + for ws in list(_connections): + try: + if not ws.closed: + await ws.send_str(jmsg) + else: + dead.add(ws) + except Exception: + dead.add(ws) + for ws in dead: + _connections.discard(ws) + + asyncio.run_coroutine_threadsafe(_send_all(), _loop) return True