poj 3233(マトリクス乗算+二分+再帰)


テーマ解析:マトリクスの高速べき乗.まず、A^xはマトリクスの高速べき乗で求めることができることを知っています(poj 3070が具体的に表示されます).次に、kを二分することができ、毎回規模を半減し、kをパリティの2つに分けることができます.例えば、k=6とk=7の場合:
k=6は、S(6)=(1+A^3)*(A+A^2+A^3)=(1+A^3)*S(3)である.
k=7は、S(7)=A+(A+A^4)*(A+A^2+A^3)=A+(A+A^4)*S(3)である.
ps:マトリックスに対して構造体Matrixと定義して、Sを求める時再帰を使って、プログラムは比較的に直観的で、少し書きやすいです.もちろん配列として定義し、いくつかの前処理を行うと、効率が高くなります.
注意:
        1.最初は、ずっと再帰して、層数が多すぎて、ずっとTLEで、計算した一時記憶の(エラーコードを参照);
        2.注意マトリクスの0乗
        3.注意演算子の再ロード
正しいコード:参考になりましたhttp://blog.sina.com.cn/s/blog_6635898a0102e1am.html
#include<iostream>
#include<cstdio>
#include<algorithm>
using namespace std;

int n,k,m;
struct node{
     int matrix[50][50];
};
node a;
//     
node operator + (node x,node y)//  x+  y
{
	node ans;
	for(int i=1;i<=n;i++)
		for(int j=1;j<=n;j++)
			ans.matrix[i][j]=(x.matrix[i][j]+y.matrix[i][j])%m;
	return ans;
}
node inline mult(node x,node y)//    x*y
{
	node c;
	for(int i=1;i<=n;i++)
		for(int j=1;j<=n;j++)
		{
			int ans=0;
			for(int p=1;p<=n;p++)//
			{
				ans+=(x.matrix[i][p]*y.matrix[p][j])%m;
				ans%=m;
			}
			c.matrix[i][j]=ans%m;
		}
	return c;
}
node inline func(node x,int i)//    x^i
{
	//printf("%d**
",i); node temp,c; memset(temp.matrix,0,sizeof(temp.matrix)); for(int j=1;j<=n;j++) temp.matrix[j][j]=1; if(i==0) return temp; if(i==1) return x; c=func(x,i/2); if(i%2==0) return mult(c,c); else return mult(mult(c,c),a); } node fun(node A,int x) // a^1+a^2+...+a^k { if(x==1) return A; node B=func(A,(x+1)/2); node C=fun(A,x/2); if(x%2==0) return mult((func(A,0)+B),C);//return B+mult(C,B); else return A+mult((A+B),C);//B+mult(C,B)+C; } int main() { while(scanf("%d %d %d",&n,&k,&m)!=EOF) { int i,j; for(i=1;i<=n;i++) for(j=1;j<=n;j++) scanf("%d",&a.matrix[i][j]); node ans=fun(a,k); for(i=1;i<=n;i++) { printf("%d",ans.matrix[i][1]); for(j=2;j<=n;j++) printf(" %d",ans.matrix[i][j]); printf("
"); } } //system("pause"); return 0; }

エラーのコード:
k=6は、S(6)=(1+A^3)*(A+A^2+A^3)=(1+A^3)*S(3)である.
k=7有:S(7)=(1+A^3)*(A+A^2+A^3)+A^7.
寝間違えた・・・どうして..................?????
#include<iostream>
#include<cstdio>
#include<algorithm>
using namespace std;

int n,k,m;
struct node{
     int matrix[50][50];
	 bool flag;
}arr[11000];
node a;
//     
node operator + (node x,node y)//  x+  y
{
	node ans;
	for(int i=1;i<=n;i++)
		for(int j=1;j<=n;j++)
			ans.matrix[i][j]=(x.matrix[i][j]+y.matrix[i][j])%m;
	return ans;
}
node operator = (node x)
{
	/*node ans;
	for(int i=1;i<=n;i++)
		for(int j=1;j<=n;j++)
			ans.matrix[i][j]=x.matrix[i][j];*/
	return x;
}
node mult(node x,node y)//    x*y
{
	node c;
	for(int i=1;i<=n;i++)
		for(int j=1;j<=n;j++)
		{
			int ans=0;
			for(int p=1;p<=n;p++)//
			{
				ans+=(x.matrix[i][p]*y.matrix[p][j])%m;
				ans%=m;
			}
			c.matrix[i][j]=ans%m;
		}
	return c;
}
node func(node x,int i)//    x^i
{
	//printf("%d**
",i); if(i==1) return x; if(i%2==0) return mult(func(x,i/2),func(x,i/2)); else return mult(mult(func(x,i/2),func(x,i/2)),a); } node fun(int x) // a^1+a^2+...+a^k { node temp; if(arr[x].flag==true) return arr[x]; if(x%2==0) temp=fun(x/2)+mult(func(a,x/2),fun(x/2)); else temp=fun(x/2)+mult(func(a,x/2),fun(x/2))+func(a,x); arr[x]=temp; arr[x].flag=true; return temp; /*if(x==1) return a; if(x%2==0) return fun(x/2)+mult(func(a,x/2),fun(x/2)); else return fun(x/2)+mult(func(a,x/2),fun(x/2))+func(a,x);*/ } int main() { while(scanf("%d %d %d",&n,&k,&m)!=EOF) { for(int i=1;i<=n;i++) for(int j=1;j<=n;j++) scanf("%d",&a.matrix[i][j]); for(int i=1;i<=10000;i++) arr[i].flag=false; arr[1]=a; arr[1].flag=true; node ans=fun(k); for(int i=1;i<=n;i++) { printf("%d",ans.matrix[i][1]); for(int j=2;j<=n;j++) printf(" %d",ans.matrix[i][j]); printf("
"); } } system("pause"); return 0; }