一、DNN反向传播算法算法介绍

 

    在了解DNN的反向传播算法前,我们先要知道DNN反向传播算法要解决的问题,也就是说,什么时候我们需要这个反向传播算法? 

    回到我们监督学习的一般问题,假设我们有m个训练样本:{(x1,y1),(x2,y2),...,(xm,ym)}{(x1,y1),(x2,y2),...,(xm,ym)},其中xx为输入向量,特征维度为n_inn_in,而yy为输出向量,特征维度为n_outn_out。我们需要利用这m个样本训练出一个模型,当有一个新的测试样本(xtest,?)(xtest,?)来到时, 我们可以预测ytestytest向量的输出。 

    如果我们采用DNN的模型,即我们使输入层有n_inn_in个神经元,而输出层有n_outn_out个神经元。再加上一些含有若干神经元的隐藏层。此时我们需要找到合适的所有隐藏层和输出层对应的线性系数矩阵WW,偏倚向量bb,让所有的训练样本输入计算出的输出尽可能的等于或很接近样本输出。怎么找到合适的参数呢?

    如果大家对传统的机器学习的算法优化过程熟悉的话,这里就很容易联想到我们可以用一个合适的损失函数来度量训练样本的输出损失,接着对这个损失函数进行优化求最小化的极值,对应的一系列线性系数矩阵WW,偏倚向量bb即为我们的最终结果。在DNN中,损失函数优化极值求解的过程最常见的一般是通过梯度下降法来一步步迭代完成的,当然也可以是其他的迭代方法比如牛顿法与拟牛顿法。如果大家对梯度下降法不熟悉,建议先阅读我之前写的梯度下降(Gradient Descent)小结。

    对DNN的损失函数用梯度下降法进行迭代优化求极小值的过程即为我们的反向传播算法。

【预测模型】基于WMMSE的DNN算法实现数据预测_预测模型

二、代码

%  %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%  MATLAB code to reproduce our work on DNN research for ICASSP 2017.
%  Simply run "main.m", you will get the result for Gaussian IC case in section 4.3.
%  To get results for other sections, slightly modification may apply.
%  We also provide some pre-trained functions to show our results in Table. 1 & Table 2.
%  To run our code, Neuron Network Toolbox and Deep Learning Toolbox need to be installed first.
%  Code has been tested successfully on MATLAB 2016b prerelease platform.
%
%  References:
%  [1] Haoran Sun, Xiangyi Chen, Qingjiang Shi, Mingyi Hong and Xiao Fu.
%  "LEARNING TO OPTIMIZE: TRAINING DEEP NEURAL NETWORKS FOR WIRELESS RESOURCE MANAGEMENT."

%  %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%


clc
clear
clear all
K = 10;
num_H = 50000;
disp('####### Generate Training Data #######');
generate(K,num_H);

disp('####### Train Deep Neural Network #######');
trainDNN(K,num_H);

disp('####### Evaluate Training Performance #######');
trainperformance(K,num_H);

disp('####### Evaluate Testing Performance #######');
testperformance(K,num_H)
  • 1.
  • 2.
  • 3.
  • 4.
  • 5.
  • 6.
  • 7.
  • 8.
  • 9.
  • 10.
  • 11.
  • 12.
  • 13.
  • 14.
  • 15.
  • 16.
  • 17.
  • 18.
  • 19.
  • 20.
  • 21.
  • 22.
  • 23.
  • 24.
  • 25.
  • 26.
  • 27.
  • 28.
  • 29.
  • 30.
  • 31.

三、演示结果

【预测模型】基于WMMSE的DNN算法实现数据预测_预测模型_02