Table of Contents
- Overview
- 1. Overall Architecture of the Model Module
- 1.1 Directory Structure Design
- 1.2 Standardized Model Architecture
- 1.3 Design Principles
- 1.3.1 Unified Abstraction
- 1.3.2 Configuration-Driven Design
- 1.3.3 Task Extension
- 2. Deep Dive into the Core Base Classes
- 2.1 PreTrainedModel Base Class Architecture
- 2.1.1 Weight Management System
- 2.1.2 Serialization and Deserialization
- 2.2 Deep Dive into the Configuration System
- 2.2.1 Configuration Base Class Architecture
- 2.2.2 Configuration Version Management
- 3. Analysis of Concrete Model Implementations
- 3.1 Deep Dive into the BERT Model
- 3.1.1 BERT Architecture Design
- 3.1.2 BERT Attention Mechanism
- 3.1.3 BERT Intermediate and Output Layers
- 3.2 Task-Specific Model Extensions
- 3.2.1 Sequence Classification Model
- 3.2.2 Token Classification Model
- 4. Automatic Model Loading System
- 4.1 Auto Class System Architecture
- 4.2 Model Registration Mechanism
- 5. Advanced Features and Optimizations
- 5.1 Model Quantization and Compression
- 5.2 Memory Optimization Techniques
- 5.3 Distributed Model Support
- 6. Text Generation System
- 6.1 Generation Configuration and Interface
- 6.2 Advanced Generation Strategies
- 7. Model Module Summary and Outlook
- 7.1 Architectural Strengths
- 7.2 Technical Innovations
- 7.3 Future Directions
- 7.4 Best-Practice Recommendations
Team blog: 汽车电子社区 (Automotive Electronics Community)
Overview
The model module is the core of the Transformers library. It contains complete implementations of more than 100 pretrained models, from the classic BERT to the more recent LLaMA, spanning natural language processing, computer vision, speech processing, and other domains. Through a unified design pattern and a highly standardized architecture, the module enables code reuse across models and fast integration of new ones. It lives under src/transformers/models/; each model has its own subdirectory containing the complete implementation of its architecture, configuration, tokenizer, and related components. This document dissects the model module from several angles: software architecture, design patterns, core algorithms, and implementation details.
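The practical payoff of this standardization is that every model, regardless of family, is driven through the same handful of calls. A minimal sketch using the public API:
from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModel.from_pretrained("bert-base-uncased")
inputs = tokenizer("Hello, Transformers!", return_tensors="pt")
outputs = model(**inputs)
print(outputs.last_hidden_state.shape)  # [batch, sequence_length, 768]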
1. Overall Architecture of the Model Module
1.1 Directory Structure Design
The model module uses a highly regular directory structure so that every model implementation looks the same:
models/
├── __init__.py # model module exports
├── auto/ # automatic model loading system
│ ├── __init__.py # Auto-series API
│ ├── modeling_auto.py # AutoModel implementation
│ ├── configuration_auto.py # AutoConfig implementation
│ └── tokenization_auto.py # AutoTokenizer implementation
├── bert/ # BERT implementation
│ ├── __init__.py # BERT module exports
│ ├── configuration_bert.py # BERT configuration class (50+ lines)
│ ├── modeling_bert.py # BERT model implementation (3,791 lines)
│ ├── tokenization_bert.py # BERT tokenizer
│ └── tokenization_bert_fast.py # BERT fast tokenizer
├── gpt2/ # GPT-2 implementation
│ ├── __init__.py
│ ├── configuration_gpt2.py
│ ├── modeling_gpt2.py
│ └── ...
├── t5/ # T5 implementation
│ ├── __init__.py
│ ├── configuration_t5.py
│ ├── modeling_t5.py
│ └── ...
├── llama/ # LLaMA implementation
│ ├── __init__.py
│ ├── configuration_llama.py
│ ├── modeling_llama.py
│ └── ...
├── vit/ # Vision Transformer (ViT) implementation
│ ├── __init__.py
│ ├── configuration_vit.py
│ ├── modeling_vit.py
│ └── ...
├── wav2vec2/ # Wav2Vec2 speech model
│ ├── __init__.py
│ ├── configuration_wav2vec2.py
│ ├── modeling_wav2vec2.py
│ └── ...
└── ... # other model implementations
1.2 Standardized Model Architecture
Every model follows the same architectural pattern, which keeps the codebase consistent and maintainable:
# Standard model architecture pattern (illustrative outline, not library code)
class StandardModelArchitecture:
    """The standardized set of components every model directory provides."""
    # Required components
    #   ConfigClass               - configuration class (subclasses PreTrainedConfig)
    #   ModelClass                - main model class (subclasses PreTrainedModel)
    #   TokenizerClass            - tokenizer class (subclasses PreTrainedTokenizer)
    # Optional components
    #   FastTokenizerClass        - fast tokenizer (subclasses PreTrainedTokenizerFast)
    #   FeatureExtractorClass     - feature extractor (subclasses PreTrainedFeatureExtractor)
    #   ProcessorClass            - multimodal processor
    # Task-specific models
    #   ForSequenceClassification - sequence classification model
    #   ForTokenClassification    - token classification model
    #   ForQuestionAnswering      - question answering model
    #   ForCausalLM               - causal language model
    #   ForMaskedLM               - masked language model
    #   ForMultipleChoice         - multiple choice model
1.3 Design Principles
1.3.1 Unified Abstraction
All models inherit from a common base class, which guarantees a consistent interface:
# Unified abstraction hierarchy
PreTrainedModel (base class)
├── Encoder-only models: BERT, RoBERTa, ALBERT
├── Decoder-only models: GPT-2, LLaMA, OPT
├── Encoder-decoder models: T5, BART, Pegasus
└── Multi-modal models: CLIP, BLIP, ViLT
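Because this hierarchy is enforced in code, the shared interface can be checked directly; a quick sanity check across the three main text families:
from transformers import BertModel, GPT2LMHeadModel, T5ForConditionalGeneration, PreTrainedModel

for cls in (BertModel, GPT2LMHeadModel, T5ForConditionalGeneration):
    assert issubclass(cls, PreTrainedModel)
    # every model therefore exposes the same core API
    assert hasattr(cls, "from_pretrained") and hasattr(cls, "save_pretrained")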
1.3.2 Configuration-Driven Design
A configuration object controls every hyperparameter and behavior of the model:
# Configuration-driven model construction (illustrative)
class ModelFromConfig:
    def __init__(self, config):
        self.config = config
        self._build_model_from_config()
    def _build_model_from_config(self):
        # Build the sub-modules dynamically from the configuration
        self.embeddings = self._build_embeddings()
        self.encoder = self._build_encoder()
        self.pooler = self._build_pooler() if self.config.add_pooling_layer else None
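In practice this means the same code path builds arbitrarily sized models purely from configuration values; for example, a deliberately tiny BERT:
from transformers import BertConfig, BertModel

small = BertModel(BertConfig(hidden_size=128, num_hidden_layers=2,
                             num_attention_heads=2, intermediate_size=256))
print(small.num_parameters())  # a few million parameters instead of ~110M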
1.3.3 Task Extension
Each base model can be extended into heads for different downstream tasks:
# Task extension example
class BaseModel(PreTrainedModel):
    """Base model class"""
class BaseModelForSequenceClassification(BaseModel):
    """Sequence classification extension"""
class BaseModelForTokenClassification(BaseModel):
    """Token classification extension"""
class BaseModelForQuestionAnswering(BaseModel):
    """Question answering extension"""
2. Deep Dive into the Core Base Classes
2.1 PreTrainedModel Base Class Architecture
PreTrainedModel in modeling_utils.py is the abstract base class of every model. At 4,697 lines, it provides the complete infrastructure models build on:
class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
    """Abstract base class for all pretrained models."""
    # Class attributes -- subclasses must define these
    config_class = None                      # matching configuration class
    base_model_prefix = ""                   # attribute name prefix of the base model
    main_input_name = "input_ids"            # name of the primary input
    supports_gradient_checkpointing = False  # whether gradient checkpointing is supported
    def __init__(self, config):
        super().__init__()
        self.config = config
        # Post-initialization processing
        self.post_init()
    def post_init(self):
        """Processing executed at the end of model initialization."""
        # Initialize the weights
        self.init_weights()
        # Backward-compatibility setup (e.g. gradient checkpointing flags)
        self._backward_compatibility_gradient_checkpointing()
    @property
    def device(self):
        """Device on which the model parameters live."""
        return next(self.parameters()).device
    @property
    def dtype(self):
        """Data type of the model parameters."""
        return next(self.parameters()).dtype
2.1.1 Weight Management System
class WeightManagementMixin:
    """Weight-management mixin (illustrative sketch of the library's logic)"""
    def init_weights(self):
        """Initialize the model weights."""
        # 1. Pick the configured initialization method, if any
        init_method = getattr(self.config, "init_method", self._default_init_method)
        # 2. Recursively initialize all modules
        for module in self.modules():
            if hasattr(module, 'weight') and module.weight is not None:
                if isinstance(module, nn.Linear):
                    # Linear layers
                    self._init_linear_weights(module, init_method)
                elif isinstance(module, nn.Embedding):
                    # Embedding layers
                    self._init_embedding_weights(module, init_method)
                elif isinstance(module, nn.LayerNorm):
                    # Layer normalization
                    self._init_layernorm_weights(module)
    def _default_init_method(self, tensor):
        """Default weight initialization."""
        if getattr(self.config, "weight_init_std", None) is not None:
            # Normal distribution with an explicit standard deviation
            nn.init.normal_(tensor, mean=0.0, std=self.config.weight_init_std)
        elif hasattr(self.config, 'initializer_range'):
            # Normal distribution scaled by the configured initializer range
            nn.init.normal_(tensor, mean=0.0, std=self.config.initializer_range)
        else:
            # Fall back to Xavier initialization
            nn.init.xavier_uniform_(tensor)
    def _init_linear_weights(self, module, init_method):
        """Initialize a linear layer."""
        init_method(module.weight.data)
        # Biases start at zero
        if module.bias is not None:
            nn.init.zeros_(module.bias.data)
    def _init_embedding_weights(self, module, init_method):
        """Initialize an embedding layer."""
        init_method(module.weight.data)
        # Zero out the padding-token embedding
        if getattr(self.config, "pad_token_id", None) is not None:
            module.weight.data[self.config.pad_token_id].zero_()
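In the released library, concrete models customize initialization by overriding a single _init_weights(module) hook on their *PreTrainedModel subclass rather than through a separate mixin. A minimal sketch of that pattern (MyConfig and MyPreTrainedModel are hypothetical names):
import torch.nn as nn
from transformers import PreTrainedModel

class MyPreTrainedModel(PreTrainedModel):
    config_class = MyConfig  # hypothetical configuration class
    def _init_weights(self, module):
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)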
2.1.2 Serialization and Deserialization
class SerializationMixin:
    """Serialization mixin (illustrative sketch)"""
    def save_pretrained(self, save_directory: Union[str, os.PathLike]):
        """Save a pretrained model."""
        # 1. Create the target directory
        os.makedirs(save_directory, exist_ok=True)
        # 2. Save the model weights
        weights_file = os.path.join(save_directory, WEIGHTS_NAME)
        if self.config.save_format == "safetensors":
            # SafeTensors format
            self._save_safetensors(weights_file)
        else:
            # PyTorch format
            self._save_pytorch_weights(weights_file)
        # 3. Save the configuration file
        self.config.save_pretrained(save_directory)
        # 4. Save model state information
        state_file = os.path.join(save_directory, "state.json")
        state = {
            "model_type": self.config.model_type,
            "framework": "pytorch",
            "transformers_version": __version__,
        }
        with open(state_file, "w") as f:
            json.dump(state, f, indent=2)
    def _save_safetensors(self, weights_file: str):
        """Save in SafeTensors format."""
        from safetensors.torch import save_file
        # Extract the model weights
        state_dict = self.state_dict()
        # Shard the weights (large-model optimization)
        if self.config.use_sharded_weights:
            self._save_sharded_safetensors(state_dict, weights_file)
        else:
            save_file(state_dict, weights_file)
    def _save_pytorch_weights(self, weights_file: str):
        """Save in PyTorch format."""
        state_dict = self.state_dict()
        # Shard the weights (large-model optimization)
        if self.config.use_sharded_weights:
            self._save_sharded_pytorch(state_dict, weights_file)
        else:
            torch.save(state_dict, weights_file)
    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Union[str, os.PathLike],
        *model_args,
        **kwargs
    ):
        """Load from a pretrained model."""
        # 1. Load the configuration
        config = kwargs.pop("config", None)
        if config is None:
            config = cls.config_class.from_pretrained(pretrained_model_name_or_path, **kwargs)
        # 2. Instantiate the model
        model = cls(config, *model_args, **kwargs)
        # 3. Load the weights
        state_dict = cls._load_state_dict(pretrained_model_name_or_path, **kwargs)
        # 4. Convert and load the weights
        model.load_state_dict(state_dict, strict=kwargs.get("strict", True))
        return model
    @classmethod
    def _load_state_dict(cls, pretrained_model_name_or_path: str, **kwargs):
        """Load the state dict."""
        # 1. Resolve the weight file paths
        if os.path.isdir(pretrained_model_name_or_path):
            # Local directory
            weights_files = cls._get_weight_files(pretrained_model_name_or_path)
        else:
            # Hub repository
            weights_files = cls._download_weights(pretrained_model_name_or_path, **kwargs)
        # 2. Load the state dict
        if len(weights_files) == 1:
            # Single weight file
            if weights_files[0].endswith(".safetensors"):
                from safetensors.torch import load_file
                state_dict = load_file(weights_files[0])
            else:
                state_dict = torch.load(weights_files[0], map_location="cpu")
        else:
            # Sharded weights
            state_dict = cls._load_sharded_weights(weights_files)
        return state_dict
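A round trip through the public API shows both halves working together; saving writes config.json plus the weight file, and loading reverses the process:
from transformers import BertConfig, BertModel

model = BertModel(BertConfig(hidden_size=128, num_hidden_layers=2,
                             num_attention_heads=2, intermediate_size=256))
model.save_pretrained("./tiny-bert")          # writes config + weights
reloaded = BertModel.from_pretrained("./tiny-bert")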
2.2 Deep Dive into the Configuration System
2.2.1 Configuration Base Class Architecture
class PreTrainedConfig(PushToHubMixin):
    """Configuration base class from which all model configurations derive."""
    # Class attributes
    model_type: str = ""                    # model type identifier
    is_composition: bool = False            # whether this is a composite configuration
    attribute_map: dict = {}                # attribute name mapping
    keys_to_ignore_at_inference: list = []  # keys ignored at inference time
    def __init__(self, **kwargs):
        # Standard configuration parameters
        self.vocab_size = kwargs.pop("vocab_size", None)
        self.hidden_size = kwargs.pop("hidden_size", None)
        self.num_hidden_layers = kwargs.pop("num_hidden_layers", None)
        self.num_attention_heads = kwargs.pop("num_attention_heads", None)
        self.intermediate_size = kwargs.pop("intermediate_size", None)
        # Activation function
        self.hidden_act = kwargs.pop("hidden_act", "gelu")
        # Dropout and regularization
        self.hidden_dropout_prob = kwargs.pop("hidden_dropout_prob", 0.1)
        self.attention_probs_dropout_prob = kwargs.pop("attention_probs_dropout_prob", 0.1)
        # LayerNorm parameters
        self.layer_norm_eps = kwargs.pop("layer_norm_eps", 1e-12)
        # Initialization parameters
        self.initializer_range = kwargs.pop("initializer_range", 0.02)
        self.weight_decay = kwargs.pop("weight_decay", 0.0)
        # Sequence length parameters
        self.max_position_embeddings = kwargs.pop("max_position_embeddings", 512)
        # Special tokens
        self.pad_token_id = kwargs.pop("pad_token_id", None)
        self.bos_token_id = kwargs.pop("bos_token_id", None)
        self.eos_token_id = kwargs.pop("eos_token_id", None)
        self.unk_token_id = kwargs.pop("unk_token_id", None)
        # Task-specific parameters
        self.problem_type = kwargs.pop("problem_type", None)
        self.num_labels = kwargs.pop("num_labels", None)
        # Keep any unused kwargs around
        self.init_kwargs = kwargs
        # Apply the attribute name mapping
        self._apply_attribute_map()
    def _apply_attribute_map(self):
        """Apply the attribute name mapping (legacy name -> canonical name)."""
        for old_name, new_name in self.attribute_map.items():
            if hasattr(self, old_name):
                setattr(self, new_name, getattr(self, old_name))
                delattr(self, old_name)
    def to_dict(self) -> Dict[str, Any]:
        """Convert the configuration to a dictionary."""
        output = {}
        for key, value in self.__dict__.items():
            if not key.startswith("_") and not callable(value):
                output[key] = value
        return output
    def to_json_string(self) -> str:
        """Convert the configuration to a JSON string."""
        return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
    @classmethod
    def from_dict(cls, config_dict: Dict[str, Any], **kwargs):
        """Create a configuration from a dictionary."""
        # Instantiate the configuration
        config = cls(**config_dict)
        # Apply extra kwargs on top
        for key, value in kwargs.items():
            if hasattr(config, key):
                setattr(config, key, value)
        return config
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs):
        """Load a configuration from a pretrained model."""
        # 1. Resolve and read the configuration file
        config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
        # 2. Build the configuration from the dictionary
        return cls.from_dict(config_dict, **kwargs)
    @classmethod
    def get_config_dict(cls, pretrained_model_name_or_path: str, **kwargs):
        """Fetch the configuration dictionary."""
        # 1. Load from the Hub or a local path
        config_file = cached_file(
            pretrained_model_name_or_path,
            CONFIG_NAME,
            _raise_exceptions_for_missing_entries=False,
            **kwargs
        )
        # 2. Read the configuration file
        if config_file is None:
            raise ValueError(f"Config file not found in {pretrained_model_name_or_path}")
        with open(config_file, "r", encoding="utf-8") as reader:
            config_dict = json.load(reader)
        return config_dict, kwargs
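The dictionary round trip can be exercised directly with the public API:
from transformers import BertConfig

config = BertConfig.from_pretrained("bert-base-uncased")
print(config.hidden_size, config.num_hidden_layers)  # 768 12
restored = BertConfig.from_dict(config.to_dict())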
2.2.2 Configuration Version Management
from packaging import version

class ConfigVersionManager:
    """Configuration version manager (illustrative sketch)"""
    @staticmethod
    def upgrade_config(config_dict: Dict[str, Any], target_version: str) -> Dict[str, Any]:
        """Upgrade a configuration dictionary to the target version."""
        current_version = config_dict.get("transformers_version", "0.0.0")
        # Version upgrade chain
        if version.parse(current_version) < version.parse("4.0.0"):
            config_dict = ConfigVersionManager._upgrade_to_v4_0(config_dict)
        if version.parse(current_version) < version.parse("4.20.0"):
            config_dict = ConfigVersionManager._upgrade_to_v4_20(config_dict)
        # ... further version upgrades
        return config_dict
    @staticmethod
    def _upgrade_to_v4_0(config_dict: Dict[str, Any]) -> Dict[str, Any]:
        """Upgrade to v4.0."""
        # Drop parameters deprecated in this (hypothetical) migration
        deprecated_params = ["use_cache", "output_attentions"]
        for param in deprecated_params:
            config_dict.pop(param, None)
        # Add defaults for newly introduced parameters
        if "use_return_dict" not in config_dict:
            config_dict["use_return_dict"] = True
        return config_dict
3. Analysis of Concrete Model Implementations
3.1 Deep Dive into the BERT Model
3.1.1 BERT Architecture Design
class BertModel(BertPreTrainedModel):
    """The base BERT model."""
    def __init__(self, config, add_pooling_layer=True):
        super().__init__(config)
        self.config = config
        # 1. Embedding layer
        self.embeddings = BertEmbeddings(config)
        # 2. Encoder stack
        self.encoder = BertEncoder(config)
        # 3. Pooling layer
        self.pooler = BertPooler(config) if add_pooling_layer else None
        # Initialize the weights
        self.post_init()
    def get_input_embeddings(self):
        """Return the input embedding layer."""
        return self.embeddings.word_embeddings
    def set_input_embeddings(self, value):
        """Replace the input embedding layer."""
        self.embeddings.word_embeddings = value
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
        """BERT forward pass."""
        # 1. Resolve default flags from the configuration
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # 2. Validate the inputs
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")
        # 3. Work out the device and the cached sequence length
        device = input_ids.device if input_ids is not None else inputs_embeds.device
        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
        # 4. Create the attention mask if none was given
        if attention_mask is None:
            attention_mask = torch.ones(input_shape, device=device)
        # 5. Broadcast the attention mask for the attention layers
        extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, device)
        # 6. Prepare the encoder attention mask (for cross-attention)
        if self.config.is_decoder and encoder_hidden_states is not None:
            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
        else:
            encoder_extended_attention_mask = None
        # 7. Prepare the head mask
        if head_mask is not None:
            head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
        # 8. Embedding layer
        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
            past_key_values_length=past_key_values_length
        )
        # 9. Encoder stack
        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # 10. Pooling
        sequence_output = encoder_outputs[0]
        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
        # 11. Assemble the outputs
        if not return_dict:
            return (sequence_output, pooled_output) + encoder_outputs[1:]
        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            past_key_values=encoder_outputs.past_key_values,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
            cross_attentions=encoder_outputs.cross_attentions,
        )
class BertEmbeddings(nn.Module):
    """BERT embedding layer."""
    def __init__(self, config):
        super().__init__()
        self.config = config
        # Word embeddings
        self.word_embeddings = nn.Embedding(
            config.vocab_size,
            config.hidden_size,
            padding_idx=config.pad_token_id
        )
        # Position embeddings
        self.position_embeddings = nn.Embedding(
            config.max_position_embeddings,
            config.hidden_size
        )
        # Token type (segment) embeddings
        self.token_type_embeddings = nn.Embedding(
            config.type_vocab_size,
            config.hidden_size
        )
        # Layer normalization and dropout
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # Position ID buffer
        self.register_buffer(
            "position_ids",
            torch.arange(config.max_position_embeddings).expand((1, -1)),
            persistent=False
        )
        # Token type ID buffer
        self.register_buffer(
            "token_type_ids",
            torch.zeros(self.position_ids.size(), dtype=torch.long),
            persistent=False
        )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        past_key_values_length: int = 0,
    ) -> torch.Tensor:
        """Embedding forward pass."""
        if input_ids is not None:
            input_shape = input_ids.size()
        else:
            input_shape = inputs_embeds.size()[:-1]
        seq_length = input_shape[1]
        # Resolve position IDs
        if position_ids is None:
            position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
        # Resolve token type IDs
        if token_type_ids is None:
            if hasattr(self, "token_type_ids"):
                buffered_token_type_ids = self.token_type_ids[:, :seq_length]
                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
                token_type_ids = buffered_token_type_ids_expanded
            else:
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
        # Look up the embeddings
        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)
        # Sum the three embeddings
        embeddings = inputs_embeds + position_embeddings + token_type_embeddings
        # Layer normalization and dropout
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings
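To make the data flow concrete, a tiny randomly initialized BertModel can be pushed end to end; a minimal sketch using the public API:
import torch
from transformers import BertConfig, BertModel

config = BertConfig(vocab_size=100, hidden_size=32, num_hidden_layers=2,
                    num_attention_heads=2, intermediate_size=64)
model = BertModel(config)
input_ids = torch.randint(0, 100, (1, 8))  # batch of 1, sequence length 8
outputs = model(input_ids)
print(outputs.last_hidden_state.shape)     # torch.Size([1, 8, 32])
print(outputs.pooler_output.shape)         # torch.Size([1, 32])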
3.1.2 BERT Attention Mechanism
class BertSelfAttention(nn.Module):
    """BERT self-attention implementation."""
    def __init__(self, config, position_embedding_type=None):
        super().__init__()
        self.config = config
        self.position_embedding_type = position_embedding_type or getattr(config, "position_embedding_type", "absolute")
        # Linear projections
        self.query = nn.Linear(config.hidden_size, config.hidden_size)
        self.key = nn.Linear(config.hidden_size, config.hidden_size)
        self.value = nn.Linear(config.hidden_size, config.hidden_size)
        # Dropout
        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        # Head count and dimensions
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = config.hidden_size // config.num_attention_heads
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        # Relative position embeddings (if enabled)
        if self.position_embedding_type in ["relative_key", "relative_key_query"]:
            self.max_position_embeddings = config.max_position_embeddings
            self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
    def transpose_for_scores(self, x):
        """Reshape for multi-head attention: [batch, seq, hidden] -> [batch, heads, seq, head_size]."""
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        """Self-attention forward pass."""
        mixed_query_layer = self.query(hidden_states)
        # Cross-attention takes its keys/values from the encoder states
        if encoder_hidden_states is not None:
            key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
            value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
            attention_mask = encoder_attention_mask
        elif past_key_value is not None:
            # Append the new keys/values to the cache along the sequence dimension
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))
            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
        else:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))
        # Multi-head reshape for the queries
        query_layer = self.transpose_for_scores(mixed_query_layer)
        # Raw attention scores
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        if self.position_embedding_type in ["relative_key", "relative_key_query"]:
            # Relative position attention
            if self.position_embedding_type == "relative_key":
                relative_position_scores = self._compute_relative_key_scores(query_layer, key_layer)
            else:
                relative_position_scores = self._compute_relative_key_query_scores(query_layer, key_layer)
            attention_scores = attention_scores + relative_position_scores
        # Scale
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        # Apply the (additive) attention mask
        if attention_mask is not None:
            attention_scores = attention_scores + attention_mask
        # Attention weights
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)
        # Dropout
        attention_probs = self.dropout(attention_probs)
        # Head mask
        if head_mask is not None:
            attention_probs = attention_probs * head_mask
        # Context vectors
        context_layer = torch.matmul(attention_probs, value_layer)
        # Merge the heads back: [batch, heads, seq, head_size] -> [batch, seq, hidden]
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)
        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
        return outputs
class BertSelfOutput(nn.Module):
    """BERT self-attention output layer."""
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        """Output layer forward pass: projection, dropout, then residual + LayerNorm."""
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states
class BertAttention(nn.Module):
    """BERT attention block (self-attention plus its output layer)."""
    def __init__(self, config, position_embedding_type=None):
        super().__init__()
        self.self = BertSelfAttention(config, position_embedding_type)
        self.output = BertSelfOutput(config)
        self.pruned_heads = set()
    def prune_heads(self, heads):
        """Prune attention heads."""
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
        )
        # Prune the linear projections
        self.self.query = prune_linear_layer(self.self.query, index)
        self.self.key = prune_linear_layer(self.self.key, index)
        self.self.value = prune_linear_layer(self.self.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
        # Update the head count and the set of pruned heads
        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
        self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        """Attention block forward pass."""
        self_outputs = self.self(
            hidden_states,
            attention_mask,
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            past_key_value,
            output_attentions,
        )
        attention_output = self.output(self_outputs[0], hidden_states)
        # Include the attention weights in the outputs if requested
        if output_attentions:
            outputs = (attention_output,) + self_outputs[1:]
        else:
            outputs = (attention_output,)
        return outputs
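The shape bookkeeping that transpose_for_scores performs is easy to verify in isolation; a small standalone sketch:
import torch

batch, seq, hidden, heads = 2, 5, 32, 4
head_size = hidden // heads
x = torch.randn(batch, seq, hidden)
# [batch, seq, hidden] -> [batch, heads, seq, head_size]
x = x.view(batch, seq, heads, head_size).permute(0, 2, 1, 3)
scores = x @ x.transpose(-1, -2) / head_size ** 0.5
print(scores.shape)  # torch.Size([2, 4, 5, 5]) -- one seq x seq map per head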
3.1.3 BERT Intermediate and Output Layers
class BertIntermediate(nn.Module):
    """BERT intermediate (feed-forward) layer."""
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        # Activation function
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Intermediate layer forward pass."""
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        return hidden_states
class BertOutput(nn.Module):
    """BERT output layer."""
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        """Output layer forward pass: projection, dropout, then residual + LayerNorm."""
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states
class BertLayer(nn.Module):
    """A single BERT encoder layer."""
    def __init__(self, config, position_embedding_type=None):
        super().__init__()
        # Attention block
        self.attention = BertAttention(config, position_embedding_type=position_embedding_type)
        # Decoder flag
        self.is_decoder = config.is_decoder
        # Cross-attention (decoder only)
        if self.is_decoder:
            self.crossattention = BertAttention(config, position_embedding_type=position_embedding_type)
        # Intermediate and output layers
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        """Encoder layer forward pass."""
        # Self-attention
        self_attention_outputs = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
            past_key_value=past_key_value,
        )
        attention_output = self_attention_outputs[0]
        # Apply cross-attention if this is a decoder and encoder states are given
        if self.is_decoder and encoder_hidden_states is not None:
            cross_attention_outputs = self.crossattention(
                attention_output,
                attention_mask,
                head_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                output_attentions=output_attentions,
            )
            attention_output = cross_attention_outputs[0]
            outputs = self_attention_outputs[1:] + cross_attention_outputs[1:]
        else:
            outputs = self_attention_outputs[1:]
        # Feed-forward network
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        outputs = (layer_output,) + outputs
        return outputs
class BertEncoder(nn.Module):
    """The BERT encoder: a stack of BertLayer modules."""
    def __init__(self, config, position_embedding_type=None):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList([BertLayer(config, position_embedding_type) for _ in range(config.num_hidden_layers)])
        # Gradient checkpointing flag
        self.gradient_checkpointing = False
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = False,
        output_hidden_states: Optional[bool] = False,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
        """Encoder forward pass."""
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        all_cross_attentions = () if output_attentions and self.config.is_decoder else None
        next_decoder_cache = () if use_cache else None
        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)
            # Per-layer head mask and cache entry (either may be absent)
            layer_head_mask = head_mask[i] if head_mask is not None else None
            past_key_value = past_key_values[i] if past_key_values is not None else None
            # Gradient checkpointing
            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpoint(
                    layer_module,
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    past_key_value,
                    output_attentions,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    past_key_value,
                    output_attentions,
                )
            hidden_states = layer_outputs[0]
            if use_cache:
                # The layer's cached key/value pair (returned last in the full implementation)
                next_decoder_cache += (layer_outputs[-1],)
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)
                if self.config.is_decoder:
                    all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)
        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    next_decoder_cache,
                    all_hidden_states,
                    all_self_attentions,
                    all_cross_attentions,
                ]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_decoder_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )
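The BertIntermediate/BertOutput pair is simply a feed-forward block wrapped in a post-norm residual connection; the pattern in isolation:
import torch
import torch.nn as nn

hidden, intermediate = 32, 64
x = torch.randn(2, 5, hidden)
ffn = nn.Sequential(nn.Linear(hidden, intermediate), nn.GELU(),
                    nn.Linear(intermediate, hidden))
norm = nn.LayerNorm(hidden)
out = norm(ffn(x) + x)  # post-norm residual, as in BertOutput
print(out.shape)        # torch.Size([2, 5, 32])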
3.2 Task-Specific Model Extensions
3.2.1 Sequence Classification Model
class BertForSequenceClassification(BertPreTrainedModel):
    """BERT model for sequence classification."""
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        # The underlying BERT model
        self.bert = BertModel(config)
        # Classification head
        classifier_dropout = (
            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        # Initialize the weights
        self.post_init()
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
        """Sequence classification forward pass."""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # BERT forward pass
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # Classify on the pooled [CLS] representation
        pooled_output = outputs[1]
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        # Compute the loss
        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"
            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)
        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output
        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
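A minimal fine-tuning-style forward pass with the public API (the classification head starts randomly initialized, so a warning about newly initialized weights is expected):
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)
inputs = tokenizer("The movie was great!", return_tensors="pt")
outputs = model(**inputs, labels=torch.tensor([1]))
print(outputs.loss, outputs.logits.shape)  # scalar loss, torch.Size([1, 2])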
3.2.2 Token Classification Model
class BertForTokenClassification(BertPreTrainedModel):
    """BERT model for token classification (e.g. NER)."""
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        # The underlying BERT model (no pooling layer needed)
        self.bert = BertModel(config, add_pooling_layer=False)
        # Classification head
        classifier_dropout = (
            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        # Initialize the weights
        self.post_init()
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
        """Token classification forward pass."""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # BERT forward pass (no pooler_output)
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # Classify every token representation
        sequence_output = outputs[0]
        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)
        # Compute the loss
        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output
        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
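In practice, positions that should not contribute to the token-level loss (padding, sub-word continuations) are labeled -100, which CrossEntropyLoss ignores by default:
import torch
from torch.nn import CrossEntropyLoss

logits = torch.randn(1, 4, 3)                # batch 1, 4 tokens, 3 labels
labels = torch.tensor([[2, 0, -100, -100]])  # last two positions masked out
loss = CrossEntropyLoss()(logits.view(-1, 3), labels.view(-1))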
4. Automatic Model Loading System
4.1 Auto Class System Architecture
class AutoModel:
    """Base class for automatic model loading (simplified sketch of the real implementation)"""
    # Mapping from configuration classes to model classes
    _model_mapping = MODEL_MAPPING
    @classmethod
    def from_config(cls, config, **kwargs):
        """Instantiate a model from a configuration."""
        # 1. Look up the model class registered for this configuration type
        model_class = cls._model_mapping[type(config)]
        # 2. Build the model
        return model_class._from_config(config, **kwargs)
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Load from a pretrained checkpoint."""
        # 1. Load the configuration
        config = kwargs.pop("config", None)
        if config is None:
            config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
        # 2. Look up the model class
        model_class = cls._model_mapping[type(config)]
        # 3. Load the model
        return model_class.from_pretrained(
            pretrained_model_name_or_path,
            *model_args,
            config=config,
            **kwargs
        )
# Task-specific Auto classes
class AutoModelForSequenceClassification(AutoModel):
    """Automatic sequence classification model"""
    _model_mapping = MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING
class AutoModelForTokenClassification(AutoModel):
    """Automatic token classification model"""
    _model_mapping = MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
class AutoModelForQuestionAnswering(AutoModel):
    """Automatic question answering model"""
    _model_mapping = MODEL_FOR_QUESTION_ANSWERING_MAPPING
4.2 Model Registration Mechanism
class ModelRegistry:
    """Model registration system (illustrative; see the official register API below)"""
    _models = {}
    _configs = {}
    @classmethod
    def register(cls, name: str, model_class: Type[PreTrainedModel], config_class: Type[PreTrainedConfig]):
        """Register a new model."""
        cls._models[name] = model_class
        cls._configs[name] = config_class
        # Update the Auto-class mapping
        if hasattr(model_class, 'config_class'):
            MODEL_MAPPING.register(config_class, model_class)
    @classmethod
    def get_model_class(cls, name: str) -> Type[PreTrainedModel]:
        """Fetch a model class by name."""
        return cls._models.get(name)
    @classmethod
    def get_config_class(cls, name: str) -> Type[PreTrainedConfig]:
        """Fetch a configuration class by name."""
        return cls._configs.get(name)
    @classmethod
    def list_models(cls) -> List[str]:
        """List all registered models."""
        return list(cls._models.keys())
# Decorator-based registration
def register_model(name: str):
    """Model registration decorator."""
    def decorator(model_class):
        config_class = model_class.config_class
        ModelRegistry.register(name, model_class, config_class)
        return model_class
    return decorator
# Usage example
@register_model("my_custom_model")
class MyCustomModel(PreTrainedModel):
    config_class = MyCustomConfig
    def __init__(self, config):
        super().__init__(config)
        # model implementation
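The released library exposes an official path for this through the Auto classes themselves; a minimal sketch using the public register API (MyCustomConfig and MyCustomModel are hypothetical):
from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel

class MyCustomConfig(PretrainedConfig):
    model_type = "my_custom_model"

class MyCustomModel(PreTrainedModel):
    config_class = MyCustomConfig

AutoConfig.register("my_custom_model", MyCustomConfig)
AutoModel.register(MyCustomConfig, MyCustomModel)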
5. Advanced Features and Optimizations
5.1 Model Quantization and Compression
class QuantizationMixin:
    """Model quantization mixin (illustrative sketch; QuantizationConfig is hypothetical)"""
    def quantize(self, quantization_config: QuantizationConfig):
        """Quantize the model."""
        if quantization_config.quant_method == "static":
            return self._static_quantization(quantization_config)
        elif quantization_config.quant_method == "dynamic":
            return self._dynamic_quantization(quantization_config)
        elif quantization_config.quant_method == "qat":
            return self._quantization_aware_training(quantization_config)
        else:
            raise ValueError(f"Unsupported quantization method: {quantization_config.quant_method}")
    def _static_quantization(self, config: QuantizationConfig):
        """Static quantization."""
        # 1. Prepare the calibration data
        calibration_data = self._prepare_calibration_data(config.calibration_dataset)
        # 2. Insert observers
        self._insert_observers(config)
        # 3. Calibrate
        self._calibrate(calibration_data)
        # 4. Convert to a quantized model
        quantized_model = torch.quantization.convert(self.eval(), inplace=config.inplace)
        return quantized_model
    def _dynamic_quantization(self, config: QuantizationConfig):
        """Dynamic quantization."""
        # Quantize the specified module types on the fly
        quantized_model = torch.quantization.quantize_dynamic(
            self,
            {nn.Linear, nn.LSTM, nn.GRU},
            dtype=torch.qint8,
            inplace=config.inplace
        )
        return quantized_model
    def _quantization_aware_training(self, config: QuantizationConfig):
        """Quantization-aware training."""
        # 1. Put the model in training mode
        self.train()
        # 2. Insert fake-quantization nodes
        self._insert_fake_quant(config)
        # 3. Convert to a QAT model
        qat_model = torch.quantization.prepare_qat(self, inplace=config.inplace)
        return qat_model
class PruningMixin:
    """Model pruning mixin (illustrative sketch; PruningConfig is hypothetical)"""
    def prune(self, pruning_config: PruningConfig):
        """Prune the model."""
        if pruning_config.pruning_method == "magnitude":
            return self._magnitude_pruning(pruning_config)
        elif pruning_config.pruning_method == "structured":
            return self._structured_pruning(pruning_config)
        elif pruning_config.pruning_method == "gradual":
            return self._gradual_pruning(pruning_config)
        else:
            raise ValueError(f"Unsupported pruning method: {pruning_config.pruning_method}")
    def _magnitude_pruning(self, config: PruningConfig):
        """Magnitude-based weight pruning."""
        # 1. Compute per-weight importance scores
        importance_scores = {}
        for name, param in self.named_parameters():
            if "weight" in name and param.dim() > 1:
                importance_scores[name] = torch.abs(param)
        # 2. Determine the pruning threshold
        all_scores = torch.cat([score.flatten() for score in importance_scores.values()])
        threshold = torch.kthvalue(all_scores, int(len(all_scores) * config.sparsity)).values
        # 3. Build the pruning masks
        mask_dict = {}
        for name, scores in importance_scores.items():
            mask = scores > threshold
            mask_dict[name] = mask.to(torch.float32)
        # 4. Apply the masks
        self._apply_pruning_masks(mask_dict, inplace=config.inplace)
        return self
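PyTorch ships magnitude pruning out of the box in torch.nn.utils.prune, which is often simpler than a hand-rolled mixin; a minimal sketch pruning 30% of a linear layer's weights by L1 magnitude:
import torch.nn as nn
import torch.nn.utils.prune as prune

layer = nn.Linear(128, 128)
prune.l1_unstructured(layer, name="weight", amount=0.3)  # mask the 30% smallest weights
print(float((layer.weight == 0).float().mean()))         # ~0.3 sparsity
prune.remove(layer, "weight")                            # make the pruning permanent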
5.2 Memory Optimization Techniques
class MemoryOptimizationMixin:
    """Memory optimization mixin (illustrative sketch)"""
    def enable_gradient_checkpointing(self):
        """Enable gradient checkpointing."""
        if not self.supports_gradient_checkpointing:
            raise ValueError("Model does not support gradient checkpointing")
        # Enable checkpointing recursively
        def enable_checkpointing_recursive(module):
            if hasattr(module, "gradient_checkpointing_enable"):
                module.gradient_checkpointing_enable()
            for child in module.children():
                enable_checkpointing_recursive(child)
        enable_checkpointing_recursive(self)
    def optimize_for_inference(self):
        """Optimize the model for inference."""
        # 1. Switch to eval mode
        self.eval()
        # 2. Fuse operations
        self._fuse_modules()
        # 3. Optimize the memory layout
        self._optimize_memory_layout()
        # 4. JIT compile (when available)
        if hasattr(torch.jit, "optimize_for_inference"):
            return torch.jit.optimize_for_inference(torch.jit.script(self))
        return self
    def _fuse_modules(self):
        """Fuse adjacent modules to reduce memory traffic."""
        # Fusable module patterns
        fusion_patterns = [
            ["Linear", "BatchNorm1d"],
            ["Conv2d", "BatchNorm2d", "ReLU"],
            ["Linear", "ReLU"],
            ["Conv2d", "ReLU"],
        ]
        # Apply the fusions
        for pattern in fusion_patterns:
            self._apply_fusion_pattern(pattern)
    def _optimize_memory_layout(self):
        """Optimize the memory layout."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                # Convert 4-D convolution weights to channels_last (CNNs only)
                module.weight.data = module.weight.data.to(memory_format=torch.channels_last)
            if isinstance(module, (nn.Linear, nn.Conv2d)):
                # Make sure the weights are contiguous
                if not module.weight.is_contiguous():
                    module.weight.data = module.weight.data.contiguous()
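In the released library, gradient checkpointing is toggled through a single public method on any supporting model:
from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)
model.gradient_checkpointing_enable()  # trade extra compute for activation memory
model.train()
# forward passes now recompute intermediate activations during backward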
5.3 Distributed Model Support
class DistributedModelMixin:
    """Distributed model mixin (illustrative sketch)"""
    def prepare_for_distributed(self, strategy: str = "ddp"):
        """Prepare the model for distributed training."""
        if strategy == "ddp":
            return self._prepare_for_ddp()
        elif strategy == "deepspeed":
            return self._prepare_for_deepspeed()
        elif strategy == "fsdp":
            return self._prepare_for_fsdp()
        else:
            raise ValueError(f"Unsupported distributed strategy: {strategy}")
    def _prepare_for_ddp(self):
        """Prepare the model for DDP."""
        import torch.distributed as dist
        if not dist.is_initialized():
            raise RuntimeError("Distributed training not initialized")
        # Wrap the model (assumes one process per GPU so that rank == device index;
        # multi-node setups should use the local rank instead)
        ddp_model = torch.nn.parallel.DistributedDataParallel(
            self,
            device_ids=[dist.get_rank()],
            output_device=dist.get_rank(),
            find_unused_parameters=getattr(self.config, "ddp_find_unused_parameters", False)
        )
        return ddp_model
    def _prepare_for_deepspeed(self):
        """Prepare the model for DeepSpeed."""
        # Usually driven by the Trainer; this is only a placeholder --
        # the actual DeepSpeed initialization happens inside the Trainer
        pass
    def _prepare_for_fsdp(self):
        """Prepare the model for FSDP."""
        # Full FSDP integration requires Accelerate;
        # only the basic interface is provided here
        pass
    def shard_model(self, shard_strategy: str = "zero2"):
        """Shard the model."""
        if shard_strategy == "zero1":
            return self._zero1_sharding()
        elif shard_strategy == "zero2":
            return self._zero2_sharding()
        elif shard_strategy == "zero3":
            return self._zero3_sharding()
        else:
            raise ValueError(f"Unsupported shard strategy: {shard_strategy}")
    def _zero2_sharding(self):
        """ZeRO-2 sharding strategy."""
        # Shard the gradient state
        for param in self.parameters():
            if param.requires_grad:
                param.data = param.data.detach().contiguous()
                if hasattr(param, "grad") and param.grad is not None:
                    param.grad = param.grad.detach().contiguous()
        return self
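For comparison, a minimal DDP launch sketch (run with torchrun --nproc_per_node=N train.py, which sets LOCAL_RANK; assumes one process per GPU):
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from transformers import AutoModelForSequenceClassification

dist.init_process_group("nccl")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)
model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2).cuda()
model = DDP(model, device_ids=[local_rank])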
6. Text Generation System
6.1 Generation Configuration and Interface
class GenerationMixin:
    """Generation mixin that gives every model text-generation capabilities (simplified sketch)"""
    @staticmethod
    def _get_generation_mode(
        assistant_model: Optional["PreTrainedModel"] = None,
        input_ids: Optional[torch.LongTensor] = None,
        **kwargs
    ) -> GenerationMode:
        """Determine the generation mode."""
        if kwargs.get("num_beams", 1) > 1:
            # Beam search, with or without sampling
            return GenerationMode.BEAM_SAMPLE if kwargs.get("do_sample") else GenerationMode.BEAM_SEARCH
        if assistant_model is not None:
            return GenerationMode.ASSISTED_GENERATION
        if kwargs.get("do_sample"):
            return GenerationMode.SAMPLE
        return GenerationMode.GREEDY_SEARCH
    def generate(
        self,
        input_ids: torch.LongTensor,
        generation_config: Optional[GenerationConfig] = None,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
        synced_gpus: Optional[bool] = None,
        assistant_model: Optional["PreTrainedModel"] = None,
        streamer: Optional["BaseStreamer"] = None,
        **kwargs,
    ) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
        """Main generation entry point."""
        # 1. Prepare the generation configuration
        generation_config = self._prepare_generation_config(generation_config, **kwargs)
        # 2. Determine the generation mode
        generation_mode = self._get_generation_mode(
            assistant_model=assistant_model,
            input_ids=input_ids,
            **generation_config.to_dict()
        )
        # 3. Prepare the inputs
        model_inputs = self._prepare_model_inputs(input_ids, generation_config.bos_token_id)
        # 4. Prepare the logits processors and stopping criteria
        logits_processor = self._get_logits_processor(
            generation_config,
            input_ids_length=model_inputs["input_ids"].shape[-1],
            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
            logits_processor=logits_processor,
        )
        stopping_criteria = self._get_stopping_criteria(
            generation_config, stopping_criteria
        )
        # 5. Run the selected decoding loop
        if generation_mode == GenerationMode.GREEDY_SEARCH:
            return self._greedy_search(
                input_ids,
                logits_processor,
                stopping_criteria,
                generation_config,
                **model_inputs,
            )
        elif generation_mode == GenerationMode.SAMPLE:
            return self._sample(
                input_ids,
                logits_processor,
                stopping_criteria,
                generation_config,
                **model_inputs,
            )
        elif generation_mode == GenerationMode.BEAM_SEARCH:
            return self._beam_search(
                input_ids,
                logits_processor,
                stopping_criteria,
                generation_config,
                **model_inputs,
            )
        elif generation_mode == GenerationMode.ASSISTED_GENERATION:
            return self._assisted_generation(
                input_ids,
                assistant_model,
                logits_processor,
                stopping_criteria,
                generation_config,
                **model_inputs,
            )
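From the caller's side, all of this is reached through a single call; a minimal sketch using the public API:
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("The Transformers library", return_tensors="pt")
gen_config = GenerationConfig(max_new_tokens=20, do_sample=True,
                              top_p=0.9, temperature=0.8)
output_ids = model.generate(**inputs, generation_config=gen_config)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))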
6.2 Advanced Generation Strategies
import torch
import torch.nn.functional as F

class AdvancedGenerationStrategies:
    """Advanced generation strategies."""
    @staticmethod
    def nucleus_sampling(logits: torch.Tensor, p: float) -> torch.Tensor:
        """Nucleus (top-p) sampling filter."""
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        # Mask tokens once the cumulative probability exceeds p
        sorted_indices_to_remove = cumulative_probs > p
        # Shift right so the first token above the threshold is kept
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
        logits[indices_to_remove] = float("-inf")
        return logits
    @staticmethod
    def top_k_sampling(logits: torch.Tensor, k: int) -> torch.Tensor:
        """Top-k sampling filter."""
        if k == 0:
            return logits
        top_k_logits, top_k_indices = torch.topk(logits, k)
        # Mask everything below the k-th largest logit
        indices_to_remove = logits < top_k_logits[:, -1:]
        logits[indices_to_remove] = float("-inf")
        return logits
    @staticmethod
    def temperature_scaling(logits: torch.Tensor, temperature: float) -> torch.Tensor:
        """Temperature scaling."""
        return logits / temperature
    @staticmethod
    def repetition_penalty(
        logits: torch.Tensor,
        input_ids: torch.Tensor,
        penalty: float
    ) -> torch.Tensor:
        """Repetition penalty."""
        score = torch.gather(logits, 1, input_ids)
        # Penalize tokens that were already generated
        if penalty != 1.0:
            score = torch.where(
                score < 0,
                score * penalty,
                score / penalty
            )
        logits.scatter_(1, input_ids, score)
        return logits
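Chained together, these filters implement one step of a typical sampling loop; a minimal sketch using the class above:
import torch
import torch.nn.functional as F

logits = torch.randn(1, 50257)  # fake next-token logits for a GPT-2-sized vocabulary
logits = AdvancedGenerationStrategies.temperature_scaling(logits, 0.8)
logits = AdvancedGenerationStrategies.top_k_sampling(logits, 50)
logits = AdvancedGenerationStrategies.nucleus_sampling(logits, 0.9)
next_token = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)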
7. Model Module Summary and Outlook
7.1 Architectural Strengths
The Transformers model module embodies best practices for implementing deep learning models:
1. Strong standardization: unified interfaces, configurations, and implementation patterns keep the codebase consistent
2. Modular design: clean component separation makes maintenance and extension easy
3. Configuration-driven: configuration files fully control model behavior and hyperparameters
4. Task extensibility: base models extend easily into a wide range of downstream tasks
5. Automatic loading: the Auto-class system makes model loading seamless
6. Performance optimization: quantization, pruning, and memory optimizations are built in
7. Distributed support: multiple distributed training strategies are supported natively
7.2 Technical Innovations
1. Unified abstraction: the PreTrainedModel base class unifies the interface of every model
2. Dynamic configuration: the configuration system supports dynamic parameters and version management
3. Automatic discovery: the Auto-class system recognizes and loads models automatically
4. Gradient checkpointing: built-in support reduces memory consumption
5. Generation system: a unified generation interface supports many decoding strategies
6. Multimodal support: a single interface covers text, image, and speech models
7.3 Future Directions
1. Larger models: better support for trillion-parameter models
2. More modalities: support for emerging modalities such as video, 3D, and graphs
3. Edge optimization: dedicated optimizations for mobile and edge devices
4. Automated modeling: automatic architecture search and optimization
5. Green AI: more efficient energy and resource usage
7.4 Best-Practice Recommendations
1. Follow the standard: implement new models strictly according to the standardized pattern
2. Stay configuration-driven: control all model behavior through the configuration
3. Test thoroughly: ensure solid unit and integration test coverage
4. Document well: provide detailed API documentation and usage examples
5. Use the built-in optimizations: take full advantage of the bundled performance features
6. Mind versioning: watch out for weight and version compatibility
Through its architectural design and rich feature set, the Transformers model module provides a powerful, flexible foundation for implementing deep learning models. It dramatically simplifies the development and integration of new models and has played an important role in the rapid spread and adoption of AI technology.
