(6) Freezing Some Layers of a Pretrained Model

This post shows how to freeze some layers of a neural network in PyTorch by setting `requires_grad=False`, so that those layers are not updated during training. For a cat-vs-dog classification task trained for 10 epochs, the experiments suggest that the model with some layers frozen performs better than the fully trainable one.

1. First, take a look at each layer's information

"""
    函数说明: 打印模型里面各层的信息,陈20221012
    :param    model: 加载的模型
    :return:  
"""
def look_layer(model):
    cnt = 0
    print("看看模型的param.requires_grad和个数1:")
    for name, param in model.named_parameters():
        cnt += 1
        print("now cnt ",cnt,name,param.requires_grad,param.shape)
    print(cnt)   
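
For instance, a minimal usage sketch (assuming torchvision is installed and `look_layer` is defined as above) that prints one line per parameter tensor of a pretrained ResNet18:

# Usage sketch: inspect a pretrained ResNet18 before deciding what to freeze.
from torchvision import models
from torchvision.models import resnet18

net = resnet18(weights=models.ResNet18_Weights.DEFAULT)
look_layer(net)   # prints index, name, requires_grad and shape for every parameter tensor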

 2. Freeze the network parameters that should not be trained

param.requires_grad = False   # once set to False, this parameter is no longer updated during training
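
A tiny self-contained sketch of what this flag does: after setting `requires_grad=False`, `backward()` leaves that parameter's `.grad` as `None`, so the optimizer never updates it (hypothetical single linear layer):

import torch
from torch import nn

layer = nn.Linear(4, 2)
layer.weight.requires_grad = False      # freeze the weight; the bias stays trainable

loss = layer(torch.randn(1, 4)).sum()
loss.backward()
print(layer.weight.grad)                # None: no gradient is computed for the frozen weight
print(layer.bias.grad is not None)      # True: the bias still receives a gradient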

3. An example of freezing parameters

In the example below, the first 45 parameter tensors are frozen (not trained) and training starts from the 46th. Note that `named_parameters()` iterates over parameter tensors (individual weights and biases), so the count is per tensor rather than per layer.

    target_frozennum = 46  # parameter tensors 1-45 are not trained
    cnt1 = 0
    for name, param in model.named_parameters():
        cnt1 += 1
        # print("now cnt ",cnt1,name,param.requires_grad,param.shape)
        if cnt1 == target_frozennum:
            break
        param.requires_grad = False
    # print(cnt1)
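
Counting parameter tensors like this works, but the count is architecture-specific. An alternative (a minimal sketch, assuming the same ResNet18-based Sequential model that the full script below builds, and that only the final Linear head should stay trainable) is to freeze everything and then re-enable the last module, optionally handing the optimizer only the trainable parameters:

from torch import nn, optim
from torchvision import models
from torchvision.models import resnet18

# Build the same ResNet18-based model used in the full script below.
trained_model = resnet18(weights=models.ResNet18_Weights.DEFAULT)
model = nn.Sequential(*list(trained_model.children())[:-1],   # backbone without the original fc
                      nn.Flatten(),
                      nn.Linear(512, 2))

# Freeze everything first, then re-enable gradients only for the final Linear head,
# i.e. the last sub-module of the Sequential model.
for param in model.parameters():
    param.requires_grad = False
for param in model[-1].parameters():
    param.requires_grad = True

# Optionally, give the optimizer only the parameters that are still trainable;
# frozen parameters never receive gradients, so they would be skipped either way.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.Adam(trainable_params, lr=1e-3)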

The complete code is as follows:

import  torch
from    torch import optim, nn
# import  visdom
# from tensorboardX import SummaryWriter  #(1)引入tensorboardX
from torch.utils.tensorboard import SummaryWriter
 
# import  torchvision
from    torch.utils.data import DataLoader
from    torchvision import transforms, datasets, models
# from    utils import Flatten
 
# from    pokemon import Pokemon
# from    resnet import ResNet18
from    torchvision.models import resnet18
from    PIL import Image
from    tqdm import tqdm
from    torchinfo import summary
import  os
 
 
# batchsz = 32
batch_size = 256
lr = 1e-3
epochs = 10
img_resize = 224
 
# print("cuda:") 
# print(torch.cuda.is_available())
# print(torch.cuda.device_count())
# print(torch.cuda.current_device()) 
# print(torch.cuda.get_device_name(0)) 
 
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# device = torch.device('cpu')
torch.manual_seed(1234)
 
# tf = transforms.Compose([
#                      transforms.Resize((224,224)),
#                      transforms.ToTensor(),
#      ])
# the input should be of type PIL.Image
tf = transforms.Compose([
    # anonymous function (commented out; ImageFolder already yields PIL images)
    # lambda x:Image.open(x).convert('RGB'), # string path= > image data
    transforms.Resize((int(img_resize*1.25), int(img_resize*1.25))),
    transforms.RandomRotation(15),
    transforms.CenterCrop(img_resize),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                            std=[0.229, 0.224, 0.225])
])
# db = torchvision.datasets.ImageFolder(root='pokemon', transform=tf)
# train_db = datasets.ImageFolder(root='D:/pytorch_learning2022/data/cat_dog/train', transform=tf)
# print(train_db.class_to_idx)
# print("个数",len(train_db))
# val_db = datasets.ImageFolder(root='D:/pytorch_learning2022/data/cat_dog/val', transform=tf)
# test_db = datasets.ImageFolder(root='D:/pytorch_learning2022/data/cat_dog/test', transform=tf)
 
train_db = datasets.ImageFolder(root='/home/chen/chen_deep/data/cat_dog/train', transform=tf)
print(train_db.class_to_idx)
print("个数",len(train_db))
val_db = datasets.ImageFolder(root='/home/chen/chen_deep/data/cat_dog/val', transform=tf)
test_db = datasets.ImageFolder(root='/home/chen/chen_deep/data/cat_dog/test', transform=tf)
 
 
# train_db = Pokemon('pokemon', 224, mode='train')
# val_db = Pokemon('pokemon', 224, mode='val')
# test_db = Pokemon('pokemon', 224, mode='test')
# train_loader = DataLoader(train_db, batch_size=batchsz, shuffle=True,
#                           num_workers=4)
# val_loader = DataLoader(val_db, batch_size=batchsz, num_workers=2)
# test_loader = DataLoader(test_db, batch_size=batchsz, num_workers=2)
train_loader = DataLoader(train_db, batch_size = batch_size, shuffle=True)
val_loader = DataLoader(val_db, batch_size = batch_size)
test_loader = DataLoader(test_db, batch_size = batch_size)
 
 
# viz = visdom.Visdom()
#(2) initialize the writer; note that you can pass a log directory
writer = SummaryWriter('runs/chen_cat_dog_test1')
# pip install tensorboard
# tensorboard --logdir=runs
# http://localhost:6006
 
def evalute(model, loader):
    model.eval()
    
    correct = 0
    total = len(loader.dataset)
 
    for x,y in loader:
        x,y = x.to(device), y.to(device)
        with torch.no_grad():
            logits = model(x)
            pred = logits.argmax(dim=1)
        correct += torch.eq(pred, y).sum().float().item()
 
    return correct / total
 
"""
    函数说明: 打印模型里面各层的信息,陈20221012
    :param    model: 加载的模型
    :return:  
"""
def look_layer(model):
    cnt = 0
    print("看看模型的param.requires_grad和个数1:")
    for name, param in model.named_parameters():
        cnt += 1
        print("now cnt ",cnt,name,param.requires_grad,param.shape)
    print(cnt)    
 
def main():
    #(1) if you use anaconda, activate your own environment
    # conda env list
    # conda activate chentorch_cp310
    #(2) install
    # pip install visdom
    #(3) usage
    # python -m visdom.server
    # http://localhost:8097/
 
    # model = ResNet18(5).to(device)
 
    # model = ResNet18(5).to(device)
    #The current behavior is equivalent to passing `weights=ResNet18_Weights.IMAGENET1K_V1`.
    #  You can also use `weights=ResNet18_Weights.DEFAULT`
    # trained_model = resnet18(pretrained=True)
    trained_model = resnet18(weights=models.ResNet18_Weights.DEFAULT)
    model = nn.Sequential(*list(trained_model.children())[:-1], #[b, 512, 1, 1]
                        #   Flatten(), # [b, 512, 1, 1] => [b, 512]
                          nn.Flatten(),
                        #   nn.Linear(512, 5)
                          nn.Linear(512, 2)
                          ).to(device)
    look_layer(model)
    target_frozennum = 46  # parameter tensors 1-45 are not trained
    cnt1 = 0
    for name, param in model.named_parameters():
        cnt1 += 1
        # print("now cnt ",cnt1,name,param.requires_grad,param.shape)
        if cnt1 == target_frozennum:
            break
        param.requires_grad = False
 
    # x = torch.randn(2, 3, 224, 224)
    # print(model(x).shape)
 
    # multi-GPU setup on a single machine
    # # os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3'  # expose all usable GPUs, four in total
    # os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'  # expose the usable GPUs, two in total
    # device_ids = [0,1]  # pick two of them
    # model = nn.DataParallel(model, device_ids=device_ids)  # run on the two GPUs in parallel
    # # model = torch.nn.DataParallel(model)  # uses all device_ids by default
 
 
    # batch_size = 256
    summary(model, input_size=(batch_size, 3, 224, 224))
    # print(model)
    
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criteon = nn.CrossEntropyLoss()
 
 
    best_acc, best_epoch = 0, 0
    global_step = 0
    # viz.line([0], [-1], win='loss', opts=dict(title='loss'))
    # viz.line([0], [-1], win='val_acc', opts=dict(title='val_acc'))
    for epoch in tqdm(range(epochs)):
 
        for step, (x,y) in tqdm(enumerate(train_loader)):
 
            # x: [b, 3, 224, 224], y: [b]
            x, y = x.to(device), y.to(device)
            
            model.train()
            logits = model(x)
            loss = criteon(logits, y)
 
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
 
            # viz.line([loss.item()], [global_step], win='loss', update='append')
            #(3) log the batch TrainLoss to TensorBoard
            writer.add_scalar('TrainLoss', loss.item(), global_step=global_step)
            global_step += 1
 
        if epoch % 1 == 0:
 
            val_acc = evalute(model, val_loader)
            if val_acc> best_acc:
                best_epoch = epoch
                best_acc = val_acc
 
                torch.save(model.state_dict(), 'best_cat_dog_transfer_frozensome.mdl')
 
                # viz.line([val_acc], [global_step], win='val_acc', update='append')
                #(4) log the epoch validation accuracy (TestAcc) to TensorBoard
                writer.add_scalar('TestAcc', val_acc, global_step=epoch) 
 
 
    print('best acc:', best_acc, 'best epoch:', best_epoch)
 
    model.load_state_dict(torch.load('best_cat_dog_transfer_frozensome.mdl'))
    print('loaded from ckpt!')
 
    test_acc = evalute(model, test_loader)
    print('test acc:', test_acc)
 
    #(5) close the writer
    writer.close()  
    
 
if __name__ == '__main__':
    main()
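
As a quick sanity check, the snippet below (a sketch meant to be dropped in right after the freezing loop in `main()`) counts how many parameter tensors, and how many individual weights, are frozen versus trainable:

    # Sanity-check sketch: compare frozen vs. trainable parameters after the freezing loop.
    frozen_tensors = sum(1 for p in model.parameters() if not p.requires_grad)
    trainable_tensors = sum(1 for p in model.parameters() if p.requires_grad)
    frozen_elems = sum(p.numel() for p in model.parameters() if not p.requires_grad)
    trainable_elems = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("frozen tensors:", frozen_tensors, "trainable tensors:", trainable_tensors)
    print("frozen weights:", frozen_elems, "trainable weights:", trainable_elems)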

4. The test results suggest freezing works better than not freezing

For the simple cat-vs-dog classification task, the test results suggest that, within 10 epochs, freezing part of the network works better than training all layers.

The test code for a single image is as follows:

 

import  torch
from    torch import optim, nn
 
from    torch.utils.data import DataLoader
from    torchvision import transforms,datasets ,models
 
# from    resnet import ResNet18
from    torchvision.models import resnet18
from    PIL import Image
# from    utils import Flatten
import  os

img_resize = 224
mean=[0.485, 0.456, 0.406]
std=[0.229, 0.224, 0.225]
 
# device = torch.device('cuda')
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = torch.device('cpu')
 
 
tf = transforms.Compose([
    # anonymous function: turn a file path into RGB image data
    lambda x:Image.open(x).convert('RGB'), # string path= > image data
    transforms.Resize((img_resize, img_resize)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                            std=[0.229, 0.224, 0.225])
])
 
"""
    函数说明: 根据训练结果对模型进行测试
    :param    img_test: 待测试的图片
    :return:  y: 测试结果,分类序号
 """
def model_test_img(model,img_test):
    model.eval()
    # img = Image.open(img_test).convert('RGB') 
    # resize = transforms.Resize([224,224])
    # x = transforms.Resize([img_resize,img_resize])(img)
    # x = transforms.ToTensor()(x)
    x =tf(img_test)
    x = x.to(device)
    x = x.unsqueeze(0)
    x = transforms.Normalize(mean,std)(x)
    # print(x.shape)
    with torch.no_grad():
        logits = model(x)
        pred = logits.argmax(dim=1)
    return pred
 
 
def main():
 
    #(1)如用anaconda激活你自己的环境
    # conda env list
    # conda activate chentorch_cp310
 
    # class names
    class_name = ['cat', 'dog']
    # image_file = "D:/pytorch_learning2022/data/pokeman/train/bulbasaur/00000002.jpg"
    # image_file = "/home/chen/chen_deep/data/cat_dog/test/cat/cat.20.jpg"  #出错
    image_file = "/home/chen/chen_deep/data/cat_dog/test/cat/cat.21.jpg"  #出错    
    # image_file = "/home/chen/chen_deep/data/catdogpuretestpic/test/test1/82.jpg"  #出错
    # image_file = "/home/chen/chen_deep/data/catdogpuretestpic/test/test1/83.jpg"     
    # image_file = "/home/chen/chen_deep/data/catdogpuretestpic/test/test1/302.jpg"  
    # image_file = "/home/chen/chen_deep/data/cat_dog/test/cat/cat.10.jpg"    #出错
    # image_file = "/home/chen/chen_deep/data/cat_dog/test/cat/cat.285.jpg"
    # image_file = "/home/chen/chen_deep/data/cat_dog/test/cat/cat.576.jpg"


    # trained_model = resnet18(pretrained=True)
    trained_model = resnet18(weights = models.ResNet18_Weights.DEFAULT)

    model = nn.Sequential(*list(trained_model.children())[:-1], #[b, 512, 1, 1]
                          nn.Flatten(),
                        #   Flatten(), # [b, 512, 1, 1] => [b, 512]
                          nn.Linear(512, 2)
                          ).to(device)
    # Note: if the state_dict was saved from a multi-GPU (DataParallel) model,
    # it must also be loaded into a DataParallel-wrapped model at test time.
    # os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'  # expose the usable GPUs, two in total
    # device_ids = [0,1]  # pick two of them
    # model = nn.DataParallel(model, device_ids=device_ids)  # run on the two GPUs in parallel

    model.load_state_dict(torch.load('best_cat_dog_transfer_frozensome.mdl'))
    y = model_test_img(model,image_file)
    print(y)
 
    print("detect result is: ",class_name[y])
 
    img_show = Image.open(image_file)
    img_show.show()
 
if __name__ == '__main__':
    main()
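
To test more than one image, the same `model`, `class_name` list and `model_test_img` function can be reused over a directory; a minimal sketch meant to run inside `main()` after the weights are loaded (the folder path is a hypothetical example, and `os` is already imported at the top of the script):

    # Sketch: evaluate every image in a folder.
    test_dir = "/home/chen/chen_deep/data/cat_dog/test/cat"
    for fname in sorted(os.listdir(test_dir)):
        if not fname.lower().endswith(('.jpg', '.jpeg', '.png')):
            continue                                    # skip non-image files
        pred = model_test_img(model, os.path.join(test_dir, fname))
        print(fname, "->", class_name[pred.item()])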
