While training a ResNet-18 classification network, the following error occurred: RuntimeError: CUDA error: device-side assert triggered
Original code:
for batch_index, (images, labels) in enumerate(train_dataloader):
    if epoch <= args.warm:
        warmup_scheduler.step()
    images = images.cuda()
    labels = labels.cuda()
    optimizer.zero_grad()
    predicts = net(images)
    loss = 0
    for i, pi in enumerate(predicts):
        t = torch.full_like(pi, 0, device=device)
        if labels[i].long() > 0:
            print('labels:', labels[i])
            t[labels[i].long() - 1] = 1
            print('t', t)
        loss_curr = criterion(pi, t)  # BCE
        loss += loss_curr
        train_total_loss += loss_curr.item()
    loss.backward()
    optimizer.step()
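One cause worth ruling out for a device-side assert in a loop like this is an out-of-range index in the `t[labels[i].long() - 1] = 1` assignment: if a label is larger than the number of network outputs, the write falls outside `t`. The sketch below is only an illustration under that assumption, not a confirmed diagnosis. `build_one_hot_target` is a hypothetical helper that is not part of the original code. It builds the same 1-based one-hot target but checks the index on the host first, so a bad label raises a readable Python exception instead of an opaque CUDA assert.

```python
import torch


def build_one_hot_target(pred: torch.Tensor, label: torch.Tensor) -> torch.Tensor:
    """Build a 1-based one-hot target for a single 1-D prediction, with a bounds check.

    Hypothetical helper for illustration; label 0 (or negative) yields an all-zero target,
    mirroring the `if labels[i].long() > 0` branch in the loop above.
    """
    target = torch.zeros_like(pred)
    # .item() copies the value to the host, so the comparison below runs on the CPU.
    idx = int(label.item())
    if idx > 0:
        if idx > pred.numel():
            raise ValueError(
                f"label {idx} needs index {idx - 1}, "
                f"but the prediction has only {pred.numel()} elements"
            )
        target[idx - 1] = 1
    return target


if __name__ == "__main__":
    # CPU tensors are used here so the sketch runs without a GPU.
    pred = torch.zeros(10)
    label = torch.tensor([7], dtype=torch.int32)
    print(build_one_hot_target(pred, label))
```

Inside the training loop, `t = build_one_hot_target(pi, labels[i])` could stand in for the `torch.full_like` call and the `if` block. Because CUDA kernels run asynchronously, the Python line reported in the traceback is not always the one that failed; setting the environment variable `CUDA_LAUNCH_BLOCKING=1` or temporarily running the same batch on the CPU usually gives a more precise error location.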
After the error occurred, I printed the values at the failing location:
labels: tensor([7], device='cuda:0', dtype=torch.int32)
t tensor([0.