"""
Generates analysis plots from an evaluation CSV file.
Based on research by Delgado-Mohatar & Alelú-Paz, When Algorithms Guard Democracy — integrating Levitsky & Ziblatt’s four dimensions with LLM analysis.

This script creates dot plots and radar charts to visualize authoritarianism scores
based on the methodology from Delgado-Mohatar & Alelú-Paz.

Usage:
    python analysis.py --csv_file <path_to_your_evaluation.csv>
"""

import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
import argparse
import textwrap
import os
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
from numpy.linalg import norm

# Hitler "Gold Standard" vector based on max scores per subcategory
# This vector represents the maximum observed score for each indicator across all of Hitler's speeches.
hitler_gold_standard_max = pd.Series({
    "Attacks on the Established Order or the Electoral Process": 10,
    "Censorship Proposals": 9,
    "Glorification/Normalization of Historical or Foreign Violence": 10,
    "Opponent as Corrupt/Traitor": 9,
    "Opponent as Existential Threat": 10,
    "Organizational Links with Violence": 10,
    "Repressive Measures": 10,
    "Restriction of Civil Liberties": 10,
    "Support/Justification of Present Violence": 10,
    "Use of Extra-Legal Solutions": 10,
    "‘Enemies of the People’ Language": 10
}).sort_index()

# Reference evaluation CSVs, keyed by the label shown for each source in the
# plots and tables. Paths are relative to the current working directory;
# files that do not exist are skipped with a warning when data is loaded.
# Labels containing "Hitler" are excluded from the Hitlerian index comparison.
reference_speeches = {
    "Hitler pre-government": "llm_data/reference_speeches/Hitler_gpt-4o_evaluation_flat_1922-1932.csv",
    "Hitler government": "llm_data/reference_speeches/Hitler_gpt-4o_evaluation_flat_speeches_1933-1939.csv",
    "Trump rallies": "llm_data/reference_speeches/Trump_gpt-4o_evaluation_flat_rallies.csv",
    "Trump presidency 2017-2021": "llm_data/reference_speeches/Trump_gpt-4o_evaluation_flat_presidential_speeches_2017-2021.csv",
    "Trump presidency 2025": "llm_data/reference_speeches/Trump_gpt-4o_evaluation_flat_presidential_speeches_2025.csv",
    "Nicola Sturgeon presidency": "llm_data/reference_speeches/Nicola_gpt-4o_evaluation_speeches_No_Covid.csv",
}

def _read_scores_csv(path, person: str) -> pd.DataFrame:
    """Read one evaluation CSV, normalise its indicator column name and tag it with *person*."""
    frame = pd.read_csv(path)
    # Some exports name the indicator column 'subcategory'. Normalising each
    # file *before* concatenation matters: renaming after concat is skipped as
    # soon as any file already has 'indicator', leaving NaN indicators (and
    # therefore dropped rows) for the 'subcategory' files.
    if "indicator" not in frame.columns and "subcategory" in frame.columns:
        frame = frame.rename(columns={"subcategory": "indicator"})
    frame["person"] = person
    return frame


def load_and_clean_data(user_csv_path: str) -> pd.DataFrame:
    """
    Loads reference CSVs and the user-provided CSV, then cleans and combines them.

    Reference files listed in `reference_speeches` that do not exist are
    skipped with a warning. The user's rows are labelled with the name of the
    CSV's parent directory. If that name is empty (e.g. a bare file name in
    the working directory), the resolved parent directory name is used, and
    failing that the file stem.

    Args:
        user_csv_path: Path to the user's input CSV file.

    Returns:
        A combined pandas DataFrame with a numeric 'score' column and no
        missing 'score', 'indicator' or 'category' values.

    Raises:
        FileNotFoundError: If the user CSV does not exist.
        ValueError: If a required column is missing after combining.
    """
    all_data = []

    # 1. Load reference speeches
    for person, path in reference_speeches.items():
        if not Path(path).exists():
            print(f"Warning: Reference file not found, skipping: {path}")
            continue
        all_data.append(_read_scores_csv(path, person))

    # 2. Load user's CSV
    csv_path = Path(user_csv_path)
    if not csv_path.exists():
        raise FileNotFoundError(f"User CSV file not found at: {user_csv_path}")

    # `Path("file.csv").parent.name` is "" -> fall back so the source is never unlabeled.
    user_person_name = (
        csv_path.parent.name or csv_path.resolve().parent.name or csv_path.stem
    )
    all_data.append(_read_scores_csv(csv_path, user_person_name))
    print(f"Loaded user data for '{user_person_name}'.")

    # 3. Combine and clean all data (all_data always holds at least the user frame)
    df_all = pd.concat(all_data, ignore_index=True)

    required_cols = ["score", "indicator", "category", "person"]
    for col in required_cols:
        if col not in df_all.columns:
            raise ValueError(f"Combined DataFrame is missing required column: '{col}'")

    # Non-numeric scores become NaN and are dropped together with rows that
    # lack an indicator or category.
    df_all["score"] = pd.to_numeric(df_all["score"], errors="coerce")
    df_all = df_all.dropna(subset=["score", "indicator", "category"])

    print(f"Successfully loaded and combined data for {df_all['person'].nunique()} sources.")
    return df_all

def generate_dot_plot(df: pd.DataFrame, plot_type: str, output_dir: str):
    """
    Draw one dot per (source, category/indicator) at that group's highest
    score, plus a dashed risk-threshold line at 6, and save it as a PNG.

    Args:
        df: Combined scores for all sources; needs 'person', 'score' and the
            column named by plot_type.
        plot_type: 'category' or 'indicator' — the column used for colouring.
        output_dir: Directory where 'comparative_dot_plot_<plot_type>.png' is written.

    Raises:
        ValueError: If plot_type is not 'category' or 'indicator'.
    """
    valid_types = ('category', 'indicator')
    if plot_type not in valid_types:
        raise ValueError("plot_type must be 'category' or 'indicator'")

    # Highest score per source and group.
    per_group_max = (
        df.groupby(["person", plot_type])["score"]
        .max()
        .reset_index()
        .rename(columns={"score": "max_score"})
    )

    x_order = sorted(df['person'].unique())
    hue_levels = sorted(per_group_max[plot_type].unique())
    colors = sns.color_palette("tab20", n_colors=len(hue_levels))
    color_by_level = {level: color for level, color in zip(hue_levels, colors)}

    fig = plt.figure(figsize=(16, 10))
    ax = sns.stripplot(
        data=per_group_max,
        x="person",
        y="max_score",
        hue=plot_type,
        order=x_order,
        size=12,
        alpha=0.9,
        jitter=0.2,
        linewidth=0.5,
        edgecolor="black",
        palette=color_by_level,
    )

    # Labelled so the threshold also appears in the legend built below.
    ax.axhline(y=6, color="red", linestyle="--", linewidth=2, alpha=0.9, label="Risk Threshold (6)")
    ax.set_xlabel("Source", fontsize=14)
    ax.set_ylabel(f"Score (maximum per {plot_type})", fontsize=14)
    plt.xticks(rotation=45, ha="right", fontsize=12)
    plt.yticks(fontsize=12)
    ax.legend(
        title=plot_type.capitalize(),
        bbox_to_anchor=(1.02, 1.0),
        loc="upper left",
        frameon=False,
    )
    # Keep the right-hand 15% of the figure free for the out-of-axes legend.
    fig.tight_layout(rect=[0, 0, 0.85, 1])

    output_path = Path(output_dir) / f"comparative_dot_plot_{plot_type}.png"
    fig.savefig(output_path)
    print(f"Saved dot plot to: {output_path}")
    plt.close(fig)

def generate_radar_chart(df: pd.DataFrame, plot_type: str, output_dir: str):
    """
    Draw one polar (radar) trace per source showing its maximum score for
    every category or indicator, with a dashed circle at the risk threshold 6,
    and save the figure as a PNG.

    Args:
        df: Combined scores for all sources; needs 'person', 'score' and the
            column named by plot_type.
        plot_type: 'category' or 'indicator' — the column used for the spokes.
        output_dir: Directory where 'comparative_radar_chart_<plot_type>.png' is written.

    Raises:
        ValueError: If plot_type is not 'category' or 'indicator'.
    """
    if plot_type not in ('category', 'indicator'):
        raise ValueError("plot_type must be 'category' or 'indicator'")

    max_scores = (
        df.groupby(["person", plot_type])["score"]
        .max()
        .reset_index()
        .rename(columns={"score": "max_score"})
    )
    sources = sorted(max_scores["person"].unique())
    spokes = sorted(max_scores[plot_type].unique())

    # One angle per spoke; the first is repeated so every polygon closes.
    theta = np.linspace(0, 2 * np.pi, len(spokes), endpoint=False).tolist()
    theta = theta + theta[:1]

    fig, ax = plt.subplots(figsize=(14, 14), subplot_kw={"polar": True})

    for source in sources:
        rows = max_scores[max_scores["person"] == source]
        # Spokes a source has no score for are drawn at radius 0.
        radii = rows.set_index(plot_type)["max_score"].reindex(spokes).fillna(0).tolist()
        radii = radii + radii[:1]
        ax.plot(theta, radii, linewidth=2, linestyle='solid', label=source)
        ax.fill(theta, radii, alpha=0.1)

    ax.set_yticklabels([])
    ax.set_xticks(theta[:-1])
    ax.set_xticklabels([textwrap.fill(name, 15) for name in spokes], size=10)

    circle = np.linspace(0, 2 * np.pi, 100)
    ax.plot(circle, [6] * len(circle), color="red", linestyle="--", label="Risk Threshold (6)")

    ax.set_title(f"Comparative Max Scores per {plot_type.capitalize()}", size=16, y=1.1)
    ax.legend(loc='upper right', bbox_to_anchor=(1.3, 1.1))

    output_path = Path(output_dir) / f"comparative_radar_chart_{plot_type}.png"
    fig.savefig(output_path, bbox_inches='tight')
    print(f"Saved radar chart to: {output_path}")
    plt.close(fig)

# --- PCA Analysis Functions (from pca_analysis.py) ---

def build_feature_matrix(df_all: pd.DataFrame, level="indicator", stat="max"):
    """Builds a person-feature matrix for PCA."""
    if stat == "max":
        agg = df_all.groupby(["person", level])["score"].max().reset_index()
    else:
        raise ValueError("Only stat='max' is currently supported for integration.")
    
    mat = agg.pivot(index="person", columns=level, values="score").fillna(0.0)
    mat = mat.sort_index(axis=1)
    return mat

def preprocess(X_df: pd.DataFrame):
    """
    Median-impute missing values, then standardize each feature column to
    zero mean and unit variance.

    Args:
        X_df: Feature matrix (one row per source, one column per feature).

    Returns:
        The imputed, z-scored values as a NumPy array.
    """
    raw_values = X_df.values
    filled = SimpleImputer(strategy="median").fit_transform(raw_values)
    return StandardScaler().fit_transform(filled)

def pca_kmeans(X_z: np.ndarray, persons_index, random_state=42):
    """
    Project the standardized features onto 2 principal components and cluster
    the projected points with K-Means.

    K-Means is tried for k = 2 .. min(6, n_samples - 1). The k with the highest
    silhouette score is kept.

    Args:
        X_z: Standardized feature matrix, one row per source.
        persons_index: Source names aligned with the rows of X_z (not used
            here; kept for interface compatibility).
        random_state: Seed passed to PCA and K-Means.

    Returns:
        (X_pcs, labels, k): the 2-D projection, one integer cluster label per
        row, and the number of clusters. If no k >= 2 gives a valid clustering
        (fewer than 3 sources, or all projected points coincide), every row is
        placed in cluster 0 and k is 1 instead of returning None.
    """
    pca = PCA(n_components=2, random_state=random_state)
    X_pcs = pca.fit_transform(X_z)
    n_samples = X_pcs.shape[0]

    best_k, best_score, best_labels = None, None, None
    for k in range(2, min(6, n_samples - 1) + 1):
        km = KMeans(n_clusters=k, n_init=20, random_state=random_state)
        labels = km.fit_predict(X_pcs)
        # With duplicate points KMeans can return fewer distinct clusters than
        # k; silhouette_score raises when only one label is present.
        if len(np.unique(labels)) < 2:
            continue
        sil = silhouette_score(X_pcs, labels)
        if best_score is None or sil > best_score:
            best_k, best_score, best_labels = k, sil, labels

    print("[PCA] Using 2 components.")
    if best_labels is None:
        print("[KMeans] Not enough distinct sources to cluster; using a single cluster.")
        return X_pcs, np.zeros(n_samples, dtype=int), 1

    print(f"[KMeans] Best k={best_k} with silhouette score={best_score:.3f}")
    return X_pcs, best_labels, best_k

def spread_duplicate_points(XY: np.ndarray, spread_radius=0.05):
    """
    Spread overlapping 2-D points around a small circle for better visualization.

    Points whose coordinates agree to 5 decimal places are treated as
    duplicates. Each group of n duplicates is placed evenly on a circle of
    radius `spread_radius` centred on the shared position. Unique points are
    left where they are.

    Args:
        XY: Array whose rows are points; columns 0 and 1 are x and y.
            Integer arrays are accepted; the result is always float.
        spread_radius: Radius of the circle duplicates are spread on.

    Returns:
        (XY_new, offsets): a float copy of XY with duplicates moved, and a dict
        mapping each moved row index to its (dx, dy) offset. XY is not modified.
    """
    # Work on a float copy: adding fractional offsets in place to an integer
    # array would silently truncate them and leave duplicates overlapping.
    XY_new = np.array(XY, dtype=float)

    # Group row indices by rounded position (computed before any offset is applied).
    groups_map = {}
    for i, p in enumerate(XY_new):
        groups_map.setdefault(tuple(np.round(p, 5)), []).append(i)

    offsets = {}
    for indices in groups_map.values():
        if len(indices) < 2:
            continue
        angles = np.linspace(0, 2 * np.pi, len(indices), endpoint=False)
        for angle, idx in zip(angles, indices):
            dx = spread_radius * np.cos(angle)
            dy = spread_radius * np.sin(angle)
            XY_new[idx, 0] += dx
            XY_new[idx, 1] += dy
            offsets[idx] = (dx, dy)
    return XY_new, offsets

def generate_pca_plot(df: pd.DataFrame, output_dir: str):
    """
    Generates and saves a PCA scatter plot of the sources.

    Each source is represented by its maximum score per indicator (missing
    indicators count as 0). The matrix is imputed and standardized, projected
    onto two principal components, and the points are coloured by K-Means
    cluster.

    If there are fewer than two sources or two indicators, a 2-component PCA
    is impossible. The plot is then skipped with a message instead of raising.

    Args:
        df: Combined scores with 'person', 'indicator' and 'score' columns.
        output_dir: Directory where 'pca_scatter_plot.png' is written.
    """
    print("Generating PCA plot...")
    # 1. Build feature matrix
    X_df = build_feature_matrix(df, level="indicator", stat="max")
    n_sources, n_indicators = X_df.shape
    if n_sources < 2 or n_indicators < 2:
        print(
            "Skipping PCA plot: need at least 2 sources and 2 indicators, "
            f"got {n_sources} source(s) and {n_indicators} indicator(s)."
        )
        return

    # 2. Preprocess data
    X_z = preprocess(X_df)

    # 3. Run PCA and K-Means
    persons_index = X_df.index.tolist()
    X_pcs, labels, k = pca_kmeans(X_z, persons_index)
    # pca_kmeans may return no labels when clustering was impossible (too few
    # sources); fall back to one cluster so the plot can still be drawn.
    if labels is None:
        labels = np.zeros(len(persons_index), dtype=int)
        k = 1

    # 4. Plot
    XY_spread, offsets = spread_duplicate_points(X_pcs[:, :2])

    fig, ax = plt.subplots(figsize=(11, 8))
    sns.scatterplot(
        x=XY_spread[:, 0], y=XY_spread[:, 1],
        hue=labels, palette="tab10", s=150, alpha=0.9, edgecolor="white", ax=ax
    )

    for i, name in enumerate(persons_index):
        # Nudge each label away from its point, in the direction it was spread.
        dx, dy = offsets.get(i, (0.0, 0.0))
        ax.text(
            XY_spread[i, 0] + np.sign(dx if dx != 0 else 1) * 0.01,
            XY_spread[i, 1] + np.sign(dy if dy != 0 else 1) * 0.01,
            " " + str(name),
            fontsize=9, va="center"
        )

    ax.set_xlabel("Principal Component 1")
    ax.set_ylabel("Principal Component 2")
    ax.set_title(f"PCA of Sources based on Indicators (k={k})")
    ax.legend(title="Cluster", frameon=False, bbox_to_anchor=(1.02, 1), loc="upper left")
    fig.tight_layout()

    # Save plot; bbox_inches="tight" keeps the out-of-axes legend from being clipped.
    output_path = Path(output_dir) / "pca_scatter_plot.png"
    fig.savefig(output_path, bbox_inches="tight")
    print(f"Saved PCA plot to: {output_path}")
    plt.close(fig)

def calculate_and_display_hitlerian_index(df: pd.DataFrame, hitler_vector: pd.Series, *, max_score: float = 10):
    """
    Compare each non-Hitler source's indicator profile with a Hitler reference
    vector and print the results as a table.

    Each source is represented by its maximum score per indicator, reindexed
    onto hitler_vector's indicators. Indicators the source lacks count as 0;
    indicators not in hitler_vector are ignored, with a warning. Two indices
    are computed:

    1. Cosine index: cosine similarity of the two vectors. It measures how
       similar the *shape* of the profiles is, regardless of magnitude (0 if
       either vector is all zeros).
    2. Euclidean index: 1 - ||h - p|| / ||(max_score, ..., max_score)||. It
       measures closeness in *magnitude*: 1 is identical, and 0 is maximally
       distant when all scores lie in [0, max_score].

    Args:
        df: Combined scores with 'person', 'indicator' and 'score' columns.
            Sources whose name contains 'Hitler' are not compared.
        hitler_vector: Reference score per indicator, indexed by indicator name.
        max_score: Upper bound of the scoring scale, used to normalise the
            Euclidean distance (default 10).

    Returns:
        The results table (columns 'Leader', 'Cosine Index (%)',
        'Euclidean Index (%)', sorted by Euclidean index, descending), or None
        when there is no source to compare.

    Raises:
        ValueError: If max_score is not positive.
    """
    if max_score <= 0:
        raise ValueError("max_score must be positive")

    print("\n" + "="*60)
    print("  Authoritarian Profile Analysis vs. Hitler Gold Standard")
    print("="*60)

    persons = sorted([p for p in df['person'].unique() if 'Hitler' not in p])
    compared = df[df['person'].isin(persons)]

    # Loop invariants: the reference vector, its norm, and the largest possible
    # distance (all-zero vs. all-max_score vector).
    hitler_vector = hitler_vector.sort_index()
    v_hitler = hitler_vector.to_numpy(dtype=float)
    hitler_norm = norm(v_hitler)
    max_possible_dist = norm(np.full(v_hitler.shape, float(max_score)))

    # Indicator names that do not match the reference (e.g. differing quote
    # characters) would otherwise silently count as 0.
    unmatched = sorted(set(compared['indicator'].dropna()) - set(hitler_vector.index), key=str)
    if unmatched:
        print("Warning: indicators not in the reference vector are ignored: "
              + ", ".join(map(str, unmatched)))

    results = []
    for person in persons:
        person_scores = compared[compared['person'] == person].groupby('indicator')['score'].max()
        v_person = person_scores.reindex(hitler_vector.index).fillna(0).to_numpy(dtype=float)

        # --- 1. Cosine Similarity (profile shape) ---
        person_norm = norm(v_person)
        if person_norm == 0 or hitler_norm == 0:
            cosine_sim = 0.0
        else:
            cosine_sim = float(np.dot(v_hitler, v_person) / (hitler_norm * person_norm))

        # --- 2. Euclidean-based Similarity (profile magnitude) ---
        euclidean_sim = 1 - norm(v_hitler - v_person) / max_possible_dist

        results.append({
            "Leader": person,
            "Cosine Index (%)": cosine_sim * 100,
            "Euclidean Index (%)": euclidean_sim * 100
        })

    if not results:
        print("No leaders to compare against the Hitler reference vector.")
        return None

    # Display results in a formatted table
    results_df = pd.DataFrame(results).sort_values(by="Euclidean Index (%)", ascending=False)
    print(results_df.to_string(index=False, formatters={
        'Cosine Index (%)': '{:,.2f}'.format,
        'Euclidean Index (%)': '{:,.2f}'.format
    }))
    print("="*60)
    print("\n- Cosine Index: Measures profile SHAPE similarity (direction). High if peaks align.")
    print("- Euclidean Index: Measures profile INTENSITY similarity (magnitude). High if scores are close in value.")
    return results_df


def main():
    """
    Parse command-line arguments, print the Hitlerian index table and write
    all comparison plots.

    Command-line options:
        --csv_file    Path to the evaluation CSV to analyse (required). Its
                      parent directory name is used as the source's label.
        --output_dir  Directory for the generated PNG files (default:
                      "analysis"); created if it does not exist.
    """
    parser = argparse.ArgumentParser(
        description="Generate analysis plots from an authoritarianism evaluation CSV file."
    )
    parser.add_argument(
        "--csv_file",
        required=True,
        help="Path to the input CSV file generated by evaluate.py."
    )
    parser.add_argument(
        "--output_dir",
        default="analysis",
        help="Directory where the generated plots are saved (default: analysis)."
    )
    args = parser.parse_args()

    # Output directory is configurable; the default keeps the previous behaviour.
    output_dir = args.output_dir
    os.makedirs(output_dir, exist_ok=True)

    # Load and process data
    df_cleaned = load_and_clean_data(args.csv_file)

    # Calculate and display the Hitlerian Index
    calculate_and_display_hitlerian_index(df_cleaned, hitler_gold_standard_max)

    # Generate plots
    generate_dot_plot(df_cleaned, 'category', output_dir)
    generate_dot_plot(df_cleaned, 'indicator', output_dir)
    generate_radar_chart(df_cleaned, 'category', output_dir)
    generate_radar_chart(df_cleaned, 'indicator', output_dir)
    generate_pca_plot(df_cleaned, output_dir)

if __name__ == "__main__":
    main()