_compute_attention in seq2seq.AttentionWrapperState

Understanding the _compute_attention() function from the AttentionWrapper source:
GitHub source code link


def _compute_attention(attention_mechanism, cell_output, attention_state,
                       attention_layer):
  """Computes the attention and alignments for a given attention_mechanism."""
  
  alignments, next_attention_state = attention_mechanism(
      cell_output, state=attention_state)
  # The attention mechanism is invoked as
  #   def __call__(self, query, state)
  # and returns (alignments, next_state); for the standard mechanisms,
  # next_state is simply the alignments.
  # query: tensor matching the dtype of self.values, shape [batch_size, query_depth]
  # state: tensor matching the dtype of self.values, shape [batch_size, alignments_size]
  #   (alignments_size is 'max_time')
  # Returns:
  # alignments: tensor matching the dtype of self.values, shape [batch_size, alignments_size]
  #   (alignments_size is 'max_time')
  
  # Reshape from [batch_size, memory_time] to [batch_size, 1, memory_time]
  expanded_alignments = array_ops.expand_dims(alignments, 1)
  # Context is the inner product of alignments and values along the
  # memory time dimension.
  # alignments shape is
  #   [batch_size, 1, memory_time]
  # attention_mechanism.values shape is
  #   [batch_size, memory_time, memory_size]
  # the batched matmul is over memory_time, so the output shape is
  #   [batch_size, 1, memory_size].
  # we then squeeze out the singleton dim.
  context = math_ops.matmul(expanded_alignments, attention_mechanism.values)
  context = array_ops.squeeze(context, [1])

  if attention_layer is not None:
    attention = attention_layer(array_ops.concat([cell_output, context], 1))
  else:
    attention = context

  return attention, alignments, next_attention_state

Usage (the call site inside AttentionWrapper.call):

      attention, alignments, next_attention_state = _compute_attention(
          attention_mechanism, cell_output, previous_attention_state[i],
          self._attention_layers[i] if self._attention_layers else None)
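The context computation above is just a batched matrix multiplication of the alignment weights with the memory. A minimal NumPy sketch (toy shapes chosen here, not taken from the post) showing how [batch_size, 1, memory_time] times [batch_size, memory_time, memory_size] collapses to a [batch_size, memory_size] context vector:

import numpy as np

batch_size, memory_time, memory_size = 2, 5, 8        # assumed toy shapes
alignments = np.random.rand(batch_size, memory_time)
alignments /= alignments.sum(axis=1, keepdims=True)   # softmax-like weights
values = np.random.rand(batch_size, memory_time, memory_size)

expanded_alignments = alignments[:, np.newaxis, :]    # [batch_size, 1, memory_time]
context = np.matmul(expanded_alignments, values)      # [batch_size, 1, memory_size]
context = np.squeeze(context, axis=1)                 # [batch_size, memory_size]

# The same thing written as an explicit weighted sum over memory_time:
weighted_sum = np.einsum('bt,bts->bs', alignments, values)
assert np.allclose(context, weighted_sum)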

AttentionWrapperState:
Fields:
cell_state: the state of the wrapped RNNCell at the previous time step
attention: the attention emitted at the previous time step
time: int32 scalar containing the current time step
alignments: a single Tensor or a tuple of Tensors, the alignments emitted at the previous time step for each attention mechanism
alignment_history: (if enabled) a single TensorArray or a tuple of TensorArrays holding the alignment matrices from all time steps for each attention mechanism
attention_state: a single nested object or a tuple of nested objects holding the state fed to each attention mechanism; the objects may contain Tensors or TensorArrays
A TensorArray can be thought of as a Tensor array with dynamic size; it is usually used together with while_loop or map_fn.
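
A hedged sketch of where these fields come from in practice (toy sizes and placeholder encoder tensors of my own, using the TF 1.x tf.contrib.seq2seq API that the source below belongs to): wrap a plain RNN cell in an AttentionWrapper, build its zero state, and seed the cell_state field from the encoder with clone():

import tensorflow as tf  # TF 1.x, where tf.contrib.seq2seq is available

batch_size, max_time, num_units = 32, 20, 128
encoder_outputs = tf.random_normal([batch_size, max_time, num_units])  # stand-in encoder memory
source_lengths = tf.fill([batch_size], max_time)
encoder_state = tf.nn.rnn_cell.LSTMStateTuple(
    c=tf.zeros([batch_size, num_units]), h=tf.zeros([batch_size, num_units]))

attention_mechanism = tf.contrib.seq2seq.LuongAttention(
    num_units=num_units, memory=encoder_outputs,
    memory_sequence_length=source_lengths)

cell = tf.nn.rnn_cell.LSTMCell(num_units)
attn_cell = tf.contrib.seq2seq.AttentionWrapper(
    cell, attention_mechanism,
    attention_layer_size=num_units,  # enables the attention_layer branch in _compute_attention
    alignment_history=True)          # makes alignment_history a TensorArray

# zero_state returns an AttentionWrapperState; clone() swaps in the encoder's final state.
initial_state = attn_cell.zero_state(batch_size, tf.float32)
initial_state = initial_state.clone(cell_state=encoder_state)

# initial_state.cell_state, .attention, .time, .alignments,
# .alignment_history and .attention_state are exactly the fields listed above.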

class AttentionWrapperState(
    collections.namedtuple("AttentionWrapperState",
                           ("cell_state", "attention", "time", "alignments",
                            "alignment_history", "attention_state"))):
  """`namedtuple` storing the state of a `AttentionWrapper`.
  Contains:
    - `cell_state`: The state of the wrapped `RNNCell` at the previous time
      step.
    - `attention`: The attention emitted at the previous time step.
    - `time`: int32 scalar containing the current time step.
    - `alignments`: A single or tuple of `Tensor`(s) containing the alignments
       emitted at the previous time step for each attention mechanism.
    - `alignment_history`: (if enabled) a single or tuple of `TensorArray`(s)
       containing alignment matrices from all time steps for each attention
       mechanism. Call `stack()` on each to convert to a `Tensor`.
    - `attention_state`: A single or tuple of nested objects
       containing attention mechanism state for each attention mechanism.
       The objects may contain Tensors or TensorArrays.
  """

  def clone(self, **kwargs):
    """Clone this object, overriding components provided by kwargs.
    The new state fields' shape must match original state fields' shape. This
    will be validated, and original fields' shape will be propagated to new
    fields.
    Example:
    ```python
    initial_state = attention_wrapper.zero_state(dtype=..., batch_size=...)
    initial_state = initial_state.clone(cell_state=encoder_state)
    ```
    Args:
      **kwargs: Any properties of the state object to replace in the returned
        `AttentionWrapperState`.
    Returns:
      A new `AttentionWrapperState` whose properties are the same as
      this one, except any overridden properties as provided in `kwargs`.
    """
    def with_same_shape(old, new):
      """Check and set new tensor's shape."""
      if isinstance(old, ops.Tensor) and isinstance(new, ops.Tensor):
        return tensor_util.with_same_shape(old, new)
      return new

    return nest.map_structure(
        with_same_shape,
        self,
        super(AttentionWrapperState, self)._replace(**kwargs))


def hardmax(logits, name=None):
  """Returns batched one-hot vectors.
  The depth index containing the `1` is that of the maximum logit value.
  Args:
    logits: A batch tensor of logit values.
    name: Name to use when creating ops.
  Returns:
    A batched one-hot tensor.
  """
  with ops.name_scope(name, "Hardmax", [logits]):
    logits = ops.convert_to_tensor(logits, name="logits")
    if logits.get_shape()[-1].value is not None:
      depth = logits.get_shape()[-1].value
    else:
      depth = array_ops.shape(logits)[-1]
    return array_ops.one_hot(
        math_ops.argmax(logits, -1), depth, dtype=logits.dtype)
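
A quick illustration of hardmax with made-up numbers (it is exported as tf.contrib.seq2seq.hardmax in TF 1.x): each row of the result is a one-hot vector with the 1 at that row's argmax.

import tensorflow as tf  # TF 1.x

logits = tf.constant([[0.1, 2.0, 0.3],
                      [5.0, -1.0, 0.0]])
one_hot = tf.contrib.seq2seq.hardmax(logits)

with tf.Session() as sess:
    print(sess.run(one_hot))
    # [[0. 1. 0.]
    #  [1. 0. 0.]]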


def _compute_attention(attention_mechanism, cell_output, attention_state,
                       attention_layer):
  """Computes the attention and alignments for a given attention_mechanism."""
  alignments, next_attention_state = attention_mechanism(
      cell_output, state=attention_state)

  # Reshape from [batch_size, memory_time] to [batch_size, 1, memory_time]
  expanded_alignments = array_ops.expand_dims(alignments, 1)
  # Context is the inner product of alignments and values along the
  # memory time dimension.
  # alignments shape is
  #   [batch_size, 1, memory_time]
  # attention_mechanism.values shape is
  #   [batch_size, memory_time, memory_size]
  # the batched matmul is over memory_time, so the output shape is
  #   [batch_size, 1, memory_size].
  # we then squeeze out the singleton dim.
  context = math_ops.matmul(expanded_alignments, attention_mechanism.values)
  context = array_ops.squeeze(context, [1])

  if attention_layer is not None:
    attention = attention_layer(array_ops.concat([cell_output, context], 1))
  else:
    attention = context

  return attention, alignments, next_attention_state
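
To see the shapes that _compute_attention produces, here is a hedged, self-contained sketch (toy sizes of my own, Luong attention, no attention_layer) that repeats the same steps by hand with the TF 1.x contrib API:

import tensorflow as tf  # TF 1.x

batch_size, max_time, num_units = 4, 7, 16
memory = tf.random_normal([batch_size, max_time, num_units])   # encoder outputs
cell_output = tf.random_normal([batch_size, num_units])        # the query

mechanism = tf.contrib.seq2seq.LuongAttention(num_units=num_units, memory=memory)

# For Luong/Bahdanau attention the mechanism state is just the alignments.
attention_state = mechanism.initial_alignments(batch_size, tf.float32)
alignments, next_attention_state = mechanism(cell_output, state=attention_state)

expanded_alignments = tf.expand_dims(alignments, 1)            # [batch_size, 1, max_time]
context = tf.matmul(expanded_alignments, mechanism.values)     # [batch_size, 1, num_units]
context = tf.squeeze(context, [1])                             # [batch_size, num_units]
attention = context                                            # attention_layer is None here

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    a, c = sess.run([alignments, context])
    print(a.shape, c.shape)   # (4, 7) (4, 16)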