参考资料:
整个源码解读主要围绕Transformers库中的modeling_t5展开,并默认大家已经熟悉Transformer的基本原理,了解以GPT为代表的decoder-only架构和以BERT为代表的encoder-only架构。这里介绍另一个流派的代表:基于encoder-decoder架构的T5模型。
1.代码文件框架:
这里主要围绕modeling_t5展开(基本原理强力推荐上面的博客,下面主要是对官方代码的解读):
class T5LayerNorm(nn.Module):
    """T5-style layer normalization: a learned per-feature scale, no bias and no mean subtraction."""

    def __init__(self, hidden_size, eps=1e-6):
        """
        Construct a layernorm module in the T5 style. No bias and no subtraction of mean.

        Args:
            hidden_size: Size of the last dimension of the inputs; one scale weight per feature.
            eps: Constant added to the mean of squares before the reciprocal square root.
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        """Divide `hidden_states` by the root mean square of its last dimension, then scale by `weight`.

        The output has the same shape as the input.
        """
        # The mean of squares is always computed in float32, whatever the input dtype.
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)

        # convert into half-precision if necessary
        if self.weight.dtype in [torch.float16, torch.bfloat16]:
            hidden_states = hidden_states.to(self.weight.dtype)

        return self.weight * hidden_states
层归一化(T5风格:不减均值、没有偏置,只按均方根做缩放,相当于RMSNorm):输入hidden_states的形状为(B,L,C),输出形状同样为(B,L,C)
class T5DenseActDense(nn.Module):
    """Two-layer feed-forward projection: d_model -> d_ff -> d_model, without biases."""

    def __init__(self, config: T5Config):
        """Build both projections, the dropout and the activation selected by `config.dense_act_fn`."""
        super().__init__()
        self.wi = nn.Linear(config.d_model, config.d_ff, bias=False)
        self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
        self.dropout = nn.Dropout(config.dropout_rate)
        self.act = ACT2FN[config.dense_act_fn]

    def forward(self, hidden_states):
        """Apply wi -> activation -> dropout -> wo; the last dimension goes back to d_model."""
        hidden_states = self.wi(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.dropout(hidden_states)
        # Cast the activations to the dtype of `wo` when they differ, except when `wo` holds int8 weights.
        if (
            isinstance(self.wo.weight, torch.Tensor)
            and hidden_states.dtype != self.wo.weight.dtype
            and self.wo.weight.dtype != torch.int8
        ):
            hidden_states = hidden_states.to(self.wo.weight.dtype)
        hidden_states = self.wo(hidden_states)
        return hidden_states
主要作为FFN中的一部分,输入和输出维度没有变化,用于深层次加工和信息提炼
class T5DenseGatedActDense(nn.Module):
    """Gated feed-forward projection: act(wi_0(x)) * wi_1(x), then projected back to d_model."""

    def __init__(self, config: T5Config):
        """Build the two input projections, the output projection, the dropout and the activation."""
        super().__init__()
        self.wi_0 = nn.Linear(config.d_model, config.d_ff, bias=False)
        self.wi_1 = nn.Linear(config.d_model, config.d_ff, bias=False)
        self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
        self.dropout = nn.Dropout(config.dropout_rate)
        self.act = ACT2FN[config.dense_act_fn]

    def forward(self, hidden_states):
        """Gate the linear branch with the activated branch, apply dropout, then project with wo."""
        # `hidden_gelu` is the activated branch; the activation actually used comes from the config.
        hidden_gelu = self.act(self.wi_0(hidden_states))
        hidden_linear = self.wi_1(hidden_states)
        hidden_states = hidden_gelu * hidden_linear
        hidden_states = self.dropout(hidden_states)

        # To make 8bit quantization work for google/flan-t5-xxl, self.wo is kept in float32.
        # See https://github.com/huggingface/transformers/issues/20287
        # we also make sure the weights are not in `int8` in case users will force `_keep_in_fp32_modules` to be `None``
        if (
            isinstance(self.wo.weight, torch.Tensor)
            and hidden_states.dtype != self.wo.weight.dtype
            and self.wo.weight.dtype != torch.int8
        ):
            hidden_states = hidden_states.to(self.wo.weight.dtype)

        hidden_states = self.wo(hidden_states)
        return hidden_states
这个模块和前面的T5DenseActDense相比多了门控机制:wi_0的输出经过激活函数后作为门控权重,与wi_1的线性输出逐元素相乘;输入输出维度依旧不变
class T5LayerFF(nn.Module):
    """Feed-forward sub-layer: layer norm -> dense block -> dropout, added back onto the input."""

    def __init__(self, config: T5Config):
        """Pick the gated or plain dense block according to `config.is_gated_act`."""
        super().__init__()
        if config.is_gated_act:
            self.DenseReluDense = T5DenseGatedActDense(config)
        else:
            self.DenseReluDense = T5DenseActDense(config)

        self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(self, hidden_states):
        """Return `hidden_states` plus the dropped-out feed-forward output of its normalized copy."""
        # The normalization is applied before the dense block; the residual uses the un-normalized input.
        forwarded_states = self.layer_norm(hidden_states)
        forwarded_states = self.DenseReluDense(forwarded_states)
        hidden_states = hidden_states + self.dropout(forwarded_states)
        return hidden_states
FFN,输入输出维度不变,用于特征融合。
下面介绍最重要的注意力机制:
1.初始化:
def __init__(
    self,
    config: T5Config,
    has_relative_attention_bias=False,
    layer_idx: Optional[int] = None,
):
    """Set up the attention projections and, optionally, the relative position bias table.

    Args:
        config: Model configuration; read for `is_decoder`, the relative-attention bucket settings,
            `d_model`, `d_kv`, `num_heads` and `dropout_rate`.
        has_relative_attention_bias: When true, this layer owns an embedding with one row per bucket
            and one column per head.
        layer_idx: Index of this layer, used to address the key/value cache in the forward pass.
    """
    super().__init__()
    self.is_decoder = config.is_decoder
    self.has_relative_attention_bias = has_relative_attention_bias
    self.relative_attention_num_buckets = config.relative_attention_num_buckets
    self.relative_attention_max_distance = config.relative_attention_max_distance
    self.d_model = config.d_model
    self.key_value_proj_dim = config.d_kv
    self.n_heads = config.num_heads
    self.dropout = config.dropout_rate
    # Width of the concatenated per-head projections.
    self.inner_dim = self.n_heads * self.key_value_proj_dim
    self.layer_idx = layer_idx
    if layer_idx is None and self.is_decoder:
        logger.warning_once(
            f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and "
            "will lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
            "when creating this class."
        )

    # Mesh TensorFlow initialization to avoid scaling before softmax
    self.q = nn.Linear(self.d_model, self.inner_dim, bias=False)
    self.k = nn.Linear(self.d_model, self.inner_dim, bias=False)
    self.v = nn.Linear(self.d_model, self.inner_dim, bias=False)
    self.o = nn.Linear(self.inner_dim, self.d_model, bias=False)

    if self.has_relative_attention_bias:
        self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)
    self.pruned_heads = set()
    self.gradient_checkpointing = False
下面是各个参数的解释:
2._relative_position_bucket函数:
@staticmethod
def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
    """Map relative positions to bucket indices in the range [0, num_buckets).

    Small absolute distances each get their own bucket; larger ones share logarithmically sized
    buckets, and every distance of `max_distance` or more lands in the last bucket. Declared as a
    static method because it takes no `self` but is called through the instance.

    Args:
        relative_position: Integer tensor of (key position - query position) values.
        bidirectional: When true, half of the buckets are used for positive offsets and half for
            non-positive ones; when false, positive offsets are all treated as distance 0.
        num_buckets: Total number of buckets.
        max_distance: Distance from which everything maps to the last bucket.

    Returns:
        An integer tensor with the same shape as `relative_position`.
    """
    relative_buckets = 0
    if bidirectional:
        num_buckets //= 2
        # Positive offsets use the upper half of the bucket range.
        relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
        relative_position = torch.abs(relative_position)
    else:
        relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
    # now relative_position is in the range [0, inf)

    # half of the buckets are for exact increments in positions
    max_exact = num_buckets // 2
    is_small = relative_position < max_exact

    # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
    relative_position_if_large = max_exact + (
        torch.log(relative_position.float() / max_exact)
        / math.log(max_distance / max_exact)
        * (num_buckets - max_exact)
    ).to(torch.long)
    relative_position_if_large = torch.min(
        relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
    )

    # The logarithmic value is only kept where the distance is not "small"; elsewhere it is discarded.
    relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
    return relative_buckets
目标就是将任意两个 token 之间的相对距离(如 -500, -10, 0, 5, 100)映射到一个固定数量的、有限的“桶”(buckets) 中,进而加到注意力分数上,具体的数学公式和代码细节参见上面博客。
3.compute_bias函数:
def compute_bias(self, query_length, key_length, device=None, cache_position=None):
    """Compute binned relative position bias

    Args:
        query_length: Number of query positions; used only when `cache_position` is None.
        key_length: Number of key positions.
        device: Device for the position indices; defaults to the device of the bias embedding.
        cache_position: Optional 1-D tensor of query positions that replaces `arange(query_length)`.

    Returns:
        Tensor of shape (1, num_heads, query_length, key_length).
    """
    if device is None:
        device = self.relative_attention_bias.weight.device
    if cache_position is None:
        context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
    else:
        context_position = cache_position[:, None].to(device)
    memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
    relative_position = memory_position - context_position  # shape (query_length, key_length)
    relative_position_bucket = self._relative_position_bucket(
        relative_position,  # shape (query_length, key_length)
        bidirectional=(not self.is_decoder),
        num_buckets=self.relative_attention_num_buckets,
        max_distance=self.relative_attention_max_distance,
    )
    values = self.relative_attention_bias(relative_position_bucket)  # shape (query_length, key_length, num_heads)
    values = values.permute([2, 0, 1]).unsqueeze(0)  # shape (1, num_heads, query_length, key_length)
    return values
4.forward函数
def forward(
    self,
    hidden_states,
    mask=None,
    key_value_states=None,
    position_bias=None,
    past_key_value=None,
    layer_head_mask=None,
    query_length=None,
    use_cache=False,
    output_attentions=False,
    cache_position=None,
):
    """Self-attention (if key_value_states is None) or cross-attention over `key_value_states`.

    Args:
        hidden_states: Query input of shape (batch_size, seq_length, dim).
        mask: Additive mask, (batch_size, 1, 1, key_length) or (batch_size, 1, seq_length, key_length).
        key_value_states: When given, keys and values are projected from it instead of `hidden_states`.
        position_bias: Precomputed bias added to the scores; built here when None.
        past_key_value: Optional key/value cache, updated in place through its `update` method.
        layer_head_mask: Optional multiplier applied to the attention weights.
        query_length: Query length including cached positions; used only to build the relative bias.
        use_cache: Accepted for interface compatibility; not read by this method.
        output_attentions: When true, the attention weights are appended to the returned tuple.
        cache_position: Positions of the current queries; used for the cache update and the bias.

    Returns:
        `(attn_output, position_bias)` plus `attn_weights` when `output_attentions` is true.
    """
    # Input is (batch_size, seq_length, dim)
    # Mask is (batch_size, 1, 1, key_length) (non-causal encoder) or (batch_size, 1, seq_length, key_length) (causal decoder)
    batch_size, seq_length = hidden_states.shape[:2]

    # if key_value_states are provided this layer is used as a cross-attention layer for the decoder
    is_cross_attention = key_value_states is not None

    query_states = self.q(hidden_states)
    query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)

    # Check is encoder-decoder model is being used. Otherwise we'll get `DynamicCache`
    has_encoder_decoder_cache = past_key_value is not None and isinstance(past_key_value, EncoderDecoderCache)
    # Defaulting to False keeps the reuse test below defined for caches that carry no `is_updated` map.
    is_updated = False
    if has_encoder_decoder_cache:
        is_updated = past_key_value.is_updated.get(self.layer_idx)
        if is_cross_attention:
            # after the first generated id, we can subsequently re-use all key/value_states from cache
            curr_past_key_value = past_key_value.cross_attention_cache
        else:
            curr_past_key_value = past_key_value.self_attention_cache
    else:
        curr_past_key_value = past_key_value

    current_states = key_value_states if is_cross_attention else hidden_states
    if is_cross_attention and past_key_value is not None and is_updated:
        # reuse k,v, cross_attentions
        key_states = curr_past_key_value.key_cache[self.layer_idx]
        value_states = curr_past_key_value.value_cache[self.layer_idx]
    else:
        key_states = self.k(current_states)
        value_states = self.v(current_states)
        key_states = key_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
        value_states = value_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)

        if past_key_value is not None:
            # save all key/value_states to cache to be re-used for fast auto-regressive generation
            cache_position = cache_position if not is_cross_attention else None
            key_states, value_states = curr_past_key_value.update(
                key_states, value_states, self.layer_idx, {"cache_position": cache_position}
            )
            # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
            if is_cross_attention and has_encoder_decoder_cache:
                past_key_value.is_updated[self.layer_idx] = True

    # compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
    scores = torch.matmul(query_states, key_states.transpose(3, 2))

    if position_bias is None:
        key_length = key_states.shape[-2]
        if not self.has_relative_attention_bias:
            position_bias = torch.zeros(
                (1, self.n_heads, seq_length, key_length), device=scores.device, dtype=scores.dtype
            )
            if self.gradient_checkpointing and self.training:
                position_bias.requires_grad = True
        else:
            # cache position is 0-indexed so we add 1 to get the real length of queries (aka with past)
            # Only needed on this branch, so layers without a bias table never index `cache_position`.
            real_seq_length = query_length if query_length is not None else cache_position[-1] + 1
            position_bias = self.compute_bias(
                real_seq_length, key_length, device=scores.device, cache_position=cache_position
            )
            # Keep only the rows belonging to the current queries.
            position_bias = position_bias[:, :, -seq_length:, :]

        if mask is not None:
            # The mask is folded into the bias, so the returned bias carries it as well.
            causal_mask = mask[:, :, :, : key_states.shape[-2]]
            position_bias = position_bias + causal_mask

    if self.pruned_heads:
        # Drop the bias rows of pruned heads; a separate name avoids shadowing the `mask` argument.
        kept_heads = torch.ones(position_bias.shape[1])
        kept_heads[list(self.pruned_heads)] = 0
        position_bias_masked = position_bias[:, kept_heads.bool()]
    else:
        position_bias_masked = position_bias

    scores += position_bias_masked

    # (batch_size, n_heads, seq_length, key_length)
    attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores)
    attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

    # Mask heads if we want to
    if layer_head_mask is not None:
        attn_weights = attn_weights * layer_head_mask

    attn_output = torch.matmul(attn_weights, value_states)

    # Merge the heads back into a single (batch_size, seq_length, inner_dim) tensor.
    attn_output = attn_output.transpose(1, 2).contiguous()
    attn_output = attn_output.view(batch_size, -1, self.inner_dim)
    attn_output = self.o(attn_output)

    outputs = (attn_output, position_bias)

    if output_attentions:
        outputs = outputs + (attn_weights,)
    return outputs
注意:
Mask 是 (batch_size, 1, 1, key_length) (编码器非因果) 或 (batch_size, 1, seq_length, key_length) (解码器因果)
核心代码解释:
# Input is (batch_size, seq_length, dim)
# Mask is (batch_size, 1, 1, key_length) (non-causal encoder) or (batch_size, 1, seq_length, key_length) (causal decoder)
batch_size, seq_length = hidden_states.shape[:2]

# If key_value_states is provided, this layer acts as a cross-attention layer of the decoder
is_cross_attention = key_value_states is not None

query_states = self.q(hidden_states)
# Split the projection into heads: (batch_size, n_heads, seq_length, key_value_proj_dim)
query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
得到q值
# Check is encoder-decoder model is being used. Otherwise we'll get `DynamicCache`
has_encoder_decoder_cache = past_key_value is not None and isinstance(past_key_value, EncoderDecoderCache)
# Defaulting to False keeps the reuse test below defined for caches that carry no `is_updated` map.
is_updated = False
if has_encoder_decoder_cache:
    is_updated = past_key_value.is_updated.get(self.layer_idx)
    if is_cross_attention:
        # after the first generated id, we can subsequently re-use all key/value_states from cache
        curr_past_key_value = past_key_value.cross_attention_cache
    else:
        curr_past_key_value = past_key_value.self_attention_cache
else:
    curr_past_key_value = past_key_value

current_states = key_value_states if is_cross_attention else hidden_states
if is_cross_attention and past_key_value is not None and is_updated:
    # reuse k,v, cross_attentions
    key_states = curr_past_key_value.key_cache[self.layer_idx]
    value_states = curr_past_key_value.value_cache[self.layer_idx]
else:
    key_states = self.k(current_states)
    value_states = self.v(current_states)
    key_states = key_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
    value_states = value_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)

    if past_key_value is not None:
        # save all key/value_states to cache to be re-used for fast auto-regressive generation
        cache_position = cache_position if not is_cross_attention else None
        key_states, value_states = curr_past_key_value.update(
            key_states, value_states, self.layer_idx, {"cache_position": cache_position}
        )
        # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
        if is_cross_attention and has_encoder_decoder_cache:
            past_key_value.is_updated[self.layer_idx] = True
得到K,V值,这是在介绍K,V是怎么来的
将Mask信息和相对位置编码一起打包成偏置项
if position_bias is None:
    key_length = key_states.shape[-2]
    if not self.has_relative_attention_bias:
        # Layers without a bias table start from an all-zero bias.
        position_bias = torch.zeros(
            (1, self.n_heads, seq_length, key_length), device=scores.device, dtype=scores.dtype
        )
        if self.gradient_checkpointing and self.training:
            position_bias.requires_grad = True
    else:
        # cache position is 0-indexed so we add 1 to get the real length of queries (aka with past)
        # Only needed on this branch, so layers without a bias table never index `cache_position`.
        real_seq_length = query_length if query_length is not None else cache_position[-1] + 1
        position_bias = self.compute_bias(
            real_seq_length, key_length, device=scores.device, cache_position=cache_position
        )
        # Keep only the rows belonging to the current queries.
        position_bias = position_bias[:, :, -seq_length:, :]

    if mask is not None:
        # The mask is added into the bias, so both reach the scores as one additive term.
        causal_mask = mask[:, :, :, : key_states.shape[-2]]
        position_bias = position_bias + causal_mask
解释:
这些就是T5的难点和重点了,下面无非就是对这些模块的组合拼装
class T5LayerSelfAttention(nn.Module):
    """Self-attention sub-layer: layer norm -> attention -> dropout, added back onto the input."""

    def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None):
        """Build the attention module, the layer norm and the dropout from `config`."""
        super().__init__()
        self.SelfAttention = T5Attention(
            config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx
        )
        self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        position_bias=None,
        layer_head_mask=None,
        past_key_value=None,
        use_cache=False,
        output_attentions=False,
        cache_position=None,
    ):
        """Run self-attention on the normalized input and add the result to `hidden_states`.

        Returns:
            A tuple whose first item is the updated hidden states, followed by whatever the attention
            module returned after its own output.
        """
        normed_hidden_states = self.layer_norm(hidden_states)
        attention_output = self.SelfAttention(
            normed_hidden_states,
            mask=attention_mask,
            position_bias=position_bias,
            layer_head_mask=layer_head_mask,
            past_key_value=past_key_value,
            use_cache=use_cache,
            output_attentions=output_attentions,
            cache_position=cache_position,
        )
        # The residual connection uses the un-normalized input.
        hidden_states = hidden_states + self.dropout(attention_output[0])
        outputs = (hidden_states,) + attention_output[1:]  # add attentions if we output them
        return outputs
T5的自注意力模块
class T5LayerCrossAttention(nn.Module):
    """Cross-attention sub-layer: layer norm -> attention over `key_value_states` -> dropout, plus residual."""

    def __init__(self, config, layer_idx: Optional[int] = None):
        """Build the attention module (without a relative position bias table), the layer norm and the dropout."""
        super().__init__()
        self.EncDecAttention = T5Attention(config, has_relative_attention_bias=False, layer_idx=layer_idx)
        self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(
        self,
        hidden_states,
        key_value_states,
        attention_mask=None,
        position_bias=None,
        layer_head_mask=None,
        past_key_value=None,
        use_cache=False,
        query_length=None,
        output_attentions=False,
        cache_position=None,
    ):
        """Attend from the normalized `hidden_states` to `key_value_states` and add the result back.

        Returns:
            A tuple whose first item is the updated hidden states, followed by whatever the attention
            module returned after its own output.
        """
        normed_hidden_states = self.layer_norm(hidden_states)
        attention_output = self.EncDecAttention(
            normed_hidden_states,
            mask=attention_mask,
            key_value_states=key_value_states,
            position_bias=position_bias,
            layer_head_mask=layer_head_mask,
            past_key_value=past_key_value,
            use_cache=use_cache,
            query_length=query_length,
            output_attentions=output_attentions,
            cache_position=cache_position,
        )
        # The residual connection uses the un-normalized input.
        layer_output = hidden_states + self.dropout(attention_output[0])
        outputs = (layer_output,) + attention_output[1:]  # add attentions if we output them
        return outputs
T5的交叉注意力模块
class T5Block(GradientCheckpointingLayer):
    """One T5 block: self-attention, an optional cross-attention (decoder only) and a feed-forward layer."""

    def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None):
        """Assemble the sub-layers; the cross-attention layer is added only when `config.is_decoder`."""
        super().__init__()
        self.is_decoder = config.is_decoder
        self.layer = nn.ModuleList()
        self.layer.append(
            T5LayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx)
        )
        if self.is_decoder:
            self.layer.append(T5LayerCrossAttention(config, layer_idx=layer_idx))

        self.layer.append(T5LayerFF(config))

    @staticmethod
    def _clamp_fp16(hidden_states):
        """Clamp float16 activations into the finite range; any other dtype is returned untouched."""
        # clamp inf values to enable fp16 training
        if hidden_states.dtype == torch.float16:
            # When an inf is present the bound is pulled 1000 below the dtype maximum.
            clamp_value = torch.where(
                torch.isinf(hidden_states).any(),
                torch.finfo(hidden_states.dtype).max - 1000,
                torch.finfo(hidden_states.dtype).max,
            )
            hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
        return hidden_states

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        position_bias=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        encoder_decoder_position_bias=None,
        layer_head_mask=None,
        cross_attn_layer_head_mask=None,
        past_key_value=None,
        use_cache=False,
        output_attentions=False,
        return_dict=True,
        cache_position=None,
    ):
        """Run the sub-layers in order, clamping float16 activations after each one.

        Cross-attention runs only for decoder blocks that receive `encoder_hidden_states`; on that
        path `cache_position` must not be None because the query length is derived from it.

        Returns:
            A tuple starting with the hidden states, followed by the extra outputs of the
            self-attention layer and, when it ran, of the cross-attention layer.
        """
        self_attention_outputs = self.layer[0](
            hidden_states,
            attention_mask=attention_mask,
            position_bias=position_bias,
            layer_head_mask=layer_head_mask,
            past_key_value=past_key_value,
            use_cache=use_cache,
            output_attentions=output_attentions,
            cache_position=cache_position,
        )
        hidden_states = self_attention_outputs[0]
        attention_outputs = self_attention_outputs[1:]  # Keep self-attention outputs and relative position weights

        hidden_states = self._clamp_fp16(hidden_states)

        do_cross_attention = self.is_decoder and encoder_hidden_states is not None
        if do_cross_attention:
            cross_attention_outputs = self.layer[1](
                hidden_states,
                key_value_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                position_bias=encoder_decoder_position_bias,
                layer_head_mask=cross_attn_layer_head_mask,
                past_key_value=past_key_value,
                query_length=cache_position[-1] + 1,
                use_cache=use_cache,
                output_attentions=output_attentions,
            )
            hidden_states = self._clamp_fp16(cross_attention_outputs[0])

            # Keep cross-attention outputs and relative position weights
            attention_outputs = attention_outputs + cross_attention_outputs[1:]

        # Apply Feed Forward layer
        hidden_states = self._clamp_fp16(self.layer[-1](hidden_states))

        outputs = (hidden_states,)

        return (
            outputs + attention_outputs
        )  # hidden-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
T5的基本模块(encoder:(self-attention+FFN)decoder:(self-attention,cross-attention,FFN))
class T5ClassificationHead(nn.Module):
    """Head for sentence-level classification tasks."""

    def __init__(self, config: T5Config):
        """Build a d_model -> d_model dense layer, a dropout and a d_model -> num_labels projection."""
        super().__init__()
        self.dense = nn.Linear(config.d_model, config.d_model)
        self.dropout = nn.Dropout(p=config.classifier_dropout)
        self.out_proj = nn.Linear(config.d_model, config.num_labels)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply dropout -> dense -> tanh -> dropout -> out_proj; the last dimension becomes num_labels."""
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.dense(hidden_states)
        hidden_states = torch.tanh(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.out_proj(hidden_states)
        return hidden_states
分类头,用于下游任务
T5Stack:用来把前面的block堆叠起来,组成encoder或decoder:
初始化:
class T5Stack(T5PreTrainedModel):
    """A stack of `config.num_layers` T5 blocks followed by a final layer norm and dropout."""

    def __init__(self, config, embed_tokens=None):
        """Build the block stack.

        Args:
            config: Model configuration; `is_decoder` is stored on the stack and also decides, inside
                each block, whether a cross-attention layer is built.
            embed_tokens: Token embedding module stored on the stack; may be None.
        """
        super().__init__(config)

        self.embed_tokens = embed_tokens
        self.is_decoder = config.is_decoder

        # Only the first block (i == 0) owns a relative attention bias table.
        self.block = nn.ModuleList(
            [T5Block(config, has_relative_attention_bias=bool(i == 0), layer_idx=i) for i in range(config.num_layers)]
        )
        self.final_layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.dropout = nn.Dropout(config.dropout_rate)

        # Initialize weights and apply final processing
        self.post_init()
        # Model parallel
        self.model_parallel = False
        self.device_map = None
        self.gradient_checkpointing = False
forward函数:
这里主要呈现核心代码(很多其实是工程上的需要)
# Embed the token ids only when the caller did not pass embeddings directly.
if inputs_embeds is None:
    if self.embed_tokens is None:
        raise ValueError("You have to initialize the model with valid token embeddings")
    inputs_embeds = self.embed_tokens(input_ids)
将id序列的每个id变成embedding向量
# A decoder that is asked to cache but was given no cache creates an empty one.
if self.is_decoder:
    if use_cache and past_key_values is None:
        if self.config.is_encoder_decoder:
            # One cache is passed for each of the two constructor arguments.
            past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache())
        else:
            past_key_values = DynamicCache()
# Feed the hidden states through every block in order.
for i, layer_module in enumerate(self.block):
    layer_head_mask = head_mask[i]
    cross_attn_layer_head_mask = cross_attn_head_mask[i]
    # Model parallel
    if self.model_parallel:
        torch.cuda.set_device(hidden_states.device)
        # Ensure that attention_mask is always on the same device as hidden_states
        if causal_mask is not None:
            causal_mask = causal_mask.to(hidden_states.device)
        if position_bias is not None:
            position_bias = position_bias.to(hidden_states.device)
        if encoder_hidden_states is not None:
            encoder_hidden_states = encoder_hidden_states.to(hidden_states.device)
        if encoder_extended_attention_mask is not None:
            encoder_extended_attention_mask = encoder_extended_attention_mask.to(hidden_states.device)
        if encoder_decoder_position_bias is not None:
            encoder_decoder_position_bias = encoder_decoder_position_bias.to(hidden_states.device)
        if layer_head_mask is not None:
            layer_head_mask = layer_head_mask.to(hidden_states.device)
        if cross_attn_layer_head_mask is not None:
            cross_attn_layer_head_mask = cross_attn_layer_head_mask.to(hidden_states.device)
    if output_hidden_states:
        # Record the input of this block before it is overwritten.
        all_hidden_states = all_hidden_states + (hidden_states,)

    layer_outputs = layer_module(
        hidden_states,
        causal_mask,
        position_bias,
        encoder_hidden_states,
        encoder_extended_attention_mask,
        encoder_decoder_position_bias,  # as a positional argument for gradient checkpointing
        layer_head_mask=layer_head_mask,
        cross_attn_layer_head_mask=cross_attn_layer_head_mask,
        past_key_value=past_key_values,
        use_cache=use_cache,
        output_attentions=output_attentions,
        return_dict=return_dict,
        cache_position=cache_position,
    )

    # The first item of the block output is the new hidden state.
    hidden_states = layer_outputs[0]
循环遍历组装的模块
T5Model:
def forward(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    attention_mask: Optional[torch.FloatTensor] = None,
    decoder_input_ids: Optional[torch.LongTensor] = None,
    decoder_attention_mask: Optional[torch.BoolTensor] = None,
    head_mask: Optional[torch.FloatTensor] = None,
    decoder_head_mask: Optional[torch.FloatTensor] = None,
    cross_attn_head_mask: Optional[torch.Tensor] = None,
    encoder_outputs: Optional[tuple[tuple[torch.FloatTensor]]] = None,
    past_key_values: Optional[Cache] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
    decoder_inputs_embeds: Optional[torch.Tensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
) -> Union[tuple[torch.FloatTensor], Seq2SeqModelOutput]:
    """Run the encoder (unless `encoder_outputs` is supplied) and then the decoder.

    `use_cache` and `return_dict` fall back to the values stored in `self.config` when None.

    Returns:
        With `return_dict` false, the decoder outputs concatenated with the encoder outputs;
        otherwise a `Seq2SeqModelOutput` whose `last_hidden_state` comes from the decoder.
    """
    use_cache = use_cache if use_cache is not None else self.config.use_cache
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
    if head_mask is not None and decoder_head_mask is None:
        if self.config.num_layers == self.config.num_decoder_layers:
            # A bare `__NAME` inside a class body is rewritten by private-name mangling, so the
            # message is fetched by string key. NOTE(review): assumes the constant is a module-level
            # global of this file - confirm.
            warnings.warn(globals()["__HEAD_MASK_WARNING_MSG"], FutureWarning)
            decoder_head_mask = head_mask

    # Encode if needed (training, first prediction pass)
    if encoder_outputs is None:
        encoder_outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
    elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
        # Wrap a plain tuple so the attribute access at the end of this method works.
        encoder_outputs = BaseModelOutput(
            last_hidden_state=encoder_outputs[0],
            hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
            attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
        )

    hidden_states = encoder_outputs[0]

    # Set device for model parallelism
    if self.model_parallel:
        torch.cuda.set_device(self.decoder.first_device)
        hidden_states = hidden_states.to(self.decoder.first_device)
        if decoder_input_ids is not None:
            decoder_input_ids = decoder_input_ids.to(self.decoder.first_device)
        if attention_mask is not None:
            attention_mask = attention_mask.to(self.decoder.first_device)
        if decoder_attention_mask is not None:
            decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device)

    # Decode
    decoder_outputs = self.decoder(
        input_ids=decoder_input_ids,
        attention_mask=decoder_attention_mask,
        inputs_embeds=decoder_inputs_embeds,
        past_key_values=past_key_values,
        encoder_hidden_states=hidden_states,
        encoder_attention_mask=attention_mask,
        head_mask=decoder_head_mask,
        cross_attn_head_mask=cross_attn_head_mask,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        cache_position=cache_position,
    )

    if not return_dict:
        return decoder_outputs + encoder_outputs

    return Seq2SeqModelOutput(
        last_hidden_state=decoder_outputs.last_hidden_state,
        past_key_values=decoder_outputs.past_key_values,
        decoder_hidden_states=decoder_outputs.hidden_states,
        decoder_attentions=decoder_outputs.attentions,
        cross_attentions=decoder_outputs.cross_attentions,
        encoder_last_hidden_state=encoder_outputs.last_hidden_state,
        encoder_hidden_states=encoder_outputs.hidden_states,
        encoder_attentions=encoder_outputs.attentions,
    )
将由T5Stack构成的encoder部分和decoder部分组装起来,主要输出decoder最后一层的隐状态