This commit is contained in:
2026-02-04 12:45:35 -05:00
parent 0a27b763f7
commit 700ea8d6a4
51 changed files with 4715 additions and 2904 deletions
-1342
View File
File diff suppressed because it is too large Load Diff
+11
View File
@@ -0,0 +1,11 @@
"""hbd package - scaffolding for heartbeat daemon
This package contains the refactored modules for the original monolithic
`hbd` script. The initial implementation contains small scaffolds so you can
start moving functionality into the package.
"""
__all__ = ["main", "__version__"]
__version__ = "0.1"
from .cli import main
+45
View File
@@ -0,0 +1,45 @@
"""Command line interface for hbd package."""
import argparse
from .config import load_config
from .server import run as run_server
PUSHSRVS = ["all", "pushover", "mattermost"]
def build_parser():
parser = argparse.ArgumentParser(
prog="hbd",
description="HeartBeatDaemon - Wait for heartbeat messages and act on them (or their absence)",
formatter_class=argparse.RawDescriptionHelpFormatter,
)
parser.add_argument("-c", "--config", dest="configfile", help="Config file path (YAML)")
parser.add_argument("-f", "--foreground", action="store_true", help="Run in foreground")
parser.add_argument("-v", "--verbose", action="store_true", help="Verbose output")
parser.add_argument("-p", "--pushsrv", dest="pushsrv", choices=PUSHSRVS, help="Push service to use")
parser.add_argument("-x", "--debug", action="count", default=0, help="Increase debug level")
return parser
def main(argv=None):
parser = build_parser()
args = parser.parse_args(argv)
config = load_config(args.configfile)
# Apply CLI overrides
if args.foreground:
config["foreground"] = True
if args.verbose:
config["verbose"] = True
if args.pushsrv:
config["pushsrv"] = args.pushsrv
if args.debug:
config.setdefault("debug", 0)
config["debug"] += args.debug
run_server(config)
if __name__ == "__main__":
main()
+54
View File
@@ -0,0 +1,54 @@
"""Configuration loader and defaults for hbd."""
import os
try:
import yaml
except Exception:
yaml = None
DEFAULTS = {
"hb_port": 50003,
"hbd_port": 50004,
"hbd_host": "",
"pickfile": "/tmp/hb.pick",
"logfile": "/var/log/heartbeat.log",
"logfmt": "text",
"pushsrv": "pushover",
"interval": 20,
"grace": 2,
"dyndomains": ["wrede.org"],
"watchhosts": [],
"dyndnshosts": [],
"drophosts": [],
"nsupdate_bin": "/usr/bin/nsupdate",
"foreground": False,
"verbose": False,
"debug": 0,
}
def load_config(path=None):
"""Load configuration from a YAML file and merge with defaults.
If YAML is not available or the file does not exist, defaults are returned.
"""
cfg = DEFAULTS.copy()
if not path:
# default path (~/.hb.yaml)
path = os.path.join(os.path.expanduser("~"), ".hb.yaml")
if os.path.exists(path):
if yaml:
with open(path) as fh:
data = yaml.safe_load(fh) or {}
# only keep known keys
for k, v in data.items():
if k in cfg:
cfg[k] = v
else:
# ignore unknown keys for now
pass
else:
# yaml not installed: do not attempt to parse; user must ensure defaults
pass
return cfg
+91
View File
@@ -0,0 +1,91 @@
"""DNS update helper and thread for heartbeat daemon."""
from __future__ import annotations
import threading
import subprocess
from subprocess import Popen, PIPE, STDOUT
from typing import Optional
def create_nsupdate_payload(hostname: str, newip: str, dyndomain: str, dnsttl: str = "5") -> str:
D = {"domain": dyndomain, "fqdn": f"{hostname}.dy.{dyndomain}", "dnsttl": dnsttl, "newip": newip, "ts": __import__("time").strftime("%Y-%m-%d.%H:%M:%S", __import__("time").gmtime())}
if ":" in newip:
nsup = (
"""update delete %(fqdn)s AAAA
update add %(fqdn)s %(dnsttl)s AAAA %(newip)s
update delete %(fqdn)s TXT
update add %(fqdn)s %(dnsttl)s TXT "Created: %(ts)s"
send
answer
""" % D
)
else:
nsup = (
"""update delete %(fqdn)s A
update add %(fqdn)s %(dnsttl)s A %(newip)s
update delete %(fqdn)s TXT
update add %(fqdn)s %(dnsttl)s TXT "Created: %(ts)s"
send
answer
""" % D
)
return nsup
def nsupdate(hostname: str, newip: str, dyndomain: str, nsupdate_bin: str = "/usr/local/bin/nsupdate", rndc_key: str = "/etc/dhcpc/rndc-key") -> Optional[str]:
"""Perform DNS update via nsupdate command.
Returns None on success, else returns combined stdout/stderr as a string.
"""
nsup = create_nsupdate_payload(hostname, newip, dyndomain)
cmd = [nsupdate_bin, "-k", rndc_key, "-v"]
try:
p = Popen(cmd, shell=False, bufsize=0, stdin=PIPE, stdout=PIPE, stderr=STDOUT)
except OSError as e:
return f"nsupdate: execution failed: {e}"
except Exception as e:
return f"nsupdate: some error occured: {e}"
(output, err) = p.communicate(nsup.encode())
out = output.decode() if output else ""
if out.find("status: NOERROR") >= 0:
return None
return out
def dnsupdatethread(hbdclass, cfg: dict, log: Optional[callable] = None, email: Optional[callable] = None):
"""Thread target: process dns update queue from hbdclass.Host.dnsQ.
hbdclass: module with Host class that exposes dnsQ queue
cfg: configuration mapping with 'dyndomains' and 'nsupdate_bin'
log: callable(host, message)
email: callable(subject, message)
"""
while True:
name, addr = hbdclass.Host.dnsQ.get()
m = f"changed address to {addr}"
for dyndomain in cfg.get("dyndomains", []):
err = nsupdate(name, addr, dyndomain, nsupdate_bin=cfg.get("nsupdate_bin", "/usr/local/bin/nsupdate"))
if err:
m += f", DNS update failed: {err}"
if email:
try:
email("error: nsupdate failed", f"{name}.dy.{dyndomain}: {m}")
except Exception:
pass
else:
m += ", DNS updated."
hbdclass.Host.dnsQ.task_done()
if log:
try:
log(name, m)
except Exception:
pass
def start_dns_thread(hbdclass, cfg: dict, log: Optional[callable] = None, email: Optional[callable] = None) -> threading.Thread:
t = threading.Thread(target=dnsupdatethread, args=(hbdclass, cfg, log, email))
t.daemon = True
t.start()
return t
+236
View File
@@ -0,0 +1,236 @@
"""HTTP server and handler scaffolds (thin wrappers around http.server)."""
from http import server
import json
import time
import urllib.parse
from urllib3 import request
from fastapi.templating import Jinja2Templates
class HttpServer(server.ThreadingHTTPServer):
allow_reuse_address = True
def threaded(self):
pass
def make_handler_class(
config,
hbdclass,
msgs_getter,
log=None,
email=None,
pushmsg=None,
msg_to_websockets=None,
tcss=None,
DEBUG=0,
verbose=False,
get_now=None,
VER="",
):
"""Return a BaseHTTPRequestHandler subclass bound to runtime objects.
`msgs_getter` should be a callable that returns a list-like of messages.
"""
templates = Jinja2Templates(directory="templates")
get_now = get_now or (lambda: time.time())
class CustomHandler(server.BaseHTTPRequestHandler):
server_version = f"HeartbeatHTTP/{VER}"
def version_string(self):
return self.server_version
def handle(self):
try:
return server.BaseHTTPRequestHandler.handle(self)
except Exception as e:
self.log_error("Request went away: %r", e)
self.close_connection = 1
return
def do_HEAD(self):
self.setheaders(200)
def setheaders(self, code, headerdict={}):
self.send_response(code)
self.send_header(
"Last-Modified",
time.strftime("%a, %d %b %Y %H:%M:%S GMT", time.gmtime(get_now())),
)
for h in headerdict:
self.send_header(h, headerdict[h])
self.end_headers()
def buildhead(self, title="Heartbeat", refresh=None, extras=None):
res = []
res.append('<!DOCTYPE HTML PUBLIC "-//IETF//DTD HTML 2.0//EN">')
res.append("<html>")
res.append("<head>")
res.append("<title>%s</title>" % (title))
if refresh:
res.append("<meta http-equiv = Refresh content = %d>\n" % refresh)
if extras:
res.append(extras)
res.append("</head>")
res.append('<body BGCOLOR = "#FFFFFF" LINK = "#008000" VLINK = "#008000">')
return res
def buildpage(self):
res = self.buildhead(refresh=60, extras=tcss)
res.append("<H2>Heartbeat status %s</h2>" % VER)
res += hbdclass.ubHost.buildhosttable()
res += hbdclass.ubHost.buildmsgtable(msgs_getter())
res.append(
"<p> %s (%s)</p>" % (time.strftime("%H:%M:%S", time.localtime(get_now())), config.get("tz", "CET-1CDT"))
)
res.append("</body></html>")
return res
def builderror(self, code, cause, lcause):
res = []
res.append('<!DOCTYPE HTML PUBLIC "-//IETF//DTD HTML 2.0//EN">')
res.append("<html><head>")
res.append("<title>%s %s</title>" % (code, cause))
res.append("</head><body>")
res.append("<h1>%s</h1>" % (cause))
res.append("<p>%s</p>" % lcause)
res.append("<hr>")
res.append(
"<address>hbd (Unix) Server at %s:%s</address>" % (config.get("hbd_host"), config.get("hbd_port"))
)
res.append("</body></html>")
return code, res
def do_GET(self):
xsig = 0
rqAcceptEncoding = self.headers.get("Accept-encoding", {})
headerdict = {"Content-Type": "text/html; charset = ISO-8859-1"}
qr = urllib.parse.urlparse(self.path)
qa = urllib.parse.parse_qs(qr.query)
if qr.path == "/":
res = self.buildpage()
elif qr.path == "/c": # command on host /c?h=melschserver&c=sudo%20ls
uname = qa.get("h", [None])[0]
ucmd = qa.get("c", [None])[0]
if not ucmd or not uname:
code, res = self.builderror(400, "Argument error", "need h= and c= arguments")
elif uname not in hbdclass.Host.hosts:
code, res = self.builderror(400, "Data error", "h=%s not found" % uname)
else:
hbdclass.Host.hosts[uname].cmds.append(("CMD", {"cmd": urllib.parse.unquote(ucmd)}))
res = self.buildhead()
res.append("cmd %s queued for host %s" % (uname, ucmd))
elif qr.path == "/d": # drop host /d?h=melschserver
uname = qa.get("h", [None])[0]
if not uname:
code, res = self.builderror(400, "Argument error", "need h= argument")
if uname not in hbdclass.Host.hosts:
code, res = self.builderror(400, "Data error", "h=%s not found" % uname)
else:
if log:
log(uname, "dropped")
del hbdclass.Host.hosts[uname]
res = self.buildhead()
res.append("Done")
elif qr.path == "/n": # register name
uname = qa.get("h", [None])[0]
if not uname:
code, res = self.builderror(400, "Argument error", "need h= argument")
if uname not in hbdclass.Host.hosts:
code, res = self.builderror(400, "Data error", "h=%s not found" % uname)
else:
ll = hbdclass.Host.hosts[uname].registerDns()
res = self.buildhead()
res.append(ll)
if log:
log(uname, ll)
elif qr.path == "/u": # update
uname = urllib.parse.unquote(qa.get("h", [None])[0])
ucode = qa.get("c", [None])[0]
if not ucode or not uname:
code, res = self.builderror(400, "Argument error", "need h= and c= arguments")
elif uname != "All" and uname not in hbdclass.Host.hosts:
code, res = self.builderror(400, "Data error", "h=%s not found" % uname)
else:
res = self.buildhead()
if uname != "All":
names = [uname]
else:
names = []
for n in hbdclass.Host.hosts:
if hbdclass.Host.hosts[n].cver >= 2: # earliest version that supports update
names.append(n)
for n in names:
err = None
try:
from hbd import proto
# read code from a file name, fallback to sending ucode as data
err = None
# attempt to send update command to host
r = {"csum": None, "code": ucode}
hbdclass.Host.hosts[n].cmds.append(("UPD", r))
except Exception as e:
err = str(e)
res.append("update started for %s: %s<br>" % (n, err if err else "OK"))
res.append("Done")
elif qr.path == "/api/0/hosts": # api access to host table
headerdict = {"Content-Type": "application/json; charset=utf-8"}
lst = []
for h in hbdclass.Host.hosts:
lst.append(hbdclass.Host.hosts[h].jsons())
res = ["[" + ",".join(lst) + "]"]
elif qr.path == "/api/0/messages": # api access to host table
headerdict = {"Content-Type": "application/json; charset=utf-8"}
lst = msgs_getter()[-30:]
res = [json.dumps(lst)]
elif qr.path == "/r": # restart
res = self.buildhead()
res.append("restart request")
xsig = 1 # signal.SIGHUP will be handled by application
if log:
log(None, "restart request")
elif qr.path == "/live": # show live view with websockets
host = config.get("hb_host", "localhost")
extra_scripts = '' # '<script src="/static/js/live.js"></script>'
heartbeat_ws_url = f"ws://{host}:50005/hbd"
res = templates.TemplateResponse(
"live.html ",
{
"title": "Heartbeat",
"header": "Heartbeat",
"heartbeat_ws_url": heartbeat_ws_url,
"extra_scripts": extra_scripts,
},
)
else:
code, res = self.builderror(404, "Not Found", "requested URL was not found on this server.")
if "deflate" in rqAcceptEncoding:
headerdict["Content-Encoding"] = "deflate"
towrite = __import__("zlib").compress("\n".join(res).encode(), 6)
else:
towrite = "\n".join(res)
headerdict["Content-Length"] = len(towrite)
headerdict["Cache-Control"] = "private, must-revalidate, max-age=0"
headerdict["Expires"] = "Thu, 01 Jan 1970 00:00:00 GMT"
self.setheaders(200 if 'res' in locals() else code, headerdict)
self.wfile.write(towrite if isinstance(towrite, bytes) else towrite.encode())
if xsig:
# inform application via setting a flag on the server instance
try:
self.server.xsig = xsig
except Exception:
pass
return CustomHandler
+163
View File
@@ -0,0 +1,163 @@
"""Notification helpers: email, pushover, mattermost, signal and dispatcher."""
from typing import Optional
import http.client
import urllib.parse
import subprocess
import smtplib
import time
import traceback
DEFAULT_PUSHPROVIDERS = ["all", "pushover", "mattermost", "signal"]
# module-level configuration set via setup()
_config = {}
def setup(cfg: dict):
"""Initialize notifier defaults from a configuration dict."""
global _config
_config = dict(cfg)
def send_email(aemail, smtpserver, sender, subject, body, debug=0):
"""Send a plain email via SMTP. Returns True on success."""
try:
server = smtplib.SMTP(smtpserver)
if debug > 0:
server.set_debuglevel(1)
server.sendmail(sender, aemail, body)
except Exception as e:
if debug:
print("email send failed:", e)
try:
server.quit()
except Exception:
pass
return False
try:
server.quit()
except Exception:
pass
return True
def email(subject: str, msg: str, debug: int = 0) -> bool:
"""Convenience wrapper exposed to the rest of the application.
Uses module-level configuration to supply recipient list, smtp server
and sender address.
"""
toaddrs = _config.get("AEMAIL") or _config.get("aemail") or _config.get("email_to") or []
fromemail = _config.get("fromemail") or _config.get("sender") or f"aew.heartbeat@{_config.get('domain','local') }"
smtpserver = _config.get("SMTPSERVER") or _config.get("smtpserver") or _config.get("SMTPSERVER", "localhost")
date = time.strftime("%a, %d %b %Y %H:%M:%S %z", time.localtime())
body = "To: %s\nFrom: %s\nSubject: %s\nDate: %s\n\n%s" % (
toaddrs[0] if toaddrs else "",
fromemail,
subject,
date,
msg,
)
return send_email(toaddrs, smtpserver, fromemail, subject, body, debug=debug)
def pushover(token: str, user: str, msg: str, debug: int = 0) -> bool:
"""Send message via Pushover API."""
conn = http.client.HTTPSConnection("api.pushover.net:443")
try:
conn.request(
"POST",
"/1/messages.json",
urllib.parse.urlencode({"token": token, "user": user, "message": msg}),
{"Content-type": "application/x-www-form-urlencoded"},
)
r = conn.getresponse()
if debug:
print("pushover response:", r.status, r.reason)
return r.status == 200
except Exception as e:
if debug:
print("pushover error:", e)
return False
def pushmattermost(host: str, token: str, channel: str, msg: str, username: str = "hbd", icon: Optional[str] = None, debug: int = 0) -> bool:
"""Send a message to Mattermost via simple webhook driver if available.
This helper tries to import mattermostdriver.Driver and uses webhooks if present.
If the import fails it returns False.
"""
try:
from mattermostdriver import Driver
except Exception:
return False
ses = {"url": host, "scheme": "http", "basepath": "/api/v4", "port": 8065}
mm = Driver(ses)
payload = {"text": msg, "channel": channel, "username": username}
if icon:
payload["icon_url"] = icon
try:
rc = mm.webhooks.call_webhook(token, payload)
if debug:
print("mattermost rc:", rc)
return bool(rc is None or rc == "")
except Exception as e:
if debug:
print("mattermost error:", e)
return False
def pushsignal(signal_cli_bin: str, user: str, recipient: str, msg: str, debug: int = 0) -> bool:
"""Send a message via signal-cli (requires local installation).
Uses subprocess to call signal-cli. Returns True if the command succeeded.
"""
CLI = [signal_cli_bin, "-u", user, "send", "-m", msg, recipient]
if debug:
print("signal cli: ", CLI)
try:
res = subprocess.run(CLI, capture_output=True)
if res.returncode != 0:
if debug:
print("signal failed:", res.stderr.decode())
return False
if debug:
print("signal sent:", res.stdout.decode())
return True
except Exception as e:
if debug:
print("signal exception:", e)
return False
def pushmsg(cfg: dict, msg: str, debug: int = 0):
"""Dispatch push notifications according to `cfg['pushsrv']`.
cfg is expected to contain keys for different services when needed, e.g.
- cfg['pushsrv'] : one of 'all', 'pushover', 'mattermost', 'signal'
- cfg['pushover_token'], cfg['pushover_user']
- cfg['matter_host'], cfg['matter_token'], cfg['matter_channel']
- cfg['signal_cli'], cfg['signal_user'], cfg['signal_recipient']
Returns a dict of results per provider.
"""
results = {}
p = cfg.get("pushsrv", "pushover")
if p in ("all", "pushover"):
ok = pushover(cfg.get("pushover_token", ""), cfg.get("pushover_user", ""), msg, debug=debug)
results["pushover"] = ok
if p in ("all", "mattermost"):
ok = pushmattermost(cfg.get("matter_host", ""), cfg.get("matter_token", ""), cfg.get("matter_channel", ""), msg, username=cfg.get("matter_username", "hbd"), icon=cfg.get("matter_icon"), debug=debug)
results["mattermost"] = ok
if p in ("all", "signal"):
ok = pushsignal(cfg.get("signal_cli", "/usr/local/bin/signal-cli"), cfg.get("signal_user", ""), cfg.get("signal_recipient", ""), msg, debug=debug)
results["signal"] = ok
if debug:
print("push results:", results)
return results
def pushmsg_from_config(msg: str, debug: int = 0) -> dict:
"""Use the module-level configuration dict to dispatch a push message."""
return pushmsg(_config, msg, debug=debug)
+81
View File
@@ -0,0 +1,81 @@
"""Message encoding/decoding utilities for hbd protocol."""
from typing import Dict, Any
import zlib
def dicttos(ID: str, d: Dict[str, Any], compress: bool = False):
"""Serialize a dict to protocol message bytes.
If compress is True, the payload is zlib-compressed and the message is
prefixed with `!ID:` as the original script did. Otherwise the format is
`ID:key=value;...` (bytes).
"""
s = []
for k in d:
v = d[k]
if isinstance(v, float):
s.append(f"{k}={v:0.5f}")
else:
s.append(f"{k}={v}")
pk = ";".join(s)
if compress:
zpk = zlib.compress(pk.encode(), 6)
hdr = ("!" + ID + ":").encode()
return hdr + zpk
else:
return (ID + ":" + pk).encode()
def stodict(msg: bytes):
"""Deserialize a protocol message into a dict.
Mirrors original behaviour: detects compressed messages starting with
'!' and decodes accordingly. Returns a dict with key 'ID' set to the
message ID and the parsed key/value pairs.
"""
d = {}
if len(msg) > 0 and chr(msg[0]) == "!":
# message is: b'!ID:' + compressed_payload
# original code used msg[1:4].decode() for ID (3 bytes including colon)
try:
pk = zlib.decompress(msg[5:]).decode()
except Exception:
# malformed compressed payload
return {}
d["ID"] = msg[1:4].decode()
else:
try:
r0 = msg.split(b":", 1)
pk = r0[1].decode()
d["ID"] = r0[0].decode()
except Exception:
return {}
if not pk:
return d
parts = pk.split(";")
for v in parts:
if not v:
continue
vr = v.split("=", 1)
k = vr[0].strip()
if len(vr) == 1:
d[k] = None
else:
val = vr[1].strip()
if val and val[0].isdigit():
try:
val_e = eval(val)
except Exception:
val_e = val
d[k] = val_e
else:
d[k] = val
return d
def oldmtodict(msg: bytes):
"""Compatibility wrapper for old-style messages (no ID prefix).
The original implementation prefixed with 'HTB:' and called stodict.
"""
return stodict(b"HTB:" + msg)
+128
View File
@@ -0,0 +1,128 @@
"""Server runtime: starts UDP listener, HTTP server and websocket stubs."""
import asyncio
import logging
from . import udp
logger = logging.getLogger(__name__)
async def _run_async(config):
loop = asyncio.get_running_loop()
# shared runtime collections and helpers
msgs = []
# prepare runtime dependencies
import threading
import time
import hbdclass
from . import http as http_mod
from . import ws as ws_mod
from . import dns as dns_mod
from . import notify as notify_mod
notify_mod.setup(config)
def log(host, m, service=None):
ts = time.time()
s = f"{time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(ts))} {host or ''} {m}"
msgs.append(s)
logger.info(s)
email = notify_mod.email
pushmsg = notify_mod.pushmsg_from_config
msg_to_websockets = ws_mod.broadcast
# UDP server endpoint (handler wired to handle_datagram with context)
bind_addr = ("0.0.0.0", config.get("hb_port", 50003))
logger.info("Starting UDP server on %s:%s", *bind_addr)
def udp_handler(msg, addr, transport):
ctx = dict(
config=config,
hbdclass=hbdclass,
log=log,
email=email,
pushmsg=pushmsg,
msg_to_websockets=msg_to_websockets,
msgs=msgs,
DEBUG=config.get("debug", 0),
verbose=config.get("verbose", False),
)
udp.handle_datagram(msg, addr, transport, ctx)
transport, protocol = await loop.create_datagram_endpoint(
lambda: udp.EchoServerProtocol(config=config, handler=udp_handler),
local_addr=bind_addr,
)
# HTTP server (runs in its own thread)
try:
handler_cls = http_mod.make_handler_class(
config=config,
hbdclass=hbdclass,
msgs_getter=lambda: msgs,
log=log,
email=email,
pushmsg=pushmsg,
msg_to_websockets=msg_to_websockets,
tcss=None,
DEBUG=config.get("debug", 0),
verbose=config.get("verbose", False),
get_now=lambda: time.time(),
VER="",
)
serv = http_mod.HttpServer((config.get("hbd_host", ""), config.get("hbd_port", 50004)), handler_cls)
http_thread = threading.Thread(target=serv.serve_forever, daemon=True)
http_thread.start()
logger.info("HTTP server started on %s:%s", config.get("hbd_host", ""), config.get("hbd_port", 50004))
except Exception as e:
logger.exception("failed to start HTTP server: %s", e)
# start dns update thread
dns_mod.start_dns_thread(hbdclass, config, log=log, email=email)
logger.info("dns update thread started")
# Start the websocket servers as a background task
try:
ws_task = asyncio.create_task(
ws_mod.start(
host=config.get("hbd_host", ""),
ws_port=config.get("ws_port", 50005),
wss_port=config.get("wss_port", None),
ssl_context=None,
get_hosts=lambda: [hbdclass.Host.hosts[h].stateinfo() for h in sorted(hbdclass.Host.hosts)],
get_msgs=lambda: msgs,
verbose=config.get("verbose", False),
)
)
logger.info("WebSocket task started")
except Exception as e:
logger.exception("websocket server failed to start: %s", e)
try:
# run forever
await asyncio.Future()
finally:
transport.close()
try:
serv.shutdown()
except Exception:
pass
try:
ws_task.cancel()
except Exception:
pass
def run(config):
"""Start the hbd service (blocking).
This is a thin wrapper around asyncio.run to host the async services.
"""
logging.basicConfig(level=logging.DEBUG if config.get("debug", 0) > 0 else logging.INFO)
try:
asyncio.run(_run_async(config))
except KeyboardInterrupt:
logger.info("Shutting down (KeyboardInterrupt)")
+235
View File
@@ -0,0 +1,235 @@
"""UDP listener and datagram processing."""
import asyncio
from compression import zlib
import logging
logger = logging.getLogger(__name__)
from .proto import stodict, oldmtodict
from hbd.utils import dur
class EchoServerProtocol(asyncio.DatagramProtocol):
def __init__(self, config=None, handler=None):
super().__init__()
self.config = config or {}
self.handler = handler
def connection_made(self, transport):
self.transport = transport
logger.info("UDP Server listening...")
def datagram_received(self, data, addr):
logger.debug("Received from %s", addr)
try:
msg = parse_message(data)
if self.handler:
# handler can be a callable provided by the application
# pass the transport so handlers can send replies (ACKs/commands)
self.handler(msg, addr, self.transport)
except Exception:
logger.exception("Error while processing datagram from %s", addr)
def parse_message(data: bytes):
"""Parse a raw datagram into a message dict.
Uses the protocol decoding helpers and falls back to old format when
decoding returns an empty dict (compat with older clients).
"""
msg = stodict(data)
if not msg:
# fallback to old format
msg = oldmtodict(data)
return msg
def dicttos(ID, d, compress=False):
s = []
for k in d:
if isinstance(d[k], float):
s.append("%s=%0.5f" % (k, d[k]))
else:
s.append("%s=%s" % (k, d[k]))
pk = ";".join(s)
if compress:
zpk = zlib.compress(pk.encode(), 6)
ID = "!" + ID + ":"
opk = ID.encode() + zpk
else:
zpk = pk
opk = ID + ":" + zpk
return opk
def handle_datagram(msg: dict, addr, transport, ctx: dict):
"""Handle a parsed datagram message.
ctx is a dictionary with runtime dependencies:
- config: dict of configuration
- hbdclass: module providing Host/Connection classes
- log: callable(loghost, message)
- email: callable(subject, message)
- pushmsg: callable(message)
- msg_to_websockets: callable(typ, data)
- msgs: list for storing message strings
- DEBUG, verbose
"""
if not msg:
return
now = __import__("time").time()
cfg = ctx.get("config", {})
hbdcls = ctx.get("hbdclass")
log = ctx.get("log")
email = ctx.get("email")
pushmsg = ctx.get("pushmsg")
msg_to_websockets = ctx.get("msg_to_websockets")
msgs = ctx.get("msgs")
DEBUG = ctx.get("DEBUG", 0)
verbose = ctx.get("verbose", False)
# normalize addr (ip, port)
ip = addr[0] if isinstance(addr, (list, tuple)) else addr
name = msg.get("name", "unknown")
from hbd.utils import shortname
uname = shortname(name)
if uname not in hbdcls.Host.hosts:
host = hbdcls.Host(uname)
host.dyn = uname in cfg.get("dyndnshosts", [])
if verbose:
print(("XX: New host, num now %s" % (len(hbdcls.Host.hosts))))
newh = True
else:
host = hbdcls.Host.hosts[uname]
newh = False
cid = msg.get("id", 0)
try:
rtt = float(msg.get("rtt", None))
except Exception:
rtt = None
if msg.get("ID") == "HTB":
host.doesack = msg.get("acks", -1)
host.setcver(msg.get("ver", 0))
try:
conn, res = host.conndata(cid, ip, rtt, now)
except Exception as e:
if DEBUG > 0:
print("conndata failed: %s" % e)
return
if res:
if log:
log(uname, res)
if uname in cfg.get("watchhosts", []):
if email:
email("address change", "%s %s" % (host.name, res))
if pushmsg:
pushmsg("%s %s" % (host.name, res))
interval = int(msg.get("interval", 0) or 0)
shutdown = msg.get("shutdown", 0)
service = msg.get("service", "unknown")
message = msg.get("msg", None)
boot = msg.get("boot", 0)
if boot:
if log:
log(uname, "booted")
if uname in cfg.get("watchhosts", []):
m = "%s booted" % (host.name)
if email:
email("booted", m)
if pushmsg:
pushmsg(m)
if message:
if log:
log(uname, "msg: %s" % message, service=service)
if uname in cfg.get("watchhosts", []):
if email:
email("msg", message)
if pushmsg:
pushmsg(message)
if conn.getstate() != hbdcls.Connection.UP:
lasts = conn.state
d = conn.newstate(hbdcls.Connection.UP, now)
m = "%s back after being %s for %s" % (conn.afam, lasts, dur(d))
if log:
log(uname, m)
if uname in cfg.get("watchhosts", []):
if email:
email("%s back" % conn.afam, uname)
if pushmsg:
pushmsg("%s %s is back" % (uname, conn.afam))
if boot or newh:
host.upcount = host.doesack
else:
host.upcount += 1
if shutdown:
if log:
log(uname, "%s shutdown" % conn.afam)
if uname in cfg.get("watchhosts", []):
if email:
email("shutdown", "%s %s shutdown" % (uname, conn.afam))
if pushmsg:
pushmsg("%s %s shutdown" % (uname, conn.afam))
conn.newstate(hbdcls.Connection.DOWN, now)
if interval > 0:
host.interval = interval
# send ACK back
rmsg = {"time": __import__("time").time()}
if host.cver < 1:
opkt = b"ACK"
else:
opkt = dicttos("ACK", rmsg, host.cver > 1)
try:
transport.sendto(opkt, addr)
except Exception as e:
if DEBUG > 0:
print(("cannot send ack: %s" % e))
# send any commands we have queued
while len(host.cmds):
op, rmsg = host.cmds[0]
if op == "CMD":
if email:
email("%s cmd exec" % uname, "command '%s' sent" % rmsg)
del host.cmds[0]
if log:
log(uname, "command sent")
if host.cver < 1:
rmsg = rmsg["cmd"]
elif op == "UPD":
del host.cmds[0]
if log:
log(uname, "update initiated")
if host.cver < 1:
if log:
log(uname, " ver 0 does not support UPD")
continue
if host.cver < 1:
opkt = rmsg if isinstance(rmsg, (bytes, str)) else str(rmsg)
if isinstance(opkt, str):
opkt = opkt.encode()
else:
opkt = dicttos(op, rmsg, True)
try:
transport.sendto(opkt, addr)
except Exception as e:
if DEBUG > 0:
print(("cannot send cmd/update: %s" % e))
if msg_to_websockets:
try:
msg_to_websockets("host", host.stateinfo())
except Exception:
pass
+36
View File
@@ -0,0 +1,36 @@
"""Utility helpers extracted from the original script."""
import time
def shortname(name: str) -> str:
return name.split(".")[0]
def dur(sec: int) -> str:
sec = int(sec)
h = int(sec / 3600)
m = int((sec - h * 3600) / 60)
s = int((sec - h * 3600) % 60)
if h > 0:
return "%d:%02d:%02d" % (h, m, s)
if m > 0:
return "%d:%02d" % (m, s)
return "0:%02d" % s
def initlog(logfile: str):
"""Open logfile for appending; fall back to creating it or returning stderr.
This mirrors the original behaviour from the monolithic script.
"""
try:
return open(logfile, "a+")
except Exception:
pass
try:
return open(logfile, "w")
except Exception as e:
import sys
print(f"cannot open logfile {logfile}, using STDERR: {e}")
return sys.stderr
+125
View File
@@ -0,0 +1,125 @@
"""WebSocket server and broadcast helpers for hbd.
Provides an asyncio-based WebSocket server and a thread-safe broadcast
function that other threads or synchronous code can call.
"""
import asyncio
import json
import logging
from typing import Callable, Iterable, Optional
import websockets
logger = logging.getLogger(__name__)
_connections = set()
_loop: Optional[asyncio.AbstractEventLoop] = None
_get_hosts: Optional[Callable[[], Iterable]] = None
_get_msgs: Optional[Callable[[], Iterable]] = None
_verbose = False
async def _handler(websocket, path):
global _connections
_connections.add(websocket)
remote_address = websocket.remote_address
if _verbose:
logger.info("DBG ws_serve: %s: %s", remote_address, path)
try:
# send initial hosts
if _get_hosts:
for h in _get_hosts():
jmsg = json.dumps({"type": "host", "data": h})
await websocket.send(jmsg)
# send recent messages
if _get_msgs:
for m in list(_get_msgs())[-100:]:
jmsg = json.dumps({"type": "message", "data": m})
await websocket.send(jmsg)
# keep connection open until client disconnects
async for _ in websocket:
# we don't expect meaningful incoming messages besides the initial
# client 'hello' that some clients send; ignore for now
if _verbose:
logger.debug("received ws data: %s", _)
except (websockets.exceptions.ConnectionClosedOK, websockets.exceptions.ConnectionClosedError) as e:
if _verbose:
logger.info("ws closed: %r", e)
except Exception as e:
logger.exception("ws handler exception: %s", e)
finally:
try:
_connections.remove(websocket)
except KeyError:
pass
await websocket.wait_closed()
async def start(host: str, ws_port: int, wss_port: Optional[int] = None, ssl_context=None, get_hosts: Optional[Callable] = None, get_msgs: Optional[Callable] = None, verbose: bool = False):
"""Start WebSocket servers and block until cancelled.
This is intended to be awaited inside the main asyncio event loop.
If `wss_port` and `ssl_context` are provided, a WSS server will also be
started.
"""
global _loop, _get_hosts, _get_msgs, _verbose
_loop = asyncio.get_running_loop()
_get_hosts = get_hosts
_get_msgs = get_msgs
_verbose = verbose
servers = []
# plain WebSocket
ws_server = websockets.serve(_handler, host, ws_port, subprotocols=["hbd"])
servers.append(ws_server)
# secure WebSocket (optional)
if wss_port and ssl_context:
wss_server = websockets.serve(_handler, host, wss_port, ssl=ssl_context, subprotocols=["hbd"])
servers.append(wss_server)
# await starting of all servers
for srv in servers:
await srv
if _verbose:
logger.info("WebSocket server started on port %s (wss %s)", ws_port, wss_port)
# block forever (until loop is stopped or cancelled)
await asyncio.Future()
def broadcast(typ: str, data) -> bool:
"""Thread-safe broadcast helper.
Schedules coroutine(s) on the running loop to send message to all
connected websockets. Returns False if server was not running.
"""
global _loop
if not _loop:
return False
jmsg = json.dumps({"type": typ, "data": data})
to_close = []
for ws in list(_connections):
if ws.closed:
to_close.append(ws)
continue
try:
asyncio.run_coroutine_threadsafe(ws.send(jmsg), _loop)
except Exception:
to_close.append(ws)
logger.debug("ws.send exception: closed")
for ws in to_close:
try:
asyncio.run_coroutine_threadsafe(ws.wait_closed(), _loop)
except Exception:
pass
if ws in _connections:
_connections.remove(ws)
return True
def connection_count() -> int:
return len(_connections)