目录
【思考】什么是范数,什么是L2范数,这里为什么要打印梯度范数?
6.2 梯度爆炸实验
造成简单循环网络较难建模长程依赖问题的原因有两个:梯度爆炸和梯度消失。一般来讲,循环网络的梯度爆炸问题比较容易解决,一般通过权重衰减或梯度截断可以较好地来避免;对于梯度消失问题,更加有效的方式是改变模型,比如通过长短期记忆网络LSTM来进行缓解。
本节将首先进行复现简单循环网络中的梯度爆炸问题,然后尝试使用梯度截断的方式进行解决。这里采用长度为20的数据集进行实验,训练过程中将进行输出,
,
的梯度向量的范数,以此来衡量梯度的变化情况。
6.2.1 梯度打印函数
使用custom_print_log实现了在训练过程中打印梯度的功能,custom_print_log需要接收runner的实例,并通过model.named_parameters()获取该模型中的参数名和参数值. 这里我们分别定义W_list, U_list和b_list,用于分别存储训练过程中参数W,U和b的梯度范数。
import torch
W_list = []
U_list = []
b_list = []
# 计算梯度范数
def custom_print_log(runner):
model = runner.model
W_grad_l2, U_grad_l2, b_grad_l2 = 0, 0, 0
for name, param in model.named_parameters():
if name == "rnn_model.W":
W_grad_l2 = torch.norm(param.grad, p=2).numpy()
if name == "rnn_model.U":
U_grad_l2 = torch.norm(param.grad, p=2).numpy()
if name == "rnn_model.b":
b_grad_l2 = torch.norm(param.grad, p=2).numpy()
print(f"[Training] W_grad_l2: {W_grad_l2:.5f}, U_grad_l2: {U_grad_l2:.5f}, b_grad_l2: {b_grad_l2:.5f} ")
W_list.append(W_grad_l2)
U_list.append(U_grad_l2)
b_list.append(b_grad_l2)
【思考】什么是范数,什么是L2范数,这里为什么要打印梯度范数?
范数:
范数是一种强化了的距离概念。我们知道距离的定义是:只要满足非负、自反、三角不等式就可以称之为距离。而范数在定义上比距离多了一条数乘的运算法则。有时候为了便于理解,我们可以把范数当作距离来理解。
L2范数:
我们用的最多的度量距离“欧氏距离”就是一种L2范数,它的定义如下: