在使用 PyTorch 进行深度学习时,张量形状与广播机制常常是让初学者感到困惑的地方。我们需要时常面对多维张量,并在批量、通道、空间位置等多个维度之间做运算。
如果能熟练掌握各种维度变换操作——包括 unsqueeze
、expand
、view
/reshape
、transpose
/permute
等,可以帮助我们灵活地操纵张量,写出高效而简洁的矩阵化(vectorized)代码。本文将重点聚焦于 expand
以及与之密切相关的“维度扩展”技巧和其底层原理。
一、为什么需要 expand
?
许多同学初次接触 PyTorch 的广播(broadcasting)时,往往只知道 +
、*
等运算符会自动广播。确实,在能自动广播的场合,可能不一定要手动去调用 expand
。但实际开发中,有些场景我们想精确控制哪一个维度被复制、如何被复制,就会手动用到 expand
:
- 构造大维度掩码(mask)。例如,需要对
(B, K, H*W)
的张量,在“anchor 样本 b”与“被遍历样本 b2”两大维度上进行运算,那么就会创建(B, B, K, H*W)
的大张量。如果仅靠自动广播&#