refactor
This commit is contained in:
@@ -0,0 +1,125 @@
|
||||
"""WebSocket server and broadcast helpers for hbd.
|
||||
|
||||
Provides an asyncio-based WebSocket server and a thread-safe broadcast
|
||||
function that other threads or synchronous code can call.
|
||||
"""
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from typing import Callable, Iterable, Optional
|
||||
|
||||
import websockets
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_connections = set()
|
||||
_loop: Optional[asyncio.AbstractEventLoop] = None
|
||||
_get_hosts: Optional[Callable[[], Iterable]] = None
|
||||
_get_msgs: Optional[Callable[[], Iterable]] = None
|
||||
_verbose = False
|
||||
|
||||
|
||||
async def _handler(websocket, path):
|
||||
global _connections
|
||||
_connections.add(websocket)
|
||||
remote_address = websocket.remote_address
|
||||
if _verbose:
|
||||
logger.info("DBG ws_serve: %s: %s", remote_address, path)
|
||||
try:
|
||||
# send initial hosts
|
||||
if _get_hosts:
|
||||
for h in _get_hosts():
|
||||
jmsg = json.dumps({"type": "host", "data": h})
|
||||
await websocket.send(jmsg)
|
||||
# send recent messages
|
||||
if _get_msgs:
|
||||
for m in list(_get_msgs())[-100:]:
|
||||
jmsg = json.dumps({"type": "message", "data": m})
|
||||
await websocket.send(jmsg)
|
||||
|
||||
# keep connection open until client disconnects
|
||||
async for _ in websocket:
|
||||
# we don't expect meaningful incoming messages besides the initial
|
||||
# client 'hello' that some clients send; ignore for now
|
||||
if _verbose:
|
||||
logger.debug("received ws data: %s", _)
|
||||
|
||||
except (websockets.exceptions.ConnectionClosedOK, websockets.exceptions.ConnectionClosedError) as e:
|
||||
if _verbose:
|
||||
logger.info("ws closed: %r", e)
|
||||
except Exception as e:
|
||||
logger.exception("ws handler exception: %s", e)
|
||||
finally:
|
||||
try:
|
||||
_connections.remove(websocket)
|
||||
except KeyError:
|
||||
pass
|
||||
await websocket.wait_closed()
|
||||
|
||||
|
||||
async def start(host: str, ws_port: int, wss_port: Optional[int] = None, ssl_context=None, get_hosts: Optional[Callable] = None, get_msgs: Optional[Callable] = None, verbose: bool = False):
|
||||
"""Start WebSocket servers and block until cancelled.
|
||||
|
||||
This is intended to be awaited inside the main asyncio event loop.
|
||||
If `wss_port` and `ssl_context` are provided, a WSS server will also be
|
||||
started.
|
||||
"""
|
||||
global _loop, _get_hosts, _get_msgs, _verbose
|
||||
_loop = asyncio.get_running_loop()
|
||||
_get_hosts = get_hosts
|
||||
_get_msgs = get_msgs
|
||||
_verbose = verbose
|
||||
|
||||
servers = []
|
||||
# plain WebSocket
|
||||
ws_server = websockets.serve(_handler, host, ws_port, subprotocols=["hbd"])
|
||||
servers.append(ws_server)
|
||||
|
||||
# secure WebSocket (optional)
|
||||
if wss_port and ssl_context:
|
||||
wss_server = websockets.serve(_handler, host, wss_port, ssl=ssl_context, subprotocols=["hbd"])
|
||||
servers.append(wss_server)
|
||||
|
||||
# await starting of all servers
|
||||
for srv in servers:
|
||||
await srv
|
||||
|
||||
if _verbose:
|
||||
logger.info("WebSocket server started on port %s (wss %s)", ws_port, wss_port)
|
||||
|
||||
# block forever (until loop is stopped or cancelled)
|
||||
await asyncio.Future()
|
||||
|
||||
|
||||
def broadcast(typ: str, data) -> bool:
|
||||
"""Thread-safe broadcast helper.
|
||||
|
||||
Schedules coroutine(s) on the running loop to send message to all
|
||||
connected websockets. Returns False if server was not running.
|
||||
"""
|
||||
global _loop
|
||||
if not _loop:
|
||||
return False
|
||||
jmsg = json.dumps({"type": typ, "data": data})
|
||||
to_close = []
|
||||
for ws in list(_connections):
|
||||
if ws.closed:
|
||||
to_close.append(ws)
|
||||
continue
|
||||
try:
|
||||
asyncio.run_coroutine_threadsafe(ws.send(jmsg), _loop)
|
||||
except Exception:
|
||||
to_close.append(ws)
|
||||
logger.debug("ws.send exception: closed")
|
||||
for ws in to_close:
|
||||
try:
|
||||
asyncio.run_coroutine_threadsafe(ws.wait_closed(), _loop)
|
||||
except Exception:
|
||||
pass
|
||||
if ws in _connections:
|
||||
_connections.remove(ws)
|
||||
return True
|
||||
|
||||
|
||||
def connection_count() -> int:
|
||||
return len(_connections)
|
||||
Reference in New Issue
Block a user