On break MISSING_BLOCK_LABEL appearing in pairs

Source code:

ThreadGate gate;
gate = getThreadGate();
if (!gate.isActiveThread()) {
    gate.enter(-1L);
    evict(((Map) (getFrontMap())));
    evict(getBackMap());
    processDeferredEvents(true);
    gate.exit();
}

After decompilation:

ThreadGate gate;
        gate = getThreadGate();
        if(gate.isActiveThread())
            break MISSING_BLOCK_LABEL_53;
        gate.enter(-1L);
        evict(((Map) (getFrontMap())));
        evict(getBackMap());
        processDeferredEvents(true);
        gate.exit();
        break MISSING_BLOCK_LABEL_53;
        Exception exception;
        exception;
        gate.exit();
        throw exception;
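
The paired break MISSING_BLOCK_LABEL_53 statements are a decompiler artifact: decompilers such as Jad emit them when they cannot rebuild a try/finally block from the bytecode. Each break MISSING_BLOCK_LABEL_53 marks a normal exit that jumps over the synthetic exception handler, while the handler itself repeats gate.exit() and rethrows, which is exactly the shape the compiler generates for a finally clause. A plausible reconstruction of the original source (a sketch only; the try/finally is inferred from the handler, not read from the actual source):

ThreadGate gate = getThreadGate();
if (!gate.isActiveThread()) {
    gate.enter(-1L);
    try {
        evict((Map) getFrontMap());
        evict(getBackMap());
        processDeferredEvents(true);
    } finally {
        gate.exit(); // runs on both the normal path and the exception path
    }
}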


import os
import shutil
import cv2
import numpy as np
from scipy.spatial import distance_matrix
from glob import glob
from tqdm import tqdm
from ultralytics import YOLO


class AdvancedLabelValidator:
    def __init__(self, template_img_path, template_label_path, pixel_tolerance=0.02, min_matches=10):
        """
        :param template_img_path: path to the template image
        :param template_label_path: path to the template label file (YOLO format)
        :param pixel_tolerance: distance tolerance for the refinement stage (fraction of the image diagonal)
        :param min_matches: minimum number of valid feature-point matches
        """
        # Load the template data
        self.template_img = cv2.imread(template_img_path)
        if self.template_img is None:
            raise ValueError(f"Failed to load template image: {template_img_path}")
        self.template_bboxes = self._load_yolo_bboxes(template_label_path, self.template_img.shape)
        self.pixel_tolerance = pixel_tolerance

        # Initialize the SIFT feature detector and FLANN matcher
        self.sift = cv2.SIFT_create()
        self.flann = cv2.FlannBasedMatcher(
            dict(algorithm=1, trees=5),
            dict(checks=50)
        )
        self.min_matches = min_matches

        # Precompute the template feature points
        self.template_kp, self.template_des = self.sift.detectAndCompute(
            cv2.cvtColor(self.template_img, cv2.COLOR_BGR2GRAY), None
        )

    def _load_yolo_bboxes(self, label_path, img_shape):
        """Load YOLO-format labels and convert them to pixel coordinates."""
        img_h, img_w = img_shape[:2]
        bboxes = []
        if not os.path.exists(label_path):
            raise FileNotFoundError(f"Label file does not exist: {label_path}")
        with open(label_path) as f:
            for line in f.readlines():
                parts = line.strip().split()
                if len(parts) != 5:
                    continue
                _, x_center, y_center, width, height = map(float, parts)
                x1 = int((x_center - width / 2) * img_w)
                y1 = int((y_center - height / 2) * img_h)
                x2 = int((x_center + width / 2) * img_w)
                y2 = int((y_center + height / 2) * img_h)
                bboxes.append([x1, y1, x2, y2])
        return bboxes

    def _align_with_sift(self, test_img):
        """Coarse alignment using SIFT feature points."""
        gray_test = cv2.cvtColor(test_img, cv2.COLOR_BGR2GRAY)
        kp_test, des_test = self.sift.detectAndCompute(gray_test, None)
        if des_test is None:  # guard: no descriptors found in the test image
            return None

        # Feature matching
        matches = self.flann.knnMatch(self.template_des, des_test, k=2)

        # Lowe's ratio test to keep only good matches
        good_matches = []
        for pair in matches:
            if len(pair) < 2:  # guard: knnMatch may return fewer than k matches
                continue
            m, n = pair
            if m.distance < 0.7 * n.distance:
                good_matches.append(m)

        if len(good_matches) < self.min_matches:
            print(f"Warning: only {len(good_matches)} matches found, below the threshold of {self.min_matches}")
            return None

        # Compute the homography matrix
        src_pts = np.float32([self.template_kp[m.queryIdx].pt for m in good_matches]).reshape(-1, 1, 2)
        dst_pts = np.float32([kp_test[m.trainIdx].pt for m in good_matches]).reshape(-1, 1, 2)
        H, _ = cv2.findHomography(src_pts, dst_pts, cv2.RANSAC, 5.0)
        return H

    def _calculate_max_distance(self, img_size):
        """Compute the adaptive distance threshold."""
        return self.pixel_tolerance * np.sqrt(img_size[0] ** 2 + img_size[1] ** 2)

    def validate_labels(self, test_img_path, test_label_path):
        """
        Improved validation pipeline:
        1. Coarse alignment via SIFT feature points
        2. Distance-based refinement matching
        """
        test_img = cv2.imread(test_img_path)
        if test_img is None:
            raise ValueError(f"Failed to load test image: {test_img_path}")
        test_bboxes = self._load_yolo_bboxes(test_label_path, test_img.shape)

        # Stage 1: SIFT coarse alignment
        H = self._align_with_sift(test_img)
        if H is None:
            print("Feature matching failed; falling back to basic distance matching")
            return self._basic_distance_match(test_img, test_bboxes)

        # Warp the template bboxes into the test-image space
        aligned_bboxes = []
        for bbox in self.template_bboxes:
            corners = np.array([
                [[bbox[0], bbox[1]]],
                [[bbox[2], bbox[3]]]
            ], dtype=np.float32)
            transformed = cv2.perspectiveTransform(corners, H)
            x1, y1 = transformed[0][0]
            x2, y2 = transformed[1][0]
            aligned_bboxes.append([x1, y1, x2, y2])

        # Stage 2: distance-based refinement matching
        return self._refined_distance_match(aligned_bboxes, test_bboxes, test_img.shape)

    def _basic_distance_match(self, test_img, test_bboxes):
        """Basic distance matching (no alignment)."""
        template_centers = np.array([[(x1 + x2) / 2, (y1 + y2) / 2]
                                     for x1, y1, x2, y2 in self.template_bboxes])
        test_centers = np.array([[(x1 + x2) / 2, (y1 + y2) / 2]
                                 for x1, y1, x2, y2 in test_bboxes])
        return self._match_centers(template_centers, test_centers, test_img.shape)

    def _refined_distance_match(self, aligned_bboxes, test_bboxes, img_shape):
        """Refined distance matching after alignment."""
        aligned_centers = np.array([[(x1 + x2) / 2, (y1 + y2) / 2]
                                    for x1, y1, x2, y2 in aligned_bboxes])
        test_centers = np.array([[(x1 + x2) / 2, (y1 + y2) / 2]
                                 for x1, y1, x2, y2 in test_bboxes])
        return self._match_centers(aligned_centers, test_centers, img_shape)

    def _match_centers(self, ref_centers, test_centers, img_size):
        """Core matching logic."""
        dist_mat = distance_matrix(ref_centers, test_centers)
        max_dist = self._calculate_max_distance(img_size[:2][::-1])

        matched_pairs = []
        missing_indices = list(range(len(ref_centers)))
        extra_indices = list(range(len(test_centers)))

        # Greedy matching in both directions
        for ref_idx in range(len(ref_centers)):
            if len(test_centers) == 0:
                break
            min_dist_idx = np.argmin(dist_mat[ref_idx])
            min_dist = dist_mat[ref_idx, min_dist_idx]
            if min_dist <= max_dist:
                matched_pairs.append((ref_idx, min_dist_idx, min_dist))
                if ref_idx in missing_indices:
                    missing_indices.remove(ref_idx)
                if min_dist_idx in extra_indices:
                    extra_indices.remove(min_dist_idx)

        return {
            'matched': matched_pairs,
            'missing': missing_indices,
            'extra': extra_indices,
            'max_distance': max_dist
        }

    def visualize_validation(self, test_img_path, test_label_path, output_path=None):
        """Visualize the validation result (enhanced version)."""
        test_img = cv2.imread(test_img_path)
        test_bboxes = self._load_yolo_bboxes(test_label_path, test_img.shape)
        result = self.validate_labels(test_img_path, test_label_path)

        # Draw the aligned template bboxes (semi-transparent blue)
        H = self._align_with_sift(test_img)
        if H is not None:
            for i, bbox in enumerate(self.template_bboxes):
                corners = np.array([[[bbox[0], bbox[1]]], [[bbox[2], bbox[3]]]], dtype=np.float32)
                transformed = cv2.perspectiveTransform(corners, H)
                x1, y1 = transformed[0][0]
                x2, y2 = transformed[1][0]
                overlay = test_img.copy()
                cv2.rectangle(overlay, (int(x1), int(y1)), (int(x2), int(y2)), (255, 0, 0), -1)
                cv2.addWeighted(overlay, 0.2, test_img, 0.8, 0, test_img)
                cv2.putText(test_img, f"Aligned T{i}", (int(x1), int(y1) - 10),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 0, 0), 2)

        # Draw matched detections
        for ref_idx, test_idx, dist in result['matched']:
            bbox = test_bboxes[test_idx]
            cv2.rectangle(test_img, (bbox[0], bbox[1]), (bbox[2], bbox[3]), (0, 255, 0), 2)
            cv2.putText(test_img, f"M(D={dist:.1f})", (bbox[0], bbox[3] + 20),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2)

        # Draw missed detections
        for i in result['missing']:
            if H is not None:
                try:
                    # Build the corner array (top-left and bottom-right corners)
                    corners = np.array([[
                        [self.template_bboxes[i][0], self.template_bboxes[i][1]],
                        [self.template_bboxes[i][2], self.template_bboxes[i][3]]
                    ]], dtype=np.float32)
                    # Apply the perspective transform
                    transformed = cv2.perspectiveTransform(corners, H)
                    # Extract the transformed coordinates
                    (x1, y1), (x2, y2) = transformed[0]
                    # Draw a red box
                    cv2.rectangle(test_img, (int(x1), int(y1)), (int(x2), int(y2)),
                                  (0, 0, 255), 2, cv2.LINE_AA)
                    cv2.putText(test_img, "Missing", (int(x1), int(y1) - 10),
                                cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 255), 2)
                except Exception as e:
                    print(f"Error while drawing a missing region: {str(e)}")
                    continue

        # Draw false detections
        for i in result['extra']:
            bbox = test_bboxes[i]
            cv2.rectangle(test_img, (bbox[0], bbox[1]), (bbox[2], bbox[3]), (0, 255, 255), 2)
            cv2.putText(test_img, "Extra", (bbox[0], bbox[3] + 20),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 255), 2)

        if output_path:
            os.makedirs(os.path.dirname(output_path), exist_ok=True)
            cv2.imwrite(output_path, test_img)
        return test_img


def validate_dataset(template_img_path, template_label_path, test_img_dir, test_label_dir, validate_result):
    """
    Full validation pipeline:
    1. Load the template
    2. Iterate over the test set and validate
    3. Generate the report
    """
    validator = AdvancedLabelValidator(template_img_path, template_label_path)
    test_images = sorted(glob(os.path.join(test_img_dir, "*.*")))
    results = []
    for img_path in tqdm(test_images, desc="Validating"):
        base_name = os.path.splitext(os.path.basename(img_path))[0]
        label_path = os.path.join(test_label_dir, f"{base_name}.txt")
        if not os.path.exists(label_path):
            print(f"Warning: missing label file {label_path}")
            continue
        try:
            # Validate and save the result
            os.makedirs(validate_result, exist_ok=True)
            output_path = os.path.join(validate_result, f"validated_{base_name}.jpg")
            # Run the validation
            validator.visualize_validation(img_path, label_path, output_path)
            result = validator.validate_labels(img_path, label_path)
            results.append({
                'image': base_name,
                'matched': len(result['matched']),
                'missing': len(result['missing']),
                'extra': len(result['extra']),
                'result_image': output_path
            })
        except Exception as e:
            print(f"Error while processing {img_path}: {str(e)}")
    generate_report(results)
    return results


def generate_report(results):
    """Generate an interactive HTML report."""
    html = f"""<!DOCTYPE html>
<html>
<head>
    <title>Solder Joint Inspection Report</title>
    <style>
        body {{ font-family: Arial; margin: 20px; }}
        .summary {{ background: #f0f0f0; padding: 15px; border-radius: 5px; }}
        table {{ width: 100%; border-collapse: collapse; margin-top: 20px; }}
        th, td {{ padding: 10px; border: 1px solid #ddd; text-align: center; }}
        th {{ background: #4CAF50; color: white; }}
        tr:nth-child(even) {{ background: #f2f2f2; }}
        img {{ max-width: 300px; cursor: pointer; transition: transform 0.3s; }}
        img:hover {{ transform: scale(1.5); }}
        .modal {{ display: none; position: fixed; z-index: 1; padding-top: 100px; left: 0; top: 0; width: 100%; height: 100%; overflow: auto; background-color: rgba(0,0,0,0.9); }}
        .modal-content {{ margin: auto; display: block; max-width: 80%; }}
        .close {{ position: absolute; top: 15px; right: 35px; color: #f1f1f1; font-size: 40px; font-weight: bold; cursor: pointer; }}
    </style>
</head>
<body>
    <h1>Solder Joint Quality Report</h1>
    <div class="summary">
        <h2>Summary</h2>
        <p>Images inspected: {len(results)}</p>
        <p>Matched solder joints: {sum([r['matched'] for r in results])}</p>
        <p>Average missed detections: {np.mean([r['missing'] for r in results]):.1f}</p>
        <p>Average false detections: {np.mean([r['extra'] for r in results]):.1f}</p>
    </div>
    <table>
        <tr>
            <th>Image</th>
            <th>Matched</th>
            <th>Missing</th>
            <th>Extra</th>
            <th>Preview</th>
        </tr>
"""
    for r in results:
        html += f"""
        <tr>
            <td>{r['image']}</td>
            <td>{r['matched']}</td>
            <td>{r['missing']}</td>
            <td>{r['extra']}</td>
            <td><img src="{r['result_image']}" onclick="openModal('{r['result_image']}')"></td>
        </tr>
"""
    html += """
    </table>
    <div id="imageModal" class="modal">
        <span class="close" onclick="closeModal()">×</span>
        <img class="modal-content" id="modalImage">
    </div>
    <script>
        function openModal(src) {
            document.getElementById('modalImage').src = src;
            document.getElementById('imageModal').style.display = "block";
        }
        function closeModal() {
            document.getElementById('imageModal').style.display = "none";
        }
    </script>
</body>
</html>
"""
    with open("validation_report.html", "w") as f:
        f.write(html)
    print("Report generated: validation_report.html")


def detect_and_save_labels(input_dir, output_label_dir, model_path):
    """Run detection and save the labels to the given directory."""
    model = YOLO(model_path)
    # Temporary directory (YOLO creates a nested output directory by default)
    temp_dir = os.path.join("yolo_output", "output")
    # Run detection (results are saved into the temporary directory)
    results = model.predict(
        source=input_dir,
        project="yolo_output",
        name='output',
        save_txt=True,
        save=True,
        conf=0.336,
        iou=0.2
    )
    # Copy the label files to the target directory
    os.makedirs(output_label_dir, exist_ok=True)
    for txt_file in glob(os.path.join(temp_dir, "labels", "*.txt")):
        shutil.copy(txt_file, output_label_dir)
    # Clean up the temporary directory
    shutil.rmtree("yolo_output")
    print(f"Labels saved to: {output_label_dir}")


if __name__ == "__main__":
    # Configure paths
    template_img = "1/moban/IMG_20250617_094520.jpg"
    template_label = "1/moban/IMG_20250617_094520.txt"
    test_img_dir = "1/images"
    test_label_dir = "1/labels"
    validate_result = "1/result"
    model_path = 'runs/detect/train28/weights/last.pt'

    # 1. Detect with YOLO and generate the labels
    detect_and_save_labels(test_img_dir, test_label_dir, model_path)
    # 2. Validate
    validate_dataset(template_img, template_label, test_img_dir, test_label_dir, validate_result)

Please optimize the alignment strategy: during alignment, first do a coarse alignment with SIFT features, rotating and translating the template image so that it lines up with the image under test; then align the rotated and translated template labels against the test image's labels, with the center point of each target box as the main matching criterion. Improve the alignment approach along these lines and produce a standalone version of the code used just for alignment.
import os
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import monai
from monai.transforms import (
    Compose, LoadImaged, EnsureChannelFirstd, Spacingd, Orientationd,
    ScaleIntensityRanged, RandCropByPosNegLabeld, RandFlipd, RandRotate90d,
    EnsureTyped, Activations, AsDiscrete, Resized, RandZoomd,
    RandGaussianNoised, CenterSpatialCropd
)
from monai.data import list_data_collate, decollate_batch
from monai.networks.nets import SwinUNETR
from monai.losses import DiceCELoss
from monai.metrics import DiceMetric
from glob import glob
from sklearn.model_selection import train_test_split
from monai.data import PersistentDataset
from torch.optim.lr_scheduler import LambdaLR
from tqdm import tqdm
from torch.cuda.amp import GradScaler, autocast
import matplotlib.pyplot as plt

# ======= Configuration =======
root_dir = "datasets/LiTS/processed"
images_dir = os.path.join(root_dir, "images")
labels_dir = os.path.join(root_dir, "labels")
max_epochs = 200
batch_size = 1
learning_rate = 1e-4
num_classes = 3  # background, liver, tumor
warmup_epochs = 10
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
use_amp = False  # automatic mixed-precision training


def get_valid_size(size):
    return tuple([max(32, (s // 32) * 32) for s in size])


base_size = (128, 128, 64)
resized_size = get_valid_size(base_size)
crop_size = get_valid_size((64, 64, 32))
print(f"Using sizes: Resized={resized_size}, Crop={crop_size}")


# ======= Cross-modal attention fusion module =======
class LightCrossAttentionFusion(nn.Module):
    def __init__(self, img_feat_dim=192, text_feat_dim=512, num_heads=2):
        super().__init__()
        self.num_heads = num_heads
        self.img_feat_dim = img_feat_dim
        self.text_feat_dim = text_feat_dim
        self.query_proj = nn.Linear(img_feat_dim, img_feat_dim)
        self.key_proj = nn.Linear(text_feat_dim, img_feat_dim)
        self.value_proj = nn.Linear(text_feat_dim, img_feat_dim)
        self.out_proj = nn.Linear(img_feat_dim, img_feat_dim)

    def forward(self, img_feat, text_feat):
        B, C, D, H, W = img_feat.shape
        N = D * H * W
        img_feat_flat = img_feat.view(B, C, N).permute(0, 2, 1)  # (B, N, C)
        Q = self.query_proj(img_feat_flat)           # (B, N, C)
        K = self.key_proj(text_feat).unsqueeze(1)    # (B, 1, C)
        V = self.value_proj(text_feat).unsqueeze(1)  # (B, 1, C)
        head_dim = C // self.num_heads
        Q = Q.view(B, N, self.num_heads, head_dim).permute(0, 2, 1, 3)
        K = K.view(B, 1, self.num_heads, head_dim).permute(0, 2, 1, 3)
        V = V.view(B, 1, self.num_heads, head_dim).permute(0, 2, 1, 3)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (head_dim ** 0.5)
        # Note: the softmax runs over the query axis (dim=-2); with a single text
        # token, a softmax over the key axis (dim=-1) would be identically 1.
        attn = torch.softmax(scores, dim=-2)
        context = torch.matmul(attn, V)
        context = context.permute(0, 2, 1, 3).contiguous().view(B, N, C)
        out = self.out_proj(context)
        out = out.permute(0, 2, 1).view(B, C, D, H, W)
        fused = img_feat + out
        return fused


# ======= SwinUNETR subclass with CLIP fusion =======
class SwinUNETRWithCLIPFusion(SwinUNETR):
    def __init__(self, img_size, in_channels, out_channels, feature_size=12,
                 use_checkpoint=False, text_feat_dim=512):
        super().__init__(img_size=img_size, in_channels=in_channels,
                         out_channels=out_channels, feature_size=feature_size,
                         use_checkpoint=use_checkpoint)
        fusion_img_feat_dim = feature_size * 16  # originally 192
        self.fusion = LightCrossAttentionFusion(img_feat_dim=fusion_img_feat_dim,
                                                text_feat_dim=text_feat_dim)
        self.fusion_reduce = nn.Conv3d(fusion_img_feat_dim, feature_size, kernel_size=1)
        # Reduce the skip-connection channels to feature_size as well
        self.skip4_reduce = nn.Conv3d(96, feature_size, kernel_size=1)
        self.skip3_reduce = nn.Conv3d(48, feature_size, kernel_size=1)
        self.skip2_reduce = nn.Conv3d(24, feature_size, kernel_size=1)
        self.skip1_reduce = nn.Conv3d(12, feature_size, kernel_size=1)

        # Custom decoder: every block takes in=feature_size, skip=feature_size
        from monai.networks.blocks import UnetUpBlock
        self.decoder1 = UnetUpBlock(
            spatial_dims=3,
            in_channels=feature_size,
            out_channels=feature_size,
            kernel_size=3,            # convolution kernel size
            stride=1,                 # no downsampling on the main path
            upsample_kernel_size=2,   # upsampling factor
            norm_name="instance",
        )
        self.decoder2 = UnetUpBlock(3, feature_size, feature_size, kernel_size=3,
                                    stride=1, upsample_kernel_size=2, norm_name="instance")
        self.decoder3 = UnetUpBlock(3, feature_size, feature_size, kernel_size=3,
                                    stride=1, upsample_kernel_size=2, norm_name="instance")
        self.decoder4 = UnetUpBlock(3, feature_size, feature_size, kernel_size=3,
                                    stride=1, upsample_kernel_size=2, norm_name="instance")
        self.decoder5 = nn.ConvTranspose3d(feature_size, feature_size, kernel_size=2, stride=2)

    def forward(self, x, text_feat=None):
        enc_out_list = self.swinViT(x)
        if not hasattr(self, "printed_shapes"):
            print("enc_out_list channel counts:", [feat.shape[1] for feat in enc_out_list])
            print("enc_out_list feature map shapes:", [feat.shape for feat in enc_out_list])
            self.printed_shapes = True
        enc_out = enc_out_list[-1]  # [B, 192, 2, 2, 1]
        if text_feat is not None:
            enc_out = self.fusion(enc_out, text_feat)
        enc_out = self.fusion_reduce(enc_out)  # [B, 12, ...]

        # Channel-reduced skip connections
        skip4 = self.skip4_reduce(enc_out_list[-2])  # 96 -> 12
        skip3 = self.skip3_reduce(enc_out_list[-3])  # 48 -> 12
        skip2 = self.skip2_reduce(enc_out_list[-4])  # 24 -> 12
        skip1 = self.skip1_reduce(enc_out_list[-5])  # 12 -> 12

        # Decoder path (all channels are feature_size)
        d1 = self.decoder1(enc_out, skip4)
        d2 = self.decoder2(d1, skip3)
        d3 = self.decoder3(d2, skip2)
        d4 = self.decoder4(d3, skip1)
        d5 = self.decoder5(d4)
        out = self.out(d5)
        return out


# ======= Data preprocessing =======
train_transforms = Compose([
    LoadImaged(keys=["image", "label"]),
    EnsureChannelFirstd(keys=["image", "label"]),
    Orientationd(keys=["image", "label"], axcodes="RAS"),
    Spacingd(keys=["image", "label"], pixdim=(1.5, 1.5, 2.0), mode=("bilinear", "nearest")),
    ScaleIntensityRanged(keys=["image"], a_min=-200, a_max=200, b_min=0.0, b_max=1.0, clip=True),
    Resized(keys=["image", "label"], spatial_size=resized_size, mode=("trilinear", "nearest")),
    RandCropByPosNegLabeld(
        keys=["image", "label"],
        label_key="label",
        spatial_size=crop_size,
        pos=1.0,
        neg=1.0,
        num_samples=1,
        image_key="image",
        image_threshold=0,
    ),
    RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
    RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=1),
    RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=2),
    RandRotate90d(keys=["image", "label"], prob=0.5, max_k=3),
    RandZoomd(keys=["image", "label"], prob=0.5, min_zoom=0.8, max_zoom=1.2, mode=("trilinear", "nearest")),
    RandGaussianNoised(keys=["image"], prob=0.2, mean=0.0, std=0.1),
    EnsureTyped(keys=["image", "label"]),
])

val_transforms = Compose([
    LoadImaged(keys=["image", "label"]),
    EnsureChannelFirstd(keys=["image", "label"]),
    Orientationd(keys=["image", "label"], axcodes="RAS"),
    Spacingd(keys=["image", "label"], pixdim=(1.5, 1.5, 2.0), mode=("bilinear", "nearest")),
    ScaleIntensityRanged(keys=["image"], a_min=-200, a_max=200, b_min=0.0, b_max=1.0, clip=True),
    Resized(keys=["image", "label"], spatial_size=resized_size, mode=("trilinear", "nearest")),
    CenterSpatialCropd(keys=["image", "label"], roi_size=crop_size),
    EnsureTyped(keys=["image", "label"]),
])

images = sorted(glob(os.path.join(images_dir, "*.nii.gz")))
labels = sorted(glob(os.path.join(labels_dir, "*.nii.gz")))
data = [{"image": img, "label": lbl} for img, lbl in zip(images, labels)]
train_files, val_files = train_test_split(data, test_size=0.2, random_state=42)
train_ds = PersistentDataset(data=train_files, transform=train_transforms, cache_dir="./cache/train")
val_ds = PersistentDataset(data=val_files, transform=val_transforms, cache_dir="./cache/val")
train_loader = DataLoader(
    train_ds, batch_size=batch_size, shuffle=True, num_workers=4,
    collate_fn=list_data_collate, pin_memory=True
)
val_loader = DataLoader(
    val_ds, batch_size=1, shuffle=False, num_workers=2,
    collate_fn=list_data_collate, pin_memory=True
)

# ======= Load pre-extracted text features =======
clip_text_features = np.load("./clip_text_features.npy")  # shape (num_prompts, 512)
clip_text_features = torch.from_numpy(clip_text_features).float()


def get_text_features_for_batch(batch_size, clip_text_features):
    if clip_text_features.shape[0] >= batch_size:
        return clip_text_features[:batch_size].to(device)
    else:
        return clip_text_features.repeat(batch_size, 1).to(device)


# ======= Model, loss, optimizer, scheduler =======
model = SwinUNETRWithCLIPFusion(
    img_size=crop_size,
    in_channels=1,
    out_channels=num_classes,
    feature_size=12,
    use_checkpoint=True,
    text_feat_dim=512,
).to(device)

class_weights = torch.tensor([0.2, 0.3, 0.5]).to(device)
loss_function = DiceCELoss(
    to_onehot_y=True,
    softmax=True,
    include_background=True,
    ce_weight=class_weights,
    lambda_dice=0.5,
    lambda_ce=0.5
)
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=1e-5)


def lr_lambda(epoch):
    # Linear warmup over warmup_epochs, then cosine decay to zero at max_epochs
    # (e.g. after epoch 1 the factor is 2/10, i.e. lr = 2e-5, as the log shows).
    if epoch < warmup_epochs:
        return (epoch + 1) / warmup_epochs
    progress = (epoch - warmup_epochs) / (max_epochs - warmup_epochs)
    return 0.5 * (1 + np.cos(np.pi * progress))


scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda)

post_pred = Compose([Activations(softmax=True), AsDiscrete(argmax=True, to_onehot=num_classes)])
post_label = Compose([AsDiscrete(to_onehot=num_classes)])
dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False, num_classes=num_classes)
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

best_metric = -1
best_metric_epoch = -1
train_loss_history = []
val_dice_history = []
os.makedirs("fusion_checkpoints", exist_ok=True)
os.makedirs("logs", exist_ok=True)

for epoch in range(max_epochs):
    print(f"\nEpoch {epoch+1}/{max_epochs}")
    model.train()
    epoch_loss = 0
    step = 0
    with tqdm(total=len(train_loader), desc=f"Train Epoch {epoch+1}") as pbar:
        for batch_data in train_loader:
            step += 1
            inputs = batch_data["image"].to(device)
            labels = batch_data["label"].to(device)
            batch_size_now = inputs.shape[0]
            text_feat = get_text_features_for_batch(batch_size_now, clip_text_features)
            optimizer.zero_grad()
            with autocast(enabled=use_amp):
                outputs = model(inputs, text_feat)
                loss = loss_function(outputs, labels)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            epoch_loss += loss.item()
            pbar.update(1)
            pbar.set_postfix({"loss": f"{loss.item():.4f}"})
    epoch_loss /= step
    train_loss_history.append(epoch_loss)
    print(f"Mean training loss: {epoch_loss:.4f}")
    scheduler.step()
    current_lr = optimizer.param_groups[0]['lr']
    print(f"Current learning rate: {current_lr:.7f}")

    model.eval()
    dice_values = []
    with torch.no_grad(), tqdm(total=len(val_loader), desc=f"Val Epoch {epoch+1}") as pbar:
        for val_data in val_loader:
            val_images = val_data["image"].to(device)
            val_labels = val_data["label"].to(device)
            batch_size_val = val_images.shape[0]
            text_feat_val = get_text_features_for_batch(batch_size_val, clip_text_features)
            with autocast(enabled=use_amp):
                val_outputs = model(val_images, text_feat_val)
            val_outputs_list = decollate_batch(val_outputs.detach().cpu())
            val_labels_list = decollate_batch(val_labels.detach().cpu())
            val_output_convert = [post_pred(x) for x in val_outputs_list]
            val_label_convert = [post_label(x) for x in val_labels_list]
            dice_metric(y_pred=val_output_convert, y=val_label_convert)
            metric = dice_metric.aggregate().item()
            dice_values.append(metric)
            dice_metric.reset()
            pbar.update(1)
            pbar.set_postfix({"dice": f"{metric:.4f}"})
    avg_metric = np.mean(dice_values)
    val_dice_history.append(avg_metric)
    print(f"Mean validation Dice: {avg_metric:.4f}")

    if avg_metric > best_metric:
        best_metric = avg_metric
        best_metric_epoch = epoch + 1
        torch.save(model.state_dict(),
                   f"fusion_checkpoints/best_model_epoch{best_metric_epoch}_dice{best_metric:.4f}.pth")
        print(f"Saved new best model! Epoch: {best_metric_epoch}, Dice: {best_metric:.4f}")
    if (epoch + 1) % 10 == 0:
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': epoch_loss,
            'dice': avg_metric
        }, f"fusion_checkpoints/checkpoint_epoch_{epoch+1}.pth")

    plt.figure(figsize=(12, 6))
    plt.subplot(1, 2, 1)
    plt.plot(train_loss_history, label='训练损失')  # these CJK labels trigger the font warnings below
    plt.title('训练损失')
    plt.xlabel('Epoch')
    plt.ylabel('损失')
    plt.legend()
    plt.subplot(1, 2, 2)
    plt.plot(val_dice_history, label='验证Dice', color='orange')
    plt.title('验证Dice系数')
    plt.xlabel('Epoch')
    plt.ylabel('Dice')
    plt.legend()
    plt.tight_layout()
    plt.savefig("logs/fusion_training_metrics.png")
    plt.close()

print(f"\nTraining finished! Best Dice: {best_metric:.4f} at epoch {best_metric_epoch}")

There is something wrong with this code's output:

Epoch 1/200
Train Epoch 1:   0%| 0/104 [00:00<?, ?it/s]
enc_out_list channel counts: [12, 24, 48, 96, 192]
enc_out_list feature map shapes: [torch.Size([1, 12, 32, 32, 16]), torch.Size([1, 24, 16, 16, 8]), torch.Size([1, 48, 8, 8, 4]), torch.Size([1, 96, 4, 4, 2]), torch.Size([1, 192, 2, 2, 1])]
Train Epoch 1: 100%| 104/104 [00:14<00:00, 7.14it/s, loss=0.7403]
Mean training loss: 0.7277
Current learning rate: 0.0000200
Val Epoch 1: 100%| 26/26 [00:05<00:00, 4.48it/s, dice=0.0361]
Mean validation Dice: 0.0469
Saved new best model! Epoch: 1, Dice: 0.0469
/home/liulicheng/MultiModal_MedSeg_2025/train/train_swinunetr_clipfusion.py:350: UserWarning: Glyph 25439 (\N{CJK UNIFIED IDEOGRAPH-635F}) missing from current font.
  plt.tight_layout()
/home/liulicheng/MultiModal_MedSeg_2025/train/train_swinunetr_clipfusion.py:351: UserWarning: Glyph 25439 (\N{CJK UNIFIED IDEOGRAPH-635F}) missing from current font.
  plt.savefig("logs/fusion_training_metrics.png")
[the same UserWarning repeats at :350 (plt.tight_layout) and :351 (plt.savefig) for each CJK glyph in the plot labels: 损 失 训 练 验 证 系 数]
Epoch 2/200
Train Epoch 2: 100%| 104/104 [00:13<00:00, 7.71it/s, loss=0.7325]
Mean training loss: 0.7216
Current learning rate: 0.0000300
Val Epoch 2: 100%| 26/26 [00:05<00:00, 4.44it/s, dice=0.0592]
Mean validation Dice: 0.0790
Saved new best model! Epoch: 2, Dice: 0.0790
Epoch 3/200
Train Epoch 3: 100%| 104/104 [00:13<00:00, 7.72it/s, loss=0.7141]
Mean training loss: 0.7082
Current learning rate: 0.0000400
Val Epoch 3: 100%| 26/26 [00:05<00:00, 4.48it/s, dice=0.1195]
Mean validation Dice: 0.1622
Saved new best model! Epoch: 3, Dice: 0.1622
Epoch 4/200
Train Epoch 4: 100%| 104/104 [00:13<00:00, 7.67it/s, loss=0.6731]
Mean training loss: 0.6850
Current learning rate: 0.0000500
Val Epoch 4: 100%| 26/26 [00:05<00:00, 4.50it/s, dice=0.2126]
Mean validation Dice: 0.2260
Saved new best model! Epoch: 4, Dice: 0.2260
Epoch 5/200
Train Epoch 5: 100%| 104/104 [00:13<00:00, 7.81it/s, loss=0.6340]
Mean training loss: 0.6579
Current learning rate: 0.0000600
Val Epoch 5: 100%| 26/26 [00:05<00:00, 4.38it/s, dice=0.2866]
Mean validation Dice: 0.3043
Saved new best model! Epoch: 5, Dice: 0.3043
Epoch 6/200
Train Epoch 6: 100%| 104/104 [00:13<00:00, 7.64it/s, loss=0.5915]
Mean training loss: 0.6187
Current learning rate: 0.0000700
Val Epoch 6: 100%| 26/26 [00:05<00:00, 4.56it/s, dice=0.3499]
Mean validation Dice: 0.3778
Saved new best model! Epoch: 6, Dice: 0.3778
Epoch 7/200
Train Epoch 7: 100%| 104/104 [00:13<00:00, 7.58it/s, loss=0.5379]
Mean training loss: 0.5679
Current learning rate: 0.0000800
Val Epoch 7: 100%| 26/26 [00:05<00:00, 4.50it/s, dice=0.3741]
Mean validation Dice: 0.4120
Saved new best model! Epoch: 7, Dice: 0.4120
Epoch 8/200
Train Epoch 8: 100%| 104/104 [00:13<00:00, 7.53it/s, loss=0.4831]
Mean training loss: 0.5142
Current learning rate: 0.0000900
Val Epoch 8: 100%| 26/26 [00:05<00:00, 4.53it/s, dice=0.4267]
Mean validation Dice: 0.4663
Saved new best model! Epoch: 8, Dice: 0.4663
Epoch 9/200
Train Epoch 9: 100%| 104/104 [00:13<00:00, 7.58it/s, loss=0.4233]
Mean training loss: 0.4585
Current learning rate: 0.0001000
Val Epoch 9: 100%| 26/26 [00:05<00:00, 4.64it/s, dice=0.3302]
Mean validation Dice: 0.3476
Epoch 10/200
Train Epoch 10: 100%| 104/104 [00:13<00:00, 7.71it/s, loss=0.3787]
Mean training loss: 0.4152
Current learning rate: 0.0001000
Val Epoch 10: 100%| 26/26 [00:05<00:00, 4.65it/s, dice=0.3302]
Mean validation Dice: 0.3476
Epoch 11/200
Train Epoch 11: 100%| 104/104 [00:13<00:00, 7.68it/s, loss=0.3197]
Mean training loss: 0.3688
Current learning rate: 0.0001000
Val Epoch 11: 100%| 26/26 [00:05<00:00, 4.59it/s, dice=0.3302]
Mean validation Dice: 0.3476
Epoch 12/200
Train Epoch 12: 100%| 104/104 [00:13<00:00, 7.69it/s, loss=0.4502]
Mean training loss: 0.3400
Current learning rate: 0.0001000
Val Epoch 12: 100%| 26/26 [00:05<00:00, 4.45it/s, dice=0.3302]
Mean validation Dice: 0.3476
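
The UserWarning spam in the log is a separate issue from the Dice plateau: Matplotlib's default font simply has no glyphs for the Chinese plot labels ('训练损失', '验证Dice系数', and so on). A minimal sketch of the usual fix, assuming a CJK-capable font such as Noto Sans CJK SC or SimHei is installed on the machine:

import matplotlib.pyplot as plt

# Point Matplotlib at a font that contains CJK glyphs; adjust the list to
# whatever `fc-list :lang=zh` reports on your system.
plt.rcParams["font.sans-serif"] = ["Noto Sans CJK SC", "SimHei", "Microsoft YaHei"]
plt.rcParams["axes.unicode_minus"] = False  # keep the minus sign renderable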