import os
#import sys
import numpy as np
import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping
from tensorflow.keras.callbacks import TensorBoard
from tensorflow.keras.layers import Conv2D, Conv1D, ZeroPadding2D
from tensorflow.keras.layers import Input
from tensorflow.keras.layers import Lambda, multiply
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import MeanSquaredError
#sys.path.append('../data')
#sys.path.append('../models')
from data.data_generator import DIV2KDatasetMultiple as Database
import h5py
# HDF5 dataset keys used by CrossIntraModel.train(): inputs are the
# reconstructed luma block and its boundary samples; the training target
# is the original chroma block.
nodes = {'in': ['rec_luma', 'rec_boundary'], 'out': ['org_chroma']}
# Scheme 1 architecture
class get_att(tf.keras.layers.Layer):
    """Compute a soft attention map between spatial (luma) features and
    boundary features.

    Call inputs: (spatial [bs, N, N, h], boundary [bs, b, h]).
    Returns: attention weights [bs, N*N, b], softmax over the boundary axis.
    """

    def __init__(self):
        super(get_att, self).__init__()

    def call(self, inputs):
        spatial, boundary = inputs  # [bs, N, N, h], [bs, b, h]
        batch = K.shape(spatial)[0]
        flat = K.shape(spatial)[1] * K.shape(spatial)[2]
        depth = K.shape(spatial)[-1]
        # Collapse the two spatial axes into one so each position can be
        # scored against every boundary sample.
        spatial = K.reshape(spatial, shape=[batch, flat, depth])
        scores = tf.matmul(spatial, boundary, transpose_b=True)
        # Dividing by the 0.5 temperature sharpens the softmax distribution.
        return K.softmax(scores / 0.5, axis=-1)
class apply_att(tf.keras.layers.Layer):
    """Apply a precomputed attention map to boundary features.

    Call inputs: (att [bs, N*N, b], boundary [bs, b, D], template [bs, N, N, D]).
    Returns: attended boundary features reshaped to the template's shape.
    """

    def __init__(self):
        super(apply_att, self).__init__()

    def call(self, inputs):
        att, boundary, template = inputs
        # Weighted sum of boundary vectors per spatial position: [bs, N*N, D].
        attended = K.batch_dot(att, boundary)
        # Restore the 2-D spatial layout carried by the template tensor.
        return K.reshape(attended, shape=K.shape(template))
class CrossIntraModel:
    """Attention-based cross-component intra-prediction model ("Scheme 1").

    Predicts a 2-channel chroma block from a reconstructed luma block and
    its 1-D boundary samples, fusing the two branches with an attention
    layer (get_att / apply_att).
    """

    def __init__(self, cf):
        # cf: configuration object supplying layer widths (att_h, bb1, bb2,
        # lb1, lb2, tb), paths (data_path, output_path, model) and training
        # hyper-parameters (batch_size, lr, beta, epochs, es_patience).
        self._cf = cf
        self.name = "multi"
        self.model = self.get_model()

    def attentive_join(self, x, b):
        """Fuse the luma feature map `x` with boundary features `b`.

        Projects both into a shared attention space, attends each spatial
        position over the boundary samples, and gates the luma features
        with the attended boundary features via element-wise multiply.
        """
        att_b = Conv1D(self._cf.att_h, kernel_size=1, strides=1, padding='same',
                       activation='relu', name='att_b')(b)
        att_x = Conv2D(self._cf.att_h, kernel_size=1, strides=1, padding='same',
                       activation='relu', name='att_x')(x)
        # Match x's channel depth to b's so the two can be multiplied.
        x_out = Conv2D(b.shape[-1], kernel_size=1, strides=1, padding='same',
                       activation='relu', name='att_x1')(x)
        att = get_att()([att_x, att_b])
        b_out = apply_att()([att, b, x_out])
        return multiply([x_out, b_out])

    def get_model(self):
        """Build the two-branch (luma + boundary) prediction network."""
        l_input = Input((None, None, 1), name='l_input')
        b_input = Input((None, 1), name='b_input')
        # Boundary branch: two pointwise 1-D convolutions.
        b = Conv1D(self._cf.bb1, kernel_size=1, strides=1, padding='same',
                   activation='relu', name='b1')(b_input)
        b = Conv1D(self._cf.bb2, kernel_size=1, strides=1, padding='same',
                   activation='relu', name='b2')(b)
        # Luma branch: pad by 2 so two valid 3x3 convs preserve spatial size.
        x = ZeroPadding2D((2, 2))(l_input)
        x = Conv2D(self._cf.lb1, kernel_size=3, strides=1, padding='valid',
                   activation=None, name='x1')(x)
        x = Conv2D(self._cf.lb2, kernel_size=3, strides=1, padding='valid',
                   activation='relu', name='x2')(x)
        # Trunk: attention fusion, then predict the 2 chroma channels.
        t = self.attentive_join(x, b)
        t = Conv2D(self._cf.tb, kernel_size=3, strides=1, padding='same',
                   activation=None, name='t2')(t)
        output = Conv2D(2, kernel_size=1, strides=1, padding='same',
                        activation='linear', name='out')(t)
        return Model([l_input, b_input], output)

    @staticmethod
    def norm_mse(y_true, y_pred):
        """MSE computed on pixel-scale (x255) values.

        Bug fix: MeanSquaredError is a loss *class* — the original passed
        the tensors to its constructor instead of instantiating it and
        calling it on (y_true, y_pred).
        """
        return MeanSquaredError()(y_true * 255, y_pred * 255)

    def create_dataset(self, features, labels):
        """Wrap (features, labels) arrays in a shuffled, batched tf.data pipeline."""
        dataset = tf.data.Dataset.from_tensor_slices((features, labels))
        dataset = (dataset.shuffle(buffer_size=len(features))
                          .batch(self._cf.batch_size)
                          .prefetch(tf.data.AUTOTUNE))
        return dataset

    def train(self):
        """Train the model on the 4x4-block HDF5 training set."""
        print("Training model: %s" % self.name)
        output_path = os.path.join(self._cf.output_path, self._cf.model)
        # makedirs(exist_ok=True) replaces the two racy exists()/mkdir pairs.
        os.makedirs(output_path, exist_ok=True)
        train_data_path = os.path.join(self._cf.data_path, 'train', "4x4.h5")
        # Open the file once (the original opened it twice and never closed
        # it). Datasets must be materialized before the file is closed.
        with h5py.File(train_data_path, 'r') as h5:
            train_data_in = [np.array(h5[k]) for k in nodes['in']]
            train_data_out = [np.array(h5[k]) for k in nodes['out']]
        train_ds = self.create_dataset(np.array(train_data_in),
                                       np.array(train_data_out))
        # NOTE(review): these callbacks monitor 'val_loss' but fit() is given
        # no validation data, so best-checkpointing and early stopping cannot
        # trigger — supply validation_data or monitor 'loss'. TODO confirm.
        checkpoint = ModelCheckpoint(os.path.join(output_path, ".weights.h5"),
                                     monitor='val_loss', verbose=0, mode='min',
                                     save_best_only=True, save_weights_only=True)
        early_stop = EarlyStopping(monitor='val_loss', mode="min",
                                   patience=self._cf.es_patience)
        tensorboard = TensorBoard(log_dir=output_path)
        callbacks_list = [checkpoint, early_stop, tensorboard]
        optimizer = Adam(self._cf.lr, self._cf.beta)
        self.model.compile(optimizer=optimizer, loss=self.norm_mse,
                           metrics=['accuracy'])
        self.model.summary()
        # Bug fix: callbacks_list was built but never passed to fit().
        self.model.fit(train_ds,
                       epochs=self._cf.epochs,
                       callbacks=callbacks_list)
# (removed stray scraped text "最新发布" — it was a bare name that raised NameError on import)