save state to pickle file, restart timers on restart

This commit is contained in:
2026-04-06 17:24:59 -04:00
parent 57c4b86430
commit 832a8b0bda
6 changed files with 195 additions and 91 deletions
+1 -1
View File
@@ -189,7 +189,7 @@ class Connection:
except Exception:
pass
self.addr = addr
Connection.htab[addr] = self.host.nameconnection_count
Connection.htab[addr] = self.host.name
if self.host.isDynDns():
Host.dnsQ.put((self.host.name, self.addr))
return r
+46 -10
View File
@@ -22,12 +22,12 @@ eventlog = notify_mod.eventlog
# shared runtime collections and helpers
def cleanup_function(config, hbdclass):
"""This function will be executed upon program exit."""
logger.info("Running cleanup function...")
def save_state(config, hbdclass):
"""Save current state to pickle file. Safe to call at any time."""
import pickle
import os
# Ensure all timer references are cleared before pickling
# Clear timer references before pickling (they can't be serialized)
for hostname, host in list(hbdclass.Host.hosts.items()):
for conn_type, conn in host.connections.items():
if hasattr(conn, 'cancel_overdue_timer'):
@@ -40,13 +40,26 @@ def cleanup_function(config, hbdclass):
conn.timeout_duration = None
pickfile = config.get("pickfile", "hbd.pickle")
tmpfile = pickfile + ".tmp"
pickf = open(pickfile, "wb")
pick = pickle.Pickler(pickf)
pick.dump(hbdclass.Host.hosts)
pick.dump(data.msgs)
pickf.close()
try:
with open(tmpfile, "wb") as pickf:
pick = pickle.Pickler(pickf)
pick.dump(hbdclass.Host.hosts)
pick.dump(data.msgs)
os.replace(tmpfile, pickfile)
except Exception as e:
logger.error("Failed to save state: %s", e)
try:
os.unlink(tmpfile)
except Exception:
pass
def cleanup_function(config, hbdclass):
"""This function will be executed upon program exit."""
logger.info("Running cleanup function...")
save_state(config, hbdclass)
logger.info("Cleanup complete.")
@@ -185,6 +198,16 @@ async def _run_async(config, config_path=None):
sock=sock,
)
# Restore connection timers for hosts loaded from pickle
restore_ctx = dict(
config=config,
hbdclass=hbdclass,
log=eventlog,
msg_to_websockets=msg_to_websockets,
threshold_checker=threshold_checker,
)
udp.restore_connection_timers(hbdclass, restore_ctx)
# HTTP server (asyncio-based via aiohttp)
try:
http_task = asyncio.create_task(
@@ -257,6 +280,19 @@ async def _run_async(config, config_path=None):
except Exception as e:
logger.exception("websocket server failed to start: %s", e)
# Periodic autosave task
autosave_interval = config.get("autosave_interval", 300) # default: 5 minutes
async def autosave_task():
while True:
await asyncio.sleep(autosave_interval)
logger.debug("Autosaving state...")
save_state(config, hbdclass)
logger.debug("Autosave complete (%d hosts)", len(hbdclass.Host.hosts))
autosave = asyncio.create_task(autosave_task())
logger.info("Autosave task started (interval: %ds)", autosave_interval)
# Main event loop - monitor shutdown and reload events
try:
while True:
@@ -304,7 +340,7 @@ async def _run_async(config, config_path=None):
except Exception as e:
logger.warning("Error closing UDP transport: %s", e)
tasks_to_cancel = [http_task, ws_task]
tasks_to_cancel = [http_task, ws_task, autosave]
for task in tasks_to_cancel:
if task:
try:
+15 -5
View File
@@ -459,23 +459,27 @@
}
}
// Protocol metadata fields injected by the client never plugin metrics
const SKIP_FIELDS = new Set(['id', 'name']);
function renderPluginData(data, timestamp) {
// Check if this should be rendered as a simple table
const pluginName = getCurrentPluginName();
const simplePlugins = ['os_info', 'cpu_monitor', 'memory_monitor', 'nagios_runner'];
if (simplePlugins.includes(pluginName) && isSimpleKeyValueData(data)) {
return renderSimpleDataTable(data, timestamp);
}
let html = '<div class="metric-grid">';
for (const [key, value] of Object.entries(data)) {
if (SKIP_FIELDS.has(key)) continue;
// Skip nested objects for now, handle them separately
if (typeof value === 'object' && value !== null) {
continue;
}
html += renderMetric(key, value);
}
@@ -572,10 +576,11 @@
// Table body
html += '<tbody>';
for (const [key, value] of Object.entries(data)) {
if (SKIP_FIELDS.has(key)) continue;
const label = formatLabel(key);
const formattedValue = formatValue(key, value);
const unit = getUnit(key);
html += '<tr>';
html += `<td class="name">${label}</td>`;
html += `<td class="value">${formattedValue}${unit ? ' ' + unit : ''}</td>`;
@@ -1012,12 +1017,17 @@
}
function formatLabel(key) {
if (key === 'time') return 'Collected At';
return key
.replace(/_/g, ' ')
.replace(/\b\w/g, l => l.toUpperCase());
}
function formatValue(key, value) {
// Epoch timestamp field sent by the client alongside plugin data
if (key === 'time' && typeof value === 'number') {
return new Date(value * 1000).toLocaleString();
}
if (typeof value === 'number') {
// Format percentages
if (key.includes('percent') || key.includes('usage')) {
+101 -47
View File
@@ -61,6 +61,102 @@ def dicttos(ID, d):
return opk
DROPOVERDUE = 7 * 24 * 3600 # seconds before an overdue host becomes UNKNOWN
def _make_timer_callbacks(uname, host, watchhosts, ctx):
"""Return (on_overdue, on_unknown) async callbacks for connection timer logic.
Captured values are bound at call time so callbacks are safe to use in loops.
"""
msg_to_websockets = ctx.get("msg_to_websockets")
threshold_checker = ctx.get("threshold_checker")
cfg = ctx.get("config", {})
async def on_unknown(connection):
connection.newstate(connection.__class__.UNKNOWN, connection.lastbeat)
if msg_to_websockets:
msg_to_websockets("host", host.stateinfo())
async def on_overdue(connection):
import time
if connection.getstate() != connection.__class__.UP:
return
now = time.time()
connection.newstate(connection.__class__.OVERDUE, now, cfg.get("grace", 2))
msg = f"{connection.afam} overdue"
eventlog(uname, "CRITICAL" if uname in watchhosts else "WARNING", msg)
if uname in watchhosts:
notify_mod.pushmsg_for_host(uname, f"{uname} {msg}")
if threshold_checker:
threshold_checker.check_value(
host_name=uname,
metric_path="rtt",
value=float("inf"),
alert_states=host.alert_states,
)
if msg_to_websockets:
msg_to_websockets("host", host.stateinfo())
connection.reset_overdue_timer(DROPOVERDUE, on_unknown)
return on_overdue, on_unknown
def restore_connection_timers(hbdclass, ctx):
"""Restore overdue timers for all loaded connections after a pickle restore.
For UP connections, the remaining time until overdue is calculated from
lastbeat so that clients that vanished during hbd's downtime are detected.
For OVERDUE connections, the UNKNOWN drop timer is restored.
"""
import time
now = time.time()
cfg = ctx.get("config", {})
grace = cfg.get("grace", 2)
from . import config as config_mod
watchhosts = config_mod.get_watchhosts(cfg)
restored = 0
for uname, host in list(hbdclass.Host.hosts.items()):
interval = host.interval
for afam, conn in list(host.connections.items()):
state = conn.getstate()
if state == hbdclass.Connection.DOWN:
continue
on_overdue, on_unknown = _make_timer_callbacks(uname, host, watchhosts, ctx)
if state == hbdclass.Connection.UP and interval > 0:
elapsed = now - conn.lastbeat
remaining = max(1.0, (interval + grace) - elapsed)
conn.reset_overdue_timer(remaining, on_overdue)
logger.debug(
"Restored UP timer %s/%s: %.0fs remaining (elapsed %.0fs)",
uname, afam, remaining, elapsed,
)
restored += 1
elif state == hbdclass.Connection.OVERDUE:
elapsed_overdue = now - conn.statetime
remaining = DROPOVERDUE - elapsed_overdue
if remaining <= 1:
# Already past the drop window — mark UNKNOWN immediately
conn.newstate(hbdclass.Connection.UNKNOWN, conn.lastbeat)
logger.info(
"Marking %s/%s UNKNOWN (overdue %.1f days)",
uname, afam, elapsed_overdue / 86400,
)
else:
conn.reset_overdue_timer(remaining, on_unknown)
logger.debug(
"Restored OVERDUE timer %s/%s: %.0fs remaining",
uname, afam, remaining,
)
restored += 1
logger.info("Restored timers for %d connection(s)", restored)
def handle_datagram(msg: dict, addr, transport, ctx: dict):
"""Handle a parsed datagram message.
@@ -138,8 +234,9 @@ def handle_datagram(msg: dict, addr, transport, ctx: dict):
# Handle plugin data message
plugin_name = msg.get("plugin")
if plugin_name:
# Extract all fields except ID and plugin name
plugin_data = {k: v for k, v in msg.items() if k not in ["ID", "plugin"]}
# Extract plugin fields, dropping protocol metadata fields
plugin_data = {k: v for k, v in msg.items()
if k not in ("ID", "plugin", "id", "name")}
# Store plugin data with timestamp
host.add_plugin_data(plugin_name, plugin_data, timestamp=now)
if DEBUG > 1:
@@ -229,51 +326,8 @@ def handle_datagram(msg: dict, addr, transport, ctx: dict):
# Reset overdue timer on every heartbeat
if interval > 0 and conn.getstate() != hbdcls.Connection.DOWN:
grace = cfg.get("grace", 2)
timeout_seconds = (interval + grace) if interval > 0 else 30
# Create callback for timer expiration
async def on_overdue(connection):
"""Called when connection timer expires (no heartbeat received)."""
import time
now = time.time()
# Only mark as overdue if still in UP state (not already marked)
if connection.getstate() == hbdcls.Connection.UP:
connection.newstate(hbdcls.Connection.OVERDUE, now, cfg.get("grace", 2))
msg = f"{connection.afam} overdue"
eventlog(uname, "CRITICAL" if uname in watchhosts else "WARNING", msg)
if uname in watchhosts:
notify_mod.pushmsg_for_host(uname, f"{uname} {msg}")
# Check RTT thresholds with infinite RTT for overdue hosts
threshold_checker = ctx.get("threshold_checker")
if threshold_checker:
metric_path = "rtt"
threshold_checker.check_value(
host_name=uname,
metric_path=metric_path,
value=float('inf'),
alert_states=host.alert_states
)
# Notify websockets
if msg_to_websockets:
msg_to_websockets("host", host.stateinfo())
# Set a longer timer for marking as UNKNOWN (7 days)
DROPOVERDUE = 7 * 24 * 3600
async def on_unknown(connection):
"""Mark connection as unknown after extended absence."""
connection.newstate(hbdcls.Connection.UNKNOWN, connection.lastbeat)
if msg_to_websockets:
msg_to_websockets("host", host.stateinfo())
connection.reset_overdue_timer(DROPOVERDUE, on_unknown)
# Reset the timer
timeout_seconds = interval + grace
on_overdue, _ = _make_timer_callbacks(uname, host, watchhosts, ctx)
conn.reset_overdue_timer(timeout_seconds, on_overdue)
# Check RTT thresholds using the threshold checker
+29 -26
View File
@@ -65,11 +65,7 @@ async def _handler(websocket, path=None):
logger.exception("WebSocket handler exception from %s: %s", remote_address, e)
finally:
logger.debug("Removing WebSocket connection from %s", remote_address)
try:
_connections.remove(websocket)
except KeyError:
pass
await websocket.wait_closed()
_connections.discard(websocket)
async def start(
@@ -93,33 +89,40 @@ async def start(
_verbose = config.get("verbose", False),
_debug = config.get("debug", 0),
servers = []
# plain WebSocket
websockets_logger = logging.getLogger("websockets.server")
#if _debug > 2:
# websockets_logger.setLevel(logging.DEBUG)
#else:
# websockets_logger.setLevel(logging.INFO)
# regular WebSocket
ws_server = websockets.serve(_handler, host, ws_port) # , subprotocols=["hbd"])
servers.append(ws_server)
# secure WebSocket (optional)
# Start servers and keep the server objects for clean shutdown
running_servers = []
ws_server = await websockets.serve(_handler, host, ws_port)
running_servers.append(ws_server)
if wss_port and ssl_context:
wss_server = websockets.serve(
_handler, host, wss_port, ssl=ssl_context
) # , subprotocols=["hbd"])
servers.append(wss_server)
# await starting of all servers
for srv in servers:
await srv
wss_server = await websockets.serve(_handler, host, wss_port, ssl=ssl_context)
running_servers.append(wss_server)
logger.info(
"WebSocket server(s) started on port %s (wss %s)", ws_port, wss_port
)
# block forever (until loop is stopped or cancelled)
await asyncio.Future()
try:
# Block until cancelled
await asyncio.Future()
except asyncio.CancelledError:
pass
finally:
# Close all active browser connections so their handler coroutines exit
active = list(_connections)
if active:
logger.info("Closing %d active WebSocket connection(s)...", len(active))
await asyncio.gather(
*[ws.close() for ws in active],
return_exceptions=True,
)
# Stop the listening servers and wait for all handlers to finish
for srv in running_servers:
srv.close()
await asyncio.gather(
*[srv.wait_closed() for srv in running_servers],
return_exceptions=True,
)
logger.info("WebSocket server(s) stopped")
def broadcast(typ: str, data) -> bool: