[Pytorch]如何在Pytorch之中选择性Fine-Tuning(微调)模型的参数

在机器学习之中,使用预训练的模型是很常见的事情。但是如果直接使用预训练的模型进行使用的结果会很差,需要进行Fine-tuning,比如ResNet,ResNet3D等模型。那么如何才能在Pytorch之中进行FineTuning呢?下面以ResNet3D为例,说明如何有选择的进行FineTuning。

1. 模型的下载与加载

首先我们需要定义一个ResNet3D-101模型(PytorchVision之中并没有内置),并且下载好预训练的模型(直接在Github搜索ResNet3D Pretrained即可找到)。因为ResNet3D代码过长,因此这里只展示调用,具体代码在末尾给出。

model = ResNet(block=Bottleneck,
              layers=[3, 4, 23, 3],
              shortcut_type='B',
              num_classes=n_classes, # 自己决定
              sample_duration=sample_duration, # 序列长度
              sample_size=sample_size) # 图片大小
              
n_finetune_classes = 700 # 预训练的时候是700分类
pretrain = torch.load(pretrained_resnet101_path)
model.fc = nn.Linear(model.fc.in_features, n_finetune_classes)
model.load_state_dict(pretrain['state_dict'])
model.fc = nn.Linear(model.fc.in_features, n_classes)

上面代码是先定义了模型之后,再将模型的输出类别调整为和预训练的分类一致,再在载入之后替换回自己的分类。

2. 选择我们要Fine-Tuning的参数

假设我们只需要Fine-Tuning最后一层FC以及最后的一个卷积层。我们需要通过如下代码获取到我们需要的参数:

 ft_begin_index = 4 # 开始Fine-Tuning的层数
 ft_module_names = []
 for i in range(ft_begin_index, 5):
     ft_module_names.append('layer{}'.format(i))
 ft_module_names.append('fc')

 parameters = []
 for k, v in model.named_parameters():
     for ft_module in ft_module_names:  # fc
         if ft_module in k:
             parameters.append({'params': v})
             break
     else:
         parameters.append({'params': v, 'lr': 0.0})

在上半段,我们先获取到了我们所有要微调的参数的名称(根据具体的模型案例来决定,但是ResNet几乎都完全一致);在下半段,我们对于微调的参数设置lr为0,将其他的参数的lr不设置。

3. 使用SGD Optimizer验证

我们以及获取到了要微调的参数,我们可以使用SGD来验证一下是否只训练这些参数:

optimizer = torch.optim.SGD(parameters, lr=0.1)
for param_group in optimizer.param_groups:
    print(param_group['lr'])

然后我们从输出结果可以看到,只有我们需要微调的几个参数的学习率为0.1,其余全部为0.0 。

0.0
0.0
....
0.1
0.1

由此,我们便可以进行下一步的训练了。

Last. ResNet3D-101 模型

import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial
import math


def conv3x3x3(in_planes, out_planes, stride=1):
    # 3x3x3 convolution with padding
    return nn.Conv3d(in_channels=in_planes,
                     out_channels=out_planes,
                     kernel_size=3,
                     stride=stride,
                     padding=1,
                     bias=False)


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, in_planes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.stride = stride

        self.conv1 = nn.Conv3d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm3d(planes)
        self.conv2 = nn.Conv3d(planes,
                               planes,
                               kernel_size=3,
                               stride=stride,
                               padding=1,
                               bias=False)
        self.bn2 = nn.BatchNorm3d(planes)
        self.conv3 = nn.Conv3d(planes,
                               planes * self.expansion,
                               kernel_size=1,
                               bias=False)
        self.bn3 = nn.BatchNorm3d(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)

        self.downsample = downsample

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class ResNet(nn.Module):
    def __init__(self,
                 block,
                 layers,
                 sample_size,
                 sample_duration,
                 shortcut_type='B',
                 num_classes=8):
        super(ResNet, self).__init__()
        self.in_planes = 64
        self.conv1 = nn.Conv3d(3,
                               64,
                               kernel_size=7,
                               stride=(1, 2, 2),
                               padding=(3, 3, 3),
                               bias=False)
        self.bn1 = nn.BatchNorm3d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool3d(kernel_size=(3, 3, 3), stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0], shortcut_type)
        self.layer2 = self._make_layer(block,
                                       128,
                                       layers[1],
                                       shortcut_type,
                                       stride=2)
        self.layer3 = self._make_layer(block,
                                       256,
                                       layers[2],
                                       shortcut_type,
                                       stride=2)
        self.layer4 = self._make_layer(block,
                                       512,
                                       layers[3],
                                       shortcut_type,
                                       stride=2)
        last_duration = int(math.ceil(sample_duration / 16))
        last_size = int(math.ceil(sample_size / 32))
        self.avgpool = nn.AvgPool3d((last_duration, last_size, last_size),
                                    stride=1)
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                # m.weight = nn.init.kaiming_normal(m.weight, mode='fan_out')
                nn.init.kaiming_normal_(m.weight, mode='fan_out')
            elif isinstance(m, nn.BatchNorm3d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, shortcut_type, stride=1):
        downsample = None
        # downsample case
        if stride != 1 or self.in_planes != planes * block.expansion:
            if shortcut_type == 'A':
                assert True, 'Not implemented!'
            else:
                downsample = nn.Sequential(
                    nn.Conv3d(self.in_planes,
                              planes * block.expansion,
                              kernel_size=1,
                              stride=stride,
                              bias=False),
                    nn.BatchNorm3d(planes * block.expansion),
                )
        layers = []
        layers.append(block(self.in_planes, planes, stride, downsample))
        self.in_planes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.in_planes, planes))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)

        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x


评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值