基于PyTorch的图像分类全流程指南
1. 数据加载器的构建
在将数据导入到PyTorch中时,我们可以使用以下几行Python代码来构建数据加载器:
batch_size=64
train_data_loader = data.DataLoader(train_data, batch_size=batch_size)
val_data_loader = data.DataLoader(val_data, batch_size=batch_size)
test_data_loader = data.DataLoader(test_data, batch_size=batch_size)
这里的 batch_size 是一个关键参数,它决定了在训练和更新网络之前,有多少张图像会通过网络。理论上,我们可以将 batch_size 设置为测试集和训练集中图像的数量,这样网络在更新之前会看到每一张图像。但在实践中,我们通常不会这样做,因为较小的批次(文献中更常见的称为小批量)比存储数据集中每一张图像的所有信息所需的内存更少,而且较小的批次大小会使训练更快,因为我们可以更频繁地更新网络。
默认情况下,PyTorch的数据加载器的 batch_size 设置为1,你几乎肯定需要更改这个值。虽然这里选择了64,但你可以进行实验,看看在不耗尽GPU内存的情况下,能使用多大的小批量。此外,你还可以尝试一些其他的参数,例如指定数据集的采样方式、每次运行时是否打乱整个数据集,以及使用多少个工作进程从数据集中提取数据,
超级会员免费看
订阅专栏 解锁全文
82

被折叠的 条评论
为什么被折叠?



