基于增强型U-Net的印刷缺陷检测技术

在工业生产中,印刷质量直接关系到产品的外观和用户体验。印刷缺陷检测作为质量控制的关键环节,传统方法依赖人工目检或简单图像处理,效率低、精度有限。随着深度学习的发展,基于卷积神经网络(CNN)的方法逐渐成为主流。本文将结合代码,详细介绍一种基于增强型U-Net的印刷缺陷检测技术,该技术结合了注意力机制和混合损失函数,实现高精度的缺陷分割。

一、模型架构

(一)增强型U-Net

U-Net因其在医学图像分割中的成功而广为人知。本文提出的增强型U-Net在此基础上进行了改进。编码器部分通过多次卷积和池化操作提取图像的多尺度特征,而解码器部分则通过转置卷积逐步恢复图像的空间分辨率。在每个解码器块中,加入了跳跃连接,将编码器部分的特征图与解码器部分的特征图进行融合,有效解决了梯度消失问题,同时保留了图像的细节信息。

class EnhancedUNet(nn.Module):
    def __init__(self):
        super().__init__()
        features = Config.init_features

        self._initialize = nn.init.kaiming_normal_

        # 编码器
        self.encoder1 = self._build_block(Config.in_channels, features)
        self.pool1 = nn.MaxPool2d(2)
        self.encoder2 = self._build_block(features, features * 2)
        self.pool2 = nn.MaxPool2d(2)
        self.encoder3 = self._build_block(features * 2, features * 4)
        self.pool3 = nn.MaxPool2d(2)
        self.encoder4 = self._build_block(features * 4, features * 8)
        self.pool4 = nn.MaxPool2d(2)

        # 瓶颈层
        self.bottleneck = self._build_block(features * 8, features * 16)

        # 解码器
        self.upconv4 = nn.ConvTranspose2d(features * 16, features * 8, 2, 2)
        self.decoder4 = self._build_block(features * 16, features * 8)
        self.upconv3 = nn.ConvTranspose2d(features * 8, features * 4, 2, 2)
        self.decoder3 = self._build_block(features * 8, features * 4)
        self.upconv2 = nn.ConvTranspose2d(features * 4, features * 2, 2, 2)
        self.decoder2 = self._build_block(features * 4, features * 2)
        self.upconv1 = nn.ConvTranspose2d(features * 2, features, 2, 2)
        self.decoder1 = self._build_block(features * 2, features)

        # 注意力机制
        self.attention4 = AttentionBlock(features * 8, features * 8)
        self.attention3 = AttentionBlock(features * 4, features * 4)
        self.attention2 = AttentionBlock(features * 2, features * 2)
        self.attention1 = AttentionBlock(features * 1, features * 1)

        self.final_conv = nn.Conv2d(features, Config.out_channels, 1)
        self._initialize_weights()

    def _build_block(self, in_ch, out_ch):
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

 

(二)注意力机制

为了进一步提升模型对缺陷区域的关注度,引入了注意力机制。在解码器的每个块中,使用注意力模块对特征图进行加权。注意力模块通过卷积和激活函数学习特征图中每个位置的重要性权重,然后将权重与特征图相乘,使模型更加关注与缺陷相关的特征。这种注意力机制不仅增强了模型对缺陷的敏感度,还提高了分割结果的准确性。

 

class AttentionBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, 1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)
        self.conv3 = nn.Conv2d(out_channels, out_channels, 3, padding=1)
        self.relu = nn.ReLU(inplace=True)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.conv3(x)
        x = self.sigmoid(x)
        return x

二、数据处理与增强

(一)数据集

实验使用了COCO格式的数据集,包含印刷品的图像及其对应的注释信息。数据集分为训练集和验证集,分别用于模型训练和性能评估。

class PrintDefectDataset(Dataset):
    def __init__(self, image_dir, annotation_path, transform=None, mode='train'):
        self.image_dir = image_dir
        self.transform = transform
        self.mode = mode
        self.target_size = (Config.img_size, Config.img_size)

        if not os.path.isdir(image_dir):
            raise ValueError(f"Image directory not found: {image_dir}")
        if not os.path.isfile(annotation_path):
            raise ValueError(f"Annotation file not found: {annotation_path}")

        with open(annotation_path) as f:
            data = json.load(f)
            self._validate_data(data)

        self.images = {img['id']: img for img in data['images']}
        self.annotations = defaultdict(list)
        for ann in data['annotations']:
            img_id = ann['image_id']
            self.annotations[img_id].append(ann)
        self.ids = [img_id for img_id in self.images.keys() if img_id in self.annotations]
        self._analyze_dataset(data)

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, idx):
        img_id = self.ids[idx]
        img_info = self.images[img_id]

        img_path = os.path.join(self.image_dir, img_info['file_name'])
        image = self._load_image(img_path)

        mask = self._create_mask(img_id, img_info)

        if self.transform and self.mode == 'train':
            transformed = self.transform(image=image, mask=mask)
            image, mask = transformed['image'], transformed['mask']

        self._validate_sample(image, mask)

        return (
            torch.from_numpy(image).unsqueeze(0).float(),
            torch.from_numpy(mask).unsqueeze(0).float()
        )

 

(二)数据增强

为了增加数据的多样性,提高模型的泛化能力,采用了多种数据增强方法。包括弹性变换、网格畸变、平移缩放旋转、随机亮度对比度调整、高斯模糊、水平翻转、垂直翻转、旋转、粗Dropout和高斯噪声等。这些增强操作在训练过程中随机应用,使模型能够学习到不同条件下的图像特征,增强其鲁棒性。

train_transform = A.Compose([
    A.Resize(img_size, img_size, interpolation=cv2.INTER_LANCZOS4),
    A.OneOf([
        A.ElasticTransform(alpha=120, sigma=120 * 0.05, alpha_affine=120 * 0.03, interpolation=cv2.INTER_NEAREST, p=0.5),
        A.GridDistortion(num_steps=5, distort_limit=0.3, interpolation=cv2.INTER_NEAREST, p=0.5)
    ], p=0.7),
    A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=30, interpolation=cv2.INTER_NEAREST, border_mode=cv2.BORDER_CONSTANT, p=0.5),
    A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.5),
    A.GaussianBlur(blur_limit=3, p=0.2),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.Rotate(limit=30, interpolation=cv2.INTER_NEAREST, p=0.5),
    A.CoarseDropout(max_holes=8, max_height=32, max_width=32, p=0.3),
    A.GaussNoise(var_limit=(0.001, 0.01), mean=0, per_channel=True, p=0.2),
], additional_targets={'mask': 'mask'})

 

三、训练策略

(一)损失函数

提出了混合损失函数,结合了Focal Loss和Dice Loss。Focal Loss用于解决类别不平衡问题,使模型更加关注难分类的样本。Dice Loss则直接优化分割的Dice系数,提高分割结果的连通性和准确性。混合损失函数的使用,使模型在训练过程中能够更好地平衡不同类别和区域的损失,提升整体性能。

class HybridLoss(nn.Module):
    def __init__(self, alpha=0.75, gamma=2.0):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.bce = nn.BCEWithLogitsLoss(reduction='none')

    def forward(self, pred, target):
        # Focal Loss
        bce_loss = self.bce(pred, target)
        pt = torch.exp(-bce_loss)
        focal_loss = self.alpha * (1 - pt) ** self.gamma * bce_loss
        focal_loss = focal_loss.mean()

        # Dice Loss
        pred_sigmoid = torch.sigmoid(pred)
        smooth = 1e-7
        intersection = 2.0 * (pred_sigmoid * target).sum(dim=(1, 2, 3))
        union = pred_sigmoid.sum(dim=(1, 2, 3)) + target.sum(dim=(1, 2, 3))
        dice_loss = 1 - (intersection + smooth) / (union + smooth)
        dice_loss = dice_loss.mean()

        return focal_loss + dice_loss

 

(二)优化器与学习率调度

选择AdamW优化器,其在处理权重衰减方面表现良好,有助于防止过拟合。同时,采用余弦退火学习率调度策略,在训练过程中逐步调整学习率,使模型能够更快地收敛到最优解。

self.optimizer = optim.AdamW(self.model.parameters(), lr=config.lr, weight_decay=1e-5)
self.scheduler = CosineAnnealingLR(self.optimizer, T_max=config.epochs, eta_min=1e-7)

(三)早停机制

为了防止过拟合,设置了早停机制。在验证集上,如果连续一定数量的轮次(如15轮)模型性能没有提升,则提前停止训练,保存当前最佳模型。

if self.early_stop_counter >= self.config.early_stop:
    print("Early stopping triggered!")
    break

 

四、实验结果与分析

(一)训练过程

在训练过程中,观察到训练损失和验证损失逐渐下降,验证F1分数和IoU(交并比)逐渐上升。这表明模型在训练过程中逐渐学习到了图像的特征和缺陷的模式,性能不断提升。

def train(self, train_loader, val_loader):
    print("Start training...")
    metrics = {
        'train_loss': [],
        'val_loss': [],
        'val_f1': [],
        'val_iou': []
    }

    for epoch in range(1, self.config.epochs + 1):
        epoch_start = time.time()

        train_loss = self.train_epoch(train_loader)
        val_loss, val_f1, val_iou = self.evaluate(val_loader)

        epoch_time = time.time() - epoch_start

        print(f"\nEpoch {epoch:03d}/{self.config.epochs} | Time: {epoch_time:.1f}s")
        print(f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Val F1: {val_f1:.4f} | Val IoU: {val_iou:.4f}")

        metrics['train_loss'].append(train_loss)
        metrics['val_loss'].append(val_loss)
        metrics['val_f1'].append(val_f1)
        metrics['val_iou'].append(val_iou)

        if val_f1 > self.best_f1 or (val_f1 == self.best_f1 and val_iou > self.best_iou):
            self.best_f1 = val_f1
            self.best_iou = val_iou
            self.early_stop_counter = 0
            self.save_model()
            self.visualize_predictions(val_loader)
            visual.save_metrics(metrics, self.config.output_dir)
        else:
            self.early_stop_counter += 1
            print(f"No improvement for {self.early_stop_counter} epochs")

        if self.early_stop_counter >= self.config.early_stop:
            print("Early stopping triggered!")
            break

        self.scheduler.step()

    visual.plot_metrics(metrics, self.config.output_dir)

 

(二)可视化结果

通过对验证集的预测结果进行可视化,可以看到模型能够准确地分割出印刷缺陷的区域。预测结果与真实标签(Ground Truth)的对比显示,模型的分割结果与真实情况高度吻合,说明模型具有良好的分割能力。

def visualize_predictions(self, loader, num_samples=3):
    self.model.eval()
    with torch.no_grad():
        for idx, (images, masks) in enumerate(loader):
            if idx >= num_samples:
                break
            images = images.to(self.config.device)
            outputs = self.model(images)
            preds = torch.sigmoid(outputs).cpu().numpy()

            plt.figure(figsize=(18, 6))

            plt.subplot(1, 3, 1)
            plt.imshow(images[0].cpu().squeeze().numpy(), cmap='gray')
            plt.title('Original Image')

            plt.subplot(1, 3, 2)
            plt.imshow(masks[0].cpu().squeeze().numpy(), cmap='jet')
            plt.title('Ground Truth')

            plt.subplot(1, 3, 3)
            plt.imshow(preds[0].squeeze(), cmap='jet', vmin=0, vmax=1)
            plt.title(f'Prediction (F1: {self.calculate_f1(preds[0][0], masks[0].cpu().numpy().squeeze()):.3f}, IoU: {self.calculate_iou(preds[0][0], masks[0].cpu().numpy().squeeze()):.3f})')
            plt.show()

(三)性能指标

在验证集上,模型取得了较高的F1分数和IoU值。F1分数衡量了模型的精确率和召回率的平衡,而IoU则反映了分割结果与真实标签的重叠程度。这些指标的高值表明模型在缺陷检测任务中具有较高的准确性和鲁棒性。

def evaluate(self, loader):
    self.model.eval()
    total_loss = 0
    all_f1 = []
    all_iou = []
    progress = tqdm(loader, desc="Evaluating", ncols=100)

    with torch.no_grad():
        for images, masks in progress:
            images = images.to(self.config.device)
            masks = masks.to(self.config.device)

            outputs = self.model(images)
            loss = self.criterion(outputs, masks)
            total_loss += loss.item()

            preds = torch.sigmoid(outputs).cpu().numpy()
            targets = masks.cpu().numpy().astype(np.uint8)

            batch_f1 = [self.calculate_f1(p[0], t[0]) for p, t in zip(preds, targets)]
            batch_iou = [self.calculate_iou(p[0], t[0]) for p, t in zip(preds, targets)]
            all_f1.extend(batch_f1)
            all_iou.extend(batch_iou)

            progress.set_postfix(f1=f"{np.mean(batch_f1):.4f}", iou=f"{np.mean(batch_iou):.4f}")

    return total_loss / len(loader), np.mean(all_f1), np.mean(all_iou)

 下篇博客继续更新优化后的代码。。。。。

五、总结与展望

本文介绍了一种基于增强型U-Net的印刷缺陷检测技术。通过引入注意力机制和混合损失函数,模型能够有效地分割印刷缺陷,提高检测的准确性和鲁棒性。实验结果表明,该方法在印刷缺陷检测任务中具有良好的性能。

未来的研究可以进一步优化模型架构,探索更有效的注意力机制和损失函数。同时,可以将该技术应用于其他类型的工业缺陷检测,如电子元件检测、纺织品缺陷检测等,为工业质量控制提供更有力的支持。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值