accept websocket connection on http:.../ws

This commit is contained in:
Andreas Wrede
2026-04-12 06:44:32 -04:00
parent 6559f5462c
commit 7f049a4e26
3 changed files with 93 additions and 168 deletions
+5 -5
View File
@@ -12,6 +12,7 @@ from . import data
from . import notify as notify_mod from . import notify as notify_mod
from . import settings as settings_mod from . import settings as settings_mod
from . import users as users_mod from . import users as users_mod
from . import ws as ws_mod
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -242,11 +243,9 @@ async def start(
env = jinja2.Environment(loader=jinja2.FileSystemLoader(templates_dir)) env = jinja2.Environment(loader=jinja2.FileSystemLoader(templates_dir))
host = config.get("hb_host", "localhost") host = config.get("hb_host", "localhost")
extra_scripts = config.get("http_extra_scripts", "") extra_scripts = config.get("http_extra_scripts", "")
host = request.host.split(":")[0] host = request.host # includes port if non-standard
if config.get("wss_port"): scheme = "wss" if request.secure else "ws"
heartbeat_ws_url = f"wss://{host}:{config['wss_port']}/hbd" heartbeat_ws_url = f"{scheme}://{host}/ws"
else:
heartbeat_ws_url = f"ws://{host}:{config.get('ws_port', 50005)}/hbd"
tmpl = env.get_template("live.html") tmpl = env.get_template("live.html")
body = tmpl.render( body = tmpl.render(
title="Heartbeat", title="Heartbeat",
@@ -843,6 +842,7 @@ async def start(
web.get("/settings", settings_page), web.get("/settings", settings_page),
web.get("/static/{path:.*}", static), web.get("/static/{path:.*}", static),
web.get("/favicon.ico", favicon), web.get("/favicon.ico", favicon),
web.get("/ws", ws_mod.handler),
] ]
) )
+12 -40
View File
@@ -275,45 +275,17 @@ async def _run_async(config, config_path=None):
except Exception as e: except Exception as e:
logger.exception("dns worker failed to start: %s", e) logger.exception("dns worker failed to start: %s", e)
# Start the websocket servers as a background task # Register WebSocket state — connections are now served through /ws on the HTTP port
if config.get("wss_port", None): ws_task = None
ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) ws_mod.setup(
ssl_path = config.get("cert_path", "") loop=loop,
wss_pem = ssl_path + config.get("wss_pem", "") get_hosts=lambda: [
wss_key = ssl_path + config.get("wss_key", "") hbdclass.Host.hosts[h].stateinfo()
try: for h in sorted(hbdclass.Host.hosts)
ssl_context.load_cert_chain(wss_pem, keyfile=wss_key) ],
except FileNotFoundError: verbose=config.get("verbose", False),
logger.error("error: missing SSL keys %s or %s", wss_pem, wss_key) )
sys.exit(1) logger.info("WebSocket handler registered on /ws (HTTP port %s)", config.get("hbd_port", 50004))
logger.info(
"Starting secure WebSocket server on port %s with cert %s",
config.get("wss_port", None),
wss_pem,
)
else:
ssl_context = None
try:
ws_port = config.get("ws_port", 50005)
logger.info("Starting WebSocket server on port %s", ws_port)
ws_task = asyncio.create_task(
ws_mod.start(
host=config.get("hbd_host", ""),
ws_port=ws_port,
wss_port=config.get("wss_port", None),
ssl_context=ssl_context,
get_hosts=lambda: [
hbdclass.Host.hosts[h].stateinfo()
for h in sorted(hbdclass.Host.hosts)
],
# get_msgs=lambda: msgs,
config=config,
)
)
logger.info("WebSocket task started")
except Exception as e:
logger.exception("websocket server failed to start: %s", e)
# Periodic autosave task # Periodic autosave task
autosave_interval = config.get("autosave_interval", 300) # default: 5 minutes autosave_interval = config.get("autosave_interval", 300) # default: 5 minutes
@@ -375,7 +347,7 @@ async def _run_async(config, config_path=None):
except Exception as e: except Exception as e:
logger.warning("Error closing UDP transport: %s", e) logger.warning("Error closing UDP transport: %s", e)
tasks_to_cancel = [http_task, ws_task, autosave] tasks_to_cancel = [http_task, autosave]
for task in tasks_to_cancel: for task in tasks_to_cancel:
if task: if task:
try: try:
+76 -123
View File
@@ -1,7 +1,8 @@
"""WebSocket server and broadcast helpers for hbd. """WebSocket handler and broadcast helpers for hbd.
Provides an asyncio-based WebSocket server and a thread-safe broadcast WebSocket connections are served through the regular HTTP port via the
function that other threads or synchronous code can call. /ws route registered in http.py (aiohttp WebSocketResponse upgrade).
The separate standalone WebSocket server on ws_port is no longer used.
""" """
import asyncio import asyncio
@@ -10,147 +11,99 @@ import logging
from typing import Callable, Iterable, Optional from typing import Callable, Iterable, Optional
from . import data from . import data
import websockets
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
_connections = set() _connections: set = set()
_loop: Optional[asyncio.AbstractEventLoop] = None _loop: Optional[asyncio.AbstractEventLoop] = None
_get_hosts: Optional[Callable[[], Iterable]] = None _get_hosts: Optional[Callable[[], Iterable]] = None
#_get_msgs: Optional[Callable[[], Iterable]] = None _verbose: bool = False
_verbose = False
async def _handler(websocket, path=None): def setup(
_connections.add(websocket) loop: asyncio.AbstractEventLoop,
remote_address = websocket.remote_address get_hosts: Optional[Callable[[], Iterable]] = None,
if path is None: verbose: bool = False,
path = getattr(websocket, "path", None)
logger.info("WebSocket connection from %s: %s", remote_address, path)
try:
# send initial hosts
if _get_hosts:
try:
hosts = list(_get_hosts())
logger.debug("Sending %d hosts to new WebSocket client", len(hosts))
for h in hosts:
jmsg = json.dumps({"type": "host", "data": h})
await websocket.send(jmsg)
except Exception as e:
logger.error("Error sending initial hosts: %s", e, exc_info=True)
# send recent messages
if data.msgs:
try:
# msgs = list(_get_msgs())[-100:]
logger.debug("Sending %d recent messages to new WebSocket client", len(data.msgs))
for m in data.msgs:
jmsg = json.dumps({"type": "message", "data": m})
await websocket.send(jmsg)
except Exception as e:
logger.error("Error sending initial messages: %s", e, exc_info=True)
# keep connection open until client disconnects
async for _ in websocket:
# we don't expect meaningful incoming messages besides the initial
# client 'hello' that some clients send; ignore for now
if _verbose:
logger.debug("received ws data: %s", _)
except (
websockets.exceptions.ConnectionClosedOK,
websockets.exceptions.ConnectionClosedError,
) as e:
logger.info("WebSocket closed from %s: %r", remote_address, e)
except Exception as e:
logger.exception("WebSocket handler exception from %s: %s", remote_address, e)
finally:
logger.debug("Removing WebSocket connection from %s", remote_address)
_connections.discard(websocket)
async def start(
host: str,
ws_port: int,
wss_port: Optional[int] = None,
ssl_context=None,
get_hosts: Optional[Callable] = None,
# get_msgs: Optional[Callable] = None,
config: dict = {},
): ):
"""Start WebSocket servers and block until cancelled. """Register the running loop and initial-state callback.
This is intended to be awaited inside the main asyncio event loop. Call this once from _run_async before starting the HTTP server.
If `wss_port` and `ssl_context` are provided, a WSS server will also be
started.
""" """
global _loop, _get_hosts, _verbose global _loop, _get_hosts, _verbose
_loop = asyncio.get_running_loop() _loop = loop
_get_hosts = get_hosts _get_hosts = get_hosts
_verbose = config.get("verbose", False), _verbose = verbose
_debug = config.get("debug", 0),
# Start servers and keep the server objects for clean shutdown
running_servers = []
ws_server = await websockets.serve(_handler, host, ws_port)
running_servers.append(ws_server)
if wss_port and ssl_context:
wss_server = await websockets.serve(_handler, host, wss_port, ssl=ssl_context)
running_servers.append(wss_server)
logger.info( async def handler(request):
"WebSocket server(s) started on port %s (wss %s)", ws_port, wss_port """aiohttp WebSocket upgrade handler — register as GET /ws."""
) from aiohttp import web
ws = web.WebSocketResponse()
await ws.prepare(request)
_connections.add(ws)
remote = request.remote
logger.info("WebSocket connected from %s", remote)
try: try:
# Block until cancelled # Send current host state to the new client
await asyncio.Future() if _get_hosts:
except asyncio.CancelledError: try:
pass for h in list(_get_hosts()):
await ws.send_str(json.dumps({"type": "host", "data": h}))
except Exception as e:
logger.error("Error sending initial hosts: %s", e)
# Send recent messages
if data.msgs:
try:
for m in data.msgs:
await ws.send_str(json.dumps({"type": "message", "data": m}))
except Exception as e:
logger.error("Error sending initial messages: %s", e)
# Keep connection open, ignore incoming frames
async for msg in ws:
from aiohttp import WSMsgType
if msg.type == WSMsgType.TEXT:
if _verbose:
logger.debug("ws recv from %s: %s", remote, msg.data)
elif msg.type in (WSMsgType.ERROR, WSMsgType.CLOSE):
break
except Exception as e:
logger.exception("WebSocket handler error from %s: %s", remote, e)
finally: finally:
# Close all active browser connections so their handler coroutines exit _connections.discard(ws)
active = list(_connections) logger.info("WebSocket disconnected from %s", remote)
if active:
logger.info("Closing %d active WebSocket connection(s)...", len(active)) return ws
await asyncio.gather(
*[ws.close() for ws in active],
return_exceptions=True,
)
# Stop the listening servers and wait for all handlers to finish
for srv in running_servers:
srv.close()
await asyncio.gather(
*[srv.wait_closed() for srv in running_servers],
return_exceptions=True,
)
logger.info("WebSocket server(s) stopped")
def broadcast(typ: str, data) -> bool: def broadcast(typ: str, payload) -> bool:
"""Thread-safe broadcast helper. """Thread-safe broadcast to all connected WebSocket clients.
Schedules coroutine(s) on the running loop to send message to all Can be called from any thread; schedules sends on the event loop.
connected websockets. Returns False if server was not running. Returns False if the loop is not running yet.
""" """
if not _loop: if not _loop:
return False return False
jmsg = json.dumps({"type": typ, "data": data}) jmsg = json.dumps({"type": typ, "data": payload})
to_close = []
for ws in list(_connections): async def _send_all():
if ws.state != websockets.protocol.State.OPEN: dead = set()
to_close.append(ws) for ws in list(_connections):
continue try:
try: if not ws.closed:
asyncio.run_coroutine_threadsafe(ws.send(jmsg), _loop) await ws.send_str(jmsg)
except Exception: else:
to_close.append(ws) dead.add(ws)
logger.debug("ws.send exception: closed") except Exception:
for ws in to_close: dead.add(ws)
try: for ws in dead:
asyncio.run_coroutine_threadsafe(ws.wait_closed(), _loop) _connections.discard(ws)
except Exception:
pass asyncio.run_coroutine_threadsafe(_send_all(), _loop)
if ws in _connections:
_connections.remove(ws)
return True return True