Keras Layer
Simple Introduction
Keras implements many layers, including core layers, convolutional layers, recurrent (RNN) layers, and other commonly used network structures.
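For orientation, here is a minimal usage sketch, assuming the Keras 0.x-era API that the source below belongs to (the Dense and Activation layers and the compile arguments are shown only for illustration):

from keras.models import Sequential
from keras.layers.core import Dense, Activation

# Stack core (fully connected) layers into a model.
model = Sequential()
model.add(Dense(64, input_shape=(100,)))   # 100-dim inputs, batch size omitted
model.add(Activation('relu'))
model.add(Dense(10))
model.add(Activation('softmax'))
model.compile(loss='categorical_crossentropy', optimizer='sgd')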
Core Layers
Source
class Layer(object):
    '''Abstract base layer class.

    All Keras layers accept certain keyword arguments:

        trainable: boolean. Set to "False" before model compilation
            to freeze layer weights (they won't be updated further
            during training).
        input_shape: a tuple of integers specifying the expected shape
            of the input samples. Does not include the batch size
            (e.g. `(100,)` for 100-dimensional inputs).
        batch_input_shape: a tuple of integers specifying the expected
            shape of a batch of input samples. Includes the batch size
            (e.g. `(32, 100)` for a batch of 32 100-dimensional inputs).
    '''
    def __init__(self, **kwargs):
        allowed_kwargs = {'input_shape',
                          'trainable',
                          'batch_input_shape',
                          'cache_enabled'}
        for kwarg in kwargs:
            assert kwarg in allowed_kwargs, 'Keyword argument not understood: ' + kwarg
        if 'input_shape' in kwargs:
            # input_shape omits the batch size, so prepend None
            # as an unspecified batch dimension.
            self.set_input_shape((None,) + tuple(kwargs['input_shape']))
        if 'batch_input_shape' in kwargs:
            # batch_input_shape already includes the batch size.
            self.set_input_shape(tuple(kwargs['batch_input_shape']))
        if 'trainable' in kwargs:
            self._trainable = kwargs['trainable']
        if not hasattr(self, 'params'):
            self.params = []
        self._cache_enabled = True
        if 'cache_enabled' in kwargs:
            self._cache_enabled = kwargs['cache_enabled']
    @property
    def cache_enabled(self):
        return self._cache_enabled

    @cache_enabled.setter
    def cache_enabled(self, value):
        self._cache_enabled = value
    def __call__(self, X, mask=None, train=False):
        # Temporarily rebind get_input/get_input_mask so that
        # get_output() is computed on the tensor X given here.
        tmp_input = self.get_input
        tmp_mask = None
        if hasattr(self, 'get_input_mask'):
            tmp_mask = self.get_input_mask
            self.get_input_mask = lambda _: mask
        self.get_input = lambda _: X
        Y = self.get_output(train=train)
        # Restore the original input getters.
        if hasattr(self, 'get_input_mask'):
            self.get_input_mask = tmp_mask
        self.get_input = tmp_input
        return Y
    def set_previous(self, layer, connection_map={}):
        '''Connect a layer to its parent in the computational graph.
        '''
        assert self.nb_input == layer.nb_output == 1, 'Cannot connect layers: input count and output count should be 1.'
        if hasattr(self, 'input_ndim'):
            assert self.input_ndim == len(layer.output_shape), ('Incompatible shapes: layer expected input with ndim=' +
                                                                str(self.input_ndim) +
                                                                ' but previous layer has output_shape ' +
                                                                str(layer.output_shape))
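The keyword arguments documented above and the __call__ hook can be exercised directly. The following is a minimal sketch, assuming the 0.x-era Dense layer and a Theano backend; the variable names (frozen, sized, X, Y) are illustrative, not part of the source:

import theano.tensor as T
from keras.layers.core import Dense

# input_shape omits the batch size; trainable=False freezes the weights.
frozen = Dense(64, input_shape=(100,), trainable=False)

# batch_input_shape includes the batch size explicitly.
sized = Dense(64, batch_input_shape=(32, 100))

# Layer.__call__ temporarily points get_input at X, computes
# get_output, then restores the layer's original wiring.
X = T.matrix('X')
Y = frozen(X)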