Using pretrained models is common practice in machine learning, but applying a pretrained model directly to a new task usually gives poor results; the model needs to be fine-tuned first. This applies to models such as ResNet and ResNet3D. So how do we fine-tune in PyTorch? Below we use ResNet3D as an example to show how to fine-tune selectively.
1. Downloading and Loading the Model
First we define a ResNet3D-101 model (torchvision does not include a 3D ResNet-101) and download pretrained weights for it (searching GitHub for "ResNet3D Pretrained" will turn them up). Because the ResNet3D implementation is long, only the call site is shown here; the full code is given at the end.
model = ResNet(block=Bottleneck,
               layers=[3, 4, 23, 3],
               shortcut_type='B',
               num_classes=n_classes,            # your own number of classes
               sample_duration=sample_duration,  # clip length (number of frames)
               sample_size=sample_size)          # spatial size of each frame
n_finetune_classes = 700  # the pretrained checkpoint was trained on 700 classes
pretrain = torch.load(pretrained_resnet101_path)
model.fc = nn.Linear(model.fc.in_features, n_finetune_classes)
model.load_state_dict(pretrain['state_dict'])
model.fc = nn.Linear(model.fc.in_features, n_classes)
The code above first builds the model, then resizes its FC layer to match the 700 classes of the pretrained checkpoint so that load_state_dict succeeds, and finally swaps the FC layer back to our own number of classes (this last layer is therefore freshly initialized).
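One practical note: checkpoints saved from a model wrapped in nn.DataParallel store every key with a 'module.' prefix, and load_state_dict will then fail with missing/unexpected key errors. Below is a minimal sketch of stripping such a prefix before loading, assuming the same checkpoint layout as above; whether this step is needed depends on how your particular weights were saved.

from collections import OrderedDict

checkpoint = torch.load(pretrained_resnet101_path, map_location='cpu')

# Strip a possible 'module.' prefix left behind by nn.DataParallel training.
cleaned = OrderedDict()
for key, value in checkpoint['state_dict'].items():
    cleaned[key[len('module.'):] if key.startswith('module.') else key] = value

model.load_state_dict(cleaned)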
2. Selecting the Parameters to Fine-Tune
Suppose we only want to fine-tune the final FC layer and the last convolutional stage (layer4). The following code collects the parameters accordingly:
ft_begin_index = 4  # index of the first ResNet stage to fine-tune
ft_module_names = []
for i in range(ft_begin_index, 5):
    ft_module_names.append('layer{}'.format(i))
ft_module_names.append('fc')

parameters = []
for k, v in model.named_parameters():
    for ft_module in ft_module_names:
        if ft_module in k:
            parameters.append({'params': v})  # uses the optimizer's default lr
            break
    else:  # for/else: runs only when no ft_module matched
        parameters.append({'params': v, 'lr': 0.0})
In the first half, we collect the names of the modules we want to fine-tune (these depend on the specific model, but ResNet variants are almost identical); in the second half, we leave the learning rate unset for the parameters we want to fine-tune, so they fall back to the optimizer's default, and set lr to 0.0 for every other parameter, which effectively freezes it.
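An alternative worth knowing (not part of the original code) is to freeze the other layers via requires_grad instead of a zero learning rate, which also skips their gradient computation entirely. A sketch, assuming the same ft_module_names as above:

for name, param in model.named_parameters():
    # Only layer4.* and fc.* remain trainable; everything else is frozen.
    param.requires_grad = any(m in name for m in ft_module_names)

# Pass only the trainable parameters to the optimizer.
optimizer = torch.optim.SGD(
    [p for p in model.parameters() if p.requires_grad], lr=0.1)

The zero-lr approach used in this article keeps every parameter registered in the optimizer, which makes it easy to unfreeze layers later simply by raising their group's learning rate; the requires_grad version is cheaper per training step.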
3. Verifying with the SGD Optimizer
Having collected the parameter groups, we can build an SGD optimizer to verify that only the selected parameters will actually be trained:
optimizer = torch.optim.SGD(parameters, lr=0.1)
for param_group in optimizer.param_groups:
    print(param_group['lr'])
The output shows that only the parameters we want to fine-tune have a learning rate of 0.1, while all the others are 0.0:
0.0
0.0
....
0.1
0.1
With that confirmed, we can move on to training.
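For completeness, a minimal training step using this optimizer might look as follows; train_loader and the loss choice are hypothetical placeholders, not part of the original article:

criterion = nn.CrossEntropyLoss()
model.train()
for clips, labels in train_loader:  # hypothetical DataLoader of (N, 3, T, H, W) clips
    outputs = model(clips)
    loss = criterion(outputs, labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()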
Last. The ResNet3D-101 Model
import math

import torch
import torch.nn as nn
def conv3x3x3(in_planes, out_planes, stride=1):
    # 3x3x3 convolution with padding
    return nn.Conv3d(in_channels=in_planes,
                     out_channels=out_planes,
                     kernel_size=3,
                     stride=stride,
                     padding=1,
                     bias=False)
class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, in_planes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.stride = stride
        # 1x1x1 -> 3x3x3 -> 1x1x1 bottleneck, expanding channels by 4x.
        self.conv1 = nn.Conv3d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm3d(planes)
        self.conv2 = nn.Conv3d(planes,
                               planes,
                               kernel_size=3,
                               stride=stride,
                               padding=1,
                               bias=False)
        self.bn2 = nn.BatchNorm3d(planes)
        self.conv3 = nn.Conv3d(planes,
                               planes * self.expansion,
                               kernel_size=1,
                               bias=False)
        self.bn3 = nn.BatchNorm3d(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)
        return out
class ResNet(nn.Module):

    def __init__(self,
                 block,
                 layers,
                 sample_size,
                 sample_duration,
                 shortcut_type='B',
                 num_classes=8):
        super(ResNet, self).__init__()
        self.in_planes = 64
        self.conv1 = nn.Conv3d(3,
                               64,
                               kernel_size=7,
                               stride=(1, 2, 2),
                               padding=(3, 3, 3),
                               bias=False)
        self.bn1 = nn.BatchNorm3d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool3d(kernel_size=(3, 3, 3), stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0], shortcut_type)
        self.layer2 = self._make_layer(block,
                                       128,
                                       layers[1],
                                       shortcut_type,
                                       stride=2)
        self.layer3 = self._make_layer(block,
                                       256,
                                       layers[2],
                                       shortcut_type,
                                       stride=2)
        self.layer4 = self._make_layer(block,
                                       512,
                                       layers[3],
                                       shortcut_type,
                                       stride=2)
        # Pool the remaining temporal/spatial extent down to 1x1x1.
        last_duration = int(math.ceil(sample_duration / 16))
        last_size = int(math.ceil(sample_size / 32))
        self.avgpool = nn.AvgPool3d((last_duration, last_size, last_size),
                                    stride=1)
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out')
            elif isinstance(m, nn.BatchNorm3d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
    def _make_layer(self, block, planes, blocks, shortcut_type, stride=1):
        downsample = None
        # Downsample the residual path when the spatial size or channel
        # count changes.
        if stride != 1 or self.in_planes != planes * block.expansion:
            if shortcut_type == 'A':
                raise NotImplementedError('shortcut_type A is not implemented')
            else:
                downsample = nn.Sequential(
                    nn.Conv3d(self.in_planes,
                              planes * block.expansion,
                              kernel_size=1,
                              stride=stride,
                              bias=False),
                    nn.BatchNorm3d(planes * block.expansion),
                )

        layers = []
        layers.append(block(self.in_planes, planes, stride, downsample))
        self.in_planes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.in_planes, planes))
        return nn.Sequential(*layers)
    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x
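As a quick sanity check, the model can be instantiated and fed a dummy clip to confirm the output shape; the sizes below are illustrative, not prescribed by the original article:

if __name__ == '__main__':
    net = ResNet(block=Bottleneck,
                 layers=[3, 4, 23, 3],  # the ResNet-101 configuration
                 sample_size=112,       # 112x112 frames
                 sample_duration=16,    # 16-frame clips
                 num_classes=700)
    dummy = torch.randn(2, 3, 16, 112, 112)  # (batch, channels, frames, H, W)
    print(net(dummy).shape)                  # expected: torch.Size([2, 700])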