使用Ray Train微调Hugging Face文本分类模型

使用Ray Train微调Hugging Face文本分类模型

ray ray-project/ray: 是一个分布式计算框架,它没有使用数据库。适合用于大规模数据处理和机器学习任务的开发和实现,特别是对于需要使用分布式计算框架的场景。特点是分布式计算框架、无数据库。 ray 项目地址: https://gitcode.com/gh_mirrors/ra/ray

概述

本文将介绍如何使用Ray Train框架来微调Hugging Face的Transformers模型进行文本分类任务。Ray Train是Ray项目中的一个分布式训练框架,它可以帮助我们轻松地将单机训练扩展到多机多GPU环境。

准备工作

环境设置

首先需要初始化Ray集群。Ray支持本地单机集群和分布式集群两种模式:

import ray
ray.init()

可以通过ray.cluster_resources()查看集群可用资源,包括CPU、GPU数量等。

任务配置

我们将使用GLUE基准测试中的文本分类任务。GLUE包含多个子任务,每个任务对应不同的文本分类场景:

GLUE_TASKS = [
    "cola", "mnli", "mnli-mm", "mrpc", 
    "qnli", "qqp", "rte", "sst2", 
    "stsb", "wnli"
]

选择具体任务和模型检查点:

task = "cola"  # 选择任务
model_checkpoint = "distilbert-base-uncased"  # 模型检查点
batch_size = 16  # 批大小

数据处理

加载数据集

使用Hugging Face的datasets库加载GLUE数据集:

from datasets import load_dataset
actual_task = "mnli" if task == "mnli-mm" else task
datasets = load_dataset("glue", actual_task)

转换为Ray Dataset

将Hugging Face数据集转换为Ray Dataset以便分布式处理:

import ray.data

ray_datasets = {
    "train": ray.data.from_huggingface(datasets["train"]),
    "validation": ray.data.from_huggingface(datasets["validation"]),
    "test": ray.data.from_huggingface(datasets["test"]),
}

数据预处理

定义tokenizer和预处理函数:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True)

def collate_fn(examples: Dict[str, np.array]):
    # 根据任务获取文本列名
    sentence1_key, sentence2_key = task_to_keys[task]
    
    # 单句或双句处理
    if sentence2_key is None:
        outputs = tokenizer(
            list(examples[sentence1_key]),
            truncation=True,
            padding="longest",
            return_tensors="pt",
        )
    else:
        outputs = tokenizer(
            list(examples[sentence1_key]),
            list(examples[sentence2_key]),
            truncation=True,
            padding="longest",
            return_tensors="pt",
        )
    
    outputs["labels"] = torch.LongTensor(examples["label"])
    return outputs

模型训练

训练函数

定义Ray Train的训练函数:

import torch
from transformers import AutoModelForSequenceClassification, Trainer
from ray.train.huggingface.transformers import prepare_trainer

def train_func(config):
    # 初始化模型
    model = AutoModelForSequenceClassification.from_pretrained(
        model_checkpoint, num_labels=num_labels
    )
    
    # 训练参数
    training_args = TrainingArguments(
        output_dir="./results",
        evaluation_strategy="epoch",
        learning_rate=2e-5,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        num_train_epochs=3,
        weight_decay=0.01,
    )
    
    # 创建Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=ray_datasets["train"],
        eval_dataset=ray_datasets["validation"],
        tokenizer=tokenizer,
        compute_metrics=compute_metrics,
    )
    
    # 准备Ray Train适配
    trainer = prepare_trainer(trainer)
    trainer.train()

启动训练

使用TorchTrainer启动分布式训练:

from ray.train.torch import TorchTrainer
from ray.air.config import ScalingConfig

trainer = TorchTrainer(
    train_func,
    scaling_config=ScalingConfig(
        num_workers=num_workers,
        use_gpu=use_gpu,
    ),
)
result = trainer.fit()

关键点解析

  1. Ray Train优势

    • 轻松实现单机多GPU或多机分布式训练
    • 自动处理数据分片和模型并行
    • 与Hugging Face生态无缝集成
  2. 数据处理

    • Ray Dataset提供了高效的数据加载和预处理
    • 支持大规模数据集的内存高效处理
  3. 模型训练

    • 使用标准的Hugging Face Trainer接口
    • 通过prepare_trainer适配Ray Train环境
    • 自动处理分布式训练中的梯度同步

总结

本文展示了如何使用Ray Train框架来分布式微调Hugging Face文本分类模型。Ray Train提供了简单易用的API,使得分布式训练变得非常简单。通过这种方式,我们可以充分利用集群资源,加速模型训练过程。

对于想要进一步扩展训练规模或处理更大数据集的用户,Ray Train是一个值得考虑的选择。它不仅支持Hugging Face Transformers,还可以与其他PyTorch或TensorFlow模型一起使用。

ray ray-project/ray: 是一个分布式计算框架,它没有使用数据库。适合用于大规模数据处理和机器学习任务的开发和实现,特别是对于需要使用分布式计算框架的场景。特点是分布式计算框架、无数据库。 ray 项目地址: https://gitcode.com/gh_mirrors/ra/ray

创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

宣连璐Maura

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值