[ZOJ 3813 Alternating Sum] 線分木

7363 ワード

問題
http://acm.zju.edu.cn/onlinejudge/showProblem.do?problemCode=3813
分析
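$f[R]=\sum_{i=L}^{R} g[i][R]$ と定義する（$g[i][R]$ は区間 $[i,R]$ の交代和 $S[i]-S[i+1]+S[i+2]-\dots$）。
$S[k]$ の係数は、$k$ と $L$ の偶奇が一致するとき $1$、異なるとき $0$ になる。したがって $f[R]=S[L]+S[L+2]+S[L+4]+\dots+S[L']$ と簡略化できる。ここで $L'$ は $R$ 以下で $L$ と同じパリティを持つ最大の添字である。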
$Ans=\sum_{r=L}^{R} f[r]$、$Length=R-L+1$ とする。
$R$ と $L$ のパリティが異なる場合、$Ans=S[L]\cdot Length+S[L+2]\cdot(Length-2)+\dots+S[R-1]\cdot 2$ となる。
$R$ と $L$ のパリティが同じ場合、$Ans=S[L]\cdot Length+S[L+2]\cdot(Length-2)+\dots+S[R]\cdot 1=S[L]\cdot(Length+1)+S[L+2]\cdot(Length-1)+\dots+S[R]\cdot 2-(S[L]+S[L+2]+\dots+S[R])$ となる。
そこで、奇数位置と偶数位置の数字をそれぞれ取り出して別々の線分木を作る。各区間 $[l,r]$ について $\sum_{i=l}^{r} S[i]\cdot(r-i+1)\cdot 2$ と $\sum_{i=l}^{r} S[i]$ を維持する（$i$ はパリティごとに詰めた配列上の添字）。
各クエリは、左端の不完全な周期、中間で繰り返される完全な周期、右端の不完全な周期の3つに分けて、それぞれ計算する。
最初に文字列を2倍にしておくと、周期が偶数長になり、どの周期でも各位置のパリティが変わらない。これにより場合分けを減らせる。
中間の完全な周期は、周期ごとの寄与が等差数列をなすので、等差数列の和の公式でまとめて求められる。
コード
/**************************************************
 *        Problem:  ZOJ 3813
 *         Author:  clavichord93
 *          State:  Accepted
 **************************************************/

#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <cstring>
using namespace std;

// All answers are reported modulo this prime.
const long long Mod = 1000000007;
// Modular inverse of 2 (2 * Inv2 == 1 mod Mod); equals 500000004.
const long long Inv2 = (Mod + 1) / 2;
// Buffer capacity. The pattern is stored 1-based and doubled, so s needs
// room for indices 1..2*|P|.
const int MAX_N = 200005;

// Length of the doubled pattern (set to 2 * tot in main).
int n;
// Original pattern length; also the number of slots in each parity tree.
int tot;
// 1-based pattern, doubled in place.
char s[MAX_N];
// a[k][p]: the k-th digit among doubled-pattern positions with parity p.
int a[MAX_N][2];
// Segment trees (one per parity): plain digit sums per node.
long long sum[MAX_N << 1][2];
// Segment trees (one per parity): sigma S[i] * (r - i + 1) * 2 per node [l, r].
long long sum2[MAX_N << 1][2];

// Heap-style child indices for the 1-based segment trees: node t has
// children 2t and 2t+1. The argument is parenthesized so that expression
// arguments (e.g. lch(a | b)) expand with the intended precedence.
#define lch(t) ((t) << 1)
#define rch(t) ((t) << 1 | 1)

// Recompute internal node t (covering [l, r]) of parity tree `id` from its
// two children. sum2 stores sigma S[i] * (r - i + 1) * 2 over the node's range.
// The right child's weights are already measured from r. Every element of the
// left child is (r - mid) positions farther from r than the child assumed, so
// the left child contributes an extra sum(left) * (r - mid) * 2.
static void pullUp(int id, int t, int l, int r) {
    const int mid = (l + r) >> 1;
    const long long leftSum = sum[lch(t)][id];
    sum[t][id] = (leftSum + sum[rch(t)][id]) % Mod;
    sum2[t][id] = (sum2[lch(t)][id]
                   + leftSum * (r - mid) % Mod * 2 % Mod
                   + sum2[rch(t)][id]) % Mod;
}

// Store value `val` in leaf t of parity tree `id`. A single element is at
// distance 1 from its own right end, so its weighted value is val * 2.
static void setLeaf(int id, int t, long long val) {
    sum[t][id] = val % Mod;
    sum2[t][id] = 2 * val % Mod;
}

// Build parity tree `id` over a[l..r][id], rooted at node t.
void makeTree(int id, int t, int l, int r) {
    if (l == r) {
        setLeaf(id, t, a[l][id]);
        return;
    }
    const int mid = (l + r) >> 1;
    makeTree(id, lch(t), l, mid);
    makeTree(id, rch(t), mid + 1, r);
    pullUp(id, t, l, r);
}

// Point assignment: set slot x of parity tree `id` to val, then repair every
// ancestor on the way back up.
void change(int id, int t, int l, int r, int x, int val) {
    if (l == r) {
        setLeaf(id, t, val);
        return;
    }
    const int mid = (l + r) >> 1;
    if (x <= mid) {
        change(id, lch(t), l, mid, x, val);
    } else {
        change(id, rch(t), mid + 1, r, x, val);
    }
    pullUp(id, t, l, r);
}

// Range sum of parity tree `id` over slots [x, y], restricted to node t which
// covers [l, r]. The result is reduced modulo Mod.
long long getSum(int id, int t, int l, int r, int x, int y) {
    // Node lies entirely inside the query: its stored total is the answer.
    if (x <= l && r <= y) return sum[t][id];

    const int half = (l + r) >> 1;
    long long acc = 0;
    if (x <= half) acc = (acc + getSum(id, lch(t), l, half, x, y)) % Mod;
    if (half < y) acc = (acc + getSum(id, rch(t), half + 1, r, x, y)) % Mod;
    return acc;
}

// Weighted range query on parity tree `id`: returns, modulo Mod, the sum over
// slots i in [x, y] (restricted to node t's range [l, r]) of S[i] * (y - i + 1) * 2.
// Each element is weighted by twice its inclusive distance to y.
long long getSum2(int id, int t, int l, int r, int x, int y) {
    if (x <= l && r <= y) {
        // Stored weights are measured from r; push them out to y.
        const long long shift = sum[t][id] * (y - r) % Mod * 2 % Mod;
        return (sum2[t][id] + shift) % Mod;
    }

    const int half = (l + r) >> 1;
    long long acc = 0;
    if (x <= half) acc = (acc + getSum2(id, lch(t), l, half, x, y)) % Mod;
    if (half < y) acc = (acc + getSum2(id, rch(t), half + 1, r, x, y)) % Mod;
    return acc;
}

#undef lch
#undef rch

int main() {
    #ifdef LOCAL_JUDGE
    freopen("in.txt", "r", stdin);
    freopen("out.txt", "w", stdout);
    #endif
    int T;
    scanf("%d", &T);
    for (int cas = 1; cas <= T; cas++) {
        scanf("%s", s + 1);
        n = strlen(s + 1);
        for (int i = 1; i <= n; i++) {
            s[n + i] = s[i];
        }
        tot = n;
        n *= 2;
        for (int i = 1; i <= n; i++) {
            a[(i + 1) / 2][i & 1] = s[i] - '0';
        }
        makeTree(0, 1, 1, tot);
        makeTree(1, 1, 1, tot);
        int q;
        scanf("%d", &q);
        for (int i = 0; i < q; i++) {
            int op;
            scanf("%d", &op);
            if (op == 1) {
                int x, d, id;
                scanf("%d %d", &x, &d);
                id = x & 1;
                change(id, 1, 1, tot, (x + 1) / 2, d);
                x += tot;
                id = x & 1;
                change(id, 1, 1, tot, (x + 1) / 2, d);
            }
            else {
                long long l, r;
                scanf("%lld %lld", &l, &r);

                long long cntL = l / n;
                long long cntR = r / n;
                int limL = l % n;
                int limR = r % n;
                if (limL == 0) {
                    limL = n;
                    cntL--;
                }
                if (limR == 0) {
                    limR = n;
                    cntR--;
                }

                int id = limL & 1;
                bool parity = (limL & 1) == (limR & 1);
                limR = parity ? limR : limR - 1;
                if (limR == 0) {
                    limR = n;
                    cntR--;
                }

                limL = (limL + 1) / 2;
                limR = (limR + 1) / 2;

                //cout << cntL << ' ' << cntR << endl;
                //cout << limL << ' ' << limR << endl;

                if (cntL == cntR) {
                    long long sum2 = getSum2(id, 1, 1, tot, limL, limR);
                    long long sum = getSum(id, 1, 1, tot, limL, limR);
                    long long ans = sum2;
                    if (parity) {
                        ans = (ans - sum + Mod) % Mod;
                    }
                    //cout << sum2 << ' ' << sum << endl;
                    printf("%lld
", ans); } else { long long cntM = (cntR - cntL - 1) % Mod; long long sumAll = getSum(id, 1, 1, tot, 1, tot); long long sumL = getSum(id, 1, 1, tot, limL, tot); long long sumM = sumAll * cntM % Mod; long long sumR = getSum(id, 1, 1, tot, 1, limR); long long sum = (sumL + sumM + sumR) % Mod; long long sum2All = getSum2(id, 1, 1, tot, 1, tot); long long A1 = (sum2All + sumAll * limR % Mod * 2 % Mod) % Mod; long long An = (sum2All + ((limR + (cntM - 1 + Mod) % Mod * tot % Mod) % Mod * 2 % Mod) % Mod * sumAll % Mod) % Mod; long long sum2R = getSum2(id, 1, 1, tot, 1, limR); long long sum2M = (A1 + An) % Mod * cntM % Mod * Inv2 % Mod; long long sum2L = (getSum2(id, 1, 1, tot, limL, tot) + (cntM * tot % Mod + limR) % Mod * sumL % Mod * 2 % Mod) % Mod; long long ans = (sum2L + sum2M + sum2R) % Mod; //cout << sum2All << ' ' << sumAll << ' ' << limR << ' ' << cntM << endl; //cout << sum2L << ' ' << sum2M << ' ' << sum2R << endl; if (parity) { ans = (ans - sum + Mod) % Mod; } printf("%lld
", ans); } } } } return 0; }