#!/usr/bin/env python3
import argparse
import ast
import json
import operator as op
import re
from dataclasses import dataclass
from typing import Any, Dict, List


# Whitelist of AST operator node types that safe_eval() may evaluate, mapped
# to the function that implements each one. Binary and unary operators share
# this table; an AST node type only ever appears in one of those positions.
ALLOWED_OPS = {
    ast.Add: op.add,
    ast.Sub: op.sub,
    ast.Mult: op.mul,
    ast.Div: op.truediv,
    ast.Pow: op.pow,
    ast.USub: op.neg,
    # Unary plus ("+5", "3*+2") is as harmless as unary minus; without it such
    # expressions were rejected as unsupported.
    ast.UAdd: op.pos,
}


@dataclass
class ToolResult:
    """Uniform result envelope returned by the tool functions and run_tool().

    When ``ok`` is False the call failed: ``error`` holds a message and, in the
    failure paths of this file, ``output`` is None.
    """

    ok: bool  # True if the tool call succeeded
    tool: str  # name of the tool that produced this result
    output: Any  # tool-specific payload (list of docs, dict, or None on failure)
    error: str = ""  # error message; left empty on success


def safe_eval(expr: str) -> float:
    """Safely evaluate a basic arithmetic expression and return it as a float.

    Only int/float literals (``bool`` excluded) and the operator node types
    present in ``ALLOWED_OPS`` are accepted; any other syntax is rejected.

    Raises:
        ValueError: ``expr`` is not valid expression syntax or uses an
            unsupported construct.
        ZeroDivisionError: division by zero, or zero raised to a negative power.
        OverflowError: a value is too large to represent as a float.
        TypeError: a negative base raised to a fractional power (Python
            produces a complex number, which ``float()`` rejects).
    """
    def _eval(node):
        # bool is a subclass of int; exclude it so True/False are not accepted as 1/0.
        if (
            isinstance(node, ast.Constant)
            and isinstance(node.value, (int, float))
            and not isinstance(node.value, bool)
        ):
            return node.value
        if isinstance(node, ast.BinOp) and type(node.op) in ALLOWED_OPS:
            left = _eval(node.left)
            right = _eval(node.right)
            if isinstance(node.op, ast.Pow):
                # Exponentiate in float: untrusted input such as "9**9**9" would
                # otherwise build an enormous integer (CPU/memory DoS). Float pow
                # raises OverflowError immediately instead; the result is
                # converted to float at the end anyway.
                left, right = float(left), float(right)
            return ALLOWED_OPS[type(node.op)](left, right)
        if isinstance(node, ast.UnaryOp) and type(node.op) in ALLOWED_OPS:
            return ALLOWED_OPS[type(node.op)](_eval(node.operand))
        raise ValueError("지원하지 않는 수식입니다")

    try:
        parsed = ast.parse(expr, mode="eval")
    except SyntaxError as exc:
        # Report malformed input with the same exception type as unsupported input.
        raise ValueError("지원하지 않는 수식입니다") from exc
    return float(_eval(parsed.body))


def extract_expression(text: str) -> str:
    """Pull the arithmetic expression out of a free-form request.

    Processing steps:

    * Thousands separators between digits are dropped ("1,200" -> "1200").
    * Every character other than digits, ``+ - * /``, parentheses, ``.`` and
      space is treated as a gap.
    * Whitespace lying directly between two operands splits the text into
      separate candidates. For example, "3월 비용 100 + 200" yields "3" and
      "100+200" rather than the glued-together "3100+200".

    Returns the longest candidate that contains a binary operator, with all
    whitespace removed. If no candidate has an operator, returns the longest
    candidate.

    Raises:
        ValueError: no digit/operator characters were found at all.
    """
    # Remove thousands separators so "1,200" stays a single number.
    cleaned = re.sub(r"(?<=\d),(?=\d)", "", text)
    # Anything that cannot be part of an arithmetic expression becomes a gap.
    cleaned = re.sub(r"[^0-9+\-*/(). ]", " ", cleaned)
    # A gap sitting between two operands separates unrelated numbers: split there.
    segments = re.split(r"(?<=[\d.)])\s+(?=[\d.(])", cleaned)
    candidates = [re.sub(r"\s+", "", seg) for seg in segments]
    candidates = [c for c in candidates if c]
    if not candidates:
        raise ValueError("계산식을 찾지 못했습니다")
    with_operator = [c for c in candidates if re.search(r"[\d.)][+\-*/]", c)]
    return max(with_operator or candidates, key=len)


def search_docs(query: str) -> ToolResult:
    """Keyword search over a small built-in knowledge base.

    ``query`` is split on whitespace. A document matches when any token is a
    substring of its title or its snippet. Up to three matches are returned in
    knowledge-base order; when nothing matches (including an empty query), the
    first two documents are returned as a fallback. The result is always
    ``ok=True``.
    """
    kb = [
        {"title": "환불 정책", "snippet": "결제 후 7일 이내 환불 가능", "url": "https://example.com/refund"},
        {"title": "온보딩 체크리스트", "snippet": "계정 생성 > 권한 요청 > 첫 과제", "url": "https://example.com/onboarding"},
        {"title": "장애 대응 가이드", "snippet": "로그 확인 > 재시작 > 에스컬레이션", "url": "https://example.com/incident"},
    ]
    tokens = query.split()
    # Check title and snippet separately. Concatenating them with no separator
    # let a token straddle the boundary (e.g. "정책결제") and produce false hits.
    hits = [
        doc
        for doc in kb
        if any(token in doc["title"] or token in doc["snippet"] for token in tokens)
    ]
    return ToolResult(ok=True, tool="search_docs", output=hits[:3] if hits else kb[:2])


def calculate(request: str) -> ToolResult:
    """Extract an arithmetic expression from ``request`` and evaluate it.

    On success the output is ``{"expression": ..., "value": ...}``. Any
    failure is reported as ``ok=False`` with the exception text in ``error``.
    """
    try:
        expression = extract_expression(request)
        result_value = safe_eval(expression)
    except Exception as exc:  # tool boundary: every failure becomes a result, not a crash
        return ToolResult(ok=False, tool="calculate", output=None, error=str(exc))
    payload = {"expression": expression, "value": result_value}
    return ToolResult(ok=True, tool="calculate", output=payload)


def draft_reply(request: str) -> ToolResult:
    """Return a fixed acknowledgement reply, echoing the original request."""
    reply_lines = (
        "안녕하세요. 문의 주셔서 감사합니다.",
        "요청하신 내용은 확인 후 오늘 중으로 업데이트드리겠습니다.",
        "추가로 필요한 정보가 있으면 회신 부탁드립니다.",
    )
    payload = {"request": request, "reply": "\n".join(reply_lines)}
    return ToolResult(ok=True, tool="draft_reply", output=payload)


def baseline_router(_: str) -> str:
    """Naive comparison router: ignore the request and always pick search_docs."""
    default_tool = "search_docs"
    return default_tool


def contract_router(request: str) -> str:
    """Choose a tool name for ``request`` using keyword and pattern rules.

    Rules, applied in priority order:

    * ``"calculate"`` if the request contains a calculation keyword or inline
      arithmetic (digit, operator, digit; spaces around the operator allowed).
    * ``"draft_reply"`` if it contains a reply/notice keyword.
    * ``"search_docs"`` otherwise.
    """
    text = request.strip()

    calc_keywords = ("계산", "합계", "예산", "비용", "곱", "나눠", "더해", "매출")
    reply_keywords = ("답장", "회신", "메일", "공지", "문구", "메시지", "알림")

    # Tolerate whitespace around the operator: "120 + 85" is as much an
    # arithmetic request as "120+85" (the old pattern only matched the latter).
    has_arithmetic = re.search(r"\d\s*[+\-*/]\s*\d", text) is not None

    if has_arithmetic or any(k in text for k in calc_keywords):
        return "calculate"
    if any(k in text for k in reply_keywords):
        return "draft_reply"
    return "search_docs"


def run_tool(tool_name: str, request: str) -> ToolResult:
    """Dispatch ``request`` to the tool named ``tool_name``.

    Unrecognised names yield ``ok=False`` with ``error="unknown_tool"``
    instead of raising.
    """
    registry = (
        ("search_docs", search_docs),
        ("calculate", calculate),
        ("draft_reply", draft_reply),
    )
    for name, handler in registry:
        if tool_name == name:
            return handler(request)
    return ToolResult(ok=False, tool=tool_name, output=None, error="unknown_tool")


def evaluate(
    tasks: List[Dict[str, str]], *, pass_threshold: float = 0.75
) -> Dict[str, Any]:
    """Score the baseline and contract routers against labelled tasks.

    Args:
        tasks: items carrying ``request`` and ``expected_tool`` keys.
        pass_threshold: minimum contract-router accuracy required for
            ``"pass"`` (default 0.75).

    Returns:
        A report dict with the following keys:

        * ``total``
        * ``baseline_accuracy`` and ``contract_accuracy``, rounded to 3 decimals
        * ``improved``: the contract router routed more tasks correctly
        * ``pass``
        * ``records``: per-task records

        Only the contract router's chosen tool is executed; its result is
        stored in each record.
    """
    records = []
    baseline_hit = 0
    contract_hit = 0

    for idx, task in enumerate(tasks, start=1):
        request = task["request"]
        expected = task["expected_tool"]

        b_tool = baseline_router(request)
        c_tool = contract_router(request)

        # The baseline is scored on routing alone, so its tool is not executed;
        # running it was wasted work whose result was never used.
        c_result = run_tool(c_tool, request)

        baseline_hit += int(b_tool == expected)
        contract_hit += int(c_tool == expected)

        records.append(
            {
                "id": idx,
                "request": request,
                "expected_tool": expected,
                "baseline_tool": b_tool,
                "contract_tool": c_tool,
                "contract_ok": c_result.ok,
                "contract_output": c_result.output,
                "contract_error": c_result.error,
            }
        )

    total = len(tasks)
    baseline_ratio = baseline_hit / total if total else 0.0
    contract_ratio = contract_hit / total if total else 0.0

    return {
        "total": total,
        "baseline_accuracy": round(baseline_ratio, 3),
        "contract_accuracy": round(contract_ratio, 3),
        # Decide on raw counts / unrounded ratios: 3-decimal rounding could
        # otherwise hide a real improvement or round e.g. 0.74956 up to a pass.
        "improved": contract_hit > baseline_hit,
        "pass": contract_ratio >= pass_threshold,
        "records": records,
    }


def run_single(request: str) -> Dict[str, Any]:
    """Route one request with contract_router, execute the tool, and summarise it."""
    tool_name = contract_router(request)
    outcome = run_tool(tool_name, request)
    return dict(
        request=request,
        chosen_tool=tool_name,
        ok=outcome.ok,
        output=outcome.output,
        error=outcome.error,
    )


def load_tasks(path: str) -> List[Dict[str, str]]:
    """Load evaluation tasks from a UTF-8 JSON file.

    The file must contain a JSON list whose items are objects with string
    ``request`` and ``expected_tool`` fields (the fields used for evaluation).

    Raises:
        OSError: the file cannot be opened.
        ValueError: the file is not valid JSON (``json.JSONDecodeError`` is a
            subclass), the payload is not a list, or an item is malformed.
    """
    with open(path, "r", encoding="utf-8") as f:
        payload = json.load(f)
    if not isinstance(payload, list):
        raise ValueError("입력 JSON은 list 형식이어야 합니다")
    # Validate every item up front so a bad entry is reported with its position
    # here, rather than as a bare KeyError/AttributeError midway through evaluation.
    for pos, item in enumerate(payload, start=1):
        if not (
            isinstance(item, dict)
            and isinstance(item.get("request"), str)
            and isinstance(item.get("expected_tool"), str)
        ):
            raise ValueError(
                f"{pos}번째 항목은 문자열 'request'와 'expected_tool' 필드를 가진 객체여야 합니다"
            )
    return payload


def main():
    """Command-line entry point.

    ``--mode eval`` scores the routers on the tasks in ``--input``.
    ``--mode single`` routes and runs ``--request``. The result is printed as
    indented JSON with non-ASCII characters preserved.
    """
    parser = argparse.ArgumentParser(description="HF Agents Day4 Mini Project")
    parser.add_argument("--mode", choices=["eval", "single"], default="eval")
    parser.add_argument("--input", default="sample_tasks_day4.json")
    parser.add_argument("--request", default="이번 달 비용 120000 + 85000 계산해줘")
    args = parser.parse_args()

    if args.mode == "eval":
        try:
            tasks = load_tasks(args.input)
        except (OSError, ValueError) as exc:
            # A missing or malformed task file is a usage error: report it via
            # argparse (usage + message, exit status 2) instead of a traceback.
            # json.JSONDecodeError and UnicodeDecodeError are ValueError subclasses.
            parser.error(f"--input {args.input}: {exc}")
        payload = evaluate(tasks)
    else:
        payload = run_single(args.request)
    print(json.dumps(payload, ensure_ascii=False, indent=2))


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported as a module.
    main()
