refactor and rewrite for asyncio
This commit is contained in:
+242
-48
@@ -1,38 +1,88 @@
|
||||
"""Server runtime: starts UDP listener, HTTP server and websocket stubs."""
|
||||
import asyncio
|
||||
import logging
|
||||
import atexit
|
||||
import time
|
||||
import signal
|
||||
import sys
|
||||
from . import __version__
|
||||
|
||||
from . import udp
|
||||
from . import hbdclass
|
||||
from . import ws as ws_mod
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
msg_to_websockets = ws_mod.broadcast
|
||||
|
||||
logf = None
|
||||
lastfm = ["", "", ""]
|
||||
|
||||
# shared runtime collections and helpers
|
||||
msgs = []
|
||||
|
||||
def initlog(logfile):
|
||||
try:
|
||||
return open(logfile, "a+")
|
||||
except Exception as e:
|
||||
import sys
|
||||
print("cannot open loffile %s, using STDERR: %s" % (logfile, e))
|
||||
return sys.stderr
|
||||
|
||||
def log(host, m, service=None):
|
||||
ts = time.time()
|
||||
s = f"{time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(ts))} {host or ''} {m}"
|
||||
msgs.append(s)
|
||||
logger.info(s)
|
||||
if logf:
|
||||
try:
|
||||
logf.write(s + "\n")
|
||||
logf.flush()
|
||||
except Exception as e:
|
||||
logger.warning("failed to write to logfile: %s", e)
|
||||
msg_to_websockets("message", s)
|
||||
|
||||
def cleanup_function(config):
|
||||
"""This function will be executed upon program exit."""
|
||||
logger.info("Running cleanup function...")
|
||||
import pickle
|
||||
pickfile = config.get("pickfile", "hbd.pickle")
|
||||
|
||||
pickf = open(pickfile, "wb")
|
||||
pick = pickle.Pickler(pickf)
|
||||
pick.dump(hbdclass.Host.hosts)
|
||||
pick.dump(msgs)
|
||||
pick.dump(lastfm)
|
||||
pickf.close()
|
||||
|
||||
logger.info("Cleanup complete.")
|
||||
|
||||
async def _run_async(config):
|
||||
global msgs
|
||||
loop = asyncio.get_running_loop()
|
||||
shutdown_event = asyncio.Event()
|
||||
|
||||
# shared runtime collections and helpers
|
||||
msgs = []
|
||||
# Signal handlers for graceful shutdown
|
||||
def signal_handler(signum, frame):
|
||||
sig_name = signal.Signals(signum).name if hasattr(signal, 'Signals') else signum
|
||||
logger.info(f"Received {sig_name}, initiating shutdown...")
|
||||
loop.call_soon_threadsafe(shutdown_event.set)
|
||||
|
||||
# Register signal handlers
|
||||
loop.add_signal_handler(signal.SIGINT, signal_handler, signal.SIGINT, None)
|
||||
loop.add_signal_handler(signal.SIGTERM, signal_handler, signal.SIGTERM, None)
|
||||
|
||||
# prepare runtime dependencies
|
||||
import threading
|
||||
import time
|
||||
import hbdclass
|
||||
from . import hbdclass
|
||||
from . import http as http_mod
|
||||
from . import ws as ws_mod
|
||||
from . import dns as dns_mod
|
||||
from . import notify as notify_mod
|
||||
from . import monitor as monitor_mod
|
||||
|
||||
notify_mod.setup(config)
|
||||
|
||||
def log(host, m, service=None):
|
||||
ts = time.time()
|
||||
s = f"{time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(ts))} {host or ''} {m}"
|
||||
msgs.append(s)
|
||||
logger.info(s)
|
||||
|
||||
email = notify_mod.email
|
||||
pushmsg = notify_mod.pushmsg_from_config
|
||||
msg_to_websockets = ws_mod.broadcast
|
||||
|
||||
# UDP server endpoint (handler wired to handle_datagram with context)
|
||||
bind_addr = ("0.0.0.0", config.get("hb_port", 50003))
|
||||
@@ -46,7 +96,6 @@ async def _run_async(config):
|
||||
email=email,
|
||||
pushmsg=pushmsg,
|
||||
msg_to_websockets=msg_to_websockets,
|
||||
msgs=msgs,
|
||||
DEBUG=config.get("debug", 0),
|
||||
verbose=config.get("verbose", False),
|
||||
)
|
||||
@@ -57,32 +106,37 @@ async def _run_async(config):
|
||||
local_addr=bind_addr,
|
||||
)
|
||||
|
||||
# HTTP server (runs in its own thread)
|
||||
# HTTP server (asyncio-based via aiohttp)
|
||||
try:
|
||||
handler_cls = http_mod.make_handler_class(
|
||||
config=config,
|
||||
hbdclass=hbdclass,
|
||||
msgs_getter=lambda: msgs,
|
||||
log=log,
|
||||
email=email,
|
||||
pushmsg=pushmsg,
|
||||
msg_to_websockets=msg_to_websockets,
|
||||
tcss=None,
|
||||
DEBUG=config.get("debug", 0),
|
||||
verbose=config.get("verbose", False),
|
||||
get_now=lambda: time.time(),
|
||||
VER="",
|
||||
http_task = asyncio.create_task(
|
||||
http_mod.start(
|
||||
host=config.get("hbd_host", ""),
|
||||
port=config.get("hbd_port", 50004),
|
||||
config=config,
|
||||
hbdclass=hbdclass,
|
||||
msgs_getter=lambda: msgs,
|
||||
log=log,
|
||||
email=email,
|
||||
pushmsg=pushmsg,
|
||||
msg_to_websockets=msg_to_websockets,
|
||||
tcss=None,
|
||||
DEBUG=config.get("debug", 0),
|
||||
verbose=config.get("verbose", False),
|
||||
get_now=lambda: time.time(),
|
||||
VER="",
|
||||
)
|
||||
)
|
||||
serv = http_mod.HttpServer((config.get("hbd_host", ""), config.get("hbd_port", 50004)), handler_cls)
|
||||
http_thread = threading.Thread(target=serv.serve_forever, daemon=True)
|
||||
http_thread.start()
|
||||
logger.info("HTTP server started on %s:%s", config.get("hbd_host", ""), config.get("hbd_port", 50004))
|
||||
except Exception as e:
|
||||
logger.exception("failed to start HTTP server: %s", e)
|
||||
|
||||
# start dns update thread
|
||||
dns_mod.start_dns_thread(hbdclass, config, log=log, email=email)
|
||||
logger.info("dns update thread started")
|
||||
# start dns update worker (async)
|
||||
dns_task = None
|
||||
try:
|
||||
dns_task = dns_mod.start_dns_worker(hbdclass, config, log=log, email=email, loop=loop)
|
||||
logger.info("dns update worker started")
|
||||
except Exception as e:
|
||||
logger.exception("dns worker failed to start: %s", e)
|
||||
|
||||
# Start the websocket servers as a background task
|
||||
try:
|
||||
@@ -101,28 +155,168 @@ async def _run_async(config):
|
||||
except Exception as e:
|
||||
logger.exception("websocket server failed to start: %s", e)
|
||||
|
||||
# Start the monitor thread as a background task
|
||||
try:
|
||||
# run forever
|
||||
await asyncio.Future()
|
||||
finally:
|
||||
transport.close()
|
||||
try:
|
||||
serv.shutdown()
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
ws_task.cancel()
|
||||
except Exception:
|
||||
pass
|
||||
monitor_task = asyncio.create_task(
|
||||
monitor_mod.start(
|
||||
config=config,
|
||||
hbdclass=hbdclass,
|
||||
log=log,
|
||||
email=email,
|
||||
pushmsg=pushmsg,
|
||||
msg_to_websockets=msg_to_websockets,
|
||||
)
|
||||
)
|
||||
logger.info("Monitor task started")
|
||||
except Exception as e:
|
||||
logger.exception("monitor task failed to start: %s", e)
|
||||
|
||||
try:
|
||||
# run forever until shutdown event is set
|
||||
await shutdown_event.wait()
|
||||
logger.info("Shutdown signal received, stopping services...")
|
||||
except Exception as e:
|
||||
logger.exception("Error in main loop: %s", e)
|
||||
finally:
|
||||
# Cancel all running tasks
|
||||
logger.info("Cancelling tasks...")
|
||||
try:
|
||||
transport.close()
|
||||
except Exception as e:
|
||||
logger.warning("Error closing UDP transport: %s", e)
|
||||
|
||||
tasks_to_cancel = [http_task, ws_task, monitor_task]
|
||||
for task in tasks_to_cancel:
|
||||
if task:
|
||||
try:
|
||||
task.cancel()
|
||||
logger.debug("Cancelled task: %s", task)
|
||||
except Exception as e:
|
||||
logger.warning("Error cancelling task: %s", e)
|
||||
|
||||
# Wait for tasks to finish cancellation with timeout
|
||||
remaining_tasks = [t for t in tasks_to_cancel if t]
|
||||
if remaining_tasks:
|
||||
try:
|
||||
await asyncio.wait_for(asyncio.gather(*remaining_tasks, return_exceptions=True), timeout=2.0)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("Timeout waiting for tasks to cancel")
|
||||
except Exception as e:
|
||||
logger.debug("Exception during task cancellation: %s", e)
|
||||
|
||||
# Signal DNS worker to exit and await it
|
||||
try:
|
||||
if 'dns_task' in locals() and dns_task:
|
||||
try:
|
||||
hbdclass.Host.dnsQ.put(None)
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
await asyncio.wait_for(dns_task, timeout=2.0)
|
||||
logger.info("DNS worker finished")
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("Timeout waiting for DNS worker to finish")
|
||||
dns_task.cancel()
|
||||
except asyncio.CancelledError:
|
||||
logger.info("DNS worker was cancelled")
|
||||
except Exception as e:
|
||||
logger.warning("Error awaiting DNS worker: %s", e)
|
||||
finally:
|
||||
# Clear queue bridge to release any held references
|
||||
hbdclass.Host.dnsQ = None
|
||||
except Exception as e:
|
||||
logger.warning("Error stopping DNS worker: %s", e)
|
||||
|
||||
logger.info("All tasks cancelled")
|
||||
|
||||
|
||||
def load_pickled_hosts(config, hbdclass):
|
||||
"""Load pickled hosts from file, if available."""
|
||||
global lastfm, msgs
|
||||
import os
|
||||
import pickle
|
||||
|
||||
pickfile = config.get("pickfile", "hbd.pickle")
|
||||
dyndnshosts = config.get("dyndnshosts", [])
|
||||
watchhosts = config.get("watchhosts", [])
|
||||
drophosts = config.get("drophosts", [])
|
||||
if 1 and os.path.exists(pickfile):
|
||||
if config.get("verbose", False):
|
||||
logger.info("opening pickls %s", pickfile)
|
||||
pickf = open(pickfile, "rb")
|
||||
pick = pickle.Unpickler(pickf)
|
||||
try:
|
||||
hbdclass.Host.hosts = pick.load()
|
||||
msgs = pick.load()
|
||||
try:
|
||||
lastfm = pick.load()
|
||||
except:
|
||||
lastfm = ["", "", ""]
|
||||
pickf.close()
|
||||
except Exception as e:
|
||||
print(("load pickled failed: %s" % e))
|
||||
os.unlink(pickfile)
|
||||
hbdclass.Connection.htab = {}
|
||||
for h in list(hbdclass.Host.hosts.keys()):
|
||||
hbdclass.Host.hosts[h].dyn = h in dyndnshosts
|
||||
hbdclass.Host.hosts[h].watched = h in watchhosts
|
||||
hbdclass.Host.hosts[h].fixup()
|
||||
for h in drophosts:
|
||||
if h in hbdclass.Host.hosts:
|
||||
del hbdclass.Host.hosts[h]
|
||||
if config.get("verbose", False):
|
||||
logger.info("%s pickled hosts loaded", len(hbdclass.Host.hosts))
|
||||
else:
|
||||
if config.get("verbose", False):
|
||||
logger.info("no pickled data")
|
||||
|
||||
def run(config):
|
||||
"""Start the hbd service (blocking).
|
||||
|
||||
This is a thin wrapper around asyncio.run to host the async services.
|
||||
Manually manages the event loop to ensure clean shutdown.
|
||||
"""
|
||||
global logf
|
||||
import os
|
||||
import threading
|
||||
import time as time_module
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG if config.get("debug", 0) > 0 else logging.INFO)
|
||||
load_pickled_hosts(config, hbdclass)
|
||||
|
||||
logf = initlog(logfile=config.get("logfile", "messages.log"))
|
||||
log(None, f"hbd version {__version__} starting up")
|
||||
|
||||
# Create and set the event loop manually
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
asyncio.run(_run_async(config))
|
||||
loop.run_until_complete(_run_async(config))
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Shutting down (KeyboardInterrupt)")
|
||||
logger.info("Received KeyboardInterrupt, shutting down...")
|
||||
except Exception as e:
|
||||
logger.exception("Unhandled exception in main: %s", e)
|
||||
finally:
|
||||
cleanup_function(config)
|
||||
logger.info("hbd shutdown complete")
|
||||
if logf and logf != sys.stderr:
|
||||
try:
|
||||
logf.close()
|
||||
except Exception:
|
||||
pass
|
||||
# Explicitly close the loop
|
||||
try:
|
||||
# Cancel all remaining tasks
|
||||
pending = asyncio.all_tasks(loop)
|
||||
for task in pending:
|
||||
task.cancel()
|
||||
# Run one more cycle to process cancellations
|
||||
if pending:
|
||||
loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
# Exit
|
||||
os._exit(0)
|
||||
|
||||
Reference in New Issue
Block a user