refactor and rewrite for asyncio
This commit is contained in:
@@ -19,10 +19,14 @@ _get_msgs: Optional[Callable[[], Iterable]] = None
|
||||
_verbose = False
|
||||
|
||||
|
||||
async def _handler(websocket, path):
|
||||
async def _handler(websocket, path=None):
|
||||
# Some versions of the websockets library call handler(connection) only;
|
||||
# accept optional path and fall back to websocket.path when missing.
|
||||
global _connections
|
||||
_connections.add(websocket)
|
||||
remote_address = websocket.remote_address
|
||||
remote_address = getattr(websocket, "remote_address", None)
|
||||
if path is None:
|
||||
path = getattr(websocket, "path", None)
|
||||
if _verbose:
|
||||
logger.info("DBG ws_serve: %s: %s", remote_address, path)
|
||||
try:
|
||||
@@ -72,23 +76,36 @@ async def start(host: str, ws_port: int, wss_port: Optional[int] = None, ssl_con
|
||||
|
||||
servers = []
|
||||
# plain WebSocket
|
||||
ws_server = websockets.serve(_handler, host, ws_port, subprotocols=["hbd"])
|
||||
ws_server = websockets.serve(_handler, host, ws_port) #, subprotocols=["hbd"])
|
||||
websockets_logger = logging.getLogger("websockets.server")
|
||||
websockets_logger.setLevel(logging.INFO)
|
||||
servers.append(ws_server)
|
||||
|
||||
# secure WebSocket (optional)
|
||||
if wss_port and ssl_context:
|
||||
wss_server = websockets.serve(_handler, host, wss_port, ssl=ssl_context, subprotocols=["hbd"])
|
||||
wss_server = websockets.serve(_handler, host, wss_port, ssl=ssl_context) #, subprotocols=["hbd"])
|
||||
servers.append(wss_server)
|
||||
|
||||
# await starting of all servers
|
||||
for srv in servers:
|
||||
await srv
|
||||
try:
|
||||
for srv in servers:
|
||||
await srv
|
||||
|
||||
if _verbose:
|
||||
logger.info("WebSocket server started on port %s (wss %s)", ws_port, wss_port)
|
||||
if _verbose:
|
||||
logger.info("WebSocket server started on port %s (wss %s)", ws_port, wss_port)
|
||||
|
||||
# block forever (until loop is stopped or cancelled)
|
||||
await asyncio.Future()
|
||||
# block forever (until loop is stopped or cancelled)
|
||||
await asyncio.Future()
|
||||
except asyncio.CancelledError:
|
||||
logger.info("WebSocket server shutting down...")
|
||||
# Close all active connections
|
||||
for conn in list(_connections):
|
||||
try:
|
||||
await conn.close()
|
||||
except Exception:
|
||||
pass
|
||||
_connections.clear()
|
||||
raise
|
||||
|
||||
|
||||
def broadcast(typ: str, data) -> bool:
|
||||
@@ -98,12 +115,13 @@ def broadcast(typ: str, data) -> bool:
|
||||
connected websockets. Returns False if server was not running.
|
||||
"""
|
||||
global _loop
|
||||
|
||||
if not _loop:
|
||||
return False
|
||||
jmsg = json.dumps({"type": typ, "data": data})
|
||||
to_close = []
|
||||
for ws in list(_connections):
|
||||
if ws.closed:
|
||||
if ws.state != websockets.protocol.State.OPEN:
|
||||
to_close.append(ws)
|
||||
continue
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user