AlphaFold3 `data_modules` 模块中的 `OpenFoldDataset` 类是一个自定义数据集类,继承自 `torch.utils.data.Dataset`。它的作用是在训练时实现随机过滤器(stochastic filters),用于从多个不同的数据集(`OpenFoldSingleDataset` 或 `OpenFoldSingleMultimerDataset`)中进行样本选择和过滤。该类结合了随机采样与过滤机制,以提高模型的训练质量。
源代码:
class OpenFoldDataset(torch.utils.data.Dataset):
    """
    Mix several constituent datasets and apply AlphaFold's training filters.

    Samples are selected from the constituent datasets at random, so the
    length of an OpenFoldDataset is not tied to the sizes of its
    constituents. ``epoch_len`` is stored at construction; presumably it
    sets how many samples are drawn per epoch and what ``__len__`` returns
    (TODO confirm).

    Selection is performed by ``reroll()``, which ``__init__`` calls once
    unless ``_roll_at_init=False``. Presumably it can be called again to
    draw a fresh set of samples (TODO confirm).

    The static helpers implement the training filters:
    ``deterministic_train_filter`` applies the hard filters (resolution and
    single-amino-acid proportion), and ``get_stochastic_train_filter_prob``
    computes the probability used by the stochastic filter.
    """
def __init__(self,
    datasets: Union[Sequence[OpenFoldSingleDataset], Sequence[OpenFoldSingleMultimerDataset]],
    probabilities: Sequence[float],
    epoch_len: int,
    generator: Union[torch.Generator, None] = None,
    _roll_at_init: bool = True,
):
    """
    Store the constituent datasets and sampling settings, then roll once.

    Args:
        datasets: Constituent datasets to draw samples from.
        probabilities: Sampling weights, one per entry of ``datasets``
            (positional pairing assumed). The weights are consumed outside
            this method, presumably by ``reroll()``.
        epoch_len: Number of samples per epoch. It is stored for later use,
            presumably by ``reroll()`` / ``__len__`` (TODO confirm).
        generator: Optional RNG stored for seeded sampling. ``None`` is
            accepted; presumably the sampler then falls back to torch's
            global RNG (confirm in ``reroll()``).
        _roll_at_init: If True (default), call ``reroll()`` at the end of
            construction.

    Raises:
        ValueError: If ``probabilities`` and ``datasets`` differ in length.
    """
    # Fail fast on a mismatched weight vector. Without this check the
    # mismatch would only surface later, inside sampling. For example, if
    # reroll() draws dataset indices from ``probabilities`` and uses them to
    # index ``self._samples``, extra weights raise IndexError and missing
    # weights silently leave trailing datasets unsampled (confirm against
    # reroll()).
    if len(probabilities) != len(datasets):
        raise ValueError(
            f"Expected one probability per dataset, but got "
            f"{len(probabilities)} probabilities for {len(datasets)} datasets"
        )

    self.datasets = datasets
    self.probabilities = probabilities
    self.epoch_len = epoch_len
    self.generator = generator
    # Placeholder until the first roll; presumably filled in by reroll().
    self.datapoints = None
    # One sample source per constituent dataset, built by looped_samples()
    # (defined elsewhere in this class).
    self._samples = [self.looped_samples(i) for i in range(len(self.datasets))]
    if _roll_at_init:
        self.reroll()
@staticmethod
def deterministic_train_filter(
    cache_entry: Any,
    max_resolution: float = 9.,
    max_single_aa_prop: float = 0.8,
    *args, **kwargs
) -> bool:
    """
    Apply the hard (non-random) training filters to a single cache entry.

    Args:
        cache_entry: Mapping-like record. It must provide ``"seq"`` and may
            provide ``"resolution"``; a missing resolution is read as None.
        max_resolution: Threshold forwarded to ``resolution_filter``.
        max_single_aa_prop: Threshold forwarded to ``aa_count_filter``.
        *args, **kwargs: Accepted and ignored, presumably so callers can
            pass one shared argument set to several filter functions.

    Returns:
        True only if both ``resolution_filter`` and ``aa_count_filter``
        accept the entry, otherwise False.

    Raises:
        KeyError: For dict-like entries that lack a ``"seq"`` key.
    """
    entry_resolution = cache_entry.get("resolution", None)
    entry_seqs = [cache_entry["seq"]]

    # Both filters are always evaluated, in this order (no short-circuit
    # between the calls); only the final truth test short-circuits.
    passes_resolution = resolution_filter(
        resolution=entry_resolution,
        max_resolution=max_resolution,
    )
    passes_composition = aa_count_filter(
        seqs=entry_seqs,
        max_single_aa_prop=max_single_aa_prop,
    )
    return bool(passes_resolution) and bool(passes_composition)
@staticmethod
def get_stochastic_train_filter_prob(
cache_entry: Any,
*args, **kwargs
) -> float:
# Stochastic filters
probabilities = []
cluster_size = cache_entry.get("cluster_size", None)
if cluster_size is not None and cluster_size > 0:
probabilities.append(1 / cluster_size)
chain_length = len(cache_entry["seq"])