Algorithm——行列乗算のStrassenアルゴリズム(六)
20588 ワード
Algorithmマトリクス乗算のStrassenアルゴリズム
アルゴリズム導論は暴力を利用して2 NxN行列の積を求める問題を用いてStrassenアルゴリズムを引き出した.以下のコード実装は,本中の暴力求解法,分治求解法,Strassen求解法の実装にそれぞれ対応しており,具体的には以下に示す.
この部分の内容に関する擬似コードは、「アルゴリズム導論」4.2章を参照して説明することができる.
行列の乗算知識によると、2つのNxNの行列AとBを乗算した結果、行列Cの暴力アルゴリズムは以下の通りである.
PS:
最後に,Strassenアルゴリズムが方程式積の複雑さを低減する理由をずっと考えていた.最後に知っている上で1つの答えを見て、解釈がとても良いと感じます;今貼って、一緒に分かち合います:
「strassenアルゴリズムの鍵は、乗算か加算かではなく、アルゴリズムの内部再帰呼び出しの回数にある.strassenアルゴリズムの鍵は、内部再帰呼び出しの回数が1(通常の8回から特殊な7回に変わる)減少することにある.「茂み」度は、この時間的複雑さの再帰アルゴリズムにおける変化を記述する.したがってstrassenアルゴリズムの鍵は,再帰呼び出しの回数がどのように8回から1回減少するかである.逆に理解すると,8回目の再帰呼び出しのうち1回が冗長である,すなわち8回目の再帰乗算の結果情報が上位7回の結果に含まれており,上位7回の計算結果は線形組合せにより8回目の再帰の結果を得ることができる.一方、この線形の組み合わせの時間的複雑度は、アルゴリズム自体(すなわち、1回の再帰呼び出し)の時間的複雑度よりも低い.結論を下す.しかし、時間的複雑度を最適化できるアルゴリズムであれば、高複雑度のアルゴリズムには、冗長性の代わりにより少ない計算を用いることができれば、効率を向上させることができる計算があるに違いない.(アルゴリズムの再帰はちょうど乗算なので、ここでは乗算に重点を置いているように見えます)
従来のマトリクス乗算アルゴリズムに冗長計算がある理由についても、次のように分析してみましょう.
冗長性の根本的な原因は、基本的な乗算分配法則a*(b+c)=a*b+a*cにあるべきである.同じ計算結果、前者(等号前)の方法では計算に2回の基本演算が必要であり,後者(等号後)の方法では3回必要である.(乗算と加算が同等のオーバーヘッドの基本演算であると仮定する).一般的な行列乗算アルゴリズムでは、大量の単歩乗算後の和を求める、すなわち、上記等号右側の計算式を採用している.乗算における同一因子を前に挙げて、上記乗算分配則を用いて計算形式を変換する方法があれば、計算効率を向上させることができる率.これがstrassenアルゴリズムの本質であるはずだ.strassenアルゴリズムの過程を見ると,まず一部の分子行列を加算(減算)し,乗算を行う.実は上述した分配律を構成した左式である
作者:知乎ユーザー
リンク:https://www.zhihu.com/question/28558331/answer/146497271
出典:知っている
著作権は作者の所有である.商業転載は著者に連絡して許可を得てください.非商業転載は出典を明記してください."
アルゴリズム導論は暴力を利用して2 NxN行列の積を求める問題を用いてStrassenアルゴリズムを引き出した.以下のコード実装は,本中の暴力求解法,分治求解法,Strassen求解法の実装にそれぞれ対応しており,具体的には以下に示す.
この部分の内容に関する擬似コードは、「アルゴリズム導論」4.2章を参照して説明することができる.
行列の乗算知識によると、2つのNxNの行列AとBを乗算した結果、行列Cの暴力アルゴリズムは以下の通りである.
/**
* ; A B NxN
*
* @param A
* A
* @param B
* B
* @return
* A B C
*/
public static int[][] squareMatrixMultiply(int[][] A, int[][] B) {
int rows = A.length;
int[][] C = new int[rows][rows];
for (int i = 0; i < rows; i++) {
for (int j = 0; j < rows; j++) {
C[i][j] = 0;
for (int k = 0; k < rows; k++) {
C[i][j] = C[i][j] + A[i][k] * B[k][j];
}
}
}
return C;
}
分治思想を加えた2つのNxNの行列AとBが乗算結果行列Cを得るアルゴリズムは、/**
* NxN
* @param A
* A
* @param B
* B
* @return
*/
public static int[][] martixMultiplyRecursive(int[][] A, int[][] B) {
int rows = A.length;
int[][] C = new int[rows][rows];
if (rows == 1) {
C[0][0] = A[0][0] * B[0][0];
} else {
int[][] A11 = new int[rows / 2][rows / 2];
int[][] A12 = new int[rows / 2][rows / 2];
int[][] A21 = new int[rows / 2][rows / 2];
int[][] A22 = new int[rows / 2][rows / 2];
copyMatrixbyParamFromSrcToSubMatrix(A, 0, rows / 2, 0, rows / 2, A11);
copyMatrixbyParamFromSrcToSubMatrix(A, 0, rows / 2, rows / 2, rows / 2, A12);
copyMatrixbyParamFromSrcToSubMatrix(A, rows / 2, rows / 2, 0, rows / 2, A21);
copyMatrixbyParamFromSrcToSubMatrix(A, rows / 2, rows / 2, rows / 2, rows / 2, A22);
int[][] B11 = new int[rows / 2][rows / 2];
int[][] B12 = new int[rows / 2][rows / 2];
int[][] B21 = new int[rows / 2][rows / 2];
int[][] B22 = new int[rows / 2][rows / 2];
copyMatrixbyParamFromSrcToSubMatrix(B, 0, rows / 2, 0, rows / 2, B11);
copyMatrixbyParamFromSrcToSubMatrix(B, 0, rows / 2, rows / 2, rows / 2, B12);
copyMatrixbyParamFromSrcToSubMatrix(B, rows / 2, rows / 2, 0, rows / 2, B21);
copyMatrixbyParamFromSrcToSubMatrix(B, rows / 2, rows / 2, rows / 2, rows / 2, B22);
int[][] C11 = new int[rows / 2][rows / 2];
int[][] C12 = new int[rows / 2][rows / 2];
int[][] C21 = new int[rows / 2][rows / 2];
int[][] C22 = new int[rows / 2][rows / 2];
squareMatrixElementAdd(squareMatrixMultiply(A11, B11), squareMatrixMultiply(A12, B21), C11);
squareMatrixElementAdd(squareMatrixMultiply(A11, B12), squareMatrixMultiply(A12, B22), C12);
squareMatrixElementAdd(squareMatrixMultiply(A21, B11), squareMatrixMultiply(A22, B21), C21);
squareMatrixElementAdd(squareMatrixMultiply(A21, B12), squareMatrixMultiply(A22, B22), C22);
copySubMatrixByParamFromSrcToDest(C11, 0, rows / 2, 0, rows / 2, C);
copySubMatrixByParamFromSrcToDest(C12, 0, rows / 2, rows / 2, rows / 2, C);
copySubMatrixByParamFromSrcToDest(C21, rows / 2, rows / 2, 0, rows / 2, C);
copySubMatrixByParamFromSrcToDest(C22, rows / 2, rows / 2, rows / 2, rows / 2, C);
}
return C;
}
/**
* NxN 4 N/2xN/2
*
*/
public static void copyMatrixbyParamFromSrcToSubMatrix(int[][] src, int startI, int lenI, int startJ, int lenJ,
int[][] dest) {
for (int i = 0; i < lenI; i++)
for (int j = 0; j < lenJ; j++) {
dest[i][j] = src[startI + i][startJ + j];
}
}
/**
* 4 N/2xN/2 NxN
*
*/
public static void copySubMatrixByParamFromSrcToDest(int[][] src, int startI, int lenI, int startJ, int lenJ,
int[][] dest) {
for (int i = 0; i < lenI; i++)
for (int j = 0; j < lenJ; j++) {
dest[startI + i][startJ + j] = src[i][j];
}
}
/**
* NxN
*
* @param srcA
*
* @param srcB
*
* @param dest
*
*/
public static void squareMatrixElementAdd(int[][] srcA, int[][] srcB, int[][] dest) {
for (int i = 0; i < srcA.length; i++)
for (int j = 0; j < srcA[i].length; j++)
dest[i][j] = srcA[i][j] + srcB[i][j];
}
/**
* NxN
*
*/
public static void displaySquare(int matrix[][]) {
for (int i = 0; i < matrix.length; i++) {
for (int j : matrix[i]) {
System.out.print(j + " ");
}
System.out.println();
}
}
Strassenアルゴリズムを使用して2つの行列の積を求めるアルゴリズム実装コードは、次のとおりである./**
* Strassen NxN
*
* @param A
* A
* @param B
* B
* @return
*/
public static int[][] strassenMartixMultiplyRecursive(int[][] A, int[][] B) {
int rows = A.length;
int[][] C = new int[rows][rows];
if (rows == 1) {
C[0][0] = A[0][0] * B[0][0];
} else {
int[][] A11 = new int[rows / 2][rows / 2];
int[][] A12 = new int[rows / 2][rows / 2];
int[][] A21 = new int[rows / 2][rows / 2];
int[][] A22 = new int[rows / 2][rows / 2];
copyMatrixbyParamFromSrcToSubMatrix(A, 0, rows / 2, 0, rows / 2, A11);
copyMatrixbyParamFromSrcToSubMatrix(A, 0, rows / 2, rows / 2, rows / 2, A12);
copyMatrixbyParamFromSrcToSubMatrix(A, rows / 2, rows / 2, 0, rows / 2, A21);
copyMatrixbyParamFromSrcToSubMatrix(A, rows / 2, rows / 2, rows / 2, rows / 2, A22);
int[][] B11 = new int[rows / 2][rows / 2];
int[][] B12 = new int[rows / 2][rows / 2];
int[][] B21 = new int[rows / 2][rows / 2];
int[][] B22 = new int[rows / 2][rows / 2];
copyMatrixbyParamFromSrcToSubMatrix(B, 0, rows / 2, 0, rows / 2, B11);
copyMatrixbyParamFromSrcToSubMatrix(B, 0, rows / 2, rows / 2, rows / 2, B12);
copyMatrixbyParamFromSrcToSubMatrix(B, rows / 2, rows / 2, 0, rows / 2, B21);
copyMatrixbyParamFromSrcToSubMatrix(B, rows / 2, rows / 2, rows / 2, rows / 2, B22);
int[][] S1 = new int[rows / 2][rows / 2];
int[][] S2 = new int[rows / 2][rows / 2];
int[][] S3 = new int[rows / 2][rows / 2];
int[][] S4 = new int[rows / 2][rows / 2];
int[][] S5 = new int[rows / 2][rows / 2];
int[][] S6 = new int[rows / 2][rows / 2];
int[][] S7 = new int[rows / 2][rows / 2];
int[][] S8 = new int[rows / 2][rows / 2];
int[][] S9 = new int[rows / 2][rows / 2];
int[][] S10 = new int[rows / 2][rows / 2];
squareMatrixElementSub(B12, B22, S1);// S1 = B12 - B22
squareMatrixElementAdd(A11, A12, S2);// S2 = A11 + A12
squareMatrixElementAdd(A21, A22, S3);// S3 = A21 + A22
squareMatrixElementSub(B21, B11, S4);// S4 = B21 - B11
squareMatrixElementAdd(A11, A22, S5);// S5 = A11 + A22
squareMatrixElementAdd(B11, B22, S6);// S6 = B11 + B22
squareMatrixElementSub(A12, A22, S7);// S7 = A12 - A22
squareMatrixElementAdd(B21, B22, S8);// S8 = B21 + B22
squareMatrixElementSub(A11, A21, S9);// S9 = A11 - A21
squareMatrixElementAdd(B11, B12, S10);// S10 = B11 + B12
int[][] P1 = new int[rows / 2][rows / 2];
int[][] P2 = new int[rows / 2][rows / 2];
int[][] P3 = new int[rows / 2][rows / 2];
int[][] P4 = new int[rows / 2][rows / 2];
int[][] P5 = new int[rows / 2][rows / 2];
int[][] P6 = new int[rows / 2][rows / 2];
int[][] P7 = new int[rows / 2][rows / 2];
P1 = strassenMartixMultiplyRecursive(A11, S1); // P1 = A11 X S1
P2 = strassenMartixMultiplyRecursive(S2, B22);// P2 = S2 X B22
P3 = strassenMartixMultiplyRecursive(S3, B11);// P3 = S3 X B11
P4 = strassenMartixMultiplyRecursive(A22, S4);// P4 = A22 X S4
P5 = strassenMartixMultiplyRecursive(S5, S6);// P5 = S5 X S6
P6 = strassenMartixMultiplyRecursive(S7, S8);// P6 = S7 X S8
P7 = strassenMartixMultiplyRecursive(S9, S10);// P7 = S9 X S10
int[][] C11 = new int[rows / 2][rows / 2];
int[][] C12 = new int[rows / 2][rows / 2];
int[][] C21 = new int[rows / 2][rows / 2];
int[][] C22 = new int[rows / 2][rows / 2];
int[][] temp = new int[rows / 2][rows / 2];
// C11 = P5 + P4 - P2 + P6
squareMatrixElementAdd(P5, P4, temp);
squareMatrixElementSub(temp, P2, temp);
squareMatrixElementAdd(temp, P6, C11);
// C12 = P1 + P2
squareMatrixElementAdd(P1, P2, C12);
// C21 = P3 + P4
squareMatrixElementAdd(P3, P4, C21);
// C22 = P5 + P1 - P3 -P7
squareMatrixElementAdd(P5, P1, temp);
squareMatrixElementSub(temp, P3, temp);
squareMatrixElementSub(temp, P7, C22);
// C11/C12/C21/C22 C
copySubMatrixByParamFromSrcToDest(C11, 0, rows / 2, 0, rows / 2, C);
copySubMatrixByParamFromSrcToDest(C12, 0, rows / 2, rows / 2, rows / 2, C);
copySubMatrixByParamFromSrcToDest(C21, rows / 2, rows / 2, 0, rows / 2, C);
copySubMatrixByParamFromSrcToDest(C22, rows / 2, rows / 2, rows / 2, rows / 2, C);
}
return C;
}
/**
* NxN 4 N/2xN/2
*
*/
public static void copyMatrixbyParamFromSrcToSubMatrix(int[][] src, int startI, int lenI, int startJ, int lenJ,
int[][] dest) {
for (int i = 0; i < lenI; i++)
for (int j = 0; j < lenJ; j++) {
dest[i][j] = src[startI + i][startJ + j];
}
}
/**
* 4 N/2xN/2 NxN
*
*/
public static void copySubMatrixByParamFromSrcToDest(int[][] src, int startI, int lenI, int startJ, int lenJ,
int[][] dest) {
for (int i = 0; i < lenI; i++)
for (int j = 0; j < lenJ; j++) {
dest[startI + i][startJ + j] = src[i][j];
}
}
/**
* NxN
*
* @param srcA
*
* @param srcB
*
* @param dest
*
*/
public static void squareMatrixElementAdd(int[][] srcA, int[][] srcB, int[][] dest) {
for (int i = 0; i < srcA.length; i++)
for (int j = 0; j < srcA[i].length; j++)
dest[i][j] = srcA[i][j] + srcB[i][j];
}
/**
* NxN
*
* @param srcA
*
* @param srcB
*
* @param dest
*
*/
public static void squareMatrixElementSub(int[][] srcA, int[][] srcB, int[][] dest) {
for (int i = 0; i < srcA.length; i++)
for (int j = 0; j < srcA[i].length; j++)
dest[i][j] = srcA[i][j] - srcB[i][j];
}
最後に本明細書で説明した完全なテストは、次のとおりです.public class StrassenAlgor {
static int[][] A = {
{ 1, 2, 2, 1 },
{ 1, 2, 2, 1 },
{ 1, 2, 2, 1 },
{ 1, 2, 2, 1 }
};
static int[][] B = {
{ 1, 2, 2, 1 },
{ 1, 2, 2, 1 },
{ 1, 2, 3, 1 },
{ 1, 2, 2, 1 }
};
public static void main(String[] args) {
System.out.println(" ");
int[][] C = martixMultiplyRecursive(A, B);
displaySquare(C);
System.out.println(" ");
int[][] C1 = martixMultiplyRecursive(A, B);
displaySquare(C1);
System.out.println("Strassen ");
int[][] C2 = strassenMartixMultiplyRecursive(A, B);
displaySquare(C2);
}
/**
* ; A B NxN
*
* @param A
* A
* @param B
* B
* @return A B C
*/
public static int[][] squareMatrixMultiply(int[][] A, int[][] B) {
int rows = A.length;
int[][] C = new int[rows][rows];
for (int i = 0; i < rows; i++) {
for (int j = 0; j < rows; j++) {
C[i][j] = 0;
for (int k = 0; k < rows; k++) {
C[i][j] = C[i][j] + A[i][k] * B[k][j];
}
}
}
return C;
}
/**
* NxN
*
* @param A
* A
* @param B
* B
* @return
*/
public static int[][] martixMultiplyRecursive(int[][] A, int[][] B) {
int rows = A.length;
int[][] C = new int[rows][rows];
if (rows == 1) {
C[0][0] = A[0][0] * B[0][0];
} else {
int[][] A11 = new int[rows / 2][rows / 2];
int[][] A12 = new int[rows / 2][rows / 2];
int[][] A21 = new int[rows / 2][rows / 2];
int[][] A22 = new int[rows / 2][rows / 2];
copyMatrixbyParamFromSrcToSubMatrix(A, 0, rows / 2, 0, rows / 2, A11);
copyMatrixbyParamFromSrcToSubMatrix(A, 0, rows / 2, rows / 2, rows / 2, A12);
copyMatrixbyParamFromSrcToSubMatrix(A, rows / 2, rows / 2, 0, rows / 2, A21);
copyMatrixbyParamFromSrcToSubMatrix(A, rows / 2, rows / 2, rows / 2, rows / 2, A22);
int[][] B11 = new int[rows / 2][rows / 2];
int[][] B12 = new int[rows / 2][rows / 2];
int[][] B21 = new int[rows / 2][rows / 2];
int[][] B22 = new int[rows / 2][rows / 2];
copyMatrixbyParamFromSrcToSubMatrix(B, 0, rows / 2, 0, rows / 2, B11);
copyMatrixbyParamFromSrcToSubMatrix(B, 0, rows / 2, rows / 2, rows / 2, B12);
copyMatrixbyParamFromSrcToSubMatrix(B, rows / 2, rows / 2, 0, rows / 2, B21);
copyMatrixbyParamFromSrcToSubMatrix(B, rows / 2, rows / 2, rows / 2, rows / 2, B22);
int[][] C11 = new int[rows / 2][rows / 2];
int[][] C12 = new int[rows / 2][rows / 2];
int[][] C21 = new int[rows / 2][rows / 2];
int[][] C22 = new int[rows / 2][rows / 2];
squareMatrixElementAdd(squareMatrixMultiply(A11, B11), squareMatrixMultiply(A12, B21), C11);
squareMatrixElementAdd(squareMatrixMultiply(A11, B12), squareMatrixMultiply(A12, B22), C12);
squareMatrixElementAdd(squareMatrixMultiply(A21, B11), squareMatrixMultiply(A22, B21), C21);
squareMatrixElementAdd(squareMatrixMultiply(A21, B12), squareMatrixMultiply(A22, B22), C22);
// C11/C12/C21/C22 C
copySubMatrixByParamFromSrcToDest(C11, 0, rows / 2, 0, rows / 2, C);
copySubMatrixByParamFromSrcToDest(C12, 0, rows / 2, rows / 2, rows / 2, C);
copySubMatrixByParamFromSrcToDest(C21, rows / 2, rows / 2, 0, rows / 2, C);
copySubMatrixByParamFromSrcToDest(C22, rows / 2, rows / 2, rows / 2, rows / 2, C);
}
return C;
}
/**
* Strassen NxN
*
* @param A
* A
* @param B
* B
* @return
*/
public static int[][] strassenMartixMultiplyRecursive(int[][] A, int[][] B) {
int rows = A.length;
int[][] C = new int[rows][rows];
if (rows == 1) {
C[0][0] = A[0][0] * B[0][0];
} else {
int[][] A11 = new int[rows / 2][rows / 2];
int[][] A12 = new int[rows / 2][rows / 2];
int[][] A21 = new int[rows / 2][rows / 2];
int[][] A22 = new int[rows / 2][rows / 2];
copyMatrixbyParamFromSrcToSubMatrix(A, 0, rows / 2, 0, rows / 2, A11);
copyMatrixbyParamFromSrcToSubMatrix(A, 0, rows / 2, rows / 2, rows / 2, A12);
copyMatrixbyParamFromSrcToSubMatrix(A, rows / 2, rows / 2, 0, rows / 2, A21);
copyMatrixbyParamFromSrcToSubMatrix(A, rows / 2, rows / 2, rows / 2, rows / 2, A22);
int[][] B11 = new int[rows / 2][rows / 2];
int[][] B12 = new int[rows / 2][rows / 2];
int[][] B21 = new int[rows / 2][rows / 2];
int[][] B22 = new int[rows / 2][rows / 2];
copyMatrixbyParamFromSrcToSubMatrix(B, 0, rows / 2, 0, rows / 2, B11);
copyMatrixbyParamFromSrcToSubMatrix(B, 0, rows / 2, rows / 2, rows / 2, B12);
copyMatrixbyParamFromSrcToSubMatrix(B, rows / 2, rows / 2, 0, rows / 2, B21);
copyMatrixbyParamFromSrcToSubMatrix(B, rows / 2, rows / 2, rows / 2, rows / 2, B22);
int[][] S1 = new int[rows / 2][rows / 2];
int[][] S2 = new int[rows / 2][rows / 2];
int[][] S3 = new int[rows / 2][rows / 2];
int[][] S4 = new int[rows / 2][rows / 2];
int[][] S5 = new int[rows / 2][rows / 2];
int[][] S6 = new int[rows / 2][rows / 2];
int[][] S7 = new int[rows / 2][rows / 2];
int[][] S8 = new int[rows / 2][rows / 2];
int[][] S9 = new int[rows / 2][rows / 2];
int[][] S10 = new int[rows / 2][rows / 2];
squareMatrixElementSub(B12, B22, S1);// S1 = B12 - B22
squareMatrixElementAdd(A11, A12, S2);// S2 = A11 + A12
squareMatrixElementAdd(A21, A22, S3);// S3 = A21 + A22
squareMatrixElementSub(B21, B11, S4);// S4 = B21 - B11
squareMatrixElementAdd(A11, A22, S5);// S5 = A11 + A22
squareMatrixElementAdd(B11, B22, S6);// S6 = B11 + B22
squareMatrixElementSub(A12, A22, S7);// S7 = A12 - A22
squareMatrixElementAdd(B21, B22, S8);// S8 = B21 + B22
squareMatrixElementSub(A11, A21, S9);// S9 = A11 - A21
squareMatrixElementAdd(B11, B12, S10);// S10 = B11 + B12
int[][] P1 = new int[rows / 2][rows / 2];
int[][] P2 = new int[rows / 2][rows / 2];
int[][] P3 = new int[rows / 2][rows / 2];
int[][] P4 = new int[rows / 2][rows / 2];
int[][] P5 = new int[rows / 2][rows / 2];
int[][] P6 = new int[rows / 2][rows / 2];
int[][] P7 = new int[rows / 2][rows / 2];
P1 = strassenMartixMultiplyRecursive(A11, S1); // P1 = A11 X S1
P2 = strassenMartixMultiplyRecursive(S2, B22);// P2 = S2 X B22
P3 = strassenMartixMultiplyRecursive(S3, B11);// P3 = S3 X B11
P4 = strassenMartixMultiplyRecursive(A22, S4);// P4 = A22 X S4
P5 = strassenMartixMultiplyRecursive(S5, S6);// P5 = S5 X S6
P6 = strassenMartixMultiplyRecursive(S7, S8);// P6 = S7 X S8
P7 = strassenMartixMultiplyRecursive(S9, S10);// P7 = S9 X S10
int[][] C11 = new int[rows / 2][rows / 2];
int[][] C12 = new int[rows / 2][rows / 2];
int[][] C21 = new int[rows / 2][rows / 2];
int[][] C22 = new int[rows / 2][rows / 2];
int[][] temp = new int[rows / 2][rows / 2];
// C11 = P5 + P4 - P2 + P6
squareMatrixElementAdd(P5, P4, temp);
squareMatrixElementSub(temp, P2, temp);
squareMatrixElementAdd(temp, P6, C11);
// C12 = P1 + P2
squareMatrixElementAdd(P1, P2, C12);
// C21 = P3 + P4
squareMatrixElementAdd(P3, P4, C21);
// C22 = P5 + P1 - P3 -P7
squareMatrixElementAdd(P5, P1, temp);
squareMatrixElementSub(temp, P3, temp);
squareMatrixElementSub(temp, P7, C22);
// C11/C12/C21/C22 C
copySubMatrixByParamFromSrcToDest(C11, 0, rows / 2, 0, rows / 2, C);
copySubMatrixByParamFromSrcToDest(C12, 0, rows / 2, rows / 2, rows / 2, C);
copySubMatrixByParamFromSrcToDest(C21, rows / 2, rows / 2, 0, rows / 2, C);
copySubMatrixByParamFromSrcToDest(C22, rows / 2, rows / 2, rows / 2, rows / 2, C);
}
return C;
}
/**
* NxN 4 N/2xN/2
*
*/
public static void copyMatrixbyParamFromSrcToSubMatrix(int[][] src, int startI, int lenI, int startJ, int lenJ,
int[][] dest) {
for (int i = 0; i < lenI; i++)
for (int j = 0; j < lenJ; j++) {
dest[i][j] = src[startI + i][startJ + j];
}
}
/**
* 4 N/2xN/2 NxN
*
*/
public static void copySubMatrixByParamFromSrcToDest(int[][] src, int startI, int lenI, int startJ, int lenJ,
int[][] dest) {
for (int i = 0; i < lenI; i++)
for (int j = 0; j < lenJ; j++) {
dest[startI + i][startJ + j] = src[i][j];
}
}
/**
* NxN
*
* @param srcA
*
* @param srcB
*
* @param dest
*
*/
public static void squareMatrixElementAdd(int[][] srcA, int[][] srcB, int[][] dest) {
for (int i = 0; i < srcA.length; i++)
for (int j = 0; j < srcA[i].length; j++)
dest[i][j] = srcA[i][j] + srcB[i][j];
}
/**
* NxN
*
* @param srcA
*
* @param srcB
*
* @param dest
*
*/
public static void squareMatrixElementSub(int[][] srcA, int[][] srcB, int[][] dest) {
for (int i = 0; i < srcA.length; i++)
for (int j = 0; j < srcA[i].length; j++)
dest[i][j] = srcA[i][j] - srcB[i][j];
}
/**
* NxN
*
*/
public static void displaySquare(int[][] matrix) {
for (int i = 0; i < matrix.length; i++) {
for (int j : matrix[i]) {
System.out.print(j + " ");
}
System.out.println();
}
}
}
の出力は次のとおりです.
6 12 14 6
6 12 14 6
6 12 14 6
6 12 14 6
6 12 14 6
6 12 14 6
6 12 14 6
6 12 14 6
Strassen
6 12 14 6
6 12 14 6
6 12 14 6
6 12 14 6
PS:
最後に,Strassenアルゴリズムが方程式積の複雑さを低減する理由をずっと考えていた.最後に知っている上で1つの答えを見て、解釈がとても良いと感じます;今貼って、一緒に分かち合います:
「strassenアルゴリズムの鍵は、乗算か加算かではなく、アルゴリズムの内部再帰呼び出しの回数にある.strassenアルゴリズムの鍵は、内部再帰呼び出しの回数が1(通常の8回から特殊な7回に変わる)減少することにある.「茂み」度は、この時間的複雑さの再帰アルゴリズムにおける変化を記述する.したがってstrassenアルゴリズムの鍵は,再帰呼び出しの回数がどのように8回から1回減少するかである.逆に理解すると,8回目の再帰呼び出しのうち1回が冗長である,すなわち8回目の再帰乗算の結果情報が上位7回の結果に含まれており,上位7回の計算結果は線形組合せにより8回目の再帰の結果を得ることができる.一方、この線形の組み合わせの時間的複雑度は、アルゴリズム自体(すなわち、1回の再帰呼び出し)の時間的複雑度よりも低い.結論を下す.しかし、時間的複雑度を最適化できるアルゴリズムであれば、高複雑度のアルゴリズムには、冗長性の代わりにより少ない計算を用いることができれば、効率を向上させることができる計算があるに違いない.(アルゴリズムの再帰はちょうど乗算なので、ここでは乗算に重点を置いているように見えます)
従来のマトリクス乗算アルゴリズムに冗長計算がある理由についても、次のように分析してみましょう.
冗長性の根本的な原因は、基本的な乗算分配法則a*(b+c)=a*b+a*cにあるべきである.同じ計算結果、前者(等号前)の方法では計算に2回の基本演算が必要であり,後者(等号後)の方法では3回必要である.(乗算と加算が同等のオーバーヘッドの基本演算であると仮定する).一般的な行列乗算アルゴリズムでは、大量の単歩乗算後の和を求める、すなわち、上記等号右側の計算式を採用している.乗算における同一因子を前に挙げて、上記乗算分配則を用いて計算形式を変換する方法があれば、計算効率を向上させることができる率.これがstrassenアルゴリズムの本質であるはずだ.strassenアルゴリズムの過程を見ると,まず一部の分子行列を加算(減算)し,乗算を行う.実は上述した分配律を構成した左式である
作者:知乎ユーザー
リンク:https://www.zhihu.com/question/28558331/answer/146497271
出典:知っている
著作権は作者の所有である.商業転載は著者に連絡して許可を得てください.非商業転載は出典を明記してください."