[size=x-large][color=blue]1、切换backend[/color][/size]
修改~/.keras/keras.json 文件中的 theano 字段为tensorflow即可
官方文档:[url]https://keras.io/backend/[/url]
[size=x-large][color=blue]2、theano和tensorflow卷积核互相转换[/color][/size]
切换backend后,模型运算会出错,原因在于tensorflow中的卷积实际上时相关,二theano中的卷积是真正的卷积!!!!所以,切换backend时,需要对卷积核进行翻转
参见:[url]https://github.com/fchollet/keras/wiki/Converting-convolution-kernels-from-Theano-to-TensorFlow-and-vice-versa[/url]
[b]通用转换代码[/b]如下(theano和tensorflow互转):
[b]tensoflow专用转换代码[/b]如下:
修改~/.keras/keras.json 文件中的 theano 字段为tensorflow即可
官方文档:[url]https://keras.io/backend/[/url]
[size=x-large][color=blue]2、theano和tensorflow卷积核互相转换[/color][/size]
切换backend后,模型运算会出错,原因在于tensorflow中的卷积实际上时相关,二theano中的卷积是真正的卷积!!!!所以,切换backend时,需要对卷积核进行翻转
参见:[url]https://github.com/fchollet/keras/wiki/Converting-convolution-kernels-from-Theano-to-TensorFlow-and-vice-versa[/url]
[b]通用转换代码[/b]如下(theano和tensorflow互转):
from keras import backend as K
from keras.utils.np_utils import convert_kernel
model = model_from_json(open(os.path.join('.', 'model.json')).read())
model.load_weights(os.path.join('.', 'model_weights.h5'))
for layer in model.layers:
if layer.__class__.__name__ in ['Convolution1D', 'Convolution2D','Convolution3D', 'AtrousConvolution2D']:
original_w = K.get_value(layer.W)
converted_w = convert_kernel(original_w)
K.set_value(layer.W, converted_w)
print('running')
K.get_session().run(ops)
print('saving')
model.save_weights('model_weights_anotherBackend.h5')
[b]tensoflow专用转换代码[/b]如下:
from keras import backend as K
from keras.utils.np_utils import convert_kernel
import tensorflow as tf
model = model_from_json(open(os.path.join('.', 'model.json')).read())
model.load_weights(os.path.join('.', 'model_weights.h5'))
ops = []
for layer in model.layers:
if layer.__class__.__name__ in ['Convolution1D', 'Convolution2D', 'Convolution3D', 'AtrousConvolution2D']:
original_w = K.get_value(layer.W)
print(layer.W.name)
print('\t',end='')
print(layer.W.get_shape().to_list())
converted_w = convert_kernel(original_w)
ops.append(tf.assign(layer.W, converted_w).op)
print('running')
K.get_session().run(ops)
print('saving')
model.save_weights('model_weights_tensorflow.h5')