PyTorch使用(4)-张量拼接操作

文章目录

  • 张量拼接操作
  • 1. torch.cat 函数的使用
      • 1.1. torch.cat 定义
      • 1.2. 语法
      • 1.3. 关键规则
    • 1.4. 示例代码
      • 1.4.1. 沿行拼接(dim=0)
      • 1.4.2. 沿列拼接(dim=1)
      • 1.4.3. 高维拼接(dim=2)
    • 1.5. 错误场景分析
      • 1.5.1. 维度数不一致
      • 1.5.2. 非拼接维度大小不匹配
      • 1.5.3. 设备或数据类型不一致
      • 1.6. 与 torch.stack 的区别
    • 1.7. 高级用法
      • 1.7.1. 批量拼接(Batch-wise Concatenation)
      • 1.7.2. 自动广播支持
    • 1.8. 总结
  • 2. torch.stack 函数的使用
    • 2.1. 函数定义
    • 2.2. 核心规则
    • 2.3. 使用示例
    • 2.4. 与 torch.cat 的对比
    • 2.4. 常见错误与调试
    • 2.5. 工程实践技巧
    • 2.7. 性能优化建议
    • 2.8. 总结

张量拼接操作

1. torch.cat 函数的使用

在 PyTorch 中,torch.cat 是用于沿指定维度拼接多个张量的核心函数

1.1. torch.cat 定义

功能: 将多个张量沿指定维度(dim)拼接,生成新张量。

输入要求:

所有输入张量的 维度数必须相同。

非拼接维度的大小必须一致。

张量必须位于 同一设备 且 数据类型相同。

1.2. 语法

torch.cat(tensors, dim=0, *, out=None) → Tensor

参数:

tensors (sequence of Tensors):需拼接的张量序列(列表或元组)。

dim (int, optional):拼接的维度索引,默认为 0。

out (Tensor, optional):可选输出张量。

1.3. 关键规则

规则示例
输入张量维度数必须相同不允许将 2D 张量与 3D 张量拼接
非拼接维度大小必须一致若 dim=1,所有张量的 dim=0、dim=2 等大小必须相同
拼接维度大小可以不同沿 dim=0 拼接形状为 (2, 3) 和 (3, 3) 的张量,结果为 (5, 3)
输出维度数与输入相同输入均为 3D 张量,输出仍为 3D 张量

1.4. 示例代码

1.4.1. 沿行拼接(dim=0)

import torch

A = torch.tensor([[1, 2], [3, 4]])    # shape: (2, 2)
B = torch.tensor([[5, 6], [7, 8]])    # shape: (2, 2)
C = torch.cat([A, B], dim=0)          # shape: (4, 2)
print(C)
# 输出:
# tensor([[1, 2],
#         [3, 4],
#         [5, 6],
#         [7, 8]])

1.4.2. 沿列拼接(dim=1)

D = torch.tensor([[9], [10]])          # shape: (2, 1)
E = torch.cat([A, D], dim=1)          # shape: (2, 3)
print(E)
# 输出:
# tensor([[ 1,  2,  9],
#         [ 3,  4, 10]])

1.4.3. 高维拼接(dim=2)

F = torch.randn(2, 3, 4)              # shape: (2, 3, 4)
G = torch.randn(2, 3, 5)              # shape: (2, 3, 5)
H = torch.cat([F, G], dim=2)          # shape: (2, 3, 9)

1.5. 错误场景分析

1.5.1. 维度数不一致

A_2D = torch.randn(2, 3)
B_3D = torch.randn(2, 3, 4)
try:
    torch.cat([A_2D, B_3D], dim=0)  # 报错:维度数不同
except RuntimeError as e:
    print("错误:", e)

1.5.2. 非拼接维度大小不匹配

A = torch.randn(2, 3)
B = torch.randn(3, 3)              # dim=0 大小不同
try:
    torch.cat([A, B], dim=1)       # 报错:非拼接维度大小不一致
except RuntimeError as e:
    print("错误:", e)

1.5.3. 设备或数据类型不一致

if torch.cuda.is_available():
    A_cpu = torch.randn(2, 3)
    B_gpu = torch.randn(2, 3).cuda()
    try:
        torch.cat([A_cpu, B_gpu], dim=0)  # 报错:设备不一致
    except RuntimeError as e:
        print("错误:", e)

1.6. 与 torch.stack 的区别

函数输入维度输出维度核心用途
torch.cat所有张量维度相同维度数与输入相同沿现有维度扩展张量
torch.stack所有张量形状严格相同新增一个维度创建新维度合并张量

示例对比:

A = torch.tensor([1, 2])          # shape: (2)
B = torch.tensor([3, 4])          # shape: (2)

# cat 沿 dim=0
C_cat = torch.cat([A, B])         # shape: (4)

# stack 沿 dim=0
C_stack = torch.stack([A, B])     # shape: (2, 2)

1.7. 高级用法

1.7.1. 批量拼接(Batch-wise Concatenation)

# 批量数据拼接(batch_size=2)
batch_A = torch.randn(2, 3, 4)    # shape: (2, 3, 4)
batch_B = torch.randn(2, 3, 5)    # shape: (2, 3, 5)
batch_C = torch.cat([batch_A, batch_B], dim=2)  # shape: (2, 3, 9)

1.7.2. 自动广播支持

torch.cat 不支持广播,必须显式匹配形状:

A = torch.randn(3, 1)            # shape: (3, 1)
B = torch.randn(1, 3)            # shape: (1, 3)
try:
    torch.cat([A, B], dim=1)     # 报错:非拼接维度大小不一致
except RuntimeError as e:
    print("错误:", e)

1.8. 总结

适用场景:合并同维度的特征、批量数据拼接等。

核心规则

1、输入张量维度数相同。

2、非拼接维度大小严格一致。

3、设备与数据类型一致。

优先使用 torch.cat:当需要在现有维度扩展时;需新增维度时选择 torch.stack。

2. torch.stack 函数的使用

2.1. 函数定义

torch.stack(tensors, dim=0, *, out=None) → Tensor

功能:将多个张量沿新维度堆叠(非拼接),要求所有输入张量形状严格相同。

  • 输入:
    • tensors (sequence of Tensors):形状相同的张量序列(列表/元组)。
    • dim (int):新维度的插入位置(支持负数索引)。
  • 输出:
    • 比输入张量多一维的新张量。

2.2. 核心规则

规则示例
输入张量形状必须完全相同(3, 4) 只能与 (3, 4) 堆叠,不能与 (3, 5) 堆叠
输出维度 = 输入维度 + 1输入(3, 4) → 输出 (n, 3, 4)(n为堆叠数量)
新维度大小 = 张量数量堆叠3个张量 → 新维度大小为3
设备/数据类型必须一致所有张量需在同一设备(CPU/GPU)且 dtype 相同

2.3. 使用示例

(1) 基础用法

import torch
# 定义两个相同形状的张量
A = torch.tensor([1, 2, 3])      # shape: (3,)
B = torch.tensor([4, 5, 6])      # shape: (3,)

# 沿新维度0堆叠
C = torch.stack([A, B])          # shape: (2, 3)
print(C)
# tensor([[1, 2, 3],
#         [4, 5, 6]])

# 沿新维度1堆叠
D = torch.stack([A, B], dim=1)   # shape: (3, 2)
print(D)
# tensor([[1, 4],
#         [2, 5],
#         [3, 6]])

(2) 高维张量堆叠

# 形状为 (2, 3) 的张量
X = torch.randn(2, 3)
Y = torch.randn(2, 3)

# 沿dim=0堆叠(新增最外层维度)
Z0 = torch.stack([X, Y])         # shape: (2, 2, 3)

# 沿dim=1堆叠(插入到第二维)
Z1 = torch.stack([X, Y], dim=1)  # shape: (2, 2, 3)

# 沿dim=-1堆叠(插入到最后一维)
Z2 = torch.stack([X, Y], dim=-1) # shape: (2, 3, 2)

(3) 批量数据构建

# 模拟批量图像数据(单张图像shape: (3, 32, 32))
image1 = torch.randn(3, 32, 32)
image2 = torch.randn(3, 32, 32)
image3 = torch.randn(3, 32, 32)

# 构建batch维度(batch_size=3)
batch = torch.stack([image1, image2, image3])  # shape: (3, 3, 32, 32)

2.4. 与 torch.cat 的对比

特性 torch.stack torch.cat
输入要求 所有张量形状严格相同 仅需非拼接维度相同
输出维度 比输入多1维 与输入维度相同
内存开销 更高(新增维度) 更低(复用现有维度)
典型场景 构建batch、新增序列维度 合并特征、扩展现有维度
示例对比:

A = torch.tensor([1, 2])
B = torch.tensor([3, 4])

# stack -> 新增维度
stacked = torch.stack([A, B])    # shape: (2, 2)

# cat -> 沿现有维度扩展
concatenated = torch.cat([A, B]) # shape: (4)

2.4. 常见错误与调试

(1) 形状不匹配

A = torch.randn(2, 3)
B = torch.randn(2, 4)  # 第二维不同
try:
    torch.stack([A, B])
except RuntimeError as e:
    print("Error:", e)  # Sizes of tensors must match

(2) 设备不一致

A_cpu = torch.randn(3, 4)
B_gpu = torch.randn(3, 4).cuda()
try:
    torch.stack([A_cpu, B_gpu])
except RuntimeError as e:
    print("Error:", e)  # Expected all tensors to be on the same device

(3) 空张量处理

empty_tensors = [torch.tensor([]) for _ in range(3)]
try:
    torch.stack(empty_tensors)  # 可能引发未定义行为
except RuntimeError as e:
    print("Error:", e)

2.5. 工程实践技巧

(1) 批量数据预处理

# 从数据加载器中逐批读取数据并堆叠
batch_images = []
for image in dataloader:
    batch_images.append(image)
    if len(batch_images) == batch_size:
        batch = torch.stack(batch_images)  # shape: (batch_size, C, H, W)
        process_batch(batch)
        batch_images = []

(2) 序列建模中的时间步堆叠

# RNN输入序列构建(T个时间步,每个步长特征dim=D)
time_steps = [torch.randn(1, D) for _ in range(T)]
input_seq = torch.stack(time_steps, dim=1)  # shape: (1, T, D)

(3) 多任务输出合并

# 多任务学习中的输出堆叠
task1_out = torch.randn(batch_size, 10)
task2_out = torch.randn(batch_size, 5)
multi_out = torch.stack([task1_out, task2_out], dim=1)  # shape: (batch_size, 2, ...)

2.7. 性能优化建议

避免循环中频繁堆叠:优先在内存中收集所有张量后一次性堆叠。

# 低效做法
result = None
for x in data_stream:
    if result is None:
        result = x.unsqueeze(0)
    else:
        result = torch.stack([result, x.unsqueeze(0)])

# 高效做法
tensor_list = [x for x in data_stream]
result = torch.stack(tensor_list)

显存不足时考虑分块处理:

chunk_size = 1000
for i in range(0, len(big_list), chunk_size):
    chunk = torch.stack(big_list[i:i+chunk_size])
    process(chunk)

2.8. 总结

核心用途:构建batch、新增维度、多任务输出整合。

关键检查点:

  • 输入张量形状完全一致。
  • 设备与数据类型统一。
  • 合理选择 dim 参数控制维度扩展位置。

优先选择场景:当需要显式创建新维度时使用;若仅需扩展现有维度,用 torch.cat 更高效。

<think>嗯,用户想知道在PyTorch中如何进行张量拼接,特别是torch.cat和torch.stack的使用方法和示例。首先,我需要回忆这两个函数的具体用法和区别。 根据引用中的资料,torch.cat用于在已有的维度上拼接张量。比如,如果两个张量的形状都是(2,3),用dim=0的话,拼接后的形状应该是(4,3),而dim=1的话则是(2,6)。这需要确保除了拼接维度外,其他维度的大小相同。用户可能需要看到具体的例子,所以应该准备一些代码示例,比如创建两个张量然后进行cat操作。 然后是torch.stack,它会在新的维度上拼接张量。比如,两个(2,3)张量用stack之后,如果dim=0的话,形状会是(2,2,3),也就是新增了一个维度。这个函数要求所有张量的形状完全相同。这时候可能需要对比cat和stack的不同,比如在同样的输入下,cat和stack的输出形状差异。引用中提到stack可以看作是增加维度,而cat是在现有维度上扩展,这点需要明确说明。 接下来,用户可能需要知道两者的区别,比如是否需要相同维度,以及结果维度的变化。例如,使用cat时,拼接后的维度是原维度之和,而stack会新增一个维度。这时候需要举例子说明,比如引用中的例子,三个(7,4)张量用stack后变成(3,7,4),而用cat的话如果在dim=0上则是(21,4),不过这可能要看具体的dim参数选择。 还要注意错误情况的处理,比如如果张量形状不匹配,使用这两个函数会报错。这时候需要提醒用户注意张量的形状是否一致,特别是在使用stack的时候必须完全相同,而cat则只需要在非拼接维度上一致。 另外,用户的问题中提到需要示例代码,所以应该准备好使用简单张量进行演示,比如创建x和y,然后展示cat和stack的不同结果。同时,要解释dim参数的影响,比如dim=0和dim=1时的不同拼接方式。 可能还需要提到其他相关的函数,比如torch.chunk和torch.split,但用户主要问的是cat和stack,所以可以简要带过,如引用中提到的cat是split和chunk的逆向操作,但重点还是放在用户的问题上。 最后,确保回答结构清晰,分点说明两个函数的使用方法、示例代码以及区别,并生成相关问题。注意引用标注的位置,比如在提到cat和stack的不同时引用对应的文献。</think>在PyTorch中,张量拼接主要通过`torch.cat`和`torch.stack`实现。以下是具体方法、示例及区别说明: --- ### 1. `torch.cat()`:沿现有维度拼接张量 **核心规则**:所有张量在非拼接维度上的形状必须一致[^1]。 - **语法**:`torch.cat(tensors, dim=0, out=None)` - **示例**: ```python import torch x = torch.tensor([[1, 2], [3, 4]]) # shape (2,2) y = torch.tensor([[5, 6], [7, 8]]) # shape (2,2) # 沿行(dim=0)拼接,结果形状 (4,2) result_cat0 = torch.cat([x, y], dim=0) print(result_cat0) # [[1,2], [3,4], [5,6], [7,8]] # 沿列(dim=1)拼接,结果形状 (2,4) result_cat1 = torch.cat([x, y], dim=1) print(result_cat1) # [[1,2,5,6], [3,4,7,8]] ``` --- ### 2. `torch.stack()`:沿新维度拼接张量 **核心规则**:所有张量形状必须完全相同[^3]。 - **语法**:`torch.stack(tensors, dim=0, out=None)` - **示例**: ```python x = torch.tensor([1, 2]) # shape (2) y = torch.tensor([3, 4]) # shape (2) # 新增第0维度,结果形状 (2,2) result_stack0 = torch.stack([x, y], dim=0) print(result_stack0) # [[1,2], [3,4]] # 新增第1维度,结果形状 (2,2) result_stack1 = torch.stack([x, y], dim=1) print(result_stack1) # [[1,3], [2,4]] ``` --- ### 3. `torch.cat`与`torch.stack`的区别 | 特性 | `torch.cat` | `torch.stack` | |---------------------|---------------------------------------|----------------------------------------| | **输入要求** | 仅在拼接维度外形状一致[^2] | 所有张量形状完全相同[^4] | | **输出维度变化** | 沿现有维度扩展(如从(2,2)(4,2)) | 新增维度(如从(2,2)(1,2,2)) | | **典型应用场景** | 合并同类数据(如多个特征矩阵拼接) | 构建批次数据或通道叠加(如RGB图像) | --- ### 示例对比 ```python a = torch.randn(2, 3) b = torch.randn(2, 3) cat_result = torch.cat([a, b], dim=0) # 形状 (4,3) stack_result = torch.stack([a, b], dim=0) # 形状 (2,2,3) ``` ---
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值