def train_one_epoch(model, optimizer, data_loader, device, epoch):
model.train()
loss_function = torch.nn.CrossEntropyLoss(label_smoothing=0.1)
accu_loss = torch.zeros(1).to(device) # 累计损失 tensor([0.])
accu_num = torch.zeros(1).to(device) # 累计预测正确的样本数
optimizer.zero_grad()
sample_num = 0
data_loader = tqdm(data_loader, file=sys.stdout) # 循环遍历数据加载器:通过tqdm包装的data_loader进行循环,以便在训练过程中显示进度条。对于每批数据
for step, data in enumerate(data_loader):
images, labels = data
sample_num += images.shape[0]
'''images.shape[0]用于获取当前批次(batch)中图像的数量。
在PyTorch中,一个张量(Tensor)的.shape属性返回一个元组,
描述了张量在每个维度上的大小。对于图像数据,这通常遵循格式
(批次大小, 通道数, 高度, 宽度):
'''
pred = model(images.to(device))
pred_classes = torch.max(pred, dim=1)[1]
accu_num += torch.eq(pred_classes, labels.to(device)).sum()
loss = loss_function(pred, labels.to(device))
loss.backward()
accu_loss += loss.detach()
data_loader.desc = "[train epoch {}] loss: {:.3f}, acc: {:.3f}".format(epoch,
accu_loss.item() / (step + 1),
accu_num.item() / sample_num)
if not torch.isfinite(loss): # 检查损失的有效性:如果在某一步损失变成非有限值(比如NaN或无穷大),训练将停止,并显示警告信息。
print('WARNING: non-finite loss, ending training ', loss)
sys.exit(1)
optimizer.step()
optimizer.zero_grad()
return accu_loss.item() / (step + 1), accu_num.item() / sample_num
DeepLearning MobileViT训练过程学习
最新推荐文章于 2025-09-07 21:49:50 发布
本文详细解释了如何在PyTorch中实现一个训练循环,包括使用CrossEntropyLoss、优化器更新、计算损失和精度,以及处理非有限损失的情况。
部署运行你感兴趣的模型镜像
您可能感兴趣的与本文相关的镜像
PyTorch 2.5
PyTorch
Cuda
PyTorch 是一个开源的 Python 机器学习库,基于 Torch 库,底层由 C++ 实现,应用于人工智能领域,如计算机视觉和自然语言处理
2821

被折叠的 条评论
为什么被折叠?



