PyTorch使用(5)-张量索引操作

文章目录

  • 1. 简单行、列索引
    • 1.1. 基础用法
    • 1.2. 工程实践要点
  • 2. 列表索引
    • 2.1. 基础用法
    • 2.2. 高级用法
    • 2.3. 性能考虑
  • 3. 范围索引
    • 3.1. 基础用法
    • 3.2. 高级技巧
    • 3.3. 内存特性
  • 4. 布尔索引
    • 4.1. 基础用法
    • 4.2. 高级用法
    • 4.3. 性能注意事项
  • 5. 多维索引
    • 5.1. 基础用法
    • 5.2. 高级模式
    • 5.3. 工程实践
  • 6. 综合性能比较
  • 7. 最佳实践建议

1. 简单行、列索引

简单的行、列索引是最基本的索引操作,通过整数来访问张量中的元素。可以使用类似数组索引的方式来操作。

1.1. 基础用法

import torch

# 创建一个3x4的矩阵
x = torch.tensor([[1, 2, 3, 4],
                 [5, 6, 7, 8],
                 [9, 10, 11, 12]])

# 获取第2行(索引从0开始)
row = x[1]  # tensor([5, 6, 7, 8])

# 获取第3列
col = x[:, 2]  # tensor([3, 7, 11])

1.2. 工程实践要点

内存视图:简单索引返回的是原张量的视图,不复制数据

性能:O(1)时间复杂度,是最快的索引方式

GPU兼容:在CUDA张量上同样高效

# 获取连续多行/多列
rows = x[1:3]  # 第2-3行
cols = x[:, 1:3]  # 第2-3列

2. 列表索引

列表索引是通过一个列表或数组来选择张量中的多个元素。这种索引方式可以选择多个位置的元素,并返回一个新的张量。

2.1. 基础用法

# 使用列表选择特定行
selected_rows = x[[0, 2]]  # 第1和第3行

# 使用列表选择特定列
selected_cols = x[:, [1, 3]]  # 第2和第4列

# 选择特定元素
elements = x[[0, 1, 2], [1, 2, 3]]  # (0,1), (1,2), (2,3)位置的元素

2.2. 高级用法

# 创建索引张量(比Python列表更高效)
indices = torch.tensor([0, 2], device=x.device)
selected = x[indices]  # 第1和第3行

# 组合行列索引
x[[[0], [2]], [1, 3]]  # 第1/3行的第2/4列 → 2x2矩阵

2.3. 性能考虑

内存开销:列表索引会创建新张量,复制数据

替代方案:对于连续索引,优先使用切片

GPU优化:将索引张量放在与数据相同的设备上

3. 范围索引

范围索引允许你选择张量的一个切片,类似于 Python 列表的切片操作。通过起始索引和结束索引来选择一段连续的元素

3.1. 基础用法

# 基本范围切片
sub_matrix = x[0:2, 1:3]  # 第1-2行,第2-3列

# 带步长的范围切片
every_other = x[::2, ::3]  # 每隔一行/三列选取

# 反向索引
reversed_rows = x[::-1]  # 行顺序反转

3.2. 高级技巧

# 创建范围索引张量
range_idx = torch.arange(1, 3)  # tensor([1, 2])
selected = x[range_idx]  # 第2-3行

# 结合步长和偏移
strided = x[1::2, ::2]  # 从第2行开始每隔一行,所有列每隔一个

3.3. 内存特性

连续范围:返回视图,不复制数据

非连续范围:可能触发拷贝操作

最佳实践:尽量使用基础切片而非arange创建的范围

4. 布尔索引

布尔索引是根据条件来选择张量中的元素。它使用一个布尔数组或条件表达式来判断哪些元素符合条件,从而选择它们。

4.1. 基础用法

# 创建布尔掩码
mask = x > 5
# tensor([[False, False, False, False],
#        [False,  True,  True,  True],
#        [ True,  True,  True,  True]])

# 应用布尔索引
selected = x[mask]  # tensor([6, 7, 8, 9, 10, 11, 12])

# 条件赋值
x[x % 2 == 0] = 0  # 将所有偶数置0

4.2. 高级用法

# 多条件组合
mask = (x > 3) & (x < 9)
selected = x[mask]

# 按行/列条件索引
row_mask = torch.any(x > 10, dim=1)  # 选择包含大于10的元素的整行
selected_rows = x[row_mask]

4.3. 性能注意事项

掩码创建:布尔操作会创建临时张量

内存占用:大张量的布尔掩码会消耗大量内存

GPU优势:布尔索引在CUDA上并行化效果极佳

5. 多维索引

多维索引可以是混合多种索引方式,包括整数索引、切片索引、布尔索引等。它让你能够根据复杂的条件或结构对张量进行切片和访问

5.1. 基础用法

# 创建3D张量
y = torch.randn(2, 3, 4)  # batch=2, seq_len=3, features=4

# 各维度单独索引
elem = y[1, 2, 3]  # 第2个batch,第3个序列,第4个特征

# 混合索引方式
sub_tensor = y[1, :, [0, 2]]  # 第2个batch,所有序列,第1和3个特征

5.2. 高级模式

# 使用Ellipsis(...)简化索引
first_batch_all_features = y[0, ...]  # 等价于 y[0, :, :]

# 使用None增加维度
expanded = y[:, None, :, :]  # 在第二维增加一个维度

# 跨维度索引
diag = y.diagonal(dim1=1, dim2=2)  # 获取每个batch的特征对角线

5.3. 工程实践

维度顺序:注意PyTorch的通道优先约定(N, C, H, W)

广播机制:了解索引操作中的广播规则

视图与拷贝:复杂索引可能触发意外拷贝

6. 综合性能比较

操作类型返回视图内存效率GPU加速比适用场景
简单索引10-100x常规子矩阵提取
列表索引5-20x非连续元素选择
范围索引通常10-50x连续区块操作
布尔索引20-100x条件筛选
多维索引有时不定10-50x高维数据操作

使用总结

  • 简单行列索引:基础的整数索引,用来访问单个元素。
  • 列表索引:通过提供一个索引列表或数组来选择多个元素。
  • 范围索引:通过切片来选择张量的一个区间。
  • 布尔索引:通过布尔条件来选择符合条件的元素。
  • 多维索引:通过混合使用不同的索引方式,进行复杂的索引操作。

7. 最佳实践建议

优先使用简单索引:性能最佳,内存最友好

避免频繁的小规模索引:合并多个操作为一个

注意设备一致性:索引张量应与数据在同一设备

利用原地操作:对于大张量修改,使用_后缀方法

预分配内存:对于已知大小的结果,先创建目标张量

# 高效索引操作示例
def efficient_indexing(x, row_indices, col_indices):
    # 预分配结果张量
    result = torch.empty(len(row_indices), 
                        len(col_indices),
                        device=x.device)
    
    # 批量索引操作
    torch.index_select(x, 0, row_indices, out=result)
    torch.index_select(result, 1, col_indices, out=result)
    
    return result
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值