# 该代码主要是为了讲解介绍Dataloader的工作机制
from torch.utils import data
# 任何Dataset数据类的子类,并重写相关的函数
class NerDataset(data.Dataset):
# 将需要的参数进行初始化
def __init__(self, examples, tokenizer, label_map, max_seq_length):
self.examples=examples
self.tokenizer=tokenizer
self.label_map=label_map
self.max_seq_length=max_seq_length
# 该函数是为了计算样本的数量,其意义在于保证__getitem__()的参数inx的取值范围在(0, len)之间
def __len__(self):
return len(self.examples)
# 迭代传入__getitem__()索引,并返回相应的内容
def __getitem__(self, idx):
feat=example2feature(self.examples[idx], self.tokenizer, self.label_map, max_seq_length)
return feat.input_ids, feat.input_mask, feat.segment_ids, feat.predict_mask, feat.label_ids
# classmethod修饰符的函数不需要实例化,不需要 self 参数。第一个参数cls表示自身类,用于调用类的属性、类的方法、实例化对象等(如cls.tokenizer)。
# 此函数的第二个参数batch,即是__getitem__()迭代batch_size次的返回结果来作为参数。
@classmethod
def pad(cls, batch):
seqlen_list = [len(sample[
Pytorch中的DataLoader处理机制
最新推荐文章于 2025-12-01 09:46:19 发布
博客详细解释了data.Dataloader()的工作流程,包括构建Dataset子类、以特定方式迭代获取batch_data、对batch_data进行collate_fn()操作,最终得到供torch使用的数据对象,还提到需将变量类型处理成torch.tensor,最后给出代码示例解读的参考文献。

最低0.47元/天 解锁文章
2117

被折叠的 条评论
为什么被折叠?



