# -------------------------- Added: disable TensorFlow oneDNN (avoids numeric interference) --------------------------
import os
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'  # optional: makes CUDA errors surface at the failing call, easing debugging
# -------------------------------------------------------------------------------------------------
import argparse
import json
import random
from pathlib import Path
import numpy as np
import pandas as pd
import time
import pickle
## pytorch
import torch
from torch import nn
from torch.optim import Adam, SGD
from torch.backends import cudnn
import torch.nn.functional as F
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.tensorboard import SummaryWriter
import torchvision.transforms as transforms
# -------------------------- Changed: adapted to the torchmetrics 1.2.0 API --------------------------
from torchmetrics.classification import BinaryAUROC as AUC  # binary AUROC replaces the legacy AUC meter
from torchmetrics.classification import BinaryJaccardIndex as JaccardIndex  # binary Jaccard (IoU) index
# ---------------------------------------------------------------------------------------------------
## models: ResNet-based UNet-family architectures
from models import UNet, UNet11, UNet16, UNet16BN, LinkNet34
from loss import LossBinary
from dataset import make_loader
from utils import save_weights, write_event, write_tensorboard, print_model_summay, set_freeze_layers, set_train_layers, get_freeze_layer_names
from validation import validation_binary
from transforms import DualCompose, ImageOnly, Normalize, HorizontalFlip, VerticalFlip
# -------------------------- Changed: AllInOneMeter reimplemented on top of torchmetrics --------------------------
class AllInOneMeter:
    """Accumulates per-class AUROC/Jaccard and running losses over an epoch.

    Built on torchmetrics 1.x binary metrics (one metric instance per
    attribute class). ``value()`` computes epoch-level numbers, resets the
    internal state, and returns a plain dict.
    """

    def __init__(self, num_classes=5):
        self.num_classes = num_classes
        # One binary metric per attribute class (torchmetrics 1.x API).
        self.auc_meters = [AUC() for _ in range(num_classes)]
        self.jaccard_meters = [JaccardIndex() for _ in range(num_classes)]
        # Per-batch scalar losses accumulated across the epoch.
        self.loss1_list = []
        self.loss2_list = []
        self.loss3_list = []
        self.total_loss_list = []

    def add(self, pred_mask, gt_mask, pred_ind1, pred_ind2, gt_ind,
            loss1, loss2, loss3, total_loss):
        """Fold one batch into the running metrics.

        pred_mask / gt_mask: (B, C, H, W) probability / binary masks.
        pred_ind1, pred_ind2 and gt_ind are accepted for interface
        compatibility but are not folded into any metric here.
        """
        # (B, C, H, W) -> (B*H*W, C) so each column is one class.
        pred_flat = pred_mask.permute(0, 2, 3, 1).reshape(-1, self.num_classes)
        gt_flat = gt_mask.permute(0, 2, 3, 1).reshape(-1, self.num_classes)
        for c in range(self.num_classes):
            # detach + cpu: metric states live on CPU by default, so feeding
            # CUDA tensors straight in would raise a device-mismatch error.
            preds = pred_flat[:, c].detach().cpu()
            target = gt_flat[:, c].detach().cpu().int()
            self.auc_meters[c].update(preds, target)
            self.jaccard_meters[c].update(preds, target)
        self.loss1_list.append(loss1)
        self.loss2_list.append(loss2)
        self.loss3_list.append(loss3)
        self.total_loss_list.append(total_loss)

    def value(self):
        """Compute epoch metrics, reset internal state, return a dict."""
        def _mean(values):
            # Guard against an empty accumulator: np.mean([]) is NaN + warning.
            return float(np.mean(values)) if values else 0.0
        avg_loss1 = _mean(self.loss1_list)
        avg_loss2 = _mean(self.loss2_list)
        avg_loss3 = _mean(self.loss3_list)
        avg_total_loss = _mean(self.total_loss_list)
        # Per-class metric values for the epoch.
        aucs = [self.auc_meters[c].compute().item() for c in range(self.num_classes)]
        jaccards = [self.jaccard_meters[c].compute().item() for c in range(self.num_classes)]
        # Reset so the meter is ready for the next epoch.
        self.reset()
        return {
            "loss1": avg_loss1, "loss2": avg_loss2, "loss3": avg_loss3, "total_loss": avg_total_loss,
            "auc": aucs, "jaccard": jaccards, "mean_auc": np.mean(aucs), "mean_jaccard": np.mean(jaccards)
        }

    def reset(self):
        """Clear all metric state and loss accumulators."""
        for c in range(self.num_classes):
            self.auc_meters[c].reset()
            self.jaccard_meters[c].reset()
        self.loss1_list = []
        self.loss2_list = []
        self.loss3_list = []
        self.total_loss_list = []
# -------------------------------------------------------------------------------------------------
def get_split(train_test_split_file='./data/train_test_id.pickle'):
with open(train_test_split_file,'rb') as f:
train_test_id = pickle.load(f)
train_test_id['total'] = train_test_id[['pigment_network',
'negative_network',
'streaks',
'milia_like_cyst',
'globules']].sum(axis=1)
valid = train_test_id[train_test_id.Split != 'train'].copy()
valid['Split'] = 'train'
train_test_id = pd.concat([train_test_id, valid], axis=0)
return train_test_id
def main():
    """CLI entry point: train a multi-task lesion-attribute model.

    Builds the chosen UNet/LinkNet variant, the train/valid loaders and the
    optimizer, then runs the epoch loop with console logging, TensorBoard
    output, best-model checkpointing and ReduceLROnPlateau scheduling.
    """
    parser = argparse.ArgumentParser()
    arg = parser.add_argument
    arg('--jaccard-weight', type=float, default=1)
    arg('--checkpoint', type=str, default='checkpoint/1_multi_task_unet', help='checkpoint path')
    arg('--train-test-split-file', type=str, default='./data/train_test_id.pickle', help='train test split file path')
    arg('--image-path', type=str, default='data/task2_h5/', help='image path')
    arg('--batch-size', type=int, default=8)
    arg('--n-epochs', type=int, default=100)
    arg('--optimizer', type=str, default='Adam', help='Adam or SGD')
    arg('--lr', type=float, default=0.001)
    arg('--workers', type=int, default=4)
    arg('--model', type=str, default='UNet16', choices=['UNet', 'UNet11', 'UNet16', 'UNet16BN', 'LinkNet34'])
    arg('--model-weight', type=str, default=None)
    arg('--resume-path', type=str, default=None)
    arg('--attribute', type=str, default='all', choices=['pigment_network', 'negative_network',
                                                         'streaks', 'milia_like_cyst',
                                                         'globules', 'all'])
    args = parser.parse_args()
    ## Folder for checkpoints.
    checkpoint = Path(args.checkpoint)
    checkpoint.mkdir(exist_ok=True, parents=True)
    image_path = args.image_path
    # 'all' trains one output channel per attribute; otherwise a single class.
    if args.attribute == 'all':
        num_classes = 5
    else:
        num_classes = 1
    args.num_classes = num_classes
    ### Save the run parameters next to the weights for reproducibility.
    print('--' * 10)
    print(args)
    print('--' * 10)
    checkpoint.joinpath('params.json').write_text(
        json.dumps(vars(args), indent=True, sort_keys=True))
    ## Build the requested model. pretrained=True loads ImageNet weights for
    ## the (ResNet-based) encoder — see models.py.
    if args.model == 'UNet':
        model = UNet(num_classes=num_classes)
    elif args.model == 'UNet11':
        model = UNet11(num_classes=num_classes, pretrained=True)
    elif args.model == 'UNet16':
        model = UNet16(num_classes=num_classes, pretrained=True)
    elif args.model == 'UNet16BN':
        model = UNet16BN(num_classes=num_classes, pretrained=True)
    elif args.model == 'LinkNet34':
        # LinkNet34 is itself ResNet34-based.
        model = LinkNet34(num_classes=num_classes, pretrained=True)
    else:
        model = UNet(num_classes=num_classes, input_channels=3)
    ## Multiple GPUs via DataParallel; falls back to CPU when CUDA is absent.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = nn.DataParallel(model)
    model.to(device)
    ## Optionally load a full pretrained weight file.
    if args.model_weight is not None:
        state = torch.load(args.model_weight)
        model.load_state_dict(state['model'])
        print('--' * 10)
        print('Load pretrained model weight from:', args.model_weight)
        print('--' * 10)
    ## Model summary (printed once).
    print_model_summay(model)
    ## Loss: BCE plus a weighted Jaccard term (see loss.LossBinary).
    loss_fn = LossBinary(jaccard_weight=args.jaccard_weight)
    ## cudnn autotuner: faster when input sizes are fixed across batches.
    cudnn.benchmark = True
    ## Train/test split table.
    train_test_id = get_split(args.train_test_split_file)
    ## Report the train/val row counts.
    print('--' * 10)
    print('num train = {}, num val = {}'.format(
        (train_test_id['Split'] == 'train').sum(),
        (train_test_id['Split'] != 'train').sum()
    ))
    print('--' * 10)
    ## Augmentation pipelines: flips for training, normalisation for both.
    train_transform = DualCompose([
        HorizontalFlip(),
        VerticalFlip(),
        ImageOnly(Normalize())
    ])
    val_transform = DualCompose([
        ImageOnly(Normalize())
    ])
    # Spare ImageNet normalisation (enable in the loader if the dataset needs it).
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    ## Data loaders.
    train_loader = make_loader(
        train_test_id, image_path, args, train=True, shuffle=True,
        train_test_split_file=args.train_test_split_file,
        transform=train_transform
    )
    valid_loader = make_loader(
        train_test_id, image_path, args, train=False, shuffle=False,  # validation needs no shuffling
        train_test_split_file=args.train_test_split_file,
        transform=val_transform
    )
    ## Data sanity check (debugging aid).
    if True:
        print('--'*10)
        print('Check data format and range:')
        train_image, train_mask, train_mask_ind = next(iter(train_loader))
        print('train_image.shape:', train_image.shape)  # expected: (B, H, W, C)
        print('train_mask.shape:', train_mask.shape)  # expected: (B, H, W, C)
        print('train_mask_ind.shape:', train_mask_ind.shape)  # expected: (B, 1) or (B, C)
        print('train_image.min:', train_image.min().item())
        print('train_image.max:', train_image.max().item())
        print('train_mask.min:', train_mask.min().item())
        print('train_mask.max:', train_mask.max().item())
        print('train_mask_ind.min:', train_mask_ind.min().item())
        print('train_mask_ind.max:', train_mask_ind.max().item())
        print('--' * 10)
    ## Validation function.
    valid_fn = validation_binary
    ###########
    ## Optimizer.
    if args.optimizer == 'Adam':
        optimizer = Adam(model.parameters(), lr=args.lr)
    elif args.optimizer == 'SGD':
        optimizer = SGD(model.parameters(), lr=args.lr, momentum=0.9, weight_decay=1e-4)  # weight decay guards against overfitting
    else:
        raise ValueError(f"Unsupported optimizer: {args.optimizer}, choose 'Adam' or 'SGD'")
    ## Loss (alias of loss_fn above).
    criterion = loss_fn
    ## LR scheduler (verbose=False avoids the deprecation warning).
    scheduler = ReduceLROnPlateau(optimizer, 'min', factor=0.8, patience=5, verbose=False)
    ##########
    ## Resume a previous run when requested and a checkpoint exists.
    previous_valid_loss = 10.0
    model_path = checkpoint / 'model.pt'
    if args.resume_path is not None and model_path.exists():
        try:
            state = torch.load(str(model_path), map_location=device)
            epoch = state['epoch']
            step = state['step']
            model.load_state_dict(state['model'])
            optimizer.load_state_dict(state['optimizer'])  # restore optimizer state (important)
            scheduler.load_state_dict(state['scheduler'])  # restore scheduler state (important)
            previous_valid_loss = state.get('valid_loss', 10.0)
            previous_valid_jaccard = state.get('valid_jaccard', 0.0)
            print('--' * 10)
            print(f"Restored training from epoch {epoch}, step {step:,}")
            print(f"Previous valid loss: {previous_valid_loss:.4f}, previous valid jaccard: {previous_valid_jaccard:.4f}")
            print('--' * 10)
        except Exception as e:
            print(f"Warning: Failed to resume training - {str(e)}")
            print("Start training from scratch")
            epoch = 1
            step = 0
            previous_valid_jaccard = 0.0
    else:
        epoch = 1
        step = 0
        previous_valid_jaccard = 0.0
    #########
    ## Training loop.
    log = checkpoint.joinpath('train.log').open('at', encoding='utf8')
    writer = SummaryWriter(log_dir=checkpoint)
    meter = AllInOneMeter(num_classes=num_classes)  # multi-class metric accumulator
    print('Start training with ResNet-based model:', args.model)
    # Loss weights (tune per task).
    w1 = 1.0  # main mask loss weight
    w2 = 0.5  # auxiliary task 1 (center branch) weight
    w3 = 0.5  # auxiliary task 2 (mask branch) weight
    for epoch in range(epoch, args.n_epochs + 1):
        model.train()
        random.seed(epoch)  # per-epoch seed keeps augmentation reproducible
        start_time = time.time()
        meter.reset()  # clear metric state for this epoch
        try:
            for i, (train_image, train_mask, train_mask_ind) in enumerate(train_loader):
                # (B, H, W, C) -> (B, C, H, W) to match PyTorch's layout.
                train_image = train_image.permute(0, 3, 1, 2).float()
                train_mask = train_mask.permute(0, 3, 1, 2).float()
                train_mask_ind = train_mask_ind.unsqueeze(1).float()  # assumes the loader yields per-image labels — TODO confirm this shape matches the aux-head outputs
                # Move the batch to the training device.
                train_image = train_image.to(device, non_blocking=True)
                train_mask = train_mask.to(device, non_blocking=True)
                train_mask_ind = train_mask_ind.to(device, non_blocking=True)
                # Forward pass: main mask plus two auxiliary outputs.
                outputs, outputs_mask_ind1, outputs_mask_ind2 = model(train_image)
                # Main mask loss (custom LossBinary with Jaccard term).
                loss1 = criterion(outputs, train_mask)
                # Auxiliary losses (per-image attribute presence, BCE on logits).
                loss2 = F.binary_cross_entropy_with_logits(outputs_mask_ind1, train_mask_ind)
                loss3 = F.binary_cross_entropy_with_logits(outputs_mask_ind2, train_mask_ind)
                # Weighted total.
                total_loss = loss1 * w1 + loss2 * w2 + loss3 * w3
                # Print every 10th batch (and the first) to keep logs compact.
                if (i + 1) % 10 == 0 or i == 0:
                    print(f"Epoch {epoch:3d} | Iter {i:3d}/{len(train_loader)-1:3d} | "
                          f"Loss1: {loss1.item():.4f} | Loss2: {loss2.item():.4f} | "
                          f"Loss3: {loss3.item():.4f} | Total Loss: {total_loss.item():.4f}")
                # Backward pass.
                optimizer.zero_grad()
                total_loss.backward()
                optimizer.step()
                step += 1
                # Update the metric accumulator (no grads needed here).
                with torch.no_grad():
                    train_prob = torch.sigmoid(outputs)
                    train_mask_ind_prob1 = torch.sigmoid(outputs_mask_ind1)
                    train_mask_ind_prob2 = torch.sigmoid(outputs_mask_ind2)
                    meter.add(train_prob, train_mask, train_mask_ind_prob1, train_mask_ind_prob2,
                              train_mask_ind, loss1.item(), loss2.item(), loss3.item(), total_loss.item())
            #########################
            ## Epoch-level metrics (uses the last batch's tensors for previews).
            epoch_time = time.time() - start_time
            train_metrics = meter.value()
            train_metrics['epoch_time'] = epoch_time
            train_metrics['image'] = train_image.data
            train_metrics['mask'] = train_mask.data
            train_metrics['prob'] = train_prob.data
            # Validation-set evaluation.
            valid_metrics = valid_fn(model, criterion, valid_loader, device, num_classes)
            print(f"Validation - Loss: {valid_metrics['loss1']:.4f}, Mean Jaccard: {valid_metrics['mean_jaccard']:.4f}")
            ##############
            ## Append to the log file.
            write_event(log, step, epoch=epoch, train_metrics=train_metrics, valid_metrics=valid_metrics)
            #########################
            ## TensorBoard output.
            write_tensorboard(writer, model, epoch, train_metrics=train_metrics, valid_metrics=valid_metrics)
            #########################
            ## Save the best model (separately by loss and by jaccard).
            valid_loss = valid_metrics['loss1']
            valid_jaccard = valid_metrics['mean_jaccard']  # mean Jaccard as the model-selection metric
            if valid_loss < previous_valid_loss:
                save_weights(model, model_path, epoch + 1, step, train_metrics, valid_metrics,
                             optimizer, scheduler, valid_loss, valid_jaccard)
                previous_valid_loss = valid_loss
                print('Save best model by loss')
            if valid_jaccard > previous_valid_jaccard:
                save_weights(model, model_path, epoch + 1, step, train_metrics, valid_metrics,
                             optimizer, scheduler, valid_loss, valid_jaccard)
                previous_valid_jaccard = valid_jaccard
                print('Save best model by jaccard')
            #########################
            ## Step the LR scheduler on validation loss; report the current LR.
            scheduler.step(valid_metrics['loss1'])
            current_lr = optimizer.param_groups[0]['lr']  # current learning rate
            print(f"Epoch {epoch:3d} | Current Learning Rate: {current_lr:.6f}")
            print(f"Epoch {epoch:3d} completed in {epoch_time:.2f} seconds")
        except KeyboardInterrupt:
            print("Training interrupted by user")
            writer.close()
            log.close()
            return
        except Exception as e:
            print(f"Error occurred during training: {str(e)}")
            writer.close()
            log.close()
            raise
    writer.close()
    log.close()
    print("Training completed successfully")
# Script entry point.
if __name__ == '__main__':
    main()
# --- End of train.py. The models.py module follows below. ---
from torch import nn
import torch
from torchvision import models
import torchvision
from torch.nn import functional as F
def conv3x3(in_, out):
    """3x3 convolution with padding 1 (preserves spatial size)."""
    return nn.Conv2d(in_channels=in_, out_channels=out, kernel_size=3, padding=1)
class ConvRelu(nn.Module):
    """A 3x3 convolution immediately followed by an in-place ReLU."""

    def __init__(self, in_: int, out: int):
        super().__init__()
        self.conv = conv3x3(in_, out)
        self.activation = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.activation(self.conv(x))
class DecoderBlock(nn.Module):
    """Decoder stage: channel-mixing ConvRelu followed by 2x upsampling.

    Deconvolution parameters follow
    https://distill.pub/2016/deconv-checkerboard/ to avoid checkerboard
    artifacts.
    """

    def __init__(self, in_channels, middle_channels, out_channels, is_deconv=True):
        super().__init__()
        self.in_channels = in_channels
        if is_deconv:
            # Learned upsampling via a stride-2 transposed convolution.
            stages = [
                ConvRelu(in_channels, middle_channels),
                nn.ConvTranspose2d(middle_channels, out_channels,
                                   kernel_size=4, stride=2, padding=1),
                nn.ReLU(inplace=True),
            ]
        else:
            # Fixed bilinear upsampling followed by two ConvRelu blocks.
            stages = [
                nn.Upsample(scale_factor=2, mode='bilinear'),
                ConvRelu(in_channels, middle_channels),
                ConvRelu(middle_channels, out_channels),
            ]
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        return self.block(x)
class DecoderBlockBN(nn.Module):
    """Decoder stage with batch norm after the transposed convolution.

    Deconvolution parameters follow
    https://distill.pub/2016/deconv-checkerboard/ to avoid checkerboard
    artifacts. The bilinear branch (is_deconv=False) carries no batch norm.
    """

    def __init__(self, in_channels, middle_channels, out_channels, is_deconv=True):
        super().__init__()
        self.in_channels = in_channels
        if is_deconv:
            # Learned upsampling with BN before the final ReLU.
            stages = [
                ConvRelu(in_channels, middle_channels),
                nn.ConvTranspose2d(middle_channels, out_channels,
                                   kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]
        else:
            # Fixed bilinear upsampling followed by two ConvRelu blocks.
            stages = [
                nn.Upsample(scale_factor=2, mode='bilinear'),
                ConvRelu(in_channels, middle_channels),
                ConvRelu(middle_channels, out_channels),
            ]
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        return self.block(x)
class ResNet18Encoder(nn.Module):
    """ResNet-18 feature extractor returning the four stage outputs.

    Stage channel counts are 64 / 128 / 256 / 512.
    """

    def __init__(self, pretrained=True):
        super().__init__()
        # torchvision >= 0.13: select weights via the Weights enum rather
        # than the deprecated `pretrained` flag.
        weights = models.ResNet18_Weights.IMAGENET1K_V1 if pretrained else None
        backbone = models.resnet18(weights=weights)
        self.conv1 = backbone.conv1
        self.bn1 = backbone.bn1
        self.relu = backbone.relu
        self.maxpool = backbone.maxpool
        self.layer1 = backbone.layer1  # 64 channels
        self.layer2 = backbone.layer2  # 128 channels
        self.layer3 = backbone.layer3  # 256 channels
        self.layer4 = backbone.layer4  # 512 channels

    def forward(self, x):
        """Return the (stage1, stage2, stage3, stage4) feature maps."""
        stem = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        feat1 = self.layer1(stem)    # 64 channels
        feat2 = self.layer2(feat1)   # 128 channels
        feat3 = self.layer3(feat2)   # 256 channels
        feat4 = self.layer4(feat3)   # 512 channels
        return feat1, feat2, feat3, feat4
# Basic UNet implementation
class UNet(nn.Module):
    """Classic double-conv UNet with two auxiliary per-image heads.

    forward() returns (mask_logits, aux_logits1, aux_logits2): mask logits
    at input resolution plus two pooled (B, num_classes) logit vectors for
    the attribute-presence tasks.
    """

    def __init__(self, num_classes=1, input_channels=3, num_filters=64):
        super(UNet, self).__init__()
        self.num_classes = num_classes
        # Encoder: four double-conv stages with 2x2 max-pooling in between.
        self.encoder1 = nn.Sequential(
            ConvRelu(input_channels, num_filters),
            ConvRelu(num_filters, num_filters)
        )
        self.pool1 = nn.MaxPool2d(2, 2)
        self.encoder2 = nn.Sequential(
            ConvRelu(num_filters, num_filters * 2),
            ConvRelu(num_filters * 2, num_filters * 2)
        )
        self.pool2 = nn.MaxPool2d(2, 2)
        self.encoder3 = nn.Sequential(
            ConvRelu(num_filters * 2, num_filters * 4),
            ConvRelu(num_filters * 4, num_filters * 4)
        )
        self.pool3 = nn.MaxPool2d(2, 2)
        self.encoder4 = nn.Sequential(
            ConvRelu(num_filters * 4, num_filters * 8),
            ConvRelu(num_filters * 8, num_filters * 8)
        )
        self.pool4 = nn.MaxPool2d(2, 2)
        # Bottleneck.
        self.middle = nn.Sequential(
            ConvRelu(num_filters * 8, num_filters * 16),
            ConvRelu(num_filters * 16, num_filters * 16)
        )
        # Decoder: transposed-conv upsampling plus skip concatenation.
        self.upconv4 = nn.ConvTranspose2d(num_filters * 16, num_filters * 8, kernel_size=2, stride=2)
        self.decoder4 = nn.Sequential(
            ConvRelu(num_filters * 16, num_filters * 8),
            ConvRelu(num_filters * 8, num_filters * 8)
        )
        self.upconv3 = nn.ConvTranspose2d(num_filters * 8, num_filters * 4, kernel_size=2, stride=2)
        self.decoder3 = nn.Sequential(
            ConvRelu(num_filters * 8, num_filters * 4),
            ConvRelu(num_filters * 4, num_filters * 4)
        )
        self.upconv2 = nn.ConvTranspose2d(num_filters * 4, num_filters * 2, kernel_size=2, stride=2)
        self.decoder2 = nn.Sequential(
            ConvRelu(num_filters * 4, num_filters * 2),
            ConvRelu(num_filters * 2, num_filters * 2)
        )
        self.upconv1 = nn.ConvTranspose2d(num_filters * 2, num_filters, kernel_size=2, stride=2)
        self.decoder1 = nn.Sequential(
            ConvRelu(num_filters * 2, num_filters),
            ConvRelu(num_filters, num_filters)
        )
        # Output heads.
        self.final = nn.Conv2d(num_filters, num_classes, kernel_size=1)
        # Auxiliary heads (per-image attribute-presence logits).
        self.aux1 = nn.Conv2d(num_filters * 8, num_classes, kernel_size=1)
        self.aux2 = nn.Conv2d(num_filters, num_classes, kernel_size=1)

    def forward(self, x):
        # Encoder path.
        enc1 = self.encoder1(x)
        enc2 = self.encoder2(self.pool1(enc1))
        enc3 = self.encoder3(self.pool2(enc2))
        enc4 = self.encoder4(self.pool3(enc3))
        # Bottleneck.
        middle = self.middle(self.pool4(enc4))
        # Decoder path with skip connections.
        dec4 = self.upconv4(middle)
        dec4 = torch.cat([dec4, enc4], dim=1)
        dec4 = self.decoder4(dec4)
        dec3 = self.upconv3(dec4)
        dec3 = torch.cat([dec3, enc3], dim=1)
        dec3 = self.decoder3(dec3)
        dec2 = self.upconv2(dec3)
        dec2 = torch.cat([dec2, enc2], dim=1)
        dec2 = self.decoder2(dec2)
        dec1 = self.upconv1(dec2)
        dec1 = torch.cat([dec1, enc1], dim=1)
        dec1 = self.decoder1(dec1)
        # Main mask output.
        x_out_mask = self.final(dec1)
        # Auxiliary heads. BUG FIX: squeeze only the two spatial dims —
        # a bare .squeeze() would also drop the batch dim when B == 1 and
        # break the (B, num_classes) contract with the BCE loss.
        aux1_out = self.aux1(dec4)
        x_out_empty_ind1 = F.adaptive_avg_pool2d(aux1_out, (1, 1)).squeeze(-1).squeeze(-1)
        aux2_out = self.aux2(dec1)
        x_out_empty_ind2 = F.adaptive_max_pool2d(aux2_out, (1, 1)).squeeze(-1).squeeze(-1)
        return x_out_mask, x_out_empty_ind1, x_out_empty_ind2
class UNet11(nn.Module):
    """UNet-style decoder on top of a ResNet-18 encoder, with two aux heads.

    forward() returns (mask_logits, aux_logits1, aux_logits2); the aux
    outputs are pooled (B, num_classes) logit vectors.
    """

    def __init__(self, num_classes=1, num_filters=32, pretrained=False):
        super().__init__()
        self.pool = nn.MaxPool2d(2, 2)
        self.num_classes = num_classes
        self.encoder = ResNet18Encoder(pretrained=pretrained)
        # Decoder channel sizes track the encoder pyramid (64/128/256/512).
        self.center = DecoderBlock(
            in_channels=512,
            middle_channels=num_filters * 8 * 2,
            out_channels=num_filters * 8,
            is_deconv=True
        )
        self.dec5 = DecoderBlock(
            in_channels=512 + num_filters * 8,
            middle_channels=num_filters * 8 * 2,
            out_channels=num_filters * 8,
            is_deconv=True
        )
        self.dec4 = DecoderBlock(
            in_channels=256 + num_filters * 8,
            middle_channels=num_filters * 8 * 2,
            out_channels=num_filters * 4,
            is_deconv=True
        )
        self.dec3 = DecoderBlock(
            in_channels=128 + num_filters * 4,
            middle_channels=num_filters * 4 * 2,
            out_channels=num_filters * 2,
            is_deconv=True
        )
        self.dec2 = DecoderBlock(
            in_channels=64 + num_filters * 2,
            middle_channels=num_filters * 2 * 2,
            out_channels=num_filters,
            is_deconv=True
        )
        self.dec1 = ConvRelu(
            in_=num_filters,
            out=num_filters
        )
        self.final = nn.Conv2d(num_filters, num_classes, kernel_size=1)
        # Auxiliary heads (per-image attribute-presence logits).
        self.aux1 = nn.Conv2d(num_filters * 8, num_classes, kernel_size=1)
        self.aux2 = nn.Conv2d(num_filters, num_classes, kernel_size=1)

    def forward(self, x):
        x1, x2, x3, x4 = self.encoder(x)  # channels: 64, 128, 256, 512
        conv5 = x4
        center = self.center(self.pool(conv5))
        dec5 = self.dec5(torch.cat([center, conv5], 1))
        dec4 = self.dec4(torch.cat([dec5, x3], 1))
        dec3 = self.dec3(torch.cat([dec4, x2], 1))
        dec2 = self.dec2(torch.cat([dec3, x1], 1))
        dec1 = self.dec1(dec2)
        x_out_mask = self.final(dec1)
        # BUG FIX: the ResNet stem downsamples 4x but this decoder only
        # recovers 2x of that, leaving the mask at half the input
        # resolution; upsample so the logits match the ground-truth size.
        # (No-op when the shapes already agree.)
        if x_out_mask.shape[2:] != x.shape[2:]:
            x_out_mask = F.interpolate(x_out_mask, size=x.shape[2:],
                                       mode='bilinear', align_corners=False)
        # Auxiliary heads. BUG FIX: squeeze only the spatial dims so a
        # batch of one keeps its batch dimension.
        aux1_out = self.aux1(dec5)
        x_out_empty_ind1 = F.adaptive_avg_pool2d(aux1_out, (1, 1)).squeeze(-1).squeeze(-1)
        aux2_out = self.aux2(dec1)
        x_out_empty_ind2 = F.adaptive_max_pool2d(aux2_out, (1, 1)).squeeze(-1).squeeze(-1)
        return x_out_mask, x_out_empty_ind1, x_out_empty_ind2
class UNet16(nn.Module):
    """ResNet-18-encoder UNet whose first aux head reads the bottleneck.

    forward() returns (mask_logits, aux_logits1, aux_logits2); aux_logits1
    is pooled from the center (bottleneck) features, aux_logits2 from the
    final mask logits.
    """

    def __init__(self, num_classes=1, num_filters=32, pretrained=False):
        super().__init__()
        self.num_classes = num_classes
        self.num_filters = num_filters
        self.pool = nn.MaxPool2d(2, 2)
        self.encoder = ResNet18Encoder(pretrained=pretrained)
        # Decoder channel sizes redesigned to match the encoder pyramid.
        self.center = DecoderBlock(
            in_channels=512,
            middle_channels=num_filters * 8 * 2,
            out_channels=num_filters * 8
        )
        # 1x1 projection for the first auxiliary head.
        self.center_Conv2d = nn.Conv2d(num_filters * 8, num_classes, kernel_size=1)
        self.dec5 = DecoderBlock(
            in_channels=512 + num_filters * 8,
            middle_channels=num_filters * 8 * 2,
            out_channels=num_filters * 8
        )
        self.dec4 = DecoderBlock(
            in_channels=256 + num_filters * 8,
            middle_channels=num_filters * 8 * 2,
            out_channels=num_filters * 8
        )
        self.dec3 = DecoderBlock(
            in_channels=128 + num_filters * 8,
            middle_channels=num_filters * 8 * 2,
            out_channels=num_filters * 4
        )
        self.dec2 = DecoderBlock(
            in_channels=64 + num_filters * 4,
            middle_channels=num_filters * 4 * 2,
            out_channels=num_filters * 2
        )
        # dec1 concatenates dec2 with the (possibly resized) x1 skip.
        self.dec1 = DecoderBlock(
            in_channels=num_filters * 2 + 64,
            middle_channels=num_filters * 2 * 2,
            out_channels=num_filters
        )
        self.final = nn.Conv2d(num_filters, num_classes, kernel_size=1)

    def forward(self, x):
        x1, x2, x3, x4 = self.encoder(x)  # channels: 64, 128, 256, 512
        conv5 = x4
        center = self.center(self.pool(conv5))
        center_conv = self.center_Conv2d(center)
        # BUG FIX: squeeze only the spatial dims; a bare .squeeze() would
        # also drop the batch dim when B == 1 and break the BCE-loss shape.
        x_out_empty_ind1 = F.adaptive_avg_pool2d(center_conv, (1, 1)).squeeze(-1).squeeze(-1)
        dec5 = self.dec5(torch.cat([center, conv5], 1))
        dec4 = self.dec4(torch.cat([dec5, x3], 1))
        dec3 = self.dec3(torch.cat([dec4, x2], 1))
        dec2 = self.dec2(torch.cat([dec3, x1], 1))
        # Resize the x1 skip if its spatial size differs from dec2's.
        if x1.size()[2:] != dec2.size()[2:]:
            x1_resized = F.interpolate(x1, size=dec2.size()[2:], mode='bilinear', align_corners=False)
        else:
            x1_resized = x1
        dec1 = self.dec1(torch.cat([dec2, x1_resized], 1))
        x_out_mask = self.final(dec1)
        # Same spatial-only squeeze for the second auxiliary head.
        x_out_empty_ind2 = F.adaptive_max_pool2d(x_out_mask, (1, 1)).squeeze(-1).squeeze(-1)
        return x_out_mask, x_out_empty_ind1, x_out_empty_ind2
class UNet16BN(nn.Module):
    """Batch-norm variant of UNet16 (DecoderBlockBN in place of DecoderBlock).

    forward() returns (mask_logits, aux_logits1, aux_logits2); aux_logits1
    is pooled from the bottleneck, aux_logits2 from the final mask logits.
    """

    def __init__(self, num_classes=1, num_filters=32, pretrained=False):
        super().__init__()
        self.num_classes = num_classes
        self.num_filters = num_filters
        self.pool = nn.MaxPool2d(2, 2)
        self.encoder = ResNet18Encoder(pretrained=pretrained)
        # Decoder channel sizes redesigned to match the encoder pyramid.
        self.center = DecoderBlockBN(
            in_channels=512,
            middle_channels=num_filters * 8 * 2,
            out_channels=num_filters * 8
        )
        # 1x1 projection for the first auxiliary head.
        self.center_Conv2d = nn.Conv2d(num_filters * 8, num_classes, kernel_size=1)
        self.dec5 = DecoderBlockBN(
            in_channels=512 + num_filters * 8,
            middle_channels=num_filters * 8 * 2,
            out_channels=num_filters * 8
        )
        self.dec4 = DecoderBlockBN(
            in_channels=256 + num_filters * 8,
            middle_channels=num_filters * 8 * 2,
            out_channels=num_filters * 8
        )
        self.dec3 = DecoderBlockBN(
            in_channels=128 + num_filters * 8,
            middle_channels=num_filters * 8 * 2,
            out_channels=num_filters * 4
        )
        self.dec2 = DecoderBlockBN(
            in_channels=64 + num_filters * 4,
            middle_channels=num_filters * 4 * 2,
            out_channels=num_filters * 2
        )
        # dec1 concatenates dec2 with the (possibly resized) x1 skip.
        self.dec1 = DecoderBlockBN(
            in_channels=num_filters * 2 + 64,
            middle_channels=num_filters * 2 * 2,
            out_channels=num_filters
        )
        self.final = nn.Conv2d(num_filters, num_classes, kernel_size=1)

    def forward(self, x):
        x1, x2, x3, x4 = self.encoder(x)  # channels: 64, 128, 256, 512
        conv5 = x4
        center = self.center(self.pool(conv5))
        center_conv = self.center_Conv2d(center)
        # BUG FIX: squeeze only the spatial dims; a bare .squeeze() would
        # also drop the batch dim when B == 1 and break the BCE-loss shape.
        x_out_empty_ind1 = F.adaptive_avg_pool2d(center_conv, (1, 1)).squeeze(-1).squeeze(-1)
        dec5 = self.dec5(torch.cat([center, conv5], 1))
        dec4 = self.dec4(torch.cat([dec5, x3], 1))
        dec3 = self.dec3(torch.cat([dec4, x2], 1))
        dec2 = self.dec2(torch.cat([dec3, x1], 1))
        # Resize the x1 skip if its spatial size differs from dec2's.
        if x1.size()[2:] != dec2.size()[2:]:
            x1_resized = F.interpolate(x1, size=dec2.size()[2:], mode='bilinear', align_corners=False)
        else:
            x1_resized = x1
        dec1 = self.dec1(torch.cat([dec2, x1_resized], 1))
        x_out_mask = self.final(dec1)
        # Same spatial-only squeeze for the second auxiliary head.
        x_out_empty_ind2 = F.adaptive_max_pool2d(x_out_mask, (1, 1)).squeeze(-1).squeeze(-1)
        return x_out_mask, x_out_empty_ind1, x_out_empty_ind2
class DecoderBlockLinkNet(nn.Module):
    """LinkNet decoder block: 1x1 squeeze, 2x deconv upsample, 1x1 expand."""

    def __init__(self, in_channels, n_filters):
        super().__init__()
        self.relu = nn.ReLU(inplace=True)
        # (B, C, H, W) -> (B, C/4, H, W): channel squeeze.
        self.conv1 = nn.Conv2d(in_channels, in_channels // 4, 1)
        self.norm1 = nn.BatchNorm2d(in_channels // 4)
        # (B, C/4, H, W) -> (B, C/4, 2H, 2W): learned upsampling.
        self.deconv2 = nn.ConvTranspose2d(in_channels // 4, in_channels // 4, kernel_size=4,
                                          stride=2, padding=1, output_padding=0)
        self.norm2 = nn.BatchNorm2d(in_channels // 4)
        # (B, C/4, 2H, 2W) -> (B, n_filters, 2H, 2W): channel expand.
        self.conv3 = nn.Conv2d(in_channels // 4, n_filters, 1)
        self.norm3 = nn.BatchNorm2d(n_filters)

    def forward(self, x):
        squeezed = self.relu(self.norm1(self.conv1(x)))
        upsampled = self.relu(self.norm2(self.deconv2(squeezed)))
        return self.relu(self.norm3(self.conv3(upsampled)))
class LinkNet34(nn.Module):
    """LinkNet with a ResNet-34 encoder and two auxiliary per-image heads.

    forward() returns (mask_logits, aux_logits1, aux_logits2); the aux
    outputs are pooled (B, num_classes) logit vectors.
    """

    def __init__(self, num_classes=1, num_channels=3, pretrained=True):
        super().__init__()
        assert num_channels == 3
        self.num_classes = num_classes
        filters = [64, 128, 256, 512]
        # torchvision >= 0.13 weights API (replaces the `pretrained` flag).
        if pretrained:
            resnet = models.resnet34(weights=models.ResNet34_Weights.IMAGENET1K_V1)
        else:
            resnet = models.resnet34(weights=None)
        self.firstconv = resnet.conv1
        self.firstbn = resnet.bn1
        self.firstrelu = resnet.relu
        self.firstmaxpool = resnet.maxpool
        self.encoder1 = resnet.layer1
        self.encoder2 = resnet.layer2
        self.encoder3 = resnet.layer3
        self.encoder4 = resnet.layer4
        # Decoder
        self.decoder4 = DecoderBlockLinkNet(filters[3], filters[2])
        self.decoder3 = DecoderBlockLinkNet(filters[2], filters[1])
        self.decoder2 = DecoderBlockLinkNet(filters[1], filters[0])
        self.decoder1 = DecoderBlockLinkNet(filters[0], filters[0])
        # Final classifier head restoring the input resolution.
        self.finaldeconv1 = nn.ConvTranspose2d(filters[0], 32, 3, stride=2)
        self.finalrelu1 = nn.ReLU(inplace=True)
        self.finalconv2 = nn.Conv2d(32, 32, 3)
        self.finalrelu2 = nn.ReLU(inplace=True)
        self.finalconv3 = nn.Conv2d(32, num_classes, 2, padding=1)
        # Auxiliary heads (per-image attribute-presence logits).
        self.aux1 = nn.Conv2d(filters[2], num_classes, kernel_size=1)
        self.aux2 = nn.Conv2d(32, num_classes, kernel_size=1)

    def forward(self, x):
        # Encoder
        x = self.firstconv(x)
        x = self.firstbn(x)
        x = self.firstrelu(x)
        x = self.firstmaxpool(x)
        e1 = self.encoder1(x)
        e2 = self.encoder2(e1)
        e3 = self.encoder3(e2)
        e4 = self.encoder4(e3)
        # Decoder with additive skip connections.
        d4 = self.decoder4(e4) + e3
        d3 = self.decoder3(d4) + e2
        d2 = self.decoder2(d3) + e1
        d1 = self.decoder1(d2)
        # Final classification head.
        f1 = self.finaldeconv1(d1)
        f2 = self.finalrelu1(f1)
        f3 = self.finalconv2(f2)
        f4 = self.finalrelu2(f3)
        f5 = self.finalconv3(f4)
        x_out_mask = f5
        # Auxiliary heads. BUG FIX: squeeze only the spatial dims — a bare
        # .squeeze() would also drop the batch dim when B == 1 and break
        # the (B, num_classes) contract with the BCE loss.
        aux1_out = self.aux1(d4)
        x_out_empty_ind1 = F.adaptive_avg_pool2d(aux1_out, (1, 1)).squeeze(-1).squeeze(-1)
        aux2_out = self.aux2(f4)
        x_out_empty_ind2 = F.adaptive_max_pool2d(aux2_out, (1, 1)).squeeze(-1).squeeze(-1)
        return x_out_mask, x_out_empty_ind1, x_out_empty_ind2
class Conv3BN(nn.Module):
    """3x3 convolution -> optional BatchNorm -> in-place ReLU."""

    def __init__(self, in_: int, out: int, bn=False):
        super().__init__()
        self.conv = conv3x3(in_, out)
        self.bn = nn.BatchNorm2d(out) if bn else None
        self.activation = nn.ReLU(inplace=True)

    def forward(self, x):
        out = self.conv(x)
        if self.bn is not None:
            out = self.bn(out)
        return self.activation(out)
# --- End of models.py. (Original note: "this is the complete models code; apply the fixes discussed above and output the code in full".) ---