为了进一步了解上一篇中介绍的 class，我在 GitHub 上搜到了如下示例：
import tensorflow as tf
class MyLayer(tf.keras.layers.Layer):
def __init__(self, output_dim, **kwargs):
# Constructor: remember the requested output size, then forward every
# remaining keyword argument unchanged to the base tf.keras.layers.Layer
# constructor.
# output_dim: size of the layer's output; build() uses it as the second
# dimension of the kernel weight, shape (input_shape[1], output_dim).
self.output_dim = output_dim
# NOTE(review): the attribute is assigned before the base constructor runs;
# newer Keras code usually calls super().__init__ first -- confirm this
# ordering is safe with the installed TensorFlow version.
super(MyLayer, self).__init__(**kwargs)
def build(self, input_shape):
# Create a trainable weight variable for this layer.
self.kernel = self.add_weight(name='kernel',
shape=(int(input_shape[1]), self.output_dim),
initializer='uniform',