torch.stack(
(pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
).flatten(3)
用一个具体的实例来帮助你理解 (batch, seq, hidden, n//2, 2)
的操作。假设我们有以下数据和操作:
1. 数据准备
假设 pos_x
的形状是 (1, 1, 2, 4)
,内容如下:
pos_x = [
[
[
[1, 2, 3, 4], # hidden[0]
[5, 6, 7, 8] # hidden[1]
]
]
]
2. 奇偶切片
pos_x[:, :, :, 0::2]
(偶数索引):
取最后一维的偶数索引[0, 2]
:
pos_x[:, :, :, 0::2] = [
[
[
[1, 3], # hidden[0]
[5, 7] # hidden[1]
]
]
]
pos_x[:, :, :, 1::2]
(奇数索引):
取最后一维的奇数索引[1, 3]
:
pos_x[:, :, :, 1::2] = [
[
[
[2, 4], # hidden[0]
[6, 8] # hidden[1]
]
]
]
3. torch.stack
操作
使用 torch.stack
将奇数索引和偶数索引的结果沿新维度 dim=4
拼接:
stacked = torch.stack((pos_x[:, :, :, 0::2], pos_x[:, :, :, 1::2]), dim=4)
结果形状:(1, 1, 2, 2, 2)
内容如下:
stacked = [
[
[
[
[1, 2], # 偶数索引的第一个值 + 奇数索引的第一个值
[3, 4] # 偶数索引的第二个值 + 奇数索引的第二个值
],
[
[5, 6], # hidden[1] 的偶数和奇数
[7, 8]
]
]
]
]
4. flatten
操作
flatten(3)
展平第 3 维和第 4 维((2, 2)
合并成 4
):
flattened = stacked.flatten(3)
结果形状:(1, 1, 2, 4)
内容如下:
flattened = [
[
[
[1, 2, 3, 4], # hidden[0]
[5, 6, 7, 8] # hidden[1]
]
]
]
总结
stack
将奇偶分离的数据重新组合,形成一个新维度区分奇偶。flatten
按存储顺序展平新维度,将奇偶数据合并回原来的顺序。最终数据顺序为:[1, 2, 3, 4, 5, 6, 7, 8]
与原始 pos_x
相同。