在分布式模式下,需要在每个 epoch 开始时调用 set_epoch() 方法,然后再创建 DataLoader 迭代器,以使 shuffle 操作能够在多个 epoch 中正常工作。 否则,dataloader迭代器产生的数据将始终使用相同的顺序。
sampler = DistributedSampler(dataset) if is_distributed else None
loader = DataLoader(dataset, shuffle=(sampler is None),
sampler=sampler)
for epoch in range(start_epoch, n_epochs):
if is_distributed:
sampler.set_epoch(epoch)
train(loader)
参考:
本文解释了在分布式学习环境下,如何在PyTorch中正确设置epoch并确保DistributedSampler的shuffle功能有效,避免数据加载顺序固定。务必在每个epoch开始时调用set_epoch(),以实现每轮数据的重新打乱。
https://zhuanlan.zhihu.com/p/97115875
5872

被折叠的 条评论
为什么被折叠?



