pytorch的vgg19的预训练模型提取图片特征

概要

推荐中item侧的特征,item的头图、封面等是一个比较好的特征。但当物料比较少时不可能自己搭建模型训练提取图像特征。比较好的处理方法就是运用torchvision.models模块中的预训练好的模型进行特征提取。
最近,在相关的内容,并且在提取特征时遇到一个比较隐形的坑。记录一下,希望可以帮到遇到同样问题的小伙伴。

内容

torchvision.models集成了很多模型,大家选择适合自己业务场景的模型。本文选用的是vgg19。代码如下:

import torch
import torch.nn as nn
from torchvision import models, transforms
from torch.autograd import Variable
import numpy as np
from PIL import Image
import pandas as pd

# pretrained参数表示是否运用预训练模型
vgg_model = models.vgg19(pretrained=True)
# 这里重新定义最后一层,默认层输出的维度是1000,我这里重新定义可以输出自己想要的维度。
new_classifier = torch.nn.Sequential(*list(vgg_model.children())[-1][:6])
# (6): Linear(in_features=4096, out_features=1000, bias=True)
new_classifier.add_module("Linear output", torch.nn.Linear(4096, 256))
vgg_model.classifier = new_classifier
trans = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
# 这里convert成三通道,如果本来就是三通道的图片,则可以省略。
im = Image.open(image_dir).convert('RGB')
im = trans(im)
im.unsqueeze_(dim=0)
# Set model to eval mode 这一步很重要
vgg_model = vgg_model.eval()
y = vgg_model(im).data.numpy().tolist()

以上代码就是利用vgg19提取图像特征所有代码。
其中比较要注意的地方:
1、图像转换成RGB三通道,有些图片因为格式等其他问题导致不符合输入要求,容易导致报错。
2、vgg_model = vgg_model.eval()这个代码很重要,因为没有这行代码也能运行成功,但是输出的结果不稳定,同一张图像两次执行的结果不一致。这是因为一些模型使用的模块有不用的训练和评估行为,比如批量正则。需要正确的使用model.train()或model.eval()。
在这里插入图片描述
参考:
1、 Using the pre-trained models
2、PyTorch中文文档

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值