模块维度转换2

模块维度转换2

#把一个四维的转为三维的
import torch
x = torch.randn(10,3,32,32)
#BCHW 转BNC
x_out = x.permute(0,2,3,1)
out = x_out.flatten(start_dim=1,end_dim=2)
print(out.shape)
输入张量形状

原始输入 x 的形状为 (10, 3, 32, 32),对应 BCHW 格式:

  • B=10:Batch Size(批量大小)
  • C=3:Channels(通道数,如RGB图像)
  • H=32, W=32:Height(高度)和 Width(宽度)
步骤1:维度重排 permute(0,2,3,1)
  • 操作目标:将 BCHW 转换为 BHWC 格式。
  • 解释
    permute(0,2,3,1) 调整维度顺序为 (B, H, W, C),即:
    • 第0维保持 B=10(批量维度)
    • 第2维(H=32)和第3维(W=32)交换到中间
    • 第1维(C=3)移动到最后
  • 输出形状
    x_out.shape = (10, 32, 32, 3)  # BHWC
    
步骤2:展平空间维度 flatten(start_dim=1, end_dim=2)
  • 操作目标:将 HW 合并为一个维度 N(空间位置总数),得到 BNC 格式。
  • 解释
    • start_dim=1:从第1维(H=32)开始
    • end_dim=2:到第2维(W=32)结束
    • 合并 HW 后的长度为 32*32=1024
  • 输出形状
    out.shape = (10, 1024, 3)  # BNC
    
最终输出
print(out.shape)  # 输出: torch.Size([10, 1024, 3])

代码逻辑总结

  1. 目的:将图像张量从 BCHW 转换为 BNC 格式,适用于需要空间位置编码的任务(如Transformer模型)。
  2. 关键操作
    • permute:调整维度顺序为 BHWC
    • flatten:合并 HW 为单一维度 N
  3. 输出意义
    • B=10:保持批量不变。
    • N=1024:每个图像的空间位置总数(32x32像素)。
    • C=3:原始通道数(如RGB)。

潜在问题与改进

1. 内存连续性
  • 问题permute 操作可能导致张量内存不连续,后续操作(如矩阵乘法)可能报错。
  • 解决:在 flatten 前调用 contiguous()
    x_out = x.permute(0,2,3,1).contiguous()
    out = x_out.flatten(start_dim=1, end_dim=2)
    
2. 更简洁的实现

直接使用 viewreshape(需确保维度顺序正确):

out = x.permute(0,2,3,1).reshape(10, -1, 3)  # 效果相同
3. 应用场景
  • Transformer 输入:将图像转换为序列格式 (B, N, C),用于Vision Transformer(ViT)。
  • 特征嵌入:将空间位置与通道分离,便于后续编码(如位置编码 + 通道投影)。

可视化维度变换流程

原始输入: (B, C, H, W) = (10, 3, 32, 32)
    |
    | permute(0,2,3,1)
    ↓
中间结果: (B, H, W, C) = (10, 32, 32, 3)
    |
    | flatten(1, 2)
    ↓
最终输出: (B, N, C) = (10, 1024, 3)
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值