accept websocket connection on http:.../ws
This commit is contained in:
+5
-5
@@ -12,6 +12,7 @@ from . import data
|
|||||||
from . import notify as notify_mod
|
from . import notify as notify_mod
|
||||||
from . import settings as settings_mod
|
from . import settings as settings_mod
|
||||||
from . import users as users_mod
|
from . import users as users_mod
|
||||||
|
from . import ws as ws_mod
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -242,11 +243,9 @@ async def start(
|
|||||||
env = jinja2.Environment(loader=jinja2.FileSystemLoader(templates_dir))
|
env = jinja2.Environment(loader=jinja2.FileSystemLoader(templates_dir))
|
||||||
host = config.get("hb_host", "localhost")
|
host = config.get("hb_host", "localhost")
|
||||||
extra_scripts = config.get("http_extra_scripts", "")
|
extra_scripts = config.get("http_extra_scripts", "")
|
||||||
host = request.host.split(":")[0]
|
host = request.host # includes port if non-standard
|
||||||
if config.get("wss_port"):
|
scheme = "wss" if request.secure else "ws"
|
||||||
heartbeat_ws_url = f"wss://{host}:{config['wss_port']}/hbd"
|
heartbeat_ws_url = f"{scheme}://{host}/ws"
|
||||||
else:
|
|
||||||
heartbeat_ws_url = f"ws://{host}:{config.get('ws_port', 50005)}/hbd"
|
|
||||||
tmpl = env.get_template("live.html")
|
tmpl = env.get_template("live.html")
|
||||||
body = tmpl.render(
|
body = tmpl.render(
|
||||||
title="Heartbeat",
|
title="Heartbeat",
|
||||||
@@ -843,6 +842,7 @@ async def start(
|
|||||||
web.get("/settings", settings_page),
|
web.get("/settings", settings_page),
|
||||||
web.get("/static/{path:.*}", static),
|
web.get("/static/{path:.*}", static),
|
||||||
web.get("/favicon.ico", favicon),
|
web.get("/favicon.ico", favicon),
|
||||||
|
web.get("/ws", ws_mod.handler),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
+7
-35
@@ -275,45 +275,17 @@ async def _run_async(config, config_path=None):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception("dns worker failed to start: %s", e)
|
logger.exception("dns worker failed to start: %s", e)
|
||||||
|
|
||||||
# Start the websocket servers as a background task
|
# Register WebSocket state — connections are now served through /ws on the HTTP port
|
||||||
if config.get("wss_port", None):
|
ws_task = None
|
||||||
ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
|
ws_mod.setup(
|
||||||
ssl_path = config.get("cert_path", "")
|
loop=loop,
|
||||||
wss_pem = ssl_path + config.get("wss_pem", "")
|
|
||||||
wss_key = ssl_path + config.get("wss_key", "")
|
|
||||||
try:
|
|
||||||
ssl_context.load_cert_chain(wss_pem, keyfile=wss_key)
|
|
||||||
except FileNotFoundError:
|
|
||||||
logger.error("error: missing SSL keys %s or %s", wss_pem, wss_key)
|
|
||||||
sys.exit(1)
|
|
||||||
logger.info(
|
|
||||||
"Starting secure WebSocket server on port %s with cert %s",
|
|
||||||
config.get("wss_port", None),
|
|
||||||
wss_pem,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
ssl_context = None
|
|
||||||
|
|
||||||
try:
|
|
||||||
ws_port = config.get("ws_port", 50005)
|
|
||||||
logger.info("Starting WebSocket server on port %s", ws_port)
|
|
||||||
ws_task = asyncio.create_task(
|
|
||||||
ws_mod.start(
|
|
||||||
host=config.get("hbd_host", ""),
|
|
||||||
ws_port=ws_port,
|
|
||||||
wss_port=config.get("wss_port", None),
|
|
||||||
ssl_context=ssl_context,
|
|
||||||
get_hosts=lambda: [
|
get_hosts=lambda: [
|
||||||
hbdclass.Host.hosts[h].stateinfo()
|
hbdclass.Host.hosts[h].stateinfo()
|
||||||
for h in sorted(hbdclass.Host.hosts)
|
for h in sorted(hbdclass.Host.hosts)
|
||||||
],
|
],
|
||||||
# get_msgs=lambda: msgs,
|
verbose=config.get("verbose", False),
|
||||||
config=config,
|
|
||||||
)
|
)
|
||||||
)
|
logger.info("WebSocket handler registered on /ws (HTTP port %s)", config.get("hbd_port", 50004))
|
||||||
logger.info("WebSocket task started")
|
|
||||||
except Exception as e:
|
|
||||||
logger.exception("websocket server failed to start: %s", e)
|
|
||||||
|
|
||||||
# Periodic autosave task
|
# Periodic autosave task
|
||||||
autosave_interval = config.get("autosave_interval", 300) # default: 5 minutes
|
autosave_interval = config.get("autosave_interval", 300) # default: 5 minutes
|
||||||
@@ -375,7 +347,7 @@ async def _run_async(config, config_path=None):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning("Error closing UDP transport: %s", e)
|
logger.warning("Error closing UDP transport: %s", e)
|
||||||
|
|
||||||
tasks_to_cancel = [http_task, ws_task, autosave]
|
tasks_to_cancel = [http_task, autosave]
|
||||||
for task in tasks_to_cancel:
|
for task in tasks_to_cancel:
|
||||||
if task:
|
if task:
|
||||||
try:
|
try:
|
||||||
|
|||||||
+73
-120
@@ -1,7 +1,8 @@
|
|||||||
"""WebSocket server and broadcast helpers for hbd.
|
"""WebSocket handler and broadcast helpers for hbd.
|
||||||
|
|
||||||
Provides an asyncio-based WebSocket server and a thread-safe broadcast
|
WebSocket connections are served through the regular HTTP port via the
|
||||||
function that other threads or synchronous code can call.
|
/ws route registered in http.py (aiohttp WebSocketResponse upgrade).
|
||||||
|
The separate standalone WebSocket server on ws_port is no longer used.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
@@ -10,147 +11,99 @@ import logging
|
|||||||
from typing import Callable, Iterable, Optional
|
from typing import Callable, Iterable, Optional
|
||||||
from . import data
|
from . import data
|
||||||
|
|
||||||
import websockets
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
logger.setLevel(logging.INFO)
|
|
||||||
_connections = set()
|
_connections: set = set()
|
||||||
_loop: Optional[asyncio.AbstractEventLoop] = None
|
_loop: Optional[asyncio.AbstractEventLoop] = None
|
||||||
_get_hosts: Optional[Callable[[], Iterable]] = None
|
_get_hosts: Optional[Callable[[], Iterable]] = None
|
||||||
#_get_msgs: Optional[Callable[[], Iterable]] = None
|
_verbose: bool = False
|
||||||
_verbose = False
|
|
||||||
|
|
||||||
|
|
||||||
async def _handler(websocket, path=None):
|
def setup(
|
||||||
_connections.add(websocket)
|
loop: asyncio.AbstractEventLoop,
|
||||||
remote_address = websocket.remote_address
|
get_hosts: Optional[Callable[[], Iterable]] = None,
|
||||||
if path is None:
|
verbose: bool = False,
|
||||||
path = getattr(websocket, "path", None)
|
|
||||||
logger.info("WebSocket connection from %s: %s", remote_address, path)
|
|
||||||
try:
|
|
||||||
# send initial hosts
|
|
||||||
if _get_hosts:
|
|
||||||
try:
|
|
||||||
hosts = list(_get_hosts())
|
|
||||||
logger.debug("Sending %d hosts to new WebSocket client", len(hosts))
|
|
||||||
for h in hosts:
|
|
||||||
jmsg = json.dumps({"type": "host", "data": h})
|
|
||||||
await websocket.send(jmsg)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error("Error sending initial hosts: %s", e, exc_info=True)
|
|
||||||
# send recent messages
|
|
||||||
if data.msgs:
|
|
||||||
try:
|
|
||||||
# msgs = list(_get_msgs())[-100:]
|
|
||||||
logger.debug("Sending %d recent messages to new WebSocket client", len(data.msgs))
|
|
||||||
for m in data.msgs:
|
|
||||||
jmsg = json.dumps({"type": "message", "data": m})
|
|
||||||
await websocket.send(jmsg)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error("Error sending initial messages: %s", e, exc_info=True)
|
|
||||||
|
|
||||||
# keep connection open until client disconnects
|
|
||||||
async for _ in websocket:
|
|
||||||
# we don't expect meaningful incoming messages besides the initial
|
|
||||||
# client 'hello' that some clients send; ignore for now
|
|
||||||
if _verbose:
|
|
||||||
logger.debug("received ws data: %s", _)
|
|
||||||
|
|
||||||
except (
|
|
||||||
websockets.exceptions.ConnectionClosedOK,
|
|
||||||
websockets.exceptions.ConnectionClosedError,
|
|
||||||
) as e:
|
|
||||||
logger.info("WebSocket closed from %s: %r", remote_address, e)
|
|
||||||
except Exception as e:
|
|
||||||
logger.exception("WebSocket handler exception from %s: %s", remote_address, e)
|
|
||||||
finally:
|
|
||||||
logger.debug("Removing WebSocket connection from %s", remote_address)
|
|
||||||
_connections.discard(websocket)
|
|
||||||
|
|
||||||
|
|
||||||
async def start(
|
|
||||||
host: str,
|
|
||||||
ws_port: int,
|
|
||||||
wss_port: Optional[int] = None,
|
|
||||||
ssl_context=None,
|
|
||||||
get_hosts: Optional[Callable] = None,
|
|
||||||
# get_msgs: Optional[Callable] = None,
|
|
||||||
config: dict = {},
|
|
||||||
):
|
):
|
||||||
"""Start WebSocket servers and block until cancelled.
|
"""Register the running loop and initial-state callback.
|
||||||
|
|
||||||
This is intended to be awaited inside the main asyncio event loop.
|
Call this once from _run_async before starting the HTTP server.
|
||||||
If `wss_port` and `ssl_context` are provided, a WSS server will also be
|
|
||||||
started.
|
|
||||||
"""
|
"""
|
||||||
global _loop, _get_hosts, _verbose
|
global _loop, _get_hosts, _verbose
|
||||||
_loop = asyncio.get_running_loop()
|
_loop = loop
|
||||||
_get_hosts = get_hosts
|
_get_hosts = get_hosts
|
||||||
_verbose = config.get("verbose", False),
|
_verbose = verbose
|
||||||
_debug = config.get("debug", 0),
|
|
||||||
|
|
||||||
# Start servers and keep the server objects for clean shutdown
|
|
||||||
running_servers = []
|
|
||||||
ws_server = await websockets.serve(_handler, host, ws_port)
|
|
||||||
running_servers.append(ws_server)
|
|
||||||
if wss_port and ssl_context:
|
|
||||||
wss_server = await websockets.serve(_handler, host, wss_port, ssl=ssl_context)
|
|
||||||
running_servers.append(wss_server)
|
|
||||||
|
|
||||||
logger.info(
|
async def handler(request):
|
||||||
"WebSocket server(s) started on port %s (wss %s)", ws_port, wss_port
|
"""aiohttp WebSocket upgrade handler — register as GET /ws."""
|
||||||
)
|
from aiohttp import web
|
||||||
|
|
||||||
|
ws = web.WebSocketResponse()
|
||||||
|
await ws.prepare(request)
|
||||||
|
|
||||||
|
_connections.add(ws)
|
||||||
|
remote = request.remote
|
||||||
|
logger.info("WebSocket connected from %s", remote)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Block until cancelled
|
# Send current host state to the new client
|
||||||
await asyncio.Future()
|
if _get_hosts:
|
||||||
except asyncio.CancelledError:
|
try:
|
||||||
pass
|
for h in list(_get_hosts()):
|
||||||
|
await ws.send_str(json.dumps({"type": "host", "data": h}))
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Error sending initial hosts: %s", e)
|
||||||
|
|
||||||
|
# Send recent messages
|
||||||
|
if data.msgs:
|
||||||
|
try:
|
||||||
|
for m in data.msgs:
|
||||||
|
await ws.send_str(json.dumps({"type": "message", "data": m}))
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Error sending initial messages: %s", e)
|
||||||
|
|
||||||
|
# Keep connection open, ignore incoming frames
|
||||||
|
async for msg in ws:
|
||||||
|
from aiohttp import WSMsgType
|
||||||
|
if msg.type == WSMsgType.TEXT:
|
||||||
|
if _verbose:
|
||||||
|
logger.debug("ws recv from %s: %s", remote, msg.data)
|
||||||
|
elif msg.type in (WSMsgType.ERROR, WSMsgType.CLOSE):
|
||||||
|
break
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("WebSocket handler error from %s: %s", remote, e)
|
||||||
finally:
|
finally:
|
||||||
# Close all active browser connections so their handler coroutines exit
|
_connections.discard(ws)
|
||||||
active = list(_connections)
|
logger.info("WebSocket disconnected from %s", remote)
|
||||||
if active:
|
|
||||||
logger.info("Closing %d active WebSocket connection(s)...", len(active))
|
return ws
|
||||||
await asyncio.gather(
|
|
||||||
*[ws.close() for ws in active],
|
|
||||||
return_exceptions=True,
|
|
||||||
)
|
|
||||||
# Stop the listening servers and wait for all handlers to finish
|
|
||||||
for srv in running_servers:
|
|
||||||
srv.close()
|
|
||||||
await asyncio.gather(
|
|
||||||
*[srv.wait_closed() for srv in running_servers],
|
|
||||||
return_exceptions=True,
|
|
||||||
)
|
|
||||||
logger.info("WebSocket server(s) stopped")
|
|
||||||
|
|
||||||
|
|
||||||
def broadcast(typ: str, data) -> bool:
|
def broadcast(typ: str, payload) -> bool:
|
||||||
"""Thread-safe broadcast helper.
|
"""Thread-safe broadcast to all connected WebSocket clients.
|
||||||
|
|
||||||
Schedules coroutine(s) on the running loop to send message to all
|
Can be called from any thread; schedules sends on the event loop.
|
||||||
connected websockets. Returns False if server was not running.
|
Returns False if the loop is not running yet.
|
||||||
"""
|
"""
|
||||||
if not _loop:
|
if not _loop:
|
||||||
return False
|
return False
|
||||||
jmsg = json.dumps({"type": typ, "data": data})
|
jmsg = json.dumps({"type": typ, "data": payload})
|
||||||
to_close = []
|
|
||||||
|
async def _send_all():
|
||||||
|
dead = set()
|
||||||
for ws in list(_connections):
|
for ws in list(_connections):
|
||||||
if ws.state != websockets.protocol.State.OPEN:
|
|
||||||
to_close.append(ws)
|
|
||||||
continue
|
|
||||||
try:
|
try:
|
||||||
asyncio.run_coroutine_threadsafe(ws.send(jmsg), _loop)
|
if not ws.closed:
|
||||||
|
await ws.send_str(jmsg)
|
||||||
|
else:
|
||||||
|
dead.add(ws)
|
||||||
except Exception:
|
except Exception:
|
||||||
to_close.append(ws)
|
dead.add(ws)
|
||||||
logger.debug("ws.send exception: closed")
|
for ws in dead:
|
||||||
for ws in to_close:
|
_connections.discard(ws)
|
||||||
try:
|
|
||||||
asyncio.run_coroutine_threadsafe(ws.wait_closed(), _loop)
|
asyncio.run_coroutine_threadsafe(_send_all(), _loop)
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
if ws in _connections:
|
|
||||||
_connections.remove(ws)
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user