My suggestion is to start from a simple, relevant project. Here I use a PyTorch implementation of U-Net segmentation as the example (https://github.com/milesial/Pytorch-UNet).
I will first walk through the project's code in full, then draw on my own experience to offer some advice on customizing network modules and tuning hyperparameters.
Overall project workflow
Here is the rough flow of the project, in pseudocode:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm

# Define the network
net = UNet(n_channels=3, n_classes=args.classes, bilinear=args.bilinear)

# Load the dataset and split it into training / validation sets
dataset = BasicDataset(dir_img, dir_mask)
n_val = int(len(dataset) * val_percent)
n_train = len(dataset) - n_val
train_set, val_set = random_split(dataset, [n_train, n_val],
                                  generator=torch.Generator().manual_seed(0))
loader_args = dict(batch_size=batch_size, num_workers=4, pin_memory=True)
train_loader = DataLoader(train_set, shuffle=True, **loader_args)
val_loader = DataLoader(val_set, shuffle=False, drop_last=True, **loader_args)

# Set up the optimizer, LR scheduler, AMP gradient scaler, and loss
optimizer = optim.RMSprop(net.parameters(), lr=learning_rate, weight_decay=1e-8, momentum=0.9)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=2)  # goal: maximize Dice score
grad_scaler = torch.cuda.amp.GradScaler(enabled=amp)
criterion = nn.CrossEntropyLoss()

# Training loop
for epoch in range(1, epochs + 1):
    net.train()
    epoch_loss = 0
    with tqdm(total=n_train, desc=f'Epoch {epoch}/{epochs}', unit='img') as pbar:
        for batch in train_loader:
            images = batch['image'].to(device=device, dtype=torch.float32)
            true_masks = batch['mask'].to(device=device, dtype=torch.long)
            masks_pred = net(images)
            loss = criterion(masks_pred, true_masks) \
                   + dice_loss(F.softmax(masks_pred, dim=1).float(),
                               F.one_hot(true_masks, net.n_classes).permute(0, 3, 1, 2).float(),
                               multiclass=True)
            # backward() must run before step(); the AMP scaler wraps both
            optimizer.zero_grad()
            grad_scaler.scale(loss).backward()
            grad_scaler.step(optimizer)
            grad_scaler.update()

    # Validation: evaluate returns the Dice score, which drives the LR scheduler
    val_score = evaluate(net, val_loader, device)
    scheduler.step(val_score)
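The loss above pairs cross-entropy with a Dice term. The repo implements dice_loss in utils/dice_score.py; the following is my own simplified sketch of a multiclass soft Dice loss, not the repo's exact code (the epsilon smoothing constant and the unused multiclass flag, kept only for signature compatibility, are assumptions):

```python
import torch
import torch.nn.functional as F

def dice_loss(pred: torch.Tensor, target: torch.Tensor,
              multiclass: bool = True, epsilon: float = 1e-6) -> torch.Tensor:
    """Soft Dice loss. pred: (N, C, H, W) softmax probabilities,
    target: (N, C, H, W) one-hot masks. Returns 1 - mean Dice coefficient."""
    inter = (pred * target).sum(dim=(-1, -2))            # per-image, per-class overlap
    sets_sum = pred.sum(dim=(-1, -2)) + target.sum(dim=(-1, -2))
    dice = (2 * inter + epsilon) / (sets_sum + epsilon)  # epsilon avoids 0/0 on empty classes
    return 1 - dice.mean()

# A perfect prediction gives a loss of (almost) zero:
masks = torch.tensor([[[0, 1], [1, 0]]])                  # (1, 2, 2) class indices
one_hot = F.one_hot(masks, 2).permute(0, 3, 1, 2).float() # (1, 2, 2, 2) one-hot masks
print(dice_loss(one_hot, one_hot).item())                 # ~0.0
```

Because Dice directly measures region overlap, it complements pixel-wise cross-entropy, which is why the project sums the two.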
Defining the network
First, recall the basic structure of the U-Net network, shown in the figure below.
In essence, the network downsamples the input several times and then upsamples it the same number of times, using double convolutions to extract features along the way; during upsampling, each stage is combined with the encoder feature map of the same spatial size.
In this project the network code lives under /unet, and /unet/unet_model.py gives the model definition:
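The double-convolution and downsampling blocks described above live in /unet/unet_parts.py. Here is a minimal sketch of those two building blocks, assuming the usual Conv-BN-ReLU layout; treat it as a paraphrase of the repo's code, not a verbatim copy:

```python
import torch
import torch.nn as nn

class DoubleConv(nn.Module):
    """(Conv 3x3 => BatchNorm => ReLU) * 2 -- the feature extractor used at every scale."""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.double_conv(x)  # spatial size preserved (padding=1), channels change

class Down(nn.Module):
    """Halve the spatial size with max-pooling, then apply DoubleConv."""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(nn.MaxPool2d(2),
                                          DoubleConv(in_channels, out_channels))

    def forward(self, x):
        return self.maxpool_conv(x)

# Shape check: DoubleConv keeps H x W, Down halves them
x = torch.randn(1, 3, 16, 16)
print(DoubleConv(3, 8)(x).shape)                    # torch.Size([1, 8, 16, 16])
print(Down(8, 16)(torch.randn(1, 8, 16, 16)).shape) # torch.Size([1, 16, 8, 8])
```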
class UNet(nn.Module):
    def __init__(self, n_channels, n_classes, bilinear=False):
        super(UNet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        factor = 2 if bilinear else 1
        self.down4 = Down(512, 1024 // factor)
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)
        return logits
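Note that each up-stage in forward takes two inputs, e.g. self.up1(x5, x4): the deeper decoder feature plus the encoder feature of the same scale. Inside Up (defined in /unet/unet_parts.py) the decoder feature is upsampled, padded to match, and concatenated along the channel dimension. A standalone sketch of that step, with illustrative tensor sizes of my choosing:

```python
import torch
import torch.nn.functional as F

x1 = torch.randn(1, 64, 28, 28)   # decoder feature after upsampling (can be smaller for odd input sizes)
x2 = torch.randn(1, 64, 32, 32)   # encoder feature arriving via the skip connection

# Pad x1 so its spatial size matches x2, then concatenate along channels
diffY = x2.size(2) - x1.size(2)
diffX = x2.size(3) - x1.size(3)
x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                diffY // 2, diffY - diffY // 2])
x = torch.cat([x2, x1], dim=1)
print(x.shape)  # torch.Size([1, 128, 32, 32]) -- channels add up; a DoubleConv then reduces them
```

This concatenation is what lets the decoder reuse the high-resolution detail captured by the encoder.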
As you can see, the author implements the network by defining a UNet class that contains exactly two methods: __init__ and forward.

The __init__ method declares the parameters and modules the network needs. When the network is constructed with the line
net = UNet(n_channels=3, n_classes=args.classes, bilinear=args.bilinear)
Python first invokes __init__, which stores the passed-in arguments such as n_channels as attributes of the instance, so that the other methods of the class can access them:
self.n_channels = n_channels
self.n_classes = n_classes
self.bilinear = bilinear
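A tiny self-contained illustration of why this matters (Tiny is a made-up module for demonstration, not part of the repo): attributes assigned in __init__ become instance state readable from any method or from outside code, and assigning an nn.Module sub-object additionally registers its parameters with the optimizer machinery automatically:

```python
import torch.nn as nn

class Tiny(nn.Module):
    def __init__(self, n_classes):
        super().__init__()
        self.n_classes = n_classes                           # plain attribute, readable anywhere as self.n_classes
        self.head = nn.Conv2d(64, n_classes, kernel_size=1)  # submodule: its weights are auto-registered

net = Tiny(n_classes=2)
print(net.n_classes)                                  # 2 -- e.g. used by F.one_hot(..., net.n_classes)
print(sum(p.numel() for p in net.parameters()))       # 130 = 64*2 weights + 2 biases, all tracked
```

This is exactly why the training loop above can call F.one_hot(true_masks, net.n_classes) without passing the class count around separately.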
The forward method then defines the network's forward pass:
def forward(self, x):
    x1 = self.inc(x)
    x2 = self.down1(x1)
    x3 = self.down2(x2)
    x4 = self.down3(x3)
    x5 = self.down4(x4)