import os
import argparse
import torch
import pandas as pd
from PIL import Image
from torchvision import transforms
from tqdm import tqdm

def compute_errors(model, folder, label, device, image_size=(128, 128),
                   extensions=(".png",)):
    """Compute the per-image reconstruction error of ``model`` over a folder.

    Each matching file in ``folder`` (non-recursive) is loaded, converted to
    grayscale ("L" mode), resized to ``image_size``, turned into a tensor with
    ``transforms.ToTensor`` and passed through ``model`` as a batch of one.
    The error is the mean of the squared difference between the input tensor
    and the model output.

    Args:
        model: Callable mapping an input tensor to its reconstruction; assumed
            to return a tensor shape-compatible with its input (TODO confirm
            for the traced model in use).
        folder: Directory scanned for images.
        label: 0 for normal data, anything else for abnormal. Only used to
            label the progress bar; it is not stored in the result.
        device: Torch device the input tensors are moved to.
        image_size: Target size passed to ``transforms.Resize``.
        extensions: Filename suffix (or tuple of suffixes) to include, matched
            case-insensitively. Defaults to PNG only.

    Returns:
        pandas DataFrame with columns ``filename`` and
        ``reconstruction_error``, one row per image in sorted path order.
        The columns are present even when no images match.
    """
    transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
    ])

    # str.endswith needs a tuple; accept a single string for convenience.
    if isinstance(extensions, str):
        extensions = (extensions,)
    suffixes = tuple(ext.lower() for ext in extensions)

    paths = sorted(
        os.path.join(folder, f)
        for f in os.listdir(folder)
        # Skip directories that merely happen to have an image-like name.
        if f.lower().endswith(suffixes)
        and os.path.isfile(os.path.join(folder, f))
    )

    results = []
    desc = f"Processing {'normal' if label==0 else 'abnormal'}"
    # Inference only: disable autograd once for the whole loop.
    with torch.no_grad():
        for path in tqdm(paths, desc=desc):
            # Close the file handle promptly; convert() returns a new, fully
            # loaded image, so it remains usable after the file is closed.
            with Image.open(path) as img:
                gray = img.convert("L")
            tensor = transform(gray).unsqueeze(0).to(device)  # add batch dim

            recon = model(tensor)
            error = torch.mean((tensor - recon) ** 2).item()

            results.append({
                "filename": os.path.basename(path),
                "reconstruction_error": error,
            })

    # Explicit columns keep the header even when `results` is empty.
    return pd.DataFrame(results, columns=["filename", "reconstruction_error"])

def evaluate_and_save_excel(model_path, normal_dir, abnormal_dir, output_xlsx):
    """Score normal and abnormal image folders and write the errors to Excel.

    Loads a TorchScript model from ``model_path`` onto the best available
    device (CUDA, then MPS, then CPU). It computes per-image reconstruction
    errors for both folders via ``compute_errors`` and writes them to
    ``output_xlsx`` as two sheets, "normal" and "abnormal".

    Args:
        model_path: Path to a model loadable with ``torch.jit.load``.
        normal_dir: Folder of normal test images.
        abnormal_dir: Folder of abnormal test images.
        output_xlsx: Destination .xlsx path. Its parent directory is created
            if needed; a bare filename writes to the current directory.
    """
    device = torch.device("cuda" if torch.cuda.is_available()
                          else "mps" if torch.backends.mps.is_available()
                          else "cpu")
    print(f"Using device: {device}")

    model = torch.jit.load(model_path, map_location=device)
    model.eval()

    df_normal = compute_errors(model, normal_dir, label=0, device=device)
    df_abnormal = compute_errors(model, abnormal_dir, label=1, device=device)

    # os.path.dirname("out.xlsx") == "" and os.makedirs("") raises
    # FileNotFoundError, so only create a directory when one is named.
    out_dir = os.path.dirname(output_xlsx)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    with pd.ExcelWriter(output_xlsx) as writer:
        df_normal.to_excel(writer, index=False, sheet_name="normal")
        df_abnormal.to_excel(writer, index=False, sheet_name="abnormal")

    print(f"Saved Excel file to {output_xlsx}")

if __name__ == "__main__":
    # Command-line entry point: every option is a required string.
    cli = argparse.ArgumentParser(
        description="Evaluate reconstruction error and save to Excel."
    )
    for flag, help_text in (
        ("--model", "Path to traced .pt model"),
        ("--normal", "Folder of normal test images"),
        ("--abnormal", "Folder of abnormal test images"),
        ("--output", "Path to output .xlsx file"),
    ):
        cli.add_argument(flag, type=str, required=True, help=help_text)
    opts = cli.parse_args()

    evaluate_and_save_excel(
        model_path=opts.model,
        normal_dir=opts.normal,
        abnormal_dir=opts.abnormal,
        output_xlsx=opts.output,
    )