使用load_dataset()
API默认读取到的数据集是MapDataset
对象,MapDataset
是paddle.io.Dataset
的功能增强版本。其内置的map()
方法适合用来进行批量数据集处理。map()
方法传入的是一个用于数据处理的function。 以下是Dureader-Robust中数据转化的用法:
max_seq_length = 512
doc_stride = 128
train_trans_func = partial(prepare_train_features,
max_seq_length=max_seq_length,
doc_stride=doc_stride,
tokenizer=tokenizer)
train_ds.map(train_trans_func, batched=True)
dev_trans_func = partial(prepare_validation_features,
max_seq_length=max_seq_length,
doc_stride<