请优化这段代码“import onnxruntime as ort
import numpy as np
import argparse
import logging
import os
import os.path as osp
import mmengine
import torch
from tqdm import tqdm
import pandas as pd
from torch.utils.data import DataLoader
from datasets import build_dataset
from utils import pr_label_from_logits
from models.utils import process_model_inputs
from torch.utils.data import Subset, ConcatDataset
def extract_img_paths(dataset):
    """Recursively collect every image path from the supported dataset types.

    Handles ConcatDataset, Subset, datasets exposing ``img_paths`` or
    ``data_list``, and finally any indexable dataset whose samples carry an
    ``'img_path'`` key.

    Raises:
        AttributeError: if no strategy can extract paths from ``dataset``.
    """
    # ConcatDataset: concatenate the paths of every sub-dataset.
    if isinstance(dataset, ConcatDataset):
        all_paths = []
        for sub in dataset.datasets:
            all_paths.extend(extract_img_paths(sub))
        return all_paths
    # Subset: keep only the selected indices.
    # Bug fix: a Python list cannot be indexed by a list of indices
    # (``paths[dataset.indices]`` raised TypeError), so pick paths one
    # index at a time instead.
    if isinstance(dataset, Subset):
        parent_paths = extract_img_paths(dataset.dataset)
        return [parent_paths[i] for i in dataset.indices]
    # PersonAttributesDataset-style datasets.
    if hasattr(dataset, 'img_paths'):
        return dataset.img_paths
    # Datasets that store their samples in a ``data_list`` attribute.
    if hasattr(dataset, 'data_list'):
        return [item['img_path'] for item in dataset.data_list]
    # Generic fallback: probe the samples directly.
    try:
        return [dataset[i]['img_path'] for i in range(len(dataset))]
    except Exception:
        raise AttributeError("无法从此数据集类型提取图像路径")
def filter_dataset(dataset, filt_csv):
    """Restrict ``dataset`` to the filenames listed in ``filt_csv``.

    The CSV must have a ``filename`` column; matching is done on the base
    name of each image path. A falsy ``filt_csv`` returns the dataset
    untouched, otherwise a ``Subset`` of the matching samples is returned.
    """
    if not filt_csv:
        return dataset
    # Target base names to keep.
    wanted = set(pd.read_csv(filt_csv)['filename'])
    # Indices of samples whose image file name is in the target set.
    keep = []
    for idx, path in enumerate(extract_img_paths(dataset)):
        if os.path.basename(path) in wanted:
            keep.append(idx)
    return Subset(dataset, keep)
def load_onnx_model(onnx_model_path):
    """Create and return an ONNX Runtime inference session for the model file."""
    return ort.InferenceSession(onnx_model_path)
def parse_args(argv=None):
    """Parse command-line arguments.

    Args:
        argv: optional list of argument strings; ``None`` means ``sys.argv[1:]``
            (the new parameter is backward-compatible and eases testing).
    """
    parser = argparse.ArgumentParser()
    # Raw strings so the Windows backslashes cannot be misread as escapes.
    parser.add_argument('--model_root', default=r'configs\person_attribute')
    parser.add_argument('--ckpt_path',
                        default=r'E:\python\cct_train\onnx_model\simple.onnx')
    parser.add_argument('--imgs_root', default=None)
    parser.add_argument('--file', default=None)  # sample label file
    parser.add_argument('--save_root', default=None, help="default = draw")
    parser.add_argument('--not_draw_results', default=False, action='store_true')
    parser.add_argument('--filt_csv', default=None, help='filter the dataset')
    parser.add_argument('--num_classes', type=int, default=8,
                        help='the num of classes')
    parser.add_argument('--transforms', default='test_v1',
                        help="test_v1 padding_v1 padding_v2")
    # New: make the output artifact names configurable instead of hard-coded.
    parser.add_argument('--result_name', default='test.pkl',
                        help='filename of the dumped prediction results')
    parser.add_argument('--report_name', default='error_analysis_report.txt',
                        help='filename of the error analysis report')
    return parser.parse_args(argv)
def test_onnx(model_session, data_loader):
    """Run the ONNX model over ``data_loader`` and accumulate per-attribute accuracy.

    Args:
        model_session: an ONNX Runtime ``InferenceSession`` for the exported model.
        data_loader: yields dicts with an ``'img_path'`` list plus the tensors
            consumed by ``process_model_inputs`` (image and label).

    Returns:
        Tuple of (image paths, stacked logits, stacked predicted labels,
        per-category accuracy list, error-record dict).
    """
    # Name of the model's (single) input tensor.
    input_name = model_session.get_inputs()[0].name
    # Accuracy bookkeeping. Only 6 categories here: no clothing-colour classes.
    category_correct_counts = [0] * 6
    category_total_counts = [0] * 6
    category_recall_total = [0] * 6
    out_img_paths = []
    out_logits = []
    out_labels = []
    # Error-sample bookkeeping.
    error_records = {
        'attributes': {i: [] for i in range(6)},  # attribute categories 0-5
        'samples': []  # full per-sample error records
    }
    # Wrap data_loader with tqdm for a progress bar.
    for idx, data in enumerate(tqdm(data_loader, desc="推理进度", total=len(data_loader))):
        process_data = process_model_inputs(data, 'cpu')  # to CPU, normalised, channels converted
        img_data = process_data['img'].numpy()  # ONNX Runtime consumes numpy arrays
        # Per-image inference.
        result = model_session.run(None, {input_name: img_data})
        predict_logits = np.concatenate(result, axis=1)  # merge the output heads
        # Convert logits to discrete labels.
        predict_labels = pr_label_from_logits(torch.from_numpy(predict_logits))
        # Accuracy accounting.
        for img_path, predict_label, predict_logit, gt_label in zip(data['img_path'],
                predict_labels, predict_logits, process_data['label']):
            predict_label = [ele for ele in predict_label]
            gt_label = gt_label.cpu().tolist()
            gt_label = [int(gt_label[i]) for i in range(len(gt_label))]
            # Skip samples whose ground truth is entirely -1 (unlabelled).
            if len(gt_label) == 8:
                if gt_label == [-1]*8:
                    print(img_path)
                    print("gt_label不合规, 跳过")
                    continue
            elif len(gt_label) == 32:
                if gt_label == [-1]*32:
                    print(img_path)
                    print("gt_label不合规, 跳过")
                    continue
            elif len(gt_label) == 6:
                if gt_label == [-1]*6:
                    print(img_path)
                    print("gt_label不合规, 跳过")
                    continue
            # Error record for the current sample.
            sample_error = {
                'img_path': img_path,
                'wrong_attributes': []  # indices/values of wrongly predicted attributes
            }
            if len(gt_label) == 8 or len(gt_label)==32:
                # Binary-classification attributes (part 1).
                for i in range(5):
                    if gt_label[i] != -1:
                        # Count valid (labelled) samples.
                        category_recall_total[i] += 1
                        if predict_label[i] != -1:
                            # total and recall should match: predictions can only be 0/1.
                            category_total_counts[i] += 1
                            if predict_label[i] == gt_label[i]:
                                category_correct_counts[i] += 1  # correct predictions
                            else:
                                wrong_attr = {
                                    'attribute_idx': i,
                                    'gt_value': gt_label[i],
                                    'pred_value': predict_label[i],
                                    'confidence': predict_logit[i]  # prediction confidence
                                }
                                sample_error['wrong_attributes'].append(wrong_attr)
                                error_records['attributes'][i].append(img_path)
                # Lower-body clothing type.
                lower_body_index = 5
                if len(gt_label)==8:
                    if gt_label[5: 8] == [-1] * 3:
                        pass
                    else:
                        category_total_counts[5] += 1
                        if predict_label[17: 20] == gt_label[5: 8]:
                            category_correct_counts[5] += 1
                        else:
                            # Record the lower-body mistake.
                            wrong_attr = {
                                'attribute_idx': lower_body_index,
                                'gt_value': gt_label[5:8],
                                'pred_value': predict_label[17:20],
                                'confidence': max(predict_logit[17:20])  # highest confidence
                            }
                            sample_error['wrong_attributes'].append(wrong_attr)
                            error_records['attributes'][lower_body_index].append(img_path)
                else:
                    if gt_label[17: 20] == [-1] * 3:
                        pass
                    else:
                        category_total_counts[5] += 1
                        if predict_label[17: 20] == gt_label[17: 20]:
                            category_correct_counts[5] += 1
                        else:
                            # Record the lower-body mistake.
                            wrong_attr = {
                                'attribute_idx': lower_body_index,
                                'gt_value': gt_label[17:20],
                                'pred_value': predict_label[17:20],
                                'confidence': max(predict_logit[17:20])  # highest confidence
                            }
                            sample_error['wrong_attributes'].append(wrong_attr)
                            error_records['attributes'][lower_body_index].append(img_path)
            elif len(gt_label)==6:
                # Binary-classification attributes (offset by 2 in the prediction).
                for i in range(3):
                    if gt_label[i] != -1:
                        # Count valid (labelled) samples.
                        category_recall_total[i+2] += 1
                        if predict_label[i+2] != -1:
                            # total and recall should match: predictions can only be 0/1.
                            category_total_counts[i+2] += 1
                            if predict_label[i+2] == gt_label[i]:
                                category_correct_counts[i+2] += 1  # correct predictions
                            else:
                                wrong_attr = {
                                    'attribute_idx': i+2,
                                    'gt_value': gt_label[i],
                                    'pred_value': predict_label[i+2],
                                    'confidence': predict_logit[i+2]  # prediction confidence
                                }
                                sample_error['wrong_attributes'].append(wrong_attr)
                                error_records['attributes'][i+2].append(img_path)
                # Lower-body clothing type.
                lower_body_index = 5
                if gt_label[3: 6] == [-1] * 3:
                    pass
                else:
                    category_total_counts[5] += 1
                    if predict_label[17: 20] == gt_label[3: 6]:
                        category_correct_counts[5] += 1
                    else:
                        # Record the lower-body mistake.
                        wrong_attr = {
                            'attribute_idx': lower_body_index,
                            'gt_value': gt_label[3: 6],
                            'pred_value': predict_label[17:20],
                            'confidence': max(predict_logit[17:20])  # highest confidence
                        }
                        sample_error['wrong_attributes'].append(wrong_attr)
                        error_records['attributes'][lower_body_index].append(img_path)
            # If the sample has any wrong attribute, add it to the summary.
            if sample_error['wrong_attributes']:
                error_records['samples'].append(sample_error)
        out_img_paths.extend(data['img_path'])
        out_logits.append(predict_logits)
        out_labels.append(predict_labels)
    category_accuracies = []
    for correct_count, total_count in zip(category_correct_counts, category_total_counts):
        accuracy = correct_count / total_count if total_count > 0 else 0.0
        category_accuracies.append(accuracy)
    # Print raw correct/total counts per category for inspection.
    for i in range(6):
        print(category_correct_counts[i], category_total_counts[i])
    out_logits = np.concatenate(out_logits, axis=0)
    out_labels = np.concatenate(out_labels, axis=0)
    return out_img_paths, out_logits, out_labels, category_accuracies, error_records
def main(cfg, ckpt_path, test_imgs_root, file, save_root, filt_csv, num_classes,
         transforms, result_name='test.pkl',
         report_name='error_analysis_report.txt'):
    """Evaluate an exported ONNX person-attribute model on a dataset.

    Args:
        cfg: config mapping; ``cfg['val_data_loader']`` is (re)written here.
        ckpt_path: path of the .onnx checkpoint; a missing file aborts the run.
        test_imgs_root: root directory of the test images.
        file: label file of the samples.
        save_root: if set, the prediction dump and error report are written here.
        filt_csv: optional CSV whose ``filename`` column restricts the dataset.
        num_classes: number of attribute classes.
        transforms: transform preset name — 'test_v1' (resize, presumably to
            (128, 256) per the original note — confirm in the dataset code),
            'padding_v1' or 'padding_v2'.
        result_name: filename for the dumped predictions (new, configurable).
        report_name: filename for the error analysis report (new, configurable).
    """
    if not os.path.exists(ckpt_path):
        logging.warning('checkpoint does not exist: %s', ckpt_path)
        return
    # Build the ONNX Runtime session.
    model_session = load_onnx_model(ckpt_path)
    # Build the test dataset.
    test_datasets = [dict(
        type='PersonAttributesDataset',
        imgs_root=test_imgs_root,
        labels_path=file,
        num_classes=num_classes,
        transforms=transforms,
    )]
    test_dataset = build_dataset(test_datasets)
    # num_workers is the number of subprocesses used for data loading.
    cfg['val_data_loader'] = dict(batch_size=1, num_workers=4)
    if filt_csv:
        # Keep only the samples listed in the filter CSV.
        test_dataset = filter_dataset(test_dataset, filt_csv)
    test_data_loader = DataLoader(test_dataset, **cfg['val_data_loader'])
    # Evaluate the ONNX model.
    img_paths, predict_logits, predict_labels, acc, error_records = test_onnx(
        model_session, test_data_loader)
    print(acc)
    if save_root:
        os.makedirs(save_root, exist_ok=True)
        generate_error_report(error_records, save_root, output_name=report_name)
        mmengine.dump([img_paths, predict_logits, predict_labels],
                      osp.join(save_root, result_name))


def generate_error_report(error_records, save_root,
                          output_name="error_analysis_report.txt"):
    """Write a detailed error-analysis report to ``save_root/output_name``."""
    output_path = osp.join(save_root, output_name)
    # utf-8 so the Chinese section headers survive on any platform/locale.
    with open(output_path, 'w', encoding='utf-8') as f:
        # Error count per attribute.
        f.write("===== 属性错误统计 =====\n")
        for attr_idx, img_paths in error_records['attributes'].items():
            f.write(f"属性 {attr_idx} 错误样本数: {len(img_paths)}\n")
        # Detailed list of wrong samples.
        f.write("\n===== 详细错误样本列表 =====\n")
        for sample in error_records['samples']:
            f.write(f"\n图像路径: {sample['img_path']}\n")
            # Details of each wrong attribute.
            f.write("错误属性:\n")
            for wrong_attr in sample['wrong_attributes']:
                f.write(f" 属性 {wrong_attr['attribute_idx']}: ")
                f.write(f"真实值={wrong_attr['gt_value']}, ")
                f.write(f"预测值={wrong_attr['pred_value']}, ")
                f.write(f"置信度={wrong_attr['confidence']:.4f}\n")
        # Summary across all attributes.
        f.write("\n===== 错误样本汇总 =====\n")
        all_error_paths = set()
        for attr_list in error_records['attributes'].values():
            all_error_paths.update(attr_list)
        f.write(f"总错误样本数: {len(all_error_paths)}\n")
        f.write("所有错误样本路径:\n")
        for path in sorted(all_error_paths):
            f.write(f"{path}\n")
    print(f"错误分析报告已保存至: {output_path}")


if __name__ == '__main__':
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    logger.addHandler(logging.StreamHandler())
    args = parse_args()
    cfg = mmengine.load(osp.join(args.model_root, 'config.yml'))
    logger.info(f'predict args {args}')
    # getattr keeps this call working even with an older parse_args that has
    # no --result_name / --report_name options.
    main(cfg, args.ckpt_path, args.imgs_root, args.file, args.save_root,
         args.filt_csv, args.num_classes, args.transforms,
         result_name=getattr(args, 'result_name', 'test.pkl'),
         report_name=getattr(args, 'report_name', 'error_analysis_report.txt'))
最新发布