AlphaFold3 data_modules 模块的 OpenFoldMultimerDataset
类deterministic_train_filter
方法是该类中最关键的一个静态方法之一。这个方法用于 训练时对样本进行确定性过滤(hard filters),只有通过这些过滤条件的样本才可能参与后续训练。
源代码:
@staticmethod
def deterministic_train_filter(
cache_entry: Any,
is_distillation: bool,
max_resolution: float = 9.,
max_single_aa_prop: float = 0.8,
minimum_number_of_residues: int = 200,
*args, **kwargs
) -> bool:
"""
Implement multimer training filtering criteria described in
https://www.biorxiv.org/content/10.1101/2021.10.04.463034v2.full.pdf Supplementary section 7.1
"""
resolution = cache_entry.get("resolution", None)
seqs = cache_entry["seqs"]
return all([resolution_filter(resolution=resolution,
max_resolution=max_resolution),
aa_count_filter(seqs=seqs,
max_single_aa_prop=max_single_aa_prop),
(not is_distillation or all_seq_len_filter(seqs=seqs,
minimum_number_of_residues=minimum_number_of_residues))])
代码解读:
参数解释:
-
cache_entry
: 一个 dict,包含蛋白的缓存信息(如序列、分辨率等)。 -
is_distillation
: 表示这个数据是否是从结构数据库中**蒸馏(distilled)**出来的。如果是,则更严格地限制序列长度。 -
max_resolution
: 最大允许的分辨率(默认为 9Å)。 -
max_single_aa_prop
: 允许出现最多的氨基酸在序列中所占的最大比例(默认 0.8)。 -
minimum_number_of_residues
: 蒸馏样本中每条链至少要有的残基数(默认 200 个)。
过滤逻辑:
该方法使用了三个过滤器函数,要求都必须通过(all([...])
):
条件 | 使用的函数 | 含义 |
---|---|---|
分辨率过滤 | resolution_filter | 分辨率必须小于等于 9Å(或指定值) |
氨基酸单一比例过滤 | aa_count_filter | 不能有某种氨基酸比例 > 0.8 |
序列长度过滤(可选) | all_seq_len_filter | 如果是蒸馏数据,则所有链总长度 ≥ 200 |