#!/usr/bin/env python3
"""Telegram bot for live monitoring and control of the sandman TV sleep pressure system."""

import asyncio
import datetime
import json
import logging
import math
import os
import re
import subprocess
import sys
import time
from logging.handlers import RotatingFileHandler
from pathlib import Path

import tv_control
from telegram import Update, BotCommand, InlineKeyboardButton, InlineKeyboardMarkup, WebAppInfo
from telegram.ext import Application, CommandHandler, CallbackQueryHandler, ContextTypes
from telegram.constants import ParseMode

WEBAPP_URL = "https://a0d7b954-ssh.tail821319.ts.net/devices.html"
TV_REMOTE_URL = "https://a0d7b954-ssh.tail821319.ts.net/tv_remote.html"
SANDMAN_APP_URL = "https://a0d7b954-ssh.tail821319.ts.net/sandman_app.html"

# ── Config ──────────────────────────────────────────────────────────────────

# NOTE(security): a Telegram bot token is a credential — anyone who can read
# this file can drive the bot. Supply it via the SANDMAN_BOT_TOKEN environment
# variable. The literal fallback is kept only so existing deployments keep
# working. Rotate this token (via @BotFather) and then delete the fallback.
BOT_TOKEN = os.environ.get(
    "SANDMAN_BOT_TOKEN", "8782685448:AAG2QRMgxr1WGc1J8VcQGwxmNmj7rDDDnBI"
)
LOG_FILE = "/share/sandman.log"
BOT_LOG_FILE = "/share/sandman_bot.log"
CHAT_ID_FILE = "/share/sandman_bot_chat_id"
TV_IP_CACHE = "/share/sandman_bot_tv_ip"
SANDMAN_SCRIPT = "/share/sandman.py"
PROD_CONFIG = "/share/sandman_config.json"
TEST_CONFIG = "/share/sandman_config_test.json"

# JointSpace auth and UPnP constants now in tv_control module

# Polling periods in seconds — presumably for the background state/log pollers; confirm.
STATE_POLL_INTERVAL = 30
LOG_POLL_INTERVAL = 2

# ── Attention Deficit Model ────────────────────────────────────────────────

STATE_FILE = "/share/sandman_state.json"
SPEED_FILE = "/share/sandman_speed.json"

BASE_IMPACT_RATE = 0.003
# Per-minute exponential decay rate of the deficit while not watching
# (_estimate_recovery solves deficit * exp(-RECOVERY_RATE * minutes) = target).
RECOVERY_RATE = 0.0005

# Time-of-day rate multipliers as half-open [start, end) windows in HH:MM.
# An end at or before its start means the window crosses midnight.
TIME_MULTIPLIERS = [
    {"start": "06:00", "end": "18:00", "rate": 1.0},
    {"start": "18:00", "end": "22:00", "rate": 1.2},
    {"start": "22:00", "end": "22:30", "rate": 0.3},
    {"start": "22:30", "end": "00:00", "rate": 1.8},
    {"start": "00:00", "end": "02:00", "rate": 2.5},
    {"start": "02:00", "end": "06:00", "rate": 3.0},
]


def read_state() -> dict:
    """Load the persisted attention-deficit state from STATE_FILE.

    Any failure to open or parse the file (missing, unreadable, bad JSON)
    yields a fresh zeroed state dict instead of raising.
    """
    fallback = {
        "deficit": 0.0,
        "last_updated": None,
        "session_start": None,
        "total_watch_today_min": 0,
        "daily_log": {},
    }
    try:
        with open(STATE_FILE) as fh:
            loaded = json.load(fh)
    except Exception:
        return fallback
    return loaded


def write_state(state: dict):
    """Persist *state* to STATE_FILE as pretty-printed JSON, atomically.

    read_state() falls back to a zeroed state on any read error. A reader that
    caught a half-written file (or a crash mid-write) would therefore silently
    reset the deficit. The file is presumably also read by the sandman process
    (cmd_test relies on sandman noticing a changed deficit).

    So the JSON is written to a sibling temp file first, then os.replace()d
    over the target. Readers see either the old or the new complete document.

    Raises:
        Whatever json.dump / the filesystem raises; the target file is left
        untouched and the temp file is removed in that case.
    """
    tmp_path = f"{STATE_FILE}.{os.getpid()}.tmp"
    try:
        with open(tmp_path, "w") as f:
            json.dump(state, f, indent=2)
            f.flush()
            os.fsync(f.fileno())
        os.replace(tmp_path, STATE_FILE)
    except BaseException:
        # Clean up the partial temp file, then let the caller see the error.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise


def _time_multiplier_at(hour: int, minute: int) -> float:
    """Look up the TIME_MULTIPLIERS rate whose window covers clock time HH:MM.

    Windows are half-open [start, end). A window whose end is at or before
    its start is treated as crossing midnight. Returns 1.0 when no window
    matches.
    """
    day = 24 * 60

    def to_minutes(hhmm: str) -> int:
        hh, mm = hhmm.split(":")
        return int(hh) * 60 + int(mm)

    now_min = hour * 60 + minute
    for window in TIME_MULTIPLIERS:
        lo = to_minutes(window["start"])
        hi = to_minutes(window["end"])
        if hi <= lo:
            # e.g. 22:30 -> 00:00 really means 22:30 -> 24:00
            hi += day
        # Times before the window start are shifted into "tomorrow"
        probe = now_min if now_min >= lo else now_min + day
        if lo <= probe < hi:
            return window["rate"]
    return 1.0


def _fmt_duration(minutes: float) -> str:
    """Render a minute count as "Xh YYm", or just "Ym" when under an hour.

    Fractional minutes are truncated toward zero before formatting.
    """
    hours, mins = divmod(int(minutes), 60)
    return f"{hours}h {mins:02d}m" if hours > 0 else f"{mins}m"


def _estimate_recovery(deficit: float, target: float) -> str:
    """Roughly estimate how long *deficit* takes to decay to *target*.

    Assumes no further watching and exponential decay
    deficit * exp(-RECOVERY_RATE * minutes), solved for minutes. A
    non-positive target is clamped to 0.001 so the logarithm is defined.

    Returns "now", "~Nm" when under an hour, otherwise "~Nh".
    """
    if deficit <= target:
        return "now"
    goal = 0.001 if target <= 0 else target
    minutes = -math.log(goal / deficit) / RECOVERY_RATE
    hours = minutes / 60
    return f"~{int(minutes)}m" if hours < 1 else f"~{hours:.0f}h"

# ── Logging ─────────────────────────────────────────────────────────────────

log = logging.getLogger("sandman_bot")
log.setLevel(logging.INFO)
_fmt = logging.Formatter(
    fmt="%(asctime)s  %(levelname)-8s  %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)
# Size-capped file log (2 MB, 3 backups) plus a console mirror, sharing one format.
_rh = RotatingFileHandler(BOT_LOG_FILE, maxBytes=2_000_000, backupCount=3)
_sh = logging.StreamHandler()
for _handler in (_rh, _sh):
    _handler.setFormatter(_fmt)
    log.addHandler(_handler)
del _handler

# ── Globals ─────────────────────────────────────────────────────────────────

owner_chat_id: int | None = None
tv_ip: str | None = None
last_power: str | None = None
last_app: str | None = None
last_playback: str | None = None
last_title: str | None = None
log_offset: int = 0
last_volume: int | None = None

# ── TV discovery (delegated to tv_control) ─────────────────────────────────


def get_tv_ip(force_rediscover: bool = False) -> str | None:
    """Return cached TV IP, re-discovering if needed."""
    global tv_ip
    if tv_ip and not force_rediscover:
        return tv_ip
    if not force_rediscover and os.path.exists(TV_IP_CACHE):
        cached = Path(TV_IP_CACHE).read_text().strip()
        if cached:
            tv_ip = cached
            tv_control.TV_IP = cached
            return tv_ip
    tv_ip = tv_control.discover_tv()
    if tv_ip:
        Path(TV_IP_CACHE).write_text(tv_ip)
        tv_control.TV_IP = tv_ip
    return tv_ip


def get_tv_ip_or_cached() -> str | None:
    """Return TV IP, preferring cache even if TV is off (for standby commands)."""
    global tv_ip
    if tv_ip:
        return tv_ip
    if os.path.exists(TV_IP_CACHE):
        cached = Path(TV_IP_CACHE).read_text().strip()
        if cached:
            tv_ip = cached
            tv_control.TV_IP = cached
            return tv_ip
    tv_ip = "192.168.1.50"
    tv_control.TV_IP = tv_ip
    return tv_ip


def invalidate_tv_ip():
    """Forget the TV IP in memory and delete TV_IP_CACHE (a missing file is fine).

    tv_control.TV_IP is left as it was.
    """
    global tv_ip
    tv_ip = None
    Path(TV_IP_CACHE).unlink(missing_ok=True)


# ── JointSpace helpers (delegated to tv_control) ──────────────────────────

# Count of back-to-back failed js_get() calls; drives automatic re-discovery.
_consecutive_failures = 0


def js_get(path: str, timeout: float = 5) -> dict | None:
    """GET a JointSpace endpoint on the TV; returns the parsed result or None.

    After 10 consecutive failed calls, network discovery runs once. If the TV
    answers at a different address, the module tv_ip, tv_control.TV_IP and
    TV_IP_CACHE are all updated. The failure counter resets on any success and
    after each discovery attempt.
    """
    global _consecutive_failures, tv_ip
    ip = get_tv_ip_or_cached()
    if not ip:
        return None
    result = tv_control.js_get(path, timeout=timeout, ip=ip)
    if result is not None:
        _consecutive_failures = 0
        return result

    _consecutive_failures += 1
    if _consecutive_failures < 10:
        return result

    new_ip = tv_control.discover_tv()
    if new_ip and new_ip != ip:
        tv_ip = new_ip
        tv_control.TV_IP = new_ip
        Path(TV_IP_CACHE).write_text(new_ip)
        log.info("TV IP changed: %s -> %s", ip, new_ip)
    _consecutive_failures = 0
    return result


def js_post(path: str, data: dict, timeout: float = 5) -> bool:
    """POST *data* to a JointSpace endpoint; False when no TV IP is known."""
    ip = get_tv_ip_or_cached()
    return tv_control.js_post(path, data, timeout=timeout, ip=ip) if ip else False


def send_key(key: str) -> bool:
    """Send a remote-control key (e.g. "Pause", "Standby") via tv_control.js_key."""
    return tv_control.js_key(key, ip=get_tv_ip_or_cached())


# ── UPnP helpers (delegated to tv_control) ──────────────────────────────────


def upnp_get_volume() -> int | None:
    """Current TV volume via UPnP (tv_control.get_volume), or None if unavailable."""
    return tv_control.get_volume(ip=get_tv_ip_or_cached())


def upnp_set_volume(vol: int):
    """Set the TV volume to *vol* via UPnP; tv_control's return value is discarded."""
    target_ip = get_tv_ip_or_cached()
    tv_control.set_volume(vol, ip=target_ip)


def upnp_set_mute(muted: bool):
    """Mute (True) or unmute (False) the TV via UPnP; result discarded."""
    target_ip = get_tv_ip_or_cached()
    tv_control.set_mute(muted, ip=target_ip)


# ── Security ────────────────────────────────────────────────────────────────


def load_owner() -> int | None:
    if os.path.exists(CHAT_ID_FILE):
        try:
            return int(Path(CHAT_ID_FILE).read_text().strip())
        except Exception:
            pass
    return None


def save_owner(chat_id: int):
    """Record *chat_id* as the bot owner: in memory first, then in CHAT_ID_FILE."""
    global owner_chat_id
    owner_chat_id = chat_id
    owner_file = Path(CHAT_ID_FILE)
    owner_file.write_text(str(chat_id))
    log.info("Owner chat_id saved: %d", chat_id)


def is_owner(update: Update) -> bool:
    """Authorise *update* against the single owner chat.

    Trust-on-first-use: if no owner is recorded yet (in memory or on disk),
    the chat that sent this update is saved as the owner and allowed.
    """
    global owner_chat_id
    if owner_chat_id is None:
        owner_chat_id = load_owner()
    sender = update.effective_chat.id
    if owner_chat_id is not None:
        return sender == owner_chat_id
    save_owner(sender)
    return True


# ── Sandman process management ──────────────────────────────────────────────


def find_sandman_pid() -> int | None:
    try:
        out = subprocess.check_output(["ps", "aux"], text=True)
        my_pid = os.getpid()
        for line in out.splitlines():
            if "sandman.py" in line and "sandman_bot" not in line and "grep" not in line:
                parts = line.split()
                pid = int(parts[1])
                if pid != my_pid:
                    return pid
    except Exception:
        pass
    return None


def kill_sandman() -> bool:
    """Run `kill <pid>` on the sandman process found by find_sandman_pid().

    Returns True only when a PID was found and `kill` exited successfully.
    Failures are logged and reported as False.
    """
    pid = find_sandman_pid()
    if not pid:
        return False
    try:
        subprocess.run(["kill", str(pid)], check=True)
    except Exception as e:
        log.error("Failed to kill sandman: %s", e)
        return False
    log.info("Killed sandman PID %d", pid)
    return True


def start_sandman(config: str) -> bool:
    """Launch sandman.py with the given config file, detached from the bot.

    The child runs in its own session with stdout/stderr discarded and logs
    to /share/. Returns False (after logging) if spawning fails.
    """
    cmd = ["python3", SANDMAN_SCRIPT, "--config", config, "--log-dir", "/share/"]
    try:
        child = subprocess.Popen(
            cmd,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
            start_new_session=True,
        )
    except Exception as e:
        log.error("Failed to start sandman: %s", e)
        return False
    log.info("Started sandman PID %d with config %s", child.pid, config)
    return True


# ── ADB helpers (delegated to tv_control) ──────────────────────────────────


def adb_available() -> bool:
    """True if the `adb` binary is on PATH and `adb version` exits with status 0."""
    try:
        subprocess.run(["adb", "version"], capture_output=True, check=True)
    except (FileNotFoundError, subprocess.CalledProcessError):
        return False
    return True


def adb_get_current_app() -> tuple:
    """Query the TV's foreground app over ADB via tv_control.

    Callers unpack the result as (package, playback_state, extras).
    """
    return tv_control.adb_get_current_app(ip=get_tv_ip_or_cached())


def take_screenshot() -> str | None:
    """Capture a TV screenshot over ADB and return its file path, or None.

    Returns None immediately when adb is not installed.
    """
    if adb_available():
        return tv_control.adb_screenshot(ip=get_tv_ip_or_cached())
    return None


# ── Formatting helpers ──────────────────────────────────────────────────────

_LOG_LINE_RE = re.compile(r"\d{4}-\d{2}-\d{2} (\d{2}:\d{2}):\d{2} \[(\w+)\]\s*(.*)")


def fmt_log_line(line: str) -> str:
    """Shorten a sandman log line for Telegram.

    '2026-03-20 23:09:11 [INFO]    Grace period: 54 min ...'
    becomes '23:09 Grace period: 54 min ...'. WARNING lines get a ⚠️ prefix.
    Lines that don't match the log format are returned stripped.
    """
    parsed = _LOG_LINE_RE.match(line)
    if parsed is None:
        return line.strip()
    hhmm, level, message = parsed.groups()
    prefix = "⚠️ " if level == "WARNING" else ""
    return f"{prefix}{hhmm} {message}"


# ── Command handlers ────────────────────────────────────────────────────────


async def cmd_start(update: Update, ctx: ContextTypes.DEFAULT_TYPE):
    """/start — greet the owner and point them at the command menu."""
    if not is_owner(update):
        return
    greeting = (
        "🌙 *Sandman Bot*\n\n"
        "Use the command menu (tap /) for all controls.\n"
        "Log updates will be pushed here automatically."
    )
    await update.message.reply_text(greeting, parse_mode=ParseMode.MARKDOWN)


def _status_tv_lines(ip: str | None) -> list[str]:
    """Blocking: power, volume and foreground-app lines for /status."""
    lines = []
    if not ip:
        lines.append("📺 TV IP: unknown (never discovered)")
        return lines

    pw = js_get("powerstate")
    power = pw.get("powerstate", "?") if pw else "unreachable"
    icon = "🟢" if power == "On" else "🔴" if power == "Standby" else "❓"
    lines.append(f"{icon} Power: {power}")
    lines.append(f"🌐 IP: {ip}")

    vol_data = js_get("audio/volume")
    if vol_data:
        cur = vol_data.get("current", "?")
        mx = vol_data.get("max", "?")
        muted = "🔇" if vol_data.get("muted") else "🔊"
        lines.append(f"{muted} Volume: {cur}/{mx}")
    else:
        uv = upnp_get_volume()
        if uv is not None:
            lines.append(f"🔊 Volume: {uv} (UPnP)")

    pkg, playback, extras = adb_get_current_app()
    if pkg:
        friendly = APP_NAMES.get(pkg, pkg)
        status = f" ({playback})" if playback else ""
        lines.append(f"📱 App: {friendly}{status}")
        if extras:
            if extras.get("title"):
                title_line = f"🎬 {extras['title']}"
                if extras.get("artist"):
                    title_line += f" — {extras['artist']}"
                lines.append(title_line)
            if extras.get("position_s"):
                # int() so the :02d format also works if a float is reported.
                mins, secs = divmod(int(extras["position_s"]), 60)
                lines.append(f"⏱ {mins}:{secs:02d}")
    else:
        # Fallback to JointSpace
        act = js_get("activities/current")
        if act:
            p = act.get("component", {}).get("packageName", "")
            if p and p != "NA":
                lines.append(f"📱 App: {APP_NAMES.get(p, p)}")
    return lines


def _status_bad_hours_line(cfg: dict, now: datetime.datetime) -> str:
    """Describe where *now* sits relative to the configured bad-hours window.

    The window opens at cfg["bad_hours_start"] (default 22:30) and runs until
    the first 06:00 after it opens. While inside, report minutes elapsed;
    otherwise report minutes until the next opening.
    """
    bad_start = cfg.get("bad_hours_start", "22:30")
    bad_peak = cfg.get("bad_hours_peak", "00:30")
    bh, bm = map(int, bad_start.split(":"))
    one_day = datetime.timedelta(days=1)

    # Most recent opening at or before now (timedelta avoids day=0 on the 1st).
    start = now.replace(hour=bh, minute=bm, second=0, microsecond=0)
    if start > now:
        start -= one_day
    # The window closes at the first 06:00 after it opened.
    end = start.replace(hour=6, minute=0)
    if end <= start:
        end += one_day

    if now < end:
        mins_in = int((now - start).total_seconds() / 60)
        return f"🌙 Bad hours: {bad_start}→{bad_peak} (IN — {mins_in}m in)"
    mins_until = int((start + one_day - now).total_seconds() / 60)
    return f"🌙 Bad hours: {bad_start}→{bad_peak} ({mins_until}m until)"


def _status_sandman_lines() -> list[str]:
    """Blocking: sandman process, bad-hours, grace and last-action lines."""
    pid = find_sandman_pid()
    if not pid:
        return ["🤖 Sandman: stopped"]
    try:
        # Which config? Inferred from the command lines listed by `ps aux`.
        r = subprocess.run(["ps", "aux"], capture_output=True, text=True, timeout=3)
        config_name = "test" if "config_test" in r.stdout else "prod"

        # Only the newest log lines matter; deque keeps the tail without
        # holding the whole log in memory.
        from collections import deque
        with open(LOG_FILE, "r") as f:
            recent = deque(f, maxlen=30)
        grace_info = ""
        severity_info = ""
        for line in recent:
            if "Grace period" in line:
                m = re.search(r"Grace period — (\d+) min remaining", line)
                if m:
                    grace_info = f"⏳ Grace: {m.group(1)} min remaining"
            elif "sev=" in line:
                m = re.search(r"sev=([\d.]+) max_sev=(\d+) \| (\S+)", line)
                if m:
                    severity_info = f"⚡ Last action: {m.group(3)} (sev {m.group(1)})"

        config_path = TEST_CONFIG if config_name == "test" else PROD_CONFIG
        with open(config_path) as f:
            cfg = json.load(f)
        bad_hours = _status_bad_hours_line(cfg, datetime.datetime.now())
    except Exception as e:
        # Best-effort detail; still report that sandman is running.
        log.warning("/status: sandman details unavailable: %s", e)
        return [f"🤖 Sandman: running (PID {pid})"]

    lines = [bad_hours, f"🤖 Sandman: {config_name} mode (PID {pid})"]
    if grace_info:
        lines.append(grace_info)
    if severity_info:
        lines.append(severity_info)
    return lines


def _status_network_lines() -> list[str]:
    """Blocking: traffic-shaping (tc htb on end0) and MITM lines."""
    lines = []
    try:
        tc_check = subprocess.run(["tc", "qdisc", "show", "dev", "end0"],
                                  capture_output=True, text=True, timeout=5)
        if "htb" in tc_check.stdout:
            lines.append("🌐 Throttle: ON")
    except (OSError, subprocess.SubprocessError):
        pass  # tc missing or hung: just omit the throttle line
    if os.path.exists("/share/arp_mitm_state.json"):
        lines.append("🔀 MITM: active")
    return lines


def _status_deficit_line() -> str:
    """Deficit, current session length and today's watch time from the state file."""
    state = read_state()
    d = state.get("deficit", 0.0)
    session_min = 0
    if state.get("session_start"):
        try:
            ss = datetime.datetime.fromisoformat(state["session_start"])
            # now(None) is naive local time, matching a naive stamp (the style
            # cmd_test writes); an aware stamp gets an aware "now" in its zone.
            # NOTE(review): assumes sandman writes naive stamps in local time.
            elapsed = (datetime.datetime.now(ss.tzinfo) - ss).total_seconds() / 60
            session_min = max(0.0, elapsed)
        except (TypeError, ValueError):
            pass
    today_min = state.get("total_watch_today_min", 0)
    return f"🧠 Deficit: {d:.2f} | Session: {_fmt_duration(session_min)} | Today: {_fmt_duration(today_min)}"


def _status_collect() -> str:
    """Blocking: assemble the full /status report text."""
    lines = _status_tv_lines(get_tv_ip_or_cached())
    lines += _status_sandman_lines()
    lines += _status_network_lines()
    lines.append(_status_deficit_line())
    return "\n".join(lines)


async def cmd_status(update: Update, ctx: ContextTypes.DEFAULT_TYPE):
    """/status — TV power/volume/app, sandman state, network shaping, deficit."""
    if not is_owner(update):
        return
    msg = await update.message.reply_text("⏳ Checking...")
    try:
        # Collection does blocking network/subprocess/file I/O (multi-second
        # timeouts); run it off the event loop like the other handlers do.
        text = await asyncio.to_thread(_status_collect)
    except Exception as e:
        log.exception("/status failed")
        text = f"❌ Status failed: {e}"
    await msg.edit_text(text)


async def cmd_screenshot(update: Update, ctx: ContextTypes.DEFAULT_TYPE):
    """/screenshot — capture the TV screen over ADB and send it as a photo."""
    if not is_owner(update):
        return
    if not adb_available():
        await update.message.reply_text("❌ ADB not installed")
        return
    msg = await update.message.reply_text("📸 Taking screenshot...")
    path = await asyncio.to_thread(take_screenshot)
    if path and os.path.exists(path) and os.path.getsize(path) > 0:
        # Close the handle once the upload finishes; a bare open() here would
        # leak one file descriptor per screenshot.
        with open(path, "rb") as photo:
            await update.message.reply_photo(photo=photo)
        await msg.delete()
    else:
        await msg.edit_text("❌ Screenshot failed — TV off or DRM content")


async def cmd_mute(update: Update, ctx: ContextTypes.DEFAULT_TYPE):
    """/mute — flip mute, based on the mute flag the TV reports via JointSpace."""
    if not is_owner(update):
        return
    vol_data = await asyncio.to_thread(js_get, "audio/volume")
    if not vol_data:
        await update.message.reply_text("❌ Can't reach TV")
        return
    want_muted = not vol_data.get("muted", False)
    await asyncio.to_thread(upnp_set_mute, want_muted)
    await update.message.reply_text("🔇 Muted" if want_muted else "🔊 Unmuted")


async def cmd_volume(update: Update, ctx: ContextTypes.DEFAULT_TYPE):
    """/volume [N] — set volume to N (clamped to 0–60) via UPnP.

    Without a numeric argument, reports the current JointSpace volume, or the
    usage text if the TV can't be reached.
    """
    if not is_owner(update):
        return
    args = ctx.args
    if args and args[0].isdigit():
        vol = max(0, min(60, int(args[0])))
        await asyncio.to_thread(upnp_set_volume, vol)
        await update.message.reply_text(f"🔊 Volume → {vol}")
        return
    vol_data = await asyncio.to_thread(js_get, "audio/volume")
    if vol_data:
        await update.message.reply_text(f"🔊 Volume: {vol_data.get('current', '?')}/{vol_data.get('max', '?')}")
    else:
        await update.message.reply_text("Usage: /volume N (0-60)")


async def cmd_pause(update: Update, ctx: ContextTypes.DEFAULT_TYPE):
    """/pause — send the Pause remote key."""
    if is_owner(update):
        sent = await asyncio.to_thread(send_key, "Pause")
        await update.message.reply_text("⏸ Paused" if sent else "❌ Failed")


async def cmd_play(update: Update, ctx: ContextTypes.DEFAULT_TYPE):
    """/play — send the Play remote key."""
    if is_owner(update):
        sent = await asyncio.to_thread(send_key, "Play")
        await update.message.reply_text("▶️ Playing" if sent else "❌ Failed")


async def cmd_home(update: Update, ctx: ContextTypes.DEFAULT_TYPE):
    """/home — send the Home remote key."""
    if is_owner(update):
        sent = await asyncio.to_thread(send_key, "Home")
        await update.message.reply_text("🏠 Home" if sent else "❌ Failed")


async def cmd_power(update: Update, ctx: ContextTypes.DEFAULT_TYPE):
    """/power — toggle TV power by sending the Standby key."""
    if not is_owner(update):
        return
    if not get_tv_ip_or_cached():
        await update.message.reply_text("❌ TV IP unknown — turn TV on manually first so I can discover it")
        return
    sent = await asyncio.to_thread(send_key, "Standby")
    reply = "⚡ Power toggled" if sent else "❌ Failed — TV may be fully off (not standby)"
    await update.message.reply_text(reply)


async def cmd_test(update: Update, ctx: ContextTypes.DEFAULT_TYPE):
    """Quick test: force the deficit to 0.7 so heavy actions fire soon.

    Rewrites the state file directly; last_updated is stamped with naive
    local time.
    """
    if not is_owner(update):
        return
    state = read_state()
    state.update(deficit=0.7, last_updated=datetime.datetime.now().isoformat())
    write_state(state)
    await update.message.reply_text("⚡ Test mode — deficit set to 0.70\nHeavy actions should trigger soon.")


async def cmd_prod(update: Update, ctx: ContextTypes.DEFAULT_TYPE):
    """/prod — restart sandman with the production config."""
    if not is_owner(update):
        return
    status = await update.message.reply_text("⏳ Switching to production...")
    kill_sandman()
    await asyncio.sleep(1)  # give the old process a moment to exit
    started = start_sandman(PROD_CONFIG)
    await status.edit_text("🌙 Production mode ON" if started else "❌ Failed to start")


async def cmd_stop(update: Update, ctx: ContextTypes.DEFAULT_TYPE):
    """/stop — kill the running sandman process, if there is one."""
    if not is_owner(update):
        return
    killed = kill_sandman()
    reply = "🛑 Sandman stopped" if killed else "ℹ️ Sandman wasn't running"
    await update.message.reply_text(reply)


async def cmd_start_sandman(update: Update, ctx: ContextTypes.DEFAULT_TYPE):
    """Start sandman with the production config unless an instance is running."""
    if not is_owner(update):
        return
    if find_sandman_pid():
        reply = "ℹ️ Sandman already running"
    else:
        reply = "✅ Sandman started" if start_sandman(PROD_CONFIG) else "❌ Failed to start"
    await update.message.reply_text(reply)


async def cmd_log(update: Update, ctx: ContextTypes.DEFAULT_TYPE):
    """/log [N] — show the last N sandman log lines (default 15, clamped to 1–50).

    Lines are shortened with fmt_log_line(). The reply is trimmed to its last
    4000 characters to stay under Telegram's message length limit.
    """
    if not is_owner(update):
        return
    if not os.path.exists(LOG_FILE):
        await update.message.reply_text("ℹ️ No log file")
        return

    n = 15
    if ctx.args and ctx.args[0].isdigit():
        # Clamp to >= 1: a slice of lines[-0:] would return the whole file.
        n = max(1, min(50, int(ctx.args[0])))

    def read_tail() -> list[str]:
        from collections import deque
        # Stream the file, keeping only the last n lines in memory.
        with open(LOG_FILE, "r") as f:
            return list(deque(f, maxlen=n))

    try:
        tail = await asyncio.to_thread(read_tail)
        formatted = [fmt_log_line(l) for l in tail if l.strip()]
        text = "\n".join(formatted) or "(empty)"
        if len(text) > 4000:
            text = text[-4000:]  # keep the newest end
        await update.message.reply_text(text)
    except Exception as e:
        await update.message.reply_text(f"❌ Error: {e}")


async def cmd_mitm(update: Update, ctx: ContextTypes.DEFAULT_TYPE):
    """Control the ARP MITM daemon — /mitm [on|off|status].

    No argument or "status": report whether the daemon's state file
    (/share/arp_mitm_state.json) exists.
    "on": spawn /share/arp_mitm.py and check for the state file 5 s later.
    "off": remove traffic shaping, clear the throttle marker, then pkill the
    daemon.
    Any other argument gets a usage reply.
    """
    if not is_owner(update):
        return
    args = ctx.args
    running = os.path.exists("/share/arp_mitm_state.json")
    arg = args[0].lower() if args else "status"

    if arg == "status":
        await update.message.reply_text(
            f"🔀 MITM: {'ON' if running else 'OFF'}\n\n"
            "/mitm on — start (needed for throttle)\n"
            "/mitm off — stop and restore network")
        return

    if arg == "on":
        if running:
            await update.message.reply_text("ℹ️ MITM already running")
            return
        msg = await update.message.reply_text("⏳ Starting MITM...")
        subprocess.Popen(["python3", "/share/arp_mitm.py"], start_new_session=True,
                         stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
        await asyncio.sleep(5)
        if os.path.exists("/share/arp_mitm_state.json"):
            await msg.edit_text("🔀 MITM ON — TV traffic flows through Pi")
        else:
            await msg.edit_text("❌ MITM failed to start")
    elif arg == "off":
        if not running:
            await update.message.reply_text("ℹ️ MITM not running")
            return
        # Remove tc first. Both tc and pkill are blocking subprocess work,
        # so keep them off the event loop.
        await asyncio.to_thread(tv_control.throttle_remove)
        _clear_throttle_state()
        # Stop MITM gracefully
        await asyncio.to_thread(subprocess.run, ["pkill", "-f", "arp_mitm.py"], capture_output=True)
        await asyncio.sleep(5)
        await update.message.reply_text("🔀 MITM OFF — network restored")
    else:
        await update.message.reply_text(f"❓ Unknown option: {args[0]}\nUsage: /mitm on|off|status")


# Optional JSON object mapping entity_id -> custom display name
# (read by _get_device_names(); overrides Home Assistant's friendly names).
DEVICE_NAMES_FILE = "/share/device_names.json"
# Home Assistant entity IDs whose state is fetched via
# tv_control.ha_get_device_states(); the toggle callback rebuilds its keyboard
# from that result.
HA_DEVICES = [
    "switch.sonoff_10020b052f_1",
    "switch.sonoff_10020b052f_2",
    "switch.sonoff_10020b0c7c_1",
    "switch.sonoff_10020b0c7c_2",
    "switch.magic_switch_s1e_bb55_switch_1",
    "switch.magic_switch_s1e_bb55_switch_2",
    "switch.magic_switch_s1e_bb55_switch_3",
    "switch.magic_switch_s1e_595a_kitchen_switch_1",
    "switch.magic_switch_s1e_595a_kitchen_switch_2",
    "switch.magic_switch_s1e_595a_kitchen_switch_3",
    "switch.tz3000_cayepv1a_ts011f",
    "switch.tz3000_cayepv1a_ts011f_2",
    "switch.tz3000_cayepv1a_ts011f_3",
    "light.fan",
    "light.fan_2",
    "light.fan_3",
]


def _get_device_names() -> dict:
    """Load entity_id -> custom name overrides from DEVICE_NAMES_FILE.

    Any read or parse failure yields an empty dict.
    """
    try:
        with open(DEVICE_NAMES_FILE) as fh:
            names = json.load(fh)
    except Exception:
        names = {}
    return names


def _ha_api(method: str, path: str, data: dict | None = None) -> dict | None:
    """Thin pass-through to tv_control.ha_api.

    Presumably a Home Assistant REST call, given the /states and /services
    paths its callers use.
    """
    return tv_control.ha_api(method, path, data)


def _get_ha_device_states() -> list:
    """Returns [(entity_id, friendly_name, state), ...] for HA_DEVICES.

    Names from DEVICE_NAMES_FILE take precedence over those tv_control reports.
    """
    overrides = _get_device_names()
    return [
        (eid, overrides.get(eid, ha_name), state)
        for eid, ha_name, state in tv_control.ha_get_device_states(HA_DEVICES)
    ]


# Room label -> HA entity IDs in that room (the trailing comments give each
# fixture's human name). Not every HA_DEVICES entry is assigned: the light.fan*
# entities and tz3000 ..._ts011f / _ts011f_2 appear in no room.
# NOTE(review): consumers are outside this view — presumably the devices UI.
DEVICE_AREAS = {
    "🛏 Bedroom": [
        "switch.sonoff_10020b052f_2",   # Lights outside Bedrooms
    ],
    "🛋 Living Room": [
        "switch.magic_switch_s1e_bb55_switch_1",  # Living Room Lights
        "switch.magic_switch_s1e_bb55_switch_2",  # Living Room Track Lights
        "switch.magic_switch_s1e_bb55_switch_3",  # TV Track Lights
    ],
    "🍽 Dining": [
        "switch.sonoff_10020b052f_1",   # Dining Area Lights
        "switch.sonoff_10020b0c7c_1",   # Dining Area Cove Lights
    ],
    "🍳 Kitchen": [
        "switch.sonoff_10020b0c7c_2",   # Kitchen Cove Lights
        "switch.magic_switch_s1e_595a_kitchen_switch_3",  # Kitchen Lights
        "switch.magic_switch_s1e_595a_kitchen_switch_2",  # Kitchen Toilet Lights
        "switch.tz3000_cayepv1a_ts011f_3",  # Induction Fan and Hob
    ],
    "🧺 Laundry": [
        "switch.magic_switch_s1e_595a_kitchen_switch_1",  # Laundry Area Lights
    ],
}


async def cmd_devices(update: Update, ctx: ContextTypes.DEFAULT_TYPE):
    """/devices — reply with a button that opens the devices Mini App."""
    if not is_owner(update):
        return
    button = InlineKeyboardButton("🏠 Open Devices", web_app=WebAppInfo(url=WEBAPP_URL))
    await update.message.reply_text("Tap to open:", reply_markup=InlineKeyboardMarkup([[button]]))


async def cmd_tv(update: Update, ctx: ContextTypes.DEFAULT_TYPE):
    """/tv — reply with a button that opens the TV remote Mini App."""
    if not is_owner(update):
        return
    button = InlineKeyboardButton("📺 Open TV Remote", web_app=WebAppInfo(url=TV_REMOTE_URL))
    await update.message.reply_text("Tap to open:", reply_markup=InlineKeyboardMarkup([[button]]))


async def cmd_sandman_app(update: Update, ctx: ContextTypes.DEFAULT_TYPE):
    """Reply with a button that opens the Sandman observability Mini App."""
    if not is_owner(update):
        return
    button = InlineKeyboardButton("🌙 Open Sandman", web_app=WebAppInfo(url=SANDMAN_APP_URL))
    await update.message.reply_text("Tap to open:", reply_markup=InlineKeyboardMarkup([[button]]))


async def device_callback(update: Update, ctx: ContextTypes.DEFAULT_TYPE):
    """Handle a device toggle button press (callback data "dev:<entity_id>").

    Reads the entity's current HA state and calls turn_off if it is "on",
    turn_on otherwise. Then redraws the whole device keyboard. Applies the
    same owner gate as the command handlers, and only entities listed in
    HA_DEVICES may be toggled.
    """
    query = update.callback_query
    if not query or not query.data or not query.data.startswith("dev:"):
        return
    # Callback data comes from the client, so authorise it like a command.
    if not is_owner(update):
        await query.answer()
        return

    eid = query.data[4:]
    # Reject stale or forged buttons instead of calling arbitrary services.
    if eid not in HA_DEVICES:
        await query.answer("❌ Unknown device")
        return
    await query.answer()
    domain = eid.split(".")[0]

    # Get current state
    state_data = await asyncio.to_thread(_ha_api, "GET", f"/states/{eid}")
    if not state_data:
        await query.edit_message_text("❌ Device unreachable")
        return

    current = state_data.get("state", "off")
    service = "turn_off" if current == "on" else "turn_on"
    await asyncio.to_thread(_ha_api, "POST", f"/services/{domain}/{service}", {"entity_id": eid})

    # Wait for state change and refresh the keyboard
    await asyncio.sleep(0.5)
    devices = await asyncio.to_thread(_get_ha_device_states)
    keyboard = []
    for d_eid, name, state in devices:
        icon = "🟢" if state == "on" else "⚫" if state == "off" else "❓"
        keyboard.append([InlineKeyboardButton(
            f"{icon} {name}",
            callback_data=f"dev:{d_eid}"
        )])

    await query.edit_message_text(
        "🏠 Devices — tap to toggle",
        reply_markup=InlineKeyboardMarkup(keyboard)
    )


async def cmd_discover(update: Update, ctx: ContextTypes.DEFAULT_TYPE):
    """/discover — drop the cached TV IP and rescan the network for the TV."""
    if not is_owner(update):
        return
    status = await update.message.reply_text("🔍 Scanning...")
    invalidate_tv_ip()
    found = await asyncio.to_thread(get_tv_ip, force_rediscover=True)
    await status.edit_text(f"📺 Found TV at {found}" if found else "❌ TV not found")


async def cmd_netspeed(update: Update, ctx: ContextTypes.DEFAULT_TYPE):
    """Measure actual TV bandwidth + show throttle status.

    Reports three things:
    - whether an htb qdisc is installed on end0, and its rate;
    - ping latency and packet loss to the TV;
    - a rough throughput figure from fetching a few JointSpace endpoints.

    All measurement runs in a worker thread. The result replaces the
    "Measuring speed..." placeholder message.
    """
    if not is_owner(update):
        return
    ip = get_tv_ip_or_cached()
    msg = await update.message.reply_text("⏳ Measuring speed...")

    def measure() -> str:
        lines = []

        # Throttle status
        try:
            r = subprocess.run(["tc", "-s", "qdisc", "show", "dev", "end0"],
                               capture_output=True, text=True, timeout=5)
            if "htb" in r.stdout:
                rc = subprocess.run(["tc", "class", "show", "dev", "end0"],
                                    capture_output=True, text=True, timeout=5)
                # Look only after the "1:20" class id — presumably the
                # throttled class sandman installs; TODO confirm in sandman.py.
                tail = rc.stdout.split("1:20")[-1] if "1:20" in rc.stdout else ""
                m = re.search(r"rate (\d+[KMG]?bit)", tail)
                rate = m.group(1) if m else "?"
                lines.append(f"🔴 Throttle ACTIVE — limit: {rate}")
            else:
                lines.append("🟢 Throttle OFF — full speed")
        except Exception:
            lines.append("❓ Throttle status unknown")

        if not ip:
            lines.append("📶 TV IP unknown")
            return "\n".join(lines)

        # Ping
        try:
            r = subprocess.run(["ping", "-c", "3", "-W", "2", ip],
                               capture_output=True, text=True, timeout=10)
            # rtt min/avg/max summary: the middle value is the average
            m = re.search(r"([\d.]+)/([\d.]+)/([\d.]+)", r.stdout)
            if m:
                avg = float(m.group(2))
                if avg < 5:
                    lines.append(f"📶 Ping: {avg:.0f}ms (excellent)")
                elif avg < 20:
                    lines.append(f"📶 Ping: {avg:.0f}ms (good)")
                else:
                    lines.append(f"📶 Ping: {avg:.0f}ms (slow)")
            loss_m = re.search(r"(\d+)% packet loss", r.stdout)
            if loss_m and int(loss_m.group(1)) > 0:
                lines.append(f"⚠️ Packet loss: {loss_m.group(1)}%")
        except Exception:
            lines.append("📶 Ping failed")

        # Quick bandwidth estimate: time a few JointSpace fetches. The monotonic
        # clock is immune to wall-clock adjustments.
        start = time.monotonic()
        total_bytes = 0
        for endpoint in ["system", "applications", "channeldb/tv/channelLists/all"]:
            try:
                data = tv_control.js_get(endpoint, timeout=5, ip=ip)
                total_bytes += len(json.dumps(data).encode()) if data else 0
            except Exception:
                pass
        elapsed = time.monotonic() - start
        if elapsed > 0 and total_bytes > 0:
            kbps = (total_bytes * 8) / elapsed / 1000
            lines.append(f"📊 API throughput: {kbps:.0f} kbps ({total_bytes/1024:.0f}KB in {elapsed:.1f}s)")

        # What this means for streaming
        lines.append("")
        lines.append("📺 Netflix requirements:")
        lines.append("  4K = 25mbps, 1080p = 5mbps, 720p = 3mbps")

        return "\n".join(lines)

    result = await asyncio.to_thread(measure)
    await msg.edit_text(result)


# Marker file written while a throttle is active ({"tv_ip", "bw"} JSON) so a
# crash can be cleaned up on the next start.
THROTTLE_STATE_FILE = "/share/sandman_throttle_active"
_throttler = None  # LinuxBandwidthThrottler from sandman.py, reused while active
_throttle_timer = None  # asyncio task that turns the throttle off automatically
_sandman_mod = None  # lazily imported sandman.py module


def _load_sandman_module():
    """Import SANDMAN_SCRIPT as a module named "sandman", once, and cache it.

    The module is cached only after it executes successfully. A failed import
    (missing file, syntax error, exception at import time) is therefore
    retried on the next call, rather than handing out a half-initialised
    module forever.

    Raises:
        ImportError: if no import spec/loader can be built for SANDMAN_SCRIPT.
        Exception: whatever executing sandman.py raises.
    """
    global _sandman_mod
    if _sandman_mod is None:
        import importlib.util
        spec = importlib.util.spec_from_file_location("sandman", SANDMAN_SCRIPT)
        if spec is None or spec.loader is None:
            raise ImportError(f"cannot load sandman module from {SANDMAN_SCRIPT}")
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
        _sandman_mod = module
    return _sandman_mod


def _cleanup_orphaned_throttle(gateway_ip: str = "192.168.1.254"):
    """On startup, undo a throttle left behind by a previous crash.

    Does nothing unless THROTTLE_STATE_FILE exists. Otherwise it:
      1. removes the tc rules (tv_control.throttle_remove),
      2. asks a sandman LinuxBandwidthThrottler to restore the ARP entries for
         the TV recorded in the state file and for ``gateway_ip``,
      3. sets net.ipv4.ip_forward=0,
      4. deletes the state file.
    Every step is best-effort: a failure is logged and the remaining steps
    still run, so a broken step cannot prevent the bot from starting.

    Args:
        gateway_ip: LAN gateway whose ARP entry is restored alongside the TV's.
            Defaults to the previously hard-coded router address.
    """
    if not os.path.exists(THROTTLE_STATE_FILE):
        return
    log.info("Found orphaned throttle state — cleaning up tc rules and ARP...")
    try:
        tv_control.throttle_remove()
    except Exception as e:
        # Must not propagate: post_init runs this, and the ARP restore should still happen.
        log.warning("Orphan cleanup tc removal failed: %s", e)

    # Restore correct ARP entries for the TV and the gateway
    try:
        state = json.loads(Path(THROTTLE_STATE_FILE).read_text())
        tv_ip = state.get("tv_ip")
        if tv_ip:
            mod = _load_sandman_module()
            t = mod.LinuxBandwidthThrottler(tv_ip)
            t._our_mac = t._get_our_mac()
            t._tv_mac = t._resolve_mac(tv_ip)
            t._gw_mac = t._resolve_mac(gateway_ip)
            if t._tv_mac and t._gw_mac:
                t._restore_arp()
                log.info("ARP restored for %s", tv_ip)
    except Exception as e:
        log.warning("Orphan cleanup ARP restore failed: %s", e)

    # NOTE(review): cmd_throttle's "off" path deliberately keeps the arp_mitm.py
    # daemon running; if it survived the crash, disabling forwarding here may
    # cut the TV off while it is still being ARP-spoofed — confirm intent.
    try:
        r = subprocess.run(["sysctl", "-w", "net.ipv4.ip_forward=0"], capture_output=True)
        if r.returncode != 0:
            log.warning("Orphan cleanup: sysctl exited %d", r.returncode)
    except OSError as e:
        log.warning("Orphan cleanup: could not run sysctl: %s", e)

    try:
        os.remove(THROTTLE_STATE_FILE)
    except FileNotFoundError:
        pass
    except OSError as e:
        log.warning("Orphan cleanup: could not remove %s: %s", THROTTLE_STATE_FILE, e)
    log.info("Orphaned throttle cleaned up")


def _save_throttle_state(tv_ip: str, bw: int):
    """Record an active throttle in THROTTLE_STATE_FILE as JSON.

    The file is what lets a restarted bot undo a throttle after a crash, so it
    is written atomically: the JSON goes to a sibling ``.tmp`` file that then
    replaces the target via os.replace. A crash mid-write therefore cannot
    leave a truncated, unparseable state file behind.

    Args:
        tv_ip: IP address of the throttled TV.
        bw: Applied bandwidth limit in kbps.
    """
    target = Path(THROTTLE_STATE_FILE)
    tmp = target.with_name(target.name + ".tmp")
    tmp.write_text(json.dumps({"tv_ip": tv_ip, "bw": bw}))
    os.replace(tmp, target)


def _clear_throttle_state():
    """Delete THROTTLE_STATE_FILE; it is not an error if the file is already gone."""
    Path(THROTTLE_STATE_FILE).unlink(missing_ok=True)


def _get_throttler():
    """Return the cached LinuxBandwidthThrottler, or build and cache a new one.

    The cached instance is reused only while it is active. Otherwise sandman.py
    is loaded and a throttler is created for the current (or cached) TV IP.
    Returns None when no TV IP is known.
    """
    global _throttler
    existing = _throttler
    if existing and existing.active:
        return existing
    sandman = _load_sandman_module()
    tv_ip = get_tv_ip_or_cached()
    if tv_ip:
        _throttler = sandman.LinuxBandwidthThrottler(tv_ip)
        return _throttler
    return None


async def cmd_throttle(update: Update, ctx: ContextTypes.DEFAULT_TYPE):
    """Show, apply or remove the bandwidth throttle on the TV.

    Usage:
        /throttle                 show whether an HTB qdisc is installed on end0
        /throttle off             cancel any pending auto-off and remove the tc rules
        /throttle on              throttle to 8000 kbps
        /throttle <kbps> [secs]   throttle to <kbps>; auto-off after <secs> if given

    If /share/arp_mitm_state.json is missing, /share/arp_mitm.py is started
    first and given 5 s to come up. Applying a throttle always cancels a pending
    auto-off timer from an earlier command, so that timer cannot remove the new
    throttle early. Owner-only; other senders are ignored.
    """
    global _throttle_timer
    if not is_owner(update):
        return
    args = ctx.args
    # No args: show current status
    if not args:
        try:
            # tc is a blocking subprocess; keep it off the event loop.
            tc_check = await asyncio.to_thread(
                subprocess.run, ["tc", "qdisc", "show", "dev", "end0"],
                capture_output=True, text=True,
            )
        except OSError as e:
            await update.message.reply_text(f"❌ Could not query tc: {e}")
            return
        is_throttled = "htb" in tc_check.stdout
        if is_throttled:
            await update.message.reply_text("🌐 Throttle is ON\n\n/throttle off — restore full speed")
        else:
            await update.message.reply_text("🌐 Throttle is OFF\n\n/throttle 2000 — degraded (deficit 0.50+)\n/throttle 1000 — visibly bad quality\n/throttle 800 — very degraded (deficit 0.70+)\n/throttle 500 — barely loads\n/throttle 2000 60 — for 60s\n\nSandman auto-throttles based on deficit level.")
        return

    arg = args[0].lower()

    if arg == "off":
        if _throttle_timer and not _throttle_timer.done():
            _throttle_timer.cancel()
        _throttle_timer = None
        # Remove tc rules only — keep MITM running for future throttles
        try:
            await asyncio.to_thread(tv_control.throttle_remove)
        except Exception as e:
            # Keep the state file so the startup cleanup can retry.
            await update.message.reply_text(f"❌ Throttle error: {e}")
            return
        _clear_throttle_state()
        await update.message.reply_text("🌐 Throttle OFF — full speed restored")
        return

    # Parse: /throttle 5000 or /throttle 5000 60
    # isdecimal rather than isdigit: int() rejects digit-like chars such as "²".
    if arg == "on":
        bw = 8000
    elif arg.isdecimal() and int(arg) > 0:
        bw = int(arg)
    else:
        await update.message.reply_text("Usage:\n/throttle 5000 — throttle to 5mbps (until /throttle off)\n/throttle 5000 60 — throttle for 60s\n/throttle off — restore full speed")
        return

    duration = 0  # 0 = no auto-off (manual /throttle off required)
    if len(args) >= 2 and args[1].isdecimal():
        duration = int(args[1])

    ip = get_tv_ip_or_cached()
    if not ip:
        await update.message.reply_text("❌ TV IP unknown")
        return

    # Auto-start MITM daemon if not running
    mitm_running = os.path.exists("/share/arp_mitm_state.json")
    if not mitm_running:
        msg = await update.message.reply_text(f"⏳ Starting MITM + {bw}kbps throttle...")
        # Start MITM daemon
        subprocess.Popen(["python3", "/share/arp_mitm.py"], start_new_session=True,
                         stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
        await asyncio.sleep(5)  # wait for ARP spoof to stabilize
        if not os.path.exists("/share/arp_mitm_state.json"):
            await msg.edit_text("❌ MITM failed to start — TV may be off")
            return
    else:
        msg = await update.message.reply_text(f"⏳ Applying {bw}kbps throttle...")

    # A pending auto-off belongs to an earlier throttle; left running it would
    # remove this one early (even if this one has no duration).
    if _throttle_timer and not _throttle_timer.done():
        _throttle_timer.cancel()
    _throttle_timer = None

    try:
        ok = await asyncio.to_thread(tv_control.throttle_apply, tv_ip=ip, bandwidth_kbps=bw)
        # Record state even on failure so /throttle off or the startup cleanup
        # can remove partially-applied rules.
        _save_throttle_state(ip, bw)
        # NOTE(review): assumes throttle_apply returns False on failure; None or
        # other values are treated as success, as before.
        if ok is False:
            await msg.edit_text(f"❌ Failed to apply {bw}kbps throttle — /throttle off to clean up")
            return

        if duration > 0:
            await msg.edit_text(f"🌐 Throttle ON — {bw}kbps (auto-off in {duration}s)")
            # There is no module-level `app`; the handler context carries the bot.
            bot = ctx.bot

            async def auto_off():
                await asyncio.sleep(duration)
                try:
                    await asyncio.to_thread(tv_control.throttle_remove)
                except Exception as e:
                    log.warning("Throttle auto-off failed: %s", e)
                    notice = f"❌ Throttle auto-off failed: {e}"
                else:
                    _clear_throttle_state()
                    notice = f"🌐 Throttle auto-off after {duration}s — full speed restored"
                if owner_chat_id:
                    try:
                        await bot.send_message(chat_id=owner_chat_id, text=notice)
                    except Exception as e:
                        log.warning("Failed to send throttle auto-off notice: %s", e)

            _throttle_timer = asyncio.create_task(auto_off())
        else:
            await msg.edit_text(f"🌐 Throttle ON — {bw}kbps (use /throttle off to stop)")

    except Exception as e:
        await msg.edit_text(f"❌ Throttle error: {e}")


# ── Deficit commands ───────────────────────────────────────────────────────


async def cmd_deficit(update: Update, ctx: ContextTypes.DEFAULT_TYPE):
    """Show or set the current attention deficit.

    /deficit         reply with the deficit, session/today watch time and the
                     estimated recovery time to 0.15 and 0.05
    /deficit <val>   store <val> (clamped to [0.0, 1.0]) as the deficit and
                     stamp last_updated with the local time

    Owner-only; other senders are ignored.
    """
    if not is_owner(update):
        return
    args = ctx.args

    # /deficit 0.7 — set deficit
    if args:
        try:
            val = float(args[0])
        except ValueError:
            val = math.nan
        # NaN would otherwise clamp silently to 1.0.
        if math.isnan(val):
            await update.message.reply_text("Usage: /deficit [0.0-1.0]")
            return
        val = max(0.0, min(1.0, val))
        state = read_state()
        state["deficit"] = val
        state["last_updated"] = datetime.datetime.now().isoformat()
        write_state(state)
        await update.message.reply_text(f"✅ Deficit set to {val:.2f}")
        return

    # /deficit — show current
    state = read_state()
    d = state.get("deficit", 0.0)
    session_min = 0
    if state.get("session_start"):
        try:
            ss = datetime.datetime.fromisoformat(state["session_start"])
            # now(None) is naive local time (matching a naive timestamp); an aware
            # timestamp gets an aware "now". Forcing UTC for naive timestamps made
            # the subtraction raise TypeError, so the session always showed 0.
            elapsed = datetime.datetime.now(ss.tzinfo) - ss
            session_min = max(0.0, elapsed.total_seconds() / 60)
        except (TypeError, ValueError):
            pass
    today_min = state.get("total_watch_today_min", 0)

    recovery_15 = _estimate_recovery(d, 0.15)
    recovery_05 = _estimate_recovery(d, 0.05)

    text = (
        f"🧠 Deficit: {d:.2f}\n"
        f"📊 Session: {_fmt_duration(session_min)} | Today: {_fmt_duration(today_min)}\n"
        f"🔮 Recovery: {recovery_15} to 0.15 | {recovery_05} to 0.05"
    )
    await update.message.reply_text(text)


def _simulate_parse_duration(text: str) -> float | None:
    """Parse a /simulate duration ("3h", "1.5h", "90m", "1h30m", bare hours "2") into minutes.

    Returns None when the text is not one of those forms.
    """
    num = r"(?:\d+(?:\.\d*)?|\.\d+)"
    text = text.strip().lower()
    if re.fullmatch(num, text):
        return float(text) * 60
    m = re.fullmatch(rf"(?:({num})h)?(?:({num})m)?", text)
    if not m or (m.group(1) is None and m.group(2) is None):
        return None
    return float(m.group(1) or 0) * 60 + float(m.group(2) or 0)


def _simulate_parse_clock(text: str) -> tuple[int, int] | None:
    """Parse "HH:MM" in 00:00–23:59 into (hour, minute); None if invalid."""
    hh, sep, mm = text.partition(":")
    if not sep or not hh.isdecimal() or not mm.isdecimal():
        return None
    hour, minute = int(hh), int(mm)
    if hour > 23 or minute > 59:
        return None
    return hour, minute


async def cmd_simulate(update: Update, ctx: ContextTypes.DEFAULT_TYPE):
    """Simulate deficit after watching N hours starting at HH:MM.
    Usage: /simulate 3h 23:00

    Starts from the stored deficit and, for each whole minute watched, adds
    BASE_IMPACT_RATE * ln(1 + session_minutes / 30) * time_multiplier, capping
    at 1.0. The session is assumed to start at HH:MM; recovery is not modelled.
    Durations above 24 h are rejected because the per-minute loop runs on the
    event loop.
    """
    if not is_owner(update):
        return
    args = ctx.args
    if not args or len(args) < 2:
        await update.message.reply_text("Usage: /simulate 3h 23:00")
        return

    # Parse duration (minutes are computed directly, so no h→min float round-trip
    # can truncate a minute)
    total_min = _simulate_parse_duration(args[0])
    if total_min is None:
        await update.message.reply_text("Bad duration. Use e.g. 3h or 90m")
        return
    if total_min > 24 * 60:
        await update.message.reply_text("Duration too long (max 24h)")
        return

    # Parse start time
    start = _simulate_parse_clock(args[1])
    if start is None:
        await update.message.reply_text("Bad time. Use e.g. 23:00")
        return
    start_h, start_m = start

    state = read_state()
    deficit = state.get("deficit", 0.0)
    start_deficit = deficit

    # Simulate minute by minute
    cur_h, cur_m = start_h, start_m
    for minute in range(int(total_min)):
        tm = _time_multiplier_at(cur_h, cur_m)
        session_min_so_far = minute + 1
        impact = BASE_IMPACT_RATE * math.log(1 + session_min_so_far / 30) * tm
        deficit = min(deficit + impact, 1.0)
        cur_m += 1
        if cur_m >= 60:
            cur_m = 0
            cur_h = (cur_h + 1) % 24

    end_time = f"{cur_h:02d}:{cur_m:02d}"
    text = (
        f"📈 Simulation: {args[0]} from {args[1]}\n"
        f"Start deficit: {start_deficit:.2f}\n"
        f"End deficit:   {deficit:.2f} (at {end_time})\n"
        f"Delta:         +{deficit - start_deficit:.2f}"
    )
    await update.message.reply_text(text)


async def cmd_curve(update: Update, ctx: ContextTypes.DEFAULT_TYPE):
    """Show text bar chart of deficit projection over next 4 hours.

    Assumes continuous watching from now, continuing the current session when
    state has a session_start. One row every 30 minutes (9 rows: now … +4h);
    rows at or above 0.70 are marked with ⚡. The minute-by-minute model is
    simulated once across all rows (each row extends the previous one), giving
    the same values as re-simulating every row from scratch.
    """
    if not is_owner(update):
        return

    state = read_state()
    deficit = state.get("deficit", 0.0)
    now = datetime.datetime.now()
    cur_h, cur_m = now.hour, now.minute

    BAR_WIDTH = 20
    STEP_MIN = 30
    STEPS = 9

    session_min_so_far = 0
    if state.get("session_start"):
        try:
            ss = datetime.datetime.fromisoformat(state["session_start"])
            # Take "now" in the timestamp's own zone (naive → local). Stripping an
            # aware timestamp's tzinfo mis-measured the session by the UTC offset.
            elapsed = datetime.datetime.now(ss.tzinfo) - ss
            session_min_so_far = max(0, elapsed.total_seconds() / 60)
        except Exception:
            pass

    lines = []
    sim_deficit = deficit
    simulated_until = 0  # minutes from now already folded into sim_deficit
    for step in range(STEPS):
        offset = step * STEP_MIN
        for m in range(simulated_until, offset):
            sm = session_min_so_far + m + 1
            tm = _time_multiplier_at((cur_h + (cur_m + m) // 60) % 24, (cur_m + m) % 60)
            impact = BASE_IMPACT_RATE * math.log(1 + sm / 30) * tm
            sim_deficit = min(sim_deficit + impact, 1.0)
        simulated_until = offset

        t_h = (cur_h + (cur_m + offset) // 60) % 24
        t_m = (cur_m + offset) % 60
        label = f"{t_h:02d}:{t_m:02d}"

        # Clamp for drawing only, so an out-of-range stored deficit can't bend the bar width.
        filled = int(min(max(sim_deficit, 0.0), 1.0) * BAR_WIDTH)
        bar = "█" * filled + "░" * (BAR_WIDTH - filled)
        warn = " ⚡" if sim_deficit >= 0.7 else ""
        lines.append(f"{label} {bar} {sim_deficit:.2f}{warn}")

    await update.message.reply_text("📊 Deficit projection (if watching):\n```\n" + "\n".join(lines) + "\n```", parse_mode=ParseMode.MARKDOWN)


async def cmd_speed(update: Update, ctx: ContextTypes.DEFAULT_TYPE):
    """Set sandman speed multiplier. /speed 10 = 10x faster, /speed 1 = normal.

    Without arguments, replies with the multiplier stored in SPEED_FILE (1 if
    it is missing or unreadable). With an argument in [0.1, 100], writes
    {"speed": <value>} to SPEED_FILE. Owner-only.
    """
    if not is_owner(update):
        return
    args = ctx.args
    if not args:
        # Show current speed
        try:
            with open(SPEED_FILE) as f:
                data = json.load(f)
            mult = data.get("speed", 1)
        except Exception:
            mult = 1
        await update.message.reply_text(f"⏩ Speed multiplier: {mult}x\nUsage: /speed 10 or /speed 1")
        return

    try:
        mult = float(args[0])
    except ValueError:
        await update.message.reply_text("Usage: /speed 10")
        return
    # A chained range test also rejects NaN, which fails every comparison and so
    # slipped past the former `< 0.1 or > 100` check.
    if not 0.1 <= mult <= 100:
        await update.message.reply_text("Speed must be 0.1–100")
        return

    # Write-then-rename so a concurrent reader never sees a half-written file.
    tmp_path = SPEED_FILE + ".tmp"
    try:
        with open(tmp_path, "w") as f:
            json.dump({"speed": mult}, f)
        os.replace(tmp_path, SPEED_FILE)
    except OSError as e:
        await update.message.reply_text(f"❌ Could not save speed: {e}")
        return

    if mult == 1:
        await update.message.reply_text("⏩ Speed reset to 1x (normal)")
    else:
        await update.message.reply_text(f"⏩ Speed set to {mult}x")


async def cmd_history(update: Update, ctx: ContextTypes.DEFAULT_TYPE):
    """Reply with per-day watch time for the 14 latest dates in the state's daily_log.

    A day's entry is either a minute count or a dict with a "watch_min" key;
    days with 0 minutes are labelled as rest days. Owner-only.
    """
    if not is_owner(update):
        return
    daily_log = read_state().get("daily_log", {})
    if not daily_log:
        await update.message.reply_text("📅 No watch history yet.")
        return

    def describe(entry):
        watched = entry.get("watch_min", 0) if isinstance(entry, dict) else entry
        return "0m (rest day ✅)" if watched == 0 else _fmt_duration(watched)

    # ISO date keys sort chronologically; keep only the most recent 14
    recent_days = sorted(daily_log)[-14:]
    rows = [f"  {day}: {describe(daily_log[day])}" for day in recent_days]
    await update.message.reply_text("\n".join(["📅 Watch history:", *rows]))


# ── Background tasks ───────────────────────────────────────────────────────


async def log_watcher(app: Application):
    """Poll sandman.log for new lines and push to Telegram.

    Starts at the current end of LOG_FILE (existing history is not replayed),
    then checks every LOG_POLL_INTERVAL seconds. If the file shrinks (truncated
    or rotated), reading restarts from offset 0. Blank lines, [DEBUG] lines and
    "Auto-discovering" lines are dropped; the rest pass through fmt_log_line
    and are sent to the owner in chunks of at most 4000 characters.

    The file is read as bytes and decoded with errors="replace", so a stray
    non-UTF-8 byte can no longer make every poll fail without advancing the
    offset. Only newline-terminated lines are consumed; a line still being
    written is delivered whole on a later poll.
    """
    global log_offset
    if os.path.exists(LOG_FILE):
        log_offset = os.path.getsize(LOG_FILE)
    else:
        log_offset = 0

    while True:
        await asyncio.sleep(LOG_POLL_INTERVAL)
        if not owner_chat_id:
            continue
        try:
            if not os.path.exists(LOG_FILE):
                continue
            size = os.path.getsize(LOG_FILE)
            if size < log_offset:
                log_offset = 0
            if size == log_offset:
                continue
            with open(LOG_FILE, "rb") as f:
                f.seek(log_offset)
                data = f.read()
            end = data.rfind(b"\n")
            if end < 0:
                continue  # only an unfinished line so far; pick it up next poll
            log_offset += end + 1
            text = data[: end + 1].decode("utf-8", errors="replace")
            # Mirror text-mode universal newlines: "\r\n" and "\r" both end a line.
            text = text.replace("\r\n", "\n").replace("\r", "\n")
            new_lines = [line + "\n" for line in text.split("\n")[:-1]]
            # Format and send — skip DEBUG level and boring lines
            formatted = [
                fmt_log_line(line)
                for line in new_lines
                if line.strip() and "[DEBUG]" not in line and "Auto-discovering" not in line
            ]
            if not formatted:
                continue
            message = "\n".join(formatted)
            for i in range(0, len(message), 4000):
                chunk = message[i:i + 4000]
                try:
                    await app.bot.send_message(chat_id=owner_chat_id, text=chunk)
                except Exception as e:
                    log.warning("Failed to send log: %s", e)
        except Exception as e:
            log.warning("Log watcher error: %s", e)


# Package-name → display-name lookup from tv_control, used by state_monitor for
# "App →" notifications (unknown packages fall back to the raw package name).
APP_NAMES = tv_control.APP_NAMES


# Last observed mute flag / ambilight power state, tracked by state_monitor to
# detect changes. None means "not observed yet": the first reading is only
# recorded, not announced.
last_muted: bool | None = None
last_ambilight: str | None = None


async def state_monitor(app: Application):
    """Poll the TV every 5 s and message the owner when its state changes.

    Tracks power, foreground app, playback state, title, volume, mute and
    ambilight in module globals. A change is announced only when a previous
    value exists (None = not yet observed), so each field's first reading is
    recorded silently. If powerstate cannot be read while the TV was last seen
    "On", one "TV unreachable" message is sent and the power state is reset.
    Unexpected errors are logged and polling continues.
    """
    global last_power, last_app, last_volume, last_muted, last_ambilight, last_playback, last_title

    async def notify(text):
        # Best-effort delivery: a failed send must never stop the monitor.
        try:
            await app.bot.send_message(chat_id=owner_chat_id, text=text)
        except Exception:
            pass

    while True:
        await asyncio.sleep(5)  # poll every 5s for responsiveness
        if not owner_chat_id:
            continue
        try:
            pw = await asyncio.to_thread(js_get, "powerstate")
            if not pw:
                if last_power == "On":
                    last_power = None
                    await notify("❓ TV unreachable")
                continue

            power = pw.get("powerstate")
            if last_power is not None and power != last_power:
                badge = "📺" if power == "On" else "💤"
                await notify(f"{badge} TV → {power}")
            last_power = power
            if power != "On":
                continue

            # Foreground app + playback, via ADB
            pkg, playback, extras = await asyncio.to_thread(adb_get_current_app)
            if pkg and pkg != last_app:
                if last_app is not None:
                    headline = f"📱 App → {APP_NAMES.get(pkg, pkg)}"
                    if playback:
                        headline += f" ({playback})"
                    message_lines = [headline]
                    if extras and extras.get("video_id"):
                        message_lines.append(f"🎬 youtu.be/{extras['video_id']}")
                    await notify("\n".join(message_lines))
                last_app = pkg

            # PLAYING / PAUSED / other
            if playback and playback != last_playback:
                if last_playback is not None:
                    if playback == "PLAYING":
                        badge = "▶️"
                    elif playback == "PAUSED":
                        badge = "⏸"
                    else:
                        badge = "⏹"
                    await notify(f"{badge} {playback}")
                last_playback = playback

            # Title / video
            title = extras["title"] if extras and extras.get("title") else None
            if title is not None and title != last_title:
                if last_title is not None:
                    announcement = f"🎬 {title}"
                    artist = extras.get("artist")
                    if artist:
                        announcement += f" — {artist}"
                    await notify(announcement)
                last_title = title

            # Volume + mute
            vol_data = await asyncio.to_thread(js_get, "audio/volume")
            if vol_data:
                cur = vol_data.get("current")
                muted = vol_data.get("muted", False)
                if cur is not None and last_volume is not None and cur != last_volume:
                    diff = cur - last_volume
                    arrow = "🔺" if diff > 0 else "🔻"
                    sign = "+" if diff > 0 else ""
                    await notify(f"{arrow} Volume {last_volume} → {cur} ({sign}{diff})")
                last_volume = cur
                if last_muted is not None and muted != last_muted:
                    await notify("🔇 Muted" if muted else "🔊 Unmuted")
                last_muted = muted

            # Ambilight
            ambi = await asyncio.to_thread(js_get, "ambilight/power")
            if ambi:
                ambi_state = ambi.get("power", "")
                if last_ambilight is not None and ambi_state != last_ambilight:
                    await notify(f"💡 Ambilight → {ambi_state}")
                last_ambilight = ambi_state
        except Exception as e:
            log.warning("State monitor error: %s", e)


async def adb_input_monitor(app: Application):
    """Monitor TV remote key presses via ADB getevent. Requires wireless debugging.

    Returns immediately if adb_available() is false. Otherwise it connects ADB
    to <tv_ip>:5555 and streams ``adb shell getevent -l``, sending one
    "🎮 Remote: <key>" message per key-down event. When the stream ends or
    fails, the getevent process is killed and reaped. The monitor then marks
    ADB disconnected and retries after 10 s. Previously an EOF respawned
    getevent at once, in a tight loop.
    """
    if not adb_available():
        log.info("ADB not available — skipping input monitor")
        return

    KEY_NAMES = {
        "0073": "Vol+", "0072": "Vol-", "0074": "Power", "0066": "Home",
        "009e": "Back", "0160": "OK/Confirm", "0067": "Up", "006c": "Down",
        "0069": "Left", "006a": "Right", "00a4": "Pause", "00cf": "Play",
        "00d0": "PlayPause", "00a5": "FastFwd", "00a6": "Rewind",
        "0071": "Mute", "0192": "Menu", "0174": "Exit",
        "0193": "Red", "0194": "Green", "0195": "Yellow", "0196": "Blue",
        "00e2": "CH+", "00e3": "CH-", "0090": "Num0", "0087": "Num1",
        "0088": "Num2", "0089": "Num3", "008a": "Num4", "008b": "Num5",
        "008c": "Num6", "008d": "Num7", "008e": "Num8", "008f": "Num9",
        "0166": "Info", "016b": "Guide", "00ae": "Source",
    }

    ip = None
    connected = False

    while True:
        if not owner_chat_id:
            await asyncio.sleep(5)
            continue

        # Get TV IP and connect ADB
        if not connected:
            ip = get_tv_ip_or_cached()
            if not ip:
                await asyncio.sleep(10)
                continue
            target = f"{ip}:5555"
            try:
                r = await asyncio.to_thread(
                    lambda: subprocess.run(["adb", "connect", target],
                                           capture_output=True, text=True, timeout=10))
                if "connected" in r.stdout.lower() or "already" in r.stdout.lower():
                    connected = True
                    log.info("ADB connected to %s for input monitoring", target)
                else:
                    log.info("ADB connect failed: %s", r.stdout.strip())
                    await asyncio.sleep(30)
                    continue
            except Exception as e:
                log.warning("ADB connect error: %s", e)
                await asyncio.sleep(30)
                continue

        # Stream getevent — each key press is a line
        target = f"{ip}:5555"
        proc = None  # so cleanup never touches an unbound or stale process
        try:
            proc = await asyncio.create_subprocess_exec(
                "adb", "-s", target, "shell", "getevent", "-l",
                stdout=asyncio.subprocess.PIPE,
                # stderr is never read; a PIPE could fill up and stall adb.
                stderr=asyncio.subprocess.DEVNULL,
            )

            while True:
                try:
                    line = await asyncio.wait_for(proc.stdout.readline(), timeout=60)
                except asyncio.TimeoutError:
                    continue
                if not line:
                    break
                text = line.decode(errors="replace").strip()
                # Filter for key DOWN events: EV_KEY KEY_xxx DOWN
                if "EV_KEY" not in text or "DOWN" not in text:
                    continue
                # Format: /dev/input/eventX EV_KEY KEY_NAME DOWN
                key_label = None
                for p in text.split():
                    if p.startswith("KEY_"):
                        key_label = p.replace("KEY_", "")
                        break
                    # Sometimes it's a hex code
                    if len(p) == 4 and all(c in "0123456789abcdef" for c in p.lower()):
                        key_label = KEY_NAMES.get(p.lower(), p)

                if key_label:
                    try:
                        await app.bot.send_message(
                            chat_id=owner_chat_id, text=f"🎮 Remote: {key_label}"
                        )
                    except Exception:
                        pass

            # EOF: getevent exited (device offline / ADB dropped)
            log.info("ADB getevent stream for %s ended — reconnecting", target)
        except Exception as e:
            log.warning("ADB getevent error: %s", e)
        finally:
            if proc is not None:
                if proc.returncode is None:
                    try:
                        proc.kill()
                    except ProcessLookupError:
                        pass
                try:
                    await proc.wait()  # reap the child
                except Exception:
                    pass

        connected = False
        await asyncio.sleep(10)


# Strong references to the background tasks started by post_init. The event loop
# only keeps weak references to tasks, so an unreferenced task may be
# garbage-collected before it finishes.
_background_tasks: set[asyncio.Task] = set()


def _on_background_task_done(task: asyncio.Task) -> None:
    """Forget a finished background task and log it if it ended with an exception."""
    _background_tasks.discard(task)
    if task.cancelled():
        return
    exc = task.exception()
    if exc is not None:
        log.error("Background task %s crashed", task.get_name(), exc_info=exc)


async def post_init(app: Application):
    """Start background tasks after bot init.

    Loads the owner chat id, registers the bot's command menu, cleans up a
    throttle orphaned by a previous crash, discovers the TV, then starts
    log_watcher, state_monitor and adb_input_monitor as background tasks.
    Cleanup and discovery failures are logged rather than aborting startup.
    """
    global owner_chat_id
    owner_chat_id = load_owner()
    log.info("Bot starting. Owner chat_id: %s", owner_chat_id)

    await app.bot.set_my_commands([
        BotCommand("status", "Quick overview"),
        BotCommand("tv", "📺 TV Remote"),
        BotCommand("sandman", "🌙 Sandman"),
        BotCommand("devices", "🏠 Devices"),
        BotCommand("power", "Toggle TV power"),
        BotCommand("pause", "Play/Pause"),
        BotCommand("mute", "Toggle mute"),
        BotCommand("ss", "Screenshot to chat"),
    ])

    # Clean up any orphaned throttle from a previous crash
    try:
        await asyncio.to_thread(_cleanup_orphaned_throttle)
    except Exception as e:
        log.warning("Orphaned throttle cleanup failed: %s", e)

    log.info("Discovering TV...")
    try:
        ip = await asyncio.to_thread(get_tv_ip)
    except Exception as e:
        # The monitors look the IP up themselves, so keep starting.
        log.warning("TV discovery failed: %s", e)
        ip = None
    log.info("TV IP: %s", ip)

    for coro in (log_watcher(app), state_monitor(app), adb_input_monitor(app)):
        task = asyncio.create_task(coro)
        _background_tasks.add(task)
        task.add_done_callback(_on_background_task_done)


# ── Main ────────────────────────────────────────────────────────────────────


def _graceful_shutdown(signum, frame):
    """Clean up throttle on SIGTERM/SIGINT before exit.

    There are two throttle paths:
      - the cached sandman throttler (``_throttler``), stopped if active;
      - the tc rules applied by /throttle through tv_control, whose only record
        is THROTTLE_STATE_FILE. Previously these were left in place until the
        bot next started.
    Both cleanups are best-effort (failures are logged, and a remaining state
    file lets the next startup retry). Always exits with status 0.
    """
    log.info("Received signal %d — cleaning up...", signum)
    if _throttler and _throttler.active:
        try:
            _throttler.stop()
            _clear_throttle_state()
            log.info("Throttle stopped gracefully")
        except Exception as e:
            log.warning("Throttle cleanup failed: %s", e)
    if os.path.exists(THROTTLE_STATE_FILE):
        try:
            tv_control.throttle_remove()
            _clear_throttle_state()
            log.info("tc throttle removed")
        except Exception as e:
            log.warning("tc throttle cleanup failed: %s", e)
    sys.exit(0)


# Register _graceful_shutdown at import time so throttle cleanup runs on
# SIGTERM/SIGINT.
# NOTE(review): Application.run_polling() appears to install its own handlers
# for its stop_signals (SIGINT/SIGTERM/SIGABRT by default) through
# loop.add_signal_handler, which would replace these while polling — confirm;
# if so, move this cleanup into a post_shutdown hook.
import signal as _signal
_signal.signal(_signal.SIGTERM, _graceful_shutdown)
_signal.signal(_signal.SIGINT, _graceful_shutdown)


def main():
    """Build the Telegram application, register every handler and poll for updates.

    Blocks inside run_polling() until the bot stops.
    """
    application = Application.builder().token(BOT_TOKEN).post_init(post_init).build()

    # Registration order is kept exactly as before (the dev: callback handler
    # sits between /sandman and /mitm).
    handlers = [
        CommandHandler("start", cmd_start),
        CommandHandler("status", cmd_status),
        CommandHandler("deficit", cmd_deficit),
        CommandHandler("simulate", cmd_simulate),
        CommandHandler("curve", cmd_curve),
        CommandHandler("speed", cmd_speed),
        CommandHandler("history", cmd_history),
        CommandHandler("screenshot", cmd_screenshot),
        CommandHandler("ss", cmd_screenshot),
        CommandHandler("mute", cmd_mute),
        CommandHandler("volume", cmd_volume),
        CommandHandler("pause", cmd_pause),
        CommandHandler("play", cmd_play),
        CommandHandler("home", cmd_home),
        CommandHandler("power", cmd_power),
        CommandHandler("test", cmd_test),
        CommandHandler("prod", cmd_prod),
        CommandHandler("stop", cmd_stop),
        CommandHandler("start_sandman", cmd_start_sandman),
        CommandHandler("log", cmd_log),
        CommandHandler("discover", cmd_discover),
        CommandHandler("devices", cmd_devices),
        CommandHandler("tv", cmd_tv),
        CommandHandler("sandman", cmd_sandman_app),
        CallbackQueryHandler(device_callback, pattern="^dev:"),
        CommandHandler("mitm", cmd_mitm),
        CommandHandler("netspeed", cmd_netspeed),
        CommandHandler("throttle", cmd_throttle),
    ]
    for handler in handlers:
        application.add_handler(handler)

    log.info("Sandman bot starting polling...")
    application.run_polling(allowed_updates=Update.ALL_TYPES)


# Run the bot only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
