Tensorflow小技巧整理:tf.multinomial()采样

在生成任务中,除了tf.argmax(),tf.multinomial()提供了一种概率分布采样的方法。它接受二维logits张量,并根据概率分布进行采样,返回词汇id,为生成增加了多样性。例如,在处理[batch_size, vocab_size]的logits时,它能按概率生成样本。" 123300482,12640680,PyCharm中配置Anaconda环境加载CUDA及PyTorch,"['pytorch', 'pycharm', 'python', 'conda', 'cuda']

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

tf.multinomial()

做生成任务时,得到 decoder 最终的输出之后,就需要决策选如何利用得到的输出张量进行生成。tf.argmax()是最简单最粗暴的一种方法,直接选取概率最大的词汇作为输出。beam search 等算法的出现,使得生成的结果有了更多的可能性。最近看到一段代码,使用的是 tf.multinomial() 进行采样,也尝试用了一下。

tf.multinomial(logits, num_samples, seed=None, name=None)

logits是一个二维张量,num_samples指的是采样的个数。其实很好理解,我们生成每个时刻的 logits 时,输出维度应该是 [ batch_size, vocab_size ] 形式的,代表着该时刻,每一个batch对应的词典中各词汇生成的概率。tf.multinomial() 将按照该概率分布进行采样,返回的值是 logits 第二维上的 id,也就是我们需要的字典的 id。
举个例子:

比如每次将从5个候选词汇中采样,概率分布如图所示,采样个数为100,统计一下结果如下:
可以看到,第一个词和最后一个词的采样次数会高很多,而概率为 0.05 的第二个词和第三个词则很少被采样到。如果5个词概率相同:
则我们的采样结果为:
C:\Users\14473> & C:/Users/14473/AppData/Local/Microsoft/WindowsApps/python3.11.exe "g:/import os.py" Sliding Window Attention is enabled but not implemented for `sdpa`; unexpected results may be encountered. Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 8.94it/s] 测试输入: 打印机显示缺纸怎么办 Traceback (most recent call last): File "g:\import os.py", line 63, in <module> output = test_model(test_input) ^^^^^^^^^^^^^^^^^^^^^^ File "g:\import os.py", line 48, in test_model outputs = model.generate( ^^^^^^^^^^^^^^^ File "C:\Users\14473\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.11_qbz5n2kfra8p0\LocalCache\local-packages\Python311\site-packages\torch\utils\_contextlib.py", line 116, in decorate_context return func(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^ File "C:\Users\14473\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.11_qbz5n2kfra8p0\LocalCache\local-packages\Python311\site-packages\transformers\generation\utils.py", line 2223, in generate result = self._sample( ^^^^^^^^^^^^^ File "C:\Users\14473\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.11_qbz5n2kfra8p0\LocalCache\local-packages\Python311\site-packages\transformers\generation\utils.py", line 3257, in _sample next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ RuntimeError: probability tensor contains either `inf`, `nan` or element < 0
03-08
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值