torch分类

转载分类加预测:::https://github.com/heXiangpeng/textcnn

bert文本分类:::https://blog.youkuaiyun.com/ganxiwu9686/article/details/85061759?utm_medium=distribute.pc_relevant.none-task-blog-blogcommendfrommachinelearnpai2-2.edu_weight&depth_1-utm_source=distribute.pc_relevant.none-task-blog-blogcommendfrommachinelearnpai2-2.edu_weight

### PyTorch分类模型评估指标 #### 准确率 (Accuracy) 准确率是指所有预测正确的样本占总样本的比例。对于多类别分类问题,准确率是一个常用的初步衡量标准。 ```python correct = 0 total = 0 with torch.no_grad(): for data in test_loader: images, labels = data outputs = model(images) _, predicted = torch.max(outputs.data, 1) total += labels.size(0) correct += (predicted == labels).sum().item() accuracy = 100 * correct / total print(f'Accuracy of the network on the test images: {accuracy}%') ``` 此代码片段展示了如何计算测试集上的准确率[^1]。 #### 精度 (Precision), 召回率 (Recall), 和 F1分数 (F1 Score) 这些指标通常用于二元分类问题中,但在多标签或多类别的场景下同样适用。为了处理GPU上的数据并累积批次的结果,需要特别注意`target.data.cpu()`的操作来获取张量并在CPU上操作它。 ```python from sklearn.metrics import precision_score, recall_score, f1_score y_true = [] y_pred = [] with torch.no_grad(): for inputs, targets in dataloader: outputs = model(inputs.cuda()) preds = torch.argmax(outputs, dim=1).cpu() y_true.extend(targets.tolist()) y_pred.extend(preds.tolist()) precision = precision_score(y_true, y_pred, average='weighted') recall = recall_score(y_true, y_pred, average='weighted') f1 = f1_score(y_true, y_pred, average='weighted') print(f'Precision: {precision}, Recall: {recall}, F1-Score: {f1}') ``` 这段代码说明了如何收集每一批次的真实标签和预测标签,并最终计算整体的精度、召回率以及F1得分[^2]。 #### 混淆矩阵 (Confusion Matrix) 混淆矩阵提供了一种更直观的方式来看待分类器的表现,特别是当涉及到多个类别时。通过这个矩阵可以清楚地看出哪些类别被误分得最多。 ```python import seaborn as sns from sklearn.metrics import confusion_matrix import matplotlib.pyplot as plt cm = confusion_matrix(y_true, y_pred) plt.figure(figsize=(10,7)) sns.heatmap(cm, annot=True, fmt='d') plt.xlabel('Predicted') plt.ylabel('Truth') plt.show() ``` 上述代码利用Seaborn库绘制了一个热图形式的混淆矩阵,帮助可视化不同类别之间的错误分布情况。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值