打卡
基于MindSpore的GPT2文本摘要
环境部署
%%capture captured_output
#如需更换mindspore版本,可更改下面mindspore的版本号
!pip uninstall mindspore -y
!pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore==2.2.14
!pip install tokenizers==0.15.0 -i https://pypi.tuna.tsinghua.edu.cn/simple
# 该案例在 mindnlp 0.3.1 版本完成适配,如果发现案例跑不通,可以指定mindnlp版本,执行`!pip install mindnlp==0.3.1`
!pip install mindnlp==0.3.1
数据集加载与处理
- 数据集加载
- 这里的实验用的是nlpcc2017z摘要数据,内容为新闻正文及其摘要,总计50000个样本。
from mindnlp.utils import http_get
# download dataset
url = 'https://download.mindspore.cn/toolkits/mindnlp/dataset/text_generation/nlpcc2017/train_with_summ.txt'
path = http_get(url, './')
from mindspore.dataset import TextFileDataset
# load dataset
dataset = TextFileDataset(str(path), shuffle=False)
# 查看其大小
dataset.get_dataset_size() #50000
# split into training and testing dataset
train_dataset, test_dataset = dataset.split([0.9, 0.1], randomize=False)
-
数据预处理
原始数据格式
article: [CLS] article_context [SEP] summary: [CLS] summary_context [SEP]
预处理后的数据格式
[CLS] article_context [SEP] summary_context [SEP]
# 这是预处理数据的函数
import json
import numpy as np
# preprocess dataset
def process_dataset(dataset, tokenizer, batch_size=6, max_seq_len=1024, shuffle=False):
def read_map(text):
data = json.loads(text.tobytes())
return np.array(data['article']), np.array(data['summarization'])