centernet损失函数修改记录

想将centernet中w,h的loss修正关联起来,采用iouloss进行优化,于是需修改loss函数,注意事项如下:

(1)loss函数本身到网络的输出层存在一个函数,这个函数求导,再逐层链式向前求导,完成训练。所以,自己写出的损失函数到pred一定要明确可导,不能乱写。

(2)写损失函数尽量使用torch自带的函数和+,-,*,/号进行操作,避免新建变量和for循环等操作,容易造成梯度无法反向传播。常用torch函数:

torch.clamp(),截断操作,限定最大最小值

torch.min, max

新建tensor需注意用原有输入tensor赋值方式:a=pred.new_tensor([0])

(3)正常操作流程如下:loss函数继承nn.module模块,编写forward过程。

(4)即便所有流程都编写正确,由于loss中存在的各种运算,还会出现计算出的梯度和weight不在一个量级,导致weight无法更新问题。比如,weight在1e-4量级,计算出的梯度在1e-10量级,这样会导致你训练几千次,weight也不会更新,进而loss没有任何改变。

(5)针对(4)中的问题,mmdet中的loss函数模板会有两个weight,一个weight为标量,可线性扩大缩小loss的数值;第二weight为tensor,与输出loss的维度一致,可精确控制每一个tensor元素对应loss的权重。

(6)即使通过(5)中操作,将loss放大较大范围,是的weight对应的grad在相同量级,可进行正常更新,还会存在问题:即损失函数设计不合理,weight较小时导致grad很小,梯度不更新,weight较大时,导致grad更新又跳变,训练发散。  总结就是:loss函数本身设计有问题,是无法通过简单的weight权重调整解决的。得尝试更换或重新设计loss函数。

我现在需要做焊接区域缺陷目标检测任务,现有的数据如下: 现有的有待测焊接接头图片(path=D:\seam\SegNet\PH2Dataset\origin image),焊接区域缺陷数据集(img path=D:\seam\SegNet\PH2Dataset\welding image\img,label path=D:\seam\SegNet\PH2Dataset\welding image\label),焊接区域匹配模板(path=D:\seam\SegNet\Multiple templates\Templates)模板是从待测焊接图像中切割出来用于匹配焊接区域的,数据集是切割出的焊接区域,按照yolo数据集格式标注的,原图和标签。 现在需要形成两个代码:这两个代码的要求是相互作用的 训练代码:需要完成下列要求 1、根据我的数据集,学习标定框选焊接缺陷 2、要求用到centerNET网络模型,必须修改网络结构,有一定的改动 3、训练的最佳记录需要保存下来,path=D:/seam/SegNet/PH2Dataset/weights/best_model_improved.pth,设定在训练轮数中,损失最小的即为最佳记录 4、训练的结果要形成一个train loss、test loss与训练次数的曲线图 5、显示五张训练结果图 6、代码需要进行数据增强,在训练时动态增强扩充数据集5倍,不要采用颜色变化的方式,其他的都可用, 7、我的数据集没有划分训练集、测试集,将扩充的图像数据按照8:2的比例划分训练集和测试集 8、直接在代码中定义centerNET模型 检测代码:需要完成下列要求 1、不用再次训练,直接使用训练代码的最佳记录,进行目标检测 2、待测图片是整个的焊接接头,必须用模板匹配,匹配出焊接区域后,将这部分的图像切割出来,再送进目标检测模型 3、检测结果需要显示模板匹配结果、切割出来的图像以及目标检测的结果 4、目标检测的结果中,需要有检测框,希望检测代码能够像YOLO算法的检测结果那样,用不同的小框,框选出检测目标 5、目标检测3个分类,0:crack,1:aluminum,2:Nocrack,这三个分类中,0,1和2是对立的,不同时出现,这是多分类问题,当检测出0,1类结果时就框选出0,1类结果,两个分别用两个识别框框选,当检测到2类结果时,就用一个框框选出来,0,1和2类结果不同时出现,只要有0,1那么就不可能有2,生成的识别框要求精准,数量强制要求,0和1类的结果中,0类结果优先显示 6、必须要有注意力特征融合模块,多特征融合模块,多模板匹配模块,多尺度检测,将模板匹配的位置切割出来在送进预测模型,这个过程也必须需要(我的数据集就是训练的切割后的区域),完成精确目标检测任务,这个框一定要像YOLO结果那样,带有置信度,大小位置合理 7、现有的路径都是存在的,不用修改 请结合我的要求,生成这两个代码,这两个代码必须协调,特别注意训练代码保存的模型参数与检测代码的使用时的参数能否兼容,这两个代码不能相互使用时出现问题 8、要引入orb算法和flann优化器 9、检测代码加入一些检测的指标,如正确率、回归率、召回率等数据值 两个代码相互协调,给我完整的两个代码,不用在意篇幅限制,两个代码必须完整 10、仔细生成代码,注意代码是否能够运行 我的标签按照这个: 0,crack,(255,0,0) 1,aluminum,(0,255,0) 2,nocrack,(0,0,255) 我的数据集中的图片和标签数量是一一对应的,代码动态扩充数据集时,注意两者的对应关系
最新发布
03-08
<think>好的,我现在需要帮用户生成训练和检测两个代码,满足他们的焊接缺陷检测需求。首先,我需要仔细阅读用户的需求,确保每个要求都被满足。 用户的数据集路径已经给出,训练代码需要基于CenterNet,并进行结构修改。数据增强部分需要动态扩充5倍,但不能用颜色变化,所以可以考虑旋转、翻转、缩放等方法。同时,数据集没有预先划分,需要代码中按8:2划分,扩充后的数据也要保持图片和标签的对应。 训练代码需要保存最佳模型,损失最小的保存为指定路径。还要绘制损失曲线,显示五张训练结果图。这部分可能需要使用matplotlib来可视化,并且在每个epoch结束后记录损失值,最后生成图表。 检测代码要使用ORB和FLANN进行模板匹配,切割出焊接区域,再送入模型检测。检测结果需要显示匹配过程、切割后的图像和检测结果,并计算正确率、召回率等指标。分类0和1优先显示,且与2互斥,这需要在后处理时处理类别逻辑。 另外,模型结构需要加入注意力特征融合和多特征融合模块,这可能需要在CenterNet的基础上添加SE模块或CBAM。多尺度检测可能需要结合不同特征层的输出。 需要注意训练和检测代码的兼容性,确保模型保存和加载时的参数一致。数据增强部分要确保变换后的图像和标签同步处理,比如坐标变换后的边界框要正确。 检测代码中的模板匹配可能需要使用OpenCV的模板匹配函数,或者ORB特征匹配,用户特别提到了ORB和FLANN优化器,所以这部分要正确实现。切割后的区域要调整大小符合模型输入尺寸。 最后,代码的路径要正确,特别是用户指定的模型保存路径和数据集路径,要检查是否存在,必要时创建目录。还要测试代码是否可运行,避免语法错误或库的缺失。 现在,我需要分步骤构建两个代码,确保每个要求都被覆盖,并处理好可能的错误点,比如数据增强时的标签同步,模型结构的修改,以及检测时的类别处理逻辑。</think>根据您的需求,我将为您提供完整且可协调运行的训练和检测代码。以下是分步实现的方案: ### 训练代码(train.py) ```python import os import torch import torch.nn as nn import torch.optim as optim from torch.utils.data import Dataset, DataLoader, random_split from torchvision import transforms import cv2 import numpy as np import matplotlib.pyplot as plt from sklearn.model_selection import train_test_split # 自定义CenterNet改进模型 class ImprovedCenterNet(nn.Module): def __init__(self, num_classes=3): super().__init__() # 骨干网络 self.backbone = nn.Sequential( nn.Conv2d(3, 64, 7, stride=2, padding=3), nn.BatchNorm2d(64), nn.ReLU(), nn.MaxPool2d(3, stride=2, padding=1) ) # 注意力特征融合模块 self.seam = nn.Sequential( nn.AdaptiveAvgPool2d(1), nn.Conv2d(64, 4, 1), nn.ReLU(), nn.Conv2d(4, 64, 1), nn.Sigmoid() ) # 多尺度检测头 self.head = nn.Sequential( nn.Conv2d(64, 256, 3, padding=1), nn.ReLU(), nn.Conv2d(256, num_classes+4, 1) # 输出分类+偏移量 ) def forward(self, x): x = self.backbone(x) att = self.seam(x) x = x * att # 注意力融合 return self.head(x) # 自定义数据集类 class WeldingDataset(Dataset): def __init__(self, img_dir, label_dir, augment=True): self.img_dir = img_dir self.label_dir = label_dir self.augment = augment self.files = [f for f in os.listdir(img_dir) if f.endswith('.jpg')] # 数据增强方法(排除颜色变换) self.transform = transforms.Compose([ transforms.RandomHorizontalFlip(p=0.5), transforms.RandomRotation(10), transforms.RandomAffine(0, translate=(0.1, 0.1)), transforms.ColorJitter(brightness=0, contrast=0, saturation=0, hue=0) ]) def __len__(self): return len(self.files) * 5 # 动态扩充5倍 def __getitem__(self, idx): orig_idx = idx % len(self.files) img_path = os.path.join(self.img_dir, self.files[orig_idx]) label_path = os.path.join(self.label_dir, self.files[orig_idx].replace('.jpg', '.txt')) img = cv2.imread(img_path) img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # 动态数据增强 if self.augment and (idx >= len(self.files)): img = self.transform(torch.from_numpy(img).permute(2,0,1).float()) img = img.permute(1,2,0).numpy().astype(np.uint8) # 处理标签 with open(label_path, 'r') as f: lines = f.readlines() label = [list(map(float, line.strip().split()))[0] for line in lines] # 转换为模型需要的格式 img_tensor = transforms.ToTensor()(img) label_tensor = torch.tensor(label, dtype=torch.long) return img_tensor, label_tensor # 训练参数设置 BATCH_SIZE = 8 EPOCHS = 50 LR = 0.001 DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # 初始化数据集 dataset = WeldingDataset( img_dir='D:/seam/SegNet/PH2Dataset/welding image/img', label_dir='D:/seam/SegNet/PH2Dataset/welding image/label' ) # 划分训练测试集 train_size = int(0.8 * len(dataset)) test_size = len(dataset) - train_size train_dataset, test_dataset = random_split(dataset, [train_size, test_size]) # 创建数据加载器 train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True) test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE) # 初始化模型和优化器 model = ImprovedCenterNet().to(DEVICE) optimizer =
评论 7
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值