"""Change log

Input is a 1-channel H x W x D volume. The stem halves the spatial size (1/2 of the
original) with C=32, and the full-resolution feature map taken before the stride-2
conv is kept as a residual (skip) connection.
The network output is first concatenated with that residual via torch.cat and then
passed through the final 1x1 conv (final_layer).
"""
import torch.nn as nn
import torch
from torchsummary import summary
BN_MOMENTUM = 0.1
'''
conv -> bn -> relu -> conv -> bn -> (+ residual) -> relu
'''
class BasicBlock(nn.Module):
expansion = 1
def __init__(self, inplanes, planes, stride=1, downsample=None):
super(BasicBlock, self).__init__()
self.conv1 = nn.Conv3d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
self.bn1 = nn.BatchNorm3d(planes, momentum=BN_MOMENTUM)
self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv3d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)  # the second conv keeps stride=1; only conv1 (with downsample) may stride
self.bn2 = nn.BatchNorm3d(planes, momentum=BN_MOMENTUM)
self.downsample = downsample
self.stride = stride
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out
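# A minimal, hedged sanity check for BasicBlock (the helper name _demo_basic_block
# is illustrative, not part of the original file): with stride=1 and no downsample,
# the block preserves both the channel count and the spatial size.
def _demo_basic_block():
    blk = BasicBlock(32, 32).eval()
    with torch.no_grad():
        y = blk(torch.randn(1, 32, 24, 24, 24))
    assert y.shape == (1, 32, 24, 24, 24)  # same channels, same resolution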
'''
1x1 conv -> bn -> relu -> 3x3 conv -> bn -> relu -> 1x1 conv -> bn -> (+ residual) -> relu
(the last conv expands the channels by `expansion` = 4)
'''
class Bottleneck(nn.Module):
expansion = 4
def __init__(self, inplanes, planes, stride=1, downsample=None):
super(Bottleneck, self).__init__()
self.conv1 = nn.Conv3d(inplanes, planes, kernel_size=1, bias=False)
self.bn1 = nn.BatchNorm3d(planes, momentum=BN_MOMENTUM)
self.conv2 = nn.Conv3d(planes, planes, kernel_size=3, stride=stride,
padding=1, bias=False)
self.bn2 = nn.BatchNorm3d(planes, momentum=BN_MOMENTUM)
self.conv3 = nn.Conv3d(planes, planes * self.expansion, kernel_size=1,
bias=False)
self.bn3 = nn.BatchNorm3d(planes * self.expansion,
momentum=BN_MOMENTUM)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out
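# A hedged sketch of why Bottleneck takes a downsample projection: the main path
# expands the channels by `expansion` (4x), so the residual must be projected by a
# 1x1 conv to match. The helper name _demo_bottleneck is illustrative.
def _demo_bottleneck():
    ds = nn.Sequential(
        nn.Conv3d(32, 32 * Bottleneck.expansion, kernel_size=1, bias=False),
        nn.BatchNorm3d(32 * Bottleneck.expansion, momentum=BN_MOMENTUM)
    )
    blk = Bottleneck(32, 32, downsample=ds).eval()
    with torch.no_grad():
        y = blk(torch.randn(1, 32, 12, 12, 12))
    assert y.shape == (1, 128, 12, 12, 12)  # 32 -> 32 * 4 channels, size unchanged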
class StageModule(nn.Module):
def __init__(self, input_branches, output_branches, c):
"""
构建对应stage,即用来融合不同尺度的实现
:param input_branches: 输入的分支数,每个分支对应一种尺度
:param output_branches: 输出的分支数
:param c: 输入的第一个分支通道数
"""
super().__init__()
self.input_branches = input_branches
self.output_branches = output_branches
        self.branches = nn.ModuleList()  # holds the blocks of every branch
        for i in range(self.input_branches):  # each branch first runs through its own BasicBlocks
            w = c * (2 ** i)  # channels of branch i; the count doubles at every scale
branch = nn.Sequential(
BasicBlock(w, w),
BasicBlock(w, w),
BasicBlock(w, w),
BasicBlock(w, w)
)
            self.branches.append(branch)  # the blocks of this branch are ready
        self.fuse_layers = nn.ModuleList()  # fuses the outputs of the branches
for i in range(self.output_branches):
self.fuse_layers.append(nn.ModuleList())
for j in range(self.input_branches):
                if i == j:
                    # input and output are the same branch: no processing needed
                    self.fuse_layers[-1].append(nn.Identity())
                elif i < j:
                    # input branch j has a higher downsampling rate than output branch i,
                    # so adjust its channels and upsample it so the later addition lines up
self.fuse_layers[-1].append(
nn.Sequential(
nn.Conv3d(c * (2 ** j), c * (2 ** i), kernel_size=1, stride=1, bias=False),
nn.BatchNorm3d(c * (2 ** i), momentum=BN_MOMENTUM),
nn.Upsample(scale_factor=2.0 ** (j - i), mode='trilinear')
)
)
                else:  # i > j
                    # input branch j has a lower downsampling rate than output branch i,
                    # so adjust its channels and downsample it so the later addition lines up
                    # note: every 2x downsampling is one 3x3 stride-2 conv;
                    # 4x takes two of them, 8x takes three -- i - j convs in total
                    ops = []
                    # the first i - j - 1 convs only downsample, keeping the channel count
                    for k in range(i - j - 1):
ops.append(
nn.Sequential(
nn.Conv3d(c * (2 ** j), c * (2 ** j), kernel_size=3, stride=2, padding=1, bias=False),
nn.BatchNorm3d(c * (2 ** j), momentum=BN_MOMENTUM),
nn.ReLU(inplace=True)
)
)
                    # the final conv adjusts the channels and downsamples as well
ops.append(
nn.Sequential(
nn.Conv3d(c * (2 ** j), c * (2 ** i), kernel_size=3, stride=2, padding=1, bias=False),
nn.BatchNorm3d(c * (2 ** i), momentum=BN_MOMENTUM)
)
)
self.fuse_layers[-1].append(nn.Sequential(*ops))
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
        # run each branch through its own blocks
x = [branch(xi) for branch, xi in zip(self.branches, x)]
        # then fuse the information across scales
x_fused = []
for i in range(len(self.fuse_layers)):
x_fused.append(
self.relu(
                    sum([self.fuse_layers[i][j](x[j]) for j in range(len(self.branches))])  # output branch i processes every input branch j (Identity, 2x/4x upsampling, or downsampling) and sums the results
)
)
return x_fused
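# A hedged example of the cross-scale fusion above, assuming two branches with
# c=32: 32 channels on the higher-resolution branch, 64 on the branch downsampled
# once more. The helper name _demo_stage_module is illustrative.
def _demo_stage_module():
    stage = StageModule(input_branches=2, output_branches=2, c=32).eval()
    x = [torch.randn(1, 32, 24, 24, 24),  # higher-resolution branch
         torch.randn(1, 64, 12, 12, 12)]  # lower-resolution branch
    with torch.no_grad():
        y = stage(x)
    # each output branch keeps its own channels and resolution: branch 0 adds in
    # branch 1 upsampled 2x, branch 1 adds in branch 0 downsampled 2x
    assert y[0].shape == (1, 32, 24, 24, 24)
    assert y[1].shape == (1, 64, 12, 12, 12)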
class HighResolutionNet(nn.Module):
def __init__(self, base_channel: int = 32, output_channels: int = 6):
super().__init__()
        '''
        Stem: conv1 (stride 1) keeps the full resolution at C=32; its activation is
        saved as a residual in forward(). conv2 (stride 2) downsamples once, so the
        stem output is 1/2 of the input size, still with C=32.
        Layer1 keeps this spatial size and only changes the channel count
        (32 -> base_channel * 4).
        transition1 then creates two branches: branch 1 stays at 1/2 resolution with
        base_channel channels; branch 2 goes to 1/4 resolution with base_channel * 2 channels.
        '''
self.conv1 = nn.Conv3d(1, 32, kernel_size=3, stride=1, padding=1, bias=False)
self.bn1 = nn.BatchNorm3d(32, momentum=BN_MOMENTUM)
self.conv2 = nn.Conv3d(32, 32, kernel_size=3, stride=2, padding=1, bias=False)
self.bn2 = nn.BatchNorm3d(32, momentum=BN_MOMENTUM)
self.relu = nn.ReLU(inplace=True)
# Stage1
downsample = nn.Sequential(
nn.Conv3d(32, 128, kernel_size=1, stride=1, bias=False),
nn.BatchNorm3d(128, momentum=BN_MOMENTUM)
)
        '''
        Layer1: a plain stack of Bottlenecks at a single scale
        '''
self.layer1 = nn.Sequential(
            Bottleneck(32, 32, downsample=downsample),  # ResNet bottleneck: c channels in, 4c out
Bottleneck(128, 32),
Bottleneck(128, 32),
Bottleneck(128, 32)
)
        self.transition1 = nn.ModuleList([  # two branches: 1/2 and 1/4 of the input resolution
nn.Sequential(
nn.Conv3d(128, base_channel, kernel_size=3, stride=1, padding=1, bias=False),
nn.BatchNorm3d(base_channel, momentum=BN_MOMENTUM),
nn.ReLU(inplace=True)
),
nn.Sequential(
            nn.Sequential(  # the nested Sequential matches the layout of the original project's pretrained weights
nn.Conv3d(128, base_channel * 2, kernel_size=3, stride=2, padding=1, bias=False),
nn.BatchNorm3d(base_channel * 2, momentum=BN_MOMENTUM),
nn.ReLU(inplace=True)
)
)
])
# Stage2
self.stage2 = nn.Sequential(
StageModule(input_branches=2, output_branches=2, c=base_channel)
)
        # transition2: pass the two stage-2 outputs through unchanged, and derive a third branch by downsampling the second one
self.transition2 = nn.ModuleList([
nn.Identity(), # None, - Used in place of "None" because it is callable
nn.Identity(), # None, - Used in place of "None" because it is callable
nn.Sequential(
nn.Sequential(
nn.Conv3d(base_channel * 2, base_channel * 4, kernel_size=3, stride=2, padding=1, bias=False),
nn.BatchNorm3d(base_channel * 4, momentum=BN_MOMENTUM),
nn.ReLU(inplace=True)
)
)
])
# Stage3
self.stage3 = nn.Sequential(
StageModule(input_branches=3, output_branches=3, c=base_channel),
StageModule(input_branches=3, output_branches=3, c=base_channel),
StageModule(input_branches=3, output_branches=3, c=base_channel),
StageModule(input_branches=3, output_branches=3, c=base_channel)
)
# transition3
self.transition3 = nn.ModuleList([
nn.Identity(), # None, - Used in place of "None" because it is callable
nn.Identity(), # None, - Used in place of "None" because it is callable
nn.Identity(), # None, - Used in place of "None" because it is callable
nn.Sequential(
nn.Sequential(
nn.Conv3d(base_channel * 4, base_channel * 8, kernel_size=3, stride=2, padding=1, bias=False),
nn.BatchNorm3d(base_channel * 8, momentum=BN_MOMENTUM),
nn.ReLU(inplace=True)
)
)
])
# Stage4
        # note: the last StageModule only outputs the highest-resolution feature map
self.stage4 = nn.Sequential(
StageModule(input_branches=4, output_branches=4, c=base_channel),
StageModule(input_branches=4, output_branches=4, c=base_channel),
StageModule(input_branches=4, output_branches=1, c=base_channel)
)
self.up = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True)
# Final layer
self.final_layer = nn.Conv3d(base_channel*2, output_channels, kernel_size=1, stride=1)
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
residual = x
x = self.conv2(x)
x = self.bn2(x)
x = self.relu(x) # stem
x = self.layer1(x)
        x = [trans(x) for trans in self.transition1]  # x becomes a list with one tensor per branch
x = self.stage2(x) # 把前一层的x输入传入
        x = [
            self.transition2[0](x[0]),
            self.transition2[1](x[1]),
            self.transition2[2](x[-1])
        ]  # the new branch derives from the lowest-resolution branch only
x = self.stage3(x)
        x = [
            self.transition3[0](x[0]),
            self.transition3[1](x[1]),
            self.transition3[2](x[2]),
            self.transition3[3](x[-1]),
        ]  # the new branch derives from the lowest-resolution branch only
        x = self.stage4(x)  # 4 input branches, 1 output branch
        # the stage4 output is at 1/2 size; upsample it and concatenate with the stem residual
x = self.up(x[0])
x = self.final_layer(torch.cat((x, residual),dim=1))
return x
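# A hedged end-to-end shape check of the contract in the module docstring: the
# final upsample + concat with the stem residual restores the input resolution.
# The helper name _demo_io_contract is illustrative; 64^3 keeps a CPU run cheap.
def _demo_io_contract():
    net = HighResolutionNet(base_channel=32, output_channels=6).eval()
    with torch.no_grad():
        y = net(torch.randn(1, 1, 64, 64, 64))
    assert y.shape == (1, 6, 64, 64, 64)  # full resolution, output_channels maps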
if __name__ == '__main__':
torch.cuda.set_device(0)
network = HighResolutionNet()
net = network.cuda().eval()
    summary(net, (1, 96, 96, 96))
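    # Hedged extra: exercise the illustrative shape checks defined above. The
    # _demo_* helpers are additions for this walkthrough, run on CPU, and can be
    # removed freely.
    _demo_basic_block()
    _demo_bottleneck()
    _demo_stage_module()
    _demo_io_contract()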