AI/DL

Training on the CIFAR-10 dataset with lightning-hydra-template

민사민서 2024. 8. 29. 14:49

https://github.com/ashleve/lightning-hydra-template

 


 

💡 Using the CIFAR-10 dataset

1. Create configs/data/cifar10.yaml and add the data config

The CIFAR-10 dataset already ships split into a training set and a test set; if you need a validation set, carve it out of the training set.
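The 45,000 / 5,000 split used below (out of CIFAR-10's 50,000 training images) can be sketched without torch; `random_split` in the datamodule later does the same thing with a fixed generator seed. A minimal stand-in using only the stdlib:

```python
import random

def split_indices(n_total: int, n_val: int, seed: int = 42):
    """Shuffle indices reproducibly and carve off a validation subset."""
    rng = random.Random(seed)
    idx = list(range(n_total))
    rng.shuffle(idx)
    return idx[n_val:], idx[:n_val]  # (train indices, val indices)

train_idx, val_idx = split_indices(50_000, 5_000)
print(len(train_idx), len(val_idx))  # 45000 5000
```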

_target_: src.data.cifar10_datamodule.CIFAR10DataModule
data_dir: ${paths.data_dir}
batch_size: 128 # Needs to be divisible by the number of devices (e.g., if in a distributed setup)
train_val_split: [45_000, 5_000]
num_workers: 4
pin_memory: True
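For context, Hydra resolves `_target_` by importing the dotted path and calling it with the remaining keys as keyword arguments. A rough pure-Python sketch of that mechanism (illustrative only, not Hydra's actual implementation), using a stdlib class as a stand-in target:

```python
import importlib

def instantiate(cfg: dict):
    """Toy version of hydra.utils.instantiate: import _target_, call it with the rest."""
    cfg = dict(cfg)  # don't mutate the caller's config
    module_path, _, cls_name = cfg.pop("_target_").rpartition(".")
    cls = getattr(importlib.import_module(module_path), cls_name)
    return cls(**cfg)

counter = instantiate({"_target_": "collections.Counter", "a": 2, "b": 1})
print(counter["a"])  # 2
```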

 

2. Create configs/model/cifar10.yaml

_target_: src.models.cifar10_module.CIFAR10LitModule

optimizer:
  _target_: torch.optim.Adam
  _partial_: true
  lr: 0.001
  weight_decay: 0.0

scheduler:
  _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
  _partial_: true
  mode: min
  factor: 0.1
  patience: 10

net:
  _target_: torchvision.models.resnet18
  weights: null # `pretrained` is deprecated in recent torchvision; `weights: null` means train from scratch
  num_classes: 10

# compile model for faster training with pytorch 2.0
compile: false
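The `_partial_: true` flag tells Hydra to return a `functools.partial` instead of calling the target right away, because the optimizer still needs the model's parameters, which only exist once the module is built. The idea, with a hypothetical stand-in for `torch.optim.Adam`:

```python
from functools import partial

def make_optimizer(params, lr, weight_decay):
    """Hypothetical stand-in for torch.optim.Adam."""
    return {"params": list(params), "lr": lr, "weight_decay": weight_decay}

# What `_partial_: true` produces: everything bound except `params`.
optimizer_factory = partial(make_optimizer, lr=0.001, weight_decay=0.0)

# Later, in configure_optimizers(), the module supplies the parameters:
opt = optimizer_factory(params=[1, 2, 3])
print(opt["lr"], len(opt["params"]))  # 0.001 3
```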

 

3. Edit configs/eval.yaml & train.yaml

configs/eval.yaml:

# @package _global_

defaults:
  - _self_
  - data: cifar10 # choose datamodule with `test_dataloader()` for evaluation
  - model: cifar10
  - logger: null
  - trainer: default
  - paths: default
  - extras: default
  - hydra: default

task_name: "eval"

tags: ["dev"]

# passing checkpoint path is necessary for evaluation
ckpt_path: ???
configs/train.yaml:

# @package _global_

# specify here default configuration
# order of defaults determines the order in which configs override each other
defaults:
  - _self_
  - data: cifar10
  - model: cifar10
  - callbacks: default
  - logger: null # set logger here or use command line (e.g. `python train.py logger=tensorboard`)
  - trainer: default
  - paths: default
  - extras: default
  - hydra: default

  # experiment configs allow for version control of specific hyperparameters
  # e.g. best hyperparameters for given model and datamodule
  - experiment: null

  # config for hyperparameter optimization
  - hparams_search: null

  # optional local config for machine/user specific settings
  # it's optional since it doesn't need to exist and is excluded from version control
  - optional local: default

  # debugging config (enable through command line, e.g. `python train.py debug=default`)
  - debug: null

# task name, determines output directory path
task_name: "train"

# tags to help you identify your experiments
# you can overwrite this in experiment configs
# overwrite from command line with `python train.py tags="[first_tag, second_tag]"`
tags: ["dev"]

# set False to skip model training
train: True

# evaluate on test set, using best model weights achieved during training
# lightning chooses best weights based on the metric specified in checkpoint callback
test: True

# simply provide checkpoint path to resume training
ckpt_path: null

# seed for random number generators in pytorch, numpy and python.random
seed: null

 

4. Create src/data/cifar10_datamodule.py

from typing import Any, Dict, Optional, Tuple

import torch
from lightning import LightningDataModule
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision.datasets import CIFAR10
from torchvision.transforms import transforms


class CIFAR10DataModule(LightningDataModule):
    def __init__(
        self,
        data_dir: str = "data/",
        train_val_split: Tuple[int, int] = (45_000, 5_000),
        batch_size: int = 128,
        num_workers: int = 4,
        pin_memory: bool = True,
    ) -> None:
        super().__init__()
        self.save_hyperparameters(logger=False)

        # normalize with the CIFAR-10 per-channel mean and std
        self.transforms = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))]
        )

        self.data_train: Optional[Dataset] = None
        self.data_val: Optional[Dataset] = None
        self.data_test: Optional[Dataset] = None

        self.batch_size_per_device = batch_size

    @property
    def num_classes(self) -> int:
        return 10

    def prepare_data(self) -> None:
        CIFAR10(self.hparams.data_dir, train=True, download=True)
        CIFAR10(self.hparams.data_dir, train=False, download=True)

    def setup(self, stage: Optional[str] = None) -> None:
        if self.trainer is not None:
            if self.hparams.batch_size % self.trainer.world_size != 0:
                raise RuntimeError(
                    f"Batch size ({self.hparams.batch_size}) is not divisible by the number of devices ({self.trainer.world_size})."
                )
            self.batch_size_per_device = self.hparams.batch_size // self.trainer.world_size

        if not self.data_train and not self.data_val and not self.data_test:
            full_train_dataset = CIFAR10(self.hparams.data_dir, train=True, transform=self.transforms)
            self.data_train, self.data_val = random_split(
                full_train_dataset,
                self.hparams.train_val_split,
                generator=torch.Generator().manual_seed(42),
            )
            self.data_test = CIFAR10(self.hparams.data_dir, train=False, transform=self.transforms)

    def train_dataloader(self) -> DataLoader[Any]:
        return DataLoader(
            dataset=self.data_train,
            batch_size=self.batch_size_per_device,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            shuffle=True,
        )

    def val_dataloader(self) -> DataLoader[Any]:
        return DataLoader(
            dataset=self.data_val,
            batch_size=self.batch_size_per_device,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            shuffle=False,
        )

    def test_dataloader(self) -> DataLoader[Any]:
        return DataLoader(
            dataset=self.data_test,
            batch_size=self.batch_size_per_device,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            shuffle=False,
        )

    def teardown(self, stage: Optional[str] = None) -> None:
        pass

    def state_dict(self) -> Dict[Any, Any]:
        return {}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        pass


if __name__ == "__main__":
    _ = CIFAR10DataModule()
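`setup()` above divides the global batch size evenly across devices and rejects sizes that don't divide. The arithmetic in isolation (a standalone mirror of that check, not the Lightning API itself):

```python
def per_device_batch_size(batch_size: int, world_size: int) -> int:
    """Mirror of the divisibility check in CIFAR10DataModule.setup()."""
    if batch_size % world_size != 0:
        raise RuntimeError(
            f"Batch size ({batch_size}) is not divisible by the number of devices ({world_size})."
        )
    return batch_size // world_size

print(per_device_batch_size(128, 4))  # 32: each of 4 GPUs gets a quarter of the global batch
```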

 

5. Create src/models/cifar10_module.py

from typing import Any, Dict, Tuple

import torch
from lightning import LightningModule
from torchmetrics import MaxMetric, MeanMetric
from torchmetrics.classification.accuracy import Accuracy

class CIFAR10LitModule(LightningModule):
    def __init__(
        self,
        net: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        scheduler: torch.optim.lr_scheduler,
        compile: bool,
    ) -> None:
        super().__init__()

        # this line allows to access init params with 'self.hparams' attribute
        # also ensures init params will be stored in ckpt
        self.save_hyperparameters(logger=False)

        self.net = net

        # loss function
        self.criterion = torch.nn.CrossEntropyLoss()

        # metric objects for calculating and averaging accuracy across batches
        self.train_acc = Accuracy(task="multiclass", num_classes=10)
        self.val_acc = Accuracy(task="multiclass", num_classes=10)
        self.test_acc = Accuracy(task="multiclass", num_classes=10)

        # for averaging loss across batches
        self.train_loss = MeanMetric()
        self.val_loss = MeanMetric()
        self.test_loss = MeanMetric()

        # for tracking best so far validation accuracy
        self.val_acc_best = MaxMetric()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)

    def on_train_start(self) -> None:
        self.val_loss.reset()
        self.val_acc.reset()
        self.val_acc_best.reset()

    def model_step(
        self, batch: Tuple[torch.Tensor, torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        x, y = batch
        logits = self.forward(x)
        loss = self.criterion(logits, y)
        preds = torch.argmax(logits, dim=1)
        return loss, preds, y

    def training_step(
        self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int
    ) -> torch.Tensor:
        loss, preds, targets = self.model_step(batch)

        # update and log metrics
        self.train_loss(loss)
        self.train_acc(preds, targets)
        self.log("train/loss", self.train_loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log("train/acc", self.train_acc, on_step=False, on_epoch=True, prog_bar=True)

        # return loss or backpropagation will fail
        return loss

    def on_train_epoch_end(self) -> None:
        pass

    def validation_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> None:
        loss, preds, targets = self.model_step(batch)

        # update and log metrics
        self.val_loss(loss)
        self.val_acc(preds, targets)
        self.log("val/loss", self.val_loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log("val/acc", self.val_acc, on_step=False, on_epoch=True, prog_bar=True)

    def on_validation_epoch_end(self) -> None:
        acc = self.val_acc.compute()  # get current val acc
        self.val_acc_best(acc)  # update best so far val acc
        self.log("val/acc_best", self.val_acc_best.compute(), sync_dist=True, prog_bar=True)

    def test_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> None:
        loss, preds, targets = self.model_step(batch)

        # update and log metrics
        self.test_loss(loss)
        self.test_acc(preds, targets)
        self.log("test/loss", self.test_loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log("test/acc", self.test_acc, on_step=False, on_epoch=True, prog_bar=True)

    def on_test_epoch_end(self) -> None:
        pass

    def setup(self, stage: str) -> None:
        if self.hparams.compile and stage == "fit":
            self.net = torch.compile(self.net)

    def configure_optimizers(self) -> Dict[str, Any]:
        optimizer = self.hparams.optimizer(params=self.trainer.model.parameters())
        if self.hparams.scheduler is not None:
            scheduler = self.hparams.scheduler(optimizer=optimizer)
            return {
                "optimizer": optimizer,
                "lr_scheduler": {
                    "scheduler": scheduler,
                    "monitor": "val/loss",
                    "interval": "epoch",
                    "frequency": 1,
                },
            }
        return {"optimizer": optimizer}


if __name__ == "__main__":
    _ = CIFAR10LitModule(None, None, None, None)
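The `ReduceLROnPlateau` settings from the model config (`mode: min`, `factor: 0.1`, `patience: 10`) combined with `"monitor": "val/loss"` above mean: if the monitored loss fails to improve for more than 10 consecutive epochs, multiply the learning rate by 0.1. A simplified sketch of that rule (the real torch scheduler also has thresholds and cooldown):

```python
class PlateauSketch:
    """Minimal mode='min' plateau rule: cut lr by `factor` after `patience` bad epochs."""

    def __init__(self, lr: float, factor: float = 0.1, patience: int = 10):
        self.lr, self.factor, self.patience = lr, factor, patience
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, metric: float) -> float:
        if metric < self.best:            # improvement: reset the counter
            self.best, self.bad_epochs = metric, 0
        else:                             # no improvement this epoch
            self.bad_epochs += 1
            if self.bad_epochs > self.patience:
                self.lr *= self.factor
                self.bad_epochs = 0
        return self.lr

sched = PlateauSketch(lr=0.001)
for _ in range(12):  # val/loss stuck at 0.5: one "best" epoch, then 11 bad ones
    lr = sched.step(0.5)
print(round(lr, 5))  # 0.0001
```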

 

💡 Troubleshooting

1. The src/data/cifar10_datamodule.py module was not recognized (via `_target_` in configs/data/cifar10.yaml …)

 

- Fixed after reconnecting

- A caching issue? Renaming the file to test.py, hooking it up, then renaming it back made it work

 

2. Training ran far too slowly; a closer look showed GPU utilization was 0!

# train on CPU
python train.py trainer=cpu

# train on 1 GPU
python train.py trainer=gpu

 

3. A 10-15 second delay between the end of each epoch and the start of the next

At the end of every epoch, the `save_last` option makes Lightning always save an exact copy of the latest checkpoint to a file last.ckpt,

fully overwriting the file ⇒ skipping this step, or saving only every `every_n_epochs` epochs, cuts down the waiting time.

(Is it worth it, though? Saved checkpoints are what let you resume training from a specific point...)

 

Set save_last to False in configs/callbacks/default.yaml:

model_checkpoint:
  dirpath: ${paths.output_dir}/checkpoints
  filename: "epoch_{epoch:03d}"
  monitor: "val/acc"
  mode: "max"
  save_last: False
  auto_insert_metric_name: False

and in configs/callbacks/model_checkpoint.yaml:

  every_n_epochs: 5 # number of epochs between checkpoints

With this setup, the checkpoint-save delay only occurs at the end of every 5th epoch.
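The `every_n_epochs` option simply gates the save; the decision the checkpoint callback makes at each epoch end boils down to roughly this (simplified sketch, not Lightning's actual code):

```python
def should_save(epoch: int, every_n_epochs: int) -> bool:
    """Save only when the 1-indexed epoch count is a multiple of every_n_epochs.
    `epoch` is 0-indexed, as in trainer.current_epoch."""
    return (epoch + 1) % every_n_epochs == 0

print([e for e in range(10) if should_save(e, 5)])  # [4, 9]
```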

 

4. Accuracy on CIFAR-10 only reached about 0.7-0.8

 

  • Add augmentations such as random crop and horizontal flip on top of the basic preprocessing
  • Use a deeper backbone such as resnet50 or resnet101 instead of resnet18
  • Tune hyperparameters ⇒ learning rate, batch size, optimizer parameters, scheduler settings
  • Change the weight-initialization scheme; add layers such as batch normalization or dropout
  • Warm-up learning-rate scheduling

Plenty of things left to try.