用huggingface.transformers.AutoModelForSeq2SeqLM实现文本生成任务

本文详细介绍了如何使用BART模型和transformers库的Seq2SeqTrainer进行文本生成,包括设置参数、数据预处理、训练过程以及生成和推理方法的演示。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

诸神缄默不语-个人优快云博文目录

本文介绍的是用BART(别的s2s模型也差不多)+ transformers官方训练框架(Seq2SeqTrainer)运行文本生成代码,先介绍前置条件,然后介绍训练代码,最后会介绍使用generate()函数或者使用pipeline来进行推理的方法。

1. 导入包

import json,random,copy,os,re

from transformers import AutoTokenizer,AutoModelForSeq2SeqLM,DataCollatorForSeq2Seq,Seq2SeqTrainingArguments,\
                        Seq2SeqTrainer,EarlyStoppingCallback

from datasets import Dataset,DatasetDict

2. 设置参数

model_path="bart-base-chinese"
output_path="data/my_checkpoints/bart_202403071330"
source_max_length=1024
target_max_length=128  #这2个都是模型限长

3. 导入分词器和模型

tokenizer=AutoTokenizer.from_pretrained(model_path)
model=AutoModelForSeq2SeqLM.from_pretrained(model_path)

4. 导入数据集

具体部分略,总之最后要是datasets.Dataset的格式,输入键是text,输出键是summary
然后进行如下数据预处理工作:

def preprocess_function(examples):
    model_inputs = tokenizer(examples['text'],max_length=source_max_length,truncation=True)

    labels = tokenizer(text_target=examples["summary"],max_length=target_max_length,truncation=True)

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

tokenized_dataset = datasets.map(preprocess_function, batched=True)

data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

datasets包的使用文档可参考huggingface.datasets使用说明

5. 训练

training_args = Seq2SeqTrainingArguments(
    output_dir=output_path,
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=32,
    weight_decay=0.01,
    save_strategy="epoch",
    save_steps=1,
    save_total_limit=3,
    num_train_epochs=100,
    predict_with_generate=True,
    fp16=True,
    push_to_hub=False,
    metric_for_best_model="eval_loss",
    load_best_model_at_end=True
)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["valid"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    callbacks=[EarlyStoppingCallback(10)],
)

trainer.train()

6. 测试

在output_path中会保留save_total_limit个checkpoint,一般来说建议用checkpoint数最大的那个文件夹。这个文件夹称为final_checkpoint_folder,从这个文件夹里面调用模型权重:model=AutoModelForSeq2SeqLM.from_pretrained(final_checkpoint_folder)
以下代码中的model指的是这个model

6.1 使用generate()函数

函数文档:https://huggingface.co/docs/transformers/main/en/main_classes/text_generation#transformers.GenerationMixin.generate

outputs=model.generate(tokenizer.encode(input_text,return_tensors="pt",max_length=source_max_length),max_new_tokens=target_max_length,do_sample=False)
output_text=tokenizer.decode(outputs[0],skip_special_tokens=True).replace(" ","")

(这个会在CPU上跑,如果把模型和数据都放到GPU上就能在GPU上跑)

之所以要删空格是因为BART输出会在token之间自动添加空格,就跟英文一样。

generate()函数的使用可以参考这2篇博文:基于 transformers 的 generate() 方法实现多样化文本生成:参数含义和算法原理解读_length_penalty-优快云博客LLM(大语言模型)解码时是怎么生成文本的? - 知乎

6.2 使用pipeline

函数源码:https://github.com/huggingface/transformers/blob/main/src/transformers/pipelines/text2text_generation.py

函数文档:https://huggingface.co/docs/transformers/main_classes/pipelines#transformers.Text2TextGenerationPipeline

需要导入pipeline函数:

from transformers import pipeline

generator = pipeline("text2text-generation", model=model, tokenizer=tokenizer)
outputs = generator(input_text, max_new_tokens=128, do_sample=False)
output_text = outputs[0]["generated_text"].replace(" ", "")

pipeline()device入参可以设置GPU

使用pipeline的问题在于我没找到限制输入长度的位置,所以如果输入文本太长的话就会报错,不像用encodegenerate时可以显式限制输入长度

在本文撰写过程中参考到的网络资料

  1. https://huggingface.co/docs/transformers/tasks/summarization
  2. https://blog.youkuaiyun.com/daotianweng/article/details/121036353
  3. 完整项目文件可参考:https://github.com/PolarisRisingWar/llm-throught-ages/blob/master/models/BART/bart_generation1.py
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

诸神缄默不语

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值