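Summary of Time-Series Processing with LSTM Networks (Part 2)
I have recently been studying LSTM networks. Since I am just getting started, I mainly program in MATLAB, and I have tried four different network structures for different time-series tasks. This is the second post in the series. It covers using a bidirectional LSTM (BiLSTM) for time-series classification.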
1 Data Preprocessing
1.1 Loading the Data
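clear; close all; clc
load WaveformData.mat                        % loads the cell array data and the categorical labels
numChannels = size(data{1},2);               % each sequence is numTimeSteps-by-numChannels
num2visualize = 4;                           % number of sequences to plot
idx = randperm(numel(data), num2visualize);  % random indices drawn without replacement (no duplicate plots)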
1.2 Initial Visualization of the Data
figure
c_num = 2;                                 % number of tile columns
r_num = ceil(num2visualize/c_num);         % rows needed to hold all tiles
tiledlayout(r_num,c_num)
for i = 1:num2visualize
    nexttile
    stackedplot(data{idx(i)},...
        DisplayLabels="Channel "+string(1:numChannels))
    xlabel("Time step")
    title("Class: " + string(labels(idx(i))))
end
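The number of tile rows is the number of plots divided by the number of columns, rounded up. The short Python sketch below (Python is used only for illustration; the post itself uses MATLAB) checks that a floor-based if/else and a ceiling give the same row count. Both function names are hypothetical.

```python
import math

def grid_rows(num_plots: int, num_cols: int) -> int:
    """Rows needed to place num_plots tiles in num_cols columns."""
    return math.ceil(num_plots / num_cols)

def grid_rows_ifelse(num_plots: int, num_cols: int) -> int:
    """Same count written as an if/else with floor division."""
    if num_plots % num_cols == 0:
        return num_plots // num_cols
    return num_plots // num_cols + 1

# Both versions agree for every plot count from 1 to 8 with 2 columns.
for n in range(1, 9):
    assert grid_rows(n, 2) == grid_rows_ifelse(n, 2)
print(grid_rows(4, 2), grid_rows(5, 2))  # 2 3
```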
The resulting plots:

[Figure: initial visualization of the data]
1.3 Splitting into Training, Validation, and Test Sets
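numObservations = numel(data);
% trainingPartitions is a supporting function that ships with MATLAB's sequence
% classification example (it is not a built-in); it returns disjoint random
% index sets in the ratio 80% / 10% / 10%
[idxTrain,idxValid,idxTest] = trainingPartitions(numObservations,[0.8 0.1 0.1]);
XTrain = data(idxTrain);
TTrain = labels(idxTrain);
XValid = data(idxValid);
TValid = labels(idxValid);
XTest = data(idxTest);
TTest = labels(idxTest);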
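The split can be pictured as shuffling the observation indices once and cutting them into three disjoint slices. This Python sketch illustrates that idea with a fixed random seed; `partition_indices` is a hypothetical helper and is not the source of MATLAB's `trainingPartitions`, whose exact rounding may differ.

```python
import random

def partition_indices(n, fractions, seed=0):
    """Randomly split range(n) into disjoint index lists with the given fractions.

    Hypothetical helper for illustration only.
    """
    if abs(sum(fractions) - 1.0) > 1e-9:
        raise ValueError("fractions must sum to 1")
    idx = list(range(n))
    random.Random(seed).shuffle(idx)          # one shuffle, then contiguous slices
    sizes = [int(f * n) for f in fractions[:-1]]
    sizes.append(n - sum(sizes))              # the last part takes the remainder
    parts, start = [], 0
    for size in sizes:
        parts.append(idx[start:start + size])
        start += size
    return parts

train, valid, test = partition_indices(1000, [0.8, 0.1, 0.1])
print(len(train), len(valid), len(test))  # 800 100 100
```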
1.4 Visualizing the Distribution of Training Sequence Lengths
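numObservations = numel(XTrain);
sequenceLengths = zeros(1,numObservations);   % preallocate
for i = 1:numObservations
    sequence = XTrain{i};
    sequenceLengths(i) = size(sequence,1);    % number of time steps
end
% Sort the sequences by length (with Shuffle="every-epoch" the training
% order is reshuffled anyway, so here the sort mainly serves the plot)
[sequenceLengths,idx] = sort(sequenceLengths);
XTrain = XTrain(idx);
TTrain = TTrain(idx);
figure
bar(sequenceLengths)
xlabel("Sequence index (sorted)")
ylabel("Sequence length")
title("Sequences sorted by length (ascending)")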
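Sorting by length pays off whenever mini-batches are formed in the sorted order (for example at prediction time, which does not shuffle), because each mini-batch is padded to its longest sequence. The Python sketch below counts padded time steps for batches of 16 using synthetic random lengths (an assumption, not the real WaveformData lengths) and shows what left padding looks like. The helper names are hypothetical.

```python
import random

def left_pad(seq, length, pad=0.0):
    """Pad a sequence at the start so its real data ends at the last time step."""
    return [pad] * (length - len(seq)) + list(seq)

def total_padding(lengths, batch_size):
    """Padded time steps when every mini-batch is padded to its longest member."""
    total = 0
    for start in range(0, len(lengths), batch_size):
        batch = lengths[start:start + batch_size]
        total += sum(max(batch) - n for n in batch)
    return total

rng = random.Random(0)                                # fixed seed for reproducibility
lengths = [rng.randint(50, 500) for _ in range(192)]  # synthetic lengths, 12 batches of 16
print("unsorted:", total_padding(lengths, 16))
print("sorted:  ", total_padding(sorted(lengths), 16))
print(left_pad([1.0, 2.0, 3.0], 5))                   # [0.0, 0.0, 1.0, 2.0, 3.0]
```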
The resulting plot:

[Figure: distribution of training sequence lengths]
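2 Time-Series Classification with an LSTM Network
2.1 Network Training
We start with the most basic structure, which contains just a single bidirectional LSTM (BiLSTM) layer.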
numHiddenUnits = 128;                               % hidden units per direction
numClasses = 4;                                     % number of waveform classes
layers = [
    sequenceInputLayer(numChannels)
    bilstmLayer(numHiddenUnits,OutputMode="last")   % output only the last time step
    fullyConnectedLayer(numClasses)
    softmaxLayer];
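A quick way to sanity-check this architecture is to count its learnable parameters. Assuming the standard LSTM layout with four gates per direction (the layout MATLAB's `bilstmLayer` uses for its `InputWeights`, `RecurrentWeights`, and `Bias`), one BiLSTM layer with H hidden units and C input channels has 8*H*(C + H) + 8*H parameters, and its "last" output has 2*H features feeding the fully connected layer. The Python arithmetic below assumes C = 3; replace it with the `numChannels` value of your data.

```python
def bilstm_param_count(input_size: int, hidden_units: int) -> int:
    """Learnable parameters of one BiLSTM layer: 4 gates x 2 directions."""
    rows = 8 * hidden_units                  # stacked gate rows for both directions
    input_weights = rows * input_size        # 8H-by-C
    recurrent_weights = rows * hidden_units  # 8H-by-H
    bias = rows                              # 8H-by-1
    return input_weights + recurrent_weights + bias

def fc_param_count(input_size: int, output_size: int) -> int:
    """Weight matrix (output-by-input) plus one bias per output."""
    return output_size * input_size + output_size

num_channels = 3    # assumption: use numChannels from your data
num_hidden = 128    # numHiddenUnits in the MATLAB code
num_classes = 4

bilstm = bilstm_param_count(num_channels, num_hidden)
fc = fc_param_count(2 * num_hidden, num_classes)  # BiLSTM output has 2*H features
print(bilstm, fc, bilstm + fc)  # 135168 1028 136196
```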
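Next, specify the training options and train the network. This run allows up to 200 epochs (MaxEpochs=200; validation patience can stop it earlier), tracks RMSE as an additional metric, and uses cross-entropy as the loss. For a classification task, Metrics="accuracy" would probably be the more informative metric.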
%% Specify training options
options = trainingOptions("adam", ...
    ValidationData={XValid,TValid}, ...
    ValidationFrequency=5, ...             % validate every 5 iterations
    ValidationPatience=100, ...            % stop after 100 validations without improvement
    MaxEpochs=200, ...
    SequencePaddingDirection="left", ...   % pad at the start so the last step (OutputMode="last") is real data
    Shuffle="every-epoch", ...
    OutputNetwork="best-validation", ...   % return the network with the best validation result
    Plots="training-progress", ...
    Metrics="rmse", ...
    L2Regularization=0.0001, ...
    InitialLearnRate=0.001, ...
    GradientThreshold=1, ...
    Verbose=true);
net = trainnet(XTrain,TTrain,layers,"crossentropy",options);
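To make the loss and the extra metric concrete: for a single observation, cross-entropy is the negative log of the softmax score assigned to the true class, and RMSE compares the score vector with a one-hot target. The Python sketch below implements these per-observation textbook definitions; how `trainnet` normalizes them across a mini-batch is not reproduced here and may differ.

```python
import math

def softmax(logits):
    """Numerically stable softmax."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy(scores, true_index):
    """-log(score of the true class) for a single observation."""
    return -math.log(scores[true_index])

def rmse(scores, true_index):
    """Root-mean-square error between the scores and a one-hot target."""
    target = [1.0 if k == true_index else 0.0 for k in range(len(scores))]
    return math.sqrt(sum((s - t) ** 2 for s, t in zip(scores, target)) / len(scores))

scores = softmax([2.0, 0.5, 0.1, -1.0])   # hypothetical logits for 4 classes, true class 0
print(cross_entropy(scores, 0), rmse(scores, 0))
```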
The training progress plot:

[Figure: training progress]
The command-line output during training:

[Figure: training command-line output]
Finally, save the trained network.
counter = 1;
file_dir = 'networks/';
if ~exist(file_dir, 'dir')      % create the directory if it does not exist
    mkdir(file_dir)
end
file_head = [file_dir, 'trained_net'];
name_temp = [file_head, num2str(counter), '.mat'];
while exist(name_temp, "file")  % increase the counter until the file name is unused
    counter = counter + 1;
    name_temp = [file_head, num2str(counter), '.mat'];
end
save(name_temp, 'net')          % never overwrites an earlier run
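The file-naming loop looks for the first `trained_netN.mat` that does not exist yet, so earlier runs are never overwritten. The same logic in Python, as a sketch with a hypothetical `next_free_path` helper:

```python
import os
import tempfile

def next_free_path(directory, stem="trained_net", ext=".mat"):
    """Return directory/stemN.ext for the smallest N >= 1 that does not exist yet."""
    os.makedirs(directory, exist_ok=True)   # create the folder if it is missing
    counter = 1
    path = os.path.join(directory, f"{stem}{counter}{ext}")
    while os.path.exists(path):
        counter += 1
        path = os.path.join(directory, f"{stem}{counter}{ext}")
    return path

# Usage in a throwaway directory: the second call skips the file created by the first.
with tempfile.TemporaryDirectory() as tmp:
    first = next_free_path(os.path.join(tmp, "networks"))
    open(first, "w").close()                # simulate a saved network
    second = next_free_path(os.path.join(tmp, "networks"))
    print(os.path.basename(first), os.path.basename(second))  # trained_net1.mat trained_net2.mat
```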
2.2 Network Testing
Load the saved network file, prepare the test set, run the prediction, and visualize the results.
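load(name_temp, 'net')                      % network saved above; change the path to load another file
classNames = categories(labels);            % class names used to decode the scores
numObservationsTest = numel(XTest);
sequenceLengthsTest = zeros(1,numObservationsTest);
for i = 1:numObservationsTest
    sequence = XTest{i};
    sequenceLengthsTest(i) = size(sequence,1);
end
% Sort by length so sequences of similar length share a mini-batch (less padding)
[sequenceLengthsTest,idx] = sort(sequenceLengthsTest);
XTest = XTest(idx);
TTest = TTest(idx);
% Pad on the left, as during training, so OutputMode="last" sees real data
scores = minibatchpredict(net,XTest,SequencePaddingDirection="left");
YTest = scores2label(scores,classNames);
fprintf("Classification accuracy: %.4f\n", mean(YTest == TTest));
figure
confusionchart(TTest,YTest)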
Because this is a classification task, the main output is the confusion matrix.

[Figure: confusion matrix of the test results]
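`scores2label` picks the class with the highest score, accuracy is the fraction of matching labels, and the confusion matrix counts (true class, predicted class) pairs with true classes on the rows, which is how `confusionchart` lays them out. A Python sketch of these three steps with made-up data (the class names, scores, and labels are hypothetical):

```python
def scores_to_labels(scores, classes):
    """For each row of scores, return the class with the highest score."""
    return [classes[max(range(len(row)), key=row.__getitem__)] for row in scores]

def accuracy(y_true, y_pred):
    """Fraction of predictions that match the true labels."""
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)

def confusion_matrix(y_true, y_pred, classes):
    """Counts with rows = true class and columns = predicted class."""
    pos = {c: i for i, c in enumerate(classes)}
    matrix = [[0] * len(classes) for _ in classes]
    for t, p in zip(y_true, y_pred):
        matrix[pos[t]][pos[p]] += 1
    return matrix

classes = ["A", "B", "C"]
scores = [[0.7, 0.2, 0.1],
          [0.4, 0.5, 0.1],
          [0.1, 0.8, 0.1],
          [0.2, 0.1, 0.7]]
y_true = ["A", "A", "B", "C"]
y_pred = scores_to_labels(scores, classes)   # ['A', 'B', 'B', 'C']
print(accuracy(y_true, y_pred))              # 0.75
for row in confusion_matrix(y_true, y_pred, classes):
    print(row)                               # [1, 1, 0] / [0, 1, 0] / [0, 0, 1]
```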