644 lines
17 KiB
Python
644 lines
17 KiB
Python
#!/usr/bin/env python3
|
|
"""
HeartBeat Client (hbc) - Async version with plugin support.

Sends heartbeat messages to HeartBeat Daemon (hbd) servers and collects
system information via plugins.
"""
|
|
|
|
import argparse
|
|
import asyncio
|
|
import logging
|
|
import os
|
|
import signal
|
|
import socket
|
|
import sys
|
|
import time
|
|
from hashlib import md5
|
|
from pathlib import Path
|
|
from typing import Dict, List, Optional
|
|
|
|
# Import protocol and config
|
|
from .config import load_config
|
|
from ..common.proto import dicttos, stodict
|
|
|
|
# Import plugin system
|
|
from .plugin import PluginRegistry, PluginLoader, InfoPlugin, MonitorPlugin
|
|
|
|
# Constants
# Default UDP port, used when the config has no "hb_port" entry.
PORT = 50003
# Default heartbeat interval in seconds, used when the config has no "interval" entry.
INTERVAL = 10
# Value placed in the "ver" field of every outgoing message.
VER = 6
# NOTE(review): not referenced anywhere in the visible code - presumably a
# receive-size limit kept from an earlier version; confirm before removing.
MAXRECV = 32767

# Global state
# Polled by the heartbeat/plugin loops; cleared by stop() to end them.
running = True
# Set by handle_update() after a successful update; main() re-executes the
# program when it is true.
dorestart = False
|
|
|
|
|
|
class AsyncConnection:
    """Async UDP connection to a heartbeat server.

    Holds the asyncio datagram transport used for sending, plus per-server
    send/ACK counters and a short history of round-trip estimates.
    """

    # Number of most recent RTT samples kept in ``rtts``.
    _RTT_HISTORY = 10

    def __init__(self, conn_id: int, addr: str, port: int, af: int, name: str):
        """Record the destination and reset all counters.

        Args:
            conn_id: Identifier sent as the ``id`` field of every message
            addr: Destination address passed to ``transport.sendto``
            port: Destination UDP port
            af: Address family used when creating the datagram endpoint
            name: Host name; its short form is sent as the ``name`` field
        """
        self.conn_id = conn_id
        self.addr = addr
        self.port = port
        self.af = af
        self.name = name

        self.ackcount = 0      # ACKs handled so far
        self.lastack = 0.0     # timestamp recorded for the latest ACK
        self.send_count = 0    # datagrams handed to the transport
        self.lastsend = 0.0    # local time of the latest send
        self.rtts = [0.0]      # recent RTT estimates in ms, newest last

        self.transport: Optional[asyncio.DatagramTransport] = None
        self.protocol: Optional[asyncio.DatagramProtocol] = None

        self.logger = logging.getLogger(f"hbc.conn.{addr}")

    async def open(self) -> bool:
        """Open the UDP connection.

        Returns:
            True if successful, False otherwise
        """
        try:
            # We are inside a coroutine, so the running loop is the right one;
            # get_event_loop() is discouraged in this position.
            loop = asyncio.get_running_loop()

            # Create datagram endpoint
            self.transport, self.protocol = await loop.create_datagram_endpoint(
                lambda: HeartbeatProtocol(self),
                family=self.af,
            )
            self.logger.debug(f"Opened connection to {self.addr}:{self.port}")
            return True
        except Exception as e:
            self.logger.error(f"Failed to open connection: {e}")
            return False

    def close(self):
        """Close the connection."""
        if self.transport:
            self.transport.close()
            self.transport = None
            self.protocol = None

    async def sendto(self, msg: dict, msg_id: str = "HTB"):
        """Send a message to the server.

        The transport is opened on demand. The caller's dictionary is left
        untouched; the standard fields are added to a copy.

        Args:
            msg: Message dictionary
            msg_id: Message ID (HTB, PLG, etc.)
        """
        if not self.transport:
            await self.open()

        if not self.transport:
            self.logger.error("Cannot send - no transport")
            return

        # Add standard fields to a copy so a dict shared between several
        # connections (or reused by the caller) is never modified.
        payload = {
            **msg,
            "name": shortname(self.name),
            "id": self.conn_id,
            "ver": VER,
            "time": time.time(),
        }

        # Encode message
        data = dicttos(msg_id, payload)

        # Send
        self.transport.sendto(data, (self.addr, self.port))
        self.send_count += 1
        self.lastsend = time.time()

        self.logger.debug(f"Sent {msg_id} message ({len(data)} bytes)")

    def handle_ack(self, msg: dict, now: float):
        """Handle ACK message from server.

        Args:
            msg: Decoded ACK; its ``time`` value is used when present and
                usable in arithmetic
            now: Local receive time, used when ``time`` is missing or invalid
        """
        try:
            self.lastack = msg["time"]
            # NOTE(review): the 2000.0 factor presumably doubles a one-way
            # delay (server stamp minus local send time) and converts seconds
            # to ms - confirm the ACK "time" semantics against the server.
            rtt = (self.lastack - self.lastsend) * 2000.0
        except (KeyError, TypeError):
            # No usable server timestamp: measure the full round trip on the
            # local clock, which only needs the seconds -> ms conversion.
            self.lastack = now
            rtt = (self.lastack - self.lastsend) * 1000.0

        self.rtts.append(rtt)
        if len(self.rtts) > self._RTT_HISTORY:
            self.rtts.pop(0)

        self.ackcount += 1
        self.logger.debug(f"ACK received, RTT: {rtt:.1f}ms")
|
|
|
|
|
|
class HeartbeatProtocol(asyncio.DatagramProtocol):
    """Protocol handler for incoming UDP messages."""

    def __init__(self, connection: "AsyncConnection"):
        """Bind the protocol to the connection that owns the endpoint."""
        self.connection = connection
        self.logger = logging.getLogger("hbc.protocol")
        # Strong references to in-flight handler tasks. The event loop only
        # keeps weak references, so an unreferenced task may be garbage
        # collected before it finishes.
        self._tasks = set()

    def _spawn(self, coro):
        """Schedule *coro* as a task and keep it referenced until it is done."""
        task = asyncio.create_task(coro)
        self._tasks.add(task)
        task.add_done_callback(self._tasks.discard)
        return task

    def datagram_received(self, data: bytes, addr):
        """Handle incoming datagram.

        Decodes the payload and dispatches on its ``ID`` field: ``ACK`` is
        handled synchronously on the connection, ``CMD`` and ``UPD`` are
        handed to background tasks, anything else is logged as unknown.
        """
        try:
            msg = stodict(data)
            if not msg:
                self.logger.warning(f"Failed to parse message from {addr}")
                return

            now = time.time()
            msg_id = msg.get("ID")

            if msg_id == "ACK":
                self.connection.handle_ack(msg, now)
            elif msg_id == "CMD":
                # Command from server
                self._spawn(handle_command(self.connection, msg))
            elif msg_id == "UPD":
                # Update from server
                self._spawn(handle_update(self.connection, msg))
            else:
                self.logger.warning(f"Unknown message type: {msg_id}")

        except Exception as e:
            # A malformed datagram must never take the receiver down.
            self.logger.error(f"Error processing datagram: {e}", exc_info=True)

    def error_received(self, exc):
        """Handle protocol errors."""
        self.logger.error(f"Protocol error: {exc}")
|
|
|
|
|
|
async def handle_command(conn: "AsyncConnection", msg: dict):
    """Execute a command received from server.

    Runs the ``cmd`` string of *msg* through the shell with a 30 second
    timeout and reports ``"<status> <output>"`` back over *conn*. Messages
    without a ``cmd`` are ignored.

    Args:
        conn: Connection the response is sent on
        msg: Decoded CMD message
    """
    import functools
    import subprocess

    cmd = msg.get("cmd", "")
    if not cmd:
        return

    logger = logging.getLogger("hbc.command")
    logger.info(f"Executing command: {cmd}")

    # SECURITY NOTE(review): the command string arrives in a UDP datagram and
    # is run through the shell unchanged. Nothing in this file authenticates
    # the sender - confirm this is acceptable for the deployment.
    run = functools.partial(
        subprocess.check_output,
        cmd, shell=True, stderr=subprocess.STDOUT, timeout=30,
    )

    try:
        # check_output blocks for up to 30 s; run it in the default executor
        # so heartbeats and ACK handling keep flowing meanwhile.
        output = await asyncio.get_running_loop().run_in_executor(None, run)
        # Command output is arbitrary bytes; never fail on undecodable data.
        result = output.decode(errors="replace")
        status = "OK"
    except subprocess.CalledProcessError as e:
        result = str(e)
        status = "CalledProcessError"
    except subprocess.TimeoutExpired:
        result = "Command timed out"
        status = "Timeout"
    except Exception as e:
        result = str(e)
        status = "Error"

    # Send response
    response = {
        "service": "command",
        "msg": f"{status} {result}",
    }
    await conn.sendto(response)
|
|
|
|
|
|
async def handle_update(conn: AsyncConnection, msg: dict):
    """Handle self-update from server.

    Decodes the base64 ``code`` field, checks it against the MD5 hex digest
    in ``csum``, saves the current program (``sys.argv[0]``) as ``<name>.sav``
    and overwrites it with the new code. Every outcome is reported over
    *conn*; on success a restart is requested and the main loops are stopped.

    Args:
        conn: Connection the status reply is sent on
        msg: Decoded UPD message
    """
    import base64
    import shutil

    logger = logging.getLogger("hbc.update")

    # SECURITY NOTE(review): MD5 only guards against corruption in transit;
    # nothing in this file authenticates the sender of the update.
    try:
        # b64decode accepts both str and bytes payloads.
        raw = base64.b64decode(msg["code"])
        raw.decode()  # still refuse payloads that are not valid UTF-8 text
        csum = msg["csum"]
    except Exception as e:
        error = f"Missing code/csum: {e}"
        logger.error(error)
        await conn.sendto({"service": "update", "msg": error})
        return

    # Verify checksum
    if md5(raw).hexdigest() != csum:
        error = "Checksum mismatch"
        logger.error(error)
        await conn.sendto({"service": "update", "msg": error})
        return

    # Backup current file
    fn = sys.argv[0]
    ofn = f"{fn}.sav"
    try:
        shutil.copy2(fn, ofn)
    except Exception as e:
        error = f"Backup failed: {e}"
        logger.error(error)
        await conn.sendto({"service": "update", "msg": error})
        return

    # Write new code. Binary mode stores exactly the bytes that were
    # checksummed, independent of locale encoding or newline translation.
    try:
        with open(fn, "wb") as fh:
            fh.write(raw)
    except Exception as e:
        error = f"Write failed: {e}"
        logger.error(error)
        # The target may now be truncated; put the saved copy back.
        try:
            shutil.copy2(ofn, fn)
        except Exception as restore_error:
            logger.error(f"Restore from {ofn} failed: {restore_error}")
        await conn.sendto({"service": "update", "msg": error})
        return

    logger.info("Update successful, restart required")
    await conn.sendto({"service": "update", "msg": "OK"})

    # Trigger restart
    global dorestart
    dorestart = True
    stop()
|
|
|
|
|
|
async def heartbeat_sender(conn: AsyncConnection, interval: int):
    """Send periodic heartbeats.

    Loops until the module-level ``running`` flag is cleared. Send errors
    are logged and do not end the loop.

    Args:
        conn: Connection to send on
        interval: Heartbeat interval in seconds
    """
    logger = logging.getLogger("hbc.heartbeat")
    loop = asyncio.get_running_loop()

    while running:
        try:
            msg = {
                "acks": conn.ackcount,
                "rtt": conn.rtts[-1],
                "interval": interval,
            }
            await conn.sendto(msg, "HTB")

        except Exception as e:
            logger.error(f"Error sending heartbeat: {e}", exc_info=True)

        # Wait for next interval in slices of at most one second, so a stop
        # request is noticed promptly instead of after a full interval.
        # The sleep always runs at least once, which keeps the loop yielding
        # to the event loop even for a zero interval.
        deadline = loop.time() + interval
        while True:
            await asyncio.sleep(max(0.0, min(deadline - loop.time(), 1.0)))
            if not running or loop.time() >= deadline:
                break
|
|
|
|
|
|
async def plugin_collector(conn: AsyncConnection, registry: PluginRegistry):
    """Collect and send plugin data.

    Every InfoPlugin is collected once, right away. MonitorPlugins are then
    grouped by their ``interval`` and each group gets its own periodic task;
    this coroutine returns when all of those tasks have finished.

    Args:
        conn: Connection to send on
        registry: Plugin registry
    """
    log = logging.getLogger("hbc.plugins")

    # One-shot collection of the InfoPlugins
    for info in registry.get_by_type(InfoPlugin):
        try:
            payload = await info.collect()
            if not payload:
                continue
            # PLG message carries the plugin name next to its data
            await conn.sendto({"plugin": info.name, **payload}, "PLG")
            log.info(f"Sent {info.name} data")
        except Exception as exc:
            log.error(f"Error collecting {info.name}: {exc}", exc_info=True)

    # Bucket the MonitorPlugins by collection interval
    schedule = {}
    for monitor in registry.get_by_type(MonitorPlugin):
        schedule.setdefault(monitor.interval, []).append(monitor)

    # One periodic task per distinct interval
    workers = [
        asyncio.create_task(plugin_collector_interval(conn, group, period))
        for period, group in schedule.items()
    ]

    # Block until every periodic task has ended
    if workers:
        await asyncio.gather(*workers)
|
|
|
|
|
|
async def plugin_collector_interval(
    conn: AsyncConnection,
    plugins: List,
    interval: int
):
    """Collect plugins on a specific interval.

    Loops until the module-level ``running`` flag is cleared. A failing
    plugin is logged and does not prevent the others from being collected.

    Args:
        conn: Connection to send on
        plugins: List of plugins to collect
        interval: Collection interval in seconds
    """
    logger = logging.getLogger(f"hbc.plugins.{interval}s")
    loop = asyncio.get_running_loop()

    while running:
        for plugin in plugins:
            try:
                data = await plugin.collect()
                if data:
                    # Don't use encode_plugin_data - create dict directly
                    plugin_msg = {"plugin": plugin.name, **data}
                    await conn.sendto(plugin_msg, "PLG")
                    logger.debug(f"Sent {plugin.name} data")
            except Exception as e:
                logger.error(
                    f"Error collecting {plugin.name}: {e}",
                    exc_info=True
                )

        # Wait in slices of at most one second, so a stop request is noticed
        # promptly even when the collection interval is long. The sleep
        # always runs at least once, which keeps the loop yielding to the
        # event loop even for a zero interval.
        deadline = loop.time() + interval
        while True:
            await asyncio.sleep(max(0.0, min(deadline - loop.time(), 1.0)))
            if not running or loop.time() >= deadline:
                break
|
|
|
|
|
|
def shortname(name: str) -> str:
    """Return the part of *name* before its first dot (all of it if none)."""
    host, _, _ = name.partition(".")
    return host
|
|
|
|
|
|
def stop():
    """Request shutdown by clearing the module-level ``running`` flag.

    Nothing is cancelled here: the loops that poll ``running`` end on their
    next check of the flag.
    """
    global running
    running = False
|
|
|
|
|
|
async def cleanup(connections: List[AsyncConnection]):
    """Send a shutdown notice on every connection, then close it.

    A failed send is logged and the connection is closed regardless. The
    coroutine ends with a half-second pause.

    Args:
        connections: Connections to notify and close
    """
    log = logging.getLogger("hbc.cleanup")
    log.info("Cleaning up connections")

    for link in connections:
        try:
            await link.sendto({"shutdown": 1, "acks": link.ackcount})
        except Exception as exc:
            log.error(f"Error sending shutdown: {exc}")

        link.close()

    # Give messages time to send
    await asyncio.sleep(0.5)
|
|
|
|
|
|
async def async_main(args, config):
    """Async main function.

    Resolves the target hosts, opens one connection per resolved address,
    optionally sends a boot/service message, loads the plugins and then runs
    the heartbeat and plugin tasks until they end.

    Args:
        args: Parsed command-line arguments (see ``build_parser``)
        config: Configuration mapping; ``hb_port`` and ``interval`` are read

    Returns:
        0 on normal completion, 1 if no connection could be established
    """
    logger = logging.getLogger("hbc.main")
    loop = asyncio.get_running_loop()

    # Setup
    iam = socket.gethostname()
    if args.name:
        iam = args.name

    hb_hosts = args.hosts
    hb_port = config.get("hb_port", PORT)
    interval = config.get("interval", INTERVAL)

    logger.info(f"Starting hbc for {iam} -> {hb_hosts}")
    logger.info(f"Port: {hb_port}, Interval: {interval}s")

    # Create connections
    connections = []
    conn_id = 1

    for host in hb_hosts:
        try:
            # Resolve through the loop so a slow DNS lookup does not block it.
            addrs = await loop.getaddrinfo(
                host, hb_port,
                type=socket.SOCK_DGRAM, proto=socket.IPPROTO_UDP,
            )
        except socket.gaierror as e:
            logger.error(f"Cannot resolve {host}: {e}")
            continue

        for addr_info in addrs:
            af = addr_info[0]
            addr = addr_info[4][0]

            conn = AsyncConnection(conn_id, addr, hb_port, af, iam)
            if await conn.open():
                connections.append(conn)
                conn_id += 1

    if not connections:
        logger.error("No connections established")
        return 1

    logger.info(f"Created {len(connections)} connections")

    # Send boot/message if requested
    if args.boot or args.message:
        boot_msg = {}
        if args.boot:
            boot_msg["boot"] = 1
        if args.message:
            boot_msg["service"] = "service"
            boot_msg["msg"] = args.message

        boot_msg["acks"] = 0
        for conn in connections:
            await conn.sendto(boot_msg)

        if args.message and not args.daemon:
            # Message-only mode
            await cleanup(connections)
            return 0

    # Load plugins
    registry = PluginRegistry()
    loader = PluginLoader(registry)

    plugin_dir = Path(__file__).parent / "plugins"
    if plugin_dir.exists():
        count = await loader.load_from_directory(plugin_dir, config)
        logger.info(f"Loaded {count} plugins")
    else:
        logger.warning(f"Plugin directory not found: {plugin_dir}")

    # Start async tasks
    tasks = []

    # Heartbeat senders (one per connection)
    for conn in connections:
        tasks.append(asyncio.create_task(heartbeat_sender(conn, interval)))

    # Plugin collector (uses all connections, but we'll use first one)
    if connections and registry.get_enabled():
        tasks.append(
            asyncio.create_task(plugin_collector(connections[0], registry))
        )

    # Setup signal handlers
    for sig in (signal.SIGTERM, signal.SIGINT):
        loop.add_signal_handler(sig, stop)

    # Wait for stop or tasks to complete
    try:
        await asyncio.gather(*tasks)
    except asyncio.CancelledError:
        pass
    finally:
        # Runs on failure too: stop whatever is still going (a no-op for
        # finished tasks), then always release connections and plugins.
        for task in tasks:
            task.cancel()
        await asyncio.gather(*tasks, return_exceptions=True)

        # Cleanup
        await cleanup(connections)
        await loader.unload_all()

    return 0
|
|
|
|
|
|
def daemonize(
    working_dir="/",
    stdin="/dev/zero",
    stdout="/dev/null",
    stderr="/dev/null"
):
    """Detach the process with two forks and redirect the standard streams.

    Args:
        working_dir: Directory to change into after the first fork
        stdin: Path opened for reading and duplicated onto standard input
        stdout: Path opened in append mode and duplicated onto standard output
        stderr: Path opened in append mode and duplicated onto standard error
    """
    # First fork: the parent leaves immediately
    try:
        if os.fork() > 0:
            os._exit(0)
    except OSError as err:
        sys.stderr.write(f"fork #1 failed: {err}\n")
        os._exit(1)

    os.chdir(working_dir)
    os.setsid()
    os.umask(0)

    # Second fork: again only the child carries on
    try:
        if os.fork() > 0:
            os._exit(0)
    except OSError as err:
        sys.stderr.write(f"fork #2 failed: {err}\n")
        sys.exit(1)

    # Push out anything still buffered before the descriptors are replaced
    for stream in (sys.stdout, sys.stderr):
        stream.flush()

    new_in = open(stdin, "r")
    new_out = open(stdout, "a+")
    new_err = open(stderr, "a+")

    redirections = (
        (new_in, sys.stdin),
        (new_out, sys.stdout),
        (new_err, sys.stderr),
    )
    for source, target in redirections:
        os.dup2(source.fileno(), target.fileno())
|
|
|
|
|
|
def build_parser():
    """Create the ``hbc`` command-line parser.

    Returns:
        An ``argparse.ArgumentParser`` with the boot, config, message, name,
        daemon, verbose and debug options and the positional ``hosts`` list
    """
    cli = argparse.ArgumentParser(
        prog="hbc",
        description="HeartBeatClient - send heartbeat messages to HeartBeatDaemon",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )

    # (names, keyword settings) in the order they appear in the help output
    argument_specs = (
        (("-b", "--boot"),
         {"action": "store_true", "help": "Send a boot message"}),
        (("-c", "--config"),
         {"dest": "configfile", "help": "Config file path (YAML)"}),
        (("-m", "--message"),
         {"dest": "message", "help": "Send a message"}),
        (("-n", "--name"),
         {"dest": "name", "help": "Name to use in heartbeat message"}),
        (("-d", "--daemon"),
         {"action": "store_true", "help": "Run in daemon mode"}),
        (("-v", "--verbose"),
         {"action": "store_true", "help": "Verbose output"}),
        (("-x", "--debug"),
         {"action": "count", "default": 0, "help": "Increase debug level"}),
        (("hosts",),
         {"nargs": "+", "help": "Heartbeat daemon hosts to send to"}),
    )
    for names, settings in argument_specs:
        cli.add_argument(*names, **settings)

    return cli
|
|
|
|
|
|
def _install_syslog_logging(level):
    """Replace the root logger's handlers with one that writes to syslog."""
    import syslog

    class _SyslogHandler(logging.Handler):
        """Forward log records to the ``syslog`` module."""

        def emit(self, record):
            try:
                if record.levelno >= logging.CRITICAL:
                    priority = syslog.LOG_CRIT
                elif record.levelno >= logging.ERROR:
                    priority = syslog.LOG_ERR
                elif record.levelno >= logging.WARNING:
                    priority = syslog.LOG_WARNING
                elif record.levelno >= logging.INFO:
                    priority = syslog.LOG_INFO
                else:
                    priority = syslog.LOG_DEBUG
                syslog.syslog(priority, self.format(record))
            except Exception:
                self.handleError(record)

    handler = _SyslogHandler()
    # openlog("hbc", LOG_PID, ...) already prefixes each entry with the ident
    # and pid, so the record format carries only logger name, level and text.
    handler.setFormatter(
        logging.Formatter("%(name)s %(levelname)s: %(message)s")
    )

    root = logging.getLogger()
    for old in list(root.handlers):
        root.removeHandler(old)
    root.addHandler(handler)
    root.setLevel(level)


def main(argv=None):
    """Main entry point.

    Parses the command line, loads the configuration, sets up logging,
    optionally daemonizes, runs ``async_main`` and finally either re-executes
    the program (after a self-update) or exits with the resulting code.

    Args:
        argv: Argument list to parse; ``None`` makes argparse use ``sys.argv``
    """
    parser = build_parser()
    args = parser.parse_args(argv)

    # Load config
    config = load_config(args.configfile)

    # Setup logging
    log_level = logging.DEBUG if (args.verbose or args.debug) else logging.INFO

    logging.basicConfig(
        level=log_level,
        format="%(asctime)s %(name)s %(levelname)s: %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S"
    )

    # Daemonize if requested
    if args.daemon:
        print("Daemonizing...")
        import syslog
        syslog.openlog("hbc", syslog.LOG_PID, syslog.LOG_DAEMON)
        syslog.syslog(syslog.LOG_INFO, f"Starting heartbeat to {', '.join(args.hosts)}")
        daemonize()

        # Reconfigure logging for syslog. A second basicConfig() call would
        # be a no-op because the root logger already has a handler, and after
        # daemonize() stderr points at /dev/null, so swap the handler instead.
        _install_syslog_logging(log_level)

    # Run async main
    try:
        exit_code = asyncio.run(async_main(args, config))
    except KeyboardInterrupt:
        logging.info("Interrupted by user")
        exit_code = 0
    except Exception as e:
        logging.error(f"Fatal error: {e}", exc_info=True)
        exit_code = 1

    # Handle restart
    if dorestart:
        logging.info("Restarting...")
        try:
            os.execv(sys.argv[0], sys.argv)
        except OSError as e:
            # execv only returns on failure; report it instead of a traceback.
            logging.error(f"Restart failed: {e}")
            sys.exit(1)

    sys.exit(exit_code)
|
|
|
|
|
|
# Run the client only when this module is executed, not when it is imported.
if __name__ == "__main__":
    main()
|