torch中关于torch.max()和torch.min()函数的理解

Tensor中max和min函数的用法详解
  • 简介

在tensor类型的数据中,max和min函数常用来比较两个tensor数据的大小,或者取出tensor数据中的最大值。关于max函数和min函数的用法有以下几种场景:

对于tensorA和tensorB:

  1. torch.max(tensorA):返回tensor中的最大值。
  2. torch.mac(tensorA,dim):dim表示指定的维度,返回指定维度的最大数和对应下标
  3. torch.max(tensorA,tensorB):比较tensorA和tensorB相对较大的元素。
输入:
x = th.arange(0,16,1).view(4,4)
print('x:\n',x)
print('t.max(x):\n',t.max(x))
print('t.max(x,1):\n',t.max(x,1))
print('t.max(x,0):\n',t.max(x,0))
print('t.max(x,1)[0]:\n',t.max(x,1)[0])
print('t.max(x,1)[1]:\n',t.max(x,1)[1])
print('t.max(x,1)[1].data:\n',t.max(x,1)[1].data)
print('t.max(x,1)[1].data.numpy():\n',t.max(x,1)[1].data.numpy())
print('t.max(x,1)[1].data.numpy().squeeze():\n',t.max(x,1)[1].data.numpy().squeeze())
print('t.max(x,1)[0].data:\n',t.max(x,1)[0].data)
print('t.max(x,1)[0].data.numpy():\n',t.max(x,1)[0].data.numpy())
print('t.max(x,1)[0].data.numpy().squeeze():\n',t.max(x,1)[0].data.numpy().squeeze())
输出:
x:
 tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11],
        [12, 13, 14, 15]])
t.max(x):
 tensor(15)
t.max(x,1):
 torch.return_types.max(
values=tensor([ 3,  7, 11, 15]),
indices=tensor([3, 3, 3, 3]))
t.max(x,0):
 torch.return_types.max(
values=tensor([12, 13, 14, 15]),
indices=tensor([3, 3, 3, 3]))
t.max(x,1)[0]:
 tensor([ 3,  7, 11, 15])
t.max(x,1)[1]:
 tensor([3, 3, 3, 3])
t.max(x,1)[1].data:
 tensor([3, 3, 3, 3])
t.max(x,1)[1].data.numpy():
 [3 3 3 3]
t.max(x,1)[1].data.numpy().squeeze():
 [3 3 3 3]
t.max(x,1)[0].data:
 tensor([ 3,  7, 11, 15])
t.max(x,1)[0].data.numpy():
 [ 3  7 11 15]
t.max(x,1)[0].data.numpy().squeeze():
 [ 3  7 11 15]

上面代码中x是一个4*4的二阶张量:

tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11],
        [12, 13, 14, 15]])

t.max(x)表示从x中取出最大的那个值:15

t.max(x,1)表示每行中取出最大的元素,并返回其下标,第一行为3,第二行为7,第三行为11,第四行为15.注意:t.max(x,1)中的1表示行,0表示列,因为此处的例子是二阶的张量,只有行和列。

所以t.max(x,1)表示每列中取出最大的元素,即[3,7,11,15]。

.data 只返回variable中的数据部分(去掉Variable containing:)。

.data.numpy() 把数据转化成numpy ndarry。

].data.numpy().squeeze() 把数据条目中维度为1 的删除掉。

torch.max(tensorA,tensorB) element-wise 比较tensorA 和tensorB 中的元素,返回较大的那个值。

那如果是3阶的张两呢?

下面我们将x的张量变成2*2*4,即2维的两行四列。

输入:
x = th.arange(0,16,1).view(2,2,4)
print('x:\n',x)
print('t.max(x):\n',t.max(x))
print('t.max(x,1):\n',t.max(x,1))
print('t.max(x,0):\n',t.max(x,0))
print('t.max(x,1)[0]:\n',t.max(x,1)[0])
print('t.max(x,1)[1]:\n',t.max(x,1)[1])
print('t.max(x,1)[1].data:\n',t.max(x,1)[1].data)
print('t.max(x,1)[1].data.numpy():\n',t.max(x,1)[1].data.numpy())
print('t.max(x,1)[1].data.numpy().squeeze():\n',t.max(x,1)[1].data.numpy().squeeze())
print('t.max(x,1)[0].data:\n',t.max(x,1)[0].data)
print('t.max(x,1)[0].data.numpy():\n',t.max(x,1)[0].data.numpy())
print('t.max(x,1)[0].data.numpy().squeeze():\n',t.max(x,1)[0].data.numpy().squeeze())


输出:

x:
 tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7]],

        [[ 8,  9, 10, 11],
         [12, 13, 14, 15]]])
t.max(x):
 tensor(15)
t.max(x,1):
 torch.return_types.max(
values=tensor([[ 4,  5,  6,  7],
        [12, 13, 14, 15]]),
indices=tensor([[1, 1, 1, 1],
        [1, 1, 1, 1]]))
t.max(x,0):
 torch.return_types.max(
values=tensor([[ 8,  9, 10, 11],
        [12, 13, 14, 15]]),
indices=tensor([[1, 1, 1, 1],
        [1, 1, 1, 1]]))
t.max(x,1)[0]:
 tensor([[ 4,  5,  6,  7],
        [12, 13, 14, 15]])
t.max(x,1)[1]:
 tensor([[1, 1, 1, 1],
        [1, 1, 1, 1]])
t.max(x,1)[1].data:
 tensor([[1, 1, 1, 1],
        [1, 1, 1, 1]])
t.max(x,1)[1].data.numpy():
 [[1 1 1 1]
 [1 1 1 1]]
t.max(x,1)[1].data.numpy().squeeze():
 [[1 1 1 1]
 [1 1 1 1]]
t.max(x,1)[0].data:
 tensor([[ 4,  5,  6,  7],
        [12, 13, 14, 15]])
t.max(x,1)[0].data.numpy():
 [[ 4  5  6  7]
 [12 13 14 15]]
t.max(x,1)[0].data.numpy().squeeze():
 [[ 4  5  6  7]
 [12 13 14 15]]

从输出结果看,

  1. t.max(x):依然是从x中取出最大的一个值。
  2. t.max(x,1):输出
    [[ 4,  5,  6,  7],
    [12, 13, 14, 15]],因为此时有两个维度,每个维度上有两行四列的数据,所以1表示行,即取出每个维度上每个行的最大值。[ 4, 5, 6, 7]是第一个维度中的最大值所在的行,[12, 13, 14, 15]是第二个维度中最大值所在的行。
  3. t.max(x,0):输出:
    torch.return_types.max(
    values=tensor([[ 8,  9, 10, 11],
            [12, 13, 14, 15]]),
    indices=tensor([[1, 1, 1, 1],
            [1, 1, 1, 1]]))此时,0表示是维度,x有两个维度,表示从这两个维度中取出最大值所在的维度。即[[ 8, 9, 10, 11], [12, 13, 14, 15]]。
    

假如我们想要取出每个维度中每一列中最大的元素该怎么办呢?我们只需要指定t.max(x,dim)中的dim=2就可以取出列向量了。

输入:
x = th.arange(0,16,1).view(2,2,4)
print('x:\n',x)
print('t.max(x):\n',t.max(x))
print('t.max(x,2):\n',t.max(x,2))
print('t.max(x,2)[0]:\n',t.max(x,2))
print('t.max(x,1)[0][0]:\n',t.max(x,2)[0][0])


输出:

x:
 tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7]],

        [[ 8,  9, 10, 11],
         [12, 13, 14, 15]]])
t.max(x):
 tensor(15)
t.max(x,2):
 torch.return_types.max(
values=tensor([[ 3,  7],
        [11, 15]]),
indices=tensor([[3, 3],
        [3, 3]]))
t.max(x,2)[0]:
 torch.return_types.max(
values=tensor([[ 3,  7],
        [11, 15]]),
indices=tensor([[3, 3],
        [3, 3]]))
t.max(x,1)[0][0]:
 tensor([3, 7])

 

<think>我们被要求解释torch.max函数的作用,并且根据提供的引用材料进行回答。引用材料中详细描述了torch.max的用法,包括其调用方式、参数、返回值以及示例。 根据引用[1]引用[4]: - torch.max有两种主要调用方式: 1. 不指定维度:返回整个输入张量中的最大值。 2. 指定维度(dim):返回该维度上每一切片的最大值,同时可以选择返回最大值的位置索引。 引用[2]引用[3]提供了关于dim参数使用的示例,说明当指定维度时,函数会沿着该维度进行最大值计算,并返回两个张量:最大值构成的张量对应的索引张量。 引用[5]补充了argmax的用法,但这里我们关注的是max。 因此,我们可以将torch.max的作用总结为: 1. 求全局最大值(不指定dim) 2. 沿着指定维度求最大值,并可以选择返回最大值的位置(指定dim,并设置keepdim参数) 另外,引用[4]提到参数keepdim(默认为False)的作用:是否保持输出张量的维度。如果设置为True,则输出张量在指定的维度上长度为1,否则该维度会被压缩(去掉)。 接下来,我们按照要求组织回答: - 使用中文回答 - 行内公式用$...$,独立公式用$$...$$ - 在适当位置添加引用标识 - 最后列出相关问题 由于问题只要求解释作用,我们不需要写太多代码,但可以给出简单的示例说明。 注意:用户的问题中引用了多个材料,我们需要综合这些材料。 </think>### `torch.max`函数在PyTorch中的作用详解 `torch.max`是PyTorch中的核心函数,主要用于**张量的极值计算****位置索引获取**。其功能根据参数不同分为两种模式: #### 1. 全局最大值计算(不指定维度) - **功能**:返回整个张量中的最大值 - **返回值**:单个标量值 - **数学表达**: $$ \text{max\_val} = \max_{i,j,\cdots} x_{ij\cdots} $$ ```python import torch x = torch.tensor([[5, 2, 2], [7, 2, 0], [8, 0, 6]]) max_val = torch.max(x) # 返回 tensor(8) ``` #### 2. 沿维度最大值计算(指定`dim`参数) - **功能**:沿指定维度计算最大值及其索引 - **返回值**:命名元组`(values, indices)`,包含: - `values`:各切片的最大值 - `indices`:最大值在原始维度的位置索引 - **数学表达**(以`dim=0`为例): $$ \text{values}_j = \max_i x_{ij}, \quad \text{indices}_j = \arg\max_i x_{ij} $$ ##### 维度计算示例: ```python # 沿列计算(dim=0) values, indices = torch.max(x, dim=0) """ values = tensor([8, 2, 6]) # 每列最大值 indices = tensor([2, 0, 2]) # 最大值所在行号 """ # 沿行计算(dim=1) values, indices = torch.max(x, dim=1) """ values = tensor([5, 7, 8]) # 每行最大值 indices = tensor([0, 0, 0]) # 最大值所在列号 """ ``` #### 3. 关键参数说明 | 参数 | 类型 | 作用 | 示例 | |----------|---------|-------------------------------|---------------------------| | `dim` | int | 指定计算维度 | `dim=0`(列方向)[^2][^4] | | `keepdim`| bool | 是否保留原维度 | `False`(默认压缩维度) | | `out` | tuple | 指定输出存储位置 | 可选 | #### 4. 特殊行为图示 ```python t = torch.randn(2, 2, 2) # 3维张量 """ t = tensor([[[-0.1, 1.2], [ 0.5, -0.3]], [[ 2.1, -0.7], [-1.0, 0.9]]]) """ # 沿第0维计算 torch.max(t, dim=0) """ values: tensor([[2.1, 1.2], # 两个2x2张量的逐元素最大值 [0.5, 0.9]]) indices: tensor([[1, 0], [0, 1]]) """ ``` #### 5. 与相关函数的对比 | 函数 | 返回值 | 主要用途 | |------------------|---------------------|------------------------------| | `torch.max()` | (值, 索引) 或 标量 | 获取最大值及其位置[^1][^3] | | `torch.argmax()` | 索引张量 | 仅获取最大值位置[^5] | | `torch.min()` | (值, 索引) 或 标量 | 最小值计算(与max对称) | > **典型应用场景**: > 1. 分类任务中获取模型预测类别:`pred = torch.max(logits, dim=1)[1]` > 2. 特征图中提取空间最大值(MaxPooling基础) > 3. 数据归一化处理时计算数值范围 --- ### 相关问题 1. 当`keepdim=True`时,`torch.max`的输出维度会发生什么变化?[^4] 2. 如何理解高维张量(如4D图像张量)中`dim`参数的意义? 3. `torch.max`与NumPy的`np.max`在广播机制上有哪些差异? 4. 在梯度计算中,使用`torch.max`获取的索引是否参与反向传播?为什么? 5. 如何利用`torch.max`实现Top-k选择功能?
评论 1
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值