前言
本文主要记录复现方法DiffusionAD的步骤以及一些个注意点。
一、数据以及代码准备
1.数据准备
数据准备的主要目录如下:
说明:VisA文件夹放置于dataset文件夹下,其中DISthresh文件夹存放对象轮廓掩码,ground_truth文件夹存放异常区域掩码,--因为该文重建部分主要采用扩散模型,所以训练集只有正常样本--,若出现无ground_truth对应数据问题,下文有对应说明。
该文作者也提供了两个数据集(MvTec-AD和VisA)的图像以及对象轮廓掩码,并做好了数据划分,链接如下:谷歌网盘链接
VisA
|-- candle
|-----|----- DISthresh
|-----|----- ground_truth
|-----|----- test
|-----|--------|------ good
|-----|--------|------ bad
|-----|----- train
|-----|--------|------ good
|-- capsules
|-- ...
二、环境安装
建议:不要直接下载提供的requirements.txt,除非cuda版本以及其他环境大致相同,可以将对应的安装包修改符合本地cuda版本后再进行安装。或者按部就班,先把torch一套安装后,其他挨个安装,避免版本错误或者cuda版本对应错误。
安装命令如下:
pip install -r requirements.txt
注意:下述部分环境配置需要按照本地设备环境进行对应修改。
这里提供几个关于torch版本对应链接,自行参考:
pytorch官方链接
【torch、torchvision、torchaudio】版本对应关系
【环境搭建】Python、PyTorch与cuda的版本对应表
若需对应的cuda版本安装命令可参考:
pip install torch==1.12.0+cu116 torchvision==0.13.0+cu116 torchaudio==0.12.0 -f https://download.pytorch.org/whl/torch_stable.html
requirements.txt内容如下:
graphviz==0.8.4
helpers==0.2.0
huggingface-hub==0.17.2
idna==3.4
imageio==2.26.0
imgaug==0.4.0
numpy==1.23.1
nvidia-cublas-cu11==11.10.3.66
nvidia-cuda-nvrtc-cu11==11.7.99
nvidia-cuda-runtime-cu11==11.7.99
nvidia-cudnn-cu11==8.5.0.96
opencv-python==4.7.0.72
pandas==2.0.1
Pillow==9.4.0
scikit-image==0.20.0
scikit-learn==1.2.1
scipy==1.9.1
shapely==2.0.1
threadpoolctl==3.1.0
tokenizers==0.13.3
torch==1.13.1
torchaudio==0.13.1
torchvision==0.14.1
tqdm==4.65.0
transformers==4.33.2
三、训练及评估
1.参数文件调整以及数据集构建说明
下述参数根据个人设备情况自行调整,若算力有限,个人建议优先调整Batch_Size。
{
"img_size": [256,256], # 图像大小
"Batch_Size": 16, # 每次取出图像张数
"EPOCHS": 3000, # 训练轮数
"T": 1000, # 时间步
"base_channels": 128,
"beta_schedule": "linear",
"loss_type": "l2",
"diffusion_lr": 1e-4, # 扩散模型学习率
"seg_lr": 1e-5, # 分割网络学习率
"random_slice": true,
"weight_decay": 0.0,
"save_imgs":true,
"save_vids":false,
"dropout":0,
"attention_resolutions":"32,16,8",
"num_heads":4,
"num_head_channels":-1,
"noise_fn":"gauss",
"channels":3,
"mvtec_root_path":"datasets/mvtec", #数据集路径,此处可自行设置数据集个数
"visa_root_path":"datasets/VisA_1class/1cls",
"dagm_root_path":"datasets/dagm",
"mpdd_root_path":"datasets/mpdd",
"anomaly_source_path":"datasets/DTD", #异常源数据集
"noisier_t_range":600,
"less_t_range":300,
"condition_w":1,
"eval_normal_t":200,
"eval_noisier_t":400,
"output_path":"outputs" #输出路径
}
数据集构建代码部分,以VisA数据集为例:
说明:对应上文提到的ground_truth文件数据缺失问题,此处可更换自有数据集,即有异常区域掩码数据,并放到对应路径上;若没有对应异常区域掩码且对复现结果不做要求,可取消这部分输入。
# 该处代码,
if base_dir == 'good':
image, mask = self.transform_image(img_path, None)
has_anomaly = np.array([0], dtype=np.float32)
else:
mask_path = os.path.join(dir_path, '../../ground_truth/')
mask_path = os.path.join(mask_path, base_dir)
mask_file_name = file_name.split(".")[0]+".png"
mask_path = os.path.join(mask_path, mask_file_name)
image, mask = self.transform_image(img_path, mask_path)
has_anomaly = np.array([1], dtype=np.float32)
2.训练
调整对应位置代码后,运行根目录下的train.py文件即可。
# 该处代码,根据个人的数据集数量进行调整,对应类别也需调整
mvtec_classes = ['carpet', 'grid', 'leather', 'tile', 'wood', 'bottle', 'cable', 'capsule', 'hazelnut', 'metal_nut', 'pill', 'screw',
'toothbrush', 'transistor', 'zipper']
visa_classes = ['candle', 'capsules', 'cashew', 'chewinggum', 'fryum', 'macaroni1', 'macaroni2', 'pcb1', 'pcb2',
'pcb3', 'pcb4', 'pipe_fryum']
mpdd_classes = ['bracket_black', 'bracket_brown', 'bracket_white', 'connector', 'metal_plate', 'tubes']
dagm_class = ['Class1', 'Class2', 'Class3', 'Class4', 'Class5','Class6', 'Class7', 'Class8', 'Class9', 'Class10']
for sub_class in current_classes:
print("class",sub_class)
if sub_class in visa_classes:
subclass_path = os.path.join(args["visa_root_path"],sub_class)
print(subclass_path)
# 数据集准备
# 训练数据集准备
training_dataset = VisATrainDataset(
subclass_path,sub_class,img_size=args["img_size"],args=args
)
# 测试数据集准备
testing_dataset = VisATestDataset(
subclass_path,sub_class,img_size=args["img_size"],
)
class_type='VisA'
elif sub_class in mpdd_classes:
subclass_path = os.path.join(args["mpdd_root_path"],sub_class)
training_dataset = MPDDTrainDataset(
subclass_path,sub_class,img_size=args["img_size"],args=args
)
testing_dataset = MPDDTestDataset(
subclass_path,sub_class,img_size=args["img_size"],
)
class_type='MPDD'
elif sub_class in mvtec_classes:
subclass_path = os.path.join(args["mvtec_root_path"],sub_class)
training_dataset = MVTecTrainDataset(
subclass_path,sub_class,img_size=args["img_size"],args=args
)
testing_dataset = MVTecTestDataset(
subclass_path,sub_class,img_size=args["img_size"],
)
class_type='MVTec'
elif sub_class in dagm_class:
subclass_path = os.path.join(args["dagm_root_path"],sub_class)
training_dataset = DAGMTrainDataset(
subclass_path,sub_class,img_size=args["img_size"],args=args
)
testing_dataset =DAGMTestDataset(
subclass_path,sub_class,img_size=args["img_size"],
)
class_type='DAGM'
3.评估
同样的,调整数据输入对应位置代码后,运行根目录下的eval.py文件。
总之,确保输入的数据与对应的类别对应即可。
# 该处代码,根据个人的数据集数量进行调整,对应类别也需调整
if sub_class in visa_classes:
subclass_path = os.path.join(args['visa_root_path'],sub_class)
print("visa_eval_root_path",args["visa_root_path"])
print("subclass_path",subclass_path)
testing_dataset = VisATestDataset(
subclass_path,sub_class,img_size=args["img_size"],
)
class_type='VisA'
elif sub_class in mpdd_classes:
subclass_path = os.path.join(args["mpdd_root_path"],sub_class)
testing_dataset = MVTecTestDataset(
subclass_path,sub_class,img_size=args["img_size"],
)
class_type='MPDD'
elif sub_class in mvtec_classes:
subclass_path = os.path.join(args["mvtec_root_path"],sub_class)
testing_dataset = MVTecTestDataset(
subclass_path,sub_class,img_size=args["img_size"],
)
class_type='MVTec'
elif sub_class in dagm_class:
subclass_path = os.path.join(args["dagm_root_path"],sub_class)
testing_dataset =DAGMTestDataset(
subclass_path,sub_class,img_size=args["img_size"],
)
class_type='DAGM'
先写到这里,后续。。。。
1万+

被折叠的 条评论
为什么被折叠?



