import argparse
import sys

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.metrics import roc_curve, auc

def on_key(event):
    """Close the figure that received a quit key press.

    Meant to be registered as a matplotlib ``key_press_event`` callback.
    Pressing ``q``, ``Q`` or ``escape`` closes the figure whose canvas
    generated the event; every other key (including ``None``) is ignored.

    Args:
        event: matplotlib key event; only ``event.key`` and
            ``event.canvas.figure`` are read.
    """
    if event.key in ("q", "Q", "escape"):
        # Close the figure that owns this canvas. A bare plt.close() closes
        # pyplot's "current figure", which may be a different window when
        # several figures are open.
        plt.close(event.canvas.figure)

def plot_histogram_and_roc(xlsx_path, bins=50):
    """Plot reconstruction-error histograms and a ROC curve from an Excel file.

    The workbook must contain sheets named ``normal`` and ``abnormal``, each
    with a ``reconstruction_error`` column. Normal rows are labelled 0 and
    abnormal rows 1, and the reconstruction error is used directly as the
    score for ``roc_curve`` (higher error => ranked as more abnormal).

    Both histograms share the same bin edges, computed over the combined
    data, so the two distributions can be compared bin-for-bin.

    Args:
        xlsx_path: Path to the ``.xlsx`` file.
        bins: Anything accepted by ``numpy.histogram_bin_edges``: a number of
            bins (default 50), a binning-strategy name, or explicit edges.

    Returns:
        None. On invalid input an error message is printed to stderr and
        nothing is plotted. Otherwise blocks in ``plt.show()`` until the
        window is closed.
    """
    # Load both sheets; the context manager releases the file handle.
    with pd.ExcelFile(xlsx_path) as xls:
        if not {"normal", "abnormal"}.issubset(xls.sheet_names):
            print("Error: Excel file must contain 'normal' and 'abnormal' sheets.", file=sys.stderr)
            return
        df_normal = pd.read_excel(xls, sheet_name="normal")
        df_abnormal = pd.read_excel(xls, sheet_name="abnormal")

    if "reconstruction_error" not in df_normal.columns or "reconstruction_error" not in df_abnormal.columns:
        print("Error: Missing 'reconstruction_error' column.", file=sys.stderr)
        return

    # Coerce to float and drop blank / non-numeric cells: roc_curve rejects
    # NaN input, and text cells would otherwise produce an object array.
    errors_normal = (
        pd.to_numeric(df_normal["reconstruction_error"], errors="coerce").dropna().to_numpy(dtype=float)
    )
    errors_abnormal = (
        pd.to_numeric(df_abnormal["reconstruction_error"], errors="coerce").dropna().to_numpy(dtype=float)
    )
    dropped = len(df_normal) + len(df_abnormal) - errors_normal.size - errors_abnormal.size
    if dropped:
        print(
            f"Warning: ignored {dropped} empty or non-numeric 'reconstruction_error' value(s).",
            file=sys.stderr,
        )

    # A ROC curve needs at least one sample of each class.
    if errors_normal.size == 0 or errors_abnormal.size == 0:
        print(
            "Error: Both 'normal' and 'abnormal' sheets need at least one numeric "
            "'reconstruction_error' value.",
            file=sys.stderr,
        )
        return

    # ROC data: label normal=0, abnormal=1; the error itself is the score.
    y_true = np.concatenate([np.zeros(errors_normal.size), np.ones(errors_abnormal.size)])
    y_score = np.concatenate([errors_normal, errors_abnormal])

    fpr, tpr, _ = roc_curve(y_true, y_score)
    roc_auc = auc(fpr, tpr)

    # Common bin edges over all scores so both histograms are comparable.
    bin_edges = np.histogram_bin_edges(y_score, bins=bins)

    # Figure setup: histogram on the left, ROC curve on the right.
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
    fig.canvas.mpl_connect("key_press_event", on_key)

    # Histogram (left)
    ax1.hist(errors_normal, bins=bin_edges, alpha=0.7, color='steelblue', label='Normal', edgecolor='black')
    ax1.hist(errors_abnormal, bins=bin_edges, alpha=0.7, color='darkorange', label='Abnormal', edgecolor='black')
    ax1.set_title("Reconstruction Error Histogram")
    ax1.set_xlabel("Reconstruction Error")
    ax1.set_ylabel("Frequency")
    ax1.legend()
    ax1.grid(True)

    # ROC curve (right), with the chance diagonal for reference.
    ax2.plot(fpr, tpr, color='darkred', lw=2, label=f'AUC = {roc_auc:.3f}')
    ax2.plot([0, 1], [0, 1], linestyle='--', color='gray', lw=1)
    ax2.set_xlim([0.0, 1.0])
    ax2.set_ylim([0.0, 1.05])
    ax2.set_xlabel('False Positive Rate')
    ax2.set_ylabel('True Positive Rate')
    ax2.set_title('ROC Curve')
    ax2.legend(loc="lower right")
    ax2.grid(True)

    fig.tight_layout()
    print("Press Q or ESC to close.")
    plt.show()

if __name__ == "__main__":
    # Command-line entry point: read the workbook path and bin count, then plot.
    cli = argparse.ArgumentParser(
        description="Plot histogram and ROC curve from Excel file.",
    )
    cli.add_argument(
        "xlsx_path",
        type=str,
        help="Path to .xlsx file with 'normal' and 'abnormal' sheets",
    )
    cli.add_argument(
        "--bins",
        type=int,
        default=50,
        help="Number of histogram bins",
    )
    options = cli.parse_args()
    plot_histogram_and_roc(options.xlsx_path, bins=options.bins)