0543266c92
- Restructuring of the project directory into client and server components
- Renaming of modules and classes to better reflect their purpose and functionality
- Moving common utilities and configurations to a shared location
- Updating import statements to reflect the new structure
- Adding new documentation files that clarify various aspects of the project
- Removing deprecated or unused code to streamline the codebase
- Ensuring that all existing functionality is preserved after the refactoring
152 lines
4.9 KiB
Python
152 lines
4.9 KiB
Python
"""WebSocket server and broadcast helpers for hbd.
|
|
|
|
Provides an asyncio-based WebSocket server and a thread-safe broadcast
|
|
function that other threads or synchronous code can call.
|
|
"""
|
|
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
from typing import Callable, Iterable, Optional
|
|
|
|
import websockets
|
|
|
|
logger = logging.getLogger(__name__)

# --- Module-level server state -------------------------------------------
# Connections registered by ``_handler``; ``broadcast`` sends to these.
_connections: set = set()
# Event loop captured by ``start``; ``broadcast`` schedules sends onto it.
_loop: Optional[asyncio.AbstractEventLoop] = None
# Snapshot providers set by ``start``; ``_handler`` calls them for each new
# client to send "host" and "message" frames respectively.
_get_hosts: Optional[Callable[[], Iterable]] = None
_get_msgs: Optional[Callable[[], Iterable]] = None
# When True, ``_handler`` logs incoming client frames at DEBUG level.
_verbose = False
|
|
|
|
|
|
async def _handler(websocket, path=None):
    """Serve one WebSocket client for the lifetime of its connection.

    Registers *websocket* in ``_connections`` so ``broadcast`` can reach it,
    then sends an initial snapshot:

    * every item returned by ``_get_hosts()`` as a ``{"type": "host"}`` frame;
    * the last 100 items returned by ``_get_msgs()`` as
      ``{"type": "message"}`` frames.

    It then drains incoming frames (ignored, optionally logged) until the
    client disconnects. On exit the connection is always unregistered and
    closed.

    Args:
        websocket: Server-side connection object supplied by
            ``websockets.serve``.
        path: Request path. Older ``websockets`` releases pass it as a second
            argument; newer ones call the handler with the connection only,
            in which case it is looked up on the connection object.
    """
    # Registered before the snapshot is sent, so a concurrent broadcast may
    # interleave with the snapshot rather than being missed.
    _connections.add(websocket)
    remote_address = websocket.remote_address
    if path is None:
        # Legacy connections expose ``.path``; the newer asyncio
        # implementation exposes it as ``.request.path``.
        path = getattr(websocket, "path", None)
        if path is None:
            path = getattr(getattr(websocket, "request", None), "path", None)
    logger.info("WebSocket connection from %s: %s", remote_address, path)
    try:
        # send initial hosts
        if _get_hosts:
            try:
                hosts = list(_get_hosts())
                logger.debug("Sending %d hosts to new WebSocket client", len(hosts))
                for h in hosts:
                    await websocket.send(json.dumps({"type": "host", "data": h}))
            except websockets.exceptions.ConnectionClosed:
                # Client left mid-snapshot: a normal close, not an error.
                # Let the outer handler log it and stop sending.
                raise
            except Exception as e:
                logger.error("Error sending initial hosts: %s", e, exc_info=True)

        # send recent messages (bounded to the most recent 100)
        if _get_msgs:
            try:
                msgs = list(_get_msgs())[-100:]
                logger.debug("Sending %d recent messages to new WebSocket client", len(msgs))
                for m in msgs:
                    await websocket.send(json.dumps({"type": "message", "data": m}))
            except websockets.exceptions.ConnectionClosed:
                raise
            except Exception as e:
                logger.error("Error sending initial messages: %s", e, exc_info=True)

        # Keep the connection open until the client disconnects. Incoming
        # frames (e.g. a client 'hello') carry no meaning yet and are ignored.
        async for frame in websocket:
            if _verbose:
                logger.debug("received ws data: %s", frame)

    except websockets.exceptions.ConnectionClosed as e:
        # Base class of ConnectionClosedOK and ConnectionClosedError.
        logger.info("WebSocket closed from %s: %r", remote_address, e)
    except Exception as e:
        logger.exception("WebSocket handler exception from %s: %s", remote_address, e)
    finally:
        logger.debug("Removing WebSocket connection from %s", remote_address)
        _connections.discard(websocket)
        # close() returns promptly on an already-closed connection; unlike
        # wait_closed() it cannot leave the handler parked on a socket that
        # is still open after an unexpected error.
        await websocket.close()
|
|
|
|
|
|
async def start(
    host: str,
    ws_port: int,
    wss_port: Optional[int] = None,
    ssl_context=None,
    get_hosts: Optional[Callable] = None,
    get_msgs: Optional[Callable] = None,
    verbose: bool = False,
):
    """Start WebSocket servers and block until cancelled.

    This is intended to be awaited inside the main asyncio event loop.
    If `wss_port` and `ssl_context` are provided, a WSS server will also be
    started.

    Args:
        host: Interface to bind both servers to.
        ws_port: Port for the plain ``ws://`` server.
        wss_port: Port for the ``wss://`` server; used only together with
            *ssl_context*.
        ssl_context: TLS context for the ``wss://`` server.
        get_hosts: Zero-argument callable whose items are sent to each new
            client as ``"host"`` frames.
        get_msgs: Zero-argument callable whose recent items are sent to each
            new client as ``"message"`` frames.
        verbose: Enable DEBUG logging for incoming frames and for the
            ``websockets.server`` logger.

    Every server that started is closed when this coroutine exits. That covers
    cancellation and a later server failing to start. On exit the module's
    loop reference is cleared, so ``broadcast`` returns False afterwards.
    """
    global _loop, _get_hosts, _get_msgs, _verbose
    from contextlib import AsyncExitStack

    loop = asyncio.get_running_loop()
    _loop = loop
    _get_hosts = get_hosts
    _get_msgs = get_msgs
    _verbose = verbose

    websockets_logger = logging.getLogger("websockets.server")
    websockets_logger.setLevel(logging.DEBUG if verbose else logging.INFO)

    try:
        async with AsyncExitStack() as stack:
            # Each server is registered on the stack as soon as it is
            # listening, so a failure to start a later one still closes the
            # ones already bound.
            # regular WebSocket
            await stack.enter_async_context(
                websockets.serve(_handler, host, ws_port)  # , subprotocols=["hbd"])
            )
            # secure WebSocket (optional)
            if wss_port and ssl_context:
                await stack.enter_async_context(
                    websockets.serve(_handler, host, wss_port, ssl=ssl_context)
                )  # , subprotocols=["hbd"])

            logger.info(
                "WebSocket server(s) started on port %s (wss %s)", ws_port, wss_port
            )

            # block forever (until loop is stopped or cancelled)
            await asyncio.Future()
    finally:
        # Only clear the loop if a later start() has not replaced it.
        if _loop is loop:
            _loop = None
|
|
|
|
|
|
def broadcast(typ: str, data) -> bool:
    """Thread-safe broadcast helper.

    Serializes ``{"type": typ, "data": data}`` to JSON. It then schedules a
    send of that text to every open connection on the server's event loop.
    The sends are only scheduled, never waited on. A send that later fails
    is logged at DEBUG level.

    Connections that are no longer open, or whose send could not be
    scheduled, are dropped from the connection set.

    Args:
        typ: Value stored under the frame's ``"type"`` key.
        data: JSON-serializable payload stored under ``"data"``.

    Returns:
        False if the server is not running: no loop has been recorded, or
        that loop is closed. True once the sends have been scheduled.

    Raises:
        TypeError: If *data* is not JSON-serializable (from ``json.dumps``).
    """
    loop = _loop  # read once; another thread may reset the global
    if loop is None or loop.is_closed():
        return False
    jmsg = json.dumps({"type": typ, "data": data})
    to_close = []
    for ws in list(_connections):  # snapshot: the loop thread mutates the set
        if ws.state != websockets.protocol.State.OPEN:
            to_close.append(ws)
            continue
        fut = _submit_threadsafe(ws.send(jmsg), loop)
        if fut is None:
            to_close.append(ws)
            logger.debug("ws.send exception: closed")
        else:
            fut.add_done_callback(_log_broadcast_send_failure)
    for ws in to_close:
        _submit_threadsafe(ws.wait_closed(), loop)
        # discard(): _handler may remove the same connection concurrently
        # from the loop thread, so a check-then-remove could raise KeyError.
        _connections.discard(ws)
    return True


def _submit_threadsafe(coro, loop):
    """Submit *coro* to *loop* from any thread; return its future, or None.

    If submission fails (e.g. the loop has been closed), the coroutine is
    closed so it does not trigger a "coroutine was never awaited" warning.
    """
    try:
        return asyncio.run_coroutine_threadsafe(coro, loop)
    except Exception:
        coro.close()
        return None


def _log_broadcast_send_failure(fut):
    """Done-callback for broadcast sends: log at DEBUG if the send raised."""
    if fut.cancelled():
        return
    exc = fut.exception()
    if exc is not None:
        logger.debug("ws.send failed during broadcast: %r", exc)
|
|
|
|
|
|
def connection_count() -> int:
    """Return the number of WebSocket connections currently registered.

    A connection counts from the moment its handler registers it until it
    is removed, either by the handler on exit or by ``broadcast`` when found
    not open.
    """
    registered = _connections
    return len(registered)
|