matlab LSTM序列分类的官方示例

matlab版本是2018b及其以上。

%%
%加载序列数据
%数据描述:总共270组训练样本共分为9类,每组训练样本的训练样个数不等,每个训练训练样本由12个特征向量组成,
[XTrain,YTrain] = japaneseVowelsTrainData;
%数据可视化
figure
plot(XTrain{1}')
xlabel("Time Step")
title("Training Observation 1")
legend("Feature " + string(1:12),'Location','northeastoutside')
%%
%LSTM可以将分组后等量的训练样本进行训练,从而提高训练效率
%如果每组的样本数量不同,进行小批量拆分,则需要尽量保证分块的训练样本数相同
%首先找到每组样本数和总的组数
numObservations = numel(XTrain);
for i=1:numObservations
    sequence = XTrain{i};
    sequenceLengths(i) = size(sequence,2);
end
%绘图前后排序的各组数据个数
figure
subplot(1,2,1)
bar(sequenceLengths)
ylim([0 30])
xlabel("Sequence")
ylabel("Length")
title("Sorted Data")
%按序列长度对测试数据进行排序
[sequenceLengths,idx] = sort(sequenceLengths);
XTrain = XTrain(idx);
YTrain = YTrain(idx);
subplot(1,2,2)
bar(sequenceLengths)
ylim([0 30])
xlabel("Sequence")
ylabel("Length")
title("Sorted Data")

%%
%设置LSTM训练数据的小批量分组个数
miniBatchSize = 27;

%%
%定义LSTM网络架构:
%将输入大小指定为序列大小 12(输入数据的维度)
%指定具有 100 个隐含单元的双向 LSTM 层,并输出序列的最后一个元素。
%指定九个类,包含大小为 9 的全连接层,后跟 softmax 层和分类层。
inputSize = 12;
numHiddenUnits = 100;
numClasses = 9;

layers = [ ...
    sequenceInputLayer(inputSize)
    bilstmLayer(numHiddenUnits,'OutputMode','last')
    fullyConnectedLayer(numClasses)
    softmaxLayer
    classificationLayer]

%%
%指定训练选项:
%求解器为 'adam'
%梯度阈值为 1,最大轮数为 100。
% 27 作为小批量数。
%填充数据以使长度与最长序列相同,序列长度指定为 'longest'。
%数据保持按序列长度排序的状态,不打乱数据。
% 'ExecutionEnvironment' 指定为 'cpu',设定为'auto'表明使用GPU。

maxEpochs = 100;
miniBatchSize = 27;

options = trainingOptions('adam', ...
    'ExecutionEnvironment','cpu', ...
    'GradientThreshold',1, ...
    'MaxEpochs',maxEpochs, ...
    'MiniBatchSize',miniBatchSize, ...
    'SequenceLength','longest', ...
    'Shuffle','never', ...
    'Verbose',0, ...
    'Plots','training-progress');

%%
%训练LSTM网络
net = trainNetwork(XTrain,YTrain,layers,options);

%%
%测试LSTM网络
%加载测试集
[XTest,YTest] = japaneseVowelsTestData;

%由于LSTM已经按照相似长度的小批量分组27,测试需要按照相同方式对数据进行排序处理。
numObservationsTest = numel(XTest);
for i=1:numObservationsTest
    sequence = XTest{i};
    sequenceLengthsTest(i) = size(sequence,2);
end
[sequenceLengthsTest,idx] = sort(sequenceLengthsTest);
XTest = XTest(idx);
YTest = YTest(idx);

%使用classify进行分类,指定小批量大小27,指定组内数据按照最长的数据填充
miniBatchSize = 27;
YPred = classify(net,XTest, ...
    'MiniBatchSize',miniBatchSize, ...
    'SequenceLength','longest');
%计算分类准确度
acc = sum(YPred == YTest)./numel(YTest)

 

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值