TensorFlow 2.0: A Source-Code Analysis of the Layer Definition

I never fully understood how layers work in TensorFlow 2.0, so I went and read the source code. It turns out it isn't that hard; if you know Python reasonably well, you can read the source yourself and understand it quickly.

The TensorFlow 2.0 APIs adopt the Keras interface. This article focuses on how the Keras Layer class is implemented, working up from a custom layer to understand the mechanism.


A custom layer is defined as follows. super(MyDenseLayer, self).__init__() executes the parent class's __init__() method, which means the subclass instance picks up every attribute the parent defines there; with the parent's machinery attached to self, the subclass is free to build on it.

class MyDenseLayer(tf.keras.layers.Layer):
  def __init__(self, num_outputs):
    super(MyDenseLayer, self).__init__()
    self.num_outputs = num_outputs

  def build(self, input_shape):
    # `add_weight` (the non-deprecated name for `add_variable`) creates
    # the kernel once the input shape is known.
    self.kernel = self.add_weight("kernel",
                                  shape=[int(input_shape[-1]),
                                         self.num_outputs])

  def call(self, inputs):
    return tf.matmul(inputs, self.kernel)

layer = MyDenseLayer(10)

_ = layer(tf.zeros([10, 5]))  # Calling the layer builds it.


This layer defines three methods:

__init__: the initializer. It runs when a MyDenseLayer object is instantiated.

build: generally runs once, just before the first call. It is where the layer's weights are created with their shapes. You could create the weights in __init__ instead, but defining them in build means you can instantiate MyDenseLayer first and let the input shape reach the layer later, at call time (see the sketch after this list).

call: defines the computation the layer performs.
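
To make the deferred building concrete, here is a minimal sketch (reusing the MyDenseLayer defined above) showing that the kernel does not exist until the first call supplies an input shape:

import tensorflow as tf

layer = MyDenseLayer(10)
print(layer.built)           # False -- build() has not run, so no weights yet
print(layer.weights)         # []

_ = layer(tf.zeros([3, 5]))  # first call triggers build(TensorShape([3, 5]))
print(layer.built)           # True
print(layer.kernel.shape)    # (5, 10) -- the input dim 5 was inferred at call time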

You might wonder: the instance never seems to invoke the call method explicitly, so why does it run? The answer is Python's __call__ protocol.
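
Here is the pattern in miniature, stripped of everything TensorFlow-specific (an illustration of the mechanism, not actual TF code):

class Base:
  def __call__(self, inputs):
    # `self` is whatever instance was invoked, so `self.call` resolves
    # to the subclass's `call` via normal Python attribute lookup.
    print('Base.__call__ running, dispatching to', type(self).__name__)
    return self.call(inputs)

class Child(Base):
  def call(self, inputs):
    return inputs * 2

child = Child()
print(child(21))  # prints "Base.__call__ running, dispatching to Child", then 42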

Because the parent class defines __call__, executing layer(tf.zeros([10, 5])) invokes the parent's __call__ method. Below is the code of tf.keras.layers.Layer's __call__ (note that the self passed in is the subclass instance):

 def __call__(self, inputs, *args, **kwargs):
   
    input_list = nest.flatten(inputs)

    if context.executing_eagerly():
      # Accept NumPy inputs by converting to Tensors when executing eagerly.
      if all(isinstance(x, (np.ndarray, float, int)) for x in input_list):
        inputs = nest.map_structure(ops.convert_to_tensor, inputs)
        input_list = nest.flatten(inputs)

    # We will attempt to build a TF graph if & only if all inputs are symbolic.
    # This is always the case in graph mode. It can also be the case in eager
    # mode when all inputs can be traced back to `keras.Input()` (when building
    # models using the functional API).
    build_graph = tf_utils.are_all_symbolic_tensors(input_list)
    executing_eagerly = context.executing_eagerly()

    # Handle Keras mask propagation from previous layer to current layer.
    previous_mask = None
    if build_graph and (not hasattr(self, '_compute_previous_mask') or
                        self._compute_previous_mask):
      previous_mask = base_layer_utils.collect_previous_mask(inputs)
      if not hasattr(self, '_call_fn_args'):
        self._call_fn_args = function_utils.fn_args(self.call)
      if ('mask' in self._call_fn_args and 'mask' not in kwargs and
          not generic_utils.is_all_none(previous_mask)):
        # The previous layer generated a mask, and mask was not explicitly pass
        # to __call__, hence we set previous_mask as the default value.
        kwargs['mask'] = previous_mask

    input_shapes = None

    with ops.name_scope(self._name_scope()):
      if not self.built:
        # Build layer if applicable (if the `build` method has been overridden).
        self._maybe_build(inputs)
        # We must set self.built since user defined build functions are not
        # constrained to set self.built.
        self.built = True

      # Check input assumptions set after layer building, e.g. input shape.
      if build_graph:
        # Symbolic execution on symbolic tensors. We will attempt to build
        # the corresponding TF subgraph inside `backend.get_graph()`
        input_spec.assert_input_compatibility(
            self.input_spec, inputs, self.name)
        graph = backend.get_graph()
        with graph.as_default():
          if not executing_eagerly:
            # In graph mode, failure to build the layer's graph
            # implies a user-side bug. We don't catch exceptions.
            outputs = self.call(inputs, *args, **kwargs)
          else:
            try:
              outputs = self.call(inputs, *args, **kwargs)
            except Exception:  # pylint: disable=broad-except
              # Any issue during graph-building means we will later run the
              # model in eager mode, whether the issue was related to
              # graph mode or not. This provides a nice debugging experience.
              self._call_is_graph_friendly = False
              # We will use static shape inference to return symbolic tensors
              # matching the specifications of the layer outputs.
              # Since we have set `self._call_is_graph_friendly = False`,
              # we will never attempt to run the underlying TF graph (which is
              # disconnected).
              # TODO(fchollet): consider py_func as an alternative, which
              # would enable us to run the underlying graph if needed.
              input_shapes = nest.map_structure(lambda x: x.shape, inputs)
              output_shapes = self.compute_output_shape(input_shapes)
              outputs = nest.map_structure(
                  lambda shape: backend.placeholder(shape, dtype=self.dtype),
                  output_shapes)

          if outputs is None:
            raise ValueError('A layer\'s `call` method should return a '
                             'Tensor or a list of Tensors, not None '
                             '(layer: ' + self.name + ').')
          self._handle_activity_regularization(inputs, outputs)
          self._set_mask_metadata(inputs, outputs, previous_mask)
          if base_layer_utils.have_all_keras_metadata(inputs):
            inputs, outputs = self._set_connectivity_metadata_(
                inputs, outputs, args, kwargs)
          if hasattr(self, '_set_inputs') and not self.inputs:
            # Subclassed network: explicitly set metadata normally set by
            # a call to self._set_inputs().
            # This is not relevant in eager execution.
            self._set_inputs(inputs, outputs)
      else:
        # Eager execution on data tensors.
        outputs = self.call(inputs, *args, **kwargs)
        self._handle_activity_regularization(inputs, outputs)
        return outputs

    if not context.executing_eagerly():
      # Optionally load weight values specified at layer instantiation.
      # TODO(fchollet): consider enabling this with eager execution too.
      if (hasattr(self, '_initial_weights') and
          self._initial_weights is not None):
        self.set_weights(self._initial_weights)
        del self._initial_weights
    return outputs

This method first normalizes the input types (in eager mode, NumPy arrays are converted to tensors) and then handles propagating the Keras mask from the previous layer to the current one. A mask is a boolean tensor that flags which positions of a sequence input are real data and which are padding, so downstream layers can ignore the padded positions; it is not essential to the build/call mechanism we are tracing here.
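
As a quick aside, here is a small example of where a mask comes from: the Masking layer is the standard way to generate one, and compute_mask is the public API a layer uses to produce it (a minimal sketch):

import tensorflow as tf

# One sequence of 4 timesteps with 1 feature; the last two steps are padding.
x = tf.constant([[[3.0], [5.0], [0.0], [0.0]]])

masking = tf.keras.layers.Masking(mask_value=0.0)
print(masking.compute_mask(x))  # [[ True  True False False]] -- padding masked out

Back in __call__, the key part is this block: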

 with ops.name_scope(self._name_scope()):
      if not self.built:
        # Build layer if applicable (if the `build` method has been overridden).
        self._maybe_build(inputs)
        # We must set self.built since user defined build functions are not
        # constrained to set self.built.
        self.built = True

      # Check input assumptions set after layer building, e.g. input shape.

This block checks whether the layer has been built. If not, it runs self._maybe_build(inputs) and then sets self.built = True; that way, when the layer is used repeatedly, the build step does not run again. Here is _maybe_build:

  def _maybe_build(self, inputs):
    # Check input assumptions set before layer building, e.g. input rank.
    input_spec.assert_input_compatibility(
        self.input_spec, inputs, self.name)
    input_list = nest.flatten(inputs)
    if input_list and self._dtype is None:
      try:
        self._dtype = input_list[0].dtype.base_dtype.name
      except AttributeError:
        pass
    input_shapes = None
    if all(hasattr(x, 'shape') for x in input_list):
      input_shapes = nest.map_structure(lambda x: x.shape, inputs)
    # Only call `build` if the user has manually overridden the build method.
    if not hasattr(self.build, '_is_default'):
      self.build(input_shapes)

Notice the last line: it calls the subclass's build method. Why the subclass's? Because the object we instantiated is the subclass, and when the parent's __call__(self, inputs, *args, **kwargs) runs, the self passed in is the subclass instance, so self.build naturally resolves to the subclass's method.
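
A quick sketch of the build-once behavior this enables (again using MyDenseLayer from above):

import tensorflow as tf

layer = MyDenseLayer(10)
_ = layer(tf.zeros([3, 5]))  # first call: _maybe_build runs build()
_ = layer(tf.ones([8, 5]))   # second call: self.built is True, build() is skipped
print(len(layer.weights))    # 1 -- both calls share the same kernel

Now continue downward in __call__: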

      if build_graph:
        # Symbolic execution on symbolic tensors. We will attempt to build
        # the corresponding TF subgraph inside `backend.get_graph()`
        input_spec.assert_input_compatibility(
            self.input_spec, inputs, self.name)
        graph = backend.get_graph()
        with graph.as_default():
          if not executing_eagerly:
            # In graph mode, failure to build the layer's graph
            # implies a user-side bug. We don't catch exceptions.
            outputs = self.call(inputs, *args, **kwargs)
          else:
            try:
              outputs = self.call(inputs, *args, **kwargs)
            except Exception:  # pylint: disable=broad-except
              # Any issue during graph-building means we will later run the
              # model in eager mode, whether the issue was related to
              # graph mode or not. This provides a nice debugging experience.
              self._call_is_graph_friendly = False
              # We will use static shape inference to return symbolic tensors
              # matching the specifications of the layer outputs.
              # Since we have set `self._call_is_graph_friendly = False`,
              # we will never attempt to run the underlying TF graph (which is
              # disconnected).
              # TODO(fchollet): consider py_func as an alternative, which
              # would enable us to run the underlying graph if needed.
              input_shapes = nest.map_structure(lambda x: x.shape, inputs)
              output_shapes = self.compute_output_shape(input_shapes)
              outputs = nest.map_structure(
                  lambda shape: backend.placeholder(shape, dtype=self.dtype),
                  output_shapes)

          if outputs is None:
            raise ValueError('A layer\'s `call` method should return a '
                             'Tensor or a list of Tensors, not None '
                             '(layer: ' + self.name + ').')
          self._handle_activity_regularization(inputs, outputs)
          self._set_mask_metadata(inputs, outputs, previous_mask)
          if base_layer_utils.have_all_keras_metadata(inputs):
            inputs, outputs = self._set_connectivity_metadata_(
                inputs, outputs, args, kwargs)
          if hasattr(self, '_set_inputs') and not self.inputs:
            # Subclassed network: explicitly set metadata normally set by
            # a call to self._set_inputs().
            # This is not relevant in eager execution.
            self._set_inputs(inputs, outputs)
      else:
        # Eager execution on data tensors.
        outputs = self.call(inputs, *args, **kwargs)
        self._handle_activity_regularization(inputs, outputs)
        return outputs

This block builds the corresponding computation graph (when the inputs are symbolic), and its main job is to execute

outputs = self.call(inputs, *args, **kwargs)

Because self is the subclass instance, this again dispatches to the subclass's call method, which produces the outputs. The executing_eagerly flag decides whether the layer runs in dynamic eager mode or as part of a static graph; the details are worth exploring on their own.
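
To see both execution paths from user code, here is a minimal illustration (a sketch assuming TF 2.x defaults): calling the layer directly runs it eagerly, while wrapping the call in tf.function traces it into a static graph.

import tensorflow as tf

layer = MyDenseLayer(10)

# Eager path: call() executes immediately on a concrete tensor.
eager_out = layer(tf.zeros([2, 5]))

# Graph path: tf.function traces call() into a graph on the first invocation
# and reuses that graph on later calls with matching shapes and dtypes.
@tf.function
def forward(x):
  return layer(x)

graph_out = forward(tf.zeros([2, 5]))
print(eager_out.shape, graph_out.shape)  # (2, 10) (2, 10)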


To sum up: in TensorFlow 2.0, or really in Keras, every Layer executes by routing the instance through the parent Layer's __call__ method; __call__ then invokes self.call(), which dispatches to the call method defined in your Layer subclass.
