代码
import tensorflow as tf
import matplotlib.pyplot as plt
#读取jpg文件
original_data = tf.read_file("test.jpg")
#解析数据
img_data = tf.image.decode_jpeg(original_data)
img_data = tf.cast(img_data, tf.float64)
#卷积核, 卷积高度 * 宽度 * 通道数 *卷积核个数
filter = tf.Variable(tf.random_normal(shape=[5,5,3,3], dtype=tf.float64))
#卷积操作
img_filter_data = tf.nn.conv2d([img_data], filter, [1, 1, 1, 1], padding='SAME')
#值类型转换
img_filter_data_u64 = tf.cast(img_filter_data, tf.uint64)
with tf.Session() as sess:
tf.global_variables_initializer().run()
img_filter_data_u64 = sess.run(img_filter_data_u64)
b,h,w,c = (img_filter_data_u64.shape)
#如果是灰度图片,需转换一下,才能在plt上显示
if c == 1:
img_filter_data_u64 = img_filter_data_u64.reshape(b,h,w)
plt.imshow(img_filter_data_u64[0])
plt.show()
效果
