整体架构图
网络讲解
conv 3x3,ReLu就是卷积层,其中卷积核大小是3x3,然后经过ReLu激活。
copy and crop的意思是复制和裁剪。这块内容我觉得很多人最初和我一样,不明白是什么意思,这里的意思就是对于你输出的尺寸,你需要进行复制并进行中心剪裁。方便和后面上采样生成的尺寸进行拼接。
max pool 2x2,就是最大池化层,卷积核为2x2。
up-conv 2x2:这里对于初学者来说,是最难领悟的地方,因为看不懂这个符号是啥意思。我最初以为是upsample+conv2d,试了一下,好像生成不了符合要求的尺寸,后来想了一下,这个是不是就是反卷积,用来上采样的,然后试了一下,可以实现,并且卷积核也是2x2。本文中使用的就是ConvTranspose2d()函数进行该操作。
conv 1x1 这里就是卷积层,卷积核大小是1x1。
第一块内容
# 由572*572*1变成了570*570*64
self.conv1_1 = nn.Conv2d(in_channels=1, out_channels=64, kernel_size=3, stride=1, padding=0)
self.relu1_1 = nn.ReLU(inplace=True)
# 由570*570*64变成了568*568*64
self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=0)
self.relu1_2 = nn.ReLU(inplace=True)
由Unet网络架构图,可以看出输入图像是1x572x572大小,其中的1代表的是通道数(后续可以自己更改成自己想要的,比如3通道),输出通道是64,并且通过conv3x3,得知卷积核为3x3尺寸,并且由图片中的尺寸变成570x570,因此可以得出相关的参数值in_channels=1, out_channels=64, kernel_size=3, stride=1, padding=0,整个图片的蓝色箭头的卷积操作都是这样,因此kernel_size=3, stride=1, padding=0可以固定了。只需要更改输入和输出通道数的大小即可。
数据维度变化:1x572x572->64x570x570->64x568x568
最大池化层1
# 采用最大池化进行下采样,图片大小减半,通道数不变,由568*568*64变成284*284*64
self.maxpool_1 = nn.MaxPool2d(kernel_size=2, stride=2)
最大池化的卷积核和步长都设置为2,使得输出尺寸减半,通道数不变。
数据维度变化:64x568x568->64x284x284
第二块内容
self.conv2_1 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=0) # 284*284*64->282*282*128
self.relu2_1 = nn.ReLU(inplace=True)
self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=0) # 282*282*128->280*280*128
self.relu2_2 = nn.ReLU(inplace=True)
数据维度变化:64x284x284->128x282x282->128x280x280
最大池化层2
# 采用最大池化进行下采样 280*280*128->140*140*128
self.maxpool_2 = nn.MaxPool2d(kernel_size=2, stride=2)
最大池化的卷积核和步长都设置为2,使得输出尺寸减半,通道数不变。
数据维度变化:128x280x280->128x140x140
Unet左边部分剩下内容,等等等等(有空补)
Unet左边部分汇总
由Unet网络架构图,可以看出,每经过一次卷积+relu操作,图像尺寸-2,可以得出padding=0(VGG16中padding=1,因此使得图像尺寸不变);每经过一次最大池化,图像尺寸减半。
左边部分代码如下:
class Unet(nn.Module):
def __init__(self):
super(Unet, self).__init__()
self.conv1_1 = nn.Conv2d(in_channels=1, out_channels=64, kernel_size=3, stride=1, padding=0) # 由572*572*1变成了570*570*64
self.relu1_1 = nn.ReLU(inplace=True)
self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=0) # 由570*570*64变成了568*568*64
self.relu1_2 = nn.ReLU(inplace=True)
self.maxpool_1 = nn.MaxPool2d(kernel_size=2, stride=2) # 采用最大池化进行下采样,图片大小减半,通道数不变,由568*568*64变成284*284*64
self.conv2_1 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=0) # 284*284*64->282*282*128
self.relu2_1 = nn.ReLU(inplace=True)
self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=0) # 282*282*128->280*280*128
self.relu2_2 = nn.ReLU(inplace=True)
self.maxpool_2 = nn.MaxPool2d(kernel_size=2, stride=2) # 采用最大池化进行下采样 280*280*128->140*140*128
self.conv3_1 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=0) # 140*140*128->138*138*256
self.relu3_1 = nn.ReLU(inplace=True)
self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=0) # 138*138*256->136*136*256
self.relu3_2 = nn.ReLU(inplace=True)
self.maxpool_3 = nn.MaxPool2d(kernel_size=2, stride=2) # 采用最大池化进行下采样 136*136*256->68*68*256
self.conv4_1 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=0) # 68*68*256->66*66*512
self.relu4_1 = nn.ReLU(inplace=True)
self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=0) # 66*66*512->64*64*512
self.relu4_2 = nn.ReLU(inplace=True)
self.maxpool_4 = nn.MaxPool2d(kernel_size=2, stride=2) # 采用最大池化进行下采样 64*64*512->32*32*512
self.conv5_1 = nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=3, stride=1, padding=0) # 32*32*512->30*30*1024
self.relu5_1 = nn.ReLU(inplace=True)
self.conv5_2 = nn.Conv2d(1024, 1024, kernel_size=3, stride=1, padding=0) # 30*30*1024->28*28*1024
self.relu5_2 = nn.ReLU(inplace=True)
前向传播函数中,你不能将所有对象都写成x,因为这个网络涉及到copy and crop,如果你全部都当作x,那么就无法复制和裁剪了,因为每次都是对最终结果进行复制,而不是中间步骤进行复制。
注意一下,下面的写法是错误的。
原因是因为Unet网络需要copy,因此你不能将所有层的输出都定义为x。
**正确做法应该是,在最大池化(下采样)之前,你需要有个新变量保存输出的内容,方便后续进行复制和裁剪。**这里不知道你能否听懂。代码如下:
def forward(self, x):
x1 = self.conv1_1(x)
x1 = self.relu1_1(x1)
x2 = self.conv1_2(x1)
x2 = self.relu1_2(x2) # 这个后续需要使用
down1 = self.maxpool_1(x2)
x3 = self.conv2_1(down1)
x3 = self.relu2_1(x3)
x4 = self.conv2_2(x3)
x4 = self.relu2_2(x4) # 这个后续需要使用
down2 = self.maxpool_2(x4)
x5 = self.conv3_1(down2)
x5 = self.relu3_1(x5)
x6 = self.conv3_2(x5)
x6 = self.relu3_2(x6) # 这个后续需要使用
down3 = self.maxpool_3(x6)
x7 = self.conv4_1(down3)
x7 = self.relu4_1(x7)
x8 = self.conv4_2(x7)
x8 = self.relu4_2(x8) # 这个后续需要使用
down4 = self.maxpool_4(x8)
x9 = self.conv5_1(down4)
x9 = self.relu5_1(x9)
x10 = self.conv5_2(x9)
x10 = self.relu5_2(x10)
右边部分代码讲解
右边部分的架构如下,当然,由于Unet网络的特殊性,不能只看右半边。
右半部分每一层最开始的数据,由两部分组成,一部分由up-conv 2x2的上采样组成,另外一部风是由左边部分进行复制并进行中心裁剪后得到的,然后对这两部分进行拼接。
以最下面的绿色箭头这部分,举个例子。
最下面的是1024x28x28的图像,经过上采样(绿色箭头),得到512x56x56的图像,尺寸扩大一倍,通道数减半。
然后看最下面的灰色的横向箭头。灰色箭头左边的图像是512x64x64,然后对其进行复制并中心裁剪(中心裁剪是看图得出的),最后得到512x56x56,然后和刚刚说的上采样得到的图像进行拼接,最后得出1024x56x56,我这应该讲的很清楚了。我最开始的时候,这地方没有仔细看图,一直在想到底是如何得出的。
有了上面这个例子,大家应该就能理解右半部分了。
接下来就实现这个上面说的。
注意:我在init中只是定义了上采样的函数,没有涉及到copy and crop,这个我放到forward函数中实现。
下面这四个上采样,就是图片中绿色箭头部分,大家可以关注一下数据维度的变化。
上采样部分代码
# 接下来实现上采样中的up-conv2*2
self.up_conv_1 = nn.ConvTranspose2d(in_channels=1024, out_channels=512, kernel_size=2, stride=2, padding=0) # 28*28*1024->56*56*512
数据维度变化:1024x28x28->512x56x56
self.up_conv_2 = nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=2, stride=2, padding=0) # 52*52*512->104*104*256
数据维度变化:512x52x52->256x104x104
self.up_conv_3 = nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=2, stride=2, padding=0) # 100*100*256->200*200*128
数据维度变化:256x100x100->128x200x200
self.up_conv_4 = nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=2, stride=2, padding=0) # 196*196*128->392*392*64
数据维度变化:128x196x196->64x392x392
右半部分的卷积
右边部分的卷积层也有四个大层,每个大层经过两个卷积层。
self.conv6_1 = nn.Conv2d(in_channels=1024, out_channels=512, kernel_size=3, stride=1, padding=0) # 56*56*1024->54*54*512
self.relu6_1 = nn.ReLU(inplace=True)
self.conv6_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=0) # 54*54*512->52*52*512
self.relu6_2 = nn.ReLU(inplace=True)
数据维度变化:1024x56x56->512x54x54->512x52x52
self.conv7_1 = nn.Conv2d(in_channels=512, out_channels=256, kernel_size=3, stride=1, padding=0) # 104*104*512->102*102*256
self.relu7_1 = nn.ReLU(inplace=True)
self.conv7_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=0) # 102*102*256->100*100*256
self.relu7_2 = nn.ReLU(inplace=True)
数据维度变化:512x104x104->256x102x102->256x100x100
self.conv8_1 = nn.Conv2d(in_channels=256, out_channels=128, kernel_size=3, stride=1, padding=0) # 200*200*256->198*198*128
self.relu8_1 = nn.ReLU(inplace=True)
self.conv8_2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=0) # 198*198*128->196*196*128
self.relu8_2 = nn.ReLU(inplace=True)
数据维度变化:256x200x200->128x198x198->128x196x196
self.conv9_1 = nn.Conv2d(in_channels=128, out_channels=64, kernel_size=3, stride=1, padding=0) # 392*392*128->390*390*64
self.relu9_1 = nn.ReLU(inplace=True)
self.conv9_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=0) # 390*390*64->388*388*64
self.relu9_2 = nn.ReLU(inplace=True)
数据维度变化:128x392x392->64x390x390->64x388x388
最后的conv 1x1
# 最后的conv1*1
self.conv_10 = nn.Conv2d(in_channels=64, out_channels=2, kernel_size=1, stride=1, padding=0)
这个代码就是最后的conv1x1操作,输入通道数为64,输出通道数为2,卷积核大小为1,步长为1,padding=0,使得
数据维度变化:64x388x388->2x388x388
其中的输出通道数可以根据自己的需要进行更改。
copy and crop的实现,并且实现拼接操作
上面的代码是init中定义好的层,copy and crop操作没有能够直接实现的函数,因此我放到forward函数中。
# 中心裁剪,
def crop_tensor(self, tensor, target_tensor):
target_size = target_tensor.size()[2]
tensor_size = tensor.size()[2]
delta = tensor_size - target_size
delta = delta // 2
# 如果原始张量的尺寸为10,而delta为2,那么"delta:tensor_size - delta"将截取从索引2到索引8的部分,长度为6,以使得截取后的张量尺寸变为6。
return tensor[:, :, delta:tensor_size - delta, delta:tensor_size - delta]
首先我实现了一个这样的函数。这个函数可以帮助我将tensor中心裁剪成target_tensor的尺寸,符合Unet网络的需求。
# 第一次上采样,需要"Copy and crop"(复制并裁剪)
up1 = self.up_conv_1(x10) # 得到56*56*512
# 需要对x8进行裁剪,从中心往外裁剪
crop1 = self.crop_tensor(x8, up1)
# 拼接操作
up_1 = torch.cat([crop1, up1], dim=1)
这是第一次实现上采样并且进行拼接。
首先up1 = self.up_conv_1(x10)这段代码实现上采样,得到512x56x56的数据,x8就是经过conv4_2和relu操作后,处在左下角灰色箭头左边的数据,其维度是512x64x64,我们需要将其裁剪成up1的形状,因此可以调用self.crop_tensor函数,得到crop1,其维度和up1一样,都是512x56x56。
然后就可以进行拼接,使用torch.cat()函数对张量列表在指定维度上进行拼接,这里就是将crop1和up1进行在通道数维度上的拼接,最后拼接成1024x56x56大小的数据(由unet架构图中可以看出,crop1在前面,up1在后面)。
然后经过两次卷积后,继续上采样,copy and crop,然后进行拼接。
这是第二次的这个过程:上采样+裁剪+拼接
# 第二次上采样,需要"Copy and crop"(复制并裁剪)
up2 = self.up_conv_2(y2)
# 需要对x6进行裁剪,从中心往外裁剪
crop2 = self.crop_tensor(x6, up2)
# 拼接
up_2 = torch.cat([crop2, up2], dim=1)
同理:经过两次卷积后,继续上采样,copy and crop,然后进行拼接。
这是第三次的这个过程:上采样+裁剪+拼接
# 第三次上采样,需要"Copy and crop"(复制并裁剪)
up3 = self.up_conv_3(y4)
# 需要对x4进行裁剪,从中心往外裁剪
crop3 = self.crop_tensor(x4, up3)
up_3 = torch.cat([crop3, up3], dim=1)
同理:经过两次卷积后,继续上采样,copy and crop,然后进行拼接。
这是第四次的这个过程:上采样+裁剪+拼接
# 第四次上采样,需要"Copy and crop"(复制并裁剪)
up4 = self.up_conv_4(y6)
# 需要对x2进行裁剪,从中心往外裁剪
crop4 = self.crop_tensor(x2, up4)
up_4 = torch.cat([crop4, up4], dim=1)
最终代码展示
这个代码我敢肯定,这是全网最基础、最简单的代码,但是是最适合小白的代码。
并且我配上了本代码的具体名称对应的层,比如x1就是第一个conv 3x3,大家可以自己对应着看。
import torch
import torch.nn as nn
class Unet(nn.Module):
def __init__(self):
super(Unet, self).__init__()
self.conv1_1 = nn.Conv2d(in_channels=1, out_channels=64, kernel_size=3, stride=1, padding=0) # 由572*572*1变成了570*570*64
self.relu1_1 = nn.ReLU(inplace=True)
self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=0) # 由570*570*64变成了568*568*64
self.relu1_2 = nn.ReLU(inplace=True)
self.maxpool_1 = nn.MaxPool2d(kernel_size=2, stride=2) # 采用最大池化进行下采样,图片大小减半,通道数不变,由568*568*64变成284*284*64
self.conv2_1 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=0) # 284*284*64->282*282*128
self.relu2_1 = nn.ReLU(inplace=True)
self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=0) # 282*282*128->280*280*128
self.relu2_2 = nn.ReLU(inplace=True)
self.maxpool_2 = nn.MaxPool2d(kernel_size=2, stride=2) # 采用最大池化进行下采样 280*280*128->140*140*128
self.conv3_1 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=0) # 140*140*128->138*138*256
self.relu3_1 = nn.ReLU(inplace=True)
self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=0) # 138*138*256->136*136*256
self.relu3_2 = nn.ReLU(inplace=True)
self.maxpool_3 = nn.MaxPool2d(kernel_size=2, stride=2) # 采用最大池化进行下采样 136*136*256->68*68*256
self.conv4_1 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=0) # 68*68*256->66*66*512
self.relu4_1 = nn.ReLU(inplace=True)
self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=0) # 66*66*512->64*64*512
self.relu4_2 = nn.ReLU(inplace=True)
self.maxpool_4 = nn.MaxPool2d(kernel_size=2, stride=2) # 采用最大池化进行下采样 64*64*512->32*32*512
self.conv5_1 = nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=3, stride=1, padding=0) # 32*32*512->30*30*1024
self.relu5_1 = nn.ReLU(inplace=True)
self.conv5_2 = nn.Conv2d(1024, 1024, kernel_size=3, stride=1, padding=0) # 30*30*1024->28*28*1024
self.relu5_2 = nn.ReLU(inplace=True)
# 接下来实现上采样中的up-conv2*2
self.up_conv_1 = nn.ConvTranspose2d(in_channels=1024, out_channels=512, kernel_size=2, stride=2, padding=0) # 28*28*1024->56*56*512
self.conv6_1 = nn.Conv2d(in_channels=1024, out_channels=512, kernel_size=3, stride=1, padding=0) # 56*56*1024->54*54*512
self.relu6_1 = nn.ReLU(inplace=True)
self.conv6_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=0) # 54*54*512->52*52*512
self.relu6_2 = nn.ReLU(inplace=True)
self.up_conv_2 = nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=2, stride=2, padding=0) # 52*52*512->104*104*256
self.conv7_1 = nn.Conv2d(in_channels=512, out_channels=256, kernel_size=3, stride=1, padding=0) # 104*104*512->102*102*256
self.relu7_1 = nn.ReLU(inplace=True)
self.conv7_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=0) # 102*102*256->100*100*256
self.relu7_2 = nn.ReLU(inplace=True)
self.up_conv_3 = nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=2, stride=2, padding=0) # 100*100*256->200*200*128
self.conv8_1 = nn.Conv2d(in_channels=256, out_channels=128, kernel_size=3, stride=1, padding=0) # 200*200*256->198*198*128
self.relu8_1 = nn.ReLU(inplace=True)
self.conv8_2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=0) # 198*198*128->196*196*128
self.relu8_2 = nn.ReLU(inplace=True)
self.up_conv_4 = nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=2, stride=2, padding=0) # 196*196*128->392*392*64
self.conv9_1 = nn.Conv2d(in_channels=128, out_channels=64, kernel_size=3, stride=1, padding=0) # 392*392*128->390*390*64
self.relu9_1 = nn.ReLU(inplace=True)
self.conv9_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=0) # 390*390*64->388*388*64
self.relu9_2 = nn.ReLU(inplace=True)
# 最后的conv1*1
self.conv_10 = nn.Conv2d(in_channels=64, out_channels=2, kernel_size=1, stride=1, padding=0)
# 中心裁剪,
def crop_tensor(self, tensor, target_tensor):
target_size = target_tensor.size()[2]
tensor_size = tensor.size()[2]
delta = tensor_size - target_size
delta = delta // 2
# 如果原始张量的尺寸为10,而delta为2,那么"delta:tensor_size - delta"将截取从索引2到索引8的部分,长度为6,以使得截取后的张量尺寸变为6。
return tensor[:, :, delta:tensor_size - delta, delta:tensor_size - delta]
def forward(self, x):
x1 = self.conv1_1(x)
x1 = self.relu1_1(x1)
x2 = self.conv1_2(x1)
x2 = self.relu1_2(x2) # 这个后续需要使用
down1 = self.maxpool_1(x2)
x3 = self.conv2_1(down1)
x3 = self.relu2_1(x3)
x4 = self.conv2_2(x3)
x4 = self.relu2_2(x4) # 这个后续需要使用
down2 = self.maxpool_2(x4)
x5 = self.conv3_1(down2)
x5 = self.relu3_1(x5)
x6 = self.conv3_2(x5)
x6 = self.relu3_2(x6) # 这个后续需要使用
down3 = self.maxpool_3(x6)
x7 = self.conv4_1(down3)
x7 = self.relu4_1(x7)
x8 = self.conv4_2(x7)
x8 = self.relu4_2(x8) # 这个后续需要使用
down4 = self.maxpool_4(x8)
x9 = self.conv5_1(down4)
x9 = self.relu5_1(x9)
x10 = self.conv5_2(x9)
x10 = self.relu5_2(x10)
# 第一次上采样,需要"Copy and crop"(复制并裁剪)
up1 = self.up_conv_1(x10) # 得到56*56*512
# 需要对x8进行裁剪,从中心往外裁剪
crop1 = self.crop_tensor(x8, up1)
up_1 = torch.cat([crop1, up1], dim=1)
y1 = self.conv6_1(up_1)
y1 = self.relu6_1(y1)
y2 = self.conv6_2(y1)
y2 = self.relu6_2(y2)
# 第二次上采样,需要"Copy and crop"(复制并裁剪)
up2 = self.up_conv_2(y2)
# 需要对x6进行裁剪,从中心往外裁剪
crop2 = self.crop_tensor(x6, up2)
up_2 = torch.cat([crop2, up2], dim=1)
y3 = self.conv7_1(up_2)
y3 = self.relu7_1(y3)
y4 = self.conv7_2(y3)
y4 = self.relu7_2(y4)
# 第三次上采样,需要"Copy and crop"(复制并裁剪)
up3 = self.up_conv_3(y4)
# 需要对x4进行裁剪,从中心往外裁剪
crop3 = self.crop_tensor(x4, up3)
up_3 = torch.cat([crop3, up3], dim=1)
y5 = self.conv8_1(up_3)
y5 = self.relu8_1(y5)
y6 = self.conv8_2(y5)
y6 = self.relu8_2(y6)
# 第四次上采样,需要"Copy and crop"(复制并裁剪)
up4 = self.up_conv_4(y6)
# 需要对x2进行裁剪,从中心往外裁剪
crop4 = self.crop_tensor(x2, up4)
up_4 = torch.cat([crop4, up4], dim=1)
y7 = self.conv9_1(up_4)
y7 = self.relu9_1(y7)
y8 = self.conv9_2(y7)
y8 = self.relu9_2(y8)
# 最后的conv1*1
out = self.conv_10(y8)
return out
if __name__ == '__main__':
input_data = torch.randn([1, 1, 572, 572])
unet = Unet()
output = unet(input_data)
print(output.shape)
# torch.Size([1, 2, 388, 388])
这段代码包括空行在内,一共写了160行++,因为我是初学者,我懂初学者的痛。
网上的代码是经过封装的,因为Unet网络中涉及到很多重复的操作,大家为了简便代码,都通过定义相同操作的类,通过调用,从而减少代码量,使得代码看起来简短一些。但是这就不利于初学者去学习了,因为一般,大家都不喜欢嵌套,跳来跳去会容易晕(除非你一步一步debug)。
正是由于这一点,我才花了很多时间,来写这个博客,想着从初学者的角度,如何逐行编写网络结构。
后续也会打算从这篇文章开始,把重复操作的步骤,通过定义成类,进行调用,从而使得代码简短一些,也更加符合大佬们的写法。
我相信,经过这两天,VGG16+Unet网络架构从零到一的编写,大家的能力会得到很大的提升。
二、网络结构详解
UNet总体上分为编码器和解码器,其中编码器负责提取特征信息,解码器负责还原特征信息;编码器主要由4个块组成,每个块分别由2个卷积层、1个最大池化层组成。解码器也是由4个块组成,每个块都是由1个上采样层、2个卷积层组成,详细信息请见下图。
三、网络组成部分实现
第1步:导入需要的包
import torch
import torch.nn as nn
import torch.nn.functional as F
第2步:我们需要自定义一个卷积的基础块,该基础块由2个卷积层组成。
class DoubleConv(nn.Module):
"""(convolution => [BN] => ReLU) * 2"""
def __init__(self, in_channels, out_channels, mid_channels=None):
super().__init__()
if not mid_channels:
mid_channels = out_channels
self.double_conv = nn.Sequential(
nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1),
nn.BatchNorm2d(mid_channels),
nn.ReLU(inplace=True),
nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
def forward(self, x):
return self.double_conv(x)
第3步:我们需要自定义一个编码器的基础块,该块由1个最大池化层和第2步的卷积基础块组成。
class Down(nn.Module):
"""Downscaling with maxpool then double conv"""
def __init__(self, in_channels, out_channels):
super().__init__()
self.maxpool_conv = nn.Sequential(
nn.MaxPool2d(2),
DoubleConv(in_channels, out_channels)
)
def forward(self, x):
return self.maxpool_conv(x)
第4步:我们需要自定义一个解码器的基础块,该基础块由1个上采样层和2个卷积层组成。
class Up(nn.Module):
"""Upscaling then double conv"""
def __init__(self, in_channels, out_channels, bilinear=True):
super().__init__()
# if bilinear, use the normal convolutions to reduce the number of channels
if bilinear:
self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) # 双线性插值
self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
else:
self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2) # 转置卷积
self.conv = DoubleConv(in_channels, out_channels)
def forward(self, x1, x2):
x1 = self.up(x1)
# input is CHW
diffY = x2.size()[2] - x1.size()[2]
diffX = x2.size()[3] - x1.size()[3]
x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
diffY // 2, diffY - diffY // 2])
x = torch.cat([x2, x1], dim=1)
return self.conv(x)
第5步:定义一个最后的输出层
class OutConv(nn.Module):
def __init__(self, in_channels, out_channels):
super(OutConv, self).__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
def forward(self, x):
return self.conv(x)
四、网络结构实现
- 第1步:我们需要把上述定义的类一股脑的导入到你要定义的网络文件中,因为每个人的文件夹不同,这里就不详细讲述。
- 第2步:初始化你的网络模型参数
- 第3步:编写前向传播方法
class UNet(nn.Module): def __init__(self, args, n_channels, n_classes, bilinear=True): super(UNet, self).__init__() # 简单点讲:就是子类使用父类的初始化方法进行初始化,这会使得代码非常的整洁 self.n_channels = n_channels self.n_classes = n_classes self.bilinear = bilinear """DoubleConv <-> (convolution => [BN] => ReLU) * 2""" self.inc = DoubleConv(n_channels, 64) self.down1 = Down(64, 128) self.down2 = Down(128, 256) self.down3 = Down(256, 512) factor = 2 if bilinear else 1 self.down4 = Down(512, 1024 // factor) self.up1 = Up(1024, 512 // factor, bilinear) self.up2 = Up(512, 256 // factor, bilinear) self.up3 = Up(256, 128 // factor, bilinear) self.up4 = Up(128, 64, bilinear) self.outc = OutConv(64, n_classes) def forward(self, x): x1 = self.inc(x) x2 = self.down1(x1) x3 = self.down2(x2) x4 = self.down3(x3) x5 = self.down4(x4) x = self.up1(x5, x4) x = self.up2(x, x3) x = self.up3(x, x2) x = self.up4(x, x1) logits = self.outc(x) return logits