torch.argmax(input, dim, keepdim=False) → LongTensor
功能:返回最大值所在的索引。根据dim的不同有不同的处理方式
1.dim=None时,则返回扁平化输入的 argmax,就是将input展成一维,然后从中找出最大值所在索引,此时返回的Tensor只有一个值
>>> a = torch.randn(4, 4) >>> a tensor([[ 1.3398, 0.2663, -0.2686, 0.2450], [-0.7401, -0.8805, -0.3402, -1.1936], [ 0.4907, -1.3948, -1.0691, -0.3132], [-1.6092, 0.5419, -0.2993, 0.3195]]) >>> torch.argmax(a) tensor(0)
2.dim有值时,返回该维最大值的缩印组成的tensor
>>> a = torch.randn(4, 4) >>> a tensor([[ 1.3398, 0.2663, -0.2686, 0.2450], [-0.7401, -0.8805, -0.3402, -1.1936], [ 0.4907, -1.3948, -1.0691, -0.3132], [-1.6092, 0.5419, -0.2993, 0.3195]]) >>> torch.argmax(a, dim=1) tensor([ 0, 2, 0, 1])