DBNコード解析


転載元: http://blog.csdn.net/Rainbow0210/article/details/53010694?locationNum=1&fps=1
DBNの実装(DeepLearnToolbox):
ここではDBNを教師なし学習(事前学習)の枠組みとして使用し、その学習結果(重み)でニューラルネットワーク(ANN)を初期化してから分類を行う。
トレーニングセットは28×28の手書き数字画像60000枚、テストセットは28×28の手書き数字画像10000枚です。したがって、画像1枚あたりの特徴次元は28×28=784です。
% function test_example_DBN
% Example: pre-train a 100-100 hidden-unit DBN on MNIST without labels,
% copy its weights into a feed-forward NN, fine-tune the NN on the labelled
% data and check the test error. The function header above is commented
% out, so this file runs as a script.
load mnist_uint8;
% NOTE(review): mnist_uint8 is expected to provide train_x, train_y,
% test_x and test_y as uint8 arrays -- confirm against the data file.

% Scale the 0..255 pixel values to [0, 1]; rbmtrain asserts that its input
% is floating point with every entry in [0, 1].
train_x = double(train_x) / 255;
test_x  = double(test_x)  / 255;
% Labels are only converted to double (presumably one-hot rows, 10 columns).
train_y = double(train_y);
test_y  = double(test_y);

%%  ex2 train a 100-100 hidden unit DBN and use its weights to initialize a NN
% Seed the random generator so shuffling/sampling is reproducible.
% NOTE(review): rand('state',0) is the legacy seeding syntax (it switches
% to MATLAB's legacy generator); newer code uses rng(0), but changing it
% would change the random stream and hence the results.
rand('state',0)
%train dbn
% DBN architecture: two hidden layers of 100 units each. dbnsetup prepends
% the input dimension (784 for 28x28 images), giving 784-100-100.
dbn.sizes = [100 100];
% Number of passes over the training set for each RBM.
opts.numepochs =   2;
% Minibatch size; must divide the number of training samples
% (rbmtrain asserts this).
opts.batchsize = 100;
% Momentum coefficient for the parameter updates (0 = plain gradient steps).
opts.momentum  =   0;
% Learning rate.
opts.alpha     =   1;
% Build the DBN: one zero-initialised RBM per pair of adjacent layer sizes.
dbn = dbnsetup(dbn, train_x, opts);
% Greedy layer-wise unsupervised training (uses train_x only, no labels).
dbn = dbntrain(dbn, train_x, opts);
% At this point the DBN holds the pre-trained RBM weights.

%unfold dbn to nn
% Unroll the DBN into a feed-forward NN initialised with the DBN weights.
% Presumably 10 is the number of output units (digit classes) -- confirm
% in dbnunfoldtonn.
nn = dbnunfoldtonn(dbn, 10);

% Use sigmoid ('sigm') activations in the NN.
nn.activation_function = 'sigm';

%train nn
% Supervised fine-tuning of the NN on the labelled training data.
% opts is reused; only numepochs and batchsize are reset here.
opts.numepochs =  3;
opts.batchsize = 100;
nn = nntrain(nn, train_x, train_y, opts);
% er is compared against 0.10 below, so it is presumably the fraction of
% misclassified test samples; bad presumably lists the misclassified ones.
[er, bad] = nntest(nn, test_x, test_y);

assert(er < 0.10, 'Too big error');
  function dbn = dbnsetup(dbn, x, opts)
  %DBNSETUP  Allocate the RBM stack of a DBN before training.
  %   dbn.sizes lists the hidden-layer sizes (row or column vector). The
  %   input dimension size(x, 2) is prepended, so for 28x28 images and
  %   dbn.sizes = [100 100] the layer sizes become [784 100 100].
  %   One RBM is created for every pair of adjacent layers: RBM u has
  %   layer u as its visible layer and layer u+1 as its hidden layer.
  %
  %   x    - training data, one sample per row; only size(x, 2) is used.
  %   opts - alpha (learning rate) and momentum, copied into every RBM.
  %
  %   All weights, biases and their momentum buffers start at zero.
  %   NOTE: dbn.sizes is overwritten with the full layer list, so calling
  %   dbnsetup twice on the same struct prepends the input size again.
      visibleDim = size(x, 2);
      % reshape makes this work for both row and column vectors; a plain
      % [visibleDim, dbn.sizes] fails on a column vector.
      dbn.sizes = [visibleDim, reshape(dbn.sizes, 1, [])];

      for u = 1 : numel(dbn.sizes) - 1
          nVis = dbn.sizes(u);
          nHid = dbn.sizes(u + 1);
          dbn.rbm{u}.alpha    = opts.alpha;
          dbn.rbm{u}.momentum = opts.momentum;
          % Weights and their momentum buffer are hidden x visible
          % (e.g. 100x784 for the first RBM, 100x100 for the second).
          dbn.rbm{u}.W  = zeros(nHid, nVis);
          dbn.rbm{u}.vW = zeros(nHid, nVis);
          % Visible biases b and their momentum buffer (column vectors).
          dbn.rbm{u}.b  = zeros(nVis, 1);
          dbn.rbm{u}.vb = zeros(nVis, 1);
          % Hidden biases c and their momentum buffer (column vectors).
          dbn.rbm{u}.c  = zeros(nHid, 1);
          dbn.rbm{u}.vc = zeros(nHid, 1);
      end
  end
    
  function dbn = dbntrain(dbn, x, opts)
  %DBNTRAIN  Greedy layer-wise training of the RBM stack in dbn.rbm.
  %   The first RBM is trained directly on x. Each later RBM is trained on
  %   the output of the RBM below it, obtained by pushing the previous
  %   layer's training data upward with rbmup.
  %
  %   dbn  - struct whose field rbm is a cell array of RBM structs
  %          (e.g. as built by dbnsetup).
  %   x    - training data, one sample per row.
  %   opts - training options forwarded unchanged to rbmtrain.
      numLayers = numel(dbn.rbm);
      layerInput = x;
      dbn.rbm{1} = rbmtrain(dbn.rbm{1}, layerInput, opts);

      layer = 2;
      while layer <= numLayers
          % The hidden activations of the RBM just trained become the
          % visible data for the next RBM.
          layerInput = rbmup(dbn.rbm{layer - 1}, layerInput);
          dbn.rbm{layer} = rbmtrain(dbn.rbm{layer}, layerInput, opts);
          layer = layer + 1;
      end
  end
    
  function x = rbmup(rbm, x)
  %RBMUP  Propagate data upward through one RBM.
  %   Returns sigm(x * rbm.W' + rbm.c') row by row, i.e. the hidden-unit
  %   activations for each sample (sigm is the sigmoid helper; no sampling
  %   is done here).
  %   rbm - RBM struct with weights W (hidden x visible) and hidden bias c
  %         (column vector).
  %   x   - input data, one sample per row.
      preActivation = bsxfun(@plus, x * rbm.W', rbm.c');
      x = sigm(preActivation);
  end
    
  function rbm = rbmtrain(rbm, x, opts)
  %RBMTRAIN  Train one RBM with one-step contrastive divergence (CD-1).
  %   rbm  - RBM struct with weights W (hidden x visible), visible bias b
  %          and hidden bias c (column vectors), their momentum buffers
  %          vW, vb, vc, plus alpha (learning rate) and momentum.
  %   x    - training data, one sample per row, floating point in [0, 1].
  %   opts - numepochs (passes over x) and batchsize (rows per minibatch;
  %          must divide size(x, 1)).
  %   Returns the updated rbm and prints the average reconstruction error
  %   after every epoch.
      % The input must be floating point with every entry in [0, 1].
      assert(isfloat(x), 'x must be a float');
      assert(all(x(:)>=0) && all(x(:)<=1), 'all data in x must be in [0:1]');

      % m = number of samples (e.g. 60000 for the MNIST training set).
      m = size(x, 1);
      numbatches = m / opts.batchsize;
      % Only full minibatches are supported: batchsize must divide m.
      assert(rem(numbatches, 1) == 0, 'numbatches not integer');

      for i = 1 : opts.numepochs
          % Random permutation of the sample indices, so each epoch visits
          % the minibatches in a new random order.
          kk = randperm(m);
          err = 0;

          for l = 1 : numbatches
              % Rows of the l-th minibatch of the shuffled order.
              batch = x(kk((l - 1) * opts.batchsize + 1 : l * opts.batchsize), :);

              % One Gibbs step: data v1 -> sampled hidden h1 -> sampled
              % reconstruction v2 -> hidden probabilities h2.
              v1 = batch;
              h1 = sigmrnd(repmat(rbm.c', opts.batchsize, 1) + v1 * rbm.W');
              v2 = sigmrnd(repmat(rbm.b', opts.batchsize, 1) + h1 * rbm.W);
              h2 = sigm(repmat(rbm.c', opts.batchsize, 1) + v2 * rbm.W');

              % Positive (data) and negative (reconstruction) statistics.
              c1 = h1' * v1;
              c2 = h2' * v2;

              % Sum over the batch dimension explicitly (dim 1). A bare
              % sum() picks dim 2 when batchsize == 1 (a 1xN row), which
              % produced a scalar that was broadcast into every bias.
              rbm.vW = rbm.momentum * rbm.vW + rbm.alpha * (c1 - c2)        / opts.batchsize;
              rbm.vb = rbm.momentum * rbm.vb + rbm.alpha * sum(v1 - v2, 1)' / opts.batchsize;
              rbm.vc = rbm.momentum * rbm.vc + rbm.alpha * sum(h1 - h2, 1)' / opts.batchsize;

              % Apply the momentum-smoothed updates.
              rbm.W = rbm.W + rbm.vW;
              rbm.b = rbm.b + rbm.vb;
              rbm.c = rbm.c + rbm.vc;

              % Accumulate the per-sample squared reconstruction error.
              err = err + sum(sum((v1 - v2) .^ 2)) / opts.batchsize;
          end
          disp(['epoch ' num2str(i) '/' num2str(opts.numepochs)  '. Average reconstruction error is: ' num2str(err / numbatches)]);
      end
  end