EMアルゴリズム及びOpenCVソース分析

6271 ワード

EMの原理については、以下を参照してください.
http://blog.csdn.net/app_12062011/article/details/50350428
EMアルゴリズムの起動と終了
アルゴリズム実行の開始ステップには、3つの指定方法があります.CvEM::START_を使用している場合AUTO_STEPでは、k-meansアルゴリズムを呼び出して最初のパラメータを推定し、K-meansはクラスの中心をランダムに初期化し、KMEANS_PP_CENTERSは、EMアルゴリズムによって異なる結果が得られ、データ量が大きいほどこの差異は小さくなる.CvEM::START_を指定した場合E_STEPまたはCvEM::START_M_STEPパラメータでは,同じ入力データが発生せず,異なる結果が得られる.CvEM::START_を指定した場合M_STEPパラメータは、Mステップで開始し、Mステップ固定这里写图片描述这里写图片描述を最適化し、確率这里写图片描述を与えなければならないCvEM::START_E_STEPは、Eステップで開始し、CvEMParams::meansは与えなければならず、CvEMParams::weightsおよびCvEMParams::covsパラメータは与えられなくてもよく、weightsは初期の各成分を表す確率を与える.
アルゴリズム実行の終了条件.EMアルゴリズムは反復アルゴリズムであり,自然終了条件は反復回数が達成されたか,あるいは2回の反復間の差がepsilonより小さいと終了することができる.パラメータの解析については、機械学習中国語参考マニュアルを参照してください.
CvEM::train関数で次の手順を実行します.
  • init_params.
  • emObj=EM//EMオブジェクトを作成します.
  • 準拠_params.start_Step値は、異なるプロセス、train、trainE、trainMを実行します.この3つのtrainプロセスはlogLikelihoods(Mat構造)を返します.Labels,probs(所与のサンプルxが各カテゴリに属する後験確率).トレーニングの前にtrain関数でsetTrainDataを呼び出してトレーニングデータを準備し、do_を呼び出します.train正式訓練.setTrainDataは、START_AUTO_STEPではK-meansが開き、データをCV_に変換します32FC1.訓練データはクラスメンバーtrainSamplesに保存されます.
  • 実行do_trainプロセス1.clusterTrainSamplesでkmeansメソッドクラスタリングトレーニングを呼び出しnclustersクラスを統合し,各サンプルのカテゴリを得た.Kmeansが実行する前にデータがCVであることを保証する32 FC 1タイプ、実行後にCV_に変換64 FC 1タイプ.2.labelsに従ってすべてのデータをnclusters個のマトリクスに配置し、各マトリクスの共分散マトリクスと各カテゴリの重み値(サンプル数をサンプル総数で割ったもの)をそれぞれ計算する.共分散行列ごとに特異値分解(SVD)を行い,最大特異値の逆数を得た.3.次の条件が満たされるまで、Eステップ、Mステップを繰り返し実行する:
  • trainLogLikelihoodDeltaは,2回の隣接反復過程における対数尤度の増分を表す.
    E-stepソース:





    上の式に注意して、logに変換した後、乗算除算は加算減算になります
    Vec2d EM::computeProbabilities(const Mat& sample, Mat* probs) const
    {
        // L_ik = log(weight_k) - 0.5 * log(|det(cov_k)|) - 0.5 *(x_i - mean_k)' cov_k^(-1) (x_i - mean_k)]
        // q = arg(max_k(L_ik))
        // probs_ik = exp(L_ik - L_iq) / (1 + sum_j!=q (exp(L_ij - L_iq))
        // see Alex Smola's blog http://blog.smola.org/page/2 for
        // details on the log-sum-exp trick
    
        CV_Assert(!means.empty());
        CV_Assert(sample.type() == CV_64FC1);
        CV_Assert(sample.rows == 1);
        CV_Assert(sample.cols == means.cols);
    
        int dim = sample.cols;
    
        Mat L(1, nclusters, CV_64FC1);	//L  1*nclusters
        int label = 0;
        for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++)
        {
            const Mat centeredSample = sample - means.row(clusterIndex); //    
    
            Mat rotatedCenteredSample = covMatType != EM::COV_MAT_GENERIC ?
                    centeredSample : centeredSample * covsRotateMats[clusterIndex];
    
            double Lval = 0;
            for(int di = 0; di < dim; di++)
            {
                double w = invCovsEigenValues[clusterIndex].at(covMatType != EM::COV_MAT_SPHERICAL ? di : 0); 
    			//              
                double val = rotatedCenteredSample.at(di);
                Lval += w * val * val;//                      
            }
            CV_DbgAssert(!logWeightDivDet.empty());
            L.at(clusterIndex) = logWeightDivDet.at(clusterIndex) - 0.5 * Lval;
    		  // note: logWeightDivDet = log(weight_k) - 0.5 * log(|det(cov_k)|)
    		  // note: L.at(clusterIndex) =  log(weight_k) - 0.5 * log(|det(cov_k)|)-0.5 * Lval
    		  
            if(L.at(clusterIndex) > L.at(label))   
                label = clusterIndex;		//      label 
        }
    
        double maxLVal = L.at(label);  //
        Mat expL_Lmax = L; // exp(L_ij - L_iq)	  //L  1*nclusters
        for(int i = 0; i < L.cols; i++)
            expL_Lmax.at(i) = std::exp(L.at(i) - maxLVal);
        double expDiffSum = sum(expL_Lmax)[0]; // sum_j(exp(L_ij - L_iq))
    
        if(probs) //probs
        {
            probs->create(1, nclusters, CV_64FC1);
            double factor = 1./expDiffSum;
            expL_Lmax *= factor;
            expL_Lmax.copyTo(*probs);
        }
    
        Vec2d res;
        res[0] = std::log(expDiffSum)  + maxLVal - 0.5 * dim * CV_LOG2PI; //dim      CV_LOG2PI (1.8378770664093454835606594728112)
       	 //
    	res[1] = label;
    
        return res;
    }

    M-step:
    EM算法及OpenCV源码分析_第1张图片