基于Matlab的BiLSTM实现

Matlab实现BiLSTM深度学习分类
本文探讨了在Python深度学习环境配置困难的情况下,使用Matlab搭建BiLSTM模型进行时序遥感数据分类的优势。通过实验,模型达到88.9%的分类精度,展示Matlab在深度学习领域的应用潜力。
部署运行你感兴趣的模型镜像

问题背景

目前深度学习多使用python实现。不过想要配置好一个python的深度学习环境有时却并不轻松,常常因为各个第三方库版本兼容性问题而失败。相比之下,matlab仅需一次安装简化了不少工作。这几年matlab的深度学习工具箱也是发展迅速。但我发现matlab的相关资料却比较少。因此,我探索了下如何用matlab搭建一个BiLSTM用于时序遥感数据的分类。

clc,clear;
%% Load the training data
% trainX: an array with shape of (n, c, t). n represents the number
%            of training samples, c is the number of features, t is the
%             length of time sequence.
% trainY: an array with shape of (n,). n represents the number of training
%            samples. 
rootDir = 'root_dir';
trainingData = importdata(fullfile(rootDir,'train.mat'));
trainX = trainingData.trainx;
trainY = trainingData.trainy+1;
Xtrain = cell({});
for i = 1:size(trainX,1)
    Xtrain{i,1} = squeeze(trainX(i,:,:));
end
Ytrain = categorical(trainY');

%% Load the validation data
valData = importdata(fullfile(rootDir,'test.mat'));
valX = valData.testx;
valY = valData.testy+1;
Yval = categorical(valY');
Xval = cell({});
for i = 1:size(valX,1)
    Xval{i,1} = squeeze(valX(i,:,:));
end
valDataSet = cell({Xval,Yval});

%% Create bilstm model
% numFeatures: The number of expected features in input data
% numHiddens: The number of features in the hidden state
% numClasses: The number of classess
numFeatures = 8;
numHiddens = 256;
numClasses = 5;

netLayers = [
    sequenceInputLayer(numFeatures,"Name","input")
    bilstmLayer(numHiddens,"Name","bilstm_1",'OutputMode','last')
    bilstmLayer(numHiddens,"Name","bilstm_2",'OutputMode','last')
    dropoutLayer(0.5,"Name","dropout")
    flattenLayer("Name","flatten")
    fullyConnectedLayer(numClasses,"Name","fc")
    softmaxLayer("Name","softmax")
    classificationLayer("Name","classification")];

%% Set the hyper parameters for unet training 
options = trainingOptions('adam', ...                          
                                        'InitialLearnRate',1e-4, ...    
                                        'Plots','training-progress',...
                                        'MaxEpochs',60, ... 
                                        'MiniBatchSize',128,...
                                        'VerboseFrequency',1,...
                                        'ExecutionEnvironment', 'auto',...
                                        'Shuffle','every-epoch',...
                                         'ValidationData',valDataSet,...
                                         'ValidationFrequency',1,...
                                         'WorkerLoad',4,...
                                         'CheckPointPath',rootDir);
% start training
net = trainNetwork(Xtrain,Ytrain, netLayers, options);

%% Save and load model
save('bilstm.mat','net');
bilstm = importdata('bilstm.mat');

%% Accuracy assessment
pred = classify(bilstm, Xval);
[confusionMatrix,order] = confusionmat(categorical(valY),pred);
cm = confusionchart(confusionMatrix);

% caculate user accuracy and mapping accuracy
confusionMatrix = [confusionMatrix, zeros(size(confusionMatrix,1),1)];
confusionMatrix = [confusionMatrix; zeros(1,size(confusionMatrix,2))];
confusionMatrix(1:end-1,end) = confusionMatrix(sub2ind(size(confusionMatrix),1:numClasses,1:numClasses))...
                                                ./sum(confusionMatrix(1:end-1,1:end-1),2)';
confusionMatrix(end,1:end-1) = confusionMatrix(sub2ind(size(confusionMatrix),1:numClasses,1:numClasses))...
                                                ./sum(confusionMatrix(1:end-1,1:end-1),1);    
confusionMatrix(end,end) = sum(confusionMatrix(sub2ind(size(confusionMatrix),1:numClasses,1:numClasses)))...
                                                ./sum(sum(confusionMatrix(1:end-1,1:end-1)));
                                            
mappingAccuracy = confusionMatrix(end,1:end-1);
userAccuracy = confusionMatrix(1:end-1,end);
totalAccuracy = confusionMatrix(end,end);

分类精度还不错,达到了88.9%

 

您可能感兴趣的与本文相关的镜像

Python3.9

Python3.9

Conda
Python

Python 是一种高级、解释型、通用的编程语言,以其简洁易读的语法而闻名,适用于广泛的应用,包括Web开发、数据分析、人工智能和自动化脚本

评论 7
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值