Tensorflow模型
简单点说,一个tensorflow模型包含了神经网络的结构(graph)和通过训练得到的一系列神经网络的参数。
神经网络的结构(graph)即神经网络的节点(nodes)及其流图(flow),节点是一系列张量,每个张量中运行一个op,流图是这些节点之间的运算关系,神经网络的参数包括训练得到的神经网络各层的权重(weights)和偏置(biases)以及程序中用到的其他变量。
要了解如何保存(持久化,frezee)一个tensorflow模型,首先要了解一下两个概念:
Protocol buffers 协议缓冲区
说到Tensorflow中model的保存,就不得不提Protocol buffers ,Protocol buffers 是tensorflow中保存数据的协议,所有TensorFlow的文件格式都是基于Protocol Buffers的,概括说来,Protocol buffers 是一种语言中立的,平台中立的,可扩展的串行化结构化数据的方式,Protocol buffers实现了这样一种协议,你可以在文本文件中定义数据结构,使用Protocol buffers编译器编译之后,可以把文本文件中定义的数据结构生成C,Python,Java和其他语言的类。Protocol buffers文本文件结构跟xml,json等文件结构类似,其优点是文件更小,序列化读取,保存速度更快,操作更简单,官方文档中说,实现同等功能Protocol buffers文件比xml文件小3~10倍,序列化读取,保存速度快20~100倍。
通过在.proto文件中Protocol buffers message类型来指定希望将序列化信息结构化的方式。每个Protocol buffers message是包含一系列名称 - 值对的信息的小逻辑记录。下面是一个.proto文件的一个非常基本的例子,它定义了一个包含有关人员信息的message:
syntax = "proto2";
package tutorial;
message Person {
required string name = 1;
required int32 id = 2;
optional string email = 3;
enum PhoneType {
MOBILE = 0;
HOME = 1;
WORK = 2;
}
message PhoneNumber {
required string number = 1;
optional PhoneType type = 2 [default = HOME];
}
repeated PhoneNumber phones = 4;
}
message AddressBook {
repeated Person people = 1;
}
syntax = "proto2";
package tutorial;
message Person {
required string name = 1;
required int32 id = 2;
optional string email = 3;
enum PhoneType {
MOBILE = 0;
HOME = 1;
WORK = 2;
}
message PhoneNumber {
required string number = 1;
optional PhoneType type = 2 [default = HOME];
}
repeated PhoneNumber phones = 4;
}
message AddressBook {
repeated Person people = 1;
}
通过Protocol buffers编译器编译成C++,会生成类似下面的代码:
class Person
{
// name
inline bool has_name() const;
inline void clear_name();
inline const ::std::string& name() const;
inline void set_name(const ::std::string& value);
inline void set_name(const char* value);
inline ::std::string* mutable_name();
// id
inline bool has_id() const;
inline void clear_id();
inline int32_t id() const;
inline void set_id(int32_t value);
// email
inline bool has_email() const;
inline void clear_email();
inline const ::std::string& email() const;
inline void set_email(const ::std::string& value);
inline void set_email(const char* value);
inline ::std::string* mutable_email();
// phones
inline int phones_size() const;
inline void clear_phones();
inline const ::google::protobuf::RepeatedPtrField< ::tutorial::Person_PhoneNumber >& phones() const;
inline ::google::protobuf::RepeatedPtrField< ::tutorial::Person_PhoneNumber >* mutable_phones();
inline const ::tutorial::Person_PhoneNumber& phones(int index) const;
inline ::tutorial::Person_PhoneNumber* mutable_phones(int index);
inline ::tutorial::Person_PhoneNumber* add_phones();
}
Protocol buffers 文件有txt和二进制两种保存格式,文件后缀分别为.pbtxt和.pb,以txt格式保存的文件是可读的。
MetaGraph(元图)
MetaGraph是一个Protocol buffers,tensorflow通过MetaGraph来记录计算图中的节点信息以及运行计算图中节点所需要的元数据,通俗点说就是,MetaGraph包含了这个神经网络的结构和设计这个神经网络所用到的所有变量。
MetaGraph包含一个GraphDef和所有与graph中计算相关的元数据,用于图形的长期存储。 MetaGraph包含继续训练,执行评估或在以前训练过的graph上运行推断所需的信息。
MetaGraph中包含的信息用MetaGraphDef协议缓冲区来表示。简而言之,一个MetaGraph包含了神经网络的结构--graph,也包含了运行这个graph相关的参数(权重,偏置和其他变量),它包含以下字段:
--MetaInfoDef 用于元信息,如版本和其他用户信息。
--GraphDef 用于描述graph。
--SaverDef Saver定义相关信息。
--CollectionDef 进一步描述了模型其他组件的map,如Variables,tf.train.QueueRunner等。为了使Python对象能够从MetaGraphDef序列化,Python类必须实现to_proto()和from_proto()方法 ,并使用register_proto_function将其注册到系统。
因此,要保存一个tensorflow模型,需要保存一下两方面的信息:
①神经网络的结构---MetaGraph
MetaGraph以.meta格式的文件保存,它保存了神经网络的结构(节点及其流图)和神经网络中用到的所有参数,我们称之为神经网络变量(variables),注意:这里所说的变量是变量名,而没有变量值
②神经网络参数值---checkpoint文件
如上所述,这些参数包括训练得到的神经网络各层的权重(weights)和偏置(biases)以及程序中用到的其他变量的值。
tensorflow用checkpoint文件来保存神经网络的所有变量值,文件后缀名.cpkt。0.11之前的版本,通常只有一个.cpkt文件,新版tensorflow使用了两个文件:
model.cpkt.data
model.cpkt.index
其中model.cpkt.data文件中保存了神经
class Person
{
// name
inline bool has_name() const;
inline void clear_name();
inline const ::std::string& name() const;
inline void set_name(const ::std::string& value);
inline void set_name(const char* value);
inline ::std::string* mutable_name();
// id
inline bool has_id() const;
inline void clear_id();
inline int32_t id() const;
inline void set_id(int32_t value);
// email
inline bool has_email() const;
inline void clear_email();
inline const ::std::string& email() const;
inline void set_email(const ::std::string& value);
inline void set_email(const char* value);
inline ::std::string* mutable_email();
// phones
inline int phones_size() const;
inline void clear_phones();
inline const ::google::protobuf::RepeatedPtrField< ::tutorial::Person_PhoneNumber >& phones() const;
inline ::google::protobuf::RepeatedPtrField< ::tutorial::Person_PhoneNumber >* mutable_phones();
inline const ::tutorial::Person_PhoneNumber& phones(int index) const;
inline ::tutorial::Person_PhoneNumber* mutable_phones(int index);
inline ::tutorial::Person_PhoneNumber* add_phones();
}