一、引言
在深度学习领域,Tensor(张量)是数据存储与运算的核心载体。掌握 Tensor 的进阶操作,能更高效地处理数据。在学习这些进阶操作前,理解 dim
(维度)的概念至关重要,它贯穿于众多 Tensor 操作中。本篇围绕 Tensor 的索引与数据筛选、裁剪、组合拼接、切片、变形、填充、广播机制等展开,结合原理分析、应用场景说明及代码示例,帮助新手深入理解并灵活运用这些技巧,为后续模型开发夯实基础。
二、dim
(维度)
开始之前让我们先回顾一下,Tensor维度的概念。
2.1 维度的直观理解
我们可以把 Tensor 想象成不同形状的容器,维度就是描述这个容器形状的参数。
零维张量(标量)
零维张量就像是一个单独的数字,它没有方向,也没有形状,只有一个值。例如,torch.tensor(5)
就是一个零维张量,它只有一个元素,不存在维度的概念,就像一个孤立的点。
一维张量(向量)
一维张量可以看作是一排数字,就像我们日常排队一样。每个数字都有一个对应的位置,这个位置就是它的索引。例如,torch.tensor([1, 2, 3, 4])
就是一个一维张量,它有 4 个元素,就像有 4 个人在排队。这里的维度就是 1,它表示这是一个线性的排列。
二维张量(矩阵)
二维张量类似于一个表格,有行和列。我们可以把它想象成一个教室的座位表,每一行代表一排座位,每一列代表一列座位。例如,torch.tensor([[1, 2, 3], [4, 5, 6]])
是一个二维张量,它有 2 行 3 列。在这个例子中,我们可以说有两个维度,第一个维度表示行,第二个维度表示列。
高维张量
当维度超过 2 时,理解起来可能会有点困难,但我们可以通过类比来想象。比如三维张量可以看作是一叠表格,就像一叠扑克牌,每一张扑克牌就是一个二维的表格,而这一叠的顺序就是第三个维度。更高维度的张量可以继续这样层层嵌套地想象。
2.2 dim
在操作中的作用
在 PyTorch 的很多函数中,dim
参数用于指定操作在哪个维度上进行。可以把它理解成我们对 Tensor 这个容器进行操作时的一个“方向”。
例如,在进行求和操作时,如果指定 dim = 0
,就相当于沿着第一个维度(行)的方向进行求和。对于一个二维张量,就是把每一列的元素相加;如果指定 dim = 1
,则是沿着第二个维度(列)的方向进行求和,也就是把每一行的元素相加。
import torch
# 创建一个二维张量
tensor = torch.tensor([[1, 2, 3], [4, 5, 6]])
# 沿着行的方向(dim = 0)求和
sum_row = torch.sum(tensor, dim=0)
print("沿着行的方向求和:", sum_row) # 输出: tensor([5, 7, 9])
# 沿着列的方向(dim = 1)求和
sum_col = torch.sum(tensor, dim=1)
print("沿着列的方向求和:", sum_col) # 输出: tensor([ 6, 15])
通过这个例子,我们可以看到 dim
参数决定了操作的方向,帮助我们在不同的维度上对 Tensor 进行处理。理解了 dim
的概念后,我们就可以更轻松地学习后续的 Tensor 进阶操作了。
三、Tensor 的索引与数据筛选
3.1 索引与数据筛选的重要性
在实际数据处理中,常需从 Tensor 提取特定元素或满足条件的数据。例如,图像分类中筛选特定区域像素值,自然语言处理中提取符合条件的词向量。PyTorch 提供的函数可精准实现索引与数据筛选。
3.2 常用索引与筛选函数详解
torch.where(condition, x, y)
- 原理:根据
condition
布尔值,从x
和y
选取元素。True
取x
元素,False
取y
元素。 - 应用场景:数据清洗、按条件替换元素。
- 示例:
import torch
x = torch.tensor([1, 2, 3])
y = torch.tensor([4, 5, 6])
condition = torch.tensor([True, False, True])
where_result = torch.where(condition, x, y)
print("torch.where 结果:", where_result)
# 输出:torch.where 结果: tensor([1, 5, 3])
torch.gather(input, dim, index, out=None)
- 原理:在
input
的指定维度dim
,按index
索引选取元素。 - 应用场景:推荐系统提取推荐结果,强化学习提取奖励值。
- 示例:
input_tensor = torch.tensor([[1, 2], [3, 4]])
dim = 1
index = torch.tensor([[0], [1]])
gather_result = torch.gather(input_tensor, dim, index)
print("torch.gather 结果:\n", gather_result)
# 输出:tensor([[1], [4]])
torch.index_select(input, dim, index, out=None)
- 原理:在指定维度
dim
,按index
索引选取input
元素。 - 应用场景:从大规模数据集选取特定样本或特征。
- 示例:
input_tensor = torch.tensor([[1, 2, 3], [4, 5, 6]])
dim = 1
index = torch.tensor([0, 2])
index_select_result = torch.index_select(input_tensor, dim, index)
print("torch.index_select 结果:\n", index_select_result)
# 输出:tensor([[1, 3], [4, 6]])
torch.masked_select(input, mask, out=None)
- 原理:通过
mask
布尔值筛选input
元素,输出一维向量。 - 应用场景:提取符合条件的元素,如筛选高于阈值的像素值。
- 示例:
input_tensor = torch.tensor([1, 2, 3, 4])
mask = torch.tensor([False, True, False, True])
masked_select_result = torch.masked_select(input_tensor, mask)
print("torch.masked_select 结果:", masked_select_result)
# 输出:tensor([2, 4])
torch.take(input, indices)
- 原理:将
input
视为一维 Tensor,按indices
索引选取元素。 - 应用场景:对多维 Tensor 扁平化后的索引操作。
- 示例:
input_tensor = torch.tensor([[1, 2], [3, 4]])
indices = torch.tensor([0, 2, 3])
take_result = torch.take(input_tensor, indices)
print("torch.take 结果:", take_result)
# 输出:tensor([1, 3, 4])
torch.nonzero(input, out=None)
- 原理:返回
input
非零元素坐标,以二维 Tensor 呈现。 - 应用场景:分析非零元素分布,如稀疏矩阵处理。
- 示例:
input_tensor = torch.tensor([[0, 1], [2, 0]])
nonzero_result = torch.nonzero(input_tensor)
print("torch.nonzero 结果:\n", nonzero_result)
# 输出:tensor([[0, 1], [1, 0]])
四、Tensor 的裁剪运算
4.1 裁剪运算的核心作用
深度学习训练中,梯度离散或爆炸是常见问题。裁剪运算(如 clamp
函数)可过滤元素范围,限制梯度,确保训练稳定,也用于数据预处理处理异常值。
4.2 clamp
函数深度解析
- 原理:
clamp(min, max)
将小于min
的元素设为min
,大于max
的设为max
,中间值保留。 - 示例:
tensor = torch.tensor([1, 4, 6, 8])
clamped_tensor = tensor.clamp(2, 7)
print("裁剪结果:", clamped_tensor)
# 输出:tensor([2, 4, 6, 7])
五、Tensor 的组合/拼接
5.1 torch.cat
:沿已有维度拼接
- 原理:
torch.cat(seq, dim=0)
沿指定维度dim
拼接 Tensor,要求其他维度形状一致。 - 应用场景:合并训练数据批次,组合不同特征 Tensor。
- 示例:
t1 = torch.tensor([[1, 2], [3, 4]])
t2 = torch.tensor([[5, 6]])
cat_result = torch.cat((t1, t2), dim=0)
print("torch.cat 结果:\n", cat_result)
# 输出:tensor([[1, 2], [3, 4], [5, 6]])
5.2 torch.stack
:按新维度拼接
- 原理:
torch.stack(seq, dim=0)
创建新维度dim
堆叠 Tensor,要求所有 Tensor 形状一致。 - 应用场景:将多个样本组合为批次,如图像分类输入。
- 示例:
t1 = torch.tensor([1, 2])
t2 = torch.tensor([3, 4])
stack_result = torch.stack((t1, t2), dim=1)
print("torch.stack 结果:\n", stack_result)
# 输出:tensor([[1, 3], [2, 4]])
六、Tensor 的切片
6.1 torch.chunk
:按维度平均分块
- 原理:
torch.chunk(tensor, chunks, dim=0)
沿维度dim
平均分chunks
块,最后一块可能较小。 - 应用场景:大规模数据分块处理,如分布式训练分配数据。
- 示例:
tensor = torch.arange(8)
chunks = torch.chunk(tensor, 3, dim=0)
for i, chunk in enumerate(chunks):
print(f"第 {i + 1} 块: {chunk}")
# 输出:第 1 块: tensor([0, 1, 2]);第 2 块: tensor([3, 4, 5]);第 3 块: tensor([6, 7])
6.2 torch.split
:按指定大小分割
- 原理:
torch.split(tensor, split_size_or_sections, dim=0)
沿维度dim
,按指定大小或列表分割。 - 应用场景:灵活划分数据,如分训练集、验证集、测试集。
- 示例:
tensor = torch.arange(9)
splits = torch.split(tensor, [3, 4, 2], dim=0)
for i, split in enumerate(splits):
print(f"第 {i + 1} 分割: {split}")
# 输出:第 1 分割: tensor([0, 1, 2]);第 2 分割: tensor([3, 4, 5, 6]);第 3 分割: tensor([7, 8])
七、Tensor 的变形操作
7.1 变形操作的多样性与实用性
深度学习中,常需根据层的输入输出要求改变 Tensor 形状,PyTorch 提供丰富变形函数满足需求。
7.2 常用变形函数详解
torch.reshape(input, shape)
- 原理:不改变数据存储顺序,重排
input
为指定shape
。 - 示例:
tensor = torch.arange(6)
reshaped = torch.reshape(tensor, (2, 3))
print("torch.reshape 结果:\n", reshaped)
# 输出:tensor([[0, 1, 2], [3, 4, 5]])
torch.t(input)
- 原理:专门用于 2D Tensor 转置,交换行和列。
- 示例:
tensor_2d = torch.tensor([[1, 2], [3, 4]])
transposed = torch.t(tensor_2d)
print("torch.t 结果:\n", transposed)
# 输出:tensor([[1, 3], [2, 4]])
torch.transpose(input, dim0, dim1)
- 原理:交换
input
中指定的两个维度dim0
和dim1
,适用于高维 Tensor。 - 示例:
tensor_3d = torch.rand(2, 3, 4)
transposed_3d = torch.transpose(tensor_3d, 1, 2)
print("torch.transpose 结果形状:", transposed_3d.shape)
# 输出:torch.Size([2, 4, 3])
torch.squeeze(input, dim=None, out=None)
- 原理:不指定
dim
时,去除所有维度为 1 的维度;指定dim
时,仅去除该维度(需维度为 1)。 - 示例:
tensor = torch.tensor([[1]])
squeezed = torch.squeeze(tensor)
print("torch.squeeze 结果:", squeezed)
# 输出:tensor(1)
torch.unsqueeze(input, dim, out=None)
- 原理:在指定位置
dim
添加维度为 1 的维度。 - 示例:
tensor = torch.tensor([1, 2, 3])
unsqueezed = torch.unsqueeze(tensor, 0)
print("torch.unsqueeze 结果形状:", unsqueezed.shape)
# 输出:torch.Size([1, 3])
torch.flip(input, dims)
- 原理:沿
dims
中指定维度翻转 Tensor。 - 示例:
tensor = torch.tensor([[1, 2], [3, 4]])
flipped = torch.flip(tensor, dims=[0])
print("torch.flip 结果:\n", flipped)
# 输出:tensor([[3, 4], [1, 2]])
torch.rot90(input, k, dims)
- 原理:沿指定维度
dims
,将 Tensor 旋转k
个 90 度。 - 示例:
tensor = torch.tensor([[1, 2], [3, 4]])
rotated = torch.rot90(tensor, k=1, dims=[0, 1])
print("torch.rot90 结果:\n", rotated)
# 输出:tensor([[2, 4], [1, 3]])
八、Tensor 的填充操作
8.1 torch.full
:快速创建填充 Tensor
- 原理:
torch.full(size, fill_value)
创建指定形状size
的 Tensor,用fill_value
填充。 - 应用场景:初始化权重矩阵,如神经网络全连接层权重。
- 示例:
filled_tensor = torch.full((2, 3), 3.14)
print("填充结果:\n", filled_tensor)
# 输出:tensor([[3.1400, 3.1400, 3.1400], [3.1400, 3.1400, 3.1400]])
九、PyTorch 中的广播机制
9.1 广播机制的原理与规则
广播机制允许不同形状 Tensor 运算,通过自动扩展维度匹配形状,需满足:
- 每个张量至少有一个维度。
- 从最后一维向前匹配,维度大小相等或其中一个为 1。
9.2 广播机制示例
t1 = torch.rand(2, 1, 1)
t2 = torch.rand(3)
result = t1 + t2
print("广播后结果形状:", result.shape)
# 输出:torch.Size([2, 3, 3])
广播机制的规则
从右向左比较维度:从张量的最后一个维度开始,逐一比较两个张量的维度大小。
- 维度相等或其一为 1:
若两个张量对应维度的大小相等,或者其中一个维度大小为 1,那么这两个维度是兼容的。 - 扩展维度为 1 的张量:
当某个张量的某个维度大小为 1 时,会将该维度扩展为与另一个张量对应维度相同的大小。 - 维度缺失则插入 1:
若一个张量的维度少于另一个,会在其左侧插入大小为 1 的维度,直至维度数量相同。
t1 和 t2 相加的详细过程
已知 t1 的形状是 (2, 1, 1),t2 的形状是 (3)。
- 使维度数量相同
t2 只有一个维度,而 t1 有三个维度。按照规则,在 t2 的左侧插入大小为 1 的维度,使它的维度数量和 t1 一样。这样,t2 的形状就变成了 (1, 1, 3)。 - 从右向左比较维度
- 最后一个维度:
t1 的最后一个维度大小是 1,t2 的最后一个维度大小是 3。依据规则,会把 t1 的最后一个维度扩展为 3,也就是把 t1 最后一个维度的值复制 3 次。 - 倒数第二个维度:t1 的倒数第二个维度大小是 1,t2 的倒数第二个维度大小也是 1,这两个维度兼容,无需扩展。
- 第一个维度:t1 的第一个维度大小是 2,t2 的第一个维度大小是 1。按照规则,会把 t2 的第一个维度扩展为 2,也就是把 t2 第一个维度的值复制 2 次。
- 最后一个维度:
- 扩展后的形状
扩展之后,t1 的形状变为 (2, 1, 3),t2 的形状也变为 (2, 1, 3),此时两个张量形状相同,就能够进行逐元素相加了。
9.3 广播机制的应用场景
广播机制广泛应用于批量运算,如对批次图像标准化,通过广播将均值、方差应用到每个图像,提升计算效率。
十、总结
本篇深入解析了 Tensor 的进阶操作,包括索引筛选、裁剪、组合拼接、切片、变形、填充、广播机制等,结合原理、场景与代码。通过学习,新手可理解操作逻辑,在实际项目中灵活运用,为深度学习模型构建提供有力支持。后续在模型训练优化中,这些技巧将助力实现更好的深度学习目标。