『跟着雨哥学AI』系列之三:详解飞桨框架模型训练

本课程详细介绍了百度飞桨框架中的模型训练流程,包括模型训练的目的、配置、训练方法及评估预测等内容。通过实例展示了不同训练接口的应用。

 

课程简介:

“跟着雨哥学AI”是百度飞桨开源框架近期针对高层API推出的系列课。本课程由多位资深飞桨工程师精心打造,不仅提供了从数据处理、到模型组网、模型训练、模型评估和推理部署全流程讲解;还提供了丰富的趣味案例,旨在帮助开发者更全面清晰地掌握百度飞桨框架的用法,并能够举一反三、灵活使用飞桨框架进行深度学习实践。

图片

1. 什么是模型训练?

在深度学习领域,我们经常听到「模型训练」这一关键词,上节课中我们详细的解释了什么是模型以及模型是如何组建的,这节课我们需要考虑「什么是模型训练?」「为什么要进行模型训练?」「我们如何进行模型的训练?」, 甚至我们还会好奇「模型训练的结果是什么?」 。

 

以识别任务为例,如下图所示,面对大量的数据和素材,我们的目的就是使用我们设计并组建的模型(算法)能够实现对目标进行准确的识别。那么这一目标也就是我们进行模型训练的原动力,为了达到这一目标,我们需要有一个好的算法,而算法对应的就是我们上节课讲述的模型,这套算法里面包含若干的关键权重信息,用于指导每个模型(算法)节点如何对输入数据做特征提取,在刚建设好模型的时候,这些权重信息会随机设置,效果很差,无法直接用于我们的任务使用,我们就需要为算法找到一组最合适的权重参数,这组权重参数就是我们模型训练后得到的结果。

 

总的来说,模型训练其实就是我们使用大量的数据「调教」模型(算法)找出最优权重参数的过程。

图片

那么如何才能进行模型训练呢?

 

2. 模型训练详解

 

以往我们实现模型训练时常常需要面对非常繁杂的代码,要写好多步骤,才能正确的使程序运行起来。这些代码里面包含比较多的概念和接口使用,刚刚上手的同学们一般需要花比较多的时间和精力来弄明白相关的知识和使用方法,使得许多开发者望而却步。

为了解决这种问题,同时满足新手开发者和资深开发者,既能够减少入门的难度,提升开发的效率,又能拥有较好的定制化能力。飞桨框架提供了两种模型训练的方法:一种是基于基础API的常规训练方式;另一种是用paddle.Model对组建好的模型进行封装,通过高层API完成模型的训练与预测,可以在3-5行内,完成模型的训练。前者适合框架经验比较多的资深开发者,而后者极大的简化了学习和开发的代码量,对初学者用户非常友好。

 

接下来我们就进行到详细的讲解环节吧。

 

Note: 高层API实现的模型训练与预测API都可以通过基础API实现,本文着重介绍高层API的训练方式,然后会将高层API拆解为基础API,方便同学们对比学习。

2.1 模型训练配置

 

什么是模型训练配置呢?

 

这里是我们做的一个概念抽象,在模型训练的时候我们需要选用和指定我们要使用的梯度优化器、损失函数计算方法和模型评估指标计算方法,那么我们可以在正式启动训练之前对这些所需的必备内容做一个统一配置。

 

那么如何进行模型训练配置呢?

 

第一步就是我们需要使用paddle.Model接口完成对模型的封装,将网络结构组合成一个可快速使用高层API进行训练和预测的类。完成模型的封装以后,我们便可以使用model.prepare接口实现模型的配置。

 

为了完整的实现这个过程,我们就一起回忆一下前几节课的内容,进行数据的处理和加载、模型组建、模型的封装以及模型配置吧。

 

数据的处理与加载

In [1]
import paddle
import paddle.nn as nn

paddle.__version__
'2.0.0-rc1'
In [2]
import paddle.vision.transforms as T
from paddle.vision.datasets import MNIST


# 数据预处理,这里用到了随机调整亮度、对比度和饱和度
transform = T.Normalize(mean=[127.5], std=[127.5])

# 数据加载,在训练集上应用数据预处理的操作
train_dataset = MNIST(mode='train', transform=transform)
test_dataset = MNIST(mode='test', transform=transform)
Cache file /home/aistudio/.cache/paddle/dataset/mnist/train-images-idx3-ubyte.gz not found, downloading https://dataset.bj.bcebos.com/mnist/train-images-idx3-ubyte.gz
Begin to download

Download finished
Cache file /home/aistudio/.cache/paddle/dataset/mnist/train-labels-idx1-ubyte.gz not found, downloading https://dataset.bj.bcebos.com/mnist/train-labels-idx1-ubyte.gz
Begin to download
........
Download finished
Cache file /home/aistudio/.cache/paddle/dataset/mnist/t10k-images-idx3-ubyte.gz not found, downloading https://dataset.bj.bcebos.com/mnist/t10k-images-idx3-ubyte.gz
Begin to download

Download finished
Cache file /home/aistudio/.cache/paddle/dataset/mnist/t10k-labels-idx1-ubyte.gz not found, downloading https://dataset.bj.bcebos.com/mnist/t10k-labels-idx1-ubyte.gz
Begin to download
..
Download finished  
模型的组建
In [3]
# 模型的组建
mnist = nn.Sequential(
nn.Flatten(),
nn.Linear(784, 512),
nn.ReLU(),
nn.Linear(512, 10))
模型的封装
In [5]
# 将网络结构用 Model类封装成为模型
model = paddle.Model(mnist)
模型训练配置
In [6]
# 为模型训练做准备,参数optimizer设置优化器,参数loss损失函数,参数metrics设置精度计算方式
model.prepare(optimizer=paddle.optimizer.Adam(parameters=model.parameters()),
loss=paddle.nn.CrossEntropyLoss(),
metrics=paddle.metric.Accuracy())

2.2 模型训练

完成模型的配置工作以后,我们可以正式进入模型训练环节。那么如何进行模型训练呢?

飞桨框架进行模型训练有3种方式来完成:

  1. 全流程的训练启动,包含轮次迭代,数据集迭代,模型评估等,我们可以使用高层APImodel.fit接口来完成;

  2. 对于一些单轮训练内部定制化训练计算过程可以自己手写轮次和数据集迭代,通过高层APImodel.train_batch来实现单个批次数据的训练操作,比如训练GAN类型的网络,需要同时训练生成器和辨别器两个网络;

  3. 如果每个计算细节都想自定义来完成,我们也可以直接使用基础API来实现整个训练过程。

接下来我们就给大家展示下3种方式的使用示例代码来具体了解一下接口的使用方式。

 

2.2.1 全流程模型训练model.fit接口

In [7]
# 启动模型训练,指定训练数据集,设置训练轮次,设置每次数据集计算的批次大小,设置日志格式
model.fit(train_dataset, epochs=5, batch_size=64, verbose=1)
The loss value printed in the log is the current step, and the metric is the average value of previous step.
Epoch 1/5
step  10/938 [..............................] - loss: 0.9135 - acc: 0.3906 - ETA: 26s - 29ms/step
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/fluid/layers/utils.py:77: DeprecationWarning: Using or&n
评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值