简介
1,什么是BitFit
BitFit 是一种在小规模数据集上进行微调(fine-tuning)的技术,主要应用于大规模预训练模型,尤其是在深度学习和自然语言处理(NLP)领域中。BitFit 通过仅微调模型中的偏置项(bias term),而不更新模型的其他参数(如权重),来实现高效且节省资源的微调方式。相比传统的微调方法,BitFit 大大减少了计算和存储开销,同时还能在一定程度上保留模型的性能。
2,BitFit 微调的核心思想
传统的微调方法通常更新模型中的所有参数,包括权重和偏置项。BitFit 的关键创新在于,它只更新 偏置项(bias),而不更新模型中的其他参数(如权重)。在深度神经网络中,偏置项负责对模型的输出进行平移,使得网络能更灵活地调整预测结果。
为什么只更新偏置项?
- 偏置项对学习的影响:尽管偏置项在网络中占据较小的比例,但它们在训练过程中起到了重要的作用,尤其是在激活函数的作用下,偏置项可以显著影响模型的输出。通过仅调整偏置项,模型可以在特定任务中适应新的数据分布,而不需要对整个网络进行重新训练。
- 资源效率:更新整个模型的所有参数需要大量的计算和存储资源,尤其是对于大型模型(如BERT、GPT-3等)。而只更新偏置项,意味着可以显著减少计算量和存储需求,降低微调的成本。
3. BitFit 微调的步骤
BitFit 微调的过程非常简洁,通常包括以下几个步骤:
1. 选择预训练模型
首先,选择一个在大规模数据集上预训练过的模型,通常是一个基于 Transformer 架构的模型(如BERT、GPT等)。这个模型已经学到了通用的语言特征,但在特定任务上可能并不完全优化。
2. 冻结所有层的权重
在微调过程中,冻结模型中的所有权重参数。这意味着,所有的模型层(如注意力层、前馈层等)在训练过程中保持不变,模型的权重不进行任何更新。
3. 只更新偏置项
在冻结了模型的大部分参数后,BitFit 只对模型中的偏置项进行微调。通过标准的梯度下降算法来优化这些偏置项,以适应特定任务的数据分布。
4. 训练和优化
使用特定任务的数据(如分类、回归等任务的数据集)进行训练。由于只微调了偏置项,训练过程通常非常快速,且计算资源需求远低于标准微调方法。
5. 评估和应用
完成微调后,使用验证集对模型进行评估。如果微调后的模型在目标任务上表现良好,就可以直接应用于实际应用中。
4. BitFit 微调的优点
BitFit 微调方法有几个显著的优点,尤其是在大规模模型和小规模数据集的场景下:
1. 计算和存储效率
- 减少计算开销:传统的微调方法需要更新大量的参数,尤其是当模型规模很大时,计算资源的消耗非常高。BitFit 只更新偏置项,大大减少了计算量。
- 节省存储空间:因为更新的参数较少,存储需求也大大减少。对于硬件资源有限的场景,BitFit 提供了一个有效的解决方案。
2. 快速训练
由于只需要更新少量的参数,BitFit 的训练时间比标准微调方法要短得多。即使是在计算资源有限的环境中,开发者也可以快速完成模型微调。
3. 保持原有模型的表现
BitFit 只更新偏置项,避免了对模型架构的重大修改,因此能够在许多任务中保持与传统微调方法相似的性能。尤其是在迁移学习的场景下,BitFit 能够有效地适应新的任务数据,同时避免了过拟合。
4. 适应小规模数据集
在数据较少的情况下,标准微调可能会导致过拟合,因为训练数据不足以支持对整个模型的更新。而 BitFit 只更新偏置项,减少了过拟合的风险,尤其适用于小规模数据集的微调。
5. BitFit 微调的局限性
尽管 BitFit 具有许多优势,但它也存在一些局限性:
1. 仅适用于特定任务
由于 BitFit 只更新偏置项,它的效果可能在某些任务上不如标准微调方法,尤其是在需要大量参数调整以适应新任务的复杂任务上。例如,在某些复杂的NLP任务中,可能需要对模型的更多层进行更新,以达到最优的表现。
2. 偏置项更新的效果有限
对于某些模型,偏置项可能对模型的性能影响较小,因此只更新偏置项可能无法有效提高模型的准确性或表现。
3. 依赖预训练模型的质量
BitFit 的效果很大程度上依赖于所选用的预训练模型的质量。如果预训练模型在原始任务上表现不佳,那么仅微调偏置项可能无法大幅提升模型的性能。
适合使用 BitFit 微调的模型类型
上面讲解 BitFit 这多,其实已经告知那些模型适合使用BitFit微调。那些下面就更加具体的针对模型类型来讲解是否适合使用BitFit微调。
1 Transformer 架构的预训练模型
BitFit 最常应用于基于 Transformer 的大规模预训练模型。这些模型通常已经在海量数据上进行了预训练,学习了通用的语言特征。常见的模型包括:
BERT(Bidirectional Encoder Representations from Transformers)
- 适合文本分类、问答系统、文本相似度计算等任务。
- BERT 的层中包含大量偏置项,且偏置项在模型输出中有重要贡献。
GPT 系列(Generative Pretrained Transformer)
- 适合生成任务(如文本生成)和下游语言理解任务。
- 在语言生成任务中,偏置项微调可以调整生成文本的风格或倾向。
RoBERTa
- 类似于 BERT,但使用了更多的数据和更优的训练策略。
- 在需要更细粒度特征的任务中,BitFit 微调同样高效。
T5(Text-to-Text Transfer Transformer)
- 适用于多任务学习,例如文本生成、翻译、问答等。
- T5 的偏置项调整可以快速适配任务。
2 其他 NLP 模型
除了经典的 Transformer 架构,以下类型的模型也可以采用 BitFit 微调:
基于 RNN 的语言模型
- 虽然不如 Transformer 常见,但某些情况下仍被使用。RNN 中的偏置项也可以用来调整输出分布。
Seq2Seq 模型
- 用于翻译、摘要生成等任务。只调整偏置项可以让模型更快速适应新任务的目标分布。
3 预训练的多模态模型
CLIP(Contrastive Language-Image Pretraining)
- 结合文本和图像的模型,广泛用于图像与文本的匹配、图像生成等任务。
- 在 CLIP 中,偏置项微调可以快速适配新任务的特定语义需求。
Vision-Language Models
- 如 BLIP、ALIGN 等多模态模型,也能通过 BitFit 微调实现高效适配新任务。
4 视觉模型(部分情况适用)
ViT(Vision Transformer)
- 在图像分类、目标检测等任务中,可以采用 BitFit 微调。
- 偏置项在 ViT 中控制各层的特征分布,微调可以调整模型对不同特征的敏感性。
ConvNeXt、ResNet 等 CNN 模型
- 传统的卷积神经网络中,偏置项的作用较小,BitFit 对其的效果有限,通常不如 Transformer 模型显著。
代码
代码部分其实很简单,对模型named_parameters添加一个判断,是bias就进行微调,不是就保持不变。
from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, DataCollatorForSeq2Seq
from transformers import pipeline, TrainingArguments, Trainer
# 分词器
tokenizer = AutoTokenizer.from_pretrained("langboat_bloom-1b4-zh")
# 函数内将instruction和response拆开分词的原因是:
# 为了便于mask掉不需要计算损失的labels, 即代码labels = [-100] * len(instruction["input_ids"]) + response["input_ids"]
def process_func(example):
MAX_LENGTH = 256
input_ids, attention_mask, labels = [], [], []
instruction = tokenizer(
"\n".join(["Human: " + example["instruction"], example["input"]]).strip() + "\n\nAssistant: ")
response = tokenizer(example["output"] + tokenizer.eos_token)
input_ids = instruction["input_ids"] + response["input_ids"]
attention_mask = instruction["attention_mask"] + response["attention_mask"]
labels = [-100] * len(instruction["input_ids"]) + response["input_ids"]
if len(input_ids) > MAX_LENGTH:
input_ids = input_ids[:MAX_LENGTH]
attention_mask = attention_mask[:MAX_LENGTH]
labels = labels[:MAX_LENGTH]
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"labels": labels
}
if __name__ == "__main__":
# 加载数据集
dataset = Dataset.load_from_disk("./alpaca_data_zh/")
# 处理数据
tokenized_ds = dataset.map(process_func, remove_columns=dataset.column_names)
# print(tokenizer.decode(tokenized_ds[1]["input_ids"]))
# print(tokenizer.decode(list(filter(lambda x: x != -100, tokenized_ds[1]["labels"]))))
# 创建模型
model = AutoModelForCausalLM.from_pretrained("langboat_bloom-1b4-zh", low_cpu_mem_usage=True)
# 基于bitfit只训练带有bias的参数
for name, param in model.named_parameters():
if "bias" not in name:
param.requires_grad = False
# 训练参数
args = TrainingArguments(
output_dir="./chatbot",
per_device_train_batch_size=1,
gradient_accumulation_steps=8,
logging_steps=10,
num_train_epochs=1
)
# trainer
trainer = Trainer(
model=model,
args=args,
train_dataset=tokenized_ds,
data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True)
)
# 训练模型
trainer.train()
# 模型推理
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, device=0)
ipt = "Human: {}\n{}".format("考试有哪些技巧?", "").strip() + "\n\nAssistant: "
output = pipe(ipt, max_length=256, do_sample=True)
print(output)
数据集和模型都已下载到本地,下载发生进入地址
数据集:shibing624/alpaca-zh
模型:Langboat/bloom-1b4-zh