AI/DL

lightning-hydra-template 코드분석

민사민서 2024. 8. 29. 14:12

https://github.com/ashleve/lightning-hydra-template

 

GitHub - ashleve/lightning-hydra-template: PyTorch Lightning + Hydra. A very user-friendly template for ML experimentation. ⚡

PyTorch Lightning + Hydra. A very user-friendly template for ML experimentation. ⚡🔥⚡ - ashleve/lightning-hydra-template

github.com

PyTorch Lightning과 Hydra를 사용하여 딥러닝 프로젝트를 설정하고 관리하기 위한 템플릿

딥러닝 모델 개발, 훈련, 검증, 테스트 등의 과정을 구조화하고 자동화

💡 템플릿을 활용한 General Workflow

  1. Write your PyTorch Lightning Module (see examples in src/modules/single_module.py)
  2. Write your PyTorch Lightning DataModule (see examples in src/datamodules/datamodules.py)
  3. Fill up your configs, particularly create experiment configs
  4. Run experiments:

💡 템플릿 코드 분석

  • configs/data/mnist.yaml
# 데이터 모듈 클래스의 경로를 정의함
_target_: src.data.mnist_datamodule.MNISTDataModule
# 데이터가 저장될 디렉토리 경로를 지정함 (paths/default.yaml에 저장)
data_dir: ${paths.data_dir}
batch_size: 128 # Needs to be divisible by the number of devices (e.g., if in a distributed setup)
# 데이터셋을 학습/검증/테스트로 나누는 비율
train_val_test_split: [55_000, 5_000, 10_000]
# 데이터 로딩을 위한 worker process 개수
num_workers: 0
# 데이터 로더가 고정된 메모리 영역을 사용할지?
pin_memory: False
  • configs/experiment/example.yaml
    # 기본 설정 정의 (override)
    defaults:
      - override /data: mnist
      - override /model: mnist
      - override /callbacks: default
      - override /trainer: default
    
    # 실험에 식별 위한 태그를 추가
    tags: ["mnist", "simple_dense_net"]
    
    # 랜덤 시드값 설정
    seed: 12345
    
    # 최소/최대 epoch 수, clip_Val 설정하여 그래디언트 폭발 방지
    trainer:
      min_epochs: 10
      max_epochs: 10
      gradient_clip_val: 0.5
    
    # 모델 설정 정의
    model:
    	# 옵티마이져 설정 => 학습률 정의
      optimizer:
        lr: 0.002
      # NW 아키텍쳐
      net:
        lin1_size: 128
        lin2_size: 256
        lin3_size: 64
      compile: false
    
    data:
      batch_size: 64
    
    # 로깅 설정: Weights&Biases 로거 설정 / Aim 로거 설정 등등
    logger:
      wandb:
        tags: ${tags}
        group: "mnist"
      aim:
        experiment: "mnist"
    
    실험 설정을 정의
  • configs/hparams_search/mnist_optuna.yaml
    • Optuna와 Lightning이 협력하여 학습률/배치크기/중간레이어 크기 등 다양한 하이퍼파라미터 탐색하고 최적의 조합을 찾는 과정 이루어짐
    • 하이퍼파라미터 검색을 위한 설정 정의
# 기본 설정을 정의 (optuna sweeper 사용해 하이퍼파라미터 최적화 수행)
defaults:
  - override /hydra/sweeper: optuna

# optuna가 최적화할 메트릭 => 검증 정확도
optimized_metric: "val/acc_best"

# hydra 설정을 정의함
hydra:
	# 멀티런 모드 => 여러 실험 병렬로 실행
  mode: "MULTIRUN" 
	
	# optuna sweeper 설정 정의
  sweeper:
    _target_: hydra_plugins.hydra_optuna_sweeper.optuna_sweeper.OptunaSweeper

    # 최적화 결과 저장할 스토리지 URL
    storage: null

    # 최적화 결과 저장할 스터디 이름
    study_name: null

    # 병렬로 실행할 작업 수
    n_jobs: 1

    # 최적화 방향, 검증 정확도 최대화
    direction: maximize

    # total number of runs that will be executed
    n_trials: 20

		# 하이퍼파라미터 샘플러 결정
    sampler:
      _target_: optuna.samplers.TPESampler
      seed: 1234
      n_startup_trials: 10 # number of random sampling runs before optimization starts

    # define hyperparameter search space
    # 학습률 범위 설정, batch 크기 범위 설정, FC Layer 크기 후보군
    params:
      model.optimizer.lr: interval(0.0001, 0.1)
      data.batch_size: choice(32, 64, 128, 256)
      model.net.lin1_size: choice(64, 128, 256)
      model.net.lin2_size: choice(64, 128, 256)
      model.net.lin3_size: choice(32, 64, 128, 256)
  • configs/model/mnist.yaml모델 설정 정의
  • # 사용할 모델의 클래스 _target_: src.models.mnist_module.MNISTLitModule # Optimizer 설정 # Adam Optimizer, 추후 추가 인자 제공, lr, 가중치 감쇠 값 optimizer: _target_: torch.optim.Adam _partial_: true lr: 0.001 weight_decay: 0.0 # 학습률 스케쥴러 설정 # min 모드, 학습률 감소 비율은 0.1, 학습률 감소 전에 기다릴 epoch 수 scheduler: _target_: torch.optim.lr_scheduler.ReduceLROnPlateau _partial_: true mode: min factor: 0.1 patience: 10 # 신경망 아키텍처 설정 # input size(28*28), 중간 완전연결 레이어 net: _target_: src.models.components.simple_dense_net.SimpleDenseNet input_size: 784 lin1_size: 64 lin2_size: 128 lin3_size: 64 output_size: 10 # compile model for faster training with pytorch 2.0 compile: false
  • configs/eval.yaml
    # @package _global_
    
    defaults:
      - _self_ # 현재 설정 파일(eval.yaml) 포함
      - data: mnist # 평가에 사용할 데이터모듈, test_dataloader() 포함 필요
      - model: mnist # mnist 모델 사용
      - logger: null
      - trainer: default
      - paths: default
      - extras: default
      - hydra: default
    
    # 작업 이름
    task_name: "eval"
    
    # 작업 식별하는 태그
    tags: ["dev"]
    
    # passing checkpoint path is necessary for evaluation
    ckpt_path: ???
    
    
    • 평가 설정 정의
  • configs/train.yaml
    # @package _global_
    
    defaults:
      - _self_ # 현재 설정 파일(train.yaml)을 포함
      - data: mnist
      - model: mnist
      - callbacks: default
      - logger: null # set logger here or use command line (e.g. `python train.py logger=tensorboard`)
      - trainer: default
      - paths: default
      - extras: default
      - hydra: default
    
      # 실험 설정, 특정 하이퍼파라미터 버전 관리 시 사용
      - experiment: null
    
      # 하이퍼파라미터 최적화 설정
      - hparams_search: null
    
      # optional local config for machine/user specific settings
      - optional local: default
    
      # debugging config (enable through command line, e.g. `python train.py debug=default)
      - debug: null
    
    # task name, determines output directory path
    task_name: "train"
    
    # tags to help you identify your experiments
    # you can overwrite this in experiment configs
    # overwrite from command line with `python train.py tags="[first_tag, second_tag]"`
    tags: ["dev"]
    
    # set False to skip model training
    train: True
    
    # evaluate on test set, using best model weights achieved during training
    # lightning chooses best weights based on the metric specified in checkpoint callback
    test: True
    
    # 훈련 재개하기 위한 체크포인트 경로
    # null로 설정 시 새로 훈련을 시작함
    ckpt_path: null
    
    # seed for random number generators in pytorch, numpy and python.random
    seed: null
    
    • 훈련 설정 정의
  • src/data/mnist_datamodule.py
from typing import Any, Dict, Optional, Tuple

import torch
from lightning import LightningDataModule
from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split
from torchvision.datasets import MNIST
from torchvision.transforms import transforms


class MNISTDataModule(LightningDataModule):
		# 초기화 함수
		# data directory, 분할 비율, batch size, worker #, fixed mem, transformation 정의
    def __init__(
        self,
        data_dir: str = "data/",
        train_val_test_split: Tuple[int, int, int] = (55_000, 5_000, 10_000),
        batch_size: int = 64,
        num_workers: int = 0,
        pin_memory: bool = False,
    ) -> None:
        super().__init__()
				
				# 초기화 시 전달된 하이퍼파라미터 저장
        # 이후 self.hparams 통해 하이퍼파라미터에 접근 가능
        self.save_hyperparameters(logger=False)

        # 데이터를 텐서로 변환
        # (이미 알려진 MNIST dataset의) 평균과 표준편차로 정규화하는 변환 정의
        self.transforms = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        )
				
				# 데이터셋 초기화, 나중에 setup 메서드에서 할당
        self.data_train: Optional[Dataset] = None
        self.data_val: Optional[Dataset] = None
        self.data_test: Optional[Dataset] = None
				
				# 배치 사이즈 초기화
        self.batch_size_per_device = batch_size

    @property
    def num_classes(self) -> int:
        return 10

    def prepare_data(self) -> None:
        # 데이터를 다운로드함
        # Lightning 에서는 이 함수 한 번만 호출
        MNIST(self.hparams.data_dir, train=True, download=True)
        MNIST(self.hparams.data_dir, train=False, download=True)

    def setup(self, stage: Optional[str] = None) -> None:
		    # 데이터 로드, 학습, 검증, 테스트 데이터셋 설정
		    # stage 매개변수 => 현재 단계(fit, validate, test, predict), 기본값 None
    
        # 배치 크기 조정 => 배치 크기를 장치 수로 나눔
        if self.trainer is not None:
            if self.hparams.batch_size % self.trainer.world_size != 0:
                raise RuntimeError(
                    f"Batch size ({self.hparams.batch_size}) is not divisible by the number of devices ({self.trainer.world_size})."
                )
            self.batch_size_per_device = self.hparams.batch_size // self.trainer.world_size

        # 데이터셋을 로드하고, 학습/검증/테스트 데이터셋으로 분할
        # 이미 로드된 경우 다시 로드하지 않음
        if not self.data_train and not self.data_val and not self.data_test:
            trainset = MNIST(self.hparams.data_dir, train=True, transform=self.transforms)
            testset = MNIST(self.hparams.data_dir, train=False, transform=self.transforms)
            dataset = ConcatDataset(datasets=[trainset, testset])
            self.data_train, self.data_val, self.data_test = random_split(
                dataset=dataset,
                lengths=self.hparams.train_val_test_split,
                generator=torch.Generator().manual_seed(42),
            )

    def train_dataloader(self) -> DataLoader[Any]:
		    # 학습 데이터 로더를 생성하고 반환
        return DataLoader(
            dataset=self.data_train,
            batch_size=self.batch_size_per_device,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            shuffle=True,
        )

    def val_dataloader(self) -> DataLoader[Any]:
        # 검증 데이터 로더를 생성하고 반환
        return DataLoader(
            dataset=self.data_val,
            batch_size=self.batch_size_per_device,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            shuffle=False,
        )

    def test_dataloader(self) -> DataLoader[Any]:
        # 테스트 데이터 로더를 생성하고 반환
        return DataLoader(
            dataset=self.data_test,
            batch_size=self.batch_size_per_device,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            shuffle=False,
        )

    def teardown(self, stage: Optional[str] = None) -> None:
        # 훈련, 검증, 테스트 후 정리 작업 수행 (default do nothing)
        pass

    def state_dict(self) -> Dict[Any, Any]:
        # 체크포인트 저장 시 데이터 모듈 상태 저장 (default return empty dict)
        return {}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        # 체크포인트 로드 시 데이터 모듈 상태 로드 (default do nothing)
        pass


if __name__ == "__main__":
    _ = MNISTDataModule()
  • src/models/mnist_module.py
    • lightning에서는 모델 학습 / 검증 / 테스트 / 예측 을 단계로 나누어 관리
    • 각 단계는 여러 epoch로 구성, 각 epoch는 여러 배치로 구성, 각 batch에 대한 작업을 single step
      • fit: 모델을 훈련(train)하고 검증(validate)하는 단계입니다. 이 단계는 주로 trainer.fit() 메서드를 호출할 때 실행됩니다.
      • validate: 훈련된 모델을 별도로 검증하는 단계입니다. 이 단계는 trainer.validate() 메서드를 호출할 때 실행됩니다.
      • test: 모델을 테스트하는 단계입니다. 이 단계는 trainer.test() 메서드를 호출할 때 실행됩니다.
      • predict: 모델을 사용하여 예측하는 단계입니다. 이 단계는 trainer.predict() 메서드를 호출할 때 실행됩니다
      from typing import Any, Dict, Tuple
      
      import torch
      from lightning import LightningModule
      from torchmetrics import MaxMetric, MeanMetric
      from torchmetrics.classification.accuracy import Accuracy
      
      class MNISTLitModule(LightningModule):
          def __init__(
              self,
              net: torch.nn.Module,
              optimizer: torch.optim.Optimizer,
              scheduler: torch.optim.lr_scheduler,
              compile: bool,
          ) -> None:
          # 훈련할 모델 / 최적화 알고리즘 / 학습률 스케쥴러 / 모델 컴파일 여부
          
              super().__init__()
              
              # 초기화 시 전달된 하이퍼파라미터 저장
              self.save_hyperparameters(logger=False)
      				
      				# 모델 인스턴스
              self.net = net
      
              # 손실 함수
              self.criterion = torch.nn.CrossEntropyLoss()
      
              # metric objects for calculating and averaging accuracy across batches
              self.train_acc = Accuracy(task="multiclass", num_classes=10)
              self.val_acc = Accuracy(task="multiclass", num_classes=10)
              self.test_acc = Accuracy(task="multiclass", num_classes=10)
      
              # for averaging loss across batches
              self.train_loss = MeanMetric()
              self.val_loss = MeanMetric()
              self.test_loss = MeanMetric()
      
              # for tracking best so far validation accuracy
              self.val_acc_best = MaxMetric()
      
          def forward(self, x: torch.Tensor) -> torch.Tensor:
      		    # 모델의 순전파 정의, 입력 이미지 텐서를 받음, 모델의 출력(logits) 리턴
              return self.net(x)
      
          def on_train_start(self) -> None:
              # 검증 지표가 훈련 시작 전 상태/이전 실행의 값 반영하지 않도록
              self.val_loss.reset()
              self.val_acc.reset()
              self.val_acc_best.reset()
      
          def model_step(
              self, batch: Tuple[torch.Tensor, torch.Tensor]
          ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
      		    # 배치 데이터를 처리하여 손실, 예측값, 실제값 계산
      		    
      		    # x는 입력 이미지, y는 실제 레이블
              x, y = batch
              # 모델의 출력 logits
              logits = self.forward(x)
              # 손실값 계산
              loss = self.criterion(logits, y)
              # 예측 레이블
              preds = torch.argmax(logits, dim=1)
              return loss, preds, y
      
          def training_step(
              self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int
          ) -> torch.Tensor:
      				# 단일 훈련 단계를 수행함 (model_step 호출)
              loss, preds, targets = self.model_step(batch)
      
              # update and log metrics
              self.train_loss(loss)
              self.train_acc(preds, targets)
              self.log("train/loss", self.train_loss, on_step=False, on_epoch=True, prog_bar=True)
              self.log("train/acc", self.train_acc, on_step=False, on_epoch=True, prog_bar=True)
      
              # return loss or backpropagation will fail
              return loss
      
          def on_train_epoch_end(self) -> None:
      		    # train epoch 끝날 때마다 호출되는 hook function
              pass
      
          def validation_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> None:
              # 단일 검증 단계를 수행함 (model_step 호출)
              loss, preds, targets = self.model_step(batch)
      
              # update and log metrics
              self.val_loss(loss)
              self.val_acc(preds, targets)
              self.log("val/loss", self.val_loss, on_step=False, on_epoch=True, prog_bar=True)
              self.log("val/acc", self.val_acc, on_step=False, on_epoch=True, prog_bar=True)
      
          def on_validation_epoch_end(self) -> None:
              acc = self.val_acc.compute()  # get current val acc
              self.val_acc_best(acc)  # update best so far val acc
              # log `val_acc_best` as a value through `.compute()` method, instead of as a metric object
              # otherwise metric would be reset by lightning after each epoch
              self.log("val/acc_best", self.val_acc_best.compute(), sync_dist=True, prog_bar=True)
      
          def test_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> None:
              # 단일 테스트 단계를 수행함 (model_step 호출)
              loss, preds, targets = self.model_step(batch)
      
              # update and log metrics
              self.test_loss(loss)
              self.test_acc(preds, targets)
              self.log("test/loss", self.test_loss, on_step=False, on_epoch=True, prog_bar=True)
              self.log("test/acc", self.test_acc, on_step=False, on_epoch=True, prog_bar=True)
      
          def on_test_epoch_end(self) -> None:
      		    # test epoch 끝날 때마다 호출되는 hook function
              pass
      
          def setup(self, stage: str) -> None:
      		    # 각 단계(fit, validate, test, predict) 시작 시 호출되는 hook
              if self.hparams.compile and stage == "fit":
      		        # 모델을 컴파일
                  self.net = torch.compile(self.net)
      
          def configure_optimizers(self) -> Dict[str, Any]:
      			  # 최적화 알고리즘
              optimizer = self.hparams.optimizer(params=self.trainer.model.parameters())
              if self.hparams.scheduler is not None:
      		        # 학습률 스케쥴러
                  scheduler = self.hparams.scheduler(optimizer=optimizer)
                  return {
                      "optimizer": optimizer,
                      "lr_scheduler": {
                          "scheduler": scheduler,
                          "monitor": "val/loss",
                          "interval": "epoch",
                          "frequency": 1,
                      },
                  }
              return {"optimizer": optimizer}
      
      if __name__ == "__main__":
          _ = MNISTLitModule(None, None, None, None)
      
      
  • src/eval.py
    • trainer.test()를 통해 주어진 모델과 데이터모듈을 사용해 테스트를 수행하고, metric_dict 리턴
    • hydra.main ⇒ eval.yaml 파일을 사용
from typing import Any, Dict, List, Tuple

import hydra
import rootutils
from lightning import LightningDataModule, LightningModule, Trainer
from lightning.pytorch.loggers import Logger
from omegaconf import DictConfig

rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)

    extras,
    instantiate_loggers,
    log_hyperparameters,
    task_wrapper,
)

log = RankedLogger(__name__, rank_zero_only=True)


@task_wrapper
def evaluate(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Evaluates given checkpoint on a datamodule testset.

    This method is wrapped in optional @task_wrapper decorator, that controls the behavior during
    failure. Useful for multiruns, saving info about the crash, etc.

    :param cfg: DictConfig configuration composed by Hydra.
    :return: Tuple[dict, dict] with metrics and dict with all instantiated objects.
    """
    assert cfg.ckpt_path

    log.info(f"Instantiating datamodule <{cfg.data._target_}>")
    datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data)

    log.info(f"Instantiating model <{cfg.model._target_}>")
    model: LightningModule = hydra.utils.instantiate(cfg.model)

    log.info("Instantiating loggers...")
    logger: List[Logger] = instantiate_loggers(cfg.get("logger"))

    log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
    trainer: Trainer = hydra.utils.instantiate(cfg.trainer, logger=logger)

    object_dict = {
        "cfg": cfg,
        "datamodule": datamodule,
        "model": model,
        "logger": logger,
        "trainer": trainer,
    }

    if logger:
        log.info("Logging hyperparameters!")
        log_hyperparameters(object_dict)

    log.info("Starting testing!")
    trainer.test(model=model, datamodule=datamodule, ckpt_path=cfg.ckpt_path)

    # for predictions use trainer.predict(...)
    # predictions = trainer.predict(model=model, dataloaders=dataloaders, ckpt_path=cfg.ckpt_path)

    metric_dict = trainer.callback_metrics

    return metric_dict, object_dict


@hydra.main(version_base="1.3", config_path="../configs", config_name="eval.yaml")
def main(cfg: DictConfig) -> None:
    """Main entry point for evaluation.

    :param cfg: DictConfig configuration composed by Hydra.
    """
    # apply extra utilities
    # (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.)
    extras(cfg)

    evaluate(cfg)


if __name__ == "__main__":
    main()
  • src/train.py
    • trainer.fit()을 통해 모델을 학습하고, trainer.test()를 통해 테스트를 수행함
    • train_metrics와 test_metrics를 리턴함
    • hydra.main ⇒ train.yaml 파일을 사용
    • from typing import Any, Dict, List, Optional, Tuple
      # auto-log experiment in ClearML
      from clearml import Task, Logger
      
      import hydra
      import lightning as L
      import rootutils
      import torch
      from lightning import Callback, LightningDataModule, LightningModule, Trainer
      from lightning.pytorch.loggers import Logger
      from omegaconf import DictConfig
      
      rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
      # ------------------------------------------------------------------------------------ #
      # the setup_root above is equivalent to:
      # - adding project root dir to PYTHONPATH
      #       (so you don't need to force user to install project as a package)
      #       (necessary before importing any local modules e.g. `from src import utils`)
      # - setting up PROJECT_ROOT environment variable
      #       (which is used as a base for paths in "configs/paths/default.yaml")
      #       (this way all filepaths are the same no matter where you run the code)
      # - loading environment variables from ".env" in root dir
      #
      # you can remove it if you:
      # 1. either install project as a package or move entry files to project root dir
      # 2. set `root_dir` to "." in "configs/paths/default.yaml"
      #
      # more info: <https://github.com/ashleve/rootutils>
      # ------------------------------------------------------------------------------------ #
      
      from src.utils import (
          RankedLogger,
          extras,
          get_metric_value,
          instantiate_callbacks,
          instantiate_loggers,
          log_hyperparameters,
          task_wrapper,
      )
      
      log = RankedLogger(__name__, rank_zero_only=True)
      
      @task_wrapper
      def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
          """Trains the model. Can additionally evaluate on a testset, using best weights obtained during
          training.
      
          This method is wrapped in optional @task_wrapper decorator, that controls the behavior during
          failure. Useful for multiruns, saving info about the crash, etc.
      
          :param cfg: A DictConfig configuration composed by Hydra.
          :return: A tuple with metrics and dict with all instantiated objects.
          """
          # set seed for random number generators in pytorch, numpy and python.random
          if cfg.get("seed"):
              L.seed_everything(cfg.seed, workers=True)
      
          log.info(f"Instantiating datamodule <{cfg.data._target_}>")
          datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data)
      
          log.info(f"Instantiating model <{cfg.model._target_}>")
          model: LightningModule = hydra.utils.instantiate(cfg.model)
      
          log.info("Instantiating callbacks...")
          callbacks: List[Callback] = instantiate_callbacks(cfg.get("callbacks"))
      
          log.info("Instantiating loggers...")
          logger: List[Logger] = instantiate_loggers(cfg.get("logger"))
      
          log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
          trainer: Trainer = hydra.utils.instantiate(cfg.trainer, callbacks=callbacks, logger=logger)
      
          object_dict = {
              "cfg": cfg,
              "datamodule": datamodule,
              "model": model,
              "callbacks": callbacks,
              "logger": logger,
              "trainer": trainer,
          }
      
          if logger:
              log.info("Logging hyperparameters!")
              log_hyperparameters(object_dict)
      
          if cfg.get("train"):
              log.info("Starting training!")
              trainer.fit(model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path"))
      
          train_metrics = trainer.callback_metrics
      
          if cfg.get("test"):
              log.info("Starting testing!")
              ckpt_path = trainer.checkpoint_callback.best_model_path
              if ckpt_path == "":
                  log.warning("Best ckpt not found! Using current weights for testing...")
                  ckpt_path = None
              trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path)
              log.info(f"Best ckpt path: {ckpt_path}")
      
          test_metrics = trainer.callback_metrics
      
          # merge train and test metrics
          metric_dict = {**train_metrics, **test_metrics}
      
          return metric_dict, object_dict
      
      @hydra.main(version_base="1.3", config_path="../configs", config_name="train.yaml")
      def main(cfg: DictConfig) -> Optional[float]:
          """Main entry point for training.
      
          :param cfg: DictConfig configuration composed by Hydra.
          :return: Optional[float] with optimized metric value.
          """
          # apply extra utilities
          # (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.)
          extras(cfg)
      
          # train the model
          metric_dict, _ = train(cfg)
      
          # safely retrieve metric value for hydra-based hyperparameter optimization
          metric_value = get_metric_value(
              metric_dict=metric_dict, metric_name=cfg.get("optimized_metric")
          )
      
          # return optimized metric
          return metric_value
      
      if __name__ == "__main__":
          task = Task.init(project_name='lightning-hydra', task_name='experiment1')
          main()