-
Notifications
You must be signed in to change notification settings - Fork 7
/
train.py
38 lines (26 loc) · 874 Bytes
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import logging
import os
import hydra
import torch
from omegaconf import DictConfig, OmegaConf
from cxrclip import convert_dictconfig_to_dict, run, seed_everything
# Module-level logger named after this module; main() uses it to log the configuration.
log = logging.getLogger(__name__)
@hydra.main(version_base=None, config_path="configs", config_name="train")
def main(cfg: DictConfig):
    """Training entry point invoked through Hydra.

    Resolves the config in place, works out this process's local rank,
    logs the configuration when the rank is below 1, seeds via
    ``seed_everything``, turns on cuDNN benchmark mode, and finally hands
    the config (converted with ``convert_dictconfig_to_dict``) to ``run``.

    Args:
        cfg: Configuration composed by Hydra from ``configs/train``.
            ``cfg.base.seed`` is read and passed to ``seed_everything``.
    """
    OmegaConf.resolve(cfg)

    # LOCAL_RANK is passed by torchrun or torch.distributed.launch (DDP).
    # When it is not set (e.g. debugging) the rank falls back to -1.
    env_rank = os.environ.get("LOCAL_RANK")
    rank = -1 if env_rank is None else int(env_rank)

    # Ranks -1 and 0 are the only ones that log the configuration.
    if rank <= 0:
        log.info(f"Configurations:\n{OmegaConf.to_yaml(cfg)}")

    seed_everything(cfg.base.seed)
    # NOTE(review): the benchmark flag is assigned after seed_everything();
    # keep this order in case seed_everything() also sets cuDNN flags — TODO confirm.
    # torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True

    run(rank, convert_dictconfig_to_dict(cfg))
# Call main() only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()