文章目录
Tensorflow有着专门的数据读取模块tfrecord,可以直接从硬盘读取数据,而不受内存限制,但Pytorch一直没有一个专有的、高效的数据读取方法。如何基于PyTorch,高效读取大规模训练数据集,并充分利用GPU性能,成为一个关键问题。
在笔者之前的文章(tensorflow2.x(一)显存不够或内存不够要怎么办?)中,介绍过 tensorflow 针对大数据集的训练有 tfrecords的数据存储格式,可以将超大的数据样本先存在硬盘,每次只从存储硬盘中读取一个batch的数据入内存,而不是将整个训练样本一次性全部读入,因此可以大大减小内存的限制。
1.官方给出的方案
我们先介绍一些官方给出的解决方法及其局限:
在Pytorch中,官方给了可以自定义的数据集基类 Dataset:
class Dataset(object):
def __getitem__(self, index):
raise NotImplementedError
def __len__(self):
raise NotImplementedError
只要继承,并实现这两个抽象方法即可:
- getitem() 方法通过索引返回数据集中选定的样本。
- len() 方法返回数据集的总大小。例如,如果您的数据集包含 1,00,000 个样本,则该len方法应返回 1,00,000
举个栗子:
#导入相关模块
from torch.utils.data import DataLoader,Dataset
from skimage import io,transform
import matplotlib.pyplot as plt
import os
import torch
from torchvision import transforms
import numpy as np
class AnimalData(Dataset): #继承Dataset
def __init__(self, root_dir):
self.root_dir = root_dir #文件目录
self.images = os.listdir(self.root_dir) #目录里的所有文件
def __len__(self): #返回整个数据集的大小
return len(self.images)
def __getitem__(self,index) :#根据索引index返回dataset[index]
image_index = self.images[index]
img_path = os.path.join(self.root_dir, image_index)
img = io.imread(img_path) # 读取该图片
label = img_path.split('\\')[-1].split('.')[0]
sample = {'image':img,'label':label}
return sample # 返回该样本及其标签
可以看到,网络上的大部分例子都是类似的,都是针对图片数据,每张图片即是一个样本,训练时,只需要预先读取样本(图片)的文件地址,再通过 DataLoader 依次加载即可。
但是!对于非图片数据的样本,假如我们有数百万上千万个样本,但是每个样本又很小,如果按照这种方法使每个样本独立存储,就会有数百万个小文件,这就相当不优雅,且速度较慢。
2. 其他可行方案
PyTorch下训练数据小文件转大文件读写(附有各种存储格式对比),这篇文章中给出了几个可行的解决方案,诸如:
- 在pytorch下,使用 TFRecord 方法
- 直接将现有数据集按照二进制读取,存入一个bins的大文件中
- 采用如sqlite的数据库存储数据并读取
比较推荐的一个方法是 sqlite,转换前后,数据存储大小不变,也可以正常多进程读取。给出代码如下:
写入到sqlite数据库中
import sqlite3
from pathlib import Path
from tqdm import tqdm
def read_txt(txt_path):
with open(txt_path, 'r', encoding='utf-8-sig') as f:
data = list(map(lambda x: x.rstrip('\n'), f))
return data
def img_to_bytes(img_path):
with open(img_path, 'rb') as f:
img_bytes = f.read()
return img_bytes
class SQLiteWriter(object):
def __init__(self, db_path):
self.conn = sqlite3.connect(db_path)
self.cursor = self.conn.cursor()
def execute(self, sql, value=None):
if value:
self.cursor.execute(sql, value)
else:
self.cursor.execute(sql)
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.cursor.close()
self.conn.commit()
self.conn.close()
if __name__ == '__main__':
dataset_dir = Path('datasets/minist')
save_db_dir = dataset_dir / 'sqlite'
save_db_path = str(save_db_dir / 'val.db')
# val.txt中 每行为:图像路径\t对应文本值 e.g. xxxx.jpg\txxxxxx
img_paths = read_txt(str(dataset_dir / 'val.txt'))
with SQLiteWriter(save_db_path) as db_writer:
# 创建表
table_name = 'minist'
# 注意这里的表中字段,要根据自己数据集来定义
# 具体数据库类型,可参考:https://docs.python.org/zh-cn/3/library/sqlite3.html#sqlite-and-python-types
# demo中示例所涉及到的数据集为文本识别数据集,样本为图像,标签为对应文本,
# 下面示例字段的数据类型为python下的数据类型,只需转为以下对应数据类型即可写入数据库的表中
# e.g. img_path: str(xxxx.jpg), img_data: bytes格式的图像数据, img_label: str(xxxxx)
create_table_sql = f'create table {table_name} (img_path TEXT primary key, img_data BLOB, img_label TEXT)'
db_writer.execute(create_table_sql)
# 向表中插入数据,value部分采用占位符
insert_sql = f'insert into {table_name} (img_path, img_data, img_label) values(?, ?, ?)'
for img_info in tqdm(img_paths):
img_path, label = img_info.split('\t')
img_full_path = str(dataset_dir / 'images' / img_path)
img_data = img_to_bytes(img_full_path)
db_writer.execute(insert_sql, (img_path, img_data, label))
读取数据库
class SimpleDataset(Dataset):
def __init__(self, db_path, transform=None) -> None:
self.db_path = db_path
self.conn = None
self.establish_conn()
# 数据库中表名
self.table_name = 'Synthetic_chinese_dataset'
self.cursor.execute(f'select max(rowid) from {self.table_name}')
self.nums = self.cursor.fetchall()[0][0]
self.transform = transform
def __getitem__(self, index: int):
self.establish_conn()
# 查询
search_sql = f'select * from {self.table_name} where rowid=?'
self.cursor.execute(search_sql, (index+1, ))
img_path, img_bytes, label = self.cursor.fetchone()
# 还原图像和标签
img = Image.open(BytesIO(img_bytes))
img = img.convert('RGB')
img = scale_resize_pillow(img, (320, 32))
if self.transform:
img = self.transform(img)
return img, label
def __len__(self) -> int:
return self.nums
def establish_conn(self):
if self.conn is None:
self.conn = sqlite3.connect(self.db_path,
check_same_thread=False,
cached_statements=1024)
self.cursor = self.conn.cursor()
return self
def close_conn(self):
if self.conn is not None:
self.cursor.close()
self.conn.close()
del self.conn
self.conn = None
return self
# --------------------------------------------------
train_dataset = SimpleDataset(train_db_path, train_transforms)
# ✧✧使用部分,需要手动关闭数据库连接
train_dataset.close_conn()
train_dataloader = DataLoader(train_dataset,
batch_size=batch_size,
num_workers=n_worker,
pin_memory=True,
sampler=train_sampler)
3. 幻方量化开源方案—FFRecord
好的,接下来才是本章的重点——FFRecord数据存储格式。
FFRecord格式是幻方AI开发的一种适用于3FS性能的二进制序列存储格式,适配了 PyTorch 的 Dataset 和 Dataloader 接口,可以非常方便的加载发起训练。具体项目地址: https://github.com/HFAiLab/ffrecord
项目目前没有上到pypi,所以需要手动下载文件包并通过python setup.py install
来安装。
注意项目的依赖环境必须满足:
- OS: Linux
- Python >= 3.6
- Pytorch >= 1.6
- NumPy
- tqdm
- zlib:
sudo apt-get install zliblg-dev
- cmake:
pip install cmake pybind11 >= 2.8
3.1 FFRecord数据的读写
FFRecord数据的写入
import pickle
from ffrecord import FileWriter
def serialize(sample):
""" Serialize a sample to bytes or bytearray
You could use anything you like to serialize the sample.
Here we simply use pickle.dumps().
"""
return pickle.dumps(sample)
samples = [i for i in range(100)] # anything you would like to store
fname = 'test.ffr'
n = len(samples) # number of samples to be written
writer = FileWriter(fname, n)
for i in range(n):
data = serialize(samples[i]) # data should be bytes or bytearray
writer.write_one(data)
writer.close()
以上代码中,通过FileWriter(fname, n)
创建对象,其中n是ffr文件内的样本数据,再通过writer.write_one(data)
逐个写入样本即可。注意,写入样本的样本个数不能超过预先设定的n。
FFRecord数据的读取
from ffrecord import FileReader
def deserialize(data):
""" deserialize bytes data
The deserialize method should be paired with the serialize method above.
"""
return pickle.loads(data)
fname = 'test.ffr'
reader = FileReader(fname, check_data=True)
print(f'Number of samples: {reader.n}')
indices = [3, 6, 0, 10] # indices of each sample
data = reader.read(indices) # return a list of bytes-like data
for i in data:
sample = deserialize(i)
print(sample)
reader.close()
输出为:
Number of samples: 100
3
6
0
10
若是数据集太大,无法一次性预先知道具有训练集内具体有多少个样本数n,也可以分批次存储为多个ffr文件,再合并读取:
import pickle
from ffrecord import FileReader
def deserialize(data):
""" deserialize bytes data
The deserialize method should be paired with the serialize method above.
"""
return pickle.loads(data)
fname_list = ['test1.ffr', 'test2.ffr']
reader = FileReader(fname_list , check_data=True)
print(f'Number of samples: {reader.n}')
indices = [0, 5, 100, 105] # indices of each sample
data = reader.read(indices) # return a list of bytes-like data
for i in data:
sample = deserialize(i)
print(sample)
输出为:
Number of samples: 200
0
5
100
105
以上代码中,test1.ffr
存储了100个值为0-99的的样本,test2.ffr
存储了100个值为100-199的的样本,FileReader
通过读取文件地址的list,将两个文件合并。
3.1 ffrecord.torch.Dataset 和 ffrecord.torch.DataLoader
项目内,提供了 ffrecord.torch.Dataset
和ffrecord.torch.DataLoader
专门用于PyTorch 的训练数据准备。
数据集的自定义与 PyTorch 的 Dataset 基本相同:
class CustomDataset(ffrecord.torch.Dataset):
def __init__(self, fname, check_data=True, transform=None):
self.reader = FileReader(fname, check_data)
self.transform = transform
def __len__(self):
return self.reader.n
def __getitem__(self, indices):
# we read a batch of samples at once
assert isintance(indices, list)
data = self.reader.read(indices)
# deserialize data
samples = [pickle.loads(b) for b in data]
# transform data
if self.transform:
samples = [self.transform(s) for s in samples]
return samples
dataset = CustomDataset('train.ffr')
indices = [3, 4, 1, 0]
samples = dataset[indices]
与 PyTorch 的Dataset 稍有不同的是, ffrecord.torch.Dataset
可以一次性输入一组数据索引的列表,并直接返回这一组样本值。
ffrecord.torch.DataLoader 与 PyTorch 的 DataLoader 基本相同:
dataset = CustomDataset('train.ffr')
loader = ffrecord.torch.DataLoader(dataset,
batch_size=16,
shuffle=True,
num_workers=8)
start_epoch = 5
start_step = 100 # resume from epoch 5, step 100
loader.set_step(start_step)
for epoch in range(start_epoch, epochs):
for i, batch in enumerate(loader):
# training model
loader.set_step(0) # remember to reset before the next epoch
这里的 loader 支持通过set_step()
方法在训练过程中跳过数据 step,以便在训练中断时,在断点处继续训练。
基于 ffrecord 项目的数据读取方法极大降低数据加载的开销,充分利用GPU的计算时间,提高GPU利用率。
参考资料