https://github.com/ashleve/lightning-hydra-template
PyTorch Lightning과 Hydra를 사용하여 딥러닝 프로젝트를 설정하고 관리하기 위한 템플릿
딥러닝 모델 개발, 훈련, 검증, 테스트 등의 과정을 구조화하고 자동화
💡 템플릿을 활용한 General Workflow
- Write your PyTorch Lightning Module (see examples in src/modules/single_module.py)
- Write your PyTorch Lightning DataModule (see examples in src/datamodules/datamodules.py)
- Fill up your configs, particularly create experiment configs
- Run experiments:
💡 템플릿 코드 분석
- configs/data/mnist.yaml
# 데이터 모듈 클래스의 경로를 정의함
_target_: src.data.mnist_datamodule.MNISTDataModule
# 데이터가 저장될 디렉토리 경로를 지정함 (paths/default.yaml에 저장)
data_dir: ${paths.data_dir}
batch_size: 128 # Needs to be divisible by the number of devices (e.g., if in a distributed setup)
# 데이터셋을 학습/검증/테스트로 나누는 비율
train_val_test_split: [55_000, 5_000, 10_000]
# 데이터 로딩을 위한 worker process 개수
num_workers: 0
# 데이터 로더가 고정된 메모리 영역을 사용할지?
pin_memory: False
- configs/experiment/example.yaml
실험 설정을 정의# 기본 설정 정의 (override) defaults: - override /data: mnist - override /model: mnist - override /callbacks: default - override /trainer: default # 실험에 식별 위한 태그를 추가 tags: ["mnist", "simple_dense_net"] # 랜덤 시드값 설정 seed: 12345 # 최소/최대 epoch 수, clip_Val 설정하여 그래디언트 폭발 방지 trainer: min_epochs: 10 max_epochs: 10 gradient_clip_val: 0.5 # 모델 설정 정의 model: # 옵티마이져 설정 => 학습률 정의 optimizer: lr: 0.002 # NW 아키텍쳐 net: lin1_size: 128 lin2_size: 256 lin3_size: 64 compile: false data: batch_size: 64 # 로깅 설정: Weights&Biases 로거 설정 / Aim 로거 설정 등등 logger: wandb: tags: ${tags} group: "mnist" aim: experiment: "mnist"
- configs/hparams_search/mnist_optuna.yaml
- Optuna와 Lightning이 협력하여 학습률/배치크기/중간레이어 크기 등 다양한 하이퍼파라미터 탐색하고 최적의 조합을 찾는 과정 이루어짐
- 하이퍼파라미터 검색을 위한 설정 정의
# 기본 설정을 정의 (optuna sweeper 사용해 하이퍼파라미터 최적화 수행)
defaults:
- override /hydra/sweeper: optuna
# optuna가 최적화할 메트릭 => 검증 정확도
optimized_metric: "val/acc_best"
# hydra 설정을 정의함
hydra:
# 멀티런 모드 => 여러 실험 병렬로 실행
mode: "MULTIRUN"
# optuna sweeper 설정 정의
sweeper:
_target_: hydra_plugins.hydra_optuna_sweeper.optuna_sweeper.OptunaSweeper
# 최적화 결과 저장할 스토리지 URL
storage: null
# 최적화 결과 저장할 스터디 이름
study_name: null
# 병렬로 실행할 작업 수
n_jobs: 1
# 최적화 방향, 검증 정확도 최대화
direction: maximize
# total number of runs that will be executed
n_trials: 20
# 하이퍼파라미터 샘플러 결정
sampler:
_target_: optuna.samplers.TPESampler
seed: 1234
n_startup_trials: 10 # number of random sampling runs before optimization starts
# define hyperparameter search space
# 학습률 범위 설정, batch 크기 범위 설정, FC Layer 크기 후보군
params:
model.optimizer.lr: interval(0.0001, 0.1)
data.batch_size: choice(32, 64, 128, 256)
model.net.lin1_size: choice(64, 128, 256)
model.net.lin2_size: choice(64, 128, 256)
model.net.lin3_size: choice(32, 64, 128, 256)
- configs/model/mnist.yaml모델 설정 정의
# 사용할 모델의 클래스 _target_: src.models.mnist_module.MNISTLitModule # Optimizer 설정 # Adam Optimizer, 추후 추가 인자 제공, lr, 가중치 감쇠 값 optimizer: _target_: torch.optim.Adam _partial_: true lr: 0.001 weight_decay: 0.0 # 학습률 스케쥴러 설정 # min 모드, 학습률 감소 비율은 0.1, 학습률 감소 전에 기다릴 epoch 수 scheduler: _target_: torch.optim.lr_scheduler.ReduceLROnPlateau _partial_: true mode: min factor: 0.1 patience: 10 # 신경망 아키텍처 설정 # input size(28*28), 중간 완전연결 레이어 net: _target_: src.models.components.simple_dense_net.SimpleDenseNet input_size: 784 lin1_size: 64 lin2_size: 128 lin3_size: 64 output_size: 10 # compile model for faster training with pytorch 2.0 compile: false
- configs/eval.yaml
# @package _global_ defaults: - _self_ # 현재 설정 파일(eval.yaml) 포함 - data: mnist # 평가에 사용할 데이터모듈, test_dataloader() 포함 필요 - model: mnist # mnist 모델 사용 - logger: null - trainer: default - paths: default - extras: default - hydra: default # 작업 이름 task_name: "eval" # 작업 식별하는 태그 tags: ["dev"] # passing checkpoint path is necessary for evaluation ckpt_path: ???
- 평가 설정 정의
- configs/train.yaml
# @package _global_ defaults: - _self_ # 현재 설정 파일(train.yaml)을 포함 - data: mnist - model: mnist - callbacks: default - logger: null # set logger here or use command line (e.g. `python train.py logger=tensorboard`) - trainer: default - paths: default - extras: default - hydra: default # 실험 설정, 특정 하이퍼파라미터 버전 관리 시 사용 - experiment: null # 하이퍼파라미터 최적화 설정 - hparams_search: null # optional local config for machine/user specific settings - optional local: default # debugging config (enable through command line, e.g. `python train.py debug=default) - debug: null # task name, determines output directory path task_name: "train" # tags to help you identify your experiments # you can overwrite this in experiment configs # overwrite from command line with `python train.py tags="[first_tag, second_tag]"` tags: ["dev"] # set False to skip model training train: True # evaluate on test set, using best model weights achieved during training # lightning chooses best weights based on the metric specified in checkpoint callback test: True # 훈련 재개하기 위한 체크포인트 경로 # null로 설정 시 새로 훈련을 시작함 ckpt_path: null # seed for random number generators in pytorch, numpy and python.random seed: null
- 훈련 설정 정의
- src/data/mnist_datamodule.py
from typing import Any, Dict, Optional, Tuple
import torch
from lightning import LightningDataModule
from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split
from torchvision.datasets import MNIST
from torchvision.transforms import transforms
class MNISTDataModule(LightningDataModule):
# 초기화 함수
# data directory, 분할 비율, batch size, worker #, fixed mem, transformation 정의
def __init__(
self,
data_dir: str = "data/",
train_val_test_split: Tuple[int, int, int] = (55_000, 5_000, 10_000),
batch_size: int = 64,
num_workers: int = 0,
pin_memory: bool = False,
) -> None:
super().__init__()
# 초기화 시 전달된 하이퍼파라미터 저장
# 이후 self.hparams 통해 하이퍼파라미터에 접근 가능
self.save_hyperparameters(logger=False)
# 데이터를 텐서로 변환
# (이미 알려진 MNIST dataset의) 평균과 표준편차로 정규화하는 변환 정의
self.transforms = transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
)
# 데이터셋 초기화, 나중에 setup 메서드에서 할당
self.data_train: Optional[Dataset] = None
self.data_val: Optional[Dataset] = None
self.data_test: Optional[Dataset] = None
# 배치 사이즈 초기화
self.batch_size_per_device = batch_size
@property
def num_classes(self) -> int:
return 10
def prepare_data(self) -> None:
# 데이터를 다운로드함
# Lightning 에서는 이 함수 한 번만 호출
MNIST(self.hparams.data_dir, train=True, download=True)
MNIST(self.hparams.data_dir, train=False, download=True)
def setup(self, stage: Optional[str] = None) -> None:
# 데이터 로드, 학습, 검증, 테스트 데이터셋 설정
# stage 매개변수 => 현재 단계(fit, validate, test, predict), 기본값 None
# 배치 크기 조정 => 배치 크기를 장치 수로 나눔
if self.trainer is not None:
if self.hparams.batch_size % self.trainer.world_size != 0:
raise RuntimeError(
f"Batch size ({self.hparams.batch_size}) is not divisible by the number of devices ({self.trainer.world_size})."
)
self.batch_size_per_device = self.hparams.batch_size // self.trainer.world_size
# 데이터셋을 로드하고, 학습/검증/테스트 데이터셋으로 분할
# 이미 로드된 경우 다시 로드하지 않음
if not self.data_train and not self.data_val and not self.data_test:
trainset = MNIST(self.hparams.data_dir, train=True, transform=self.transforms)
testset = MNIST(self.hparams.data_dir, train=False, transform=self.transforms)
dataset = ConcatDataset(datasets=[trainset, testset])
self.data_train, self.data_val, self.data_test = random_split(
dataset=dataset,
lengths=self.hparams.train_val_test_split,
generator=torch.Generator().manual_seed(42),
)
def train_dataloader(self) -> DataLoader[Any]:
# 학습 데이터 로더를 생성하고 반환
return DataLoader(
dataset=self.data_train,
batch_size=self.batch_size_per_device,
num_workers=self.hparams.num_workers,
pin_memory=self.hparams.pin_memory,
shuffle=True,
)
def val_dataloader(self) -> DataLoader[Any]:
# 검증 데이터 로더를 생성하고 반환
return DataLoader(
dataset=self.data_val,
batch_size=self.batch_size_per_device,
num_workers=self.hparams.num_workers,
pin_memory=self.hparams.pin_memory,
shuffle=False,
)
def test_dataloader(self) -> DataLoader[Any]:
# 테스트 데이터 로더를 생성하고 반환
return DataLoader(
dataset=self.data_test,
batch_size=self.batch_size_per_device,
num_workers=self.hparams.num_workers,
pin_memory=self.hparams.pin_memory,
shuffle=False,
)
def teardown(self, stage: Optional[str] = None) -> None:
# 훈련, 검증, 테스트 후 정리 작업 수행 (default do nothing)
pass
def state_dict(self) -> Dict[Any, Any]:
# 체크포인트 저장 시 데이터 모듈 상태 저장 (default return empty dict)
return {}
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
# 체크포인트 로드 시 데이터 모듈 상태 로드 (default do nothing)
pass
if __name__ == "__main__":
_ = MNISTDataModule()
- src/models/mnist_module.py
- lightning에서는 모델 학습 / 검증 / 테스트 / 예측 을 단계로 나누어 관리
- 각 단계는 여러 epoch로 구성, 각 epoch는 여러 배치로 구성, 각 batch에 대한 작업을 single step
- fit: 모델을 훈련(train)하고 검증(validate)하는 단계입니다. 이 단계는 주로 trainer.fit() 메서드를 호출할 때 실행됩니다.
- validate: 훈련된 모델을 별도로 검증하는 단계입니다. 이 단계는 trainer.validate() 메서드를 호출할 때 실행됩니다.
- test: 모델을 테스트하는 단계입니다. 이 단계는 trainer.test() 메서드를 호출할 때 실행됩니다.
- predict: 모델을 사용하여 예측하는 단계입니다. 이 단계는 trainer.predict() 메서드를 호출할 때 실행됩니다
from typing import Any, Dict, Tuple import torch from lightning import LightningModule from torchmetrics import MaxMetric, MeanMetric from torchmetrics.classification.accuracy import Accuracy class MNISTLitModule(LightningModule): def __init__( self, net: torch.nn.Module, optimizer: torch.optim.Optimizer, scheduler: torch.optim.lr_scheduler, compile: bool, ) -> None: # 훈련할 모델 / 최적화 알고리즘 / 학습률 스케쥴러 / 모델 컴파일 여부 super().__init__() # 초기화 시 전달된 하이퍼파라미터 저장 self.save_hyperparameters(logger=False) # 모델 인스턴스 self.net = net # 손실 함수 self.criterion = torch.nn.CrossEntropyLoss() # metric objects for calculating and averaging accuracy across batches self.train_acc = Accuracy(task="multiclass", num_classes=10) self.val_acc = Accuracy(task="multiclass", num_classes=10) self.test_acc = Accuracy(task="multiclass", num_classes=10) # for averaging loss across batches self.train_loss = MeanMetric() self.val_loss = MeanMetric() self.test_loss = MeanMetric() # for tracking best so far validation accuracy self.val_acc_best = MaxMetric() def forward(self, x: torch.Tensor) -> torch.Tensor: # 모델의 순전파 정의, 입력 이미지 텐서를 받음, 모델의 출력(logits) 리턴 return self.net(x) def on_train_start(self) -> None: # 검증 지표가 훈련 시작 전 상태/이전 실행의 값 반영하지 않도록 self.val_loss.reset() self.val_acc.reset() self.val_acc_best.reset() def model_step( self, batch: Tuple[torch.Tensor, torch.Tensor] ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: # 배치 데이터를 처리하여 손실, 예측값, 실제값 계산 # x는 입력 이미지, y는 실제 레이블 x, y = batch # 모델의 출력 logits logits = self.forward(x) # 손실값 계산 loss = self.criterion(logits, y) # 예측 레이블 preds = torch.argmax(logits, dim=1) return loss, preds, y def training_step( self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int ) -> torch.Tensor: # 단일 훈련 단계를 수행함 (model_step 호출) loss, preds, targets = self.model_step(batch) # update and log metrics self.train_loss(loss) self.train_acc(preds, targets) self.log("train/loss", self.train_loss, on_step=False, on_epoch=True, prog_bar=True) self.log("train/acc", self.train_acc, on_step=False, on_epoch=True, prog_bar=True) # return loss or backpropagation will fail return loss def on_train_epoch_end(self) -> None: # train epoch 끝날 때마다 호출되는 hook function pass def validation_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> None: # 단일 검증 단계를 수행함 (model_step 호출) loss, preds, targets = self.model_step(batch) # update and log metrics self.val_loss(loss) self.val_acc(preds, targets) self.log("val/loss", self.val_loss, on_step=False, on_epoch=True, prog_bar=True) self.log("val/acc", self.val_acc, on_step=False, on_epoch=True, prog_bar=True) def on_validation_epoch_end(self) -> None: acc = self.val_acc.compute() # get current val acc self.val_acc_best(acc) # update best so far val acc # log `val_acc_best` as a value through `.compute()` method, instead of as a metric object # otherwise metric would be reset by lightning after each epoch self.log("val/acc_best", self.val_acc_best.compute(), sync_dist=True, prog_bar=True) def test_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> None: # 단일 테스트 단계를 수행함 (model_step 호출) loss, preds, targets = self.model_step(batch) # update and log metrics self.test_loss(loss) self.test_acc(preds, targets) self.log("test/loss", self.test_loss, on_step=False, on_epoch=True, prog_bar=True) self.log("test/acc", self.test_acc, on_step=False, on_epoch=True, prog_bar=True) def on_test_epoch_end(self) -> None: # test epoch 끝날 때마다 호출되는 hook function pass def setup(self, stage: str) -> None: # 각 단계(fit, validate, test, predict) 시작 시 호출되는 hook if self.hparams.compile and stage == "fit": # 모델을 컴파일 self.net = torch.compile(self.net) def configure_optimizers(self) -> Dict[str, Any]: # 최적화 알고리즘 optimizer = self.hparams.optimizer(params=self.trainer.model.parameters()) if self.hparams.scheduler is not None: # 학습률 스케쥴러 scheduler = self.hparams.scheduler(optimizer=optimizer) return { "optimizer": optimizer, "lr_scheduler": { "scheduler": scheduler, "monitor": "val/loss", "interval": "epoch", "frequency": 1, }, } return {"optimizer": optimizer} if __name__ == "__main__": _ = MNISTLitModule(None, None, None, None)
- src/eval.py
- trainer.test()를 통해 주어진 모델과 데이터모듈을 사용해 테스트를 수행하고, metric_dict 리턴
- hydra.main ⇒ eval.yaml 파일을 사용
from typing import Any, Dict, List, Tuple
import hydra
import rootutils
from lightning import LightningDataModule, LightningModule, Trainer
from lightning.pytorch.loggers import Logger
from omegaconf import DictConfig
rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
extras,
instantiate_loggers,
log_hyperparameters,
task_wrapper,
)
log = RankedLogger(__name__, rank_zero_only=True)
@task_wrapper
def evaluate(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
"""Evaluates given checkpoint on a datamodule testset.
This method is wrapped in optional @task_wrapper decorator, that controls the behavior during
failure. Useful for multiruns, saving info about the crash, etc.
:param cfg: DictConfig configuration composed by Hydra.
:return: Tuple[dict, dict] with metrics and dict with all instantiated objects.
"""
assert cfg.ckpt_path
log.info(f"Instantiating datamodule <{cfg.data._target_}>")
datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data)
log.info(f"Instantiating model <{cfg.model._target_}>")
model: LightningModule = hydra.utils.instantiate(cfg.model)
log.info("Instantiating loggers...")
logger: List[Logger] = instantiate_loggers(cfg.get("logger"))
log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
trainer: Trainer = hydra.utils.instantiate(cfg.trainer, logger=logger)
object_dict = {
"cfg": cfg,
"datamodule": datamodule,
"model": model,
"logger": logger,
"trainer": trainer,
}
if logger:
log.info("Logging hyperparameters!")
log_hyperparameters(object_dict)
log.info("Starting testing!")
trainer.test(model=model, datamodule=datamodule, ckpt_path=cfg.ckpt_path)
# for predictions use trainer.predict(...)
# predictions = trainer.predict(model=model, dataloaders=dataloaders, ckpt_path=cfg.ckpt_path)
metric_dict = trainer.callback_metrics
return metric_dict, object_dict
@hydra.main(version_base="1.3", config_path="../configs", config_name="eval.yaml")
def main(cfg: DictConfig) -> None:
"""Main entry point for evaluation.
:param cfg: DictConfig configuration composed by Hydra.
"""
# apply extra utilities
# (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.)
extras(cfg)
evaluate(cfg)
if __name__ == "__main__":
main()
- src/train.py
- trainer.fit()을 통해 모델을 학습하고, trainer.test()를 통해 테스트를 수행함
- train_metrics와 test_metrics를 리턴함
- hydra.main ⇒ train.yaml 파일을 사용
-
from typing import Any, Dict, List, Optional, Tuple # auto-log experiment in ClearML from clearml import Task, Logger import hydra import lightning as L import rootutils import torch from lightning import Callback, LightningDataModule, LightningModule, Trainer from lightning.pytorch.loggers import Logger from omegaconf import DictConfig rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) # ------------------------------------------------------------------------------------ # # the setup_root above is equivalent to: # - adding project root dir to PYTHONPATH # (so you don't need to force user to install project as a package) # (necessary before importing any local modules e.g. `from src import utils`) # - setting up PROJECT_ROOT environment variable # (which is used as a base for paths in "configs/paths/default.yaml") # (this way all filepaths are the same no matter where you run the code) # - loading environment variables from ".env" in root dir # # you can remove it if you: # 1. either install project as a package or move entry files to project root dir # 2. set `root_dir` to "." in "configs/paths/default.yaml" # # more info: <https://github.com/ashleve/rootutils> # ------------------------------------------------------------------------------------ # from src.utils import ( RankedLogger, extras, get_metric_value, instantiate_callbacks, instantiate_loggers, log_hyperparameters, task_wrapper, ) log = RankedLogger(__name__, rank_zero_only=True) @task_wrapper def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: """Trains the model. Can additionally evaluate on a testset, using best weights obtained during training. This method is wrapped in optional @task_wrapper decorator, that controls the behavior during failure. Useful for multiruns, saving info about the crash, etc. :param cfg: A DictConfig configuration composed by Hydra. :return: A tuple with metrics and dict with all instantiated objects. """ # set seed for random number generators in pytorch, numpy and python.random if cfg.get("seed"): L.seed_everything(cfg.seed, workers=True) log.info(f"Instantiating datamodule <{cfg.data._target_}>") datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data) log.info(f"Instantiating model <{cfg.model._target_}>") model: LightningModule = hydra.utils.instantiate(cfg.model) log.info("Instantiating callbacks...") callbacks: List[Callback] = instantiate_callbacks(cfg.get("callbacks")) log.info("Instantiating loggers...") logger: List[Logger] = instantiate_loggers(cfg.get("logger")) log.info(f"Instantiating trainer <{cfg.trainer._target_}>") trainer: Trainer = hydra.utils.instantiate(cfg.trainer, callbacks=callbacks, logger=logger) object_dict = { "cfg": cfg, "datamodule": datamodule, "model": model, "callbacks": callbacks, "logger": logger, "trainer": trainer, } if logger: log.info("Logging hyperparameters!") log_hyperparameters(object_dict) if cfg.get("train"): log.info("Starting training!") trainer.fit(model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path")) train_metrics = trainer.callback_metrics if cfg.get("test"): log.info("Starting testing!") ckpt_path = trainer.checkpoint_callback.best_model_path if ckpt_path == "": log.warning("Best ckpt not found! Using current weights for testing...") ckpt_path = None trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path) log.info(f"Best ckpt path: {ckpt_path}") test_metrics = trainer.callback_metrics # merge train and test metrics metric_dict = {**train_metrics, **test_metrics} return metric_dict, object_dict @hydra.main(version_base="1.3", config_path="../configs", config_name="train.yaml") def main(cfg: DictConfig) -> Optional[float]: """Main entry point for training. :param cfg: DictConfig configuration composed by Hydra. :return: Optional[float] with optimized metric value. """ # apply extra utilities # (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.) extras(cfg) # train the model metric_dict, _ = train(cfg) # safely retrieve metric value for hydra-based hyperparameter optimization metric_value = get_metric_value( metric_dict=metric_dict, metric_name=cfg.get("optimized_metric") ) # return optimized metric return metric_value if __name__ == "__main__": task = Task.init(project_name='lightning-hydra', task_name='experiment1') main()
'AI > DL' 카테고리의 다른 글
서버 ClearML agent 설정 + 도커에서 돌아가게 세팅하기 (0) | 2024.08.29 |
---|---|
효율적인 MLOps를 가능케 하는 ClearML (0) | 2024.08.29 |
lightning-hydra-template CIFAR-10 데이터셋 학습해보기 (1) | 2024.08.29 |
[책] PyTorch를 활용한 머신러닝/딥러닝 철저 입문 (1) | 2024.08.29 |
Loss function for classification & regression (0) | 2024.08.29 |