# architectures/base_model.py
import torch
import inspect
import torch.nn as nn
from abc import ABC, abstractmethod
from components import *
import os
from datetime import datetime
class BaseMultiArchModel(ABC, nn.Module):
    """Base class for multi-architecture models.

    On construction it:

    * builds the task-specific embedding via ``_create_embedding()``;
    * instantiates every class registered in ``ALLOWED_COMPONENTS`` and stores
      each instance as an attribute named after its whitelist key
      (e.g. ``self.rnn_GRU``);
    * binds the shortcut attributes ``attention``, ``norm``, ``cnn``, ``rnn``,
      ``transformer`` and ``fusion`` to one of those instances, as selected by
      ``model_config["default_components"]``;
    * creates ``GENE_SNAPSHOT_DIR``.

    Subclasses must implement ``_create_embedding``, ``_create_fusion`` and
    ``forward``.
    """

    # Component whitelist. Once the core is finalized, only its owner should
    # edit it. Register concrete implementation classes only, never abstract
    # base classes (this is checked at construction time).
    # NOTE(review): this is a plain mutable dict; nothing here enforces the
    # "owner only" policy. Keys are referenced by configs and snapshots, so do
    # not rename them (including the "fusion_Conca" spelling).
    ALLOWED_COMPONENTS = {
        "embedding_Standard": StandardEmbedding,
        "embedding_Position": PositionalEmbedding,
        "attention_Self": SelfAttentionModule,
        "attention_Dynamic": DynamicAttention,
        "norm_Layer": LayerNormModule,
        "norm_Dynamic": DynamicNormModule,
        "cnn_Conv1d": Conv1dModule,
        "cnn_MultiK": MultiKernelCNN,
        "rnn_GRU": GRUModule,
        "rnn_LSTM": LSTMModule,
        "transformer_Transformer": TransformerModule,
        "transformer_Lightweight": LightweightTransformer,
        "fusion_Conca": ConcatFusion,
        "fusion_Attention": AttentionFusion,
    }

    # Directory for weight snapshots ("gene snapshots"). A relative path is
    # resolved against the process's current working directory.
    GENE_SNAPSHOT_DIR = "gene_snapshots/"

    # Shortcut attributes written to / read from snapshots in addition to every
    # whitelisted component.
    _SNAPSHOT_SHORTCUTS = ("embedding", "attention", "norm", "cnn", "rnn", "transformer", "fusion")

    def __init__(self, model_config):
        """
        Args:
            model_config: dict-like config. Keys read by this class:
                ``"model_components"`` (per-component kwargs, optional),
                ``"default_components"`` (shortcut -> whitelist key, optional),
                and ``"embedding_dim"`` / ``"hidden_dim"`` in the ``_create_*``
                helpers.
        """
        super().__init__()
        self.model_config = model_config
        self._init_components()
        self._ensure_snapshot_dir()

    @staticmethod
    def _check_component_class(comp_type, comp_class):
        """Raise ``TypeError`` unless ``comp_class`` is a concrete ``BaseComponent`` subclass.

        Explicit raises are used instead of ``assert`` because asserts are
        removed under ``python -O``, which would silently disable the check.
        """
        if not (isinstance(comp_class, type) and issubclass(comp_class, BaseComponent)):
            raise TypeError(f"非法组件:{comp_type}")
        if inspect.isabstract(comp_class):
            raise TypeError(f"非法组件:{comp_type} 是抽象类")

    def _init_components(self):
        """Create the embedding, every whitelisted component, and the shortcuts.

        Each whitelisted component is built with the kwargs found under
        ``model_config["model_components"][<key>]`` (no kwargs when absent).

        Raises:
            TypeError: a whitelist entry is not a concrete ``BaseComponent``.
            KeyError: ``default_components`` names something that is not a
                whitelisted component.
        """
        # Task-specific embedding (implemented by the subclass).
        self.embedding = self._create_embedding()

        components_cfg = self.model_config.get("model_components", {})
        # Iterate the whitelist itself so this stays in sync with the snapshot
        # save/load code and with subclasses that override ALLOWED_COMPONENTS.
        # NOTE: every registered component is instantiated (and its parameters
        # registered on this module) even if no shortcut ends up using it.
        for comp_type, comp_class in self.ALLOWED_COMPONENTS.items():
            self._check_component_class(comp_type, comp_class)
            setattr(self, comp_type, comp_class(**components_cfg.get(comp_type, {})))

        # Shortcut attribute -> whitelist key used when the config does not
        # explicitly choose one via model_config["default_components"].
        default_mapping = {
            "attention": "attention_Self",
            "norm": "norm_Layer",
            "cnn": "cnn_Conv1d",
            "rnn": "rnn_GRU",
            "transformer": "transformer_Transformer",
            "fusion": "fusion_Conca",
        }
        default_components = self.model_config.get("default_components", {})
        for attr_name, fallback_key in default_mapping.items():
            config_key = default_components.get(attr_name, fallback_key)
            # Require a whitelist key: a bare hasattr() would also accept
            # unrelated attributes such as "model_config" or "embedding".
            if config_key not in self.ALLOWED_COMPONENTS or not hasattr(self, config_key):
                raise KeyError(f"默认组件 {config_key} 未在组件初始化中创建,请检查配置!")
            # e.g. self.attention = self.attention_Self (same module object).
            setattr(self, attr_name, getattr(self, config_key))

    def _ensure_snapshot_dir(self):
        """Create ``GENE_SNAPSHOT_DIR`` if it does not exist yet."""
        # exist_ok avoids the check-then-create race when several processes
        # construct models concurrently.
        os.makedirs(self.GENE_SNAPSHOT_DIR, exist_ok=True)

    @abstractmethod
    def _create_embedding(self):
        """Create and return the task-specific embedding module."""

    @abstractmethod
    def _create_fusion(self):
        """Create and return the feature-fusion module.

        NOTE(review): this base class never calls ``_create_fusion``;
        ``self.fusion`` is bound from the whitelist in ``_init_components``.
        """

    # Builders for individually configured components.
    # NOTE(review): none of these are called by this class, and they choose the
    # component via model_components["<x>_type"], while _init_components uses
    # model_config["default_components"]; the two selections can disagree.
    def _build_selected_component(self, selector_key, default_type, **dims):
        """Instantiate the whitelisted component chosen by ``model_components[selector_key]``.

        ``dims`` are passed together with that component's own kwargs from
        ``model_components[<chosen type>]``; a key present in both raises
        ``TypeError`` (duplicate keyword argument).
        """
        components_cfg = self.model_config.get("model_components", {})
        comp_type = components_cfg.get(selector_key, default_type)
        comp_class = self.ALLOWED_COMPONENTS[comp_type]
        self._check_component_class(comp_type, comp_class)
        return comp_class(**dims, **components_cfg.get(comp_type, {}))

    def _create_attention(self):
        """Build the attention component selected by ``attention_type``."""
        return self._build_selected_component(
            "attention_type", "attention_Self",
            embed_dim=self.model_config["embedding_dim"],
        )

    def _create_normalization(self):
        """Build the normalization component selected by ``norm_type``."""
        return self._build_selected_component(
            "norm_type", "norm_Layer",
            num_features=self.model_config["embedding_dim"],
        )

    def _create_cnn(self):
        """Build the CNN component selected by ``cnn_type``."""
        return self._build_selected_component(
            "cnn_type", "cnn_Conv1d",
            in_channels=self.model_config["embedding_dim"],
            out_channels=self.model_config["hidden_dim"],
        )

    def _create_rnn(self):
        """Build the RNN component selected by ``rnn_type``."""
        return self._build_selected_component(
            "rnn_type", "rnn_GRU",
            input_size=self.model_config["embedding_dim"],
            hidden_size=self.model_config["hidden_dim"],
        )

    def _create_transformer(self):
        """Build the Transformer component selected by ``transformer_type``."""
        return self._build_selected_component(
            "transformer_type", "transformer_Transformer",
            d_model=self.model_config["embedding_dim"],
        )

    def forward_features(self, x):
        """Shared feature-extraction pipeline.

        Runs embedding -> norm -> attention. It then feeds the result to the
        CNN, RNN and Transformer branches independently. Their non-None outputs
        go to ``self.fusion`` as a dict keyed ``"cnn"`` / ``"rnn"`` /
        ``"transformer"``. Branches whose attribute is missing are skipped.

        Returns:
            Whatever ``self.fusion`` returns.
        """
        x_embed = self.embedding(x)
        x_embed = self.norm(x_embed) if hasattr(self, "norm") else x_embed
        # NOTE(review): assumes the attention/RNN components return a single
        # tensor (torch's nn.MultiheadAttention / nn.GRU return tuples) — confirm.
        x_embed = self.attention(x_embed) if hasattr(self, "attention") else x_embed

        # Dims 1 and 2 are swapped around the CNN; presumably x_embed is
        # (batch, seq, dim) and the conv expects (batch, channels, seq) — TODO confirm.
        cnn_out = self.cnn(x_embed.transpose(1, 2)).transpose(1, 2) if hasattr(self, "cnn") else None
        rnn_out = self.rnn(x_embed) if hasattr(self, "rnn") else None
        trans_out = self.transformer(x_embed) if hasattr(self, "transformer") else None

        fusion_inputs = {
            name: out
            for name, out in (("cnn", cnn_out), ("rnn", rnn_out), ("transformer", trans_out))
            if out is not None
        }
        return self.fusion(fusion_inputs)

    @abstractmethod
    def forward(self, x):
        """Main forward pass (implemented by subclasses)."""

    def save_gene_snapshot(self, snapshot_name=None, evolution_note=""):
        """Save the config and component weights ("gene snapshot") and log the change.

        Args:
            snapshot_name: file stem. Defaults to the current local time as
                ``YYYYmmdd_HHMMSS``, so two saves in the same second overwrite
                each other.
            evolution_note: free-text reason appended to ``evolution_log.md``.

        Returns:
            Path of the written ``.pth`` file.
        """
        # One timestamp so the default name and the logged time agree.
        now = datetime.now()
        snapshot_name = snapshot_name or now.strftime("%Y%m%d_%H%M%S")
        snapshot_path = os.path.join(self.GENE_SNAPSHOT_DIR, f"{snapshot_name}.pth")
        # Recreate the directory in case it was removed after __init__.
        self._ensure_snapshot_dir()

        snapshot = {"model_config": self.model_config}
        # Except for "embedding", shortcut entries alias whitelisted modules, so
        # those weights appear under two keys. Both are kept so the file format
        # is unchanged.
        for attr in (*self._SNAPSHOT_SHORTCUTS, *self.ALLOWED_COMPONENTS):
            if hasattr(self, attr):
                snapshot[attr] = getattr(self, attr).state_dict()
        torch.save(snapshot, snapshot_path)

        # Explicit utf-8: the log text is Chinese and the platform default
        # encoding (e.g. cp1252 on Windows) may be unable to encode it.
        log_path = os.path.join(self.GENE_SNAPSHOT_DIR, "evolution_log.md")
        with open(log_path, "a", encoding="utf-8") as f:
            f.write(
                f"## {snapshot_name}\n"
                f"变更原因:{evolution_note}\n"
                f"时间:{now.strftime('%Y-%m-%d %H:%M:%S')}\n\n"
            )
        return snapshot_path

    def load_gene_snapshot(self, snapshot_path):
        """Restore weights and the stored config from a gene snapshot (revive / rollback).

        Only modules that exist on this instance and have an entry in the
        snapshot are loaded; other entries are ignored. Modules are NOT rebuilt
        from the stored config, so the snapshot must come from an identically
        configured model, or ``load_state_dict`` raises on mismatched keys/shapes.

        Raises:
            KeyError: the snapshot has no ``"model_config"`` entry.
        """
        # SECURITY: torch.load unpickles the file (fully so on PyTorch versions
        # whose weights_only default is False); only load trusted snapshots.
        # map_location="cpu" lets GPU-saved snapshots load on CPU-only hosts;
        # load_state_dict copies into the existing parameters, whose devices
        # are therefore unchanged.
        snapshot = torch.load(snapshot_path, map_location="cpu")
        # Read the config first so a malformed file fails before any weights
        # are touched, but install it only after every module loaded cleanly.
        new_config = snapshot["model_config"]
        for attr in (*self._SNAPSHOT_SHORTCUTS, *self.ALLOWED_COMPONENTS):
            if attr in snapshot and hasattr(self, attr):
                getattr(self, attr).load_state_dict(snapshot[attr])
        self.model_config = new_config
        print(f"[萧默芯] 成功加载基因快照: {snapshot_path}")