#!/usr/bin/env python3
import argparse
import json
import os
import re
from typing import Any, Dict, List

from dotenv import load_dotenv
from smolagents import CodeAgent, InferenceClientModel, LiteLLMModel, tool


# Agent tool: returns the sum a + b.
# NOTE(review): the smolagents `@tool` decorator is expected to derive the
# tool's description and argument schema from the docstring below, which makes
# that text runtime data shown to the model. It is therefore left verbatim;
# confirm against the smolagents docs before rewording or translating it.
@tool
def add_numbers(a: float, b: float) -> float:
    """두 수를 더한다.

    Args:
        a: 첫 번째 수
        b: 두 번째 수
    """
    return a + b


# Agent tool: returns the product a * b.
# NOTE(review): the smolagents `@tool` decorator is expected to derive the
# tool's description and argument schema from the docstring below, which makes
# that text runtime data shown to the model. It is therefore left verbatim;
# confirm against the smolagents docs before rewording or translating it.
@tool
def multiply_numbers(a: float, b: float) -> float:
    """두 수를 곱한다.

    Args:
        a: 첫 번째 수
        b: 두 번째 수
    """
    return a * b


# Agent tool: returns amount * (1 + vat_rate), i.e. the net amount plus VAT.
# vat_rate is a fraction (default 0.1), not a percentage.
# NOTE(review): the smolagents `@tool` decorator is expected to derive the
# tool's description and argument schema from the docstring below, which makes
# that text runtime data shown to the model. It is therefore left verbatim;
# confirm against the smolagents docs before rewording or translating it.
@tool
def add_vat(amount: float, vat_rate: float = 0.1) -> float:
    """공급가에 부가세를 더한 총액을 계산한다.

    Args:
        amount: 공급가
        vat_rate: 부가세율(기본 10%)
    """
    return amount * (1 + vat_rate)


def select_model(provider: str = "auto"):
    """Choose a model backend from *provider* and the credentials in the environment.

    ``load_dotenv()`` is called first, then ``OPENAI_API_KEY`` and ``HF_TOKEN``
    are read from the environment. OpenAI takes precedence over Hugging Face
    when *provider* allows both.

    Args:
        provider: ``"auto"`` (try OpenAI, then Hugging Face), ``"openai"``, or
            ``"huggingface"``.

    Returns:
        A ``(model, model_name)`` tuple: the constructed model object and a
        human-readable label for it.

    Raises:
        RuntimeError: If no backend permitted by *provider* has a non-empty
            credential available.
    """
    load_dotenv()

    openai_key, hf_token = (os.getenv(name) for name in ("OPENAI_API_KEY", "HF_TOKEN"))
    openai_allowed = provider in {"auto", "openai"}
    hf_allowed = provider in {"auto", "huggingface"}

    # OpenAI is checked first, so "auto" prefers it when both keys are present.
    if openai_allowed and openai_key:
        model_id = "openai/gpt-4o-mini"
        return LiteLLMModel(model_id=model_id, api_key=openai_key), model_id

    if hf_allowed and hf_token:
        return InferenceClientModel(token=hf_token), "InferenceClientModel(default)"

    raise RuntimeError("모델 키를 찾지 못했습니다. OPENAI_API_KEY 또는 HF_TOKEN을 .env에 설정하세요.")


def extract_numbers(text: str) -> List[float]:
    """Extract every decimal number that appears in *text*, in order.

    A comma is treated as a thousands separator only when it sits inside a
    well-formed grouping (1-3 leading digits followed by one or more groups of
    exactly three digits, e.g. ``1,234,567``). Any other comma separates
    values, so ``"(45,90)"`` yields ``[45.0, 90.0]`` rather than ``[4590.0]``.

    Args:
        text: Arbitrary text that may contain numbers.

    Returns:
        The numbers found, as floats. A leading ``-`` is taken as a sign and a
        single ``.ddd`` suffix as the fractional part. Empty if none are found.
    """
    # Alternatives, tried in order:
    #   1. properly grouped digits that do not run into further digits;
    #   2. a plain run of digits.
    # Only non-capturing groups are used so findall() returns whole matches.
    pattern = r"-?(?:\d{1,3}(?:,\d{3})+(?!\d)|\d+)(?:\.\d+)?"
    return [float(token.replace(",", "")) for token in re.findall(pattern, text)]


def is_expected_number_present(text: str, expected: float) -> bool:
    """Report whether any number found in *text* equals *expected*.

    Args:
        text: Text to scan; numbers are pulled out with ``extract_numbers``.
        expected: The value to look for.

    Returns:
        ``True`` if at least one extracted number differs from *expected* by
        less than ``1e-6`` (absolute difference), otherwise ``False``.
    """
    tolerance = 1e-6
    for candidate in extract_numbers(text):
        if abs(candidate - expected) < tolerance:
            return True
    return False


def run_selfcheck() -> Dict[str, Any]:
    """Exercise each arithmetic tool locally, without calling any model.

    Returns:
        A dict with ``"mode"`` set to ``"selfcheck"``, ``"checks"`` mapping each
        call label to the value the tool returned, and ``"tools_ok"`` which is
        ``True`` only when every tool's result is within ``1e-6`` of its
        expected value.
    """
    # (label, actual result, expected result) for every tool. All three are
    # verified; previously add_numbers was reported but never checked.
    cases = (
        ("add_numbers(31,14)", add_numbers(31, 14), 45),
        ("multiply_numbers(45,2)", multiply_numbers(45, 2), 90),
        ("add_vat(100000,0.1)", add_vat(100000, 0.1), 110000),
    )
    checks = {label: actual for label, actual, _ in cases}
    # Tolerance comparison because add_vat involves floating-point multiplication.
    tools_ok = all(abs(actual - wanted) < 1e-6 for _, actual, wanted in cases)
    return {
        "mode": "selfcheck",
        "tools_ok": tools_ok,
        "checks": checks,
    }


def run_single(task: str, provider: str = "auto") -> Dict[str, Any]:
    """Run one task through a freshly built ``CodeAgent`` and report the answer.

    Args:
        task: The instruction passed to the agent.
        provider: Backend selector forwarded to ``select_model``.

    Returns:
        A dict with ``"mode"`` (``"single"``), ``"model"`` (the model label),
        ``"task"`` (echoed back) and ``"result"`` (the agent's answer converted
        to ``str``).
    """
    llm, llm_label = select_model(provider)
    toolbox = [add_numbers, multiply_numbers, add_vat]
    runner = CodeAgent(tools=toolbox, model=llm, max_steps=5)

    answer = runner.run(task)

    report: Dict[str, Any] = {"mode": "single"}
    report["model"] = llm_label
    report["task"] = task
    # The answer is stringified before it goes into the report.
    report["result"] = str(answer)
    return report


def _load_eval_tasks(input_path: str) -> List[Any]:
    """Read the task file and return validated ``(task, expected_number)`` pairs."""
    with open(input_path, "r", encoding="utf-8") as f:
        tasks = json.load(f)

    if not isinstance(tasks, list):
        raise ValueError("입력 JSON 루트는 list여야 합니다")

    cases = []
    for i, item in enumerate(tasks, start=1):
        if not isinstance(item, dict) or "task" not in item or "expected_number" not in item:
            raise ValueError(f"{i}번째 항목에는 'task'와 'expected_number' 키가 있어야 합니다")
        raw_expected = item["expected_number"]
        try:
            expected = float(raw_expected)
        except (TypeError, ValueError) as err:
            raise ValueError(
                f"{i}번째 항목의 expected_number가 숫자가 아닙니다: {raw_expected!r}"
            ) from err
        cases.append((item["task"], expected))
    return cases


def run_eval(input_path: str, provider: str = "auto", pass_threshold: float = 0.66) -> Dict[str, Any]:
    """Run every task in a JSON file through the agent and score the answers.

    The file must contain a JSON list of objects, each with a ``"task"`` entry
    and an ``"expected_number"`` entry convertible to ``float``. The whole file
    is validated before a model is selected, so a malformed entry fails
    immediately instead of after earlier tasks have already been run.

    Args:
        input_path: Path to the UTF-8 encoded JSON task file.
        provider: Backend selector forwarded to ``select_model``.
        pass_threshold: Minimum score (fraction of passed tasks, rounded to
            three decimals) for the overall ``"pass"`` flag. Defaults to 0.66.

    Returns:
        A dict with ``"mode"`` (``"eval"``), ``"model"``, ``"total"``,
        ``"passed"``, ``"score"``, the overall ``"pass"`` flag, and
        ``"records"`` holding one entry per task (1-based ``"id"``, ``"task"``,
        ``"expected_number"``, stringified ``"result"``, per-task ``"pass"``).

    Raises:
        ValueError: If the JSON root is not a list, an item is not an object
            with both required keys, or ``expected_number`` is not numeric.
        RuntimeError: Propagated from ``select_model`` when no usable API key
            is found.
    """
    cases = _load_eval_tasks(input_path)

    model, model_name = select_model(provider)
    agent = CodeAgent(
        tools=[add_numbers, multiply_numbers, add_vat],
        model=model,
        max_steps=5,
    )

    records = []
    for i, (task, expected) in enumerate(cases, start=1):
        result = str(agent.run(task))
        records.append(
            {
                "id": i,
                "task": task,
                "expected_number": expected,
                "result": result,
                # A task passes when the expected number appears anywhere in the answer text.
                "pass": is_expected_number_present(result, expected),
            }
        )

    total = len(records)
    passed = sum(1 for record in records if record["pass"])
    # An empty task list scores 0.0 rather than dividing by zero.
    score = round(passed / total, 3) if total else 0.0
    return {
        "mode": "eval",
        "model": model_name,
        "total": total,
        "passed": passed,
        "score": score,
        "pass": score >= pass_threshold,
        "records": records,
    }


def main():
    """Parse command-line options, run the selected mode, and print its report as JSON.

    ``--mode selfcheck`` (the default) runs the local tool checks, ``single``
    runs ``--task`` through the agent, and ``eval`` runs the task file given by
    ``--input``. The resulting dict is printed as indented, non-ASCII-escaped
    JSON.
    """
    cli = argparse.ArgumentParser(description="HF Agents Day5 - First smolagents CodeAgent")
    cli.add_argument("--mode", choices=["selfcheck", "single", "eval"], default="selfcheck")
    cli.add_argument("--task", default="31과 14를 더한 뒤 2를 곱해줘. 답은 숫자만 말해줘")
    cli.add_argument("--input", default="sample_tasks_day5.json")
    cli.add_argument("--provider", choices=["auto", "openai", "huggingface"], default="auto")
    opts = cli.parse_args()

    # Dispatch table; any mode not listed here (i.e. "eval") falls through to run_eval.
    handlers = {
        "selfcheck": run_selfcheck,
        "single": lambda: run_single(opts.task, opts.provider),
    }
    handler = handlers.get(opts.mode, lambda: run_eval(opts.input, opts.provider))
    report = handler()

    rendered = json.dumps(report, ensure_ascii=False, indent=2)
    print(rendered)


# Run the CLI only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
