144 lines
4.3 KiB
Python
144 lines
4.3 KiB
Python
"""WebSocket server and broadcast helpers for hbd.
|
|
|
|
Provides an asyncio-based WebSocket server and a thread-safe broadcast
|
|
function that other threads or synchronous code can call.
|
|
"""
|
|
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
from typing import Callable, Iterable, Optional
|
|
|
|
import websockets
|
|
|
|
logger = logging.getLogger(__name__)

# Currently registered client connections. _handler adds each connection on
# connect and removes it on exit; broadcast() also prunes non-open entries.
_connections: set = set()

# Event loop the servers run on; assigned by start() and used by broadcast()
# to schedule sends via asyncio.run_coroutine_threadsafe.
_loop: Optional[asyncio.AbstractEventLoop] = None

# Optional callables (supplied to start()) whose results are replayed to each
# newly connected client: the host list and the recent messages.
_get_hosts: Optional[Callable[[], Iterable]] = None
_get_msgs: Optional[Callable[[], Iterable]] = None

# When True, connection/debug details are logged.
_verbose: bool = False
|
|
|
|
|
|
async def _handler(websocket, path=None):
    """Serve one WebSocket client for the lifetime of its connection.

    The connection is registered in ``_connections`` (so ``broadcast()``
    reaches it). It is then sent one ``{"type": "host", ...}`` JSON message
    per item from ``_get_hosts()``, followed by one
    ``{"type": "message", ...}`` message per item among the last 100 items
    from ``_get_msgs()``. After that, incoming frames are read and ignored
    until the client disconnects.

    Args:
        websocket: the connection object handed in by ``websockets.serve``.
        path: request path, if the caller passes one. When omitted, the
            connection's ``path`` attribute is used (``None`` if absent).
            The path is only used for verbose logging.
    """
    _connections.add(websocket)
    remote_address = websocket.remote_address
    if path is None:
        path = getattr(websocket, "path", None)
    if _verbose:
        logger.info("DBG ws_serve: %s: %s", remote_address, path)

    try:
        # send initial hosts
        if _get_hosts:
            for host_entry in _get_hosts():
                await websocket.send(json.dumps({"type": "host", "data": host_entry}))

        # send recent messages (replay is capped at the last 100)
        if _get_msgs:
            for msg in list(_get_msgs())[-100:]:
                await websocket.send(json.dumps({"type": "message", "data": msg}))

        # Keep the connection open until the client disconnects. Incoming
        # data (e.g. a client 'hello') is not meaningful here and is ignored.
        async for incoming in websocket:
            if _verbose:
                logger.debug("received ws data: %s", incoming)

    except (
        websockets.exceptions.ConnectionClosedOK,
        websockets.exceptions.ConnectionClosedError,
    ) as e:
        if _verbose:
            logger.info("ws closed: %r", e)
    except Exception as e:
        logger.exception("ws handler exception: %s", e)
    finally:
        # discard(): broadcast() may already have dropped this connection.
        _connections.discard(websocket)
        # Actively close instead of waiting. After an unexpected error the
        # socket is still open, and wait_closed() would block until the client
        # happened to disconnect. On an already-closed connection this is a
        # no-op.
        await websocket.close()
|
|
|
|
|
|
async def start(
    host: str,
    ws_port: int,
    wss_port: Optional[int] = None,
    ssl_context=None,
    get_hosts: Optional[Callable] = None,
    get_msgs: Optional[Callable] = None,
    verbose: bool = False,
):
    """Start WebSocket servers and block until cancelled.

    This is intended to be awaited inside the main asyncio event loop.
    A plain WebSocket server is always started on ``host:ws_port``. If both
    ``wss_port`` and ``ssl_context`` are truthy, a WSS server is also
    started on ``host:wss_port``.

    Args:
        host: address to bind the server(s) to.
        ws_port: port of the plain WebSocket server.
        wss_port: optional port of the secure (WSS) server.
        ssl_context: SSL context passed as ``ssl=`` for the WSS server.
        get_hosts: optional callable whose items are sent to each new client
            as ``"host"`` messages.
        get_msgs: optional callable whose last 100 items are sent to each new
            client as ``"message"`` messages.
        verbose: enable extra logging; also sets the ``websockets.server``
            logger to DEBUG (INFO otherwise).

    This coroutine never returns normally. It ends by raising, typically
    ``asyncio.CancelledError`` when its task is cancelled. On exit, including
    a failure while starting the second server, every server that was
    started is closed and awaited, and ``broadcast()`` returns False again.
    """
    global _loop, _get_hosts, _get_msgs, _verbose
    loop = asyncio.get_running_loop()
    _loop = loop
    _get_hosts = get_hosts
    _get_msgs = get_msgs
    _verbose = verbose

    websockets_logger = logging.getLogger("websockets.server")
    websockets_logger.setLevel(logging.DEBUG if verbose else logging.INFO)

    # Plain WebSocket always; secure WebSocket only when configured.
    factories = [websockets.serve(_handler, host, ws_port)]  # , subprotocols=["hbd"])
    if wss_port and ssl_context:
        factories.append(
            websockets.serve(_handler, host, wss_port, ssl=ssl_context)
        )  # , subprotocols=["hbd"])

    running = []
    try:
        # Awaiting a factory binds the listening socket and yields the
        # server object. Keep each one so it can be shut down below.
        for factory in factories:
            running.append(await factory)

        if _verbose:
            logger.info(
                "WebSocket server(s) started on port %s (wss %s)", ws_port, wss_port
            )

        # block forever (until loop is stopped or cancelled)
        await asyncio.Future()
    finally:
        # Stop broadcast() from scheduling onto a loop we are leaving.
        if _loop is loop:
            _loop = None
        # Without this, cancelling start() would leave the ports listening
        # (and a failed WSS bind would leak the already-started WS server).
        for server in running:
            server.close()
        for server in running:
            await server.wait_closed()
|
|
|
|
|
|
def broadcast(typ: str, data) -> bool:
    """Thread-safe broadcast helper.

    Serialises ``{"type": typ, "data": data}`` to JSON once, then schedules a
    ``send`` of it on the server loop (via ``run_coroutine_threadsafe``) for
    every open connection. Sends are fire-and-forget: this function does not
    wait for them, and errors raised inside a send are not reported here.

    Connections that are no longer open, or for which scheduling fails
    (e.g. the loop was closed), are removed from the registry.

    Args:
        typ: value for the message's ``"type"`` field.
        data: JSON-serialisable payload for the ``"data"`` field.

    Returns:
        False if the server was not running (no loop registered), else True.

    Raises:
        TypeError / ValueError: from ``json.dumps`` if ``data`` cannot be
            serialised.
    """
    # Read the global once so the check and the scheduling below use the
    # same loop object.
    loop = _loop
    if loop is None:
        return False

    jmsg = json.dumps({"type": typ, "data": data})
    stale = []
    # Iterate a snapshot: the loop thread adds/removes connections concurrently.
    for ws in list(_connections):
        if ws.state != websockets.protocol.State.OPEN:
            stale.append(ws)
            continue
        coro = ws.send(jmsg)
        try:
            asyncio.run_coroutine_threadsafe(coro, loop)
        except Exception:
            # Scheduling failed: close the coroutine so it does not emit a
            # "coroutine was never awaited" warning, and drop the client.
            coro.close()
            stale.append(ws)
            logger.debug("ws.send exception: closed")

    # discard() rather than check-then-remove: _handler may remove the same
    # connection on the loop thread in between, which would raise KeyError.
    for ws in stale:
        _connections.discard(ws)
    return True
|
|
|
|
|
|
def connection_count() -> int:
    """Return the number of client connections currently registered."""
    active = len(_connections)
    return active
|