Please modify the code below for me: it works fine with R-CNN, and I want to adapt it for SSD. The trained model is an SSD model with a vgg16 backbone, and the images and annotations are in a YOLO-format dataset. I need metrics such as mAP50, mAP50-95, precision P, recall R, parameter count (Parameters), compute cost (GFLOPs), and detection speed (FPS). Please generate code that can be used directly.

import os
import sys
import json
import torch
import numpy as np
import pandas as pd
import time
import argparse
import logging
from pathlib import Path
from PIL import Image
from tqdm import tqdm
import glob
import shutil
import warnings
from collections import defaultdict
# Suppress noisy warnings
warnings.filterwarnings('ignore')
# ==============================================
# 1. Path fix - locate the SSD project automatically
# ==============================================
def find_ssd_project():
    """Locate the SSD project directory."""
    current_dir = os.path.dirname(os.path.abspath(__file__))
    # Candidate project paths
    possible_project_dirs = [
        # Search the current directory and its parents
        current_dir,
        os.path.dirname(current_dir),
        os.path.dirname(os.path.dirname(current_dir)),
        # Directories with the expected name
        os.path.join(os.path.dirname(current_dir), 'ssd-pytorch-master'),
        os.path.join(current_dir, 'ssd-pytorch-master'),
        # Your specific path
        'C:/A/ZY/0_YOLO资源包/1_环境配置/ssd-pytorch-master',
        '../ssd-pytorch-master',
        '../../ssd-pytorch-master',
    ]
    # Check each candidate directory
    for project_dir in possible_project_dirs:
        if os.path.exists(project_dir):
            # Check that it contains the expected files
            ssd_file = os.path.join(project_dir, 'ssd.py')
            nets_dir = os.path.join(project_dir, 'nets')
            # SSD projects usually ship ssd.py or nets/ssd.py
            if os.path.exists(nets_dir) or os.path.exists(ssd_file):
                print(f"Found SSD project directory: {project_dir}")
                return project_dir
    return None
# Locate the project directory and add it to the import path
project_dir = find_ssd_project()
if project_dir:
    # Add it to the Python path
    if project_dir not in sys.path:
        sys.path.insert(0, project_dir)
else:
    print("Error: SSD project directory not found")
    print("Please make sure you are running from the correct directory")
    print("\nCurrent directory:", os.getcwd())
    sys.exit(1)
# ==============================================
# 2. Import the SSD project modules
# ==============================================
try:
    # NOTE: in this project nets/ssd.py defines SSD300, while the SSD wrapper
    # class lives in the top-level ssd.py, so this first import usually fails
    # and the fallback below is the one that succeeds.
    from nets.ssd import SSD
    print("Imported SSD module successfully")
    SSD_MODEL = SSD
except ImportError:
    try:
        # The usual layout for this project: the wrapper class is in ssd.py
        from ssd import SSD
        SSD_MODEL = SSD
        print("Imported SSD module successfully")
    except ImportError:
        try:
            # Some projects organize it this way instead
            from model.ssd import SSD
            SSD_MODEL = SSD
            print("Imported SSD module successfully")
        except ImportError as e:
            print(f"Failed to import the SSD module: {e}")
            print("Please check the SSD project's file structure")
            # List Python files in the current directory
            print("\nCurrent directory contents:")
            for file in os.listdir('.'):
                if file.endswith('.py'):
                    print(f"  {file}")
            print("\nnets directory contents:")
            nets_dir = os.path.join(project_dir, 'nets')
            if os.path.exists(nets_dir):
                for file in os.listdir(nets_dir):
                    if file.endswith('.py'):
                        print(f"  {file}")
            sys.exit(1)
try:
    # Try the normal utils imports
    from utils.utils import get_classes
    from utils.utils_map import get_map
    print("Imported utils modules successfully")
except ImportError as e:
    print(f"Failed to import utils modules: {e}")
    # Fall back to loading the files directly
    import importlib.util
    # Load get_classes
    utils_path = os.path.join(project_dir, "utils", "utils.py")
    if os.path.exists(utils_path):
        spec = importlib.util.spec_from_file_location("utils", utils_path)
        utils_module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(utils_module)
        get_classes = utils_module.get_classes
        print("Imported get_classes directly")
    else:
        print(f"utils.py not found: {utils_path}")
        sys.exit(1)
    # Load get_map
    utils_map_path = os.path.join(project_dir, "utils", "utils_map.py")
    if os.path.exists(utils_map_path):
        spec2 = importlib.util.spec_from_file_location("utils_map", utils_map_path)
        utils_map_module = importlib.util.module_from_spec(spec2)
        spec2.loader.exec_module(utils_map_module)
        get_map = utils_map_module.get_map
        print("Imported get_map directly")
    else:
        print(f"utils_map.py not found: {utils_map_path}")
        sys.exit(1)
# ==============================================
# 3. SSD evaluation functions
# ==============================================
def setup_logging(output_dir):
    """Configure logging to both file and console."""
    os.makedirs(output_dir, exist_ok=True)
    log_path = os.path.join(output_dir, 'ssd_evaluation.log')
    logger = logging.getLogger('SSD_Evaluation')
    logger.setLevel(logging.INFO)
    # Clear any existing handlers
    logger.handlers.clear()
    # File handler
    file_handler = logging.FileHandler(log_path, encoding='utf-8')
    file_handler.setLevel(logging.INFO)
    # Console handler
    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.INFO)
    # Formatter
    formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
    file_handler.setFormatter(formatter)
    console_handler.setFormatter(formatter)
    # Attach handlers
    logger.addHandler(file_handler)
    logger.addHandler(console_handler)
    return logger
def calculate_model_params(model):
    """Count total and trainable parameters.
    The SSD wrapper in ssd.py is not an nn.Module itself, so fall back to its
    underlying network via the .net attribute; calling .parameters() on the
    wrapper directly would always hit the except branch and report 0."""
    try:
        net = model if isinstance(model, torch.nn.Module) else getattr(model, 'net', model)
        total_params = sum(p.numel() for p in net.parameters())
        trainable_params = sum(p.numel() for p in net.parameters() if p.requires_grad)
    except Exception:
        total_params = 0
        trainable_params = 0
    return total_params, trainable_params
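# ---------------------------------------------------------------------------
# GFLOPs helper: a minimal sketch for the requested "GFLOPs" metric. It
# assumes the thop package is installed (pip install thop) and that the
# nn.Module you pass in accepts a single 1x3xHxW tensor. Note that thop's
# profile() returns MACs, which are conventionally reported as GFLOPs after
# dividing by 1e9.
# ---------------------------------------------------------------------------
def calculate_model_gflops(net, input_shape=(640, 640), device='cpu'):
    from thop import profile
    dummy = torch.zeros(1, 3, int(input_shape[0]), int(input_shape[1])).to(device)
    flops, _ = profile(net, inputs=(dummy,), verbose=False)
    return flops / 1e9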
def load_val_images(yolo_data_path):
    """Collect the validation-set images."""
    val_images = []
    # Try several common dataset layouts
    possible_paths = [
        os.path.join(yolo_data_path, 'val.txt'),
        os.path.join(yolo_data_path, 'images/val'),
        os.path.join(yolo_data_path, 'val/images'),
        os.path.join(yolo_data_path, 'val'),
        os.path.join(yolo_data_path, 'images'),
        yolo_data_path
    ]
    for path in possible_paths:
        if os.path.exists(path):
            if path.endswith('.txt'):
                with open(path, 'r') as f:
                    for line in f:
                        line = line.strip()
                        if line:
                            val_images.append(line)
                if val_images:
                    return val_images
            else:
                # Collect image files
                image_extensions = ['*.jpg', '*.jpeg', '*.png', '*.bmp']
                for ext in image_extensions:
                    val_images.extend(glob.glob(os.path.join(path, '**', ext), recursive=True))
                    val_images.extend(glob.glob(os.path.join(path, ext)))
                if val_images:
                    return val_images
    return val_images
def find_image_full_path(yolo_data_path, img_record):
    """Resolve a val-set record to a full image path."""
    img_record = img_record.strip()
    # Already an absolute path that exists
    if os.path.isabs(img_record) and os.path.exists(img_record):
        return img_record
    # Try several candidate locations
    possible_paths = [
        os.path.join(yolo_data_path, img_record),
        os.path.join(yolo_data_path, 'images', img_record),
        os.path.join(yolo_data_path, 'images', 'val', img_record),
        os.path.join(yolo_data_path, 'val', 'images', img_record),
        os.path.join(yolo_data_path, 'val', img_record),
    ]
    for test_path in possible_paths:
        if os.path.exists(test_path):
            return test_path
    # If the record has no extension, try the common ones
    filename, ext = os.path.splitext(img_record)
    if not ext:
        for ext in ['.jpg', '.jpeg', '.png', '.bmp']:
            test_paths = [
                os.path.join(yolo_data_path, 'images', 'val', filename + ext),
                os.path.join(yolo_data_path, 'images', filename + ext),
                os.path.join(yolo_data_path, filename + ext),
                os.path.join(yolo_data_path, 'val', 'images', filename + ext),
            ]
            for test_path in test_paths:
                if os.path.exists(test_path):
                    return test_path
    return None
def find_label_file(yolo_data_path, img_path):
    """Find the YOLO label file that matches an image."""
    # Image filename without extension
    img_filename = os.path.splitext(os.path.basename(img_path))[0]
    # Relative path of the image below the images directory
    if 'images' in img_path:
        img_rel_path = img_path.split('images')[1].lstrip(os.sep)
        img_rel_dir = os.path.dirname(img_rel_path)
    else:
        img_rel_dir = ''
    # Candidate label locations
    possible_label_paths = [
        img_path.replace('images', 'labels').replace(os.path.splitext(img_path)[1], '.txt'),
        os.path.join(yolo_data_path, 'labels', 'val', img_filename + '.txt'),
        os.path.join(yolo_data_path, 'labels', img_filename + '.txt'),
        os.path.join(yolo_data_path, 'labels', img_rel_dir, img_filename + '.txt'),
        os.path.join(yolo_data_path, 'val', 'labels', img_filename + '.txt'),
    ]
    for label_path in possible_label_paths:
        if os.path.exists(label_path):
            return label_path
    return None
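# ---------------------------------------------------------------------------
# Precision/recall helper: a minimal sketch added for the requested P and R
# metrics. It re-reads the ground-truth/ and detection-results/ txt files
# written into temp_dir below and performs greedy one-to-one matching per
# image at a single IoU threshold, micro-averaged over all classes. This is a
# deliberate simplification, not the VOC-style PR curve that get_map computes
# internally; adjust to your reporting conventions as needed.
# ---------------------------------------------------------------------------
def compute_pr_from_txt(temp_dir, iou_threshold=0.5, conf_threshold=0.5):
    def box_iou(b1, b2):
        # Intersection-over-union of two [x1, y1, x2, y2] boxes
        ix1, iy1 = max(b1[0], b2[0]), max(b1[1], b2[1])
        ix2, iy2 = min(b1[2], b2[2]), min(b1[3], b2[3])
        inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
        a1 = (b1[2] - b1[0]) * (b1[3] - b1[1])
        a2 = (b2[2] - b2[0]) * (b2[3] - b2[1])
        return inter / (a1 + a2 - inter + 1e-9)

    tp, fp, n_gt = 0, 0, 0
    gt_dir = os.path.join(temp_dir, 'ground-truth')
    det_dir = os.path.join(temp_dir, 'detection-results')
    for gt_file in glob.glob(os.path.join(gt_dir, '*.txt')):
        image_id = os.path.splitext(os.path.basename(gt_file))[0]
        # Ground truths: "class x1 y1 x2 y2"; third slot marks "already matched"
        gts = []
        with open(gt_file) as f:
            for line in f:
                p = line.split()
                if len(p) >= 5:
                    gts.append([p[0], [float(v) for v in p[1:5]], False])
        n_gt += len(gts)
        # Detections: "class score x1 y1 x2 y2", filtered by confidence
        dets = []
        det_file = os.path.join(det_dir, image_id + '.txt')
        if os.path.exists(det_file):
            with open(det_file) as f:
                for line in f:
                    p = line.split()
                    if len(p) >= 6 and float(p[1]) >= conf_threshold:
                        dets.append((p[0], float(p[1]), [float(v) for v in p[2:6]]))
        # Match highest-confidence detections first
        dets.sort(key=lambda d: -d[1])
        for cls, _, box in dets:
            best, best_iou = None, iou_threshold
            for g in gts:
                if g[0] != cls or g[2]:
                    continue
                ov = box_iou(box, g[1])
                if ov >= best_iou:
                    best, best_iou = g, ov
            if best is not None:
                best[2] = True
                tp += 1
            else:
                fp += 1
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    recall = tp / n_gt if n_gt > 0 else 0.0
    return precision, recall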
def evaluate_ssd_val(model, classes_path, yolo_data_path, output_dir='ssd_evaluation_results_val',
                     confidence=0.5, nms_iou=0.5, max_images=None, logger=None):
    """
    Evaluate the SSD model on the validation set.
    """
    # Create the output directory
    os.makedirs(output_dir, exist_ok=True)
    if logger is None:
        logger = logging.getLogger('SSD_Evaluation')
    # Load class information
    class_names, _ = get_classes(classes_path)
    num_classes = len(class_names)
    logger.info(f"Number of classes: {num_classes}")
    logger.info(f"Class list: {class_names}")
    # 1. Count model parameters
    logger.info("Counting model parameters...")
    total_params, trainable_params = calculate_model_params(model)
    logger.info("Model parameter counts:")
    logger.info(f"  Total parameters: {total_params:,}")
    logger.info(f"  Trainable parameters: {trainable_params:,}")
    # 2. Load the validation images
    logger.info("Loading val-set images...")
    val_image_records = load_val_images(yolo_data_path)
    if not val_image_records:
        logger.error("No val-set images found!")
        return None
    logger.info(f"Found {len(val_image_records)} val-set records")
    # Resolve records to full paths
    val_images = []
    for record in tqdm(val_image_records, desc="Resolving image paths"):
        full_path = find_image_full_path(yolo_data_path, record)
        if full_path and os.path.exists(full_path):
            val_images.append(full_path)
        else:
            logger.warning(f"Could not locate image: {record}")
    logger.info(f"Resolved {len(val_images)} val-set image paths")
    if not val_images:
        logger.error("No valid image files found!")
        return None
    # Optionally cap the number of test images
    if max_images and max_images < len(val_images):
        test_images = val_images[:max_images]
        logger.info(f"Limiting evaluation to {max_images} images")
    else:
        test_images = val_images
        logger.info(f"Evaluating all {len(test_images)} images")
    # 3. Evaluation bookkeeping
    all_predictions = []
    all_ground_truths = []
    inference_times = []
    # 4. Temporary directory in the layout get_map expects
    temp_dir = os.path.join(output_dir, 'temp_mAP')
    # Remove a stale directory if present
    if os.path.exists(temp_dir):
        shutil.rmtree(temp_dir, ignore_errors=True)
    os.makedirs(temp_dir, exist_ok=True)
    os.makedirs(os.path.join(temp_dir, 'ground-truth'), exist_ok=True)
    os.makedirs(os.path.join(temp_dir, 'detection-results'), exist_ok=True)
logger.info("开始推理...")
# 5. 进行推理
for img_path in tqdm(test_images, desc="推理进度"):
try:
# 获取图像ID
image_id = os.path.splitext(os.path.basename(img_path))[0]
# 加载图像
image = Image.open(img_path)
image = image.convert('RGB')
img_width, img_height = image.size
# 推理时间测量
start_time = time.time()
# 使用SSD模型进行检测
# 根据您的SSD项目,这里可能需要调整
if hasattr(model, 'detect_image'):
# 假设SSD模型有detect_image方法
r_image, predictions = model.detect_image(image)
# 保存检测结果到文件
det_file = os.path.join(temp_dir, 'detection-results', f"{image_id}.txt")
with open(det_file, 'w') as f:
for pred in predictions:
# 假设predictions的格式为 [x1, y1, x2, y2, class_id, confidence]
if len(pred) >= 6:
class_id = int(pred[4])
confidence_score = float(pred[5])
if class_id < num_classes:
class_name = class_names[class_id]
x1, y1, x2, y2 = pred[:4]
f.write(f"{class_name} {confidence_score:.6f} {x1:.1f} {y1:.1f} {x2:.1f} {y2:.1f}\n")
inference_time = time.time() - start_time
inference_times.append(inference_time)
            # Find the matching label file
            label_path = find_label_file(yolo_data_path, img_path)
            if label_path and os.path.exists(label_path):
                with open(label_path, 'r') as f:
                    lines = f.readlines()
                for line in lines:
                    parts = line.strip().split()
                    if len(parts) >= 5:
                        try:
                            # YOLO format: class x_center y_center w h (normalized)
                            class_id = int(parts[0])
                            x_center = float(parts[1]) * img_width
                            y_center = float(parts[2]) * img_height
                            width = float(parts[3]) * img_width
                            height = float(parts[4]) * img_height
                            x1 = max(0, x_center - width / 2)
                            y1 = max(0, y_center - height / 2)
                            x2 = min(img_width, x_center + width / 2)
                            y2 = min(img_height, y_center + height / 2)
                            if 0 <= class_id < num_classes:
                                all_ground_truths.append({
                                    'image_id': image_id,
                                    'bbox': [x1, y1, x2, y2],
                                    'category_id': class_id,
                                    'category_name': class_names[class_id]
                                })
                                # Append the ground truth to its txt file
                                gt_file = os.path.join(temp_dir, 'ground-truth', f"{image_id}.txt")
                                with open(gt_file, 'a') as f:
                                    f.write(f"{class_names[class_id]} {x1:.1f} {y1:.1f} {x2:.1f} {y2:.1f}\n")
                        except Exception as e:
                            logger.warning(f"Failed to parse label line: {line}, error: {e}")
                            continue
            else:
                logger.warning(f"No label file found for: {img_path}")
        except Exception as e:
            logger.error(f"Error while processing image {img_path}: {e}")
            continue
logger.info(f"成功处理 {len(inference_times)} 张图像")
logger.info(f"真实标注有 {len(all_ground_truths)} 个目标")
if len(inference_times) == 0:
logger.error("没有成功处理任何图像!")
return None
# 6. 确保所有图像都有对应的文件
logger.info("确保所有文件存在...")
# 为每个图像创建空的检测结果文件(如果不存在)
for img_path in test_images:
image_id = os.path.splitext(os.path.basename(img_path))[0]
det_file = os.path.join(temp_dir, 'detection-results', f"{image_id}.txt")
if not os.path.exists(det_file):
with open(det_file, 'w') as f:
pass
    # 7. Compute FPS (note: this timing includes result-file writing)
    avg_inference_time = np.mean(inference_times)
    fps = 1.0 / avg_inference_time if avg_inference_time > 0 else 0
    logger.info("Inference speed:")
    logger.info(f"  Average inference time: {avg_inference_time * 1000:.2f}ms")
    logger.info(f"  FPS: {fps:.2f}")
    # 8. Compute mAP50
    logger.info("Computing mAP50...")
    map50 = 0.0
    try:
        # note: 'score_threhold' is the (misspelled) kwarg name in utils_map
        map50 = get_map(0.5, False, score_threhold=0.5, path=temp_dir)
        logger.info(f"  mAP50: {map50:.4f}")
    except Exception as e:
        logger.error(f"Error while computing mAP50: {e}")
        map50 = 0.0
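    # mAP50-95 and P/R (hedged sketch). This assumes the bubbliiiing-style
    # get_map accepts an arbitrary MINOVERLAP as its first positional
    # argument, so we can sweep IoU 0.50:0.95:0.05 over the same txt files;
    # P/R come from the simple matcher defined above. Fold these values into
    # evaluation_report['performance_metrics'] if you want them in the
    # JSON/CSV output as well.
    map50_95, precision, recall = 0.0, 0.0, 0.0
    try:
        aps = [get_map(float(t), False, score_threhold=0.5, path=temp_dir)
               for t in np.arange(0.5, 1.0, 0.05)]
        map50_95 = float(np.mean(aps))
        logger.info(f"  mAP50-95: {map50_95:.4f}")
    except Exception as e:
        logger.warning(f"Failed to compute mAP50-95: {e}")
    try:
        precision, recall = compute_pr_from_txt(temp_dir, iou_threshold=0.5, conf_threshold=confidence)
        logger.info(f"  Precision (P): {precision:.4f}")
        logger.info(f"  Recall (R):    {recall:.4f}")
    except Exception as e:
        logger.warning(f"Failed to compute P/R: {e}")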
    # 9. Build the evaluation report
evaluation_report = {
'model_info': {
'total_parameters': int(total_params),
'trainable_parameters': int(trainable_params),
'confidence_threshold': confidence,
'nms_iou_threshold': nms_iou
},
'performance_metrics': {
'mAP50': float(map50)
},
'inference_speed': {
'avg_inference_time_ms': float(avg_inference_time * 1000),
'fps': float(fps),
'test_images_count': len(inference_times)
},
'dataset_info': {
'val_set_size': len(val_images),
'actual_tested_count': len(inference_times),
'num_classes': num_classes,
'class_names': class_names
},
'detection_counts': {
'total_ground_truths': len(all_ground_truths),
'average_ground_truths_per_image': len(all_ground_truths) / len(inference_times) if inference_times else 0
}
}
    # 10. Save the results
    logger.info("Saving evaluation results...")
    # Save as JSON
    json_path = os.path.join(output_dir, 'ssd_evaluation_results.json')
    with open(json_path, 'w', encoding='utf-8') as f:
        json.dump(evaluation_report, f, indent=4, ensure_ascii=False, default=str)
    logger.info(f"  JSON results saved to: {json_path}")
    # Save as CSV
    csv_data = []
    main_metrics = [
        ('Model parameters', evaluation_report['model_info']['total_parameters'], 'Parameters'),
        ('mAP50', evaluation_report['performance_metrics']['mAP50'], ''),
        ('FPS', evaluation_report['inference_speed']['fps'], ''),
        ('Inference time', evaluation_report['inference_speed']['avg_inference_time_ms'], 'ms'),
    ]
    for name, value, unit in main_metrics:
        csv_data.append({
            'Metric category': 'Main metrics',
            'Metric name': name,
            'Value': value,
            'Unit': unit
        })
    df = pd.DataFrame(csv_data)
    csv_path = os.path.join(output_dir, 'ssd_evaluation_results.csv')
    df.to_csv(csv_path, index=False, encoding='utf-8-sig')
    logger.info(f"  CSV results saved to: {csv_path}")
    # 11. Clean up the temporary directory
    try:
        shutil.rmtree(temp_dir, ignore_errors=True)
    except Exception as e:
        logger.warning(f"Error while cleaning up the temp directory: {e}")
    return evaluation_report
def main():
    """Entry point: evaluate the SSD model on the validation set."""
    # Parse command-line arguments
    parser = argparse.ArgumentParser(description='SSD model evaluation')
    parser.add_argument('--dataset', type=str, default='dataset11.0', help='dataset path')
    parser.add_argument('--classes', type=str, default='dataset11.0/classes.txt', help='class file path')
    parser.add_argument('--confidence', type=float, default=0.5, help='confidence threshold')
    parser.add_argument('--nms_iou', type=float, default=0.3, help='NMS IoU threshold')
    parser.add_argument('--output_dir', type=str, default='ssd_evaluation_results_val', help='output directory')
    parser.add_argument('--max_images', type=int, default=None, help='maximum number of test images')
    parser.add_argument('--model_path', type=str, default=None, help='model weights path')
    args = parser.parse_args()
    config = {
        'yolo_dataset_path': args.dataset,
        'classes_path': args.classes,
        'confidence': args.confidence,
        'nms_iou': args.nms_iou,
        'output_dir': args.output_dir,
        'max_images': args.max_images,
        'model_path': args.model_path
    }
    # Set up logging
    logger = setup_logging(config['output_dir'])
    # Print the configuration
    logger.info("=" * 60)
    logger.info("SSD evaluation configuration")
    logger.info("=" * 60)
    for key, value in config.items():
        logger.info(f"{key}: {value}")
    # Initialize the SSD model
    logger.info("Loading SSD model...")
    try:
        # Adjust to match your SSD project's constructor; the SSD class in
        # ssd.py accepts keyword overrides of its _defaults dict. Passing
        # model_path=None would override the default weights path, so only
        # pass it when it was given on the command line.
        kwargs = dict(confidence=config['confidence'], nms_iou=config['nms_iou'])
        if config['model_path']:
            kwargs['model_path'] = config['model_path']
        ssd_model = SSD_MODEL(**kwargs)
    except Exception as e:
        logger.error(f"Failed to load the SSD model: {e}")
        # Try an alternative initialization
        try:
            # Fall back to the default parameters
            ssd_model = SSD_MODEL()
            logger.info("Initialized SSD model with default parameters")
        except Exception as e2:
            logger.error(f"Default initialization failed as well: {e2}")
            return
    # Evaluate the model
    logger.info("Starting evaluation...")
results = evaluate_ssd_val(
model=ssd_model,
classes_path=config['classes_path'],
yolo_data_path=config['yolo_dataset_path'],
output_dir=config['output_dir'],
confidence=config['confidence'],
nms_iou=config['nms_iou'],
max_images=config['max_images'],
logger=logger
)
    if results:
        logger.info("=" * 60)
        logger.info("SSD evaluation finished!")
        logger.info("=" * 60)
        # Print the key metrics
        print("\n" + "=" * 60)
        print("SSD key metrics summary:")
        print("=" * 60)
        print("1. Model:")
        print(f"   Parameters: {results['model_info']['total_parameters']:,}")
        print("\n2. Detection performance:")
        print(f"   mAP50: {results['performance_metrics']['mAP50']:.4f}")
        print("\n3. Inference speed:")
        print(f"   FPS: {results['inference_speed']['fps']:.2f}")
        print(f"   Inference time: {results['inference_speed']['avg_inference_time_ms']:.2f}ms")
        print("\n4. Dataset:")
        print(f"   Val set size: {results['dataset_info']['val_set_size']} images")
        print(f"   Actually tested: {results['dataset_info']['actual_tested_count']} images")
        print(f"   Classes: {', '.join(results['dataset_info']['class_names'])}")
        print(f"\nDetailed report saved to: {config['output_dir']}")
        # Write a short plain-text summary report
        summary_path = os.path.join(config['output_dir'], 'SSD_evaluation_report.txt')
        with open(summary_path, 'w', encoding='utf-8') as f:
            f.write("=" * 60 + "\n")
            f.write("SSD Model Evaluation Report\n")
            f.write("=" * 60 + "\n\n")
            f.write(f"Evaluated at: {time.strftime('%Y-%m-%d %H:%M:%S')}\n\n")
            f.write("1. Model info:\n")
            f.write(f"   Total parameters: {results['model_info']['total_parameters']:,}\n")
            f.write(f"   Confidence threshold: {results['model_info']['confidence_threshold']}\n")
            f.write(f"   NMS IoU threshold: {results['model_info']['nms_iou_threshold']}\n\n")
            f.write("2. Detection performance:\n")
            f.write(f"   mAP50: {results['performance_metrics']['mAP50']:.4f}\n\n")
            f.write("3. Inference speed:\n")
            f.write(f"   FPS: {results['inference_speed']['fps']:.2f}\n")
            f.write(f"   Average inference time: {results['inference_speed']['avg_inference_time_ms']:.2f}ms\n")
            f.write(f"   Test image count: {results['inference_speed']['test_images_count']}\n\n")
            f.write("4. Dataset info:\n")
            f.write(f"   Val set size: {results['dataset_info']['val_set_size']} images\n")
            f.write(f"   Actually tested: {results['dataset_info']['actual_tested_count']} images\n")
            f.write(f"   Number of classes: {results['dataset_info']['num_classes']}\n")
            f.write(f"   Class list: {', '.join(results['dataset_info']['class_names'])}\n")
        print(f"Summary report: {summary_path}")
    else:
        logger.error("SSD evaluation failed!")
if __name__ == "__main__":
    main()

The following is my ssd.py code:

import colorsys
import os
import time
import warnings
import numpy as np
import torch
import torch.backends.cudnn as cudnn
from PIL import Image, ImageDraw, ImageFont
from nets.ssd import SSD300
from utils.anchors import get_anchors
from utils.utils import (cvtColor, get_classes, preprocess_input, resize_image,
show_config)
from utils.utils_bbox import BBoxUtility
warnings.filterwarnings("ignore")
#--------------------------------------------#
#   To run prediction with your own trained model,
#   you must change 3 parameters:
#   model_path, backbone and classes_path!
#   If a shape mismatch occurs, also double-check the
#   num_classes, model_path and classes_path settings
#   used in the training config.
#--------------------------------------------#
class SSD(object):
    _defaults = {
        #--------------------------------------------------------------------------#
        #   To predict with your own trained model you must change model_path and classes_path!
        #   model_path points to a weights file under logs, classes_path to the txt under model_data.
        #
        #   After training, several weights files sit in the logs folder; pick one with a low
        #   validation loss. A low validation loss does not guarantee a high mAP, only better
        #   generalization on the validation set. On shape mismatches, also check the
        #   model_path and classes_path used during training.
        #--------------------------------------------------------------------------#
        "model_path"        : 'logs/ep200-loss1.828-val_loss1.712.pth',
        "classes_path"      : 'dataset11.0/classes.txt',
        #---------------------------------------------------------------------#
        #   Input image size for prediction; use the same one as in training
        #---------------------------------------------------------------------#
        "input_shape"       : [640, 640],
        #-------------------------------#
        #   Backbone network choice:
        #   vgg, mobilenetv2 or resnet50
        #-------------------------------#
        "backbone"          : "vgg",
        #---------------------------------------------------------------------#
        #   Only predictions scoring above this confidence are kept
        #---------------------------------------------------------------------#
        "confidence"        : 0.5,
        #---------------------------------------------------------------------#
        #   IoU threshold used by non-maximum suppression
        #---------------------------------------------------------------------#
        "nms_iou"           : 0.3,
        #---------------------------------------------------------------------#
        #   Anchor (prior box) sizes: [21, 45, 99, 153, 207, 261, 315] or [30, 60, 111, 162, 213, 264, 315]
        #---------------------------------------------------------------------#
        'anchors_size'      : [21, 45, 99, 153, 207, 261, 315],
        #---------------------------------------------------------------------#
        #   Whether to use letterbox_image for a distortion-free resize.
        #   Repeated tests suggest plain resizing (letterbox off) works better here.
        #---------------------------------------------------------------------#
        "letterbox_image"   : False,
        #-------------------------------#
        #   Whether to use CUDA;
        #   set to False if no GPU is available
        #-------------------------------#
        "cuda"              : True,
    }
@classmethod
def get_defaults(cls, n):
if n in cls._defaults:
return cls._defaults[n]
else:
return "Unrecognized attribute name '" + n + "'"
    #---------------------------------------------------#
    #   Initialize the SSD wrapper
    #---------------------------------------------------#
    def __init__(self, **kwargs):
        self.__dict__.update(self._defaults)
        for name, value in kwargs.items():
            setattr(self, name, value)
        #---------------------------------------------------#
        #   Compute the total number of classes (+1 for background)
        #---------------------------------------------------#
        self.class_names, self.num_classes = get_classes(self.classes_path)
        self.anchors = torch.from_numpy(get_anchors(self.input_shape, self.anchors_size, self.backbone)).type(torch.FloatTensor)
        if self.cuda:
            self.anchors = self.anchors.cuda()
        self.num_classes = self.num_classes + 1
        #---------------------------------------------------#
        #   Assign a distinct color to each class for drawing boxes
        #---------------------------------------------------#
        hsv_tuples = [(x / self.num_classes, 1., 1.) for x in range(self.num_classes)]
        self.colors = list(map(lambda x: colorsys.hsv_to_rgb(*x), hsv_tuples))
        self.colors = list(map(lambda x: (int(x[0] * 255), int(x[1] * 255), int(x[2] * 255)), self.colors))
        self.bbox_util = BBoxUtility(self.num_classes)
        self.generate()
        show_config(**self._defaults)
    #---------------------------------------------------#
    #   Load the model
    #---------------------------------------------------#
    def generate(self, onnx=False):
        #-------------------------------#
        #   Build the network and load the weights
        #-------------------------------#
        self.net = SSD300(self.num_classes, self.backbone)
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.net.load_state_dict(torch.load(self.model_path, map_location=device))
        self.net = self.net.eval()
        print('{} model, anchors, and classes loaded.'.format(self.model_path))
        if not onnx:
            if self.cuda:
                self.net = torch.nn.DataParallel(self.net)
                self.net = self.net.cuda()
    #---------------------------------------------------#
    #   Detect objects in an image
    #---------------------------------------------------#
    def detect_image(self, image, crop = False, count = False):
        #---------------------------------------------------#
        #   Record the input image's height and width
        #---------------------------------------------------#
        image_shape = np.array(np.shape(image)[0:2])
        #---------------------------------------------------------#
        #   Convert the image to RGB so grayscale input does not
        #   crash prediction; only RGB images are supported, so all
        #   other formats are converted.
        #---------------------------------------------------------#
        image = cvtColor(image)
        #---------------------------------------------------------#
        #   Add gray bars for a distortion-free resize,
        #   or resize directly for recognition
        #---------------------------------------------------------#
        image_data = resize_image(image, (self.input_shape[1], self.input_shape[0]), self.letterbox_image)
        #---------------------------------------------------------#
        #   Add the batch dimension, preprocess and normalize
        #---------------------------------------------------------#
        image_data = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, dtype='float32')), (2, 0, 1)), 0)
        with torch.no_grad():
            #---------------------------------------------------#
            #   Convert to a torch tensor
            #---------------------------------------------------#
            images = torch.from_numpy(image_data).type(torch.FloatTensor)
            if self.cuda:
                images = images.cuda()
            #---------------------------------------------------------#
            #   Feed the image through the network
            #---------------------------------------------------------#
            outputs = self.net(images)
            #-----------------------------------------------------------#
            #   Decode the predictions
            #-----------------------------------------------------------#
            results = self.bbox_util.decode_box(outputs, self.anchors, image_shape, self.input_shape, self.letterbox_image,
                                                    nms_iou = self.nms_iou, confidence = self.confidence)
            #--------------------------------------#
            #   If nothing was detected, return the original image
            #--------------------------------------#
            if len(results[0]) <= 0:
                return image
            top_label   = np.array(results[0][:, 4], dtype = 'int32')
            top_conf    = results[0][:, 5]
            top_boxes   = results[0][:, :4]
        #---------------------------------------------------------#
        #   Configure the font and box thickness
        #---------------------------------------------------------#
        font = ImageFont.truetype(font='model_data/simhei.ttf', size=np.floor(3e-2 * np.shape(image)[1] + 0.5).astype('int32'))
        thickness = max((np.shape(image)[0] + np.shape(image)[1]) // self.input_shape[0], 1)
        #---------------------------------------------------------#
        #   Per-class detection counts
        #---------------------------------------------------------#
        if count:
            print("top_label:", top_label)
            classes_nums = np.zeros([self.num_classes])
            for i in range(self.num_classes):
                num = np.sum(top_label == i)
                if num > 0:
                    print(self.class_names[i], " : ", num)
                classes_nums[i] = num
            print("classes_nums:", classes_nums)
        #---------------------------------------------------------#
        #   Optionally crop out each detected object
        #---------------------------------------------------------#
        if crop:
            for i, c in list(enumerate(top_boxes)):
                top, left, bottom, right = top_boxes[i]
                top     = max(0, np.floor(top).astype('int32'))
                left    = max(0, np.floor(left).astype('int32'))
                bottom  = min(image.size[1], np.floor(bottom).astype('int32'))
                right   = min(image.size[0], np.floor(right).astype('int32'))
                dir_save_path = "img_crop"
                if not os.path.exists(dir_save_path):
                    os.makedirs(dir_save_path)
                crop_image = image.crop([left, top, right, bottom])
                crop_image.save(os.path.join(dir_save_path, "crop_" + str(i) + ".png"), quality=95, subsampling=0)
                print("save crop_" + str(i) + ".png to " + dir_save_path)
        #---------------------------------------------------------#
        #   Draw the boxes and labels
        #---------------------------------------------------------#
        for i, c in list(enumerate(top_label)):
            predicted_class = self.class_names[int(c)]
            box             = top_boxes[i]
            score           = top_conf[i]
            top, left, bottom, right = box
            top     = max(0, np.floor(top).astype('int32'))
            left    = max(0, np.floor(left).astype('int32'))
            bottom  = min(image.size[1], np.floor(bottom).astype('int32'))
            right   = min(image.size[0], np.floor(right).astype('int32'))
            label = '{} {:.2f}'.format(predicted_class, score)
            draw = ImageDraw.Draw(image)
            label_size = draw.textsize(label, font)
            label = label.encode('utf-8')
            print(label, top, left, bottom, right)
            if top - label_size[1] >= 0:
                text_origin = np.array([left, top - label_size[1]])
            else:
                text_origin = np.array([left, top + 1])
            for i in range(thickness):
                draw.rectangle([left + i, top + i, right - i, bottom - i], outline=self.colors[c])
            draw.rectangle([tuple(text_origin), tuple(text_origin + label_size)], fill=self.colors[c])
            draw.text(text_origin, str(label, 'UTF-8'), fill=(0, 0, 0), font=font)
            del draw
        return image
    def get_FPS(self, image, test_interval):
        #---------------------------------------------------#
        #   Record the input image's height and width
        #---------------------------------------------------#
        image_shape = np.array(np.shape(image)[0:2])
        #---------------------------------------------------------#
        #   Convert to RGB (only RGB input is supported)
        #---------------------------------------------------------#
        image = cvtColor(image)
        #---------------------------------------------------------#
        #   Distortion-free resize with gray bars, or direct resize
        #---------------------------------------------------------#
        image_data = resize_image(image, (self.input_shape[1], self.input_shape[0]), self.letterbox_image)
        #---------------------------------------------------------#
        #   Add the batch dimension, preprocess and normalize
        #---------------------------------------------------------#
        image_data = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, dtype='float32')), (2, 0, 1)), 0)
        with torch.no_grad():
            #---------------------------------------------------#
            #   Convert to a torch tensor
            #---------------------------------------------------#
            images = torch.from_numpy(image_data).type(torch.FloatTensor)
            if self.cuda:
                images = images.cuda()
            #---------------------------------------------------------#
            #   Warm-up: one forward pass plus decode
            #---------------------------------------------------------#
            outputs = self.net(images)
            results = self.bbox_util.decode_box(outputs, self.anchors, image_shape, self.input_shape, self.letterbox_image,
                                                    nms_iou = self.nms_iou, confidence = self.confidence)
        t1 = time.time()
        for _ in range(test_interval):
            with torch.no_grad():
                #---------------------------------------------------------#
                #   Forward pass
                #---------------------------------------------------------#
                outputs = self.net(images)
                #-----------------------------------------------------------#
                #   Decode the predictions
                #-----------------------------------------------------------#
                results = self.bbox_util.decode_box(outputs, self.anchors, image_shape, self.input_shape, self.letterbox_image,
                                                        nms_iou = self.nms_iou, confidence = self.confidence)
        t2 = time.time()
        tact_time = (t2 - t1) / test_interval
        return tact_time
def convert_to_onnx(self, simplify, model_path):
import onnx
self.generate(onnx=True)
        im = torch.zeros(1, 3, *self.input_shape).to('cpu')  # image size (1, 3, *input_shape) BCHW
input_layer_names = ["images"]
output_layer_names = ["output"]
# Export the model
print(f'Starting export with onnx {onnx.__version__}.')
torch.onnx.export(self.net,
im,
f = model_path,
verbose = False,
opset_version = 12,
training = torch.onnx.TrainingMode.EVAL,
do_constant_folding = True,
input_names = input_layer_names,
output_names = output_layer_names,
dynamic_axes = None)
# Checks
model_onnx = onnx.load(model_path) # load onnx model
onnx.checker.check_model(model_onnx) # check onnx model
# Simplify onnx
if simplify:
import onnxsim
print(f'Simplifying with onnx-simplifier {onnxsim.__version__}.')
model_onnx, check = onnxsim.simplify(
model_onnx,
dynamic_input_shape=False,
input_shapes=None)
assert check, 'assert check failed'
onnx.save(model_onnx, model_path)
        print('Onnx model saved as {}'.format(model_path))
    def get_map_txt(self, image_id, image, class_names, map_out_path):
        f = open(os.path.join(map_out_path, "detection-results/" + image_id + ".txt"), "w")
        #---------------------------------------------------#
        #   Record the input image's height and width
        #---------------------------------------------------#
        image_shape = np.array(np.shape(image)[0:2])
        #---------------------------------------------------------#
        #   Convert to RGB (only RGB input is supported)
        #---------------------------------------------------------#
        image = cvtColor(image)
        #---------------------------------------------------------#
        #   Distortion-free resize with gray bars, or direct resize
        #---------------------------------------------------------#
        image_data = resize_image(image, (self.input_shape[1], self.input_shape[0]), self.letterbox_image)
        #---------------------------------------------------------#
        #   Add the batch dimension, preprocess and normalize
        #---------------------------------------------------------#
        image_data = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, dtype='float32')), (2, 0, 1)), 0)
        with torch.no_grad():
            #---------------------------------------------------#
            #   Convert to a torch tensor
            #---------------------------------------------------#
            images = torch.from_numpy(image_data).type(torch.FloatTensor)
            if self.cuda:
                images = images.cuda()
            #---------------------------------------------------------#
            #   Feed the image through the network
            #---------------------------------------------------------#
            outputs = self.net(images)
            #-----------------------------------------------------------#
            #   Decode the predictions
            #-----------------------------------------------------------#
            results = self.bbox_util.decode_box(outputs, self.anchors, image_shape, self.input_shape, self.letterbox_image,
                                                    nms_iou = self.nms_iou, confidence = self.confidence)
            #--------------------------------------#
            #   If nothing was detected, leave the file empty
            #   (close it first so it is flushed to disk)
            #--------------------------------------#
            if len(results[0]) <= 0:
                f.close()
                return
            top_label   = np.array(results[0][:, 4], dtype = 'int32')
            top_conf    = results[0][:, 5]
            top_boxes   = results[0][:, :4]
        for i, c in list(enumerate(top_label)):
            predicted_class = self.class_names[int(c)]
            box             = top_boxes[i]
            score           = str(top_conf[i])
            top, left, bottom, right = box
            if predicted_class not in class_names:
                continue
            f.write("%s %s %s %s %s %s\n" % (predicted_class, score[:6], str(int(left)), str(int(top)), str(int(right)), str(int(bottom))))
        f.close()
        return

This is my get_map.py code. Based on the three pieces of code I have provided, please generate a script that can be used directly.
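For reference, here is a hypothetical way the evaluation script above could be invoked once saved (the filename eval_ssd.py is a placeholder; the weights path matches the default in ssd.py, so adjust everything to your setup):

python eval_ssd.py --dataset dataset11.0 --classes dataset11.0/classes.txt --model_path logs/ep200-loss1.828-val_loss1.712.pth --confidence 0.5 --nms_iou 0.3 --output_dir ssd_evaluation_results_val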