搭建了网络以及训练代码之后,程序在第一个epoch是没问题的,但是一进入第二个epoch就报错:RuntimeError: Trying to backward through the graph a second time。代码如下:
for epoch in range(config['epochs']):
# 训练阶段
model.train()
epoch_train_loss = 0
progress_bar = tqdm(train_loader, desc=f'Epoch {epoch + 1} [Train]')
for batch in progress_bar:
spectra = batch['spectrum'].to(device)
num_src = batch['num_sources'].to(device)
angles = batch['angles'].to(device)
optimizer.zero_grad()
pred_num, pred_angles = model(spectra)
loss = criterion(pred_num, pred_angles, num_src, angles)
loss.backward()
optimizer.step()
epoch_train_loss += loss.item()
progress_bar.set_postfix({'loss': loss.item()})
尝试按照提示修改backward部分,加入retain_grad=True,即:
loss.backward(retain_grad=True)
并没有解决。
解决方案: 将模型输出结果使用detach()函数将其从计算图中分离,使其不参与反向传播。如下:
pred_num, pred_angles = model(spectra)
loss = criterion(pred_num.detach(), pred_angles.detach(), num_src, angles)
这样就可以正常训练啦。