#!/usr/bin/env python3
"""
HeartBeat Client (hbc) - Async version with plugin support.

Sends heartbeat messages to HeartBeat Daemon (hbd) servers and collects
system information via plugins.
"""

import argparse
import asyncio
import codecs
import logging
import os
import shutil
import signal
import socket
import subprocess
import sys
import time
from collections import defaultdict
from hashlib import md5
from logging.handlers import SysLogHandler
from pathlib import Path
from typing import Dict, List, Optional, Set

# Import protocol and config
from .config import load_config
from ..common.proto import dicttos, stodict

# Import plugin system
from .plugin import PluginRegistry, PluginLoader, InfoPlugin, MonitorPlugin

# Constants
PORT = 50003
INTERVAL = 10
MAXRECV = 32767

# Global state
running = True
dorestart = False
shutdown_event: Optional[asyncio.Event] = None
active_tasks: List[asyncio.Task] = []

# Strong references to fire-and-forget tasks spawned from datagram_received.
# The event loop only keeps weak references to tasks, so without this set a
# CMD/UPD handler could be garbage-collected before it finishes.
_background_tasks: Set[asyncio.Task] = set()


def _background_done(task: asyncio.Task) -> None:
    """Drop a finished background task and log any exception it raised."""
    _background_tasks.discard(task)
    if task.cancelled():
        return
    exc = task.exception()
    if exc is not None:
        logging.getLogger("hbc.background").error(
            "Background task failed: %s", exc, exc_info=exc
        )


def _spawn_background(coro) -> asyncio.Task:
    """Schedule *coro* as a task and keep a reference until it completes."""
    task = asyncio.create_task(coro)
    _background_tasks.add(task)
    task.add_done_callback(_background_done)
    return task


class AsyncConnection:
    """Async UDP connection to a heartbeat server."""

    def __init__(self, conn_id: int, addr: str, port: int, af: int, name: str):
        self.conn_id = conn_id
        self.addr = addr
        self.port = port
        self.af = af
        self.name = name

        self.ackcount = 0
        self.lastack = 0.0
        self.send_count = 0
        self.lastsend = 0.0
        # Most recent round-trip times in ms (at most 10 are kept).
        self.rtts = [0.0]

        self.transport: Optional[asyncio.DatagramTransport] = None
        self.protocol: Optional[asyncio.DatagramProtocol] = None
        self.logger = logging.getLogger(f"hbc.conn.{addr}")

    async def open(self) -> bool:
        """Open the UDP connection.

        Returns:
            True if successful, False otherwise
        """
        try:
            loop = asyncio.get_running_loop()

            # Create datagram endpoint
            self.transport, self.protocol = await loop.create_datagram_endpoint(
                lambda: HeartbeatProtocol(self), family=self.af
            )
            self.logger.debug(f"Opened connection to {self.addr}:{self.port}")
            return True
        except Exception as e:
            self.logger.error(f"Failed to open connection: {e}")
            return False

    def close(self):
        """Close the connection."""
        if self.transport:
            self.transport.close()
            self.transport = None
            self.protocol = None

    async def sendto(self, msg: dict, msg_id: str = "HTB"):
        """Send a message to the server.

        The transport is (re)opened on demand.  The caller's dictionary is
        not modified; the standard fields are added to a copy.

        Args:
            msg: Message dictionary
            msg_id: Message ID (HTB, PLG, etc.)
        """
        if not self.transport:
            await self.open()
        if not self.transport:
            self.logger.error("Cannot send - no transport")
            return

        # Add standard fields (on a copy, so a dict shared between several
        # connections is never mutated behind the caller's back)
        payload = {
            **msg,
            "name": shortname(self.name),
            "id": self.conn_id,
            "time": time.time(),
        }

        # Encode message
        data = dicttos(msg_id, payload)

        # Send
        self.transport.sendto(data, (self.addr, self.port))
        self.send_count += 1
        # NOTE(review): lastsend is updated for every message type, so the
        # RTT computed in handle_ack is relative to the last message of any
        # kind sent on this connection - confirm this is intended.
        self.lastsend = time.time()

        self.logger.debug(f"Sent {msg_id} message ({len(data)} bytes)")

    def handle_ack(self, msg: dict, now: float):
        """Handle ACK message from server.

        RTT is calculated as: (time ACK received) - (time of last send)

        Args:
            msg: Decoded ACK message (currently unused)
            now: Timestamp at which the ACK was received
        """
        self.lastack = now

        # Calculate RTT: time ACK received minus time of last send
        rtt = (now - self.lastsend) * 1000.0  # Convert to ms
        self.rtts.append(rtt)
        if len(self.rtts) > 10:
            self.rtts.pop(0)

        self.ackcount += 1
        self.logger.debug(f"ACK received, RTT: {rtt:.1f}ms")


class HeartbeatProtocol(asyncio.DatagramProtocol):
    """Protocol handler for incoming UDP messages."""

    def __init__(self, connection: AsyncConnection):
        self.connection = connection
        self.logger = logging.getLogger("hbc.protocol")

    def datagram_received(self, data: bytes, addr):
        """Handle incoming datagram.

        ACKs are processed inline; CMD and UPD messages are dispatched to
        background tasks so this callback returns immediately.
        """
        try:
            msg = stodict(data)
            if not msg:
                self.logger.warning(f"Failed to parse message from {addr}")
                return

            now = time.time()
            msg_id = msg.get("ID")

            if msg_id == "ACK":
                self.connection.handle_ack(msg, now)
            elif msg_id == "CMD":
                # Command from server
                _spawn_background(handle_command(self.connection, msg))
            elif msg_id == "UPD":
                # Update from server
                _spawn_background(handle_update(self.connection, msg))
            else:
                self.logger.warning(f"Unknown message type: {msg_id}")
        except Exception as e:
            self.logger.error(f"Error processing datagram: {e}", exc_info=True)

    def error_received(self, exc):
        """Handle protocol errors."""
        self.logger.error(f"Protocol error: {exc}")


def _run_shell_command(cmd: str, timeout: float = 30) -> bytes:
    """Run *cmd* through the shell and return its combined stdout/stderr."""
    # SECURITY: the command string arrives in a datagram from the network
    # and is executed via the shell.  Nothing in this module authenticates
    # the sender.
    return subprocess.check_output(
        cmd,
        shell=True,
        stderr=subprocess.STDOUT,
        timeout=timeout,
    )


async def handle_command(conn: AsyncConnection, msg: dict):
    """Execute a command received from server.

    The command runs in the default executor so the event loop (and with it
    the heartbeat senders) is not blocked for up to the 30 s timeout.  The
    outcome is reported back on *conn* as a "command" service message.

    Args:
        conn: Connection the response is sent on
        msg: Decoded CMD message; the command line is read from key "cmd"
    """
    cmd = msg.get("cmd", "")
    if not cmd:
        return

    logger = logging.getLogger("hbc.command")
    logger.info(f"Executing command: {cmd}")

    loop = asyncio.get_running_loop()
    try:
        output = await loop.run_in_executor(None, _run_shell_command, cmd)
        # Command output is arbitrary bytes; never fail on bad UTF-8.
        result = output.decode(errors="replace")
        status = "OK"
    except subprocess.CalledProcessError as e:
        result = str(e)
        status = "CalledProcessError"
    except subprocess.TimeoutExpired:
        result = "Command timed out"
        status = "Timeout"
    except Exception as e:
        result = str(e)
        status = "Error"

    # Send response
    response = {
        "service": "command",
        "msg": f"{status} {result}"
    }
    await conn.sendto(response)


async def handle_update(conn: AsyncConnection, msg: dict):
    """Handle self-update from server.

    Decodes the base64 payload in ``msg["code"]``, compares its MD5 digest
    with ``msg["csum"]``, backs up ``sys.argv[0]`` to ``<file>.sav``,
    overwrites it with the new code and requests a restart.  Every failure
    is logged and reported back on *conn*.

    Args:
        conn: Connection the status report is sent on
        msg: Decoded UPD message with keys "code" and "csum"
    """
    logger = logging.getLogger("hbc.update")

    try:
        encoded = msg["code"]
        if isinstance(encoded, str):
            # The base64 codec only accepts bytes-like input.
            encoded = encoded.encode("ascii")
        code = codecs.decode(encoded, "base64").decode()
        csum = msg["csum"]
    except Exception as e:
        error = f"Missing code/csum: {e}"
        logger.error(error)
        await conn.sendto({"service": "update", "msg": error})
        return

    # Verify checksum
    # SECURITY: MD5 only guards against transmission corruption here; it does
    # not authenticate the sender of the update.
    if md5(code.encode()).hexdigest() != csum:
        error = "Checksum mismatch"
        logger.error(error)
        await conn.sendto({"service": "update", "msg": error})
        return

    # Backup current file
    fn = sys.argv[0]
    ofn = f"{fn}.sav"
    try:
        shutil.copy2(fn, ofn)
    except Exception as e:
        error = f"Backup failed: {e}"
        logger.error(error)
        await conn.sendto({"service": "update", "msg": error})
        return

    # Write new code (the payload was decoded as UTF-8, so write it as UTF-8)
    try:
        with open(fn, "w", encoding="utf-8") as fh:
            fh.write(code)
    except Exception as e:
        error = f"Write failed: {e}"
        logger.error(error)
        await conn.sendto({"service": "update", "msg": error})
        return

    logger.info("Update successful, restart required")
    await conn.sendto({"service": "update", "msg": "OK"})

    # Trigger restart
    global dorestart
    dorestart = True
    stop()


async def _wait_for_next_cycle(interval: float) -> bool:
    """Sleep for *interval* seconds or until shutdown is signalled.

    Returns:
        True if the shutdown event was set while waiting, False if the
        interval elapsed normally.
    """
    if shutdown_event is None:
        await asyncio.sleep(interval)
        return False
    try:
        await asyncio.wait_for(shutdown_event.wait(), timeout=interval)
    except asyncio.TimeoutError:
        return False  # Normal timeout, caller continues its loop
    return True


async def heartbeat_sender(conn: AsyncConnection, interval: int):
    """Send periodic heartbeats.

    Runs until the global ``running`` flag is cleared, the shutdown event is
    set, or the task is cancelled.

    Args:
        conn: Connection to send on
        interval: Heartbeat interval in seconds
    """
    logger = logging.getLogger("hbc.heartbeat")

    while running:
        try:
            msg = {
                "acks": conn.ackcount,
                "rtt": conn.rtts[-1],
                "interval": interval
            }
            await conn.sendto(msg, "HTB")
        except asyncio.CancelledError:
            logger.debug("Heartbeat sender cancelled")
            raise
        except Exception as e:
            logger.error(f"Error sending heartbeat: {e}", exc_info=True)

        # Wait for next interval or shutdown event
        try:
            if await _wait_for_next_cycle(interval):
                break
        except asyncio.CancelledError:
            logger.debug("Heartbeat sender cancelled during sleep")
            raise


async def plugin_collector(conn: AsyncConnection, registry: PluginRegistry):
    """Collect and send plugin data.

    InfoPlugins are collected once; MonitorPlugins are grouped by their
    ``interval`` attribute and each group is polled by its own task.

    Args:
        conn: Connection to send on
        registry: Plugin registry
    """
    logger = logging.getLogger("hbc.plugins")

    # Collect InfoPlugins once at startup
    info_plugins = registry.get_by_type(InfoPlugin)
    for plugin in info_plugins:
        try:
            data = await plugin.collect()
            if data:
                # Create PLG message with plugin name
                plugin_msg = {"plugin": plugin.name, **data}
                await conn.sendto(plugin_msg, "PLG")
                logger.info(f"Sent {plugin.name} data")
        except Exception as e:
            logger.error(f"Error collecting {plugin.name}: {e}", exc_info=True)

    # Schedule MonitorPlugins
    # Group plugins by interval
    by_interval = defaultdict(list)
    monitor_plugins = registry.get_by_type(MonitorPlugin)
    for plugin in monitor_plugins:
        by_interval[plugin.interval].append(plugin)

    # Create tasks for each interval
    tasks = [
        asyncio.create_task(plugin_collector_interval(conn, plugins, interval))
        for interval, plugins in by_interval.items()
    ]

    # Wait for all tasks
    if tasks:
        try:
            await asyncio.gather(*tasks, return_exceptions=True)
        except asyncio.CancelledError:
            logger.debug("Plugin collector cancelled, cancelling sub-tasks")
            for task in tasks:
                if not task.done():
                    task.cancel()
            raise


async def plugin_collector_interval(
    conn: AsyncConnection,
    plugins: List,
    interval: int
):
    """Collect plugins on a specific interval.

    Args:
        conn: Connection to send on
        plugins: List of plugins to collect
        interval: Collection interval in seconds
    """
    logger = logging.getLogger(f"hbc.plugins.{interval}s")

    while running:
        for plugin in plugins:
            try:
                data = await plugin.collect()
                if data:
                    # Don't use encode_plugin_data - create dict directly
                    plugin_msg = {"plugin": plugin.name, **data}
                    await conn.sendto(plugin_msg, "PLG")
                    logger.debug(f"Sent {plugin.name} data")
            except asyncio.CancelledError:
                logger.debug("Plugin collector cancelled")
                raise
            except Exception as e:
                logger.error(
                    f"Error collecting {plugin.name}: {e}",
                    exc_info=True
                )

        # Wait for next interval or shutdown event
        try:
            if await _wait_for_next_cycle(interval):
                break
        except asyncio.CancelledError:
            logger.debug("Plugin collector cancelled during sleep")
            raise


def shortname(name: str) -> str:
    """Extract short hostname."""
    return name.split(".")[0]


def stop():
    """Request shutdown: clear ``running``, wake sleepers, cancel tasks."""
    global running
    running = False

    # Set shutdown event to wake up sleeping tasks
    if shutdown_event:
        shutdown_event.set()

    # Cancel all active tasks
    for task in active_tasks:
        if not task.done():
            task.cancel()


async def cleanup(connections: List[AsyncConnection]):
    """Send a shutdown message on every connection, then close it."""
    logger = logging.getLogger("hbc.cleanup")
    logger.info("Cleaning up connections")

    for conn in connections:
        try:
            msg = {
                "shutdown": 1,
                "acks": conn.ackcount
            }
            await conn.sendto(msg)
        except Exception as e:
            logger.error(f"Error sending shutdown: {e}")
        conn.close()

    # Give messages time to send
    await asyncio.sleep(0.5)


async def async_main(args, config):
    """Async main function.

    Args:
        args: Parsed command-line arguments (see build_parser)
        config: Configuration mapping; "hb_port" and "interval" are read

    Returns:
        Process exit code: 0 on normal shutdown, 1 if no connection could
        be established.
    """
    global shutdown_event, active_tasks

    # Create shutdown event
    shutdown_event = asyncio.Event()
    active_tasks = []

    logger = logging.getLogger("hbc.main")

    # Setup
    iam = socket.gethostname()
    if args.name:
        iam = args.name

    hb_hosts = args.hosts
    hb_port = config.get("hb_port", PORT)
    interval = config.get("interval", INTERVAL)

    logger.info(f"Starting hbc for {iam} -> {hb_hosts}")
    logger.info(f"Port: {hb_port}, Interval: {interval}s")

    # Create connections (one per resolved address of every host)
    connections = []
    conn_id = 1

    for host in hb_hosts:
        try:
            addrs = socket.getaddrinfo(host, hb_port, 0, 0, socket.SOL_UDP)
        except socket.gaierror as e:
            logger.error(f"Cannot resolve {host}: {e}")
            continue

        for addr_info in addrs:
            af = addr_info[0]
            addr = addr_info[4][0]

            conn = AsyncConnection(conn_id, addr, hb_port, af, iam)
            if await conn.open():
                connections.append(conn)
                conn_id += 1

    if not connections:
        logger.error("No connections established")
        return 1

    logger.info(f"Created {len(connections)} connections")

    # Send boot/message if requested
    if args.boot or args.message:
        boot_msg = {}
        if args.boot:
            boot_msg["boot"] = 1
        if args.message:
            boot_msg["service"] = "service"
            boot_msg["msg"] = args.message
        boot_msg["acks"] = 0

        for conn in connections:
            await conn.sendto(boot_msg)

        if args.message and not args.daemon:
            # Message-only mode
            await cleanup(connections)
            return 0

    # Load plugins
    registry = PluginRegistry()
    loader = PluginLoader(registry)

    plugin_dir = Path(__file__).parent / "plugins"
    if plugin_dir.exists():
        count = await loader.load_from_directory(plugin_dir, config)
        logger.info(f"Loaded {count} plugins")
    else:
        logger.warning(f"Plugin directory not found: {plugin_dir}")

    # Setup signal handlers
    loop = asyncio.get_running_loop()
    for sig in (signal.SIGTERM, signal.SIGINT):
        loop.add_signal_handler(sig, stop)

    # Start async tasks
    # Heartbeat senders (one per connection)
    for conn in connections:
        task = asyncio.create_task(heartbeat_sender(conn, interval))
        active_tasks.append(task)

    # Plugin collector (uses all connections, but we'll use first one)
    if connections and registry.get_enabled():
        task = asyncio.create_task(plugin_collector(connections[0], registry))
        active_tasks.append(task)

    # Wait for stop or tasks to complete
    try:
        await asyncio.gather(*active_tasks, return_exceptions=True)
    except asyncio.CancelledError:
        logger.info("Tasks cancelled")

    # Cleanup
    logger.info("Shutting down...")
    await cleanup(connections)
    await loader.unload_all()

    return 0


def daemonize(
    working_dir="/",
    stdin="/dev/zero",
    stdout="/dev/null",
    stderr="/dev/null"
):
    """UNIX double-fork daemonization.

    Args:
        working_dir: Directory to chdir into after the first fork
        stdin: Path opened and duplicated onto file descriptor 0
        stdout: Path opened (append) and duplicated onto file descriptor 1
        stderr: Path opened (append) and duplicated onto file descriptor 2
    """
    try:
        pid = os.fork()
        if pid > 0:
            os._exit(0)
    except OSError as e:
        sys.stderr.write(f"fork #1 failed: {e}\n")
        os._exit(1)

    os.chdir(working_dir)
    os.setsid()
    os.umask(0)

    try:
        pid = os.fork()
        if pid > 0:
            os._exit(0)
    except OSError as e:
        sys.stderr.write(f"fork #2 failed: {e}\n")
        sys.exit(1)

    sys.stdout.flush()
    sys.stderr.flush()

    si = open(stdin, "r")
    so = open(stdout, "a+")
    se = open(stderr, "a+")

    os.dup2(si.fileno(), sys.stdin.fileno())
    os.dup2(so.fileno(), sys.stdout.fileno())
    os.dup2(se.fileno(), sys.stderr.fileno())


def _reconfigure_logging_for_daemon(log_level: int) -> None:
    """Replace StreamHandlers (now writing to /dev/null) with a SysLogHandler.

    Logs go to the local ``/dev/log`` socket when it exists; otherwise to
    syslog over UDP on localhost:514, and a warning about the fallback is
    emitted through the new handler.

    Args:
        log_level: Level applied to the root logger
    """
    root = logging.getLogger()
    for handler in root.handlers[:]:
        root.removeHandler(handler)
        handler.close()

    syslog_handler = None
    # SysLogHandler may ignore connection errors for a UNIX socket during
    # construction, so check for the socket explicitly instead of relying
    # only on OSError.
    if os.path.exists("/dev/log"):
        try:
            syslog_handler = SysLogHandler(
                address="/dev/log",
                facility=SysLogHandler.LOG_DAEMON,
            )
        except OSError:
            syslog_handler = None

    used_fallback = syslog_handler is None
    if used_fallback:
        syslog_handler = SysLogHandler(
            address=("localhost", 514),
            facility=SysLogHandler.LOG_DAEMON,
        )

    syslog_handler.setFormatter(
        logging.Formatter("hbc[%(process)d]: %(name)s %(levelname)s: %(message)s")
    )
    root.addHandler(syslog_handler)
    root.setLevel(log_level)

    if used_fallback:
        # Emitted after the handler is attached so the warning reaches syslog
        logging.warning("/dev/log not found, using syslog UDP localhost:514")


def build_parser():
    """Build argument parser."""
    parser = argparse.ArgumentParser(
        prog="hbc",
        description="HeartBeatClient - send heartbeat messages to HeartBeatDaemon",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )

    parser.add_argument(
        "-b", "--boot",
        action="store_true",
        help="Send a boot message"
    )
    parser.add_argument(
        "-c", "--config",
        dest="configfile",
        help="Config file path (YAML)"
    )
    parser.add_argument(
        "-m", "--message",
        dest="message",
        help="Send a message"
    )
    parser.add_argument(
        "-n", "--name",
        dest="name",
        help="Name to use in heartbeat message"
    )
    parser.add_argument(
        "-d", "--daemon",
        action="store_true",
        help="Run in daemon mode"
    )
    parser.add_argument(
        "-v", "--verbose",
        action="store_true",
        help="Verbose output"
    )
    parser.add_argument(
        "-x", "--debug",
        action="count",
        default=0,
        help="Increase debug level"
    )
    parser.add_argument(
        "hosts",
        nargs="+",
        help="Heartbeat daemon hosts to send to"
    )

    return parser


def main(argv=None):
    """Main entry point.

    Parses arguments, configures logging, optionally daemonizes, runs
    async_main and exits with its return code.  If an update requested a
    restart, the process re-executes ``sys.argv[0]`` instead of exiting.

    Args:
        argv: Argument list to parse; None lets argparse use sys.argv
    """
    parser = build_parser()
    args = parser.parse_args(argv)

    # Setup logging
    log_level = logging.WARNING
    if args.verbose:
        log_level = logging.INFO
    if args.debug:
        log_level = logging.DEBUG

    logging.basicConfig(
        level=log_level,
        format="%(asctime)s %(name)s %(levelname)s: %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S"
    )

    # Load config
    config = load_config(args.configfile)

    # Daemonize if requested
    if args.daemon:
        print("Daemonizing...")
        daemonize()
        _reconfigure_logging_for_daemon(log_level)
        logging.info(f"hbc starting, sending heartbeat to {', '.join(args.hosts)}")

    # Run async main
    try:
        exit_code = asyncio.run(async_main(args, config))
    except KeyboardInterrupt:
        logging.info("Interrupted by user")
        exit_code = 0
    except Exception as e:
        logging.error(f"Fatal error: {e}", exc_info=True)
        exit_code = 1

    # Handle restart
    if dorestart:
        logging.info("Restarting...")
        os.execv(sys.argv[0], sys.argv)

    sys.exit(exit_code)


if __name__ == "__main__":
    main()