tensorflow自定义网络模型

本文介绍了TensorFlow的Slim模块,包括变量、层的构建,以及arg_scope的使用。此外,还讲述了如何查看ckpt文件中的变量,如使用tf.train.NewCheckpointReader和freeze_graph。同时,讲解了TensorBoard的使用,包括控制依赖、命名空间和标量总结。最后,提到了文本检测模型EAST的搭建过程。

Slim

TF-Slim 模块是 TensorFlow 中最好用的 API 之一。尤其是里面引入的 arg_scope、model_variables、repeat、stack。
TF-Slim 是 TensorFlow 中一个用来构建、训练、评估复杂模型的轻量化库。TF-Slim 模块可以和 TensorFlow 中其它API混合使用。

Slim模块的导入

1
import tensorflow.contrib.slim as slim

Slim 构建模型

可以用 slim、variables、layers 和 scopes 来十分简洁地定义模型。下面对各个部分进行了详细描述:

Slim变量(Variables)
1
2
3
4
5
6
weights = slim.variable('weights',
                        shape=[10, 10, 3 , 3],
                        initializer=tf.truncated_normal_initializer(stddev=0.1),
                        regularizer=slim.l2_regularizer(0.05),
                        device='/CPU:0')
~
Slim 层(Layers)

使用基础(plain)的 TensorFlow 代码:

1
2
3
4
5
6
7
8
9
input = ...
with tf.name_scope('conv1_1') as scope:
  kernel = tf.Variable(tf.truncated_normal([3, 3, 64, 128], dtype=tf.float32,
                                           stddev=1e-1), name='weights')
  conv = tf.nn.conv2d(input, kernel, [1, 1, 1, 1], padding='SAME')
  biases = tf.Variable(tf.constant(0.0, shape=[128], dtype=tf.float32),
                       trainable=True, name='biases')
  bias = tf.nn.bias_add(conv, biases)
  conv1 = tf.nn.relu(bias, name=scope)

为了避免代码的重复。Slim 提供了很多方便的神经网络 layers 的高层 op。例如:与上面的代码对应的 Slim 版的代码:

1
2
input = ...
net = slim.conv2d(input, 128, [3, 3], scope='conv1_1')

slim.arg_scope() 函数的使用

这个函数的作用是给list_ops中的内容设置默认值。但是每个list_ops中的每个成员需要用@add_arg_scope修饰才行。所以使用slim.arg_scope()有两个步骤:

  • 使用@slim.add_arg_scope修饰目标函数
  • 用 slim.arg_scope()为目标函数设置默认参数.
    例如如下代码;首先用@slim.add_arg_scope修饰目标函数fun1(),然后利用slim.arg_scope()为它设置默认参数。
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    
    import tensorflow as tf
    slim =tf.contrib.slim
     
    @slim.add_arg_scope
    def fun1(a=0,b=0):
        return (a+b)
     
    with slim.arg_scope([fun1],a=10):
        x=fun1(b=30)
        print(x)
    

运行结果:40
参考链接:
https://blog.youkuaiyun.com/u013921430/article/details/80915696

其他用法见参考链接

https://blog.youkuaiyun.com/wanttifa/article/details/90208398

查看ckpt中变量的几种方法

查看ckpt中变量的方法有三种:

  • 在有model的情况下,使用tf.train.Saver进行restore
  • 使用tf.train.NewCheckpointReader直接读取ckpt文件,这种方法不需要model。
  • 使用tools里的freeze_graph来读取ckpt
    Tips:
  • 如果模型保存为.ckpt的文件,则使用该文件就可以查看.ckpt文件里的变量。ckpt路径为 model.ckpt
  • 如果模型保存为.ckpt-xxx-data (图结构)、.ckpt-xxx.index (参数名)、.ckpt-xxx-meta (参数值)文件,则需要同时拥有这三个文件才行。并且ckpt的路径为 model.ckpt-xxx

    1.基于model来读取ckpt文件里的变量

    1.首先建立起model
    2.从ckpt中恢复变量
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    
    with tf.Graph().as_default() as g: 
      #建立model
      images, labels = cifar10.inputs(eval_data=eval_data) 
      logits = cifar10.inference(images) 
      top_k_op = tf.nn.in_top_k(logits, labels, 1) 
      #从ckpt中恢复变量
      sess = tf.Session()
      saver = tf.train.Saver() #saver = tf.train.Saver(...variables...) # 恢复部分变量时,只需要在Saver里指定要恢复的变量
      save_path = 'ckpt的路径'
      saver.restore(sess, save_path) # 从ckpt中恢复变量
    

注意:基于model来读取ckpt中变量时,model和ckpt必须匹配。

2.使用tf.train.NewCheckpointReader直接读取ckpt文件里的变量,使用tools.inspect_checkpoint里的print_tensors_in_checkpoint_file函数打印ckpt里的东西

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
#使用NewCheckpointReader来读取ckpt里的变量
from tensorflow.python import pywrap_tensorflow
checkpoint_path = os.path.join(model_dir, "model.ckpt")
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path) #tf.train.NewCheckpointReader
var_to_shape_map = reader.get_variable_to_shape_map()
for key in var_to_shape_map:
  print("tensor_name: ", key)
  #print(reader.get_tensor(key))
#使用print_tensors_in_checkpoint_file打印ckpt里的内容
from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file

print_tensors_in_checkpoint_file(file_name, #ckpt文件名字
                 tensor_name, # 如果为None,则默认为ckpt里的所有变量
                 all_tensors, # bool 是否打印所有的tensor,这里打印出的是tensor的
TensorFlow 是一个强大的机器学习框架,可以用于训练和部署各种类型的神经网络模型。在 TensorFlow 中,用户可以自定义模型,以便适应其特定的需求。下面是一些创建自定义模型的步骤: 1. 确定模型架构:用户需要决定模型的架构,包括层数,每层的神经元数,激活函数等。 2. 定义模型类:用户需要创建一个模型类,并继承 TensorFlow 的 Model 类。在模型类中,用户需要定义模型的结构和前向传递过程。 3. 定义前向传递:在模型类中定义前向传递过程,包括输入和输出张量的形状和类型。用户需要实现 call 方法,该方法将输入张量传递给模型,并返回输出张量。 4. 定义损失函数:用户需要定义一个损失函数,该函数将模型的输出张量和真实标签张量作为输入,并返回一个标量损失值。常用的损失函数包括交叉熵损失函数、均方误差损失函数等。 5. 定义优化器:用户需要定义一个优化器,该优化器将最小化损失函数。常用的优化器包括随机梯度下降优化器、Adam 优化器等。 6. 编译模型:最后,用户需要编译模型,将损失函数和优化器传递给模型。在编译模型时,用户可以指定评估指标,例如准确率或 F1 分数。 7. 训练模型:用户可以使用训练数据训练模型。在每个 epoch 结束时,用户可以使用验证数据评估模型的性能。如果模型表现不佳,则可以调整模型架构或超参数,并重新训练模型。 8. 测试模型:在训练结束后,用户可以使用测试数据测试模型的性能。如果模型表现良好,则可以将模型部署到生产环境中,用于实时预测。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值