#!/usr/bin/env python3
"""Persistent ARP MITM for TV traffic. Runs as a daemon.
Routes all TV traffic through the Pi with NAT forwarding.
Sandman/bot can then apply tc rules without needing to manage ARP."""

import json
import logging
import logging.handlers
import os
import signal
import socket
import struct
import subprocess
import sys
import time

# LAN interface used for everything: raw ARP frames are sent on it, the NAT
# masquerade rule is bound to it, and tc rules are cleared on it.
IFACE = "end0"
# LAN gateway; the TV is told this IP lives at our MAC.
GATEWAY_IP = "192.168.1.254"
# JSON record of the TV/gateway pair currently being spoofed; removed on clean shutdown.
STATE_FILE = "/share/arp_mitm_state.json"
# Rotating daemon log: 2 MB per file, current file + 2 backups.
LOG_FILE = "/share/arp_mitm.log"

# Log to both the rotating file and stderr (StreamHandler's default stream).
# NOTE: RotatingFileHandler opens LOG_FILE immediately, so /share must exist and
# be writable before this module is imported.
log = logging.getLogger("arp_mitm")
log.setLevel(logging.INFO)
fmt = logging.Formatter("%(asctime)s [%(levelname)s] %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
fh = logging.handlers.RotatingFileHandler(LOG_FILE, maxBytes=2_000_000, backupCount=2)
fh.setFormatter(fmt)
log.addHandler(fh)
sh = logging.StreamHandler()
sh.setFormatter(fmt)
log.addHandler(sh)


def get_our_mac() -> bytes:
    """Return this host's MAC address on IFACE as raw bytes.

    The kernel exposes the address in sysfs as colon-separated hex text
    (e.g. ``aa:bb:cc:dd:ee:ff``); the colons are dropped and the hex decoded.
    """
    sysfs_path = f"/sys/class/net/{IFACE}/address"
    with open(sysfs_path) as handle:
        raw_text = handle.read()
    hex_digits = raw_text.strip().replace(":", "")
    return bytes.fromhex(hex_digits)


def resolve_mac(ip: str) -> bytes | None:
    import re
    subprocess.run(["ping", "-c", "1", "-W", "2", ip], capture_output=True, timeout=5)
    try:
        r = subprocess.run(["ip", "neigh", "show", ip], capture_output=True, text=True, timeout=5)
        m = re.search(r"([\da-f]{1,2}:[\da-f]{1,2}:[\da-f]{1,2}:[\da-f]{1,2}:[\da-f]{1,2}:[\da-f]{1,2})", r.stdout, re.I)
        if m:
            return bytes.fromhex(m.group(1).replace(":", "").zfill(12))
    except Exception:
        pass
    return None


def build_arp_reply(src_mac: bytes, src_ip: str, dst_mac: bytes, dst_ip: str) -> bytes:
    """Build a complete Ethernet frame carrying an ARP reply (opcode 2).

    The frame asserts "*src_ip* is at *src_mac*" and is addressed to
    *dst_mac* / *dst_ip*. MACs are 6-byte strings; IPs are dotted-quad IPv4.
    """
    # Ethernet header: destination MAC, source MAC, EtherType 0x0806 (ARP).
    eth_hdr = dst_mac + src_mac + b'\x08\x06'
    # ARP fixed part: htype=Ethernet(1), ptype=IPv4(0x0800), hlen=6, plen=4, oper=reply(2).
    arp_fixed = struct.pack('!HHBBH', 1, 0x0800, 6, 4, 2)
    sender_fields = src_mac + socket.inet_aton(src_ip)
    target_fields = dst_mac + socket.inet_aton(dst_ip)
    return b"".join((eth_hdr, arp_fixed, sender_fields, target_fields))


def discover_tv(timeout: float = 5.0) -> str | None:
    """Find the TV on the LAN via an SSDP M-SEARCH.

    Multicasts a discovery request and returns the source IP of the first
    response whose payload contains "Philips" or (case-insensitively) "android".

    Args:
        timeout: Overall time budget in seconds for collecting responses.

    Returns:
        The responder's IPv4 address, or None if nothing matched within
        *timeout* or the request could not be sent (e.g. network not up yet).
    """
    ssdp = 'M-SEARCH * HTTP/1.1\r\nHOST: 239.255.255.250:1900\r\nMAN: "ssdp:discover"\r\nMX: 3\r\nST: ssdp:all\r\n\r\n'
    deadline = time.monotonic() + timeout
    try:
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDP) as sock:
            sock.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_TTL, 2)
            sock.sendto(ssdp.encode(), ("239.255.255.250", 1900))
            # Use one overall deadline. A fixed per-recv timeout would restart
            # after every non-matching reply, stretching the wait beyond *timeout*.
            while (remaining := deadline - time.monotonic()) > 0:
                sock.settimeout(remaining)
                data, addr = sock.recvfrom(4096)
                text = data.decode(errors="replace")
                # NOTE(review): "android" matches any Android device answering
                # SSDP, not only the TV; tighten if other Android hardware is on
                # the LAN, otherwise the "TV moved" check may flap between devices.
                if "Philips" in text or "android" in text.lower():
                    return addr[0]
    except socket.timeout:
        pass
    except OSError as e:
        # Network unreachable or similar: report "not found" instead of letting
        # the exception escape to the daemon's main loop.
        log.warning("SSDP discovery failed: %s", e)
    return None


def setup_forwarding():
    """Enable IPv4 forwarding and install the NAT/forward rules the MITM needs.

    Adds, only if not already present:
      * nat/POSTROUTING ``-o IFACE -j MASQUERADE``: source-NAT forwarded traffic
        leaving IFACE;
      * filter/FORWARD ``-j ACCEPT``: let forwarded packets through.

    Unlike flushing whole tables/chains, this leaves unrelated firewall rules
    (e.g. ones installed by Docker or a VPN) intact, and is safe to call
    repeatedly. Failures to add a rule are logged.
    """
    subprocess.run(["sysctl", "-w", "net.ipv4.ip_forward=1"], capture_output=True)
    rules = (
        ("nat", ["POSTROUTING", "-o", IFACE, "-j", "MASQUERADE"]),
        ("filter", ["FORWARD", "-j", "ACCEPT"]),
    )
    for table, rule in rules:
        base = ["iptables", "-t", table]
        # `-C` exits non-zero when the rule is absent. Append only then, so
        # restarts (e.g. after the TV changes IP) don't stack duplicate rules.
        if subprocess.run(base + ["-C", *rule], capture_output=True).returncode == 0:
            continue
        added = subprocess.run(base + ["-A", *rule], capture_output=True, text=True)
        if added.returncode != 0:
            log.warning("iptables -t %s -A %s failed: %s",
                        table, " ".join(rule), added.stderr.strip())
    log.info("IP forwarding + NAT enabled")


def teardown_forwarding():
    """Remove the MITM's NAT/forward rules and disable IPv4 forwarding.

    Deletes only the two rules the MITM installs:
      * nat/POSTROUTING ``-o IFACE -j MASQUERADE``
      * filter/FORWARD ``-j ACCEPT``

    Each delete is repeated so duplicates left by earlier runs are removed too.
    Other firewall rules are left untouched, rather than flushing whole chains.
    """
    rules = (
        ("nat", ["POSTROUTING", "-o", IFACE, "-j", "MASQUERADE"]),
        ("filter", ["FORWARD", "-j", "ACCEPT"]),
    )
    for table, rule in rules:
        # `-D` fails once no matching rule remains. The cap only guards against
        # looping forever if iptables misbehaves.
        for _ in range(32):
            deleted = subprocess.run(["iptables", "-t", table, "-D", *rule], capture_output=True)
            if deleted.returncode != 0:
                break
    # NOTE(review): this unconditionally disables forwarding. If the host needed
    # it before we started (e.g. for containers), consider restoring the prior value.
    subprocess.run(["sysctl", "-w", "net.ipv4.ip_forward=0"], capture_output=True)
    log.info("IP forwarding + NAT disabled")


def restore_arp(our_mac, tv_mac, tv_ip, gw_mac, gw_ip):
    """Undo the spoof by re-advertising the real TV and gateway MAC addresses.

    Sends 15 rounds, 0.2 s apart, of three ARP replies:
      1. to the TV: *gw_ip* is at *gw_mac*;
      2. to the gateway: *tv_ip* is at *tv_mac*;
      3. to broadcast: *gw_ip* is at *gw_mac*.

    Errors are logged rather than raised, so the caller's cleanup can continue.

    Args:
        our_mac: Unused; kept so existing call sites keep working.
        tv_mac, gw_mac: Real MACs as 6 raw bytes.
        tv_ip, gw_ip: Dotted-quad IPv4 addresses.
    """
    rounds = 15
    bcast = b'\xff\xff\xff\xff\xff\xff'
    try:
        packets = (
            build_arp_reply(gw_mac, gw_ip, tv_mac, tv_ip),
            build_arp_reply(tv_mac, tv_ip, gw_mac, gw_ip),
            build_arp_reply(gw_mac, gw_ip, bcast, "0.0.0.0"),
        )
        # The context manager closes the raw socket even if bind/send raises,
        # instead of leaking it until garbage collection.
        with socket.socket(socket.AF_PACKET, socket.SOCK_RAW, socket.htons(0x0003)) as sock:
            sock.bind((IFACE, 0))
            for _ in range(rounds):
                for pkt in packets:
                    sock.send(pkt)
                time.sleep(0.2)
        log.info("ARP restored (%d packets)", rounds * len(packets))
    except Exception as e:
        log.warning("ARP restore failed: %s", e)


def save_state(tv_ip, tv_mac_hex, gw_mac_hex):
    """Persist the active spoof target to STATE_FILE as JSON.

    Writes the keys ``tv_ip``, ``tv_mac`` and ``gw_mac``; values are stored as
    given (the caller passes the MACs pre-formatted as hex strings).

    The JSON goes to a temporary sibling file, which is then atomically renamed
    over STATE_FILE. A reader, or a crash mid-write, therefore never sees a
    truncated file.
    """
    tmp_path = f"{STATE_FILE}.tmp"
    with open(tmp_path, "w") as f:
        json.dump({"tv_ip": tv_ip, "tv_mac": tv_mac_hex, "gw_mac": gw_mac_hex}, f)
    os.replace(tmp_path, STATE_FILE)


def clear_state():
    """Delete STATE_FILE; it is not an error if the file is already gone."""
    try:
        os.unlink(STATE_FILE)
    except FileNotFoundError:
        return


IP_CHECK_INTERVAL = 300  # re-discover TV IP every 5 min
mac_fmt = lambda b: ":".join(f"{x:02x}" for x in b) if b else "?"

running = True


def shutdown(signum, frame):
    """Signal handler: ask the daemon to stop gracefully.

    Clears the module-level ``running`` flag, which the main and spoof loops
    poll; they perform their own cleanup once they see it.
    """
    global running
    running = False
    log.info("Shutting down (signal %d)...", signum)


# Stop gracefully on SIGTERM (e.g. from a service manager) and SIGINT (Ctrl-C).
# Ignore SIGHUP so losing the controlling terminal doesn't kill the daemon.
for _stop_signal in (signal.SIGTERM, signal.SIGINT):
    signal.signal(_stop_signal, shutdown)
del _stop_signal
signal.signal(signal.SIGHUP, signal.SIG_IGN)


def discover_tv_with_cache() -> str | None:
    """Discover the TV's IP, falling back to the last IP seen.

    A successful SSDP discovery is written (best effort) to
    ``/share/sandman_bot_tv_ip``. When discovery finds nothing, the cached value
    is returned instead.

    Returns:
        The TV's IP, or None if neither discovery nor the cache yields one.
    """
    cache_path = "/share/sandman_bot_tv_ip"
    ip = discover_tv()
    if ip:
        try:
            with open(cache_path, "w") as f:
                f.write(ip)
        except OSError as e:
            # The cache is only a fallback; failing to update it is not fatal.
            log.debug("Could not update TV IP cache: %s", e)
        return ip

    try:
        with open(cache_path) as f:
            return f.read().strip() or None
    except (OSError, ValueError):
        # Missing, unreadable or undecodable cache: no fallback available.
        return None


def _pin_neighbors(tv_ip: str, tv_mac: bytes, gw_mac: bytes) -> None:
    """Pin the real TV and gateway MACs as permanent entries in our own ARP table.

    This keeps the Pi's own traffic to both hosts going to the right MAC while the
    spoof is active. Failures are logged, not raised.
    """
    try:
        for ip, mac in ((tv_ip, tv_mac), (GATEWAY_IP, gw_mac)):
            subprocess.run(["ip", "neigh", "replace", ip, "lladdr",
                            mac_fmt(mac), "nud", "permanent", "dev", IFACE],
                           capture_output=True, timeout=5)
        log.info("Pinned static ARP: TV=%s GW=%s (Pi can always reach both)", tv_ip, GATEWAY_IP)
    except Exception as e:
        log.warning("Failed to pin ARP: %s", e)


def _unpin_neighbors(tv_ip: str) -> None:
    """Remove the permanent neighbour entries added by _pin_neighbors (best effort).

    Without this they outlive the spoof. A stale entry for an old TV IP, or for
    the gateway if its MAC ever changes, would misdirect the Pi's own traffic.
    """
    for ip in (tv_ip, GATEWAY_IP):
        try:
            subprocess.run(["ip", "neigh", "del", ip, "dev", IFACE],
                           capture_output=True, timeout=5)
        except Exception as e:
            log.warning("Failed to unpin ARP for %s: %s", ip, e)


def run_spoof(tv_ip: str) -> str | None:
    """Run the ARP spoof loop for one TV IP until shutdown, failure or a TV move.

    Every second, it tells the TV that GATEWAY_IP is at our MAC, and tells the
    gateway that *tv_ip* is at our MAC. Forwarding/NAT (setup_forwarding) relays
    the traffic. Every IP_CHECK_INTERVAL seconds the TV is re-discovered.

    On exit, the following cleanup runs exactly once, whatever the reason:
      * the root tc qdisc on IFACE is deleted;
      * the real MACs are re-advertised;
      * our pinned neighbour entries are removed.
    If a shutdown was requested, forwarding is also torn down and the state file
    cleared.

    Returns:
        The TV's new IP if it moved (caller should restart with it), else None
        (shutdown, MAC resolution failure, or error).
    """
    our_mac = get_our_mac()
    tv_mac = resolve_mac(tv_ip)
    gw_mac = resolve_mac(GATEWAY_IP)

    if not our_mac or not tv_mac or not gw_mac:
        log.warning("MAC resolution failed (us=%s, tv=%s, gw=%s)",
                    mac_fmt(our_mac), mac_fmt(tv_mac), mac_fmt(gw_mac))
        return None

    pkt_to_tv = build_arp_reply(our_mac, GATEWAY_IP, tv_mac, tv_ip)
    pkt_to_gw = build_arp_reply(our_mac, tv_ip, gw_mac, GATEWAY_IP)

    # Open the raw socket before touching any system state. A permission or
    # interface problem then leaves nothing half-configured (forwarding,
    # pins, state file) behind.
    sock = None
    try:
        sock = socket.socket(socket.AF_PACKET, socket.SOCK_RAW, socket.htons(0x0003))
        sock.bind((IFACE, 0))
    except Exception as e:
        if sock is not None:
            sock.close()
        log.error("Failed to open raw socket: %s", e)
        return None

    log.info("Spoofing: TV %s (%s) ↔ GW %s (%s)",
             tv_ip, mac_fmt(tv_mac), GATEWAY_IP, mac_fmt(gw_mac))

    try:
        save_state(tv_ip, mac_fmt(tv_mac), mac_fmt(gw_mac))
        setup_forwarding()
        _pin_neighbors(tv_ip, tv_mac, gw_mac)

        last_ip_check = time.time()
        send_failing = False
        while running:
            try:
                sock.send(pkt_to_tv)
                sock.send(pkt_to_gw)
            except OSError as e:
                # Keep retrying (the link may come back), but log only the first
                # failure of a streak rather than swallowing errors silently.
                if not send_failing:
                    log.warning("ARP spoof send failed, retrying: %s", e)
                    send_failing = True
            else:
                if send_failing:
                    log.info("ARP spoof sends recovered")
                    send_failing = False
            time.sleep(1)

            # Periodically re-discover TV IP
            if time.time() - last_ip_check > IP_CHECK_INTERVAL:
                last_ip_check = time.time()
                new_ip = discover_tv()
                if new_ip and new_ip != tv_ip:
                    log.info("TV moved: %s → %s — restarting spoof", tv_ip, new_ip)
                    # The finally block restores ARP for the old IP exactly once.
                    return new_ip
    except Exception as e:
        log.error("Spoof loop error: %s", e)
    finally:
        try:
            sock.close()
        except Exception:
            pass
        log.info("Restoring ARP...")
        # Each cleanup step is best effort. For example, a missing `tc` binary
        # must not prevent the ARP restore that follows.
        try:
            subprocess.run(["tc", "qdisc", "del", "dev", IFACE, "root"], capture_output=True)
        except Exception as e:
            log.warning("Failed to remove tc qdisc: %s", e)
        restore_arp(our_mac, tv_mac, tv_ip, gw_mac, GATEWAY_IP)
        _unpin_neighbors(tv_ip)
        if not running:
            teardown_forwarding()
            clear_state()
            log.info("Clean shutdown complete")

    return None


def _sleep_while_running(seconds: float) -> None:
    """Sleep up to *seconds*, waking at least once per second to honour shutdown."""
    deadline = time.monotonic() + seconds
    while running and (remaining := deadline - time.monotonic()) > 0:
        time.sleep(min(1.0, remaining))


def main():
    """Daemon entry point: keep an ARP spoof running against the TV until signalled.

    This is a watchdog loop:
      1. Discover the TV (SSDP, falling back to the cached IP).
      2. Run the spoof.
      3. When the spoof ends, restart it: immediately with the reported IP if
         the TV moved, otherwise after a short back-off.

    Unexpected exceptions in one iteration are logged and retried rather than
    terminating the daemon.
    """
    log.info("ARP MITM daemon starting (with watchdog)...")

    next_ip = None  # IP reported by run_spoof when the TV moved
    while running:
        tv_ip = next_ip
        next_ip = None
        try:
            # Prefer the IP run_spoof just reported. Re-discovering could fail
            # and fall back to the stale cached address.
            if not tv_ip:
                tv_ip = discover_tv_with_cache()
            if not tv_ip:
                log.warning("TV not found — retrying in 30s")
                _sleep_while_running(30)
                continue

            log.info("TV IP: %s", tv_ip)

            # Run spoof — returns new IP if TV moved, None on shutdown/error
            next_ip = run_spoof(tv_ip)
        except Exception:
            log.exception("Watchdog iteration failed")

        if not running:
            break

        if next_ip:
            # TV moved — loop back and spoof the reported IP immediately
            log.info("Restarting spoof for new IP %s", next_ip)
            continue

        # Spoof crashed or failed — wait and retry
        log.warning("Spoof stopped unexpectedly — restarting in 10s")
        _sleep_while_running(10)

    log.info("MITM daemon exited")


# Script entry point. main() returns None, so the process exits with status 0
# after a clean shutdown.
if __name__ == "__main__":
    sys.exit(main())
