效果:
- 将实验用的参数写入 yaml 文件,而不是全部用 argparse 传,否则命令会很长,而且在 jupyter notebook 内用不了 argparse;
- 同时支持在命令行临时加、改一些参数,避免每次都要在 yaml 中改参数,比较灵活(如 grid-search 时遍历不同的 loss weights)。
最初是在 MMDetection 中看到这种写法,参考 [1] 中 --cfg-options 这个参数,核心是 DictAction 类,定义在 [2]。yaml 一些支持的写法参考 [3]。本文同时作为 python yaml 读、写简例。
Code
- DictAction 类抄自 [2];
- parse_cfg 函数读 yaml 参数,并按命令行输入加、改参数(覆盖 yaml),用 EasyDict 装;
- 用 yaml 备份参数时,用 easydict2dict 将 EasyDict 递归改回 dict,yaml 会干净点。不用也行。
from argparse import Action, ArgumentParser, Namespace
import copy
from easydict import EasyDict
from typing import Any, Optional, Sequence, Tuple, Union
import yaml
class DictAction(Action):
    """Argparse action that collects ``KEY=VALUE`` pairs into a dictionary.

    Copied from MMEngine. Each argument is split on the first ``=`` and the
    pair is stored in a dict under ``self.dest``. List options can be passed
    as comma separated values, i.e. 'KEY=V1,V2,V3', or with explicit
    brackets, i.e. 'KEY=[V1,V2,V3]'. Nested brackets are also supported to
    build list/tuple values, e.g. 'KEY=[(V1,V2),(V3,V4)]'.
    """

    @staticmethod
    def _parse_int_float_bool(val: str) -> Union[int, float, bool, Any]:
        """Convert a scalar string to int, float, bool or None if possible.

        Args:
            val (str): Value string.

        Returns:
            int | float | bool | None | str: The converted value, or ``val``
            itself when no conversion applies.
        """
        # Try int before float so that '5' stays an integer.
        try:
            return int(val)
        except ValueError:
            pass
        try:
            return float(val)
        except ValueError:
            pass
        if val.lower() in ['true', 'false']:
            return val.lower() == 'true'
        if val == 'None':
            return None
        return val

    @staticmethod
    def _parse_iterable(val: str) -> Union[list, tuple, Any]:
        """Parse iterable values in the string.

        All elements inside '()' or '[]' are treated as iterable values.

        Args:
            val (str): Value string.

        Returns:
            list | tuple | Any: The expanded list or tuple from the string,
            or single value if no iterable values are found.

        Raises:
            AssertionError: If the round or square brackets in a segment
                being scanned are not balanced.

        Examples:
            >>> DictAction._parse_iterable('1,2,3')
            [1, 2, 3]
            >>> DictAction._parse_iterable('[a, b, c]')
            ['a', 'b', 'c']
            >>> DictAction._parse_iterable('[(1, 2, 3), [a, b], c]')
            [(1, 2, 3), ['a', 'b'], 'c']
        """

        def find_next_comma(string):
            """Find the position of the next top-level comma in the string.

            If no ',' is found in the string, return the string length. All
            chars inside '()' and '[]' are treated as one element and thus
            ',' inside these brackets are ignored.
            """
            assert (string.count('(') == string.count(')')) and (
                string.count('[') == string.count(']')), \
                f'Imbalanced brackets exist in {string}'
            # Running (opening - closing) counters for the text seen so far;
            # a comma is a separator only when both counters are zero.
            round_depth = 0
            square_depth = 0
            for idx, char in enumerate(string):
                if char == ',' and round_depth == 0 and square_depth == 0:
                    return idx
                if char == '(':
                    round_depth += 1
                elif char == ')':
                    round_depth -= 1
                elif char == '[':
                    square_depth += 1
                elif char == ']':
                    square_depth -= 1
            return len(string)

        # Strip ' and " characters and remove spaces.
        val = val.strip('\'\"').replace(' ', '')
        is_tuple = False
        if val.startswith('(') and val.endswith(')'):
            is_tuple = True
            val = val[1:-1]
        elif val.startswith('[') and val.endswith(']'):
            val = val[1:-1]
        elif ',' not in val:
            # val is a single value
            return DictAction._parse_int_float_bool(val)

        values = []
        while len(val) > 0:
            comma_idx = find_next_comma(val)
            element = DictAction._parse_iterable(val[:comma_idx])
            values.append(element)
            val = val[comma_idx + 1:]

        if is_tuple:
            return tuple(values)
        return values

    def __call__(self,
                 parser: ArgumentParser,
                 namespace: Namespace,
                 values: Union[str, Sequence[Any], None],
                 option_string: Optional[str] = None):
        """Parse ``KEY=VALUE`` strings and merge them into the namespace.

        Args:
            parser (ArgumentParser): Argument parser.
            namespace (Namespace): Argument namespace.
            values (Union[str, Sequence[Any], None]): One ``KEY=VALUE``
                string or a sequence of them.
            option_string (str, optional): Option string.
                Defaults to None.

        Raises:
            ValueError: If an item contains no ``=``.
        """
        # Copied behavior from `argparse._ExtendAction`: work on a copy so
        # a shared default dict is never mutated.
        options = copy.copy(getattr(namespace, self.dest, None) or {})
        if isinstance(values, str):
            # Without `nargs` argparse passes a bare string; iterating it
            # would walk over single characters instead of one pair.
            values = [values]
        if values is not None:
            for kv in values:
                key, val = kv.split('=', maxsplit=1)
                options[key] = self._parse_iterable(val)
        setattr(namespace, self.dest, options)
def parse_cfg(yaml_file, update_dict=None):
    """Load configurations from a YAML file and apply command-line overrides.

    Args:
        yaml_file (str): Path to a YAML configuration file.
        update_dict (dict, optional): Options to add or overwrite in the
            loaded configuration. A key may use dots to address nested
            entries, e.g. ``optimizer.group1.lr``; missing intermediate
            levels are created as empty ``EasyDict`` objects. ``None`` or an
            empty dict leaves the loaded configuration untouched.

    Returns:
        EasyDict: The loaded (and possibly updated) configuration.

    Raises:
        AssertionError: If ``update_dict`` is non-empty but not a dict.
    """
    # Read the file named by the parameter, not a module-level `args`
    # object, so the function also works when imported.
    with open(yaml_file, "r") as f:
        cfg = EasyDict(yaml.safe_load(f))

    if not update_dict:
        return cfg
    assert isinstance(update_dict, dict)

    for dotted_key, value in update_dict.items():
        # 'a.b.c' -> parents ['a', 'b'], leaf 'c'; a plain key has no parents.
        *parents, leaf = dotted_key.split('.')
        node = cfg
        for name in parents:
            if name not in node:
                node[name] = EasyDict()
            node = node[name]
        node[leaf] = value
    return cfg
def easydict2dict(ed):
    """Recursively convert an EasyDict into plain dicts for a clean YAML dump.

    Nested mappings (EasyDict or any other ``dict`` subclass) become plain
    ``dict`` objects. Lists and tuples are walked as well, so mappings stored
    inside sequences are converted too. All other values are kept as they
    are.

    Args:
        ed: The mapping to convert; it must provide ``items()``.

    Returns:
        dict: A new plain dictionary with the same keys and converted values.
    """

    def _plain(value):
        """Return ``value`` with every nested mapping turned into a dict."""
        if isinstance(value, dict):
            return {k: _plain(v) for k, v in value.items()}
        if isinstance(value, list):
            return [_plain(item) for item in value]
        if isinstance(value, tuple):
            return tuple(_plain(item) for item in value)
        return value

    return {k: _plain(v) for k, v in ed.items()}
if __name__ == "__main__":
    # Example invocation:
    #   python config.py --cfg-options int=5 dict2.lr=8 dict2.newdict.newitem=fly
    from pprint import pprint

    cfg_options_help = (
        'override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.'
    )

    parser = ArgumentParser()
    parser.add_argument("--cfg", type=str, default="config.yaml", help="指定 yaml")
    parser.add_argument(
        '--cfg-options', nargs='+', action=DictAction, help=cfg_options_help)
    args = parser.parse_args()

    # Options added or changed on the command line, collected as a dict.
    pprint(args.cfg_options)

    # Load the YAML file, then add/overwrite entries from the command line.
    cfg = parse_cfg(args.cfg, args.cfg_options)
    pprint(cfg)

    # Back up the resulting configuration as YAML. Dumping `cfg` directly
    # also works; converting to plain dicts first gives a cleaner file.
    with open("backup-config.yaml", 'w') as f:
        yaml.dump(easydict2dict(cfg), f)
输入的 config.yaml:
- 语法参考 [3]
# An example yaml configuration file, used in utils/config.py as an example input.
# Ref: https://pyyaml.org/wiki/PyYAMLDocumentation
log_path: ./log
none: [~, null, None]
bool: [true, false, on, off, True, False]
int: 42 # <- 改它
float: 3.14159
list: [LITE, RES_ACID, SUS_DEXT]
list2:
- -1
- 0
- 1
str:
  a
  2
  0.2
# d: tom
dict: {hp: 13, sp: 5}
dict2: # <- 加它
  lr: 0.01 # <- 改它
  decay_rate: 0.1
  name: jerry
测试:
# --cfg-options 支持多级指定(用「.」分隔)
python config.py --cfg config.yaml --cfg-options int=5 dict2.lr=8 dict2.newdict.newitem=fly
输出:
{'dict2.lr': 8, 'dict2.newdict.newitem': 'fly', 'int': 5}
{'bool': [True, False, True, False, True, False],
'dict': {'hp': 13, 'sp': 5},
'dict2': {'decay_rate': 0.1,
'lr': 8, # <- 改了
'name': 'jerry',
'newdict': {'newitem': 'fly'}}, # <- 加了
'float': 3.14159,
'int': 5, # <- 改了
'list': ['LITE', 'RES_ACID', 'SUS_DEXT'],
'list2': [-1, 0, 1],
'log_path': './log',
'none': [None, None, 'None'],
'str': 'a 2 0.2'}
json + argparse.Namespace
如果是基于别人的代码改,已经用 argparse 写了很长的参数列表,而现在想在 jupyter notebook 内用 checkpoint 跑些东西,可以用 json 存下 argparse 的参数,然后在 notebook 里读此 json 文件,并用 argparse.Namespace 类复原。这可以用作前文 yaml + EasyDict 的简单替代方案。
example
在用原代码 .py 脚本跑训练时,把命令行参数存入 json 文件:
# train.py
import argparse
import json

parser = argparse.ArgumentParser()
parser.add_argument("lr", type=float)
# ... the remaining n parser.add_argument calls are omitted here
args = parser.parse_args()

# vars(args) is the namespace's attribute dict (args.__dict__).
cmdln_args = vars(args)
print(type(args), cmdln_args, args.lr)  # <class 'argparse.Namespace'> {'lr': 0.1} 0.1

# Save the command-line arguments of this run to a json file.
with open("cmdln-args.json", "w") as f:
    json.dump(cmdln_args, f, indent=1)
训练命令:python train.py 0.1 <以及后续 n 个参数>。
在 jupyter notebook 内,不能在命令行传参,改从之前存下的 cmdln-args.json 读参数:
# test.ipynb
import argparse
import json

# Restore the arguments saved by train.py and rebuild the Namespace from them.
with open("cmdln-args.json", "r") as f:
    saved_args = json.load(f)
args = argparse.Namespace(**saved_args)
print(type(args), vars(args), args.lr)  # <class 'argparse.Namespace'> {'lr': 0.1} 0.1
本文介绍了一种在 Python 中用 yaml 文件管理实验参数、并在命令行临时加改部分参数的方法:用 `DictAction` 解析命令行的 KEY=VALUE 覆盖项,用 `EasyDict` 存放配置。该写法参考自 MMDetection,适用于 grid-search 时遍历不同 loss weights 等场景。
1522

被折叠的 条评论
为什么被折叠?



