运行LaneNet,基于Tensorflow2.0兼容Tensorflow1.0运行报错。TypeError: Dimension value must be integer or None , go

本文讲述了如何在TensorFlow2环境中运行TensorFlow1.x的代码,包括导入兼容模块、使用`tf.compat.v1`和`disable_v2_behavior()`函数,以及在遇到`tf.contrib.layers`等API差异时的报错处理策略。重点介绍了`conv2d`函数的迁移和解决初始化错误的方法。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

想要兼容代码,比较可行的是参考如下博主指出的方法

用tensorflow2运行tensorflow1.x代码的方法_tf2调用tf1代码-优快云博客icon-default.png?t=N7T8https://blog.youkuaiyun.com/qq_28941587/article/details/128466526?spm=1001.2101.3001.6650.2&utm_medium=distribute.pc_relevant.none-task-blog-2~default~CTRLIST~Rate-2-128466526-blog-114885736.235%5Ev43%5Epc_blog_bottom_relevance_base5&depth_1-utm_source=distribute.pc_relevant.none-task-blog-2~default~CTRLIST~Rate-2-128466526-blog-114885736.235%5Ev43%5Epc_blog_bottom_relevance_base5&utm_relevant_index=5

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

比如如下的一长条代码中

class CNNBaseModel(object):
    """
    Base model for other specific cnn ctpn_models
    """

    def __init__(self):
        pass

    @staticmethod
    def conv2d(inputdata, out_channel, kernel_size, padding='SAME',
               stride=1, w_init=None, b_init=None,
               split=1, use_bias=True, data_format='NHWC', name=None):
        """
        Packing the tensorflow conv2d function.
        :param name: op name
        :param inputdata: A 4D tensorflow tensor which ust have known number of channels, but can have other
        unknown dimensions.
        :param out_channel: number of output channel.
        :param kernel_size: int so only support square kernel convolution
        :param padding: 'VALID' or 'SAME'
        :param stride: int so only support square stride
        :param w_init: initializer for convolution weights
        :param b_init: initializer for bias
        :param split: split channels as used in Alexnet mainly group for GPU memory save.
        :param use_bias:  whether to use bias.
        :param data_format: default set to NHWC according tensorflow
        :return: tf.Tensor named ``output``
        """
        with tf.variable_scope(name):
            in_shape = inputdata.get_shape().as_list()
            channel_axis = 3 if data_format == 'NHWC' else 1
            in_channel = in_shape[channel_axis]
            assert in_channel is not None, "[Conv2D] Input cannot have unknown channel!"
            assert in_channel % split == 0
            assert out_channel % split == 0

            padding = padding.upper()

            # if isinstance(kernel_size, list):
            #     filter_shape = [kernel_size[0], kernel_size[1]] + [in_channel / split, out_channel]
            # else:
            #     filter_shape = [kernel_size, kernel_size] + [in_channel / split, out_channel]
            if isinstance(kernel_size, list):
                filter_shape = [kernel_size[0], kernel_size[1]] + [in_channel // split, out_channel]
            else:
                filter_shape = [kernel_size, kernel_size] + [in_c
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

光芒再现dev

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值