torch.max()
是一个 PyTorch 函数,可让您计算 PyTorch 张量指定维度上的最大值及其相应索引。以下是 的基本用法torch.max()
import torch
# 创建一个样本张量
x = torch.tensor([[1, 2, 3],
[4, 5, 6]])
# 计算最大值及其沿特定维度的索引
max_values, max_indices = torch.max(x, dim=0)
print("Maximum values:", max_values)
print("Indices of maximum values:", max_indices)
输出的结果:
Maximum values: tensor([4, 5, 6])
Indices of maximum values: tensor([1, 1, 1])
在此示例中,torch.max(x, dim=0)
计算输入张量 沿维度 0(沿列)的最大值及其相应索引x
。
torch.max(x, dim=1)
计算输入张量 沿维度 1(沿行)的最大值及其相应索引x
。
结果将是两个张量:
max_values
将包含沿指定维度的最大值。max_indices
将包含沿指定维度的最大值的索引。