深層学習のGRUアルゴリズムの例


まず、https://github.com/whk6688/rnn からコードをダウンロードします。
例1：次の文字の予測（コードは以下のとおりです）
/**
 * Trains the GRU to predict the next character of every training sequence.
 *
 * <p>Runs 100 epochs. Sequences shorter than 3 characters are skipped. For each
 * remaining sequence the forward pass feeds characters {@code 0 .. length-2}; at
 * step {@code t} it stores the input ({@code "x" + t}), the decoded prediction
 * ({@code "py" + t}) and the target, i.e. the vector of the character at
 * {@code t + 1} ({@code "y" + t}). The first character and the predicted
 * characters are printed on one line, then {@code gru.bptt} is run from the last
 * filled step index ({@code length - 2}).
 *
 * @param ctext corpus supplying the index-to-character map, the character vectors
 *     and the training sequences
 * @param lr learning rate forwarded to {@code gru.bptt}
 */
private void train(CharText ctext, double lr) {
    Map<Integer, String> indexChar = ctext.getIndexChar();
    Map<String, DoubleMatrix> charVector = ctext.getCharVector();
    List<String> sequence = ctext.getSequence();
    for (int i = 0; i < 100; i++) {
        // TODO: error is never accumulated because the loss call below is commented
        // out, so the reported error is always 0 (or NaN when no sequence was trained).
        double error = 0;
        double num = 0;
        double start = System.currentTimeMillis();
        for (String seq : sequence) {
            if (seq.length() < 3) {
                continue;
            }

            Map<String, DoubleMatrix> acts = new HashMap<>();
            // forward pass: at step t the model sees seq[t] and should predict seq[t + 1]
            System.out.print(seq.charAt(0) + "->");
            for (int t = 0; t < seq.length() - 1; t++) {
                DoubleMatrix xt = charVector.get(String.valueOf(seq.charAt(t)));
                acts.put("x" + t, xt);

                gru.active(t, acts);

                DoubleMatrix predictedYt = gru.decode(acts.get("h" + t));
                acts.put("py" + t, predictedYt);
                DoubleMatrix trueYt = charVector.get(String.valueOf(seq.charAt(t + 1)));
                acts.put("y" + t, trueYt);

                System.out.print(indexChar.get(predictedYt.argmax()));
                // error += LossFunction.getMeanCategoricalCrossEntropy(predictedYt, trueYt);
            }
            System.out.println();

            // back-propagation through time, starting from the last filled step (length - 2)
            gru.bptt(acts, seq.length() - 2, lr);

            // one prediction per step, i.e. length - 1 predictions per sequence
            num += seq.length() - 1;
        }
        System.out.println("Iter = " + i + ", error = " + error / num + ", time = "
                + (System.currentTimeMillis() - start) / 1000 + "s");
    }
}

    /**
     * Feeds the seed character {@code " "} through the trained GRU and prints the
     * character the model predicts next (no trailing newline).
     *
     * @param ctext corpus supplying the index-to-character map and character vectors
     * @throws IllegalArgumentException if {@code " "} has no character vector
     */
    private void test(CharText ctext) {
        test(ctext, " ");
    }

    /**
     * Feeds every character of {@code seed} through the GRU as consecutive time steps
     * (the same {@code t = 0, 1, ...} pattern used by {@code train}). It then prints
     * the character predicted after the last one, with no trailing newline.
     *
     * @param ctext corpus supplying the index-to-character map and character vectors
     * @param seed non-empty prefix to condition the prediction on
     * @throws IllegalArgumentException if {@code seed} is null/empty or contains a
     *     character that has no vector in {@code ctext}
     */
    private void test(CharText ctext, String seed) {
        if (seed == null || seed.isEmpty()) {
            throw new IllegalArgumentException("seed must be a non-empty string");
        }
        Map<Integer, String> indexChar = ctext.getIndexChar();
        Map<String, DoubleMatrix> charVector = ctext.getCharVector();
        Map<String, DoubleMatrix> acts = new HashMap<>();

        DoubleMatrix predictedYt = null;
        for (int t = 0; t < seed.length(); t++) {
            String ch = String.valueOf(seed.charAt(t));
            DoubleMatrix xt = charVector.get(ch);
            if (xt == null) {
                // fail fast instead of handing a null input vector to the GRU
                throw new IllegalArgumentException("character not in vocabulary: '" + ch + "'");
            }
            acts.put("x" + t, xt);
            gru.active(t, acts);
            predictedYt = gru.decode(acts.get("h" + t));
            acts.put("py" + t, predictedYt);
        }
        System.out.print(indexChar.get(predictedYt.argmax()));
    }

学習に使用したテキストは次のとおりです（※元記事の本文は抽出時に失われ、以下は空行のみになっています）。
    
    
    
    
    
    
  ×××
    
    
    
    
    
    
    
    
    
    
    
    
    
    
    
    
    
×××
    
    
    
    
    
    
    
    

「いいえ」にあたる文字を入力すると、次の文字として「花」が予測されます。このアルゴリズムには時間（系列の順序）の概念があるため、成語の先頭ではない2文字を前に加えて入力すると、結果が変わることが分かります。
例2：花の種類の予測
/**
 * Example 2: trains a GRU classifier on the rows of {@code train_x} (inputs) and
 * {@code train_y} (targets). Each iteration prints the fraction of samples whose
 * predicted argmax differs from the target argmax. The method then starts an
 * evaluation pass.
 *
 * <p>NOTE(review): this listing is truncated in the source article at the evaluation
 * loop ({@code for(int s = 0; s ...}). The rest of the method and its closing braces
 * are missing, so it does not compile as shown.
 */
public static void main(String[] args) {  
        loadData();  // presumably fills train_x / train_y — not shown here
        int hiddenSize = 4;// hidden-state size passed to the GRU constructor
        double lr = 0.1;  
        gru = new GRU(4, hiddenSize, new MatIniter(MatIniter.Type.Uniform, 0.1, 0, 0),3);// presumably 4 = input size, 3 = output size (matches the [1][4] / [1][3] arrays below) — TODO confirm against GRU's constructor
        for (int i = 0; i < 2000; i++) {// 2000 training iterations
            double error = 0;  
            double num = 0;  
            double start = System.currentTimeMillis();  
            // NOTE(review): one activation map per iteration. All samples are fed as
            // one sequence (time step = sample index s), so later samples presumably
            // see the hidden state left by earlier ones.
            Map acts = new HashMap<>();  
            for (int s = 0; s < train_x.length; s++) {  
                double newx[][] = new double[1][4];  
                newx[0] = train_x[s];  // replaces the freshly allocated row; the [4] inner array is unused
                DoubleMatrix xt = new DoubleMatrix(newx);// wrap sample s as a one-row matrix
                //System.out.println(xt.getColumns()+" "+xt.getRows());  
                acts.put("x" + s, xt);  
                gru.active(s, acts);  
                DoubleMatrix predcitYt = gru.decode(acts.get("h" + s));  
                acts.put("py" + s, predcitYt);  

                double newy[][] = new double[1][3];  
                newy[0] = train_y[s];  
                DoubleMatrix trueYt = new DoubleMatrix(newy);  
                acts.put("y" + s, trueYt);  

                //System.out.println(predcitYt.argmax()+"-->"+trueYt.argmax());

                // count a misclassification when the predicted class differs from the target
                if(predcitYt.argmax()!=trueYt.argmax())  
                    error++;  

                // (bptt runs once per iteration, after this loop)
                num ++;  
            }  
            // back-propagation through time over the whole sample sequence
            gru.bptt(acts, train_x.length-1, lr);  
            System.out.println("Iter = " + i + ", error = " + error / num + ", time = " + (System.currentTimeMillis() - start) / 1000 + "s");  
        }// end of training loop

        // evaluation pass (remainder lost in the source article)
        int num = 0,error = 0;  
        Map acts = new HashMap<>();  
        for(int s = 0; s
        // NOTE(review): listing truncated here — the loop condition, loop body and closing braces are missing.

この例では花の種類を予測しています。もちろん決定木を用いて実現することもできますが、手法を変えて試してみるのも良いでしょう。
参照:https://blog.csdn.net/czs1130/article/details/70717348