背景
编码项目代码时,往往涉及到很多的超参数。Hydra可以帮助整理这些超参数,使实验过程中的参数设置更清晰。
从config.yaml文件中读取超参数
假设当前的文件路径为:
├─configs
│ └─config.yaml
└─main.py
config.yaml的内容是:
name: exp
save_dir: ./checkpoint
data:
dataroot: ./data
batch_size: 64
main.py的内容是:
from omegaconf import DictConfig, OmegaConf
import hydra
@hydra.main(version_base=None, config_path="configs", config_name="config")
def func(cfg):
cfg_str = OmegaConf.to_yaml(cfg) # 将cfg转换成string格式,方便打印
print(cfg_str)
if __name__ == "__main__":
func()
在@hydra.main
中,config_path代表存放配置文件的文件夹,config_name代表主配置文件的名称。@hydra.main读取这些配置后形成数据格式DictConfig,然后传递给func。
读取多个.yaml的超参数
假设当前的文件路径为:
├─configs
│ ├─data
│ │ └─data_1.yaml
│ ├─model
│ │ └─model_1.yaml
│ └─config.yaml
└─main.py
.yaml的内容分别是:
# config.yaml
name: exp
save_dir: ./checkpoint
defaults:
- data: data_1
- model: model_1
# data_1.yaml
dataroot: ./data
# model_1.yaml
n_layers: 3
执行main.py后输出:
data:
dataroot: ./data
model:
n_layers: 3
name: exp
save_dir: ./checkpoint
保存本次实验的configs
假设当前目录是:
├─configs
│ └─config.yaml
└─main.py
config.yaml内容是:
# 通用
name: exp # 本次实验的名称
seed: 42 # 随机种子
checkpoint_dir: checkpoints # ckpt目录
output_dir: ./${checkpoint_dir}/${name} # 本次实验的输出目录
# hydra每次运行会自动生成一个目录,用于区分不同的实验
hydra:
run:
# dir: ./${checkpoint_dir}/${now:%Y-%m-%d}/${now:%H-%M-%S}
dir: ./${checkpoint_dir}/${name}
那么保存函数可以是:
def save_configs(cfg):
# 保存configs
OmegaConf.save(cfg, cfg.output_dir + '/configs.yaml')
print('[Config] configs saved in %s' % (cfg.output_dir + '/configs.yaml'))