LSTM网络之时间序列处理总结(二)

LSTM网络之时间序列处理总结(二)

最近在研究LSTM网络,因为是初步学习,主要使用MATLAB进行编程,实践了4种不同的网络结构,用于不同的时间序列任务。本期是第二期,主要更新双层LSTM用于时间序列分类的部分。

1 前置数据处理部分

1.1 数据读取与分割

clear; close all; clc 
load WaveformData.mat

numChannels = size(data{1},2);

num2visualize = 4;
idx = randi(length(data), 1, num2visualize);

1.2 进行数据的初步可视化

figure
c_num = 2;
if mod(num2visualize, c_num) == 0
    r_num = floor(num2visualize/c_num);
else
    r_num = floor(num2visualize/c_num) + 1;
end
tiledlayout(r_num,c_num) 
for i = 1:num2visualize
    nexttile
    stackedplot(data{idx(i)},...
        DisplayLabels="特征通道 "+string(1:numChannels))
    
    xlabel("时间步索引")
    title("当前序列属于类别 " + string(labels(idx(i))))
end

可视化结果为

project cover
数据初步可视化结果

1.3 训练数据的构成

numObservations = numel(data);
[idxTrain,idxValid,idxTest] = trainingPartitions(numObservations,[0.8  0.1 0.1]);

XTrain = data(idxTrain);
TTrain = labels(idxTrain);

XValid = data(idxValid);
TValid = labels(idxValid);

XTest = data(idxTest);
TTest = labels(idxTest);

1.4 训练数据分布情况的可视化

numObservations = numel(XTrain);
for i=1:numObservations
    sequence = XTrain{i};
    sequenceLengths(i) = size(sequence,1);
end

% 将序列长度进行排序
[sequenceLengths,idx] = sort(sequenceLengths);
XTrain = XTrain(idx);
TTrain = TTrain(idx);

figure
bar(sequenceLengths)
xlabel("时间序列排序后索引")
ylabel("时间序列长度")
title("按序列长度升序排列的序列数据")

可视化结果为

project cover
数据分布情况可视化结果

2 LSTM网络进行时间序列预测

2.1 网络训练过程

首先是最普通的LSTM网络结构,只有单个双层LSTM

numHiddenUnits = 128;
numClasses = 4;

layers = [
    sequenceInputLayer(numChannels)
    bilstmLayer(numHiddenUnits,OutputMode="last")
    fullyConnectedLayer(numClasses)
    softmaxLayer];

接下来需要指定训练选项,并进行训练。这次训练次数为200次,增加了RMSE指标,并采用交叉熵进行损失评估

%% 指定训练选项
options = trainingOptions("adam", ...
     ValidationData={XValid,TValid}, ...
    ValidationFrequency=5, ...
    ValidationPatience=100, ...
    MaxEpochs=200, ...
    SequencePaddingDirection="left", ...
    Shuffle="every-epoch", ...
    OutputNetwork="best-validation", ...
    Plots="training-progress", ...
    Metrics = ["rmse"], ...
    L2Regularization = 0.0001, ...
    InitialLearnRate = 0.001, ...
    GradientThreshold = 1, ...
    Verbose=true);
    
net = trainnet(XTrain,TTrain,layers,"crossentropy",options);

训练过程可视化结果为

project cover
训练过程可视化结果

训练过程命令行输出结果为

project cover
训练过程命令行输出可视化结果

最终进行训练结果的存储

counter = 1;
file_dir = 'networks/'; % 当目录不存在,就再创建一个
if ~exist(file_dir, 'dir')
    mkdir(file_dir)
end

file_head = [file_dir, 'trained_net'];
name_temp = [file_head, num2str(counter), '.mat'];
while exist(name_temp,"file")
    counter = counter + 1;
    name_temp = [file_head, num2str(counter), '.mat'];
end
save(name_temp, 'net')

3.2 网络测试部分

读取保存的网络文件并处理好测试集,进行检测后进行结果可视化

%numObservationsTest = numel(XTest);
for i=1:numObservationsTest
    sequence = XTest{i};
    sequenceLengthsTest(i) = size(sequence,1);
end

[sequenceLengthsTest,idx] = sort(sequenceLengthsTest);
XTest = XTest(idx);
TTest = TTest(idx);

scores = minibatchpredict(net,XTest);
YTest = scores2label(scores,classNames);

fprintf(['分类准确度为: ', num2str(mean(YTest == TTest), "%.4f"),'\n']); 

figure
confusionchart(TTest,YTest)

由于这里是对分类任务进行测试,故主要输出其混淆矩阵的结果

project cover
测试结果可视化混淆矩阵
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

空 白II

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值