First, hats off to the original author; truly impressive work.
If you read the PyTorch BiLSTM+CRF implementation side by side with that write-up, I believe it will clear things up for you as well.
And now, the obligatory fancy divider.
The author's analysis of BiLSTM+CRF is extremely thorough and resolved many of the points that had confused me. I won't repeat the detailed steps here; please see the references.
PyTorch example
# -*- coding: utf-8 -*-
import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.optim as optim
torch.manual_seed(1)
def argmax(vec):
    # Return the argmax as a python int.
    # torch.max(vec, 1) returns two tensors, the max values and their
    # indices along dim=1; we want the position of the maximum within
    # the row vector.
    _, idx = torch.max(vec, 1)
    return idx.item()
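# A quick sanity check (my addition, not in the original tutorial):
# for a 1 x N row vector, argmax returns the column index of the maximum.
# >>> argmax(torch.tensor([[0.1, 0.9, 0.3]]))
# 1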
def prepare_sequence(seq, to_ix):
    idxs = [to_ix[w] for w in seq]
    return torch.tensor(idxs, dtype=torch.long)
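# For example (the tiny to_ix dict here is made up for illustration):
# >>> prepare_sequence("the wall".split(), {"the": 0, "wall": 1})
# tensor([0, 1])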
# Compute log-sum-exp in a numerically stable way for the forward algorithm.
# It uses the identity log(sum(exp(v))) = max(v) + log(sum(exp(v - max(v)))),
# which avoids overflow when the scores in v are large.
def log_sum_exp(vec):
    max_score = vec[0, argmax(vec)]
    max_score_broadcast = max_score.view(1, -1).expand(1, vec.size()[1])
    return max_score + \
        torch.log(torch.sum(torch.exp(vec - max_score_broadcast)))
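# Sanity check (my addition; assumes the 1 x N shape used throughout this
# tutorial): the hand-rolled version should match PyTorch's built-in
# torch.logsumexp.
# >>> vec = torch.randn(1, 5)
# >>> torch.allclose(log_sum_exp(vec), torch.logsumexp(vec, dim=1)[0])
# True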
#####################################################################
# The model. (Note: START_TAG and STOP_TAG, i.e. "<START>" and "<STOP>",
# are defined later in the full script, along with the training data.)
class BiLSTM_CRF(nn.Module):
    def __init__(self, vocab_size, tag_to_ix, embedding_dim, hidden_dim):
        super(BiLSTM_CRF, self).__init__()
        self.embedding_dim = embedding_dim  # 5 in the tutorial
        self.hidden_dim = hidden_dim        # 4 in the tutorial
        self.vocab_size = vocab_size
        self.tag_to_ix = tag_to_ix
        self.tagset_size = len(tag_to_ix)   # 5 tags: B, I, O, <START>, <STOP>

        self.word_embeds = nn.Embedding(vocab_size, embedding_dim)
        # Each direction gets hidden_dim // 2 hidden units, so the
        # concatenated forward/backward output has size hidden_dim again.
        self.lstm = nn.LSTM(embedding_dim, hidden_dim // 2,
                            num_layers=1, bidirectional=True)

        # Maps the output of the LSTM into tag space:
        # (hidden_dim, tagset_size).
        self.hidden2tag = nn.Linear(hidden_dim, self.tagset_size)

        # Matrix of transition parameters. Entry i,j is the score of
        # transitioning *to* i *from* j. The transition matrix is learned
        # during training and is initialized randomly.
        self.transitions = nn.Parameter(
            torch.randn(self.tagset_size, self.tagset_size))

        # These two statements enforce the constraint that we never
        # transition *to* the start tag and never transition *from* the
        # stop tag, by pinning those scores to -10000:
        self.transitions.data[tag_to_ix[START_TAG], :] = -10000  # the entire START_TAG row
        self.transitions.data[:, tag_to_ix[STOP_TAG]] = -10000   # the entire STOP_TAG column
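        # Quick inspection (hypothetical REPL session; it assumes the
        # tutorial's ordering tag_to_ix = {"B": 0, "I": 1, "O": 2,
        # START_TAG: 3, STOP_TAG: 4}, so <START> is row 3 and <STOP> is
        # column 4):
        # >>> model = BiLSTM_CRF(len(word_to_ix), tag_to_ix, 5, 4)
        # >>> model.transitions.data[tag_to_ix[START_TAG]]    # row: all -10000
        # >>> model.transitions.data[:, tag_to_ix[STOP_TAG]]  # column: all -10000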