Major refactoring of the codebase, including restructuring of files and directories, renaming of modules and classes, and improvements to the overall organization and readability of the code. This refactoring aims to enhance maintainability, scalability, and clarity of the codebase while preserving existing functionality. The changes include:
- Restructuring of the project directory into client and server components
- Renaming of modules and classes to better reflect their purpose and functionality
- Moving common utilities and configurations to a shared location
- Updating import statements to reflect the new structure
- Adding new documentation files for better clarity on various aspects of the project
- Removing deprecated or unused code to streamline the codebase
- Ensuring that all existing functionality is preserved and that the codebase remains functional after the refactoring.
This commit is contained in:
@@ -0,0 +1,396 @@
|
||||
"""Server runtime: starts UDP listener, HTTP server and websocket stubs."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import socket
|
||||
import time
|
||||
import signal
|
||||
import sys
|
||||
import ssl
|
||||
from . import __version__
|
||||
|
||||
from . import udp
|
||||
from . import hbdclass
|
||||
|
||||
from . import ws as ws_mod
|
||||
|
||||
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Alias for the websocket module's broadcast function; log() calls it as
# msg_to_websockets("message", line) and it is passed on to other services.
msg_to_websockets = ws_mod.broadcast

# File object that log() writes to; stays None until run() assigns initlog()'s result.
logf = None
# Three-string list saved to and restored from the pickle file alongside hosts.
# NOTE(review): its meaning is not visible in this module - confirm with its users.
lastfm = ["", "", ""]

# shared runtime collections and helpers
# Formatted log lines appended by log(); persisted by cleanup_function().
msgs = []
|
||||
|
||||
|
||||
def initlog(logfile):
    """Open *logfile* for appending, falling back to standard error.

    Args:
        logfile: Path of the log file to open in ``"a+"`` mode.

    Returns:
        The opened file object, or ``sys.stderr`` when the file cannot be
        opened (missing directory, permissions, invalid path value, ...).
    """
    import sys

    try:
        return open(logfile, "a+")
    except (OSError, TypeError, ValueError) as e:
        # The diagnostic goes to stderr as well: it announces that stderr is
        # the fallback and must not pollute stdout.
        print(
            "cannot open logfile %s, using STDERR: %s" % (logfile, e),
            file=sys.stderr,
        )
        return sys.stderr
|
||||
|
||||
|
||||
def log(host, m, service=None):
    """Record one timestamped message in every configured sink.

    The line is built as ``"<local timestamp> <host> <m>"`` (an empty string
    replaces a falsy *host*), then appended to ``msgs``, emitted through the
    module logger, written to ``logf`` when one is set, and finally broadcast
    to websocket clients as a ``"message"`` event.

    Args:
        host: Host the message relates to, or a falsy value for none.
        m: Message text.
        service: Accepted for interface compatibility; not used here.
    """
    now = time.time()
    stamp = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(now))
    line = f"{stamp} {host or ''} {m}"

    msgs.append(line)
    logger.info(line)

    if logf:
        # A failing log file must never stop message delivery.
        try:
            logf.write(f"{line}\n")
            logf.flush()
        except Exception as err:
            logger.warning("failed to write to logfile: %s", err)

    msg_to_websockets("message", line)
|
||||
|
||||
|
||||
def cleanup_function(config):
    """Persist runtime state to the pickle file on program exit.

    Writes, in this order, ``hbdclass.Host.hosts``, ``msgs`` and ``lastfm``
    to the file named by ``config["pickfile"]`` (default ``"hbd.pickle"``).
    load_pickled_hosts() reads them back in the same order.

    Args:
        config: Mapping providing the optional ``"pickfile"`` entry.

    Raises:
        OSError: If the pickle file cannot be opened or written.
        pickle.PicklingError: If any of the objects cannot be pickled.
    """
    import pickle

    logger.info("Running cleanup function...")

    pickfile = config.get("pickfile", "hbd.pickle")

    # The context manager closes the file even when a dump() call raises.
    with open(pickfile, "wb") as pickf:
        pick = pickle.Pickler(pickf)
        pick.dump(hbdclass.Host.hosts)
        pick.dump(msgs)
        pick.dump(lastfm)

    logger.info("Cleanup complete.")
|
||||
|
||||
|
||||
async def _run_async(config):
    """Start all services on the running event loop and wait for shutdown.

    Sets up SIGINT/SIGTERM handlers, the notification module, the message
    journal, the threshold checker, the dual-stack UDP listener, the HTTP
    server, the DNS update worker, the websocket server(s) and the monitor
    task. It then blocks until a signal sets the shutdown event and tears
    everything down again.

    Args:
        config: Mapping with the runtime configuration. Keys read here:
            ``hb_port``, ``hbd_host``, ``hbd_port``, ``ws_port``, ``wss_port``,
            ``cert_path``, ``wss_pem``, ``wss_key``, ``debug``, ``verbose``
            and ``threshold_renotify_interval``.

    Raises:
        SystemExit: With status 1 when ``wss_port`` is configured but the
            certificate or key file does not exist.
        OSError: If the UDP socket cannot be bound.
    """
    loop = asyncio.get_running_loop()
    shutdown_event = asyncio.Event()

    # Signal handlers for graceful shutdown
    def signal_handler(signum, frame):
        sig_name = signal.Signals(signum).name if hasattr(signal, "Signals") else signum
        logger.info("Received %s, initiating shutdown...", sig_name)
        loop.call_soon_threadsafe(shutdown_event.set)

    # Register signal handlers; the extra arguments are passed to the callback.
    loop.add_signal_handler(signal.SIGINT, signal_handler, signal.SIGINT, None)
    loop.add_signal_handler(signal.SIGTERM, signal_handler, signal.SIGTERM, None)

    from . import http as http_mod
    from . import dns as dns_mod
    from . import notify as notify_mod
    from . import monitor as monitor_mod
    from . import journal as journal_mod
    from ..client import threshold as threshold_mod

    notify_mod.setup(config)

    # Initialize message journal
    msg_journal = journal_mod.get_journal(config)
    await msg_journal.initialize()

    # Initialize threshold checker
    threshold_checker = threshold_mod.ThresholdChecker(
        config=config,
        notification_callback=notify_mod.pushmsg_from_config,
        renotify_interval=config.get("threshold_renotify_interval", 3600),
        journal=msg_journal,
    )
    logger.info("Threshold checker initialized")

    pushmsg = notify_mod.pushmsg_from_config

    sock = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)
    # Disable IPV6_V6ONLY option to enable dual-stack (listen on IPv4 as well)
    # This option is system-dependent; on many systems, setting it to False enables
    # the socket to handle both IPv4 and IPv6 traffic.
    try:
        sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, False)
    except OSError as e:
        logger.warning(
            "Warning: Could not reset IPV6_V6ONLY not supported or dual-stack is unavailable. Error: %s",
            e,
        )

    # Bind to all interfaces (::) on the configured heartbeat port.
    # UDP server endpoint (handler wired to handle_datagram with context)
    bind_addr = ("::", config.get("hb_port", 50003))
    sock.bind(bind_addr)
    logger.info("Starting UDP server on %s:%s", *bind_addr)

    def udp_handler(msg, addr, transport):
        # The context is rebuilt per datagram, so "debug"/"verbose" are read
        # from config at the time each datagram is handled.
        ctx = dict(
            config=config,
            hbdclass=hbdclass,
            log=log,
            pushmsg=pushmsg,
            msg_to_websockets=msg_to_websockets,
            msg_journal=msg_journal,
            threshold_checker=threshold_checker,
            DEBUG=config.get("debug", 0),
            verbose=config.get("verbose", False),
        )
        udp.handle_datagram(msg, addr, transport, ctx)

    transport, protocol = await loop.create_datagram_endpoint(
        lambda: udp.EchoServerProtocol(config=config, handler=udp_handler),
        sock=sock,
    )

    # Pre-bind every task name: each one is assigned inside a try block below,
    # and the shutdown code must be able to reference all of them even when a
    # service failed to start.
    http_task = None
    ws_task = None
    monitor_task = None
    dns_task = None

    # HTTP server (asyncio-based via aiohttp)
    try:
        http_task = asyncio.create_task(
            http_mod.start(
                host=config.get("hbd_host", ""),
                port=config.get("hbd_port", 50004),
                config=config,
                hbdclass=hbdclass,
                msgs_getter=lambda: msgs,
                log=log,
                pushmsg=pushmsg,
                msg_to_websockets=msg_to_websockets,
                threshold_checker=threshold_checker,
                tcss=None,
                DEBUG=config.get("debug", 0),
                verbose=config.get("verbose", False),
                get_now=lambda: time.time(),
                VER="",
            )
        )
        logger.info(
            "HTTP server started on %s:%s",
            config.get("hbd_host", ""),
            config.get("hbd_port", 50004),
        )
    except Exception as e:
        logger.exception("failed to start HTTP server: %s", e)

    # start dns update worker (async)
    try:
        dns_task = dns_mod.start_dns_worker(
            hbdclass, config, log=log, pushmsg=pushmsg, loop=loop
        )
        logger.info("dns update worker started")
    except Exception as e:
        logger.exception("dns worker failed to start: %s", e)

    # Start the websocket servers as a background task
    if config.get("wss_port", None):
        ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
        # cert_path is prepended verbatim, so it needs its own trailing separator.
        ssl_path = config.get("cert_path", "")
        wss_pem = ssl_path + config.get("wss_pem", "")
        wss_key = ssl_path + config.get("wss_key", "")
        try:
            ssl_context.load_cert_chain(wss_pem, keyfile=wss_key)
        except FileNotFoundError:
            logger.error("error: missing SSL keys %s or %s", wss_pem, wss_key)
            # NOTE(review): this exits before the shutdown block below is
            # entered, so the services started above are not torn down here.
            sys.exit(1)
        logger.info(
            "Starting secure WebSocket server on port %s with cert %s",
            config.get("wss_port", None),
            wss_pem,
        )
    else:
        ssl_context = None

    try:
        ws_port = config.get("ws_port", 50005)
        logger.info("Starting WebSocket server on port %s", ws_port)
        ws_task = asyncio.create_task(
            ws_mod.start(
                host=config.get("hbd_host", ""),
                ws_port=ws_port,
                wss_port=config.get("wss_port", None),
                ssl_context=ssl_context,
                get_hosts=lambda: [
                    hbdclass.Host.hosts[h].stateinfo()
                    for h in sorted(hbdclass.Host.hosts)
                ],
                get_msgs=lambda: msgs,
                verbose=config.get("verbose", False),
            )
        )
        logger.info("WebSocket task started")
    except Exception as e:
        logger.exception("websocket server failed to start: %s", e)

    # Start the monitor as a background task
    try:
        monitor_task = asyncio.create_task(
            monitor_mod.start(
                config=config,
                hbdclass=hbdclass,
                log=log,
                pushmsg=pushmsg,
                msg_to_websockets=msg_to_websockets,
            )
        )
        logger.info("Monitor task started")
    except Exception as e:
        logger.exception("monitor task failed to start: %s", e)

    try:
        # run forever until shutdown event is set
        await shutdown_event.wait()
        logger.info("Shutdown signal received, stopping services...")
    except Exception as e:
        logger.exception("Error in main loop: %s", e)
    finally:
        # Cancel all running tasks
        logger.info("Cancelling tasks...")
        try:
            transport.close()
        except Exception as e:
            logger.warning("Error closing UDP transport: %s", e)

        # Only the services that actually started have a task to cancel.
        remaining_tasks = [
            t for t in (http_task, ws_task, monitor_task) if t is not None
        ]
        for task in remaining_tasks:
            try:
                task.cancel()
                logger.debug("Cancelled task: %s", task)
            except Exception as e:
                logger.warning("Error cancelling task: %s", e)

        # Wait for tasks to finish cancellation with timeout
        if remaining_tasks:
            try:
                await asyncio.wait_for(
                    asyncio.gather(*remaining_tasks, return_exceptions=True),
                    timeout=2.0,
                )
            except asyncio.TimeoutError:
                logger.warning("Timeout waiting for tasks to cancel")
            except Exception as e:
                logger.debug("Exception during task cancellation: %s", e)

        # Close message journal
        try:
            await msg_journal.close()
        except Exception as e:
            logger.warning("Error closing message journal: %s", e)

        # Signal DNS worker to exit and await it
        try:
            if dns_task:
                try:
                    # None is queued as the stop signal.
                    # NOTE(review): presumably the worker treats None as a
                    # sentinel - confirm in the dns module.
                    hbdclass.Host.dnsQ.put(None)
                except Exception as e:
                    logger.debug("Could not queue DNS worker stop signal: %s", e)
                try:
                    await asyncio.wait_for(dns_task, timeout=2.0)
                    logger.info("DNS worker finished")
                except asyncio.TimeoutError:
                    logger.warning("Timeout waiting for DNS worker to finish")
                    dns_task.cancel()
                except asyncio.CancelledError:
                    logger.info("DNS worker was cancelled")
                except Exception as e:
                    logger.warning("Error awaiting DNS worker: %s", e)
                finally:
                    # Clear queue bridge to release any held references
                    hbdclass.Host.dnsQ = None
        except Exception as e:
            logger.warning("Error stopping DNS worker: %s", e)

        logger.info("All tasks cancelled")
|
||||
|
||||
|
||||
def load_pickled_hosts(config, hbdclass):
    """Load pickled hosts, messages and ``lastfm`` from file, if available.

    Reads the objects in the order cleanup_function() writes them. A pickle
    without the third object (``lastfm``) is accepted and ``lastfm`` is reset
    to three empty strings. If loading fails the error is logged and the
    pickle file is deleted.

    After loading, ``hbdclass.Connection.htab`` is reset, every host gets its
    ``dyn`` and ``watched`` flags recomputed from the configuration and its
    ``fixup()`` method called, and hosts listed in ``drophosts`` are removed.

    Args:
        config: Mapping; reads ``pickfile`` (default ``"hbd.pickle"``),
            ``dyndnshosts``, ``watchhosts``, ``drophosts`` and ``verbose``.
        hbdclass: Object exposing ``Host.hosts`` and ``Connection.htab``.

    Raises:
        OSError: If an existing pickle file cannot be opened.
    """
    global lastfm, msgs
    import os
    import pickle

    verbose = config.get("verbose", False)
    pickfile = config.get("pickfile", "hbd.pickle")
    dyndnshosts = config.get("dyndnshosts", [])
    watchhosts = config.get("watchhosts", [])
    drophosts = config.get("drophosts", [])

    if not os.path.exists(pickfile):
        if verbose:
            logger.info("no pickled data")
        return

    if verbose:
        logger.info("opening pickle %s", pickfile)

    # SECURITY: unpickling can execute arbitrary code, so the pickle file must
    # only ever be written by this program and be protected from tampering.
    load_failed = False
    with open(pickfile, "rb") as pickf:
        pick = pickle.Unpickler(pickf)
        try:
            hbdclass.Host.hosts = pick.load()
            msgs = pick.load()
            try:
                lastfm = pick.load()
            except Exception:
                # Pickle holds only two objects: fall back to the default.
                lastfm = ["", "", ""]
        except Exception as e:
            logger.exception("load pickled failed: %s", e)
            load_failed = True

    if load_failed:
        # Remove the unreadable file only after its handle has been closed.
        os.unlink(pickfile)

    hbdclass.Connection.htab = {}

    # Snapshot the items so the mapping may change while hosts are fixed up.
    for name, host in list(hbdclass.Host.hosts.items()):
        host.dyn = name in dyndnshosts
        host.watched = name in watchhosts
        host.fixup()

    for name in drophosts:
        hbdclass.Host.hosts.pop(name, None)

    if verbose:
        logger.info("%s pickled hosts loaded", len(hbdclass.Host.hosts))
|
||||
|
||||
|
||||
def run(config):
    """Start the hbd service (blocking).

    Manually manages the event loop to ensure clean shutdown: configures
    logging, restores pickled state, opens the log file, runs _run_async()
    to completion, then saves state, closes the log file, cancels leftover
    tasks, closes the loop and terminates the process with ``os._exit``.

    Args:
        config: Mapping with the runtime configuration. Keys read here are
            ``debug`` and ``logfile``; the mapping is also passed on to
            load_pickled_hosts(), _run_async() and cleanup_function().

    This function does not return: the process exits with status 0, or 1
    when _run_async() failed with an unhandled exception.
    """
    global logf
    import os

    logging.basicConfig(
        level=logging.DEBUG if config.get("debug", 0) > 0 else logging.INFO
    )
    load_pickled_hosts(config, hbdclass)

    logf = initlog(logfile=config.get("logfile", "messages.log"))
    log(None, f"hbd version {__version__} starting up")

    # Create and set the event loop manually
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)

    exit_status = 0
    try:
        loop.run_until_complete(_run_async(config))
    except KeyboardInterrupt:
        logger.info("Received KeyboardInterrupt, shutting down...")
    except Exception as e:
        logger.exception("Unhandled exception in main: %s", e)
        # Report the failure to the caller instead of exiting "successfully".
        exit_status = 1
    finally:
        # A failure while saving state must not prevent the rest of the
        # shutdown (closing the log file and the loop) from running.
        try:
            cleanup_function(config)
        except Exception as e:
            logger.exception("cleanup failed: %s", e)
        logger.info("hbd shutdown complete")

        # Detach the global before closing so that log() calls made while
        # the remaining tasks are cancelled do not write to a closed file.
        log_handle, logf = logf, None
        if log_handle and log_handle is not sys.stderr:
            try:
                log_handle.close()
            except Exception as e:
                logger.debug("Error closing logfile: %s", e)

        # Explicitly close the loop
        try:
            # Cancel all remaining tasks
            pending = asyncio.all_tasks(loop)
            for task in pending:
                task.cancel()
            # Run one more cycle to process cancellations
            if pending:
                loop.run_until_complete(
                    asyncio.gather(*pending, return_exceptions=True)
                )
        except Exception as e:
            logger.debug("Error while cancelling remaining tasks: %s", e)
        finally:
            loop.close()

    # Exit immediately, without running interpreter cleanup handlers.
    os._exit(exit_status)
|
||||
Reference in New Issue
Block a user