import argparse
import pandas as pd
import matplotlib.pyplot as plt
import os

def on_key(event):
    """Close the figure that received the key press when a quit key is hit.

    Meant to be registered as a Matplotlib ``key_press_event`` callback.

    Args:
        event: The key event delivered by Matplotlib. Its ``key`` attribute
            is compared against the quit keys (``q``, ``Q``, ``escape``) and
            its ``canvas.figure`` is the figure that gets closed.
    """
    if event.key in ("q", "Q", "escape"):
        # Close the figure that owns this event rather than whatever figure
        # pyplot currently treats as "current"; the two can differ as soon
        # as more than one figure is open.
        plt.close(event.canvas.figure)

def plot_loss_curve(xlsx_path: str) -> None:
    """Plot the loss-per-epoch curve stored in an Excel file and display it.

    The workbook is read with ``pandas.read_excel`` and must provide an
    ``Epoch`` column (x axis) and a ``Loss`` column (y axis). The figure is
    shown with ``plt.show()`` and can be closed with Q or ESC via ``on_key``.

    Args:
        xlsx_path: Path to the Excel file to read.

    Returns:
        None. A missing file or missing columns are reported on standard
        output and the function returns without plotting.
    """
    # isfile rather than exists: a directory path would pass an existence
    # check and then fail inside read_excel with an unrelated error.
    if not os.path.isfile(xlsx_path):
        print(f"File not found: {xlsx_path}")
        return

    df = pd.read_excel(xlsx_path)

    if "Epoch" not in df.columns or "Loss" not in df.columns:
        print("Error: Excel file must contain 'Epoch' and 'Loss' columns.")
        return

    fig, ax = plt.subplots(figsize=(8, 5))
    fig.canvas.mpl_connect("key_press_event", on_key)

    ax.plot(df["Epoch"], df["Loss"], marker="o", label="Loss")
    ax.set_title("Training Loss Curve")
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss")
    ax.grid(True)
    ax.legend()
    # Lay out this specific figure instead of relying on pyplot's notion of
    # the current figure.
    fig.tight_layout()
    print("Press Q or ESC to close the plot.")
    plt.show()

if __name__ == "__main__":
    # Command-line entry point: a single positional argument naming the
    # Excel workbook whose loss curve should be plotted.
    cli = argparse.ArgumentParser(
        description="Plot training loss curve from Excel file.",
    )
    cli.add_argument(
        "xlsx_path",
        type=str,
        help="Path to Excel file (e.g., models/ae_R6_loss.xlsx)",
    )
    plot_loss_curve(cli.parse_args().xlsx_path)