refactor and rewrite for asyncio

This commit is contained in:
2026-02-06 12:34:59 -05:00
parent 700ea8d6a4
commit 4df700e4ef
54 changed files with 772 additions and 4334 deletions
+1 -1
View File
@@ -6,6 +6,6 @@ start moving functionality into the package.
"""
__all__ = ["main", "__version__"]
__version__ = "0.1"
__version__ = "5.0"
from .cli import main
+5 -3
View File
@@ -1,4 +1,5 @@
"""Configuration loader and defaults for hbd."""
import logging
import os
try:
@@ -14,6 +15,8 @@ DEFAULTS = {
"logfile": "/var/log/heartbeat.log",
"logfmt": "text",
"pushsrv": "pushover",
"pushover_token": "",
"pushover_user": "",
"interval": 20,
"grace": 2,
"dyndomains": ["wrede.org"],
@@ -40,14 +43,13 @@ def load_config(path=None):
if os.path.exists(path):
if yaml:
with open(path) as fh:
data = yaml.safe_load(fh) or {}
data = yaml.safe_load(fh)
# only keep known keys
for k, v in data.items():
if k in cfg:
cfg[k] = v
else:
# ignore unknown keys for now
pass
logging.warning("unknown config key %s in %s", k, path)
else:
# yaml not installed: do not attempt to parse; user must ensure defaults
pass
+98 -18
View File
@@ -1,9 +1,9 @@
"""DNS update helper and thread for heartbeat daemon."""
"""DNS update helper and pure asyncio worker for heartbeat daemon."""
from __future__ import annotations
import threading
import subprocess
from subprocess import Popen, PIPE, STDOUT
from typing import Optional
import asyncio
def create_nsupdate_payload(hostname: str, newip: str, dyndomain: str, dnsttl: str = "5") -> str:
@@ -54,38 +54,118 @@ def nsupdate(hostname: str, newip: str, dyndomain: str, nsupdate_bin: str = "/us
return out
def dnsupdatethread(hbdclass, cfg: dict, log: Optional[callable] = None, email: Optional[callable] = None):
"""Thread target: process dns update queue from hbdclass.Host.dnsQ.
async def dns_update_worker(hbdclass, cfg: dict, async_queue=None, log: Optional[callable] = None, email: Optional[callable] = None, loop: Optional[asyncio.AbstractEventLoop] = None):
"""Pure async DNS worker that processes updates from asyncio.Queue.
hbdclass: module with Host class that exposes dnsQ queue
cfg: configuration mapping with 'dyndomains' and 'nsupdate_bin'
log: callable(host, message)
email: callable(subject, message)
Exits when it receives a None sentinel.
"""
if loop is None:
loop = asyncio.get_running_loop()
dnsq = async_queue
if not dnsq:
if log:
try:
await loop.run_in_executor(None, log, None, "dns_update_worker: no queue available")
except Exception:
pass
return
while True:
name, addr = hbdclass.Host.dnsQ.get()
try:
item = await dnsq.get()
except Exception as e:
if log:
try:
await loop.run_in_executor(None, log, None, f"dns_update_worker: error getting item: {e}")
except Exception:
pass
break
if item is None:
break
try:
name, addr = item
except Exception:
try:
dnsq.task_done()
except Exception:
pass
continue
m = f"changed address to {addr}"
for dyndomain in cfg.get("dyndomains", []):
err = nsupdate(name, addr, dyndomain, nsupdate_bin=cfg.get("nsupdate_bin", "/usr/local/bin/nsupdate"))
err = await loop.run_in_executor(None, nsupdate, name, addr, dyndomain, cfg.get("nsupdate_bin", "/usr/local/bin/nsupdate"), cfg.get("rndc_key", "/etc/dhcpc/rndc-key"))
if err:
m += f", DNS update failed: {err}"
if email:
try:
email("error: nsupdate failed", f"{name}.dy.{dyndomain}: {m}")
await loop.run_in_executor(None, email, "error: nsupdate failed", f"{name}.dy.{dyndomain}: {m}")
except Exception:
pass
else:
m += ", DNS updated."
hbdclass.Host.dnsQ.task_done()
try:
dnsq.task_done()
except Exception:
pass
if log:
try:
log(name, m)
await loop.run_in_executor(None, log, name, m)
except Exception:
pass
if log:
try:
await loop.run_in_executor(None, log, None, "dns_update_worker exiting")
except Exception:
pass
def start_dns_thread(hbdclass, cfg: dict, log: Optional[callable] = None, email: Optional[callable] = None) -> threading.Thread:
t = threading.Thread(target=dnsupdatethread, args=(hbdclass, cfg, log, email))
t.daemon = True
t.start()
return t
def start_dns_worker(hbdclass, cfg: dict, log: Optional[callable] = None, email: Optional[callable] = None, loop: Optional[asyncio.AbstractEventLoop] = None):
"""Start the async DNS worker and return the Task.
Replaces Host.dnsQ with an asyncio.Queue wrapped in a thread-safe bridge
so legacy synchronous put() calls from UDP handlers still work.
"""
if loop is None:
loop = asyncio.get_event_loop()
# Create asyncio.Queue and wrap in a bridge for thread-safe puts
async_q = asyncio.Queue()
class _QueueBridge:
"""Thread-safe wrapper around asyncio.Queue for synchronous callers."""
def __init__(self, loop, aq):
self._loop = loop
self._aq = aq
def put(self, item):
"""Thread-safe put that schedules onto event loop."""
try:
# Try to detect if we're in the event loop thread
asyncio.get_running_loop()
# We're in event loop context, use put_nowait directly
self._aq.put_nowait(item)
except RuntimeError:
# We're in a different thread, schedule safely
try:
self._loop.call_soon_threadsafe(self._aq.put_nowait, item)
except Exception:
pass
def task_done(self):
"""Delegate task_done to asyncio.Queue."""
try:
self._aq.task_done()
except Exception:
pass
bridge = _QueueBridge(loop, async_q)
hbdclass.Host.dnsQ = bridge
task = loop.create_task(dns_update_worker(hbdclass, cfg, async_queue=async_q, log=log, email=email, loop=loop))
return task
+380
View File
@@ -0,0 +1,380 @@
"""
host and connection class shared between hbd and
the websit's heartbeat.py
"""
import time
import json
import copy
import queue
num = 0
MAXRTTS = 10
DEBUG = 2
def log(host, m):
if DEBUG:
print("class log: %s %s" % (host, m))
class Connection:
# map of addrs to names
htab = {}
UNKNOWN = "unknown"
UP = "up"
DOWN = "down"
OVERDUE = "overdue"
def __init__(self, host, cid, addr, afam):
self.host = host
self.cid = cid
if addr[0:7] == "::ffff:":
addr = addr[7:]
self.addr = addr
self.afam = afam
self.rtts = [0]
self.lastbeat = time.time()
self.statetime = self.lastbeat
self.deltastatetime = "computed"
self.state = Connection.UNKNOWN
if host:
Connection.htab[addr] = self.host.name
if self.host.isDynDns():
log(self.host.name, "dns update %s" % self.addr)
Host.dnsQ.put((self.host.name, self.addr))
def registerDns(self):
Host.dnsQ.put((self.host.name, self.addr))
def clearstate(self):
d = {}
d["addr"] = ""
d["rtt"] = ""
d["lastbeat"] = ""
d["state"] = ""
d["statetime"] = ""
d["deltastatetime"] = ""
d["rttstate"] = ""
return d
def statedict(self, Null=False):
d = self.clearstate()
now = time.time()
if not Null:
d["addr"] = self.addr
if self.rtts[-1]:
d["rtt"] = "%0.1f" % self.rtts[-1]
elif self.state == Connection.UNKNOWN:
d["rtt"] = ""
else:
d["rtt"] = "?"
d["lastbeat"] = self.lastbeat
if self.state == Connection.OVERDUE:
d["state"] = "<b>%s</b>" % self.state
else:
d["state"] = self.state
if self.state == Connection.UP:
d["rttstate"] = d["rtt"]
elif self.state == Connection.OVERDUE:
d["rttstate"] = ""
else:
d["rttstate"] = d["state"]
d["statetime"] = time.strftime(
"%Y-%m-%d %H:%M:%S", time.localtime(self.statetime)
)
delta = now - self.statetime
if self.state == Connection.UNKNOWN:
d["deltastatetime"] = ""
elif delta > 86400:
# d['deltastatetime'] = time.strftime("%d %H:%M:%S", time.gmtime(delta))
d["deltastatetime"] = "%0.1f days" % (delta / 86400.0)
elif delta > 3600:
# d['deltastatetime'] = time.strftime("%H:%M:%S", time.gmtime(delta))
d["deltastatetime"] = time.strftime("%k:%M hrs", time.gmtime(delta))
# d['deltastatetime'] = "%0.1f hrs" % (delta / 3600.)
elif delta > 60:
# d['deltastatetime'] = time.strftime("%M:%S", time.gmtime(delta))
d["deltastatetime"] = time.strftime("%M:%S mins", time.gmtime(delta))
# d['deltastatetime'] = "%0.1f mins" % (delta / 60.)
else:
# d['deltastatetime'] = time.strftime("%S", time.gmtime(delta))
d["deltastatetime"] = "%i secs" % (delta)
if self.state == Connection.UNKNOWN and now - self.lastbeat > 86400 * 10:
d = self.clearstate()
return d
def headerdict(self, afam):
d = {}
d["addr"] = "%s Addr" % afam
d["rtt"] = "Latencey"
d["lastbeat"] = "Last Contact"
d["state"] = "State"
d["statetime"] = "Last State"
d["rttstate"] = "Reach"
d["deltastatetime"] = "Last State"
return d
def jsons(self):
return json.dumps(self.__dict__)
# set new state, return number of secs in previous state
def newstate(self, state, now, when=0):
self.state = state
delta = now - when
s = delta - self.statetime
self.statetime = delta
return s
def getstate(self):
return self.state
def newaddr(self, addr, rtt, now):
self.lastbeat = now
self.rtts.append(rtt)
if len(self.rtts) > MAXRTTS:
del self.rtts[0]
if self.addr == addr:
r = None
else:
r = "changed from %s to %s" % (self.addr, addr)
try:
del Connection.htab[self.addr]
except:
pass
self.addr = addr
Connection.htab[addr] = self.host.name
if self.host.isDynDns():
Host.dnsQ.put((self.host.name, self.addr))
return r
#
class Host:
# Table of Hosts
hosts = {}
dnsQ = queue.Queue()
def __init__(self, name):
global num
self.name = name
if name:
num += 1
Host.hosts[name] = self
self.num = num
self.dyn = False
self.watched = False
self.upcount = 0
self.interval = 0
self.doesack = -1
self.cmds = []
self.cver = 0
self.connections = {}
self.hdwcounts = [[0, 0], [0, 0], [0, 0]]
def statedict(self):
d = {}
d["name"] = self.name
if self.dyn:
d["name"] += "*"
if self.watched:
d["name"] = "<b>%s</b>" % d["name"]
d["dyn"] = str(self.dyn)
d["ver"] = str(self.cver)
d["num"] = self.num
for c in ["IPv4", "IPv6"]:
if c in self.connections:
cs = self.connections[c].statedict()
else:
cs = ubConnection.statedict(True)
for csv in cs:
d["%s.%s" % (c, csv)] = cs[csv]
return d
def headerdict(self):
d = {}
d["name"] = "Name"
d["dyn"] = "Dyn"
d["ver"] = "Ver"
d["num"] = "??"
for c in ["IPv4", "IPv6"]:
cs = ubConnection.headerdict(c)
for csv in cs:
d["%s.%s" % (c, csv)] = cs[csv]
return d
def registerDns(self):
for af in self.connections:
self.connections[af].registerDns()
def stateinfo(self):
ddict = {}
for d in self.__dict__:
if d == "connections":
cl = []
for c in self.connections:
# dirty ugly hack: fix conn to host backpointer
cld = copy.deepcopy(self.connections[c].__dict__)
cld["host"] = cld["host"].name
cl.append(cld)
ddict[d] = cl
else:
ddict[d] = self.__dict__[d]
return ddict
def jsons(self):
return json.dumps(self.stateinfo())
def setcver(self, cver):
self.cver = cver
def isDynDns(self):
return self.dyn
def isIPv4(self, addr):
if isinstance(addr, tuple):
return addr[0].find(".") > 0
else:
return addr.find(".") > 0
def conndata(self, cid, addr, rtt, now):
if addr[0:7] == "::ffff:":
addr = addr[7:]
if self.isIPv4(addr):
afam = "IPv4"
else:
afam = "IPv6"
if afam not in self.connections:
self.connections[afam] = Connection(self, cid, addr, afam)
conn = self.connections[afam]
res = conn.newaddr(addr, rtt, now)
return conn, res
# called when reloading class from pickle, add new fields here
def fixup(self):
for c in ["IPv4", "IPv6"]:
if c in self.connections:
addr = self.connections[c].addr
if addr[0:7] == "::ffff:":
addr = addr[7:]
self.connections[c].addr = addr
pass
# def dispstate(self):
# if self.state in ["down", "overdue"]:
# state = "<b>%s</b>" % self.state
# elif self.state in ["up", "UP"]:
# state = ""
# for x in list(self.connections.keys()):
# try:
# state += " %5.1f" % (self.connections[x].rtts[-1])
# except:
# state += " %5s" % (self.connections[x].rtts[-1])
# elif self.state in ["unknown", "UNKNOWN"]:
# state = ""
# else:
# state = "%s" % self.state
# return state
def dispstats(self):
if self.doesack != -1:
if self.upcount > 0:
# return "(%0.1f%%) %s %s %s " % ((self.doesack * 100.0) / self.upcount, self.doesack, self.upcount, self.hdwcounts)
r = ""
for v in range(3):
a, u = self.hdwcounts[v]
if (self.upcount - u) != 0:
vs = "%0.0f" % (
100.0 - (((self.doesack - a) * 100.0) / (self.upcount - u))
)
if vs == "0":
vs = ""
else:
vs = "-"
r += '<td align="right">%s</td>' % vs
return r
else:
return "<td>(%s)</td><td></td><td></td>" % (self.doesack)
return '<td align="right">N/A</td><td></td<td></td>>'
hostfields_long = [
"name",
"IPv4.addr",
"IPv4.state",
("IPv4.rtt", 'style="text-align: right;"'),
("IPv4.statetime", 'style="text-align: right;"'),
"IPv6.addr",
"IPv6.state",
("IPv6.rtt", 'style="text-align: right;"'),
("IPv6.statetime", 'style="text-align: right;"'),
"ver",
]
hostfields_short = [
"name",
("IPv4.rttstate", 'style="text-align: right;"'),
("IPv4.deltastatetime", 'style="text-align: right;"'),
("IPv6.rttstate", 'style="text-align: right;"'),
("IPv6.deltastatetime", 'style="text-align: right;"'),
]
def gene(self, tag, v, attrib=None):
if attrib:
a = " %s" % attrib
else:
a = ""
return "<%s%s>%s</%s>" % (tag, a, v, tag)
def htmltable(self, tag, hd, short):
if short:
hostfields = Host.hostfields_short
else:
hostfields = Host.hostfields_long
h = []
for f in hostfields:
if isinstance(f, tuple):
h.append(self.gene(tag, hd[f[0]], f[1]))
else:
h.append(self.gene(tag, hd[f]))
return self.gene("tr", "\n".join(h))
def buildhosttable(self, short=False):
if DEBUG > 1:
print("DBG buildhosttable: start")
res = []
res.append('<table id="ntable" class="sortable">')
res.append(ubHost.htmltable("th", ubHost.headerdict(), short))
hosts_sorted = list(Host.hosts.keys())
if len(hosts_sorted):
hosts_sorted.sort()
for h in hosts_sorted:
res.append(ubHost.htmltable("td", Host.hosts[h].statedict(), short))
res.append("</table>")
if DEBUG > 1:
print("DBG buildhosttable: %s" % res)
return res
def buildmsgtable(self, msgs):
res = []
le = max(40 - len(Host.hosts), 3)
res.append("<h4>Log of Events</h4>")
for m in msgs[len(msgs) - le:]:
res.append("%s<BR>" % m)
return res
# create fake "unbound objects", remove in Python 3.0
ubHost = Host(None)
ubConnection = Connection(None, "", "", "")
+157 -198
View File
@@ -1,20 +1,23 @@
"""HTTP server and handler scaffolds (thin wrappers around http.server)."""
from http import server
"""HTTP server implementation using aiohttp and jinja2."""
import asyncio
import json
import time
import urllib.parse
from urllib3 import request
import os
import logging
from aiohttp import web
from fastapi.templating import Jinja2Templates
import jinja2
class HttpServer(server.ThreadingHTTPServer):
allow_reuse_address = True
logger = logging.getLogger(__name__)
def threaded(self):
pass
def _render_template(html_str: str, **context) -> str:
tmpl = jinja2.Template(html_str)
return tmpl.render(**context)
def make_handler_class(
async def start(
host: str,
port: int,
config,
hbdclass,
msgs_getter,
@@ -28,209 +31,165 @@ def make_handler_class(
get_now=None,
VER="",
):
"""Return a BaseHTTPRequestHandler subclass bound to runtime objects.
"""Start an aiohttp web server and block until cancelled.
`msgs_getter` should be a callable that returns a list-like of messages.
This function is intended to be awaited inside the main asyncio event loop.
"""
templates = Jinja2Templates(directory="templates")
get_now = get_now or (lambda: time.time())
class CustomHandler(server.BaseHTTPRequestHandler):
async def index(request):
res = []
res.append('<!DOCTYPE HTML PUBLIC "-//IETF//DTD HTML 2.0//EN">')
res.append("<html>")
res.append("<head>")
res.append(f"<title>Heartbeat</title>")
if tcss:
res.append(tcss)
res.append("</head>")
res.append('<body BGCOLOR = "#FFFFFF" LINK = "#008000" VLINK = "#008000">')
res.append(f"<H2>Heartbeat status {VER}</h2>")
res += hbdclass.ubHost.buildhosttable()
res += hbdclass.ubHost.buildmsgtable(msgs_getter())
res.append(
"<p> %s (%s)</p>" % (time.strftime("%H:%M:%S", time.localtime(get_now())), config.get("tz", "CET-1CDT"))
)
res.append("</body></html>")
body = "\n".join(res)
return web.Response(text=body, content_type="text/html")
server_version = f"HeartbeatHTTP/{VER}"
async def api_hosts(request):
lst = [hbdclass.Host.hosts[h].jsons() for h in hbdclass.Host.hosts]
return web.json_response(json.loads("[" + ",".join(lst) + "]"))
def version_string(self):
return self.server_version
async def api_messages(request):
lst = msgs_getter()[-30:]
return web.json_response(lst)
def handle(self):
async def cmd(request):
qa = request.rel_url.query
uname = qa.get("h")
ucmd = qa.get("c")
if not ucmd or not uname:
return web.Response(status=400, text="need h= and c= arguments")
if uname not in hbdclass.Host.hosts:
return web.Response(status=400, text=f"h={uname} not found")
hbdclass.Host.hosts[uname].cmds.append(("CMD", {"cmd": urllib.parse.unquote(ucmd)}))
return web.Response(text=f"cmd {uname} queued")
async def drop(request):
qa = request.rel_url.query
uname = qa.get("h")
if not uname:
return web.Response(status=400, text="need h= argument")
if uname not in hbdclass.Host.hosts:
return web.Response(status=400, text=f"h={uname} not found")
if log:
log(uname, "dropped")
del hbdclass.Host.hosts[uname]
return web.Response(text="Done")
async def register(request):
qa = request.rel_url.query
uname = qa.get("h")
if not uname:
return web.Response(status=400, text="need h= argument")
if uname not in hbdclass.Host.hosts:
return web.Response(status=400, text=f"h={uname} not found")
ll = hbdclass.Host.hosts[uname].registerDns()
if log:
log(uname, ll)
return web.Response(text=str(ll))
async def update(request):
qa = request.rel_url.query
uname = urllib.parse.unquote(qa.get("h", ""))
ucode = qa.get("c")
if not ucode or not uname:
return web.Response(status=400, text="need h= and c= arguments")
if uname != "All" and uname not in hbdclass.Host.hosts:
return web.Response(status=400, text=f"h={uname} not found")
if uname != "All":
names = [uname]
else:
names = [n for n in hbdclass.Host.hosts if hbdclass.Host.hosts[n].cver >= 2]
out = []
for n in names:
err = None
try:
return server.BaseHTTPRequestHandler.handle(self)
r = {"csum": None, "code": ucode}
hbdclass.Host.hosts[n].cmds.append(("UPD", r))
except Exception as e:
self.log_error("Request went away: %r", e)
self.close_connection = 1
return
err = str(e)
out.append(f"update started for {n}: {err if err else 'OK'}")
return web.Response(text="\n".join(out))
def do_HEAD(self):
self.setheaders(200)
async def restart(request):
# signal main application to perform restart if needed
# not implemented here - return OK
if log:
log(None, "restart request")
return web.Response(text="restart request")
def setheaders(self, code, headerdict={}):
self.send_response(code)
self.send_header(
"Last-Modified",
time.strftime("%a, %d %b %Y %H:%M:%S GMT", time.gmtime(get_now())),
)
for h in headerdict:
self.send_header(h, headerdict[h])
self.end_headers()
async def live(request):
# render template from templates/live.html using Jinja2
env = jinja2.Environment(loader=jinja2.FileSystemLoader(config.get("templates_dir", "templates")))
host = config.get("hb_host", "localhost")
extra_scripts = config.get("http_extra_scripts", "")
heartbeat_ws_url = f"ws://{host}:{config.get('ws_port', 50005)}/hbd"
tmpl = env.get_template("live.html")
body = tmpl.render(
title="Heartbeat",
header="Heartbeat",
request=request,
heartbeat_ws_url=heartbeat_ws_url,
extra_scripts=extra_scripts,
hosts=[hbdclass.Host.hosts[h].stateinfo() for h in sorted(hbdclass.Host.hosts)],
messages=msgs_getter()[-30:],
)
return web.Response(text=body, content_type="text/html")
def buildhead(self, title="Heartbeat", refresh=None, extras=None):
res = []
res.append('<!DOCTYPE HTML PUBLIC "-//IETF//DTD HTML 2.0//EN">')
res.append("<html>")
res.append("<head>")
res.append("<title>%s</title>" % (title))
if refresh:
res.append("<meta http-equiv = Refresh content = %d>\n" % refresh)
if extras:
res.append(extras)
res.append("</head>")
res.append('<body BGCOLOR = "#FFFFFF" LINK = "#008000" VLINK = "#008000">')
return res
async def static(request):
"""Serve files from the package static directory.
def buildpage(self):
res = self.buildhead(refresh=60, extras=tcss)
res.append("<H2>Heartbeat status %s</h2>" % VER)
res += hbdclass.ubHost.buildhosttable()
res += hbdclass.ubHost.buildmsgtable(msgs_getter())
res.append(
"<p> %s (%s)</p>" % (time.strftime("%H:%M:%S", time.localtime(get_now())), config.get("tz", "CET-1CDT"))
)
res.append("</body></html>")
return res
URL form: /static/<path>
"""
p = request.match_info.get("path", "")
base = os.path.abspath(os.path.join(os.path.dirname(__file__), "static"))
# normalize and prevent directory traversal
target = os.path.abspath(os.path.normpath(os.path.join(base, p)))
if not target.startswith(base + os.sep) and target != base:
return web.Response(status=403, text="Forbidden")
if not os.path.exists(target) or not os.path.isfile(target):
return web.Response(status=404, text="Not Found")
logger.info("serving static file: %s", target)
return web.FileResponse(path=target)
def builderror(self, code, cause, lcause):
res = []
res.append('<!DOCTYPE HTML PUBLIC "-//IETF//DTD HTML 2.0//EN">')
res.append("<html><head>")
res.append("<title>%s %s</title>" % (code, cause))
res.append("</head><body>")
res.append("<h1>%s</h1>" % (cause))
res.append("<p>%s</p>" % lcause)
res.append("<hr>")
res.append(
"<address>hbd (Unix) Server at %s:%s</address>" % (config.get("hbd_host"), config.get("hbd_port"))
)
res.append("</body></html>")
return code, res
app = web.Application()
app.add_routes(
[
web.get("/", index),
web.get("/api/0/hosts", api_hosts),
web.get("/api/0/messages", api_messages),
web.get("/c", cmd),
web.get("/d", drop),
web.get("/n", register),
web.get("/u", update),
web.get("/r", restart),
web.get("/live", live),
web.get("/static/{path:.*}", static),
]
)
def do_GET(self):
xsig = 0
rqAcceptEncoding = self.headers.get("Accept-encoding", {})
headerdict = {"Content-Type": "text/html; charset = ISO-8859-1"}
qr = urllib.parse.urlparse(self.path)
qa = urllib.parse.parse_qs(qr.query)
runner = web.AppRunner(app)
await runner.setup()
site = web.TCPSite(runner, host, port)
await site.start()
if qr.path == "/":
res = self.buildpage()
if verbose:
print(f"HTTP server started on {host}:{port}")
elif qr.path == "/c": # command on host /c?h=melschserver&c=sudo%20ls
uname = qa.get("h", [None])[0]
ucmd = qa.get("c", [None])[0]
if not ucmd or not uname:
code, res = self.builderror(400, "Argument error", "need h= and c= arguments")
elif uname not in hbdclass.Host.hosts:
code, res = self.builderror(400, "Data error", "h=%s not found" % uname)
else:
hbdclass.Host.hosts[uname].cmds.append(("CMD", {"cmd": urllib.parse.unquote(ucmd)}))
res = self.buildhead()
res.append("cmd %s queued for host %s" % (uname, ucmd))
try:
await asyncio.Future()
finally:
await runner.cleanup()
elif qr.path == "/d": # drop host /d?h=melschserver
uname = qa.get("h", [None])[0]
if not uname:
code, res = self.builderror(400, "Argument error", "need h= argument")
if uname not in hbdclass.Host.hosts:
code, res = self.builderror(400, "Data error", "h=%s not found" % uname)
else:
if log:
log(uname, "dropped")
del hbdclass.Host.hosts[uname]
res = self.buildhead()
res.append("Done")
elif qr.path == "/n": # register name
uname = qa.get("h", [None])[0]
if not uname:
code, res = self.builderror(400, "Argument error", "need h= argument")
if uname not in hbdclass.Host.hosts:
code, res = self.builderror(400, "Data error", "h=%s not found" % uname)
else:
ll = hbdclass.Host.hosts[uname].registerDns()
res = self.buildhead()
res.append(ll)
if log:
log(uname, ll)
elif qr.path == "/u": # update
uname = urllib.parse.unquote(qa.get("h", [None])[0])
ucode = qa.get("c", [None])[0]
if not ucode or not uname:
code, res = self.builderror(400, "Argument error", "need h= and c= arguments")
elif uname != "All" and uname not in hbdclass.Host.hosts:
code, res = self.builderror(400, "Data error", "h=%s not found" % uname)
else:
res = self.buildhead()
if uname != "All":
names = [uname]
else:
names = []
for n in hbdclass.Host.hosts:
if hbdclass.Host.hosts[n].cver >= 2: # earliest version that supports update
names.append(n)
for n in names:
err = None
try:
from hbd import proto
# read code from a file name, fallback to sending ucode as data
err = None
# attempt to send update command to host
r = {"csum": None, "code": ucode}
hbdclass.Host.hosts[n].cmds.append(("UPD", r))
except Exception as e:
err = str(e)
res.append("update started for %s: %s<br>" % (n, err if err else "OK"))
res.append("Done")
elif qr.path == "/api/0/hosts": # api access to host table
headerdict = {"Content-Type": "application/json; charset=utf-8"}
lst = []
for h in hbdclass.Host.hosts:
lst.append(hbdclass.Host.hosts[h].jsons())
res = ["[" + ",".join(lst) + "]"]
elif qr.path == "/api/0/messages": # api access to host table
headerdict = {"Content-Type": "application/json; charset=utf-8"}
lst = msgs_getter()[-30:]
res = [json.dumps(lst)]
elif qr.path == "/r": # restart
res = self.buildhead()
res.append("restart request")
xsig = 1 # signal.SIGHUP will be handled by application
if log:
log(None, "restart request")
elif qr.path == "/live": # show live view with websockets
host = config.get("hb_host", "localhost")
extra_scripts = '' # '<script src="/static/js/live.js"></script>'
heartbeat_ws_url = f"ws://{host}:50005/hbd"
res = templates.TemplateResponse(
"live.html ",
{
"title": "Heartbeat",
"header": "Heartbeat",
"heartbeat_ws_url": heartbeat_ws_url,
"extra_scripts": extra_scripts,
},
)
else:
code, res = self.builderror(404, "Not Found", "requested URL was not found on this server.")
if "deflate" in rqAcceptEncoding:
headerdict["Content-Encoding"] = "deflate"
towrite = __import__("zlib").compress("\n".join(res).encode(), 6)
else:
towrite = "\n".join(res)
headerdict["Content-Length"] = len(towrite)
headerdict["Cache-Control"] = "private, must-revalidate, max-age=0"
headerdict["Expires"] = "Thu, 01 Jan 1970 00:00:00 GMT"
self.setheaders(200 if 'res' in locals() else code, headerdict)
self.wfile.write(towrite if isinstance(towrite, bytes) else towrite.encode())
if xsig:
# inform application via setting a flag on the server instance
try:
self.server.xsig = xsig
except Exception:
pass
return CustomHandler
+46
View File
@@ -0,0 +1,46 @@
"""monitor helper and thread for heartbeat daemon."""
from __future__ import annotations
import asyncio
import threading
import subprocess
import time
from subprocess import Popen, PIPE, STDOUT
from typing import Optional
from . import hbdclass
DROPOVERDUE = 7 * 24 * 3600
def checkoverdue(config: dict, hbdclass, log: callable, email: callable, pushmsg: callable, msg_to_websockets: callable):
now = time.time()
for h in list(hbdclass.Host.hosts.keys()):
pmsg = []
for c in hbdclass.Host.hosts[h].connections:
conn = hbdclass.Host.hosts[h].connections[c]
if conn.state == hbdclass.Connection.DOWN:
continue
timeout = hbdclass.Host.hosts[h].interval + config.get("grace", 10)
if conn.state == hbdclass.Connection.UP and (now - conn.lastbeat) > timeout:
conn.newstate(hbdclass.Connection.OVERDUE, now, config.get("grace", 10))
pmsg.append(conn.afam)
if (
conn.state == hbdclass.Connection.OVERDUE and (now - conn.lastbeat) > DROPOVERDUE
):
conn.newstate(hbdclass.Connection.UNKNOWN, conn.lastbeat)
if pmsg != []:
if h in config.get("watchhosts", []):
email("overdue", "%s overdue" % " and ".join(pmsg))
pushmsg("%s %s overdue" % (h, " and ".join(pmsg)))
log(h, "%s overdue" % " and ".join(pmsg))
msg_to_websockets("host", hbdclass.Host.hosts[h].stateinfo())
async def start(
config: dict,
hbdclass: callable,
log=None,
email=None,
pushmsg=None,
msg_to_websockets=None,
):
""" start a monitor loop that checks for overdue hosts every minute """
while True:
await asyncio.sleep(15) # 15 seconds between checks
checkoverdue(config, hbdclass, log, email, pushmsg, msg_to_websockets)
+12 -20
View File
@@ -1,4 +1,5 @@
"""Notification helpers: email, pushover, mattermost, signal and dispatcher."""
import logging
from typing import Optional
import http.client
import urllib.parse
@@ -11,6 +12,7 @@ DEFAULT_PUSHPROVIDERS = ["all", "pushover", "mattermost", "signal"]
# module-level configuration set via setup()
_config = {}
logger = logging.getLogger(__name__)
def setup(cfg: dict):
@@ -27,8 +29,7 @@ def send_email(aemail, smtpserver, sender, subject, body, debug=0):
server.set_debuglevel(1)
server.sendmail(sender, aemail, body)
except Exception as e:
if debug:
print("email send failed:", e)
logger.warning("email send failed: %s", e)
try:
server.quit()
except Exception:
@@ -72,12 +73,10 @@ def pushover(token: str, user: str, msg: str, debug: int = 0) -> bool:
{"Content-type": "application/x-www-form-urlencoded"},
)
r = conn.getresponse()
if debug:
print("pushover response:", r.status, r.reason)
logger.debug("pushover response: %s %s", r.status, r.reason)
return r.status == 200
except Exception as e:
if debug:
print("pushover error:", e)
logger.error("pushover error: %s", e)
return False
@@ -98,12 +97,10 @@ def pushmattermost(host: str, token: str, channel: str, msg: str, username: str
payload["icon_url"] = icon
try:
rc = mm.webhooks.call_webhook(token, payload)
if debug:
print("mattermost rc:", rc)
logger.debug("mattermost rc: %s", rc)
return bool(rc is None or rc == "")
except Exception as e:
if debug:
print("mattermost error:", e)
logger.error("mattermost error: %s", e)
return False
@@ -113,20 +110,16 @@ def pushsignal(signal_cli_bin: str, user: str, recipient: str, msg: str, debug:
Uses subprocess to call signal-cli. Returns True if the command succeeded.
"""
CLI = [signal_cli_bin, "-u", user, "send", "-m", msg, recipient]
if debug:
print("signal cli: ", CLI)
logger.debug("signal cli: %s", CLI)
try:
res = subprocess.run(CLI, capture_output=True)
if res.returncode != 0:
if debug:
print("signal failed:", res.stderr.decode())
logger.error("signal failed: %s". res.stderr.decode())
return False
if debug:
print("signal sent:", res.stdout.decode())
logger.debug("signal sent: %s", res.stdout.decode())
return True
except Exception as e:
if debug:
print("signal exception:", e)
logger.exception("signal exception: %s", e)
return False
@@ -152,8 +145,7 @@ def pushmsg(cfg: dict, msg: str, debug: int = 0):
if p in ("all", "signal"):
ok = pushsignal(cfg.get("signal_cli", "/usr/local/bin/signal-cli"), cfg.get("signal_user", ""), cfg.get("signal_recipient", ""), msg, debug=debug)
results["signal"] = ok
if debug:
print("push results:", results)
logger.debug("push results: %s", results)
return results
+242 -48
View File
@@ -1,38 +1,88 @@
"""Server runtime: starts UDP listener, HTTP server and websocket stubs."""
import asyncio
import logging
import atexit
import time
import signal
import sys
from . import __version__
from . import udp
from . import hbdclass
from . import ws as ws_mod
logger = logging.getLogger(__name__)
msg_to_websockets = ws_mod.broadcast
logf = None
lastfm = ["", "", ""]
# shared runtime collections and helpers
msgs = []
def initlog(logfile):
try:
return open(logfile, "a+")
except Exception as e:
import sys
print("cannot open loffile %s, using STDERR: %s" % (logfile, e))
return sys.stderr
def log(host, m, service=None):
ts = time.time()
s = f"{time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(ts))} {host or ''} {m}"
msgs.append(s)
logger.info(s)
if logf:
try:
logf.write(s + "\n")
logf.flush()
except Exception as e:
logger.warning("failed to write to logfile: %s", e)
msg_to_websockets("message", s)
def cleanup_function(config):
"""This function will be executed upon program exit."""
logger.info("Running cleanup function...")
import pickle
pickfile = config.get("pickfile", "hbd.pickle")
pickf = open(pickfile, "wb")
pick = pickle.Pickler(pickf)
pick.dump(hbdclass.Host.hosts)
pick.dump(msgs)
pick.dump(lastfm)
pickf.close()
logger.info("Cleanup complete.")
async def _run_async(config):
global msgs
loop = asyncio.get_running_loop()
shutdown_event = asyncio.Event()
# shared runtime collections and helpers
msgs = []
# Signal handlers for graceful shutdown
def signal_handler(signum, frame):
sig_name = signal.Signals(signum).name if hasattr(signal, 'Signals') else signum
logger.info(f"Received {sig_name}, initiating shutdown...")
loop.call_soon_threadsafe(shutdown_event.set)
# Register signal handlers
loop.add_signal_handler(signal.SIGINT, signal_handler, signal.SIGINT, None)
loop.add_signal_handler(signal.SIGTERM, signal_handler, signal.SIGTERM, None)
# prepare runtime dependencies
import threading
import time
import hbdclass
from . import hbdclass
from . import http as http_mod
from . import ws as ws_mod
from . import dns as dns_mod
from . import notify as notify_mod
from . import monitor as monitor_mod
notify_mod.setup(config)
def log(host, m, service=None):
ts = time.time()
s = f"{time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(ts))} {host or ''} {m}"
msgs.append(s)
logger.info(s)
email = notify_mod.email
pushmsg = notify_mod.pushmsg_from_config
msg_to_websockets = ws_mod.broadcast
# UDP server endpoint (handler wired to handle_datagram with context)
bind_addr = ("0.0.0.0", config.get("hb_port", 50003))
@@ -46,7 +96,6 @@ async def _run_async(config):
email=email,
pushmsg=pushmsg,
msg_to_websockets=msg_to_websockets,
msgs=msgs,
DEBUG=config.get("debug", 0),
verbose=config.get("verbose", False),
)
@@ -57,32 +106,37 @@ async def _run_async(config):
local_addr=bind_addr,
)
# HTTP server (runs in its own thread)
# HTTP server (asyncio-based via aiohttp)
try:
handler_cls = http_mod.make_handler_class(
config=config,
hbdclass=hbdclass,
msgs_getter=lambda: msgs,
log=log,
email=email,
pushmsg=pushmsg,
msg_to_websockets=msg_to_websockets,
tcss=None,
DEBUG=config.get("debug", 0),
verbose=config.get("verbose", False),
get_now=lambda: time.time(),
VER="",
http_task = asyncio.create_task(
http_mod.start(
host=config.get("hbd_host", ""),
port=config.get("hbd_port", 50004),
config=config,
hbdclass=hbdclass,
msgs_getter=lambda: msgs,
log=log,
email=email,
pushmsg=pushmsg,
msg_to_websockets=msg_to_websockets,
tcss=None,
DEBUG=config.get("debug", 0),
verbose=config.get("verbose", False),
get_now=lambda: time.time(),
VER="",
)
)
serv = http_mod.HttpServer((config.get("hbd_host", ""), config.get("hbd_port", 50004)), handler_cls)
http_thread = threading.Thread(target=serv.serve_forever, daemon=True)
http_thread.start()
logger.info("HTTP server started on %s:%s", config.get("hbd_host", ""), config.get("hbd_port", 50004))
except Exception as e:
logger.exception("failed to start HTTP server: %s", e)
# start dns update thread
dns_mod.start_dns_thread(hbdclass, config, log=log, email=email)
logger.info("dns update thread started")
# start dns update worker (async)
dns_task = None
try:
dns_task = dns_mod.start_dns_worker(hbdclass, config, log=log, email=email, loop=loop)
logger.info("dns update worker started")
except Exception as e:
logger.exception("dns worker failed to start: %s", e)
# Start the websocket servers as a background task
try:
@@ -101,28 +155,168 @@ async def _run_async(config):
except Exception as e:
logger.exception("websocket server failed to start: %s", e)
# Start the monitor thread as a background task
try:
# run forever
await asyncio.Future()
finally:
transport.close()
try:
serv.shutdown()
except Exception:
pass
try:
ws_task.cancel()
except Exception:
pass
monitor_task = asyncio.create_task(
monitor_mod.start(
config=config,
hbdclass=hbdclass,
log=log,
email=email,
pushmsg=pushmsg,
msg_to_websockets=msg_to_websockets,
)
)
logger.info("Monitor task started")
except Exception as e:
logger.exception("monitor task failed to start: %s", e)
try:
# run forever until shutdown event is set
await shutdown_event.wait()
logger.info("Shutdown signal received, stopping services...")
except Exception as e:
logger.exception("Error in main loop: %s", e)
finally:
# Cancel all running tasks
logger.info("Cancelling tasks...")
try:
transport.close()
except Exception as e:
logger.warning("Error closing UDP transport: %s", e)
tasks_to_cancel = [http_task, ws_task, monitor_task]
for task in tasks_to_cancel:
if task:
try:
task.cancel()
logger.debug("Cancelled task: %s", task)
except Exception as e:
logger.warning("Error cancelling task: %s", e)
# Wait for tasks to finish cancellation with timeout
remaining_tasks = [t for t in tasks_to_cancel if t]
if remaining_tasks:
try:
await asyncio.wait_for(asyncio.gather(*remaining_tasks, return_exceptions=True), timeout=2.0)
except asyncio.TimeoutError:
logger.warning("Timeout waiting for tasks to cancel")
except Exception as e:
logger.debug("Exception during task cancellation: %s", e)
# Signal DNS worker to exit and await it
try:
if 'dns_task' in locals() and dns_task:
try:
hbdclass.Host.dnsQ.put(None)
except Exception:
pass
try:
await asyncio.wait_for(dns_task, timeout=2.0)
logger.info("DNS worker finished")
except asyncio.TimeoutError:
logger.warning("Timeout waiting for DNS worker to finish")
dns_task.cancel()
except asyncio.CancelledError:
logger.info("DNS worker was cancelled")
except Exception as e:
logger.warning("Error awaiting DNS worker: %s", e)
finally:
# Clear queue bridge to release any held references
hbdclass.Host.dnsQ = None
except Exception as e:
logger.warning("Error stopping DNS worker: %s", e)
logger.info("All tasks cancelled")
def load_pickled_hosts(config, hbdclass):
"""Load pickled hosts from file, if available."""
global lastfm, msgs
import os
import pickle
pickfile = config.get("pickfile", "hbd.pickle")
dyndnshosts = config.get("dyndnshosts", [])
watchhosts = config.get("watchhosts", [])
drophosts = config.get("drophosts", [])
if 1 and os.path.exists(pickfile):
if config.get("verbose", False):
logger.info("opening pickls %s", pickfile)
pickf = open(pickfile, "rb")
pick = pickle.Unpickler(pickf)
try:
hbdclass.Host.hosts = pick.load()
msgs = pick.load()
try:
lastfm = pick.load()
except:
lastfm = ["", "", ""]
pickf.close()
except Exception as e:
print(("load pickled failed: %s" % e))
os.unlink(pickfile)
hbdclass.Connection.htab = {}
for h in list(hbdclass.Host.hosts.keys()):
hbdclass.Host.hosts[h].dyn = h in dyndnshosts
hbdclass.Host.hosts[h].watched = h in watchhosts
hbdclass.Host.hosts[h].fixup()
for h in drophosts:
if h in hbdclass.Host.hosts:
del hbdclass.Host.hosts[h]
if config.get("verbose", False):
logger.info("%s pickled hosts loaded", len(hbdclass.Host.hosts))
else:
if config.get("verbose", False):
logger.info("no pickled data")
def run(config):
"""Start the hbd service (blocking).
This is a thin wrapper around asyncio.run to host the async services.
Manually manages the event loop to ensure clean shutdown.
"""
global logf
import os
import threading
import time as time_module
logging.basicConfig(level=logging.DEBUG if config.get("debug", 0) > 0 else logging.INFO)
load_pickled_hosts(config, hbdclass)
logf = initlog(logfile=config.get("logfile", "messages.log"))
log(None, f"hbd version {__version__} starting up")
# Create and set the event loop manually
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
asyncio.run(_run_async(config))
loop.run_until_complete(_run_async(config))
except KeyboardInterrupt:
logger.info("Shutting down (KeyboardInterrupt)")
logger.info("Received KeyboardInterrupt, shutting down...")
except Exception as e:
logger.exception("Unhandled exception in main: %s", e)
finally:
cleanup_function(config)
logger.info("hbd shutdown complete")
if logf and logf != sys.stderr:
try:
logf.close()
except Exception:
pass
# Explicitly close the loop
try:
# Cancel all remaining tasks
pending = asyncio.all_tasks(loop)
for task in pending:
task.cancel()
# Run one more cycle to process cancellations
if pending:
loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
except Exception:
pass
finally:
loop.close()
# Exit
os._exit(0)
Binary file not shown.

After

Width:  |  Height:  |  Size: 5.3 KiB

+142
View File
@@ -0,0 +1,142 @@
/* http://www.designcouch.com/home/why/2014/04/23/pure-css-drawer-menu/ */
* {
box-sizing: border-box;
/* adds animation for all transitions */
transition: .25s ease-in-out;
/* margin: 0;
padding: 0; */
/* text-size-adjust: none; */
}
/* Makes sure that everything is 100% height */
html,
body {
height: 100%;
overflow: hidden;
color:#303030;
background:#fafafa top left repeat-y;
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", "Roboto", "Oxygen", "Ubuntu", "Cantarell", "Fira Sans", "Droid Sans", "Helvetica Neue", "Helvetica", "Arial", sans-serif;
font-size:100%;
margin: 0;
}
#drawer-toggle {
position: absolute;
opacity: 0;
}
#drawer ul a {
display: block;
padding: 10px;
color: #c7c7c7;
text-decoration: none;
}
#drawer-toggle-label {
user-select: none;
left: 0px;
height: 50px;
width: 50px;
display: block;
position: fixed;
color: rgb(242, 242, 242);
background: rgba(255, 255, 255, .0);
z-index: 1;
}
/* adds our "hamburger" menu icon */
#drawer-toggle-label:before {
content: '';
display: block;
position: absolute;
height: 2px;
width: 24px;
background: #8d8d8d;
left: 13px;
top: 18px;
box-shadow: 0 6px 0 #8d8d8d, 0 12px 0 #8d8d8d;
}
header {
width: 100%;
position: fixed;
left: 0px;
background: #efefef;
padding: 10px 10px 10px 50px;
font-size: 30px;
line-height: 30px;
z-index: 0;
}
/* drawer menu pane - note the 0px width */
#drawer {
position: fixed;
top: 0;
width: 150px;
left: -150px;
height: 100%;
background: #2f2f2f;
overflow-x: hidden;
overflow-y: scroll;
padding: 0px;
}
@media all and (min-resolution: 150dpi) {
header {
font-size: 30px;
/* line-height: 45px; */
}
#drawer {
font-size: 120%;
}
/* body {
background-color: lightyellow;
} */
}
/* actual page content pane */
#content {
margin-left: 0px;
margin-top: 30px;
/* width: 100%; */
height: calc(100% - 50px);
overflow-x: hidden;
overflow-y: scroll;
padding: 20px;
flex: auto;
}
/* checked styles (menu open state) */
#drawer-toggle:checked ~ #drawer-toggle-label {
height: 100%;
width: calc(100% - 150px);
color: rgb(242, 242, 242);
background: rgba(255, 255, 255, .8);
}
#drawer-toggle:checked ~ #drawer-toggle-label,
#drawer-toggle:checked ~ header {
left: 150px;
}
#drawer-toggle:checked ~ #drawer {
left: 0px;
}
#drawer-toggle:checked ~ #content {
margin-left: 150px;
}
#copyright {
font-size: 9px;
float: left;
}
-2
View File
@@ -71,7 +71,6 @@ def handle_datagram(msg: dict, addr, transport, ctx: dict):
- email: callable(subject, message)
- pushmsg: callable(message)
- msg_to_websockets: callable(typ, data)
- msgs: list for storing message strings
- DEBUG, verbose
"""
if not msg:
@@ -83,7 +82,6 @@ def handle_datagram(msg: dict, addr, transport, ctx: dict):
email = ctx.get("email")
pushmsg = ctx.get("pushmsg")
msg_to_websockets = ctx.get("msg_to_websockets")
msgs = ctx.get("msgs")
DEBUG = ctx.get("DEBUG", 0)
verbose = ctx.get("verbose", False)
+29 -11
View File
@@ -19,10 +19,14 @@ _get_msgs: Optional[Callable[[], Iterable]] = None
_verbose = False
async def _handler(websocket, path):
async def _handler(websocket, path=None):
# Some versions of the websockets library call handler(connection) only;
# accept optional path and fall back to websocket.path when missing.
global _connections
_connections.add(websocket)
remote_address = websocket.remote_address
remote_address = getattr(websocket, "remote_address", None)
if path is None:
path = getattr(websocket, "path", None)
if _verbose:
logger.info("DBG ws_serve: %s: %s", remote_address, path)
try:
@@ -72,23 +76,36 @@ async def start(host: str, ws_port: int, wss_port: Optional[int] = None, ssl_con
servers = []
# plain WebSocket
ws_server = websockets.serve(_handler, host, ws_port, subprotocols=["hbd"])
ws_server = websockets.serve(_handler, host, ws_port) #, subprotocols=["hbd"])
websockets_logger = logging.getLogger("websockets.server")
websockets_logger.setLevel(logging.INFO)
servers.append(ws_server)
# secure WebSocket (optional)
if wss_port and ssl_context:
wss_server = websockets.serve(_handler, host, wss_port, ssl=ssl_context, subprotocols=["hbd"])
wss_server = websockets.serve(_handler, host, wss_port, ssl=ssl_context) #, subprotocols=["hbd"])
servers.append(wss_server)
# await starting of all servers
for srv in servers:
await srv
try:
for srv in servers:
await srv
if _verbose:
logger.info("WebSocket server started on port %s (wss %s)", ws_port, wss_port)
if _verbose:
logger.info("WebSocket server started on port %s (wss %s)", ws_port, wss_port)
# block forever (until loop is stopped or cancelled)
await asyncio.Future()
# block forever (until loop is stopped or cancelled)
await asyncio.Future()
except asyncio.CancelledError:
logger.info("WebSocket server shutting down...")
# Close all active connections
for conn in list(_connections):
try:
await conn.close()
except Exception:
pass
_connections.clear()
raise
def broadcast(typ: str, data) -> bool:
@@ -98,12 +115,13 @@ def broadcast(typ: str, data) -> bool:
connected websockets. Returns False if server was not running.
"""
global _loop
if not _loop:
return False
jmsg = json.dumps({"type": typ, "data": data})
to_close = []
for ws in list(_connections):
if ws.closed:
if ws.state != websockets.protocol.State.OPEN:
to_close.append(ws)
continue
try: