作者:chen_h
微信号 & QQ:862251340
微信公众号:coderpai
在 TensorFlow 中定义你的模型,可能会导致一个巨大的代码量。那么,如何去组织代码,使得它是一个高可读性和高可重用的呢?如果你刚刚开始学习代码架构,那么这里有一个例子,不妨学习一下。
定义计算图
当你设计一个模型的时候,从类出发是一个非常好的开始。那么如何来设计一个类的接口呢?通常,我们会为模型设计一些输入接口和输出目标值,并且提供训练接口,验证接口,测试接口等等。
class Model:
def __init__(self, data, target):
data_size = int(data.get_shape()[1])
target_size = int(target.get_shape()[1])
weight = tf.Variable(tf.truncated_normal([data_size, target_size]))
bias = tf.Variable(tf.constant(0.1, shape=[target_size]))
incoming = tf.matmul(data, weight) + bias
self._prediction = tf.nn.softmax(incoming)
cross_entropy = -tf.reduce_sum(target, tf.log(self._prediction))
self._optimize = tf.train.RMSPropOptimizer(0.03).minimize(cross_entropy)
mistakes = tf.not_equal(
tf.argmax(target, 1), tf.argmax(self._prediction, 1))
self._error = tf.reduce_mean(tf.cast(mistakes, tf.float32))
@property
def prediction(self):
return self._prediction
@property
def optimize(self):
return self._optimize
@property
def error(self):
return self._error
基本上,我们都会使用 TensorFlow 提供的代码块来构建我们的模型。但是,它也存在一些问题。最值得注意的一点是,整个图都是在一个函数中定义和构造的,那么这即不可读也不可重复使用。
使用 @property 装饰器
如果你不了解装饰器,那么可以先学习这篇文章。
如果我们只是把代码分割成函数,这肯定是行不通的,因为每次调用函数时,图都会被新代码进行扩展。因此,我们要确保只有当函数第一次被调用时,这个操作才会被添加到图中。这个方式就是懒加载(lazy-loading,使用时才创建)。
class Model:
def __init__(self, data, target):
self.data = data
self.target = target
self._prediction = None
self._optimize = None
self._error = None
@property
def prediction(self):
if not self._prediction:
data_size = int(self.data.get_shape()[1])
target_size = int(self.target.get_shape()[1])
weight = tf.Variable(tf.truncated_normal([data_size, target_size]))
bias = tf.Variable(tf.constant(0.1, shape=[target_size]))
incoming = tf.matmul(self.data, weight) + bias
self._prediction = tf.nn.softmax(incoming)
return self._prediction
@property
def optimize(self):
if not self._optimize:
cross_entropy = -tf.reduce_sum(self.target, tf.log(self.prediction))
optimizer = tf.train.RMSPropOptimizer(0.03)
self._optimize = optimizer.minimize(cross_entropy)
return self._optimize
@property
def error(self):
if not self._error:
mistakes = tf.not_equal(
tf.argmax(self.target, 1), tf.argmax(self.prediction, 1))
self._error = tf.reduce_mean(tf.cast(mistakes, tf.float32))
return self._error
这个代码组织已经比第一个代码好很多了。你的代码现在被组织成一个单独的功能。但是,由于懒加载的逻辑,这个代码看起来还是有点臃肿。让我们来看看如何可以改进这个代码。
Python 是一种相当灵活的语言。所以,让我告诉你如何去掉刚刚例子中的冗余代码。我们将一个像 @property 一样的装饰器,但是只评估一次函数。它将结果存储在一个以装饰函数命名的成员中,并且在随后的调用中返回该值。如果你不是很了解这个装饰器是什么东西,你可以看看这个学习指南。
import functools
def lazy_property(function):
attribute = '_cache_' + function.__name__
@property
@functools.wraps(function)
def decorator(self):
if not hasattr(self, attribute):
setattr(self, attribute, function(self))
return getattr(self, attribute)
return decorator
使用这个装饰器,我们可以将上面的代码简化如下:
class Model:
def __init__(self, data, target):
self.data = data
self.target = target
self.prediction
self.optimize
self.error
@lazy_property
def prediction(self):
data_size = int(self.data.get_shape()[1])
target_size = int(self.target.get_shape()[1])
weight = tf.Variable(tf.truncated_normal([data_size, target_size]))
bias = tf.Variable(tf.constant(0.1, shape=[target_size]))
incoming = tf.matmul(self.data, weight) + bias
return tf.nn.softmax(incoming)
@lazy_property
def optimize(self):
cross_entropy = -tf.reduce_sum(self.target, tf.log(self.prediction))
optimizer = tf.train.RMSPropOptimizer(0.03)
return optimizer.minimize(cross_entropy)
@lazy_property
def error(self):
mistakes = tf.not_equal(
tf.argmax(self.target, 1), tf.argmax(self.prediction, 1))
return tf.reduce_mean(tf.cast(mistakes, tf.float32))
请注意,我们在装饰器中只是提到了这些属性。完整的图还是要我们运行 tf.initialize_variables()
来定义的。
使用 scopes 来组织图
我们现在有一个简单的方法可以来定义我们的模型,但是由此产生的计算图仍然是非常拥挤的。如果你想要将图进行可视化,那么它将会包含很多互相链接的小节点。解决方案是我们对每一个函数(每一个节点)都起一个名字,利用 tf.name_scope(‘name’) 或者 tf.variable_scope(‘name’) 就可以实现。然后节点会在图中自由组合在一起,我们只需要调整我们的装饰器就可以了。
import functools
def define_scope(function):
attribute = '_cache_' + function.__name__
@property
@functools.wraps(function)
def decorator(self):
if not hasattr(self, attribute):
with tf.variable_scope(function.__name):
setattr(self, attribute, function(self))
return getattr(self, attribute)
return decorator
我给了装饰器一个新的名字,因为除了我们加的懒惰缓存,它还具有特定的 TensorFlow 功能。除此之外,模型看起来和原来的是一样的。
完整的代码,你可以查看Github。
来源:danijar