在 PyTorch 中,张量(torch.Tensor)的索引与切片操作和 Python 的列表、NumPy 数组类似,但针对高维张量做了扩展,支持灵活的维度访问和修改。
1. 基本索引(单元素访问)
通过下标[]访问张量中的单个元素,索引从 0 开始,支持负数索引(表示从末尾开始)。
举个栗子:
- 1维张量
t1 = torch.tensor([1, 2, 3, 4, 5])
print(t1[0])
# 输出: tensor(1)
print(t1[-1])
# 输出: tensor(5)
- 2维张量(行×列)
t2 = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
print(t2[1, 2])
# 访问第2行第3列(索引1,2),输出: tensor(6)
print(t2[0][1])
# 等价于t2[0,1],输出: tensor(2)
2. 切片(范围访问)
使用 start:end:step 语法获取张量的子区域,规则与 Python 列表一致:
start:起始索引(默认 0);end:结束索引(不包含,默认维度长度);step:步长(默认 1,负数表示反向)。
举个栗子:
- 1维张量切片
t1 = torch.tensor([1, 2, 3, 4, 5])
print(t1[1:4])
# 从索引1到3(不包含4),输出: tensor([2, 3, 4])
print(t1[:3])
# 从开头到2,输出: tensor([1, 2, 3])
print(t1[::2])
# 步长2,输出: tensor([1, 3, 5])
print(t1[::-1])
# 反向,输出: tensor([5, 4, 3, 2, 1])
- 2维张量切片(行优先)
t2 = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
print(t2[0:2, 1:3])
# 行0-1,列1-2 → [[2,3],[5,6]]
# 输出:
# tensor([[2, 3],
# [5, 6]])
print(t2[:, 0])
# 所有行,第0列 → [1,4,7]
# 输出: tensor([1, 4, 7])
3. 高维张量索引
对于 3 维及以上张量(如 (深度, 行, 列) 或 (批次, 通道, 高, 宽)),通过逗号分隔各维度的索引 / 切片。
举个栗子::
- 3 维张量
t3 = torch.tensor([
[[1, 2], [3, 4]], # 第0个深度
[[5, 6], [7, 8]], # 第1个深度
[[9, 10], [11, 12]] # 第2个深度
])
# 访问第1个深度的第0行第1列
print(t3[1, 0, 1])
# 输出: tensor(6)
# 切片:所有深度,第0行,所有列
print(t3[:, 0, :])
# 输出:
# tensor([[ 1, 2],
# [ 5, 6],
# [ 9, 10]])
4. 高级索引
除了基本切片,PyTorch 还支持 整数数组索引 和 布尔索引,用于非连续元素的访问。
(1)整数数组索引(按索引列表取元素)
通过列表或张量指定要访问的索引,返回与索引列表形状一致的张量。
t2 = torch.tensor([
[1, 2, 3], # 行索引 0
[4, 5, 6], # 行索引 1
[7, 8, 9] # 行索引 2
])
indices = torch.tensor([[0, 1], [2, 0]]) # 每行是一组坐标 (行,列)
#输出:tensor([[0, 1],
# [2, 0]])
print(t2[indices[:, 0], indices[:, 1]])
# 输出: tensor([2, 7])
这里用了两个索引数组:
- indices[:, 0]:取 indices 所有行的第 0 列 → 结果是 [0, 2](行索引列表)
- indices[:, 1]:取 indices 所有行的第 1 列 → 结果是 [1, 0](列索引列表)
PyTorch 会将两个索引数组按相同位置的元素配对,即:
- 第 1 对:行索引 = 0,列索引 = 1 → 对应 t2[0, 1] → 值为 2
- 第 2 对:行索引 = 2,列索引 = 0 → 对应 t2[2, 0] → 值为 7
简化理解
把 indices[:, 0] 和 indices[:, 1] 看作两个 “并行列表”,分别存储行索引和列索引,然后按顺序一一对应取值:
行索引列表:[0, 2]
列索引列表:[1, 0]
→ 取 t2[0,1] 和 t2[2,0] → [2, 7]
适合从张量中提取 非连续的、任意位置的多个元素,只需按顺序列出它们的坐标即可。
t2 = torch.tensor([
[1, 2, 3], # 行索引 0
[4, 5, 6], # 行索引 1
[7, 8, 9] # 行索引 2
])
# 取第0行的第1和2列
print(t2[0, [1, 2]])
# 输出: tensor([2, 3])
t2[0, [1, 2]]
- 第一个索引
0:指定要访问的 行 → 第 0 行(即[1, 2, 3])。 - 第二个索引
[1, 2]:指定要访问的 列索引列表 → 第 1 列和第 2 列。
#####取值过程
第 0 行的第 1 列:t2[0, 1]→ 值为2;
第 0 行的第 2 列:t2[0, 2]→ 值为3。
2)布尔索引(按条件筛选元素)
通过与原张量形状相同的布尔张量,筛选出 True 位置的元素。
t2 = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
# 筛选所有大于5的元素
mask = t2 > 5
print(t2[mask]) # 输出: tensor([6, 7, 8, 9])
# 结合切片使用(第0行中大于1的元素)
print(t2[0, t2[0] > 1]) # 输出: tensor([2, 3])
注意事项
视图(View)与副本(Copy):
- 基本切片(如
t[:, :2])返回的是原张量的视图(view),修改切片会影响原张量; - 高级索引(如整数数组、布尔索引)返回的是副本(copy),修改不会影响原张量。
t = torch.tensor([1, 2, 3])
slice_view = t[:2]
slice_view[0] = 0
print(t) # 输出: tensor([0, 2, 3])(原张量被修改)
mask = t > 1
slice_copy = t[mask]
slice_copy[0] = 100
print(t) # 输出: tensor([0, 2, 3])(原张量未修改)
- PyTorch 中,对张量进行基本切片(如
t[:2]、t[1:4]等连续范围的切片)时,不会复制原张量的数据,而是创建一个 “视图”—— 视图仅记录对原张量的索引范围(如 “从 0 到 2 的元素”),实际数据仍存储在原张量的内存中。 - 因此,slice_view 和 t 指向同一块内存的不同区域,修改 slice_view 相当于修改原内存中的数据,原张量 t 自然会受到影响。
- 只有当使用 高级索引(如整数数组索引、布尔索引)时,PyTorch 才会创建原张量的副本(复制数据到新内存),此时修改副本不会影响原张量。
维度保留
若需保留切片后的维度(如从 2 维张量中取一行仍保持 2 维),可使用 None 或 np.newaxis 增加维度:
t2 = torch.tensor([[1, 2, 3], [4, 5, 6]])
row = t2[0, :] # 形状 (3,)(1维)
row_keep_dim = t2[0:1, :] # 形状 (1, 3)(2维,保留行维度)
第一种情况:row = t2[0, :] → 形状 (3,)(1 维)
- 原本
t2有 “行、列” 两个维度,当你明确指定 “第 0 行” 后,“行” 这个维度就消失了,只剩下 “列” 维度对应的 3 个元素,所以最终形状是(3,)(1 维张量,只有一个维度)。 - 可以理解为:把 “第 0 行” 从 2 维表格里 “抽出来”,变成了一个 1 维的 “列表” [1, 2, 3]。
二种情况:row_keep_dim = t2[0:1, :] → 形状 (1, 3)(2 维)
用范围切片(如 0:1)访问某一维度时,会 “保留” 这个维度。
- 虽然 0:1 最终只取了 1 行,但它明确表示 “这是一个包含 1 行的行维度”,所以 t2 原本的 “行、列” 两个维度都被保留了,只是行维度的长度从 2 变成了 1,最终形状是 (1, 3)(2 维张量,仍有 “行、列” 结构)。
- 可以理解为:把 “第 0 行” 从 2 维表格里 “切出来一小块”,这块仍然是一个 2 维的 “小表格”,只是只有 1 行 3 列:[[1, 2, 3]]。
负数步长的边界
反向切片(如 t[::-1])中,start 和 end 需注意边界,例如 t[3:0:-1] 表示从索引 3 到 1(不包含 0)。
先明确正向切片的基础规则
对于正向切片 t[start:end:step=1]:
- 从
start开始(包含start); - 到
end结束(不包含end); - 步长为 1,依次递增取元素。
例如t = [0,1,2,3,4],t[1:4]表示从索引 1 到 3(不包含 4),结果是[1,2,3]。
反向切片的特殊逻辑(step=-1)
当步长为负数时,切片方向变为 “从后往前”,此时:
start必须 大于end(否则切片为空);
仍然遵循 “包含start,不包含end” 的规则,但方向相反。
举个栗子:
假设 t = [0, 1, 2, 3, 4](索引 0 到 4),t[3:0:-1]:
start=3:从索引 3 开始(包含 3,对应元素 3);end=0:到索引 0 结束(不包含 0,即停止在索引 1);step=-1:每次向前移动 1 位(从高索引到低索引)。
最终结果:[3, 2, 1](从索引 3 到 1,不包含 0)。
对比 t[::-1](完整反向),start 和 end 都省略了,默认:
start为最后一个索引(4);end为第一个索引的前一位(-1,等效于不包含 0 之前的位置);- 步长 - 1,从后往前取所有元素。
结果:[4,3,2,1,0](完整反转)。
8万+

被折叠的 条评论
为什么被折叠?



