'''
実行方法：
$ python3 train_autoencoder.py --epochs 50 --output models/AE_50epoch.pth

保存されるファイル（ファイル名は --output のパスから決まる。以下は上記の例の場合）：
models/AE_50epoch.pth：重みのみ（state_dict）
models/AE_50epoch_traced.pt：構造＋重み（TorchScript）
models/AE_50epoch_loss.xlsx：学習曲線データ（Excel形式）
'''

import os
import argparse
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image
from tqdm import tqdm
import pandas as pd

# ======== Training data directory (fixed, explicit location) ========
TRAIN_DIR = './data/train/normal_R6_doragon'


# ======== Device auto-detection (CUDA -> MPS -> CPU) ========
def _detect_device():
    """Return ``(device, human_readable_name)`` for the preferred backend.

    Preference order is CUDA, then Apple MPS (only when this torch build
    exposes ``torch.backends.mps`` and reports it available), then CPU.
    """
    if torch.cuda.is_available():
        return torch.device("cuda"), torch.cuda.get_device_name(0)
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend and mps_backend.is_available():
        return torch.device("mps"), "Apple MPS"
    return torch.device("cpu"), "CPU"


device, device_name = _detect_device()

print(f"Using device: {device} ({device_name})")

# ======== Dataset definition ========
class ImageFolderDataset(Dataset):
    """Flat-folder image dataset that yields grayscale ("L" mode) images.

    Only regular files directly inside ``folder`` whose name ends with one of
    ``extensions`` (compared case-insensitively) are used; subdirectories are
    not searched. Paths are sorted so the file order is reproducible, since
    ``os.listdir`` returns entries in arbitrary order.

    Args:
        folder: Directory containing the images.
        transform: Optional callable applied to each loaded image (e.g. a
            torchvision ``Compose``). When ``None`` the image is returned as-is.
        extensions: Suffix or tuple of suffixes to accept. Defaults to
            ``(".png",)``.
    """

    def __init__(self, folder, transform=None, extensions=(".png",)):
        # Accept a single suffix string as well as a tuple of suffixes.
        if isinstance(extensions, str):
            extensions = (extensions,)
        suffixes = tuple(ext.lower() for ext in extensions)
        self.paths = sorted(
            os.path.join(folder, name)
            for name in os.listdir(folder)
            if name.lower().endswith(suffixes)
            and os.path.isfile(os.path.join(folder, name))
        )
        self.transform = transform

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        # Open inside ``with`` so the file handle is released promptly;
        # ``convert`` loads the pixel data and returns an independent image.
        with Image.open(self.paths[idx]) as src:
            img = src.convert("L")
        if self.transform is not None:
            img = self.transform(img)
        return img

# ======== Autoencoder model definition ========
class Autoencoder(nn.Module):
    """Small convolutional autoencoder for single-channel images.

    Encoder: two stride-2 3x3 convolutions (1 -> 16 -> 32 channels), each
    followed by ReLU, halving height and width twice. Decoder: two stride-2
    3x3 transposed convolutions (32 -> 16 -> 1 channels) mirroring the
    encoder, ending in ``Sigmoid`` so reconstructions lie in [0, 1]. When the
    input height and width are divisible by 4, the output shape equals the
    input shape.
    """

    def __init__(self):
        super().__init__()
        down = dict(kernel_size=3, stride=2, padding=1)
        up = dict(down, output_padding=1)
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=16, **down),
            nn.ReLU(),
            nn.Conv2d(in_channels=16, out_channels=32, **down),
            nn.ReLU(),
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(in_channels=32, out_channels=16, **up),
            nn.ReLU(),
            nn.ConvTranspose2d(in_channels=16, out_channels=1, **up),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Encode then decode ``x`` and return the reconstruction."""
        return self.decoder(self.encoder(x))

# ======== Main training routine ========
def _train_one_epoch(model, dataloader, optimizer, criterion, desc):
    """Run one optimisation pass over ``dataloader``; return the per-sample mean loss."""
    model.train()
    total_loss = 0.0
    num_samples = 0
    for imgs in tqdm(dataloader, desc=desc):
        imgs = imgs.to(device)
        outputs = model(imgs)
        loss = criterion(outputs, imgs)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # The criterion returns a batch mean; weight it by the batch size so a
        # smaller final batch does not skew the epoch average.
        batch_n = imgs.size(0)
        total_loss += loss.item() * batch_n
        num_samples += batch_n
    return total_loss / num_samples


def train(num_epochs, output_path, *, batch_size=32, lr=1e-3, image_size=128):
    """Train an ``Autoencoder`` on the images in ``TRAIN_DIR`` and save artifacts.

    Files written (derived names replace the extension of ``output_path``):

    * ``output_path``: the trained model's ``state_dict``;
    * ``<base>_traced.pt``: a TorchScript trace of the trained model;
    * ``<base>_loss.xlsx``: per-epoch mean training loss (columns Epoch, Loss).

    Args:
        num_epochs: Number of passes over the training set.
        output_path: Destination of the ``state_dict`` (e.g. ``models/ae.pth``).
        batch_size: Mini-batch size (default 32).
        lr: Adam learning rate (default 1e-3).
        image_size: Images are resized to ``image_size x image_size``; also the
            spatial size of the tracing example input (default 128). Should be
            divisible by 4 so the reconstruction matches the input size.

    Raises:
        ValueError: If ``TRAIN_DIR`` contains no usable images.
    """
    # Derive output paths and create the directory *before* training, so an
    # unusable path fails immediately instead of after all epochs have run.
    base_path = os.path.splitext(output_path)[0]
    loss_xlsx_path = base_path + "_loss.xlsx"
    traced_path = base_path + "_traced.pt"
    dir_path = os.path.dirname(base_path)
    if dir_path:
        os.makedirs(dir_path, exist_ok=True)

    transform = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
    ])
    dataset = ImageFolderDataset(TRAIN_DIR, transform)
    if len(dataset) == 0:
        # Without this check DataLoader(shuffle=True) fails with an opaque
        # "num_samples should be a positive integer" error.
        raise ValueError(f"No training images found in {TRAIN_DIR!r}")
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    model = Autoencoder().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.MSELoss()

    loss_history = []
    for epoch in range(num_epochs):
        avg_loss = _train_one_epoch(
            model, dataloader, optimizer, criterion,
            desc=f"Epoch {epoch+1}/{num_epochs}",
        )
        loss_history.append(avg_loss)
        print(f"Epoch [{epoch+1}/{num_epochs}] Loss: {avg_loss:.4f}")

    # Weights only (state_dict).
    torch.save(model.state_dict(), output_path)
    print(f"Saved model weights (state_dict) to {output_path}")

    # Structure + weights (TorchScript). Switch to eval mode first so the trace
    # records inference behaviour (a no-op for this architecture, but correct
    # if dropout/batch-norm layers are ever added).
    model.eval()
    example_input = torch.randn(1, 1, image_size, image_size).to(device)
    traced_model = torch.jit.trace(model, example_input)
    traced_model.save(traced_path)
    print(f"Saved traced model to {traced_path}")

    # Learning-curve data as Excel.
    df = pd.DataFrame({
        "Epoch": list(range(1, len(loss_history) + 1)),
        "Loss": loss_history,
    })
    df.to_excel(loss_xlsx_path, index=False)
    print(f"Saved loss history to {loss_xlsx_path}")

# ======== Entry point ========
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Autoencoder training for anomaly detection")
    parser.add_argument("--epochs", type=int, default=30, help="Number of training epochs")
    parser.add_argument("--output", type=str, required=True, help="Path to save the trained model weights (e.g., models/ae.pth)")

    args = parser.parse_args()
    # Reject non-positive epoch counts up front: otherwise the script would
    # silently save an untrained model and an empty loss sheet.
    if args.epochs < 1:
        parser.error(f"--epochs must be a positive integer (got {args.epochs})")

    train(num_epochs=args.epochs, output_path=args.output)