Files
heartbeat/hbd/server/main.py
T

491 lines
17 KiB
Python

"""Server runtime: starts UDP listener, HTTP server and websocket stubs."""
import asyncio
import logging
import socket
import time
import signal
import sys
import ssl
from . import __version__
from . import udp
from . import hbdclass
from . import ws as ws_mod
from . import notify as notify_mod
from . import data
# Module-level logger for the server runtime.
logger = logging.getLogger(__name__)
# Short aliases for the websocket broadcast and event-log callables; they are
# handed to the UDP handler and timer-restore contexts built in _run_async().
msg_to_websockets = ws_mod.broadcast
eventlog = notify_mod.eventlog
# shared runtime collections and helpers
def save_state(config, hbdclass):
    """Persist the host table and the message list to the pickle file.

    The data is first written to ``<pickfile>.tmp`` and then moved over the
    real file with ``os.replace``, so a failed write never leaves a truncated
    pickle behind. Pickling or I/O failures are logged (with traceback) and
    not raised.

    Args:
        config: Mapping with a ``get`` method; the optional ``"pickfile"``
            key names the target file (default ``"hbd.pickle"``).
        hbdclass: Object exposing ``Host.hosts``, a dict whose values carry a
            ``connections`` dict of connection objects.
    """
    import os
    import pickle

    # Timer state cannot be serialized, so it is dropped before pickling.
    # NOTE(review): this cancels the *live* overdue timers and nothing in
    # this function re-arms them. The periodic autosave calls save_state()
    # while the service is running -- confirm that timers are re-established
    # elsewhere (e.g. on the next heartbeat), otherwise overdue detection is
    # lost for connections that go silent right before a save.
    for host in list(hbdclass.Host.hosts.values()):
        for conn in host.connections.values():
            if hasattr(conn, 'cancel_overdue_timer'):
                conn.cancel_overdue_timer()
            if hasattr(conn, 'overdue_timer'):
                conn.overdue_timer = None
            if hasattr(conn, 'overdue_callback'):
                conn.overdue_callback = None
            if hasattr(conn, 'timeout_duration'):
                conn.timeout_duration = None

    pickfile = config.get("pickfile", "hbd.pickle")
    tmpfile = pickfile + ".tmp"
    try:
        with open(tmpfile, "wb") as pickf:
            pickler = pickle.Pickler(pickf)
            pickler.dump(hbdclass.Host.hosts)
            pickler.dump(data.msgs)
        # Atomic swap: readers see either the old or the new complete file.
        os.replace(tmpfile, pickfile)
    except Exception as e:
        logger.exception("Failed to save state: %s", e)
        # Best-effort removal of the partial temp file.
        try:
            os.unlink(tmpfile)
        except FileNotFoundError:
            # The temp file was never created (open() itself failed).
            pass
        except OSError as cleanup_err:
            logger.warning(
                "Could not remove temporary file %s: %s", tmpfile, cleanup_err
            )
def cleanup_function(config, hbdclass):
    """Exit hook: write the current state to disk, logging before and after.

    Both arguments are forwarded unchanged to ``save_state``.
    """
    logger.info("Running cleanup function...")
    # save_state() logs its own pickling/I-O failures instead of raising them.
    save_state(config, hbdclass)
    logger.info("Cleanup complete.")
async def reload_configuration(config_obj, config_path, components):
    """Re-read the config file and push the new settings to live components.

    The notify module is always updated; the threshold checker is updated
    when ``components`` contains a ``'threshold_checker'`` entry. Any
    exception raised along the way is logged with its traceback and turned
    into a ``False`` return value.

    According to the original notes, these settings need a restart because
    the resource is already bound, loaded or initialized:
      - hb_port, hbd_port, ws_port
      - SSL certificates
      - pickfile
      - journal settings
    and these take effect immediately after a reload:
      - notification_channels
      - threshold_configs
      - hosts (watchhosts, dyndnshosts, notification_channels)
      - grace period (used on next heartbeat)
      - debug/verbose flags (used on next message)

    Args:
        config_obj: ReloadableConfig instance
        config_path: Path to config file
        components: Dict with threshold_checker and other components

    Returns:
        True if reload succeeded, False otherwise
    """
    rule = "=" * 60
    try:
        for line in (rule, "Starting configuration reload...", rule):
            logger.info(line)

        # Step 1: parse the file again.
        new_config = await config_obj.reload(config_path)

        # Step 2: hand the fresh settings to every reloadable component.
        notify_mod.reload_config(new_config)
        if 'threshold_checker' in components:
            checker = components['threshold_checker']
            checker.reload(new_config)

        for line in (rule, "Configuration reload completed successfully", rule):
            logger.info(line)
        return True
    except Exception as exc:
        logger.error(rule)
        logger.error(f"Failed to reload configuration: {exc}", exc_info=True)
        logger.error("Keeping previous configuration")
        logger.error(rule)
        return False
def _open_udp_socket(config):
    """Create the heartbeat UDP socket, enable dual-stack if possible, bind it."""
    sock = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)
    # Clearing IPV6_V6ONLY lets the single IPv6 socket receive IPv4 traffic as
    # well. Support is system-dependent, so a failure is only a warning.
    try:
        sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, False)
    except OSError as e:
        logger.warning(
            f"Warning: Could not reset IPV6_V6ONLY not supported or dual-stack is unavailable. Error: {e}"
        )
    # Bind to all interfaces (::) on the configured heartbeat port.
    bind_addr = ("::", config.get("hb_port", 50003))
    try:
        sock.bind(bind_addr)
    except Exception:
        # Do not leak the descriptor when the port cannot be bound.
        sock.close()
        raise
    logger.info("Starting UDP server on %s:%s", *bind_addr)
    return sock


def _build_wss_context(config):
    """Return the TLS context for the secure websocket port, or None if unset."""
    if not config.get("wss_port", None):
        return None
    ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    ssl_path = config.get("cert_path", "")
    wss_pem = ssl_path + config.get("wss_pem", "")
    wss_key = ssl_path + config.get("wss_key", "")
    try:
        ssl_context.load_cert_chain(wss_pem, keyfile=wss_key)
    except FileNotFoundError:
        logger.error("error: missing SSL keys %s or %s", wss_pem, wss_key)
        sys.exit(1)
    logger.info(
        "Starting secure WebSocket server on port %s with cert %s",
        config.get("wss_port", None),
        wss_pem,
    )
    return ssl_context


async def _cancel_tasks(tasks, timeout=2.0):
    """Cancel every non-None task and wait up to *timeout* seconds for them."""
    live = [task for task in tasks if task is not None]
    for task in live:
        try:
            task.cancel()
            logger.debug("Cancelled task: %s", task)
        except Exception as e:
            logger.warning("Error cancelling task: %s", e)
    if not live:
        return
    try:
        await asyncio.wait_for(
            asyncio.gather(*live, return_exceptions=True),
            timeout=timeout,
        )
    except asyncio.TimeoutError:
        logger.warning("Timeout waiting for tasks to cancel")
    except Exception as e:
        logger.debug("Exception during task cancellation: %s", e)


async def _stop_dns_worker(dns_task, timeout=2.0):
    """Send the DNS worker its ``None`` sentinel and wait briefly for it to end."""
    if not dns_task:
        return
    try:
        try:
            hbdclass.Host.dnsQ.put(None)
        except Exception as e:
            # Best effort: the worker is still awaited (and cancelled on
            # timeout) below even if the sentinel could not be queued.
            logger.debug("Could not queue DNS worker sentinel: %s", e)
        try:
            await asyncio.wait_for(dns_task, timeout=timeout)
            logger.info("DNS worker finished")
        except asyncio.TimeoutError:
            logger.warning("Timeout waiting for DNS worker to finish")
            dns_task.cancel()
        except asyncio.CancelledError:
            logger.info("DNS worker was cancelled")
        except Exception as e:
            logger.warning("Error awaiting DNS worker: %s", e)
        finally:
            # Clear queue bridge to release any held references
            hbdclass.Host.dnsQ = None
    except Exception as e:
        logger.warning("Error stopping DNS worker: %s", e)


async def _run_async(config, config_path=None):
    """Start all services and run until a shutdown signal arrives.

    Startup order: signal handlers, notify module, message journal, threshold
    checker, UDP endpoint, connection-timer restore, HTTP task, DNS worker,
    websocket task, autosave task. Afterwards the coroutine waits for
    SIGINT/SIGTERM (shutdown) or SIGHUP (configuration reload) and tears the
    services down again on exit.

    Args:
        config: Configuration object with a dict-like ``get`` method.
        config_path: Path of the config file; when falsy, SIGHUP only logs a
            warning instead of reloading.
    """
    loop = asyncio.get_running_loop()
    shutdown_event = asyncio.Event()
    reload_event = asyncio.Event()

    # Signal handlers for graceful shutdown and reload
    def signal_handler(signum, frame):
        sig_name = signal.Signals(signum).name if hasattr(signal, "Signals") else signum
        logger.info("Received %s, initiating shutdown...", sig_name)
        loop.call_soon_threadsafe(shutdown_event.set)

    def reload_handler(signum, frame):
        sig_name = signal.Signals(signum).name if hasattr(signal, "Signals") else signum
        logger.info("Received %s, initiating config reload...", sig_name)
        loop.call_soon_threadsafe(reload_event.set)

    # add_signal_handler() passes the extra positional arguments through, so
    # the handlers keep the classic (signum, frame) signature with frame=None.
    loop.add_signal_handler(signal.SIGINT, signal_handler, signal.SIGINT, None)
    loop.add_signal_handler(signal.SIGTERM, signal_handler, signal.SIGTERM, None)
    loop.add_signal_handler(signal.SIGHUP, reload_handler, signal.SIGHUP, None)

    from . import http as http_mod
    from . import dns as dns_mod
    from . import journal as journal_mod
    from . import threshold as threshold_mod

    notify_mod.setup(config)

    # Initialize message journal
    msg_journal = journal_mod.get_journal(config)
    await msg_journal.initialize()

    # Initialize threshold checker
    threshold_checker = threshold_mod.ThresholdChecker(
        config=config,
        renotify_interval=config.get("threshold_renotify_interval", 3600),
        journal=msg_journal,
    )
    logger.info("Threshold checker initialized")

    # Components dict for reload orchestration
    components = {
        'threshold_checker': threshold_checker,
        'msg_journal': msg_journal,
    }

    # UDP server endpoint (handler wired to handle_datagram with context)
    sock = _open_udp_socket(config)

    def udp_handler(msg, addr, transport):
        # The context is rebuilt per datagram so the debug/verbose values are
        # read from the config on every message.
        ctx = dict(
            config=config,
            hbdclass=hbdclass,
            log=eventlog,
            msg_to_websockets=msg_to_websockets,
            msg_journal=msg_journal,
            threshold_checker=threshold_checker,
            DEBUG=config.get("debug", 0),
            verbose=config.get("verbose", False),
        )
        udp.handle_datagram(msg, addr, transport, ctx)

    transport, _protocol = await loop.create_datagram_endpoint(
        lambda: udp.EchoServerProtocol(config=config, handler=udp_handler),
        sock=sock,
    )

    # Restore connection timers for hosts loaded from pickle
    restore_ctx = dict(
        config=config,
        hbdclass=hbdclass,
        log=eventlog,
        msg_to_websockets=msg_to_websockets,
        threshold_checker=threshold_checker,
    )
    udp.restore_connection_timers(hbdclass, restore_ctx)

    # HTTP server (asyncio-based via aiohttp).
    # Pre-set to None so the shutdown path works even if startup fails here.
    http_task = None
    try:
        http_task = asyncio.create_task(
            http_mod.start(
                host=config.get("hbd_host", ""),
                port=config.get("hbd_port", 50004),
                config=config,
                hbdclass=hbdclass,
                tcss=None,
                verbose=config.get("verbose", False),
                get_now=time.time,
                VER="",
            )
        )
        logger.info(
            "HTTP server started on %s:%s",
            config.get("hbd_host", ""),
            config.get("hbd_port", 50004),
        )
    except Exception as e:
        logger.exception("failed to start HTTP server: %s", e)

    # start dns update worker (async)
    dns_task = None
    try:
        dns_task = dns_mod.start_dns_worker(
            hbdclass, config, log=eventlog, loop=loop
        )
        logger.info("dns update worker started")
    except Exception as e:
        logger.exception("dns worker failed to start: %s", e)

    # Start the websocket servers as a background task
    ssl_context = _build_wss_context(config)
    ws_task = None
    try:
        ws_port = config.get("ws_port", 50005)
        logger.info("Starting WebSocket server on port %s", ws_port)
        ws_task = asyncio.create_task(
            ws_mod.start(
                host=config.get("hbd_host", ""),
                ws_port=ws_port,
                wss_port=config.get("wss_port", None),
                ssl_context=ssl_context,
                get_hosts=lambda: [
                    hbdclass.Host.hosts[h].stateinfo()
                    for h in sorted(hbdclass.Host.hosts)
                ],
                config=config,
            )
        )
        logger.info("WebSocket task started")
    except Exception as e:
        logger.exception("websocket server failed to start: %s", e)

    # Periodic autosave task
    autosave_interval = config.get("autosave_interval", 300)  # default: 5 minutes

    async def autosave_task():
        while True:
            await asyncio.sleep(autosave_interval)
            logger.debug("Autosaving state...")
            try:
                save_state(config, hbdclass)
            except Exception as e:
                # One failed save must not end periodic saving for good.
                logger.exception("Autosave failed: %s", e)
                continue
            logger.debug("Autosave complete (%d hosts)", len(hbdclass.Host.hosts))

    autosave = asyncio.create_task(autosave_task())
    logger.info("Autosave task started (interval: %ds)", autosave_interval)

    # Main event loop - monitor shutdown and reload events
    try:
        while True:
            waiters = [
                asyncio.create_task(shutdown_event.wait()),
                asyncio.create_task(reload_event.wait()),
            ]
            try:
                await asyncio.wait(waiters, return_when=asyncio.FIRST_COMPLETED)
            finally:
                # Cancelling a finished task is a no-op, so this only reaps
                # the waiter that is still pending.
                for waiter in waiters:
                    waiter.cancel()

            # Shutdown takes precedence when both events are set.
            if shutdown_event.is_set():
                logger.info("Shutdown signal received, stopping services...")
                break

            if reload_event.is_set():
                # Clear the event for next reload
                reload_event.clear()
                if config_path:
                    await reload_configuration(config, config_path, components)
                else:
                    logger.warning("Cannot reload: no config path available")
    except Exception as e:
        logger.exception("Error in main loop: %s", e)
    finally:
        logger.info("Cancelling tasks...")
        try:
            transport.close()
        except Exception as e:
            logger.warning("Error closing UDP transport: %s", e)

        await _cancel_tasks([http_task, ws_task, autosave], timeout=2.0)

        # Close message journal
        try:
            await msg_journal.close()
        except Exception as e:
            logger.warning("Error closing message journal: %s", e)

        # Signal DNS worker to exit and await it
        await _stop_dns_worker(dns_task, timeout=2.0)
        logger.info("All tasks cancelled")
def load_pickled_hosts(config, hbdclass):
    """Load pickled hosts from file, if available.

    Reads two pickled objects from the configured pickle file: the host table
    (stored in ``hbdclass.Host.hosts``) and the message list (stored in
    ``data.msgs``). If unpickling fails, the error is logged and the pickle
    file is removed. In either case ``hbdclass.Connection.htab`` is reset,
    every host gets its ``dyn``/``watched`` flags refreshed from the current
    configuration and ``fixup()`` called, and hosts listed under
    ``"drophosts"`` are removed.

    Args:
        config: Configuration object with a dict-like ``get`` method; reads
            ``"pickfile"`` (default ``"hbd.pickle"``), ``"drophosts"`` and
            ``"verbose"``.
        hbdclass: Object exposing the ``Host`` and ``Connection`` classes.
    """
    import os
    import pickle
    from . import config as config_mod

    verbose = config.get("verbose", False)
    pickfile = config.get("pickfile", "hbd.pickle")
    dyndnshosts = config_mod.get_dyndnshosts(config)
    watchhosts = config_mod.get_watchhosts(config)
    drophosts = config.get("drophosts", [])

    if not os.path.exists(pickfile):
        if verbose:
            logger.info("no pickled data")
        return

    if verbose:
        logger.info("opening pickls %s", pickfile)

    # SECURITY: unpickling executes code embedded in the file; the pickle
    # file must only ever be writable by the service itself.
    load_failed = False
    with open(pickfile, "rb") as pickf:
        unpickler = pickle.Unpickler(pickf)
        try:
            hbdclass.Host.hosts = unpickler.load()
            data.msgs = unpickler.load()
        except Exception as e:
            logger.exception("load pickled failed: %s", e)
            load_failed = True

    if load_failed:
        # Remove the unreadable file only after the handle is closed.
        try:
            os.unlink(pickfile)
        except OSError as e:
            logger.warning("could not remove %s: %s", pickfile, e)

    hbdclass.Connection.htab = {}
    # Iterate over a snapshot of the host table.
    for name, host in list(hbdclass.Host.hosts.items()):
        host.dyn = name in dyndnshosts
        host.watched = name in watchhosts
        host.fixup()
    for name in drophosts:
        hbdclass.Host.hosts.pop(name, None)

    if verbose:
        logger.info("%s pickled hosts loaded", len(hbdclass.Host.hosts))
def run(config, config_path=None):
    """Start the hbd service (blocking); never returns.

    Manually manages the event loop to ensure clean shutdown, then ends the
    process with ``os._exit``. The exit status is 0 after a normal shutdown
    or KeyboardInterrupt, the requested code when ``sys.exit`` was called
    inside the service, and 1 after an unhandled exception or a failed
    cleanup step.

    Args:
        config: Configuration dictionary
        config_path: Path to config file (for reload support)
    """
    import os

    logging.basicConfig(
        level=logging.DEBUG if config.get("debug", 0) > 0 else logging.INFO
    )
    load_pickled_hosts(config, hbdclass)
    notify_mod.initlog(logfile=config.get("logfile", "messages.log"))
    eventlog(None, "INFO", f"hbd version {__version__} starting up")
    if config_path:
        logger.info("Config file: %s (reload with SIGHUP)", config_path)
    else:
        logger.warning("No config path provided - reload via SIGHUP disabled")

    # Create and set the event loop manually
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    exit_code = 0
    try:
        loop.run_until_complete(_run_async(config, config_path=config_path))
    except KeyboardInterrupt:
        logger.info("Received KeyboardInterrupt, shutting down...")
    except SystemExit as e:
        # Keep the status requested via sys.exit() for the final os._exit().
        if e.code is None:
            exit_code = 0
        elif isinstance(e.code, int):
            exit_code = e.code
        else:
            exit_code = 1
    except Exception as e:
        logger.exception("Unhandled exception in main: %s", e)
        exit_code = 1
    finally:
        # A failing cleanup step must not prevent closing the loop and
        # terminating the process below.
        try:
            cleanup_function(config, hbdclass)
            logger.info("hbd shutdown complete")
            eventlog(None, "INFO", f"hbd version {__version__} shutdown")
            notify_mod.closelog()
        except Exception as e:
            logger.exception("Error during shutdown cleanup: %s", e)
            exit_code = exit_code or 1

        # Explicitly close the loop
        try:
            # Cancel all remaining tasks
            pending = asyncio.all_tasks(loop)
            for task in pending:
                task.cancel()
            # Run one more cycle to process cancellations
            if pending:
                loop.run_until_complete(
                    asyncio.gather(*pending, return_exceptions=True)
                )
        except Exception as e:
            logger.debug("Error while cancelling leftover tasks: %s", e)
        finally:
            loop.close()

        # os._exit() skips normal interpreter shutdown, so flush and close
        # the logging handlers explicitly first.
        logging.shutdown()
        os._exit(exit_code)