当cnn_mnist.m运行完成后,我们再打开data文件夹里的mnist-baseline-simplenn文件夹,就会发现里面多了一个pdf文件和20个net-epoch-(1~20).mat,这20个net-epoch-(1~20).mat,就是经过每一轮训练后,获得的训练好的模型。
如果在训练的时候选择了opts.batchNormalization为true的话,即进行批量归一化,那么生成的文件夹便是mnist-baseline-simplenn-bnorm,文件夹下也会有20个模型。在测试的时候,如果使用此模型,并且对图像仅仅是进行了归一化和减去均值操作,那么测试便得不到想要的结果。
在此按照ImageNet测试的demo写了一个mnist测试的代码,有关注意事项在代码中说明
run ../matlab/vl_setupnn
load('../data\mnist-baseline-simplenn/net-epoch-20.mat');%此模型包含三个部分,其中一部分为net
load('../data\mnist-baseline-simplenn-bnorm/imdb.mat');%images结构体在此读取
net = vl_simplenn_tidy(net);
net.layers{1,end}.type = 'softmax';%训练时为softmaxloss,测试时为softmax
test_index = find(images.set==3);%1对应训练集,3对应测试集,1有(1——60000)3有(60001——70000)
% 挑选出测试集以及真实类别
test_data = images.data(:,:,:,test_index);
test_label = images.labels(test_index);
im_ = test_data(:,:,:,536);%随意选取一张图像
% im=imread('5.jpg');
% im_=single(im);
im_=imresize(im_,net.meta.inputSize(1:2));%此处和ImageNet网络名称不同
im_ = im_ - images.data_mean;去均值
% im_=im_-net.meta.normalization.averageImage;
res=vl_simplenn(net,im_);
y=res(end).x;
x=gather(res(end).x);
scores=squeeze(gather(res(end).x));
[bestScore,best]=max(scores);
figure(1);
clf;
imshow(im_);
title(sprintf('%s %d,%.3f',...
net.meta.classes.name{best-1},best-1,bestScore));
另外还有一个对序列号为60000-70000图像进行整体精度预测的代码,大致思路与上面相同
run ../matlab/vl_setupnn
load('../data\mnist-baseline-simplenn/net-epoch-11.mat');%此处换成自己下载模型存储的位置
load('../data\mnist-baseline-simplenn-bnorm/imdb.mat');
net = vl_simplenn_tidy(net);
net.layers{1,end}.type = 'softmax';%训练时为softmaxloss,测试时为softmax
% 挑选出测试样本在全体数据中对应的编号60001-70000
test_index = find(images.set==3);%1对应训练集,3对应测试集,1有(1——60000)3有(60001——70000)
% 挑选出测试集以及真实类别
test_data = images.data(:,:,:,test_index);
test_label = images.labels(test_index);
% 将最后一层改为 softmax (原始为softmaxloss,这是训练用)
net.layers{1, end}.type = 'softmax';
% 对每张测试图片进行分类
for i = 1:length(test_label)
i
im_ = test_data(:,:,:,i);
im_ = im_ - images.data_mean;
res = vl_simplenn(net, im_) ;
scores = squeeze(gather(res(end).x)) ;
[bestScore, best] = max(scores) ;
pre(i) = best;
end
% 计算准确率
accurcy = length(find(pre==test_label))/length(test_label);
disp(['accurcy = ',num2str(accurcy*100),'%']);