AF3 Rigid类unsqueeze方法解读

AlphaFold3 rigid_utils 模块 Rigid 类的 unsqueeze 方法是 “维度扩充” 工具,类似于 torch.unsqueeze(),它可以 在指定维度上增加一个大小为 1 的新维度

源代码:

 def unsqueeze(self, 
        dim: int,
    ) -> Rigid:
        """
            Analogous to torch.unsqueeze. The dimension is relative to the
            shared dimensions of the rotation/translation.
            
            Args:
                dim: A positive or negative dimension index.
            Returns:
                The unsqueezed transformation.
        """
        if dim >= len(self.shape):
            raise ValueError("Invalid dimension")
        rots = self._rots.unsqueeze(dim)
        trans = self._trans.unsqueeze(dim if dim >= 0 else dim - 1)

        return Rigid(rots, trans)

代码解读:

1️⃣ 输入参数
dim: int

dim 是要插入的新维度的位置,可以是正数(从头数起)也可以是负数(从末尾数起)。

2️⃣ 检查合法性
if dim >= len(self.shape):
    raise ValueError("Invalid dimension")

这里确保 dim 不能超过当前 Rigid 对象的总维度数,否则就抛出 ValueError

self.shape 代表的是批次维度的形状。

@property
    def shape(self) -> torch.Size:
        """
            Returns the virtual shape of the rotation object. This shape is
            defined as the batch dimensions of the underlying rotation matrix
            or quaternion. If the Rotation was initialized with a [10, 3, 3]
            rotation matrix tensor, for example, the resulting shape would be
            [10].
        
            Returns:
                The virtual shape of the rotation object
        """
        s = None
        if(self._quats is not None):
            s = self._quats.shape[:-1]
        else:
            s = self._rot_mats.shape[:-2]

        return s

3️⃣ 扩充旋转矩阵的维度

rots = self._rots.unsqueeze(dim)
  • self._rots 是 Rotation 对象,里面存的是旋转矩阵。

  • unsqueeze(dim) 会在指定位置新增维度,比如原来是 [batch_size, 3, 3],在 dim=1 插入维度后变成 [batch_size, 1, 3, 3]

4️⃣ 扩充平移向量的维度
trans = self._trans.unsqueeze(dim if dim >= 0 else dim - 1)
  • self._trans 是平移向量 Tensor,在对应位置插入新维度。

  • 这里有个细节:如果 dim 是负数,就需要 减1,确保和正数维度的行为一致。

  • 注: 

  • self._rots 存的是旋转矩阵,形状类似 [*batch, 3, 3]

  • self._trans 存的是平移向量,形状类似 [*batch, 3]

  • 旋转矩阵占了二维,而平移只占一维,所以当dim为负数时,添加维度的位置为dim-1。

使用场景:

1️⃣ 增加批次维度 (Batch Dimension)

在训练蛋白质模型时,需要把单个结构扩充成批次数据输入网络:

rigid_obj = rigid_obj.unsqueeze(0)  # 插入batch维度
2️⃣ 多模型联合推理

比如将多条蛋白质链组合成复合体,每条链有自己的 Rigid 变换,可以扩充维度用来管理多链数据:

multichain_rigid = rigid_obj.unsqueeze(1)  # 扩充链维度
3️⃣ 扩充时间维度 (Trajectory/MD)
  • 如果想在时间序列里追踪蛋白质结构变化,比如分子动力学 (MD) 模拟,可以扩充时间维度:

trajectory_rigid = rigid_obj.unsqueeze(0)  # 插入时间维度

总结

✅ unsqueeze() 方法扩充维度,核心功能类似于 torch.unsqueeze()
✅ 支持正负维度索引,灵活插入
✅ 扩充后的新对象仍然是 Rigid 实例,保持原有方法的兼容性
✅ 适合批次扩充、多链建模、动态追踪等场景

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值