dlibライブラリsvm_c_ex.cppの詳細な注記

15061 ワード

dlibは現代的なC++技術で書かれたクロスプラットフォームの汎用アルゴリズムライブラリであり、dlibを利用すれば、特定の用途や研究分野にアルゴリズムを迅速に適用できます。dlibライブラリの各クラスと各関数には詳細なドキュメントがあり、多くのサンプルコードも提供されていますが、dlibライブラリは非常に規模が大きく、ネット上の紹介記事も比較的少ないため、初心者にとってはいくつか難しい点があります。本稿では、dlibライブラリにおけるC-SVMのサンプル実装であるsvm_c_ex.cppについて詳しく紹介します。不足している点があれば、読者の皆様からご指摘いただければ幸いです。
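/*
      This example illustrates the SVM tools of the Dlib library, in
      particular the C parametrization of the SVM.

      It builds a simple data set, then uses cross validation and the SVM
      training functions to find a decision function that classifies the
      samples in that data set.

      The data is 2-dimensional: points whose distance from the origin is at
      most 10 are labeled +1, and all other points are labeled -1.
*/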


#include <iostream>
#include <dlib/svm.h>

using namespace std;
using namespace dlib;//      

int main()
{
    /*SVM        ,        。            typedef*/
    /*typedef int size,typedef          。    size  int      */
    /*             2    。           ,   2    ;
                 ,     2  0,   matrix.set_size()      */
    typedef matrix<double, 2, 1> sample_type;

    /*typedef          kernel  。
            (radial basis kernel)       2D  。
                with tools  */
    typedef radial_basis_kernel kernel_type;

    /*      samples labels,        */
    std::vector samples;
    std::vector<double> labels;

    /* samples labels  。                             。*/
    for (int r = -20; r <= 20; ++r)
    {
        for (int c = -20; c <= 20; ++c)
        {
            sample_type samp;//  sample_type    2 1    samp
            samp(0) = r;
            samp(1) = c;// samp    
            samples.push_back(samp);// samp       samples

            // if this point is less than 10 from the origin
            if (sqrt((double)r*r + c*c) <= 10)//sqrt((double)r*r + c*c)  2D     (0,0)   
                labels.push_back(+1);// +1
            else
                labels.push_back(-1);// -1. labels    ,push_back

        }
    }
    /*
        -20 +20,    -20 +20
       -1 +1
    (-1,1)   (0,1)   (1,1)

    (-1,0)   (0,0)   (1,0)

    (-1,-1)  (0,-1)  (1,-1)
            ,       <10  +1,   -1
    */

    /*                              。     ,  it often heads off numerical stability problems and
    also prevents one large feature from smothering others。          ,      :*/
    vector_normalizer normalizer;//        normalizer

    /*                      */
    normalizer.train(samples);

    // now normalize each sample        : samples.size()     
    for (unsigned long i = 0; i < samples.size(); ++i)
        samples[i] = normalizer(samples[i]);

    /*          。  ,           -c gamma  。                。
          cross_validate_trainer()          n               。
      ,             ,            。
      ,                    。
               ,                     :*/

    /*  gamma      sigma*/
    randomize_samples(samples, labels);//           

    /*          kernel_type  svm_c_trainer   trainer*/
    svm_c_trainer trainer;

    /*     C gamma      ,    C gamma   。
           C gamma          。
          model_selection_ex.cpp,                   */
    cout << "doing cross validation" << endl;
    //    
    for (double gamma = 0.00001; gamma <= 1; gamma *= 5)
    {
        for (double C = 1; C < 100000; C *= 5)//gamma=0.00001:1,   5;C=1:100000,   5
        {
            trainer.set_kernel(kernel_type(gamma));
            trainer.set_c(C);//     ,  C gamma       

            cout << "gamma: " << gamma << "    C: " << C;

            /*    C gamma  3     。
            cross_validate_trainer()        。            +1        ,         -1 。*/
            cout << "     cross validation accuracy: " 
                 << cross_validate_trainer(trainer, samples, labels, 3);//3     
        }
    }

    /*          ,    C gamma     5 0.15625。*/

    /*                  。
             +1     >=0  ,    -1     <0  。*/
    trainer.set_kernel(kernel_type(0.15625));
    trainer.set_c(5);//     c gamma  5 0.15625

    typedef decision_function dec_funct_type;
    typedef normalized_function funct_type;//funct_type        

    /*                。
               ,                        。*/
    funct_type learned_function;
    learned_function.normalizer = normalizer;  // save normalization information       
    learned_function.function = trainer.train(samples, labels); // perform the actual SVM training and save the results

    /*               (basis_vectors)*/
    cout << "
number of support vectors in our learned_function is "
<< learned_function.function.basis_vectors.size() << endl; // , (learned_function) sample_type sample;//sample 2 1 sample(0) = 3.123; sample(1) = 2;// (2D )。 +1, >=0. -1, <0。 // <10 +1, -1 cout << "This is a +1 class example, the classifier output is " << learned_function(sample) << endl; sample(0) = 3.123; sample(1) = 9.3545; cout << "This is a +1 class example, the classifier output is " << learned_function(sample) << endl; sample(0) = 13.123; sample(1) = 9.3545; cout << "This is a -1 class example, the classifier output is " << learned_function(sample) << endl; sample(0) = 13.123; sample(1) = 0; cout << "This is a -1 class example, the classifier output is " << learned_function(sample) << endl; /* , , 0 +1 -1 */ typedef probabilistic_decision_function probabilistic_funct_type; typedef normalized_function pfunct_type; pfunct_type learned_pfunct;// learned_pfunct learned_pfunct.normalizer = normalizer; learned_pfunct.function = train_probabilistic_decision_function(trainer, samples, labels, 3); // learned_pfunct +1 -1 。 //+1 ,-1 // , learned_function cout << "
number of support vectors in our learned_pfunct is "
<< learned_pfunct.function.decision_funct.basis_vectors.size() << endl; sample(0) = 3.123; sample(1) = 2; cout << "This +1 class example should have high probability. Its probability is: " << learned_pfunct(sample) << endl; sample(0) = 3.123; sample(1) = 9.3545; cout << "This +1 class example should have high probability. Its probability is: " << learned_pfunct(sample) << endl; sample(0) = 13.123; sample(1) = 9.3545; cout << "This -1 class example should have low probability. Its probability is: " << learned_pfunct(sample) << endl; sample(0) = 13.123; sample(1) = 0; cout << "This -1 class example should have low probability. Its probability is: " << learned_pfunct(sample) << endl; /* , dlib ,everything 。 , learned_pfunct disk , disk 。*/ serialize("saved_function.dat") << learned_pfunct;//serialize saved_function.dat,< // learned_pfunct /* , */ deserialize("saved_function.dat") >> learned_pfunct;//deserialize, saved_function.dat,>>learned_pfunct /* dlib ,file_to_code_ex.cpp。 : C++ , C++ 。 std::istringstream */ /* 200 。SVM 。 , 。 dlib 。*/ /* 10 , 。 reduced2() 。 reduced2() 。 。 。*/ cout << "
cross validation accuracy with only 10 support vectors: "
<< cross_validate_trainer(reduced2(trainer,10), samples, labels, 3); // cout << "cross validation accuracy with all the original support vectors: " << cross_validate_trainer(trainer, samples, labels, 3); /* , <10, */ // : learned_function.function = reduced2(trainer,10).train(samples, labels); // learned_pfunct.function = train_probabilistic_decision_function(reduced2(trainer,10), samples, labels, 3); }