数据增强,模型微调《动手学深度学习pytorch》

本文探讨了图像增广技术,包括翻转、裁剪和颜色变化,以扩大训练数据集并提升模型泛化能力。同时介绍了模型微调方法,通过在源数据集上预训练模型,再在目标数据集上进行调整,以应对小数据集的挑战。
部署运行你感兴趣的模型镜像

图像增广

通过对训练图像做一系列随机改变,来产生相似但又不同的训练样本,从而扩大训练数据集的规模。图像增广的另一种解释是,随机改变训练样本可以降低模型对某些属性的依赖,从而提高模型的泛化能力。

我们可以对图像进行不同方式的裁剪,使感兴趣的物体出现在不同位置,从而减轻模型对物体出现位置的依赖性。我们也可以调整亮度、色彩等因素来降低模型对色彩的敏感度。

1翻转和裁剪,2变化颜色(亮度、对比度、饱和度和色调)3将多个图像增广方法叠加使用

多样本数据增强:SMOTE,SamplePairing,mixup,无监督的数据增强:GAN

shape_aug = torchvision.transforms.RandomResizedCrop(200, scale=(0.1, 1), ratio=(0.5, 2)) #宽和高之比随机取自 0.5∼2 ,为原面积 10%∼100% 的区域,然后再将该区域的宽和高分别缩放到200像素
apply(img, shape_aug)
apply(img, torchvision.transforms.RandomVerticalFlip()) #垂直(上下)翻转
apply(img, torchvision.transforms.RandomHorizontalFlip()) #实现一半概率的图像水平(左右)翻转。
apply(img, torchvision.transforms.ColorJitter(brightness=0.5, contrast=0, saturation=0, hue=0))#亮度(brightness)、对比度(contrast)、饱和度(saturation)和色调(hue
#叠加方法
augs = torchvision.transforms.Compose([
    torchvision.transforms.RandomHorizontalFlip(), color_aug, shape_aug])
apply(img, augs)

2应用pytorch数据集

flip_aug = torchvision.transforms.Compose([
     torchvision.transforms.RandomHorizontalFlip(),
     torchvision.transforms.ToTensor()])

no_aug = torchvision.transforms.Compose([
     torchvision.transforms.ToTensor()])

train_iter = torchvision.datasets.CIFAR10(root=root, train=is_train, transform=flip_aug, download=False)
test_iter = torchvision.datasets.CIFAR10(root=root, train=not_train, transform=no_aug, download=False)

模型微调

数据集太少不能用用复杂模型。解决:迁移学习

将从源数据集学到的知识迁移到目标数据集上。例如,虽然ImageNet数据集的图像大多跟椅子无关,但在该数据集上训练的模型可以抽取较通用的图像特征,从而能够帮助识别边缘、纹理、形状和物体组成等。这些类似的特征对于识别椅子也可能同样有效。

我们介绍迁移学习中的一种常用技术:微调(fine tuning)。由以下四步构成

  1. 源数据集(如ImageNet数据集)上预训练一个神经网络模型,即源模型
  2. 创建一个新的神经网络模型,即目标模型。它复制了源模型上除了输出层外的所有模型设计及其参数。我们假设这些模型参数包含了源数据集上学习到的知识,且这些知识同样适用于目标数据集。我们还假设源模型的输出层跟源数据集的标签紧密相关,因此在目标模型中不予采用。
  3. 为目标模型添加一个输出大小为目标数据集类别个数的输出层,并随机初始化该层的模型参数
  4. 在目标数据集(如椅子数据集)上训练目标模型。我们将从头训练输出层,而其余层的参数都是基于源模型的参数微调得到的。

当目标数据集远小于源数据集时,微调有助于提升模型的泛化能力。

from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision.datasets import ImageFolder
from torchvision import transforms
from torchvision import models
#使用在ImageNet数据集上预训练的ResNet-18作为源模型。
pretrained_net = models.resnet18(pretrained=False) #pretrained=True来自动下载并加载预训练的模型参数
pretrained_net.load_state_dict(torch.load('/home/kesci/input/resnet185352/resnet18-5c106cde.pth')) 

#原来Linear(in_features=512, out_features=1000, bias=True) 改成我们需要的
pretrained_net.fc = nn.Linear(512, 2) #Linear(in_features=512, out_features=2, bias=True)
'''此时,pretrained_net的fc层就被随机初始化了,但是其他层依然保存着预训练得到的参数,参数已经足够好,因此一般只需使用较小的学习率来微调这些参数,而fc中的随机初始化参数一般需要更大的学习率从头训练。
PyTorch可以方便的对模型的不同部分设置不同的学习参数,我们在下面代码中将fc的学习率设为已经预训练过的部分的10倍。'''
output_params = list(map(id, pretrained_net.fc.parameters()))
feature_params = filter(lambda p: id(p) not in output_params, pretrained_net.parameters())

lr = 0.01
optimizer = optim.SGD([{'params': feature_params},
                       {'params': pretrained_net.fc.parameters(), 'lr': lr * 10}],
                       lr=lr, weight_decay=0.001)

def train_fine_tuning(pretrained_net, optimizer, batch_size=128, num_epochs=5):
    train_iter = DataLoader(ImageFolder(os.path.join(data_dir, 'hotdog/train'), transform=train_augs),
                            batch_size, shuffle=True)
    test_iter = DataLoader(ImageFolder(os.path.join(data_dir, 'hotdog/test'), transform=test_augs),
                           batch_size)
    loss = torch.nn.CrossEntropyLoss()
train_fine_tuning(pretrained_net, optimizer)


#对比,所有都从头训练 
scratch_net = models.resnet18(pretrained=False, num_classes=2) 
lr = 0.1
optimizer = optim.SGD(scratch_net.parameters(), lr=lr, weight_decay=0.001)
train_fine_tuning(scratch_net, optimizer)

 

您可能感兴趣的与本文相关的镜像

Llama Factory

Llama Factory

模型微调
LLama-Factory

LLaMA Factory 是一个简单易用且高效的大型语言模型(Large Language Model)训练与微调平台。通过 LLaMA Factory,可以在无需编写任何代码的前提下,在本地完成上百种预训练模型的微调

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值