本实践使用多层感知器训练(DNN)模型,用于预测手写数字图片。
以下是基于人工神经网络在PaddlePaddle平台上进行MNIST手写字符分类的实验目的和步骤:
实验目的
掌握人工神经网络的代码实现:通过实践熟悉PaddlePaddle框架中人工神经网络的构建方法,理解输入层、隐藏层、输出层等模块的代码实现及参数设置。
学习深度学习模型的参数调优技术:探究人工神经网络中关键超参数(如学习率、隐藏层神经元数量、激活函数等)对模型性能的影响,掌握使用验证集或交叉验证进行参数调优的方法。
可视化分析能力培养:通过绘制训练过程中的损失曲线、准确率曲线,直观理解人工神经网络的训练过程和分类原理。
完整实验流程实践:从数据预处理、模型构建、训练到模型评估,掌握深度学习项目全流程,并撰写规范的实验报告。
实验步骤
一、数据准备与预处理
加载数据集:从指定路径加载MNIST手写字符数据集,数据集包含60000张训练图像和10000张测试图像,每张图像为28×28的灰度图像。
数据探索与可视化:查看数据集的基本信息,如图像数量、类别分布等。随机选取部分图像,使用matplotlib等工具绘制图像,观察不同字符图像的形状和灰度分布。
数据预处理:
归一化处理:将图像的像素值从0-255缩放到0-1范围,加速模型训练。
数据增强:对训练集图像进行随机旋转、平移等操作,增加数据的多样性,防止过拟合。
二、模型训练与调参
基础人工神经网络模型训练:构建一个基础的人工神经网络模型,包含输入层、隐藏层和输出层。设置初始学习率为0.001,使用交叉熵损失函数和Adam优化器进行训练。在训练集和测试集上计算准确率,观察模型的基本性能。
网络结构对比实验:分别构建不同隐藏层神经元数量、不同激活函数(如ReLU、Sigmoid、Tanh等)的人工神经网络模型,比较不同网络结构在MNIST手写字符分类任务上的效果。
超参数调优:使用验证集对学习率、批量大小、正则化参数等超参数进行调整。例如,尝试学习率分别为0.0001、0.001、0.01,批量大小分别为32、64、128,记录不同参数组合下模型在验证集上的准确率,选择最优的超参数组合。
三、模型评估与可视化
性能评估:在测试集上计算模型的混淆矩阵、分类报告(精确率、召回率、F1值等),全面评估模型的分类性能。对比不同模型结构和超参数下的测试集准确率,分析模型的优劣。
可视化分析:
训练过程可视化:绘制训练过程中的损失曲线和准确率曲线,观察模型的收敛情况和泛化能力。
四、实验报告撰写
内容要求:
实验目的、环境:说明实验的目的,列出实验所使用的PaddlePaddle版本、Python版本及硬件环境等。
数据集描述:详细介绍MNIST手写字符数据集的来源、规模、类别分布等情况。
详细步骤与代码片段:按照实验步骤的顺序,详细描述每个环节的操作过程,并附上关键代码片段(如数据预处理、模型构建、训练等代码),对代码进行必要的注释,重点说明调参过程与可视化方法。
结果分析:对实验结果进行深入分析,包括不同模型结构和超参数对分类性能的影响、可视化结果的解读等。
问题与改进:
记录实验中遇到的问题:如数据不平衡、过拟合、训练时间过长等问题,以及相应的解决方法。
提出可能的优化方向:如尝试更先进的网络结构(如残差网络、密集连接网络等)、进一步优化超参数、进行更复杂的数据增强等,以提高模型的性能。
激活函数参照https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/nn/Overview_cn.html#activation-functional 或paddle.nn.functional
优化器、学习率可参照https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/optimizer/Overview_cn.html
模型训练与评估相关API调用举例 https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/Model_cn.html
首先导入必要的包
numpy---------->python第三方库,用于进行科学计算
PIL------------> Python Image Library,python第三方图像处理库
matplotlib----->python的绘图库 pyplot:matplotlib的绘图框架
os------------->提供了丰富的方法来处理文件和目录
#导入需要的包
import numpy as np
import paddle as paddle
import paddle.nn as nn
import paddle.nn.functional as F
from PIL import Image
import matplotlib.pyplot as plt
import paddle
from paddle.io import Dataset
import os
print("本教程基于Paddle的版本号为:"+paddle.__version__)
! python -m pip install visualdl -i https://mirror.baidu.com/pypi/simple
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/__init__.py:107: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
from collections import MutableMapping
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/rcsetup.py:20: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
from collections import Iterable, Mapping
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/colors.py:53: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
from collections import Sized
本教程基于Paddle的版本号为:2.2.2
Looking in indexes: https://mirror.baidu.com/pypi/simple, https://mirrors.aliyun.com/pypi/simple/
Requirement already satisfied: visualdl in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (2.2.0)
Requirement already satisfied: Pillow>=7.0.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl) (8.2.0)
Requirement already satisfied: bce-python-sdk in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl) (0.8.53)
Requirement already satisfied: flake8>=3.7.9 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl) (4.0.1)
Requirement already satisfied: requests in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl) (2.22.0)
Requirement already satisfied: matplotlib in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl) (2.2.3)
Requirement already satisfied: six>=1.14.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl) (1.16.0)
Requirement already satisfied: shellcheck-py in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl) (0.7.1.1)
Requirement already satisfied: Flask-Babel>=1.0.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl) (1.0.0)
Requirement already satisfied: numpy in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl) (1.21.6)
Requirement already satisfied: pandas in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl) (1.1.5)
Requirement already satisfied: pre-commit in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl) (1.21.0)
Requirement already satisfied: flask>=1.1.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl) (1.1.1)
Requirement already satisfied: protobuf>=3.11.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from visualdl) (3.20.1)
Requirement already satisfied: importlib-metadata<4.3 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flake8>=3.7.9->visualdl) (4.2.0)
Requirement already satisfied: mccabe<0.7.0,>=0.6.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flake8>=3.7.9->visualdl) (0.6.1)
Requirement already satisfied: pycodestyle<2.9.0,>=2.8.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flake8>=3.7.9->visualdl) (2.8.0)
Requirement already satisfied: pyflakes<2.5.0,>=2.4.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flake8>=3.7.9->visualdl) (2.4.0)
Requirement already satisfied: Jinja2>=2.10.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flask>=1.1.1->visualdl) (3.0.0)
Requirement already satisfied: itsdangerous>=0.24 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flask>=1.1.1->visualdl) (1.1.0)
Requirement already satisfied: Werkzeug>=0.15 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flask>=1.1.1->visualdl) (0.16.0)
Requirement already satisfied: click>=5.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from flask>=1.1.1->visualdl) (7.0)
Requirement already satisfied: Babel>=2.3 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from Flask-Babel>=1.0.0->visualdl) (2.8.0)
Requirement already satisfied: pytz in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from Flask-Babel>=1.0.0->visualdl) (2019.3)
Requirement already satisfied: future>=0.6.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from bce-python-sdk->visualdl) (0.18.0)
Requirement already satisfied: pycryptodome>=3.8.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from bce-python-sdk->visualdl) (3.9.9)
Requirement already satisfied: kiwisolver>=1.0.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from matplotlib->visualdl) (1.1.0)
Requirement already satisfied: cycler>=0.10 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from matplotlib->visualdl) (0.10.0)
Requirement already satisfied: python-dateutil>=2.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from matplotlib->visualdl) (2.8.2)
Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from matplotlib->visualdl) (3.0.9)
Requirement already satisfied: cfgv>=2.0.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pre-commit->visualdl) (2.0.1)
Requirement already satisfied: toml in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pre-commit->visualdl) (0.10.0)
Requirement already satisfied: virtualenv>=15.2 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pre-commit->visualdl) (16.7.9)
Requirement already satisfied: nodeenv>=0.11.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pre-commit->visualdl) (1.3.4)
Requirement already satisfied: aspy.yaml in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pre-commit->visualdl) (1.3.0)
Requirement already satisfied: identify>=1.0.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pre-commit->visualdl) (1.4.10)
Requirement already satisfied: pyyaml in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pre-commit->visualdl) (5.1.2)
Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from requests->visualdl) (2019.9.11)
Requirement already satisfied: chardet<3.1.0,>=3.0.2 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from requests->visualdl) (3.0.4)
Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from requests->visualdl) (1.25.6)
Requirement already satisfied: idna<2.9,>=2.5 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from requests->visualdl) (2.8)
Requirement already satisfied: typing-extensions>=3.6.4 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from importlib-metadata<4.3->flake8>=3.7.9->visualdl) (4.7.1)
Requirement already satisfied: zipp>=0.5 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from importlib-metadata<4.3->flake8>=3.7.9->visualdl) (3.8.1)
Requirement already satisfied: MarkupSafe>=2.0.0rc2 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from Jinja2>=2.10.1->flask>=1.1.1->visualdl) (2.0.1)
Requirement already satisfied: setuptools in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from kiwisolver>=1.0.1->matplotlib->visualdl) (41.4.0)
[notice] A new release of pip available: 22.1.2 -> 24.0
[notice] To update, run: pip install --upgrade pip
Step1:准备数据。
(1)数据集介绍
MNIST数据集包含60000个训练集和10000测试数据集。分为图片和标签,图片是28*28的像素矩阵,标签为0~9共10个数字。
(2)transform函数是定义了一个归一化标准化的标准
(3)train_dataset和test_dataset
paddle.vision.datasets.MNIST()中的mode='train'和mode='test'分别用于获取mnist训练集和测试集
transform=transform参数则为归一化标准
#导入数据集Compose的作用是将用于数据集预处理的接口以列表的方式进行组合。
#导入数据集Normalize的作用是图像归一化处理,支持两种方式: 1. 用统一的均值和标准差值对图像的每个通道进行归一化处理; 2. 对每个通道指定不同的均值和标准差值进行归一化处理。
from paddle.vision.transforms import Compose, Normalize
transform = Compose([Normalize(mean=[127.5],std=[127.5],data_format='CHW')])
# 使用transform对数据集做归一化
print('下载并加载训练数据')
train_dataset = paddle.vision.datasets.MNIST(mode='train', transform=transform)
test_dataset = paddle.vision.datasets.MNIST(mode='test', transform=transform)
#print(np.array(test_dataset).shape)
print('加载完成')
下载并加载训练数据
item 220/2421 [=>............................] - ETA: 1s - 616us/it
Cache file /home/aistudio/.cache/paddle/dataset/mnist/train-images-idx3-ubyte.gz not found, downloading https://dataset.bj.bcebos.com/mnist/train-images-idx3-ubyte.gz
Begin to download
item 464/2421 [====>.........................] - ETA: 1s - 678us/itemitem 509/2421 [=====>........................] - ETA: 1s - 673us/
item 8/8 [============================>.] - ETA: 0s - 1ms/it
Download finished
Cache file /home/aistudio/.cache/paddle/dataset/mnist/train-labels-idx1-ubyte.gz not found, downloading https://dataset.bj.bcebos.com/mnist/train-labels-idx1-ubyte.gz
Begin to download
Download finished
item 282/403 [===================>..........] - ETA: 0s - 637us/it
Cache file /home/aistudio/.cache/paddle/dataset/mnist/t10k-images-idx3-ubyte.gz not found, downloading https://dataset.bj.bcebos.com/mnist/t10k-images-idx3-ubyte.gz
Begin to download
item 2/2 [===========================>..] - ETA: 0s - 1ms/item
Download finished
Cache file /home/aistudio/.cache/paddle/dataset/mnist/t10k-labels-idx1-ubyte.gz not found, downloading https://dataset.bj.bcebos.com/mnist/t10k-labels-idx1-ubyte.gz
Begin to download
Download finished
加载完成
#让我们一起看看数据集中的图片是什么样子的
train_data0, train_label_0 = train_dataset[0][0],train_dataset[0][1]
train_data0 = train_data0.reshape([28,28])
plt.figure(figsize=(2,2))
print(plt.imshow(train_data0, cmap=plt.cm.binary))
print('train_data0 的标签为: ' + str(train_label_0))
AxesImage(25,22;155x154)
train_data0 的标签为: [5]
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/cbook/__init__.py:2349: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
if isinstance(obj, collections.Iterator):
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/cbook/__init__.py:2366: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
return list(data) if isinstance(data, collections.MappingView) else data
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/image.py:425: DeprecationWarning: np.asscalar(a) is deprecated since NumPy v1.16, use a.item() instead
a_min = np.asscalar(a_min.astype(scaled_dtype))
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/image.py:426: DeprecationWarning: np.asscalar(a) is deprecated since NumPy v1.16, use a.item() instead
a_max = np.asscalar(a_max.astype(scaled_dtype))
#让我们再来看看数据样子是什么样的吧
print(train_data0)
[[-1. -1. -1. -1. -1. -1.
-1. -1. -1. -1. -1. -1.
-1. -1. -1. -1. -1. -1.
-1. -1. -1. -1. -1. -1.
-1. -1. -1. -1. ]
[-1. -1. -1. -1. -1. -1.
-1. -1. -1. -1. -1. -1.
-1. -1. -1. -1. -1. -1.
-1. -1. -1. -1. -1. -1.
-1. -1. -1. -1. ]
[-1. -1. -1. -1. -1. -1.
-1. -1. -1. -1. -1. -1.
-1. -1. -1. -1. -1. -1.
-1. -1. -1. -1. -1. -1.
-1. -1. -1. -1. ]
[-1. -1. -1. -1. -1. -1.
-1. -1. -1. -1. -1. -1.
-1. -1. -1. -1. -1. -1.
-1. -1. -1. -1. -1. -1.
-1. -1. -1. -1. ]
[-1. -1. -1. -1. -1. -1.
-1. -1. -1. -1. -1. -1.
-1. -1. -1. -1. -1. -1.
-1. -1. -1. -1. -1. -1.
-1. -1. -1. -1. ]
[-1. -1. -1. -1. -1. -1.
-1. -1. -1. -1. -1. -1.
-0.9764706 -0.85882354 -0.85882354 -0.85882354 -0.01176471 0.06666667
0.37254903 -0.79607844 0.3019608 1. 0.9372549 -0.00392157
-1. -1. -1. -1. ]
[-1. -1. -1. -1. -1. -1.
-1. -1. -0.7647059 -0.7176471 -0.2627451 0.20784314
0.33333334 0.9843137 0.9843137 0.9843137 0.9843137 0.9843137
0.7647059 0.34901962 0.9843137 0.8980392 0.5294118 -0.49803922
-1. -1. -1. -1. ]
[-1. -1. -1. -1. -1. -1.
-1. -0.6156863 0.8666667 0.9843137 0.9843137 0.9843137
0.9843137 0.9843137 0.9843137 0.9843137 0.9843137 0.96862745
-0.27058825 -0.35686275 -0.35686275 -0.56078434 -0.69411767 -1.
-1. -1. -1. -1. ]
[-1. -1. -1. -1. -1. -1.
-1. -0.85882354 0.7176471 0.9843137 0.9843137 0.9843137
0.9843137 0.9843137 0.5529412 0.42745098 0.9372549 0.8901961
-1. -1. -1. -1. -1. -1.
-1. -1. -1. -1. ]
[-1. -1. -1. -1. -1. -1.
-1. -1. -0.37254903 0.22352941 -0.16078432 0.9843137
0.9843137 0.60784316 -0.9137255 -1. -0.6627451 0.20784314
-1. -1. -1. -1. -1. -1.
-1. -1. -1. -1. ]
[-1. -1. -1. -1. -1. -1.
-1. -1. -1. -0.8901961 -0.99215686 0.20784314
0.9843137 -0.29411766 -1. -1. -1. -1.
-1. -1. -1. -1. -1. -1.
-1. -1. -1. -1. ]
[-1. -1. -1. -1. -1. -1.
-1. -1. -1. -1. -1. 0.09019608
0.9843137 0.49019608 -0.9843137 -1. -1. -1.
-1. -1. -1. -1. -1. -1.
-1. -1. -1. -1. ]
[-1. -1. -1. -1. -1. -1.
-1. -1. -1. -1. -1. -0.9137255
0.49019608 0.9843137 -0.4509804 -1. -1. -1.
-1. -1. -1. -1. -1. -1.
-1. -1. -1. -1. ]
[-1. -1. -1. -1. -1. -1.
-1. -1. -1. -1. -1. -1.
-0.7254902 0.8901961 0.7647059 0.25490198 -0.15294118 -0.99215686
-1. -1. -1. -1. -1. -1.
-1. -1. -1. -1. ]
[-1. -1. -1. -1. -1. -1.
-1. -1. -1. -1. -1. -1.
-1. -0.3647059 0.88235295 0.9843137 0.9843137 -0.06666667
-0.8039216 -1. -1. -1. -1. -1.
-1. -1. -1. -1. ]
[-1. -1. -1. -1. -1. -1.
-1. -1. -1. -1. -1. -1.
-1. -1. -0.64705884 0.45882353 0.9843137 0.9843137
0.1764706 -0.7882353 -1. -1. -1. -1.
-1. -1. -1. -1. ]
[-1. -1. -1. -1. -1. -1.
-1. -1. -1. -1. -1. -1.
-1. -1. -1. -0.8745098 -0.27058825 0.9764706
0.9843137 0.46666667 -1. -1. -1. -1.
-1. -1. -1. -1. ]
[-1. -1. -1. -1. -1. -1.
-1. -1. -1. -1. -1. -1.
-1. -1. -1. -1. -1. 0.9529412
0.9843137 0.9529412 -0.49803922 -1. -1. -1.
-1. -1. -1. -1. ]
[-1. -1. -1. -1. -1. -1.
-1. -1. -1. -1. -1. -1.
-1. -1. -0.6392157 0.01960784 0.43529412 0.9843137
0.9843137 0.62352943 -0.9843137 -1. -1. -1.
-1. -1. -1. -1. ]
[-1. -1. -1. -1. -1. -1.
-1. -1. -1. -1. -1. -1.
-0.69411767 0.16078432 0.79607844 0.9843137 0.9843137 0.9843137
0.9607843 0.42745098 -1. -1. -1. -1.
-1. -1. -1. -1. ]
[-1. -1. -1. -1. -1. -1.
-1. -1. -1. -1. -0.8117647 -0.10588235
0.73333335 0.9843137 0.9843137 0.9843137 0.9843137 0.5764706
-0.3882353 -1. -1. -1. -1. -1.
-1. -1. -1. -1. ]
[-1. -1. -1. -1. -1. -1.
-1. -1. -0.81960785 -0.48235294 0.67058825 0.9843137
0.9843137 0.9843137 0.9843137 0.5529412 -0.3647059 -0.9843137
-1. -1. -1. -1. -1. -1.
-1. -1. -1. -1. ]
[-1. -1. -1. -1. -1. -1.
-0.85882354 0.34117648 0.7176471 0.9843137 0.9843137 0.9843137
0.9843137 0.5294118 -0.37254903 -0.92941177 -1. -1.
-1. -1. -1. -1. -1. -1.
-1. -1. -1. -1. ]
[-1. -1. -1. -1. -0.5686275 0.34901962
0.77254903 0.9843137 0.9843137 0.9843137 0.9843137 0.9137255
0.04313726 -0.9137255 -1. -1. -1. -1.
-1. -1. -1. -1. -1. -1.
-1. -1. -1. -1. ]
[-1. -1. -1. -1. 0.06666667 0.9843137
0.9843137 0.9843137 0.6627451 0.05882353 0.03529412 -0.8745098
-1. -1. -1. -1. -1. -1.
-1. -1. -1. -1. -1. -1.
-1. -1. -1. -1. ]
[-1. -1. -1. -1. -1. -1.
-1. -1. -1. -1. -1. -1.
-1. -1. -1. -1. -1. -1.
-1. -1. -1. -1. -1. -1.
-1. -1. -1. -1. ]
[-1. -1. -1. -1. -1. -1.
-1. -1. -1. -1. -1. -1.
-1. -1. -1. -1. -1. -1.
-1. -1. -1. -1. -1. -1.
-1. -1. -1. -1. ]
[-1. -1. -1. -1. -1. -1.
-1. -1. -1. -1. -1. -1.
-1. -1. -1. -1. -1. -1.
-1. -1. -1. -1. -1. -1.
-1. -1. -1. -1. ]]
Step2.网络配置
以下的代码判断就是定义一个简单的多层感知器,一共有三层,两个大小为100的隐层和一个大小为10的输出层,因为MNIST数据集是手写0到9的灰度图像,类别有10个,所以最后的输出大小是10。最后输出层的激活函数是Softmax,所以最后的输出层相当于一个分类器。加上一个输入层的话,多层感知器的结构是:输入层-->>隐层-->>隐层-->>输出层。
# 定义多层感知器
#动态图定义多层感知器
class mnist(paddle.nn.Layer):
def __init__(self):
super(mnist,self).__init__()
def forward(self, input_):
x = paddle.reshape(input_, [input_.shape[0], -1])
return y
from paddle.metric import Accuracy
# 用Model封装模型
model = paddle.Model(mnist())
# 定义损失函数
optim = paddle.optimizer.Adam(learning_rate=0.001, parameters=model.parameters())
# 配置模型
model.prepare(optim,paddle.nn.CrossEntropyLoss(),Accuracy())
Step3.模型训练及评估
callback = paddle.callbacks.VisualDL(log_dir='visualdl_log_dir')
# 训练保存并验证模型
model.fit(train_dataset,test_dataset,epochs=5,batch_size=64,save_dir='multilayer_perceptron',verbose=1)
#模型预测
result = model.predict(test_dataset, batch_size=1)
#请补全模型性能验证代码,可使用model下的evaluate函数或者利用上面的预测出来的结果
model.evaluate(test_dataset,verbose=1)
The loss value printed in the log is the current step, and the metric is the average value of previous steps.
Epoch 1/5
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/fluid/layers/utils.py:77: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
return (isinstance(seq, collections.Sequence) and
---------------------------------------------------------------------------
NameError Traceback (most recent call last)
/tmp/ipykernel_94/2115515413.py in <module>
1 callback = paddle.callbacks.VisualDL(log_dir='visualdl_log_dir')
2 # 训练保存并验证模型
----> 3 model.fit(train_dataset,test_dataset,epochs=5,batch_size=64,save_dir='multilayer_perceptron',verbose=1)
4
5 #模型预测
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/hapi/model.py in fit(self, train_data, eval_data, batch_size, epochs, eval_freq, log_freq, save_dir, save_freq, verbose, drop_last, shuffle, num_workers, callbacks, accumulate_grad_batches, num_iters)
1730 for epoch in range(epochs):
1731 cbks.on_epoch_begin(epoch)
-> 1732 logs = self._run_one_epoch(train_loader, cbks, 'train')
1733 cbks.on_epoch_end(epoch, logs)
1734
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/hapi/model.py in _run_one_epoch(self, data_loader, callbacks, mode, logs)
2060 step + 1 == len(data_loader))
2061
-> 2062 outs = getattr(self, mode + '_batch')(*_inputs)
2063
2064 if self._metrics and self._loss:
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/hapi/model.py in train_batch(self, inputs, labels, update)
1059 print(loss)
1060 """
-> 1061 loss = self._adapter.train_batch(inputs, labels, update)
1062 if fluid.in_dygraph_mode() and self._input_info is None:
1063 self._update_inputs()
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/hapi/model.py in train_batch(self, inputs, labels, update)
712 else:
713 outputs = self.model.network.forward(
--> 714 *[to_variable(x) for x in inputs])
715
716 losses = self.model._loss(*(to_list(outputs) + labels))
/tmp/ipykernel_94/582360547.py in forward(self, input_)
11 x = paddle.reshape(input_, [input_.shape[0], -1])
12
---> 13 return y
NameError: name 'y' is not defined
Step4.模型预测
#获取测试集的第一个图片
test_data0, test_label_0 = test_dataset[0][0],test_dataset[0][1]
test_data0 = test_data0.reshape([28,28])
plt.figure(figsize=(2,2))
#展示测试集中的第一个图片
print(plt.imshow(test_data0, cmap=plt.cm.binary))
print('test_data0 的标签为: ' + str(test_label_0))
print('test_data0 预测的数值为:%d' % np.argsort(result[0][0])[0][-1])
visualdl --logdir visualdl_log_dir
最新发布