Tensorflow模型持久化与恢复

 Tensorflow模型
简单点说,一个tensorflow模型包含了神经网络的结构(graph)和通过训练得到的一系列神经网络的参数。
神经网络的结构(graph)即神经网络的节点(nodes)及其流图(flow),节点是一系列张量,每个张量中运行一个op,流图是这些节点之间的运算关系,神经网络的参数包括训练得到的神经网络各层的权重(weights)和偏置(biases)以及程序中用到的其他变量。
要了解如何保存(持久化,frezee)一个tensorflow模型,首先要了解一下两个概念:

Protocol buffers 协议缓冲区

说到Tensorflow中model的保存,就不得不提Protocol buffers ,Protocol buffers 是tensorflow中保存数据的协议,所有TensorFlow的文件格式都是基于Protocol Buffers的,概括说来,Protocol buffers 是一种语言中立的,平台中立的,可扩展的串行化结构化数据的方式,Protocol buffers实现了这样一种协议,你可以在文本文件中定义数据结构,使用Protocol buffers编译器编译之后,可以把文本文件中定义的数据结构生成C,Python,Java和其他语言的类。Protocol buffers文本文件结构跟xml,json等文件结构类似,其优点是文件更小,序列化读取,保存速度更快,操作更简单,官方文档中说,实现同等功能Protocol buffers文件比xml文件小3~10倍,序列化读取,保存速度快20~100倍。
通过在.proto文件中Protocol buffers message类型来指定希望将序列化信息结构化的方式。每个Protocol buffers message是包含一系列名称 - 值对的信息的小逻辑记录。下面是一个.proto文件的一个非常基本的例子,它定义了一个包含有关人员信息的message:

syntax = "proto2";  
  
package tutorial;  
  
message Person {  
  required string name = 1;  
  required int32 id = 2;  
  optional string email = 3;  

  enum PhoneType {  
    MOBILE = 0;  
    HOME = 1;  
    WORK = 2;  
  }  
  
  message PhoneNumber {  
    required string number = 1;  
    optional PhoneType type = 2 [default = HOME];  
  }  
  
  repeated PhoneNumber phones = 4;  
}  
  
message AddressBook {  
  repeated Person people = 1;  
}

通过Protocol buffers编译器编译成C++,会生成类似下面的代码:

class Person  
{  
// name  
  inline bool has_name() const;  
  inline void clear_name();  
  inline const ::std::string& name() const;  
  inline void set_name(const ::std::string& value);  
  inline void set_name(const char* value);  
  inline ::std::string* mutable_name();  
  
  // id  
  inline bool has_id() const;  
  inline void clear_id();  
  inline int32_t id() const;  
  inline void set_id(int32_t value);  
  
  // email  
  inline bool has_email() const;  
  inline void clear_email();  
  inline const ::std::string& email() const;  
  inline void set_email(const ::std::string& value);  
  inline void set_email(const char* value);  
  inline ::std::string* mutable_email();  
  
  // phones  
  inline int phones_size() const;  
  inline void clear_phones();  
  inline const ::google::protobuf::RepeatedPtrField< ::tutorial::Person_PhoneNumber >& phones() const;  
  inline ::google::protobuf::RepeatedPtrField< ::tutorial::Person_PhoneNumber >* mutable_phones();  
  inline const ::tutorial::Person_PhoneNumber& phones(int index) const;  
  inline ::tutorial::Person_PhoneNumber* mutable_phones(int index);  
  inline ::tutorial::Person_PhoneNumber* add_phones();  
}

Protocol buffers 文件有txt和二进制两种保存格式,文件后缀分别为.pbtxt和.pb,以txt格式保存的文件是可读的。

MetaGraph(元图)

MetaGraph是一个Protocol buffers,tensorflow通过MetaGraph来记录计算图中的节点信息以及运行计算图中节点所需要的元数据,通俗点说就是,MetaGraph包含了这个神经网络的结构和设计这个神经网络所用到的所有变量。
MetaGraph包含一个GraphDef和所有与graph中计算相关的元数据,用于图形的长期存储。 MetaGraph包含继续训练,执行评估或在以前训练过的graph上运行推断所需的信息。
MetaGraph中包含的信息用MetaGraphDef协议缓冲区来表示。简而言之,一个MetaGraph包含了神经网络的结构--graph,也包含了运行这个graph相关的参数(权重,偏置和其他变量),它包含以下字段:
--MetaInfoDef 用于元信息,如版本和其他用户信息。
--GraphDef 用于描述graph。
--SaverDef Saver定义相关信息。
--CollectionDef 进一步描述了模型其他组件的map,如Variables,tf.train.QueueRunner等。为了使Python对象能够从MetaGraphDef序列化,Python类必须实现to_proto()和from_proto()方法 ,并使用register_proto_function将其注册到系统。


因此,要保存一个tensorflow模型,需要保存一下两方面的信息:

①神经网络的结构---MetaGraph
MetaGraph以.meta格式的文件保存,它保存了神经网络的结构(节点及其流图)和神经网络中用到的所有参数,我们称之为神经网络变量(variables),注意:这里所说的变量是变量名,而没有变量值
②神经网络参数值---checkpoint文件
如上所述,这些参数包括训练得到的神经网络各层的权重(weights)和偏置(biases)以及程序中用到的其他变量的值。
tensorflow用checkpoint文件来保存神经网络的所有变量值,文件后缀名.cpkt。0.11之前的版本,通常只有一个.cpkt文件,新版tensorflow使用了两个文件:
model.cpkt.data
model.cpkt.index
其中model.cpkt.data文件中保存了神经

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值