Original code:
action = torch.max(actions_value, 1)[1].data.numpy()[0, 0] # return the argmax
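Modified code (for newer PyTorch versions, where torch.max(input, dim) no longer keeps the reduced dimension by default, so the returned index tensor is 1-D and needs only one index):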
action = torch.max(actions_value, 1)[1].data.numpy()[0] # return the argmax
Reference:
https://morvanzhou.github.io/tutorials/machine-learning/torch/4-05-DQN/ (see the comments below the linked page)
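The sketch below shows why the index changed from [0, 0] to [0]. It assumes a recent PyTorch version and uses a made-up Q-value tensor for a single state with 3 actions (the values are illustrative only, not taken from the tutorial). With the default keepdim=False, the indices returned by torch.max have shape (1,). Passing keepdim=True restores the (1, 1) shape that the original [0, 0] indexing expects.

```python
import torch

# Hypothetical Q-values for one state and 3 actions, shape (1, n_actions)
actions_value = torch.tensor([[0.1, 0.7, 0.2]])

# torch.max along dim 1 returns (values, indices); with the default
# keepdim=False the indices tensor is 1-D with shape (1,)
indices = torch.max(actions_value, 1)[1]
print(indices.shape)  # torch.Size([1])
action = indices.data.numpy()[0]  # the modified indexing works
print(action)  # 1

# With keepdim=True the reduced dimension is kept, giving shape (1, 1),
# which is what the original [0, 0] indexing assumed
indices_kept = torch.max(actions_value, 1, keepdim=True)[1]
print(indices_kept.shape)  # torch.Size([1, 1])
action_kept = indices_kept.data.numpy()[0, 0]

# The original indexing on the 1-D result raises IndexError
try:
    indices.data.numpy()[0, 0]
except IndexError as err:
    print("IndexError:", err)
```

Either fix selects the same action. Changing the index to [0] is the smaller change, while keepdim=True keeps the original line's indexing intact.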