Files
heartbeat/scripts/hbc_windows.py
T
2026-06-07 07:53:57 -04:00

1200 lines
40 KiB
Python

#!/usr/bin/env python3
"""hbc_windows — HeartBeat Client for Windows, installable as a Windows Service.
Run standalone:
python hbc_windows.py <hbd-host>
Install as Windows Service (using NSSM):
nssm install heartbeat "C:\\path\\to\\hbc_windows.exe" <hbd-host>
nssm start heartbeat
Config: %PROGRAMDATA%\\heartbeat\\hbc.json
Logs: %PROGRAMDATA%\\heartbeat\\hbc.log
"""
import argparse
import asyncio
import ctypes
import json
import logging
import os
import platform
import re
import shutil
import signal
import socket
import subprocess
import sys
import time
import zlib
from abc import ABC, abstractmethod
from collections import defaultdict
from logging.handlers import RotatingFileHandler
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
# Client version string; updated by scripts/bumpminor.sh
__version__ = "5.3.10"
# Data directory: <PROGRAMDATA>\heartbeat, falling back to C:\ProgramData
# when the PROGRAMDATA environment variable is not set.
_DATA_DIR = os.path.join(os.environ.get("PROGRAMDATA", "C:\\ProgramData"), "heartbeat")
# Default log file path (default argument of _configure_file_logging).
LOG_FILE = os.path.join(_DATA_DIR, "hbc.log")
# Default JSON config path (used by _load_config when no path is given).
CONFIG_FILE = os.path.join(_DATA_DIR, "hbc.json")
# ---------------------------------------------------------------------------
# Protocol (mirrors hbd/common/proto.py)
# ---------------------------------------------------------------------------
def _encode_value(v: Any) -> str:
if isinstance(v, float):
return f"{v:0.5f}"
if isinstance(v, (list, dict)):
return "@" + json.dumps(v)
if isinstance(v, bool):
return str(int(v))
return str(v)
def _decode_value(val: str) -> Any:
if not val:
return val
if val.startswith("@"):
try:
return json.loads(val[1:])
except Exception:
return val[1:]
if val[0].isdigit() or (val[0] == "-" and len(val) > 1 and val[1].isdigit()):
try:
return int(val)
except ValueError:
pass
try:
return float(val)
except ValueError:
pass
return val
def _dicttos(msg_id: str, d: Dict[str, Any]) -> bytes:
payload = ";".join(f"{k}={_encode_value(v)}" for k, v in d.items()).encode()
return ("!" + msg_id + ":").encode() + zlib.compress(payload, 6)
def _stodict(data: bytes) -> Dict[str, Any]:
result: Dict[str, Any] = {}
if not data:
return result
if chr(data[0]) == "!":
try:
payload = zlib.decompress(data[5:]).decode()
except Exception:
return {}
result["ID"] = data[1:4].decode()
else:
try:
head, payload = data.split(b":", 1)
payload = payload.decode()
result["ID"] = head.decode()
except Exception:
return {}
for item in payload.split(";"):
if not item:
continue
kv = item.split("=", 1)
result[kv[0].strip()] = _decode_value(kv[1].strip()) if len(kv) > 1 else None
return result
# ---------------------------------------------------------------------------
# Config (JSON, default %PROGRAMDATA%\heartbeat\hbc.json)
# ---------------------------------------------------------------------------
# Baseline configuration; _load_config copies it and overlays the JSON file.
_DEFAULTS: Dict[str, Any] = {
    "hb_port": 50003,   # port read by _async_main via cfg.get("hb_port")
    "interval": 10,     # heartbeat interval read by _async_main
    "owner": None,      # copied into each plugin's config by _load_plugins
    "plugins": {},      # per-plugin config, looked up by plugin name
}
def _load_config(path: Optional[str] = None) -> Dict[str, Any]:
    """Return the effective configuration: defaults overlaid by a JSON file.

    When *path* is not given, ``CONFIG_FILE`` is tried first and then
    ``~/.hbc.json``.  A missing or unreadable file is never fatal; the
    defaults are returned and the problem is logged.

    The file is read as UTF-8 and a leading byte-order mark is tolerated, so
    the result no longer depends on the machine's locale code page.
    """
    log = logging.getLogger("hbc.config")
    cfg = dict(_DEFAULTS)
    # Hand out a fresh "plugins" dict so callers cannot mutate the shared default.
    cfg["plugins"] = dict(_DEFAULTS.get("plugins") or {})

    explicit = bool(path)
    if not path:
        path = CONFIG_FILE
        if not os.path.exists(path):
            path = os.path.join(os.path.expanduser("~"), ".hbc.json")

    if not os.path.exists(path):
        if explicit:
            # An explicitly requested file that is missing deserves a trace.
            log.warning("config file %s not found; using defaults", path)
        return cfg

    try:
        with open(path, encoding="utf-8-sig") as fh:
            loaded = json.load(fh)
        if not isinstance(loaded, dict):
            raise ValueError("top-level JSON value must be an object")
        cfg.update(loaded)
        log.info("loaded config from %s", path)
    except Exception as e:
        log.warning("cannot read %s: %s", path, e)
    return cfg
# ---------------------------------------------------------------------------
# Plugin base classes
# ---------------------------------------------------------------------------
class Plugin(ABC):
    """Abstract base class shared by all plugins.

    Subclasses set the class attributes below and implement ``initialize``
    and ``collect``.  ``skip_reason`` may be set by ``initialize`` to explain
    why the plugin is not being loaded.
    """

    name: str = ""
    version: str = "1.0.0"
    description: str = ""
    interval: int = 0
    enabled: bool = True

    def __init__(self, config: Optional[Dict[str, Any]] = None):
        self.skip_reason: Optional[str] = None
        self.logger = logging.getLogger(f"plugin.{self.name}")
        self.config = config if config else {}

    @abstractmethod
    async def initialize(self) -> bool:
        """Return True when the plugin is usable and should be loaded."""

    @abstractmethod
    async def collect(self) -> Dict[str, Any]:
        """Return the plugin's data as a dict."""

    async def cleanup(self) -> None:
        """Release resources; the base implementation does nothing."""
        return None
class InfoPlugin(Plugin):
    """Plugin whose data is gathered once and then served from a cache."""

    interval = 0

    def __init__(self, config: Optional[Dict[str, Any]] = None):
        super().__init__(config)
        self._cache: Optional[Dict[str, Any]] = None

    async def collect(self) -> Dict[str, Any]:
        """Return the cached info, collecting it first if the cache is empty."""
        cached = self._cache
        if cached is None:
            cached = await self._collect_info()
            self._cache = cached
        return cached

    @abstractmethod
    async def _collect_info(self) -> Dict[str, Any]:
        """Gather the information dict; called only when nothing is cached."""
class MonitorPlugin(Plugin):
    """Plugin sampled repeatedly, every ``self.interval`` seconds."""

    interval = 60

    async def collect(self) -> Dict[str, Any]:
        """Return fresh metrics, or ``{}`` (after logging) if sampling fails."""
        try:
            metrics = await self._collect_metrics()
        except Exception as exc:
            self.logger.error("collect: %s", exc)
            return {}
        return metrics

    @abstractmethod
    async def _collect_metrics(self) -> Dict[str, Any]:
        """Take one sample and return it as a dict."""
# ---------------------------------------------------------------------------
# Plugin: os_info
# ---------------------------------------------------------------------------
class OSInfoPlugin(InfoPlugin):
    """Reports static OS / hardware identification fields."""

    name = "os_info"
    description = "OS and hardware info"

    async def initialize(self) -> bool:
        return True

    @staticmethod
    def _windows_pretty_name() -> Optional[str]:
        """Return the ``Caption=`` value printed by ``wmic os get Caption /value``."""
        raw = subprocess.check_output(
            ["wmic", "os", "get", "Caption", "/value"],
            timeout=5, stderr=subprocess.DEVNULL
        )
        for row in raw.decode(errors="replace").splitlines():
            if row.startswith("Caption="):
                return row.split("=", 1)[1].strip()
        return None

    @staticmethod
    def _unix_pretty_name() -> Optional[str]:
        """Return the ``PRETTY_NAME=`` value from /etc/os-release, unquoted."""
        with open("/etc/os-release") as fh:
            for row in fh:
                if row.startswith("PRETTY_NAME="):
                    return row.split("=", 1)[1].strip().strip('"')
        return None

    async def _collect_info(self) -> Dict[str, Any]:
        """Collect platform fields; every optional probe is best-effort."""
        info: Dict[str, Any] = {
            "os": platform.system(),
            "os_release": platform.release(),
            "os_version": platform.version(),
            "machine": platform.machine(),
            "python_version": platform.python_version(),
        }
        try:
            info["hostname"] = socket.getfqdn()
        except Exception:
            pass
        try:
            info["processor"] = platform.processor()
        except Exception:
            pass

        if platform.system() == "Windows":
            probe = self._windows_pretty_name
        else:
            probe = self._unix_pretty_name
        try:
            pretty = probe()
        except Exception:
            # Missing tool/file or a timeout: leave the field out.
            pretty = None
        if pretty is not None:
            info["os_pretty_name"] = pretty
        return info
# ---------------------------------------------------------------------------
# Plugin: ping_monitor
# ---------------------------------------------------------------------------
def _parse_ping_rtt(output: str) -> Optional[float]:
for pattern in (
r"rtt min/avg/max/mdev = [\d.]+/([\d.]+)/",
r"round-trip min/avg/max/stddev = [\d.]+/([\d.]+)/",
r"Average = ([\d.]+)ms",
):
m = re.search(pattern, output)
if m:
return float(m.group(1))
return None
class PingMonitorPlugin(MonitorPlugin):
    """Pings each configured target and reports average RTT and reachability.

    Config keys: ``interval`` (default 60), ``targets`` (list of hosts,
    default empty), ``count`` (default 3) and ``timeout`` (default 5).
    """

    name = "ping_monitor"
    description = "ICMP ping RTT to configured hosts"
    interval = 60

    def __init__(self, config: Optional[Dict[str, Any]] = None):
        super().__init__(config)
        cfg = config or {}
        self.interval = cfg.get("interval", 60)
        self.targets: List[str] = cfg.get("targets", [])
        self.count: int = cfg.get("count", 3)
        self.timeout: int = cfg.get("timeout", 5)

    async def initialize(self) -> bool:
        """Load only when at least one target is configured."""
        if not self.targets:
            self.skip_reason = "no targets configured"
            return False
        return True

    async def _ping_one(self, host: str) -> Optional[float]:
        """Run ``ping`` against *host*; return the average RTT or ``None``.

        ``None`` covers a non-zero exit status, unparseable output, a timeout
        and any failure to launch the command.
        """
        if sys.platform == "win32":
            # The -w value is self.timeout converted to milliseconds.
            cmd = ["ping", "-n", str(self.count), "-w", str(self.timeout * 1000), host]
        else:
            cmd = ["ping", "-c", str(self.count), "-W", str(self.timeout), host]
        proc = None
        try:
            proc = await asyncio.create_subprocess_exec(
                *cmd,
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE,
            )
            out, _ = await asyncio.wait_for(
                proc.communicate(), timeout=self.timeout * self.count + 5
            )
            if proc.returncode == 0:
                return _parse_ping_rtt(out.decode(errors="replace"))
        except Exception as e:
            self.logger.debug("ping %s: %s", host, e)
        finally:
            # On timeout or cancellation the child is still running; kill and
            # reap it so ping processes do not pile up.
            if proc is not None and proc.returncode is None:
                try:
                    proc.kill()
                    await proc.wait()
                except ProcessLookupError:
                    pass
        return None

    async def _collect_metrics(self) -> Dict[str, Any]:
        """Ping every target in turn.

        For each host, emits ``ping_<key>_rtt`` (``-1.0`` on failure) and
        ``ping_<key>_ok`` (1/0), where ``<key>`` is the host with every
        character outside ``[a-zA-Z0-9_]`` replaced by ``_``.
        """
        results: Dict[str, Any] = {}
        for host in self.targets:
            rtt = await self._ping_one(host)
            key = re.sub(r"[^a-zA-Z0-9_]", "_", host)
            results[f"ping_{key}_rtt"] = rtt if rtt is not None else -1.0
            results[f"ping_{key}_ok"] = 1 if rtt is not None else 0
        return results
# ---------------------------------------------------------------------------
# Plugin: nagios_runner
# ---------------------------------------------------------------------------
class NagiosRunnerPlugin(MonitorPlugin):
    """Runs configured check commands and reports their exit code and output.

    Config keys: ``interval`` (default 300), ``checks`` (mapping of check
    name to command line, default empty) and ``timeout`` (default 30).
    """

    name = "nagios_runner"
    description = "Run Nagios-compatible check scripts"
    interval = 300

    def __init__(self, config: Optional[Dict[str, Any]] = None):
        super().__init__(config)
        cfg = config or {}
        self.interval = cfg.get("interval", 300)
        self.checks: Dict[str, str] = cfg.get("checks", {})
        self.timeout: int = cfg.get("timeout", 30)

    async def initialize(self) -> bool:
        """Load only when at least one check is configured."""
        if not self.checks:
            self.skip_reason = "no checks configured"
            return False
        return True

    async def _run_check(self, name: str, command: str) -> Dict[str, Any]:
        """Run one check and return ``<name>_status_code`` / ``<name>_output``.

        The command line is split on whitespace (no quoting support).  A
        timeout or launch failure is reported as status code 3 with an
        ``UNKNOWN:`` message.  Output is truncated to 500 characters.
        """
        proc = None
        try:
            proc = await asyncio.create_subprocess_exec(
                *command.split(),
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE,
            )
            out, err = await asyncio.wait_for(proc.communicate(), timeout=self.timeout)
            rc = proc.returncode or 0
            # Fall back to stderr when the check printed nothing on stdout.
            output = (out or err or b"").decode(errors="replace").strip()
        except asyncio.TimeoutError:
            rc = 3
            output = f"UNKNOWN: check timed out after {self.timeout}s"
        except Exception as e:
            rc = 3
            output = f"UNKNOWN: {e}"
        finally:
            # A timed-out (or cancelled) check is still running: kill and reap
            # it so hung checks do not accumulate.
            if proc is not None and proc.returncode is None:
                try:
                    proc.kill()
                    await proc.wait()
                except ProcessLookupError:
                    pass
        return {
            f"{name}_status_code": rc,
            f"{name}_output": output[:500],
        }

    async def _collect_metrics(self) -> Dict[str, Any]:
        """Run every configured check sequentially and merge the results."""
        results: Dict[str, Any] = {}
        for name, command in self.checks.items():
            r = await self._run_check(name, command)
            results.update(r)
        return results
# ---------------------------------------------------------------------------
# Plugin: cpu_monitor (Linux only — skips on Windows)
# ---------------------------------------------------------------------------
def _read_cpu_stat() -> Optional[List[int]]:
try:
with open("/proc/stat") as fh:
line = fh.readline()
parts = line.split()
if parts[0] == "cpu":
return [int(x) for x in parts[1:]]
except Exception:
pass
return None
class CPUMonitorPlugin(MonitorPlugin):
    """CPU utilisation derived from successive /proc/stat samples."""

    name = "cpu_monitor"
    description = "CPU usage via /proc/stat (Linux only)"
    interval = 300

    def __init__(self, config: Optional[Dict[str, Any]] = None):
        super().__init__(config)
        self.interval = (config or {}).get("interval", 300)
        # (timestamp, counters) of the previous sample.
        self._prev: Optional[Tuple[float, List[int]]] = None

    async def initialize(self) -> bool:
        """Take the baseline sample; refuse to load off Linux."""
        if platform.system() != "Linux":
            self.skip_reason = "Linux only (/proc/stat not available)"
            return False
        baseline = _read_cpu_stat()
        if baseline is None:
            self.skip_reason = "/proc/stat not readable"
            return False
        self._prev = (time.time(), baseline)
        return True

    async def _collect_metrics(self) -> Dict[str, Any]:
        """Return ``cpu_percent``, ``cpu_count`` and, when available, load averages.

        ``cpu_percent`` is 100 * (1 - idle delta / total delta) between the
        previous sample and this one, using counter index 3 as idle.
        """
        now = time.time()
        sample = _read_cpu_stat()
        if sample is None or self._prev is None:
            return {}
        last_ts, last = self._prev
        self._prev = (now, sample)
        if now - last_ts <= 0:
            return {}

        total_delta = sum(sample) - sum(last)
        idle_delta = sample[3] - last[3]
        if total_delta:
            cpu_pct = round(100.0 * (1.0 - idle_delta / total_delta), 1)
        else:
            cpu_pct = 0.0

        data: Dict[str, Any] = {
            "cpu_percent": cpu_pct,
            "cpu_count": os.cpu_count() or 1,
        }
        try:
            data["load_1"], data["load_5"], data["load_15"] = os.getloadavg()
        except (AttributeError, OSError):
            pass
        return data
# ---------------------------------------------------------------------------
# Plugin: memory_monitor (Linux via /proc/meminfo; Windows via ctypes)
# ---------------------------------------------------------------------------
def _read_meminfo() -> Dict[str, int]:
result: Dict[str, int] = {}
try:
with open("/proc/meminfo") as fh:
for line in fh:
parts = line.split()
if len(parts) >= 2:
try:
result[parts[0].rstrip(":")] = int(parts[1])
except ValueError:
pass
except Exception:
pass
return result
if sys.platform == "win32":
    # NOTE(review): no visible code references ctypes.wintypes directly.
    import ctypes.wintypes

    class _MEMORYSTATUSEX(ctypes.Structure):
        """ctypes structure filled in by kernel32 ``GlobalMemoryStatusEx``.

        The field order and types define the binary layout handed to the
        API, so they must not be reordered or retyped.
        """
        _fields_ = [
            ("dwLength", ctypes.c_ulong),          # set to sizeof(struct) before the call
            ("dwMemoryLoad", ctypes.c_ulong),      # reported as memory_percent
            ("ullTotalPhys", ctypes.c_ulonglong),
            ("ullAvailPhys", ctypes.c_ulonglong),
            ("ullTotalPageFile", ctypes.c_ulonglong),
            ("ullAvailPageFile", ctypes.c_ulonglong),
            ("ullTotalVirtual", ctypes.c_ulonglong),
            ("ullAvailVirtual", ctypes.c_ulonglong),
            ("ullAvailExtendedVirtual", ctypes.c_ulonglong),
        ]
def _windows_memory_info() -> Optional[Dict[str, Any]]:
    """Query ``GlobalMemoryStatusEx`` and return memory/swap figures.

    Returns ``None`` when the API call reports failure.  Swap figures are
    derived as page-file totals minus the physical-memory figures and are
    only included when the derived swap total is positive.
    """
    status = _MEMORYSTATUSEX()
    status.dwLength = ctypes.sizeof(_MEMORYSTATUSEX)
    succeeded = ctypes.windll.kernel32.GlobalMemoryStatusEx(ctypes.byref(status))
    if not succeeded:
        return None

    phys_total = status.ullTotalPhys
    phys_avail = status.ullAvailPhys
    phys_used = phys_total - phys_avail
    page_total = status.ullTotalPageFile
    page_avail = status.ullAvailPageFile
    swap_total = max(page_total - phys_total, 0)
    swap_used = max(page_total - page_avail - phys_used, 0)

    info: Dict[str, Any] = {
        "memory_total": phys_total,
        "memory_used": phys_used,
        "memory_available": phys_avail,
        "memory_percent": status.dwMemoryLoad,
    }
    if swap_total > 0:
        info.update(
            swap_total=swap_total,
            swap_used=swap_used,
            swap_free=swap_total - swap_used,
            swap_percent=round(100.0 * swap_used / swap_total, 1),
        )
    return info
class MemoryMonitorPlugin(MonitorPlugin):
    """Memory and swap usage (Windows API on win32, /proc/meminfo on Linux)."""

    name = "memory_monitor"
    description = "Memory usage"
    interval = 300

    def __init__(self, config: Optional[Dict[str, Any]] = None):
        super().__init__(config)
        self.interval = (config or {}).get("interval", 300)

    async def initialize(self) -> bool:
        """Accept win32 unconditionally and Linux when /proc/meminfo is readable."""
        if sys.platform == "win32":
            return True
        if platform.system() != "Linux":
            self.skip_reason = "Linux or Windows only"
            return False
        if not _read_meminfo():
            self.skip_reason = "/proc/meminfo not readable"
            return False
        return True

    @staticmethod
    def _zfs_arc_kb() -> int:
        """Return the ZFS ARC ``size`` value from arcstats divided by 1024, or 0."""
        try:
            with open("/proc/spl/kstat/zfs/arcstats") as fh:
                for row in fh:
                    cols = row.split()
                    if len(cols) >= 3 and cols[0] == "size":
                        return int(cols[2]) // 1024
        except (OSError, ValueError):
            pass
        return 0

    async def _collect_metrics(self) -> Dict[str, Any]:
        """Return memory figures; /proc/meminfo values are multiplied by 1024.

        On Linux the ZFS ARC size is added to the available figure (capped at
        the total) before "used" is computed.
        """
        if sys.platform == "win32":
            return _windows_memory_info() or {}

        mi = _read_meminfo()
        if not mi:
            return {}

        total = mi.get("MemTotal", 0)
        free = mi.get("MemFree", 0)
        avail = mi.get("MemAvailable", mi.get("MemFree", 0))
        avail = min(avail + self._zfs_arc_kb(), total)
        used = total - avail

        data: Dict[str, Any] = {
            "memory_total": total * 1024,
            "memory_used": used * 1024,
            "memory_available": avail * 1024,
            "memory_free": free * 1024,
            "memory_percent": round(100.0 * used / total, 1) if total else 0.0,
        }
        optional_fields = (
            ("Buffers", "memory_buffers"),
            ("Cached", "memory_cached"),
            ("Active", "memory_active"),
            ("Inactive", "memory_inactive"),
        )
        for source, key in optional_fields:
            if source in mi:
                data[key] = mi[source] * 1024

        swap_total = mi.get("SwapTotal", 0)
        if swap_total:
            swap_free = mi.get("SwapFree", 0)
            swap_used = swap_total - swap_free
            data["swap_total"] = swap_total * 1024
            data["swap_used"] = swap_used * 1024
            data["swap_free"] = swap_free * 1024
            data["swap_percent"] = round(100.0 * swap_used / swap_total, 1)
        return data
# ---------------------------------------------------------------------------
# Plugin: disk_monitor (Windows via ctypes; Unix via df -P)
# ---------------------------------------------------------------------------
def _windows_drives() -> List[str]:
drives = []
for letter in "CDEFGHIJKLMNOPQRSTUVWXYZ":
path = f"{letter}:\\"
if os.path.exists(path):
drives.append(path)
return drives
def _windows_disk_usage(path: str) -> Optional[Dict[str, Any]]:
    """Return total/used/free/percent for *path* via ``GetDiskFreeSpaceExW``.

    Returns ``None`` when the call fails or reports a total of zero.
    """
    caller_free = ctypes.c_ulonglong(0)
    total_bytes = ctypes.c_ulonglong(0)
    free_bytes = ctypes.c_ulonglong(0)
    succeeded = ctypes.windll.kernel32.GetDiskFreeSpaceExW(
        path,
        ctypes.byref(caller_free),
        ctypes.byref(total_bytes),
        ctypes.byref(free_bytes),
    )
    if not succeeded or total_bytes.value == 0:
        return None
    total = total_bytes.value
    free = free_bytes.value
    used = total - free
    return {
        "total": total,
        "used": used,
        "free": free,
        "percent": round(100.0 * used / total, 1),
    }
class DiskMonitorPlugin(MonitorPlugin):
    """Per-partition disk usage (Windows API on win32, ``df`` elsewhere).

    Config keys: ``interval`` (default 300) and ``mounts`` (optional list
    restricting which drives / mount points are reported).
    """

    name = "disk_monitor"
    description = "Disk usage"
    interval = 300

    def __init__(self, config: Optional[Dict[str, Any]] = None):
        super().__init__(config)
        cfg = config or {}
        self.interval = cfg.get("interval", 300)
        self.mounts: List[str] = cfg.get("mounts", [])

    async def initialize(self) -> bool:
        """Accept win32, Linux, Darwin and FreeBSD."""
        if sys.platform != "win32" and platform.system() not in ("Linux", "Darwin", "FreeBSD"):
            self.skip_reason = "unsupported platform"
            return False
        return True

    async def _collect_metrics(self) -> Dict[str, Any]:
        if sys.platform == "win32":
            return self._collect_windows()
        return await self._collect_unix()

    def _collect_windows(self) -> Dict[str, Any]:
        """Report each configured (or detected) drive, keyed by drive letter."""
        drives = self.mounts if self.mounts else _windows_drives()
        partitions: Dict[str, Any] = {}
        for drive in drives:
            info = _windows_disk_usage(drive)
            if info:
                mount_key = drive.rstrip("\\").replace(":", "")
                partitions[mount_key] = info
        return {"partitions": partitions} if partitions else {}

    async def _collect_unix(self) -> Dict[str, Any]:
        """Parse ``df -P -k`` output into per-mount byte counts.

        ``-k`` is passed explicitly because the sizes are multiplied by 1024
        below; POSIX ``df -P`` alone defaults to 512-byte units.
        """
        proc = None
        try:
            proc = await asyncio.create_subprocess_exec(
                "df", "-P", "-k",
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE,
            )
            out, _ = await asyncio.wait_for(proc.communicate(), timeout=10)
        except Exception as e:
            self.logger.warning("df failed: %s", e)
            return {}
        finally:
            # A df hung on a dead mount would otherwise be leaked on timeout.
            if proc is not None and proc.returncode is None:
                try:
                    proc.kill()
                    await proc.wait()
                except ProcessLookupError:
                    pass

        partitions: Dict[str, Any] = {}
        for line in out.decode(errors="replace").splitlines()[1:]:
            parts = line.split()
            if len(parts) < 6:
                continue
            # The mount point is the last column and may contain spaces.
            mount = " ".join(parts[5:])
            if self.mounts and mount not in self.mounts:
                continue
            try:
                total_kb = int(parts[1])
                used_kb = int(parts[2])
                avail_kb = int(parts[3])
                pct = int(parts[4].rstrip("%"))
            except (ValueError, IndexError):
                continue
            partitions[mount] = {
                "total": total_kb * 1024,
                "used": used_kb * 1024,
                "free": avail_kb * 1024,
                "percent": pct,
            }
        return {"partitions": partitions} if partitions else {}
# ---------------------------------------------------------------------------
# Plugin: network_monitor (Linux only — skips on Windows)
# ---------------------------------------------------------------------------
def _read_net_dev() -> Dict[str, Tuple[int, int]]:
    """Parse /proc/net/dev into ``{interface: (first counter, ninth counter)}``.

    Each line is split at the ``:`` after the interface name before the
    counters are split on whitespace, so a line with no space after the
    colon (``eth0:12345 ...``) is still parsed correctly.  The two header
    lines are skipped.  A malformed line is skipped on its own rather than
    aborting the whole read.  Returns ``{}`` if the file cannot be read.
    """
    result: Dict[str, Tuple[int, int]] = {}
    try:
        with open("/proc/net/dev") as fh:
            rows = fh.readlines()[2:]
    except Exception:
        return result
    for row in rows:
        iface, sep, counters = row.partition(":")
        if not sep:
            continue
        cols = counters.split()
        if len(cols) < 9:
            continue
        try:
            # Same columns the callers have always received: counter 0 and
            # counter 8 after the interface name (reported as recv / sent).
            result[iface.strip()] = (int(cols[0]), int(cols[8]))
        except ValueError:
            continue
    return result
class NetworkMonitorPlugin(MonitorPlugin):
    """Per-interface byte counters and their change since the last sample.

    Config keys: ``interval`` (default 300) and ``skip_interfaces``
    (default ``["lo"]``).
    """

    name = "network_monitor"
    description = "Network I/O rates via /proc/net/dev (Linux only)"
    interval = 300

    def __init__(self, config: Optional[Dict[str, Any]] = None):
        super().__init__(config)
        cfg = config or {}
        self.interval = cfg.get("interval", 300)
        self.skip_ifaces: set = set(cfg.get("skip_interfaces", ["lo"]))
        # (timestamp, counters) of the previous sample.
        self._prev: Optional[Tuple[float, Dict[str, Tuple[int, int]]]] = None

    async def initialize(self) -> bool:
        """Take the baseline sample; refuse to load off Linux."""
        if platform.system() != "Linux":
            self.skip_reason = "Linux only (/proc/net/dev not available)"
            return False
        baseline = _read_net_dev()
        if not baseline:
            self.skip_reason = "/proc/net/dev not readable"
            return False
        self._prev = (time.time(), baseline)
        return True

    async def _collect_metrics(self) -> Dict[str, Any]:
        """Return ``{"interfaces": {...}}`` with absolute counters and deltas.

        Interfaces in the skip list, or absent from the previous sample, are
        left out.  The current sample always becomes the new baseline.
        """
        now = time.time()
        counters = _read_net_dev()
        previous = self._prev
        self._prev = (now, counters)
        if not counters or previous is None:
            return {}
        last_ts, last = previous
        if now - last_ts <= 0:
            return {}

        interfaces: Dict[str, Any] = {
            iface: {
                "bytes_recv": rx,
                "bytes_sent": tx,
                "bytes_recv_delta": rx - last[iface][0],
                "bytes_sent_delta": tx - last[iface][1],
            }
            for iface, (rx, tx) in counters.items()
            if iface not in self.skip_ifaces and iface in last
        }
        return {"interfaces": interfaces} if interfaces else {}
# ---------------------------------------------------------------------------
# Plugin registry
# ---------------------------------------------------------------------------
# Every plugin class _load_plugins attempts to instantiate and initialise,
# in this order.
_ALL_PLUGIN_CLASSES: List[type] = [
    OSInfoPlugin,
    PingMonitorPlugin,
    NagiosRunnerPlugin,
    CPUMonitorPlugin,
    MemoryMonitorPlugin,
    DiskMonitorPlugin,
    NetworkMonitorPlugin,
]
async def _load_plugins(cfg: Dict[str, Any]) -> List[Plugin]:
    """Instantiate and initialise every registered plugin.

    A plugin's config is taken from ``cfg["plugins"][<name>]``, falling back
    to ``cfg[<name>]``.  The top-level ``owner`` is copied in when the plugin
    config does not set its own.  Plugins whose ``initialize`` returns a
    false value or raises are skipped and the reason is logged.
    """
    log = logging.getLogger("hbc.plugins")
    per_plugin: Dict[str, Any] = cfg.get("plugins", {})
    ready: List[Plugin] = []
    for plugin_cls in _ALL_PLUGIN_CLASSES:
        settings = dict(per_plugin.get(plugin_cls.name) or cfg.get(plugin_cls.name) or {})
        if "owner" in cfg:
            settings.setdefault("owner", cfg["owner"])
        instance: Plugin = plugin_cls(config=settings)
        try:
            ok = await instance.initialize()
        except Exception as exc:
            log.error("init %s: %s", plugin_cls.name, exc)
            ok = False
        if not ok:
            log.info("skip %s: %s", instance.name, instance.skip_reason or "init failed")
            continue
        ready.append(instance)
        log.info("loaded %s (interval=%ds)", instance.name, instance.interval)
    return ready
# ---------------------------------------------------------------------------
# Global state
# ---------------------------------------------------------------------------
# Loop flag checked by the sender/collector loops; cleared by _stop().
_running = True
# Set by _handle_update after a successful install; main() re-executes when set.
_dorestart = False
# Created in _async_main; set by _stop() to wake anything blocked in _sleep().
_shutdown_event: Optional[asyncio.Event] = None
# Top-level tasks created in _async_main; cancelled by _stop().
_active_tasks: List[asyncio.Task] = []
# Fallbacks used by _async_main when the config lacks "hb_port" / "interval".
PORT = 50003
INTERVAL = 10
def _shortname(name: str) -> str:
return name.split(".")[0]
def _stop():
    """Request shutdown: clear the run flag, set the shutdown event, cancel tasks."""
    global _running
    _running = False
    event = _shutdown_event
    if event:
        event.set()
    for task in _active_tasks:
        if task.done():
            continue
        task.cancel()
async def _sleep(seconds: float):
    """Sleep up to *seconds*, returning early once the shutdown event is set.

    Without a shutdown event this is a plain ``asyncio.sleep``.
    """
    event = _shutdown_event
    if not event:
        await asyncio.sleep(seconds)
        return
    try:
        await asyncio.wait_for(event.wait(), timeout=seconds)
    except asyncio.TimeoutError:
        # Normal case: the full interval elapsed without a shutdown request.
        pass
# ---------------------------------------------------------------------------
# UDP protocol handler + connection
# ---------------------------------------------------------------------------
class _HeartbeatProtocol(asyncio.DatagramProtocol):
def __init__(self, conn: "AsyncConnection"):
self._conn = conn
self._log = logging.getLogger("hbc.proto")
def datagram_received(self, data: bytes, addr):
try:
msg = _stodict(data)
if not msg:
return
msg_id = msg.get("ID")
now = time.time()
if msg_id == "ACK":
self._conn._handle_ack(msg, now)
elif msg_id == "CMD":
asyncio.create_task(_handle_command(self._conn, msg))
elif msg_id == "UPD":
asyncio.create_task(_handle_update(self._conn))
else:
self._log.debug("unknown msg type: %s", msg_id)
except Exception as e:
self._log.error("datagram error: %s", e)
def error_received(self, exc):
self._log.warning("protocol error on %s: %s — will retry", self._conn.addr, exc)
self._conn.close()
class AsyncConnection:
    """One UDP association with a heartbeat server address.

    Tracks the ACK count, the time of the last send and up to ten recent
    round-trip times, and lazily (re)opens its datagram transport.
    """

    def __init__(self, conn_id: int, addr: str, port: int, af: int, name: str):
        self.conn_id = conn_id
        self.name = name
        self.addr = addr
        self.port = port
        self.af = af
        self.ackcount = 0
        self.lastsend = 0.0
        self.rtts: List[float] = [0.0]
        self._dead = False
        self._transport: Optional[asyncio.DatagramTransport] = None
        # Set when an ACK carries a truthy "request_update" field.
        self._request_info: asyncio.Event = asyncio.Event()
        self._log = logging.getLogger(f"hbc.conn.{addr}")

    async def open(self) -> bool:
        """Create the datagram endpoint; return True on success."""
        try:
            loop = asyncio.get_event_loop()
            endpoint = await loop.create_datagram_endpoint(
                lambda: _HeartbeatProtocol(self), family=self.af
            )
        except Exception as exc:
            self._log.error("open: %s", exc)
            return False
        self._transport = endpoint[0]
        return True

    def close(self):
        """Close and forget the transport, if there is one."""
        transport = self._transport
        if transport:
            transport.close()
            self._transport = None

    def _handle_ack(self, msg: Dict[str, Any], now: float):
        """Record the RTT (milliseconds since the last send) and count the ACK."""
        self.rtts.append((now - self.lastsend) * 1000.0)
        if len(self.rtts) > 10:
            del self.rtts[0]
        self.ackcount += 1
        if msg.get("request_update"):
            self._request_info.set()

    async def sendto(self, msg: Dict[str, Any], msg_id: str = "HTB"):
        """Send *msg* with ``name``, ``id`` and ``time`` fields added.

        Does nothing when the connection is marked dead or the transport
        cannot be opened.
        """
        if self._dead:
            return
        if not self._transport:
            await self.open()
            if not self._transport:
                return
        envelope = {
            **msg,
            "name": _shortname(self.name),
            "id": self.conn_id,
            "time": time.time(),
        }
        self._transport.sendto(_dicttos(msg_id, envelope), (self.addr, self.port))
        self.lastsend = time.time()
# ---------------------------------------------------------------------------
# Server command handlers
# ---------------------------------------------------------------------------
async def _handle_command(conn: AsyncConnection, msg: Dict[str, Any]):
    """Run the shell command in ``msg["cmd"]`` and send the result back.

    The reply has ``service="command"`` and ``msg="<status> <output>"``,
    where status is ``OK``, ``Error`` or ``Timeout``.  Nothing is sent when
    the message carries no command.

    The command runs in the default executor so its 30-second limit cannot
    stall the event loop (and with it the heartbeats).
    """
    cmd = msg.get("cmd", "")
    if not cmd:
        return
    log = logging.getLogger("hbc.cmd")
    log.info("exec: %s", cmd)
    # SECURITY(review): this executes, through the shell, a string taken from
    # a received datagram.  No authentication is visible in this file, so
    # anyone able to deliver a CMD datagram can run commands as this process.
    loop = asyncio.get_running_loop()
    try:
        raw = await loop.run_in_executor(
            None,
            lambda: subprocess.check_output(
                cmd, shell=True, stderr=subprocess.STDOUT, timeout=30
            ),
        )
        # Command output is often not UTF-8 (e.g. a console code page); do
        # not let that turn a successful run into an error.
        out = raw.decode(errors="replace")
        status = "OK"
    except subprocess.CalledProcessError as e:
        out, status = str(e), "Error"
    except subprocess.TimeoutExpired:
        out, status = "timed out", "Timeout"
    except Exception as e:
        out, status = str(e), "Error"
    await conn.sendto({"service": "command", "msg": f"{status} {out}"})
async def _handle_update(conn: AsyncConnection):
    """Run ``hb_install.sh mini`` and request a restart if it succeeds.

    The installer is looked up on PATH and then next to ``sys.argv[0]``.
    Every outcome is reported to the server as a ``service="update"``
    message: ``OK`` on success, otherwise an error description.  On success
    ``_dorestart`` is set and ``_stop()`` is called.
    """
    global _dorestart
    log = logging.getLogger("hbc.update")

    async def _fail(err: str) -> None:
        log.error(err)
        await conn.sendto({"service": "update", "msg": err})

    installer = shutil.which("hb_install.sh")
    if installer is None:
        candidate = Path(sys.argv[0]).parent / "hb_install.sh"
        if candidate.exists():
            installer = str(candidate)
    if installer is None:
        await _fail("hb_install.sh not found in PATH or alongside hbc_windows.py")
        return

    log.info("running installer: %s", installer)
    proc = None
    try:
        proc = await asyncio.create_subprocess_exec(
            installer, "mini",
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.STDOUT,
        )
        out, _ = await asyncio.wait_for(proc.communicate(), timeout=120)
    except asyncio.TimeoutError:
        # Do not leave a half-finished installer running in the background.
        if proc is not None and proc.returncode is None:
            try:
                proc.kill()
                await proc.wait()
            except ProcessLookupError:
                pass
        await _fail("installer timed out")
        return
    except Exception as e:
        await _fail(f"installer failed: {e}")
        return

    if proc.returncode != 0:
        # Installer output is arbitrary bytes; never fail while reporting it.
        detail = out.decode(errors="replace").strip()
        await _fail(f"installer exited {proc.returncode}: {detail}")
        return

    log.info("update successful, restarting")
    await conn.sendto({"service": "update", "msg": "OK"})
    _dorestart = True
    _stop()
# ---------------------------------------------------------------------------
# Heartbeat sender
# ---------------------------------------------------------------------------
async def _heartbeat_sender(conn: AsyncConnection, interval: int):
    """Send a heartbeat on *conn* every *interval* seconds until stopped.

    Each heartbeat carries the ACK count, the latest RTT and the interval.
    Send failures are logged and do not end the loop.
    """
    log = logging.getLogger("hbc.hb")
    while _running:
        try:
            beat = {
                "acks": conn.ackcount,
                "rtt": conn.rtts[-1],
                "interval": interval,
            }
            await conn.sendto(beat)
        except Exception as exc:
            log.error("send: %s", exc)
        await _sleep(interval)
# ---------------------------------------------------------------------------
# Plugin collection loops
# ---------------------------------------------------------------------------
async def _run_info_plugins(conn: AsyncConnection, plugins: List[Plugin]):
    """Collect each plugin once and send non-empty results as ``PLG`` messages.

    A failure in one plugin is logged and does not stop the others.
    """
    log = logging.getLogger("hbc.plugins")
    for plugin in plugins:
        try:
            payload = await plugin.collect()
            if not payload:
                continue
            await conn.sendto({"plugin": plugin.name, **payload}, "PLG")
            log.info("sent %s", plugin.name)
        except Exception as exc:
            log.error("%s collect: %s", plugin.name, exc)
async def _run_monitor_group(conn: AsyncConnection, plugins: List[Plugin], interval: int):
    """Repeatedly collect a group of plugins that share one interval.

    Non-empty results are sent as ``PLG`` messages.  Cancellation is
    propagated; any other per-plugin error is logged and skipped.
    """
    log = logging.getLogger(f"hbc.plugins.{interval}s")
    while _running:
        for plugin in plugins:
            try:
                payload = await plugin.collect()
                if not payload:
                    continue
                await conn.sendto({"plugin": plugin.name, **payload}, "PLG")
                log.debug("sent %s", plugin.name)
            except asyncio.CancelledError:
                raise
            except Exception as exc:
                log.error("%s: %s", plugin.name, exc)
        await _sleep(interval)
async def _info_refresh_loop(conn: AsyncConnection, info: List[Plugin]):
    """Re-send InfoPlugin data each time the connection's refresh event fires.

    The event is cleared, every plugin's cache is dropped, and the plugins
    are collected and sent again.
    """
    log = logging.getLogger("hbc.plugins")
    while _running:
        await conn._request_info.wait()
        if not _running:
            return
        conn._request_info.clear()
        log.info("refreshing InfoPlugins on server request")
        for plugin in info:
            # Force the next collect() to gather fresh data.
            plugin._cache = None
        await _run_info_plugins(conn, info)
async def _plugin_collector(conn: AsyncConnection, plugins: List[Plugin]):
    """Drive all plugins for one connection.

    InfoPlugins are sent once up front.  Then one refresh task plus one
    monitor task per distinct interval run until they finish or are
    cancelled.
    """
    info = [p for p in plugins if isinstance(p, InfoPlugin)]
    await _run_info_plugins(conn, info)

    by_interval: Dict[int, List[Plugin]] = defaultdict(list)
    for plugin in plugins:
        if isinstance(plugin, MonitorPlugin):
            by_interval[plugin.interval].append(plugin)

    tasks = [asyncio.create_task(_info_refresh_loop(conn, info))]
    tasks.extend(
        asyncio.create_task(_run_monitor_group(conn, group, seconds))
        for seconds, group in by_interval.items()
    )
    await asyncio.gather(*tasks, return_exceptions=True)
# ---------------------------------------------------------------------------
# Logging
# ---------------------------------------------------------------------------
def _configure_file_logging(level: int, log_file: str = LOG_FILE):
    """Route root logging to a rotating file, replacing any existing handlers.

    The log directory is created if needed.  The file rotates at 5 MiB with
    three backups and is written as UTF-8.
    """
    parent = os.path.dirname(log_file)
    if parent:
        os.makedirs(parent, exist_ok=True)

    root = logging.getLogger()
    for stale in list(root.handlers):
        root.removeHandler(stale)
        stale.close()

    formatter = logging.Formatter(
        "%(asctime)s %(name)s %(levelname)s: %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )
    handler = RotatingFileHandler(
        log_file, maxBytes=5 * 1024 * 1024, backupCount=3, encoding="utf-8"
    )
    handler.setFormatter(formatter)
    root.addHandler(handler)
    root.setLevel(level)
# ---------------------------------------------------------------------------
# Async main
# ---------------------------------------------------------------------------
async def _async_main(args, cfg: Dict[str, Any]) -> int:
    """Run the client: connect, optionally send a one-shot message, then loop.

    Returns 1 when no connection could be opened before shutdown was
    requested, otherwise 0.
    """
    # NOTE(review): send_shutdown is declared global here although no
    # module-level definition of it is visible in this file.
    global _running, _shutdown_event, _active_tasks, send_shutdown
    _running = True
    _shutdown_event = asyncio.Event()
    _active_tasks = []
    log = logging.getLogger("hbc.main")
    iam = args.name or socket.gethostname()
    port = cfg.get("hb_port", PORT)
    interval = cfg.get("interval", INTERVAL)
    log.info("hbc_windows %s on %s -> %s port=%d interval=%ds",
             __version__, iam, args.hosts, port, interval)
    # 0 leaves the address family unrestricted for getaddrinfo.
    af_filter = (socket.AF_INET if getattr(args, "ipv4_only", False)
                 else socket.AF_INET6 if getattr(args, "ipv6_only", False)
                 else 0)
    connections: List[AsyncConnection] = []
    conn_id = 1
    _retry_delay = 5
    # Keep resolving/opening until at least one connection exists or a stop
    # is requested; the delay doubles after each failed round, capped at 60 s.
    while _running and not connections:
        for host in args.hosts:
            try:
                addrs = socket.getaddrinfo(host, port, af_filter, 0, socket.SOL_UDP)
            except socket.gaierror as e:
                log.warning("cannot resolve %s: %s — retrying in %ds", host, e, _retry_delay)
                continue
            # One connection per resolved address.
            for ai in addrs:
                conn = AsyncConnection(conn_id, ai[4][0], port, ai[0], iam)
                if await conn.open():
                    connections.append(conn)
                    conn_id += 1
        if not connections:
            await _sleep(_retry_delay)
            _retry_delay = min(_retry_delay * 2, 60)
    if not connections:
        return 1
    send_shutdown = False
    if args.boot or args.message:
        bmsg: Dict[str, Any] = {"acks": 0}
        if args.boot:
            bmsg["boot"] = 1
            args.boot = False
            # A boot announcement is paired with a shutdown message on exit.
            send_shutdown = True
        if args.message:
            bmsg["service"] = "service"
            bmsg["msg"] = args.message
        # Prefer a connection that already has a transport.
        target = next((c for c in connections if c._transport), connections[0])
        await target.sendto(bmsg)
        if args.message and not args.daemon:
            # One-shot mode: pause briefly after sending, then close and exit.
            await asyncio.sleep(0.3)
            for c in connections:
                c.close()
            return 0
    plugins = await _load_plugins(cfg)
    # Windows: signal.signal() instead of loop.add_signal_handler() (Unix-only)
    loop = asyncio.get_running_loop()
    signal.signal(signal.SIGTERM, lambda s, f: loop.call_soon_threadsafe(_stop))
    signal.signal(signal.SIGINT, lambda s, f: loop.call_soon_threadsafe(_stop))
    for conn in connections:
        _active_tasks.append(asyncio.create_task(_heartbeat_sender(conn, interval)))
    # Plugin data is sent on the first connection only.
    if plugins and connections:
        _active_tasks.append(asyncio.create_task(_plugin_collector(connections[0], plugins)))
    try:
        await asyncio.gather(*_active_tasks, return_exceptions=True)
    except asyncio.CancelledError:
        pass
    log.info("shutting down")
    target = next((c for c in connections if c._transport), connections[0] if connections else None)
    if target and send_shutdown:
        try:
            await target.sendto({"shutdown": 1, "acks": target.ackcount})
        except Exception:
            # Best effort: a failed shutdown notice must not block exit.
            pass
    for conn in connections:
        conn.close()
    await asyncio.sleep(0.3)
    for plugin in plugins:
        await plugin.cleanup()
    return 0
# ---------------------------------------------------------------------------
# CLI entry point
# ---------------------------------------------------------------------------
def main(argv=None):
    """Command-line entry point: parse arguments, set up logging, run the client.

    Exits the process with the client's return code.  If an update requested
    a restart, the process re-executes itself instead of exiting.
    """
    global _dorestart
    parser = argparse.ArgumentParser(
        prog="hbc_windows",
        description="HeartBeat Client for Windows — no external dependencies",
    )
    parser.add_argument("-b", "--boot", action="store_true", help="Send boot message")
    parser.add_argument("-c", "--config", dest="configfile", help="Config file (JSON)")
    parser.add_argument("-m", "--message", dest="message", help="Send a one-shot message")
    parser.add_argument("-n", "--name", dest="name", help="Override hostname")
    parser.add_argument("-d", "--daemon", action="store_true",
                        help="Log to file instead of console (for service use)")
    parser.add_argument("--log-file", dest="log_file", default=None,
                        help=f"Log file path (default: {LOG_FILE})")
    parser.add_argument("-v", "--verbose", action="store_true", help="Verbose output")
    parser.add_argument("-x", "--debug", action="count", default=0, help="Debug level")
    af_group = parser.add_mutually_exclusive_group()
    af_group.add_argument("-4", dest="ipv4_only", action="store_true", help="Use IPv4 only")
    af_group.add_argument("-6", dest="ipv6_only", action="store_true", help="Use IPv6 only")
    parser.add_argument("hosts", nargs="+", help="HBD server(s)")
    args = parser.parse_args(argv)

    # --debug wins over --verbose; the default level is WARNING.
    level = logging.WARNING
    if args.verbose:
        level = logging.INFO
    if args.debug:
        level = logging.DEBUG
    if args.daemon or args.log_file:
        _configure_file_logging(level, args.log_file or LOG_FILE)
    else:
        logging.basicConfig(
            level=level,
            format="%(asctime)s %(name)s %(levelname)s: %(message)s",
            datefmt="%Y-%m-%d %H:%M:%S",
        )

    cfg = _load_config(args.configfile)
    try:
        rc = asyncio.run(_async_main(args, cfg))
    except KeyboardInterrupt:
        rc = 0
    except Exception as e:
        logging.error("fatal: %s", e, exc_info=True)
        rc = 1

    if _dorestart:
        logging.info("restarting...")
        # sys.argv[0] is a .py script when run as "python hbc_windows.py" and
        # cannot be exec'ed directly; always exec the real executable.
        if getattr(sys, "frozen", False):
            # Frozen build: sys.executable is the program itself.
            os.execv(sys.executable, sys.argv)
        else:
            os.execv(sys.executable, [sys.executable] + sys.argv)
    sys.exit(rc)
# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()