OpenNMT/CTranslate2中的Transformer模型支持指南

OpenNMT/CTranslate2中的Transformer模型支持指南

CTranslate2 Fast inference engine for Transformer models CTranslate2 项目地址: https://gitcode.com/gh_mirrors/ct/CTranslate2

概述

OpenNMT/CTranslate2是一个高效的推理引擎,支持多种Transformer架构的模型。本文将详细介绍CTranslate2对Hugging Face Transformers模型的支持情况,以及如何转换和使用这些模型。

支持的Transformer模型

CTranslate2目前支持以下主流Transformer模型:

  • 文本生成类:BART、BLOOM、CodeGen、Falcon、GPT系列(GPT2、GPT-J、GPT-NeoX)、Llama 2、MPT、OPT
  • 翻译类:M2M100、MarianMT、MBART、NLLB
  • 语音识别:Whisper
  • 编码器类:BERT、DistilBERT、XLM-RoBERTa
  • 其他:Pegasus、T5

模型转换方法

要将Hugging Face的Transformer模型转换为CTranslate2格式,需要使用ct2-transformers-converter工具。基本转换命令如下:

pip install transformers[torch]
ct2-transformers-converter --model 模型名称或路径 --output_dir 输出目录

转换完成后,您可以在输出目录中找到CTranslate2格式的模型文件。

特殊注意事项

特殊标记处理

与其他框架不同,CTranslate2不会自动为Transformers模型添加特殊标记。这是因为Hugging Face的tokenizer已经包含了这些标记。如果您不使用Hugging Face的tokenizer,需要手动添加这些特殊标记。

模型特定说明

不同模型在使用时有各自的注意事项:

  1. BERT/DistilBERT:仅支持编码器部分,任务特定层仍需使用PyTorch运行
  2. Llama 2:需要先申请模型访问权限
  3. MPT:转换时需要添加--trust_remote_code参数
  4. NLLB:需要transformers>=4.21.0版本
  5. OPT:所有输入应以</s>标记开头
  6. Whisper:需要transformers>=4.23.0版本

典型使用示例

1. 文本摘要(BART)

import ctranslate2
import transformers

# 初始化
translator = ctranslate2.Translator("bart-large-cnn")
tokenizer = transformers.AutoTokenizer.from_pretrained("facebook/bart-large-cnn")

# 准备输入
text = "您的输入文本..."
source = tokenizer.convert_ids_to_tokens(tokenizer.encode(text))

# 执行摘要
results = translator.translate_batch([source])
target = results[0].hypotheses[0]

# 解码输出
print(tokenizer.decode(tokenizer.convert_tokens_to_ids(target), skip_special_tokens=True))

2. 文本分类(BERT)

import ctranslate2
import transformers

# 初始化编码器
encoder = ctranslate2.Encoder("bert-base-uncased-yelp-polarity")

# 加载分类头
classifier = transformers.AutoModelForSequenceClassification.from_pretrained(
    "textattack/bert-base-uncased-yelp-polarity").classifier

# 准备输入
inputs = ["正面评价", "负面评价"]
tokens = tokenizer(inputs).input_ids

# 获取编码
output = encoder.forward_batch(tokens)
pooler_output = output.pooler_output

# 分类预测
logits = classifier(pooler_output)
predicted_class_ids = logits.argmax(1)

3. 文本生成(GPT-2)

import ctranslate2
import transformers

# 初始化
generator = ctranslate2.Generator("gpt2_ct2")
tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")

# 无条件生成
start_tokens = [tokenizer.bos_token]
results = generator.generate_batch([start_tokens], max_length=30, sampling_topk=10)

# 有条件生成
start_tokens = tokenizer.convert_ids_to_tokens(tokenizer.encode("开头文本"))
results = generator.generate_batch([start_tokens], max_length=30, sampling_topk=10)

4. 机器翻译(T5)

import ctranslate2
import transformers

# 初始化
translator = ctranslate2.Translator("t5-small-ct2")
tokenizer = transformers.AutoTokenizer.from_pretrained("t5-small")

# 准备输入
input_text = "translate English to German: The house is wonderful."
input_tokens = tokenizer.convert_ids_to_tokens(tokenizer.encode(input_text))

# 执行翻译
results = translator.translate_batch([input_tokens])

# 解码输出
output_text = tokenizer.decode(tokenizer.convert_tokens_to_ids(results[0].hypotheses[0]))

5. 语音识别(Whisper)

import ctranslate2
import librosa
import transformers

# 加载音频
audio, _ = librosa.load("audio.wav", sr=16000, mono=True)

# 提取特征
processor = transformers.WhisperProcessor.from_pretrained("openai/whisper-tiny")
inputs = processor(audio, return_tensors="np", sampling_rate=16000)
features = ctranslate2.StorageView.from_array(inputs.input_features)

# 初始化模型
model = ctranslate2.models.Whisper("whisper-tiny-ct2")

# 语言检测
results = model.detect_language(features)
language, probability = results[0][0]

# 语音识别
prompt = processor.tokenizer.convert_tokens_to_ids(["<|startoftranscript|>", language, "<|transcribe|>"])
results = model.generate(features, [prompt])
output = processor.decode(results[0].sequences_ids[0])

性能优化建议

  1. 量化:对于大模型,可以使用--quantization参数进行量化,减少内存占用和提高推理速度
  2. 设备选择:明确指定运行设备(CPU/GPU)以获得最佳性能
  3. 批次处理:尽可能使用批次处理提高吞吐量
  4. 参数调优:根据任务调整max_lengthsampling_topk等生成参数

常见问题

  1. 模型转换失败:检查transformers版本是否满足要求,确保模型名称正确
  2. 内存不足:尝试量化模型或使用更小的模型变体
  3. 特殊标记问题:确认是否正确处理了模型所需的特殊标记
  4. 性能问题:检查是否使用了合适的设备,并尝试调整批次大小

通过本文介绍的方法,您可以轻松地将各种Transformer模型转换为CTranslate2格式,并在生产环境中高效地使用它们。

CTranslate2 Fast inference engine for Transformer models CTranslate2 项目地址: https://gitcode.com/gh_mirrors/ct/CTranslate2

创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

计姗群

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值