在keras读取某卷积层的数据
获取图片卷积后的值,图片示例:
#-*-coding:utf-8-*-
import tensorflow as tf
import cv2
keras =tf.keras
K=tf.keras.backend
#定义模型这里是227x227x3的输入->卷积->最大池化->卷积,模型没有训练,用的是初始化参数
model=keras.Sequential()
model.add(keras.layers.Conv2D(32,kernel_size=(3,3),strides=(2,2),padding='valid',kernel_initializer='uniform',input_shape=(227,227,3),activation=keras.activations.relu))
model.add(keras.layers.MaxPool2D(strides=[2,2]))
model.add(keras.layers.Conv2D(64,kernel_size=(3,3),strides=(2,2),padding='valid',activation=keras.activations.relu))
img=cv2.imread('d:/cm.jpg') #保存的图片在d盘
#按输入格式改变图数组
img2=cv2.resize(img,(227,227))
img3=img2.reshape([-1,227,227,3]) #输入层函数必须是4维
layer1=K.function([model.layers[0].input],[model.layers[2].output]) #layers[0]表示开始层,layers[2]表示输出层,K.function取局部模型
f1=layer1([img3])[0][0] #第一个[0],表示([],dtype='float')中的[],第二个[0]表示第一个图的三维数组,[img3]表示输入,返回f1表示输出
import matplotlib.pyplot as plt
for i in range(64): #64是第二个卷积的输出特征数
ag=f1[:,:,i] #最后一维表示第i个特征的图
plt.subplot(8,8,i+1)
plt.imshow(ag)
plt.axis('off')
plt.show()
结果:
在tensorflow中读取某卷积层的数据
与上面是同样的模型,获取卷积中的特征图
#-*-coding:utf-8-*-
import tensorflow as tf
import cv2
img=cv2.imread('d:/cm.jpg')
img2=cv2.resize(img,(227,227))
img3=img2.reshape([-1,227,227,3])
#定义模型这里是227x227x3的输入->卷积->最大池化->卷积,模型没有训练,用的是初始化参数
X=tf.placeholder('float',shape=[None,227,227,3],name='X')
w1=tf.get_variable('w1',shape=(3,3,3,32))
b1=tf.get_variable('b1',shape=[32])
conv=tf.nn.conv2d(X,w1,strides=[1,2,2,1],padding='VALID')
conv=tf.nn.bias_add(conv,b1)
conv=tf.nn.relu(conv)
tf.nn.max_pool(conv,ksize=[1,2,2,1],strides=[1,2,2,1],padding='VALID')
w2=tf.get_variable('w2',shape=[3,3,32,64])
b2=tf.get_variable('b2',shape=[64])
conv=tf.nn.conv2d(conv,w2,strides=[1,2,2,1],padding='VALID')
conv=tf.nn.bias_add(conv ,b2)
conv=tf.nn.relu(conv)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
out=sess.run(conv,feed_dict={X:img3})[0]
import matplotlib.pyplot as plt
for i in range(64): #64是第二个卷积的输出特征数
ag=out[:,:,i] #最后一维表示第i个特征的图
plt.subplot(8,8,i+1)
plt.imshow(ag)
plt.axis('off')
plt.show()
结果: