DBNコード解析
20751 ワード
転載元:http://blog.csdn.net/Rainbow0210/article/details/53010694?locationNum=1&fps=1
DBNの実装(DeepLearn Toolbox):
ここではDBNを教師なし学習の枠組みとして使用し、事前学習で得られた重みをANN(ニューラルネットワーク)の初期値として与えて分類を行う.
トレーニングセットは60000枚、テストセットは10000枚で、いずれも28*28の手書き数字画像です.1枚の画像に対応する特徴次元は28*28=784です.
DBNの実装(DeepLearn Toolbox):
ここではDBNを教師なし学習の枠組みとして使用し、事前学習で得られた重みをANN(ニューラルネットワーク)の初期値として与えて分類を行う.
トレーニングセットは60000枚、テストセットは10000枚で、いずれも28*28の手書き数字画像です.1枚の画像に対応する特徴次元は28*28=784です.
% function test_example_DBN
% Example script: pre-train a two-layer DBN on the training images, unfold
% it into a feed-forward NN, train that NN with the labels, and assert that
% the test error stays below 10%.
%
% Load the data file. The code below uses train_x, train_y, test_x and
% test_y, so presumably mnist_uint8 provides those four variables
% -- TODO confirm against the data file.
load mnist_uint8;
% Convert the images to double and divide by 255; rbmtrain asserts that
% every input value lies in [0, 1].
train_x = double(train_x) / 255;
test_x = double(test_x) / 255;
% Convert the labels to double as well.
train_y = double(train_y);
test_y = double(test_y);
%% ex2 train a 100-100 hidden unit DBN and use its weights to initialize a NN
% Seed the uniform random generator so repeated runs are reproducible.
rand('state',0)
%train dbn
% Hidden-layer sizes of the DBN: two hidden layers of 100 units each.
% dbnsetup prepends the input dimension to this vector.
dbn.sizes = [100 100];
% Number of passes over the training data for each RBM.
opts.numepochs = 2;
% Mini-batch size; rbmtrain asserts that it divides the sample count evenly.
opts.batchsize = 100;
% Momentum coefficient applied to the previous update (0 disables it).
opts.momentum = 0;
% Learning rate of the RBM updates.
opts.alpha = 1;
% Allocate the zero-initialised parameters of every RBM.
dbn = dbnsetup(dbn, train_x, opts);
% Train the RBMs one after another, each on the output of the previous one.
dbn = dbntrain(dbn, train_x, opts);
%unfold dbn to nn
% dbnunfoldtonn is defined outside this file; presumably it builds a NN
% initialised from the DBN weights with a 10-unit output layer
% -- TODO confirm.
nn = dbnunfoldtonn(dbn, 10);
% Select the 'sigm' activation function for the NN.
nn.activation_function = 'sigm';
%train nn
% Reuse opts for the supervised stage with 3 epochs and batches of 100.
% nntrain / nntest are defined outside this file.
opts.numepochs = 3;
opts.batchsize = 100;
nn = nntrain(nn, train_x, train_y, opts);
% er is compared against the 10% threshold below; bad is not used afterwards.
[er, bad] = nntest(nn, test_x, test_y);
assert(er < 0.10, 'Too big error');
function dbn = dbnsetup(dbn, x, opts)
%DBNSETUP Allocate zero-initialised parameters for every RBM of a DBN.
%   dbn = DBNSETUP(dbn, x, opts) prepends the number of columns of x to
%   dbn.sizes and creates one RBM in dbn.rbm for each adjacent pair of
%   layer sizes.
%
%   Inputs:
%     dbn  - struct whose field 'sizes' lists the hidden-layer sizes
%     x    - data matrix; only its column count (input dimension) is used
%     opts - struct with fields 'alpha' and 'momentum', copied into each RBM
%
%   Output:
%     dbn  - same struct with 'sizes' extended and 'rbm' populated

    % The input dimension becomes the size of the first visible layer.
    inputDim = size(x, 2);
    dbn.sizes = [inputDim, dbn.sizes];

    % One RBM sits between each pair of consecutive layers.
    numRbms = numel(dbn.sizes) - 1;

    for layer = 1 : numRbms
        nVisible = dbn.sizes(layer);
        nHidden  = dbn.sizes(layer + 1);

        % Hyper-parameters copied from opts.
        dbn.rbm{layer}.alpha    = opts.alpha;
        dbn.rbm{layer}.momentum = opts.momentum;

        % Weight matrix (hidden x visible) and its update accumulator.
        dbn.rbm{layer}.W  = zeros(nHidden, nVisible);
        dbn.rbm{layer}.vW = zeros(nHidden, nVisible);

        % Visible-side bias column and its update accumulator.
        dbn.rbm{layer}.b  = zeros(nVisible, 1);
        dbn.rbm{layer}.vb = zeros(nVisible, 1);

        % Hidden-side bias column and its update accumulator.
        dbn.rbm{layer}.c  = zeros(nHidden, 1);
        dbn.rbm{layer}.vc = zeros(nHidden, 1);
    end
end
function dbn = dbntrain(dbn, x, opts)
%DBNTRAIN Train the RBMs of a DBN one layer at a time.
%   dbn = DBNTRAIN(dbn, x, opts) trains dbn.rbm{1} on x, then for every
%   following RBM passes the data through the previously trained RBM with
%   rbmup and trains the next RBM on that output.
%
%   Inputs:
%     dbn  - struct with a cell array 'rbm' (as created by dbnsetup)
%     x    - training data, one sample per row
%     opts - options struct forwarded unchanged to rbmtrain
%
%   Output:
%     dbn  - same struct with every dbn.rbm{k} replaced by its trained version

    layerCount = numel(dbn.rbm);

    % The first RBM is trained directly on the raw input.
    dbn.rbm{1} = rbmtrain(dbn.rbm{1}, x, opts);

    % Each later RBM is trained on the output of the RBM below it.
    layerInput = x;
    for k = 1 : layerCount - 1
        layerInput = rbmup(dbn.rbm{k}, layerInput);
        dbn.rbm{k + 1} = rbmtrain(dbn.rbm{k + 1}, layerInput, opts);
    end
end
function x = rbmup(rbm, x)
%RBMUP Propagate data upward through one RBM.
%   x = RBMUP(rbm, x) returns sigm(c' + x * W') evaluated row by row,
%   where rbm.c is added to every row of x * rbm.W'.
%   sigm is defined outside this file (presumably the logistic sigmoid
%   -- TODO confirm).

    % Replicate the transposed bias c once per input row.
    numRows = size(x, 1);
    biasRows = repmat(rbm.c', numRows, 1);

    % Pre-activation of the upper layer, then squashing.
    x = sigm(biasRows + x * rbm.W');
end
function rbm = rbmtrain(rbm, x, opts)
%RBMTRAIN Train a single RBM with mini-batch updates.
%   rbm = RBMTRAIN(rbm, x, opts) runs opts.numepochs passes over x. In each
%   pass the rows of x are shuffled, split into batches of opts.batchsize
%   rows, and for every batch one up-down-up pass (v1 -> h1 -> v2 -> h2) is
%   performed. The weights and biases are then moved by the difference
%   between the data-side and reconstruction-side statistics.
%
%   Inputs:
%     rbm  - struct with fields W, vW, b, vb, c, vc, alpha, momentum
%     x    - floating-point data, one sample per row, all values in [0, 1]
%     opts - struct with fields numepochs and batchsize; batchsize must
%            divide the number of rows of x evenly
%
%   Output:
%     rbm  - the struct with updated W, b, c and accumulators vW, vb, vc
%
%   Errors (via assert) if x is not floating point, contains values outside
%   [0, 1], or its row count is not a multiple of opts.batchsize.
%
%   sigm and sigmrnd are defined outside this file; presumably sigm is the
%   logistic sigmoid and sigmrnd draws binary samples from it -- TODO confirm.

    % Validate the input data.
    assert(isfloat(x), 'x must be a float');
    assert(all(x(:)>=0) && all(x(:)<=1), 'all data in x must be in [0:1]');

    % m is the number of training samples (rows of x).
    m = size(x, 1);

    % The data must split into whole batches.
    numbatches = m / opts.batchsize;
    assert(rem(numbatches, 1) == 0, 'numbatches not integer');

    for i = 1 : opts.numepochs
        % Fresh random ordering of the samples for this epoch.
        kk = randperm(m);

        % Accumulated reconstruction error of this epoch.
        err = 0;

        for l = 1 : numbatches
            % Rows of the l-th batch in the shuffled order.
            batch = x(kk((l - 1) * opts.batchsize + 1 : l * opts.batchsize), :);

            % Up-down-up pass: data -> hidden -> reconstruction -> hidden.
            v1 = batch;
            h1 = sigmrnd(repmat(rbm.c', opts.batchsize, 1) + v1 * rbm.W');
            v2 = sigmrnd(repmat(rbm.b', opts.batchsize, 1) + h1 * rbm.W);
            h2 = sigm(repmat(rbm.c', opts.batchsize, 1) + v2 * rbm.W');

            % Hidden-by-visible correlation on the data and on the reconstruction.
            c1 = h1' * v1;
            c2 = h2' * v2;

            % Momentum-smoothed updates, averaged over the batch.
            % The sums are taken explicitly along dimension 1 (over the
            % samples). Without the dimension argument, a single-row batch
            % (opts.batchsize == 1) would be collapsed to a scalar and every
            % bias would receive the same value.
            rbm.vW = rbm.momentum * rbm.vW + rbm.alpha * (c1 - c2) / opts.batchsize;
            rbm.vb = rbm.momentum * rbm.vb + rbm.alpha * sum(v1 - v2, 1)' / opts.batchsize;
            rbm.vc = rbm.momentum * rbm.vc + rbm.alpha * sum(h1 - h2, 1)' / opts.batchsize;

            % Apply the updates.
            rbm.W = rbm.W + rbm.vW;
            rbm.b = rbm.b + rbm.vb;
            rbm.c = rbm.c + rbm.vc;

            % Squared reconstruction error of this batch, per sample.
            err = err + sum(sum((v1 - v2) .^ 2)) / opts.batchsize;
        end

        % Report the mean per-batch reconstruction error of the epoch.
        disp(['epoch ' num2str(i) '/' num2str(opts.numepochs) '. Average reconstruction error is: ' num2str(err / numbatches)]);
    end
end