#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
Evaluate a single political speech against an authoritarianism taxonomy using an LLM.
Based on research by Delgado-Mohatar & Alelú-Paz, When Algorithms Guard Democracy — integrating Levitsky & Ziblatt’s four dimensions with LLM analysis.

Inputs:
- <speech_file>.txt     (a single speech as text file)

Outputs:
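- evaluations/<author>_<name>_<model>.json      (nested per-category scores; <name> = speech file name without extension)
- evaluations/<author>_<name>_<model>.csv       (flat rows for human review)
- evaluations/<author>_<name>_<model>.ndjson    (auditable per-call logs: prompt + raw response)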

Usage:
    python evaluate.py --backend openai|gemini|grok|llama_local --author_name <author> --speech_file <speech_file.txt> [--model <model>]
"""

import os, json, re, csv, argparse, ast
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional
from tqdm import tqdm

# ==============
# Configuration
# ==============

# Default backend: "openai" | "gemini" | "grok" | "llama_local"
DEFAULT_BACKEND = "llama_local"

# Default model name for each backend (used when no model is supplied)
DEFAULT_MODELS = {
    "openai": "gpt-4o",
    "gemini": "gemini-2.5-pro",
    "grok":   "grok-4-fast-reasoning",
    "llama_local": "gemma:2b"
}

# API keys / endpoints (set via env; fallback optional).
# Every key placeholder must start with "YOUR_": the backend functions treat
# a value with that prefix as "key not configured" and fail early.
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "YOUR_OPENAI_API_KEY")
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY", "YOUR_GEMINI_API_KEY")
GROK_API_KEY   = os.environ.get("GROK_API_KEY", "YOUR_GROK_API_KEY")
LLAMA_BASE_URL = os.environ.get("LLAMA_BASE_URL", "http://localhost:11434")  # e.g., Ollama

# Files
PROMPT_FILE = "llm_data/evaluation_prompt.txt"
TAXONOMY_FILE = "llm_data/taxonomy.json"

# LLM call parameters (global constants)
TEMPERATURE = 0.1   # sampling temperature sent to the backends
TOP_P = 1.0         # nucleus-sampling cutoff sent to the backends
MAX_RETRIES = 2     # extra attempts after a failed LLM call
RETRY_SLEEP = 2     # seconds to wait between attempts

# ==============
# LLM Utilities
# ==============

def call_llm(backend: str, model: str, prompt: str, timeout: int = 60) -> str:
    """
    Dispatch a plain-text prompt to the backend named by `backend`.

    The backend name is matched case-insensitively. Returns whatever the
    selected backend function returns (the raw model reply).

    Raises:
        ValueError: if the backend name is not one of the supported ones.
    """
    name = backend.lower()

    # One guard per backend; the first match wins.
    if name == "openai":
        return call_openai(model=model, prompt=prompt, timeout=timeout)
    if name == "gemini":
        return call_gemini(model=model, prompt=prompt, timeout=timeout)
    if name == "grok":
        return call_grok(model=model, prompt=prompt, timeout=timeout)
    if name == "llama_local":
        return call_llama_local(model=model, prompt=prompt, timeout=timeout)

    raise ValueError(f"Unsupported backend: {name}")

def call_openai(model: str, prompt: str, timeout: int = 60) -> str:
    """
    Send `prompt` to an OpenAI chat model and return the reply text, stripped.

    A system message asks the model for a JSON object with keys score and
    reasoning. Requires: pip install openai

    Raises:
        RuntimeError: if OPENAI_API_KEY is empty or still the placeholder.
        Exception: any error from the SDK is printed and re-raised unchanged.
    """
    key_missing = (not OPENAI_API_KEY) or OPENAI_API_KEY.startswith("YOUR_")
    if key_missing:
        raise RuntimeError("OPENAI_API_KEY is not set. Use env var OPENAI_API_KEY.")

    try:
        # Imported lazily so the package is only needed when this backend is used.
        from openai import OpenAI

        system_message = {
            "role": "system",
            "content": "Return ONLY valid JSON: {\"score\": <int 1-10>, \"reasoning\": \"...\"}",
        }
        user_message = {"role": "user", "content": prompt}

        completion = OpenAI(api_key=OPENAI_API_KEY).chat.completions.create(
            model=model,
            messages=[system_message, user_message],
            temperature=TEMPERATURE,
            top_p=TOP_P,
            timeout=timeout,
            max_completion_tokens=200
        )

        text = completion.choices[0].message.content
        if not text:
            # Empty or missing content is normalised to an empty string.
            text = ""
        return text.strip()
    except Exception as e:
        print(f"Error in OpenAI call: {type(e).__name__}: {e}")
        raise

def call_gemini(model: str, prompt: str, timeout: int = 60) -> str:
    """
    Send `prompt` to a Google Gemini model (text-only) and return the reply text, stripped.

    Requires: pip install google-generativeai
    JSON output is requested through response_mime_type. Sampling uses the
    module-level TEMPERATURE and TOP_P, like the other hosted backends.

    Args:
        model: Gemini model name.
        prompt: full prompt text.
        timeout: request timeout in seconds, forwarded via request_options.

    Raises:
        RuntimeError: if GEMINI_API_KEY is empty or still a placeholder.
    """
    # Accept both the "YOUR_" and "YOUR " prefixes so a placeholder is
    # recognised however it was spelled, instead of being sent as a real key.
    if not GEMINI_API_KEY or GEMINI_API_KEY.startswith(("YOUR_", "YOUR ")):
        raise RuntimeError("GEMINI_API_KEY is not set. Use env var GEMINI_API_KEY.")

    # Imported lazily so the package is only needed when this backend is used.
    import google.generativeai as genai
    genai.configure(api_key=GEMINI_API_KEY)
    gmodel = genai.GenerativeModel(model_name=model)
    resp = gmodel.generate_content(
        prompt,
        generation_config={
            "temperature": TEMPERATURE,
            "top_p": TOP_P,
            "response_mime_type": "application/json",  # force JSON
        },
        safety_settings=None,  # optional: do not harden filters unless needed
        request_options={"timeout": timeout},
    )
    return (resp.text or "").strip()

def call_grok(model: str, prompt: str, timeout: int = 60) -> str:
    """
    Send `prompt` to an xAI Grok chat model and return the reply text, stripped.

    A system message asks the model for a JSON object with keys score and
    reasoning. Requires: pip install xai-sdk

    Args:
        model: Grok model name.
        prompt: full prompt text.
        timeout: accepted for signature parity with the other backends; it is
            not forwarded to the SDK here.

    Raises:
        RuntimeError: if GROK_API_KEY is empty or still the placeholder.
        Exception: any error from the SDK is printed and re-raised unchanged.
    """
    # Same early check as the OpenAI/Gemini backends: fail with a clear
    # message instead of sending the placeholder as a key.
    if not GROK_API_KEY or GROK_API_KEY.startswith("YOUR_"):
        raise RuntimeError("GROK_API_KEY is not set. Use env var GROK_API_KEY.")
    try:
        # Imported lazily so the package is only needed when this backend is used.
        from xai_sdk import Client
        from xai_sdk.chat import user, system

        client = Client(api_key=GROK_API_KEY)

        chat = client.chat.create(model=model, temperature=TEMPERATURE, top_p=TOP_P)
        chat.append(system("Return ONLY valid JSON: {\"score\": <int 1-10>, \"reasoning\": \"...\"}"))
        chat.append(user(prompt))

        response = chat.sample()

        # Normalise like the other backends: never return None, strip whitespace.
        return (response.content or "").strip()
    except Exception as e:
        print(f"Error in Grok call: {type(e).__name__}: {e}")
        raise

def call_llama_local(model: str, prompt: str, timeout: int = 60) -> str:
    """
    Send `prompt` to a local Ollama server and return the reply text, stripped.

    The `ollama` package is imported only when this backend is used.

    Args:
        model: name of the local model.
        prompt: full prompt text, sent as a single user message.
        timeout: accepted for signature parity with the other backends; it is
            not forwarded, so a slow local generation is not cut off.

    Raises:
        RuntimeError: if the `ollama` package cannot be imported.
    """
    try:
        from ollama import Client
    except ImportError as e:
        raise RuntimeError("Backend llama_local requires package 'ollama' installed (pip install ollama).") from e

    # Honour LLAMA_BASE_URL only when the env var was explicitly set. With
    # host=None the client resolves the server address itself, which keeps
    # setups that never set LLAMA_BASE_URL working as before.
    host = LLAMA_BASE_URL if os.environ.get("LLAMA_BASE_URL") else None
    client = Client(host=host)

    # Call the local Llama server
    response = client.chat(
        model=model,
        messages=[{'role': 'user', 'content': prompt}],
        options={"temperature": TEMPERATURE},
    )

    # Normalise like the other backends: never return None, strip whitespace.
    return (response.message.content or "").strip()


# ==================
# Helper functions
# ==================

def read_file(path: str) -> str:
    """Return the entire content of the UTF-8 text file at `path`."""
    with open(path, mode="r", encoding="utf-8") as source:
        content = source.read()
    return content

def load_json(path: str) -> Any:
    """Parse the UTF-8 JSON file at `path` and return the decoded value."""
    with open(path, mode="r", encoding="utf-8") as source:
        return json.loads(source.read())

def save_json(path: str, data: Any) -> None:
    """
    Write `data` to `path` as pretty-printed (2-space indent) UTF-8 JSON.

    Missing parent directories are created first. Non-ASCII characters are
    written as-is rather than escaped.
    """
    parent_dir = os.path.split(path)[0]
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    with open(path, mode="w", encoding="utf-8") as sink:
        json.dump(data, sink, indent=2, ensure_ascii=False)

def append_ndjson(path: str, obj: Dict[str, Any]) -> None:
    """
    Append `obj` to `path` as one JSON document on its own line (NDJSON).

    Missing parent directories are created first. Non-ASCII characters are
    written as-is rather than escaped.
    """
    parent_dir = os.path.split(path)[0]
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    with open(path, mode="a", encoding="utf-8") as sink:
        record = json.dumps(obj, ensure_ascii=False)
        sink.write(f"{record}\n")

# Greedy match from the first "{" to the last "}" in a reply; used only for
# the Python-literal fallback below.
JSON_OBJECT_RE = re.compile(r"\{.*\}", re.DOTALL)


def _coerce_score(value):
    """Return `value` as an int in 1..10, or None if it is not such a score."""
    if isinstance(value, bool):
        # bool is a subclass of int; True/False are never meant as scores.
        return None
    if isinstance(value, int):
        number = value
    elif isinstance(value, float):
        if not value.is_integer():
            return None
        number = int(value)
    elif isinstance(value, str):
        # Accept "7", " 7 ", "7.0" and "7/10".
        m = re.fullmatch(r"\s*(10|[1-9])(?:\.0+)?(?:\s*/\s*10)?\s*", value)
        if not m:
            return None
        number = int(m.group(1))
    else:
        return None
    return number if 1 <= number <= 10 else None


def _coerce_reasoning(value):
    """Return `value` as a string ("" for None)."""
    if value is None:
        return ""
    return value if isinstance(value, str) else str(value)


def _result_from_dict(obj):
    """Build the (score, reasoning, parsed) triple from a decoded dict."""
    return _coerce_score(obj.get("score")), _coerce_reasoning(obj.get("reasoning", "")), obj


def _scan_json_dicts(text):
    """Yield every JSON object that can be decoded starting at a "{" in `text`."""
    decoder = json.JSONDecoder()
    pos = text.find("{")
    while pos != -1:
        try:
            obj, end = decoder.raw_decode(text, pos)
        except (ValueError, RecursionError):
            # Not valid JSON from this brace; try the next one.
            pos = text.find("{", pos + 1)
            continue
        if isinstance(obj, dict):
            yield obj
        # Resume after the decoded object so its nested braces are skipped.
        pos = text.find("{", end)


def robust_parse_llm(raw: str):
    """
    Returns (score:int|None, reasoning:str, parsed:dict|None). Never raises.

    Strategies, in order:
      1. the whole reply is JSON (a dict, or a non-empty list whose first
         item is a dict);
      2. a JSON object embedded in surrounding text (markdown fences, prose,
         several objects) - the first one carrying a "score" key wins;
      3. a Python-style dict literal (handles single quotes);
      4. last resort: the first standalone number 1..10, with reasoning=raw.

    The score is normalised to an int in 1..10 (numeric strings and integral
    floats are converted); anything else becomes None. Reasoning is always
    returned as a string.
    """
    if not raw:
        return None, "", None
    if isinstance(raw, (bytes, bytearray)):
        raw = bytes(raw).decode("utf-8", errors="replace")
    elif not isinstance(raw, str):
        raw = str(raw)

    # 1) Direct JSON
    try:
        obj = json.loads(raw)
    except (ValueError, RecursionError):
        obj = None
    if isinstance(obj, list) and obj:
        obj = obj[0]
    if isinstance(obj, dict):
        return _result_from_dict(obj)

    # 2) JSON object embedded in other text
    first_dict = None
    for candidate in _scan_json_dicts(raw):
        if "score" in candidate:
            return _result_from_dict(candidate)
        if first_dict is None:
            first_dict = candidate

    # 3) Python literal (single quotes, trailing commas)
    m = JSON_OBJECT_RE.search(raw)
    if m:
        try:
            obj = ast.literal_eval(m.group(0))
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            obj = None
        if isinstance(obj, dict):
            return _result_from_dict(obj)

    # An object without a "score" key is still returned as the parsed payload.
    if first_dict is not None:
        return _result_from_dict(first_dict)

    # 4) Last resort: number 1..10 and reasoning=raw
    m = re.search(r"\b(10|[1-9])\b", raw)
    score = int(m.group(1)) if m else None
    return score, raw.strip(), None

def format_indicators(indicators: List[str]) -> str:
    """Render indicators as a "- item" bullet list; "- (none)" when there are none."""
    bullets = [f"- {item}" for item in indicators] if indicators else ["- (none)"]
    return "\n".join(bullets)

def format_rubric(rubric: Dict[str, str]) -> str:
    """
    Render a rubric as "band: description" lines.

    The known score bands come first in ascending order; any other keys
    follow in the rubric's own order.
    """
    canonical = ("1", "2-3", "4-5", "6-7", "8-9", "10")
    known = [band for band in canonical if band in rubric]
    extras = [band for band in rubric if band not in canonical]
    return "\n".join(f"{band}: {rubric[band]}" for band in known + extras)

def render_prompt(template: str, speech: str, category_name: str, indicators: List[str], rubric: Dict[str, str]) -> str:
    """
    Fill the prompt template's placeholders and return the final prompt.

    Recognised placeholders: {SPEECH}, {CATEGORY_NAME}, {INDICATORS}, {RUBRIC}.
    All of them are substituted in one pass over the template, so placeholder
    text that happens to occur inside an inserted value (for example a speech
    that literally contains "{RUBRIC}") is left untouched.

    Args:
        template: prompt template text.
        speech: speech text inserted for {SPEECH}.
        category_name: label inserted for {CATEGORY_NAME}.
        indicators: rendered with format_indicators for {INDICATORS}.
        rubric: rendered with format_rubric for {RUBRIC}.
    """
    values = {
        "{SPEECH}": speech,
        "{CATEGORY_NAME}": category_name,
        "{INDICATORS}": format_indicators(indicators),
        "{RUBRIC}": format_rubric(rubric),
    }
    placeholder_re = re.compile("|".join(re.escape(token) for token in values))
    # A function replacement inserts values verbatim (no backslash/group expansion).
    return placeholder_re.sub(lambda match: values[match.group(0)], template)


# ==================
# Main evaluation
# ==================

def evaluate(
    author,
    backend: str = DEFAULT_BACKEND,
    model: Optional[str] = None,
    speech_file: Optional[str] = None
) -> None:
    """
    Score one speech against every taxonomy subcategory and write the results.

    For each category/subcategory in TAXONOMY_FILE, the template from
    PROMPT_FILE is rendered with the speech, sent to the chosen backend and
    parsed with robust_parse_llm. Results go to three files named
    evaluations/<author>_<speech file stem>_<model>.{json,csv,ndjson}:
    the CSV and NDJSON log are written row by row, the JSON once at the end.

    Args:
        author: author name, used in file names and in every output record.
        backend: backend name passed to call_llm.
        model: model name; defaults to the backend's entry in DEFAULT_MODELS.
        speech_file: path to the plain-text speech (required).

    Raises:
        ValueError: if speech_file is not provided.
        Exception: the backend error of the last attempt, once a call has
            failed MAX_RETRIES + 1 times; the run is aborted at that point.
    """
    import time  # local import: only needed for the retry back-off

    if not speech_file:
        raise ValueError("speech_file is required: pass the path to a plain-text speech file.")

    # Resolve model
    model = model or DEFAULT_MODELS.get(backend, DEFAULT_MODELS["openai"])

    # Output paths
    speech_file_name = os.path.splitext(os.path.basename(speech_file))[0]
    output_json = f"evaluations/{author}_{speech_file_name}_{model}.json"
    output_csv  = f"evaluations/{author}_{speech_file_name}_{model}.csv"
    log_ndjson  = f"evaluations/{author}_{speech_file_name}_{model}.ndjson"

    # Ensure output dir exists
    outdir = os.path.dirname(output_csv)
    if outdir:
        os.makedirs(outdir, exist_ok=True)

    # Load assets
    prompt_template = read_file(PROMPT_FILE)
    taxonomy = load_json(TAXONOMY_FILE)

    # Load a single speech from a plain text file
    speech_text = read_file(speech_file).strip()

    print ("Successfully loaded speech file:", speech_file)

    # Flat list of (category, subcategory) pairs so tqdm shows overall progress
    evaluation_steps = [
        {"category": cat, "subcategory": sub}
        for cat in taxonomy.get("categories", [])
        for sub in cat.get("subcategories", [])
    ]

    nested_output = []   # list of {author, scores: [...]}

    # Reset logs: the NDJSON file is opened in append mode, so start clean
    if os.path.exists(log_ndjson):
        os.remove(log_ndjson)

    # CSV writer
    with open(output_csv, "w", newline="", encoding="utf-8") as fcsv:
        writer = csv.DictWriter(
            fcsv,
            fieldnames=["author", "category", "subcategory", "score", "reasoning"]
        )
        writer.writeheader()

        per_speech_scores = []

        # For each category/subcategory, with a tqdm progress bar
        for step in tqdm(evaluation_steps, desc="Evaluating categories"):
            cat = step["category"]
            sub = step["subcategory"]

            cat_name = cat.get("name", "")
            sub_name = sub.get("name", "")
            indicators = sub.get("indicators", [])
            rubric = sub.get("rubric", {})

            prompt = render_prompt(
                template=prompt_template,
                speech=speech_text,
                category_name=f"{cat_name} / {sub_name}",
                indicators=indicators,
                rubric=rubric,
            )

            # LLM call with retries: one initial attempt plus MAX_RETRIES more,
            # sleeping RETRY_SLEEP seconds between attempts.
            llm_response = None
            for attempt in range(MAX_RETRIES + 1):
                try:
                    llm_response = call_llm(backend=backend, model=model, prompt=prompt)
                    break
                except Exception as e:
                    print(f"LLM call failed for {cat_name} / {sub_name}: {type(e).__name__}: {e}")
                    if attempt >= MAX_RETRIES:
                        # Abort execution once the retry budget is exhausted
                        raise
                    print(f"Retrying in {RETRY_SLEEP}s (attempt {attempt + 2} of {MAX_RETRIES + 1})...")
                    time.sleep(RETRY_SLEEP)

            # Parse robustly
            score, reasoning, _ = robust_parse_llm(llm_response)

            # Log (ndjson): full prompt and raw reply, for auditing
            log_obj = {
                "ts": datetime.now(timezone.utc).isoformat(),
                "backend": backend,
                "model": model,
                "author": author,
                "category": cat_name,
                "subcategory": sub_name,
                "prompt": prompt,
                "raw_response": llm_response,
                "reasoning": reasoning,
                "parsed_score": score
            }
            append_ndjson(log_ndjson, log_obj)

            # Nested per-speech
            per_speech_scores.append({
                "category": cat_name,
                "subcategory": sub_name,
                "reasoning": reasoning,
                "score": score
            })

            # Flat CSV row (the raw response is kept only in the NDJSON log)
            row = {
                "author": author,
                "category": cat_name,
                "subcategory": sub_name,
                "score": score,
                "reasoning": reasoning
            }
            writer.writerow(row)

        # Save the nested JSON once every subcategory has been scored
        nested_output.append({
            "author": author,
            "scores": per_speech_scores
        })
        save_json(output_json, nested_output)

    print(f"Done. Wrote:\n- {output_json}\n- {output_csv}\n- {log_ndjson}")


# ===========
# CLI
# ===========

def main():
    """Parse the command-line options and run `evaluate` with them."""
    cli = argparse.ArgumentParser(description="Evaluate a single speech with a taxonomy via LLM.")

    cli.add_argument(
        "--backend",
        type=str,
        default=DEFAULT_BACKEND,
        choices=["openai", "gemini", "grok", "llama_local"],
        help="LLM backend to use (default: llama_local).",
    )
    cli.add_argument("--author_name", required=True, help="Name of the speech author")
    cli.add_argument("--speech_file", required=True, help="Path to the text file containing the speech")
    cli.add_argument("--model", type=str, default=None, help="Model name (optional)")

    options = cli.parse_args()

    evaluate(
        author=options.author_name,
        backend=options.backend,
        model=options.model,
        speech_file=options.speech_file,
    )


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()