01 | Problem Description
While running an LSTM experiment recently, I hit an unexpected error:
Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad().
Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.
Going by common experience, since the message says "backward through the graph a second time", the usual remedy is to add retain_graph=True to loss.backward() so that the computation graph is kept after the backward pass.
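For reference, here is a minimal sketch (a made-up toy example, not part of the experiment above) of the situation retain_graph=True is actually designed for, namely calling backward() more than once on the same graph:

import torch

x = torch.randn(3, requires_grad=True)
y = (x * 2).sum()

y.backward(retain_graph=True)  # keep the saved tensors so the graph can be reused
y.backward()                   # without retain_graph above, this second call raises
                               # "Trying to backward through the graph a second time"
print(x.grad)                  # gradients from both calls accumulate in x.grad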
Re-running the script then produced a new error:
Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [128, 6]] is at version 2; expected version 1 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).
This time the complaint is about an in-place operation, suggesting that some tensor needed for the backward pass was modified when it should not have been.
The code that triggered these errors is as follows:
for epoch in range(EPOCH):
    batch_num = 0
    h0 = torch.randn((num_layers, batch_size, hidden_dim), requires_grad=False)
    c0 = torch.randn((num_layers, batch_size, hidden_dim), requires_grad=False)
    for batch in data_loader:
        optimizer.zero_grad()
        output, (h0, c0) = lstm(batch[0], h0, c0)
        loss = loss_fn(output.reshape(-1, embedding_dim), batch[1])
        loss.backward(retain_graph=True)
        optimizer.step()
        batch_num += 1
        if batch_num % min(10, batch_count // 2) == 0:
            print(f'epoch {epoch} : batch {batch_num} : {loss.item():.4f}')
02 | Error Analysis
I searched related material online for a full day without finding a solution, but the problem gradually narrowed down to how the LSTM was being used.
It finally turned out that the culprit is the recursive reuse of the LSTM's hidden state and cell state.
The problem occurs at the seventh line of the code above, output, (h0, c0) = lstm(batch[0], h0, c0):
① To let the hidden state remember what was learned from the previous batch, the h0 and c0 produced by one batch are fed in as the initial state for the next batch.
② As a consequence, when the second batch is processed, the h0 and c0 entering its computation graph were produced by the first batch (and are therefore functions of the parameters in the first batch's graph). To compute gradients, backpropagation has to follow h0 and c0 back into the first batch's graph, but by default that graph has already been freed after the first call to backward() (see the minimal sketch after this list).
③ Even if retain_graph=True keeps the earlier graph alive, the backward pass through it now fails for another reason: optimizer.step() has since updated tensors in place, so a value the old graph saved for its backward computation is no longer at the version it recorded, and autograd reports this as the "in-place operation" error.
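The chain of events is easy to reproduce with a stripped-down sketch (made-up dimensions, and nn.LSTM called directly rather than the model from the experiment):

import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=4, hidden_size=6, num_layers=1, batch_first=True)
h0 = torch.zeros(1, 2, 6)                # (num_layers, batch, hidden)
c0 = torch.zeros(1, 2, 6)

for step in range(2):
    x = torch.randn(2, 5, 4)             # fake batch: (batch, seq_len, input)
    out, (h0, c0) = lstm(x, (h0, c0))    # the returned h0/c0 carry a grad_fn,
                                         # i.e. they belong to this step's graph
    loss = out.sum()
    loss.backward()                      # step 0: fine; step 1: backward follows
                                         # h0/c0 back into step 0's graph, whose
                                         # saved tensors were already freed -> error

Detaching h0 and c0 before the forward pass keeps each iteration's backward inside its own graph, which is exactly the fix below.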
The fix is therefore: use detach() to cut h0 and c0 off from the previous batch's computation.
optimizer.zero_grad()
h0 = h0.detach()  # break the link to the previous batch's graph
c0 = c0.detach()
output, (h0, c0) = lstm(batch[0], h0, c0)
It runs successfully! (Once h0 and c0 are detached, the graph no longer spans batches, so retain_graph=True can also be dropped from loss.backward().)
Lesson learned: when a variable is assigned recursively across batches in an RNN, apply detach() to it so that one batch's graph cannot leak into the next.
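A common way to package this (a generic helper with hypothetical names, not code from the experiment above) is to detach whatever the model returns as its state, whether a single tensor (RNN/GRU) or an (h, c) tuple (LSTM):

import torch

def detach_state(state):
    # Return the recurrent state cut off from the previous batch's graph.
    if isinstance(state, torch.Tensor):
        return state.detach()
    return tuple(detach_state(s) for s in state)

# Usage inside the training loop, before the forward pass:
# h0, c0 = detach_state((h0, c0))
# output, (h0, c0) = lstm(batch[0], h0, c0)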