포스트

PyTorch-Lightning 1.5에서 args로 하이퍼파라미터 전달

재작년(…) 사내 양성과정의 일환으로 진행했던 미니 프로젝트를 다시 끄집어냈다. 팀 프로젝트로 진행했던 작업이긴 하나 소스코드를 너무 대충 짜논 턱에 못 알아보겠는 부분도 많고, 너무 오래 전에 작업했던 내용이라 그 사이에 SKT에서 KoGPT2 version 2를 내고 카카오에서도 KoGPT 발표하고 난리가 나서 주말마다 리팩토링을 해보기로 했다.

2020년 때와의 가장 큰 차이점이라고 한다면 역시 PyTorch-Lightning이 정식 버전을 내고 .5가 붙었다는 건데, 이 때문에 안 맞는 API가 너~무 많았다. (지금 리팩토링하면서 사용하고 있는 PyTorch-Lightning은 버전 1.5.10) 그 중에서도 특히나 골때리는 게 2개 정도 있었는데 이번에는 argparse를 이용해 파싱한 args를 LightningModule 내에서 하이퍼파라미터로 사용할 수 있게 전달하는 부분을 다룰 생각이다.

2년 전에는 __init__에서 하이퍼파라미터 저장을 이렇게 했다.

1
2
3
4
5
class CMPersonaChat(LightningModule):
    def __init__(self, hparams, *args):
        super(CMPersonaChat, self).__init__()
        self.hparams = hparams  # <-- 이 방식으로 지정
        self.kogpt2 = get_kogpt2_model(hparams.model_params)

물론 지금은 안 된다. self.hparams에 직접 값을 지정하는 것이 불가능해졌기 때문에 save_hyperparameters()를 이용해야 한다.

1
2
3
4
5
class CMPersonaChat(LightningModule):
    def __init__(self, hparams, *args):
        super(CMPersonaChat, self).__init__()
        self.save_hyperparameters()  # <-- 여기가 바로 달라진 부분
        self.kogpt2 = get_kogpt2_model()

근데 이래놨는데 configure_optimizers에서 lr이 제대로 읽히지 않는 것이었다.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
/home/judith/miniconda3/envs/cm/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py:1584: UserWarning: GPU available but not used. Set the gpus flag in your trainer `Trainer(gpus=1)` or script `--gpus=1`.
  rank_zero_warn(
Traceback (most recent call last):
  File "/home/judith/miniconda3/envs/cm/lib/python3.8/site-packages/pytorch_lightning/utilities/parsing.py", line 272, in __getattr__
    return self[key]
KeyError: 'lr'

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "main.py", line 493, in <module>
    main()
  File "main.py", line 365, in main
    train(dataloader, args)
  File "main.py", line 280, in train
    trainer.fit(model, dataloader['train'], dataloader['valid'])
  File "/home/judith/miniconda3/envs/cm/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 740, in fit
    self._call_and_handle_interrupt(
  File "/home/judith/miniconda3/envs/cm/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 685, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/home/judith/miniconda3/envs/cm/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 777, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/home/judith/miniconda3/envs/cm/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1145, in _run
    self.accelerator.setup(self)
  File "/home/judith/miniconda3/envs/cm/lib/python3.8/site-packages/pytorch_lightning/accelerators/cpu.py", line 35, in setup
    return super().setup(trainer)
  File "/home/judith/miniconda3/envs/cm/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py", line 93, in setup
    self.setup_optimizers(trainer)
  File "/home/judith/miniconda3/envs/cm/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py", line 354, in setup_optimizers
    optimizers, lr_schedulers, optimizer_frequencies = self.training_type_plugin.init_optimizers(
  File "/home/judith/miniconda3/envs/cm/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 245, in init_optimizers
    return trainer.init_optimizers(model)
  File "/home/judith/miniconda3/envs/cm/lib/python3.8/site-packages/pytorch_lightning/trainer/optimizers.py", line 35, in init_optimizers
    optim_conf = self.call_hook("configure_optimizers", pl_module=pl_module)
  File "/home/judith/miniconda3/envs/cm/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1501, in call_hook
    output = model_fx(*args, **kwargs)
  File "main.py", line 166, in configure_optimizers
    lr=self.hparams.lr, correct_bias=False)
  File "/home/judith/miniconda3/envs/cm/lib/python3.8/site-packages/pytorch_lightning/utilities/parsing.py", line 274, in __getattr__
    raise AttributeError(f'Missing attribute "{key}"') from exp
AttributeError: Missing attribute "lr"

2년 전까지만 해도 잘 돌던 코드였는데 갑자기 왜 이러나 싶어서 __init__에서 self.hparams를 출력해봤는데 좀 골때리게 나오고 있었다.

1
"hparams": Namespace(accelerator=None, accumulate_grad_batches=None, amp_backend='native', amp_level=None, auto_lr_find=False, auto_scale_batch_size=False, auto_select_gpus=False, benchmark=False, check_val_every_n_epoch=1, checkpoint_callback=None, dataset_cache='./dataset_cache', dataset_path='new_dataset/personachat_translated_sample.json', default_root_dir=None, detect_anomaly=False, deterministic=False, device='cuda', devices=None, enable_checkpointing=True, enable_model_summary=True, enable_progress_bar=True, eval_dataset_path='eval/eval_merge.json', fast_dev_run=False, flush_logs_every_n_steps=None, gpus=None, gradient_clip_algorithm=None, gradient_clip_val=None, ipus=None, limit_predict_batches=1.0, limit_test_batches=1.0, limit_train_batches=1.0, limit_val_batches=1.0, log_every_n_steps=50, log_gpu_memory=None, logger=True, lr=5e-05, max_epochs=None, max_history=2, max_len=768, max_steps=-1, max_time=None, min_epochs=None, min_length=1, min_steps=None, mode='train', model_params=None, move_metrics_to_cpu=False, multiple_trainloader_mode='max_size_cycle', name='cm_kogpt2', no_sample=False, num_candidates=1, num_eval_pp=10, num_nodes=1, num_processes=1, num_sanity_val_steps=2, num_workers=8, overfit_batches=0.0, personality_permutations=1, plugins=None, precision=32, prepare_data_per_node=None, process_position=0, profiler=None, progress_bar_refresh_rate=None, reload_dataloaders_every_epoch=False, reload_dataloaders_every_n_epochs=0, replace_sampler_ddp=True, resume_from_checkpoint=None, stochastic_weight_avg=False, strategy=None, sync_batchnorm=False, temperature=0.7, terminate_on_nan=None, top_k=0, top_p=0.9, tpu_cores=None, track_grad_norm=-1, train_batch_size=4, val_check_interval=1.0, valid_batch_size=2, warmup_ratio=0.1, weights_save_path=None, weights_summary='top')

self.hparams가 왜 dict로 받아지는 거야…? ㅋㅋㅋㅋㅋㅋㅋ… 이렇게 되면 self.hparams['hparams'].lr 와 같은 식으로 써야 한다는 건데, 말도 안 되는 짓이었다. 아마 2년 사이에 PyTorch-Lightning의 코딩 철학에 많은 변화가 생긴 모양이었다.

그래서 문서를 다시 뒤져 확인한 뒤 코드를 몇 군데 수정했다.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
# 모델 정의
class CMPersonaChat(LightningModule):
    def __init__(self, **hparams):  # should get hparams with ** if you want pass args
    # def __init__(self, hparams):  # not like this
        super(CMPersonaChat, self).__init__()
        self.save_hyperparameters()
        self.kogpt2 = get_kogpt2_model()

# 실제 모델 불러서 args 넘기는 부분
parser = CMPersonaChat.add_model_specific_args(parser)
parser = Trainer.add_argparse_args(parser)
args = parser.parse_args()

model = CMPersonaChat(**vars(args))  # <-- namespace인 args를 dict로 바꾸고 key-value pair로 던져야 한다

수정 결과, LightningModule 내에서 하이퍼파라미터가 필요할 때 self.hparams.lr 과 같은 식으로 불러도 코드가 제대로 실행되는 걸 확인했다. 그런데 웬만하면 PyTorch-Lightning 쪽에서 dictnamespace 둘 다 받아지게 지원을 해줬으면 좋겠다. 딱히 어려운 건 아닌데!

이 기사는 저작권자의 CC BY 4.0 라이센스를 따릅니다.

Comments powered by Disqus.