import random
import numpy as np
import copy
import torch
import os
import cv2
import pandas as pd
import albumentations as A
from pathlib import Path
from sklearn.model_selection import train_test_split, StratifiedKFold
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import f1_score
import timm
import torch.nn as nn
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR

# Check if a GPU (CUDA) is available, and set the device accordingly
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Define configurations
class Config:
    seed = 42
    model_name = "swsl_resnext50_32x4d"
    epoch_size = 30
    batch_size = 48
    learning_rate = 1e-4
    early_stop = 5
    k_fold_num = 5
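    # Setting k_fold_num to -1 skips cross-validation and uses a single stratified hold-out split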


# Set random seeds
def set_random_seeds(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
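    # Note: deterministic cuDNN kernels improve reproducibility at some cost in speed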


# Data preprocessing functions
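# Assumed layout: ./data/train/<class_name>/<image files>, where <class_name> is a key of class_encoder below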
def img_gather_(img_path):
    class_encoder = {
        'dog': 0,
        'elephant': 1,
        'giraffe': 2,
        'guitar': 3,
        'horse': 4,
        'house': 5,
        'person': 6
    }

    file_lists = []
    label_lists = []

    for class_name in os.listdir(img_path):
        class_dir = os.path.join(img_path, class_name)
        file_list = [os.path.join(class_dir, file) for file in os.listdir(class_dir)]
        label_list = [class_encoder[class_name]] * len(file_list)

        file_lists.extend(file_list)
        label_lists.extend(label_list)

    return np.array(file_lists), np.array(label_lists)


class TrainDataset(Dataset):
    def __init__(self, file_lists, label_lists, transforms=None):
        self.file_lists = file_lists.copy()
        self.label_lists = label_lists.copy()
        self.transforms = transforms

    def __getitem__(self, idx):
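        # OpenCV reads images as BGR; convert to RGB before augmentation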
        img = cv2.imread(self.file_lists[idx], cv2.IMREAD_COLOR)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        if self.transforms:
            img = self.transforms(image=img)["image"]

        img = img.transpose(2, 0, 1)  # HWC -> CHW, the channel order PyTorch expects

        return torch.tensor(img, dtype=torch.float), torch.tensor(self.label_lists[idx], dtype=torch.long)

    def __len__(self):
        assert len(self.file_lists) == len(self.label_lists)
        return len(self.file_lists)


class TestDataset(Dataset):
    def __init__(self, file_lists, transforms=None):
        self.file_lists = file_lists.copy()
        self.transforms = transforms

    def __getitem__(self, idx):
        img = cv2.imread(self.file_lists[idx], cv2.IMREAD_COLOR)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        if self.transforms:
            img = self.transforms(image=img)["image"]

        img = img.transpose(2, 0, 1)

        return torch.tensor(img, dtype=torch.float)

    def __len__(self):
        return len(self.file_lists)


# Create and return the model
def create_model():
    # timm loads pretrained weights and replaces the final layer with a fresh 7-class head
    model = timm.create_model(Config.model_name, pretrained=True, num_classes=7)
    return model.to(device)


# Create optimizer and scheduler
def create_optimizer_scheduler(model):
    # Discriminative learning rates: the pretrained backbone trains at half the rate of the new head
    feature_extractor = [param for name, param in model.named_parameters() if "fc" not in name]
    classifier = list(model.fc.parameters())
    params = [
        {"params": feature_extractor, "lr": Config.learning_rate * 0.5},
        {"params": classifier, "lr": Config.learning_rate}
    ]
    optimizer = AdamW(params, lr=Config.learning_rate)
    scheduler = CosineAnnealingLR(optimizer, T_max=10, eta_min=0)
    return optimizer, scheduler


# Create loss function
def create_loss_function():
    # Per-class sample counts; weighting each class by max_count / count counters class imbalance
    class_num = np.array([329, 205, 235, 134, 151, 245, 399])
    class_weight = torch.tensor(class_num.max() / class_num).to(device=device, dtype=torch.float)
    criterion = nn.CrossEntropyLoss(weight=class_weight)
    return criterion


# Train one epoch
def train_step(model, data_loader, optimizer, criterion, epoch_idx):
    model.train()
    for iter_idx, (train_imgs, train_labels) in enumerate(data_loader["train_loader"], 1):
        train_imgs, train_labels = train_imgs.to(device=device, dtype=torch.float), train_labels.to(device)
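        # Standard update: clear gradients, forward pass, loss, backprop, optimizer step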
        optimizer.zero_grad()
        train_pred = model(train_imgs)
        train_loss = criterion(train_pred, train_labels)
        train_loss.backward()
        optimizer.step()

        print(
            f"[Epoch {epoch_idx}/{Config.epoch_size}] model training iteration {iter_idx}/{len(data_loader['train_loader'])}",
            end="\r")


# Validation function
def validate(model, valid_loader, criterion):
    model.eval()
    valid_loss = []
    valid_preds = []
    valid_targets = []
    with torch.no_grad():
        for iter_idx, (valid_imgs, valid_labels) in enumerate(valid_loader, 1):
            valid_imgs, valid_labels = valid_imgs.to(device=device, dtype=torch.float), valid_labels.to(device)
            valid_pred = model(valid_imgs)
            loss = criterion(valid_pred, valid_labels)
            valid_loss.append(loss.cpu().item())
            valid_preds.extend(valid_pred.argmax(dim=-1).cpu().tolist())
            valid_targets.extend(valid_labels.cpu().tolist())

            print(f"[Validation] iteration {iter_idx}/{len(valid_loader)}", end="\r")

    valid_loss = np.mean(valid_loss)
    valid_acc = np.mean(np.array(valid_preds) == np.array(valid_targets)) * 100
    # Compute macro F1 once over the full validation set; averaging per-batch scores biases the metric
    valid_f1 = f1_score(y_true=valid_targets, y_pred=valid_preds, average="macro")
    print(f"Validation loss: {valid_loss:.4f} | Validation acc: {valid_acc:.2f}% | Validation f1 score: {valid_f1:.4f}")
    return valid_loss, valid_acc, valid_f1


# Main training function
def train(data_loader):
    model = create_model()  # create_model already moves the model to the device
    optimizer, scheduler = create_optimizer_scheduler(model)
    criterion = create_loss_function()

    best_model_state = None
    best_f1 = 0
    early_stop_count = 0

    for epoch_idx in range(1, Config.epoch_size + 1):
        train_step(model, data_loader, optimizer, criterion, epoch_idx)
        valid_loss, valid_acc, valid_f1 = validate(model, data_loader["valid_loader"], criterion)
        scheduler.step()  # CosineAnnealingLR steps per epoch; it does not take a metric

        if valid_f1 > best_f1:
            best_f1 = valid_f1
            best_model_state = copy.deepcopy(model.state_dict())  # deepcopy so later updates don't overwrite the snapshot
            early_stop_count = 0
        else:
            early_stop_count += 1

        if early_stop_count == Config.early_stop:
            print("Early stopped." + " " * 30)
            break

    return best_model_state


# Main training and inference flow
if __name__ == "__main__":
    set_random_seeds(Config.seed)

    data_lists, data_labels = img_gather_("./data/train")
    best_models = []
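    # Each entry will hold the best state_dict from one fold (or from the single hold-out run)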

    if Config.k_fold_num == -1:
        train_lists, valid_lists, train_labels, valid_labels = train_test_split(
            data_lists, data_labels, train_size=0.8, shuffle=True, random_state=Config.seed, stratify=data_labels
        )
        train_transforms = A.Compose([
            A.Rotate(),
            A.HorizontalFlip(),
            A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
            A.Normalize()
        ])
        valid_transforms = A.Compose([A.Normalize()])
        train_dataset = TrainDataset(file_lists=train_lists, label_lists=train_labels, transforms=train_transforms)
        valid_dataset = TrainDataset(file_lists=valid_lists, label_lists=valid_labels, transforms=valid_transforms)
        train_loader = DataLoader(train_dataset, batch_size=Config.batch_size, shuffle=True)
        valid_loader = DataLoader(valid_dataset, batch_size=Config.batch_size, shuffle=False)  # no need to shuffle validation data
        data_loader = {"train_loader": train_loader, "valid_loader": valid_loader}
        print("No fold training starts ... ")
        best_model = train(data_loader)
        best_models.append(best_model)
    else:
        skf = StratifiedKFold(n_splits=Config.k_fold_num, random_state=Config.seed, shuffle=True)
        print(f"{Config.k_fold_num} fold training starts ... ")
        for fold_idx, (train_idx, valid_idx) in enumerate(skf.split(data_lists, data_labels), 1):
            train_lists, train_labels = data_lists[train_idx], data_labels[train_idx]
            valid_lists, valid_labels = data_lists[valid_idx], data_labels[valid_idx]
            train_transforms = A.Compose([
                A.Rotate(),
                A.HorizontalFlip(),
                A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
                A.Normalize()
            ])
            valid_transforms = A.Compose([A.Normalize()])
            train_dataset = TrainDataset(file_lists=train_lists, label_lists=train_labels, transforms=train_transforms)
            valid_dataset = TrainDataset(file_lists=valid_lists, label_lists=valid_labels, transforms=valid_transforms)
            train_loader = DataLoader(train_dataset, batch_size=Config.batch_size, shuffle=True)
            valid_loader = DataLoader(valid_dataset, batch_size=Config.batch_size, shuffle=False)
            data_loader = {"train_loader": train_loader, "valid_loader": valid_loader}
            print(f"- {fold_idx} fold -")
            best_model = train(data_loader)
            best_models.append(best_model)

    test_transforms = A.Compose([A.Normalize()])
    test_files = sorted(str(p) for p in Path("./data/test/0").glob("*"))  # cv2.imread expects string paths
    test_dataset = TestDataset(file_lists=test_files, transforms=test_transforms)
    test_loader = DataLoader(test_dataset, batch_size=Config.batch_size, shuffle=False)

    answer_logits = []

    model = create_model()  # one model instance, reused; each fold's weights are loaded via load_state_dict

    for fold_idx, best_model_state in enumerate(best_models, 1):
        model.load_state_dict(best_model_state)
        model.eval()
        fold_logits = []
        with torch.no_grad():
            for iter_idx, test_imgs in enumerate(test_loader, 1):
                test_imgs = test_imgs.to(device)
                test_pred = model(test_imgs)
                fold_logits.extend(test_pred.cpu().tolist())
                print(f"[{fold_idx} fold] inference iteration {iter_idx}/{len(test_loader)}", end="\r")
        answer_logits.append(fold_logits)

    # Ensemble: average logits across the fold models, then take the argmax as the final class
    answer_logits = np.mean(answer_logits, axis=0)
    answer_value = np.argmax(answer_logits, axis=-1)

    # Find the first unused submission filename (create the output directory if needed)
    Path("submissions").mkdir(exist_ok=True)
    i = 0
    while True:
        submission_path = f"submissions/submission_{i}.csv"
        if not Path(submission_path).is_file():
            break
        i += 1

    submission = pd.read_csv("test_answer_sample_.csv", index_col=False)
    submission["answer value"] = answer_value
    submission["answer value"].to_csv(submission_path, index=False)
    print("\nAll done.")
