AlphaFold3 rigid_utils 模块 Rigid 类的 unsqueeze
方法是 “维度扩充” 工具,类似于 torch.unsqueeze()
,它可以 在指定维度上增加一个大小为 1 的新维度。
源代码:
def unsqueeze(self,
dim: int,
) -> Rigid:
"""
Analogous to torch.unsqueeze. The dimension is relative to the
shared dimensions of the rotation/translation.
Args:
dim: A positive or negative dimension index.
Returns:
The unsqueezed transformation.
"""
if dim >= len(self.shape):
raise ValueError("Invalid dimension")
rots = self._rots.unsqueeze(dim)
trans = self._trans.unsqueeze(dim if dim >= 0 else dim - 1)
return Rigid(rots, trans)
代码解读:
1️⃣ 输入参数
dim: int
dim
是要插入的新维度的位置,可以是正数(从头数起)也可以是负数(从末尾数起)。
2️⃣ 检查合法性
if dim >= len(self.shape):
raise ValueError("Invalid dimension")
这里确保 dim
不能超过当前 Rigid
对象的总维度数,否则就抛出 ValueError。
self.shape
代表的是批次维度的形状。
@property
def shape(self) -> torch.Size:
"""
Returns the virtual shape of the rotation object. This shape is
defined as the batch dimensions of the underlying rotation matrix
or quaternion. If the Rotation was initialized with a [10, 3, 3]
rotation matrix tensor, for example, the resulting shape would be
[10].
Returns:
The virtual shape of the rotation object
"""
s = None
if(self._quats is not None):
s = self._quats.shape[:-1]
else:
s = self._rot_mats.shape[:-2]
return s
3️⃣ 扩充旋转矩阵的维度
rots = self._rots.unsqueeze(dim)
-
self._rots
是Rotation
对象,里面存的是旋转矩阵。 -
unsqueeze(dim)
会在指定位置新增维度,比如原来是[batch_size, 3, 3]
,在dim=1
插入维度后变成[batch_size, 1, 3, 3]
。
4️⃣ 扩充平移向量的维度
trans = self._trans.unsqueeze(dim if dim >= 0 else dim - 1)
-
self._trans
是平移向量Tensor
,在对应位置插入新维度。 -
这里有个细节:如果
dim
是负数,就需要 减1,确保和正数维度的行为一致。 -
注:
-
self._rots
存的是旋转矩阵,形状类似[*batch, 3, 3]
-
self._trans
存的是平移向量,形状类似[*batch, 3]
-
旋转矩阵占了二维,而平移只占一维,所以当dim为负数时,添加维度的位置为dim-1。
使用场景:
1️⃣ 增加批次维度 (Batch Dimension)
在训练蛋白质模型时,需要把单个结构扩充成批次数据输入网络:
rigid_obj = rigid_obj.unsqueeze(0) # 插入batch维度
2️⃣ 多模型联合推理
比如将多条蛋白质链组合成复合体,每条链有自己的 Rigid
变换,可以扩充维度用来管理多链数据:
multichain_rigid = rigid_obj.unsqueeze(1) # 扩充链维度
3️⃣ 扩充时间维度 (Trajectory/MD)
-
如果想在时间序列里追踪蛋白质结构变化,比如分子动力学 (MD) 模拟,可以扩充时间维度:
trajectory_rigid = rigid_obj.unsqueeze(0) # 插入时间维度
总结
✅ unsqueeze()
方法扩充维度,核心功能类似于 torch.unsqueeze()
✅ 支持正负维度索引,灵活插入
✅ 扩充后的新对象仍然是 Rigid 实例,保持原有方法的兼容性
✅ 适合批次扩充、多链建模、动态追踪等场景