对pytorch中 torch.argmax(dim=)、x.argmin(dim=)的容易理解
import torchx=torch.randn(3,3,4)print(x)print(x.argmax(dim=0))输出:tensor([[[ 0.5128, 0.3717, 0.3606, -0.0286], [ 0.0933, -1.4781, -0.3561, -0.2652], [-0.8861, 0.6988, 1.1243, -1.1301]], [[-0.0246, 0.0917, -0.0623, -.
原创
2021-11-02 18:50:23 ·
2738 阅读 ·
4 评论