在Resnet50CNN结构下实现图像的特征提取,这里采用的是CV2的图像读入方式,最后再把得到的图像转换成npy格式进行输出得得,图像对应的特征。
# -*- coding: utf-8 -*-
"""
Function: 图像特征的提取,可以依据需求修改CNN的输出,得到不同层网络的输出图像特征
Writer: Zenght
date:2019.2.16
"""
from __future__ import print_function, division, absolute_import
import torch
import torchvision
import torch.nn as nn
from torch.autograd import Variable
import torch.optim as optim
from torch.optim import lr_scheduler
import numpy as np
from torchvision import datasets, models, transforms
import os
import cv2
import time
import copy
import torch.utils.data as data
from Rsenet50 import Resnet
class Net(nn.Module):
# 此处可以添加自行设定的网络结构
def __init__(self):
super(Net, self).__init__()
def cv2_imageloader(path):
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
img = cv2.imread(path)
img = cv2.resize(img, (224, 224))
im_arr = np.flo