数据
- 数据收集: 收集原始样本和标签,如Img和Label。
- 数据划分: 划分成训练集train,用来训练模型;验证集valid,验证模型是否过拟合,挑选还没有过拟合的时候的模型;测试集test,测试挑选出来的模型的性能。
- 数据读取: PyTorch中数据读取的核心是Dataloader。Dataloader分为Sampler和DataSet两个子模块。Sampler的功能是生成索引,即样本序号;DataSet的功能是根据索引读取样本和标签。
- 数据预处理: 数据的中心和,标准化,旋转,翻转等,在PyTorch中是通过transforms实现的。
选择模型
根据任务的难易程度选择简单的线性模型或者复杂的神经网络模型。
损失函数
根据不同的任务选择不同的损失函数,比如线性回归中选择均方差损失函数,分类选择交叉熵。
优化器
有了loss就可以求取梯度,得到梯度,用优化器更新权值。
迭代训练
反复训练的过程