import argparse
import ast
import json
import operator
import re
from pathlib import Path
from typing import Dict, List, Tuple


def search_docs(query: str) -> str:
    """Return a canned, placeholder search summary for *query*.

    No real search is performed; the text always reports three hits so the
    routing demo has a deterministic tool output to log.
    """
    summary = f"[search_docs] '{query}'에 대한 문서 검색 결과 3건"
    return summary


# Arithmetic operators the calculator accepts; any other AST node is rejected.
_BINARY_OPERATORS = {
    ast.Add: operator.add,
    ast.Sub: operator.sub,
    ast.Mult: operator.mul,
    ast.Div: operator.truediv,
    ast.FloorDiv: operator.floordiv,
    ast.Pow: operator.pow,
}
_UNARY_OPERATORS = {
    ast.UAdd: operator.pos,
    ast.USub: operator.neg,
}
# Bounds that stop ``**`` from exhausting CPU/memory (e.g. "9**9**9**9").
_MAX_EXPONENT = 1000
_MAX_POW_RESULT_BITS = 100_000


def _check_power(base, exponent) -> None:
    """Raise ValueError if ``base ** exponent`` would be unreasonably expensive."""
    if abs(exponent) > _MAX_EXPONENT:
        raise ValueError("exponent too large")
    if isinstance(base, int) and isinstance(exponent, int) and exponent > 0:
        # An int power needs about bit_length(base) * exponent bits; refuse huge ones up front.
        if abs(base).bit_length() * exponent > _MAX_POW_RESULT_BITS:
            raise ValueError("result too large")


def _evaluate_arithmetic(node: ast.AST):
    """Recursively evaluate a parsed expression made only of numbers and + - * / // **.

    Raises:
        ValueError: for any node outside that whitelist or an oversized power.
    """
    if isinstance(node, ast.Expression):
        return _evaluate_arithmetic(node.body)
    if isinstance(node, ast.Constant) and type(node.value) in (int, float):
        return node.value
    if isinstance(node, ast.UnaryOp) and type(node.op) in _UNARY_OPERATORS:
        return _UNARY_OPERATORS[type(node.op)](_evaluate_arithmetic(node.operand))
    if isinstance(node, ast.BinOp) and type(node.op) in _BINARY_OPERATORS:
        left = _evaluate_arithmetic(node.left)
        right = _evaluate_arithmetic(node.right)
        if isinstance(node.op, ast.Pow):
            _check_power(left, right)
        return _BINARY_OPERATORS[type(node.op)](left, right)
    raise ValueError(f"unsupported expression element: {type(node).__name__}")


def calculator(expression: str) -> str:
    """Evaluate the arithmetic embedded in *expression* and return a tool message.

    Every character other than digits, whitespace, ``+ - * / ( ) .`` is removed
    first, so a request like ``"12 + 37 계산해줘"`` becomes ``"12 + 37"``. The
    remainder is parsed with ``ast`` and evaluated by a whitelist walker instead
    of ``eval`` (the text comes straight from the user request), with ``**``
    bounded so pathological inputs fail fast instead of hanging.

    Returns:
        ``"[calculator] 결과: <value>"`` on success,
        ``"[calculator] 계산 가능한 수식이 없습니다"`` when nothing computable remains,
        ``"[calculator] 수식 계산 실패"`` on any parse/evaluation/formatting error.
    """
    safe_expr = re.sub(r"[^0-9\s\+\-\*\/\(\)\.]", "", expression)
    if not safe_expr.strip():
        return "[calculator] 계산 가능한 수식이 없습니다"
    try:
        # strip(): ast.parse, unlike eval(), does not tolerate leading indentation.
        value = _evaluate_arithmetic(ast.parse(safe_expr.strip(), mode="eval"))
        return f"[calculator] 결과: {value}"
    except Exception:
        # Tool boundary: syntax errors, division by zero, overflow, rejected nodes and
        # oversized-int formatting all surface as a single failure message by design.
        return "[calculator] 수식 계산 실패"


def baseline_router(user_request: str) -> str:
    """Deliberately naive router that over-selects the search tool.

    Serves as the "before" side of the comparison: whether or not the request
    mentions weather, news or summarising, the answer is always ``"search_docs"``.
    """
    search_triggers = ("날씨", "뉴스", "정리")
    if any(trigger in user_request for trigger in search_triggers):
        return "search_docs"
    # Bad habit kept on purpose: even numeric/arithmetic questions fall through to search.
    return "search_docs"


def contract_router(user_request: str) -> str:
    """Router that applies the prompt contract (rules) for tool selection.

    A request is sent to ``"calculator"`` when it contains a math keyword or an
    inline binary expression such as ``12 + 37``; everything else goes to
    ``"search_docs"``.

    NOTE: matching is plain substring/regex, so e.g. "합" also fires inside
    "종합", and a date like "2024-01-01" matches the ``digit - digit`` pattern.
    """
    math_words = ("계산", "더하기", "빼기", "곱", "나누기", "나눈", "나눠", "합", "평균")
    looks_like_math = any(word in user_request for word in math_words) or (
        re.search(r"\d+\s*[\+\-\*\/]\s*\d+", user_request) is not None
    )
    return "calculator" if looks_like_math else "search_docs"


def run_tool(tool_name: str, user_request: str) -> str:
    """Invoke the tool called *tool_name* on the raw *user_request*.

    Only ``"calculator"`` is recognised explicitly; any other name falls back
    to ``search_docs``.
    """
    handler = calculator if tool_name == "calculator" else search_docs
    return handler(user_request)


def evaluate(tasks: List[Dict[str, str]], router_name: str) -> Tuple[List[Dict[str, object]], float]:
    """Route every task, run the chosen tool, and score the routing decisions.

    Args:
        tasks: items with at least ``"request"`` and ``"expected_tool"`` keys.
        router_name: ``"baseline"`` selects ``baseline_router``; any other value
            selects ``contract_router``.

    Returns:
        ``(logs, accuracy)``. Each log entry records the request, expected and
        chosen tool, a boolean ``"success"`` flag and the tool output (hence
        ``Dict[str, object]`` rather than ``Dict[str, str]``). ``accuracy`` is the
        fraction of tasks routed to the expected tool, ``0.0`` for no tasks.
    """
    router = baseline_router if router_name == "baseline" else contract_router
    logs: List[Dict[str, object]] = []
    correct = 0

    for item in tasks:
        request = item["request"]
        expected = item["expected_tool"]
        chosen = router(request)
        result = run_tool(chosen, request)
        ok = chosen == expected
        correct += ok  # bool counts as 0/1
        logs.append(
            {
                "request": request,
                "expected_tool": expected,
                "chosen_tool": chosen,
                "success": ok,
                "tool_result": result,
            }
        )

    accuracy = correct / len(tasks) if tasks else 0.0
    return logs, accuracy


def load_tasks(path: Path) -> List[Dict[str, str]]:
    """Return the list stored under the top-level ``"tasks"`` key of a UTF-8 JSON file.

    File and JSON errors propagate unchanged; a top-level object without a
    ``"tasks"`` key raises ``KeyError``.
    """
    document = json.loads(path.read_text(encoding="utf-8"))
    return document["tasks"]


def _print_json(payload: object) -> None:
    """Pretty-print *payload* as indented JSON, keeping non-ASCII (Korean) text readable."""
    print(json.dumps(payload, ensure_ascii=False, indent=2))


def main() -> None:
    """CLI entry point: compare both routers over a task file, or route one request.

    ``--mode single`` routes ``--request`` with both routers and prints the tools
    chosen and their outputs. ``--mode compare`` (default) loads ``--input``,
    evaluates both routers, and prints a summary followed by per-task logs.
    An unreadable or malformed input file is reported through ``parser.error``
    (usage message, exit status 2) instead of an unhandled traceback.
    """
    parser = argparse.ArgumentParser(description="HF Day3 Prompt Pattern practice")
    parser.add_argument("--mode", choices=["compare", "single"], default="compare")
    parser.add_argument(
        "--input",
        default="sample_tasks_day3.json",
        help="비교 모드에서 사용할 입력 JSON 파일 경로",
    )
    parser.add_argument("--request", default="12 + 37 계산해줘", help="single 모드에서 테스트할 요청")
    args = parser.parse_args()

    if args.mode == "single":
        b = baseline_router(args.request)
        c = contract_router(args.request)
        print("=== SINGLE TEST ===")
        _print_json(
            {
                "request": args.request,
                "baseline_tool": b,
                "contract_tool": c,
                "baseline_result": run_tool(b, args.request),
                "contract_result": run_tool(c, args.request),
            }
        )
        return

    try:
        tasks = load_tasks(Path(args.input))
    except (OSError, ValueError, KeyError, TypeError) as exc:
        # OSError: missing/unreadable file; ValueError: bad JSON or encoding;
        # KeyError/TypeError: top-level JSON is not an object with "tasks".
        parser.error(f"입력 파일을 읽을 수 없습니다: {args.input} ({exc})")

    baseline_logs, baseline_acc = evaluate(tasks, "baseline")
    contract_logs, contract_acc = evaluate(tasks, "contract")

    print("=== DAY3 PROMPT CONTRACT EVALUATION ===")
    _print_json(
        {
            "input_file": args.input,
            "task_count": len(tasks),
            "baseline_accuracy": round(baseline_acc, 2),
            "contract_accuracy": round(contract_acc, 2),
            "improved": contract_acc > baseline_acc,
        }
    )

    print("\n--- Baseline Logs ---")
    _print_json(baseline_logs)

    print("\n--- Contract Logs ---")
    _print_json(contract_logs)


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
