import torch
import re
import numpy as np
from typing import List, Tuple, Dict, Any
from transformers import (
    AutoTokenizer,
    PreTrainedModel,
    AutoConfig,
    LlamaForCausalLM,
    GenerationConfig,
)
import torch.nn as nn
from tqdm import tqdm
from collections import defaultdict
import pandas as pd
# --------------------------
# 1. Constants and preprocessing functions (new data-handling scheme)
# --------------------------
VALID_ELEMENTS = ["C", "N", "P", "O", "S", "Si", "I", "H", "Cl", "F", "Br", "B", "Se", "Fe", "Co", "As", "K", "Na"]
element_to_idx = {elem: idx for idx, elem in enumerate(VALID_ELEMENTS)}
CHEM_FORMULA_SIZE = r"([A-Z][a-z]*)([0-9]*)"
# Molecular-formula parser (newly added)
def parse_chem_formula(formula):
    pattern = r'([A-Z][a-z]?)(\d*)'
    matches = re.findall(pattern, formula)
    element_counts = defaultdict(int)
    for (element, count) in matches:
        count = int(count) if count else 1
        element_counts[element] += count
    return element_counts
def generate_element_list(formula):
    element_counts = parse_chem_formula(formula)
    elements = []
    for element, count in element_counts.items():
        # Skip hydrogen
        if element != "H":
            elements.extend([element] * count)
    return ''.join(elements)
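# Minimal sanity check for the two helpers above (the formula is just an
# illustrative example, not tied to any dataset):
assert parse_chem_formula("C9H9N3O2S2") == {"C": 9, "H": 9, "N": 3, "O": 2, "S": 2}
assert generate_element_list("C9H9N3O2S2") == "CCCCCCCCC" + "NNN" + "OO" + "SS"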
# Chemical formula -> dense vector
def formula_to_dense(chem_formula: str) -> torch.Tensor:
    dense_vec = torch.zeros(len(VALID_ELEMENTS), dtype=torch.float32)
    matches = re.findall(CHEM_FORMULA_SIZE, chem_formula)
    for chem_symbol, num_str in matches:
        num = 1 if num_str == "" else int(num_str)
        if chem_symbol in element_to_idx:
            idx = element_to_idx[chem_symbol]
            dense_vec[idx] += num
    return dense_vec
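# Hedged example: for "CH4" the vector holds C = 1 and H = 4 at their
# VALID_ELEMENTS indices; all other entries stay zero.
_v = formula_to_dense("CH4")
assert _v[element_to_idx["C"]] == 1 and _v[element_to_idx["H"]] == 4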
# Positional-encoding generator (PyTorch implementation)
def positional_encoding(max_position: int, d_model: int, min_freq: float = 1e-4) -> torch.Tensor:
    position = torch.arange(max_position).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, d_model, 2) * (-torch.log(torch.tensor(min_freq)) / d_model))
    pos_enc = torch.zeros(max_position, d_model)
    pos_enc[:, 0::2] = torch.sin(position * div_term)
    pos_enc[:, 1::2] = torch.cos(position * div_term)
    return pos_enc
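# Quick property check (a sketch): row 0 is sin(0) = 0 in the even columns and
# cos(0) = 1 in the odd columns.
_pe = positional_encoding(4, 8)
assert _pe.shape == (4, 8)
assert torch.all(_pe[0, 0::2] == 0) and torch.all(_pe[0, 1::2] == 1)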
# Initialize the positional-encoding matrix
P = positional_encoding(2000000, 254)
dimn = 254  # must match the positional-encoding dimension
# Mass-spectrum encoding - optimized for short inputs: only truncate overly long data, never pad short data
def encode_spectra(rag_tensor: list, P: torch.Tensor, dimn: int) -> list:  # returns a list rather than a stacked tensor
    encoded_list = []
    max_len = 501  # only truncate overly long inputs; short inputs are not padded
    for sample in rag_tensor:
        mz_list, intensity_list = sample
        # Base feature matrix [m/z, intensity]
        base_features = torch.tensor([mz_list, intensity_list], dtype=torch.float32).T
        # Add positional-encoding features (preserves positional information of the raw m/z values)
        pos_enc = torch.stack([P[min(int(mz), P.size(0) - 1)] for mz in mz_list])
        # Combine all features: [m/z, intensity, pos_enc...]
        features = torch.cat([base_features, pos_enc], dim=1)
        # Only truncate overly long inputs; short inputs keep their original length
        if features.size(0) > max_len:
            features = features[:max_len]
        encoded_list.append(features)  # keep original-length features
    return encoded_list
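# Illustrative shape check: a two-peak spectrum encodes to a (2, 2 + 254) matrix
# (m/z and intensity columns plus the positional encoding).
_demo = encode_spectra([[[100.0, 200.0], [1.0, 0.5]]], P, dimn)
assert _demo[0].shape == (2, 2 + dimn)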
# Mass-spectrum preprocessing - guarantees short inputs are kept intact
def preprocess_spectra_for_inference(spectrum_str: str, total_mass: float) -> list:
    # Parse the spectrum string
    pairs = spectrum_str.split()
    mz_list, intensity_list = [], []
    for pair in pairs:
        mz, intensity = pair.split(':')
        mz_list.append(float(mz))
        intensity_list.append(float(intensity))
    # Single-pair inputs already retain full precision at this point (no rounding applied)
    # Append the total exact mass (as an auxiliary feature; does not affect the original data length)
    mz_list.append(total_mass)
    intensity_list.append(0.0)
    # Only round longer inputs; short inputs keep more precision
    if len(mz_list) > 5:  # simplify only when the data is long
        mz_list = [round(mz, 2) for mz in mz_list]
        intensity_list = [round(intensity, 2) for intensity in intensity_list]
    return [[mz_list, intensity_list]]
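# Hedged example: one peak plus the appended total mass yields two (m/z, intensity) pairs.
_spec = preprocess_spectra_for_inference("256.0153:100.0", 255.0136185)
assert _spec == [[[256.0153, 255.0136185], [100.0, 0.0]]]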
# --------------------------
# 2. Model class definition (same structure, new implementation)
# --------------------------
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.generation.utils import GenerationMixin
class LlamaWithEncoder(PreTrainedModel, GenerationMixin):
    config_class = AutoConfig
    _no_split_modules = ["LlamaDecoderLayer", "TransformerEncoderLayer"]

    def __init__(self, config, base_model=None, encoder1_dim=18, encoder2_dim=256, hidden_dim=512):
        super().__init__(config)  # PreTrainedModel stores config on self.config
        # If no base_model is given, initialize one from the config
        if base_model is None:
            self.model = LlamaForCausalLM(config)
        else:
            self.model = base_model
        # First Transformer encoder (processes the molecular-formula vector)
        encoder1_layer = nn.TransformerEncoderLayer(
            d_model=encoder1_dim,
            nhead=3,
            dim_feedforward=hidden_dim,
            batch_first=True
        )
        self.encoder1 = nn.TransformerEncoder(encoder1_layer, num_layers=2)
        # Second Transformer encoder (processes the mass-spectrum matrix)
        encoder2_layer = nn.TransformerEncoderLayer(
            d_model=encoder2_dim,
            nhead=4,
            dim_feedforward=hidden_dim,
            batch_first=True
        )
        self.encoder2 = nn.TransformerEncoder(encoder2_layer, num_layers=2)
        # Projection layers: map encoder outputs to the model's hidden size
        # (use self.model.config so this also works when base_model is None)
        self.proj1 = nn.Linear(encoder1_dim, self.model.config.hidden_size)
        self.proj2 = nn.Linear(encoder2_dim, self.model.config.hidden_size)
        # Embedding layer (copies the base model's weights without sharing them)
        self.embed_tokens = nn.Embedding(
            num_embeddings=self.model.config.vocab_size,
            embedding_dim=self.model.config.hidden_size,
            padding_idx=self.model.config.pad_token_id
        )
        self.embed_tokens.weight.data = self.model.get_input_embeddings().weight.data.clone()
    # Required interface methods
    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def get_output_embeddings(self):
        return self.model.get_output_embeddings()

    def set_output_embeddings(self, new_embeddings):
        self.model.set_output_embeddings(new_embeddings)

    def get_base_model(self):
        return self.model
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        encoder1_inputs=None,
        encoder2_inputs=None,
        labels=None,
        past_key_values=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        **kwargs,
    ) -> CausalLMOutputWithPast:
        # 1. Run the encoders
        # Molecular-formula encoder output
        enc1_out = self.encoder1(encoder1_inputs)  # (batch_size, 1, 18)
        enc1_out = enc1_out.mean(dim=1)            # (batch_size, 18)
        enc1_proj = self.proj1(enc1_out)           # (batch_size, hidden_size)
        # Mass-spectrum encoder output
        enc2_out = self.encoder2(encoder2_inputs)  # (batch_size, seq_len, 256)
        enc2_out = enc2_out.mean(dim=1)            # (batch_size, 256)
        enc2_proj = self.proj2(enc2_out)           # (batch_size, hidden_size)
        # Merge the encoder outputs (used to replace <mask>)
        mask_replacement = (enc1_proj + enc2_proj) / 2  # (batch_size, hidden_size)
        # 2. Get the raw embeddings (avoid in-place ops; build new tensors throughout)
        embeddings = self.embed_tokens(input_ids)  # (batch_size, seq_len, hidden_size)
        batch_size, seq_len, hidden_size = embeddings.size()
        # 3. Replace the <mask> token (third token, index 2) via concatenation instead of in-place assignment
        if seq_len > 2:
            mask_embed = mask_replacement.unsqueeze(1)  # (batch_size, 1, hidden_size)
            # Split and re-concatenate (first 2 tokens + replacement mask_embed + remaining tokens)
            part1 = embeddings[:, :2, :]  # (batch_size, 2, hidden_size)
            part2 = mask_embed            # (batch_size, 1, hidden_size)
            part3 = embeddings[:, 3:, :]  # (batch_size, seq_len-3, hidden_size)
            # Concatenate into a new tensor (no in-place operations)
            new_embeddings = torch.cat([part1, part2, part3], dim=1)  # (batch_size, seq_len, hidden_size)
        else:
            new_embeddings = embeddings  # sequence too short; use the raw embeddings
        # 4. Call the base model
        return self.model(
            inputs_embeds=new_embeddings,
            attention_mask=attention_mask,
            labels=labels,
            past_key_values=past_key_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        return {
            "input_ids": input_ids,
            "attention_mask": kwargs.get("attention_mask", None),
            "encoder1_inputs": kwargs.get("encoder1_inputs", None),
            "encoder2_inputs": kwargs.get("encoder2_inputs", None),
        }

    def _get_generation_device(self):
        return next(self.parameters()).device
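# Note (an assumption about transformers' generate() internals): extra keyword
# arguments passed to model.generate() are carried along as model_kwargs and
# handed to prepare_inputs_for_generation at each decoding step, which is how
# encoder1_inputs/encoder2_inputs above reach forward().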
# --------------------------
# 3. Load the model and tokenizer (fixes the core error)
# --------------------------
model_path = "./llama3.2-SELFIES"  # model save path
# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_path)
# Make sure a mask token exists
if tokenizer.mask_token is None:
    tokenizer.add_special_tokens({"mask_token": "<mask>"})
# Load the model config
config = AutoConfig.from_pretrained(model_path)
# Device setup (prefer GPU)
device = "cuda:0" if torch.cuda.is_available() else "cpu"
# Fix: load the base model first, then hand it to the custom model
base_model = LlamaForCausalLM.from_pretrained(
    model_path,
    config=config,
    torch_dtype=torch.bfloat16,  # keep the base model in bfloat16
    device_map=device
)
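# If add_special_tokens actually added a new <mask> token above, the vocabulary
# grew past the checkpoint's embedding matrix; resize before wrapping (a sketch,
# only needed when the saved tokenizer lacked a mask token):
if len(tokenizer) > base_model.config.vocab_size:
    base_model.resize_token_embeddings(len(tokenizer))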
# Initialize the custom model from the base model
model = LlamaWithEncoder(
    config=config,
    base_model=base_model,
    encoder1_dim=18,
    encoder2_dim=256,
    hidden_dim=512
)
model = model.to(device)  # move to the target device first
model.eval()  # inference mode
# --------------------------
# 4. Inference function (adapted to the new data pipeline)
# --------------------------
def generate_selfies(
    formula: str,
    spectrum_str: str,
    total_mass: float,
    max_length: int = 512,
    temperature: float = 0.7,
    top_p: float = 0.9
) -> str:
    """Generate a SELFIES string."""
    model_device = next(model.parameters()).device
    # 1. Build the element list
    element_list = generate_element_list(formula)
    # 2. Build the molecular-formula vector
    formula_vec = formula_to_dense(formula).unsqueeze(0).unsqueeze(0)  # (1, 1, 18)
    formula_vec = formula_vec.to(model_device, dtype=torch.bfloat16)
    # 3. Process the spectrum (new preprocessing and encoding path)
    spectra_data = preprocess_spectra_for_inference(spectrum_str, total_mass)
    spec_encoded = encode_spectra(spectra_data, P, dimn)  # list of encoded samples
    spec_matrix = spec_encoded[0].to(model_device, dtype=torch.bfloat16).unsqueeze(0)  # add batch dim
    # 4. Build the input prompt
    prompt = f"<|User|><s><|Spectrum|>{element_list}</s>"
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model_device)
    attention_mask = torch.ones_like(input_ids).to(model_device)
    # 5. Generate
    with torch.no_grad():  # disable gradient tracking
        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            encoder1_inputs=formula_vec,  # molecular-formula features
            encoder2_inputs=spec_matrix,  # mass-spectrum features
            max_length=max_length,
            temperature=temperature,
            top_p=top_p,
        )
    # 6. Decode the generated result (special tokens are kept, matching skip_special_tokens=False)
    generated = tokenizer.decode(outputs[0], skip_special_tokens=False)
    return generated
# --------------------------
# 5. Inference example
# --------------------------
if __name__ == "__main__":
    # Example inputs
    example_formula = "C9H9N3O2S2"  # molecular formula
    example_spectrum_str = "256.0153:100.000000"  # mz:intensity format
    example_total_mass = 255.0136185  # total exact mass
    # Generate SELFIES
    result = generate_selfies(
        formula=example_formula,
        spectrum_str=example_spectrum_str,
        total_mass=example_total_mass,
        max_length=512,
        temperature=0.7,
        top_p=0.95
    )
    print("Generated SELFIES string:")
    print(result)

Modify the code to resolve the following problem:

Some weights of LlamaForCausalLM were not initialized from the model checkpoint at ./llama3.2-SELFIES and are newly initialized: ['lm_head.weight', 'model.embed_tokens.weight', 'model.layers.0.input_layernorm.weight', ..., 'model.norm.weight'] (the list covers every attention, MLP, and layer-norm weight of decoder layers 0 through 27, plus the final norm)
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.