PyTorch学习之路(level2)——自定义数据读取

在上一篇博客PyTorch学习之路(level1)——训练一个图像分类模型中介绍了如何用PyTorch训练一个图像分类模型,建议先看懂那篇博客后再看这篇博客。在那份代码中,采用torchvision.datasets.ImageFolder这个接口来读取图像数据,该接口默认你的训练数据是按照一个类别存放在一个文件夹下。但是有些情况下你的图像数据不是这样维护的,比如一个文件夹下面各个类别的图像数据都有,同时用一个对应的标签文件,比如txt文件来维护图像和标签的对应关系,在这种情况下就不能用torchvision.datasets.ImageFolder来读取数据了,需要自定义一个数据读取接口。另外这篇博客最后还顺带介绍如何保存模型和多GPU训练。

怎么做呢?

先来看看torchvision.datasets.ImageFolder这个类是怎么写的,主要代码如下,想详细了解的可以看:官方github代码

看起来很复杂,其实非常简单。继承的类是torch.utils.data.Dataset,主要包含三个方法:初始化__init__,获取图像__getitem__,数据集数量 __len____init__方法中先通过find_classes函数得到分类的类别名(classes)和类别名与数字类别的映射关系字典(class_to_idx)。然后通过make_dataset函数得到imags,这个imags是一个列表,其中每个值是一个tuple,每个tuple包含两个元素:图像路径和标签。剩下的就是一些赋值操作了。在__getitem__方法中最重要的就是 img = self.loader(path)这行,表示数据读取,可以从__init__方法中看出self.loader采用的是default_loader,这个default_loader的核心就是用python的PIL库的Image模块来读取图像数据。

class ImageFolder(data.Dataset):
    """A generic data loader where the images are arranged in this way: ::

        root/dog/xxx.png
        root/dog/xxy.png
        root/dog/xxz.png

        root/cat/123.png
        root/cat/nsdf3.png
        root/cat/asd932_.png

    Args:
        root (string): Root directory path.
        transform (callable, optional): A function/transform that  takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        loader (callable, optional): A function to load an image given its path.

     Attributes
### 如何使用自定义数据集训练Transformer模型 #### 准备环境和安装依赖库 为了能够顺利地使用自定义数据集来训练Transformer模型,需要先准备好相应的开发环境以及必要的Python包。通常情况下会采用Anaconda作为管理工具,并通过pip或conda命令安装Hugging Face的`transformers`库和其他辅助库,比如PyTorch用于构建神经网络框架。 ```bash conda create -n transformer_env python=3.9 conda activate transformer_env pip install transformers datasets torch scikit-learn ``` #### 创建自定义数据处理类 当准备好了所需的软件环境之后,下一步就是针对特定的任务场景设计适合的数据读取与预处理逻辑。这里可以继承来自`datasets.Dataset`的基础类,重写其中的关键方法以便适应不同的输入格式[^4]。 对于图像分类任务而言,可能涉及到图片路径列表、标签映射字典等结构的设计;而对于文本序列到序列的学习,则更关注于源语言句子同目标语言翻译之间的配对关系建立及其编码方式的选择(如子词分割技术的应用)。下面给出一段简化版的例子说明如何实现这一点: ```python from datasets import Dataset as HFDataset import pandas as pd class CustomTextDataset(HFDataset): def _load_data(self, file_path): df = pd.read_csv(file_path) texts = list(df['text']) labels = list(df['label']) self.data_dict = {'texts': texts, 'labels': labels} def __init__(self, csv_file): super().__init__() self._load_data(csv_file) def __getitem__(self, idx): text = self.data_dict['texts'][idx] label = self.data_dict['labels'][idx] item = {"input_ids": tokenizer(text)["input_ids"], "labels": int(label)} return item def __len__(self): return len(self.data_dict['texts']) ``` 在这个例子中,假设有一个CSV文件包含了两列——分别是待处理的文字内容(`'text'`)和对应的类别标记(`'label'`)。上述代码片段展示了怎样基于这些信息创建一个新的数据集实例,并实现了获取单条记录的方法`__getitem__()` 和返回整个集合大小的功能 `__len__()`. 同时也引入了一个名为`tokenizer`的对象来进行文本向量化操作,在实际应用前还需要对其进行初始化配置。 #### 配置Tokenizer并加载预训练模型 接着要做的工作是对所选中的基础架构执行微调之前完成tokenization过程设置。这一步骤可以通过指定想要使用的分词器类型(例如BERT风格的WordPiece或是GPT系列里的Byte-Level BPE),再结合具体应用场景调整一些超参数选项来达成目的。一旦确定好合适的方案就可以直接从官方仓库下载对应版本的权重矩阵了[^1]。 ```python from transformers import AutoTokenizer, AutoModelForSequenceClassification model_name_or_path = "bert-base-cased" num_labels = 2 # binary classification task tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) model = AutoModelForSequenceClassification.from_pretrained( model_name_or_path, num_labels=num_labels ) ``` 这段脚本里指定了一个二元情感分析案例下的初始状态迁移起点为“bert-base-cased”,并且告知系统该问题属于两类判别范畴之内。随后分别建立了关联性的Tokenzier对象和带有线性变换层输出单元数量等于预期分类数目的一套完整的Seq2Classifcation体系结构。 #### 微调模型及评估性能 最后便是正式开启fine-tuning流程的部分啦!此时应该已经拥有了经过良好组织整理过的样本资料集以及精心挑选出来的骨干组件。现在只需要按照标准做法设定优化算法参数、损失函数形式还有学习率调度策略等方面的内容即可启动迭代更新机制直至收敛为止。期间还可以定期保存checkpoint便于后续恢复继续训练或者部署上线服务端口供外部访问请求调用。 ```python from transformers import Trainer, TrainingArguments training_args = TrainingArguments( output_dir='./results', evaluation_strategy="epoch", learning_rate=2e-5, per_device_train_batch_size=8, per_device_eval_batch_size=8, num_train_epochs=3, weight_decay=0.01, ) trainer = Trainer( model=model, args=training_args, train_dataset=train_dataset, eval_dataset=val_dataset, compute_metrics=lambda p: {"accuracy": (p.predictions.argmax(axis=-1) == p.label_ids).mean()}, ) trainer.train() ``` 以上代码段落描述的是利用Trainer API快速搭建起一套简易但功能完备的小规模实验平台的过程。其中包括但不限于定义日志存储位置、验证频率间隔、批量尺寸规格等一系列重要的运行期属性控制项;同时也明确了具体的评价指标计算规则以方便实时监控进度变化趋势图谱。
评论 10
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值