link and flake cleanup

This commit is contained in:
2026-02-08 16:05:03 -05:00
parent 087a264e97
commit 5e6dfc75ad
24 changed files with 393 additions and 186 deletions
+1
View File
@@ -7,5 +7,6 @@ __pycache__/
.venv/ .venv/
test/ test/
build/ build/
dist/
*.egg-info/ *.egg-info/
ssl/ ssl/
+33 -15
View File
@@ -1,5 +1,3 @@
# Heartbeat Daemon (hbd) ✅ # Heartbeat Daemon (hbd) ✅
A lightweight daemon that listens for UDP heartbeat messages and acts on them: keeps host state, optionally updates DNS records via `nsupdate`, forwards messages to WebSocket clients, and sends notifications (email, Pushover, Mattermost, Signal). It is a refactor of a previously monolithic script into a modular Python package (`hbd`). A lightweight daemon that listens for UDP heartbeat messages and acts on them: keeps host state, optionally updates DNS records via `nsupdate`, forwards messages to WebSocket clients, and sends notifications (email, Pushover, Mattermost, Signal). It is a refactor of a previously monolithic script into a modular Python package (`hbd`).
@@ -20,25 +18,23 @@ A lightweight daemon that listens for UDP heartbeat messages and acts on them: k
## ⚙️ Quickstart ## ⚙️ Quickstart
Prerequisites: Prerequisites:
- Python 3.10+ (project uses language features from recent Python) - Python 3.10+ (project uses language features from recent Python)
- `nsupdate` (for DNS updates) if using dynamic DNS - `nsupdate` (for DNS updates) if using dynamic DNS
Install dependencies (recommended into a venv): Install dependencies (recommended into a venv):
```bash This project now declares its dependencies in `pyproject.toml`. Instead
python3 -m venv .venv of the old `requirements.txt` flow, install the package into a virtualenv
source .venv/bin/activate using `pip`:
python -m pip install --upgrade pip
python -m pip install -r requirements.txt See `scripts/install.sh` for a way to install.
# for development/testing tools
python -m pip install -r requirements-dev.txt
```
Run the daemon (example): Run the daemon (example):
```bash ```bash
# run with default config lookup (~/.hb.yaml) # run with default config lookup (~/.hb.yaml)
PYTHONPATH=. hbd -c .hb.yaml -f -v hbd -c .hb.yaml -f -v
``` ```
You can also run it directly via the package entrypoint after installation: You can also run it directly via the package entrypoint after installation:
@@ -65,7 +61,6 @@ PYTHONPATH=. python -m debugpy --listen 5678 --wait-for-client -m hbd.cli -c .hb
Set breakpoints in modules such as `hbd/udp.py`, `hbd/dns.py`, or `hbd/server.py`, and use the **Attach** configuration to connect. Use `justMyCode: false` if you need to step into third-party code. Set breakpoints in modules such as `hbd/udp.py`, `hbd/dns.py`, or `hbd/server.py`, and use the **Attach** configuration to connect. Use `justMyCode: false` if you need to step into third-party code.
--- ---
## 🛠 Configuration ## 🛠 Configuration
@@ -82,6 +77,13 @@ Set breakpoints in modules such as `hbd/udp.py`, `hbd/dns.py`, or `hbd/server.py
- `interval` / `grace`: heartbeat timing configuration - `interval` / `grace`: heartbeat timing configuration
- `dyndomains`: list of dyndomains to update via `nsupdate` - `dyndomains`: list of dyndomains to update via `nsupdate`
- `nsupdate_bin`: path to nsupdate binary - `nsupdate_bin`: path to nsupdate binary
- `ws_port`: port for plain WebSocket connections (default: 50005)
- `wss_port`: port for secure WebSocket (WSS) connections (default: none).
If set, `hbd` will attempt to serve WSS on this port when `wss_pem` and
`wss_key` SSL files are available under `cert_path` (see below).
- `cert_path`: directory where TLS certificate and key are looked up (default: /usr/local/etc/ssl/)
- `wss_pem`: filename for the certificate chain (default: fullchain.pem)
- `wss_key`: filename for the private key (default: privkey.pem)
Example `.hb.yaml` (minimal): Example `.hb.yaml` (minimal):
@@ -102,7 +104,11 @@ pushsrv: pushover
- `hbd.proto` — serialization/deserialization of heartbeat messages (supports compressed payloads) - `hbd.proto` — serialization/deserialization of heartbeat messages (supports compressed payloads)
- `hbd.udp` — UDP parsing and `handle_datagram` implementation (main state machine) - `hbd.udp` — UDP parsing and `handle_datagram` implementation (main state machine)
- `hbd.dns``create_nsupdate_payload`, `nsupdate`, and a background DNS thread (`start_dns_thread`) - `hbd.dns``create_nsupdate_payload`, `nsupdate`, and an asyncio DNS worker (`start_dns_worker`).
The DNS worker now runs as an `asyncio` task and the package exposes a
small thread-safe bridge so legacy synchronous code can `put()` updates
into the queue; there is no longer a permanently-blocking background
`threading.Thread`.
- `hbd.notify` — email and push notification helpers - `hbd.notify` — email and push notification helpers
- `hbd.ws` — WebSocket server and thread-safe broadcast helpers - `hbd.ws` — WebSocket server and thread-safe broadcast helpers
- `hbd.http` — HTTP handler factory for the status UI/API - `hbd.http` — HTTP handler factory for the status UI/API
@@ -112,6 +118,17 @@ pushsrv: pushover
This modular layout makes the code easier to test and maintain. This modular layout makes the code easier to test and maintain.
**Runtime & Shutdown**
- The main runtime is asyncio-based. Services (UDP listener, HTTP server, WebSocket server, monitor, and DNS worker) run as asyncio tasks.
- On SIGINT/SIGTERM the server triggers a graceful shutdown: it cancels active tasks, signals the DNS worker via a sentinel, and cleans up resources before exit.
- The DNS update worker is implemented as an `asyncio` task; synchronous producers can still enqueue DNS updates via a small thread-safe bridge available at `hbd.hbdclass.Host.dnsQ`.
**Templates & Static Files**
- Template files are located under `hbd/templates` by default. The HTTP server resolves templates relative to the `hbd` package but the path can be overridden with the `templates_dir` config key.
- Static assets (CSS/JS/images) are served from `hbd/static` via the `/static/<path>` HTTP route. Place your static files in that directory or configure the HTTP server as needed.
--- ---
## 🧪 Testing & Dev ## 🧪 Testing & Dev
@@ -126,8 +143,8 @@ pytest -q
``` ```
Developer tooling included: Developer tooling included:
- `pyproject.toml` — project metadata and dependencies - `pyproject.toml` — project metadata and dependencies
- `requirements-dev.txt` — dev/test dependencies
- `tox.ini` — convenience wrappers for running tests, lint, and mypy - `tox.ini` — convenience wrappers for running tests, lint, and mypy
To run linters and type checks locally: To run linters and type checks locally:
@@ -153,6 +170,7 @@ tox -e mypy
## 🤝 Contributing ## 🤝 Contributing
Contributions welcome! Please: Contributions welcome! Please:
1. Open an issue to discuss larger changes. 1. Open an issue to discuss larger changes.
2. Create a topic branch and a clear PR. 2. Create a topic branch and a clear PR.
3. Add tests for new features and run linters. 3. Add tests for new features and run linters.
@@ -167,8 +185,8 @@ This repository is licensed under the MIT license. See `LICENSE` for details.
--- ---
If you'd like, I can also: If you'd like, I can also:
- add a **GitHub Actions** workflow that runs tests and lint on push/PR 🔁 - add a **GitHub Actions** workflow that runs tests and lint on push/PR 🔁
- add a `CONTRIBUTING.md` template for PRs and code style 💬 - add a `CONTRIBUTING.md` template for PRs and code style 💬
Which one should I do next? ✨ Which one should I do next? ✨
BIN
View File
Binary file not shown.
BIN
View File
Binary file not shown.
+13 -4
View File
@@ -1,4 +1,5 @@
"""Command line interface for hbd package.""" """Command line interface for hbd package."""
import argparse import argparse
from .config import load_config from .config import load_config
@@ -13,11 +14,19 @@ def build_parser():
description="HeartBeatDaemon - Wait for heartbeat messages and act on them (or their absence)", description="HeartBeatDaemon - Wait for heartbeat messages and act on them (or their absence)",
formatter_class=argparse.RawDescriptionHelpFormatter, formatter_class=argparse.RawDescriptionHelpFormatter,
) )
parser.add_argument("-c", "--config", dest="configfile", help="Config file path (YAML)") parser.add_argument(
parser.add_argument("-f", "--foreground", action="store_true", help="Run in foreground") "-c", "--config", dest="configfile", help="Config file path (YAML)"
)
parser.add_argument(
"-f", "--foreground", action="store_true", help="Run in foreground"
)
parser.add_argument("-v", "--verbose", action="store_true", help="Verbose output") parser.add_argument("-v", "--verbose", action="store_true", help="Verbose output")
parser.add_argument("-p", "--pushsrv", dest="pushsrv", choices=PUSHSRVS, help="Push service to use") parser.add_argument(
parser.add_argument("-x", "--debug", action="count", default=0, help="Increase debug level") "-p", "--pushsrv", dest="pushsrv", choices=PUSHSRVS, help="Push service to use"
)
parser.add_argument(
"-x", "--debug", action="count", default=0, help="Increase debug level"
)
return parser return parser
+2 -1
View File
@@ -1,4 +1,5 @@
"""Configuration loader and defaults for hbd.""" """Configuration loader and defaults for hbd."""
import logging import logging
import os import os
@@ -37,7 +38,7 @@ DEFAULTS = {
"wss_port": None, "wss_port": None,
"cert_path": "/usr/local/etc/ssl/", "cert_path": "/usr/local/etc/ssl/",
"wss_pem": "fullchain.pem", "wss_pem": "fullchain.pem",
"wss_key": "privkey.pem" "wss_key": "privkey.pem",
} }
+66 -13
View File
@@ -1,13 +1,23 @@
"""DNS update helper and pure asyncio worker for heartbeat daemon.""" """DNS update helper and pure asyncio worker for heartbeat daemon."""
from __future__ import annotations from __future__ import annotations
import subprocess
from subprocess import Popen, PIPE, STDOUT from subprocess import Popen, PIPE, STDOUT
from typing import Optional from typing import Optional
import asyncio import asyncio
def create_nsupdate_payload(hostname: str, newip: str, dyndomain: str, dnsttl: str = "5") -> str: def create_nsupdate_payload(
D = {"domain": dyndomain, "fqdn": f"{hostname}.dy.{dyndomain}", "dnsttl": dnsttl, "newip": newip, "ts": __import__("time").strftime("%Y-%m-%d.%H:%M:%S", __import__("time").gmtime())} hostname: str, newip: str, dyndomain: str, dnsttl: str = "5"
) -> str:
D = {
"domain": dyndomain,
"fqdn": f"{hostname}.dy.{dyndomain}",
"dnsttl": dnsttl,
"newip": newip,
"ts": __import__("time").strftime(
"%Y-%m-%d.%H:%M:%S", __import__("time").gmtime()
),
}
if ":" in newip: if ":" in newip:
nsup = ( nsup = (
"""update delete %(fqdn)s AAAA """update delete %(fqdn)s AAAA
@@ -17,7 +27,8 @@ update add %(fqdn)s %(dnsttl)s TXT "Created: %(ts)s"
send send
answer answer
""" % D """
% D
) )
else: else:
nsup = ( nsup = (
@@ -28,12 +39,19 @@ update add %(fqdn)s %(dnsttl)s TXT "Created: %(ts)s"
send send
answer answer
""" % D """
% D
) )
return nsup return nsup
def nsupdate(hostname: str, newip: str, dyndomain: str, nsupdate_bin: str = "/usr/local/bin/nsupdate", rndc_key: str = "/etc/dhcpc/rndc-key") -> Optional[str]: def nsupdate(
hostname: str,
newip: str,
dyndomain: str,
nsupdate_bin: str = "/usr/local/bin/nsupdate",
rndc_key: str = "/etc/dhcpc/rndc-key",
) -> Optional[str]:
"""Perform DNS update via nsupdate command. """Perform DNS update via nsupdate command.
Returns None on success, else returns combined stdout/stderr as a string. Returns None on success, else returns combined stdout/stderr as a string.
@@ -54,7 +72,14 @@ def nsupdate(hostname: str, newip: str, dyndomain: str, nsupdate_bin: str = "/us
return out return out
async def dns_update_worker(hbdclass, cfg: dict, async_queue=None, log: Optional[callable] = None, pushmsg: Optional[callable] = None, loop: Optional[asyncio.AbstractEventLoop] = None): async def dns_update_worker(
hbdclass,
cfg: dict,
async_queue=None,
log: Optional[callable] = None,
pushmsg: Optional[callable] = None,
loop: Optional[asyncio.AbstractEventLoop] = None,
):
"""Pure async DNS worker that processes updates from asyncio.Queue. """Pure async DNS worker that processes updates from asyncio.Queue.
Exits when it receives a None sentinel. Exits when it receives a None sentinel.
@@ -66,7 +91,9 @@ async def dns_update_worker(hbdclass, cfg: dict, async_queue=None, log: Optional
if not dnsq: if not dnsq:
if log: if log:
try: try:
await loop.run_in_executor(None, log, None, "dns_update_worker: no queue available") await loop.run_in_executor(
None, log, None, "dns_update_worker: no queue available"
)
except Exception: except Exception:
pass pass
return return
@@ -77,7 +104,9 @@ async def dns_update_worker(hbdclass, cfg: dict, async_queue=None, log: Optional
except Exception as e: except Exception as e:
if log: if log:
try: try:
await loop.run_in_executor(None, log, None, f"dns_update_worker: error getting item: {e}") await loop.run_in_executor(
None, log, None, f"dns_update_worker: error getting item: {e}"
)
except Exception: except Exception:
pass pass
break break
@@ -96,12 +125,25 @@ async def dns_update_worker(hbdclass, cfg: dict, async_queue=None, log: Optional
m = f"changed address to {addr}" m = f"changed address to {addr}"
for dyndomain in cfg.get("dyndomains", []): for dyndomain in cfg.get("dyndomains", []):
err = await loop.run_in_executor(None, nsupdate, name, addr, dyndomain, cfg.get("nsupdate_bin", "/usr/local/bin/nsupdate"), cfg.get("rndc_key", "/etc/dhcpc/rndc-key")) err = await loop.run_in_executor(
None,
nsupdate,
name,
addr,
dyndomain,
cfg.get("nsupdate_bin", "/usr/local/bin/nsupdate"),
cfg.get("rndc_key", "/etc/dhcpc/rndc-key"),
)
if err: if err:
m += f", DNS update failed: {err}" m += f", DNS update failed: {err}"
if pushmsg: if pushmsg:
try: try:
await loop.run_in_executor(None, pushmsg, "error: nsupdate failed", f"{name}.dy.{dyndomain}: {m}") await loop.run_in_executor(
None,
pushmsg,
"error: nsupdate failed",
f"{name}.dy.{dyndomain}: {m}",
)
except Exception: except Exception:
pass pass
else: else:
@@ -125,7 +167,13 @@ async def dns_update_worker(hbdclass, cfg: dict, async_queue=None, log: Optional
pass pass
def start_dns_worker(hbdclass, cfg: dict, log: Optional[callable] = None, pushmsg: Optional[callable] = None, loop: Optional[asyncio.AbstractEventLoop] = None): def start_dns_worker(
hbdclass,
cfg: dict,
log: Optional[callable] = None,
pushmsg: Optional[callable] = None,
loop: Optional[asyncio.AbstractEventLoop] = None,
):
"""Start the async DNS worker and return the Task. """Start the async DNS worker and return the Task.
Replaces Host.dnsQ with an asyncio.Queue wrapped in a thread-safe bridge Replaces Host.dnsQ with an asyncio.Queue wrapped in a thread-safe bridge
@@ -139,6 +187,7 @@ def start_dns_worker(hbdclass, cfg: dict, log: Optional[callable] = None, pushms
class _QueueBridge: class _QueueBridge:
"""Thread-safe wrapper around asyncio.Queue for synchronous callers.""" """Thread-safe wrapper around asyncio.Queue for synchronous callers."""
def __init__(self, loop, aq): def __init__(self, loop, aq):
self._loop = loop self._loop = loop
self._aq = aq self._aq = aq
@@ -167,5 +216,9 @@ def start_dns_worker(hbdclass, cfg: dict, log: Optional[callable] = None, pushms
bridge = _QueueBridge(loop, async_q) bridge = _QueueBridge(loop, async_q)
hbdclass.Host.dnsQ = bridge hbdclass.Host.dnsQ = bridge
task = loop.create_task(dns_update_worker(hbdclass, cfg, async_queue=async_q, log=log, pushmsg=pushmsg, loop=loop)) task = loop.create_task(
dns_update_worker(
hbdclass, cfg, async_queue=async_q, log=log, pushmsg=pushmsg, loop=loop
)
)
return task return task
+27 -20
View File
@@ -7,10 +7,7 @@ import time
import socket import socket
import os import os
import signal import signal
import getopt
import string
import select import select
import errno
import traceback import traceback
from hashlib import md5 from hashlib import md5
import shutil import shutil
@@ -37,13 +34,13 @@ helpflag = False
verbose = False verbose = False
fdaemon = False fdaemon = False
daemonized = False daemonized = False
optlist = []
msgboot = {} msgboot = {}
home = os.environ["HOME"] home = os.environ["HOME"]
configfile = "%s/.hbrc" % home configfile = "%s/.hbrc" % home
cmdargs = [] cmdargs = []
iam = socket.gethostname() iam = socket.gethostname()
def log(msg): def log(msg):
if fdaemon: if fdaemon:
syslog.syslog(syslog.LOG_ERR, msg) syslog.syslog(syslog.LOG_ERR, msg)
@@ -115,7 +112,7 @@ class Conn:
try: try:
self.lastack = msgDict["time"] self.lastack = msgDict["time"]
mul = 2 mul = 2
except: except Exception:
self.lastack = now self.lastack = now
mul = 1 mul = 1
rtt = (self.lastack - self.lastsend) * mul rtt = (self.lastack - self.lastsend) * mul
@@ -140,7 +137,7 @@ def shortname(name):
def dicttos(ID, d): def dicttos(ID, d):
s = [] s = []
for k in d: for k in d:
if type(d[k]) == type(1.2): if isinstance(d[k], float):
s.append("%s=%0.5f" % (k, d[k])) s.append("%s=%0.5f" % (k, d[k]))
else: else:
s.append("%s=%s" % (k, d[k])) s.append("%s=%s" % (k, d[k]))
@@ -169,7 +166,7 @@ def stodict(msg):
v = vr[1].strip() v = vr[1].strip()
try: try:
v = eval(v) v = eval(v)
except: except Exception:
pass pass
d[k] = v d[k] = v
if verbose: if verbose:
@@ -199,7 +196,7 @@ def XXstodict(msg):
try: try:
if v[0].isdigit(): if v[0].isdigit():
v = eval(v) v = eval(v)
except: except Exception:
pass pass
d[k] = v d[k] = v
return d return d
@@ -208,8 +205,8 @@ def XXstodict(msg):
def syslogtrace(note): def syslogtrace(note):
logm = "%s hbc died: \n%s" % (note, traceback.format_exc()) logm = "%s hbc died: \n%s" % (note, traceback.format_exc())
log(logm) log(logm)
for l in logm.split("\n"): for line in logm.split("\n"):
syslog.syslog(syslog.LOG_ERR, " tb: %s" % l) syslog.syslog(syslog.LOG_ERR, " tb: %s" % line)
if verbose: if verbose:
print(logm) print(logm)
@@ -314,7 +311,7 @@ def restart():
e = "fallthrough" e = "fallthrough"
try: try:
os.execv(sys.argv[0], [sys.argv[0]] + cmdargs) os.execv(sys.argv[0], [sys.argv[0]] + cmdargs)
except Exception as e: except Exception:
pass pass
print("should not be here:", str(e)) print("should not be here:", str(e))
log("restart failed: %s" % e) log("restart failed: %s" % e)
@@ -350,7 +347,7 @@ def process():
if running: if running:
running = False running = False
break break
except: except Exception:
if running: if running:
syslogtrace("select") syslogtrace("select")
running = False running = False
@@ -374,12 +371,12 @@ def process():
"sock.recvfrom: %s (%s) %s" "sock.recvfrom: %s (%s) %s"
% (addr, len(data), str(msgDict)[:80]) % (addr, len(data), str(msgDict)[:80])
) )
if msgDict == None: if msgDict is None:
print("bad backet from %s (%s) %s" % (addr, len(data), data)) print("bad backet from %s (%s) %s" % (addr, len(data), data))
elif msgDict["ID"] == "ACK": elif msgDict["ID"] == "ACK":
conns[conn].ack(msgDict, now) conns[conn].ack(msgDict, now)
elif msgDict["ID"] == "UPD": elif msgDict["ID"] == "UPD":
if doupdate(conn, msgDict) == None: if doupdate(conn, msgDict) is None:
if verbose: if verbose:
print("process: restart after update") print("process: restart after update")
dorestart = True dorestart = True
@@ -473,6 +470,7 @@ def daemonize(
os.dup2(so.fileno(), sys.stdout.fileno()) os.dup2(so.fileno(), sys.stdout.fileno())
os.dup2(se.fileno(), sys.stderr.fileno()) os.dup2(se.fileno(), sys.stderr.fileno())
# #
# Main program # Main program
# #
@@ -483,17 +481,26 @@ def build_parser():
formatter_class=argparse.RawDescriptionHelpFormatter, formatter_class=argparse.RawDescriptionHelpFormatter,
) )
parser.add_argument("-b", "--boot", action="store_true", help="Send a boot message") parser.add_argument("-b", "--boot", action="store_true", help="Send a boot message")
parser.add_argument("-c", "--config", dest="configfile", help="Config file path (YAML)") parser.add_argument(
"-c", "--config", dest="configfile", help="Config file path (YAML)"
)
parser.add_argument("-m", "--message", dest="message", help="Send a message") parser.add_argument("-m", "--message", dest="message", help="Send a message")
parser.add_argument("-n", "--name", dest="name", help="Name to use in heartbeat message") parser.add_argument(
parser.add_argument("-f", "--daemon", action="store_true", help="Run in daemon mode") "-n", "--name", dest="name", help="Name to use in heartbeat message"
)
parser.add_argument(
"-f", "--daemon", action="store_true", help="Run in daemon mode"
)
parser.add_argument("-v", "--verbose", action="store_true", help="Verbose output") parser.add_argument("-v", "--verbose", action="store_true", help="Verbose output")
parser.add_argument("-x", "--debug", action="count", default=0, help="Increase debug level") parser.add_argument(
"-x", "--debug", action="count", default=0, help="Increase debug level"
)
parser.add_argument("hosts", nargs="+", help="Heartbeat daemon hosts to send to") parser.add_argument("hosts", nargs="+", help="Heartbeat daemon hosts to send to")
return parser return parser
def main(argv=None): def main(argv=None):
global msgonly, helpflag, verbose, fdaemon, daemonized, optlist, msgboot, home, configfile, cmdargs, iam, hb_port, conns, interval, hb_hosts global msgonly, verbose, fdaemon, daemonized, cmdargs, iam, hb_port, conns, interval, hb_hosts
parser = build_parser() parser = build_parser()
args = parser.parse_args(argv) args = parser.parse_args(argv)
@@ -575,7 +582,6 @@ def main(argv=None):
syslog.syslog(syslog.LOG_ERR, "starting heartbeat to %s" % ",".join(hb_hosts)) syslog.syslog(syslog.LOG_ERR, "starting heartbeat to %s" % ",".join(hb_hosts))
signal.signal(signal.SIGTERM, handler) signal.signal(signal.SIGTERM, handler)
running = True
try: try:
process() process()
except Exception as e: except Exception as e:
@@ -589,5 +595,6 @@ def main(argv=None):
if dorestart: if dorestart:
restart() restart()
if __name__ == "__main__": if __name__ == "__main__":
main() main()
+2 -3
View File
@@ -148,7 +148,7 @@ class Connection:
r = "changed from %s to %s" % (self.addr, addr) r = "changed from %s to %s" % (self.addr, addr)
try: try:
del Connection.htab[self.addr] del Connection.htab[self.addr]
except: except Exception:
pass pass
self.addr = addr self.addr = addr
Connection.htab[addr] = self.host.name Connection.htab[addr] = self.host.name
@@ -293,7 +293,6 @@ class Host:
def dispstats(self): def dispstats(self):
if self.doesack != -1: if self.doesack != -1:
if self.upcount > 0: if self.upcount > 0:
# return "(%0.1f%%) %s %s %s " % ((self.doesack * 100.0) / self.upcount, self.doesack, self.upcount, self.hdwcounts)
r = "" r = ""
for v in range(3): for v in range(3):
a, u = self.hdwcounts[v] a, u = self.hdwcounts[v]
@@ -372,7 +371,7 @@ class Host:
res = [] res = []
le = max(40 - len(Host.hosts), 3) le = max(40 - len(Host.hosts), 3)
res.append("<h4>Log of Events</h4>") res.append("<h4>Log of Events</h4>")
for m in msgs[len(msgs) - le:]: for m in msgs[len(msgs) - le :]:
res.append("%s<BR>" % m) res.append("%s<BR>" % m)
return res return res
+15 -6
View File
@@ -1,4 +1,5 @@
"""HTTP server implementation using aiohttp and jinja2.""" """HTTP server implementation using aiohttp and jinja2."""
import asyncio import asyncio
import json import json
import time import time
@@ -6,15 +7,16 @@ import urllib.parse
import os import os
import logging import logging
from aiohttp import web from aiohttp import web
from fastapi.templating import Jinja2Templates
import jinja2 import jinja2
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def _render_template(html_str: str, **context) -> str: def _render_template(html_str: str, **context) -> str:
tmpl = jinja2.Template(html_str) tmpl = jinja2.Template(html_str)
return tmpl.render(**context) return tmpl.render(**context)
async def start( async def start(
host: str, host: str,
port: int, port: int,
@@ -42,7 +44,7 @@ async def start(
res.append('<!DOCTYPE HTML PUBLIC "-//IETF//DTD HTML 2.0//EN">') res.append('<!DOCTYPE HTML PUBLIC "-//IETF//DTD HTML 2.0//EN">')
res.append("<html>") res.append("<html>")
res.append("<head>") res.append("<head>")
res.append(f"<title>Heartbeat</title>") res.append("<title>Heartbeat</title>")
if tcss: if tcss:
res.append(tcss) res.append(tcss)
res.append("</head>") res.append("</head>")
@@ -51,7 +53,11 @@ async def start(
res += hbdclass.ubHost.buildhosttable() res += hbdclass.ubHost.buildhosttable()
res += hbdclass.ubHost.buildmsgtable(msgs_getter()) res += hbdclass.ubHost.buildmsgtable(msgs_getter())
res.append( res.append(
"<p> %s (%s)</p>" % (time.strftime("%H:%M:%S", time.localtime(get_now())), config.get("tz", "CET-1CDT")) "<p> %s (%s)</p>"
% (
time.strftime("%H:%M:%S", time.localtime(get_now())),
config.get("tz", "CET-1CDT"),
)
) )
res.append("</body></html>") res.append("</body></html>")
body = "\n".join(res) body = "\n".join(res)
@@ -73,7 +79,9 @@ async def start(
return web.Response(status=400, text="need h= and c= arguments") return web.Response(status=400, text="need h= and c= arguments")
if uname not in hbdclass.Host.hosts: if uname not in hbdclass.Host.hosts:
return web.Response(status=400, text=f"h={uname} not found") return web.Response(status=400, text=f"h={uname} not found")
hbdclass.Host.hosts[uname].cmds.append(("CMD", {"cmd": urllib.parse.unquote(ucmd)})) hbdclass.Host.hosts[uname].cmds.append(
("CMD", {"cmd": urllib.parse.unquote(ucmd)})
)
return web.Response(text=f"cmd {uname} queued") return web.Response(text=f"cmd {uname} queued")
async def drop(request): async def drop(request):
@@ -150,7 +158,9 @@ async def start(
request=request, request=request,
heartbeat_ws_url=heartbeat_ws_url, heartbeat_ws_url=heartbeat_ws_url,
extra_scripts=extra_scripts, extra_scripts=extra_scripts,
hosts=[hbdclass.Host.hosts[h].stateinfo() for h in sorted(hbdclass.Host.hosts)], hosts=[
hbdclass.Host.hosts[h].stateinfo() for h in sorted(hbdclass.Host.hosts)
],
messages=msgs_getter()[-30:], messages=msgs_getter()[-30:],
) )
return web.Response(text=body, content_type="text/html") return web.Response(text=body, content_type="text/html")
@@ -209,4 +219,3 @@ async def start(
await asyncio.Future() await asyncio.Future()
finally: finally:
await runner.cleanup() await runner.cleanup()
+14 -8
View File
@@ -1,15 +1,19 @@
"""monitor helper and thread for heartbeat daemon.""" """monitor helper and thread for heartbeat daemon."""
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import threading
import subprocess
import time import time
from subprocess import Popen, PIPE, STDOUT
from typing import Optional
from . import hbdclass
DROPOVERDUE = 7 * 24 * 3600 DROPOVERDUE = 7 * 24 * 3600
def checkoverdue(config: dict, hbdclass, log: callable, pushmsg: callable, msg_to_websockets: callable):
def checkoverdue(
config: dict,
hbdclass,
log: callable,
pushmsg: callable,
msg_to_websockets: callable,
):
now = time.time() now = time.time()
for h in list(hbdclass.Host.hosts.keys()): for h in list(hbdclass.Host.hosts.keys()):
pmsg = [] pmsg = []
@@ -22,7 +26,8 @@ def checkoverdue(config: dict, hbdclass, log: callable, pushmsg: callable, msg_t
conn.newstate(hbdclass.Connection.OVERDUE, now, config.get("grace", 10)) conn.newstate(hbdclass.Connection.OVERDUE, now, config.get("grace", 10))
pmsg.append(conn.afam) pmsg.append(conn.afam)
if ( if (
conn.state == hbdclass.Connection.OVERDUE and (now - conn.lastbeat) > DROPOVERDUE conn.state == hbdclass.Connection.OVERDUE
and (now - conn.lastbeat) > DROPOVERDUE
): ):
conn.newstate(hbdclass.Connection.UNKNOWN, conn.lastbeat) conn.newstate(hbdclass.Connection.UNKNOWN, conn.lastbeat)
if pmsg != []: if pmsg != []:
@@ -31,6 +36,7 @@ def checkoverdue(config: dict, hbdclass, log: callable, pushmsg: callable, msg_t
log(h, "%s overdue" % " and ".join(pmsg)) log(h, "%s overdue" % " and ".join(pmsg))
msg_to_websockets("host", hbdclass.Host.hosts[h].stateinfo()) msg_to_websockets("host", hbdclass.Host.hosts[h].stateinfo())
async def start( async def start(
config: dict, config: dict,
hbdclass: callable, hbdclass: callable,
@@ -38,7 +44,7 @@ async def start(
pushmsg=None, pushmsg=None,
msg_to_websockets=None, msg_to_websockets=None,
): ):
""" start a monitor loop that checks for overdue hosts every minute """ """start a monitor loop that checks for overdue hosts every minute"""
while True: while True:
await asyncio.sleep(15) # 15 seconds between checks await asyncio.sleep(15) # 15 seconds between checks
checkoverdue(config, hbdclass, log, pushmsg, msg_to_websockets) checkoverdue(config, hbdclass, log, pushmsg, msg_to_websockets)
+42 -9
View File
@@ -1,4 +1,5 @@
"""Notification helpers: email, pushover, mattermost, signal and dispatcher.""" """Notification helpers: email, pushover, mattermost, signal and dispatcher."""
import logging import logging
from typing import Optional from typing import Optional
import http.client import http.client
@@ -6,7 +7,6 @@ import urllib.parse
import subprocess import subprocess
import smtplib import smtplib
import time import time
import traceback
DEFAULT_PUSHPROVIDERS = ["all", "pushover", "mattermost", "signal"] DEFAULT_PUSHPROVIDERS = ["all", "pushover", "mattermost", "signal"]
@@ -60,7 +60,12 @@ def email(subject: str, msg: str, debug: int = 0) -> bool:
fromemail = _config.get("fromemail") fromemail = _config.get("fromemail")
smtpserver = _config.get("smtpserver") smtpserver = _config.get("smtpserver")
if not toaddrs or not fromemail or not smtpserver: if not toaddrs or not fromemail or not smtpserver:
logger.warning("email config incomplete: toemail=%s, fromemail=%s, smtpserver=%s", toaddrs, fromemail, smtpserver) logger.warning(
"email config incomplete: toemail=%s, fromemail=%s, smtpserver=%s",
toaddrs,
fromemail,
smtpserver,
)
return False return False
date = time.strftime("%a, %d %b %Y %H:%M:%S %z", time.localtime()) date = time.strftime("%a, %d %b %Y %H:%M:%S %z", time.localtime())
body = "To: %s\nFrom: %s\nSubject: %s\nDate: %s\n\n%s" % ( body = "To: %s\nFrom: %s\nSubject: %s\nDate: %s\n\n%s" % (
@@ -91,7 +96,15 @@ def pushover(token: str, user: str, msg: str, debug: int = 0) -> bool:
return False return False
def pushmattermost(host: str, token: str, channel: str, msg: str, username: str = "hbd", icon: Optional[str] = None, debug: int = 0) -> bool: def pushmattermost(
host: str,
token: str,
channel: str,
msg: str,
username: str = "hbd",
icon: Optional[str] = None,
debug: int = 0,
) -> bool:
"""Send a message to Mattermost via simple webhook driver if available. """Send a message to Mattermost via simple webhook driver if available.
This helper tries to import mattermostdriver.Driver and uses webhooks if present. This helper tries to import mattermostdriver.Driver and uses webhooks if present.
@@ -115,7 +128,9 @@ def pushmattermost(host: str, token: str, channel: str, msg: str, username: str
return False return False
def pushsignal(signal_cli_bin: str, user: str, recipient: str, msg: str, debug: int = 0) -> bool: def pushsignal(
signal_cli_bin: str, user: str, recipient: str, msg: str, debug: int = 0
) -> bool:
"""Send a message via signal-cli (requires local installation). """Send a message via signal-cli (requires local installation).
Uses subprocess to call signal-cli. Returns True if the command succeeded. Uses subprocess to call signal-cli. Returns True if the command succeeded.
@@ -125,7 +140,7 @@ def pushsignal(signal_cli_bin: str, user: str, recipient: str, msg: str, debug:
try: try:
res = subprocess.run(CLI, capture_output=True) res = subprocess.run(CLI, capture_output=True)
if res.returncode != 0: if res.returncode != 0:
logger.error("signal failed: %s". res.stderr.decode()) logger.error("signal failed: %s".res.stderr.decode())
return False return False
logger.debug("signal sent: %s", res.stdout.decode()) logger.debug("signal sent: %s", res.stdout.decode())
return True return True
@@ -148,13 +163,32 @@ def pushmsg(cfg: dict, msg: str, debug: int = 0):
results = {} results = {}
p = cfg.get("pushsrv", "pushover") p = cfg.get("pushsrv", "pushover")
if p in ("all", "pushover"): if p in ("all", "pushover"):
ok = pushover(cfg.get("pushover_token", ""), cfg.get("pushover_user", ""), msg, debug=debug) ok = pushover(
cfg.get("pushover_token", ""),
cfg.get("pushover_user", ""),
msg,
debug=debug,
)
results["pushover"] = ok results["pushover"] = ok
if p in ("all", "mattermost"): if p in ("all", "mattermost"):
ok = pushmattermost(cfg.get("matter_host", ""), cfg.get("matter_token", ""), cfg.get("matter_channel", ""), msg, username=cfg.get("matter_username", "hbd"), icon=cfg.get("matter_icon"), debug=debug) ok = pushmattermost(
cfg.get("matter_host", ""),
cfg.get("matter_token", ""),
cfg.get("matter_channel", ""),
msg,
username=cfg.get("matter_username", "hbd"),
icon=cfg.get("matter_icon"),
debug=debug,
)
results["mattermost"] = ok results["mattermost"] = ok
if p in ("all", "signal"): if p in ("all", "signal"):
ok = pushsignal(cfg.get("signal_cli", "/usr/local/bin/signal-cli"), cfg.get("signal_user", ""), cfg.get("signal_recipient", ""), msg, debug=debug) ok = pushsignal(
cfg.get("signal_cli", "/usr/local/bin/signal-cli"),
cfg.get("signal_user", ""),
cfg.get("signal_recipient", ""),
msg,
debug=debug,
)
results["signal"] = ok results["signal"] = ok
if p in ("all", "email"): if p in ("all", "email"):
ok = email("Heartbeat notification", msg, debug=debug) ok = email("Heartbeat notification", msg, debug=debug)
@@ -166,4 +200,3 @@ def pushmsg(cfg: dict, msg: str, debug: int = 0):
def pushmsg_from_config(msg: str, debug: int = 0) -> dict: def pushmsg_from_config(msg: str, debug: int = 0) -> dict:
"""Use the module-level configuration dict to dispatch a push message.""" """Use the module-level configuration dict to dispatch a push message."""
return pushmsg(_config, msg, debug=debug) return pushmsg(_config, msg, debug=debug)
+1
View File
@@ -1,4 +1,5 @@
"""Message encoding/decoding utilities for hbd protocol.""" """Message encoding/decoding utilities for hbd protocol."""
from typing import Dict, Any from typing import Dict, Any
import zlib import zlib
+41 -18
View File
@@ -1,4 +1,5 @@
"""Server runtime: starts UDP listener, HTTP server and websocket stubs.""" """Server runtime: starts UDP listener, HTTP server and websocket stubs."""
import asyncio import asyncio
import logging import logging
import socket import socket
@@ -6,7 +7,6 @@ import time
import signal import signal
import sys import sys
import ssl import ssl
import pathlib
from . import __version__ from . import __version__
from . import udp from . import udp
@@ -23,14 +23,17 @@ lastfm = ["", "", ""]
# shared runtime collections and helpers # shared runtime collections and helpers
msgs = [] msgs = []
def initlog(logfile): def initlog(logfile):
try: try:
return open(logfile, "a+") return open(logfile, "a+")
except Exception as e: except Exception as e:
import sys import sys
print("cannot open loffile %s, using STDERR: %s" % (logfile, e)) print("cannot open loffile %s, using STDERR: %s" % (logfile, e))
return sys.stderr return sys.stderr
def log(host, m, service=None): def log(host, m, service=None):
ts = time.time() ts = time.time()
s = f"{time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(ts))} {host or ''} {m}" s = f"{time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(ts))} {host or ''} {m}"
@@ -44,10 +47,12 @@ def log(host, m, service=None):
logger.warning("failed to write to logfile: %s", e) logger.warning("failed to write to logfile: %s", e)
msg_to_websockets("message", s) msg_to_websockets("message", s)
def cleanup_function(config): def cleanup_function(config):
"""This function will be executed upon program exit.""" """This function will be executed upon program exit."""
logger.info("Running cleanup function...") logger.info("Running cleanup function...")
import pickle import pickle
pickfile = config.get("pickfile", "hbd.pickle") pickfile = config.get("pickfile", "hbd.pickle")
pickf = open(pickfile, "wb") pickf = open(pickfile, "wb")
@@ -59,14 +64,14 @@ def cleanup_function(config):
logger.info("Cleanup complete.") logger.info("Cleanup complete.")
async def _run_async(config): async def _run_async(config):
global msgs
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
shutdown_event = asyncio.Event() shutdown_event = asyncio.Event()
# Signal handlers for graceful shutdown # Signal handlers for graceful shutdown
def signal_handler(signum, frame): def signal_handler(signum, frame):
sig_name = signal.Signals(signum).name if hasattr(signal, 'Signals') else signum sig_name = signal.Signals(signum).name if hasattr(signal, "Signals") else signum
logger.info(f"Received {sig_name}, initiating shutdown...") logger.info(f"Received {sig_name}, initiating shutdown...")
loop.call_soon_threadsafe(shutdown_event.set) loop.call_soon_threadsafe(shutdown_event.set)
@@ -74,9 +79,6 @@ async def _run_async(config):
loop.add_signal_handler(signal.SIGINT, signal_handler, signal.SIGINT, None) loop.add_signal_handler(signal.SIGINT, signal_handler, signal.SIGINT, None)
loop.add_signal_handler(signal.SIGTERM, signal_handler, signal.SIGTERM, None) loop.add_signal_handler(signal.SIGTERM, signal_handler, signal.SIGTERM, None)
# prepare runtime dependencies
import threading
# from . import hbdclass
from . import http as http_mod from . import http as http_mod
from . import dns as dns_mod from . import dns as dns_mod
from . import notify as notify_mod from . import notify as notify_mod
@@ -93,7 +95,9 @@ async def _run_async(config):
try: try:
sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, False) sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, False)
except OSError as e: except OSError as e:
logger.error(f"Warning: Could not set IPV6_V6ONLY to False. System may not support dual-stack or option is unavailable. Error: {e}") logger.warning(
f"Warning: Could not reset IPV6_V6ONLY not supported or dual-stack is unavailable. Error: {e}"
)
# 3. Bind to all interfaces (::) on a specific port # 3. Bind to all interfaces (::) on a specific port
@@ -138,14 +142,20 @@ async def _run_async(config):
VER="", VER="",
) )
) )
logger.info("HTTP server started on %s:%s", config.get("hbd_host", ""), config.get("hbd_port", 50004)) logger.info(
"HTTP server started on %s:%s",
config.get("hbd_host", ""),
config.get("hbd_port", 50004),
)
except Exception as e: except Exception as e:
logger.exception("failed to start HTTP server: %s", e) logger.exception("failed to start HTTP server: %s", e)
# start dns update worker (async) # start dns update worker (async)
dns_task = None dns_task = None
try: try:
dns_task = dns_mod.start_dns_worker(hbdclass, config, log=log, pushmsg=pushmsg, loop=loop) dns_task = dns_mod.start_dns_worker(
hbdclass, config, log=log, pushmsg=pushmsg, loop=loop
)
logger.info("dns update worker started") logger.info("dns update worker started")
except Exception as e: except Exception as e:
logger.exception("dns worker failed to start: %s", e) logger.exception("dns worker failed to start: %s", e)
@@ -161,7 +171,11 @@ async def _run_async(config):
except FileNotFoundError: except FileNotFoundError:
logger.error("error: missing SSL keys %s or %s", wss_pem, wss_key) logger.error("error: missing SSL keys %s or %s", wss_pem, wss_key)
sys.exit(1) sys.exit(1)
logger.info("Starting secure WebSocket server on port %s with cert %s", config.get("wss_port", None), wss_pem) logger.info(
"Starting secure WebSocket server on port %s with cert %s",
config.get("wss_port", None),
wss_pem,
)
else: else:
ssl_context = None ssl_context = None
@@ -172,7 +186,10 @@ async def _run_async(config):
ws_port=config.get("ws_port", None), ws_port=config.get("ws_port", None),
wss_port=config.get("wss_port", None), wss_port=config.get("wss_port", None),
ssl_context=ssl_context, ssl_context=ssl_context,
get_hosts=lambda: [hbdclass.Host.hosts[h].stateinfo() for h in sorted(hbdclass.Host.hosts)], get_hosts=lambda: [
hbdclass.Host.hosts[h].stateinfo()
for h in sorted(hbdclass.Host.hosts)
],
get_msgs=lambda: msgs, get_msgs=lambda: msgs,
verbose=config.get("verbose", False), verbose=config.get("verbose", False),
) )
@@ -223,7 +240,10 @@ async def _run_async(config):
remaining_tasks = [t for t in tasks_to_cancel if t] remaining_tasks = [t for t in tasks_to_cancel if t]
if remaining_tasks: if remaining_tasks:
try: try:
await asyncio.wait_for(asyncio.gather(*remaining_tasks, return_exceptions=True), timeout=2.0) await asyncio.wait_for(
asyncio.gather(*remaining_tasks, return_exceptions=True),
timeout=2.0,
)
except asyncio.TimeoutError: except asyncio.TimeoutError:
logger.warning("Timeout waiting for tasks to cancel") logger.warning("Timeout waiting for tasks to cancel")
except Exception as e: except Exception as e:
@@ -231,7 +251,7 @@ async def _run_async(config):
# Signal DNS worker to exit and await it # Signal DNS worker to exit and await it
try: try:
if 'dns_task' in locals() and dns_task: if "dns_task" in locals() and dns_task:
try: try:
hbdclass.Host.dnsQ.put(None) hbdclass.Host.dnsQ.put(None)
except Exception: except Exception:
@@ -275,7 +295,7 @@ def load_pickled_hosts(config, hbdclass):
msgs = pick.load() msgs = pick.load()
try: try:
lastfm = pick.load() lastfm = pick.load()
except: except Exception:
lastfm = ["", "", ""] lastfm = ["", "", ""]
pickf.close() pickf.close()
except Exception as e: except Exception as e:
@@ -295,6 +315,7 @@ def load_pickled_hosts(config, hbdclass):
if config.get("verbose", False): if config.get("verbose", False):
logger.info("no pickled data") logger.info("no pickled data")
def run(config): def run(config):
"""Start the hbd service (blocking). """Start the hbd service (blocking).
@@ -302,10 +323,10 @@ def run(config):
""" """
global logf global logf
import os import os
import threading
import time as time_module
logging.basicConfig(level=logging.DEBUG if config.get("debug", 0) > 0 else logging.INFO) logging.basicConfig(
level=logging.DEBUG if config.get("debug", 0) > 0 else logging.INFO
)
load_pickled_hosts(config, hbdclass) load_pickled_hosts(config, hbdclass)
logf = initlog(logfile=config.get("logfile", "messages.log")) logf = initlog(logfile=config.get("logfile", "messages.log"))
@@ -337,7 +358,9 @@ def run(config):
task.cancel() task.cancel()
# Run one more cycle to process cancellations # Run one more cycle to process cancellations
if pending: if pending:
loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True)) loop.run_until_complete(
asyncio.gather(*pending, return_exceptions=True)
)
except Exception: except Exception:
pass pass
finally: finally:
+6 -5
View File
@@ -1,14 +1,14 @@
"""UDP listener and datagram processing.""" """UDP listener and datagram processing."""
import asyncio import asyncio
import zlib import zlib
import logging import logging
logger = logging.getLogger(__name__)
from .proto import stodict, oldmtodict from .proto import stodict, oldmtodict
from hbd.utils import dur from hbd.utils import dur
logger = logging.getLogger(__name__)
class EchoServerProtocol(asyncio.DatagramProtocol): class EchoServerProtocol(asyncio.DatagramProtocol):
def __init__(self, config=None, handler=None): def __init__(self, config=None, handler=None):
@@ -44,6 +44,7 @@ def parse_message(data: bytes):
msg = oldmtodict(data) msg = oldmtodict(data)
return msg return msg
def dicttos(ID, d, compress=False): def dicttos(ID, d, compress=False):
s = [] s = []
for k in d: for k in d:
@@ -61,6 +62,7 @@ def dicttos(ID, d, compress=False):
opk = ID + ":" + zpk opk = ID + ":" + zpk
return opk return opk
def handle_datagram(msg: dict, addr, transport, ctx: dict): def handle_datagram(msg: dict, addr, transport, ctx: dict):
"""Handle a parsed datagram message. """Handle a parsed datagram message.
@@ -87,6 +89,7 @@ def handle_datagram(msg: dict, addr, transport, ctx: dict):
ip = addr[0] if isinstance(addr, (list, tuple)) else addr ip = addr[0] if isinstance(addr, (list, tuple)) else addr
name = msg.get("name", "unknown") name = msg.get("name", "unknown")
from hbd.utils import shortname from hbd.utils import shortname
uname = shortname(name) uname = shortname(name)
if uname not in hbdcls.Host.hosts: if uname not in hbdcls.Host.hosts:
@@ -215,5 +218,3 @@ def handle_datagram(msg: dict, addr, transport, ctx: dict):
msg_to_websockets("host", host.stateinfo()) msg_to_websockets("host", host.stateinfo())
except Exception: except Exception:
pass pass
-1
View File
@@ -1,5 +1,4 @@
"""Utility helpers extracted from the original script.""" """Utility helpers extracted from the original script."""
import time
def shortname(name: str) -> str: def shortname(name: str) -> str:
+21 -8
View File
@@ -3,6 +3,7 @@
Provides an asyncio-based WebSocket server and a thread-safe broadcast Provides an asyncio-based WebSocket server and a thread-safe broadcast
function that other threads or synchronous code can call. function that other threads or synchronous code can call.
""" """
import asyncio import asyncio
import json import json
import logging import logging
@@ -20,7 +21,6 @@ _verbose = False
async def _handler(websocket, path=None): async def _handler(websocket, path=None):
global _connections
_connections.add(websocket) _connections.add(websocket)
remote_address = websocket.remote_address remote_address = websocket.remote_address
if path is None: if path is None:
@@ -46,7 +46,10 @@ async def _handler(websocket, path=None):
if _verbose: if _verbose:
logger.debug("received ws data: %s", _) logger.debug("received ws data: %s", _)
except (websockets.exceptions.ConnectionClosedOK, websockets.exceptions.ConnectionClosedError) as e: except (
websockets.exceptions.ConnectionClosedOK,
websockets.exceptions.ConnectionClosedError,
) as e:
if _verbose: if _verbose:
logger.info("ws closed: %r", e) logger.info("ws closed: %r", e)
except Exception as e: except Exception as e:
@@ -59,7 +62,15 @@ async def _handler(websocket, path=None):
await websocket.wait_closed() await websocket.wait_closed()
async def start(host: str, ws_port: int, wss_port: Optional[int] = None, ssl_context=None, get_hosts: Optional[Callable] = None, get_msgs: Optional[Callable] = None, verbose: bool = False): async def start(
host: str,
ws_port: int,
wss_port: Optional[int] = None,
ssl_context=None,
get_hosts: Optional[Callable] = None,
get_msgs: Optional[Callable] = None,
verbose: bool = False,
):
"""Start WebSocket servers and block until cancelled. """Start WebSocket servers and block until cancelled.
This is intended to be awaited inside the main asyncio event loop. This is intended to be awaited inside the main asyncio event loop.
@@ -77,11 +88,13 @@ async def start(host: str, ws_port: int, wss_port: Optional[int] = None, ssl_con
websockets_logger = logging.getLogger("websockets.server") websockets_logger = logging.getLogger("websockets.server")
websockets_logger.setLevel(logging.DEBUG if verbose else logging.INFO) websockets_logger.setLevel(logging.DEBUG if verbose else logging.INFO)
# regular WebSocket # regular WebSocket
ws_server = websockets.serve(_handler, host, ws_port) #, subprotocols=["hbd"]) ws_server = websockets.serve(_handler, host, ws_port) # , subprotocols=["hbd"])
servers.append(ws_server) servers.append(ws_server)
# secure WebSocket (optional) # secure WebSocket (optional)
if wss_port and ssl_context: if wss_port and ssl_context:
wss_server = websockets.serve(_handler, host, wss_port, ssl=ssl_context ) #, subprotocols=["hbd"]) wss_server = websockets.serve(
_handler, host, wss_port, ssl=ssl_context
) # , subprotocols=["hbd"])
servers.append(wss_server) servers.append(wss_server)
# await starting of all servers # await starting of all servers
@@ -89,7 +102,9 @@ async def start(host: str, ws_port: int, wss_port: Optional[int] = None, ssl_con
await srv await srv
if _verbose: if _verbose:
logger.info("WebSocket server(s) started on port %s (wss %s)", ws_port, wss_port) logger.info(
"WebSocket server(s) started on port %s (wss %s)", ws_port, wss_port
)
# block forever (until loop is stopped or cancelled) # block forever (until loop is stopped or cancelled)
await asyncio.Future() await asyncio.Future()
@@ -101,8 +116,6 @@ def broadcast(typ: str, data) -> bool:
Schedules coroutine(s) on the running loop to send message to all Schedules coroutine(s) on the running loop to send message to all
connected websockets. Returns False if server was not running. connected websockets. Returns False if server was not running.
""" """
global _loop
if not _loop: if not _loop:
return False return False
jmsg = json.dumps({"type": typ, "data": data}) jmsg = json.dumps({"type": typ, "data": data})
+10
View File
@@ -45,3 +45,13 @@ include = ["hbd*"]
[tool.setuptools.package-data] [tool.setuptools.package-data]
"hbd" = ["*.yaml", "static/*", "static/*/*", "templates/*"] "hbd" = ["*.yaml", "static/*", "static/*/*", "templates/*"]
[tool.black]
line-length = 111
[tool.flake8]
max-line-length = 111
[tool.pylint.format]
max-line-length = 111
View File
+29 -6
View File
@@ -47,7 +47,13 @@ class TestDNS(unittest.TestCase):
proc.communicate.return_value = (b"some error", None) proc.communicate.return_value = (b"some error", None)
mock_popen.return_value = proc mock_popen.return_value = proc
err = dns.nsupdate("host", "1.2.3.4", "example", nsupdate_bin="/usr/bin/nsupdate", rndc_key="/etc/rndc.key") err = dns.nsupdate(
"host",
"1.2.3.4",
"example",
nsupdate_bin="/usr/bin/nsupdate",
rndc_key="/etc/rndc.key",
)
self.assertIsNotNone(err) self.assertIsNotNone(err)
self.assertIn("some error", err) self.assertIn("some error", err)
@@ -71,7 +77,9 @@ class TestDNS(unittest.TestCase):
Host = FakeHost Host = FakeHost
# start the thread (daemon) that processes the queue # start the thread (daemon) that processes the queue
t = dns.start_dns_thread(FakeHbd, {"dyndomains": ["example"]}, log=log, email=email) t = dns.start_dns_thread(
FakeHbd, {"dyndomains": ["example"]}, log=log, email=email
)
self.assertTrue(t.is_alive()) self.assertTrue(t.is_alive())
# enqueue one item and wait for it to be processed (polling with timeout) # enqueue one item and wait for it to be processed (polling with timeout)
@@ -83,7 +91,9 @@ class TestDNS(unittest.TestCase):
time.sleep(0.1) time.sleep(0.1)
self.assertTrue(logs, "dnsupdatethread did not call log") self.assertTrue(logs, "dnsupdatethread did not call log")
self.assertTrue(any("changed address" in m or "DNS updated" in m for (_h, m) in logs)) self.assertTrue(
any("changed address" in m or "DNS updated" in m for (_h, m) in logs)
)
def test_dnsupdatethread_calls_email_on_failure(self): def test_dnsupdatethread_calls_email_on_failure(self):
# patch nsupdate to fail with an error message # patch nsupdate to fail with an error message
@@ -104,7 +114,9 @@ class TestDNS(unittest.TestCase):
class FakeHbd: class FakeHbd:
Host = FakeHost Host = FakeHost
t = dns.start_dns_thread(FakeHbd, {"dyndomains": ["example"]}, log=log, email=email) dns.start_dns_thread(
FakeHbd, {"dyndomains": ["example"]}, log=log, email=email
)
# enqueue and wait for the email to be sent # enqueue and wait for the email to be sent
FakeHbd.Host.dnsQ.put(("testhost", "1.2.3.4")) FakeHbd.Host.dnsQ.put(("testhost", "1.2.3.4"))
@@ -114,12 +126,23 @@ class TestDNS(unittest.TestCase):
time.sleep(0.1) time.sleep(0.1)
self.assertTrue(emails, "dnsupdatethread did not call email on failure") self.assertTrue(emails, "dnsupdatethread did not call email on failure")
self.assertTrue(any("nsupdate failed" in s or "nsupdate failed" in m or "error" in m for (s, m) in emails)) self.assertTrue(
any(
"nsupdate failed" in s or "nsupdate failed" in m or "error" in m
for (s, m) in emails
)
)
@patch("hbd.dns.Popen") @patch("hbd.dns.Popen")
def test_nsupdate_raises_oserror(self, mock_popen): def test_nsupdate_raises_oserror(self, mock_popen):
mock_popen.side_effect = OSError("noexec") mock_popen.side_effect = OSError("noexec")
err = dns.nsupdate("h", "1.2.3.4", "example", nsupdate_bin="/usr/bin/nsupdate", rndc_key="/etc/rndc.key") err = dns.nsupdate(
"h",
"1.2.3.4",
"example",
nsupdate_bin="/usr/bin/nsupdate",
rndc_key="/etc/rndc.key",
)
self.assertIsNotNone(err) self.assertIsNotNone(err)
self.assertIn("execution failed", err) self.assertIn("execution failed", err)
+19 -17
View File
@@ -1,9 +1,11 @@
from hbd.udp import handle_datagram, parse_message from hbd.udp import handle_datagram, parse_message
from hbd.proto import dicttos from hbd.proto import dicttos
class FakeTransport: class FakeTransport:
def __init__(self): def __init__(self):
self.sent = [] self.sent = []
def sendto(self, data, addr): def sendto(self, data, addr):
self.sent.append((data, addr)) self.sent.append((data, addr))
@@ -18,30 +20,30 @@ def test_handle_cmd_sends_command():
import hbdclass import hbdclass
ctx = { ctx = {
'config': {'watchhosts':[], 'dyndnshosts':[]}, "config": {"watchhosts": [], "dyndnshosts": []},
'hbdclass': hbdclass, "hbdclass": hbdclass,
'log': dummy_noop, "log": dummy_noop,
'email': dummy_noop, "email": dummy_noop,
'pushmsg': dummy_noop, "pushmsg": dummy_noop,
'msg_to_websockets': dummy_noop, "msg_to_websockets": dummy_noop,
'msgs': [], "msgs": [],
'DEBUG': 0, "DEBUG": 0,
'verbose': False, "verbose": False,
} }
# create host by sending initial heartbeat # create host by sending initial heartbeat
msg = parse_message(dicttos('HTB', {'name':'cmdhost','interval':10})) msg = parse_message(dicttos("HTB", {"name": "cmdhost", "interval": 10}))
handle_datagram(msg, ('127.0.0.1',50000), ftr, ctx) handle_datagram(msg, ("127.0.0.1", 50000), ftr, ctx)
assert ftr.sent[0][0] == b'ACK' assert ftr.sent[0][0] == b"ACK"
# queue a CMD for the host and send another heartbeat; expect command sent # queue a CMD for the host and send another heartbeat; expect command sent
h = hbdclass.Host.hosts['cmdhost'] h = hbdclass.Host.hosts["cmdhost"]
h.cmds.append(('CMD', {'cmd': 'doit'})) h.cmds.append(("CMD", {"cmd": "doit"}))
ftr.sent.clear() ftr.sent.clear()
msg2 = parse_message(dicttos('HTB', {'name':'cmdhost','interval':10})) msg2 = parse_message(dicttos("HTB", {"name": "cmdhost", "interval": 10}))
handle_datagram(msg2, ('127.0.0.1',50000), ftr, ctx) handle_datagram(msg2, ("127.0.0.1", 50000), ftr, ctx)
# should have sent ACK and the command; last send should be non-empty # should have sent ACK and the command; last send should be non-empty
assert len(ftr.sent) >= 1 assert len(ftr.sent) >= 1
# the command for cver 0 will be sent as raw cmd string # the command for cver 0 will be sent as raw cmd string
# so at least one send contains b'doit' or similar # so at least one send contains b'doit' or similar
assert any(b'doit' in s[0] for s in ftr.sent) assert any(b"doit" in s[0] for s in ftr.sent)
-1
View File
@@ -1,4 +1,3 @@
import pytest
from hbd.proto import dicttos, stodict, oldmtodict from hbd.proto import dicttos, stodict, oldmtodict
+4 -4
View File
@@ -3,12 +3,12 @@ from hbd.proto import dicttos
def test_parse_message_uncompressed(): def test_parse_message_uncompressed():
raw = dicttos('HTB', {'name': 'host', 'interval': 1}) raw = dicttos("HTB", {"name": "host", "interval": 1})
m = parse_message(raw) m = parse_message(raw)
assert m['ID'].startswith('HTB') assert m["ID"].startswith("HTB")
def test_parse_message_compressed(): def test_parse_message_compressed():
raw = dicttos('ACK', {'time': 1}, compress=True) raw = dicttos("ACK", {"time": 1}, compress=True)
m = parse_message(raw) m = parse_message(raw)
assert 'ID' in m assert "ID" in m
+1 -1
View File
@@ -22,5 +22,5 @@ commands =
mypy hbd mypy hbd
[flake8] [flake8]
max-line-length = 88 max-line-length = 111
extend-ignore = E203 extend-ignore = E203