【AI | pytorch】torch.stack操作

torch.stack(
    (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
).flatten(3)

用一个具体的实例来帮助你理解 (batch, seq, hidden, n//2, 2) 的操作。假设我们有以下数据和操作:


1. 数据准备

假设 pos_x 的形状是 (1, 1, 2, 4),内容如下:

pos_x = [
    [
        [
            [1, 2, 3, 4],  # hidden[0]
            [5, 6, 7, 8]   # hidden[1]
        ]
    ]
]

2. 奇偶切片

  • pos_x[:, :, :, 0::2] (偶数索引)
    取最后一维的偶数索引 [0, 2]
pos_x[:, :, :, 0::2] = [
    [
        [
            [1, 3],  # hidden[0]
            [5, 7]   # hidden[1]
        ]
    ]
]
  • pos_x[:, :, :, 1::2] (奇数索引)
    取最后一维的奇数索引 [1, 3]
pos_x[:, :, :, 1::2] = [
    [
        [
            [2, 4],  # hidden[0]
            [6, 8]   # hidden[1]
        ]
    ]
]

3. torch.stack 操作

使用 torch.stack 将奇数索引和偶数索引的结果沿新维度 dim=4 拼接:

stacked = torch.stack((pos_x[:, :, :, 0::2], pos_x[:, :, :, 1::2]), dim=4)

结果形状:(1, 1, 2, 2, 2)
内容如下:

stacked = [
    [
        [
            [
                [1, 2],  # 偶数索引的第一个值 + 奇数索引的第一个值
                [3, 4]   # 偶数索引的第二个值 + 奇数索引的第二个值
            ],
            [
                [5, 6],  # hidden[1] 的偶数和奇数
                [7, 8]
            ]
        ]
    ]
]

4. flatten 操作

flatten(3) 展平第 3 维和第 4 维((2, 2) 合并成 4):

flattened = stacked.flatten(3)

结果形状:(1, 1, 2, 4)
内容如下:

flattened = [
    [
        [
            [1, 2, 3, 4],  # hidden[0]
            [5, 6, 7, 8]   # hidden[1]
        ]
    ]
]

总结

  1. stack 将奇偶分离的数据重新组合,形成一个新维度区分奇偶。
  2. flatten 按存储顺序展平新维度,将奇偶数据合并回原来的顺序。最终数据顺序为:
    [1, 2, 3, 4, 5, 6, 7, 8]
    

与原始 pos_x 相同。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值