#!/usr/bin/env python3
# $Id: hbc,v 1.9 2012/03/29 02:08:36 andreas Exp $
# NEW
import ast
import codecs
import errno
import getopt
import os
import select
import shutil
import signal
import socket
import string
import subprocess
import sys
import syslog
import time
import traceback
import zlib
from hashlib import md5


# Defaults and protocol constants.
PORT = 50003  # default server port; the config key hb_port overrides it
INTERVAL = 10  # default seconds between reports; config key interval overrides it
REOPENC = 6  # Conn.sendto() closes and reopens its socket every REOPENC sends
PIDFILE = "/tmp/hbc.pid"  # NOTE(review): not referenced anywhere in this file
VER = 6  # protocol version, sent as "ver" in every message
MAXRECV = 32767  # buffer size passed to recvfrom()

# Mutable global state shared by process(), cleanup() and the signal handler.
running = True  # main loop runs while True; cleared by cleanup() or process()
dorestart = False  # set by process() after a successful self-update
warned1 = False  # makes Conn.sendto() log only the first send error


def log(msg):
    """Write msg to syslog (priority LOG_ERR) in daemon mode, else print it."""
    if not fdaemon:
        print(msg)
        return
    syslog.syslog(syslog.LOG_ERR, msg)


def handler(signum, frame):
    """Signal handler: SIGTERM triggers cleanup(); any other signal is ignored."""
    if signum != signal.SIGTERM:
        return
    cleanup()


class NullDevice:
    """File-like sink whose write() discards its argument."""

    def write(self, s):
        # Intentionally does nothing with s.
        return None


class Conn:
    """One UDP (SOCK_DGRAM) channel to a single resolved server address.

    Counts the ACKs received and keeps the last (up to) 10 round-trip
    samples, in milliseconds, for the periodic report.
    """

    def __init__(self, conId, addr, port, af):
        self.conId = conId
        self.addr = addr
        self.port = port
        self.af = af

        self.ackcount = 0  # number of ACKs received
        self.lastack = 0  # timestamp of the last ACK (remote or local clock)
        self.send = 0  # successful sends; drives the periodic socket reopen
        self.lastsend = 0  # time() the last msg was sent
        self.rtts = [0]  # recent RTT samples in ms (at most 10 kept)
        self.sock = None

    def __str__(self):
        return "Con(%s, %s %s)" % (self.addr, self.port, self.af)

    def open(self):
        """Create a fresh datagram socket of this connection's family."""
        self.sock = socket.socket(self.af, socket.SOCK_DGRAM)
        self.sock.setsockopt(
            socket.SOL_SOCKET,
            socket.SO_REUSEADDR,
            self.sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR) | 1,
        )

    def sendto(self, msg, ID="HTB"):  # default ID is HearTBeat
        """Stamp ``msg`` and send it to the server.

        ``msg`` is modified in place: "name", "id", "ver" and "time" are
        added before it is encoded with dicttos() under ``ID``. The socket
        is closed and reopened every REOPENC sends. On a socket error the
        socket is closed, the error is logged (only the first time, via
        the global warned1), and the send is not counted.
        """
        global warned1

        if self.send % REOPENC == 0:
            self.close()
        if not self.sock:
            self.open()
        msg["name"] = shortname(iam)
        msg["id"] = self.conId
        msg["ver"] = VER
        msg["time"] = time.time()
        m = dicttos(ID, msg)  # always compress
        if verbose:
            log("conn.send('%s', (%s:%s) %s)" % (msg, self.addr, self.port, len(m)))
        try:
            self.sock.sendto(m, (self.addr, self.port))
        except socket.error as e:
            if not warned1:
                log("socket error: %s %s:%s" % (e, self.addr, self.port))
            warned1 = True
            self.close()
            return
        self.send += 1
        self.lastsend = time.time()

    def ack(self, msgDict, now):
        """Record an ACK and append an RTT sample (ms) to ``self.rtts``.

        If the ACK carries a numeric "time" stamp, that remote timestamp is
        used and the delay since the last send is doubled to estimate the
        round trip. This is only meaningful if the two clocks roughly agree.
        Otherwise the local receive time ``now`` is used as-is.
        """
        remote = msgDict.get("time")
        # "time" comes straight off the network. Anything other than a plain
        # number (missing, None, an unparsed string) would make the
        # subtraction below raise and take down the receive loop.
        if isinstance(remote, (int, float)) and not isinstance(remote, bool):
            self.lastack = remote
            mul = 2
        else:
            self.lastack = now
            mul = 1
        rtt = (self.lastack - self.lastsend) * mul
        if verbose:
            log("ack RTT: %0.1f ms (now %s)" % (rtt * 1000.0, now))
        self.rtts.append(rtt * 1000.0)
        if len(self.rtts) > 10:
            del self.rtts[0]
        self.ackcount += 1

    def close(self):
        """Close the socket if open; the next sendto() reopens it."""
        if self.sock:
            self.sock.close()
        self.sock = None


def shortname(name):
    """Return the first dot-separated label of a host name ("a.b.c" -> "a")."""
    head, _sep, _rest = name.partition(".")
    return head


def dicttos(ID, d):
    """Serialize dict ``d`` into the compressed wire format.

    Every item becomes "key=value". Float values are written with five
    decimals. Items are joined with ";", the text is zlib-compressed at
    level 6 and prefixed with "!<ID>:". stodict() parses this format back.
    """

    def render(key, val):
        if type(val) is float:
            return "%s=%0.5f" % (key, val)
        return "%s=%s" % (key, val)

    payload = ";".join(render(key, val) for key, val in d.items())
    header = "!" + ID + ":"
    return header.encode() + zlib.compress(payload.encode(), 6)


def stodict(msg):
    """Decode one wire message (bytes or str) into a dict.

    Two framings are accepted:
      * compressed: "!" + ID + ":" + zlib(payload)  -- what dicttos() emits
      * plain:      ID + ":" + payload

    The header goes into d["ID"]. The payload is "k1=v1;k2=v2;...". A key
    without "=" maps to None. Values that are Python literals (numbers,
    quoted strings, bytes, tuples, True/False/None) are converted with
    ast.literal_eval; anything else is kept as the stripped string.

    Raises on malformed input: ValueError for a missing ":" separator;
    zlib.error or UnicodeDecodeError for bad payloads. The caller treats
    any exception as a parse failure.
    """
    if isinstance(msg, str):
        msg = msg.encode()
    # The ID never contains ":", so the first ":" ends the header even when
    # the compressed body that follows contains ":" bytes.
    head, sep, body = msg.partition(b":")
    if not sep:
        raise ValueError("missing ':' after message ID")
    d = {}
    if head[:1] == b"!":
        pk = zlib.decompress(body).decode()
        d["ID"] = head[1:].decode()
    else:
        pk = body.decode()
        d["ID"] = head.decode()
    for item in pk.split(";"):
        key, eq, raw = item.partition("=")
        key = key.strip()
        if not eq:
            d[key] = None
            continue
        value = raw.strip()
        # Values arrive from the network: never eval() them. literal_eval
        # only accepts literals, so e.g. "cmd=id" stays the string "id"
        # instead of becoming the builtin function id.
        try:
            value = ast.literal_eval(value)
        except Exception:
            pass
        d[key] = value
    if verbose:
        print("msg is %s" % d)
    return d


def XXstodict(msg):
    """Older variant of stodict(); not called anywhere in this file.

    Accepts bytes or str. Returns None when there is no ":" after the ID.
    Otherwise returns a dict with "ID" plus the "k=v;..." payload items.
    A "!" prefix on the ID marks a zlib-compressed payload. Only values
    starting with a digit are converted (with ast.literal_eval); others
    stay strings, and keys without "=" map to None.
    """
    if isinstance(msg, str):
        msg = msg.encode()
    head, sep, body = msg.partition(b":")
    if not sep:
        return None
    d = {}
    # Work on bytes for framing/decompression, then decode once, so neither
    # zlib nor str.split ever sees the wrong type.
    if head[:1] == b"!":  # compressed
        pk = zlib.decompress(body).decode()
        d["ID"] = head[1:].decode()
    else:
        pk = body.decode()
        d["ID"] = head.decode()
    for item in pk.split(";"):
        vr = item.split("=", 1)
        k = vr[0].strip()
        if len(vr) == 1:
            d[k] = None
            continue
        v = vr[1].strip()
        if v[:1].isdigit():
            try:
                v = ast.literal_eval(v)
            except Exception:
                pass  # not a literal after all: keep the string
        d[k] = v
    return d


def syslogtrace(note):
    """Report the exception currently being handled, tagged with ``note``.

    The text "<note> hbc died: " plus the formatted traceback is passed to
    log(). Then every line is sent to syslog (LOG_ERR) with a " tb: "
    prefix, and the whole text is printed when verbose.
    """
    report = "%s hbc died: \n%s" % (note, traceback.format_exc())
    log(report)
    report_lines = report.split("\n")
    for line in report_lines:
        syslog.syslog(syslog.LOG_ERR, " tb: %s" % line)
    if verbose:
        print(report)


# Next connection id handed out by createConnections(); each resolved
# address gets its own id, which is also its key in conns.
conId = 1


def createConnections(hosts):
    """Resolve each host in ``hosts`` and add one Conn per address to conns.

    Every resolved address gets the next id from the global ``conId``.
    A host that cannot be resolved is skipped (reported when verbose), so
    the remaining hosts still get connections. Address families other than
    IPv4/IPv6 are skipped with a log message.
    """
    global conId
    for host in hosts:
        if verbose:
            log("createConnections for %s" % host)
        try:
            rs = socket.getaddrinfo(host, hb_port, 0, 0, socket.SOL_UDP)
        except socket.gaierror as e:
            # Skip only this host; one bad name must not prevent the
            # remaining hosts from being connected.
            if verbose:
                log("createConnections: cannot resolve %s: %s" % (host, e))
            continue
        for r in rs:
            if verbose:
                log("address %s" % str(r))
            # getaddrinfo reports this platform's own family constants, so
            # compare against socket.AF_* rather than per-OS numeric values.
            af = r[0]
            if af not in (socket.AF_INET, socket.AF_INET6):
                log("dont know this net type: %s" % af)
                continue

            addr = r[4][0]
            conns[conId] = Conn(conId, addr, hb_port, af)
            if verbose:
                print("cons[%s] = %s" % (conId, str(conns[conId])))
            conId += 1


def doexec(conn, data):
    """Run ``data`` as a shell command and send the result on conns[conn].

    The reply is {"service": "command", "msg": "<status> <output>"}, where
    status is "OK", "CalledProcessError" or "cmd failed: <error>".

    SECURITY: ``data`` comes from a received packet and goes to the shell
    verbatim (shell=True). Any peer whose CMD packet reaches this client
    can run arbitrary commands, so only trusted hosts should be able to
    talk to it.
    """
    try:
        # errors="replace": non-UTF-8 command output must not turn a
        # successful run into "cmd failed".
        ro = subprocess.check_output(
            data, stderr=subprocess.STDOUT, shell=True
        ).decode(errors="replace")
        fail = "OK"
    except subprocess.CalledProcessError as e:
        # Keep the command's own output: it usually explains the failure.
        out = (e.output or b"").decode(errors="replace")
        ro = "%s %s" % (e, out) if out else str(e)
        fail = "CalledProcessError"
    except Exception as e:
        syslogtrace("System")
        ro = "N/A"
        fail = "cmd failed: %s" % e
    msg = {"service": "command", "msg": fail + " " + ro}
    conns[conn].sendto(msg)


def doupdate(conn, msgDict):
    """Apply a self-update carried by an UPD message and acknowledge it.

    msgDict["code"] is base64 of the new source. msgDict["csum"] is the
    expected md5 hex digest checked by doupdateone(). The outcome goes back
    on conns[conn] as {"service": "update", "msg": "OK" | error}.

    NOTE(review): codecs' base64 codec needs a bytes-like "code" value; a
    plain str ends up in the "csum/code missing" branch. Confirm what the
    sender puts on the wire.

    Returns None on success, otherwise the error string.
    """
    fail = None
    try:
        code = codecs.decode(msgDict["code"], "base64").decode()
        csum = msgDict["csum"]
    except Exception as e:
        fail = "csum/code missing: %s" % e
    else:
        fail = doupdateone(code, csum)

    conns[conn].sendto({"service": "update", "msg": fail or "OK"})
    if not fail:
        log("hc updates, fs = %s" % (len(code)))

    return fail


def doupdateone(code, csum):
    """Install ``code`` as the new contents of this script.

    code: complete new source text. csum: expected md5 hex digest of its
    UTF-8 encoding. md5 here only catches corruption; it does not
    authenticate the sender.

    The current script (sys.argv[0]) is first copied to "<script>.sav",
    then overwritten in place. Returns None on success, otherwise an error
    string.
    """
    icsum = md5(code.encode()).hexdigest()
    if icsum != csum:
        return "checksum error"

    fn = sys.argv[0]
    ofn = "%s.sav" % fn
    try:
        shutil.copy2(fn, ofn)
    except Exception as e:
        return "cannot make backup copy: %s" % e

    try:
        # Write UTF-8 explicitly: the checksum above covers the UTF-8 bytes,
        # and the locale's default encoding may differ. The with-block also
        # closes the file when the write fails.
        with open(fn, "w", encoding="utf-8") as fh:
            fh.write(code)
    except Exception as e:
        return "cannot write new code: %s" % e

    return None


def restart():
    """Replace this process with a fresh run of the script (after an update).

    Re-executes sys.argv[0] with the options saved in cmdargs, so
    sys.argv[0] must itself be executable (this file starts with a #!
    line). On success os.execv never returns. If it fails, the error is
    printed and logged and the function returns.
    """
    if verbose:
        print("restart: execv %s %s" % (sys.argv[0], [sys.argv[0]] + cmdargs))
    syslog.syslog(syslog.LOG_ERR, "restart %s" % (sys.argv[0]))
    err = "fallthrough"
    try:
        os.execv(sys.argv[0], [sys.argv[0]] + cmdargs)
    except Exception as e:
        # Python 3 unbinds the "as" name when the except clause ends, so the
        # error is copied out to be reported below.
        err = e
    print("should not be here:", str(err))
    log("restart failed: %s" % err)


def process():
    """Run the heartbeat loop until ``running`` is cleared.

    Every ``interval`` seconds a report {"acks": n, "rtt": last_ms} is sent
    on each Conn. In between, the loop select()s on every open Conn socket
    and dispatches each received packet by its "ID":

      ACK  -> Conn.ack()
      UPD  -> doupdate(); on success sets dorestart and ends the loop
      CMD  -> doexec() on the "cmd" field
      else -> doexec() on the raw packet (legacy format, marked deprecated)

    The loop ends when ``running`` becomes False: Ctrl-C or a failure in
    select(), a successful update, or cleanup() (e.g. from the SIGTERM
    handler).
    """
    global running, dorestart

    nextReport = time.time()

    while running:
        while time.time() < nextReport:
            ifiles = {}
            conIds = {}
            for conn in conns:
                if conns[conn].sock:
                    ifiles[conns[conn].sock.fileno()] = conns[conn].sock
                    conIds[conns[conn].sock.fileno()] = conn

            sleep = nextReport - time.time()
            if sleep <= 0:
                break
            try:
                r = select.select(list(ifiles.keys()), [], [], sleep)
                # nb: delay from actual packet arrival to select is ca. 105ms!
                now = time.time()
            except KeyboardInterrupt:
                running = False
                break
            except SystemExit:
                log("daemon exit, running was %s" % running)
                if running:
                    running = False
                break
            except Exception:
                # Any other select() failure ends the loop, e.g. EBADF after
                # cleanup() closed the sockets from the SIGTERM handler.
                if running:
                    syslogtrace("select")
                    running = False
                break
            for rfh in r[0]:
                conn = conIds[rfh]
                try:
                    data, addr = ifiles[rfh].recvfrom(MAXRECV)
                except OSError as e:
                    # A failed receive on one socket (e.g. a pending ICMP
                    # error reported on some platforms) must not end the
                    # daemon; Conn.sendto() reopens sockets periodically.
                    print("recvfrom failed on %s: %s" % (conns[conn], e))
                    continue
                if verbose:
                    print("sock.recvfrom: %s (%s) %s" % (addr, len(data), data[:4]))
                try:
                    msgDict = stodict(data)
                except Exception as e:
                    print(
                        "failed to parse incoming data from %s: %s (%s)"
                        % (addr, data, e)
                    )
                    continue

                if verbose:
                    print(
                        "sock.recvfrom: %s (%s) %s"
                        % (addr, len(data), str(msgDict)[:80])
                    )
                if msgDict is None:
                    print("bad backet from %s (%s) %s" % (addr, len(data), data))
                elif msgDict["ID"] == "ACK":
                    conns[conn].ack(msgDict, now)
                elif msgDict["ID"] == "UPD":
                    if doupdate(conn, msgDict) is None:
                        if verbose:
                            print("process: restart after update")
                        dorestart = True
                        break
                elif msgDict["ID"] == "CMD":
                    # A CMD packet without a "cmd" field used to raise
                    # KeyError here and terminate the whole loop.
                    cmd = msgDict.get("cmd")
                    if cmd is None:
                        print("CMD without cmd from %s" % (addr,))
                    else:
                        doexec(conn, cmd)
                else:
                    doexec(conn, data)  # deprecated until no more VER - hbc
            if dorestart:
                running = False
                break
        if not running:
            break
        for conn in conns:
            if not running:
                # SIGTERM arrived mid-report: cleanup() already sent the
                # shutdown notice and closed the sockets; another heartbeat
                # would reopen them.
                break
            msg = {"acks": conns[conn].ackcount, "rtt": conns[conn].rtts[-1]}
            conns[conn].sendto(msg)
            # N.B. Linux (i.e. Rasperry Pi 3) drops the second pkg unless delayed
            time.sleep(0.1)
        if nextReport + interval >= time.time():
            nextReport += interval
        else:
            nextReport = time.time() + interval

    if verbose:
        log("process: done running")


def cleanup():
    """Stop the main loop and tell every server this client is going away.

    Does nothing once ``running`` is already False. process() clears it
    itself on Ctrl-C, select failure or a pending restart, so no shutdown
    notice is sent in those cases. Otherwise it sends
    {"shutdown": 1, "acks": n} on each Conn, closes it, waits one second
    and calls closeall().
    """
    global running
    if running:
        if verbose:
            log("cleanup")
        running = False
        for c in conns.values():
            c.sendto({"shutdown": 1, "acks": c.ackcount})
            c.close()
        time.sleep(1)
        closeall()


def closeall():
    """Close the socket of every Conn in conns (syslog note when verbose)."""
    if verbose:
        syslog.syslog(syslog.LOG_ERR, "closecall")
    for c in conns.values():
        c.close()


def daemonize(
    working_dir="/", stdin="/dev/zero", stdout="/dev/null", stderr="/dev/null"
):
    """
    Does the UNIX double-fork magic, see Stevens' "Advanced Programming in the
    UNIX Environment" for details (ISBN 0201563177)
    http://www.yendor.com/programming/unix/apue/proc/fork2.c

    Both intermediate parents leave via os._exit(). The surviving
    grandchild runs in ``working_dir`` with umask 0, in a new session, with
    file descriptors 0/1/2 redirected to ``stdin``/``stdout``/``stderr``.
    """

    try:
        # first fork
        pid = os.fork()
        if pid > 0:
            # exit from first parent
            os._exit(0)
    except OSError as e:
        sys.stderr.write("fork #1 failed: %d (%s)\n" % (e.errno, e.strerror))
        sys.stderr.flush()
        os._exit(1)

    # decouple from parent environment
    os.chdir(working_dir)
    os.setsid()
    os.umask(0)
    # second fork
    try:
        pid = os.fork()
        if pid > 0:
            # exit from second parent
            os._exit(0)
    except OSError as e:
        sys.stderr.write("fork #2 failed: %d (%s)\n" % (e.errno, e.strerror))
        sys.stderr.flush()
        # os._exit, as for fork #1: skip interpreter cleanup in this
        # intermediate process.
        os._exit(1)

    # redirect the standard file descriptors
    sys.stdout.flush()
    sys.stderr.flush()
    # dup2 onto the fixed fds 0/1/2 (works even if sys.std* were replaced).
    # Closing the temporary file objects afterwards does not affect the
    # duplicated descriptors, and avoids leaking three extra fds.
    with open(stdin, "r") as si, open(stdout, "a+") as so, open(stderr, "a+") as se:
        os.dup2(si.fileno(), 0)
        os.dup2(so.fileno(), 1)
        os.dup2(se.fileno(), 2)


# Command-line state, filled in by the getopt loop.
msgonly = False  # -m: send one message, then exit
helpflag = False  # -h or an option error: print usage and exit
verbose = False  # -v
fdaemon = False  # -d: daemonize; also routes log() to syslog
daemonized = False
optlist = []
args = []  # positional arguments: heartbeat server hosts
msgboot = {}  # fields of the one-shot message sent at startup (-b / -m)
# expanduser("~") uses $HOME when set and falls back to the password
# database otherwise, so a daemon started without HOME (e.g. from an init
# script) does not die with KeyError here.
home = os.path.expanduser("~")
configfile = "%s/.hbrc" % home
cmdargs = []  # options replayed by restart() after a self-update
iam = socket.gethostname()  # name reported to servers (-n / config "name")


try:
    optlist, args = getopt.getopt(sys.argv[1:], "bc:dhm:n:v")
except getopt.GetoptError as err:
    # Unknown option or missing option argument: say which, then show usage.
    print("hbc: %s" % err)
    helpflag = True

# Options that must survive restart() after a self-update are also recorded
# in cmdargs. -b and -m are not added, so they are not replayed.
for o, a in optlist:
    if o == "-b":
        msgboot["boot"] = 1
    elif o == "-c":
        configfile = a
        cmdargs += [o, a]
    elif o == "-d":
        fdaemon = True
        cmdargs += [o]
    elif o == "-h":
        helpflag = True
    elif o == "-m":
        msgboot["service"] = "service"
        msgboot["msg"] = a
        msgonly = True
    elif o == "-n":
        iam = a
        cmdargs += [o, a]
    elif o == "-v":
        verbose = True
        cmdargs += [o]


# Hosts given on the command line are replayed on restart as well.
cmdargs += args
if verbose:
    print("cmdargs for restart are %s" % cmdargs)

if helpflag:
    print("hbc HeartBeatClient")
    print("usage: hbc [-bdhv] [-c configfile] [-m msg] [-n name] [host1 [..]]")
    print()
    print("\t-b\tindicate machine boot")
    print("\t-c configfile")
    print("\t-d daemonize")
    print("\t-h this help")
    print("\t-m send a message")
    print("\t-n name reported to the server (default: host name)")
    print("\t-v verbose")
    print()
    # NOTE(review): hbc itself only reads hb_hosts, hb_port, interval and
    # name from this file; the remaining keys are presumably for the server.
    print(
        """ config file can contain
hb_hosts=('host1', 'host2', ...)
hb_port=50003
interval=20
name='myname'
logfile=...
logfmt={|test|msg}
grace=SECONDS
reportstrict={True|False}
"""
    )

    sys.exit(1)

#
# set defaults

hb_port = PORT
interval = INTERVAL
hb_hosts = []

try:
    f = open(configfile, "r")
except OSError:
    if verbose:
        print("warning: running without config file: %s" % configfile)
    f = None
else:
    if verbose:
        print("notice: using config file %s" % configfile)

if f:
    # Each line is "key=value", where value is a Python expression.
    # NOTE: values go through eval(), so the config file must be trusted.
    # The with-block closes the file even if a value fails to evaluate.
    with f:
        for l in f:
            # rstrip instead of l[:-1]: a last line without a trailing
            # newline must keep its final character. Split only on the first
            # "=" so values may themselves contain "=".
            key, sep, value = l.rstrip("\r\n").partition("=")
            key = key.strip()
            value = value.strip()
            if key == "hb_hosts":
                hb_hosts = eval(value)
                if verbose:
                    print("notice: cfg hb_hosts: %s" % hb_hosts)
            elif key == "interval":
                interval = eval(value)
            elif key == "hb_port":
                hb_port = eval(value)
            elif key == "name":
                iam = eval(value)
                if verbose:
                    print("name set to %s" % iam)

# Hosts given on the command line take precedence over the config file.
if args:
    hb_hosts = args


if len(hb_hosts) == 0:
    print("no hb server specified")
    sys.exit(1)

#
if verbose:
    notices = (
        "notice: hb_hosts: %s" % str(hb_hosts),
        "notice: hb_port: %s" % hb_port,
        "notice: interval: %s" % interval,
        "notice: iam: %s" % iam,
        "notice: msgonly: %s" % msgonly,
        "notice: msgboot: %s" % msgboot,
    )
    for notice in notices:
        print(notice)

# Regular runs announce their report interval in the startup message; a
# one-shot -m message does not carry it.
if msgonly is False or not msgonly:
    msgboot["interval"] = interval


# Connection table, keyed by connection id; filled by createConnections().
conns = {}
# Retry every 2 seconds until at least one server address resolves.
while True:
    if verbose:
        log("create connections")
    createConnections(hb_hosts)
    if conns:
        break
    if verbose:
        log("no connections yet, sleep a bit")
    time.sleep(2)

if verbose:
    log("%s connections created" % (len(conns)))

# One-shot startup message (-b boot flag and/or -m text, plus interval).
if msgboot:
    if verbose:
        print("on boot")
    msgboot["acks"] = 0
    for c in conns.values():
        c.sendto(msgboot)

# With -m the client only delivers that message and exits.
if msgonly:
    if verbose:
        print("msgboot done msgonly=%s" % msgonly)
    closeall()
    sys.exit(0)

#
# Tag subsequent syslog messages with ident "hbc" and the PID (LOG_DAEMON).
syslog.openlog("hbc", syslog.LOG_PID, syslog.LOG_DAEMON)
if fdaemon:
    print("daemoinizing.")
    daemonize()
    daemonized = True
    syslog.syslog(syslog.LOG_ERR, "starting heartbeat to %s" % ",".join(hb_hosts))

# SIGTERM -> handler() -> cleanup(): sends a shutdown notice to the servers
# and clears running so process() stops.
signal.signal(signal.SIGTERM, handler)
running = True
try:
    process()
except Exception as e:
    syslogtrace("process")
    if verbose:
        print("err: process exit: %s" % e)

if verbose:
    log("main: cleanup")
# cleanup() only acts while running is still True (e.g. process() died with
# an exception). process() clears running itself on Ctrl-C, select failure
# or a pending restart, and then no shutdown notice is sent.
cleanup()
if dorestart:
    # A successful self-update rewrote sys.argv[0]; exec the new version.
    restart()