将网络模型 imagenet-vgg-verydeep-19.mat 和测试图片 horse.jpg 放在当前工作目录下的同一个 data 文件夹中。
#coding=utf-8
import tensorflow as tf
import scipy.io
import scipy.misc
import os
import numpy as np
import matplotlib.pyplot as plt
def nets(data_path, input_img):
    """Build the VGG-19 feature-extraction graph from pretrained .mat weights.

    Args:
        data_path: path to the pretrained weight file, loaded with
            ``scipy.io.loadmat``; its ``'layers'`` and ``'normalization'``
            entries are read.
        input_img: tensor fed to the first convolution; presumably shaped
            (batch, height, width, 3) -- TODO confirm.

    Returns:
        A tuple ``(net, layers, mean_pixel)`` where ``net`` maps every layer
        name to its output tensor, ``layers`` is the ordered tuple of layer
        names, and ``mean_pixel`` is the stored normalization mean averaged
        over its first two axes (callers subtract it from the input image).

    Raises:
        ValueError: if a layer name does not start with a known kind prefix.
    """
    layers = (
        'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1',
        'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
        'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3',
        'conv3_4', 'relu3_4', 'pool3',
        'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 'relu4_3',
        'conv4_4', 'relu4_4', 'pool4',
        'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3',
        'conv5_4', 'relu5_4',
    )
    data = scipy.io.loadmat(data_path)
    mean = data['normalization'][0][0][0]
    mean_pixel = np.mean(mean, (0, 1))
    # weights[i] is looked up by the same index as layers[i], so the .mat
    # 'layers' entries are assumed to be stored in this exact order -- confirm.
    weights = data['layers'][0]

    # Output tensor of every layer, recorded in forward-propagation order.
    net = {}
    current = input_img
    for i, name in enumerate(layers):
        kind = name[:4]
        if kind == 'conv':
            kernels, bias = weights[i][0][0][0][0]
            # Swap the two spatial axes; presumably the file stores kernels as
            # (width, height, in, out) while tf.nn.conv2d expects
            # (height, width, in, out) -- TODO confirm.
            kernels = np.transpose(kernels, (1, 0, 2, 3))
            bias = np.reshape(bias, -1)
            # The pretrained kernels are fixed, so they are baked into the
            # graph as constants rather than trainable variables.
            current = tf.nn.conv2d(current, tf.constant(kernels), [1, 1, 1, 1],
                                   padding='SAME')
            current = tf.nn.bias_add(current, bias)
        elif kind == 'relu':
            current = tf.nn.relu(current)
        elif kind == 'pool':
            current = tf.nn.max_pool(current, [1, 2, 2, 1], [1, 2, 2, 1],
                                     padding='SAME')
        else:
            # Fail loudly instead of silently aliasing the previous layer.
            raise ValueError('unknown layer kind %r in layer %r' % (kind, name))
        net[name] = current
    return net, layers, mean_pixel
def _load_image_rgb(path):
    """Read an image file as a float32 (height, width, 3) array in the 0-255 range.

    Uses ``scipy.misc.imread`` when the installed SciPy still provides it
    (it was removed in SciPy 1.2) and falls back to ``plt.imread`` otherwise.
    Grayscale images are replicated to three channels and any extra (alpha)
    channel is dropped, since the mean subtraction and the pretrained first
    convolution presumably expect exactly three channels -- confirm.
    """
    reader = getattr(scipy.misc, 'imread', None)
    if reader is not None:
        img = reader(path)
    else:
        img = plt.imread(path)
        # matplotlib returns PNG files as floats in [0, 1] and other formats
        # as integers; rescale so every path yields the 0-255 range.
        if img.dtype.kind == 'f':
            img = img * 255.0
    img = np.asarray(img, dtype=np.float32)
    if img.ndim == 2:
        img = np.stack([img] * 3, axis=-1)
    return img[:, :, :3]


cwd = os.getcwd()
data_path = os.path.join(cwd, 'data', 'imagenet-vgg-verydeep-19.mat')
img_path = os.path.join(cwd, 'data', 'horse.jpg')
input_img = _load_image_rgb(img_path)
# (batch_size, height, width, channels)
shape = (1,) + input_img.shape
with tf.Session() as sess:
    image = tf.placeholder(tf.float32, shape=shape)
    net, layers, mean_pixel = nets(data_path, image)
    img_preprocess = np.array([input_img - mean_pixel])

    # Only the first channel of each layer is plotted, so fetch just that
    # slice. A single sess.run evaluates the network once, instead of
    # re-running it from the input for every layer as per-layer eval() did.
    first_channels = {name: net[name][0, :, :, 0] for name in layers}
    feature_maps = sess.run(first_channels, feed_dict={image: img_preprocess})

    # All layers go into one figure; the grid grows with the layer count.
    n_cols = 9
    n_rows = (len(layers) + n_cols - 1) // n_cols
    figure = plt.figure(figsize=(24, 12))
    for i, layer in enumerate(layers):
        print('[%d/%d] %s' % (i + 1, len(layers), layer))
        print('shape is {}'.format(net[layer].shape))
        figure.add_subplot(n_rows, n_cols, i + 1)
        plt.imshow(feature_maps[layer], cmap=plt.cm.gray)
        plt.title(layer)

    # To save the figure, savefig must be called before show().
    plt.savefig('2.png')
    plt.show()
print('Done')
下面是输入的图片以及最终的测试结果,从中可以看到该网络各层学习到的特征。
注意:保存生成的图片时一定要使用 png 格式,其他格式似乎无法保存成功。