Loading Large Datasets in PyTorch
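The data discussed in this post has the following characteristics:
- It is too large to be read into memory all at once
- It is stored in a CSV or plain-text file (each line is one sample, containing both the features and the label)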
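To make the "one line, one sample" layout concrete, here is a minimal sketch of turning a single text row into features and a label. The column layout is an assumption made only for illustration: comma-separated numeric features with an integer label in the last column. The helper name `parse_line` is hypothetical and is not part of the original code.

```python
def parse_line(line, sep=","):
    """Split one text row into (features, label).

    Assumption: all columns except the last are numeric features,
    and the last column is an integer class label.
    """
    fields = line.strip().split(sep)
    features = [float(x) for x in fields[:-1]]
    label = int(fields[-1])
    return features, label


features, label = parse_line("0.5,1.25,3\n")
```

If your file stores the label first, or uses tabs instead of commas, only the slicing and the `sep` argument need to change.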
Requirements:
- Read only a small chunk of the data into memory at a time
- Support batching
- Support shuffling
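Before wiring this into PyTorch, the three requirements can be sketched in plain Python. This is only an illustration under stated assumptions, not the implementation used in this post: the function name `iter_batches` and its parameters `chunk_size`, `batch_size`, and `seed` are hypothetical. Note that the shuffle here is local to each chunk, so samples never move between chunks.

```python
import random
from itertools import islice


def iter_batches(file_path, chunk_size, batch_size, shuffle=False, seed=None):
    """Yield batches (lists of raw lines), keeping at most chunk_size lines in memory."""
    rng = random.Random(seed)
    with open(file_path, "r") as f:
        while True:
            # Read the next block of at most chunk_size lines from the file.
            chunk = [line.rstrip("\n") for line in islice(f, chunk_size)]
            if not chunk:
                break  # end of file
            if shuffle:
                # Shuffle within the block only; this is not a global shuffle.
                rng.shuffle(chunk)
            # Cut the in-memory block into batches.
            for start in range(0, len(chunk), batch_size):
                yield chunk[start:start + batch_size]
```

A larger `chunk_size` gives a shuffle that is closer to a global one, at the cost of more memory per chunk.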
Define a custom MyDataset that inherits from torch.utils.data.Dataset, overrides __init__(), __len__(), and __getitem__(), and adds an initial() method.
import torch.utils.data as Data
import random

class MyDataset(Data.Dataset):
    def __init__(self, file_path, nraws, shuffle=False):
        """
        file_path: path to the dataset file
        nraws: number of rows (samples) loaded into memory at a time; shuffling happens within this block
        shuffle: whether to shuffle the data
        """
        file_raws = 0
        # count the total number of samples in the file
        with open(file_path, 'r') as f: