多項式除算


f,g,degf=n,degg=m(m≦n)f,g,deg⁡f=n,deg⁡g=m(m≦n)
唯一のq、r q、rを求めて、f=qになります×g+r f = q × g+r、degr例:f(x)=x 4+x 3+2 x 2+4 x+2,g(x)=x 2+x+3 f(x)=x 4+x 3+2 x 2+4 x+2,g(x)=x 2+x+3
f(x)=(x 2−1)g(x)+5 x+5 f(x)=(x 2−1)g(x)+5 x+5 q(x)=x 2+1,r(x)=5 x+5 q(x)=x 2+1,r(x)=5 x+5 q(x)=x 2+1,r(x)=5 x+5
f=q×g+r f = q × g + r
r rの存在は演算をうまく行わず,r rの影響を消去することを考慮していることが分かった.
degqのため×g=degf deg ⁡ q × g=deg
fR(x)=(qR×gR)(x)+xn−degrrR(x) f R ( x ) = ( q R × g R ) ( x ) + x n − deg ⁡ r r R ( x )
R Rは多項式を反転することを指す
degr≤degg−1=m−1⇒n−m−1≤n−degr deg ⁡ r ≤ deg ⁡ g − 1 = m − 1 ⇒ n − m − 1 ≤ n − deg ⁡ r
xn−degrrR(x)≡0(modxn−m−1) x n − deg ⁡ r r R ( x ) ≡ 0 ( mod x n − m − 1 )
fR(x)≡(qR×gR)(x)(modxn−m+1) f R ( x ) ≡ ( q R × g R ) ( x ) ( mod x n − m + 1 )
これにより、r rの影響を完全に消去する
このときgR g Rに対して逆を求めればqR q Rを解き,その後r rを反復して得る.
T(n)=O(nlogn) T ( n ) = O ( n log ⁡ n )
コードは次のとおりです.
#include 
using namespace std;
const int N = 1000010 , mod = 998244353 , G = 3;
int A[N] , B[N] , revA[N] , revB[N];
int rev[N];
int a[N] , b[N] , c[N];
int n , m;
int read() {
    int ans = 0 , flag = 1;
    char ch = getchar();
    while(ch > '9' || ch < '0') {if(ch=='-') flag = -1; ch = getchar();}
    while(ch <= '9' && ch >= '0') {ans = ans * 10 + ch - '0'; ch = getchar();}
    return ans * flag;
}
int qpow(int a , int b) {
    int ans = 1;
    while(b) {
        if(b & 1) ans = 1ll * ans * a % mod;
        a = 1ll * a * a % mod;
        b >>= 1;
    }
    return ans;
}
void dft(int *now , int n , int f) {
    for(int i = 0 ; i < n ; ++ i) {if(i < rev[i]) swap(now[i] , now[rev[i]]);}
    for(int i = 1 ; i < n ; i <<= 1) {
        int gn = qpow(G , (mod - 1) / (i<<1));
        if(f != 1) gn = qpow(gn , mod - 2);
        for(int j = 0 ; j < n ; j += (i << 1)) {
            int x , y , g = 1;
            for(int k = 0 ; k < i ; ++ k , g = 1ll * gn *g % mod) {
                x = now[j + k]; y = 1ll * g * now[i + j + k] % mod;
                now[j + k] = (x + y) % mod;
                now[i + j + k] = ((x - y) % mod + mod) % mod;
            }
        }
    }
    if(f != 1) {
        int ny = qpow(n , mod - 2);
        for(int i = 0 ; i < n ; ++ i)
            now[i] = 1ll * ny * now[i] % mod;
    }
}
void work(int deg , int *a , int *b) {
    if(deg == 1) {b[0] = qpow(a[0] , mod - 2); return;}
    work((deg + 1) >> 1 , a , b);
    int l = 0 , nn  , n = deg * 2;
    for(nn = 1 ; nn < n ; nn <<= 1) ++ l;
    for(int i = 0 ; i < nn ; ++ i)
        rev[i] = (rev[i>>1]>>1) | ((i & 1) << (l - 1));
    for(int i = 0 ; i < nn ; ++ i) c[i] = i < deg ? a[i] : 0;
    for(int i = deg ; i < nn ; ++ i) c[i] = 0;
    dft(b , nn , 1); dft(c , nn , 1);
    for(int i = 0 ; i < nn ; ++ i) b[i] = 1ll * ((2 - 1ll * c[i] * b[i] % mod ) %mod + mod ) % mod * b[i] % mod;
    dft(b , nn , -1);
    for(int i = deg ; i < nn ; ++ i) b[i] = 0;
}
int main() {
    freopen("in" , "r" , stdin);
    n = read(); m = read();
    int l = 0 , nn;
    for(nn = 1 ; nn < n * 2; nn <<= 1) ++ l;
    for(int i = 0 ; i <= n ; ++ i) A[i] = revA[n - i] = read();
    for(int i = 0 ; i <= m ; ++ i) B[i] = revB[m - i] = read();

    work(n - m + 1 , revB , b);

    memset(rev , 0 , sizeof(rev));
    for(int i = 0 ; i < nn ; ++ i) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (l - 1));

    dft(revA , nn , 1); dft(b , nn , 1);
    for(int i = 0 ; i < nn ; ++ i) c[i] = 1ll * revA[i] * b[i] %mod;
    dft(c , nn , -1);

    for(int i = n - m + 1 ; i < nn ; ++ i) c[i] = 0;
    reverse(c , c + n - m + 1);
    for(int i = 0 ; i < n - m + 1 ; ++ i) printf("%d ", c[i]);
    puts("");
    dft(B , nn , 1); dft(c , nn , 1);
    for(int i = 0 ; i < nn ; ++ i) B[i] = 1ll * B[i] * c[i] % mod;
    dft(B , nn , -1);

    for(int i = 0 ; i <= n ; ++ i) A[i] = (A[i] - B[i] + mod) % mod;
    for(int i = 0 ; i < m ; ++ i) printf("%d ",A[i]);
    puts("");
    return 0;
}