Pytorch 数据集太大,内存不够怎么办?——幻方量化优雅的解决方法

Tensorflow有着专门的数据读取模块tfrecord,可以直接从硬盘读取数据,而不受内存限制,但Pytorch一直没有一个专有的、高效的数据读取方法。如何基于PyTorch,高效读取大规模训练数据集,并充分利用GPU性能,成为一个关键问题。

在笔者之前的文章(tensorflow2.x(一)显存不够或内存不够要怎么办?)中,介绍过 tensorflow 针对大数据集的训练有 tfrecords的数据存储格式,可以将超大的数据样本先存在硬盘,每次只从存储硬盘中读取一个batch的数据入内存,而不是将整个训练样本一次性全部读入,因此可以大大减小内存的限制。

1.官方给出的方案

我们先介绍一些官方给出的解决方法及其局限:
在Pytorch中,官方给了可以自定义的数据集基类 Dataset:

class Dataset(object):
    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

只要继承,并实现这两个抽象方法即可:

  • getitem() 方法通过索引返回数据集中选定的样本。
  • len() 方法返回数据集的总大小。例如,如果您的数据集包含 1,00,000 个样本,则该len方法应返回 1,00,000

举个栗子:

#导入相关模块
from torch.utils.data import DataLoader,Dataset
from skimage import io,transform
import matplotlib.pyplot as plt
import os
import torch
from torchvision import transforms
import numpy as np

class AnimalData(Dataset): #继承Dataset
    def __init__(self, root_dir):
        self.root_dir = root_dir   #文件目录
        self.images = os.listdir(self.root_dir) #目录里的所有文件
    
    def __len__(self): #返回整个数据集的大小
        return len(self.images)
    
    def __getitem__(self,index) :#根据索引index返回dataset[index]
        image_index = self.images[index]
        img_path = os.path.join(self.root_dir, image_index)
        img = io.imread(img_path) # 读取该图片
        label = img_path.split('\\')[-1].split('.')[0]
        sample = {'image':img,'label':label}
        
        return sample # 返回该样本及其标签

可以看到,网络上的大部分例子都是类似的,都是针对图片数据,每张图片即是一个样本,训练时,只需要预先读取样本(图片)的文件地址,再通过 DataLoader 依次加载即可。

但是!对于非图片数据的样本,假如我们有数百万上千万个样本,但是每个样本又很小,如果按照这种方法使每个样本独立存储,就会有数百万个小文件,这就相当不优雅,且速度较慢。

2. 其他可行方案

PyTorch下训练数据小文件转大文件读写(附有各种存储格式对比),这篇文章中给出了几个可行的解决方案,诸如:

  • 在pytorch下,使用 TFRecord 方法
  • 直接将现有数据集按照二进制读取,存入一个bins的大文件中
  • 采用如sqlite的数据库存储数据并读取

比较推荐的一个方法是 sqlite,转换前后,数据存储大小不变,也可以正常多进程读取。给出代码如下:
写入到sqlite数据库中

import sqlite3
from pathlib import Path
 
from tqdm import tqdm
 
 
def read_txt(txt_path):
    with open(txt_path, 'r', encoding='utf-8-sig') as f:
        data = list(map(lambda x: x.rstrip('\n'), f))
    return data
 
 
def img_to_bytes(img_path):
    with open(img_path, 'rb') as f:
        img_bytes = f.read()
        return img_bytes
 
 
class SQLiteWriter(object):
    def __init__(self, db_path):
        self.conn = sqlite3.connect(db_path)
        self.cursor = self.conn.cursor()
 
    def execute(self, sql, value=None):
        if value:
            self.cursor.execute(sql, value)
        else:
            self.cursor.execute(sql)
 
    def __enter__(self):
        return self
 
    def __exit__(self, exc_type, exc_val, exc_tb):
        self.cursor.close()
        self.conn.commit()
        self.conn.close()
 
 
if __name__ == '__main__':
    dataset_dir = Path('datasets/minist')
 
    save_db_dir = dataset_dir / 'sqlite'
    save_db_path = str(save_db_dir / 'val.db')
 
    # val.txt中 每行为:图像路径\t对应文本值 e.g. xxxx.jpg\txxxxxx
    img_paths = read_txt(str(dataset_dir / 'val.txt'))
 
    with SQLiteWriter(save_db_path) as db_writer:
        # 创建表
        table_name = 'minist'
  
        # 注意这里的表中字段,要根据自己数据集来定义
        # 具体数据库类型,可参考:https://docs.python.org/zh-cn/3/library/sqlite3.html#sqlite-and-python-types
        # demo中示例所涉及到的数据集为文本识别数据集,样本为图像,标签为对应文本,
        # 下面示例字段的数据类型为python下的数据类型,只需转为以下对应数据类型即可写入数据库的表中
        # e.g. img_path: str(xxxx.jpg), img_data: bytes格式的图像数据, img_label: str(xxxxx)
        create_table_sql = f'create table {table_name} (img_path TEXT primary key, img_data BLOB, img_label TEXT)'
        db_writer.execute(create_table_sql)
 
        # 向表中插入数据,value部分采用占位符
        insert_sql = f'insert into {table_name} (img_path, img_data, img_label) values(?, ?, ?)'
        for img_info in tqdm(img_paths):
            img_path, label = img_info.split('\t')
 
            img_full_path = str(dataset_dir / 'images' / img_path)
            img_data = img_to_bytes(img_full_path)
 
            db_writer.execute(insert_sql, (img_path, img_data, label))

读取数据库

class SimpleDataset(Dataset):
    def __init__(self, db_path, transform=None) -> None:
        self.db_path = db_path
        self.conn = None
        self.establish_conn()
 
        # 数据库中表名
        self.table_name = 'Synthetic_chinese_dataset'
 
        self.cursor.execute(f'select max(rowid) from {self.table_name}')
        self.nums = self.cursor.fetchall()[0][0]
        self.transform = transform
 
    def __getitem__(self, index: int):
        self.establish_conn()
 
        # 查询
        search_sql = f'select * from {self.table_name} where rowid=?'
        self.cursor.execute(search_sql, (index+1, ))
        img_path, img_bytes, label = self.cursor.fetchone()
 
        # 还原图像和标签
        img = Image.open(BytesIO(img_bytes))
        img = img.convert('RGB')
        img = scale_resize_pillow(img, (320, 32))
 
        if self.transform:
            img = self.transform(img)
        return img, label
 
    def __len__(self) -> int:
        return self.nums
 
    def establish_conn(self):
        if self.conn is None:
            self.conn = sqlite3.connect(self.db_path,
                                        check_same_thread=False,
                                        cached_statements=1024)
            self.cursor = self.conn.cursor()
        return self
 
    def close_conn(self):
        if self.conn is not None:
            self.cursor.close()
            self.conn.close()
 
            del self.conn
            self.conn = None
        return self  
 
# --------------------------------------------------
train_dataset = SimpleDataset(train_db_path, train_transforms)
# ✧✧使用部分,需要手动关闭数据库连接
train_dataset.close_conn()
train_dataloader = DataLoader(train_dataset,
                              batch_size=batch_size,
                              num_workers=n_worker,
                              pin_memory=True,
                              sampler=train_sampler)

3. 幻方量化开源方案—FFRecord

好的,接下来才是本章的重点——FFRecord数据存储格式

FFRecord格式是幻方AI开发的一种适用于3FS性能的二进制序列存储格式,适配了 PyTorch 的 Dataset 和 Dataloader 接口,可以非常方便的加载发起训练。具体项目地址: https://github.com/HFAiLab/ffrecord

项目目前没有上到pypi,所以需要手动下载文件包并通过python setup.py install来安装。
注意项目的依赖环境必须满足:

  • OS: Linux
  • Python >= 3.6
  • Pytorch >= 1.6
  • NumPy
  • tqdm
  • zlib: sudo apt-get install zliblg-dev
  • cmake: pip install cmake pybind11 >= 2.8

3.1 FFRecord数据的读写

FFRecord数据的写入

import pickle
from ffrecord import FileWriter

def serialize(sample):
    """ Serialize a sample to bytes or bytearray

    You could use anything you like to serialize the sample.
    Here we simply use pickle.dumps().
    """
    return pickle.dumps(sample)


samples = [i for i in range(100)]  # anything you would like to store
fname = 'test.ffr'
n = len(samples)  # number of samples to be written
writer = FileWriter(fname, n)

for i in range(n):
    data = serialize(samples[i])  # data should be bytes or bytearray
    writer.write_one(data)

writer.close()

以上代码中,通过FileWriter(fname, n)创建对象,其中n是ffr文件内的样本数据,再通过writer.write_one(data)逐个写入样本即可。注意,写入样本的样本个数不能超过预先设定的n。

FFRecord数据的读取

from ffrecord import FileReader


def deserialize(data):
    """ deserialize bytes data

    The deserialize method should be paired with the serialize method above.
    """
    return pickle.loads(data)


fname = 'test.ffr'
reader = FileReader(fname, check_data=True)
print(f'Number of samples: {reader.n}')

indices = [3, 6, 0, 10]      # indices of each sample
data = reader.read(indices)  # return a list of bytes-like data

for i in data:
    sample = deserialize(i)
    print(sample)

reader.close()

输出为:

Number of samples: 100
3
6
0
10

若是数据集太大,无法一次性预先知道具有训练集内具体有多少个样本数n,也可以分批次存储为多个ffr文件,再合并读取:

import pickle
from ffrecord import FileReader


def deserialize(data):
    """ deserialize bytes data

    The deserialize method should be paired with the serialize method above.
    """
    return pickle.loads(data)


fname_list = ['test1.ffr', 'test2.ffr']
reader = FileReader(fname_list , check_data=True)
print(f'Number of samples: {reader.n}')

indices = [0, 5, 100, 105]      # indices of each sample
data = reader.read(indices)  # return a list of bytes-like data

for i in data:
    sample = deserialize(i)
    print(sample)

输出为:

Number of samples: 200
0
5
100
105

以上代码中,test1.ffr存储了100个值为0-99的的样本,test2.ffr存储了100个值为100-199的的样本,FileReader 通过读取文件地址的list,将两个文件合并。

3.1 ffrecord.torch.Dataset 和 ffrecord.torch.DataLoader

项目内,提供了 ffrecord.torch.Datasetffrecord.torch.DataLoader 专门用于PyTorch 的训练数据准备。

数据集的自定义与 PyTorch 的 Dataset 基本相同:

class CustomDataset(ffrecord.torch.Dataset):

    def __init__(self, fname, check_data=True, transform=None):
        self.reader = FileReader(fname, check_data)
        self.transform = transform

    def __len__(self):
        return self.reader.n

    def __getitem__(self, indices):
        # we read a batch of samples at once
        assert isintance(indices, list)
        data = self.reader.read(indices)

        # deserialize data
        samples = [pickle.loads(b) for b in data]

        # transform data
        if self.transform:
            samples = [self.transform(s) for s in samples]
        return samples

dataset = CustomDataset('train.ffr')
indices = [3, 4, 1, 0]
samples = dataset[indices]

与 PyTorch 的Dataset 稍有不同的是, ffrecord.torch.Dataset 可以一次性输入一组数据索引的列表,并直接返回这一组样本值。

ffrecord.torch.DataLoader 与 PyTorch 的 DataLoader 基本相同:

dataset = CustomDataset('train.ffr')
loader = ffrecord.torch.DataLoader(dataset,
                                   batch_size=16,
                                   shuffle=True,
                                   num_workers=8)

start_epoch = 5
start_step = 100  # resume from epoch 5, step 100
loader.set_step(start_step)

for epoch in range(start_epoch, epochs):
    for i, batch in enumerate(loader):
        # training model

    loader.set_step(0)  # remember to reset before the next epoch

这里的 loader 支持通过set_step()方法在训练过程中跳过数据 step,以便在训练中断时,在断点处继续训练。

基于 ffrecord 项目的数据读取方法极大降低数据加载的开销,充分利用GPU的计算时间,提高GPU利用率。

参考资料

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

RicardoOzZ

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值