xlnet预测mask位置单词

import torch
from transformers import XLNetTokenizer, XLNetLMHeadModel

tokenizer = XLNetTokenizer.from_pretrained('/data_local/plm_models/xlnet_cased_L-12_H-768_A-12/')
model = XLNetLMHeadModel.from_pretrained('/data_local/plm_models/xlnet_cased_L-12_H-768_A-12/')
# xlnet-base-cased

input_ids = torch.tensor(tokenizer.encode("I love <mask> .", add_special_tokens=False)).unsqueeze(0)
# 输入"I love <mask> ." 与 "I love you ." 最后的预测的结果一致,因为perm_mask 指定you或者<mask>不可见
print(tokenizer.tokenize("I love <mask> ."))
print(tokenizer.tokenize("I love you ."))
print(input_ids)

perm_mask = torch.zeros((1, input_ids.shape[1], input_ids.shape[1]), dtype=torch.float)
perm_mask[:, :, -3] = 1.0
# 不能使用单词you

target_mapping = torch.zeros((1, 1, input_ids.shape[1]), dtype=torch.float)
# 预测you位置的单词
target_mapping[0, 0, -3] = 1.0

outputs = model(input_ids, perm_mask=perm_mask, target_mapping=target_mapping)
next_token_logits = outputs[0]
print(outputs[0])

# predicted_index = torch.argmax(next_token_logits[0][0]).item()
predicted_li = torch.topk(next_token_logits[0][0], 5)
for predicted_index in predicted_li[1]:
    predicted_index = predicted_index.item()
    predicted_token = tokenizer.convert_ids_to_tokens(predicted_index)
    print(predicted_token)
# 结果怎么感觉不对呢?

运行结果:

['▁I', '▁love', '<mask>', '▁', '.']
['▁I', '▁love', '▁you', '▁', '.']
tensor([[ 35, 564,   6,  17,   9]])
tensor([[[-27.0873, -43.3048, -43.4165,  ..., -35.7004, -31.0599, -33.4355]]],
       grad_fn=<AddBackward0>)
▁
▁love
▁Love
▁all
▁loving

感觉这预测的不太对啊?但是参考nlp - How to get words from output of XLNet using Transformers library - Stack Overflow

确实是这么做的。

还有如果要同时预测n个位置的词,要分别制定好perm_mask 和  target_mapping 进行n次预测。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

旺旺棒棒冰

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值