Bert模型冻结指定参数进行训练

本文详细介绍了BERT模型的参数结构,并提供了如何在微调过程中冻结特定层参数的实战代码,以提高训练效率。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

      由于bert模型具有12层,参数量达一亿,bert模型做微调有的时候就需要只训练部分参数,那么就需要把其他的参数冻结掉,固定住,又能微调bert模型,还能提高模型训练的效率。这个就需要用到parameter的requires_grad的属性,来冻结和放开参数。

首先我们看看bert模型的具体参数有那些:

bert.embeddings.word_embeddings.weight torch.Size([21128, 768])
bert.embeddings.position_embeddings.weight torch.Size([512, 768])
bert.embeddings.token_type_embeddings.weight torch.Size([2, 768])
bert.embeddings.LayerNorm.weight torch.Size([768])
bert.embeddings.LayerNorm.bias torch.Size([768])

bert.encoder.layer.0.attention.self.query.weight torch.Size([768, 768])
bert.encoder.layer.0.attention.self.query.bias torch.Size([768])
bert.encoder.layer.0.attention.self.key.weight torch.Size([768, 768])
bert.encoder.layer.0.attention.self.key.bias torch.Size([768])
bert.encoder.layer.0.attention.self.value.weight torch.Size([768, 768])
bert.encoder.layer.0.attention.self.value.bias torch.Size([768])
bert.encoder.layer.0.attention.output.dense.weight torch.Size([768, 768])
bert.encoder.layer.0.attention.output.dense.bias torch.Size([768])
bert.encoder.layer.0.attention.output.LayerNorm.weight torch.Size([768])
bert.encoder.layer.0.attention.output.LayerNorm.bias torch.Size([768])
bert.encoder.layer.0.intermediate.dense.weight torch.Size([3072, 768])
bert.encoder.layer.0.intermediate.dense.bias torch.Size([3072])
bert.encoder.layer.0.output.dense.weight torch.Size([768, 3072])
bert.encoder.layer.0.output.dense.bias torch.Size([768])
bert.encoder.layer.0.output.LayerNorm.weight torch.Size([768])
bert.encoder.layer.0.output.LayerNorm.bias torch.Size([768])

bert.encoder.layer.1.attention.self.query.weight torch.Size([768, 768])
bert.encoder.layer.1.attention.self.query.bias torch.Size([768])
bert.encoder.layer.1.attention.self.key.weight torch.Size([768, 768])
bert.encoder.layer.1.attention.self.key.bias torch.Size([768])
bert.encoder.layer.1.attention.self.value.weight torch.Size([768, 768])
bert.encoder.layer.1.attention.self.value.bias torch.Size([768])
bert.encoder.layer.1.attention.output.dense.weight torch.Size([768, 768])
bert.encoder.layer.1.attention.output.dense.bias torch.Size([768])
bert.encoder.layer.1.attention.output.LayerNorm.weight torch.Size([768])
bert.encoder.layer.1.attention.output.LayerNorm.bias torch.Size([768])
bert.encoder.layer.1.intermediate.dense.weight torch.Size([3072, 768])
bert.encoder.layer.1.intermediate.dense.bias torch.Size([3072])
bert.encoder.layer.1.output.dense.weight torch.Size([768, 3072])
bert.encoder.layer.1.output.dense.bias torch.Size([768])
bert.encoder.layer.1.output.LayerNorm.weight torch.Size([768])
bert.encoder.layer.1.output.LayerNorm.bias torch.Size([768])

bert.encoder.layer.2.attention.self.query.weight torch.Size([768, 768])
bert.encoder.layer.2.attention.self.query.bias torch.Size([768])
bert.encoder.layer.2.attention.self.key.weight torch.Size([768, 768])
bert.encoder.layer.2.attention.self.key.bias torch.Size([768])
bert.encoder.layer.2.attention.self.value.weight torch.Size([768, 768])
bert.encoder.layer.2.attention.self.value.bias torch.Size([768])
bert.encoder.layer.2.attention.output.dense.weight torch.Size([768, 768])
bert.encoder.layer.2.attention.output.dense.bias torch.Size([768])
bert.encoder.layer.2.attention.output.LayerNorm.weight torch.Size([768])
bert.encoder.layer.2.attention.output.LayerNorm.bias torch.Size([768])
bert.encoder.layer.2.intermediate.dense.weight torch.Size([3072, 768])
bert.encoder.layer.2.intermediate.dense.bias torch.Size([3072])
bert.encoder.layer.2.output.dense.weight torch.Size([768, 3072])
bert.encoder.layer.2.output.dense.bias torch.Size([768])
bert.encoder.layer.2.output.LayerNorm.weight torch.Size([768])
bert.encoder.layer.2.output.LayerNorm.bias torch.Size([768])

bert.encoder.layer.3.attention.self.query.weight torch.Size([768, 768])
bert.encoder.layer.3.attention.self.query.bias torch.Size([768])
bert.encoder.layer.3.attention.self.key.weight torch.Size([768, 768])
bert.encoder.layer.3.attention.self.key.bias torch.Size([768])
bert.encoder.layer.3.attention.self.value.weight torch.Size([768, 768])
bert.encoder.layer.3.attention.self.value.bias torch.Size([768])
bert.encoder.layer.3.attention.output.dense.weight torch.Size([768, 768])
bert.encoder.layer.3.attention.output.dense.bias torch.Size([768])
bert.encoder.layer.3.attention.output.LayerNorm.weight torch.Size([768])
bert.encoder.layer.3.attention.output.LayerNorm.bias torch.Size([768])
bert.encoder.layer.3.intermediate.dense.weight torch.Size([3072, 768])
bert.encoder.layer.3.intermediate.dense.bias torch.Size([3072])
bert.encoder.layer.3.output.dense.weight torch.Size([768, 3072])
bert.encoder.layer.3.output.dense.bias torch.Size([768])
bert.encoder.layer.3.output.LayerNorm.weight torch.Size([768])
bert.encoder.layer.3.output.LayerNorm.bias torch.Size([768])

bert.encoder.layer.4.attention.self.query.weight torch.Size([768, 768])
bert.encoder.layer.4.attention.self.query.bias torch.Size([768])
bert.encoder.layer.4.attention.self.key.weight torch.Size([768, 768])
bert.encoder.layer.4.attention.self.key.bias torch.Size([768])
bert.encoder.layer.4.attention.self.value.weight torch.Size([768, 768])
bert.encoder.layer.4.attention.self.value.bias torch.Size([768])
bert.encoder.layer.4.attention.output.dense.weight torch.Size([768, 768])
bert.encoder.layer.4.attention.output.dense.bias torch.Size([768])
bert.encoder.layer.4.attention.output.LayerNorm.weight torch.Size([768])
bert.encoder.layer.4.attention.output.LayerNorm.bias torch.Size([768])
bert.encoder.layer.4.intermediate.dense.weight torch.Size([3072, 768])
bert.encoder.layer.4.intermediate.dense.bias torch.Size([3072])
bert.encoder.layer.4.output.dense.weight torch.Size([768, 3072])
bert.encoder.layer.4.output.dense.bias torch.Size([768])
bert.encoder.layer.4.output.LayerNorm.weight torch.Size([768])
bert.encoder.layer.4.output.LayerNorm.bias torch.Size([768])

bert.encoder.layer.5.attention.self.query.weight torch.Size([768, 768])
bert.encoder.layer.5.attention.self.query.bias torch.Size([768])
bert.encoder.layer.5.attention.self.key.weight torch.Size([768, 768])
bert.encoder.layer.5.attention.self.key.bias torch.Size([768])
bert.encoder.layer.5.attention.self.value.weight torch.Size([768, 768])
bert.encoder.layer.5.attention.self.value.bias torch.Size([768])
bert.encoder.layer.5.attention.output.dense.weight torch.Size([768, 768])
bert.encoder.layer.5.attention.output.dense.bias torch.Size([768])
bert.encoder.layer.5.attention.output.LayerNorm.weight torch.Size([768])
bert.encoder.layer.5.attention.output.LayerNorm.bias torch.Size([768])
bert.encoder.layer.5.intermediate.dense.weight torch.Size([3072, 768])
bert.encoder.layer.5.intermediate.dense.bias torch.Size([3072])
bert.encoder.layer.5.output.dense.weight torch.Size([768, 3072])
bert.encoder.layer.5.output.dense.bias torch.Size([768])
bert.encoder.layer.5.output.LayerNorm.weight torch.Size([768])
bert.encoder.layer.5.output.LayerNorm.bias torch.Size([768])

bert.encoder.layer.6.attention.self.query.weight torch.Size([768, 768])
bert.encoder.layer.6.attention.self.query.bias torch.Size([768])
bert.encoder.layer.6.attention.self.key.weight torch.Size([768, 768])
bert.encoder.layer.6.attention.self.key.bias torch.Size([768])
bert.encoder.layer.6.attention.self.value.weight torch.Size([768, 768])
bert.encoder.layer.6.attention.self.value.bias torch.Size([768])
bert.encoder.layer.6.attention.output.dense.weight torch.Size([768, 768])
bert.encoder.layer.6.attention.output.dense.bias torch.Size([768])
bert.encoder.layer.6.attention.output.LayerNorm.weight torch.Size([768])
bert.encoder.layer.6.attention.output.LayerNorm.bias torch.Size([768])
bert.encoder.layer.6.intermediate.dense.weight torch.Size([3072, 768])
bert.encoder.layer.6.intermediate.dense.bias torch.Size([3072])
bert.encoder.layer.6.output.dense.weight torch.Size([768, 3072])
bert.encoder.layer.6.output.dense.bias torch.Size([768])
bert.encoder.layer.6.output.LayerNorm.weight torch.Size([768])
bert.encoder.layer.6.output.LayerNorm.bias torch.Size([768])

bert.encoder.layer.7.attention.self.query.weight torch.Size([768, 768])
bert.encoder.layer.7.attention.self.query.bias torch.Size([768])
bert.encoder.layer.7.attention.self.key.weight torch.Size([768, 768])
bert.encoder.layer.7.attention.self.key.bias torch.Size([768])
bert.encoder.layer.7.attention.self.value.weight torch.Size([768, 768])
bert.encoder.layer.7.attention.self.value.bias torch.Size([768])
bert.encoder.layer.7.attention.output.dense.weight torch.Size([768, 768])
bert.encoder.layer.7.attention.output.dense.bias torch.Size([768])
bert.encoder.layer.7.attention.output.LayerNorm.weight torch.Size([768])
bert.encoder.layer.7.attention.output.LayerNorm.bias torch.Size([768])
bert.encoder.layer.7.intermediate.dense.weight torch.Size([3072, 768])
bert.encoder.layer.7.intermediate.dense.bias torch.Size([3072])
bert.encoder.layer.7.output.dense.weight torch.Size([768, 3072])
bert.encoder.layer.7.output.dense.bias torch.Size([768])
bert.encoder.layer.7.output.LayerNorm.weight torch.Size([768])
bert.encoder.layer.7.output.LayerNorm.bias torch.Size([768])

bert.encoder.layer.8.attention.self.query.weight torch.Size([768, 768])
bert.encoder.layer.8.attention.self.query.bias torch.Size([768])
bert.encoder.layer.8.attention.self.key.weight torch.Size([768, 768])
bert.encoder.layer.8.attention.self.key.bias torch.Size([768])
bert.encoder.layer.8.attention.self.value.weight torch.Size([768, 768])
bert.encoder.layer.8.attention.self.value.bias torch.Size([768])
bert.encoder.layer.8.attention.output.dense.weight torch.Size([768, 768])
bert.encoder.layer.8.attention.output.dense.bias torch.Size([768])
bert.encoder.layer.8.attention.output.LayerNorm.weight torch.Size([768])
bert.encoder.layer.8.attention.output.LayerNorm.bias torch.Size([768])
bert.encoder.layer.8.intermediate.dense.weight torch.Size([3072, 768])
bert.encoder.layer.8.intermediate.dense.bias torch.Size([3072])
bert.encoder.layer.8.output.dense.weight torch.Size([768, 3072])
bert.encoder.layer.8.output.dense.bias torch.Size([768])
bert.encoder.layer.8.output.LayerNorm.weight torch.Size([768])
bert.encoder.layer.8.output.LayerNorm.bias torch.Size([768])

bert.encoder.layer.9.attention.self.query.weight torch.Size([768, 768])
bert.encoder.layer.9.attention.self.query.bias torch.Size([768])
bert.encoder.layer.9.attention.self.key.weight torch.Size([768, 768])
bert.encoder.layer.9.attention.self.key.bias torch.Size([768])
bert.encoder.layer.9.attention.self.value.weight torch.Size([768, 768])
bert.encoder.layer.9.attention.self.value.bias torch.Size([768])
bert.encoder.layer.9.attention.output.dense.weight torch.Size([768, 768])
bert.encoder.layer.9.attention.output.dense.bias torch.Size([768])
bert.encoder.layer.9.attention.output.LayerNorm.weight torch.Size([768])
bert.encoder.layer.9.attention.output.LayerNorm.bias torch.Size([768])
bert.encoder.layer.9.intermediate.dense.weight torch.Size([3072, 768])
bert.encoder.layer.9.intermediate.dense.bias torch.Size([3072])
bert.encoder.layer.9.output.dense.weight torch.Size([768, 3072])
bert.encoder.layer.9.output.dense.bias torch.Size([768])
bert.encoder.layer.9.output.LayerNorm.weight torch.Size([768])
bert.encoder.layer.9.output.LayerNorm.bias torch.Size([768])

bert.encoder.layer.10.attention.self.query.weight torch.Size([768, 768])
bert.encoder.layer.10.attention.self.query.bias torch.Size([768])
bert.encoder.layer.10.attention.self.key.weight torch.Size([768, 768])
bert.encoder.layer.10.attention.self.key.bias torch.Size([768])
bert.encoder.layer.10.attention.self.value.weight torch.Size([768, 768])
bert.encoder.layer.10.attention.self.value.bias torch.Size([768])
bert.encoder.layer.10.attention.output.dense.weight torch.Size([768, 768])
bert.encoder.layer.10.attention.output.dense.bias torch.Size([768])
bert.encoder.layer.10.attention.output.LayerNorm.weight torch.Size([768])
bert.encoder.layer.10.attention.output.LayerNorm.bias torch.Size([768])
bert.encoder.layer.10.intermediate.dense.weight torch.Size([3072, 768])
bert.encoder.layer.10.intermediate.dense.bias torch.Size([3072])
bert.encoder.layer.10.output.dense.weight torch.Size([768, 3072])
bert.encoder.layer.10.output.dense.bias torch.Size([768])
bert.encoder.layer.10.output.LayerNorm.weight torch.Size([768])
bert.encoder.layer.10.output.LayerNorm.bias torch.Size([768])

bert.encoder.layer.11.attention.self.query.weight torch.Size([768, 768])
bert.encoder.layer.11.attention.self.query.bias torch.Size([768])
bert.encoder.layer.11.attention.self.key.weight torch.Size([768, 768])
bert.encoder.layer.11.attention.self.key.bias torch.Size([768])
bert.encoder.layer.11.attention.self.value.weight torch.Size([768, 768])
bert.encoder.layer.11.attention.self.value.bias torch.Size([768])
bert.encoder.layer.11.attention.output.dense.weight torch.Size([768, 768])
bert.encoder.layer.11.attention.output.dense.bias torch.Size([768])
bert.encoder.layer.11.attention.output.LayerNorm.weight torch.Size([768])
bert.encoder.layer.11.attention.output.LayerNorm.bias torch.Size([768])
bert.encoder.layer.11.intermediate.dense.weight torch.Size([3072, 768])
bert.encoder.layer.11.intermediate.dense.bias torch.Size([3072])
bert.encoder.layer.11.output.dense.weight torch.Size([768, 3072])
bert.encoder.layer.11.output.dense.bias torch.Size([768])
bert.encoder.layer.11.output.LayerNorm.weight torch.Size([768])
bert.encoder.layer.11.output.LayerNorm.bias torch.Size([768])

bert.pooler.dense.weight torch.Size([768, 768])
bert.pooler.dense.bias torch.Size([768])
out.weight torch.Size([2, 768])
out.bias torch.Size([2])

比如说现在我们要放开第11和12层以及bert.pooler和out层参数,冻结其他的参数,怎么实现呢?

pytorch中有 model.named_parameters() 和 requires_grad,直接写一个遍历然后设置就好。具体实现代码:

import torch.nn as nn
from transformers import BertModel
import torch

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.bert = BertModel.from_pretrained('pretrained_models/Chinese-BERT-wwm')
        self.out = nn.Linear(768,2)

    def forward(self):
        out = self.bert()
        return out

if __name__ == '__main__':
    model = Model()
    unfreeze_layers = ['layer.10','layer.11','bert.pooler','out.']

    for name, param in model.named_parameters():
        print(name,param.size())

    print("*"*30)
    print('\n')

    for name ,param in model.named_parameters():
        param.requires_grad = False
        for ele in unfreeze_layers:
            if ele in name:
                param.requires_grad = True
                break
    #验证一下
    for name, param in model.named_parameters():
        if param.requires_grad:
            print(name,param.size())

    #过滤掉requires_grad = False的参数
    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=0.00001)

最后在做训练的时候,优化器中一定要添加过滤器filter把requires_grad = False的参数过滤掉,在训练的时候,不会更新这些参数。

 

参考文章:

Pytorch 如何精确的冻结我想冻结的预训练模型的某一层,有什么命令吗?

<think>嗯,用户想了解如何微调BERT训练模型以适应特定任务,还要求示例代码和教程。首先,我需要回忆一下BERT微调的基本步骤。通常包括环境配置、数据预处理、模型加载、训练和评估这几个部分。用户提供的引用里提到了Hugging Face的transformers库,所以应该用这个库来实现。 首先,环境配置需要安装transformers、datasets和torch这些库。用户给的示例代码中有安装transormers的拼写错误,我得注意别写错了。然后,数据预处理部分,用户提到了IMDb数据集,可能需要加载该数据集,并进行分词处理。BERT的分词器是BertTokenizer,需要用它来处理文本,转换成input_ids、attention_mask这些张量。 接下来是模型加载。对于文本分类任务,应该使用BertForSequenceClassification类。需要指定num_labels参数,根据任务类型设置,比如二分类就是2。然后定义训练参数,比如学习率、批次大小、训练轮数等。这里可能需要用到TrainingArguments和Trainer这两个类来简化训练过程。 在训练过程中,可能需要自定义评估指标,比如准确率,这可以通过compute_metrics函数实现。然后将训练集和验证集传给Trainer,开始训练。最后,保存微调后的模型,并加载进行推理。 用户可能需要示例代码,所以得写出一个完整的流程,包括安装库、加载数据、分词、模型初始化、训练和评估。同时,需要注意代码中的关键点,比如数据集的划分、分词器的应用、模型配置等。另外,常见问题如学习率设置、处理小数据集的方法也需要提及,比如数据增强、分层抽样等。 还要考虑用户可能遇到的问题,比如显存不足,这时候可以建议减小批次大小或者使用梯度累积。另外,数据集标签不平衡的话,可能需要调整类别权重或者使用不同的损失函数。这部分需要在注意事项里提到。 最后,根据用户提供的引用,引用1和2都需要在适当的地方标注,比如环境配置部分和数据预处理部分。确保引用标识正确,比如[^1]和[^2]。 总结一下,回答的结构应该是步骤分明的,每个步骤详细说明,并给出代码示例。同时,相关问题部分需要生成与BERT微调相关的问题,比如处理小数据集、多分类任务、学习率设置、过拟合应对等。这样用户能进一步探索相关内容。</think>### 如何微调 BERT训练模型 #### 1. **环境配置** 安装必要的库: ```bash pip install transformers datasets torch ``` #### 2. **数据预处理** 以 IMDb 电影评论分类任务为例: ```python from datasets import load_dataset from transformers import BertTokenizer # 加载数据集 dataset = load_dataset("imdb") tokenizer = BertTokenizer.from_pretrained("bert-base-uncased") # 定义分词函数 def tokenize_function(examples): return tokenizer(examples["text"], padding="max_length", truncation=True, max_length=512) # 应用分词器 tokenized_datasets = dataset.map(tokenize_function, batched=True) ``` #### 3. **加载预训练模型** ```python from transformers import BertForSequenceClassification model = BertForSequenceClassification.from_pretrained( "bert-base-uncased", num_labels=2 # 二分类任务 ) ``` #### 4. **训练配置** ```python from transformers import TrainingArguments, Trainer import numpy as np from sklearn.metrics import accuracy_score # 定义评估指标 def compute_metrics(eval_pred): logits, labels = eval_pred predictions = np.argmax(logits, axis=-1) return {"accuracy": accuracy_score(labels, predictions)} # 设置训练参数 training_args = TrainingArguments( output_dir="./results", evaluation_strategy="epoch", learning_rate=2e-5, per_device_train_batch_size=8, per_device_eval_batch_size=8, num_train_epochs=3, weight_decay=0.01, ) ``` #### 5. **训练与评估** ```python # 创建 Trainer trainer = Trainer( model=model, args=training_args, train_dataset=tokenized_datasets["train"], eval_dataset=tokenized_datasets["test"], compute_metrics=compute_metrics, ) # 开始训练 trainer.train() # 保存模型 model.save_pretrained("./fine-tuned-bert-imdb") ``` #### 6. **推理示例** ```python text = "This movie is fantastic!" inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True) outputs = model(**inputs) prediction = np.argmax(outputs.logits.detach().numpy()) print("Positive" if prediction == 1 else "Negative") ``` ### 关键点说明 1. **学习率选择**:BERT 微调通常使用较小的学习率(如 $2 \times 10^{-5}$),避免破坏预训练权重[^1]。 2. **批次大小**:根据显存调整 `per_device_train_batch_size`。 3. **评估策略**:`evaluation_strategy="epoch"` 表示每个 epoch 结束后验证性能。 ### 常见问题处理 - **显存不足**:减少批次大小或使用梯度累积(`gradient_accumulation_steps`) - **小数据集优化**:冻结 BERT 底层参数,仅微调顶层 - **类别不平衡**:在 `Trainer` 中添加 `class_weights` 参数 ### 注意事项 - 中文任务需改用 `bert-base-chinese` 模型 - 多分类任务需调整 `num_labels` 参数 - 长文本需调整 `max_length`(不超过 512) ---
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值