#!/usr/bin/env python3
import argparse
import json
import os
import re
from typing import Any, Dict, List, Tuple

from dotenv import load_dotenv
from smolagents import InferenceClientModel, LiteLLMModel, ToolCallingAgent, tool


@tool
def detect_category(issue_text: str) -> str:
    """고객 문의 텍스트를 카테고리로 분류한다.

    Args:
        issue_text: 고객 문의 원문
    """
    # English summary: keyword-based classifier for a customer inquiry.
    # The docstring above is kept verbatim on purpose: smolagents' @tool builds
    # the tool/argument descriptions shown to the LLM from it.
    #
    # Categories are tried in the order listed; the first category having any
    # keyword that occurs as a substring of the lower-cased text wins.
    # If nothing matches, the inquiry is "general".
    lowered = issue_text.lower()
    keyword_table = (
        ("billing", ("결제", "청구", "환불", "invoice", "billing", "요금")),
        ("technical", ("오류", "에러", "버그", "느려", "접속", "다운", "500", "latency")),
        ("account", ("로그인", "비밀번호", "인증", "계정", "탈퇴", "잠김", "signin", "account")),
    )
    return next(
        (name for name, words in keyword_table if any(word in lowered for word in words)),
        "general",
    )


@tool
def assign_priority(category: str, customer_tier: str = "standard", outage_minutes: int = 0) -> str:
    """문의 우선순위를 P1/P2/P3로 산정한다.

    Args:
        category: detect_category 결과(billing/technical/account/general)
        customer_tier: 고객 등급(vip/standard)
        outage_minutes: 서비스 장애 지속 시간(분)
    """
    # English summary: compute a P1/P2/P3 priority.
    #   P1 - outage of 30+ minutes, or a VIP customer with a billing/technical issue
    #   P2 - any technical issue, or an outage of 10+ minutes
    #   P3 - everything else
    # The docstring above is kept verbatim: smolagents' @tool turns it into the
    # tool/argument descriptions shown to the LLM.
    #
    # smolagents marks parameters that have defaults as "nullable" in the tool
    # schema, so the model may legitimately send null for customer_tier or
    # outage_minutes. Fall back to the declared defaults instead of failing on
    # None.strip() / int(None).
    c = category.strip().lower()
    t = (customer_tier if customer_tier is not None else "standard").strip().lower()
    o = int(outage_minutes) if outage_minutes is not None else 0

    if o >= 30:
        return "P1"
    if t == "vip" and c in {"billing", "technical"}:
        return "P1"
    if c == "technical" or o >= 10:
        return "P2"
    return "P3"


@tool
def suggest_action(category: str, priority: str) -> str:
    """카테고리와 우선순위에 맞는 다음 조치 코드(action code)를 반환한다.

    Args:
        category: 문의 카테고리
        priority: 우선순위(P1/P2/P3)
    """
    # English summary: map (category, priority) to the next action code.
    # The docstring above is kept verbatim: smolagents' @tool turns it into the
    # tool/argument descriptions shown to the LLM.
    #
    # Each known category has an (urgent, normal) pair; only P1 gets the
    # urgent code. Unknown categories always go to the general queue.
    normalized_category = category.strip().lower()
    is_p1 = priority.strip().upper() == "P1"

    action_codes = {
        "billing": ("BILLING_ESCALATE", "BILLING_QUEUE"),
        "technical": ("INCIDENT_HOTFIX", "INCIDENT_TRIAGE"),
        "account": ("ACCOUNT_ESCALATE", "ACCOUNT_GUIDE"),
    }
    if normalized_category not in action_codes:
        return "GENERAL_QUEUE"

    urgent_code, normal_code = action_codes[normalized_category]
    return urgent_code if is_p1 else normal_code


def select_model(provider: str = "auto"):
    """Choose the LLM backend based on available credentials.

    Loads a local ``.env`` file, then prefers OpenAI (via LiteLLM) when
    ``OPENAI_API_KEY`` is set and ``provider`` is "auto" or "openai". Otherwise
    it uses the Hugging Face inference client when ``HF_TOKEN`` is set and
    ``provider`` is "auto" or "huggingface".

    Args:
        provider: "auto", "openai" or "huggingface".

    Returns:
        A ``(model, model_label)`` tuple; the label is a human-readable name
        used in the JSON output.

    Raises:
        RuntimeError: If no usable credential exists for the requested provider.
    """
    load_dotenv()

    openai_key = os.getenv("OPENAI_API_KEY")
    if provider in {"auto", "openai"} and openai_key:
        openai_model_id = "openai/gpt-4o-mini"
        return LiteLLMModel(model_id=openai_model_id, api_key=openai_key), openai_model_id

    hf_token = os.getenv("HF_TOKEN")
    if provider in {"auto", "huggingface"} and hf_token:
        return InferenceClientModel(token=hf_token), "InferenceClientModel(default)"

    raise RuntimeError("모델 키를 찾지 못했습니다. OPENAI_API_KEY 또는 HF_TOKEN을 .env에 설정하세요.")


def build_task_prompt(issue_text: str, customer_tier: str, outage_minutes: int) -> str:
    """Build the instruction prompt for one triage task.

    The prompt lists the inputs and requires the three tools to be used in
    order (detect_category -> assign_priority -> suggest_action). It also pins
    the last line of the answer to the
    ``FINAL: category=...; priority=...; action=...`` format that
    ``parse_final_line`` extracts.

    Args:
        issue_text: Raw customer inquiry text.
        customer_tier: Customer tier passed through to the agent (e.g. "vip").
        outage_minutes: Outage duration in minutes.

    Returns:
        A multi-line prompt string.
    """
    # Real "\n" escapes so each instruction sits on its own line; in particular
    # the FINAL template must be the prompt's last line, as the instruction says.
    return (
        "너는 고객 문의를 분류하는 운영 에이전트다.\n"
        "아래 입력을 바탕으로 반드시 도구를 사용해 triage를 수행해라.\n"
        f"- issue_text: {issue_text}\n"
        f"- customer_tier: {customer_tier}\n"
        f"- outage_minutes: {outage_minutes}\n\n"
        "반드시 다음 순서로 도구를 사용해라: detect_category -> assign_priority -> suggest_action\n"
        "마지막 줄은 정확히 이 형식으로만 출력해라:\n"
        "FINAL: category=<category>; priority=<priority>; action=<action>"
    )


def parse_final_line(result_text: str) -> Tuple[str, str, str]:
    """Extract ``(category, priority, action)`` from the agent's FINAL line.

    Matches ``FINAL: category=<x>; priority=P1|P2|P3; action=<Y>``
    case-insensitively, with optional whitespace around ``:``, ``=`` and ``;``.

    Args:
        result_text: The agent's final answer text (``None`` is treated as "").

    Returns:
        ``(category.lower(), priority.upper(), action.upper())`` from the LAST
        matching occurrence, or ``("", "", "")`` when nothing matches.
    """
    pattern = re.compile(
        r"final\s*:\s*category\s*=\s*([a-z_]+)\s*;\s*priority\s*=\s*(p[123])\s*;\s*action\s*=\s*([A-Z_]+)",
        re.IGNORECASE,
    )
    # The prompt asks for the FINAL line to be the *last* line, so when the
    # answer also contains an earlier draft or echo, the last match is the
    # authoritative one.
    matches = pattern.findall(result_text or "")
    if not matches:
        return "", "", ""
    category, priority, action = matches[-1]
    return category.lower(), priority.upper(), action.upper()


def run_selfcheck() -> Dict[str, Any]:
    """Offline sanity check of the three tools (no model or API key needed).

    Calls each tool once with a fixed input and compares the result with the
    expected value.

    Returns:
        ``{"mode": "selfcheck", "tools_ok": bool, "checks": {label: actual}}``,
        where ``tools_ok`` is True only if every tool returned its expected value.
    """
    # (label, actual result, expected result) — calls run in this order.
    cases = [
        ("detect_category(결제가 두 번 청구됐어요)", detect_category("결제가 두 번 청구됐어요"), "billing"),
        ("assign_priority(billing,vip,0)", assign_priority("billing", "vip", 0), "P1"),
        ("suggest_action(technical,P2)", suggest_action("technical", "P2"), "INCIDENT_TRIAGE"),
    ]

    checks = {label: actual for label, actual, _expected in cases}
    tools_ok = all(actual == expected for _label, actual, expected in cases)

    return {"mode": "selfcheck", "tools_ok": tools_ok, "checks": checks}


def run_single(issue_text: str, customer_tier: str, outage_minutes: int, provider: str = "auto") -> Dict[str, Any]:
    """Triage one inquiry with a ToolCallingAgent.

    Args:
        issue_text: Raw customer inquiry text.
        customer_tier: Customer tier forwarded to the prompt (e.g. "vip").
        outage_minutes: Outage duration in minutes.
        provider: Backend selector passed to ``select_model``.

    Returns:
        A dict with the model label, the echoed input, the raw agent answer
        (as ``str``), and the parsed FINAL-line fields. ``parse_ok`` is True
        only when category, priority and action were all extracted.

    Raises:
        RuntimeError: From ``select_model`` when no credentials are available.
    """
    model, model_name = select_model(provider)
    triage_tools = [detect_category, assign_priority, suggest_action]
    agent = ToolCallingAgent(tools=triage_tools, model=model, max_steps=6)

    raw_answer = str(agent.run(build_task_prompt(issue_text, customer_tier, outage_minutes)))
    parsed_category, parsed_priority, parsed_action = parse_final_line(raw_answer)
    all_parsed = bool(parsed_category and parsed_priority and parsed_action)

    echoed_input = {
        "issue_text": issue_text,
        "customer_tier": customer_tier,
        "outage_minutes": outage_minutes,
    }
    parsed_fields = {
        "category": parsed_category,
        "priority": parsed_priority,
        "action": parsed_action,
        "parse_ok": all_parsed,
    }
    return {
        "mode": "single",
        "model": model_name,
        "input": echoed_input,
        "result": raw_answer,
        "parsed": parsed_fields,
    }


def _parse_eval_task(index: int, item: Any) -> Tuple[str, str, int, str, str, str]:
    """Validate and normalize one eval entry (1-based ``index`` for messages).

    Returns:
        ``(issue_text, customer_tier, outage_minutes, expected_category,
        expected_priority, expected_action)``. The tier defaults to "standard",
        the outage to 0, the expected category is lower-cased, and the expected
        priority/action are upper-cased, matching how parsed output is compared.

    Raises:
        ValueError: If ``item`` is not a JSON object, or ``outage_minutes`` is a
            non-numeric string.
        KeyError: If "issue_text", "expected", or an expected field is missing.
    """
    if not isinstance(item, dict):
        raise ValueError(f"{index}번째 입력 항목은 object(dict)여야 합니다")
    issue_text = str(item["issue_text"])
    customer_tier = str(item.get("customer_tier", "standard"))
    outage_minutes = int(item.get("outage_minutes", 0))

    expected = item["expected"]
    return (
        issue_text,
        customer_tier,
        outage_minutes,
        str(expected["category"]).lower(),
        str(expected["priority"]).upper(),
        str(expected["action"]).upper(),
    )


def run_eval(input_path: str, provider: str = "auto", *, pass_threshold: float = 0.66) -> Dict[str, Any]:
    """Run the triage agent over a labeled task file and score exact matches.

    Args:
        input_path: Path to a UTF-8 JSON file whose root is a list of objects.
            Each object needs "issue_text" and "expected" (with "category",
            "priority", "action"). "customer_tier" defaults to "standard" and
            "outage_minutes" to 0.
        provider: Backend selector passed to ``select_model``.
        pass_threshold: Minimum fraction of fully-correct tasks for the overall
            "pass" flag (keyword-only; defaults to 0.66).

    Returns:
        A dict with the model label, ``total``, ``passed``, ``score`` (pass
        ratio rounded to 3 decimals), the overall ``pass`` flag, and one record
        per task (input, expected, parsed fields, per-task pass, raw answer).

    Raises:
        ValueError: If the JSON root is not a list or an entry is malformed.
        KeyError: If an entry lacks a required field.
        RuntimeError: From ``select_model`` when no credentials are available.
    """
    with open(input_path, "r", encoding="utf-8") as f:
        tasks = json.load(f)

    if not isinstance(tasks, list):
        raise ValueError("입력 JSON 루트는 list여야 합니다")

    # Validate and normalize every entry before creating the model, so that
    # malformed data fails fast instead of after earlier tasks already ran.
    parsed_tasks = [_parse_eval_task(i, item) for i, item in enumerate(tasks, start=1)]

    model, model_name = select_model(provider)
    agent = ToolCallingAgent(
        tools=[detect_category, assign_priority, suggest_action],
        model=model,
        max_steps=6,
    )

    records: List[Dict[str, Any]] = []
    passed = 0

    for i, task in enumerate(parsed_tasks, start=1):
        issue_text, customer_tier, outage_minutes, exp_category, exp_priority, exp_action = task

        prompt = build_task_prompt(issue_text, customer_tier, outage_minutes)
        result = str(agent.run(prompt))
        category, priority, action = parse_final_line(result)

        # A task passes only when all three fields match exactly.
        ok = category == exp_category and priority == exp_priority and action == exp_action
        passed += int(ok)

        records.append(
            {
                "id": i,
                "input": {
                    "issue_text": issue_text,
                    "customer_tier": customer_tier,
                    "outage_minutes": outage_minutes,
                },
                "expected": {
                    "category": exp_category,
                    "priority": exp_priority,
                    "action": exp_action,
                },
                "parsed": {
                    "category": category,
                    "priority": priority,
                    "action": action,
                    "parse_ok": bool(category and priority and action),
                },
                "pass": ok,
                "raw_result": result,
            }
        )

    total = len(records)
    ratio = passed / total if total else 0.0
    score = round(ratio, 3)

    return {
        "mode": "eval",
        "model": model_name,
        "total": total,
        "passed": passed,
        "score": score,
        # Compare the unrounded ratio so e.g. 64/97 (~0.6598) cannot pass a 0.66
        # bar just because it rounds to 0.66.
        "pass": ratio >= pass_threshold,
        "records": records,
    }


def main():
    """Command-line entry point.

    Parses CLI options, runs run_selfcheck / run_single / run_eval according
    to ``--mode``, and prints the resulting dict as indented JSON with
    non-ASCII characters left unescaped.
    """
    parser = argparse.ArgumentParser(description="HF Agents Day6 - ToolCallingAgent triage workflow")
    add = parser.add_argument
    add("--mode", choices=["selfcheck", "single", "eval"], default="selfcheck")
    add("--provider", choices=["auto", "openai", "huggingface"], default="auto")
    add("--issue", default="결제가 두 번 청구됐어요. 빠르게 확인해주세요.")
    add("--tier", default="vip")
    add("--outage", type=int, default=0)
    add("--input", default="sample_tasks_day6.json")
    args = parser.parse_args()

    # argparse restricts --mode to exactly these keys, so the lookup cannot
    # miss; only the selected runner is invoked.
    runners = {
        "selfcheck": run_selfcheck,
        "single": lambda: run_single(args.issue, args.tier, args.outage, args.provider),
        "eval": lambda: run_eval(args.input, args.provider),
    }
    output = runners[args.mode]()

    print(json.dumps(output, ensure_ascii=False, indent=2))


if __name__ == "__main__":
    main()
