Text-to-SQL模型----IRNET
笔者最近在做Text-to-SQL任务,看了这篇论文记录一下理解过程,如有理解错误,欢迎大家指正。
IRNET模型是微软2019年发表于ACL的论文,论文题目为Towards Complex Text-to-SQL in Cross-Domain Database with Intermediate Representation,论文下载地址https://arxiv.org/pdf/1905.08205.pdf
代码地址为https://github.com/microsoft/IRNet
这篇论文的创新点在于提出了SemQL,是自然语言和SQL问题的中间表示,从自然语言得到SemQL,SemQL解码得到SQL。在实际数据中,SQL的group by,having等中出现的column 并为出现在自然语言中(这里是指question),SQL中group by做聚合函数方便使用,但是很少有很细节的考虑,使得端到端的模型学习是一个挑战。这些问题统称为mismatch problem。
IRNET模型
IRNET模型主要包含SemQL,schema linking,NL encoder, schema encoder, decoder。其中SemQL是将SQL表征为SemQL树,schema linking是NL,columns 确定类型的过程,NL encoder 是对输入的NL进行编码,schema encoder 是对数据库中的column和table 进行encoder,decoder是得到sql语句的过程,模型结构图如下:
SemQL
为了解决NL和SQL之间的不匹配问题,提出了SemQL,是NL和SQL的中间表示。建立SemQL,需要提前定义规则,将SQL表征为一个树形结构。在SemQL中不会出现SQL语句中group by,having,嵌套的子句等,并且where,having中中的条件统一用filter表示,每一个节点中column和table,指定table方便表征重复的列。SemQL规则如下图:
其中(1)Z表示 sql语句中是否包含intersect,union,except;
(2)R是sql中的select 查询是否有where和orderby,
(3)Order,Suerlative是对应着orderby中内容,如果orderby 中有limit,对应Suerlative,反之对应order。
(4)Filter表示不同的计算符号,如果filter中的节点A的聚合节点为None,则表示having,否则表示where,如果R在filter下的节点,表示嵌套查询。
下面给出一个SQL的SemQL的结果:
NL: Show the names of students who have a grade higher than 5 and have at least 2 friends.
SQL: SELECT T1.name FROM friend AS T1 JOIN highschooler AS T2
ON T1.student_id = T2.id WHERE T2.grade > 5 GROUP BY T1.student_id HAVING count(*) >= 2
Schema linking
schema linking 是自然语言(其实是数据中的question)与数据库的schema中的column 和table建立关系的过程,即识别question中提及到的数据库中的column 和table并给定类型的过程。text-to-SQL的schema linking可以理解为实体链接的过程,实体是数据库schema中的column,table和value等。
在IRNet中schema linking识别question中的column 和table的过程是采用的n-gram方式,识别column和table的方式相类似,大致过程如下:
(1)NL与column按照n-gram进行完全匹配和部分匹配,NL中相匹配的span给定type为column
(2)NL与table按照n-gram进行完全匹配和部分匹配,NL中相匹配的span给定type为table
(3)NL与数据库中的value进行匹配,
question中的字符进行value匹配,这个用到了Concept Net,只考虑Concept-Net中的两个类型,分为is a type of 和related terms,这里只考虑columns的值为这两个类型的columns,根据value的match情况,将columns的类型定义为为VALUE EXACT MATCH 或者VALUE PARTIAL MATCH。
注意: (1)这里的n-gram匹配是不重合匹配
(2) 如果n-gram匹配同时匹配到column和table,优先column匹配
(3) column的类型有4个类型分别为exact match,partial match, value exact match 和value partial match
模型
IRNET的模型结构图如下:
模型分为NL Encoder,Schema Encoder和Decoder三个部分。
(1)NL Encoder
对输入的question和schema linking后的type进行encoder,这里采用的glove 进行embedding或者bert ,LSTM进行encoder。NL进行encoder前,将label type为column,table, value这三个类型拼接到对应的span前进行embedding
(2)Schema Encoder
schema encoder分别是对 schema linking后的column和table进行encoder的过程,这个schema encoder的结果用于decoder中,在解码中的pointer network以及selection column和 selection table 中使用。对column的encoder和table的encoder相似,schema 的表示为 s=(c, t) ,其中
c
=
{
(
c
1
,
ϕ
1
)
,
(
c
2
,
ϕ
2
)
,
.
.
.
,
c
n
,
ϕ
n
}
c=\{(c_{1},\phi_{1}), (c_{2},\phi_{2}), ..., c_{n},\phi_{n}\}
c={(c1,ϕ1),(c2,ϕ2),...,cn,ϕn}表示column,
t
=
{
t
1
,
t
2
,
.
.
.
,
t
m
}
t = \{t_{1}, t_{2}, ..., t_{m}\}
t={t1,t2,...,tm}表示table ,
对column的encoder 的过程如下:
i. 对column中的中每一个字符
c
i
c_{i}
ci进行embedding,然后对embedding 进行取均值得到vector
e
^
c
i
\hat{e}_{c}^{i}
e^ci
ii. 对column
c
i
c_{i}
ci中的类型
ϕ
i
\phi_{i}
ϕi进行embedding得到特征矩阵
φ
i
\varphi_{i}
φi
iii. 对column中的特征向量进行余弦相似度加权计算得到context vector
c
c
i
c_{c}^{i}
cci
iv. 对上述三个步骤的特征矩阵求和得到最后column的embedding
对table的encoder 过程如下:
(3)Decoder
decoder的目的是生成SemQL,这里使用了applyrule, selection columns 和selection table,applyrule。基于给定的SemQL 的树结构,我们使用基于利用LSTM的语法的解码器通过动态 过程生成 SemQL。在解码的时候使用了coarse-to-fine框架,解码的时候先得到SemQL,然后采用细粒度的解码器补充解码器选择column和table丢失的细节。Decoder公式如下:
其中
a
i
a_{i}
ai是当前时刻i的action taken,
a
<
i
a_{<i}
a<i是时刻i以前的序列 action,
T
T
T是总共的时间步
T
T
T。
这里的applyrule 是一个感知机层
i. selection column 和selection table
在selection column 和selection table中使用了pointer network,pointer network可以理解为一个记忆网络,记忆选过的column,decoder中有一个门控单元决定是否从memory中选择column还是数据库schema选择,如果从schema选择column,则该column加入pointer network中的memory。selection column和selection table的数学表达式如下:
ii Coarse-to-fine
在解码的时候采用有粗到细的过程。模型图如下:
具体公式表达如下:
实验结果
在spider 数据集上,IRNet + glove 在验证集和测试集上的结果分别为53.2%,46.7%,IRNet + bert 在在验证集和测试集上的结果分别为61.9%,54.7%。在Dev Set上IRNet共有483个错误预测。主要分为三类:Column Prediction(占比32.3%),Nested Query(23.9%),Operator(12.4%)。具体如下:
i. Column Prediction错误:基于字段值没能正确找到对应的column
ii. Nested Query错误:大部分是由于Extra Hard level的复杂嵌套查询没能准确构造
iii. Operator错误:部分operators的选择要求机器有一定常识,如“from old to young”需要把年龄降序排列
上述内容为IRNET模型的内容,如有错误,欢迎指正。