# Fine-tuning script for a guided-diffusion model and its classifier on defect data.
import os
import json
import argparse
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
import cv2
from tqdm import tqdm
from guided_diffusion import dist_util, logger
from guided_diffusion.train_script_util_image_mask import (
model_and_diffusion_defaults,
classifier_defaults,
create_model_and_diffusion,
create_classifier_and_diffusion,
classifier_and_diffusion_defaults,
add_dict_to_argparser,
args_to_dict
)
import torch.nn.functional as F
import warnings
warnings.filterwarnings("ignore", category=UserWarning, module='pydevd_plugins'
)
# Fix random seeds for reproducibility.
# NOTE(review): DataLoader worker processes and CUDA ops are not seeded here,
# so runs are only partially reproducible — confirm if full determinism is needed.
torch.manual_seed(42)
np.random.seed(42)
def _get_defect_names(json_dir):
"""遍历标签文件获取所有缺陷名称"""
defect_names = []
json_file_list = os.listdir(json_dir)
assert len(json_file_list) > 0, "标签文件为空..."
for json_file in json_file_list:
json_file = os.path.join(json_dir, json_file)
with open(json_file, 'r') as fp:
print("json_name=", json_file)
data = json.load(fp)
for shapes in data['shapes']:
label = shapes['label']
if label not in defect_names:
defect_names.append(label)
with open(".\\categoryName.txt", 'w') as file:
for name in defect_names:
file.write(name + '\n')
return defect_names
"""扩展标注区域-包含背景区域更能学习到全局特征分布--待添加到训练流程"""
def _expand_roi_with_mask(image, mask, rect, expand_ratio=0.2):
# 1. 计算原始最小外接矩形
x, y, w, h = rect
# 2. 按比例扩展矩形区域
new_w = int(w * (1 + expand_ratio))
new_h = int(h * (1 + expand_ratio))
new_x = max(0, x - int((new_w - w) / 2))
new_y = max(0, y - int((new_h - h) / 2))
# 3. 提取扩展区域
expanded_roi = image[new_y:new_y + new_h, new_x:new_x + new_w]
# 4. 调整掩码位置
image_height, image_width = expanded_roi.shape[:2]
new_mask = np.zeros((image_height, image_width), dtype=np.uint8)
orig_in_new_x = x - new_x
orig_in_new_y = y - new_y
new_mask[orig_in_new_y:orig_in_new_y + h, orig_in_new_x:orig_in_new_x + w] = mask[y:y + h, x:x + w]
return expanded_roi, new_mask
class DefectDataset(Dataset):
    """
    Defect dataset loader.

    Pairs images with labelme-style JSON annotations and yields one sample per
    annotated polygon: an expanded crop around the defect, a binary mask of
    that defect inside the crop, and the class index.
    """

    def __init__(self, image_dir, json_dir, defect_names, transform=None, image_size=64):
        self.image_dir = image_dir
        self.json_dir = json_dir
        self.defect_names = defect_names
        self.transform = transform
        self.image_size = image_size
        self.samples = self._prepare_samples()

    def _prepare_samples(self):
        """Scan the image/JSON directories and build one record per shape."""
        samples = []
        # Sorted listing keeps sample order stable across machines.
        for img_name in sorted(os.listdir(self.image_dir)):
            if not img_name.lower().endswith(('.png', '.jpg', '.jpeg')):
                continue
            base_name = os.path.splitext(img_name)[0]
            json_path = os.path.join(self.json_dir, f"{base_name}.json")
            if not os.path.exists(json_path):
                continue
            with open(json_path, 'r') as f:
                data = json.load(f)
            shapes = data.get("shapes", [])
            image_width = data.get("imageWidth")
            image_height = data.get("imageHeight")
            for shape in shapes:
                label = shape.get('label')
                if label not in self.defect_names:
                    continue
                points = np.array(shape.get('points'), dtype=np.int32)
                # BUG FIX: the original drew every polygon onto one shared
                # full-image mask and copied it per shape, so sample i's mask
                # also contained shapes 0..i-1. Each sample now gets a mask
                # containing only its own polygon.
                shape_mask = np.zeros((image_height, image_width), dtype=np.uint8)
                cv2.fillPoly(shape_mask, [points], 255)
                x, y, w, h = cv2.boundingRect(points)
                samples.append({
                    'image_path': os.path.join(self.image_dir, img_name),
                    'mask': shape_mask,
                    'label': self.defect_names.index(label),
                    'bbox': (x, y, w, h)
                })
        return samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample = self.samples[idx]
        image = cv2.imread(sample['image_path'])
        if image is None:
            # Fail loudly with the offending path instead of crashing in cvtColor.
            raise FileNotFoundError(f"cannot read image: {sample['image_path']}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).astype(np.float32)
        mask = sample['mask'].astype(np.float32)
        label = sample['label']
        rect = sample['bbox']
        # Expand the crop so some background context is included.
        defect_region, defect_mask = _expand_roi_with_mask(image, mask, rect, expand_ratio=0.2)
        defect_mask = defect_mask.astype(np.float32)
        defect_region = cv2.resize(defect_region, (self.image_size, self.image_size), interpolation=cv2.INTER_LINEAR)
        # Nearest-neighbour interpolation keeps the mask (near-)binary.
        defect_mask = cv2.resize(defect_mask, (self.image_size, self.image_size), interpolation=cv2.INTER_NEAREST)
        # Image to [-1, 1] (diffusion-model convention).
        defect_region = (defect_region / 127.5) - 1.0
        # Mask to [0, 1].
        defect_mask = np.clip(defect_mask / 255.0, 0, 1)
        if self.transform:
            # NOTE(review): torchvision's ToTensor does NOT rescale float
            # ndarrays, so the [-1,1]/[0,1] ranges are preserved — confirm
            # this still holds if a different transform is passed in.
            defect_region = self.transform(defect_region)
            defect_mask = self.transform(defect_mask)
        else:
            to_tensor = transforms.ToTensor()
            defect_region = to_tensor(defect_region).to(torch.float32)
            defect_mask = to_tensor(defect_mask).to(torch.float32)
        return {
            'image': defect_region,  # (3, H, W) float32 in [-1, 1]
            'mask': defect_mask,     # (1, H, W) float32 in [0, 1]
            'label': torch.tensor(label, dtype=torch.long)
        }
def train_diffusion_model(args):
    """Fine-tune the diffusion model on defect crops.

    Consumes batches of ``{'image': [-1,1] float32, 'mask': [0,1] float32,
    'label': long}`` produced by ``DefectDataset``. Returns the diffusion
    object so callers can reuse its schedule.
    """
    logger.configure()
    logger.log("preparing dataset...")
    # Preprocessing: DefectDataset already normalizes; ToTensor only reshapes.
    transform = transforms.Compose([
        transforms.ToTensor()
    ])
    train_dataset = DefectDataset(
        args.train_image_dir,
        args.train_json_dir,
        args.defect_names,
        transform=transform,
        image_size=args.image_size
    )
    train_loader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=4,
        pin_memory=True
    )
    logger.log("creating model and diffusion...")
    model, diffusion = create_model_and_diffusion(
        **args_to_dict(args, model_and_diffusion_defaults().keys())
    )
    # BUG FIX: guided_diffusion's dist_util.load_state_dict(path, ...) returns
    # a state dict loaded from a checkpoint path — it does not take or return a
    # model. Load the returned dict into the model instead of rebinding `model`.
    model.load_state_dict(
        dist_util.load_state_dict(args.diffusion_pre_model_path, map_location="cpu")
    )
    model.to(dist_util.dev())
    # NOTE(review): convert_to_fp16 casts weights to fp16, but this plain AdamW
    # loop has no loss scaling / fp32 master weights (guided_diffusion uses
    # MixedPrecisionTrainer for that) — fp16 training may underflow; confirm.
    if args.use_fp16:
        model.convert_to_fp16()
    optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    # Halve LR when the epoch loss plateaus.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=5, verbose=True
    )
    logger.log("training...")
    writer = SummaryWriter(log_dir=os.path.join(args.log_dir, "diffusion"))
    for epoch in range(args.epochs):
        model.train()
        total_loss = 0
        with tqdm(train_loader, desc=f"Epoch {epoch + 1}/{args.epochs}", unit="batch") as t:
            for batch in t:
                images = batch['image'].to(dist_util.dev())
                masks = batch['mask'].to(dist_util.dev())
                labels = batch['label'].to(dist_util.dev())
                # Sample a random timestep per image and add forward noise.
                t_batch = torch.randint(0, diffusion.num_timesteps, (images.shape[0],), device=dist_util.dev())
                noise = torch.randn_like(images)
                x_t = diffusion.q_sample(images, t_batch, noise=noise)
                # NOTE(review): passing `masks` as a third positional arg assumes
                # a customized UNet forward (x, t, mask, y=...) — the stock
                # guided_diffusion UNet takes (x, timesteps, y=None); confirm.
                model_output = model(x_t, t_batch, masks, y=labels)
                # Slice to 3 channels in case learn_sigma doubles the output.
                loss = F.mse_loss(model_output[:, :3, :, :], noise, reduction='none')
                # Masked mean over defect pixels only.
                # BUG FIX: normalize by the number of masked *elements* — the
                # mask has 1 channel while the loss has 3, so the old
                # `/ masks.sum()` overweighted by 3x — and guard against an
                # all-zero mask (division by zero -> NaN).
                # NOTE(review): with a purely masked loss, background pixels of
                # the crop receive no gradient; consider a weighted blend.
                denom = masks.sum() * loss.shape[1] + 1e-8
                loss = (loss * masks).sum() / denom
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                total_loss += loss.item()
                t.set_postfix(loss=f"{loss.item():.4f}")
        # Epoch-level bookkeeping.
        avg_loss = total_loss / len(train_loader)
        scheduler.step(avg_loss)
        writer.add_scalar('Loss/train', avg_loss, epoch)
        logger.log(f"Epoch {epoch + 1} | Loss: {avg_loss:.4f}")
        # Periodic checkpoint with optimizer state for resumption.
        if (epoch + 1) % args.save_interval == 0:
            checkpoint_path = os.path.join(args.checkpoint_dir, f"diffusion_model_epoch_{epoch + 1}.pt")
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': avg_loss,
            }, checkpoint_path)
            logger.log(f"Saved checkpoint to {checkpoint_path}")
    # Final weights-only export.
    final_model_path = os.path.join(args.model_dir, "diffusion_final_model2.pt")
    torch.save(model.state_dict(), final_model_path)
    logger.log(f"Diffusion model training complete. Model saved to {final_model_path}")
    writer.close()
    return diffusion
def train_classifier_model(args):
    """Train the noise-aware classifier aligned with the diffusion schedule.

    The classifier is trained on images noised with the SAME q_sample process
    and timestep range as the diffusion model, which is what classifier
    guidance requires at sampling time.
    """
    logger.configure()
    logger.log("preparing dataset for classifier...")
    # Same preprocessing pipeline as the diffusion model.
    transform = transforms.Compose([
        transforms.ToTensor()
    ])
    train_dataset = DefectDataset(
        args.train_image_dir,
        args.train_json_dir,
        args.defect_names,
        transform=transform,
        image_size=args.image_size
    )
    train_loader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=4,
        pin_memory=True
    )
    logger.log("creating classifier model with aligned diffusion...")
    classifier, diffusion = create_classifier_and_diffusion(
        **args_to_dict(args, classifier_and_diffusion_defaults().keys())
    )
    # BUG FIX: dist_util.load_state_dict(path, ...) returns a state dict —
    # load it into the classifier instead of rebinding the variable.
    classifier.load_state_dict(
        dist_util.load_state_dict(args.classify_pre_model_path, map_location="cpu")
    )
    classifier.to(dist_util.dev())
    # Sanity check: guidance only works if both models share the timestep grid.
    if diffusion.num_timesteps != args.diffusion_steps:
        logger.log(f"警告: 扩散模型时间步数不匹配 ({diffusion.num_timesteps} vs {args.diffusion_steps})")
        logger.log("建议使用相同的扩散步数以确保对齐")
    # NOTE(review): same caveat as the diffusion loop — fp16 without loss
    # scaling may underflow; confirm before enabling.
    if args.classifier_use_fp16:
        classifier.convert_to_fp16()
    optimizer = optim.AdamW(classifier.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    # Optional per-class weighting for imbalanced defect categories.
    class_weights = torch.ones(len(args.defect_names)).to(dist_util.dev())
    if args.class_weights:
        class_weights = torch.tensor(args.class_weights, dtype=torch.float).to(dist_util.dev())
    criterion = nn.CrossEntropyLoss(weight=class_weights)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=5, verbose=True
    )
    logger.log("training classifier with aligned noise levels...")
    writer = SummaryWriter(log_dir=os.path.join(args.log_dir, "classifier"))
    for epoch in range(args.epochs):
        classifier.train()
        total_loss = 0
        correct = 0
        total = 0
        with tqdm(train_loader, desc=f"Epoch {epoch + 1}/{args.epochs}", unit="batch") as t:
            for batch in t:
                images = batch['image'].to(dist_util.dev())
                masks = batch['mask'].to(dist_util.dev())
                labels = batch['label'].to(dist_util.dev())
                # Alignment point 1: same discrete timestep range as diffusion.
                t_batch = torch.randint(
                    0, diffusion.num_timesteps,
                    (images.shape[0],),
                    device=dist_util.dev()
                )
                # Alignment point 2: identical forward-noising process.
                noise = torch.randn_like(images)
                noisy_images = diffusion.q_sample(images, t_batch, noise=noise)
                # NOTE(review): the stock guided_diffusion classifier forward is
                # (x, timesteps); passing `masks` assumes a customized
                # mask-conditioned classifier — confirm the signature.
                outputs = classifier(noisy_images, t_batch, masks)
                loss = criterion(outputs, labels)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                # Running accuracy over the epoch so far.
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
                total_loss += loss.item()
                accuracy = 100 * correct / total
                t.set_postfix(loss=loss.item(), acc=accuracy)
        avg_loss = total_loss / len(train_loader)
        avg_accuracy = 100 * correct / total
        scheduler.step(avg_loss)
        writer.add_scalar('Loss/train', avg_loss, epoch)
        writer.add_scalar('Accuracy/train', avg_accuracy, epoch)
        logger.log(f"Epoch {epoch + 1} | Loss: {avg_loss:.4f} | Acc: {avg_accuracy:.2f}%")
        # Periodic checkpoint with optimizer state for resumption.
        if (epoch + 1) % args.save_interval == 0:
            checkpoint_path = os.path.join(args.checkpoint_dir, f"classifier_epoch_{epoch + 1}.pt")
            torch.save({
                'epoch': epoch,
                'model_state_dict': classifier.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': avg_loss,
                'accuracy': avg_accuracy,
            }, checkpoint_path)
            logger.log(f"Saved classifier checkpoint to {checkpoint_path}")
    # Final weights-only export.
    final_model_path = os.path.join(args.model_dir, "classifier_final_model.pt")
    torch.save(classifier.state_dict(), final_model_path)
    logger.log(f"Classifier training complete. Model saved to {final_model_path}")
    writer.close()
def main():
    """Parse CLI arguments, discover defect class names, and launch training."""
    parser = argparse.ArgumentParser()
    # NOTE(review): if model_and_diffusion_defaults() and classifier_defaults()
    # share any keys (e.g. image_size), the second add_dict_to_argparser call
    # will raise on the duplicate option — confirm the dicts are disjoint.
    add_dict_to_argparser(parser, model_and_diffusion_defaults())
    add_dict_to_argparser(parser, classifier_defaults())
    # Training-specific arguments.
    parser.add_argument("--train_image_dir", type=str, required=False, default="train_dataset/tuku2/images", help="训练图像目录")
    parser.add_argument("--train_json_dir", type=str, required=False, default="train_dataset/tuku2/labels", help="训练JSON标注目录")
    # BUG FIX: type=list would split a CLI string into individual characters;
    # nargs='*' parses a proper list of names. (Also fixed the copy-pasted
    # help text. The value is overwritten from the label files below anyway.)
    parser.add_argument("--defect_names", type=str, nargs='*', required=False, default=[], help="缺陷类别名称列表")
    parser.add_argument("--diffusion_pre_model_path", type=str, default="result/models/128x128_diffusion.pt", help="扩散预训练权重")
    parser.add_argument("--classify_pre_model_path", type=str, default="result/models/128x128_classifier.pt", help="分类预训练权重")
    parser.add_argument("--class_weights", type=float, nargs='+', help="类别权重(可选)")
    parser.add_argument("--model_dir", type=str, default="result/models", help="模型保存目录")
    parser.add_argument("--log_dir", type=str, default="result/logs", help="日志目录")
    parser.add_argument("--checkpoint_dir", type=str, default="result/checkpoints", help="检查点目录")
    parser.add_argument("--batch_size", type=int, default=8, help="批大小")
    parser.add_argument("--epochs", type=int, default=10, help="训练轮数")
    parser.add_argument("--lr", type=float, default=1e-4, help="学习率")
    parser.add_argument("--weight_decay", type=float, default=1e-6, help="权重衰减")
    parser.add_argument("--save_interval", type=int, default=10, help="保存间隔(epoch)")
    parser.add_argument("--train_mode", choices=["diffusion", "classifier", "both"], default="diffusion",
                        help="训练模式: diffusion, classifier or both")
    args = parser.parse_args()
    # Label order comes from the annotation files and defines the class indices.
    args.defect_names = _get_defect_names(args.train_json_dir)
    # Ensure output directories exist.
    os.makedirs(args.model_dir, exist_ok=True)
    os.makedirs(args.log_dir, exist_ok=True)
    os.makedirs(args.checkpoint_dir, exist_ok=True)
    # Dispatch on the requested training mode ("both" runs the two in sequence).
    if args.train_mode in ["diffusion", "both"]:
        logger.log("Starting diffusion model training...")
        train_diffusion_model(args)
    if args.train_mode in ["classifier", "both"]:
        logger.log("Starting classifier model training...")
        train_classifier_model(args)

if __name__ == "__main__":
    main()
    print('train finish...')