这是一段使用百度 ERNIE 1.0 做特征提取的 BiLSTM+CRF 模型代码：
class ERNIE_LSTM_CRF(nn.Module):
"""
ernie_lstm_crf model
"""
def __init__(self, ernie_config, tagset_size, embedding_dim, hidden_dim, rnn_layers, dropout_ratio, dropout1, use_cuda=False):
super(ERNIE_LSTM_CRF, self).__init__()
self.embedding_dim = embedding_dim
self.hidden_dim = hidden_dim
##加载ERNIE
self.word_embeds = AutoModel.from_pretrained(ernie_config)
# self.word_embeds = ErnieModel.from_pretrained(ernie_config, from_hf_hub=False)
self.lstm = nn.LSTM(embedding_dim, hidden_dim,
num_layers=rnn_layers, bidirectional=True,
dropout=dropout_ratio, batch_first=True)
self.rnn_layers = rnn_layers
self.dropout1 = nn.Dropout(p=dropout1)
self.crf = CRF(num_tags=tagset_size, batch_first&#