import os
import argparse
from PIL import Image
import torch
from torchvision import transforms
from tqdm import tqdm

def load_images(folder, transform=None, *, extensions=(".png",)):
    """Return the sorted paths of image files directly inside *folder*.

    Despite the name, no images are opened or decoded here; only paths are
    collected. The listing is non-recursive and the extension match is
    case-insensitive (``a.PNG`` matches ``.png``).

    Args:
        folder: Directory to scan.
        transform: Unused. Kept (now optional) for backward compatibility
            with existing callers that pass a preprocessing transform.
        extensions: File-name suffixes to accept. Defaults to ``(".png",)``.

    Returns:
        A list of ``os.path.join(folder, name)`` strings, sorted
        lexicographically.
    """
    suffixes = tuple(ext.lower() for ext in extensions)
    return sorted(
        os.path.join(folder, name)
        for name in os.listdir(folder)
        # Skip sub-directories (e.g. "foo.png/") that would otherwise match.
        if name.lower().endswith(suffixes)
        and os.path.isfile(os.path.join(folder, name))
    )

def reconstruct(model_path, input_folder, output_folder, image_size=(128, 128)):
    """Run every PNG in *input_folder* through a TorchScript model and save the result.

    Each input is converted to single-channel grayscale ("L"), resized to
    *image_size*, turned into a float tensor with ``ToTensor`` and fed to the
    model as a batch of one. The model output is written to *output_folder*
    under the same file name as its input.

    Args:
        model_path: Path to a TorchScript file loadable with ``torch.jit.load``.
        input_folder: Directory scanned (non-recursively) for ``*.png`` files.
        output_folder: Destination directory; created if missing. Existing
            files with the same name are overwritten.
        image_size: Target size passed to ``transforms.Resize``
            (``(height, width)``).
    """
    os.makedirs(output_folder, exist_ok=True)

    # Prefer CUDA, then Apple MPS, then CPU. getattr guards against torch
    # builds that do not have the ``torch.backends.mps`` module.
    mps_backend = getattr(torch.backends, "mps", None)
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif mps_backend is not None and mps_backend.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")
    print(f"Using device: {device}")

    # Load the TorchScript model directly onto the chosen device.
    model = torch.jit.load(model_path, map_location=device)
    model.eval()

    # Preprocessing: resize, then convert to a float tensor scaled to [0, 1].
    transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
    ])
    to_pil = transforms.ToPILImage()

    # Collect the input image paths.
    image_paths = load_images(input_folder, transform)
    if not image_paths:
        print(f"No PNG images found in: {input_folder}")

    # Inference only: disable autograd once for the whole loop.
    with torch.no_grad():
        for path in tqdm(image_paths, desc="Reconstructing"):
            # The context manager closes the file handle; convert() returns
            # a new, independent image that stays valid after the block.
            with Image.open(path) as src:
                img = src.convert("L")
            tensor = transform(img).unsqueeze(0).to(device)  # (1, 1, H, W)

            # NOTE(review): assumes the model returns (1, C, H, W) with values
            # nominally in [0, 1], as an autoencoder with a sigmoid head would.
            recon = model(tensor).cpu().squeeze(0)
            # ToPILImage scales float tensors by 255 and casts to uint8, so
            # out-of-range values would wrap around; clamp first.
            recon = recon.clamp(0.0, 1.0)

            out_img = to_pil(recon)
            out_path = os.path.join(output_folder, os.path.basename(path))
            out_img.save(out_path)

    print(f"Reconstructed images saved to: {output_folder}")

if __name__ == "__main__":
    # Command-line entry point: parse arguments and run the reconstruction.
    parser = argparse.ArgumentParser(description="Reconstruct images using traced autoencoder.")
    parser.add_argument("--model", type=str, required=True, help="Path to traced .pt model")
    parser.add_argument("--input", type=str, required=True, help="Folder containing input PNG images")
    parser.add_argument("--output", type=str, required=True, help="Folder to save reconstructed images")
    # Exposes reconstruct()'s image_size; the default matches its previous
    # hard-wired value, so existing invocations behave the same.
    parser.add_argument(
        "--size",
        type=int,
        nargs=2,
        default=[128, 128],
        metavar=("HEIGHT", "WIDTH"),
        help="Resize inputs to HEIGHT WIDTH before reconstruction (default: 128 128)",
    )
    args = parser.parse_args()

    reconstruct(
        model_path=args.model,
        input_folder=args.input,
        output_folder=args.output,
        image_size=tuple(args.size),
    )