Transfer Learning

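This article looks at transfer learning in computer vision. Through two main approaches, using a ConvNet as a fixed feature extractor and fine-tuning the ConvNet, it shows how pre-trained models can improve performance on a new task. It covers data preprocessing, model training, and the setup of the loss function and optimizer, with practical code examples.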

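Note: These notes are based on the code in the PyTorch tutorial "Transfer Learning for Computer Vision Tutorial". The focus here is transfer learning on images.
The code is available at: Transfer Learning. For more background, see "Transfer Learning - Machine Learning's Next Frontier".
In transfer learning, after loading a pre-trained model we usually freeze part of its parameters and train only the last few layers. This preserves the feature-extraction ability of the earlier layers. The pre-trained model must have something in common with the new dataset (for example, both are image classification problems) so that its ability to extract features transfers effectively to the new model.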

There are two main ways in which transfer learning is used:

  1. ConvNet as a fixed feature extractor: Here, you ‘freeze’ the weights of all the parameters in the network except that of the final several layers (aka “the head”, usually fully connected layers). These last layers are replaced with new ones initialized with random weights and only these layers are trained.

  2. Finetuning the ConvNet: Instead of random initialization, the model is initialized using a pretrained network, after which the training proceeds as usual but with a different dataset. Usually the head (or part of it) is also replaced in the network in case there is a different number of outputs. It is common in this method to set the learning rate to a smaller number. This is done because the network is already trained, and only minor changes are required to “finetune” it to a new dataset.

Sometimes we could combine the above two methods: First we can freeze the feature extractor, and train the head. After that, we could unfreeze the feature extractor (or part of it), set the learning rate to something smaller, and continue training.
@ Transfer Learning
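
The combined strategy described above can be sketched roughly as follows. This is only an illustration: the tiny `backbone` and `head` modules, the random toy data, and the helper names `set_requires_grad` and `train_steps` are all assumptions standing in for a real pre-trained ConvNet and a real dataset.

```python
import torch
from torch import nn, optim

torch.manual_seed(0)

# Hypothetical stand-in for a pre-trained network: a "backbone" plus a "head".
backbone = nn.Sequential(nn.Linear(8, 16), nn.ReLU())
head = nn.Linear(16, 2)  # new head, randomly initialized
model = nn.Sequential(backbone, head)

def set_requires_grad(module, flag):
    """Freeze (flag=False) or unfreeze (flag=True) every parameter of a module."""
    for param in module.parameters():
        param.requires_grad = flag

def train_steps(model, optimizer, steps=3):
    """Run a few optimization steps on random toy data."""
    criterion = nn.CrossEntropyLoss()
    inputs = torch.randn(4, 8)
    labels = torch.tensor([0, 1, 0, 1])
    for _ in range(steps):
        optimizer.zero_grad()
        loss = criterion(model(inputs), labels)
        loss.backward()
        optimizer.step()

# Stage 1: fixed feature extractor, only the head is trained.
set_requires_grad(backbone, False)
backbone_before = [p.clone() for p in backbone.parameters()]
stage1_opt = optim.SGD(head.parameters(), lr=0.01, momentum=0.9)
train_steps(model, stage1_opt)
backbone_unchanged = all(
    torch.equal(a, b) for a, b in zip(backbone_before, backbone.parameters())
)

# Stage 2: unfreeze the backbone and continue with a smaller learning rate.
set_requires_grad(backbone, True)
stage2_opt = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
train_steps(model, stage2_opt)
backbone_changed = any(
    not torch.equal(a, b) for a, b in zip(backbone_before, backbone.parameters())
)
```

The learning rates 0.01 and 0.001 are illustrative values only; the point is that the second stage uses the smaller one.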

1. Data Preprocessing

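Because different networks have different requirements for their input images, we should inspect the number of images per class and the aspect ratio (width to height) of the images. For classes with too few samples, data augmentation can be used.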
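
A quick way to run this inspection is sketched below. The `records` list is made-up data (in practice the sizes would be read from the image files), and the threshold `min_count` is an arbitrary value chosen for illustration.

```python
from collections import Counter

# Hypothetical records: (class label, width in pixels, height in pixels)
records = [
    ("dog", 640, 480), ("dog", 800, 600), ("dog", 500, 500),
    ("cat", 1024, 768), ("cat", 300, 600),
]

# Number of images per class
class_counts = Counter(label for label, _, _ in records)

# Aspect ratio (width / height) of every image
aspect_ratios = [round(width / height, 2) for _, width, height in records]

# Classes with fewer samples than the (illustrative) threshold are
# candidates for data augmentation.
min_count = 3
needs_augmentation = sorted(
    label for label, count in class_counts.items() if count < min_count
)
```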

  1. torchvision.transforms: implements data augmentation (applied to the training dataset) and basic preprocessing of the dataset, including resizing, conversion to a tensor, and optional normalization
from torchvision import transforms

transforms_mean = [0.485, 0.456, 0.406] 
transforms_std = [0.229, 0.224, 0.225]

image_transforms = {
	'train': transforms.Compose([
		transforms.RandomResizedCrop(size=256, scale=(0.8, 1.0)),
		transforms.RandomRotation(degrees=15), 
		transforms.ColorJitter(), 
		transforms.RandomHorizontalFlip(), 
		transforms.CenterCrop(size=224), 
		transforms.ToTensor(), 
		transforms.Normalize(mean=transforms_mean, std=transforms_std),
	]), 
	'val': transforms.Compose([
		transforms.Resize(size=256), 
		transforms.CenterCrop(size=224), 
		transforms.ToTensor(), 
		transforms.Normalize(mean=transforms_mean, std=transforms_std),
	]),
}

  2. Create the dataset
  • torchvision.datasets.ImageFolder: a generic data loader. The files under root are expected to be laid out as follows:

root/dog/xxx.png
root/dog/xxy.png
root/dog/xxz.png

root/cat/123.png
root/cat/nsdf3.png
root/cat/asd932_.png

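	The file names can also be collected by hand. The `filter` function can drop the hidden files that macOS adds to directories (such as .DS_Store).
		import os
		base_dir = os.path.dirname(os.path.abspath(__file__))  # absolute directory of this .py file
		img_dir = os.path.join(base_dir, img_path)  # img_path: image folder relative to this file
		# keep only the .png files, which skips the hidden macOS files
		image_names = list(filter(lambda x: x.endswith('.png'), os.listdir(img_dir)))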
torchvision.datasets.ImageFolder(root, transform=None, target_transform=None, loader=<function default_loader>, is_valid_file=None)
  • torch.utils.data.Dataset: see the official PyTorch documentation on creating datasets. This is the base class for map-style datasets. We subclass it and override __getitem__() and __len__() with our own implementations. For example, such a dataset, when accessed with dataset[idx], could read the idx-th image and its corresponding label from a folder on disk.
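
A minimal sketch of such a subclass is shown below. To stay self-contained it keeps the images in memory as random tensors instead of reading them from disk; the class name `ToyImageDataset` and its arguments are assumptions made for this example.

```python
import torch
from torch.utils.data import Dataset

class ToyImageDataset(Dataset):
    """Map-style dataset over in-memory (image, label) pairs."""

    def __init__(self, images, labels, transform=None):
        if len(images) != len(labels):
            raise ValueError("images and labels must have the same length")
        self.images = images
        self.labels = labels
        self.transform = transform

    def __len__(self):
        # Number of samples in the dataset
        return len(self.images)

    def __getitem__(self, idx):
        # Return the idx-th image (optionally transformed) and its label
        image = self.images[idx]
        if self.transform is not None:
            image = self.transform(image)
        return image, self.labels[idx]

images = torch.rand(6, 3, 224, 224)  # six fake RGB images
labels = [0, 1, 0, 1, 1, 0]
dataset = ToyImageDataset(images, labels, transform=lambda x: x * 2.0)
image, label = dataset[3]
```

A disk-backed version would store file paths in `__init__` and open the image inside `__getitem__`, so that only the requested sample is loaded.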
  3. Create the data loader
    torch.utils.data.DataLoader(dataset, batch_size=10, shuffle=True, num_workers=4): only these four parameters are discussed here.
  • dataset: dataset from which to load the data
  • batch_size: how many samples per batch to load
  • shuffle: usually set to True for the training dataset so that the samples are reshuffled at every epoch; there is no need to use it for the validation dataset.
  • num_workers: how many subprocesses to use for data loading.
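
The effect of these parameters can be seen on a small toy dataset. In this sketch the 25 samples are synthetic, and `num_workers=0` (load in the main process) is used only to keep the example simple; it is not a recommendation.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# 25 synthetic samples: feature i is the number i, label is i % 2
features = torch.arange(25, dtype=torch.float32).reshape(25, 1)
targets = torch.arange(25) % 2
dataset = TensorDataset(features, targets)

# Training loader reshuffles the samples every epoch; validation loader keeps order
train_loader = DataLoader(dataset, batch_size=10, shuffle=True, num_workers=0)
val_loader = DataLoader(dataset, batch_size=10, shuffle=False, num_workers=0)

# 25 samples with batch_size=10 give batches of 10, 10 and 5
train_batch_sizes = [batch_x.shape[0] for batch_x, _ in train_loader]

# Without shuffling, the first batch contains the first ten samples in order
val_first_batch = next(iter(val_loader))[0].flatten().tolist()
```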

2. Training The Model

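  1. Load the parameters of the original model
    model_ft = copy.deepcopy(pre_model.state_dict()): copies the parameters of the pre-trained model via state_dict() (model_ft then holds a dictionary of weights, not a model).
from torch import nn
from torchvision import models

model = models.vgg16(pretrained=True)  # newer torchvision versions use the weights= argument instead
for param in model.parameters():
	param.requires_grad = False  # freeze all pre-trained parameters
# Parameters of newly constructed modules have requires_grad=True by default
# VGG16 has no `fc` attribute; its last fully connected layer is classifier[6]
num_ftrs = model.classifier[6].in_features
model.classifier[6] = nn.Linear(num_ftrs, 2)  # 2 output classes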
  2. Load only part of a pre-trained model
from torchvision import models
# load resnet50, Net() is own class model
resnet = models.resnet50(pretrained=True)
model = Net(...)

# load parameters
pretrained_dict = resnet.state_dict() 
model_dict = model.state_dict()

# remove keys from pretrained_dict that do not exist in model_dict
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}

# update model_dict
model_dict.update(pretrained_dict)

# load needed state_dict()
model.load_state_dict(model_dict) 
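
Filtering by key alone still fails in load_state_dict() when a layer has the same name but a different shape (for example a head with a different number of classes). One possible variant, sketched here with two tiny hypothetical modules (`SmallPretrained` and `TargetNet`) in place of resnet50 and the custom Net, also compares tensor shapes:

```python
import torch
from torch import nn

class SmallPretrained(nn.Module):
    """Hypothetical stand-in for the pre-trained source network."""
    def __init__(self):
        super().__init__()
        self.features = nn.Linear(4, 8)
        self.fc = nn.Linear(8, 10)

    def forward(self, x):
        return self.fc(self.features(x))

class TargetNet(nn.Module):
    """Hypothetical custom model that reuses part of the source network."""
    def __init__(self):
        super().__init__()
        self.features = nn.Linear(4, 8)  # same name and shape: copied
        self.fc = nn.Linear(8, 2)        # same name, different shape: skipped
        self.extra = nn.Linear(2, 2)     # not in the source: keeps its own init

    def forward(self, x):
        return self.extra(self.fc(self.features(x)))

source = SmallPretrained()
model = TargetNet()

pretrained_dict = source.state_dict()
model_dict = model.state_dict()

# Keep only entries whose key exists in the target AND whose shape matches
compatible = {
    k: v for k, v in pretrained_dict.items()
    if k in model_dict and v.shape == model_dict[k].shape
}
model_dict.update(compatible)
model.load_state_dict(model_dict)

copied_keys = sorted(compatible)
```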

3. Loss Function & Optimizer

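When fine-tuning, the parameter list should be long and contain all parameters of the model. When doing feature extraction, however, the list should be short and contain only the weights and biases of the reshaped layers.

from torch import optim

# Send the model to the GPU
model_ft = model_ft.to(device)
# Gather the parameters to be optimized/updated in this run.
# If we are fine-tuning, we update all parameters.
# If we are doing feature extraction, we only update the parameters that
# were just initialized, i.e. the parameters with requires_grad == True.
params_to_update = model_ft.parameters()
print("Params to learn:")
if feature_extract:
	params_to_update = []
	for name, param in model_ft.named_parameters():
		if param.requires_grad:
			params_to_update.append(param)
			print("\t", name)
else:
	for name, param in model_ft.named_parameters():
		if param.requires_grad:
			print("\t", name)
# Observe that all parameters are being optimized
optimizer_ft = optim.SGD(params_to_update, lr=0.001, momentum=0.9)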

4. Details of Transfer Learning

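In general, transfer learning consists of the following steps:

  1. Initialize the pre-trained model
  2. Reshape the final layer so that it has as many outputs as the new dataset has classes
  3. Define for the optimization algorithm which parameters we want to update during training
  4. Run the training step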
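
The four steps can be put together in a compact sketch. `TinyNet` is a hypothetical model with random weights that only mimics the shape of a pre-trained classifier with an `fc` head, and the inputs are random toy data, so this shows the mechanics rather than a real training run.

```python
import torch
from torch import nn, optim

torch.manual_seed(0)

class TinyNet(nn.Module):
    """Hypothetical stand-in for a pre-trained classifier with an `fc` head."""
    def __init__(self, num_classes=10):
        super().__init__()
        self.features = nn.Sequential(nn.Linear(8, 16), nn.ReLU())
        self.fc = nn.Linear(16, num_classes)

    def forward(self, x):
        return self.fc(self.features(x))

feature_extract = True
num_classes = 2

# 1. Initialize the "pre-trained" model (its weights are random in this sketch)
model = TinyNet()
if feature_extract:
    for param in model.parameters():
        param.requires_grad = False

# 2. Reshape the final layer to match the number of classes in the new dataset
model.fc = nn.Linear(model.fc.in_features, num_classes)

# 3. Tell the optimizer which parameters to update
params_to_update = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.SGD(params_to_update, lr=0.001, momentum=0.9)

# 4. Run one training step on random toy data
inputs = torch.randn(4, 8)
labels = torch.tensor([0, 1, 1, 0])
loss = nn.CrossEntropyLoss()(model(inputs), labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()

trainable_names = [n for n, p in model.named_parameters() if p.requires_grad]
output_shape = tuple(model(inputs).shape)
```

With `feature_extract = False` the freezing loop is skipped and every parameter ends up in `params_to_update`, which corresponds to fine-tuning.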