Understanding the _compute_attention() function used by AttentionWrapper/AttentionWrapperState:
GitHub source: tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
def _compute_attention(attention_mechanism, cell_output, attention_state,
                       attention_layer):
  """Computes the attention and alignments for a given attention_mechanism."""
  alignments, next_attention_state = attention_mechanism(
      cell_output, state=attention_state)
  # The attention mechanism is invoked as
  #   def __call__(self, query, state)
  # and returns (alignments, next_state); for the basic (non-monotonic)
  # mechanisms next_state = alignments.
  # query: a Tensor matching self.values in dtype, shaped [batch_size, query_depth].
  # state: a Tensor matching self.values in dtype, shaped
  #   [batch_size, alignments_size] (alignments_size is the memory's max_time).
  # Returns:
  #   alignments: a Tensor matching self.values in dtype, shaped
  #   [batch_size, alignments_size] (alignments_size is the memory's max_time).

  # Reshape from [batch_size, memory_time] to [batch_size, 1, memory_time]
  expanded_alignments = array_ops.expand_dims(alignments, 1)
  # Context is the inner product of alignments and values along the
  # memory time dimension.
  # alignments shape is
  #   [batch_size, 1, memory_time]
  # attention_mechanism.values shape is
  #   [batch_size, memory_time, memory_size]
  # the batched matmul is over memory_time, so the output shape is
  #   [batch_size, 1, memory_size].
  # we then squeeze out the singleton dim.
  context = math_ops.matmul(expanded_alignments, attention_mechanism.values)
  context = array_ops.squeeze(context, [1])

  if attention_layer is not None:
    attention = attention_layer(array_ops.concat([cell_output, context], 1))
  else:
    attention = context

  return attention, alignments, next_attention_state
Usage (from AttentionWrapper.call; note that all three return values are unpacked, matching the definition above):

attention, alignments, next_attention_state = _compute_attention(
    attention_mechanism, rnn_cell_state, previous_alignments[i],
    self._attention_layers[i] if self._attention_layers else None)
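To make the shape bookkeeping concrete, here is a minimal NumPy sketch of the same expand_dims → matmul → squeeze sequence; the batch/time/depth sizes are made up purely for illustration:

import numpy as np

batch_size, memory_time, memory_size = 2, 5, 8   # illustrative sizes only

# alignments come out of the attention mechanism's softmax: [batch_size, memory_time]
alignments = np.random.rand(batch_size, memory_time)
alignments /= alignments.sum(axis=1, keepdims=True)

# stand-in for attention_mechanism.values (the processed encoder memory):
# [batch_size, memory_time, memory_size]
values = np.random.rand(batch_size, memory_time, memory_size)

expanded_alignments = np.expand_dims(alignments, 1)   # [batch_size, 1, memory_time]
context = np.matmul(expanded_alignments, values)      # [batch_size, 1, memory_size]
context = np.squeeze(context, axis=1)                 # [batch_size, memory_size]

print(context.shape)  # (2, 8)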
AttentionWrapperState:
Fields:
cell_state: the state of the wrapped RNNCell at the previous time step.
attention: the attention emitted at the previous time step.
time: an int32 scalar holding the current time step.
alignments: a single Tensor or a tuple of Tensors holding the alignments emitted at the previous time step for each attention mechanism.
alignment_history: (if enabled) a single TensorArray or a tuple of TensorArrays holding the alignment matrices from all time steps for each attention mechanism.
attention_state: a single nested object or a tuple of nested objects holding the state fed to each attention mechanism; these objects may contain Tensors or TensorArrays.
A TensorArray can be thought of as a Tensor array with dynamic size; it is usually used together with while_loop or map_fn.
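As a rough sketch of where such an AttentionWrapperState comes from (TF 1.x contrib API; the encoder_outputs placeholder and all sizes here are stand-ins of my own):

import tensorflow as tf

batch_size, max_time, enc_units, dec_units = 4, 10, 32, 32   # made-up sizes

# Stand-in for real encoder outputs; this is the attention "memory".
encoder_outputs = tf.placeholder(tf.float32, [batch_size, max_time, enc_units])

attention_mechanism = tf.contrib.seq2seq.BahdanauAttention(
    num_units=dec_units, memory=encoder_outputs)

cell = tf.nn.rnn_cell.LSTMCell(dec_units)
attn_cell = tf.contrib.seq2seq.AttentionWrapper(
    cell, attention_mechanism,
    attention_layer_size=dec_units,
    alignment_history=True)

# zero_state returns an AttentionWrapperState with exactly the fields listed above.
initial_state = attn_cell.zero_state(batch_size, tf.float32)
print(initial_state.time)        # int32 scalar
print(initial_state.alignments)  # [batch_size, max_time]
print(initial_state.attention)   # [batch_size, attention_layer_size]

The class itself is just a namedtuple with a shape-preserving clone() helper, as the source below shows.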
class AttentionWrapperState(
    collections.namedtuple("AttentionWrapperState",
                           ("cell_state", "attention", "time", "alignments",
                            "alignment_history", "attention_state"))):
  """`namedtuple` storing the state of a `AttentionWrapper`.

  Contains:

    - `cell_state`: The state of the wrapped `RNNCell` at the previous time
      step.
    - `attention`: The attention emitted at the previous time step.
    - `time`: int32 scalar containing the current time step.
    - `alignments`: A single or tuple of `Tensor`(s) containing the alignments
      emitted at the previous time step for each attention mechanism.
    - `alignment_history`: (if enabled) a single or tuple of `TensorArray`(s)
      containing alignment matrices from all time steps for each attention
      mechanism. Call `stack()` on each to convert to a `Tensor`.
    - `attention_state`: A single or tuple of nested objects
      containing attention mechanism state for each attention mechanism.
      The objects may contain Tensors or TensorArrays.
  """

  def clone(self, **kwargs):
    """Clone this object, overriding components provided by kwargs.

    The new state fields' shape must match original state fields' shape. This
    will be validated, and original fields' shape will be propagated to new
    fields.

    Example:

    ```python
    initial_state = attention_wrapper.zero_state(dtype=..., batch_size=...)
    initial_state = initial_state.clone(cell_state=encoder_state)
    ```

    Args:
      **kwargs: Any properties of the state object to replace in the returned
        `AttentionWrapperState`.

    Returns:
      A new `AttentionWrapperState` whose properties are the same as
      this one, except any overridden properties as provided in `kwargs`.
    """
    def with_same_shape(old, new):
      """Check and set new tensor's shape."""
      if isinstance(old, ops.Tensor) and isinstance(new, ops.Tensor):
        return tensor_util.with_same_shape(old, new)
      return new

    return nest.map_structure(
        with_same_shape,
        self,
        super(AttentionWrapperState, self)._replace(**kwargs))
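In practice, clone() is usually used to seed cell_state with the encoder's final state while leaving the attention-related fields at their zero values. A hedged continuation of the earlier sketch, reusing the hypothetical attn_cell and batch_size names and assuming encoder_final_state has the same structure and shapes as the wrapped cell's state:

zero_state = attn_cell.zero_state(batch_size, tf.float32)

# Only cell_state is overridden; with_same_shape() above checks that each new
# tensor keeps the static shape of the field it replaces.
decoder_initial_state = zero_state.clone(cell_state=encoder_final_state)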
def hardmax(logits, name=None):
  """Returns batched one-hot vectors.

  The depth index containing the `1` is that of the maximum logit value.

  Args:
    logits: A batch tensor of logit values.
    name: Name to use when creating ops.

  Returns:
    A batched one-hot tensor.
  """
  with ops.name_scope(name, "Hardmax", [logits]):
    logits = ops.convert_to_tensor(logits, name="logits")
    if logits.get_shape()[-1].value is not None:
      depth = logits.get_shape()[-1].value
    else:
      depth = array_ops.shape(logits)[-1]
    return array_ops.one_hot(
        math_ops.argmax(logits, -1), depth, dtype=logits.dtype)
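A quick illustration of what hardmax produces (the logit values are arbitrary); the public entry point in TF 1.x is tf.contrib.seq2seq.hardmax:

import tensorflow as tf

logits = tf.constant([[0.1, 2.0, 0.3],
                      [1.5, 0.2, 0.9]])
one_hot = tf.contrib.seq2seq.hardmax(logits)

with tf.Session() as sess:
    print(sess.run(one_hot))
    # [[0. 1. 0.]
    #  [1. 0. 0.]]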
def _compute_attention(attention_mechanism, cell_output, attention_state,
                       attention_layer):
  """Computes the attention and alignments for a given attention_mechanism."""
  alignments, next_attention_state = attention_mechanism(
      cell_output, state=attention_state)

  # Reshape from [batch_size, memory_time] to [batch_size, 1, memory_time]
  expanded_alignments = array_ops.expand_dims(alignments, 1)
  # Context is the inner product of alignments and values along the
  # memory time dimension.
  # alignments shape is
  #   [batch_size, 1, memory_time]
  # attention_mechanism.values shape is
  #   [batch_size, memory_time, memory_size]
  # the batched matmul is over memory_time, so the output shape is
  #   [batch_size, 1, memory_size].
  # we then squeeze out the singleton dim.
  context = math_ops.matmul(expanded_alignments, attention_mechanism.values)
  context = array_ops.squeeze(context, [1])

  if attention_layer is not None:
    attention = attention_layer(array_ops.concat([cell_output, context], 1))
  else:
    attention = context

  return attention, alignments, next_attention_state
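When attention_layer is not None (AttentionWrapper builds one internally whenever attention_layer_size is given, roughly a bias-free Dense layer), the context vector is mixed with the cell output through that layer instead of being returned directly. A minimal sketch with made-up sizes, using the public tf.layers.Dense as a stand-in for the layer the wrapper creates:

import tensorflow as tf

batch_size, cell_size, memory_size, attention_layer_size = 4, 32, 16, 20  # illustrative

cell_output = tf.placeholder(tf.float32, [batch_size, cell_size])
context = tf.placeholder(tf.float32, [batch_size, memory_size])

# Stand-in for the internal attention layer; projects [cell_output; context]
# down to attention_layer_size.
attention_layer = tf.layers.Dense(attention_layer_size, use_bias=False)

attention = attention_layer(tf.concat([cell_output, context], 1))
print(attention.shape)  # (4, 20)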