上一节地址:解读pytorch中的biLSTM_CRF源码(一)
本节目录
预定义的helper函数
class biLSTM_CRF中的部分成员方法__init__,init_hidden,_get_lstm_features
预定义的helper函数
1.to_scalar
def to_scalar(var): # 将变量转化为标量
# returns a python float
return var.view(-1).data.tolist()[0]
变量将张量对象、梯度及创建张量对象的函数引用封装起来;
data:获取变量相关的张量
scalar:包含一个元素的张量称为标量
2.argmax
def argmax(vec):
# return the argmax as a python int
_, idx = torch.max(vec, 1)
return to_scalar(idx)
3.prepare_sequence
def prepare_sequence(seq, to_ix):
idxs = [to_ix[w] for w in seq]