深度强化学习调试与优化指南
1. 调试技巧
1.1 梯度检查代码示例
以下是一段用于检查梯度范数的代码:
try:
grad_norm = param.grad.norm()
assert min_norm < grad_norm < max_norm, f'Gradient norm for {p_name} is {grad_norm:g}, fails the extreme value check {min_norm} < grad_norm < {max_norm}. Loss: {loss:g}. Check your network and loss computation.'
except Exception as e:
logger.warning(e)
logger.info(f'Gradient norms passed value check.')
logger.debug('Passed network parameter update check.')
# store grad norms for debugging
net.store_grad_norms()
return loss
return check_fn
此代码通过 try-except 块检查梯度范数是否在指定范围内,若不在则记录警告信息。
1.2 单个损失检查
若损失函数由多个单个损失组成,需分别检查每个损失,确保其产生正确的训练行为和网络更新。具
超级会员免费看
订阅专栏 解锁全文
996

被折叠的 条评论
为什么被折叠?



