DiffusionAD方法的复现记录以及注意点


前言

本文主要记录复现方法DiffusionAD的步骤以及一些个注意点。

一、数据以及代码准备

1.数据准备

数据准备的主要目录如下:
说明:VisA文件夹放置于dataset文件夹下,其中DISthresh文件夹存放对象轮廓掩码,ground_truth文件夹存放异常区域掩码,--因为该文重建部分主要采用扩散模型,所以训练集只有正常样本--,若出现无ground_truth对应数据问题,下文有对应说明。

该文作者也提供了两个数据集(MvTec-AD和VisA)的图像以及对象轮廓掩码,并做好了数据划分,链接如下:谷歌网盘链接

VisA
|-- candle
|-----|----- DISthresh
|-----|----- ground_truth
|-----|----- test
|-----|--------|------ good
|-----|--------|------ bad
|-----|----- train
|-----|--------|------ good
|-- capsules
|-- ...

二、环境安装

建议:不要直接下载提供的requirements.txt,除非cuda版本以及其他环境大致相同,可以将对应的安装包修改符合本地cuda版本后再进行安装。或者按部就班,先把torch一套安装后,其他挨个安装,避免版本错误或者cuda版本对应错误。

安装命令如下:

pip install -r requirements.txt

注意:下述部分环境配置需要按照本地设备环境进行对应修改。
这里提供几个关于torch版本对应链接,自行参考:
pytorch官方链接
【torch、torchvision、torchaudio】版本对应关系
【环境搭建】Python、PyTorch与cuda的版本对应表

若需对应的cuda版本安装命令可参考:

pip install torch==1.12.0+cu116 torchvision==0.13.0+cu116 torchaudio==0.12.0 -f https://download.pytorch.org/whl/torch_stable.html

requirements.txt内容如下:

graphviz==0.8.4
helpers==0.2.0
huggingface-hub==0.17.2
idna==3.4
imageio==2.26.0
imgaug==0.4.0
numpy==1.23.1
nvidia-cublas-cu11==11.10.3.66
nvidia-cuda-nvrtc-cu11==11.7.99
nvidia-cuda-runtime-cu11==11.7.99
nvidia-cudnn-cu11==8.5.0.96
opencv-python==4.7.0.72
pandas==2.0.1
Pillow==9.4.0
scikit-image==0.20.0
scikit-learn==1.2.1
scipy==1.9.1
shapely==2.0.1
threadpoolctl==3.1.0
tokenizers==0.13.3
torch==1.13.1
torchaudio==0.13.1
torchvision==0.14.1
tqdm==4.65.0
transformers==4.33.2

三、训练及评估

1.参数文件调整以及数据集构建说明

下述参数根据个人设备情况自行调整,若算力有限,个人建议优先调整Batch_Size。

{
  "img_size": [256,256],				# 图像大小
  "Batch_Size": 16,						# 每次取出图像张数
  "EPOCHS": 3000,						# 训练轮数
  "T": 1000,							# 时间步
  "base_channels": 128,
  "beta_schedule": "linear",
  "loss_type": "l2",
  "diffusion_lr": 1e-4,					# 扩散模型学习率
  "seg_lr": 1e-5,						# 分割网络学习率
  "random_slice": true,
  "weight_decay": 0.0,
  "save_imgs":true,
  "save_vids":false, 
  "dropout":0,
  "attention_resolutions":"32,16,8",
  "num_heads":4,
  "num_head_channels":-1,
  "noise_fn":"gauss",
  "channels":3,
  "mvtec_root_path":"datasets/mvtec",	#数据集路径,此处可自行设置数据集个数
  "visa_root_path":"datasets/VisA_1class/1cls",
  "dagm_root_path":"datasets/dagm",
  "mpdd_root_path":"datasets/mpdd",
  "anomaly_source_path":"datasets/DTD",	#异常源数据集
  "noisier_t_range":600,
  "less_t_range":300,
  "condition_w":1,
  "eval_normal_t":200,
  "eval_noisier_t":400,
  "output_path":"outputs"				#输出路径
}

数据集构建代码部分,以VisA数据集为例:
说明:对应上文提到的ground_truth文件数据缺失问题,此处可更换自有数据集,即有异常区域掩码数据,并放到对应路径上;若没有对应异常区域掩码且对复现结果不做要求,可取消这部分输入。

# 该处代码,
if base_dir == 'good':
    image, mask = self.transform_image(img_path, None)
    has_anomaly = np.array([0], dtype=np.float32)
else:
    mask_path = os.path.join(dir_path, '../../ground_truth/')
    mask_path = os.path.join(mask_path, base_dir)
    mask_file_name = file_name.split(".")[0]+".png"
    mask_path = os.path.join(mask_path, mask_file_name)
    image, mask = self.transform_image(img_path, mask_path)
    has_anomaly = np.array([1], dtype=np.float32)

2.训练

调整对应位置代码后,运行根目录下的train.py文件即可。

# 该处代码,根据个人的数据集数量进行调整,对应类别也需调整
mvtec_classes = ['carpet', 'grid', 'leather', 'tile', 'wood', 'bottle', 'cable', 'capsule', 'hazelnut', 'metal_nut', 'pill', 'screw',
            'toothbrush', 'transistor', 'zipper']
    
visa_classes = ['candle', 'capsules', 'cashew', 'chewinggum', 'fryum', 'macaroni1', 'macaroni2', 'pcb1', 'pcb2',
         'pcb3', 'pcb4', 'pipe_fryum']

mpdd_classes = ['bracket_black', 'bracket_brown', 'bracket_white', 'connector', 'metal_plate', 'tubes'] 
dagm_class = ['Class1', 'Class2', 'Class3', 'Class4', 'Class5','Class6', 'Class7', 'Class8', 'Class9', 'Class10']
for sub_class in current_classes:    
    print("class",sub_class)
    if sub_class in visa_classes:
        subclass_path = os.path.join(args["visa_root_path"],sub_class)
        print(subclass_path)
        # 数据集准备
        # 训练数据集准备
        training_dataset = VisATrainDataset(
            subclass_path,sub_class,img_size=args["img_size"],args=args
            )
        # 测试数据集准备
        testing_dataset = VisATestDataset(
            subclass_path,sub_class,img_size=args["img_size"],
            )
        class_type='VisA'
    elif sub_class in mpdd_classes:
        subclass_path = os.path.join(args["mpdd_root_path"],sub_class)

        training_dataset = MPDDTrainDataset(
            subclass_path,sub_class,img_size=args["img_size"],args=args
            )
        testing_dataset = MPDDTestDataset(
            subclass_path,sub_class,img_size=args["img_size"],
            )
        class_type='MPDD'
    elif sub_class in mvtec_classes:
        subclass_path = os.path.join(args["mvtec_root_path"],sub_class)
        training_dataset = MVTecTrainDataset(
            subclass_path,sub_class,img_size=args["img_size"],args=args
            )
        testing_dataset = MVTecTestDataset(
            subclass_path,sub_class,img_size=args["img_size"],
            )
        class_type='MVTec'
    elif sub_class in dagm_class:
        subclass_path = os.path.join(args["dagm_root_path"],sub_class)
        training_dataset = DAGMTrainDataset(
            subclass_path,sub_class,img_size=args["img_size"],args=args
            )
        testing_dataset =DAGMTestDataset(
            subclass_path,sub_class,img_size=args["img_size"],
            )
        class_type='DAGM'

3.评估

同样的,调整数据输入对应位置代码后,运行根目录下的eval.py文件。
总之,确保输入的数据与对应的类别对应即可。

# 该处代码,根据个人的数据集数量进行调整,对应类别也需调整
if sub_class in visa_classes:
    subclass_path = os.path.join(args['visa_root_path'],sub_class)
    print("visa_eval_root_path",args["visa_root_path"])
    print("subclass_path",subclass_path)
    testing_dataset = VisATestDataset(
        subclass_path,sub_class,img_size=args["img_size"],
        )
    class_type='VisA'
elif sub_class in mpdd_classes:
    subclass_path = os.path.join(args["mpdd_root_path"],sub_class)
    testing_dataset = MVTecTestDataset(
        subclass_path,sub_class,img_size=args["img_size"],
        )
    class_type='MPDD'
elif sub_class in mvtec_classes:
    subclass_path = os.path.join(args["mvtec_root_path"],sub_class)
    testing_dataset = MVTecTestDataset(
        subclass_path,sub_class,img_size=args["img_size"],
        )
    class_type='MVTec'

elif sub_class in dagm_class:
    subclass_path = os.path.join(args["dagm_root_path"],sub_class)
    testing_dataset =DAGMTestDataset(
        subclass_path,sub_class,img_size=args["img_size"],
        )
    class_type='DAGM'

先写到这里,后续。。。。

### 关于DiffusionAD复现 DiffusionAD是一种基于扩散模型的异常检测方法,旨在利用扩散过程中的特征变化来识别数据集中的异常样本。由于传统DDPM扩散模型存在多类别异常检测上的局限性[^1],因此DiffusionAD通过改进模型结构和训练策略以适应更复杂的场景。 #### 复现DiffusionAD实验所需准备 为了成功复现DiffusionAD模型,在开始之前需准备好如下环境: - Python版本建议3.8以上 - 安装PyTorch框架及其依赖库 - 下载官方开源代码仓库或获取最新版源码文件 #### 主要实现步骤概述 以下是简化后的Python代码片段用于展示核心逻辑: ```python import torch from torchvision import datasets, transforms from diffusion_ad.model import DiffusionModel # 假设这是自定义模块路径 transform = transforms.Compose([ transforms.ToTensor(), ]) dataset = datasets.MNIST('./data', train=True, download=True, transform=transform) model = DiffusionModel() def train(model, dataset): optimizer = torch.optim.Adam(model.parameters(), lr=0.001) for epoch in range(num_epochs): for batch_idx, (data, target) in enumerate(dataset): data.requires_grad_() noise_level = ... # 计算当前噪声水平 noisy_data = add_noise(data, noise_level) # 添加高斯白噪 loss = model(noisy_data).mean() optimizer.zero_grad() loss.backward() optimizer.step() ``` 此段伪代码仅作为概念验证用途,并未完全体现实际项目细节。具体参数设置、损失函数设计以及评估指标的选择应参照原始论文和技术文档说明。
评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值