import torch
from transformers import XLNetTokenizer, XLNetLMHeadModel
tokenizer = XLNetTokenizer.from_pretrained('/data_local/plm_models/xlnet_cased_L-12_H-768_A-12/')
model = XLNetLMHeadModel.from_pretrained('/data_local/plm_models/xlnet_cased_L-12_H-768_A-12/')
# local copy of the xlnet-base-cased checkpoint
input_ids = torch.tensor(tokenizer.encode("I love <mask> .", add_special_tokens=False)).unsqueeze(0)
# The inputs "I love <mask> ." and "I love you ." give the same prediction, because perm_mask makes that position ("you" or "<mask>") invisible to all tokens
print(tokenizer.tokenize("I love <mask> ."))
print(tokenizer.tokenize("I love you ."))
print(input_ids)
perm_mask = torch.zeros((1, input_ids.shape[1], input_ids.shape[1]), dtype=torch.float)
perm_mask[:, :, -3] = 1.0
# no token may attend to position -3, so the word "you" cannot be seen
target_mapping = torch.zeros((1, 1, input_ids.shape[1]), dtype=torch.float)
# predict the word at the "you"/<mask> position
target_mapping[0, 0, -3] = 1.0
outputs = model(input_ids, perm_mask=perm_mask, target_mapping=target_mapping)
next_token_logits = outputs[0]  # logits for the single target position, shape (1, 1, vocab_size)
print(outputs[0])
# predicted_index = torch.argmax(next_token_logits[0][0]).item()
predicted_li = torch.topk(next_token_logits[0][0], 5)
for predicted_index in predicted_li[1]:
    predicted_index = predicted_index.item()
    predicted_token = tokenizer.convert_ids_to_tokens(predicted_index)
    print(predicted_token)
# Why do these results feel off?
Output:
['▁I', '▁love', '<mask>', '▁', '.']
['▁I', '▁love', '▁you', '▁', '.']
tensor([[ 35, 564, 6, 17, 9]])
tensor([[[-27.0873, -43.3048, -43.4165, ..., -35.7004, -31.0599, -33.4355]]],
grad_fn=<AddBackward0>)
▁
▁love
▁Love
▁all
▁loving
These predictions don't look quite right to me, but the Stack Overflow answer "nlp - How to get words from output of XLNet using Transformers library - Stack Overflow" does it exactly this way.
Also, if you want to predict the words at n positions, you have to set up perm_mask and target_mapping for each position and run n predictions; a sketch of that is below.
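A minimal sketch of that n-position case, assuming the same local checkpoint path as above; the example sentence "I love <mask> and <mask> ." and the variable names here are my own, purely illustrative. Every masked position is hidden from all tokens through perm_mask, and the model is run once per target position.

import torch
from transformers import XLNetTokenizer, XLNetLMHeadModel

tokenizer = XLNetTokenizer.from_pretrained('/data_local/plm_models/xlnet_cased_L-12_H-768_A-12/')
model = XLNetLMHeadModel.from_pretrained('/data_local/plm_models/xlnet_cased_L-12_H-768_A-12/')

input_ids = torch.tensor(
    tokenizer.encode("I love <mask> and <mask> .", add_special_tokens=False)
).unsqueeze(0)
seq_len = input_ids.shape[1]
# positions of every <mask> token in the sequence
mask_positions = (input_ids[0] == tokenizer.mask_token_id).nonzero(as_tuple=True)[0].tolist()

# no token may attend to any masked position
perm_mask = torch.zeros((1, seq_len, seq_len), dtype=torch.float)
perm_mask[:, :, mask_positions] = 1.0

# one forward pass per target position, as described above
for pos in mask_positions:
    target_mapping = torch.zeros((1, 1, seq_len), dtype=torch.float)
    target_mapping[0, 0, pos] = 1.0
    with torch.no_grad():
        outputs = model(input_ids, perm_mask=perm_mask, target_mapping=target_mapping)
    logits = outputs[0]  # shape (1, 1, vocab_size)
    top5 = torch.topk(logits[0, 0], 5).indices.tolist()
    print(pos, tokenizer.convert_ids_to_tokens(top5))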