深さ学習のGRUアルゴリズムの例
まずコードをダウンロードします.https://github.com/whk6688/rnn
例1:予測以下
トレーニングのテキストは次のとおりです.
「いいえ」と入力すると、次の単語は「花」として提示されます.このアルゴリズムには時間の概念があるので、2つの単語の先頭にない成語を加えると、結果が異なることがわかります.
例2:予測結果
この例では花の種類を予測し,もちろん決定木を用いて実現することもできる.方法を変えてもいい感じだ
参照:https://blog.csdn.net/czs1130/article/details/70717348
例1:予測以下
private void train(CharText ctext, double lr) {
Map indexChar = ctext.getIndexChar();
Map charVector = ctext.getCharVector();
List sequence = ctext.getSequence();
for (int i = 0; i < 100; i++) {
double error = 0;
double num = 0;
double start = System.currentTimeMillis();
for (int s = 0; s < sequence.size(); s++) {
String seq = sequence.get(s);
if (seq.length() < 3) {
continue;
}
Map acts = new HashMap<>();
// forward pass
System.out.print(String.valueOf(seq.charAt(0)+"->"));
for (int t = 0; t < seq.length() - 1; t++) {
DoubleMatrix xt = charVector.get(String.valueOf(seq.charAt(t)));
acts.put("x" + t, xt);
gru.active(t, acts);
DoubleMatrix predcitYt = gru.decode(acts.get("h" + t));
acts.put("py" + t, predcitYt);
DoubleMatrix trueYt = charVector.get(String.valueOf(seq.charAt(t + 1)));
acts.put("y" + t, trueYt);
System.out.print(indexChar.get(predcitYt.argmax()));
//error += LossFunction.getMeanCategoricalCrossEntropy(predcitYt, trueYt);
}
System.out.println();
// bptt
gru.bptt(acts, seq.length() - 2, lr);
num += seq.length();
}
System.out.println("Iter = " + i + ", error = " + error / num + ", time = " + (System.currentTimeMillis() - start) / 1000 + "s");
}
}
private void test(CharText ctext) {
Map indexChar = ctext.getIndexChar();
Map charVector = ctext.getCharVector();
Map acts = new HashMap<>();
String seq=" ";
int t=0;
DoubleMatrix xt = charVector.get(String.valueOf(seq.charAt(t)));
acts.put("x" + t, xt);
gru.active(t, acts);
DoubleMatrix predcitYt = gru.decode(acts.get("h" + t));
acts.put("py" + t, predcitYt);
System.out.print(indexChar.get(predcitYt.argmax()));
}
トレーニングのテキストは次のとおりです.
×××
×××
「いいえ」と入力すると、次の単語は「花」として提示されます.このアルゴリズムには時間の概念があるので、2つの単語の先頭にない成語を加えると、結果が異なることがわかります.
例2:予測結果
public static void main(String[] args) {
loadData();
int hiddenSize = 4;//
double lr = 0.1;
gru = new GRU(4, hiddenSize, new MatIniter(MatIniter.Type.Uniform, 0.1, 0, 0),3);//4 ,3
for (int i = 0; i < 2000; i++) {// 2000
double error = 0;
double num = 0;
double start = System.currentTimeMillis();
Map acts = new HashMap<>();
for (int s = 0; s < train_x.length; s++) {
double newx[][] = new double[1][4];
newx[0] = train_x[s];
DoubleMatrix xt = new DoubleMatrix(newx);//
//System.out.println(xt.getColumns()+" "+xt.getRows());
acts.put("x" + s, xt);
gru.active(s, acts);
DoubleMatrix predcitYt = gru.decode(acts.get("h" + s));
acts.put("py" + s, predcitYt);
double newy[][] = new double[1][3];
newy[0] = train_y[s];
DoubleMatrix trueYt = new DoubleMatrix(newy);
acts.put("y" + s, trueYt);
//System.out.println(predcitYt.argmax()+"-->"+trueYt.argmax());
if(predcitYt.argmax()!=trueYt.argmax())
error++;
// bptt
num ++;
}
gru.bptt(acts, train_x.length-1, lr);
System.out.println("Iter = " + i + ", error = " + error / num + ", time = " + (System.currentTimeMillis() - start) / 1000 + "s");
}//
//
int num = 0,error = 0;
Map acts = new HashMap<>();
for(int s = 0; s
この例では花の種類を予測し,もちろん決定木を用いて実現することもできる.方法を変えてもいい感じだ
参照:https://blog.csdn.net/czs1130/article/details/70717348