简介
query_pos初始化为shape(n,256)
改进点1:在cross_att时候,内容信息和位置信息解偶
具体方法:
(
q
c
o
n
t
e
n
t
+
q
p
o
s
)
∗
(
k
c
o
n
t
e
n
t
+
k
p
o
s
)
T
(q_{content}+q_{pos}) * (k_{content}+k_{pos})^T
(qcontent+qpos)∗(kcontent+kpos)T
改成
c
a
t
(
q
c
,
q
p
)
∗
c
a
t
(
k
c
,
k
p
)
T
cat(q_c,q_p) * cat(k_c,k_p)^T
cat(qc,qp)∗cat(kc,kp)T
解释
首先,把特征侧和目标侧分为内容信息(content)和位置(content)
把query分为q_content和q_pos,更好理解query.对应于我之前说的query和query_pos.
第一式子:
q
c
∗
k
c
T
+
q
p
∗
k
p
T
+
q
c
∗
k
p
T
+
q
c
T
∗
k
p
q_c*k_c^T+q_p*k_p^T+q_c * k_p^T+q_c^T * k_p
qc∗kcT+qp∗kpT+qc∗kpT+qcT∗kp
第二式子:
q
c
∗
k
c
T
+
q
p
∗
k
p
T
q_c*k_c^T+q_p*k_p^T
qc∗kcT+qp∗kpT
实现各计算各的相似度再合并
改进点2:在cross_att时候,位置信息包含参考点位置编码和基于内容信息预测的参考点缩放量
具体方法:
参考点经过正余弦位置编码生成位置信息的base量shape(n,256)
内容信息经过线性映射生成位置信息的缩放量shape(n,256)
效果
1.加速收敛,训练的epoch明显减少
2.内容和位置注意力解偶