PyTorch中的索引与切片

部署运行你感兴趣的模型镜像

在 PyTorch 中,张量(torch.Tensor)的索引与切片操作和 Python 的列表、NumPy 数组类似,但针对高维张量做了扩展,支持灵活的维度访问和修改。

1. 基本索引(单元素访问)

通过下标[]访问张量中的单个元素,索引从 0 开始,支持负数索引(表示从末尾开始)。
举个栗子:

  • 1维张量
t1 = torch.tensor([1, 2, 3, 4, 5])
print(t1[0])   
# 输出: tensor(1)
print(t1[-1])  
# 输出: tensor(5)
  • 2维张量(行×列)

t2 = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
print(t2[1, 2]) 
 # 访问第2行第3列(索引1,2),输出: tensor(6)
print(t2[0][1])
 # 等价于t2[0,1],输出: tensor(2)

2. 切片(范围访问)

使用 start:end:step 语法获取张量的子区域,规则与 Python 列表一致:

  • start:起始索引(默认 0);
  • end:结束索引(不包含,默认维度长度);
  • step:步长(默认 1,负数表示反向)。

举个栗子:

  • 1维张量切片
t1 = torch.tensor([1, 2, 3, 4, 5])
print(t1[1:4])    
# 从索引1到3(不包含4),输出: tensor([2, 3, 4])
print(t1[:3])     
# 从开头到2,输出: tensor([1, 2, 3])
print(t1[::2])    
# 步长2,输出: tensor([1, 3, 5])
print(t1[::-1])  
 # 反向,输出: tensor([5, 4, 3, 2, 1])
  • 2维张量切片(行优先)
t2 = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
print(t2[0:2, 1:3])  
# 行0-1,列1-2 → [[2,3],[5,6]]
# 输出:
# tensor([[2, 3],
#         [5, 6]])

print(t2[:, 0])  
# 所有行,第0列 → [1,4,7]
# 输出: tensor([1, 4, 7])

3. 高维张量索引

对于 3 维及以上张量(如 (深度, 行, 列)(批次, 通道, 高, 宽)),通过逗号分隔各维度的索引 / 切片。
举个栗子::

  • 3 维张量
t3 = torch.tensor([
    [[1, 2], [3, 4]],    # 第0个深度
    [[5, 6], [7, 8]],    # 第1个深度
    [[9, 10], [11, 12]]  # 第2个深度
])

# 访问第1个深度的第0行第1列
print(t3[1, 0, 1])  
# 输出: tensor(6)

# 切片:所有深度,第0行,所有列
print(t3[:, 0, :])
# 输出:
# tensor([[ 1,  2],
#         [ 5,  6],
#         [ 9, 10]])

4. 高级索引

除了基本切片,PyTorch 还支持 整数数组索引布尔索引,用于非连续元素的访问。

(1)整数数组索引(按索引列表取元素)

通过列表或张量指定要访问的索引,返回与索引列表形状一致的张量。

t2 = torch.tensor([
    [1, 2, 3],   # 行索引 0
    [4, 5, 6],   # 行索引 1
    [7, 8, 9]    # 行索引 2
])
indices = torch.tensor([[0, 1], [2, 0]])  # 每行是一组坐标 (行,列)
#输出:tensor([[0, 1],
#             [2, 0]])
print(t2[indices[:, 0], indices[:, 1]])  
# 输出: tensor([2, 7])

这里用了两个索引数组:

  • indices[:, 0]:取 indices 所有行的第 0 列 → 结果是 [0, 2](行索引列表
  • indices[:, 1]:取 indices 所有行的第 1 列 → 结果是 [1, 0](列索引列表

PyTorch 会将两个索引数组按相同位置的元素配对,即:

  • 第 1 对:行索引 = 0,列索引 = 1 → 对应 t2[0, 1] → 值为 2
  • 第 2 对:行索引 = 2,列索引 = 0 → 对应 t2[2, 0] → 值为 7
简化理解

indices[:, 0]indices[:, 1] 看作两个 “并行列表”,分别存储行索引和列索引,然后按顺序一一对应取值:

行索引列表:[0, 2]
列索引列表:[1, 0]
→ 取 t2[0,1] 和 t2[2,0] → [2, 7]

适合从张量中提取 非连续的、任意位置的多个元素,只需按顺序列出它们的坐标即可。

t2 = torch.tensor([
    [1, 2, 3],   # 行索引 0
    [4, 5, 6],   # 行索引 1
    [7, 8, 9]    # 行索引 2
])
# 取第0行的第1和2列
print(t2[0, [1, 2]])  
# 输出: tensor([2, 3])
t2[0, [1, 2]]
  • 第一个索引 0:指定要访问的 行 → 第 0 行(即 [1, 2, 3])。
  • 第二个索引 [1, 2]:指定要访问的 列索引列表 → 第 1 列和第 2 列。
    #####取值过程
    第 0 行的第 1 列:t2[0, 1] → 值为 2
    第 0 行的第 2 列:t2[0, 2] → 值为 3

2)布尔索引(按条件筛选元素)

通过与原张量形状相同的布尔张量,筛选出 True 位置的元素。

t2 = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])

# 筛选所有大于5的元素
mask = t2 > 5
print(t2[mask])  # 输出: tensor([6, 7, 8, 9])

# 结合切片使用(第0行中大于1的元素)
print(t2[0, t2[0] > 1])  # 输出: tensor([2, 3])

注意事项

视图(View)与副本(Copy):

  • 基本切片(如 t[:, :2])返回的是原张量的视图(view),修改切片会影响原张量;
  • 高级索引(如整数数组、布尔索引)返回的是副本(copy),修改不会影响原张量。
t = torch.tensor([1, 2, 3])
slice_view = t[:2]
slice_view[0] = 0
print(t)  # 输出: tensor([0, 2, 3])(原张量被修改)

mask = t > 1
slice_copy = t[mask]
slice_copy[0] = 100
print(t)  # 输出: tensor([0, 2, 3])(原张量未修改)
  • PyTorch 中,对张量进行基本切片(如 t[:2]t[1:4] 等连续范围的切片)时,不会复制原张量的数据,而是创建一个 “视图”—— 视图仅记录对原张量的索引范围(如 “从 0 到 2 的元素”),实际数据仍存储在原张量的内存中。
  • 因此,slice_view 和 t 指向同一块内存的不同区域,修改 slice_view 相当于修改原内存中的数据,原张量 t 自然会受到影响。
  • 只有当使用 高级索引(如整数数组索引、布尔索引)时,PyTorch 才会创建原张量的副本(复制数据到新内存),此时修改副本不会影响原张量。

维度保留

若需保留切片后的维度(如从 2 维张量中取一行仍保持 2 维),可使用 None 或 np.newaxis 增加维度:

t2 = torch.tensor([[1, 2, 3], [4, 5, 6]])
row = t2[0, :]          # 形状 (3,)(1维)
row_keep_dim = t2[0:1, :]  # 形状 (1, 3)(2维,保留行维度)
第一种情况:row = t2[0, :] → 形状 (3,)(1 维)
  • 原本 t2 有 “行、列” 两个维度,当你明确指定 “第 0 行” 后,“行” 这个维度就消失了,只剩下 “列” 维度对应的 3 个元素,所以最终形状是 (3,)(1 维张量,只有一个维度)。
  • 可以理解为:把 “第 0 行” 从 2 维表格里 “抽出来”,变成了一个 1 维的 “列表” [1, 2, 3]。
二种情况:row_keep_dim = t2[0:1, :] → 形状 (1, 3)(2 维)

用范围切片(如 0:1)访问某一维度时,会 “保留” 这个维度。

  • 虽然 0:1 最终只取了 1 行,但它明确表示 “这是一个包含 1 行的行维度”,所以 t2 原本的 “行、列” 两个维度都被保留了,只是行维度的长度从 2 变成了 1,最终形状是 (1, 3)(2 维张量,仍有 “行、列” 结构)。
  • 可以理解为:把 “第 0 行” 从 2 维表格里 “切出来一小块”,这块仍然是一个 2 维的 “小表格”,只是只有 1 行 3 列:[[1, 2, 3]]。

负数步长的边界

反向切片(如 t[::-1])中,startend 需注意边界,例如 t[3:0:-1] 表示从索引 3 到 1(不包含 0)。

先明确正向切片的基础规则

对于正向切片 t[start:end:step=1]

  • start 开始(包含 start);
  • end 结束(不包含 end);
  • 步长为 1,依次递增取元素。
    例如 t = [0,1,2,3,4]t[1:4] 表示从索引 1 到 3(不包含 4),结果是 [1,2,3]
反向切片的特殊逻辑(step=-1)

当步长为负数时,切片方向变为 “从后往前”,此时:

  • start 必须 大于 end(否则切片为空);
    仍然遵循 “包含 start,不包含 end” 的规则,但方向相反。

举个栗子:
假设 t = [0, 1, 2, 3, 4](索引 0 到 4),t[3:0:-1]:

  • start=3:从索引 3 开始(包含 3,对应元素 3);
  • end=0:到索引 0 结束(不包含 0,即停止在索引 1);
  • step=-1:每次向前移动 1 位(从高索引到低索引)。

最终结果:[3, 2, 1](从索引 3 到 1,不包含 0)。

对比 t[::-1](完整反向),startend 都省略了,默认:

  • start 为最后一个索引(4);
  • end 为第一个索引的前一位(-1,等效于不包含 0 之前的位置);
  • 步长 - 1,从后往前取所有元素。

结果:[4,3,2,1,0](完整反转)。

您可能感兴趣的与本文相关的镜像

PyTorch 2.5

PyTorch 2.5

PyTorch
Cuda

PyTorch 是一个开源的 Python 机器学习库,基于 Torch 库,底层由 C++ 实现,应用于人工智能领域,如计算机视觉和自然语言处理

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值