12-7 tfrecoderSRESPCN
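Use an ESPCN network to restore low-resolution images from the flowers dataset to high resolution, and compare the result with those produced by other upscaling functions.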
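The key step of ESPCN is the sub-pixel (pixel-shuffle) layer. The last convolution outputs r*r*3 channels at low resolution, and `tf.depth_to_space` rearranges them into a 3-channel image that is r times larger in each dimension. The NumPy sketch below is a hypothetical re-implementation (`depth_to_space_nhwc` is not part of the original code). It is written to follow the channel ordering of the example in the TensorFlow documentation for NHWC tensors, and it only illustrates the rearrangement; nothing is learned here.

```python
import numpy as np


def depth_to_space_nhwc(x, r):
    """Rearrange (N, H, W, r*r*C) into (N, H*r, W*r, C).

    Intended to match tf.depth_to_space on NHWC data:
    out[n, h*r + i, w*r + j, c] = x[n, h, w, (i*r + j)*C + c]
    """
    n, h, w, d = x.shape
    if d % (r * r) != 0:
        raise ValueError("depth must be divisible by r*r")
    c = d // (r * r)
    # Split the depth axis into (row offset i, column offset j, channel c)
    x = x.reshape(n, h, w, r, r, c)
    # Interleave the offsets with the spatial axes: (n, h, i, w, j, c)
    x = x.transpose(0, 1, 3, 2, 4, 5)
    return x.reshape(n, h * r, w * r, c)


# One low-resolution pixel with 12 = 2*2*3 channels, like the last conv layer's output
lr = np.arange(1, 13).reshape(1, 1, 1, 12)
hr = depth_to_space_nhwc(lr, 2)
print(hr.shape)                 # (1, 2, 2, 3)
print(hr[0, :, :, 0].tolist())  # [[1, 4], [7, 10]]
```

Each group of three consecutive channels becomes one RGB pixel of the 2x2 output block, so a (100, 100, 12) feature map becomes a (200, 200, 3) image.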
Program:
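#1 Import modules and create the sample data source
import tensorflow as tf
from slim.datasets import flowers  # flowers dataset helper from the TF-slim models code (research/slim)
import numpy as np
import matplotlib.pyplot as plt
slim = tf.contrib.slim  # tf.contrib is only available in TensorFlow 1.x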
'''----------------------------------------------'''
#6 Build the image-quality evaluation functions
# Each function takes dbatch = original images followed by reconstructions
# (concatenated along axis 0), with pixel values in [0, 255].
def batch_mse_psnr(dbatch):
    im1, im2 = np.split(dbatch, 2)
    mse = ((im1 - im2) ** 2).mean(axis=(1, 2))  # MSE per image and per channel
    psnr = np.mean(20 * np.log10(255.0 / np.sqrt(mse)))
    return np.mean(mse), psnr
def batch_y_psnr(dbatch):
    # PSNR on the luma (Y) channel only
    r, g, b = np.split(dbatch, 3, axis=3)
    y = np.squeeze(0.3 * r + 0.59 * g + 0.11 * b)
    im1, im2 = np.split(y, 2)
    mse = ((im1 - im2) ** 2).mean(axis=(1, 2))
    psnr = np.mean(20 * np.log10(255.0 / np.sqrt(mse)))
    return psnr
def batch_ssim(dbatch):
    # Simplified SSIM: statistics are taken over the whole image (per channel)
    # rather than over local sliding windows
    im1, im2 = np.split(dbatch, 2)
    imgsize = im1.shape[1] * im1.shape[2]
    avg1 = im1.mean((1, 2), keepdims=True)
    avg2 = im2.mean((1, 2), keepdims=True)
    std1 = im1.std((1, 2), ddof=1)
    std2 = im2.std((1, 2), ddof=1)
    cov = ((im1 - avg1) * (im2 - avg2)).mean((1, 2)) * imgsize / (imgsize - 1)  # sample covariance
    avg1 = np.squeeze(avg1)
    avg2 = np.squeeze(avg2)
    k1 = 0.01
    k2 = 0.03
    c1 = (k1 * 255) ** 2
    c2 = (k2 * 255) ** 2
    c3 = c2 / 2
    return np.mean(
        (2 * avg1 * avg2 + c1) * 2 * (cov + c3) / (avg1 ** 2 + avg2 ** 2 + c1) / (std1 ** 2 + std2 ** 2 + c2))
'''----------------------------------------------'''
def showresult(subplot, title, orgimg, thisimg, dopsnr=True):
    # Show the first image of the batch; optionally append PSNR / Y-PSNR / SSIM to the title
    p = plt.subplot(subplot)
    p.axis('off')
    p.imshow(np.asarray(thisimg[0], dtype='uint8'))
    if dopsnr:
        conimg = np.concatenate((orgimg, thisimg))
        mse, psnr = batch_mse_psnr(conimg)
        ypsnr = batch_y_psnr(conimg)
        ssim = batch_ssim(conimg)
        p.set_title(title + str(int(psnr)) + " y:" + str(int(ypsnr)) + " s:" + str(ssim))
    else:
        p.set_title(title)
height = width = 200
batch_size = 4
DATA_DIR = "D:/tmp/data/flowers"
# Use the validation split
dataset = flowers.get_split('validation', DATA_DIR)
# Create a data provider
provider = slim.dataset_data_provider.DatasetDataProvider(dataset, num_readers=2)
# Fetch the image and label tensors from the provider
[image, label] = provider.get(['image', 'label'])
print(image.shape)
'''----------------------------------------------'''
#2 Get a batch of samples and upscale it with TensorFlow's built-in resize functions
# Crop or pad every image to the same size
distorted_image = tf.image.resize_image_with_crop_or_pad(image, height, width)  # center-crop, zero-pad if too small
################################################
images, labels = tf.train.batch([distorted_image, label], batch_size=batch_size)
print(images.shape)
x_smalls = tf.image.resize_images(images, (np.int32(height / 2), np.int32(width / 2)))  # downscale by 2 in each dimension
x_smalls2 = x_smalls / 255.0  # network input scaled to [0, 1]
# Restore the original size with classical interpolation (baselines)
x_nearests = tf.image.resize_images(x_smalls, (height, width), tf.image.ResizeMethod.NEAREST_NEIGHBOR)
x_bilins = tf.image.resize_images(x_smalls, (height, width), tf.image.ResizeMethod.BILINEAR)
x_bicubics = tf.image.resize_images(x_smalls, (height, width), tf.image.ResizeMethod.BICUBIC)
'''----------------------------------------------'''
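#3 Build the ESPCN network
net = slim.conv2d(x_smalls2, 64, 5, activation_fn=tf.nn.tanh)
net = slim.conv2d(net, 32, 3, activation_fn=tf.nn.tanh)
net = slim.conv2d(net, 12, 3, activation_fn=None)  # 12 = 2*2*3: r*r sub-pixels for each of 3 channels
y_predt = tf.depth_to_space(net, 2)  # sub-pixel rearrangement: (100, 100, 12) -> (200, 200, 3)
y_pred = y_predt * 255.0  # back to the [0, 255] range
y_pred = tf.maximum(y_pred, 0)
y_pred = tf.minimum(y_pred, 255)
dbatch = tf.concat([tf.cast(images, tf.float32), y_pred], 0)  # originals followed by predictions, for the metrics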
'''----------------------------------------------'''
#4 Define the loss (MSE on the [0, 1] scale) and the optimizer
cost = tf.reduce_mean(tf.pow(tf.cast(images, tf.float32) / 255.0 - y_predt, 2))
optimizer = tf.train.AdamOptimizer(0.000001).minimize(cost)
'''----------------------------------------------'''
#5 Create the session and train
training_epochs = 150000  # number of training steps (one batch per step)
display_step = 200
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
# Start the input queue threads
tf.train.start_queue_runners(sess=sess)
# Training loop
for epoch in range(training_epochs):
    _, c = sess.run([optimizer, cost])
    # Print training progress
    if epoch % display_step == 0:
        d_batch = dbatch.eval()  # note: this dequeues and evaluates a fresh batch
        mse, psnr = batch_mse_psnr(d_batch)
        ypsnr = batch_y_psnr(d_batch)
        ssim = batch_ssim(d_batch)
        print("Epoch:", '%04d' % (epoch + 1),
              "cost=", "{:.9f}".format(c), "psnr", psnr, "ypsnr", ypsnr, "ssim", ssim)
print("Done!")
'''----------------------------------------------'''
#7 Show the results
imagesv, label_batch, x_smallv, x_nearestv, x_bilinv, x_bicubicv, y_predv = sess.run(
    [images, labels, x_smalls, x_nearests, x_bilins, x_bicubics, y_pred])
print("original", np.shape(imagesv), "downscaled", np.shape(x_smallv), label_batch)
### Display
plt.figure(figsize=(20, 10))
showresult(161, "org", imagesv, imagesv, False)
showresult(162, "small/4", imagesv, x_smallv, False)
showresult(163, "near", imagesv, x_nearestv)
showresult(164, "bilinear", imagesv, x_bilinv)
showresult(165, "bicubic", imagesv, x_bicubicv)
showresult(166, "pred", imagesv, y_predv)
plt.show()
#
## Visualize the results (manual alternative to showresult)
# plt.figure(figsize=(20,10))
# p1 = plt.subplot(161)
# p2 = plt.subplot(162)
# p3 = plt.subplot(163)
# p4 = plt.subplot(164)
# p5 = plt.subplot(165)
# p6 = plt.subplot(166)
# p1.axis('off')
# p2.axis('off')
# p3.axis('off')
# p4.axis('off')
# p5.axis('off')
# p6.axis('off')
#
#
# p1.imshow(imagesv[0])# show image
# p2.imshow(np.asarray(x_smallv[0], dtype='uint8') )# show image; float data must be cast to uint8 to display correctly
# p3.imshow(np.asarray(x_nearestv[0], dtype='uint8') )# show image
# p4.imshow(np.asarray(x_bilinv[0], dtype='uint8') )# show image
# p5.imshow(np.asarray(x_bicubicv[0], dtype='uint8') )# show image
# p6.imshow(np.asarray(y_predv[0], dtype='uint8') )# show image
#
# p1.set_title("org")
# p2.set_title("small/4")
# p3.set_title("near")
# p4.set_title("biline")
# p5.set_title("bicubicv")
# p6.set_title("pred")
# plt.show()
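All three quality functions expect dbatch to hold the reference images followed by the reconstructions along axis 0, with pixels in [0, 255]; np.split(dbatch, 2) then separates the two halves. As a sanity check of the PSNR formula 20*log10(255/sqrt(MSE)), the hypothetical helper `psnr_pairs` below mirrors `batch_mse_psnr` and applies it to synthetic image pairs with a known, constant error. Identical images would give MSE = 0 and therefore an infinite PSNR (with a NumPy divide-by-zero warning).

```python
import numpy as np


def psnr_pairs(dbatch):
    """Mean PSNR in dB, following the batch_mse_psnr convention.

    dbatch: first half = reference images, second half = reconstructions,
    shape (2N, H, W, C), pixel values assumed to be in [0, 255].
    """
    ref, rec = np.split(np.asarray(dbatch, dtype=np.float64), 2)
    mse = ((ref - rec) ** 2).mean(axis=(1, 2))  # per image and per channel
    return float(np.mean(20 * np.log10(255.0 / np.sqrt(mse))))


ref = np.full((1, 4, 4, 3), 100.0)
rec = ref + 1.0    # every pixel off by exactly 1 -> MSE = 1
rec10 = ref + 10.0  # every pixel off by 10 -> MSE = 100
print(round(psnr_pairs(np.concatenate([ref, rec])), 2))    # 48.13
print(round(psnr_pairs(np.concatenate([ref, rec10])), 2))  # 28.13
```

A tenfold increase in RMS error lowers PSNR by 20 dB, which is why the roughly 24 to 28 dB reached at the end of training corresponds to an RMS error of about 10 to 16 gray levels.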
Results:
(?, ?, 3)
(4, 200, 200, 3)
Epoch: 0001 cost= 0.360416085 psnr 6.669006 ypsnr 6.8186326 ssim 0.021616353
Epoch: 0201 cost= 0.331164747 psnr 8.1468115 ypsnr 7.0099936 ssim 0.030560756
Epoch: 0401 cost= 0.298657089 psnr 6.903029 ypsnr 6.9319057 ssim 0.03709429
Epoch: 0601 cost= 0.278883010 psnr 8.724986 ypsnr 8.624394 ssim 0.0352783
Epoch: 0801 cost= 0.266729325 psnr 6.616158 ypsnr 6.7454033 ssim 0.052763406
Epoch: 1001 cost= 0.221033752 psnr 8.113207 ypsnr 8.400061 ssim 0.08303006
Epoch: 1201 cost= 0.248143092 psnr 8.751854 ypsnr 8.84062 ssim 0.13017127
Epoch: 1401 cost= 0.243928418 psnr 9.626335 ypsnr 10.030566 ssim 0.15815271
Epoch: 1601 cost= 0.190259323 psnr 8.577847 ypsnr 9.032168 ssim 0.14702274
Epoch: 1801 cost= 0.157699853 psnr 11.350587 ypsnr 11.985262 ssim 0.2490645
Epoch: 2001 cost= 0.123353019 psnr 9.742246 ypsnr 10.396275 ssim 0.20674324
Epoch: 2201 cost= 0.115382612 psnr 8.905223 ypsnr 9.6788645 ssim 0.24067946
......
Epoch: 148001 cost= 0.002764479 psnr 26.83477 ypsnr 26.944088 ssim 0.9765523
Epoch: 148201 cost= 0.005125942 psnr 24.819565 ypsnr 25.221066 ssim 0.92076725
Epoch: 148401 cost= 0.002659404 psnr 28.065817 ypsnr 28.351946 ssim 0.9765806
Epoch: 148601 cost= 0.002774108 psnr 24.86807 ypsnr 26.140589 ssim 0.9535808
Epoch: 148801 cost= 0.002453615 psnr 27.56114 ypsnr 27.631792 ssim 0.97236115
Epoch: 149001 cost= 0.004432773 psnr 25.963507 ypsnr 26.101135 ssim 0.9683759
Epoch: 149201 cost= 0.002581447 psnr 27.878935 ypsnr 28.195324 ssim 0.9630526
Epoch: 149401 cost= 0.004874272 psnr 26.993256 ypsnr 27.916882 ssim 0.97514504
Epoch: 149601 cost= 0.002618119 psnr 26.826574 ypsnr 26.896557 ssim 0.91461986
Epoch: 149801 cost= 0.003381842 psnr 23.671938 ypsnr 24.175095 ssim 0.90275687
Done!
original (4, 200, 200, 3) downscaled (4, 100, 100, 3) [4 2 3 3]
Results of the ESPCN example on flowers:
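The panel titles and the log report SSIM from `batch_ssim`, which uses the global means, variances and covariance of each whole image channel, i.e. a single window per image. The commonly used SSIM averages the index over local windows (typically an 11x11 Gaussian window), so these values are not directly comparable with SSIM figures reported elsewhere. The hypothetical function `global_ssim` below restates the same single-window formula for two arrays. With c3 = c2/2 the contrast and structure terms combine into (2*cov + c2) / (var1 + var2 + c2), which is what `2 * (cov + c3)` in `batch_ssim` expresses.

```python
import numpy as np


def global_ssim(im1, im2, k1=0.01, k2=0.03, L=255.0):
    """Single-window SSIM over whole arrays, same formula as batch_ssim:
    sample (ddof=1) statistics and c3 = c2 / 2."""
    im1 = np.asarray(im1, dtype=np.float64).ravel()
    im2 = np.asarray(im2, dtype=np.float64).ravel()
    c1, c2 = (k1 * L) ** 2, (k2 * L) ** 2
    mu1, mu2 = im1.mean(), im2.mean()
    var1, var2 = im1.var(ddof=1), im2.var(ddof=1)
    cov = ((im1 - mu1) * (im2 - mu2)).sum() / (im1.size - 1)
    luminance = (2 * mu1 * mu2 + c1) / (mu1 ** 2 + mu2 ** 2 + c1)
    contrast_structure = (2 * cov + c2) / (var1 + var2 + c2)
    return luminance * contrast_structure


rng = np.random.default_rng(0)
img = rng.uniform(0, 255, size=(64, 64))
noisy = np.clip(img + rng.normal(0, 20, img.shape), 0, 255)
print(global_ssim(img, img))          # approximately 1.0
print(global_ssim(img, noisy) < 1.0)  # True
```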
12-8 resESPCN
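Convert images from the flowers dataset to low resolution, restore them to high resolution with an ESPCN network extended with residual blocks, and compare the result with those produced by other upscaling functions.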
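A residual block adds its input back to the output of its convolution layers (y = x + F(x)), so each block only has to learn a correction to its input, which helps keep deep stacks of blocks trainable. The sketch below is a simplified NumPy illustration, not the article's TensorFlow code: it replaces the 3x3 `slim.conv2d` layers with hypothetical 1x1 convolutions (a per-pixel matrix multiply) and omits batch normalization, keeping only the `leaky_relu` definition and the skip connection of `residual_block`.

```python
import numpy as np


def leaky_relu(x, alpha=0.1):
    # Same definition as the article's leaky_relu: max(x, alpha * x)
    return np.maximum(x, alpha * x)


def conv1x1(x, w, b):
    # Hypothetical stand-in for slim.conv2d: a 1x1 convolution is a per-pixel
    # matrix multiply. x: (N, H, W, C_in), w: (C_in, C_out), b: (C_out,)
    return x @ w + b


def residual_block(x, w1, b1, w2, b2):
    # Two conv + leaky-ReLU layers, then the skip connection adds the input back
    h = leaky_relu(conv1x1(x, w1, b1))
    h = leaky_relu(conv1x1(h, w2, b2))
    return x + h


rng = np.random.default_rng(0)
x = rng.standard_normal((2, 8, 8, 64))
zero_w, zero_b = np.zeros((64, 64)), np.zeros(64)
# With all weights at zero the block reduces to the identity mapping
print(np.allclose(residual_block(x, zero_w, zero_b, zero_w, zero_b), x))  # True
```

Because an untrained block starts close to the identity, stacking 16 of them (as in step #2 below) does not scramble the input features at initialization.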
Program:
import tensorflow as tf
from slim.datasets import flowers
import numpy as np
import matplotlib.pyplot as plt
import os
slim = tf.contrib.slim
tf.reset_default_graph()
def batch_mse_psnr(dbatch):
    im1, im2 = np.split(dbatch, 2)
    mse = ((im1 - im2) ** 2).mean(axis=(1, 2))
    psnr = np.mean(20 * np.log10(255.0 / np.sqrt(mse)))
    return np.mean(mse), psnr
def batch_y_psnr(dbatch):
    r, g, b = np.split(dbatch, 3, axis=3)
    y = np.squeeze(0.3 * r + 0.59 * g + 0.11 * b)
    im1, im2 = np.split(y, 2)
    mse = ((im1 - im2) ** 2).mean(axis=(1, 2))
    psnr = np.mean(20 * np.log10(255.0 / np.sqrt(mse)))
    return psnr
def batch_ssim(dbatch):
    im1, im2 = np.split(dbatch, 2)
    imgsize = im1.shape[1] * im1.shape[2]
    avg1 = im1.mean((1, 2), keepdims=True)
    avg2 = im2.mean((1, 2), keepdims=True)
    std1 = im1.std((1, 2), ddof=1)
    std2 = im2.std((1, 2), ddof=1)
    cov = ((im1 - avg1) * (im2 - avg2)).mean((1, 2)) * imgsize / (imgsize - 1)
    avg1 = np.squeeze(avg1)
    avg2 = np.squeeze(avg2)
    k1 = 0.01
    k2 = 0.03
    c1 = (k1 * 255) ** 2
    c2 = (k2 * 255) ** 2
    c3 = c2 / 2
    return np.mean(
        (2 * avg1 * avg2 + c1) * 2 * (cov + c3) / (avg1 ** 2 + avg2 ** 2 + c1) / (std1 ** 2 + std2 ** 2 + c2))
def showresult(subplot, title, orgimg, thisimg, dopsnr=True):
    # Show the first image of the batch; optionally append PSNR / Y-PSNR / SSIM to the title
    p = plt.subplot(subplot)
    p.axis('off')
    p.imshow(np.asarray(thisimg[0], dtype='uint8'))
    if dopsnr:
        conimg = np.concatenate((orgimg, thisimg))
        mse, psnr = batch_mse_psnr(conimg)
        ypsnr = batch_y_psnr(conimg)
        ssim = batch_ssim(conimg)
        p.set_title(title + str(int(psnr)) + " y:" + str(int(ypsnr)) + " s:%.4f" % ssim)
    else:
        p.set_title(title)
height = width = 256
batch_size = 16
DATA_DIR = "D:/tmp/data/flowers"
# Use the validation split
dataset = flowers.get_split('validation', DATA_DIR)
# Create a data provider
provider = slim.dataset_data_provider.DatasetDataProvider(dataset, num_readers=2)
# Fetch the image and label tensors from the provider
[image, label] = provider.get(['image', 'label'])
print(image.shape)
# Crop or pad every image to the same size
distorted_image = tf.image.resize_image_with_crop_or_pad(image, height, width)  # center-crop, zero-pad if too small
################################################
'''----------------------------------------------'''
#1 Change the input resolution
images, labels = tf.train.batch([distorted_image, label], batch_size=batch_size)
print(images.shape)
x_smalls = tf.image.resize_images(images, (np.int32(height / 4), np.int32(width / 4)))  # downscale by 4 in each dimension
x_smalls2 = x_smalls / 255.0  # network input scaled to [0, 1]
'''----------------------------------------------'''
# Restore the original size with classical interpolation (baselines)
x_nearests = tf.image.resize_images(x_smalls, (height, width), tf.image.ResizeMethod.NEAREST_NEIGHBOR)
x_bilins = tf.image.resize_images(x_smalls, (height, width), tf.image.ResizeMethod.BILINEAR)
x_bicubics = tf.image.resize_images(x_smalls, (height, width), tf.image.ResizeMethod.BICUBIC)
####################################
# net = slim.conv2d(x_smalls2, 64, 5,activation_fn = tf.nn.tanh)
# net =slim.conv2d(net, 256, 3,activation_fn = tf.nn.tanh)
# net = tf.depth_to_space(net,2) #64
# net =slim.conv2d(net, 64, 3,activation_fn = tf.nn.tanh)
# net = tf.depth_to_space(net,2) #16
# y_predt = slim.conv2d(net, 3, 3,activation_fn = None)#2*2*3
######################################
'''----------------------------------------------'''
#2 Add residual blocks
def leaky_relu(x, alpha=0.1, name='lrelu'):
    with tf.name_scope(name):
        x = tf.maximum(x, alpha * x)
        return x
def residual_block(nn, i, name='resblock'):
    with tf.variable_scope(name + str(i)):
        conv1 = slim.conv2d(nn, 64, 3, activation_fn=leaky_relu, normalizer_fn=slim.batch_norm)
        conv2 = slim.conv2d(conv1, 64, 3, activation_fn=leaky_relu, normalizer_fn=slim.batch_norm)
        return tf.add(nn, conv2)  # skip connection
net = slim.conv2d(x_smalls2, 64, 5, activation_fn=leaky_relu)
block = []
for i in range(16):  # stack 16 residual blocks
    block.append(residual_block(block[-1] if i else net, i))
conv2 = slim.conv2d(block[-1], 64, 3, activation_fn=leaky_relu, normalizer_fn=slim.batch_norm)
sum1 = tf.add(conv2, net)  # long skip connection around all residual blocks
conv3 = slim.conv2d(sum1, 256, 3, activation_fn=None)
ps1 = tf.depth_to_space(conv3, 2)  # upscale 2x: 256 channels -> 64
relu2 = leaky_relu(ps1)
conv4 = slim.conv2d(relu2, 256, 3, activation_fn=None)
ps2 = tf.depth_to_space(conv4, 2)  # upscale another 2x: 256 channels -> 64
relu3 = leaky_relu(ps2)
y_predt = slim.conv2d(relu3, 3, 3, activation_fn=None)  # output: 3-channel image on the [0, 1] scale
'''----------------------------------------------'''
y_pred = y_predt * 255.0
y_pred = tf.maximum(y_pred, 0)
y_pred = tf.minimum(y_pred, 255)
dbatch = tf.concat([tf.cast(images, tf.float32), y_pred], 0)
'''----------------------------------------------'''
#3 Change the learning rate and train the network
learn_rate = 0.001
cost = tf.reduce_mean(tf.pow(tf.cast(images, tf.float32) / 255.0 - y_predt, 2))
optimizer = tf.train.AdamOptimizer(learn_rate).minimize(cost)
# training_epochs =100000
# display_step =5000
training_epochs = 10000
'''----------------------------------------------'''
display_step = 400
'''----------------------------------------------'''
#4 Add checkpointing
flags = 'b' + str(batch_size) + '_h' + str(height / 4) + '_r' + str(
    learn_rate) + '_res'  # encodes the setup in the folder name so readers can try different configurations
# flags='b'+str(batch_size)+'_r'+str(height/4)+'_depth_conv2d'#lets readers try different setups
if not os.path.exists('save'):
    os.mkdir('save')
save_path = 'save/tf_' + flags
if not os.path.exists(save_path):
    os.mkdir(save_path)
saver = tf.train.Saver(max_to_keep=1)  # create the saver (keeps only the latest checkpoint)
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
kpt = tf.train.latest_checkpoint(save_path)
print(kpt)
startepo = 0
if kpt is not None:
    saver.restore(sess, kpt)
    ind = kpt.rfind("-")  # the step follows the last '-'; the folder name may contain '-' (e.g. '1e-06')
    startepo = int(kpt[ind + 1:])
    print("startepo=", startepo)
# Start the input queue threads
tf.train.start_queue_runners(sess=sess)
# Training loop (resumes from the restored step)
for epoch in range(startepo, training_epochs):
    _, c = sess.run([optimizer, cost])
    # Print training progress and save a checkpoint
    if epoch % display_step == 0:
        d_batch = dbatch.eval()
        mse, psnr = batch_mse_psnr(d_batch)
        ypsnr = batch_y_psnr(d_batch)
        ssim = batch_ssim(d_batch)
        print("Epoch:", '%04d' % (epoch + 1),
              "cost=", "{:.9f}".format(c), "psnr", psnr, "ypsnr", ypsnr, "ssim", ssim)
        saver.save(sess, save_path + "/tfrecord.cpkt", global_step=epoch)
print("Done!")
saver.save(sess, save_path + "/tfrecord.cpkt", global_step=epoch)
'''----------------------------------------------'''
imagesv, label_batch, x_smallv, x_nearestv, x_bilinv, x_bicubicv, y_predv = sess.run(
    [images, labels, x_smalls, x_nearests, x_bilins, x_bicubics, y_pred])
print("original", np.shape(imagesv), "downscaled", np.shape(x_smallv), label_batch)
# print(np.max(imagesv[0]),np.max(x_bilinv[0]),np.max(x_bicubicv[0]),np.max(y_predv[0]))
# print(np.min(imagesv[0]),np.min(x_bilinv[0]),np.min(x_bicubicv[0]),np.min(y_predv[0]))
### Display
plt.figure(figsize=(20, 10))
showresult(161, "org", imagesv, imagesv, False)
showresult(162, "small/4", imagesv, x_smallv, False)
showresult(163, "near", imagesv, x_nearestv)
showresult(164, "bilinear", imagesv, x_bilinv)
showresult(165, "bicubic", imagesv, x_bicubicv)
showresult(166, "pred", imagesv, y_predv)
plt.show()
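Because `slim.conv2d` uses stride 1 and 'SAME' padding by default, the convolutions in this network only change the channel count; the whole 4x enlargement comes from the two `tf.depth_to_space(x, 2)` calls. The plain-Python shape trace below (the helper names are hypothetical) follows the layers of step #2 for the default batch_size = 16 and 256x256 images. It only does shape arithmetic and runs no TensorFlow.

```python
def conv_same(shape, out_channels):
    # Stride 1 with 'SAME' padding: height and width are unchanged
    n, h, w, _ = shape
    return (n, h, w, out_channels)


def depth_to_space(shape, r):
    # (N, H, W, C) -> (N, H*r, W*r, C/(r*r)), like tf.depth_to_space on NHWC data
    n, h, w, c = shape
    if c % (r * r) != 0:
        raise ValueError("channel count must be divisible by r*r")
    return (n, h * r, w * r, c // (r * r))


def trace_res_espcn(batch=16, height=256, width=256, n_blocks=16):
    """Return (layer name, output shape) pairs for the residual ESPCN of step #2."""
    steps = []
    s = (batch, height // 4, width // 4, 3)  # x_smalls2: input 4x smaller per side
    steps.append(("x_smalls2", s))
    s = conv_same(s, 64)
    steps.append(("net", s))
    for _ in range(n_blocks):  # each residual block: two 64-channel convs plus a skip
        s = conv_same(conv_same(s, 64), 64)
    steps.append(("block[-1]", s))
    s = conv_same(s, 64)  # conv2; sum1 = conv2 + net has the same shape
    steps.append(("sum1", s))
    s = conv_same(s, 256)
    steps.append(("conv3", s))
    s = depth_to_space(s, 2)
    steps.append(("ps1", s))
    s = conv_same(s, 256)
    steps.append(("conv4", s))
    s = depth_to_space(s, 2)
    steps.append(("ps2", s))
    s = conv_same(s, 3)
    steps.append(("y_predt", s))
    return steps


for name, shape in trace_res_espcn():
    print(f"{name:10s} {shape}")
```

The final shape (16, 256, 256, 3) matches `images` in the log below, which is what lets `cost` compare `y_predt` with the originals pixel by pixel.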
Results:
(?, ?, 3)
(16, 256, 256, 3)
None
Epoch: 0001 cost= 0.211108640 psnr 11.407981 ypsnr 13.488435 ssim 0.3307791485562411
Epoch: 0401 cost= 0.011509083 psnr 21.501999 ypsnr 21.893381 ssim 0.9146646896505927
Epoch: 0801 cost= 0.006142357 psnr 22.968485 ypsnr 23.852066 ssim 0.9279918317009753
Epoch: 1201 cost= 0.005762197 psnr 23.517471 ypsnr 23.978153 ssim 0.9582196990382146
Epoch: 1601 cost= 0.005582960 psnr 21.631796 ypsnr 21.95909 ssim 0.9040559790367432
Epoch: 2001 cost= 0.006873480 psnr 22.270742 ypsnr 22.955856 ssim 0.9182121559303617
Epoch: 2401 cost= 0.005612638 psnr 22.760986 ypsnr 23.089485 ssim 0.9262228042835311
Epoch: 2801 cost= 0.005443098 psnr 23.420565 ypsnr 23.99352 ssim 0.9442156526190871
Epoch: 3201 cost= 0.005806287 psnr 23.749552 ypsnr 24.477228 ssim 0.9450137807004954
Epoch: 3601 cost= 0.005395472 psnr 25.262682 ypsnr 25.838692 ssim 0.9542044947249698
Epoch: 4001 cost= 0.006084155 psnr 22.919348 ypsnr 23.530197 ssim 0.9359309824799281
Epoch: 4401 cost= 0.005560590 psnr 25.203331 ypsnr 25.738913 ssim 0.9530244535767078
Epoch: 4801 cost= 0.004648690 psnr 25.333511 ypsnr 25.992489 ssim 0.9402358180894312
Epoch: 5201 cost= 0.003865756 psnr 25.506927 ypsnr 25.824717 ssim 0.9539453031428664
Epoch: 5601 cost= 0.003216719 psnr 24.912142 ypsnr 25.47681 ssim 0.9475247548633271
Epoch: 6001 cost= 0.004854371 psnr 23.78054 ypsnr 24.404945 ssim 0.9592952459358722
Epoch: 6401 cost= 0.003208177 psnr 26.982233 ypsnr 27.602211 ssim 0.974860297423353
Epoch: 6801 cost= 0.004153049 psnr 24.687653 ypsnr 25.168243 ssim 0.94472342376237
Epoch: 7201 cost= 0.004355598 psnr 27.507273 ypsnr 28.3989 ssim 0.9794754731337089
Epoch: 7601 cost= 0.006875461 psnr 24.117432 ypsnr 24.569798 ssim 0.9477540112962018
Epoch: 8001 cost= 0.006706636 psnr 22.219862 ypsnr 22.79896 ssim 0.9211242007323356
Epoch: 8401 cost= 0.004559200 psnr 24.15137 ypsnr 24.684011 ssim 0.9578859948346384
Epoch: 8801 cost= 0.005599694 psnr 24.537134 ypsnr 25.12828 ssim 0.9583166929362248
Epoch: 9201 cost= 0.005546704 psnr 23.239655 ypsnr 23.372845 ssim 0.9045935260983738
Epoch: 9601 cost= 0.006495696 psnr 24.761713 ypsnr 25.204636 ssim 0.953236243712651
Done!
original (16, 256, 256, 3) downscaled (16, 64, 64, 3) [3 3 0 4 4 2 3 0 3 4 3 3 0 2 3 2]
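On restart, step #4 recovers the step count from the checkpoint name: `saver.save(sess, path, global_step=epoch)` appends `-<step>` to the checkpoint prefix. Locating the step via the first '-' in the path (`kpt.find("-")`) breaks as soon as the folder name contains a hyphen, for example when `str(learn_rate)` renders the 0.000001 used in 12-7 as '1e-06'. The hypothetical helper below takes the part after the last '-' of the file name instead; the example path is made up for illustration.

```python
import os


def step_from_checkpoint(path):
    """Return the training step from a checkpoint prefix such as
    'save/tf_x/tfrecord.cpkt-9600' (the '-<step>' suffix added by Saver.save)."""
    base = os.path.basename(path)  # keep only the file name, e.g. 'tfrecord.cpkt-9600'
    return int(base.rsplit("-", 1)[1])


# Hypothetical path: str(0.000001) gives '1e-06', so the folder name contains a '-'
path = "save/tf_b4_h50.0_r1e-06_res/tfrecord.cpkt-149800"
print(step_from_checkpoint(path))  # 149800

# Taking everything after the first '-' picks up part of the folder name instead
try:
    int(path[path.find("-") + 1:])
except ValueError as err:
    print("first-hyphen parsing fails:", err)
```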
12-9 rsgan
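Convert images from the flowers dataset to low resolution, restore them to high resolution with an SRGAN network, and compare the result with those produced by other upscaling functions.
Program:
Results: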