AF3 FeaturePipeline类解读

AlphaFold3 的 feature_pipeline 模块中,FeaturePipeline 类是一个封装类,它通过调用函数 np_example_to_features 来实现整个数据处理流程。

源代码:

def np_to_tensor_dict(
    np_example: Mapping[str, np.ndarray],
    features: Sequence[str],
) -> TensorDict:
    """Creates dict of tensors from a dict of NumPy arrays.

    Args:
        np_example: A dict of NumPy feature arrays. Values that are already
            torch Tensors (including Tensor subclasses) are accepted as well
            and are copied rather than re-wrapped.
        features: A list of strings of feature names to be returned in the dataset.

    Returns:
        A dictionary of features mapping feature names to features. Only the given
        features are returned, all other ones are filtered out. Every returned
        tensor is a fresh copy that is detached from any autograd graph.
    """
    # Hash the requested names once so each membership test is O(1) instead of
    # a linear scan over `features` for every key in `np_example`.
    wanted = frozenset(features)

    tensor_dict = {}
    for name, value in np_example.items():
        if name not in wanted:
            continue
        if isinstance(value, torch.Tensor):
            # torch.tensor() emits a warning when handed an existing Tensor.
            # isinstance (rather than an exact type comparison) also routes
            # Tensor subclasses such as nn.Parameter down this copy path.
            tensor_dict[name] = value.clone().detach()
        else:
            tensor_dict[name] = torch.tensor(value)

    return tensor_dict


def make_data_config(
    config: ml_collections.ConfigDict,
    mode: str,
    num_res: int,
) -> Tuple[ml_collections.ConfigDict, List[str]]:
    """Build a private copy of the data config plus the feature names to keep.

    Args:
        config: Data configuration. It is deep-copied first, so the object
            passed in is never modified.
        mode: Key of the mode-specific sub-config inside `config`
            (e.g. "train").
        num_res: Value written to the mode's `crop_size` when that setting
            is None.

    Returns:
        A `(cfg, feature_names)` tuple: the adjusted copy of the config and
        the list of feature names selected by its flags.
    """
    config_copy = copy.deepcopy(config)
    mode_settings = config_copy[mode]

    # An unset crop size falls back to `num_res`.
    if mode_settings.crop_size is None:
        mode_settings.crop_size = num_res

    common = config_copy.common

    # NOTE: `selected` is the copied config's own `unsupervised_features`
    # object; each `+=` below extends it in place, so the returned config
    # carries the extended list as well.
    selected = common.unsupervised_features

    # Sequence-embedding features are only added in seqemb mode.
    if config_copy.seqemb_mode.enabled:
        selected += common.seqemb_features

    if common.use_templates:
        selected += common.template_features

    if mode_settings.supervised:
        selected += config_copy.supervised.supervised_features

    return config_copy, selected


def np_example_to_features(
    np_example: FeatureDict,
    config: ml_collections.ConfigDict,
    mode: str,
    is_multimer: bool = False
):
    np_example = dict(np_example)

    seq_length = np_example["seq_length"]
    num_res = int(seq_length[0]) if seq_length.ndim != 0 else int(seq_length)
    cfg, feature_names = make_data_config(config, mode=mode, num_res=num_res)
 
    if "deletion_matrix_int" in np_example:
        np_example["deletion_matrix"] = np_example.pop(
            "deletion_matrix_int"
        ).astype(np.float32)

    tensor_dict = np_to_tensor_dict(
        np_example=np_example, features=feature_names
    )

    with torch.no_grad():
        if is_multimer:
            features = input_pipeline_multimer.process_tensors_from_config(
                tensor_dict,
                cfg.common,
                cfg[mode],
            )
        else:
            features = input_pipeline.process_tensors_from_config(
                tensor_dict,
                cfg.common,
                cfg[mode],
            )

    if mode == "train":
        p = torch.rand(1).item()
        use_clamped_fape_value 
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值