import random


def read_data(path):
    """Read a UTF-8 text file and return its lines.

    The text is split on newline characters only, so a carriage return from
    CRLF files stays at the end of each line. A single trailing newline at the
    end of the file does not create an extra empty line (a plain split would,
    and that empty line would later surface as a ``(None, None)`` sample).

    :param path: path of the file to read.
    :return: list of lines without their newline terminators; ``[]`` for an
        empty file.
    """
    with open(path, encoding="utf-8") as f:
        lines = f.read().split("\n")
    # split() yields a final "" when the file ends with a newline; that is a
    # line terminator, not a record.
    if lines[-1] == "":
        lines.pop()
    return lines


class MyDataLoader:
    """Iterator that yields one batch per step as ``(class_types, contents)``.

    ``MyDataset.__iter__`` creates a fresh instance for every pass over the
    dataset, so the cursor restarts at 0 on each epoch.
    """

    def __init__(self, dataset, batch_size, shuffle=0):
        """
        :param dataset: object supporting ``len()`` and integer indexing, where
            ``dataset[i]`` returns a ``(class_type, content)`` pair.
        :param batch_size: number of samples per batch; the last batch may be
            smaller. Must be positive.
        :param shuffle: if truthy, shuffle the order of the indices *within*
            each batch (which samples land in which batch does not change).
        :raises ValueError: if ``batch_size`` is not positive; the cursor would
            otherwise never advance and iteration would never terminate.
        """
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")
        self.dataset = dataset
        self.cursor = 0
        self.batch_size = batch_size
        self.shuffle = shuffle

    def __iter__(self):
        """Return self, making the loader a proper iterator."""
        return self

    def __next__(self):
        """Return the next batch as two parallel lists.

        :return: ``(class_types, contents)``, where element ``k`` of each list
            comes from the same sample.
        :raises StopIteration: once every sample has been returned.
        """
        if self.cursor >= len(self.dataset):
            raise StopIteration

        start_index = self.cursor
        end_index = min(self.cursor + self.batch_size, len(self.dataset))
        # Must be a list, not a range: random.shuffle reorders in place and a
        # range object does not support item assignment.
        batch_indices = list(range(start_index, end_index))

        if self.shuffle:
            print("进行shuffle操作")
            # In-place; only the order inside this batch changes.
            random.shuffle(batch_indices)
        print("当前批次的索引:", batch_indices)

        class_types = []
        contents = []
        for i in batch_indices:
            # Read through the dataset this loader was built with, never a
            # module-level global.
            class_type, content = self.dataset[i]
            class_types.append(class_type)
            contents.append(content)

        self.cursor = end_index
        return class_types, contents


class MyDataset:
    """Line-based dataset whose items are ``(class_type, content)`` pairs.

    Each raw line is split on ``"_!_"``; a line with exactly 5 fields yields
    its 2nd field as the class and its 4th field as the content. Iterating the
    dataset produces batches via ``MyDataLoader``.
    """

    def __init__(self, all_data, batch_size, shuffle):
        """
        :param all_data: list of raw text lines.
        :param batch_size: forwarded to ``MyDataLoader`` on each iteration.
        :param shuffle: forwarded to ``MyDataLoader`` on each iteration.
        """
        self.all_data = all_data
        self.batch_size = batch_size
        self.shuffle = shuffle

    def __len__(self):
        """Return the number of raw lines."""
        return len(self.all_data)

    def __getitem__(self, item):
        """Parse line ``item`` and keep only its class and content fields.

        :return: ``(class_type, content)``, or ``(None, None)`` when the line
            does not have exactly 5 ``"_!_"``-separated fields.
        :raises IndexError: if ``item >= len(self)``. Negative indices fall
            through to normal list indexing.
        """
        if item >= len(self.all_data):
            raise IndexError("list index out of range")
        data = self.all_data[item].split("_!_")
        if len(data) == 5:
            _, class_type, _, content, _ = data
        else:
            # NOTE(review): despite the message, the item is not dropped; it
            # is returned as (None, None) and still ends up in its batch.
            print(f"跳过格式不正确的数据项: {self.all_data[item]}")
            class_type, content = None, None
        return class_type, content

    def __iter__(self):
        """Start a new pass: the loader reads cleaned items via ``self[i]``."""
        return MyDataLoader(self, self.batch_size, self.shuffle)


def main():
    """Demo: load data.txt and print every batch for several epochs."""
    all_data = read_data("data.txt")

    # Parameters
    batch_size = 2
    epoch = 10

    dataset = MyDataset(all_data, batch_size, 1)
    d = dataset[0]
    print(d)
    for e in range(epoch):
        print(f"-----------第{e+1}轮-----------")
        for batch_data1, batch_data2 in dataset:
            print(batch_data1, batch_data2)


if __name__ == "__main__":
    main()
# 手敲大模型-基础篇 实现Dataset DataLoader方法
# 于 2025-03-10 14:43:39 首次发布