I found it here: https://github.com/rasmusbergpalm/DeepLearnToolbox
Of course, searching for "deep learning toolbox" (DeepLearnToolbox) works too.
function test_example_DBN
load mnist_uint8;
train_x = double(train_x) / 255;
test_x = double(test_x) / 255;
train_y = double(train_y);
test_y = double(test_y);
%% ex1 train a 100 hidden unit RBM and visualize its weights
rand('state',0)
dbn.sizes = [100];
opts.numepochs = 1;
opts.batchsize = 100;
opts.momentum = 0;
opts.alpha = 1;
dbn = dbnsetup(dbn, train_x, opts);
dbn = dbntrain(dbn, train_x, opts);
figure; visualize(dbn.rbm{1}.W'); % Visualize the RBM weights
%% ex2 train a 100-100 hidden unit DBN and use its weights to initialize a NN
rand('state',0)
%train dbn
dbn.sizes = [100 100];
opts.numepochs = 1;
opts.batchsize = 100;
opts.momentum = 0;
opts.alpha = 1;
dbn = dbnsetup(dbn, train_x, opts);
dbn = dbntrain(dbn, train_x, opts);
%unfold dbn to nn
nn = dbnunfoldtonn(dbn, 10);
nn.activation_function = 'sigm';
%train nn
opts.numepochs = 1;
opts.batchsize = 100;
nn = nntrain(nn, train_x, train_y, opts);
[er, bad] = nntest(nn, test_x, test_y);
assert(er < 0.10, 'Too big error');
The error is:
epoch 1/1. Average reconstruction error is: 66.2661
epoch 1/1. Average reconstruction error is: 66.2661
epoch 1/1. Average reconstruction error is: 10.286
Attempted to access lmisys(5); index out of bounds because numel(lmisys)=4.
Error in lmiunpck (line 23)
rs=lmisys(4); rv=lmisys(5); % row sizes of LMISET,LMIVAR
Error in nnsetup (line 26)
[LMI_set,LMI_var,LMI_term,data]=lmiunpck(lmisys);
Error in dbnunfoldtonn (line 6)
nn = nnsetup([dbn.sizes outputsize]);
Error in test_DBN (line 32)
nn = dbnunfoldtonn(dbn, 10);
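One possible reading of this trace, offered as a guess rather than a confirmed diagnosis: `lmiunpck` is an LMI routine that ships with MATLAB's Robust Control Toolbox, not with DeepLearnToolbox. That suggests the `nnsetup` being called at line 26 may be a different function of the same name that shadows the toolbox's own `nnsetup`. The sketch below is one way to check which file MATLAB resolves. The `toolboxRoot` path is a hypothetical placeholder, not something taken from the post.

```matlab
% List every nnsetup visible on the MATLAB path.
% The first entry printed is the one MATLAB actually calls.
which -all nnsetup

% ASSUMPTION: hypothetical location of the downloaded toolbox; adjust to your setup.
toolboxRoot = fullfile(pwd, 'DeepLearnToolbox');

% Put the toolbox folders (and subfolders) at the front of the search path,
% so its functions take precedence over same-named functions elsewhere.
addpath(genpath(toolboxRoot), '-begin');
rehash;  % refresh MATLAB's cached view of the path

% Check again: this should now point inside toolboxRoot if shadowing was the cause.
which nnsetup
```

If the first `which -all nnsetup` already lists only the DeepLearnToolbox file, shadowing is not the cause and the problem lies elsewhere.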