【線分樹】HDOJ 5316 Magician
2643 ワード
セグメントツリー区間のマージ...パリティの4つの状態を記録すればいいのですが...
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
#define ls o << 1
#define rs o << 1 | 1
#define lson o << 1, L, mid
#define rson o << 1 | 1, mid+1, R
const int maxn = 100000;
struct node
{
LL sum[2][2];
}a[maxn << 2];
int x[maxn];
const LL INF = 1e16;
int n, m;
void pushup(int o)
{
for(int i = 0; i < 2; i++)
for(int j = 0; j < 2; j++)
a[o].sum[i][j] = max(a[ls].sum[i][j], a[rs].sum[i][j]);
for(int k = 0; k < 2; k++)
for(int i = 0; i < 2; i++)
for(int j = 0; j < 2; j++)
a[o].sum[i][j] = max(a[o].sum[i][j], a[ls].sum[i][k] + a[rs].sum[1-k][j]);
}
void build(int o, int L, int R)
{
for(int i = 0; i < 2; i++)
for(int j = 0; j < 2; j++)
a[o].sum[i][j] = -INF;
if(L == R) {
if(L % 2) a[o].sum[1][1] = x[L];
else a[o].sum[0][0] = x[L];
return;
}
int mid = (L + R) >> 1;
build(lson);
build(rson);
pushup(o);
}
void update(int o, int L, int R, int q, int v)
{
if(L == R) {
for(int i = 0; i < 2; i++)
for(int j = 0; j < 2; j++)
a[o].sum[i][j] = -INF;
if(L % 2) a[o].sum[1][1] = v;
else a[o].sum[0][0] = v;
return;
}
int mid = (L + R) >> 1;
if(q <= mid) update(lson, q, v);
else update(rson, q, v);
pushup(o);
}
node query(int o, int L, int R, int ql, int qr)
{
if(ql <= L && qr >= R) return a[o];
int mid = (L + R) >> 1;
node ans;
memset(ans.sum, 0, sizeof ans.sum);
if(qr <= mid) return query(lson, ql, qr);
else if(ql > mid) return query(rson, ql, qr);
else {
node t1 = query(lson, ql, qr);
node t2 = query(rson, ql, qr);
for(int i = 0; i < 2; i++)
for(int j = 0; j < 2; j++)
ans.sum[i][j] = max(t1.sum[i][j], t2.sum[i][j]);
for(int k = 0; k < 2; k++)
for(int i = 0; i < 2; i++)
for(int j = 0; j < 2; j++)
ans.sum[i][j] = max(ans.sum[i][j], t1.sum[i][k] + t2.sum[1-k][j]);
return ans;
}
}
void work()
{
scanf("%d%d", &n, &m);
for(int i = 1; i <= n; i++) scanf("%d", &x[i]);
build(1, 1, n);
int op, ql, qr, q, v;
while(m--) {
scanf("%d", &op);
LL ans = -INF;
if(op == 0) {
scanf("%d%d", &ql, &qr);
node res = query(1, 1, n, ql, qr);
for(int i = 0; i < 2; i++)
for(int j = 0; j < 2; j++)
ans = max(ans, res.sum[i][j]);
printf("%lld
", ans);
}
else {
scanf("%d%d", &q, &v);
update(1, 1, n, q, v);
}
}
}
int main()
{
int _;
scanf("%d", &_);
while(_--) work();
return 0;
}