import argparse
import os
from glob import glob
from collections import OrderedDict
import torch
import torch.nn as nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
from torch.optim import lr_scheduler
import yaml
from albumentations import RandomRotate90, Flip, HueSaturationValue, RandomBrightness, RandomContrast, Resize, Normalize
from albumentations.core.composition import Compose, OneOf
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import wandb
import archs
import losses
from dataset import Dataset
from metrics import dice_score, iou_score, recall_score, f1_score, pixel_accuracy, precision_score
from utils import str2bool
# ------------------------------
# Multi-metric averaging utilities
# ------------------------------
class AverageMeterDict:
    """Running average tracker for a fixed set of named metrics.

    Accumulates weighted sums and sample counts per key so that ``avg``
    always reflects the mean over everything seen since the last reset.
    """

    def __init__(self, keys):
        self.keys = keys
        self.reset()

    def reset(self):
        """Zero out all accumulated sums and sample counts."""
        self.sums = dict.fromkeys(self.keys, 0)
        self.counts = dict.fromkeys(self.keys, 0)

    def update(self, metrics, n=1):
        """Fold in one batch of metric values, weighted by batch size ``n``."""
        for key in self.keys:
            self.sums[key] = self.sums[key] + metrics[key] * n
            self.counts[key] = self.counts[key] + n

    @property
    def avg(self):
        """Per-key running mean; max(count, 1) guards division before any update."""
        return {key: total / max(self.counts[key], 1)
                for key, total in self.sums.items()}
def compute_metrics(pred, target):
    """Evaluate all segmentation metrics for one batch of predictions.

    Returns a dict of plain Python floats (``.item()``) keyed by metric
    name, suitable for AverageMeterDict.update and wandb logging.
    """
    metric_fns = {
        'dice': dice_score,
        'iou': iou_score,
        'recall': recall_score,
        'precision': precision_score,
        'f1': f1_score,
        'pa': pixel_accuracy,
    }
    return {name: fn(pred, target).item() for name, fn in metric_fns.items()}
# ------------------------------
# Argument parsing
# ------------------------------
# NOTE(review): removed a redundant mid-file `import argparse` (argparse is
# already imported at the top of the file). This local definition also
# shadows the `str2bool` imported from `utils` — keeping the local one so
# the parser's behavior is self-contained and unambiguous.
def str2bool(v):
    """Parse a boolean-ish CLI string into a bool for argparse `type=`.

    Accepts an actual bool (returned unchanged) or the case-insensitive
    strings yes/true/t/1 and no/false/f/0.

    Raises:
        argparse.ArgumentTypeError: for any other value, so argparse
        reports a clean usage error instead of a traceback.
    """
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')
def parse_args():
    """Build and evaluate the command-line interface for a training run."""
    ap = argparse.ArgumentParser()

    # Run bookkeeping
    ap.add_argument('--name', default='experiment_1', help='Name of the experiment')
    ap.add_argument('--epochs', default=200, type=int, help='Number of training epochs')
    ap.add_argument('-b', '--batch_size', default=4, type=int, help='Batch size for training')

    # Optimization
    ap.add_argument('--lr', default=1e-5, type=float, help='Learning rate')
    ap.add_argument('--optimizer', default='Adam', choices=['Adam', 'SGD'],
                    help='Optimizer to use')
    ap.add_argument('--scheduler', default='CosineAnnealingLR',
                    choices=['CosineAnnealingLR', 'ReduceLROnPlateau', 'MultiStepLR', 'ConstantLR'],
                    help='Learning rate scheduler')
    ap.add_argument('--early_stopping', default=4, type=int, help='Early stopping patience')
    ap.add_argument('--num_workers', default=4, type=int, help='Number of data loading workers')

    # Model / loss
    ap.add_argument('--deep_supervision', default=False, type=str2bool,
                    help='Whether to use deep supervision')
    ap.add_argument('--loss', default='BCEDiceLoss', help='Loss function to use')
    ap.add_argument('--arch', default='NestedUNet', help='Model architecture')

    # Data
    ap.add_argument('--dataset', default='dsb2018_96', help='Dataset to use')
    ap.add_argument('--input_channels', default=3, type=int, help='Number of input channels')
    ap.add_argument('--num_classes', default=1, type=int, help='Number of output classes')
    ap.add_argument('--input_w', default=256, type=int, help='Width of the input image')
    ap.add_argument('--input_h', default=256, type=int, help='Height of the input image')
    ap.add_argument('--img_ext', default='.jpg', help='Image file extension')
    ap.add_argument('--mask_ext', default='.png', help='Mask file extension')

    return ap.parse_args()
# ------------------------------
# Training and validation loops
# ------------------------------
def train_one_epoch(config, train_loader, model, criterion, optimizer):
    """Train for one epoch and return the dict of averaged metrics.

    Args:
        config: dict with at least 'deep_supervision' (bool).
        train_loader: yields (input, target, meta) batches.
        model: segmentation network already on the GPU.
        criterion: loss function (expects logits + target).
        optimizer: torch optimizer over the model's trainable params.
    """
    metric_keys = ['loss', 'dice', 'iou', 'recall', 'f1', 'precision', 'pa']
    avg_meter = AverageMeterDict(metric_keys)
    model.train()
    pbar = tqdm(total=len(train_loader))
    for input, target, _ in train_loader:
        input = input.cuda()
        target = target.cuda()
        if config['deep_supervision']:
            # BUGFIX: deep supervision must average the loss over ALL side
            # outputs. Previously only the final output (`model(input)[-1]`)
            # was supervised, which silently disabled deep supervision.
            outputs = model(input)
            loss = sum(criterion(o, target) for o in outputs) / len(outputs)
            output = outputs[-1]  # report metrics on the final prediction only
        else:
            output = model(input)
            loss = criterion(output, target)
        metrics = compute_metrics(output, target)
        metrics['loss'] = loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        avg_meter.update(metrics, input.size(0))
        pbar.set_postfix({k: f"{v:.4f}" for k, v in avg_meter.avg.items()})
        pbar.update(1)
    pbar.close()
    return avg_meter.avg
def validate_one_epoch(config, val_loader, model, criterion):
    """Evaluate for one epoch under torch.no_grad(); returns averaged metrics.

    Mirrors train_one_epoch's objective so train/val losses are comparable.
    """
    metric_keys = ['loss', 'dice', 'iou', 'recall', 'precision', 'f1', 'pa']
    avg_meter = AverageMeterDict(metric_keys)
    model.eval()
    with torch.no_grad():
        for input, target, _ in tqdm(val_loader, total=len(val_loader)):
            input = input.cuda()
            target = target.cuda()
            if config['deep_supervision']:
                # BUGFIX: average the loss over all side outputs, matching the
                # training objective (previously only the final output was scored).
                outputs = model(input)
                loss = sum(criterion(o, target) for o in outputs) / len(outputs)
                output = outputs[-1]  # metrics on the final prediction only
            else:
                output = model(input)
                loss = criterion(output, target)
            metrics = compute_metrics(output, target)
            metrics['loss'] = loss.item()
            avg_meter.update(metrics, input.size(0))
    return avg_meter.avg
# ------------------------------
# Main training entry point
# ------------------------------
def main():
    """End-to-end training driver: config, data, model, optimization loop."""
    args = parse_args()
    config = vars(args)

    # Defensive fallback: unreachable with the current CLI default
    # ('experiment_1'), but kept so a caller that passes name=None still
    # gets a descriptive auto-generated run name.
    if config['name'] is None:
        suffix = 'DS' if config['deep_supervision'] else 'woDS'
        config['name'] = f"{config['dataset']}_{config['arch']}_{suffix}"

    os.makedirs(f"models/{config['name']}", exist_ok=True)
    with open(f"models/{config['name']}/config.yml", 'w') as f:
        yaml.dump(config, f)
    wandb.init(project="UNet++", name=config['name'], config=config, mode='offline')

    # ---- loss ----
    if config['loss'] == 'BCEWithLogitsLoss':
        criterion = nn.BCEWithLogitsLoss().cuda()
    else:
        criterion = losses.__dict__[config['loss']]().cuda()

    # ---- model ----
    print(f"=> creating model {config['arch']}")
    model = archs.__dict__[config['arch']](config['num_classes'],
                                           config['input_channels'],
                                           config['deep_supervision']).cuda()

    # ---- optimizer ----
    params = filter(lambda p: p.requires_grad, model.parameters())
    if config['optimizer'] == 'Adam':
        optimizer = optim.Adam(params, lr=config['lr'])
    else:
        optimizer = optim.SGD(params, lr=config['lr'], momentum=0.9,
                              nesterov=True, weight_decay=1e-4)

    # ---- LR scheduler ----
    # BUGFIX: 'MultiStepLR' and 'ConstantLR' are advertised CLI choices but
    # previously fell through to scheduler=None and were silently ignored.
    if config['scheduler'] == 'CosineAnnealingLR':
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=config['epochs'],
                                                   eta_min=1e-5)
    elif config['scheduler'] == 'ReduceLROnPlateau':
        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.1, patience=2,
                                                   verbose=1, min_lr=1e-5)
    elif config['scheduler'] == 'MultiStepLR':
        # Decay by 10x at 50% and 75% of training.
        scheduler = lr_scheduler.MultiStepLR(
            optimizer,
            milestones=[int(config['epochs'] * 0.5), int(config['epochs'] * 0.75)],
            gamma=0.1)
    else:  # 'ConstantLR': keep the initial learning rate throughout
        scheduler = None

    cudnn.benchmark = True

    # ---- data ----
    # NOTE(review): hard-coded absolute Windows paths — consider promoting
    # these to CLI arguments before running on another machine.
    train_img_dir = 'D:\\Pytorch-UNet-master new\\Pytorch-UNet-master\\data\\imgs_resized\\'
    train_mask_dir = 'D:\\Pytorch-UNet-master new\\Pytorch-UNet-master\\data\\masks_resized\\'
    # Train and val share the same directories; the 80/20 id split below
    # partitions them. (The previous separate val-dir glob was unused dead code.)
    val_img_dir = train_img_dir
    val_mask_dir = train_mask_dir

    img_paths = glob(os.path.join(train_img_dir, '*' + config['img_ext']))
    all_ids = [os.path.splitext(os.path.basename(p))[0] for p in img_paths]
    train_ids, val_ids = train_test_split(all_ids, test_size=0.2, random_state=41)

    # Augmentations: geometric + photometric for training, resize+normalize only for val.
    train_transform = Compose([
        RandomRotate90(),
        Flip(),
        OneOf([
            HueSaturationValue(),
            RandomBrightness(),
            RandomContrast()
        ], p=1),
        Resize(config['input_h'], config['input_w']),
        Normalize()
    ])
    val_transform = Compose([
        Resize(config['input_h'], config['input_w']),
        Normalize()
    ])

    train_dataset = Dataset(
        train_ids,
        train_img_dir,
        train_mask_dir,
        config['img_ext'],
        config['mask_ext'],
        config['num_classes'],
        train_transform
    )
    val_dataset = Dataset(
        val_ids,
        val_img_dir,
        val_mask_dir,
        config['img_ext'],
        config['mask_ext'],
        config['num_classes'],
        val_transform
    )

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=config['batch_size'], shuffle=True,
        num_workers=config['num_workers'], drop_last=True)
    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=config['batch_size'], shuffle=False,
        num_workers=config['num_workers'], drop_last=False)

    # ---- training loop with best-IoU checkpointing and early stopping ----
    best_iou = 0
    trigger = 0  # epochs since the last validation-IoU improvement
    for epoch in range(config['epochs']):
        print(f"\nEpoch [{epoch+1}/{config['epochs']}]")
        train_metrics = train_one_epoch(config, train_loader, model, criterion, optimizer)
        val_metrics = validate_one_epoch(config, val_loader, model, criterion)

        # Step the scheduler (ReduceLROnPlateau needs the monitored value).
        if config['scheduler'] == 'ReduceLROnPlateau':
            scheduler.step(val_metrics['loss'])
        elif scheduler is not None:
            scheduler.step()

        # Single wandb.log call per epoch so all metrics share one step
        # (previously three separate calls inflated the step counter).
        log_payload = {f"train_{k}": v for k, v in train_metrics.items()}
        log_payload.update({f"val_{k}": v for k, v in val_metrics.items()})
        log_payload['epoch'] = epoch
        wandb.log(log_payload)

        print(f"Train Metrics: {train_metrics}")
        print(f"Val Metrics: {val_metrics}")

        if val_metrics['iou'] > best_iou:
            torch.save(model.state_dict(), f"models/{config['name']}/model.pth")
            best_iou = val_metrics['iou']
            print("=> saved best model")
            trigger = 0
        else:
            trigger += 1
            if config['early_stopping'] > 0 and trigger >= config['early_stopping']:
                print("=> early stopping")
                break
        torch.cuda.empty_cache()
# Script entry point: run training only when executed directly, not on import.
if __name__ == '__main__':
    main()
# NOTE(review): stray non-code text was pasted at end of file
# ("训练代码是否有问题" = "is there any problem with the training code?").
# Commented out because it was a SyntaxError that prevented the module
# from being imported or run at all.