dlibライブラリsvm_c_ex.cppの詳細な注記
15061 ワード
dlibは現代のC++技術で作成されたプラットフォーム間汎用アルゴリズムライブラリであり、dlibアルゴリズムライブラリを適用すると、ある態様またはある研究分野に迅速にアルゴリズムを適用することができる。dlibライブラリの各クラスと各関数には詳細なドキュメントがあり、多くのサンプルコードが提供されていますが、dlibライブラリが膨大すぎて、dlibライブラリに関するネット上の紹介が比較的少ないため、初心者にはいくつかの困難があります。本稿では、dlibライブラリにおけるC-SVMの実装アルゴリズムを、サンプルコードsvm_c_ex.cppに基づいて詳しく紹介します。不足点は読者の皆様のご指摘をお願いいたします。
/*
    This is an example illustrating the use of the C-SVM tools from the
    dlib C++ Library (the svm_c_trainer object).  It shows how to build a
    simple binary classification problem, select the SVM parameters with
    cross validation, and train the final classifier.

    The data are 2D points: every point within a radius of 10 from the
    origin is labeled +1, every other point is labeled -1.
*/
#include <iostream>
#include <dlib/svm.h>

using namespace std;
using namespace dlib;
int main()
{
/*SVM , 。 typedef*/
/*typedef int size,typedef 。 size int */
/* 2 。 , 2 ;
, 2 0, matrix.set_size() */
typedef matrix<double, 2, 1> sample_type;
/*typedef kernel 。
(radial basis kernel) 2D 。
with tools */
typedef radial_basis_kernel kernel_type;
/* samples labels, */
std::vector samples;
std::vector<double> labels;
/* samples labels 。 。*/
for (int r = -20; r <= 20; ++r)
{
for (int c = -20; c <= 20; ++c)
{
sample_type samp;// sample_type 2 1 samp
samp(0) = r;
samp(1) = c;// samp
samples.push_back(samp);// samp samples
// if this point is less than 10 from the origin
if (sqrt((double)r*r + c*c) <= 10)//sqrt((double)r*r + c*c) 2D (0,0)
labels.push_back(+1);// +1
else
labels.push_back(-1);// -1. labels ,push_back
}
}
/*
-20 +20, -20 +20
-1 +1
(-1,1) (0,1) (1,1)
(-1,0) (0,0) (1,0)
(-1,-1) (0,-1) (1,-1)
, <10 +1, -1
*/
/* 。 , it often heads off numerical stability problems and
also prevents one large feature from smothering others。 , :*/
vector_normalizer normalizer;// normalizer
/* */
normalizer.train(samples);
// now normalize each sample : samples.size()
for (unsigned long i = 0; i < samples.size(); ++i)
samples[i] = normalizer(samples[i]);
/* 。 , -c gamma 。 。
cross_validate_trainer() n 。
, , 。
, 。
, :*/
/* gamma sigma*/
randomize_samples(samples, labels);//
/* kernel_type svm_c_trainer trainer*/
svm_c_trainer trainer;
/* C gamma , C gamma 。
C gamma 。
model_selection_ex.cpp, */
cout << "doing cross validation" << endl;
//
for (double gamma = 0.00001; gamma <= 1; gamma *= 5)
{
for (double C = 1; C < 100000; C *= 5)//gamma=0.00001:1, 5;C=1:100000, 5
{
trainer.set_kernel(kernel_type(gamma));
trainer.set_c(C);// , C gamma
cout << "gamma: " << gamma << " C: " << C;
/* C gamma 3 。
cross_validate_trainer() 。 +1 , -1 。*/
cout << " cross validation accuracy: "
<< cross_validate_trainer(trainer, samples, labels, 3);//3
}
}
/* , C gamma 5 0.15625。*/
/* 。
+1 >=0 , -1 <0 。*/
trainer.set_kernel(kernel_type(0.15625));
trainer.set_c(5);// c gamma 5 0.15625
typedef decision_function dec_funct_type;
typedef normalized_function funct_type;//funct_type
/* 。
, 。*/
funct_type learned_function;
learned_function.normalizer = normalizer; // save normalization information
learned_function.function = trainer.train(samples, labels); // perform the actual SVM training and save the results
/* (basis_vectors)*/
cout << "
number of support vectors in our learned_function is "
<< learned_function.function.basis_vectors.size() << endl;
// , (learned_function)
sample_type sample;//sample 2 1
sample(0) = 3.123;
sample(1) = 2;// (2D )。 +1, >=0. -1, <0。
// <10 +1, -1
cout << "This is a +1 class example, the classifier output is " << learned_function(sample) << endl;
sample(0) = 3.123;
sample(1) = 9.3545;
cout << "This is a +1 class example, the classifier output is " << learned_function(sample) << endl;
sample(0) = 13.123;
sample(1) = 9.3545;
cout << "This is a -1 class example, the classifier output is " << learned_function(sample) << endl;
sample(0) = 13.123;
sample(1) = 0;
cout << "This is a -1 class example, the classifier output is " << learned_function(sample) << endl;
/* , , 0
+1 -1 */
typedef probabilistic_decision_function probabilistic_funct_type;
typedef normalized_function pfunct_type;
pfunct_type learned_pfunct;// learned_pfunct
learned_pfunct.normalizer = normalizer;
learned_pfunct.function = train_probabilistic_decision_function(trainer, samples, labels, 3);
// learned_pfunct +1 -1 。
//+1 ,-1
// , learned_function
cout << "
number of support vectors in our learned_pfunct is "
<< learned_pfunct.function.decision_funct.basis_vectors.size() << endl;
sample(0) = 3.123;
sample(1) = 2;
cout << "This +1 class example should have high probability. Its probability is: "
<< learned_pfunct(sample) << endl;
sample(0) = 3.123;
sample(1) = 9.3545;
cout << "This +1 class example should have high probability. Its probability is: "
<< learned_pfunct(sample) << endl;
sample(0) = 13.123;
sample(1) = 9.3545;
cout << "This -1 class example should have low probability. Its probability is: "
<< learned_pfunct(sample) << endl;
sample(0) = 13.123;
sample(1) = 0;
cout << "This -1 class example should have low probability. Its probability is: "
<< learned_pfunct(sample) << endl;
/* , dlib ,everything 。
, learned_pfunct disk , disk 。*/
serialize("saved_function.dat") << learned_pfunct;//serialize saved_function.dat,<
// learned_pfunct
/* , */
deserialize("saved_function.dat") >> learned_pfunct;//deserialize, saved_function.dat,>>learned_pfunct
/* dlib ,file_to_code_ex.cpp。 : C++ ,
C++ 。 std::istringstream */
/* 200 。SVM 。
, 。
dlib 。*/
/* 10 , 。 reduced2() 。
reduced2() 。 。
。*/
cout << "
cross validation accuracy with only 10 support vectors: "
<< cross_validate_trainer(reduced2(trainer,10), samples, labels, 3);
//
cout << "cross validation accuracy with all the original support vectors: "
<< cross_validate_trainer(trainer, samples, labels, 3);
/* , <10, */
// :
learned_function.function = reduced2(trainer,10).train(samples, labels);
//
learned_pfunct.function = train_probabilistic_decision_function(reduced2(trainer,10), samples, labels, 3);
}