torch.multinomial()理解

本文详细解析了PyTorch中的multinomial函数,介绍了如何使用此函数进行有放回或无放回的采样,并通过实例展示了不同参数设置下的采样结果。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

torch.multinomial(input, num_samples,replacement=False, out=None) → LongTensor

作用是对input的每一行做n_samples次取值,输出的张量是每一次取值时input张量对应行的下标。

输入是一个input张量,一个取样数量,和一个布尔值replacement。

input张量可以看成一个权重张量,每一个元素代表其在该行中的权重。如果有元素为0,那么在其他不为0的元素

被取干净之前,这个元素是不会被取到的。

n_samples是每一行的取值次数,该值不能大于每一样的元素数,否则会报错。

replacement指的是取样时是否是有放回的取样,True是有放回,False无放回。

看官方给的例子:
>>> weights = torch.Tensor([0, 10, 3, 0]) # create a Tensor of weights
>>> torch.multinomial(weights, 4)

 1
 2
 0
 0
[torch.LongTensor of size 4]

>>> torch.multinomial(weights, 4, replacement=True)

 1
 2
 1
 2
[torch.LongTensor of size 4]

输入是[0,10,3,0],也就是说第0个元素和第3个元素权重都是0,在其他元素被取完之前是不会被取到的。

所以第一个multinomial取4次,可以试试重复运行这条命令,发现只会有2种结果:[1 2 0 0]以及[2 1 0 0],以[1 2 0 0]这种情况居多。这其实很好理解,第1个元素权重比第2个元素权重要大,所以先取第1个元素的概率就会大。在第1和2个元素取完之后,剩下了2个没有权重的元素,它们才会被取到。但实际上权重为0的元素被取到时也不会显示正确的下标,关于0的下标问题我还没有想到很合理的解释,先行略过。

而第二个multinomial取4次,发现就只会出现1和2这两个元素了。这是因为replacement为真,所以有放回,就永远也不会取到权重为0的元素了。

再试试输入二维张量,则返回的也会成为一个二维张量,行数为输入的行数,列数为n_samples,即每一行都取了n_samples次,取法和一维张量相同。

def generate_text(model, start_string, max_len=1000, temperature=1.0, stream=True): input_eval = torch.Tensor([char2idx[char] for char in start_string]).to(dtype=torch.int64, device=device).reshape(1, -1) #bacth_size=1, seq_len长度是多少都可以 (1,5) hidden = None text_generated = [] #用来保存生成的文本 model.eval() pbar = tqdm(range(max_len)) # 进度条 print(start_string, end="") # no_grad是一个上下文管理器,用于指定在其中的代码块中不需要计算梯度。在这个区域内,不会记录梯度信息,用于在生成文本时不影响模型权重。 with torch.no_grad(): for i in pbar:#控制进度条 logits, hidden = model(input_eval, hidden=hidden) # 温度采样,较高的温度会增加预测结果的多样性,较低的温度则更加保守。 #取-1的目的是只要最后,拼到原有的输入上 logits = logits[0, -1, :] / temperature #logits变为1维的 # using multinomial to sampling probs = F.softmax(logits, dim=-1) #算为概率分布 idx = torch.multinomial(probs, 1).item() #从概率分布中抽取一个样本,取概率较大的那些 input_eval = torch.Tensor([idx]).to(dtype=torch.int64, device=device).reshape(1, -1) #把idx转为tensor text_generated.append(idx) if stream: print(idx2char[idx], end="", flush=True) return "".join([idx2char[i] for i in text_generated]) # load checkpoints model.load_state_dict(torch.load("checkpoints/text_generation/best.ckpt", weights_only=True,map_location="cpu")) start_string = "All: " #这里就是开头,什么都可以 res = generate_text(model, start_string, max_len=1000, temperature=0.5, stream=True)这段代码有什么用
03-11
``` import torch import torch.nn as nn import torch.optim as optim import torch.nn.functional as F import numpy as np class ActorCritic(nn.Module): def __init__(self, state_dim, action_dim): super(ActorCritic, self).__init__() self.fc1 = nn.Linear(state_dim, 128) self.fc2 = nn.Linear(128, 128) self.actor = nn.Linear(128, action_dim) self.critic = nn.Linear(128, 1) def forward(self, x): x = F.relu(self.fc1(x)) x = F.relu(self.fc2(x)) action_probs = F.softmax(self.actor(x), dim=-1) state_value = self.critic(x) return action_probs, state_value class A2CScheduler: def __init__(self, state_dim, action_dim, lr=0.001, gamma=0.99): self.model = ActorCritic(state_dim, action_dim) self.optimizer = optim.Adam(self.model.parameters(), lr=lr) self.gamma = gamma def select_action(self, state): state = torch.FloatTensor(state).unsqueeze(0) action_probs, _ = self.model(state) action = torch.multinomial(action_probs, 1).item() return action, action_probs[:, action] def update(self, trajectory): rewards, log_probs, state_values = [], [], [] for (state, action, reward, log_prob, state_value) in trajectory: rewards.append(reward) log_probs.append(log_prob) state_values.append(state_value) returns = [] R = 0 for r in reversed(rewards): R = r + self.gamma * R returns.insert(0, R) returns = torch.tensor(returns) log_probs = torch.stack(log_probs) state_values = torch.stack(state_values).squeeze() advantage = returns - state_values actor_loss = -log_probs * advantage.detach() critic_loss = F.mse_loss(state_values, returns) loss = actor_loss.mean() + critic_loss self.optimizer.zero_grad() loss.backward() self.optimizer.step() # 结合 `mp-quic-go` 使用 # 1. 获取状态信息 (如带宽、RTT、丢包等) # 2. 选择路径 (基于 `select_action` 方法) # 3. 收集数据并训练模型 (基于 `update` 方法)```请详细解释每一行代码的含义和意义
04-02
评论 14
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值