%% LeNet-5 for MNIST Classification - Complete Implementation
clear; close all; clc;
%% 1. Load and Preprocess MNIST Dataset
% Loads the 4-D image arrays (height x width x channels x image count) and
% their labels, prints the set sizes, rescales the pixels, and previews 20
% random training images together with their labels.
% NOTE(review): digitTrain4DArrayData / digitTest4DArrayData load the digit
% image set bundled with Deep Learning Toolbox; this is presumably not the
% original MNIST files - confirm before reporting results as "MNIST".
[XTrain, YTrain] = digitTrain4DArrayData;
[XTest, YTest] = digitTest4DArrayData;
% Display dataset information (4th dimension indexes the images)
fprintf('MNIST Dataset:\n');
fprintf('Training Set: %d images (28x28x1)\n', size(XTrain, 4));
fprintf('Testing Set: %d images (28x28x1)\n', size(XTest, 4));
% Normalize pixel values to [0, 1]; each array is rescaled using its own
% minimum and maximum.
XTrain = rescale(XTrain);
XTest = rescale(XTest);
% Visualize sample training images
figure('Name', 'Sample MNIST Training Images', 'Position', [100, 100, 900, 300]);
numPreview = 20;
randIdx = randperm(size(XTrain, 4), numPreview);
for i = 1:numPreview
    subplot(4, 5, i);
    imshow(XTrain(:, :, :, randIdx(i)));
    % The labels are class labels (categorical per the toolbox docs), which
    % sprintf cannot format with %d; convert to text first. string() also
    % works if the labels happen to be numeric.
    title(sprintf('Label: %s', string(YTrain(randIdx(i)))));
end
%% 2. Define LeNet-5 Network Architecture
% The network is assembled from four stages that are concatenated, in
% order, into the single layer array `layers` used for training.

% Stage 0: 28x28 single-channel input, no built-in normalization
inputStage = imageInputLayer([28 28 1], 'Name', 'input', 'Normalization', 'none');

% Stage 1: 5x5 conv (6 filters) -> batch norm -> ReLU -> 2x2 average pool
convStage1 = [
    convolution2dLayer(5, 6, 'Padding', 'same', 'Name', 'conv1')
    batchNormalizationLayer('Name', 'bn1')
    reluLayer('Name', 'relu1')
    averagePooling2dLayer(2, 'Stride', 2, 'Name', 'pool1')
    ];

% Stage 2: 5x5 conv (16 filters) -> batch norm -> ReLU -> 2x2 average pool
convStage2 = [
    convolution2dLayer(5, 16, 'Padding', 'same', 'Name', 'conv2')
    batchNormalizationLayer('Name', 'bn2')
    reluLayer('Name', 'relu2')
    averagePooling2dLayer(2, 'Stride', 2, 'Name', 'pool2')
    ];

% Stage 3: fully connected head (120 -> 84 -> 10) with softmax classifier
classifierStage = [
    fullyConnectedLayer(120, 'Name', 'fc1')
    reluLayer('Name', 'relu3')
    fullyConnectedLayer(84, 'Name', 'fc2')
    reluLayer('Name', 'relu4')
    fullyConnectedLayer(10, 'Name', 'fc_out')
    softmaxLayer('Name', 'softmax')
    classificationLayer('Name', 'output')
    ];

layers = [inputStage; convStage1; convStage2; classifierStage];

% Print the assembled architecture
fprintf('\nLeNet-5 Network Architecture:\n');
disp(layers);
%% 3. Configure Training Options
% Name-value pairs are grouped by concern and expanded into one
% trainingOptions call for the Adam solver.

% Learning-rate schedule: start at 1e-3 and halve every 10 epochs
scheduleArgs = { ...
    'InitialLearnRate', 0.001, ...
    'LearnRateSchedule', 'piecewise', ...
    'LearnRateDropFactor', 0.5, ...
    'LearnRateDropPeriod', 10};

% Epoch / mini-batch handling
batchArgs = { ...
    'MaxEpochs', 15, ...
    'MiniBatchSize', 128, ...
    'Shuffle', 'every-epoch'};

% Validation on the test arrays every 100 iterations
validationArgs = { ...
    'ValidationData', {XTest, YTest}, ...
    'ValidationFrequency', 100};

% Reporting, hardware selection and weight decay
runtimeArgs = { ...
    'Plots', 'training-progress', ...
    'ExecutionEnvironment', 'auto', ... % Uses GPU if available
    'L2Regularization', 0.0001, ...
    'Verbose', true, ...
    'VerboseFrequency', 100};

options = trainingOptions('adam', ...
    scheduleArgs{:}, batchArgs{:}, validationArgs{:}, runtimeArgs{:});
%% 4. Train the Network
% Runs the baseline training pass and reports the wall-clock time. `net`
% is the trained network; `trainingInfo` holds the per-iteration history.
fprintf('\nTraining LeNet-5 on MNIST dataset...\n');
tStart = tic;
[net, trainingInfo] = trainNetwork(XTrain, YTrain, layers, options);
trainingTime = toc(tStart);   % elapsed seconds
trainingMinutes = trainingTime / 60;
fprintf('Training completed in %.2f seconds (%.2f minutes)\n', trainingTime, trainingMinutes);
%% 5. Evaluate Network Performance
% Classifies the test images, reports accuracy, draws a confusion chart and
% shows up to 20 randomly chosen misclassified images.
% Test set predictions
fprintf('\nEvaluating on test set...\n');
YPred = classify(net, XTest);
% Calculate accuracy (fraction of predictions equal to the true label)
accuracy = sum(YPred == YTest) / numel(YTest);
fprintf('Test Accuracy: %.2f%%\n', accuracy * 100);
% Confusion matrix
figure('Name', 'Confusion Matrix', 'Position', [100, 100, 700, 600]);
% All name-value pairs must sit inside ONE call: the statement may only be
% closed after the final 'Title' argument. The public confusionchart
% function is used instead of an internal package-qualified name.
confusionchart(YTest, YPred, ...
    'ColumnSummary', 'column-normalized', ...
    'RowSummary', 'row-normalized', ...
    'Title', sprintf('LeNet-5 Performance (Accuracy: %.2f%%)', accuracy*100));
% Display misclassified examples
misclassified = find(YPred ~= YTest);
if ~isempty(misclassified)
    fprintf('Number of misclassified images: %d\n', numel(misclassified));
    figure('Name', 'Misclassified Examples', 'Position', [100, 100, 900, 600]);
    % Draw the sample from ALL misclassified indices, not only the first 20.
    numShown = min(20, numel(misclassified));
    randMis = misclassified(randperm(numel(misclassified), numShown));
    for i = 1:numShown
        subplot(4, 5, i);
        imshow(XTest(:, :, :, randMis(i)));
        % Labels are class labels, so convert them to text before
        % formatting (sprintf cannot apply %d to categorical values).
        title(sprintf('True: %s\nPred: %s', ...
            string(YTest(randMis(i))), string(YPred(randMis(i)))));
    end
end
%% 6. Visualize Network Components
% Shows the learned first-layer filters and the feature maps that one test
% image produces after each convolutional block.
% First convolutional layer filters (layer 2 of the array is 'conv1')
conv1Weights = net.Layers(2).Weights;
figure('Name', 'First Convolutional Layer Filters', 'Position', [100, 100, 800, 300]);
montage(rescale(conv1Weights), 'Size', [2 3]);
title('Layer 1: 6 Filters (5x5)');
colorbar;
% Feature maps visualization
sampleImg = XTest(:, :, :, 1); % Use first test image
% activations is queried one layer at a time. For a single input image the
% result is height x width x channels, so it is reshaped to
% height x width x 1 x channels: montage then tiles each channel as its own
% grayscale image. rescale maps the values into [0, 1] for display.
featureMaps1 = activations(net, sampleImg, 'relu1');
featureMaps1 = reshape(featureMaps1, size(featureMaps1, 1), size(featureMaps1, 2), 1, []);
featureMaps2 = activations(net, sampleImg, 'relu2');
featureMaps2 = reshape(featureMaps2, size(featureMaps2, 1), size(featureMaps2, 2), 1, []);
% Layer 1 feature maps (6 channels -> 2x3 grid)
figure('Name', 'Feature Maps - Layer 1 (relu1)', 'Position', [100, 100, 900, 300]);
montage(rescale(featureMaps1), 'Size', [2 3]);
title('Feature Maps after First Convolutional Block');
% Layer 2 feature maps (16 channels -> 4x4 grid)
figure('Name', 'Feature Maps - Layer 2 (relu2)', 'Position', [100, 100, 900, 300]);
montage(rescale(featureMaps2), 'Size', [4 4]);
title('Feature Maps after Second Convolutional Block');
%% 7. Training History Visualization
% Plots per-iteration loss, accuracy and learning rate recorded in
% trainingInfo by the baseline training run.
figure('Position', [100, 100, 1000, 800]);
% Validation metrics are only computed every ValidationFrequency
% iterations; the remaining entries are NaN. Plot only the computed points,
% otherwise the isolated values are not joined into a visible line.
valIter = find(~isnan(trainingInfo.ValidationLoss));
% Loss plot
subplot(2, 2, 1);
plot(trainingInfo.TrainingLoss, 'LineWidth', 1.5);
hold on;
plot(valIter, trainingInfo.ValidationLoss(valIter), '-o', 'LineWidth', 1.5);
title('Training and Validation Loss');
xlabel('Iteration');
ylabel('Loss');
legend('Training', 'Validation');
grid on;
% Accuracy plot
subplot(2, 2, 2);
plot(trainingInfo.TrainingAccuracy, 'LineWidth', 1.5);
hold on;
plot(valIter, trainingInfo.ValidationAccuracy(valIter), '-o', 'LineWidth', 1.5);
title('Training and Validation Accuracy');
xlabel('Iteration');
ylabel('Accuracy (%)');
ylim([80 100]);
legend('Training', 'Validation', 'Location', 'southeast');
grid on;
% Learning rate schedule (the info struct stores it as BaseLearnRate)
subplot(2, 2, 3);
plot(trainingInfo.BaseLearnRate, 'LineWidth', 1.5);
title('Learning Rate Schedule');
xlabel('Iteration');
ylabel('Learning Rate');
grid on;
%% 8. Data Augmentation Improvement (Optional)
% Retrains the same architecture on randomly rotated / shifted / scaled
% training images, then compares accuracy and training time against the
% baseline model.
fprintf('\nTraining with data augmentation for improved performance...\n');
% Create augmented datastore
augmenter = imageDataAugmenter(...
    'RandRotation', [-15 15], ...
    'RandXTranslation', [-2 2], ...
    'RandYTranslation', [-2 2], ...
    'RandXScale', [0.9 1.1], ...
    'RandYScale', [0.9 1.1]);
augimdsTrain = augmentedImageDatastore([28 28], XTrain, YTrain, ...
    'DataAugmentation', augmenter);
% Train with augmented data: reuse the baseline options with more epochs
% and less frequent validation.
options.InitialLearnRate = 0.001;
options.MaxEpochs = 20;
options.ValidationFrequency = 200;
tStartAug = tic;
[netAug, infoAug] = trainNetwork(augimdsTrain, layers, options);
trainingTimeAug = toc(tStartAug);
% Evaluate augmented model
YPredAug = classify(netAug, XTest);
accuracyAug = sum(YPredAug == YTest) / numel(YTest);
fprintf('Augmented Model Training Time: %.2f seconds\n', trainingTimeAug);
fprintf('Augmented Test Accuracy: %.2f%%\n', accuracyAug * 100);
% %+.2f prints the correct sign; a literal '+' prefix would show "+-x.xx"
% whenever the augmented model scores lower than the baseline.
fprintf('Accuracy Improvement: %+.2f%%\n', (accuracyAug - accuracy)*100);
% Compare performance
figure('Position', [100, 100, 800, 400]);
subplot(1, 2, 1);
accuracyPct = [accuracy, accuracyAug] * 100;
bar(accuracyPct);
set(gca, 'XTickLabel', {'Original', 'Augmented'});
ylabel('Accuracy (%)');
title('Model Comparison');
% Keep the 95-100 zoom when both models are above 95%, but lower the axis
% floor if either bar would otherwise fall outside the visible range.
ylim([min(95, floor(min(accuracyPct)) - 1), 100]);
grid on;
subplot(1, 2, 2);
bar([trainingTime, trainingTimeAug]);
set(gca, 'XTickLabel', {'Original', 'Augmented'});
ylabel('Training Time (s)');
title('Training Time Comparison');
grid on;
%% 9. Save Models
% Write each trained network to its own MAT-file in the current folder,
% then confirm on the console.
baselineFile = 'lenet5_mnist.mat';
augmentedFile = 'lenet5_mnist_augmented.mat';
save(baselineFile, 'net');
save(augmentedFile, 'netAug');
fprintf('\nModels saved as "lenet5_mnist.mat" and "lenet5_mnist_augmented.mat"\n');
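% Reported problem: the code file above (saved as untitled3.m) fails with an error at line 98, column 77:
% "Invalid expression. When calling a function or indexing a variable, use parentheses. Otherwise, check for mismatched delimiters."
% Request: change it into correct, complete, runnable code.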