手敲大模型-基础篇 实现Dataset DataLoader方法

import random

def read_data(path):
    with open(path, encoding = "utf-8" ) as f:
        all_data = f.read().split("\n")
    return all_data

#MyDataLoader类功能: __init__每一次迭代开始时初始化变量,__next__在每一次迭代中取出批次大小个classType和content
class MyDataLoader():
    def __init__(self,dataset,batch_size,shuffle = 0):
        self.dataset = dataset
        self.cursor = 0
        self.batch_size = batch_size
        self.shuffle = shuffle

    def __next__(self):
        if self.cursor >= len(self.dataset):
            raise StopIteration  # 停止迭代

        batch_size1 = []
        batch_size2 = []

        # 获取当前批次的索引范围
        start_index = self.cursor
        end_index = min(self.cursor + self.batch_size, len(self.dataset))
        # 获取当前批次的数据
        batch_indices = list(range(start_index, end_index))
        #batch_index = range(self.cursor-self.batch_size,self.cursor-self.batch_size+cur_batch)
        #我范的错误❌TypeError: 'range' object does not support item assignment


        # 打乱当前批次的索引顺序
        #random.shuffle(batch_indices) 的作用是就地打乱列表顺序,不返回值,它直接修改传入的列表对象,使其顺序被打乱
        if self.shuffle :
            print("进行shuffle操作")
            random.shuffle(batch_indices)
        print(f"当前批次的索引:",batch_indices)
        for i in batch_indices:
            class_type, content = dataset[i]
            batch_size1.append(class_type)
            batch_size2.append(content)
        self.cursor += self.batch_size #每次迭代应该对cursor值进行+batch_size操作
        return batch_size1, batch_size2



#MyDataset类的功能是, getitem取下标时自动提取类别和内容数据, iter在调用可迭代对象时调用MyDataLoader类
class MyDataset():
    def __init__(self,all_data, batch_size,shuffle):
        self.all_data = all_data
        self.batch_size = batch_size
        self.shuffle = shuffle

    def __len__(self):
        return len(self.all_data)

    #对dataset进行,数据清洗,在通过下标调用dataset是只保留自己需要的数据即(类别和内容数据)
    def __getitem__(self, item):

        #数据清洗部分
        if item >= len(self.all_data):
            raise IndexError("list index out of range")
        data = self.all_data[item].split("_!_")
        if len(data) == 5:
            _, class_type, _, content, _ = data
        else:
            print(f"跳过格式不正确的数据项: {self.all_data[item]}")
            class_type, content = None, None  # 或者设置为默认值

        return class_type, content

    def __iter__(self):
        return MyDataLoader(self,self.batch_size,self.shuffle)#这里很有趣,self表示的是MyDataset实例化对象即dataset所对应的数据集

if __name__ == "__main__":
    all_data = read_data("data.txt")

    #参数区
    batch_size = 2
    epoch = 10


    dataset = MyDataset(all_data,batch_size,1)

    d = dataset[0]
    print(d)

    for e in range(epoch):
        print(f"-----------第{e+1}批次-----------")
        for batch_data1,batch_data2 in dataset:
            print(batch_data1, batch_data2)
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值