要让基于 TensorFlow 1.x 编写的代码在 TensorFlow 2.x 环境中兼容运行，比较可行的做法是参考下面这位博主给出的方法：
# Use the TensorFlow 1.x-style API that ships inside TensorFlow 2.x, so TF1
# calls used below (e.g. tf.variable_scope in CNNBaseModel.conv2d) resolve.
import tensorflow.compat.v1 as tf
# Turn off TensorFlow 2.x behaviors so the TF1-style code below can run.
# NOTE(review): per the TF docs this should be called at program start,
# before any tensors/graphs are created — confirm it stays at the top.
tf.disable_v2_behavior()
例如，下面这段较长的代码中用到了 tf.variable_scope 等 TensorFlow 1.x 接口：
class CNNBaseModel(object):
"""
Base model for other specific cnn ctpn_models
"""
def __init__(self):
    """Construct the base model.

    No instance state is initialised here; the class only groups static
    layer-building helpers (such as ``conv2d``) for subclasses to use.
    """
@staticmethod
def conv2d(inputdata, out_channel, kernel_size, padding='SAME',
stride=1, w_init=None, b_init=None,
split=1, use_bias=True, data_format='NHWC', name=None):
"""
Packing the tensorflow conv2d function.
:param name: op name
:param inputdata: A 4D tensorflow tensor which ust have known number of channels, but can have other
unknown dimensions.
:param out_channel: number of output channel.
:param kernel_size: int so only support square kernel convolution
:param padding: 'VALID' or 'SAME'
:param stride: int so only support square stride
:param w_init: initializer for convolution weights
:param b_init: initializer for bias
:param split: split channels as used in Alexnet mainly group for GPU memory save.
:param use_bias: whether to use bias.
:param data_format: default set to NHWC according tensorflow
:return: tf.Tensor named ``output``
"""
with tf.variable_scope(name):
in_shape = inputdata.get_shape().as_list()
channel_axis = 3 if data_format == 'NHWC' else 1
in_channel = in_shape[channel_axis]
assert in_channel is not None, "[Conv2D] Input cannot have unknown channel!"
assert in_channel % split == 0
assert out_channel % split == 0
padding = padding.upper()
# if isinstance(kernel_size, list):
# filter_shape = [kernel_size[0], kernel_size[1]] + [in_channel / split, out_channel]
# else:
# filter_shape = [kernel_size, kernel_size] + [in_channel / split, out_channel]
if isinstance(kernel_size, list):
filter_shape = [kernel_size[0], kernel_size[1]] + [in_channel // split, out_channel]
else:
filter_shape = [kernel_size, kernel_size] + [in_c