import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
import torch.nn.functional as F
from torchvision import transforms, datasets

# Runtime configuration shared by the functions below.
USE_CUDA = torch.cuda.is_available()
DEVICE = torch.device("cuda") if USE_CUDA else torch.device("cpu")
BATCH_SIZE = 32  # mini-batch size handed to the data loaders
EPOCHS = 10  # number of train/validate rounds
# Per-channel statistics used both to normalize inputs and to undo that
# normalization before plotting.
MEAN = [0.5] * 3
STD = [0.5] * 3

print(f'Using Pytorch version: {torch.__version__}, Device: {DEVICE}')


def denormalize(tensor, mean=None, std=None):
    """Undo per-channel normalization so an image tensor can be displayed.

    Computes ``tensor * std + mean``, the inverse of ``(x - mean) / std``,
    with the statistics reshaped to ``(C, 1, 1)`` so they broadcast over the
    two trailing dimensions.

    Args:
        tensor: Normalized tensor whose channel axis is third from last,
            e.g. ``(C, H, W)`` or ``(N, C, H, W)``.
        mean: Per-channel means. Defaults to the module-level ``MEAN``.
        std: Per-channel standard deviations. Defaults to the module-level
            ``STD``.

    Returns:
        A new tensor with the normalization reversed; the input is left
        untouched.
    """
    if mean is None:
        mean = MEAN
    if std is None:
        std = STD
    # Place the statistics on the input's device so that CUDA tensors do not
    # trigger a CPU/GPU device-mismatch error.
    mean_tensor = torch.as_tensor(mean, device=tensor.device).view(-1, 1, 1)
    std_tensor = torch.as_tensor(std, device=tensor.device).view(-1, 1, 1)
    return tensor * std_tensor + mean_tensor


def get_dataloaders(batch_size, data_dir="./hymenoptera_data"):
    """Build the training and validation data loaders.

    Training images get a random resized crop to 224 pixels and a random
    horizontal flip; validation images are resized to 256 and center-cropped
    to 224. Both are then converted to tensors and normalized with ``MEAN``
    and ``STD``.

    Args:
        batch_size: Number of images per batch for both loaders.
        data_dir: Dataset root. When ``<data_dir>/train`` or
            ``<data_dir>/val`` exists, that sub-directory is used as the
            ``ImageFolder`` root for the matching split; otherwise the split
            falls back to ``data_dir`` itself.

    Returns:
        Dict with keys ``'train'`` and ``'val'`` mapping to ``DataLoader``
        objects.
    """
    import os  # local import: only this function needs path handling

    normalize = transforms.Normalize(MEAN, STD)
    data_transform = {
        'train': transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]),
        'val': transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ]),
    }

    image_datasets = {}
    for split in ('train', 'val'):
        # ImageFolder treats every direct sub-directory of its root as a
        # class. Pointing both splits at the dataset root would therefore
        # make the split folders themselves the classes and would train and
        # validate on the same images.
        # NOTE(review): assumes the usual <root>/<split>/<class>/ layout for
        # this dataset - confirm against the data on disk.
        split_dir = os.path.join(data_dir, split)
        root = split_dir if os.path.isdir(split_dir) else data_dir
        image_datasets[split] = datasets.ImageFolder(root, data_transform[split])

    return {
        split: torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            num_workers=0,
            # Only training benefits from shuffling; validation metrics do
            # not depend on sample order.
            shuffle=(split == 'train'),
        )
        for split, dataset in image_datasets.items()
    }


def display_images(dataloader, num_images=10):
    """Print the size/type of the first batch and plot its leading images.

    Args:
        dataloader: Iterable yielding ``(images, labels)`` batches whose
            images were normalized with ``MEAN``/``STD``.
        num_images: Maximum number of images to show in a single row. Fewer
            are shown when the first batch is smaller.

    Returns:
        None. The figure is shown with ``plt.show()``.
    """
    first_batch = next(iter(dataloader), None)
    if first_batch is None:
        # An empty loader has nothing to plot; bail out instead of failing
        # on an unbound batch variable.
        print('No batches available to display.')
        return
    X_train, y_train = first_batch

    print(f'X_train: {X_train.size()}, type: {X_train.type()}')
    print(f'y_train: {y_train.size()}, type: {y_train.type()}')

    # Never index past the end of a batch smaller than the requested count.
    count = min(num_images, len(X_train))
    if count < 1:
        return

    plt.figure(figsize=(count, 1))
    for i in range(count):
        plt.subplot(1, count, i + 1)
        plt.axis('off')
        image_to_display = denormalize(X_train[i]).detach().cpu().numpy()
        # Channels-first (C, H, W) -> channels-last (H, W, C) for imshow.
        plt.imshow(np.transpose(image_to_display, (1, 2, 0)))
        plt.title(f'Class: {y_train[i].item()}')
    plt.show()


def _train_one_epoch(model, train_loader, optimizer, epoch, device):
    """Run one optimization pass over ``train_loader``, logging every 10th batch."""
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        loss = F.cross_entropy(model(data), target)
        loss.backward()
        optimizer.step()

        if batch_idx % 10 == 0:
            print(f"Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} ({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}")


def _evaluate(model, val_loader, device):
    """Return ``(mean loss, accuracy in percent)`` of ``model`` over ``val_loader``."""
    model.eval()
    total_loss, correct, seen = 0.0, 0, 0
    with torch.no_grad():
        for data, target in val_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            # Sum (not mean) per batch so the final average is per sample
            # even when the last batch is smaller.
            total_loss += F.cross_entropy(output, target, reduction="sum").item()
            correct += output.argmax(dim=1).eq(target).sum().item()
            seen += target.size(0)

    if seen == 0:
        raise ValueError("val_loader yielded no samples; cannot compute validation metrics")
    # Average over the samples actually processed, which stays correct when
    # the loader uses a sampler or drops the last batch.
    return total_loss / seen, 100. * correct / seen


def train_and_evaluate(model, train_loader, val_loader, optimizer, epochs, save_path="best_model_weights.pth"):
    """Train ``model`` for ``epochs`` epochs, validating after each one.

    After every epoch the model is evaluated on ``val_loader``; whenever the
    mean validation loss improves on the best seen so far, the model's
    ``state_dict`` is written to ``save_path``.

    Args:
        model: Classifier whose output is passed to ``F.cross_entropy``.
        train_loader: Loader yielding ``(data, target)`` training batches;
            must expose ``.dataset`` and support ``len()`` (used in the
            progress message).
        val_loader: Loader yielding ``(data, target)`` validation batches.
        optimizer: Optimizer that updates ``model``'s parameters.
        epochs: Number of epochs to run.
        save_path: File that receives the best weights.

    Returns:
        The lowest mean validation loss observed (``inf`` if no epoch ran).

    Raises:
        ValueError: If ``val_loader`` yields no samples.
    """
    # Move batches to wherever the model lives rather than trusting a global,
    # so a model placed on another device does not cause a device mismatch.
    first_param = next(model.parameters(), None)
    device = DEVICE if first_param is None else first_param.device

    best_val_loss = float("inf")

    for epoch in range(1, epochs + 1):
        _train_one_epoch(model, train_loader, optimizer, epoch, device)

        test_loss, test_accuracy = _evaluate(model, val_loader, device)
        print(f"[{epoch}] Test Loss: {test_loss:.4f}, accuracy: {test_accuracy:.2f}%\n")

        # Keep only the weights of the best-performing epoch on disk.
        if test_loss < best_val_loss:
            best_val_loss = test_loss
            torch.save(model.state_dict(), save_path)
            print(f"Model weights saved to {save_path} with validation loss: {best_val_loss:.4f}")

    return best_val_loss


def main():
    """Fine-tune an ImageNet-pretrained ResNet-18 as a two-class classifier.

    Builds the data loaders, previews one training batch, replaces the
    network's final fully connected layer with a 2-output layer, and trains
    all parameters with Adam (learning rate 1e-4) for ``EPOCHS`` epochs.
    """
    dataloaders = get_dataloaders(BATCH_SIZE)
    display_images(dataloaders['train'])

    # torchvision >= 0.13 deprecates ``pretrained=True`` in favour of the
    # ``weights`` argument. IMAGENET1K_V1 is the checkpoint the legacy flag
    # loaded, so both branches start from the same weights.
    weights_enum = getattr(models, "ResNet18_Weights", None)
    if weights_enum is not None:
        model = models.resnet18(weights=weights_enum.IMAGENET1K_V1)
    else:
        model = models.resnet18(pretrained=True)

    # Swap the 1000-class ImageNet head for a 2-class head.
    model.fc = nn.Linear(model.fc.in_features, 2)
    model = model.to(DEVICE)

    optimizer = optim.Adam(model.parameters(), lr=0.0001)
    train_and_evaluate(model, dataloaders["train"], dataloaders["val"], optimizer, EPOCHS)


# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()

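# --- Blog page footer captured together with the code; not part of the program ---
# Other posts in the '코드테스트' (Code Test) category:
#
#   파이토치_2 (PyTorch_2)  (1) 2023.10.31
#   프로그래머스_1 (Programmers_1)  (0) 2023.10.26
#
# + Recent posts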