112 lines
3.2 KiB
Python
112 lines
3.2 KiB
Python
"""WebSocket handler and broadcast helpers for hbd.
|
|
|
|
WebSocket connections are served through the regular HTTP port via the
|
|
/ws route registered in http.py (aiohttp WebSocketResponse upgrade).
|
|
The separate standalone WebSocket server on ws_port is no longer used.
|
|
"""
|
|
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
from typing import Callable, Iterable, Optional
|
|
from . import data
|
|
|
|
logger = logging.getLogger(name=__name__)

# --- Module-wide state shared by setup(), handler(), broadcast() and
# --- connection_count().

# When true, handler() logs received text frames at debug level.
_verbose: bool = False
# Optional zero-argument callable whose result handler() replays to every
# newly connected client as "host" messages.
_get_hosts: Optional[Callable[[], Iterable]] = None
# Event loop registered via setup(); broadcast() returns False while unset.
_loop: Optional[asyncio.AbstractEventLoop] = None
# WebSocket responses added by handler(); pruned by handler() on exit and by
# broadcast() when a socket is closed or a send fails.
_connections: set = set()
|
|
|
|
|
|
def setup(
|
|
loop: asyncio.AbstractEventLoop,
|
|
get_hosts: Optional[Callable[[], Iterable]] = None,
|
|
verbose: bool = False,
|
|
):
|
|
"""Register the running loop and initial-state callback.
|
|
|
|
Call this once from _run_async before starting the HTTP server.
|
|
"""
|
|
global _loop, _get_hosts, _verbose
|
|
_loop = loop
|
|
_get_hosts = get_hosts
|
|
_verbose = verbose
|
|
|
|
|
|
async def handler(request):
    """aiohttp WebSocket upgrade handler — register as GET /ws.

    Upgrades the request to a WebSocket and registers it in ``_connections``
    so that ``broadcast`` reaches it. It then replays the current host list
    (when a ``get_hosts`` callback was registered via ``setup``) and the
    recent messages in ``data.msgs``, and keeps the socket open until the
    client goes away. Incoming text frames are ignored; they are logged at
    debug level only when ``setup`` was called with ``verbose=True``.

    Returns:
        The ``web.WebSocketResponse`` once the connection has ended.
    """
    # Imported here (not at module top) to match the rest of this module.
    from aiohttp import WSMsgType, web

    ws = web.WebSocketResponse()
    await ws.prepare(request)

    # NOTE(review): ws is registered before the initial state is sent, so a
    # concurrent broadcast() can be interleaved with the replayed snapshot.
    _connections.add(ws)
    remote = request.remote
    logger.info("WebSocket connected from %s", remote)

    try:
        # Send current host state to the new client
        if _get_hosts:
            try:
                for h in list(_get_hosts()):
                    await ws.send_str(json.dumps({"type": "host", "data": h}))
            except Exception as e:
                logger.error("Error sending initial hosts: %s", e)

        # Send recent messages. Snapshot first: every send_str() awaits, and
        # data.msgs may be appended to while this coroutine is suspended.
        # Iterating the live container would then see it change mid-loop
        # (a deque raises RuntimeError in that case).
        msgs = list(data.msgs or ())
        if msgs:
            try:
                for m in msgs:
                    await ws.send_str(json.dumps({"type": "message", "data": m}))
            except Exception as e:
                logger.error("Error sending initial messages: %s", e)

        # Keep connection open, ignore incoming frames
        async for msg in ws:
            if msg.type == WSMsgType.TEXT:
                if _verbose:
                    logger.debug("ws recv from %s: %s", remote, msg.data)
            elif msg.type == WSMsgType.ERROR:
                # Previously dropped silently; surface the underlying cause.
                logger.warning(
                    "WebSocket error from %s: %s", remote, ws.exception()
                )
                break
            elif msg.type == WSMsgType.CLOSE:
                break

    except Exception as e:
        logger.exception("WebSocket handler error from %s: %s", remote, e)
    finally:
        _connections.discard(ws)
        logger.info("WebSocket disconnected from %s", remote)

    return ws
|
|
|
|
|
|
def broadcast(typ: str, payload) -> bool:
    """Thread-safe broadcast to all connected WebSocket clients.

    Serializes ``{"type": typ, "data": payload}`` to JSON once, then
    schedules a coroutine on the loop registered by ``setup`` that sends the
    text to every tracked connection. Connections that are already closed,
    or whose send raises, are removed from ``_connections``.

    Can be called from any thread. It only schedules the sends and does not
    wait for them to finish.

    Args:
        typ: message type, placed under the ``"type"`` key.
        payload: value placed under the ``"data"`` key; must be accepted by
            ``json.dumps``. Any serialization error propagates to the caller.

    Returns:
        True if the send was scheduled. False if no loop has been registered
        yet, or the registered loop has already been closed.
    """
    loop = _loop  # read once: another thread may call setup() concurrently
    if loop is None or loop.is_closed():
        return False

    jmsg = json.dumps({"type": typ, "data": payload})

    async def _send_all():
        dead = set()
        # Iterate a copy: handler() may add/discard entries while we await.
        for ws in list(_connections):
            try:
                if ws.closed:
                    dead.add(ws)
                else:
                    await ws.send_str(jmsg)
            except Exception:
                # A failed send means the peer is gone; drop it below.
                dead.add(ws)
        _connections.difference_update(dead)

    coro = _send_all()
    try:
        asyncio.run_coroutine_threadsafe(coro, loop)
    except RuntimeError:
        # The loop was closed between the check above and scheduling (e.g. a
        # call from another thread during shutdown). Close the coroutine so
        # it is not reported as "never awaited".
        coro.close()
        return False
    return True
|
|
|
|
|
|
def connection_count() -> int:
    """Return how many WebSocket connections are currently tracked.

    Closed sockets are still counted until they are pruned, either by
    ``handler``'s cleanup or by the next ``broadcast``.
    """
    tracked = _connections
    return len(tracked)
|