模块维度转换2
#把一个四维的转为三维的
import torch
x = torch.randn(10,3,32,32)
#BCHW 转BNC
x_out = x.permute(0,2,3,1)
out = x_out.flatten(start_dim=1,end_dim=2)
print(out.shape)
输入张量形状
原始输入 x
的形状为 (10, 3, 32, 32)
,对应 BCHW 格式:
B=10
:Batch Size(批量大小)C=3
:Channels(通道数,如RGB图像)H=32
,W=32
:Height(高度)和 Width(宽度)
步骤1:维度重排 permute(0,2,3,1)
- 操作目标:将
BCHW
转换为BHWC
格式。 - 解释:
permute(0,2,3,1)
调整维度顺序为(B, H, W, C)
,即:- 第0维保持
B=10
(批量维度) - 第2维(
H=32
)和第3维(W=32
)交换到中间 - 第1维(
C=3
)移动到最后
- 第0维保持
- 输出形状:
x_out.shape = (10, 32, 32, 3) # BHWC
步骤2:展平空间维度 flatten(start_dim=1, end_dim=2)
- 操作目标:将
H
和W
合并为一个维度N
(空间位置总数),得到BNC
格式。 - 解释:
start_dim=1
:从第1维(H=32
)开始end_dim=2
:到第2维(W=32
)结束- 合并
H
和W
后的长度为32*32=1024
- 输出形状:
out.shape = (10, 1024, 3) # BNC
最终输出
print(out.shape) # 输出: torch.Size([10, 1024, 3])
代码逻辑总结
- 目的:将图像张量从 BCHW 转换为 BNC 格式,适用于需要空间位置编码的任务(如Transformer模型)。
- 关键操作:
permute
:调整维度顺序为BHWC
。flatten
:合并H
和W
为单一维度N
。
- 输出意义:
B=10
:保持批量不变。N=1024
:每个图像的空间位置总数(32x32像素)。C=3
:原始通道数(如RGB)。
潜在问题与改进
1. 内存连续性
- 问题:
permute
操作可能导致张量内存不连续,后续操作(如矩阵乘法)可能报错。 - 解决:在
flatten
前调用contiguous()
:x_out = x.permute(0,2,3,1).contiguous() out = x_out.flatten(start_dim=1, end_dim=2)
2. 更简洁的实现
直接使用 view
或 reshape
(需确保维度顺序正确):
out = x.permute(0,2,3,1).reshape(10, -1, 3) # 效果相同
3. 应用场景
- Transformer 输入:将图像转换为序列格式
(B, N, C)
,用于Vision Transformer(ViT)。 - 特征嵌入:将空间位置与通道分离,便于后续编码(如位置编码 + 通道投影)。
可视化维度变换流程
原始输入: (B, C, H, W) = (10, 3, 32, 32)
|
| permute(0,2,3,1)
↓
中间结果: (B, H, W, C) = (10, 32, 32, 3)
|
| flatten(1, 2)
↓
最终输出: (B, N, C) = (10, 1024, 3)