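The looped_samples method of the OpenFoldMultimerDataset class in the AlphaFold3 data_modules module is a generator. It endlessly yields the indices of valid samples from one sub-dataset (selected by dataset_idx), after those samples have passed filtering. Training uses these indices to draw samples.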
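The mechanism described above can be illustrated with a small, dependency-free sketch. It is not the actual OpenFold/AlphaFold3 code. The names looped_shuffled_indices, looped_samples_sketch, chain_probs_per_entry and passes_filter are hypothetical stand-ins for looped_shuffled_dataset_idx, get_stochastic_train_filter_prob and deterministic_train_filter. Python's random module replaces torch.multinomial. The substitution works because drawing one sample from a weight row [1 - p, p] is just a Bernoulli trial that accepts with probability p.

```python
import random
from typing import Callable, Iterator, List, Sequence


def looped_shuffled_indices(n: int, rng: random.Random) -> Iterator[int]:
    """Yield 0..n-1 in a fresh random order, forever (one shuffle per pass)."""
    while True:
        order = list(range(n))
        rng.shuffle(order)
        yield from order


def looped_samples_sketch(
    chain_probs_per_entry: Sequence[List[float]],
    passes_filter: Callable[[int], bool],
    max_cache_len: int,
    rng: random.Random,
) -> Iterator[int]:
    """Hypothetical stand-in for looped_samples.

    chain_probs_per_entry[i] holds one acceptance probability per chain of
    entry i (the role of get_stochastic_train_filter_prob); passes_filter
    plays the role of deterministic_train_filter.
    """
    idx_iter = looped_shuffled_indices(len(chain_probs_per_entry), rng)
    while True:
        weights: List[float] = []
        idx: List[int] = []
        # Draw max_cache_len candidates; skip those failing the hard filter.
        for _ in range(max_cache_len):
            candidate = next(idx_iter)
            if not passes_filter(candidate):
                continue
            probs = chain_probs_per_entry[candidate]
            # One acceptance trial per chain, so the index is repeated per chain.
            weights.extend(probs)
            idx.extend([candidate] * len(probs))
        # A two-way draw over [1 - p, p] is a Bernoulli(p) trial.
        for i, p in zip(idx, weights):
            if rng.random() < p:
                yield i
```

When max_cache_len equals the dataset size, each refill consumes exactly one shuffled pass. An entry with two always-accepted chains is then yielded twice per pass, and an entry whose chains all have p = 0 is never yielded. If no entry can ever pass both filters, the sketch loops forever without yielding anything.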
Source code:
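import torch

def looped_samples(self, dataset_idx):
    # Candidates drawn per refill: epoch length times this sub-dataset's sampling probability
    max_cache_len = int(self.epoch_len * self.probabilities[dataset_idx])
    dataset = self.datasets[dataset_idx]
    is_distillation = dataset.treat_pdb_as_distillation
    # Endless iterator over shuffled indices of this sub-dataset
    idx_iter = self.looped_shuffled_dataset_idx(len(dataset))
    mmcif_data_cache = dataset.mmcif_data_cache
    while True:
        weights = []
        idx = []
        for _ in range(max_cache_len):
            candidate_idx = next(idx_iter)
            mmcif_id = dataset.idx_to_mmcif_id(candidate_idx)
            mmcif_data_cache_entry = mmcif_data_cache[mmcif_id]
            # Deterministic filter: skip entries that fail the hard criteria
            if not self.deterministic_train_filter(cache_entry=mmcif_data_cache_entry,
                                                   is_distillation=is_distillation):
                continue
            # Stochastic filter: one acceptance probability per chain
            chain_probs = self.get_stochastic_train_filter_prob(
                mmcif_data_cache_entry,
            )
            # Each row [1 - p, p] holds reject/accept weights for one chain;
            # the entry's index is repeated once per chain to stay aligned with weights
            weights.extend([[1. - p, p] for p in chain_probs])
            idx.extend([candidate_idx] * len(chain_probs))
        # Draw one column index per row (0 = reject, 1 = accept)
        samples = torch.multinomial(
            torch.tensor(weights),
            num_samples=1,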