371 lines
13 KiB
Python
371 lines
13 KiB
Python
"""UDP listener and datagram processing."""
|
|
|
|
import asyncio
|
|
import zlib
|
|
import logging
|
|
|
|
from ..common.proto import stodict, oldmtodict
|
|
from ..common.utils import dur
|
|
from . import notify as notify_mod
|
|
|
|
# Logger for this module's diagnostics.
logger = logging.getLogger(name=__name__)

# Shorthand for the notify module's event-log writer; called below as
# eventlog(host, level, text, ...).
eventlog = notify_mod.eventlog
|
|
|
|
|
|
class EchoServerProtocol(asyncio.DatagramProtocol):
    """asyncio datagram protocol that decodes packets and hands them to a handler.

    Each received datagram is decoded with ``parse_message`` and, when a
    handler was supplied, passed on as ``handler(msg, addr, transport)`` so the
    handler can reply through the same socket. Any exception raised while
    decoding or handling is logged and swallowed, so one bad packet cannot stop
    the server.
    """

    def __init__(self, config=None, handler=None):
        """Create the protocol.

        Args:
            config: configuration dict; an empty dict is used when it is
                omitted or empty.
            handler: optional callable ``handler(msg, addr, transport)``.
        """
        super().__init__()
        self.config = config or {}
        self.handler = handler
        # Assigned for real in connection_made(); initialised here so the
        # attribute always exists, even before the endpoint is connected.
        self.transport = None

    def connection_made(self, transport):
        """Remember the transport so handlers can send replies through it."""
        self.transport = transport
        logger.info("UDP Server listening...")

    def datagram_received(self, data, addr):
        """Decode one datagram from *addr* and dispatch it to the handler."""
        logger.debug("Received from %s", addr)
        try:
            msg = parse_message(data)
            if self.handler:
                # The handler is provided by the application; it receives the
                # transport so it can send replies (ACKs / commands).
                self.handler(msg, addr, self.transport)
        except Exception:
            logger.exception("Error while processing datagram from %s", addr)
|
|
|
|
|
|
def parse_message(data: bytes):
    """Decode one raw datagram into a message dict.

    The current wire format is tried first via ``stodict``. When that yields a
    falsy (empty) result, ``oldmtodict`` is used instead, so packets from older
    clients that still speak the previous format are understood.
    """
    return stodict(data) or oldmtodict(data)
|
|
|
|
|
|
def dicttos(ID, d):
    """Serialize dict *d* into a wire packet tagged with message type *ID*.

    The steps are:

    1. Render every entry as ``key=value``; float values get five decimal places.
    2. Join the entries with ``;`` in the dict's iteration order.
    3. zlib-compress the text at level 6.
    4. Prepend the prefix ``!<ID>:``.

    Example::

        dicttos("ACK", {"time": 1.5})
        # == b"!ACK:" + zlib.compress(b"time=1.50000", 6)
    """
    def render(key, value):
        template = "%s=%0.5f" if isinstance(value, float) else "%s=%s"
        return template % (key, value)

    payload = ";".join(render(key, d[key]) for key in d)
    compressed = zlib.compress(payload.encode(), 6)
    return ("!" + ID + ":").encode() + compressed
|
|
|
|
|
|
# Seconds a connection may remain OVERDUE before it is marked UNKNOWN
# (one week = 7 days * 86400 s).
DROPOVERDUE = 7 * 86400
|
|
|
|
|
|
def _make_timer_callbacks(uname, host, watchhosts, ctx):
    """Return (on_overdue, on_unknown) async callbacks for connection timer logic.

    uname, host, watchhosts and the ctx entries read below are captured when
    this factory is called, so the callbacks are safe to create inside loops.

    on_overdue(connection):
        If the connection is still UP, it is moved to OVERDUE and the event
        is logged. For watched hosts a push message is sent. An infinite
        "rtt" value is fed to the threshold checker, and the host state is
        pushed to websockets. Finally a DROPOVERDUE timer is armed that will
        call on_unknown.
    on_unknown(connection):
        If the connection is still OVERDUE, it is moved to UNKNOWN (stamped
        with its last heartbeat time) and the host state is pushed to
        websockets.
    """
    import time

    msg_to_websockets = ctx.get("msg_to_websockets")
    threshold_checker = ctx.get("threshold_checker")
    cfg = ctx.get("config", {})

    async def on_unknown(connection):
        # Only an OVERDUE connection may drop to UNKNOWN. For example, a
        # connection marked DOWN by a shutdown heartbeat while this timer was
        # pending must stay DOWN.
        if connection.getstate() != connection.__class__.OVERDUE:
            return
        connection.newstate(connection.__class__.UNKNOWN, connection.lastbeat)
        if msg_to_websockets:
            msg_to_websockets("host", host.stateinfo())

    async def on_overdue(connection):
        if connection.getstate() != connection.__class__.UP:
            return
        now = time.time()
        connection.newstate(connection.__class__.OVERDUE, now, cfg.get("grace", 2))
        try:
            msg = f"{connection.afam} overdue"
            watched = uname in watchhosts
            eventlog(uname, "CRITICAL" if watched else "WARNING", msg)
            if watched:
                notify_mod.pushmsg_for_host(uname, f"{uname} {msg}")
            if threshold_checker:
                threshold_checker.check_value(
                    host_name=uname,
                    metric_path="rtt",
                    value=float("inf"),
                    alert_states=host.alert_states,
                )
            if msg_to_websockets:
                msg_to_websockets("host", host.stateinfo())
        finally:
            # Arm the drop timer even if a notification hook above raised;
            # otherwise the connection would stay OVERDUE forever.
            connection.reset_overdue_timer(DROPOVERDUE, on_unknown)

    return on_overdue, on_unknown
|
|
|
|
|
|
def restore_connection_timers(hbdclass, ctx):
    """Re-arm reachability timers after connections were restored from a pickle.

    Each connection is handled according to its state:

    * DOWN: skipped.
    * UP, on a host with a positive heartbeat interval: the overdue timer gets
      whatever is left of ``interval + grace``, counted from ``lastbeat``
      (never less than one second). A client that vanished while hbd itself
      was not running is therefore still reported.
    * OVERDUE: the drop-to-UNKNOWN timer gets whatever is left of
      ``DROPOVERDUE``, counted from ``statetime``. If one second or less is
      left, the connection is marked UNKNOWN immediately instead.
    * Any other state: no timer is armed.
    """
    import time

    now = time.time()
    settings = ctx.get("config", {})
    grace = settings.get("grace", 2)

    from . import config as config_mod

    watched = config_mod.get_watchhosts(settings)
    conn_cls = hbdclass.Connection

    rearmed = 0
    for uname, host in list(hbdclass.Host.hosts.items()):
        host_interval = host.interval
        for afam, conn in list(host.connections.items()):
            current = conn.getstate()
            if current == conn_cls.DOWN:
                continue

            on_overdue, on_unknown = _make_timer_callbacks(uname, host, watched, ctx)

            if current == conn_cls.UP and host_interval > 0:
                since_beat = now - conn.lastbeat
                wait = max(1.0, (host_interval + grace) - since_beat)
                conn.reset_overdue_timer(wait, on_overdue)
                logger.debug(
                    "Restored UP timer %s/%s: %.0fs remaining (elapsed %.0fs)",
                    uname, afam, wait, since_beat,
                )
                rearmed += 1
            elif current == conn_cls.OVERDUE:
                overdue_for = now - conn.statetime
                wait = DROPOVERDUE - overdue_for
                if wait <= 1:
                    # The drop window already elapsed while hbd was not running.
                    conn.newstate(conn_cls.UNKNOWN, conn.lastbeat)
                    logger.info(
                        "Marking %s/%s UNKNOWN (overdue %.1f days)",
                        uname, afam, overdue_for / 86400,
                    )
                else:
                    conn.reset_overdue_timer(wait, on_unknown)
                    logger.debug(
                        "Restored OVERDUE timer %s/%s: %.0fs remaining",
                        uname, afam, wait,
                    )
                    rearmed += 1

    logger.info("Restored timers for %d connection(s)", rearmed)
|
|
|
|
|
|
# Strong references to in-flight journal tasks: the event loop keeps only weak
# references to tasks, so a fire-and-forget task with no other reference may be
# garbage-collected before it completes.
_journal_tasks = set()


def _log_to_journal(msg_journal, msg, addr, now):
    """Schedule ``msg_journal.log_message(msg, addr, now)`` without awaiting it.

    A failure to schedule is logged at debug level and otherwise ignored, so
    journaling never interferes with heartbeat processing.
    """
    try:
        # get_running_loop() rather than the deprecated get_event_loop(), so
        # the task is never queued on a loop that is not running.
        loop = asyncio.get_running_loop()
        task = loop.create_task(msg_journal.log_message(msg, addr, now))
    except Exception as e:
        logger.debug("Failed to log message to journal: %s", e)
        return
    _journal_tasks.add(task)
    task.add_done_callback(_journal_tasks.discard)


def _send_packet(transport, packet, addr, DEBUG, errfmt):
    """Send *packet* to *addr*; on failure print ``errfmt % error`` when DEBUG > 0."""
    try:
        transport.sendto(packet, addr)
    except Exception as e:
        if DEBUG > 0:
            print((errfmt % e))


def _handle_plugin_message(msg, host, uname, now, threshold_checker,
                           msg_to_websockets, DEBUG):
    """Store the payload of a PLG message on *host* and propagate it.

    Does nothing when the message carries no ``plugin`` name. Otherwise the
    message fields, minus protocol metadata keys, are stored via
    ``host.add_plugin_data``. The threshold checker is run if configured, and
    a ``"plugin"`` update is pushed to websockets.
    """
    plugin_name = msg.get("plugin")
    if not plugin_name:
        return

    # Extract plugin fields, dropping protocol metadata fields
    plugin_data = {k: v for k, v in msg.items()
                   if k not in ("ID", "plugin", "id", "name")}
    host.add_plugin_data(plugin_name, plugin_data, timestamp=now)
    if DEBUG > 1:
        print(f"Stored plugin data for {uname}: {plugin_name}")

    if threshold_checker:
        try:
            state_changes = threshold_checker.check_plugin_data(
                host_name=uname,
                plugin_name=plugin_name,
                data=plugin_data,
                alert_states=host.alert_states,
            )
            if DEBUG > 1 and state_changes:
                print(f"Threshold state changes for {uname}: {state_changes}")
        except Exception as e:
            logger.error("Error checking thresholds for %s.%s: %s",
                         uname, plugin_name, e)

    if msg_to_websockets:
        try:
            msg_to_websockets("plugin", {
                "host": uname,
                "plugin": plugin_name,
                "data": plugin_data,
                "timestamp": now,
            })
        except Exception:
            # Best effort, but leave a trace instead of swallowing silently.
            logger.debug("cannot send plugin websocket update for %s", uname,
                         exc_info=True)


def _send_queued_commands(host, uname, addr, transport, log, DEBUG):
    """Send every ``(op, payload)`` entry queued in ``host.cmds`` to *addr*.

    The queue is drained front to back; each entry is encoded with ``dicttos``.
    """
    while host.cmds:
        op, rmsg = host.cmds[0]
        # Always dequeue: if only known ops were removed, an entry with any
        # other op would stay at the head of the queue and spin this loop forever.
        del host.cmds[0]
        if log:
            if op == "CMD":
                log(uname, "command sent")
            elif op == "UPD":
                log(uname, "update initiated")
        _send_packet(transport, dicttos(op, rmsg), addr, DEBUG,
                     "cannot send cmd/update: %s")


def handle_datagram(msg: dict, addr, transport, ctx: dict):
    """Handle a parsed datagram message.

    ctx is a dictionary with runtime dependencies:
    - config: dict of configuration
    - hbdclass: module providing Host/Connection classes
    - log: callable(loghost, message)
    - msg_to_websockets: callable(typ, data)
    - msg_journal: MessageJournal instance for logging all messages
    - threshold_checker: optional checker for plugin data and RTT values
    - DEBUG, verbose

    Processing order:

    1. Journal the message.
    2. Find or create the Host.
    3. Answer HTB heartbeats with an ACK and store PLG plugin data.
    4. Update the connection: events, notifications, state, overdue timer
       and RTT thresholds.
    5. Send any queued commands.
    6. Push the host state to websockets.
    """
    if not msg:
        return

    import time

    from ..common.utils import shortname
    from . import config as config_mod

    now = time.time()

    # Log message to journal (non-blocking)
    msg_journal = ctx.get("msg_journal")
    if msg_journal:
        _log_to_journal(msg_journal, msg, addr, now)

    cfg = ctx.get("config", {})
    hbdcls = ctx.get("hbdclass")
    log = ctx.get("log")
    msg_to_websockets = ctx.get("msg_to_websockets")
    threshold_checker = ctx.get("threshold_checker")
    DEBUG = ctx.get("DEBUG", 0)
    verbose = ctx.get("verbose", False)

    # normalize addr (ip, port) -> ip
    ip = addr[0] if isinstance(addr, (list, tuple)) else addr
    uname = shortname(msg.get("name", "unknown"))

    newh = uname not in hbdcls.Host.hosts
    if newh:
        host = hbdcls.Host(uname)
        host.dyn = uname in config_mod.get_dyndnshosts(cfg)
        if verbose:
            print(("XX: New host, num now %s" % (len(hbdcls.Host.hosts))))
    else:
        host = hbdcls.Host.hosts[uname]

    # Get watchhosts once for use throughout message handling
    watchhosts = config_mod.get_watchhosts(cfg)

    cid = msg.get("id", 0)
    try:
        rtt = float(msg.get("rtt"))
    except (TypeError, ValueError):
        # Missing (None) or malformed rtt: treat as "no measurement" instead
        # of aborting processing of the whole datagram.
        rtt = None

    msg_type = msg.get("ID")
    if msg_type == "HTB":
        host.doesack = msg.get("acks", -1)
        # send ACK back
        ack = dicttos("ACK", {"time": time.time()})
        _send_packet(transport, ack, addr, DEBUG, "cannot send ack: %s")
    elif msg_type == "PLG":
        _handle_plugin_message(msg, host, uname, now, threshold_checker,
                               msg_to_websockets, DEBUG)

    try:
        conn, res = host.conndata(cid, ip, rtt, now)
    except Exception as e:
        if DEBUG > 0:
            print("conndata failed: %s" % e)
        return

    if res:
        eventlog(uname, "WARNING", res)
        if uname in watchhosts:
            notify_mod.pushmsg_for_host(uname, "%s %s" % (host.name, res))

    try:
        interval = int(msg.get("interval", 0) or 0)
    except (TypeError, ValueError):
        # A malformed interval is treated like an absent one.
        interval = 0
    shutdown = msg.get("shutdown", 0)
    service = msg.get("service", "unknown")
    message = msg.get("msg", None)
    boot = msg.get("boot", 0)

    if boot:
        eventlog(uname, "INFO", "booted")
        if uname in watchhosts:
            notify_mod.pushmsg_for_host(uname, "%s booted" % (host.name))
    if message:
        eventlog(uname, "INFO", "msg: %s" % message, service=service)
        if uname in watchhosts:
            notify_mod.pushmsg_for_host(uname, message)

    if conn.getstate() != hbdcls.Connection.UP:
        lasts = conn.state
        d = conn.newstate(hbdcls.Connection.UP, now)
        # NOTE(review): compares against the literal "unknown" rather than
        # hbdcls.Connection.UNKNOWN — confirm the two agree.
        if d == 0 or lasts == "unknown":
            m = "%s is up" % (conn.afam)
        else:
            m = "%s back after being %s for %s" % (conn.afam, lasts, dur(d))
        eventlog(uname, "RECOVER", m)
        if uname in watchhosts:
            notify_mod.pushmsg_for_host(uname, "%s %s is back" % (uname, conn.afam))

    if boot or newh:
        host.upcount = host.doesack
    else:
        host.upcount += 1

    if shutdown:
        eventlog(uname, "INFO", "%s shutdown" % conn.afam)
        if uname in watchhosts:
            notify_mod.pushmsg_for_host(uname, "%s %s shutdown" % (uname, conn.afam))
        conn.newstate(hbdcls.Connection.DOWN, now)

    if interval > 0:
        host.interval = interval
        # Timer-based reachability monitoring: re-arm the overdue timer on
        # every heartbeat of a connection that is not DOWN.
        if conn.getstate() != hbdcls.Connection.DOWN:
            timeout_seconds = interval + cfg.get("grace", 2)
            on_overdue, _ = _make_timer_callbacks(uname, host, watchhosts, ctx)
            conn.reset_overdue_timer(timeout_seconds, on_overdue)

    # Check RTT against the configured thresholds (alerts, notifications, ...)
    if threshold_checker and rtt and rtt > 0:
        threshold_checker.check_value(
            host_name=uname,
            metric_path="rtt",
            value=rtt,
            alert_states=host.alert_states,
        )

    # send any commands we have queued
    _send_queued_commands(host, uname, addr, transport, log, DEBUG)

    if msg_to_websockets:
        try:
            msg_to_websockets("host", host.stateinfo())
        except Exception as e:
            if DEBUG > 0:
                print(("cannot send websocket message: %s" % e))
|