中文文本纠错_论文Spelling Error Correction with Soft-Masked BERT(ACL_2020)学习笔记与模型复现
最近在ACL 2020上看到一篇论文《Spelling Error Correction with Soft-Masked BERT》,论文的主题为中文文本纠错中的**Chinese spelling error correction (CSC)**任务,论文作者为来自字节跳动AI Lab与复旦大学的研究人员。
《Spelling Error Correction with Soft-Masked BERT》一文中主要提出了一种新的模型框架名为Soft-Masked BERT。Soft-Masked BERT模型框架中主要含有两部分模型, 一部分称之为错误探查网络Detection Network, 另一部分称之为纠错网络Correction Network。
错误探查网络Detection Network由一个双向的GRU模型组成,而纠错网络Correction Network则基于预训练的Bert模型构建。两种网络则通过一种名为Soft masking的方式连接, 即错误探查网络Detection Network的输出经过Soft masking之后再输入进纠错网络Correction Network进行计算。
接下来, 将对论文内容中从数据集构建、模型构建、训练过程、主要实验结果、讨论这五部分进行详细地阐述。
一、数据集构建
数据集的构建在整个《Spelling Error Correction with Soft-Masked BERT》论文框架中起着重要的作用。论文中一共构建了三种数据集,分别为SIGHAN数据集、News Title数据集与5 million news titles数据集。
(1) SIGHAN数据集
SIGHAN数据集为Chinese Spelling Check Task领域的一个benchmark数据集,数据集的链接如下:SIGHAN 2013 Bake-off: Chinese Spelling Check Task。
SIGHAN数据集中包含1100条文本,共有461种错误(spelling errors),这些文本都是从中文文章中收集,相对来说数据集的主题范围较窄。
SIGHAN数据集被分为了三部分,分别为:训练集(training set)、开发集(development set)、测试集(test set)。在这里训练集(training set)用于Soft-Masked BERT模型的fine-tuning,测试集(test set)用于检测模型的性能,而开发集(development set)则被用来对超参数进行调整(hyper-parameter tuning)。在这里,一种能够提升模型性能的方法是将SIGHAN数据集的训练集(training set)中可能存在的一些不包含错误的文本从中剔除(unchanged texts),这样模型进行fine-tuning的SIGHAN训练集(training set)中所有文本都是包含错误(spelling errors)的文本。
(2) News Title数据集
相比于SIGHAN数据集,News Title数据集是一个更大的数据集。News Title数据集中的文本都来自于今日头条app中的文章的标题部分,这些文本的内容涉及政治、娱乐、体育、教育等许多方面。作者为了确保News Title数据集中包含足够多的错误文本,特意从低质量的文本中抽样了15730条样本,所有样本中一共有5,423个样本的文本包含了拼写错误(spelling errors),错误的类型一共有3441种。
值得注意的是,这里News Title数据集仅被对半分成了两部分,分别为:开发集(development set)与测试集(test set)。News Title数据集的开发集也是用来对超参数进行调整(hyper-parameter tuning),News Title数据集的测试集也用于检测模型的性能。
(3) 5 million news titles数据集
5 million news titles数据集中的文本都是从一些中文新闻app中爬取下来的。
同时,作者在这里创建了同音字混淆表(confusion table)。在5 million news titles数据集的文本中,对15%的字符进行随机替换,这15%被随机替换的字符中,有80%的字符使用同音字混淆表confusion table中此字符的同音字符进行替换;而剩下20%的字符使用随机字符进行替换。以这种方式构建数据集再用于训练模型,能够让训练出的Soft-Masked BERT模型获得较强的同音字混淆错误的纠正能力。
需要特别注意的是,5 million news titles数据集只会被用来对模型进行fine-tuning。
例如实验中,在利用SIGHAN数据集的测试集(test set)来检测模型性能之前,模型会先在5 million news titles数据集上做一次fine-tuning,再在SIGHAN数据集的训练集(training set)上再做一次fine-tuning,最后才会使用SIGHAN数据集的测试集(test set)来检测模型性能。
再如,实验中在利用News Title数据集的测试集(test set)来检测模型性能之前,也会先在5 million news titles数据集上fine-tuning一次,才会在News Title数据集的测试集(test set)上检测模型性能。
因此可以看出,在5 million news titles数据集上进行fine-tuning在整个模型训练过程以及之后的性能检测中起到至关重要的作用。
二、模型构建
Soft-Masked BERT模型的机构如下图所示:
Soft-Masked BERT模型框架细分则可以被分为三部分:错误探查网络Detection Network、Soft Masking Connection、纠错网络Correction Network。
(1) 模型输入
整个模型框架的输入input embeddings是由文本句子中每一个字符的word embedding、position embedding、segment embedding三部分嵌入的加和embedding构成的。因此,可以看出Soft-Masked BERT模型框架的输入实际和Bert模型的一般输入形式相同。
在上式中, x i x_{i} xi表示一个文本序列中的第 i i i个字符, e i e_{i} ei表示第 i i i个字符经过三部分嵌入后的加和embedding表示(input embedding)。
(2) 错误探查网络Detection Network
Soft-Masked BERT模型框架中的错误探查网络Detection Network实质上为一个双向GRU模型(Bi-GRU)。Bi-GRU模型对每个文本序列进行正向与反向编码,再将最后一层隐藏层中文本序列的正向编码的隐藏状态与反向编码的隐藏状态横向合并,Bi-GRU模型的计算过程如下公式所示:
h i d h_{i}^{d} hid是文本序列中字符 i i i的嵌入 e i e_{i} ei在经过Bi-GRU模型计算后最后一层隐藏层中双向编码的隐藏状态。论文中Bi-GRU模型的隐藏层维度数设置为256,双向编码后的隐藏层输出维度数为512。
之后,Bi-GRU模型计算得出的 h i d h_{i}^{d} hid会被输入进两个全连接层中分别计算。
-
Detection Network二分类输出计算
Detection Network中Bi-GRU模型的输出 h i d h_{i}^{d} hid会被输入进一个全连接层中进行二分类学习。在计算整个Soft-Masked BERT模型的损失函数时,Detection Network与Correction Network各自的交叉熵损失值的带权加和,构成了Soft-Masked BERT模型损失函数的表示。
上式中, b b b为此全连接层的偏置项; W W W为此全连接层的权重矩阵, W W W将 h i d h_{i}^{d} hid映射到维度为2的空间中,再经过一层 s o f t m a x softmax softmax层之后,即可计算Detection Network的二分类输出的损失值。此处 P d ( y i = k │ X ) P_d (y_i=k│X) Pd(yi=k│X)表示错误探查网络Detection Network分类文本序列中每一个字符 x i x_{i} xi是否为拼写错误字符的二分类条件概率。 -
Soft Masking Connection系数计算
另一个全连接层用来计算Soft-Masked BERT模型中Soft Masking Connection处的系数 p i p_{i} pi,其计算过程如下公式所示:
其中, b d b_{d} bd表示此全连接层中的偏置项; W d W_{d} Wd表示此全连接层中的权重矩阵,其会将 h i d h_{i}^{d} hid映射到维度为1的空间中。此全连接层的输出会再被输入进Sigmoid层中,将值映射到(0,1)之间,这样经过Sigmoid层后输出的值 p i p_{i} pi就为Soft Masking Connection处的系数。
Soft Masking Connection为此篇论文的核心idea之一,其作用是利用计算得到的Soft Masking Connection的系数 p i p_{i} pi来对整个模型框架的输入input embeddings( e i e_{i} ei)与"mask特殊符"的嵌入mask embeddings( e m a s k e_{mask} emas