加入优快云的昇思MindSpore社区,学习更多AI知识。
代码路径
resnet.py
和resnet50_distributed_training.py
是训练网络定义脚本,run.sh
是分布式训练执行脚本。
准备数据集
采用CIFAR-10数据集,由10类32*32的彩色图片组成,每类包含6000张图片,其中训练集共50000张图片,测试集共10000张图片。
将数据集下载并解压到本地,解压后文件夹为cifar-10-batches-bin
。
配置分布式环境
CPU上数据并行主要分为单机多节点和多机多节点两种并行方式(一个训练进程可以理解为一个节点)。在运行训练脚本前,需要搭建组网环境,主要是环境变量配置和训练脚本里初始化接口的调用。
环境变量配置如下:
export MS_WORKER_NUM=8 # Worker number export MS_SCHED_HOST=127.0.0.1 # Scheduler IP address export MS_SCHED_PORT=6667 # Scheduler port export MS_ROLE=MS_WORKER # The role of this node: MS_SCHED represents the scheduler, MS_WORKER represents the worker
其中,
-
MS_WORKER_NUM
:表示worker节点数,多机场景下,worker节点数是每机worker节点之和。 -
MS_SCHED_HOST
:表示scheduler节点ip地址。 -
MS_SCHED_PORT
:表示scheduler节点服务端口,用于接收worker节点发送来的ip和服务端口,然后将收集到的所有worker节点ip和端口下发给每个worker。 -
MS_ROLE
:表示节点类型,分为worker(MS_WORKER)和scheduler(MS_SCHED)两种。不管是单机多节点还是多机多节点,都需要配置一个scheduler节点用于组网。
启动训练
在CPU平台上,以单机8节点为例,执行分布式训练。通过以下shell脚本启动训练,指令bash run.sh cifar-10-batches-bin
。
启动后台运行,会有相关日志生成。
打开worker日志,发现出现了异常。是内存不足了。
修改worker数量为2.
worker日志里有训练的loss打印。