Strassen長方形アルゴリズム


行列乗算は線形代数で最も一般的な演算の一つであり,数値計算に広く応用されている.AとBが2 nnの行列であれば、それらの積C=ABは同じnnの行列である.AとBの積行列Cの要素C[i,j]は、この定義に従ってAとBの積行列Cを計算すると、Cを計算する要素C[i,j]ごとにn個の乗算とn−1回の加算が必要となると定義される.したがって,行列Cのn∧2要素を求めるのに要する計算時間は0(n∧3)である.1960年代末、Strassenは大きな整数乗算で用いられたような分治技術を採用し、2つのn次行列積を計算するのに必要な計算時間をO(n∧log 7)=O(n∧2.18)に改善し、その基本思想はやはり分治法を用いた.まず,nが2のべき乗であると仮定する必要がある.行列A,B,Cの各行列を4つの大きさの等しいサブ行列に分割し,各サブ行列は(n/2)*(n/2)の方程式である.これにより、方程式C=ABを次のように書き換えることができる.
       C∨11    C∨12        A∨11     A∨12       B∨11   B∨12
       C∨21    C∨22   =   A∨21     A∨22  =   B∨21   B∨22
これにより、C 44 B∨22(5)n=2の場合,2つの2次方程式の積は,(2)−(3)式で直接計算でき,合計8回の乗算と4回の加算が必要である.サブマトリクスの次数が2より大きい場合、2つのサブマトリクスの積を求めるために、サブマトリクスの次数が2に下がるまで、サブマトリクスをブロック化し続けることができる.これにより,分治降格の再帰アルゴリズムが生成される.このアルゴリズムによれば、2つのn次方程式の積を計算することは、8つのn/2次方程式の積と4つのn/2次方程式の加算を計算することに変換される.2つのn/2 n/2行列の加算は明らかにc*n 2/4時間以内に完了することができ,ここでcは定数である.従って,上記の分割法の計算時間はT(n)を費やし,この再帰方程式の解はT(n)=O(n 3)であることを満たすべきである.したがって,この方法は元の定義で直接計算するよりも有効ではない.その理由は,式(2)−(5)が行列の乗算回数を減少させなかったためである.マトリクス乗算にかかる時間はマトリクス加減法にかかる時間よりずっと多い.マトリクス乗算の計算時間の複雑さを改善するには,サブマトリクス乗算の回数を減らす必要がある.上記の分治法の考え方から,乗算回数を減らすには,2つの2次行列の積を計算する際に8回未満の乗算ができるかどうかが鍵となる.Strassenは2つの2次方程式の積を計算するための新しいアルゴリズムを提案した.彼のアルゴリズムは7回の乗算しか使わなかったが、加算・減算の演算回数を増やした.この7回の乗算は、M A∨22)(B∨11+B∨22)M∨6=(A∨12-A∨22)(B∨21+B∨22)M∨7=(A∨11-A∨21)(B∨11+B∨12)この7回の乗算をした後、何度か加算、減算すると、C∨11=M∨5+M∨4-M∨2+M∨6 C∨12=M∨1+M∨2 C∨21=M∨3+M∨22=M∨5+M∨1-M∨3-M∨7以上の計算の正確性が検証されやすい.例えば、C∨22=M∨12)   =A∨11B∨11+A∨11B∨22+A∨22B∨11+A∨22B∨22+A∨11B∨12   -A∨11B∨22-A∨21B∨11-A∨22B∨11-A∨11B∨11-A∨11B∨12+A∨21B∨11+A∨21B∨12   =A∨21B∨12+A∨22B∨22    
Strassen行列乗積分治アルゴリズムでは,n/2次行列積に対する7回の再帰呼び出しと18回のn/2次行列の加減演算を用いた.このことから,このアルゴリズムに必要な計算時間T(n)は,T(n)=O(nlog 7)≒O(n 2.81)となる解再帰方程式の定式化法に従って満たされることが分かった.このことから,Strassen行列乗算の計算時間複雑性は通常の行列乗算よりも段階的に改善されることが分かった.2つの2次行列乗算を計算する36の異なる方法を列挙したことがある.しかし、すべての方法は7回の乗算をしなければならない.乗算の計算回数が7回未満になるように2次方程式積を計算するアルゴリズムが見つからない限り、マトリクス積の計算時間の上界をさらに改善することは、上記の考え方によって可能である.しかしHopcroftとKerr(197 l)は,2つの22行列の積を計算し,7回の乗算が必要であることを実証した.したがって,マトリクス乗算の時間的複雑さをさらに改善するには,22マトリクスの乗算回数の減少を計算することに期待できない.33または55行列のより良いアルゴリズムを研究すべきかもしれない.Strassenの後に行列乗算の計算時間の複雑さを改善する多くのアルゴリズムがある.現在の最良の計算時間の境界はO(n 2.367)である.現在知られている行列乗算の最良の下限は依然としてその平凡な下限Ω(n 2)である.従って,マトリクス乗算の時間的複雑さはこれまで正確には知られていなかった.この研究課題についてはまだ多くの仕事がある.
 
Demo:
 
package Recursive;
/**
 * Strassen     
 */
import java.io.*;
import java.util.*;
class matrix
{
    public int[][] m = new int[32][32];
}
public class Strassen
{
    public int judgment(int n)
    {
        int flag = 0,temp=n;
        while(temp%2==0)
        {
            if(temp%2==0) 
                temp/=2;
            else flag=1;
        }
        if(temp==1) 
            flag=0;
        return flag;
    }
    public void Divide(matrix d,matrix d11,matrix d12,matrix d21,matrix d22,int n)/**//*      */
    {
        int i,j;
        for(i=1;i<=n;i++)
        for(j=1;j<=n;j++)
        {
            d11.m[i][j]=d.m[i][j];
            d12.m[i][j]=d.m[i][j+n];
            d21.m[i][j]=d.m[i+n][j];
            d22.m[i][j]=d.m[i+n][j+n];
        }
    }
    public matrix Merge(matrix a11,matrix a12,matrix a21,matrix a22,int n)/**//*      */
    {
        int i,j;
        matrix a = new matrix();
        for(i=1;i<=n;i++)
        for(j=1;j<=n;j++)
        {
            a.m[i][j]=a11.m[i][j];
            a.m[i][j+n]=a12.m[i][j];
            a.m[i+n][j]=a21.m[i][j];
            a.m[i+n][j+n]=a22.m[i][j];
        }
        return a;
    }
    public matrix AdhocMatrixMultiply(matrix x,matrix y) /**//*   2       */
    {
        int m1,m2,m3,m4,m5,m6,m7;
        matrix z = new matrix();

        m1=(y.m[1][2] - y.m[2][2]) * x.m[1][1];
        m2=y.m[2][2] * (x.m[1][1] + x.m[1][2]);
        m3=(x.m[2][1] + x.m[2][2]) * y.m[1][1];
        m4=x.m[2][2] * (y.m[2][1] - y.m[1][1]);
        m5=(x.m[1][1] + x.m[2][2]) * (y.m[1][1]+y.m[2][2]);
        m6=(x.m[1][2] - x.m[2][2]) * (y.m[2][1]+y.m[2][2]);
        m7=(x.m[1][1] - x.m[2][1]) * (y.m[1][1]+y.m[1][2]);
        z.m[1][1] = m5 + m4 - m2 + m6;
        z.m[1][2] = m1 + m2;
        z.m[2][1] = m3 + m4;
        z.m[2][2] = m5 + m1 - m3 - m7;
        return z;
    }
    public matrix MatrixPlus(matrix f,matrix g,int n) /**//*      */
    {
        int i,j;
        matrix h = new matrix();
        for(i=1;i<=n;i++)
        for(j=1;j<=n;j++)
        h.m[i][j]=f.m[i][j]+g.m[i][j];
        return h;
    }
    public matrix MatrixMinus(matrix f,matrix g,int n) /**//*      */
    {
        int i,j;
        matrix h = new matrix();
        for(i=1;i<=n;i++)
        for(j=1;j<=n;j++)
        h.m[i][j]=f.m[i][j]-g.m[i][j];
        return h;
    }

    public matrix MatrixMultiply(matrix a,matrix b,int n) /**//*      */
    {
        int k;
        matrix a11,a12,a21,a22;
        a11 = new matrix();
        a12 = new matrix();
        a21 = new matrix();
        a22 = new matrix();
        matrix b11,b12,b21,b22;
        b11 = new matrix();
        b12 = new matrix();
        b21 = new matrix();
        b22 = new matrix();
        matrix c11,c12,c21,c22,c;
        c11 = new matrix();
        c12 = new matrix();
        c21 = new matrix();
        c22 = new matrix();
        c = new matrix();
        matrix m1,m2,m3,m4,m5,m6,m7;
        k=n;
        if(k==2)
        {
            c=AdhocMatrixMultiply(a,b);
            return c;
        }
        else
        { 
            k=n/2;
            Divide(a,a11,a12,a21,a22,k); //  A、B、C  
            Divide(b,b11,b12,b21,b22,k);
            Divide(c,c11,c12,c21,c22,k);
            
            m1=MatrixMultiply(MatrixMinus(b12,b22,n/2),a11,k);
            m2=MatrixMultiply(b22,MatrixPlus(a11,a12,k),k);
            m3=MatrixMultiply(MatrixPlus(a21,a22,k),b11,k);
            m4=MatrixMultiply(a22,MatrixMinus(b21,b11,k),k);
            m5=MatrixMultiply(MatrixPlus(a11,a22,k),MatrixPlus(b11,b22,k),k);
            m6=MatrixMultiply(MatrixMinus(a12,a22,k),MatrixPlus(b21,b22,k),k);
            m7=MatrixMultiply(MatrixMinus(a11,a21,k),MatrixPlus(b11,b12,k),k);
            c11=MatrixPlus(MatrixMinus(MatrixPlus(m5,m4,k),m2,k),m6,k);
            c12=MatrixPlus(m1,m2,k);
            c21=MatrixPlus(m3,m4,k);
            c22=MatrixMinus(MatrixMinus(MatrixPlus(m5,m1,k),m3,k),m7,k);
            
            c=Merge(c11,c12,c21,c22,k); //  C  
            return c;
        } 
    }
    public static void main(String[] args)throws IOException
    {
        Strassen instance = new Strassen();
        int i,j,num;
        matrix A,B,C;
        A = new matrix();
        B = new matrix();
        C = new matrix();
        Scanner in = new Scanner(System.in);
        System.out.print("       : ");
        num = in.nextInt();
        if(instance.judgment(num)==0)
        {
            System.out.println("    A:");
            for(i=1;i<=num;i++)
                for(j=1;j<=num;j++)
                    A.m[i][j] = in.nextInt();
            System.out.println("    B:");
            for(i=1;i<=num;i++)
                for(j=1;j<=num;j++)
                    B.m[i][j] = in.nextInt();
            if(num==1) 
                C.m[1][1]=A.m[1][1]*B.m[1][1]; //     1       
            else 
                C=instance.MatrixMultiply(A,B,num);
            System.out.println("  C :");
            for(i=1;i<=num;i++)
            {
                for(j=1;j<=num;j++)
                    System.out.print(C.m[i][j] + "     ");
                System.out.println();
            }
        }
        else
            System.out.println("       2 N  ");
    }
}

 
結果を表示
       : 2
    A:
12
12
12
12
    B:
12
12
12
12
  C :
288     288     
288     288