Major refactoring of the codebase, including restructuring of files and directories, renaming of modules and classes, and improvements to the overall organization and readability of the code. This refactoring aims to enhance maintainability, scalability, and clarity of the codebase while preserving existing functionality. The changes include:
- Restructuring of the project directory into client and server components - Renaming of modules and classes to better reflect their purpose and functionality - Moving common utilities and configurations to a shared location - Updating import statements to reflect the new structure - Adding new documentation files for better clarity on various aspects of the project - Removing deprecated or unused code to streamline the codebase - Ensuring that all existing functionality is preserved and that the codebase remains functional after the refactoring.
This commit is contained in:
@@ -0,0 +1,643 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
HeartBeat Client (hbc) - Async version with plugin support.
|
||||
|
||||
Sends heartbeat messages to HeartBeat Daemon (hbd) servers and collects
|
||||
system information via plugins.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import signal
|
||||
import socket
|
||||
import sys
|
||||
import time
|
||||
from hashlib import md5
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
# Import protocol and config
|
||||
from .config import load_config
|
||||
from ..common.proto import dicttos, stodict
|
||||
|
||||
# Import plugin system
|
||||
from .plugin import PluginRegistry, PluginLoader, InfoPlugin, MonitorPlugin
|
||||
|
||||
# Constants
|
||||
PORT = 50003
|
||||
INTERVAL = 10
|
||||
VER = 6
|
||||
MAXRECV = 32767
|
||||
|
||||
# Global state
|
||||
running = True
|
||||
dorestart = False
|
||||
|
||||
|
||||
class AsyncConnection:
|
||||
"""Async UDP connection to a heartbeat server."""
|
||||
|
||||
def __init__(self, conn_id: int, addr: str, port: int, af: int, name: str):
|
||||
self.conn_id = conn_id
|
||||
self.addr = addr
|
||||
self.port = port
|
||||
self.af = af
|
||||
self.name = name
|
||||
|
||||
self.ackcount = 0
|
||||
self.lastack = 0.0
|
||||
self.send_count = 0
|
||||
self.lastsend = 0.0
|
||||
self.rtts = [0.0]
|
||||
|
||||
self.transport: Optional[asyncio.DatagramTransport] = None
|
||||
self.protocol: Optional[asyncio.DatagramProtocol] = None
|
||||
|
||||
self.logger = logging.getLogger(f"hbc.conn.{addr}")
|
||||
|
||||
async def open(self) -> bool:
|
||||
"""Open the UDP connection.
|
||||
|
||||
Returns:
|
||||
True if successful, False otherwise
|
||||
"""
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
# Create datagram endpoint
|
||||
self.transport, self.protocol = await loop.create_datagram_endpoint(
|
||||
lambda: HeartbeatProtocol(self),
|
||||
family=self.af
|
||||
)
|
||||
self.logger.debug(f"Opened connection to {self.addr}:{self.port}")
|
||||
return True
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to open connection: {e}")
|
||||
return False
|
||||
|
||||
def close(self):
|
||||
"""Close the connection."""
|
||||
if self.transport:
|
||||
self.transport.close()
|
||||
self.transport = None
|
||||
self.protocol = None
|
||||
|
||||
async def sendto(self, msg: dict, msg_id: str = "HTB"):
|
||||
"""Send a message to the server.
|
||||
|
||||
Args:
|
||||
msg: Message dictionary
|
||||
msg_id: Message ID (HTB, PLG, etc.)
|
||||
"""
|
||||
if not self.transport:
|
||||
await self.open()
|
||||
|
||||
if not self.transport:
|
||||
self.logger.error("Cannot send - no transport")
|
||||
return
|
||||
|
||||
# Add standard fields
|
||||
msg["name"] = shortname(self.name)
|
||||
msg["id"] = self.conn_id
|
||||
msg["ver"] = VER
|
||||
msg["time"] = time.time()
|
||||
|
||||
# Encode message
|
||||
data = dicttos(msg_id, msg, compress=True)
|
||||
|
||||
# Send
|
||||
self.transport.sendto(data, (self.addr, self.port))
|
||||
self.send_count += 1
|
||||
self.lastsend = time.time()
|
||||
|
||||
self.logger.debug(f"Sent {msg_id} message ({len(data)} bytes)")
|
||||
|
||||
def handle_ack(self, msg: dict, now: float):
|
||||
"""Handle ACK message from server."""
|
||||
try:
|
||||
self.lastack = msg.get("time", now)
|
||||
rtt = (self.lastack - self.lastsend) * 2000.0 # Convert to ms
|
||||
except Exception:
|
||||
self.lastack = now
|
||||
rtt = (self.lastack - self.lastsend) * 1000.0
|
||||
|
||||
self.rtts.append(rtt)
|
||||
if len(self.rtts) > 10:
|
||||
self.rtts.pop(0)
|
||||
|
||||
self.ackcount += 1
|
||||
self.logger.debug(f"ACK received, RTT: {rtt:.1f}ms")
|
||||
|
||||
|
||||
class HeartbeatProtocol(asyncio.DatagramProtocol):
|
||||
"""Protocol handler for incoming UDP messages."""
|
||||
|
||||
def __init__(self, connection: AsyncConnection):
|
||||
self.connection = connection
|
||||
self.logger = logging.getLogger("hbc.protocol")
|
||||
|
||||
def datagram_received(self, data: bytes, addr):
|
||||
"""Handle incoming datagram."""
|
||||
try:
|
||||
msg = stodict(data)
|
||||
if not msg:
|
||||
self.logger.warning(f"Failed to parse message from {addr}")
|
||||
return
|
||||
|
||||
now = time.time()
|
||||
msg_id = msg.get("ID")
|
||||
|
||||
if msg_id == "ACK":
|
||||
self.connection.handle_ack(msg, now)
|
||||
elif msg_id == "CMD":
|
||||
# Command from server
|
||||
asyncio.create_task(handle_command(self.connection, msg))
|
||||
elif msg_id == "UPD":
|
||||
# Update from server
|
||||
asyncio.create_task(handle_update(self.connection, msg))
|
||||
else:
|
||||
self.logger.warning(f"Unknown message type: {msg_id}")
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error processing datagram: {e}", exc_info=True)
|
||||
|
||||
def error_received(self, exc):
|
||||
"""Handle protocol errors."""
|
||||
self.logger.error(f"Protocol error: {exc}")
|
||||
|
||||
|
||||
async def handle_command(conn: AsyncConnection, msg: dict):
|
||||
"""Execute a command received from server."""
|
||||
import subprocess
|
||||
|
||||
cmd = msg.get("cmd", "")
|
||||
if not cmd:
|
||||
return
|
||||
|
||||
logger = logging.getLogger("hbc.command")
|
||||
logger.info(f"Executing command: {cmd}")
|
||||
|
||||
try:
|
||||
result = subprocess.check_output(
|
||||
cmd, shell=True, stderr=subprocess.STDOUT, timeout=30
|
||||
).decode()
|
||||
status = "OK"
|
||||
except subprocess.CalledProcessError as e:
|
||||
result = str(e)
|
||||
status = "CalledProcessError"
|
||||
except subprocess.TimeoutExpired:
|
||||
result = "Command timed out"
|
||||
status = "Timeout"
|
||||
except Exception as e:
|
||||
result = str(e)
|
||||
status = "Error"
|
||||
|
||||
# Send response
|
||||
response = {
|
||||
"service": "command",
|
||||
"msg": f"{status} {result}"
|
||||
}
|
||||
await conn.sendto(response)
|
||||
|
||||
|
||||
async def handle_update(conn: AsyncConnection, msg: dict):
|
||||
"""Handle self-update from server."""
|
||||
import codecs
|
||||
import shutil
|
||||
|
||||
logger = logging.getLogger("hbc.update")
|
||||
|
||||
try:
|
||||
code = codecs.decode(msg["code"], "base64").decode()
|
||||
csum = msg["csum"]
|
||||
except Exception as e:
|
||||
error = f"Missing code/csum: {e}"
|
||||
logger.error(error)
|
||||
await conn.sendto({"service": "update", "msg": error})
|
||||
return
|
||||
|
||||
# Verify checksum
|
||||
m = md5()
|
||||
m.update(code.encode())
|
||||
if m.hexdigest() != csum:
|
||||
error = "Checksum mismatch"
|
||||
logger.error(error)
|
||||
await conn.sendto({"service": "update", "msg": error})
|
||||
return
|
||||
|
||||
# Backup current file
|
||||
fn = sys.argv[0]
|
||||
ofn = f"{fn}.sav"
|
||||
try:
|
||||
shutil.copy2(fn, ofn)
|
||||
except Exception as e:
|
||||
error = f"Backup failed: {e}"
|
||||
logger.error(error)
|
||||
await conn.sendto({"service": "update", "msg": error})
|
||||
return
|
||||
|
||||
# Write new code
|
||||
try:
|
||||
with open(fn, "w") as fh:
|
||||
fh.write(code)
|
||||
except Exception as e:
|
||||
error = f"Write failed: {e}"
|
||||
logger.error(error)
|
||||
await conn.sendto({"service": "update", "msg": error})
|
||||
return
|
||||
|
||||
logger.info("Update successful, restart required")
|
||||
await conn.sendto({"service": "update", "msg": "OK"})
|
||||
|
||||
# Trigger restart
|
||||
global dorestart
|
||||
dorestart = True
|
||||
stop()
|
||||
|
||||
|
||||
async def heartbeat_sender(conn: AsyncConnection, interval: int):
|
||||
"""Send periodic heartbeats.
|
||||
|
||||
Args:
|
||||
conn: Connection to send on
|
||||
interval: Heartbeat interval in seconds
|
||||
"""
|
||||
logger = logging.getLogger("hbc.heartbeat")
|
||||
|
||||
while running:
|
||||
try:
|
||||
msg = {
|
||||
"acks": conn.ackcount,
|
||||
"rtt": conn.rtts[-1],
|
||||
"interval": interval
|
||||
}
|
||||
await conn.sendto(msg, "HTB")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error sending heartbeat: {e}", exc_info=True)
|
||||
|
||||
# Wait for next interval
|
||||
await asyncio.sleep(interval)
|
||||
|
||||
|
||||
async def plugin_collector(conn: AsyncConnection, registry: PluginRegistry):
|
||||
"""Collect and send plugin data.
|
||||
|
||||
Args:
|
||||
conn: Connection to send on
|
||||
registry: Plugin registry
|
||||
"""
|
||||
logger = logging.getLogger("hbc.plugins")
|
||||
|
||||
# Collect InfoPlugins once at startup
|
||||
info_plugins = registry.get_by_type(InfoPlugin)
|
||||
for plugin in info_plugins:
|
||||
try:
|
||||
data = await plugin.collect()
|
||||
if data:
|
||||
# Create PLG message with plugin name
|
||||
plugin_msg = {"plugin": plugin.name, **data}
|
||||
await conn.sendto(plugin_msg, "PLG")
|
||||
logger.info(f"Sent {plugin.name} data")
|
||||
except Exception as e:
|
||||
logger.error(f"Error collecting {plugin.name}: {e}", exc_info=True)
|
||||
|
||||
# Schedule MonitorPlugins
|
||||
# Group plugins by interval
|
||||
from collections import defaultdict
|
||||
by_interval = defaultdict(list)
|
||||
|
||||
monitor_plugins = registry.get_by_type(MonitorPlugin)
|
||||
for plugin in monitor_plugins:
|
||||
by_interval[plugin.interval].append(plugin)
|
||||
|
||||
# Create tasks for each interval
|
||||
tasks = []
|
||||
for interval, plugins in by_interval.items():
|
||||
task = asyncio.create_task(
|
||||
plugin_collector_interval(conn, plugins, interval)
|
||||
)
|
||||
tasks.append(task)
|
||||
|
||||
# Wait for all tasks
|
||||
if tasks:
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
|
||||
async def plugin_collector_interval(
|
||||
conn: AsyncConnection,
|
||||
plugins: List,
|
||||
interval: int
|
||||
):
|
||||
"""Collect plugins on a specific interval.
|
||||
|
||||
Args:
|
||||
conn: Connection to send on
|
||||
plugins: List of plugins to collect
|
||||
interval: Collection interval in seconds
|
||||
"""
|
||||
logger = logging.getLogger(f"hbc.plugins.{interval}s")
|
||||
|
||||
while running:
|
||||
for plugin in plugins:
|
||||
try:
|
||||
data = await plugin.collect()
|
||||
if data:
|
||||
# Don't use encode_plugin_data - create dict directly
|
||||
plugin_msg = {"plugin": plugin.name, **data}
|
||||
await conn.sendto(plugin_msg, "PLG")
|
||||
logger.debug(f"Sent {plugin.name} data")
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error collecting {plugin.name}: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
|
||||
await asyncio.sleep(interval)
|
||||
|
||||
|
||||
def shortname(name: str) -> str:
|
||||
"""Extract short hostname."""
|
||||
return name.split(".")[0]
|
||||
|
||||
|
||||
def stop():
|
||||
"""Stop the event loop."""
|
||||
global running
|
||||
running = False
|
||||
|
||||
|
||||
async def cleanup(connections: List[AsyncConnection]):
|
||||
"""Cleanup connections on shutdown."""
|
||||
logger = logging.getLogger("hbc.cleanup")
|
||||
logger.info("Cleaning up connections")
|
||||
|
||||
for conn in connections:
|
||||
try:
|
||||
msg = {
|
||||
"shutdown": 1,
|
||||
"acks": conn.ackcount
|
||||
}
|
||||
await conn.sendto(msg)
|
||||
except Exception as e:
|
||||
logger.error(f"Error sending shutdown: {e}")
|
||||
|
||||
conn.close()
|
||||
|
||||
# Give messages time to send
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
|
||||
async def async_main(args, config):
|
||||
"""Async main function."""
|
||||
global running
|
||||
|
||||
logger = logging.getLogger("hbc.main")
|
||||
|
||||
# Setup
|
||||
iam = socket.gethostname()
|
||||
if args.name:
|
||||
iam = args.name
|
||||
|
||||
hb_hosts = args.hosts
|
||||
hb_port = config.get("hb_port", PORT)
|
||||
interval = config.get("interval", INTERVAL)
|
||||
|
||||
logger.info(f"Starting hbc for {iam} -> {hb_hosts}")
|
||||
logger.info(f"Port: {hb_port}, Interval: {interval}s")
|
||||
|
||||
# Create connections
|
||||
connections = []
|
||||
conn_id = 1
|
||||
|
||||
for host in hb_hosts:
|
||||
try:
|
||||
addrs = socket.getaddrinfo(host, hb_port, 0, 0, socket.SOL_UDP)
|
||||
except socket.gaierror as e:
|
||||
logger.error(f"Cannot resolve {host}: {e}")
|
||||
continue
|
||||
|
||||
for addr_info in addrs:
|
||||
af = addr_info[0]
|
||||
addr = addr_info[4][0]
|
||||
|
||||
conn = AsyncConnection(conn_id, addr, hb_port, af, iam)
|
||||
if await conn.open():
|
||||
connections.append(conn)
|
||||
conn_id += 1
|
||||
|
||||
if not connections:
|
||||
logger.error("No connections established")
|
||||
return 1
|
||||
|
||||
logger.info(f"Created {len(connections)} connections")
|
||||
|
||||
# Send boot/message if requested
|
||||
if args.boot or args.message:
|
||||
boot_msg = {}
|
||||
if args.boot:
|
||||
boot_msg["boot"] = 1
|
||||
if args.message:
|
||||
boot_msg["service"] = "service"
|
||||
boot_msg["msg"] = args.message
|
||||
|
||||
boot_msg["acks"] = 0
|
||||
for conn in connections:
|
||||
await conn.sendto(boot_msg)
|
||||
|
||||
if args.message and not args.daemon:
|
||||
# Message-only mode
|
||||
await cleanup(connections)
|
||||
return 0
|
||||
|
||||
# Load plugins
|
||||
registry = PluginRegistry()
|
||||
loader = PluginLoader(registry)
|
||||
|
||||
plugin_dir = Path(__file__).parent / "plugins"
|
||||
if plugin_dir.exists():
|
||||
count = await loader.load_from_directory(plugin_dir, config)
|
||||
logger.info(f"Loaded {count} plugins")
|
||||
else:
|
||||
logger.warning(f"Plugin directory not found: {plugin_dir}")
|
||||
|
||||
# Start async tasks
|
||||
tasks = []
|
||||
|
||||
# Heartbeat senders (one per connection)
|
||||
for conn in connections:
|
||||
task = asyncio.create_task(heartbeat_sender(conn, interval))
|
||||
tasks.append(task)
|
||||
|
||||
# Plugin collector (uses all connections, but we'll use first one)
|
||||
if connections and registry.get_enabled():
|
||||
task = asyncio.create_task(plugin_collector(connections[0], registry))
|
||||
tasks.append(task)
|
||||
|
||||
# Setup signal handlers
|
||||
loop = asyncio.get_event_loop()
|
||||
for sig in (signal.SIGTERM, signal.SIGINT):
|
||||
loop.add_signal_handler(sig, stop)
|
||||
|
||||
# Wait for stop or tasks to complete
|
||||
try:
|
||||
await asyncio.gather(*tasks)
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
# Cleanup
|
||||
await cleanup(connections)
|
||||
await loader.unload_all()
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
def daemonize(
|
||||
working_dir="/",
|
||||
stdin="/dev/zero",
|
||||
stdout="/dev/null",
|
||||
stderr="/dev/null"
|
||||
):
|
||||
"""UNIX double-fork daemonization."""
|
||||
try:
|
||||
pid = os.fork()
|
||||
if pid > 0:
|
||||
os._exit(0)
|
||||
except OSError as e:
|
||||
sys.stderr.write(f"fork #1 failed: {e}\n")
|
||||
os._exit(1)
|
||||
|
||||
os.chdir(working_dir)
|
||||
os.setsid()
|
||||
os.umask(0)
|
||||
|
||||
try:
|
||||
pid = os.fork()
|
||||
if pid > 0:
|
||||
os._exit(0)
|
||||
except OSError as e:
|
||||
sys.stderr.write(f"fork #2 failed: {e}\n")
|
||||
sys.exit(1)
|
||||
|
||||
sys.stdout.flush()
|
||||
sys.stderr.flush()
|
||||
|
||||
si = open(stdin, "r")
|
||||
so = open(stdout, "a+")
|
||||
se = open(stderr, "a+")
|
||||
|
||||
os.dup2(si.fileno(), sys.stdin.fileno())
|
||||
os.dup2(so.fileno(), sys.stdout.fileno())
|
||||
os.dup2(se.fileno(), sys.stderr.fileno())
|
||||
|
||||
|
||||
def build_parser():
|
||||
"""Build argument parser."""
|
||||
parser = argparse.ArgumentParser(
|
||||
prog="hbc",
|
||||
description="HeartBeatClient - send heartbeat messages to HeartBeatDaemon",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
)
|
||||
parser.add_argument(
|
||||
"-b", "--boot",
|
||||
action="store_true",
|
||||
help="Send a boot message"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-c", "--config",
|
||||
dest="configfile",
|
||||
help="Config file path (YAML)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-m", "--message",
|
||||
dest="message",
|
||||
help="Send a message"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-n", "--name",
|
||||
dest="name",
|
||||
help="Name to use in heartbeat message"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-d", "--daemon",
|
||||
action="store_true",
|
||||
help="Run in daemon mode"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-v", "--verbose",
|
||||
action="store_true",
|
||||
help="Verbose output"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-x", "--debug",
|
||||
action="count",
|
||||
default=0,
|
||||
help="Increase debug level"
|
||||
)
|
||||
parser.add_argument(
|
||||
"hosts",
|
||||
nargs="+",
|
||||
help="Heartbeat daemon hosts to send to"
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
def main(argv=None):
|
||||
"""Main entry point."""
|
||||
global running, dorestart
|
||||
|
||||
parser = build_parser()
|
||||
args = parser.parse_args(argv)
|
||||
|
||||
# Load config
|
||||
config = load_config(args.configfile)
|
||||
|
||||
# Setup logging
|
||||
log_level = logging.INFO
|
||||
if args.verbose:
|
||||
log_level = logging.DEBUG
|
||||
if args.debug:
|
||||
log_level = logging.DEBUG
|
||||
|
||||
logging.basicConfig(
|
||||
level=log_level,
|
||||
format="%(asctime)s %(name)s %(levelname)s: %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S"
|
||||
)
|
||||
|
||||
# Daemonize if requested
|
||||
if args.daemon:
|
||||
print("Daemonizing...")
|
||||
import syslog
|
||||
syslog.openlog("hbc", syslog.LOG_PID, syslog.LOG_DAEMON)
|
||||
syslog.syslog(syslog.LOG_INFO, f"Starting heartbeat to {', '.join(args.hosts)}")
|
||||
daemonize()
|
||||
|
||||
# Reconfigure logging for syslog
|
||||
logging.basicConfig(
|
||||
level=log_level,
|
||||
format="hbc[%(process)d]: %(name)s %(levelname)s: %(message)s"
|
||||
)
|
||||
|
||||
# Run async main
|
||||
try:
|
||||
exit_code = asyncio.run(async_main(args, config))
|
||||
except KeyboardInterrupt:
|
||||
logging.info("Interrupted by user")
|
||||
exit_code = 0
|
||||
except Exception as e:
|
||||
logging.error(f"Fatal error: {e}", exc_info=True)
|
||||
exit_code = 1
|
||||
|
||||
# Handle restart
|
||||
if dorestart:
|
||||
logging.info("Restarting...")
|
||||
os.execv(sys.argv[0], sys.argv)
|
||||
|
||||
sys.exit(exit_code)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user