pytorch项目:测试代码中correct += (y_pred == y).sum().item()

这篇博客探讨了Python中列表和NumPy数组的区别,并通过示例展示了在比较和操作过程中可能出现的问题。在Case1中,尝试将列表直接进行相等比较导致错误,而在Case2和Case3中,使用NumPy数组和PyTorch张量进行比较则能得到预期结果。这强调了在处理数值计算时使用数组库的重要性。
部署运行你感兴趣的模型镜像

#测试代码中correct += (y_pred == y).sum().item()
import torch
import numpy as np
correct = 0

#case1:当y_pre和y为列表时
y = [1,5,7,8]
y_pred = [2,4,7,9]

correct += (y_pred == y).sum().item() #报错:AttributeError: ‘bool’ object has no attribute ‘sum’

#case2:当y_pre和y为np
y = np.array([3,6,8])
y_pred = np.array([5,6,9])
correct += (y_pred == y).sum().item()
print(correct) # 1

#case3 当为tensor时
y= torch.tensor([5,8,3])
y_pred = torch.tensor([5,8,4])
correct=0
correct += (y_pred == y).sum().item()
print(correct) #2

您可能感兴趣的与本文相关的镜像

PyTorch 2.5

PyTorch 2.5

PyTorch
Cuda

PyTorch 是一个开源的 Python 机器学习库,基于 Torch 库,底层由 C++ 实现,应用于人工智能领域,如计算机视觉和自然语言处理

criterion = nn.CrossEntropyLoss(ignore_index=-1) 这个应该写在哪里? 以下是现在训练网络的代码: def train_model(model, train_loader, test_loader, epochs=50): train_losses, test_losses = [], [] pres_accs, wifi_accs, zb_accs = [], [], [] test_pres_accs, test_wifi_accs, test_zb_accs = [], [], [] for epoch in range(epochs): model.train() running_loss = 0.0 correct_pres = 0 correct_wifi = 0 correct_zb = 0 total = 0 for inputs, pres_labels, wifi_labels, zb_labels in train_loader: inputs = inputs.to(device) pres_labels = pres_labels.to(device) wifi_labels = wifi_labels.to(device) zb_labels = zb_labels.to(device) optimizer.zero_grad() outputs = model(inputs) # 损失计算 loss_pres = criterion_pres(outputs['presence'], pres_labels) loss_wifi = criterion_ch(outputs['wifi_channel'], wifi_labels) loss_zb = criterion_ch(outputs['zigbee_channel'], zb_labels) total_loss = loss_pres + loss_wifi + loss_zb total_loss.backward() optimizer.step() running_loss += total_loss.item() # 准确率:Presence(全匹配才算对) pred_pres = (torch.sigmoid(outputs['presence']) > 0.5).int() true_pres = pres_labels.int() correct_pres += (pred_pres == true_pres).all(dim=1).sum().item() # WiFi 信道准确率(仅当WiFi存在时统计) mask_wifi = pres_labels[:, 0] > 0.5 # 只看WiFi存在的样本 if mask_wifi.sum() > 0: _, pred_wifi = outputs['wifi_channel'][mask_wifi].max(1) correct_wifi += (pred_wifi == wifi_labels[mask_wifi]).sum().item() # ZigBee 信道准确率 mask_zb = pres_labels[:, 2] > 0.5 if mask_zb.sum() > 0: _, pred_zb = outputs['zigbee_channel'][mask_zb].max(1) correct_zb += (pred_zb == zb_labels[mask_zb]).sum().item() total += inputs.size(0) avg_loss = running_loss / len(train_loader) train_losses.append(avg_loss) pres_accs.append(correct_pres / total) wifi_accs.append(correct_wifi / max(mask_wifi.sum().item(), 1)) zb_accs.append(correct_zb / max(mask_zb.sum().item(), 1)) # 测试阶段 model.eval() test_loss = 0.0 test_correct_pres = 0 test_correct_wifi = 0 test_correct_zb = 0 test_total = 0 with torch.no_grad(): for inputs, pres_labels, wifi_labels, zb_labels in test_loader: inputs = inputs.to(device) pres_labels = pres_labels.to(device) wifi_labels = wifi_labels.to(device) zb_labels = zb_labels.to(device) outputs = model(inputs) loss_pres = criterion_pres(outputs['presence'], pres_labels) loss_wifi = criterion_ch(outputs['wifi_channel'], wifi_labels) loss_zb = criterion_ch(outputs['zigbee_channel'], zb_labels) test_loss += (loss_pres + loss_wifi + loss_zb).item() # Presence Accuracy pred_pres = (torch.sigmoid(outputs['presence']) > 0.5).int() true_pres = pres_labels.int() test_correct_pres += (pred_pres == true_pres).all(dim=1).sum().item() # WiFi Channel Acc mask_wifi = pres_labels[:, 0] > 0.5 if mask_wifi.sum() > 0: _, pred_wifi = outputs['wifi_channel'][mask_wifi].max(1) test_correct_wifi += (pred_wifi == wifi_labels[mask_wifi]).sum().item() # ZigBee Channel Acc mask_zb = pres_labels[:, 2] > 0.5 if mask_zb.sum() > 0: _, pred_zb = outputs['zigbee_channel'][mask_zb].max(1) test_correct_zb += (pred_zb == zb_labels[mask_zb]).sum().item() test_total += inputs.size(0) test_losses.append(test_loss / len(test_loader)) test_pres_accs.append(test_correct_pres / test_total) test_wifi_accs.append(test_correct_wifi / max(mask_wifi.sum().item(), 1)) test_zb_accs.append(test_correct_zb / max(mask_zb.sum().item(), 1)) if (epoch + 1) % 10 == 0: print(f"Epoch [{epoch+1}/{epochs}] " f"Loss: {avg_loss:.4f} " f"Pres Acc: {pres_accs[-1]:.4f}, Test Pres Acc: {test_pres_accs[-1]:.4f} " f"Wifi Ch Acc: {wifi_accs[-1]:.4f}, ZB Ch Acc: {zb_accs[-1]:.4f}") return train_losses, test_losses, pres_accs, test_pres_accs, wifi_accs, zb_accs, test_wifi_accs, test_zb_accs
10-11
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值