Understanding the Coordinate System of F.grid_sample

This post walks through how PyTorch's grid_sample function works, focusing on the coordinate convention it expects, and demonstrates it with a small example on a 3-D volume (a 5-D NCDHW tensor).

The grid argument of torch.nn.functional.grid_sample() holds (x, y, z) coordinates in a normalized [-1, 1] coordinate system. The coordinates map onto the array dimensions as:

x -> w, y -> h, z -> d. The test code below verifies this:

import numpy as np
from torch.nn import functional as F
import torch

if __name__ == '__main__':
    d, h, w = 8, 10, 12
    # input shape: (N, C, D, H, W)
    input = torch.zeros((2, 1, d, h, w), dtype=torch.float32)
    input[:, 0, 2, 3, 4] = 1  # mark the voxel at d=2, h=3, w=4
    # grid shape: (N, D_out, H_out, W_out, 3); last dim is (x, y, z)
    grid = torch.zeros((2, 1, 1, 1, 3), dtype=torch.float32)
    x, y, z = 4, 3, 2  # correspond to the w, h, d indices of input
    # rescale to [-1, 1]; this formula matches align_corners=True
    x = 2. * x / (w - 1) - 1.
    y = 2. * y / (h - 1) - 1.
    z = 2. * z / (d - 1) - 1.
    grid[0, 0, 0, 0, :] = torch.from_numpy(np.array([x, y, z]).astype(np.float32))
    grid[1, 0, 0, 0, :] = torch.from_numpy(np.array([x, y, z]).astype(np.float32))
    out = F.grid_sample(input, grid, mode='nearest', align_corners=True)
    print(out)

The output is:

tensor([[[[[1.]]]],

        [[[[1.]]]]])
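
Note that the rescaling formula 2 * coord / (size - 1) - 1 corresponds to align_corners=True; with the default align_corners=False (the default since PyTorch 1.3), a pixel/voxel center i maps to (2 * i + 1) / size - 1 instead.

The x -> w, y -> h convention carries over to the 2-D case: for a 4-D input of shape (N, C, H, W), grid has shape (N, H_out, W_out, 2) and its last dimension is (x, y). Below is a minimal sketch of the 2-D analogue of the test above (the shapes and the marked pixel are chosen arbitrarily for illustration):

import torch
from torch.nn import functional as F

h, w = 5, 7
inp = torch.zeros((1, 1, h, w), dtype=torch.float32)
inp[0, 0, 2, 3] = 1  # mark the pixel at h=2, w=3

# grid shape: (N, H_out, W_out, 2); last dim is (x, y)
x = 2. * 3 / (w - 1) - 1.  # x normalizes the width index
y = 2. * 2 / (h - 1) - 1.  # y normalizes the height index
grid = torch.tensor([[[[x, y]]]], dtype=torch.float32)

out = F.grid_sample(inp, grid, mode='nearest', align_corners=True)
print(out)  # expected: tensor([[[[1.]]]])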
