KMアルゴリズム

6709 ワード

KMアルゴリズムを紹介する2つの良い文章:
1: http://www.cppblog.com/MatoNo1/archive/2011/07/23/151724.aspx
2:
  KMアルゴリズムは,頂点ごとに1つの符号(トップスケールと呼ばれる)を与えることによって,最大重みマッチングを求める問題を完全マッチングを求める問題に変換する.頂点XiのトップマークをA[i]、頂点YiのトップマークをB[i]、頂点XiとYjの間のエッジ権をw[i,j]とする.アルゴリズム実行中のいずれかの時点で、いずれかの辺(i,j)に対して、A[i]+B[j]>=w[i,j]は常に成立する.KMアルゴリズムの正しさは,A[i]+B[j]=w[i,j]を満たすすべての辺(i,j)からなるサブマップ(等しいサブマップと呼ぶ)が完全に一致すると,この完全マッチングが二分マップの最大重みマッチングとなるという定理に基づいている.この定理は明らかだ.二分図のいずれかのマッチングについて、等しいサブ図に含まれる場合、そのエッジ重みは、すべての頂点のトップスケール和に等しい.あるエッジが等しいサブマップに含まれていない場合、そのエッジの重みとすべての頂点のトップスケールの和より小さい.したがって、等しいサブマップの完全なマッチングは、必ず二分図の最大重みマッチングである.初期時にA[i]+B[j]>=w[i,j]を一定にするために、A[i]を頂点Xiに関連付けられたすべてのエッジの最大重みとし、B[j]=0とする.現在の等しいサブマップが完全に一致していない場合は、次の方法でトップスケールを変更して、等しいサブマップが完全に一致するまで拡大します.現在の等しいサブマップの完全なマッチングが失敗したのは,あるX頂点に対して,それから出発する交差路が見つからないためである.この時、私たちは交錯木を手に入れました.その葉の結点はすべてX頂点です.ここで、インタリーブツリーのX頂点のトップスケールをすべて値dを小さくし、Y頂点のトップスケールをすべて同じ値dを増やすと、次のことがわかります.
  • 両端はいずれもインターリーブツリーのエッジ(i,j),A[i]+B[j]の値は変化しなかった.すなわち,元は等しいサブマップに属していたが,現在も等しいサブマップに属している.
  • 両端は交錯木の中の辺(i,j),A[i]およびB[j]に変化しなかった.すなわち、元は等しいサブマップに属していた(または属していない)が、現在も等しいサブマップに属している(または属していない).
  • X端はインタリーブツリーになく、Y端はインタリーブツリーのエッジ(i,j)にあり、そのA[i]+B[j]の値は大きくなる.元は等しいサブマップに属していなかったが、現在も等しいサブマップに属していない.
  • X端はインタリーブツリーにあり、Y端はインタリーブツリーのエッジ(i,j)になく、そのA[i]+B[j]の値は減少する.すなわち,元は等子図に属していなかったが,現在は等子図に入る可能性があり,等子図を拡大した.

  • 今の問題はd値を求めることです.A[i]+B[j]>=w[i,j]が常に成立し、少なくとも1つのエッジが等しいサブマップに入るように、dはmin{A[i]+B[j]-w[i,j]|Xiがインタリーブツリーにあり、Yiがインタリーブツリーにないことに等しいはずである.
    以上がKMアルゴリズムの基本構想である.しかし、素朴な実現方法では、時間的複雑度はO(n 4)--O(n)次拡張路を探す必要があり、拡張するたびに最大O(n)次トップスケールを修正する必要があり、トップスケールを修正するたびにエッジを列挙してd値を求めるため、複雑度はO(n 2)である.実際,KMアルゴリズムの複雑さはO(n 3)まで可能である.各Y頂点に「緩和量」関数slackを与え,拡張路を探し始めるたびに無限大に初期化した.拡張路を探している間に、エッジ(i,j)をチェックするときに、等しいサブマップにない場合は、slack[j]をA[i]+B[j]-w[i,j]の元の値の小さい値にします.これにより、トップラベルを修正する際に、インタリーブツリーにないY頂点のslack値のうちの最小値をd値として取得すればよい.しかし、トップマークを修正した後、すべてのslack値をdから減算することに注意してください.
    テンプレート対応http://acm.hdu.edu.cn/showproblem.php?pid=2255
    O(N^4)テンプレート:
    #include <cstdio>
    
    #include <cstring>
    
    #include <iostream>
    
    #define maxn 301
    
    using namespace std;
    
    
    
    const int inf = 99999999;
    
    int w[maxn][maxn],link[maxn];
    
    int lx[maxn],ly[maxn];//    
    
    bool vtx[maxn],vty[maxn];//  X,Y        
    
    int nx,ny;
    
    
    
    bool dfs(int i)
    
    {
    
        int j;
    
        vtx[i] = true;//     
    
        for (j = 0; j < ny; ++j)
    
        {
    
            //         
    
            if (!vty[j] && lx[i] + ly[j] == w[i][j])
    
            {
    
                vty[j] = true;
    
                if (link[j] == -1 || dfs(link[j]))
    
                {
    
                    link[j] = i;
    
                    return true;
    
                }
    
            }
    
        }
    
        return false;
    
    }
    
    int KM()
    
    {
    
        int i,j,k;
    
        //   X     
    
        for (i = 0; i < nx; ++i)
    
        {
    
            for (j = 0,lx[i] = -inf; j < ny; ++j)
    
            {
    
                lx[i] = max(lx[i],w[i][j]);
    
            }
    
        }
    
        //   
    
        for (i = 0; i < maxn; ++i)
    
        {
    
            link[i] = -1; ly[i] = 0;
    
        }
    
        for (i = 0; i < nx; ++i)
    
        {
    
            while (1)
    
            {
    
                for (j = 0; j < maxn; ++j) vtx[j] = vty[j] = false;
    
                if (dfs(i)) break;//         
    
                int d = inf;
    
                //  d
    
                for (j = 0; j < nx; ++j)
    
                {
    
                    if (vtx[j])
    
                    {
    
                        for (k = 0; k < ny; ++k)
    
                        {
    
                            if (!vty[k]) d = min(d,lx[j] + ly[k] - w[j][k]);
    
                        }
    
                    }
    
                }
    
                if (d == inf) return -1;//       
           // dfs d, lx[i] + ly[j] == w[i][j] 。 for (j = 0; j < nx; ++j) if (vtx[j]) lx[j] -= d; for (j = 0; j < ny; ++j) if (vty[j]) ly[j] += d; } } // int sum = 0; for (i = 0; i < ny; ++i) if (link[i] > -1) sum += w[link[i]][i]; return sum; } int main() { int i,j; while (~scanf("%d",&nx)) { ny = nx; memset(w,0,sizeof(w)); for (i = 0; i < nx; ++i) { for (j = 0; j < ny; ++j) { scanf("%d",&w[i][j]); } } printf("%d
    ",KM()); } return 0; }

    O(n^3)テンプレート:
    #include <cstdio>
    
    #include <cstring>
    
    #include <iostream>
    
    #define maxn 301
    
    using namespace std;
    
    
    
    const int inf = 99999999;
    
    int w[maxn][maxn],link[maxn];
    
    int lx[maxn],ly[maxn];//    
    
    int slack[maxn];
    
    bool vtx[maxn],vty[maxn];//  X,Y        
    
    int nx,ny;
    
    
    
    bool dfs(int i)
    
    {
    
       int j;
    
       vtx[i] = true;
    
       for (j = 0; j < ny; ++j)
    
       {
    
           if (vty[j]) continue;
    
           int tmp = lx[i] + ly[j] - w[i][j];
    
           if (tmp == 0)
    
           {
    
               vty[j] = true;
    
               if (link[j] == -1 || dfs(link[j]))
    
               {
    
                   link[j] = i;
    
                   return true;
    
               }
    
           }
    
           else
    
           slack[j] = min(slack[j],tmp);//     
    
       }
    
       return false;
    
    }
    
    
    
    int KM()
    
    {
    
        int i,j;
    
        //   lx
    
        for (i = 0; i < nx; ++i)
    
        {
    
            for (j = 0,lx[i] = -inf; j < ny; ++j)
    
            {
    
                lx[i] = min(lx[i],w[i][j]);
    
            }
    
        }
    
        for (i = 0; i < maxn; ++i)
    
        {
    
            link[i] = -1; ly[i] = 0;
    
        }
    
        for (i = 0; i < nx; ++i)
    
        {
    
            for (j = 0; j < ny; ++j) slack[j] = inf;//      
    
            while (1)
    
            {
    
                for (j = 0; j < maxn; ++j) vtx[j] = vty[j] = false;
    
                if (dfs(i)) break;
    
                int d = inf;
    
                for (j = 0; j < ny; ++j)
    
                {
    
                    if (!vty[j] && d > slack[j]) d = slack[j];
    
                }
    
                if (d == inf) return -1;
    
                for (j = 0; j < nx; ++j)
    
                if (vtx[j]) lx[j] -= d;
    
                for (j = 0; j < ny; ++j)
    
                if (vty[j]) ly[j] += d;
    
                else    slack[j] -= d;//        
    
            }
    
        }
    
        int sum = 0;
    
        for (i = 0; i < ny; ++i)
    
        if (link[i] > -1) sum += w[link[i]][i];
    
        return sum;
    
    }
    
    int main()
    
    {
    
        int i,j;
    
        while (~scanf("%d",&nx))
    
        {
    
            ny = nx;
    
            memset(w,0,sizeof(w));
    
            for (i = 0; i < nx; ++i)
    
            {
    
                for (j = 0; j < ny; ++j)
    
                {
    
                    scanf("%d",&w[i][j]);
    
                }
    
            }
    
            printf("%d
    ",KM()); } return 0; }