
【关于 Transformer 代码实战(文本摘要任务篇)】 那些你不知道的事
作者:杨夕
项目地址: https:// github.com/km1994/nlp_p aper_study
个人介绍:大佬们好,我叫杨夕,该项目主要是本人在研读顶会论文和复现经典论文过程中,所见、所思、所想、所闻,可能存在一些理解错误,希望大佬们多多指正。
目录
- 【关于 Transformer 代码实战(文本摘要任务篇)】 那些你不知道的事
- 目录
- 引言
- 一、文本摘要数据集介绍
- 二、数据集加载介绍
- 2.1 数据加载
- 2.2 数据字段抽取
- 三、 数据预处理
- 3.1 summary 数据 处理
- 3.2 编码处理
- 3.3 获取 encoder 词典 和 decoder 词典 长度
- 3.4 确定 encoder 和 decoder 的 maxlen
- 3.5 序列 填充/裁剪
- 四、创建数据集 pipeline
- 五、组件构建
- 5.1 位置编码
- 5.1.1 问题
- 5.1.2 目的
- 5.1.3 思路
- 5.1.4 位置向量的作用
- 5.1.5 步骤
- 5.1.6 计算公式
- 5.1.7 代码实现
- 5.2 Masking 操作
- 5.2.1 介绍
- 5.2.3 类别:padding mask and sequence mask
- padding mask
- sequence mask
- 六、模型构建
- 6.1 self-attention
- 6.1.1 动机
- 6.1.2 传统 Attention
- 6.1.3 核心思想
- 6.1.4 目的
- 6.1.5 公式
- 6.1.6 步骤
- 6.1.7 代码实现
- 6.2 Multi-Headed Attention
- 思路
- 步骤
- 代码实现
- 6.3 前馈网络
- 思路
- 目的
- 代码实现
- 6.4 Transformer encoder 单元
- 结构
- 代码实现
- 6.5 Transformer decoder 单元
- 结构
- 代码实现
- 七、Encoder 和 Decoder 模块构建
- 7.1 Encoder 模块构建
- 7.2 Dncoder 模块构建
- 八、Transformer 构建
- 九、模型训练
- 9.1 配置类
- 9.2 优化函数定义
- 9.3 Loss 损失函数 和 评测指标 定义
- 9.3.1 Loss 损失函数 定义
- 9.4 Transformer 实例化
- 9.5 Mask 实现
- 9.6 模型结果保存
- 9.7 Training Steps
- 9.8 训练
引言
之前给 小伙伴们 写过 一篇 【【关于Transformer】 那些的你不知道的事】后,有一些小伙伴联系我,并和我请教了蛮多细节性问题,针对该问题,小菜鸡的我 也 想和小伙伴 一起 学习,所以就 找到了 一篇【Transformer 在文本摘要任务 上的应用】作为本次学习的 Coding!
一、文本摘要数据集介绍
本任务采用的 文本摘要数据集 为 Kaggle 比赛 之 Inshorts Dataset,该数据集 包含以下字段:
序号 | 字段名 | 字段介绍 | 举例 |
1 | Headline | 标题 | 4 ex-bank officials booked for cheating bank of ₹209 crore |
2 | Short | 短文 | The CBI on Saturday booked four former officials of Syndicate Bank and six others for cheating, forgery, criminal conspiracy and causing ₹209 crore loss to the state-run bank. The accused had availed home loans and credit from Syndicate Bank on the basis of forged and fabricated documents. These funds were fraudulently transferred to the companies owned by the accused persons. |
3 | Source | 数据来源 | The New Indian Express |
4 | Time | 发表时间 | 9:25:00 |
5 | Publish Date | 发表日期 | 2017/3/26 |
注:这里我们只 用到 Headline[摘要] 和 Short[长文本] 作为 文本摘要任务 实验数据
二、数据集加载介绍
2.1 数据加载
本文将数据集存储在 Excel 文件中,通过 pandas 的 read_excel() 方法 获取数据集,代码如下:
news = pd.read_excel("data/news.xlsx")
2.2 数据字段抽取
在 一、文本摘要数据集介绍 中,我们说过,我们只用到 Headline[摘要] 和 Short[长文本] 作为 文本摘要任务 实验数据,所以我们需要 清除 其他字段。代码如下:
news.drop(['Source ', 'Time ', 'Publish Date'], axis=1, inplace=True)
可以采用以下命令,查看结果:
news.head()
news.shape # (55104, 2)

方便后期操作,我们这里直接 从 DataFrame 中分别抽取 出 Headline[摘要] 和 Short[长文本] 数据:
document = news['Short']
summary = news['Headline']
document[30], summary[30]
>>>
('According to the Guinness World Records, the most generations alive in a single family have been seven. The difference between the oldest and the youngest person in the family was about 109 years, when Augusta Bunge's great-great-great-great grandson was born on January 21, 1989. The family belonged to the United States of America.',
'The most generations alive in a single family have been 7')
三、 数据预处理
3.1 summary 数据 处理
summary 数据 作为 decoder 序列数据,我们需要做一些小处理【前后分别加一个标识符】,如下所示:
# for decoder sequence
summary = summary.apply(lambda x: '<go> ' + x + ' <stop>')
summary[0]
>>>
'<go> 4 ex-bank officials booked for cheating bank of ₹209 crore <stop>'
3.2 编码处理
在 进行 文本摘要任务 之前,我们需要 将 文本进行编码:
- 变量定义
# since < and > from default tokens cannot be removed
filters = '!"#$%&()*+,-./:;=?@[]^_`{|}~tn' # 文本中特殊符号清洗
oov_token = '<unk>' # 未登录词 表示
- 定义 文本预处理 tf.keras.preprocessing.text.Tokenizer() 编码类【用于后期 文本编码处理】
document_tokenizer = tf.keras.preprocessing.text.Tokenizer(oov_token=oov_token)
summary_tokenizer = tf.keras.preprocessing.text.Tokenizer(filters=filters, oov_token=oov_token)
Tokenizer : 一个将文本向量化,转换成序列的类。用来文本处理的分词、嵌入 。
keras.preprocessing.text.Tokenizer(num_words=None,
filters='!"#$%&()*+,-./:;<=>?@[]^_`{|}~tn',
lower=True,
split=' ',
char_level=False,
oov_token=None,
document_count=0)
- 参数说明:
- num_words: 默认是None处理所有字词,但是如果设置成一个整数&#x