#!/usr/bin/env python3
"""splice-ctx — SPLICE context/memory manager CLI.

Discovers context files (Claude memories, project MEMORY.md cards, workspace
docs), tracks which are enabled for AI asks, and emits a token-budgeted
context blob. The override rule: the AI (`--by ai`) never flips an entry the
user set by hand (`set_by == "user"`).

State: ~/.config/splice/contexts.json
"""
import argparse
import glob
import json
import os
import re
import signal
import subprocess
import sys

# Restore the default SIGPIPE disposition so that writing to a closed pipe
# (e.g. `splice-ctx list | head -1`) ends the process via the signal instead
# of surfacing as a BrokenPipeError traceback.
signal.signal(signal.SIGPIPE, signal.SIG_DFL)


def _share_dir_bootstrap():
    """Return the directory expected to contain ``splice_common.py``.

    Resolution order:
      1. ``$SPLICE_SHARE`` when it is set and names an existing directory
         (returned as an absolute path);
      2. ``../share/splice`` relative to the directory of ``sys.argv[0]``,
         when that directory contains ``splice_common.py`` (dev checkout);
      3. the fixed fallback ``/usr/share/splice``.
    """
    override = os.environ.get("SPLICE_SHARE")
    if override and os.path.isdir(override):
        return os.path.abspath(override)

    script_dir = os.path.dirname(os.path.abspath(sys.argv[0]))
    candidate = os.path.normpath(
        os.path.join(script_dir, "..", "share", "splice")
    )
    marker = os.path.join(candidate, "splice_common.py")
    return candidate if os.path.isfile(marker) else "/usr/share/splice"


# Put the resolved share directory first on sys.path so the project-local
# helper module can be imported; the import has to come after this tweak,
# which is why the late-import lint (E402) is suppressed.
sys.path.insert(0, _share_dir_bootstrap())
import splice_common as sc  # noqa: E402


# ---------------------------------------------------------------- helpers

def read_text(path):
    """Return the contents of *path* as text, or ``None`` if it cannot be read.

    The file is decoded as UTF-8; undecodable bytes are replaced rather than
    raising. Any ``OSError`` (missing file, directory, permissions, read
    failure) is reported as ``None``.
    """
    try:
        handle = open(path, encoding="utf-8", errors="replace")
    except OSError:
        return None
    try:
        with handle:
            return handle.read()
    except OSError:
        return None


def title_for(path, text):
    """Pick a display title for a context file.

    Preference order: the first markdown H1 (``# Title``) in the document
    body, else the ``description:`` value from a leading frontmatter block,
    else the file's basename.

    Lines inside a properly closed frontmatter block and inside fenced code
    blocks (``` or ~~~) are ignored when looking for the H1, so a YAML or
    shell ``# comment`` is not mistaken for the title. An empty description
    (e.g. ``description: ""``) is treated as absent.

    Args:
        path: File path; only its basename is used, as the last fallback.
        text: File contents, or ``None``/empty when unreadable or blank.

    Returns:
        The chosen title string.
    """
    if text:
        lines = text.splitlines()
        description = None
        body_start = 0
        # frontmatter block: leading --- ... ---
        if lines and lines[0].strip() == "---":
            for idx in range(1, len(lines)):
                line = lines[idx]
                if line.strip() in ("---", "..."):
                    # Only a terminated block hides its lines from the H1
                    # scan; an unterminated one leaves body_start at 0.
                    body_start = idx + 1
                    break
                if description is None:
                    m = re.match(r"^description:\s*(.+?)\s*$", line)
                    if m:
                        description = m.group(1).strip("\"'") or None

        in_fence = False
        for line in lines[body_start:]:
            if line.lstrip().startswith(("```", "~~~")):
                in_fence = not in_fence
                continue
            if in_fence:
                continue
            m = re.match(r"^#\s+(.+?)\s*$", line)
            if m:
                return m.group(1)
        if description:
            return description
    return os.path.basename(path)


def discover():
    """Return a list of (id, path, kind) for every context found via settings.

    Two settings keys are consulted (both read via ``sc.load_settings()``):

    * ``memory_dirs`` — each directory is listed non-recursively; files ending
      in .md/.markdown/.txt become ``("mem:<stem>", path, "memory")``.
    * ``project_globs`` — each glob match becomes a ``"project"`` context with
      id ``proj:<parent dir name>`` for files named ``MEMORY.md`` and
      ``ws:<filename>`` otherwise.

    Ids are de-duplicated: the first path producing a given id wins and later
    ones are silently skipped. Paths are made absolute and must be regular
    files. Directories that are missing or cannot be listed are skipped
    instead of aborting discovery, and a null/absent setting is treated as an
    empty list.
    """
    settings = sc.load_settings()
    seen = set()
    out = []

    def emit(cid, path, kind):
        # Single gate for every candidate: unique id + existing regular file.
        path = os.path.abspath(path)
        if cid in seen or not os.path.isfile(path):
            return
        seen.add(cid)
        out.append((cid, path, kind))

    for d in settings.get("memory_dirs") or []:
        d = os.path.expanduser(d)
        try:
            names = sorted(os.listdir(d))
        except OSError:
            # not a directory, vanished, or unreadable — nothing to discover
            continue
        for name in names:
            if not name.lower().endswith((".md", ".markdown", ".txt")):
                continue
            base = os.path.splitext(name)[0]
            emit("mem:" + base, os.path.join(d, name), "memory")

    for pattern in settings.get("project_globs") or []:
        for path in sorted(glob.glob(os.path.expanduser(pattern))):
            name = os.path.basename(path)
            if name == "MEMORY.md":
                cid = "proj:" + os.path.basename(os.path.dirname(path))
            else:
                cid = "ws:" + name
            emit(cid, path, "project")
    return out


def total_enabled_tokens(data):
    """Sum ``tokens_est`` over contexts that are enabled and not missing.

    Entries without a ``tokens_est`` key count as 0.
    """
    total = 0
    for ctx in data["contexts"]:
        if not ctx.get("enabled") or ctx.get("missing"):
            continue
        total += ctx.get("tokens_est", 0)
    return total


# ---------------------------------------------------------------- commands

def cmd_scan(args):
    """Rebuild the stored context list from what ``discover()`` finds now.

    * Re-discovered ids keep their previous ``enabled``/``set_by`` values;
      brand-new ids start disabled with ``set_by="default"``. Title and token
      estimate are recomputed from the file on every scan.
    * Previously stored entries of kind ``"file"`` (user-added) that were not
      re-discovered are kept: flagged ``missing`` when unreadable, otherwise
      refreshed (stored title kept when non-empty).
    * Every other entry that was not re-discovered is dropped, whether or not
      its file still exists on disk.

    Saves the result and prints a one-line summary. Always returns 0.
    ``args`` is unused.
    """
    data = sc.load_contexts()
    previous = {c["id"]: c for c in data["contexts"]}
    found = discover()
    found_ids = {cid for cid, _, _ in found}

    contexts = []
    new = removed = missing = 0

    for cid, path, kind in found:
        prev = previous.get(cid)
        if prev is None:
            new += 1
        text = read_text(path)
        contexts.append({
            "id": cid,
            "title": title_for(path, text),
            "path": path,
            "kind": kind,
            "enabled": prev.get("enabled", False) if prev else False,
            "set_by": prev.get("set_by", "default") if prev else "default",
            "tokens_est": 0 if text is None else sc.est_tokens(text),
        })

    for cid, prev in previous.items():
        if cid in found_ids:
            continue
        if prev.get("kind") != "file":
            # not re-discovered and not user-added: dropped regardless of
            # whether the path is still present on disk
            removed += 1
            continue
        entry = dict(prev)
        text = read_text(prev.get("path", ""))
        if text is None:
            entry["missing"] = True
            missing += 1
        else:
            entry.pop("missing", None)
            entry["title"] = prev.get("title") or title_for(prev["path"], text)
            entry["tokens_est"] = sc.est_tokens(text)
        contexts.append(entry)

    data["contexts"] = contexts
    sc.save_contexts(data)
    summary = (len(contexts), new, removed, missing, total_enabled_tokens(data))
    print(
        "scan: %d contexts (%d new, %d removed, %d missing), ~%dt enabled"
        % summary
    )
    return 0


def cmd_list(args):
    """Print the stored contexts, as JSON (``--json``) or as a text table.

    The JSON form contains the raw context entries plus
    ``total_enabled_tokens``. The text form prints one row per context
    (checkbox, id, token estimate, who set it, title, missing flag) followed
    by the enabled-token total, or a hint to run ``scan`` when the list is
    empty. Always returns 0.
    """
    data = sc.load_contexts()
    contexts = data["contexts"]
    total = total_enabled_tokens(data)

    if args.json:
        payload = {"contexts": contexts, "total_enabled_tokens": total}
        print(json.dumps(payload, indent=2))
        return 0

    if not contexts:
        print("no contexts — run: splice-ctx scan")
        return 0

    for ctx in contexts:
        checkbox = "x" if ctx.get("enabled") else " "
        missing_tag = "  [missing]" if ctx.get("missing") else ""
        row = "[%s] %-28s ~%dt  (%s)  %s%s" % (
            checkbox,
            ctx["id"],
            ctx.get("tokens_est", 0),
            ctx.get("set_by", "default"),
            ctx.get("title", ""),
            missing_tag,
        )
        print(row)
    print("total enabled: ~%dt" % total)
    return 0


def _set_enabled(ids, by, value):
    """Set ``enabled`` to *value* on each listed context, attributed to *by*.

    Unknown ids are reported on stderr and make the return code 1; the
    remaining ids are still processed. When *by* is ``"ai"``, entries whose
    ``set_by`` is ``"user"`` are left untouched (the override rule). State is
    written back only if at least one entry was updated.

    Returns:
        0 when every id was known, otherwise 1.
    """
    data = sc.load_contexts()
    index = {ctx["id"]: ctx for ctx in data["contexts"]}
    verb = "enabled" if value else "disabled"
    status = 0
    dirty = False

    for cid in ids:
        ctx = index.get(cid)
        if ctx is None:
            print("splice-ctx: unknown context id: %s" % cid, file=sys.stderr)
            status = 1
            continue
        user_pinned = ctx.get("set_by") == "user"
        if user_pinned and by == "ai":
            print("respecting user override: %s" % cid)
            continue
        ctx["enabled"] = value
        ctx["set_by"] = by
        dirty = True
        print("%s: %s (by %s)" % (verb, cid, by))

    if dirty:
        sc.save_contexts(data)
    return status


def cmd_enable(args):
    """Handle ``enable``: turn the given context ids on, attributed to ``--by``."""
    return _set_enabled(ids=args.ids, by=args.by, value=True)


def cmd_disable(args):
    """Handle ``disable``: turn the given context ids off, attributed to ``--by``."""
    return _set_enabled(ids=args.ids, by=args.by, value=False)


def cmd_release(args):
    """Reset ``set_by`` to ``"default"`` on the given contexts.

    This lifts a user pin so ``--by ai`` changes apply again; the ``enabled``
    flag itself is not touched. Unknown ids are reported on stderr and yield
    return code 1. State is saved only if something was released.
    """
    data = sc.load_contexts()
    index = {ctx["id"]: ctx for ctx in data["contexts"]}
    status = 0
    released_any = False

    for cid in args.ids:
        ctx = index.get(cid)
        if ctx is None:
            print("splice-ctx: unknown context id: %s" % cid, file=sys.stderr)
            status = 1
            continue
        ctx["set_by"] = "default"
        released_any = True
        print("released: %s" % cid)

    if released_any:
        sc.save_contexts(data)
    return status


def cmd_add(args):
    """Register an arbitrary readable file as a user-pinned, enabled context.

    If the same absolute path was added before (an entry of kind ``"file"``),
    that entry is refreshed; otherwise a new one is appended with id
    ``file:<stem>``, suffixed ``-2``, ``-3``, ... until the id is unused.
    The title comes from ``--title`` or, failing that, from the file itself.

    Returns:
        0 on success, 1 if the file cannot be read.
    """
    path = os.path.abspath(os.path.expanduser(args.path))
    text = read_text(path)
    if text is None:
        print("splice-ctx: cannot read file: %s" % path, file=sys.stderr)
        return 1

    data = sc.load_contexts()
    contexts = data["contexts"]

    # reuse the entry if this exact path was added before
    entry = None
    for candidate in contexts:
        if candidate.get("kind") == "file" and candidate.get("path") == path:
            entry = candidate
            break

    if entry is None:
        stem = os.path.splitext(os.path.basename(path))[0] or "file"
        taken = {ctx["id"] for ctx in contexts}
        cid = "file:" + stem
        suffix = 2
        while cid in taken:
            cid = "file:%s-%d" % (stem, suffix)
            suffix += 1
        entry = {"id": cid, "path": path, "kind": "file"}
        contexts.append(entry)

    entry["title"] = args.title or title_for(path, text)
    entry["enabled"] = True
    entry["set_by"] = "user"
    entry["tokens_est"] = sc.est_tokens(text)
    entry.pop("missing", None)

    sc.save_contexts(data)
    print("added: %s ~%dt  %s" % (entry["id"], entry["tokens_est"], entry["title"]))
    return 0


# Final line cmd_blob appends to a context block whose body was cut short
# to stay within the character budget.
TRIM_MARKER = "[...trimmed by splice for token budget...]"


def cmd_blob(args):
    """Write the enabled contexts to stdout as ``### CONTEXT: <title>`` blocks.

    Blocks are joined by a single newline and the total output is capped at
    ``budget * 4 + 3`` characters, where ``budget`` is ``--max-tokens`` or the
    ``max_context_tokens`` setting (default 6000). The ``*4 + 3`` conversion
    assumes ``sc.est_tokens`` is ``len // 4`` — NOTE(review): confirm against
    splice_common.

    A context that does not fit whole is cut at a line boundary and closed
    with ``TRIM_MARKER``; if not even one line fits it is skipped and the
    next context is tried. Unreadable files are skipped. Always returns 0.
    """
    state = sc.load_contexts()
    active = [ctx for ctx in state["contexts"] if ctx.get("enabled")]
    if not active:
        return 0

    token_budget = args.max_tokens
    if token_budget is None:
        token_budget = int(sc.load_settings().get("max_context_tokens", 6000))
    # est_tokens = len//4, so total est <= budget  <=>  total chars <= budget*4 + 3
    char_limit = 4 * max(0, token_budget) + 3

    blocks = []
    spent = 0  # characters already committed, including joining newlines
    for ctx in active:
        body = read_text(ctx.get("path", ""))
        if body is None:
            continue
        body = body.rstrip("\n")
        joiner = 1 if blocks else 0  # "\n" between blocks when joined
        header = "### CONTEXT: %s\n" % ctx.get("title", ctx["id"])

        whole = header + body + "\n"
        if spent + joiner + len(whole) <= char_limit:
            blocks.append(whole)
            spent += joiner + len(whole)
            continue

        # Truncate at a line boundary. `room` is what remains for body lines
        # after the joiner, header, trim marker and the marker's newline.
        room = char_limit - spent - joiner - len(header) - len(TRIM_MARKER) - 1
        kept = []
        for line in body.splitlines():
            room -= len(line) + 1  # line + its newline
            if room < 0:
                break
            kept.append(line)
        if not kept:
            continue  # not even room for a trimmed stub; try the next one

        trimmed = header + "\n".join(kept) + "\n" + TRIM_MARKER + "\n"
        blocks.append(trimmed)
        spent += joiner + len(trimmed)

    sys.stdout.write("\n".join(blocks))
    return 0


def _auto_hints(cwd):
    """Collect short, best-effort hints describing directory *cwd*.

    Up to two hint strings are returned, in this order:
      * ``"git remote: <first line of `git remote -v`>"`` when git runs
        successfully within 3 seconds and prints something;
      * ``"dir entries: a, b, ..."`` with the first 20 sorted directory
        entries, when the directory is listable and non-empty.

    Failures (git missing, timeout, unreadable directory) just omit the hint.
    """
    hints = []

    try:
        proc = subprocess.run(
            ["git", "-C", cwd, "remote", "-v"],
            capture_output=True, text=True, timeout=3,
        )
    except (OSError, subprocess.TimeoutExpired):
        proc = None
    if proc is not None and proc.returncode == 0:
        remotes = proc.stdout.strip()
        if remotes:
            hints.append("git remote: " + remotes.splitlines()[0])

    try:
        entries = sorted(os.listdir(cwd))
    except OSError:
        entries = []
    if entries:
        hints.append("dir entries: " + ", ".join(entries[:20]))

    return hints


def _parse_id_array(raw, known_ids):
    """Defensively extract a JSON array of known ids from model output.

    The reply may be a bare JSON array, wrapped in a markdown code fence, or
    embedded in prose. Candidates are tried in order: the whole (de-fenced)
    text, then every ``[...]`` span found in it, so a stray bracket in prose
    (``[sic]``, ``[1]``) no longer hides a valid array that follows.

    Args:
        raw: Model output; ``None`` or empty means the call produced nothing.
        known_ids: Container of valid context ids (membership-tested).

    Returns:
        * a list of up to 8 distinct known ids, in the order given, taken
          from the first candidate that contains at least one known id;
        * ``[]`` if some candidate parsed as a JSON list but none contained
          a known id;
        * ``None`` if no JSON list could be found at all.
    """
    if not raw:
        return None
    text = raw.strip()
    # strip markdown code fences if present
    text = re.sub(r"^```[a-zA-Z]*\s*|\s*```$", "", text).strip()

    candidates = [text]
    candidates.extend(
        m.group(0) for m in re.finditer(r"\[.*?\]", text, re.DOTALL)
    )

    saw_list = False
    for cand in candidates:
        try:
            arr = json.loads(cand)
        except ValueError:
            continue
        if not isinstance(arr, list):
            continue
        saw_list = True
        ids = []
        for item in arr:
            # keep first occurrence only so duplicates don't eat the cap
            if isinstance(item, str) and item in known_ids and item not in ids:
                ids.append(item)
        if ids:
            return ids[:8]
    return [] if saw_list else None


def cmd_auto(args):
    """Ask the model which contexts are relevant and apply its selection.

    Builds a prompt from the known context ids/titles, the basename of the
    working directory (``--cwd`` or the process cwd) and the hints from
    ``_auto_hints``, sends it through ``sc.run_claude`` using the
    ``fix_model`` setting (default ``"haiku"``), and parses the reply with
    ``_parse_id_array``.

    The selection is then applied as the AI: picked ids are enabled, all
    others disabled, and each flipped entry gets ``set_by="ai"``. Entries
    with ``set_by == "user"`` are never changed; a message is printed when
    the model's choice disagrees with such a pin. A missing ``enabled`` key
    is treated as disabled, matching the rest of the file. State is saved
    only when at least one entry actually changed.

    Returns:
        0 on success; 1 if there are no contexts or the model reply yielded
        no usable id list (in which case nothing is modified).
    """
    data = sc.load_contexts()
    contexts = data["contexts"]
    if not contexts:
        print("auto: no contexts known — run: splice-ctx scan", file=sys.stderr)
        return 1
    cwd = os.path.abspath(os.path.expanduser(args.cwd or os.getcwd()))
    known_ids = {c["id"] for c in contexts}

    lines = ["%s — %s" % (c["id"], c.get("title", "")) for c in contexts]
    hints = _auto_hints(cwd)
    prompt = (
        "Select context files relevant to a terminal AI assistant working in "
        "the directory below. Contexts (id — title):\n"
        + "\n".join(lines)
        + "\n\nWorking directory: " + (os.path.basename(cwd) or "/")
    )
    if hints:
        prompt += "\n" + "\n".join(hints)
    prompt += (
        "\n\nReply with ONLY a JSON array of the most relevant context ids "
        '(at most 8), e.g. ["mem:example"]. No prose.'
    )

    model = sc.load_settings().get("fix_model", "haiku")
    raw = sc.run_claude(prompt, model=model)
    picked = _parse_id_array(raw, known_ids)
    if picked is None:
        print("auto: model call failed or returned no usable id list — no changes",
              file=sys.stderr)
        return 1

    # Apply as --by ai: enable picked, disable the rest.
    # The override rule keeps every user-pinned entry untouched.
    picked_set = set(picked)
    changed = False
    for c in contexts:
        want = c["id"] in picked_set
        # bool(): an entry without an "enabled" key is already "disabled"
        if bool(c.get("enabled")) == want:
            continue
        if c.get("set_by") == "user":
            print("respecting user override: %s" % c["id"])
            continue
        c["enabled"] = want
        c["set_by"] = "ai"
        changed = True
        print("%s: %s (by ai)" % ("enabled" if want else "disabled", c["id"]))
    if changed:
        sc.save_contexts(data)
    print("auto: %d context(s) selected" % len(picked_set))
    return 0


# ---------------------------------------------------------------- main

def main(argv=None):
    """Parse the command line and dispatch to the matching ``cmd_*`` handler.

    Args:
        argv: Argument list without the program name; ``None`` lets argparse
            read ``sys.argv[1:]``.

    Returns:
        The handler's integer exit status. Usage errors and ``--help`` exit
        through argparse's ``SystemExit`` instead of returning.
    """
    parser = argparse.ArgumentParser(
        prog="splice-ctx",
        description="SPLICE context/memory manager (state: ~/.config/splice/contexts.json)",
    )
    commands = parser.add_subparsers(dest="cmd", required=True)

    commands.add_parser("scan", help="discover contexts from settings dirs/globs")

    list_p = commands.add_parser("list", help="list contexts")
    list_p.add_argument("--json", action="store_true", help="machine-readable output")

    enable_p = commands.add_parser("enable", help="enable contexts")
    disable_p = commands.add_parser("disable", help="disable contexts")
    for toggle_p in (enable_p, disable_p):
        toggle_p.add_argument("ids", nargs="+", metavar="id")
        toggle_p.add_argument("--by", choices=["user", "ai"], default="user")

    release_p = commands.add_parser(
        "release", help="return contexts to AI-manageable (set_by=default)")
    release_p.add_argument("ids", nargs="+", metavar="id")

    add_p = commands.add_parser("add", help="register an arbitrary file as a context")
    add_p.add_argument("path")
    add_p.add_argument("--title")

    blob_p = commands.add_parser("blob", help="print enabled contexts as ### CONTEXT blocks")
    blob_p.add_argument("--max-tokens", type=int, default=None)

    auto_p = commands.add_parser(
        "auto", help="let the model pick relevant contexts (one cheap call)")
    auto_p.add_argument("--cwd", default=None)

    args = parser.parse_args(argv)

    dispatch = {
        "scan": cmd_scan,
        "list": cmd_list,
        "enable": cmd_enable,
        "disable": cmd_disable,
        "release": cmd_release,
        "add": cmd_add,
        "blob": cmd_blob,
        "auto": cmd_auto,
    }
    return dispatch[args.cmd](args)


if __name__ == "__main__":
    sys.exit(main())
