support SSL ws session

This commit is contained in:
2026-02-08 14:32:12 -05:00
parent 535b839bfc
commit 4c53b7cec9
12 changed files with 93 additions and 71 deletions
+8 -2
View File
@@ -30,8 +30,14 @@ DEFAULTS = {
"smtpserver": "smtp.fastmail.com",
"smtpuser": "andreas@wrede.ca",
"smtppassword": "pvtvefyp5gbhnch2",
"smtpport": 587
"smtpport": 587,
"toemail": ["aew.hbd.notify@wrede.ca"],
"fromemail": "aew.hbd@wrede.ca",
"ws_port": 50005,
"wss_port": None,
"cert_path": "/usr/local/etc/ssl/",
"wss_pem": "fullchain.pem",
"wss_key": "privkey.pem"
}
+5 -5
View File
@@ -54,7 +54,7 @@ def nsupdate(hostname: str, newip: str, dyndomain: str, nsupdate_bin: str = "/us
return out
async def dns_update_worker(hbdclass, cfg: dict, async_queue=None, log: Optional[callable] = None, email: Optional[callable] = None, loop: Optional[asyncio.AbstractEventLoop] = None):
async def dns_update_worker(hbdclass, cfg: dict, async_queue=None, log: Optional[callable] = None, pushmsg: Optional[callable] = None, loop: Optional[asyncio.AbstractEventLoop] = None):
"""Pure async DNS worker that processes updates from asyncio.Queue.
Exits when it receives a None sentinel.
@@ -99,9 +99,9 @@ async def dns_update_worker(hbdclass, cfg: dict, async_queue=None, log: Optional
err = await loop.run_in_executor(None, nsupdate, name, addr, dyndomain, cfg.get("nsupdate_bin", "/usr/local/bin/nsupdate"), cfg.get("rndc_key", "/etc/dhcpc/rndc-key"))
if err:
m += f", DNS update failed: {err}"
if email:
if pushmsg:
try:
await loop.run_in_executor(None, email, "error: nsupdate failed", f"{name}.dy.{dyndomain}: {m}")
await loop.run_in_executor(None, pushmsg, "error: nsupdate failed", f"{name}.dy.{dyndomain}: {m}")
except Exception:
pass
else:
@@ -125,7 +125,7 @@ async def dns_update_worker(hbdclass, cfg: dict, async_queue=None, log: Optional
pass
def start_dns_worker(hbdclass, cfg: dict, log: Optional[callable] = None, email: Optional[callable] = None, loop: Optional[asyncio.AbstractEventLoop] = None):
def start_dns_worker(hbdclass, cfg: dict, log: Optional[callable] = None, pushmsg: Optional[callable] = None, loop: Optional[asyncio.AbstractEventLoop] = None):
"""Start the async DNS worker and return the Task.
Replaces Host.dnsQ with an asyncio.Queue wrapped in a thread-safe bridge
@@ -167,5 +167,5 @@ def start_dns_worker(hbdclass, cfg: dict, log: Optional[callable] = None, email:
bridge = _QueueBridge(loop, async_q)
hbdclass.Host.dnsQ = bridge
task = loop.create_task(dns_update_worker(hbdclass, cfg, async_queue=async_q, log=log, email=email, loop=loop))
task = loop.create_task(dns_update_worker(hbdclass, cfg, async_queue=async_q, log=log, pushmsg=pushmsg, loop=loop))
return task
+14 -1
View File
@@ -139,7 +139,10 @@ async def start(
host = config.get("hb_host", "localhost")
extra_scripts = config.get("http_extra_scripts", "")
host = request.host.split(":")[0]
heartbeat_ws_url = f"ws://{host}:{config.get('ws_port', 50005)}/hbd"
if config.get("wss_port"):
heartbeat_ws_url = f"wss://{host}:{config['wss_port']}/hbd"
else:
heartbeat_ws_url = f"ws://{host}:{config.get('ws_port', 50005)}/hbd"
tmpl = env.get_template("live.html")
body = tmpl.render(
title="Heartbeat",
@@ -158,6 +161,7 @@ async def start(
URL form: /static/<path>
"""
p = request.match_info.get("path", "")
logger.debug("static file requested: %s", p)
base = os.path.abspath(os.path.join(os.path.dirname(__file__), "static"))
# normalize and prevent directory traversal
target = os.path.abspath(os.path.normpath(os.path.join(base, p)))
@@ -168,6 +172,14 @@ async def start(
logger.info("serving static file: %s", target)
return web.FileResponse(path=target)
async def favicon(request):
"""Serve favicon.ico from the package static directory."""
base = os.path.abspath(os.path.join(os.path.dirname(__file__), "static/images"))
target = os.path.join(base, "favicon.ico")
if not os.path.exists(target) or not os.path.isfile(target):
return web.Response(status=404, text="Not Found")
return web.FileResponse(path=target)
app = web.Application()
app.add_routes(
[
@@ -181,6 +193,7 @@ async def start(
web.get("/r", restart),
web.get("/live", live),
web.get("/static/{path:.*}", static),
web.get("/favicon.ico", favicon),
]
)
+2 -4
View File
@@ -9,7 +9,7 @@ from typing import Optional
from . import hbdclass
DROPOVERDUE = 7 * 24 * 3600
def checkoverdue(config: dict, hbdclass, log: callable, email: callable, pushmsg: callable, msg_to_websockets: callable):
def checkoverdue(config: dict, hbdclass, log: callable, pushmsg: callable, msg_to_websockets: callable):
now = time.time()
for h in list(hbdclass.Host.hosts.keys()):
pmsg = []
@@ -27,7 +27,6 @@ def checkoverdue(config: dict, hbdclass, log: callable, email: callable, pushmsg
conn.newstate(hbdclass.Connection.UNKNOWN, conn.lastbeat)
if pmsg != []:
if h in config.get("watchhosts", []):
email("overdue", "%s overdue" % " and ".join(pmsg))
pushmsg("%s %s overdue" % (h, " and ".join(pmsg)))
log(h, "%s overdue" % " and ".join(pmsg))
msg_to_websockets("host", hbdclass.Host.hosts[h].stateinfo())
@@ -36,11 +35,10 @@ async def start(
config: dict,
hbdclass: callable,
log=None,
email=None,
pushmsg=None,
msg_to_websockets=None,
):
""" start a monitor loop that checks for overdue hosts every minute """
while True:
await asyncio.sleep(15) # 15 seconds between checks
checkoverdue(config, hbdclass, log, email, pushmsg, msg_to_websockets)
checkoverdue(config, hbdclass, log, pushmsg, msg_to_websockets)
+13 -7
View File
@@ -21,21 +21,21 @@ def setup(cfg: dict):
_config = dict(cfg)
def send_email(aemail, smtpserver, sender, subject, body, debug=0):
def send_email(toaddrs, smtpserver, sender, subject, body, debug=0):
"""Send a plain email via SMTP. Returns True on success."""
try:
smtpport = _config.get("smtpport", 587)
server = smtplib.SMTP(smtpserver, smtpport))
server = smtplib.SMTP(smtpserver, smtpport)
if debug > 0:
server.set_debuglevel(1)
if smtpport == 587:
server.starttls()
server.ehlo()
smtpuser = _config.get("smtpuser", None)
smtppassword = _config.get("smtp_password", None)
smtppassword = _config.get("smtppassword", None)
if smtpuser and smtppassword:
server.login(smtpuser, smtppassword)
server.sendmail(sender, aemail, body)
server.sendmail(sender, toaddrs, body)
except Exception as e:
logger.warning("email send failed: %s", e)
try:
@@ -56,9 +56,12 @@ def email(subject: str, msg: str, debug: int = 0) -> bool:
Uses module-level configuration to supply recipient list, smtp server
and sender address.
"""
toaddrs = _config.get("AEMAIL") or _config.get("aemail") or _config.get("email_to") or []
fromemail = _config.get("fromemail") or _config.get("sender") or f"aew.heartbeat@{_config.get('domain','local') }"
smtpserver = _config.get("smtpserver") or _config.get("SMTPSERVER", "localhost")
toaddrs = _config.get("toemail")
fromemail = _config.get("fromemail")
smtpserver = _config.get("smtpserver")
if not toaddrs or not fromemail or not smtpserver:
logger.warning("email config incomplete: toemail=%s, fromemail=%s, smtpserver=%s", toaddrs, fromemail, smtpserver)
return False
date = time.strftime("%a, %d %b %Y %H:%M:%S %z", time.localtime())
body = "To: %s\nFrom: %s\nSubject: %s\nDate: %s\n\n%s" % (
toaddrs[0] if toaddrs else "",
@@ -153,6 +156,9 @@ def pushmsg(cfg: dict, msg: str, debug: int = 0):
if p in ("all", "signal"):
ok = pushsignal(cfg.get("signal_cli", "/usr/local/bin/signal-cli"), cfg.get("signal_user", ""), cfg.get("signal_recipient", ""), msg, debug=debug)
results["signal"] = ok
if p in ("all", "email"):
ok = email("Heartbeat notification", msg, debug=debug)
results["email"] = ok
logger.debug("push results: %s", results)
return results
+19 -7
View File
@@ -5,6 +5,8 @@ import socket
import time
import signal
import sys
import ssl
import pathlib
from . import __version__
from . import udp
@@ -82,7 +84,6 @@ async def _run_async(config):
notify_mod.setup(config)
email = notify_mod.email
pushmsg = notify_mod.pushmsg_from_config
sock = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)
@@ -106,7 +107,6 @@ async def _run_async(config):
config=config,
hbdclass=hbdclass,
log=log,
email=email,
pushmsg=pushmsg,
msg_to_websockets=msg_to_websockets,
DEBUG=config.get("debug", 0),
@@ -129,7 +129,6 @@ async def _run_async(config):
hbdclass=hbdclass,
msgs_getter=lambda: msgs,
log=log,
email=email,
pushmsg=pushmsg,
msg_to_websockets=msg_to_websockets,
tcss=None,
@@ -146,19 +145,33 @@ async def _run_async(config):
# start dns update worker (async)
dns_task = None
try:
dns_task = dns_mod.start_dns_worker(hbdclass, config, log=log, email=email, loop=loop)
dns_task = dns_mod.start_dns_worker(hbdclass, config, log=log, pushmsg=pushmsg, loop=loop)
logger.info("dns update worker started")
except Exception as e:
logger.exception("dns worker failed to start: %s", e)
# Start the websocket servers as a background task
if config.get("wss_port", None):
ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
ssl_path = config.get("cert_path", "")
wss_pem = ssl_path + config.get("wss_pem", "")
wss_key = ssl_path + config.get("wss_key", "")
try:
ssl_context.load_cert_chain(wss_pem, keyfile=wss_key)
except FileNotFoundError:
logger.error("error: missing SSL keys %s or %s", wss_pem, wss_key)
sys.exit(1)
logger.info("Starting secure WebSocket server on port %s with cert %s", config.get("wss_port", None), wss_pem)
else:
ssl_context = None
try:
ws_task = asyncio.create_task(
ws_mod.start(
host=config.get("hbd_host", ""),
ws_port=config.get("ws_port", 50005),
ws_port=config.get("ws_port", None),
wss_port=config.get("wss_port", None),
ssl_context=None,
ssl_context=ssl_context,
get_hosts=lambda: [hbdclass.Host.hosts[h].stateinfo() for h in sorted(hbdclass.Host.hosts)],
get_msgs=lambda: msgs,
verbose=config.get("verbose", False),
@@ -175,7 +188,6 @@ async def _run_async(config):
config=config,
hbdclass=hbdclass,
log=log,
email=email,
pushmsg=pushmsg,
msg_to_websockets=msg_to_websockets,
)
+2 -2
View File
@@ -185,11 +185,11 @@
function WS_Connect() {
if ("WebSocket" in window) {
//N.B: subprotocol field causes chrome to error 1006
var ws_hbd = new WebSocket("{{heartbeat_ws_url}}" /*, "hdb" */);
var ws_hbd = new WebSocket("{{heartbeat_ws_url}}", /* "hdb" */ );
ws_hbd.onopen = function () {
// Web Socket is connected, send data using send()
console.log("ws connect");
console.log("ws connect {{heartbeat_ws_url}}");
// Hide modal window if visible
var modal = document.getElementById("connectionModal");
if (modal) {
-14
View File
@@ -68,7 +68,6 @@ def handle_datagram(msg: dict, addr, transport, ctx: dict):
- config: dict of configuration
- hbdclass: module providing Host/Connection classes
- log: callable(loghost, message)
- email: callable(subject, message)
- pushmsg: callable(message)
- msg_to_websockets: callable(typ, data)
- DEBUG, verbose
@@ -79,7 +78,6 @@ def handle_datagram(msg: dict, addr, transport, ctx: dict):
cfg = ctx.get("config", {})
hbdcls = ctx.get("hbdclass")
log = ctx.get("log")
email = ctx.get("email")
pushmsg = ctx.get("pushmsg")
msg_to_websockets = ctx.get("msg_to_websockets")
DEBUG = ctx.get("DEBUG", 0)
@@ -122,8 +120,6 @@ def handle_datagram(msg: dict, addr, transport, ctx: dict):
if log:
log(uname, res)
if uname in cfg.get("watchhosts", []):
if email:
email("address change", "%s %s" % (host.name, res))
if pushmsg:
pushmsg("%s %s" % (host.name, res))
@@ -138,16 +134,12 @@ def handle_datagram(msg: dict, addr, transport, ctx: dict):
log(uname, "booted")
if uname in cfg.get("watchhosts", []):
m = "%s booted" % (host.name)
if email:
email("booted", m)
if pushmsg:
pushmsg(m)
if message:
if log:
log(uname, "msg: %s" % message, service=service)
if uname in cfg.get("watchhosts", []):
if email:
email("msg", message)
if pushmsg:
pushmsg(message)
@@ -158,8 +150,6 @@ def handle_datagram(msg: dict, addr, transport, ctx: dict):
if log:
log(uname, m)
if uname in cfg.get("watchhosts", []):
if email:
email("%s back" % conn.afam, uname)
if pushmsg:
pushmsg("%s %s is back" % (uname, conn.afam))
@@ -172,8 +162,6 @@ def handle_datagram(msg: dict, addr, transport, ctx: dict):
if log:
log(uname, "%s shutdown" % conn.afam)
if uname in cfg.get("watchhosts", []):
if email:
email("shutdown", "%s %s shutdown" % (uname, conn.afam))
if pushmsg:
pushmsg("%s %s shutdown" % (uname, conn.afam))
conn.newstate(hbdcls.Connection.DOWN, now)
@@ -197,8 +185,6 @@ def handle_datagram(msg: dict, addr, transport, ctx: dict):
while len(host.cmds):
op, rmsg = host.cmds[0]
if op == "CMD":
if email:
email("%s cmd exec" % uname, "command '%s' sent" % rmsg)
del host.cmds[0]
if log:
log(uname, "command sent")
+12 -25
View File
@@ -20,11 +20,9 @@ _verbose = False
async def _handler(websocket, path=None):
# Some versions of the websockets library call handler(connection) only;
# accept optional path and fall back to websocket.path when missing.
global _connections
_connections.add(websocket)
remote_address = getattr(websocket, "remote_address", None)
remote_address = websocket.remote_address
if path is None:
path = getattr(websocket, "path", None)
if _verbose:
@@ -76,36 +74,25 @@ async def start(host: str, ws_port: int, wss_port: Optional[int] = None, ssl_con
servers = []
# plain WebSocket
ws_server = websockets.serve(_handler, host, ws_port) #, subprotocols=["hbd"])
websockets_logger = logging.getLogger("websockets.server")
websockets_logger.setLevel(logging.INFO)
websockets_logger.setLevel(logging.DEBUG if verbose else logging.INFO)
# regular WebSocket
ws_server = websockets.serve(_handler, host, ws_port) #, subprotocols=["hbd"])
servers.append(ws_server)
# secure WebSocket (optional)
if wss_port and ssl_context:
wss_server = websockets.serve(_handler, host, wss_port, ssl=ssl_context) #, subprotocols=["hbd"])
wss_server = websockets.serve(_handler, host, wss_port, ssl=ssl_context ) #, subprotocols=["hbd"])
servers.append(wss_server)
# await starting of all servers
try:
for srv in servers:
await srv
for srv in servers:
await srv
if _verbose:
logger.info("WebSocket server started on port %s (wss %s)", ws_port, wss_port)
if _verbose:
logger.info("WebSocket server(s) started on port %s (wss %s)", ws_port, wss_port)
# block forever (until loop is stopped or cancelled)
await asyncio.Future()
except asyncio.CancelledError:
logger.info("WebSocket server shutting down...")
# Close all active connections
for conn in list(_connections):
try:
await conn.close()
except Exception:
pass
_connections.clear()
raise
# block forever (until loop is stopped or cancelled)
await asyncio.Future()
def broadcast(typ: str, data) -> bool:
@@ -115,7 +102,7 @@ def broadcast(typ: str, data) -> bool:
connected websockets. Returns False if server was not running.
"""
global _loop
if not _loop:
return False
jmsg = json.dumps({"type": typ, "data": data})