COJ 1208 Fibonacci sum

4643 ワード

标题:Fibonacci数列の前のN項のK次方和を求めます.
解析:この問題は明らかに行列の高速べき乗で解決され,問題はどのように繰返し方程式を探し出すかである.
Fibonacci数列の性質から、f(i)=f(i-1)+f(i-2);       
f(a)^k = [f(a-1) + f(a-2)]^k
二項式の展開には良い性質がある:展開後に原式と整列する.ちょうどここで使えます.
f(a)^x*f(a-1)^(k-x)=sum{C(x,i)*f(a-1)^(k+i-x)*f(a-2)^(x-i)|(0<=i<=x)}は常にK回であることがわかる.
これにより,繰返し式,すなわちf(a)^kを求める方法が容易に得られる.
前のn項とを求めて、2つの方法があります:
①二分行列加算
②繰返し式に加算されます.
明らかに2はずっと効率的で、実測①800+ms②40+ms
参照コード:
//     
#include 
#include 
#include 
#include 
using namespace std;
typedef long long LL;
const int maxn = 21 ;
const LL MOD = 1000000007LL;
int K ;
struct Mat{
    int m[maxn][maxn] ;
    Mat(int a=0){
        memset(m , 0 ,sizeof(m)) ;
        for(int i=0; i<=K; i++) m[i][i] = a ;
    }
};
 
Mat operator + (const Mat &a ,const Mat &b){
    Mat ret;
    for(int i=0; i<=K; i++)
        for(int j=0; j<=K ; j++)
            ret.m[i][j] = ((LL)a.m[i][j] + b.m[i][j]) % MOD ;
    return ret ;
}
Mat operator * (const Mat &a ,const Mat &b){
    Mat ret ;
    for(int i=0; i<=K; i++)
        for(int j=0; j<=K; j++){
            unsigned long long tmp = 0;
            for(int k=0; k<=K; k++){
                tmp += (LL)a.m[i][k] * b.m[k][j] ;
                if(k==10) tmp %= MOD ;
            }
            ret.m[i][j] = tmp%MOD ;
        }
    return ret ;
}
Mat operator ^ (Mat a, int n){
    Mat ret(1) ;
    while(n){
        if(n&1) ret = ret * a ;
        a = a*a ;
        n>>=1 ;
    }
    return ret ;
}
Mat sum(Mat a , int n){
    Mat E(1) ;
    if(n == 0) return E ;
    if(n&1) return ((a^(n/2+1)) + E) * sum(a , n/2) ;
    else {
        Mat  tmp = a^(n/2);
        return tmp + (tmp*a+E)*sum(a , n/2-1);
    }
}
 
Mat A ;
LL C[30][30] ;
 
void prepare(){
    C[0][0] = 1 ;
    for(int i=1; i<25; i++){
        C[i][0] = 1 ;
        for(int j=1; j<25; j++){
            C[i][j] += C[i-1][j-1] + C[i-1][j] ;
            while(C[i][j] >= MOD) C[i][j] -= MOD ;
        }
    }
}
 
int solve(int n){
    if( n == 1) { return 1 ;}
    memset(A.m ,0 ,sizeof(A.m)) ;
    for(int i=0; i<=K ;i++){
        for(int j=0; j<=i; j++)
            A.m[i][K-i+j] = C[i][j] ;
    }
    A = sum(A , n-2);
    LL ans = 0;
    for(int i=0; i<=K; i++) ans += A.m[K][i] ;
    ans++ ;
    while(ans >= MOD) ans-=MOD;
    return ans ;
}
int main()
{
    //freopen("in.txt","r",stdin);
    //freopen("out.txt","w",stdout);
    prepare() ;
    int T  , n ;
    scanf("%d" , &T);
    while(T--){
        scanf("%d%d" , &n ,&K);
        int ans = solve(n) ;
        printf("%d
" , ans) ; } return 0; }
//②          
#include 
#include 
#include 
#include 
using namespace std;
typedef long long LL;
const int maxn = 22 ;
const LL MOD = 1000000007LL;
int K ;
struct Mat{
    int m[maxn][maxn] ;
    Mat(int a=0){
        memset(m , 0 ,sizeof(m)) ;
        for(int i=0; i<=K+1; i++) m[i][i] = a ;
    }
};
 
 
Mat operator * (const Mat &a ,const Mat &b){
    Mat ret ;
    for(int i=0; i<=K+1; i++)
        for(int j=0; j<=K+1; j++){
            unsigned long long tmp = 0;
            for(int k=0; k<=K+1; k++){
                tmp += (LL)a.m[i][k] * b.m[k][j] ;
                if(k==10) tmp %= MOD ;
            }
            ret.m[i][j] = tmp%MOD ;
        }
    return ret ;
}
Mat operator ^ (Mat a, int n){
    Mat ret(1) ;
    while(n){
        if(n&1) ret = ret * a ;
        a = a*a ;
        n>>=1 ;
    }
    return ret ;
}
 
Mat A ;
LL C[30][30] ;
 
void prepare(){
    C[0][0] = 1 ;
    for(int i=1; i<25; i++){
        C[i][0] = 1 ;
        for(int j=1; j<25; j++){
            C[i][j] += C[i-1][j-1] + C[i-1][j] ;
            while(C[i][j] >= MOD) C[i][j] -= MOD ;
        }
    }
}
 
int solve(int n){
    if( n == 1) { return 1 ;}
    memset(A.m ,0 ,sizeof(A.m)) ;
    for(int i=0; i<=K ;i++){
        for(int j=0; j<=i; j++)
            A.m[i][K-i+j] = C[i][j] ;
    }
    A.m[K+1][K] = A.m[K+1][K+1] = 1;
    A = A^(n-1) ;
    LL ans = 0;
    for(int i=0; i<=K+1; i++) ans += A.m[K+1][i] ;
    while(ans >= MOD) ans-=MOD;
    return ans ;
}
int main()
{
    prepare() ;
    int T  , n ;
    scanf("%d" , &T);
    while(T--){
        scanf("%d%d" , &n ,&K);
        int ans = solve(n) ;
        printf("%d
" , ans) ; } return 0; }