#!/usr/bin/env python3
import argparse
import json
import os
import re
from typing import Any, Dict, List, Tuple

from dotenv import load_dotenv
from smolagents import InferenceClientModel, LiteLLMModel, ToolCallingAgent, tool


@tool
def detect_incident_type(issue_text: str) -> str:
    """Classify incident text as one of payment, auth, latency or general.

    Args:
        issue_text: Raw text of the incident or operations issue
    """
    # Local import keeps the tool body self-contained if its source is ever exported.
    import re

    text = issue_text.lower()

    # Groups are checked in insertion order and the first group with a hit wins,
    # so payment outranks auth, which outranks latency.
    rules = {
        "payment": ["결제", "청구", "환불", "payment", "billing", "카드"],
        "auth": ["로그인", "인증", "비밀번호", "otp", "account", "signin"],
        "latency": ["느려", "지연", "timeout", "latency", "응답"],
    }
    # HTTP 500/502/503 count as latency signals only as standalone numbers: a plain
    # substring test would also fire on user counts such as "5000명" or "1500".
    status_code = re.compile(r"(?<!\d)50[023](?!\d)")

    for incident_type, keywords in rules.items():
        hit = any(keyword in text for keyword in keywords)
        if incident_type == "latency" and not hit:
            hit = status_code.search(text) is not None
        if hit:
            return incident_type

    return "general"


@tool
def assess_severity(incident_type: str, users_affected: int = 0, payment_path: int = 0) -> str:
    """Rate incident severity as sev1, sev2 or sev3.

    Args:
        incident_type: Result of detect_incident_type (payment/auth/latency/general)
        users_affected: Number of affected users
        payment_path: Whether the payment path is affected (0/1)
    """
    t = incident_type.strip().lower()
    # Both counts are optional tool inputs, so the model may send null for them;
    # treat a missing value as 0 instead of letting int(None) raise TypeError.
    users = int(users_affected or 0)
    payment = int(payment_path or 0)

    # An affected payment path on a payment/auth incident is always sev1.
    if payment == 1 and t in {"payment", "auth"}:
        return "sev1"
    # Otherwise the number of affected users decides: >= 5000 is sev1, >= 1000 is sev2.
    if users >= 5000:
        return "sev1"
    if users >= 1000:
        return "sev2"
    # Below those thresholds only general incidents drop to sev3; typed ones stay sev2.
    if t == "general":
        return "sev3"
    return "sev2"


@tool
def estimate_eta_minutes(severity: str) -> int:
    """Return the target time to first stabilization, in minutes, for a severity level.

    Args:
        severity: Result of assess_severity (sev1/sev2/sev3)
    """
    eta_by_severity = {"sev1": 30, "sev2": 120}
    # Anything other than sev1/sev2 (sev3 or an unrecognised label) gets the 480-minute target.
    return eta_by_severity.get(severity.strip().lower(), 480)


@tool
def choose_response_lane(severity: str, incident_type: str) -> str:
    """Return the incident response lane code (WAR_ROOM, ONCALL_TRIAGE or BACKLOG_REVIEW).

    Args:
        severity: sev1/sev2/sev3
        incident_type: payment/auth/latency/general
    """
    level = severity.strip().lower()
    kind = incident_type.strip().lower()

    # Every sev1 goes to the war room regardless of incident type.
    if level == "sev1":
        return "WAR_ROOM"

    # sev2 gets on-call triage unless the incident is untyped (general).
    needs_triage = level == "sev2" and kind in ("payment", "auth", "latency")
    return "ONCALL_TRIAGE" if needs_triage else "BACKLOG_REVIEW"


@tool
def choose_comms_channel(severity: str, users_affected: int = 0) -> str:
    """Return the recommended communication channel code.

    Args:
        severity: sev1/sev2/sev3
        users_affected: Number of affected users
    """
    s = severity.strip().lower()
    # Optional input: the model may send null, which is treated as 0 affected users
    # instead of letting int(None) raise TypeError.
    users = int(users_affected or 0)

    # 2000+ affected users triggers the public status page even below sev1.
    if s == "sev1" or users >= 2000:
        return "STATUS_PAGE_AND_SLACK"
    if s == "sev2":
        return "SLACK_ONLY"
    return "TICKET_COMMENT"


@tool
def choose_rollback_policy(severity: str, payment_path: int = 0) -> str:
    """Return the recommended rollback policy code.

    Args:
        severity: sev1/sev2/sev3
        payment_path: Whether the payment path is affected (0/1)
    """
    s = severity.strip().lower()
    # Optional input: the model may send null, which is treated as 0 (not on the payment
    # path) instead of letting int(None) raise TypeError.
    payment = int(payment_path or 0)

    # Payment-path incidents at sev1/sev2 get CANARY_AND_ROLLBACK_READY; this takes
    # precedence over the generic sev1 rule below.
    if payment == 1 and s in {"sev1", "sev2"}:
        return "CANARY_AND_ROLLBACK_READY"
    if s == "sev1":
        return "IMMEDIATE_ROLLBACK"
    if s == "sev2":
        return "FEATURE_FLAG_FIRST"
    return "NORMAL_DEPLOY"


def select_model(provider: str = "auto"):
    """Pick an LLM backend based on the requested provider and available credentials.

    Environment variables are loaded from ``.env`` first. OpenAI (``openai/gpt-4o-mini``
    through LiteLLM) is preferred when ``OPENAI_API_KEY`` is set and the provider allows it;
    otherwise the Hugging Face ``InferenceClientModel`` is used when ``HF_TOKEN`` is set.

    Args:
        provider: "auto" (try OpenAI, then Hugging Face), "openai" or "huggingface".

    Returns:
        A ``(model, model_name)`` tuple; ``model_name`` is a human-readable label.

    Raises:
        RuntimeError: If no usable credential is found for the requested provider.
    """
    load_dotenv()

    allow_openai = provider in {"auto", "openai"}
    allow_hf = provider in {"auto", "huggingface"}
    openai_key = os.getenv("OPENAI_API_KEY")
    hf_token = os.getenv("HF_TOKEN")

    if allow_openai and openai_key:
        openai_model_id = "openai/gpt-4o-mini"
        return LiteLLMModel(model_id=openai_model_id, api_key=openai_key), openai_model_id

    if allow_hf and hf_token:
        return InferenceClientModel(token=hf_token), "InferenceClientModel(default)"

    raise RuntimeError("모델 키를 찾지 못했습니다. OPENAI_API_KEY 또는 HF_TOKEN을 .env에 설정하세요.")


# Matches the manager's required closing line:
#   FINAL: type=<type>; severity=<severity>; lane=<lane>; comms=<comms>; eta=<eta>; rollback=<rollback>
# Case-insensitive; whitespace (including newlines, via \s) is tolerated around ':', '=' and ';'.
_FINAL_LINE_RE = re.compile(
    r"final\s*:\s*type\s*=\s*([a-z_]+)\s*;\s*severity\s*=\s*(sev[123])\s*;\s*lane\s*=\s*([A-Z_]+)\s*;\s*comms\s*=\s*([A-Z_]+)\s*;\s*eta\s*=\s*(\d+)\s*;\s*rollback\s*=\s*([A-Z_]+)",
    re.IGNORECASE,
)


def parse_final_line(result_text: str) -> Tuple[str, str, str, str, int, str]:
    """Extract the six decision fields from the manager's ``FINAL:`` line.

    The manager is told to emit the FINAL line last. When the text contains several
    matches (e.g. a restated draft followed by the real answer), the last one wins.

    Args:
        result_text: Raw final output of the manager agent.

    Returns:
        ``(type, severity, lane, comms, eta, rollback)``, with type/severity lower-cased,
        lane/comms/rollback upper-cased and eta as an int. Returns
        ``("", "", "", "", 0, "")`` when no FINAL line is found.
    """
    matches = list(_FINAL_LINE_RE.finditer(result_text))
    if not matches:
        return "", "", "", "", 0, ""

    # The prompt requires FINAL as the last line, so prefer the last occurrence.
    match = matches[-1]
    return (
        match.group(1).lower(),
        match.group(2).lower(),
        match.group(3).upper(),
        match.group(4).upper(),
        int(match.group(5)),
        match.group(6).upper(),
    )


def build_manager_task(issue_text: str, users_affected: int, payment_path: int) -> str:
    """Build the (Korean) task prompt handed to the manager agent.

    The prompt lists the incident inputs, tells the manager to call ``risk_analyst`` and then
    ``response_planner``, and requires the last output line to follow the ``FINAL: ...``
    format that ``parse_final_line`` parses.

    Args:
        issue_text: Raw incident description.
        users_affected: Number of affected users.
        payment_path: Whether the payment path is affected (0/1).

    Returns:
        The prompt as newline-separated lines, with no trailing newline.
    """
    # Assembled from real lines: the previous "\\n" literals put a literal backslash-n into
    # the prompt instead of line breaks.
    # NOTE(review): the prompt does not explicitly ask the manager to forward users_affected /
    # payment_path to the sub-agents, which their tools need — consider adding that instruction.
    lines = [
        "너는 incident commander 매니저다.",
        "아래 입력을 바탕으로 managed agent를 사용해 운영 결정을 내려라.",
        f"- issue_text: {issue_text}",
        f"- users_affected: {users_affected}",
        f"- payment_path: {payment_path}",
        "",
        "반드시 다음 순서로 호출해라:",
        "1) risk_analyst: type/severity/eta를 도출",
        "2) response_planner: severity/type 기반 lane/comms/rollback을 도출",
        "최종 출력 마지막 줄은 정확히 아래 형식으로만 출력해라:",
        "FINAL: type=<type>; severity=<severity>; lane=<lane>; comms=<comms>; eta=<eta>; rollback=<rollback>",
    ]
    return "\n".join(lines)


def build_agents(model):
    """Wire up the incident-commander agent hierarchy on one shared model.

    Two specialist ToolCallingAgents are registered as managed agents of a manager that
    has no tools of its own:

    - ``risk_analyst``: detect_incident_type -> assess_severity -> estimate_eta_minutes,
      ending with a ``RISK_FINAL`` line.
    - ``response_planner``: choose_response_lane -> choose_comms_channel ->
      choose_rollback_policy, ending with a ``PLAN_FINAL`` line.

    Args:
        model: Model instance (as returned by ``select_model``) shared by all three agents.

    Returns:
        The manager ToolCallingAgent.
    """
    risk_tools = [detect_incident_type, assess_severity, estimate_eta_minutes]
    plan_tools = [choose_response_lane, choose_comms_channel, choose_rollback_policy]

    risk_analyst = ToolCallingAgent(
        tools=risk_tools,
        model=model,
        max_steps=6,
        name="risk_analyst",
        description="장애의 type/severity/eta를 산정하는 전문 에이전트",
        instructions=(
            "반드시 도구를 순서대로 사용해라: detect_incident_type -> assess_severity -> estimate_eta_minutes. "
            "마지막 줄은 'RISK_FINAL: type=<type>; severity=<severity>; eta=<eta>' 형식으로만 출력한다."
        ),
    )

    response_planner = ToolCallingAgent(
        tools=plan_tools,
        model=model,
        max_steps=6,
        name="response_planner",
        description="장애 severity/type에 맞는 대응 lane/comms/rollback을 정하는 전문 에이전트",
        instructions=(
            "입력에 포함된 severity, incident_type, users_affected, payment_path를 읽고 "
            "반드시 도구를 순서대로 사용해라: choose_response_lane -> choose_comms_channel -> choose_rollback_policy. "
            "마지막 줄은 'PLAN_FINAL: lane=<lane>; comms=<comms>; rollback=<rollback>' 형식으로만 출력한다."
        ),
    )

    # No tools of its own: the manager works by delegating to the two specialists.
    return ToolCallingAgent(
        tools=[],
        managed_agents=[risk_analyst, response_planner],
        model=model,
        max_steps=10,
    )


def run_selfcheck() -> Dict[str, Any]:
    """Call every tool once with fixed inputs and compare against known answers (no LLM used).

    Returns:
        ``{"mode": "selfcheck", "tools_ok": bool, "checks": {label: actual tool output}}``.
    """
    # (label, actual output, expected output) — calls happen in this order.
    cases = [
        ("detect_incident_type(결제 청구 실패)", detect_incident_type("결제 청구 실패"), "payment"),
        ("assess_severity(payment,1200,1)", assess_severity("payment", 1200, 1), "sev1"),
        ("estimate_eta_minutes(sev2)", estimate_eta_minutes("sev2"), 120),
        ("choose_response_lane(sev1,payment)", choose_response_lane("sev1", "payment"), "WAR_ROOM"),
        ("choose_comms_channel(sev3,3000)", choose_comms_channel("sev3", 3000), "STATUS_PAGE_AND_SLACK"),
        ("choose_rollback_policy(sev2,1)", choose_rollback_policy("sev2", 1), "CANARY_AND_ROLLBACK_READY"),
    ]

    checks = {label: actual for label, actual, _ in cases}
    tools_ok = all(actual == expected for _, actual, expected in cases)

    return {
        "mode": "selfcheck",
        "tools_ok": tools_ok,
        "checks": checks,
    }


def run_single(issue_text: str, users_affected: int, payment_path: int, provider: str = "auto") -> Dict[str, Any]:
    """Run the manager agent once on a single incident.

    Args:
        issue_text: Raw incident description.
        users_affected: Number of affected users.
        payment_path: Whether the payment path is affected (0/1).
        provider: Model provider passed to ``select_model``.

    Returns:
        A report with the echoed input, the raw manager output and the fields parsed from its
        FINAL line. ``parse_ok`` is True only when every parsed field is non-empty/non-zero.
    """
    model, model_name = select_model(provider)
    commander = build_agents(model)

    prompt = build_manager_task(issue_text, users_affected, payment_path)
    raw = str(commander.run(prompt))

    fields = parse_final_line(raw)
    parsed: Dict[str, Any] = dict(zip(("type", "severity", "lane", "comms", "eta", "rollback"), fields))
    parsed["parse_ok"] = all(fields)

    return {
        "mode": "single",
        "model": model_name,
        "input": {
            "issue_text": issue_text,
            "users_affected": users_affected,
            "payment_path": payment_path,
        },
        "result": raw,
        "parsed": parsed,
    }


def run_eval(input_path: str, provider: str = "auto", pass_threshold: float = 0.66) -> Dict[str, Any]:
    """Evaluate the multi-agent pipeline against a JSON file of labelled incidents.

    Args:
        input_path: Path to a UTF-8 JSON file whose root is a list of task objects. Each task
            needs ``issue_text`` and an ``expected`` object with ``type``, ``severity``,
            ``lane``, ``comms``, ``eta`` and ``rollback``. ``users_affected`` and
            ``payment_path`` are optional and default to 0.
        provider: Model provider passed to ``select_model``.
        pass_threshold: Minimum fraction of tasks that must match all six fields for the run
            to pass. Defaults to 0.66 (the previously hard-coded value).

    Returns:
        A report with the model name, totals, ``score`` (pass ratio rounded to three
        decimals), an overall ``pass`` flag and one record per task.

    Raises:
        ValueError: If the JSON root is not a list.
        KeyError: If a task lacks ``issue_text``, ``expected`` or an expected field.
    """
    with open(input_path, "r", encoding="utf-8") as f:
        tasks = json.load(f)

    if not isinstance(tasks, list):
        raise ValueError("입력 JSON 루트는 list여야 합니다")

    # One model and one manager instance are reused for every task.
    model, model_name = select_model(provider)
    manager = build_agents(model)

    records: List[Dict[str, Any]] = []
    passed = 0

    for i, item in enumerate(tasks, start=1):
        issue_text = str(item["issue_text"])
        users_affected = int(item.get("users_affected", 0))
        payment_path = int(item.get("payment_path", 0))

        # Normalise expected labels to the casing parse_final_line produces.
        expected = item["expected"]
        exp_type = str(expected["type"]).lower()
        exp_severity = str(expected["severity"]).lower()
        exp_lane = str(expected["lane"]).upper()
        exp_comms = str(expected["comms"]).upper()
        exp_eta = int(expected["eta"])
        exp_rollback = str(expected["rollback"]).upper()

        task = build_manager_task(issue_text, users_affected, payment_path)
        result = str(manager.run(task))
        incident_type, severity, lane, comms, eta, rollback = parse_final_line(result)

        # A task passes only when all six fields match exactly.
        ok = (
            incident_type == exp_type
            and severity == exp_severity
            and lane == exp_lane
            and comms == exp_comms
            and eta == exp_eta
            and rollback == exp_rollback
        )
        passed += int(ok)

        records.append(
            {
                "id": i,
                "input": {
                    "issue_text": issue_text,
                    "users_affected": users_affected,
                    "payment_path": payment_path,
                },
                "expected": {
                    "type": exp_type,
                    "severity": exp_severity,
                    "lane": exp_lane,
                    "comms": exp_comms,
                    "eta": exp_eta,
                    "rollback": exp_rollback,
                },
                "parsed": {
                    "type": incident_type,
                    "severity": severity,
                    "lane": lane,
                    "comms": comms,
                    "eta": eta,
                    "rollback": rollback,
                    "parse_ok": bool(incident_type and severity and lane and comms and eta and rollback),
                },
                "pass": ok,
                "raw_result": result,
            }
        )

    total = len(records)
    ratio = passed / total if total else 0.0
    # The threshold is checked against the exact ratio; the rounded score is for display only.
    # Rounding first would let e.g. 62/94 (0.6596 -> 0.66) pass a 0.66 bar.
    score = round(ratio, 3)

    return {
        "mode": "eval",
        "model": model_name,
        "total": total,
        "passed": passed,
        "score": score,
        "pass": ratio >= pass_threshold,
        "records": records,
    }


def main():
    """Command-line entry point: run the selected mode and print its report as JSON."""
    parser = argparse.ArgumentParser(description="HF Agents Day8 - Multi-agent incident commander")
    parser.add_argument("--mode", choices=["selfcheck", "single", "eval"], default="selfcheck")
    parser.add_argument("--provider", choices=["auto", "openai", "huggingface"], default="auto")
    # Inputs for --mode single.
    parser.add_argument("--issue", default="결제 승인 실패가 급증하고 고객 불만이 접수됩니다.")
    parser.add_argument("--users", type=int, default=3200)
    parser.add_argument("--payment", type=int, default=1)
    # Task file for --mode eval.
    parser.add_argument("--input", default="sample_tasks_day8.json")
    args = parser.parse_args()

    runners = {
        "selfcheck": run_selfcheck,
        "single": lambda: run_single(args.issue, args.users, args.payment, args.provider),
    }
    # argparse limits --mode to three choices, so anything not listed above is "eval".
    run_mode = runners.get(args.mode, lambda: run_eval(args.input, args.provider))
    report = run_mode()

    print(json.dumps(report, ensure_ascii=False, indent=2))


if __name__ == "__main__":
    main()
