1a19088cfe
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
528 lines
19 KiB
Python
528 lines
19 KiB
Python
"""UDP listener and datagram processing."""
|
||
|
||
import asyncio
|
||
import socket
|
||
import struct
|
||
import time
|
||
import zlib
|
||
import logging
|
||
|
||
from platform import system as platform_system
|
||
|
||
from ..common.proto import stodict, oldmtodict
|
||
from ..common.utils import dur
|
||
from . import notify as notify_mod
|
||
|
||
# Module-level logger for diagnostics.
logger = logging.getLogger(__name__)
# Shortcut to the notify module's event-log writer, used throughout this module.
eventlog = notify_mod.eventlog
|
||
|
||
# SO_TIMESTAMP: kernel attaches a struct timeval to each received datagram.
|
||
# Supported on Linux, FreeBSD, and macOS. The constant is not exposed by
|
||
# Python's socket module on all platforms
|
||
platform = platform_system()
|
||
if platform == "Darwin":
|
||
_SO_TIMESTAMP = 1024 # SO_TIMESTAMP on macOS (not in Python's socket module)
|
||
elif platform == "Linux":
|
||
_SO_TIMESTAMP = 29 # Linux value (not in older Python versions)
|
||
elif platform == "FreeBSD":
|
||
_SO_TIMESTAMP = 32 # FreeBSD value (not in older Python versions)
|
||
else:
|
||
logger.warning("SO_TIMESTAMP may not be supported on this platform (%s)", platform)
|
||
_SO_TIMESTAMP = None
|
||
|
||
# struct timeval uses two native C longs: tv_sec and tv_usec
|
||
_TIMEVAL = struct.Struct('@ll')
|
||
|
||
|
||
def enable_kernel_timestamps(sock) -> bool:
|
||
"""Try to enable SO_TIMESTAMP on *sock*.
|
||
|
||
Returns True if the kernel will supply receive timestamps, False otherwise
|
||
(unsupported platform, older kernel, or insufficient permissions).
|
||
"""
|
||
try:
|
||
sock.setsockopt(socket.SOL_SOCKET, _SO_TIMESTAMP, 1)
|
||
return True
|
||
except OSError:
|
||
return False
|
||
|
||
|
||
def _extract_kernel_ts(ancdata) -> float | None:
|
||
"""Parse recvmsg ancillary data and return the kernel receive time.
|
||
|
||
Returns seconds as a float, or None if no SO_TIMESTAMP cmsg is present.
|
||
"""
|
||
for cmsg_level, cmsg_type, cmsg_data in ancdata:
|
||
if cmsg_level == socket.SOL_SOCKET and cmsg_type == _SO_TIMESTAMP:
|
||
if len(cmsg_data) >= _TIMEVAL.size:
|
||
sec, usec = _TIMEVAL.unpack_from(cmsg_data)
|
||
return sec + usec * 1e-6
|
||
return None
|
||
|
||
|
||
class RecvmsgTransport:
    """Minimal transport object for the add_reader()/recvmsg() receive path.

    Used when SO_TIMESTAMP is active. It offers the same sendto()/close()
    methods as asyncio's DatagramTransport, so code that replies to clients
    does not need to know which receive path created the transport.
    """

    def __init__(self, loop, sock):
        self._loop, self._sock = loop, sock

    def sendto(self, data, addr):
        """Send *data* to *addr*; any failure is logged at DEBUG and dropped."""
        try:
            self._sock.sendto(data, addr)
        except Exception as err:
            logger.debug("sendto failed: %s", err)

    def close(self):
        """Unregister the fd from the loop, then close the socket.

        Both steps are best-effort: an error in either one is ignored and
        does not prevent the other from running.
        """
        teardown = (
            lambda: self._loop.remove_reader(self._sock.fileno()),
            lambda: self._sock.close(),
        )
        for step in teardown:
            try:
                step()
            except Exception:
                pass
|
||
|
||
|
||
def make_recvmsg_reader(sock, handler, transport):
    """Return a callback suitable for ``loop.add_reader(sock.fileno(), ...)``.

    Each invocation reads one datagram with ``recvmsg()`` so that the kernel
    receive timestamp in the ancillary data (enabled via SO_TIMESTAMP) is
    accessible. The payload is decoded with ``parse_message``; if that yields
    a non-empty message it is passed on as::

        handler(msg, addr, transport, kernel_ts)

    ``kernel_ts`` is the kernel receive time in seconds, or None when no
    timestamp cmsg was present. No fallback clock is substituted here, so the
    handler is responsible for using its own (e.g. ``time.time()``).

    A spurious wake-up (BlockingIOError) is ignored; other socket errors are
    logged as warnings. Exceptions raised while parsing or handling are logged
    with a traceback and never propagate into the event loop.
    """
    BUFSIZE = 65536  # covers the largest (non-jumbogram) UDP payload
    ANCBUFSIZE = 128  # ample room for one SO_TIMESTAMP cmsg (struct timeval)

    def _read():
        try:
            data, ancdata, _, addr = sock.recvmsg(BUFSIZE, ANCBUFSIZE)
        except BlockingIOError:
            # Reader woke up but nothing is queued any more.
            return
        except OSError as e:
            logger.warning("recvmsg error: %s", e)
            return
        try:
            kernel_ts = _extract_kernel_ts(ancdata)
            msg = parse_message(data)
            if msg:
                handler(msg, addr, transport, kernel_ts)
        except Exception:
            logger.exception("Error processing datagram from %s", addr)

    return _read
|
||
|
||
|
||
class EchoServerProtocol(asyncio.DatagramProtocol):
    """asyncio datagram protocol that decodes packets and hands them to a callback.

    *config* is stored as-is (an empty dict when falsy). *handler*, when set,
    is called as ``handler(msg, addr, transport)`` for every datagram; the
    transport is passed so the handler can send replies (ACKs/commands).
    """

    def __init__(self, config=None, handler=None):
        super().__init__()
        self.config = config if config else {}
        self.handler = handler

    def connection_made(self, transport):
        self.transport = transport
        logger.info("UDP Server listening...")

    def datagram_received(self, data, addr):
        logger.debug("Received from %s", addr)
        try:
            decoded = parse_message(data)
            callback = self.handler
            if callback:
                callback(decoded, addr, self.transport)
        except Exception:
            # Never let one bad datagram break the protocol.
            logger.exception("Error while processing datagram from %s", addr)
|
||
|
||
|
||
def parse_message(data: bytes):
    """Decode a raw datagram into a message dict.

    Tries the current protocol decoder (``stodict``) first; if that yields an
    empty/falsy result, the datagram is decoded with the legacy decoder
    (``oldmtodict``) for compatibility with older clients.
    """
    decoded = stodict(data)
    return decoded if decoded else oldmtodict(data)
|
||
|
||
|
||
def dicttos(ID, d):
    """Encode dict *d* as an outgoing packet of type *ID*.

    Layout: ``b"!" + ID + b":"`` followed by the zlib-compressed (level 6)
    text ``key=value;key=value;...`` in the dict's iteration order. Floats are
    rendered with five decimal places; everything else with ``str()``. Keys
    and values are not escaped.
    """
    def render(key, value):
        if isinstance(value, float):
            return "%s=%0.5f" % (key, value)
        return "%s=%s" % (key, value)

    payload = ";".join(render(key, d[key]) for key in d)
    header = ("!" + ID + ":").encode()
    return header + zlib.compress(payload.encode(), 6)
|
||
|
||
|
||
# One week, in seconds: how long a connection may stay OVERDUE before it is
# dropped to UNKNOWN.
DROPOVERDUE = 7 * 86400
|
||
|
||
|
||
def _set_connectivity_alert(host, afam, level_name):
    """Set or clear the ``connectivity.<afam>`` entry in ``host.alert_states``.

    *level_name* is looked up on ``AlertLevel``; an unknown name is treated
    as OK. An OK level removes the entry (so recovered hosts don't clutter the
    Alerts Dashboard). Any other level creates the AlertState if needed and
    calls ``update(level, level_name)`` on it.
    """
    from .threshold import AlertState, AlertLevel

    key = f"connectivity.{afam}"
    resolved = getattr(AlertLevel, level_name, AlertLevel.OK)

    if resolved == AlertLevel.OK:
        host.alert_states.pop(key, None)
        return

    if key not in host.alert_states:
        host.alert_states[key] = AlertState(key)
    host.alert_states[key].update(resolved, level_name)
|
||
|
||
|
||
def _make_timer_callbacks(uname, host, ctx):
    """Build the ``(on_overdue, on_unknown)`` coroutine callbacks for one host.

    *uname* and *host* are bound as arguments, so callbacks created inside a
    loop each keep their own host. The websocket pusher, threshold checker and
    config dict are fetched from *ctx* once, here; the ``grace`` value is read
    from the config dict each time on_overdue runs.

    on_overdue: if the connection is still UP, moves it to OVERDUE, logs and
    (for watched hosts) notifies, raises a CRITICAL connectivity alert, feeds
    an infinite RTT to the threshold checker, pushes host state, and arms the
    DROPOVERDUE timer that will call on_unknown.

    on_unknown: moves the connection to UNKNOWN (keeping its lastbeat) and
    pushes host state.
    """
    push_ws = ctx.get("msg_to_websockets")
    checker = ctx.get("threshold_checker")
    cfg = ctx.get("config", {})

    def publish_state():
        if push_ws:
            push_ws("host", host.stateinfo())

    async def on_unknown(connection):
        conn_cls = type(connection)
        connection.newstate(conn_cls.UNKNOWN, connection.lastbeat)
        # The connectivity alert raised on overdue is deliberately kept active.
        publish_state()

    async def on_overdue(connection):
        conn_cls = type(connection)
        if connection.getstate() != conn_cls.UP:
            return
        connection.newstate(conn_cls.OVERDUE, time.time(), cfg.get("grace", 2))

        text = f"{connection.afam} overdue"
        eventlog(uname, "CRITICAL", text)
        if host.watched:
            note = notify_mod.Notification(title=f"[CRITICAL] {uname}", body=text, level="CRITICAL")
            asyncio.create_task(notify_mod.send_notification(uname, note))

        # Track in alert_states so the Alerts Dashboard shows this.
        _set_connectivity_alert(host, connection.afam, "CRITICAL")
        if checker:
            checker.check_value(
                host_name=uname,
                metric_path="rtt",
                value=float("inf"),
                alert_states=host.alert_states,
            )
        publish_state()
        connection.reset_overdue_timer(DROPOVERDUE, on_unknown)

    return on_overdue, on_unknown
|
||
|
||
|
||
def restore_connection_timers(hbdclass, ctx):
    """Re-arm connection timers after host state has been restored from a pickle.

    - UP connections (when the host interval is positive) get an overdue
      timer computed from ``lastbeat`` plus one extra ``interval + grace`` of
      startup slack, and never shorter than one ``interval + grace``. Hosts
      that vanished while hbd was down are therefore still detected, but are
      not flagged before they have had a chance to check in.
    - OVERDUE connections get the UNKNOWN drop timer for what remains of
      DROPOVERDUE since ``statetime``. If at most one second remains, they
      are marked UNKNOWN immediately instead.
    - DOWN connections, and any other state, are left alone.
    """
    now = time.time()
    grace = ctx.get("config", {}).get("grace", 2)
    conn_cls = hbdclass.Connection
    rearmed = 0

    for uname, host in list(hbdclass.Host.hosts.items()):
        interval = host.interval
        for afam, conn in list(host.connections.items()):
            state = conn.getstate()
            if state == conn_cls.DOWN:
                continue

            on_overdue, on_unknown = _make_timer_callbacks(uname, host, ctx)

            if state == conn_cls.UP and interval > 0:
                slack = interval + grace
                since_beat = now - conn.lastbeat
                wait = max(slack, 2 * slack - since_beat)
                conn.reset_overdue_timer(wait, on_overdue)
                logger.debug(
                    "Restored UP timer %s/%s: %.0fs remaining (elapsed %.0fs, startup grace %.0fs)",
                    uname, afam, wait, since_beat, slack,
                )
                rearmed += 1
            elif state == conn_cls.OVERDUE:
                overdue_for = now - conn.statetime
                wait = DROPOVERDUE - overdue_for
                if wait <= 1:
                    # Already past the drop window: mark UNKNOWN now.
                    conn.newstate(conn_cls.UNKNOWN, conn.lastbeat)
                    logger.info(
                        "Marking %s/%s UNKNOWN (overdue %.1f days)",
                        uname, afam, overdue_for / 86400,
                    )
                else:
                    conn.reset_overdue_timer(wait, on_unknown)
                    logger.debug(
                        "Restored OVERDUE timer %s/%s: %.0fs remaining",
                        uname, afam, wait,
                    )
                    rearmed += 1

    logger.info("Restored timers for %d connection(s)", rearmed)
|
||
|
||
|
||
def _parse_rtt(value):
    """Return *value* as a float, or None when it is missing or not numeric.

    A bad ``rtt`` field must not abort processing of the whole datagram.
    """
    try:
        return float(value)
    except (TypeError, ValueError):
        return None


def _parse_interval(value):
    """Return the heartbeat interval as an int, or 0 when unusable.

    Accepts ints, floats and numeric strings (including ``"30.0"``); missing,
    empty or non-numeric values yield 0, which disables interval timers.
    """
    try:
        return int(value or 0)
    except (TypeError, ValueError, OverflowError):
        pass
    try:
        return int(float(value))
    except (TypeError, ValueError, OverflowError):
        return 0


def handle_datagram(msg: dict, addr, transport, ctx: dict):
    """Handle a parsed datagram message from a monitored host.

    Updates host/connection state, ACKs heartbeats, stores plugin data,
    writes event-log entries and notifications, re-arms the overdue timer,
    checks RTT thresholds, flushes queued commands and pushes the new host
    state to websocket clients.

    Args:
        msg: decoded message dict; a falsy value is ignored.
        addr: sender address, either an ``(ip, port, ...)`` tuple/list or a
            bare IP.
        transport: object with ``sendto(data, addr)`` used for replies.
        ctx: runtime dependencies:
            - config: dict of configuration
            - hbdclass: module providing Host/Connection classes
            - msg_to_websockets: callable(typ, data)
            - msg_journal: object whose ``log_message(msg, addr, ts)``
              coroutine records every message
            - threshold_checker: optional threshold checker
            - recv_ts: optional receive timestamp in seconds (e.g. from the
              kernel); ``time.time()`` is used when it is missing or falsy
            - DEBUG: int debug level for stdout diagnostics
            Other keys (e.g. ``log``, ``verbose``) are not used here.
    """
    if not msg:
        return

    now = ctx.get("recv_ts") or time.time()

    # Log message to journal without blocking the datagram path.
    msg_journal = ctx.get("msg_journal")
    if msg_journal:
        try:
            # Fetch the loop first so that, without a running loop, no
            # coroutine object is created and left un-awaited.
            loop = asyncio.get_running_loop()
            loop.create_task(msg_journal.log_message(msg, addr, now))
        except Exception as e:
            logger.debug("Failed to log message to journal: %s", e)

    cfg = ctx.get("config", {})
    hbdcls = ctx.get("hbdclass")
    msg_to_websockets = ctx.get("msg_to_websockets")
    threshold_checker = ctx.get("threshold_checker")
    DEBUG = ctx.get("DEBUG", 0)

    # normalize addr (ip, port) -> ip
    ip = addr[0] if isinstance(addr, (list, tuple)) else addr
    name = msg.get("name", "unknown")
    from ..common.utils import shortname
    from . import config as config_mod

    uname = shortname(name)

    if uname not in hbdcls.Host.hosts:
        host = hbdcls.Host(uname)
        # Use new config function to check dyndns
        dyndnshosts = config_mod.get_dyndnshosts(cfg)
        host.dyn = uname in dyndnshosts
        # Apply user-access settings from config
        access = config_mod.get_host_access(cfg, uname)
        host.apply_access(access["owner"], access["managers"], access["monitors"])
        logger.info("New host signed on: %s (dyn=%s, access=%s)", uname, host.dyn, access)
        newh = True
    else:
        host = hbdcls.Host.hosts[uname]
        newh = False

    def notify(level, body):
        """Send a fire-and-forget notification, for watched hosts only."""
        if host.watched:
            asyncio.create_task(notify_mod.send_notification(
                uname,
                notify_mod.Notification(title=f"[{level}] {uname}", body=body, level=level),
            ))

    cid = msg.get("id", 0)
    rtt = _parse_rtt(msg.get("rtt"))
    msg_type = msg.get("ID")

    if msg_type == "HTB":
        host.doesack = msg.get("acks", -1)
        # send ACK back; ask client to resend plugin info when we have none yet
        rmsg = {"time": time.time()}
        if not host.plugin_data:
            rmsg["request_update"] = 1
        opkt = dicttos("ACK", rmsg)
        try:
            transport.sendto(opkt, addr)
        except Exception as e:
            if DEBUG > 0:
                print("cannot send ack: %s" % e)

    elif msg_type == "PLG":
        plugin_name = msg.get("plugin")
        if plugin_name:
            # Extract plugin fields, dropping protocol metadata fields
            plugin_data = {k: v for k, v in msg.items()
                           if k not in ("ID", "plugin", "id", "name")}
            # Store plugin data with timestamp
            host.add_plugin_data(plugin_name, plugin_data, timestamp=now)

            # A server-side configured owner wins; otherwise use the owner
            # reported by os_info, and finally the configured default owner.
            if plugin_name == "os_info":
                config_owner = config_mod.get_host_access(cfg, uname).get("owner")
                default_owner = config_mod.get_default_owner(cfg)
                host.owner = config_owner or plugin_data.get("owner") or default_owner
                logger.info("owner for %s is '%s'", uname, host.owner)
            if DEBUG > 1:
                print(f"Stored plugin data for {uname}: {plugin_name}")

            # Check thresholds if checker is available
            if threshold_checker:
                try:
                    state_changes = threshold_checker.check_plugin_data(
                        host_name=uname,
                        plugin_name=plugin_name,
                        data=plugin_data,
                        alert_states=host.alert_states,
                    )
                    if DEBUG > 1 and state_changes:
                        print(f"Threshold state changes for {uname}: {state_changes}")
                except Exception as e:
                    logger.error("Error checking thresholds for %s.%s: %s", uname, plugin_name, e)

            # Notify websockets of plugin update (best-effort)
            if msg_to_websockets:
                try:
                    msg_to_websockets("plugin", {
                        "host": uname,
                        "plugin": plugin_name,
                        "data": plugin_data,
                        "timestamp": now,
                    })
                except Exception as e:
                    logger.debug("Failed to push plugin update for %s: %s", uname, e)

    try:
        conn, res = host.conndata(cid, ip, rtt, now)
    except Exception as e:
        if DEBUG > 0:
            print("conndata failed: %s" % e)
        return

    if res:
        eventlog(uname, "WARNING", res)
        notify("WARNING", res)

    interval = _parse_interval(msg.get("interval", 0))
    shutdown = msg.get("shutdown", 0)
    service = msg.get("service", "unknown")
    message = msg.get("msg", None)
    boot = msg.get("boot", 0)

    if boot:
        eventlog(uname, "INFO", "booted")
        notify("INFO", f"{host.name} booted")
    if message:
        eventlog(uname, "INFO", "msg: %s" % message, service=service)

    if conn.getstate() != hbdcls.Connection.UP:
        lasts = conn.state
        d = conn.newstate(hbdcls.Connection.UP, now)
        # Clear connectivity alert now that the host is back up
        _set_connectivity_alert(host, conn.afam, "OK")
        # Don't log/notify RECOVER for a brand-new host seen for the first
        # time: it was never down, it just hasn't been seen before.
        if not newh:
            if d == 0 or lasts == "unknown":
                m = "%s is up" % (conn.afam)
            elif d < 4:
                # Transient blip (likely client restart): skip log and notification
                m = None
            else:
                m = "%s back after being %s for %s" % (conn.afam, lasts, dur(d))
            if m:
                eventlog(uname, "RECOVER", m)
                notify("RECOVER", m)

    if boot or newh:
        host.upcount = host.doesack
    else:
        host.upcount += 1

    if shutdown:
        m = "%s shutdown" % conn.afam
        eventlog(uname, "INFO", m)
        notify("INFO", m)
        conn.newstate(hbdcls.Connection.DOWN, now)
        _set_connectivity_alert(host, conn.afam, "CRITICAL")

    if interval > 0:
        host.interval = interval

    # Timer-based reachability monitoring: reset overdue timer on every heartbeat
    if interval > 0 and conn.getstate() != hbdcls.Connection.DOWN:
        grace = cfg.get("grace", 2)
        timeout_seconds = interval + grace
        on_overdue, _ = _make_timer_callbacks(uname, host, ctx)
        conn.reset_overdue_timer(timeout_seconds, on_overdue)

    # Check RTT against configured thresholds (handles alerts, notifications, etc.)
    if threshold_checker and rtt and rtt > 0:
        threshold_checker.check_value(
            host_name=uname,
            metric_path="rtt",
            value=rtt,
            alert_states=host.alert_states,
        )

    # Send any commands we have queued. Every entry is consumed, so an
    # unexpected op can no longer make this loop spin forever.
    while host.cmds:
        op, rmsg = host.cmds[0]
        del host.cmds[0]
        if op == "CMD":
            eventlog(uname, "INFO", "command sent")
        elif op == "UPD":
            eventlog(uname, "INFO", "update initiated")
        else:
            logger.warning("Dropping queued entry with unknown op %r for %s", op, uname)
            continue
        opkt = dicttos(op, rmsg)
        try:
            transport.sendto(opkt, addr)
        except Exception as e:
            if DEBUG > 0:
                print("cannot send cmd/update: %s" % e)

    if msg_to_websockets:
        try:
            msg_to_websockets("host", host.stateinfo())
        except Exception as e:
            if DEBUG > 0:
                print("cannot send websocket message: %s" % e)
|