OptimizerHook
This code defines a class named OptimizerHook, a custom hook around the optimizer step. The hook contains operations for gradient clipping and for detecting anomalous parameters, which is useful both for improving training stability and for debugging models during deep-learning training.
Class definition
The OptimizerHook class inherits from Hook and implements several custom operations related to the optimizer.
Parameters
grad_clip: a dict that configures gradient clipping. Defaults to None.
detect_anomalous_params: a bool intended only for debugging; enabling it slows training down and detects anomalous parameters that are not included in the computational graph rooted at the loss. Defaults to False. A usage sketch follows right after this list.
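The keys of grad_clip are passed straight through to torch.nn.utils.clip_grad_norm_, so in practice it holds arguments such as max_norm and norm_type. Below is a minimal usage sketch; the concrete values and the config-dict form are illustrative assumptions about a typical MMCV-style setup, not defaults of the hook.

# Direct instantiation (assumes the OptimizerHook class shown below is importable).
from mmcv.runner import OptimizerHook

hook = OptimizerHook(
    grad_clip=dict(max_norm=35, norm_type=2),  # forwarded to clip_grad_norm_
    detect_anomalous_params=False)

# Equivalent config-dict form commonly seen in MMCV-style config files:
optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))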
import logging
from typing import Optional

from torch import Tensor
from torch.nn.utils import clip_grad

from mmcv.runner import HOOKS, Hook


@HOOKS.register_module()
class OptimizerHook(Hook):
    """A hook contains custom operations for the optimizer.

    Args:
        grad_clip (dict, optional): A config dict to control the clip_grad.
            Default: None.
        detect_anomalous_params (bool): This option is only used for
            debugging which will slow down the training speed.
            Detect anomalous parameters that are not included in
            the computational graph with `loss` as the root.
            There are two cases
                - Parameters were not used during forward pass.
                - Parameters were not used to produce loss.
            Default: False.
    """

    def __init__(self,
                 grad_clip: Optional[dict] = None,
                 detect_anomalous_params: bool = False):
        self.grad_clip = grad_clip
        self.detect_anomalous_params = detect_anomalous_params

    def clip_grads(self, params):
        # Keep only parameters that require gradients and already have one.
        params = list(
            filter(lambda p: p.requires_grad and p.grad is not None, params))
        if len(params) > 0:
            return clip_grad.clip_grad_norm_(params, **self.grad_clip)

    def after_train_iter(self, runner):
        runner.optimizer.zero_grad()
        if self.detect_anomalous_params:
            self.detect_anomalous_parameters(runner.outputs['loss'], runner)
        runner.outputs['loss'].backward()

        if self.grad_clip is not None:
            grad_norm = self.clip_grads(runner.model.parameters())
            if grad_norm is not None:
                # Add grad norm to the logger
                runner.log_buffer.update({'grad_norm': float(grad_norm)},
                                         runner.outputs['num_samples'])
        runner.optimizer.step()

    def detect_anomalous_parameters(self, loss: Tensor, runner) -> None:
        logger = runner.logger
        parameters_in_graph = set()
        visited = set()

        def traverse(grad_fn):
            # Depth-first walk over the autograd graph rooted at `loss`.
            if grad_fn is None:
                return
            if grad_fn not in visited:
                visited.add(grad_fn)
                # AccumulateGrad nodes expose their leaf parameter via `.variable`.
                if hasattr(grad_fn, 'variable'):
                    parameters_in_graph.add(grad_fn.variable)
                parents = grad_fn.next_functions
                if parents is not None:
                    for parent in parents:
                        grad_fn = parent[0]
                        traverse(grad_fn)

        traverse(loss.grad_fn)
        for n, p in runner.model.named_parameters():
            if p not in parameters_in_graph and p.requires_grad:
                logger.log(
                    level=logging.ERROR,
                    msg=f'{n} with shape {p.size()} is not '
                    f'in the computational graph \n')
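To make the grad_fn traversal concrete, here is a standalone sketch (not part of MMCV; the toy model and all names in it are hypothetical) that applies the same traversal to a tiny model in which one layer never contributes to the loss:

import torch
import torch.nn as nn


class ToyModel(nn.Module):
    # Hypothetical model: `unused` never feeds into the loss.
    def __init__(self):
        super().__init__()
        self.used = nn.Linear(4, 1)
        self.unused = nn.Linear(4, 1)

    def forward(self, x):
        return self.used(x)


model = ToyModel()
loss = model(torch.randn(2, 4)).sum()

# Same idea as detect_anomalous_parameters: walk the autograd graph rooted at
# loss.grad_fn and collect every leaf parameter it reaches.
parameters_in_graph, visited = set(), set()


def traverse(grad_fn):
    if grad_fn is None or grad_fn in visited:
        return
    visited.add(grad_fn)
    if hasattr(grad_fn, 'variable'):  # AccumulateGrad node of a leaf tensor
        parameters_in_graph.add(grad_fn.variable)
    for parent, _ in grad_fn.next_functions:
        traverse(parent)


traverse(loss.grad_fn)
for name, p in model.named_parameters():
    if p.requires_grad and p not in parameters_in_graph:
        print(f'{name} with shape {tuple(p.size())} is not in the computational graph')
# Expected to report unused.weight and unused.bias.

Running this reports unused.weight and unused.bias, which is the same kind of message the hook logs at ERROR level during training.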
Main logic
Initialization
The constructor accepts the two optional arguments grad_clip and detect_anomalous_params and stores them as instance attributes.
The clip_grads method
It filters the model parameters down to those that require gradients and currently have a gradient. If any such parameters remain, it calls clip_grad.clip_grad_norm_ on them with the options from grad_clip and returns the resulting total gradient norm.
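As a standalone illustration of that path (the toy model and the max_norm value here are assumed for the example, not taken from the hook), the effect of clip_grads boils down to:

import torch
import torch.nn as nn
from torch.nn.utils import clip_grad

model = nn.Linear(4, 2)
loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()

grad_clip = dict(max_norm=1.0, norm_type=2)
params = [p for p in model.parameters()
          if p.requires_grad and p.grad is not None]
if params:
    # clip_grad_norm_ rescales gradients in place and returns the total norm
    # measured before clipping.
    grad_norm = clip_grad.clip_grad_norm_(params, **grad_clip)
    print(f'grad_norm before clipping: {float(grad_norm):.4f}')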