论文来源:COLING 2018 Best Paper
论文链接:SGM: Sequence Generation Model for Multi-label Classification
我记得我开知乎专栏的第一篇文章写的是如何去做关于用户评论的情感分类,这其实也是一个多标签分类 (Multi-label Classification,简称 MLC) 问题。这几天重温了一下目前关于多标签分类的论文,发现了一个很有意思的研究方向:利用 Seq2Seq 的思想去做多标签分类,这么做的动机是因为往往多标签分类时的多个标签彼此之间是相互关联的,而传统的多标签分类方法是把问题看成多个单标签分类来做 (可以参考我之前写过的文章),因而失去了对这些相互关联信息的提取。咋看一下,多么 natural 的 idea,这样直观的 idea 被评为最佳论文,大家丝毫不会觉得奇怪。(为啥之前就没人想到呢。。。)
今天,我们来讲一下这篇 best paper。
传统的多标签问题解决思路
利用深度学习或者传统的机器学习,我们都可以很轻松的处理一个多标签问题。简单来说,像之前那个 AI Challenger 的比赛,我的做法是把 MLC 任务转换成许多个单标签分类的问题。可是,这种方法忽略了标签之间的相互关系;而且,在大数据集上的计算代价非常昂贵,比如比赛的数据集一共有 20 个细粒度,意味着我们要训练 20 个模型才行,那万一有 100 个细粒度呢,岂不是要训练 100 个模型才行吗 ?这样的道路感觉越走越偏了。
目前多标签分类方法存在的问题
这样的道路在标签个数有限的条件下还是可行的,可万一标签很多呢。并且上述方法是没有考虑标签间的相关性,而这种相关性可能能够提高特定问题上模型的效果。
例如,在对知乎文章进行分类的时候,我们经常能够看到标签数据挖掘和机器学习一起出现,而数据挖掘和机器学习一起出现的几率就会降低一些,我们基本可以从中得知,标签数据挖掘的文章具有较大的可能也可以具有标签机器学习,这便是标签间相关性对多标签问题模型的促进效果。
SGM 模型细节和实现
这篇论文提出一个自己的模型,叫 SGM。
论文主要的贡献是:
1. 把多标签分类问题当做序列生成问题,进而考虑标签间相关性
2. 在序列生成模型的 decode 部分进行了改造,不但考虑了标签间相关性,还自动获取了输入文本的关键信息(Attention机制)
3. 本论文提出的方法效果极好,指标比 baseline 提升很多。在关系表示上也具有非常好的效果。
模型如下图所示:
下面,让我来梳理一下论文中提出的模型细节。
Encoder
令 为 m 个单词的序列,
是第
个单词的 one-hot 表示。我们首先通过一个嵌入矩阵 (embedding matrix),把
嵌入成一个稠密的嵌入向量
,
是词汇表的大小,
是嵌入向量的维度。
我们使用一个bidirectional LSTM 从两个方向上来读取文本序列 x,并且计算每个单词的隐藏状态:
我们通过连接两个方向上的隐藏状态来得到第 个单词的最终隐藏状态,
这使得状态具有以第 个单词为中心的序列信息。
class rnn_encoder(nn.Module):
def __init__(self, config, vocab_size, embedding=None):
super(rnn_encoder, self).__init__()
if embedding is not None:
self.embedding = embedding
else:
self.embedding = nn.Embedding(vocab_size, config.emb_size)
self.rnn = nn.LSTM(input_size=config.emb_size, hidden_size=config.encoder_hidden_size,
num_layers=config.num_layers, dropout=config.dropout, bidirectional=config.bidirec)
self.config = config
def forward(self, input, lengths):
embs = pack(self.embedding(input), lengths)
outputs, (h, c) = self.rnn(embs)
outputs = unpack(outputs)[0]
if not self.config.bidirec:
return outputs, (h, c)
else:
batch_size = h.size(1)
h = h.transpose(0, 1).contiguous().view(batch_size, -1, 2 * self.config.encoder_hidden_size)
c = c.transpose(0, 1).contiguous().view(batch_size, -1, 2 * self.config.encoder_hidden_size)
state = (h.transpose(0, 1), c.transpose(0, 1))
return outputs, state
Attention
当模型预测不同的标签的时候,并不是所有的单词贡献相同。注意力机制会通过关注文本序列中的不同部分,产生一个上下文向量 (context vector)。
特别的,本文采用的 Attention 是 global attention,我在以前文章中提到过,这里就不列举公式了,看代码反而更容易理解一些。
class global_attention(nn.Module):
def __init__(self, hidden_size, activation=None):
super(global_attention, self).__init__()
self.linear_in = nn.Linear(hidden_size, hidden_size)
self.linear_out = nn.Linear(2*hidden_size, hidden_size)
self.softmax = nn.Softmax()
self.tanh = nn.Tanh()
self.activation = activation
def forward(self, x, context):
# x: batch * hidden_size
# context: batch * time * hidden_size
# batch * hidden_size * 1
gamma_h = self.linear_in(x).unsqueeze(2)
if self.activation == 'tanh':
gamma_h = self.tanh(gamma_h)
# batch * time * hidden_size batch * hidden_size * 1 => batch * time * 1 => batch * time
weights = torch.bmm(context, gamma_h).squeeze(2)
# batch * time
weights = self.softmax(weights)
# batch * 1 * time batch * time * hidden_size => batch * 1 * hidden_size => batch * hidden_size
c_t = torch.bmm(weights.unsqueeze(1), context).squeeze(1)
# batch * 2 * hidden_size => batch * hidden_size
output = self.tanh(self.linear_out(torch.cat([c_t, x], 1)))
# output: batch * hidden_size
# weights: batch * time
return output, weights
Decoder
Decoder在第 时刻的隐藏状态计算如下:
其中, 的意思是
和
的连接,
是标签的嵌入,这里的标签指的是在
分布下的最高概率对应的标签。
是在
时刻在标签空间
上的概率分布,计算如下:
在训练阶段,损失函数是 cross-entropy loss function。我们利用 beam search 算法在inference 的时候来找 top-ranked 预测。以 eos 结尾的预测路径加入到了候选路径集合。
class StackedLSTM(nn.Module):
def __init__(self, num_layers, input_size, hidden_size, dropout):
super(StackedLSTM, self).__init__()
self.dropout = nn.Dropout(dropout)
self.num_layers = num_layers
self.layers = nn.ModuleList()
for i in range(num_layers):
self.layers.append(nn.LSTMCell(input_size, hidden_size))
input_size = hidden_size
def forward(self, input, hidden):
h_0, c_0 = hidden
h_1, c_1 = [], []
for i, layer in enumerate(self.layers):
h_1_i, c_1_i = layer(input, (h_0[i], c_0[i]))
input = h_1_i
if i + 1 != self.num_layers:
input = self.dropout(input)
h_1 += [h_1_i]
c_1 += [c_1_i]
h_1 = torch.stack(h_1)
c_1 = torch.stack(c_1)
return input, (h_1, c_1)
class rnn_decoder(nn.Module):
def __init__(self, config, vocab_size, embedding=None, score_fn=None):
super(rnn_decoder, self).__init__()
if embedding is not None:
self.embedding = embedding
else:
self.embedding = nn.Embedding(vocab_size, config.emb_size)
self.rnn = StackedLSTM(input_size=config.emb_size, hidden_size=config.decoder_hidden_size,
num_layers=config.num_layers, dropout=config.dropout)
self.score_fn = score_fn
if self.score_fn.startswith('general'):
self.linear = nn.Linear(config.decoder_hidden_size, config.emb_size)
elif score_fn.startswith('concat'):
self.linear_query = nn.Linear(config.decoder_hidden_size, config.decoder_hidden_size)
self.linear_weight = nn.Linear(config.emb_size, config.decoder_hidden_size)
self.linear_v = nn.Linear(config.decoder_hidden_size, 1)
elif not self.score_fn.startswith('dot'):
self.linear = nn.Linear(config.decoder_hidden_size, vocab_size)
if hasattr(config, 'att_act'):
activation = config.att_act
print('use attention activation %s' % activation)
else:
activation = None
self.attention = models.global_attention(config.decoder_hidden_size, activation)
self.hidden_size = config.decoder_hidden_size
self.dropout = nn.Dropout(config.dropout)
self.config = config
if self.config.global_emb:
self.gated1 = nn.Linear(config.emb_size, config.emb_size)
self.gated2 = nn.Linear(config.emb_size, config.emb_size)
def forward(self, inputs, init_state, contexts):
if not self.config.global_emb:
embs = self.embedding(inputs)
outputs, state, attns = [], init_state, []
for emb in embs.split(1):
output, state = self.rnn(emb.squeeze(0), state)
output, attn_weights = self.attention(output, contexts)
output = self.dropout(output)
outputs += [output]
attns += [attn_weights]
outputs = torch.stack(outputs)
return outputs, state
else:
outputs, state, attns = [], init_state, []
embs = self.embedding(inputs).split(1)
max_time_step = len(embs)
emb = embs[0]
output, state = self.rnn(emb.squeeze(0), state)
output, attn_weights = self.attention(output, contexts)
output = self.dropout(output)
soft_score = F.softmax(self.linear(output))
outputs += [output]
attns += [attn_weights]
batch_size = soft_score.size(0)
a, b = self.embedding.weight.size()
for i in range(max_time_step-1):
emb1 = torch.bmm(soft_score.unsqueeze(1), self.embedding.weight.expand((batch_size, a, b)))
emb2 = embs[i+1]
gamma = F.sigmoid(self.gated1(emb1.squeeze())+self.gated2(emb2.squeeze()))
emb = gamma * emb1.squeeze() + (1 - gamma) * emb2.squeeze()
output, state = self.rnn(emb, state)
output, attn_weights = self.attention(output, contexts)
output = self.dropout(output)
soft_score = F.softmax(self.linear(output))
outputs += [output]
attns += [attn_weights]
outputs = torch.stack(outputs)
return outputs, state
Global Embedding
是 label 的 embedding,这个 label 是在
分布下的最高概率所对应标签得来的。可是,这个计算只是贪心的利用了
的最大值。在论文提出的 SGM 模型中,基于先前预测的标签来产生下一个标签。因此,如果在第
时刻得到了错误的预测,然后就会在预测下一个标签的时候得到了一个错误的后继标签,这也叫做 exposure bias (错上加错)。这也是为什么采用贪心法在 decoder 部分不合适的地方,这个道理不仅适用于这篇论文的任务,对于机器翻译、自动摘要等任务,仍然适用。
所以,便有了 beam search 算法的产生。beam search 算法从一定程度上缓解了这个问题,有兴趣的可以自行去搜索下 beam search 算法的原理,但是它仍然不能从根本上解决这个问题,因为 exposure bias 可能会出现在所有的路径上。 表示在
时刻的概率分布,很显然
中的所有信息对我们在第
时刻预测标签是有帮助的。通过考虑所有包含在
中的有效信号,exposure bias 问题应该会得到缓解。
基于这一点,论文继而提出了一个新的 decoder 结构,其中在 时刻中的
可以表示第
时刻的整体信息。受 highway network 中 adaptive gate 的想法的启发,这里引入 global embedding,
表示时刻
输出对应的 Embedding 的 label,在不使用 Global Embedding 时,
; 当使用 Global Embedding,让其等于某时刻
的概率向量和各个标签 Embedding 的乘积,降低
对
值的影响。
Global Embedding 的计算方法如下:
其中 是 transform gate,用于控制带权平均嵌入的比例。所有的
为权重矩阵。通过考虑每一个 label 的概率,模型可以减少先前时间步带来的错误预测的损失。这使得模型预测得更加准确。
代码细节如下,
if self.config.global_emb:
self.gated1 = nn.Linear(config.emb_size, config.emb_size)
self.gated2 = nn.Linear(config.emb_size, config.emb_size)
...
a, b = self.embedding.weight.size()
emb1 = torch.bmm(soft_score.unsqueeze(1), self.embedding.weight.expand((batch_size, a, b)))
emb2 = embs[i+1]
gamma = F.sigmoid(self.gated1(emb1.squeeze())+self.gated2(emb2.squeeze()))
emb = gamma * emb1.squeeze() + (1 - gamma) * emb2.squeeze()
回顾
回顾一下我认为本文比较出彩的一些地方。
Seq2Seq 模型的输入和输出均为序列,且能够学习到输入和输出序列的相关性。对于文本的多标签分类问题,怎么看怎么都是一种很 natural 的 idea。
虽然说想法很 natural,但是也不是说直接把 seq2seq 搬过来就能用 (那么简单的话早就有人想到了。。。),在这其中也会遇到一些问题,本文的作者的一个贡献点也是提出了解决这些问题的方法。
- 问题: 多标签分类的输出显然是不能重复的。
- 解决方法: 作者在最终
输出的时候引入了
将已输出的标签剔除。
的表示如下,如果标签已经被输出了,则
为负无穷,
- 问题: Seq2Seq 中某时刻
的输出对时刻
的输出影响很大,也就是说时刻
出错会对时刻
之后的所有输出造成严重影响。
- 解决方法: 在多标签分类问题中,我们显然不想让标签间拥有如此强的关联性,于是作者提出 Global Embedding 来解决这个问题。
此外,作者在数据预处理阶段也采用一些技巧。比如,考虑到出现次数更多的标签在标签相关性训练中具有更强的作用,在训练时把标签按照其出现次数进行从高到低排序作为输出序列。这么做的好处是,出现次数更多的标签可以出现 LSTM 的前面,进而更好地指导整个标签的输出。
参考资料
论文笔记:SGM: Sequence Generation Model for Multi-label Classification