FlashAttention
if is_flash_attn_available(): # 检查flashattention的可用性
from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
FlashAttention是Tranformer模型中用于改进注意力机制的技术,主要目的是减少计算复杂度和内存占用。
- flash_attn_func用于标准的flashattention计算。
- flash_attn_varlen_func用于处理变长序列(长度未能确定)的flashattention计算。
- index_first_axis用于处理第一个索引轴。
- pad_input将数据进行填充处理,从而确定长度。
- unpad_input将填充后的输入还原为原始形态。
Logging模块
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "LlamaConfig"
创建了名为logger的日志记录器对象,__name__用于保存模块的名称,确保每个模块都有自己的日志记录器。
_CONFIG_FOR_DOC前面带有下划线,因此可以看出其代表一个模块的内部变量。
get_unpad_data模块
def _get_unpad_data(padding_mask):
seqlens_in_batch = padding_mask.sum(dim=-1, dtype=torch.int32