自学SRCNN阶段性总结。找了代码里面几个关键步骤记录一下下。
非常感谢身边有个大牛同学。Orz
1.PSNR值计算函数
def PSNR(y_true, y_pred):
max_pixel = 1.0
return 10.0 * tf_log10((max_pixel ** 2) / (tf.keras.backend.mean(tf.keras.backend.square(y_pred - y_true))))
其中tf_log10为自定义求log函数
2.build model函数
第一步,读取label值。
data_dir = os.path.join(os.getcwd(), "checkpoint/train.h5")#读取trian的数据
train_data,train_label = read_data(data_dir)#读取train的label
data_dir = os.path.join(os.getcwd(), "checkpoint/test.h5")#读取test的数据
test_data,test_label = read_data(data_dir)#读取train的label
第二步,添加三个卷积层,分别为64,32,和1,输入模型的图像大小为33*33
model = keras.Sequential()#调用Sequential模型
model.add(keras.layers.Conv2D(64, (9, 9),padding='VALID',activation='relu',input_shape=(33,33,1),kernel_initializer='he_normal'))
model.add(keras.layers.Conv2D(32, (1, 1),padding='VALID',activation='relu',kernel_initializer='he_normal'))
model.add(keras.layers.Conv2D(1, (5, 5),padding='VALID',activation='relu',kernel_initializer='he_normal'))
model.summary()
第三步,保存权重值
checkpoint_path = "/your_path/cp-{epoch:04d}.ckpt"
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
save_weights_only=True,
verbose=1,period=5)#保存权重值
第四步,设置优化器并加上PSNR值为判断标准
model.compile(optimizer=tf.optimizers.Adam(config.learning_rate),loss=tf.keras.losses.MSE,metrics=[PSNR,'accuracy'])
第五步,训练模型,评估模型
model.fit(train_dataset,epochs=1,callbacks=[cp_callback])#训练模型
score = model.evaluate(test_dataset)#对于模型进行评估
第六步:保存结果
result = model.predict(test_dataset)
result = merge(result, [nx, ny])
result = result.squeeze()
image_path = os.path.join(os.getcwd(), config.sample_dir)
image_path = os.path.join(image_path, "test_image.png")
imsave(result, image_path)
3.进行切分与合并
def merge(images, size):
print(images.shape)
h, w = images.shape[1], images.shape[2]
img = np.zeros([h*size[0], w*size[1], 1])
print(img.shape)
j = 0
k = 0
for i in range(13):
while(j<24):
img[(i*21):((i + 1) * 21), (j*21):((j + 1) * 21), :] = images[k, 0:21, 0:21, :]
j += 1
k += 1
if(j == 24):
j = 0
print(k)
return img
思路总结:将一个图片以33*33为规格切成n个小块,将这n个小块单独塞进模型进行训练,最后将输出结果合并起来。
其中要通过三个卷积层,梯度下降等步骤。
训练出的模型为ckpt格式,如若想应用到安卓就需要转为pb格式进行模型迁移。