# 请给出具体且详细的步骤,以下是我的代码:
# train.py (完整最终版 - 完全切换到在线数据增强)
import copy
import json
import logging
import os

import numpy as np
import torch

# --- Detectron2 Imports ---
import detectron2.utils.comm as comm
from detectron2 import model_zoo
from detectron2.engine import DefaultTrainer, default_argument_parser, default_setup, launch
from detectron2.config import get_cfg
from detectron2.data import MetadataCatalog, DatasetCatalog, build_detection_train_loader
from detectron2.data.datasets import register_coco_instances
from detectron2.evaluation import COCOEvaluator
from detectron2.projects import point_rend
# --- [新增] 导入数据增强和工具模块 ---
from detectron2.data import detection_utils as utils
import detectron2.data.transforms as T
# --- 用于格式化输出的辅助函数 (无变化) ---
def print_section_header(title, width=10):
    """Print a section title framed by two runs of ``=``.

    Output is a blank line followed by ``<bar> <title> <bar>``.

    Args:
        title: Text shown between the two bars.
        width: Number of ``=`` characters in each bar (default 10, the
            value that used to be hard-coded).
    """
    bar = "=" * width
    print(f"\n{bar} {title} {bar}")
def print_script_header(title, width=50):
    """Print a banner for the top of the script output.

    The banner is a rule of ``=``, the title centred in ``width`` columns,
    another rule, and a trailing blank line.

    Args:
        title: Text to centre between the two rules.
        width: Banner width in characters (default 50, the value that used
            to be hard-coded).
    """
    rule = "=" * width
    print(rule)
    # Use the format mini-language (not str.center) so odd padding is split
    # exactly as the original f"{title:^50}" did.
    print(f"{title:^{width}}")
    print(rule + "\n")
def print_script_footer(title, width=50):
    """Print a banner for the bottom of the script output.

    The banner is a leading blank line, a rule of ``=``, the title centred
    in ``width`` columns, and a closing rule.

    Args:
        title: Text to centre between the two rules.
        width: Banner width in characters (default 50, the value that used
            to be hard-coded).
    """
    rule = "=" * width
    print("\n" + rule)
    # Format mini-language centring keeps the padding split identical to the
    # original f"{title:^50}".
    print(f"{title:^{width}}")
    print(rule)
# --- [新增] 自定义数据加载器 (在线数据增强核心) ---
def _gaussian_blur(image, sigma):
    """Return a blurred copy of an HWC (or HW) image using a separable Gaussian kernel.

    The result has the same shape and dtype as ``image``; integer images are
    rounded and clipped to the dtype's range.
    """
    if sigma <= 0:
        return image
    radius = max(1, int(np.ceil(3.0 * sigma)))
    offsets = torch.arange(-radius, radius + 1, dtype=torch.float32)
    kernel = torch.exp(-0.5 * (offsets / sigma) ** 2)
    kernel = kernel / kernel.sum()

    # ascontiguousarray also removes negative strides left by a flip, which
    # torch.from_numpy cannot handle.
    pixels = np.ascontiguousarray(image, dtype=np.float32)
    single_channel = pixels.ndim == 2
    if single_channel:
        pixels = pixels[:, :, None]

    # (H, W, C) -> (C, 1, H, W): channels act as the batch so each one is
    # filtered independently by the single-channel kernel.
    planes = torch.from_numpy(pixels).permute(2, 0, 1).unsqueeze(1)
    planes = torch.nn.functional.pad(
        planes, (radius, radius, radius, radius), mode="replicate"
    )
    planes = torch.nn.functional.conv2d(planes, kernel.view(1, 1, 1, -1))
    planes = torch.nn.functional.conv2d(planes, kernel.view(1, 1, -1, 1))
    blurred = planes.squeeze(1).permute(1, 2, 0).numpy()

    if single_channel:
        blurred = blurred[:, :, 0]
    if np.issubdtype(image.dtype, np.integer):
        limits = np.iinfo(image.dtype)
        blurred = np.clip(np.rint(blurred), limits.min, limits.max)
    return blurred.astype(image.dtype)


class _RandomGaussianBlur(T.Augmentation):
    """Blur the image with a Gaussian whose sigma is drawn uniformly per sample.

    ``detectron2.data.transforms`` has no blur augmentation, so the blur is
    wrapped in a ``ColorTransform``: pixels change, coordinates and masks do not.
    """

    def __init__(self, sigma=(0.2, 1.5)):
        super().__init__()
        self.sigma = sigma

    def get_transform(self, image):
        sigma = float(np.random.uniform(self.sigma[0], self.sigma[1]))
        return T.ColorTransform(lambda img: _gaussian_blur(img, sigma))


class _RandomErasing(T.Augmentation):
    """Fill one random rectangle with a constant to imitate partial occlusion.

    ``scale`` is the erased fraction of the image area and ``ratio`` the range
    of the rectangle's height/width aspect (sampled log-uniformly).
    Annotations are intentionally left untouched.
    """

    def __init__(self, scale=(0.02, 0.1), ratio=(0.3, 3.3), fill=0):
        super().__init__()
        self.scale = scale
        self.ratio = ratio
        self.fill = fill

    def get_transform(self, image):
        height, width = image.shape[:2]
        area = height * width * np.random.uniform(self.scale[0], self.scale[1])
        aspect = np.exp(
            np.random.uniform(np.log(self.ratio[0]), np.log(self.ratio[1]))
        )
        box_h = int(np.clip(round(float(np.sqrt(area * aspect))), 1, height))
        box_w = int(np.clip(round(float(np.sqrt(area / aspect))), 1, width))
        top = int(np.random.randint(0, height - box_h + 1))
        left = int(np.random.randint(0, width - box_w + 1))
        fill = self.fill

        def erase(img):
            erased = img.copy()
            erased[top:top + box_h, left:left + box_w] = fill
            return erased

        return T.ColorTransform(erase)


def _build_train_augmentations(config):
    """Assemble the online augmentation pipeline from the INPUT.* settings of ``config``."""
    return T.AugmentationList([
        # Photometric jitter (lighting / weather variation).
        T.RandomBrightness(0.8, 1.2),
        T.RandomContrast(0.8, 1.2),
        T.RandomSaturation(0.8, 1.2),
        # Geometric: horizontal flip only.
        T.RandomFlip(prob=0.5, horizontal=True, vertical=False),
        # Image degradation: defocus-like blur and partial occlusion.
        T.RandomApply(_RandomGaussianBlur(sigma=(0.2, 1.5)), prob=0.4),
        T.RandomApply(_RandomErasing(scale=(0.02, 0.1), ratio=(0.3, 3.3)), prob=0.5),
        # Multi-scale training driven by the config.
        T.ResizeShortestEdge(
            short_edge_length=config.INPUT.MIN_SIZE_TRAIN,
            max_size=config.INPUT.MAX_SIZE_TRAIN,
            sample_style=config.INPUT.MIN_SIZE_TRAIN_SAMPLING,
        ),
    ])


def custom_mapper(dataset_dict, config=None):
    """Load one training sample and apply random online augmentation.

    Args:
        dataset_dict: One Detectron2 dataset record; ``file_name`` and
            ``annotations`` are read here.
        config: Config node providing ``INPUT.FORMAT``, ``INPUT.MIN_SIZE_TRAIN``,
            ``INPUT.MAX_SIZE_TRAIN`` and ``INPUT.MIN_SIZE_TRAIN_SAMPLING``.
            When omitted, the module-level ``cfg`` is used. Binding it
            explicitly (``functools.partial(custom_mapper, config=cfg)``)
            avoids depending on that global, which does not exist in
            DataLoader workers started with the ``spawn`` method.

    Returns:
        A deep copy of ``dataset_dict`` without ``annotations`` and with
        ``image`` (float32 CHW tensor) and ``instances`` added.

    Raises:
        RuntimeError: If no config was passed and no global ``cfg`` exists.
    """
    if config is None:
        config = globals().get("cfg")
    if config is None:
        raise RuntimeError(
            "custom_mapper needs a config: pass config=... or set the module-level "
            "'cfg' before the data loader starts."
        )

    # Deep copy so the shared dataset record is never mutated.
    dataset_dict = copy.deepcopy(dataset_dict)
    image = utils.read_image(dataset_dict["file_name"], format=config.INPUT.FORMAT)

    aug_input = T.AugInput(image)
    transforms = _build_train_augmentations(config)(aug_input)
    image_transformed = aug_input.image
    image_shape = image_transformed.shape[:2]

    # HWC -> CHW float32; made contiguous because flips can leave negative strides.
    dataset_dict["image"] = torch.as_tensor(
        np.ascontiguousarray(image_transformed.transpose(2, 0, 1), dtype=np.float32)
    )

    # Apply the same geometric transforms to boxes/masks, skipping crowd regions.
    annos = [
        utils.transform_instance_annotations(obj, transforms, image_shape)
        for obj in dataset_dict.pop("annotations")
        if obj.get("iscrowd", 0) == 0
    ]
    # NOTE(review): "bitmask" is kept hard-coded as in the original; PointRend's
    # point head is understood to require BitMasks ground truth - confirm.
    instances = utils.annotations_to_instances(annos, image_shape, mask_format="bitmask")
    # Drop instances whose box or mask became empty after augmentation.
    dataset_dict["instances"] = utils.filter_empty_instances(instances)
    return dataset_dict
# --- 1. 数据集注册 (无变化) ---
def setup_datasets(dataset_root="dataset"):
    """Register the ``cable_{train,valid,test}`` COCO datasets and return the class names.

    Each split is expected at ``<dataset_root>/<split>/_annotations.coco.json``
    with its images in the same directory. Missing ``valid``/``test`` splits
    are skipped with a warning; a missing or unregistrable ``train`` split
    aborts.

    Args:
        dataset_root: Directory containing the ``train``/``valid``/``test`` folders.

    Returns:
        Category names from the training JSON ordered by category ``id``, or
        ``None`` when the training split is unusable.
    """
    print_section_header("数据集注册与类别名称提取")
    if not os.path.isdir(dataset_root):
        print(f"[错误] 基础数据集目录 '{dataset_root}' 未找到。无法继续进行数据集注册。")
        return None

    for subset in ("train", "valid", "test"):
        print(f"[信息] 尝试注册数据集: 'cable_{subset}'")
        json_file = os.path.join(dataset_root, subset, "_annotations.coco.json")
        image_root = os.path.join(dataset_root, subset)
        if not os.path.exists(json_file):
            print(f"[警告] 未找到 'cable_{subset}' 的标注文件: {json_file},跳过注册。")
            if subset == "train":
                print("[错误] 训练集标注文件缺失,无法提取类别名称。")
                return None
            continue
        if not os.path.isdir(image_root):
            print(f"[警告] 未找到 'cable_{subset}' 的图像目录: {image_root},跳过注册。")
            continue
        try:
            register_coco_instances(
                name=f"cable_{subset}",
                metadata={},
                json_file=json_file,
                image_root=image_root,
            )
            print(f"[成功] 已注册 'cable_{subset}'")
        except Exception as e:
            print(f"[错误] 注册 'cable_{subset}' 失败: {e}")
            if subset == "train":
                return None
            continue

    print("[信息] 尝试从 'cable_train' 提取类别名称...")
    train_json_path = os.path.join(dataset_root, "train", "_annotations.coco.json")
    if not os.path.exists(train_json_path):
        print(f"[错误] 训练集标注文件 '{train_json_path}' 未找到。")
        return None

    try:
        with open(train_json_path, 'r', encoding='utf-8') as f:
            coco_json_data = json.load(f)
        if 'categories' not in coco_json_data or not coco_json_data['categories']:
            print("[错误] 训练JSON中未找到或为空的 'categories' 字段。")
            return None

        # Detectron2's COCO loader sets thing_classes sorted by category id, and
        # metadata attributes cannot later be set to a different value. Use the
        # same order so a JSON that is not id-sorted does not trigger that
        # assertion or shift label names.
        categories = sorted(
            coco_json_data['categories'], key=lambda cat: cat.get("id", 0)
        )
        class_names = [cat['name'] for cat in categories]

        # Give every registered split the same class list so evaluation agrees
        # with training.
        registered = set(DatasetCatalog.list())
        for dataset_name in ("cable_train", "cable_valid", "cable_test"):
            if dataset_name in registered:
                MetadataCatalog.get(dataset_name).thing_classes = class_names
        print(f"[成功] 直接从JSON提取的类别名称: {class_names}")
    except Exception as e:
        print(f"[错误] 类别名称提取过程中发生错误: {e}")
        return None

    print(f"[信息] 最终派生的类别名称: {class_names}。类别数量: {len(class_names)}")
    return class_names
# --- 2. 自定义评估器Trainer (核心修改) ---
class CocoTrainer(DefaultTrainer):
    """``DefaultTrainer`` that evaluates with COCO metrics and trains with ``custom_mapper``."""

    @classmethod
    def build_evaluator(cls, cfg, dataset_name, output_folder=None):
        """Build a ``COCOEvaluator`` for ``dataset_name``.

        Args:
            cfg: Training config; only ``OUTPUT_DIR`` is read, to derive the
                default output folder.
            dataset_name: Name of a registered dataset.
            output_folder: Where evaluation results are written; defaults to
                ``<cfg.OUTPUT_DIR>/inference``.
        """
        if output_folder is None:
            output_folder = os.path.join(cfg.OUTPUT_DIR, "inference")
        print(f"[信息] 为数据集 '{dataset_name}' 构建 COCOEvaluator, 输出到 '{output_folder}'")
        # Explicit keywords instead of passing cfg in the `tasks` slot, which
        # current Detectron2 accepts only as deprecated behaviour; tasks are
        # then inferred from the predictions.
        return COCOEvaluator(dataset_name, distributed=True, output_dir=output_folder)

    @classmethod
    def build_train_loader(cls, cfg):
        """Build the training loader with ``custom_mapper`` so augmentation runs online."""
        print("[信息] 使用自定义的 `custom_mapper` 进行在线数据增强。")
        return build_detection_train_loader(cfg, mapper=custom_mapper)
# --- 3. 主训练函数 ---
def main(args):
    """Configure and run PointRend training on the ``cable_*`` datasets.

    Steps: register the datasets, build the config (PointRend defaults, the
    local YAML, then the overrides below), create the trainer, dump the model
    architecture on the main process, and train.

    Args:
        args: Namespace from ``default_argument_parser``; ``args.resume`` is
            read here and the whole object is passed to ``default_setup``.

    Raises:
        Exception: Any error raised by ``trainer.train()`` is reported, logged
            with its traceback, and re-raised so the process exits non-zero.
    """
    # custom_mapper reads this module-level config when none is passed to it.
    global cfg

    print_script_header("Detectron2 PointRend 训练流程开始 (在线增强模式)")
    class_names = setup_datasets(dataset_root="dataset")
    if not class_names:
        print("[严重] 未能获取到有效的类别列表。终止训练。")
        return
    num_classes = len(class_names)
    print(f"[成功] 数据集注册完成。共找到 {num_classes} 个类别: {class_names}")

    print_section_header("模型配置")
    cfg = get_cfg()
    # Register the PointRend-specific keys and their defaults before merging
    # the YAML; without this, keys the YAML does not spell out are missing
    # when the model is built.
    point_rend.add_pointrend_config(cfg)
    config_file_local_path = os.path.join(
        "detectron2-main", "projects", "PointRend", "configs",
        "InstanceSegmentation", "implicit_pointrend_R_50_FPN_3x_coco.yaml",
    )
    print(f"[配置] 从本地文件加载基础配置: '{config_file_local_path}'")
    if not os.path.exists(config_file_local_path):
        print(f"[错误] 配置文件未找到: '{config_file_local_path}'")
        return
    # Kept for tolerance when the local YAML is newer than the installed package.
    cfg.set_new_allowed(True)
    cfg.merge_from_file(config_file_local_path)

    print("[配置] 设置多尺度训练参数...")
    cfg.INPUT.MIN_SIZE_TRAIN = (640, 672, 704, 736, 768, 800)
    cfg.INPUT.MIN_SIZE_TRAIN_SAMPLING = "choice"

    print(f"[配置] 设置训练数据集为: ('cable_train',)")
    cfg.DATASETS.TRAIN = ("cable_train",)
    # Only evaluate on the validation split if it was actually registered;
    # otherwise the first evaluation would fail on an unknown dataset.
    if "cable_valid" in DatasetCatalog.list():
        print(f"[配置] 设置测试/验证数据集为: ('cable_valid',)")
        cfg.DATASETS.TEST = ("cable_valid",)
        eval_period = 500
    else:
        print("[警告] 验证集 'cable_valid' 未注册,已禁用训练期间的评估。")
        cfg.DATASETS.TEST = ()
        eval_period = 0

    cfg.OUTPUT_DIR = "model_output_pointrend_augmented"
    print(f"[配置] 设置输出目录为: '{cfg.OUTPUT_DIR}'")
    os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)

    print("[配置] 正在从本地加载预训练权重...")
    local_weight_path = os.path.join("pretrained_models", "model_final_f17282.pkl")
    if os.path.exists(local_weight_path):
        cfg.MODEL.WEIGHTS = local_weight_path
        print(f"[成功] 预训练权重路径已设置为: '{local_weight_path}'")
    else:
        print(f"[警告] 本地预训练权重文件未找到: '{local_weight_path}'。模型将从头开始训练。")
        cfg.MODEL.WEIGHTS = ""

    print(f"[配置] 为 {num_classes} 个类别调整模型。")
    cfg.MODEL.ROI_HEADS.NUM_CLASSES = num_classes
    cfg.MODEL.POINT_HEAD.NUM_CLASSES = num_classes
    # The former '"PointRend" in META_ARCHITECTURE' patch-up was removed: that
    # condition is not expected to match for this instance-segmentation config,
    # and the keys it patched now come from add_pointrend_config.

    print_section_header("超参数设置")
    cfg.DATALOADER.NUM_WORKERS = 2
    print(f"[配置] DATALOADER.NUM_WORKERS (数据加载器工作进程数): {cfg.DATALOADER.NUM_WORKERS}")
    cfg.SOLVER.IMS_PER_BATCH = 3
    print(f"[配置] SOLVER.IMS_PER_BATCH (每批图像数): {cfg.SOLVER.IMS_PER_BATCH}")
    cfg.SOLVER.BASE_LR = 0.00025
    print(f"[配置] SOLVER.BASE_LR (基础学习率): {cfg.SOLVER.BASE_LR}")
    # Cosine schedule with warm-up; STEPS is cleared since it is not used here.
    cfg.SOLVER.LR_SCHEDULER_NAME = "WarmupCosineLR"
    print(f"[配置] SOLVER.LR_SCHEDULER_NAME (学习率调度器): {cfg.SOLVER.LR_SCHEDULER_NAME}")
    cfg.SOLVER.MAX_ITER = 10000
    print(f"[配置] SOLVER.MAX_ITER (最大迭代次数): {cfg.SOLVER.MAX_ITER}")
    cfg.SOLVER.STEPS = []
    cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 512
    print(f"[配置] MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE (每图RoI批大小): {cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE}")
    cfg.TEST.EVAL_PERIOD = eval_period
    print(f"[配置] TEST.EVAL_PERIOD (评估周期): {cfg.TEST.EVAL_PERIOD} 次迭代")

    print("[信息] 执行默认设置 (日志、环境检查等)...")
    default_setup(cfg, args)

    print_section_header("训练器初始化与模型结构保存")
    print("[信息] 正在初始化 CocoTrainer...")
    trainer = CocoTrainer(cfg)

    if comm.is_main_process():
        # Write the architecture once, from the main process only.
        model_arch_file_path = os.path.join(cfg.OUTPUT_DIR, "model_architecture.txt")
        try:
            with open(model_arch_file_path, "w", encoding="utf-8") as f:
                f.write(f"Model Configuration Path: {config_file_local_path}\n")
                f.write(f"Number of Classes: {num_classes}\n\n")
                f.write("Model Architecture:\n")
                f.write("=" * 80 + "\n")
                f.write(str(trainer.model))
            print(f"[成功] 模型结构已保存到: {model_arch_file_path}")
        except OSError as e:
            # Best effort: failing to save the summary must not stop training.
            print(f"[错误] 保存模型结构到文件 '{model_arch_file_path}' 失败: {e}")

    resume_training = args.resume
    print(f"[信息] 训练器将尝试加载权重。是否从检查点恢复: {resume_training}")
    trainer.resume_or_load(resume=resume_training)

    print_section_header("训练开始")
    print(f"[信息] 开始使用 PointRend R-CNN 进行训练,共 {cfg.SOLVER.MAX_ITER} 次迭代...")
    try:
        trainer.train()
        print("\n[成功] 训练成功完成!")
    except Exception as e:
        print(f"\n[错误] 训练过程中发生错误: {e}")
        # Keep the traceback and propagate the failure instead of reporting
        # a clean exit to the launcher.
        logging.getLogger("detectron2").exception("Training aborted by an unhandled exception")
        raise
    finally:
        print_script_footer("Detectron2 PointRend 训练流程结束")
if __name__ == "__main__":
    # Parse Detectron2's standard command-line options.
    args = default_argument_parser().parse_args()

    # Echo every effective option, ordered by name.
    print_script_header("命令行参数")
    print("生效的命令行参数:")
    effective_options = vars(args)
    for option_name in sorted(effective_options):
        print(f" {option_name}: {effective_options[option_name]}")

    print("\n[信息] 正在启动训练过程...")
    # Start main() in one process per GPU, using the parsed distributed settings.
    distributed_settings = {
        "num_machines": args.num_machines,
        "machine_rank": args.machine_rank,
        "dist_url": args.dist_url,
    }
    launch(main, args.num_gpus, args=(args,), **distributed_settings)
# 最新发布  (web-page residue, not part of the script)