[BZOJ1500][NOI2005]维修数列
试题描述
输入
输入的第1 行包含两个数N 和M(M ≤20 000),N 表示初始时数列中数的个数,M表示要进行的操作数目。
第2行包含N个数字,描述初始时的数列。
以下M行,每行一条命令,格式参见问题描述中的表格。
任何时刻数列中最多含有500 000个数,数列中任何一个数字均在[-1 000, 1 000]内。
插入的数字总数不超过4 000 000个,输入文件大小不超过20MBytes。
输出
对于输入数据中的GET-SUM和MAX-SUM操作,向输出文件依次打印结果,每个答案(数字)占一行。
输入示例
9 8 2 -6 3 5 1 -5 -3 6 3 GET-SUM 5 4 MAX-SUM INSERT 8 3 -5 7 2 DELETE 12 1 MAKE-SAME 3 3 2 REVERSE 3 6 GET-SUM 5 4 MAX-SUM
输出示例
-1 10 1 10
数据规模及约定
见“输入”
题解
又一道裸题。但是这题维护的东西比较恶心:需要维护节点的值 v、子树大小 siz、子树权值和 sum、子树最大前缀和 ml、子树最大后缀和 mr、子树最大连续和 ms、权值懒标记 setv 以及反转标记 rev。然后注意如果一个子树 rev = 1(即打了反转标记),则它的 ml 和 mr 需要互换。
#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cctype>
#include <algorithm>
using namespace std;
int read() {
int x = 0, f = 1; char c = getchar();
while(!isdigit(c)){ if(c == '-') f = -1; c = getchar(); }
while(isdigit(c)){ x = x * 10 + c - '0'; c = getchar(); }
return x * f;
}
#define maxn 500010
#define LL long long
#define NoSign -2333
#define oo (1ll << 60)
struct Node {
int v, siz;
LL sum, ml, mr, ms, setv;
bool rev;
Node() {}
Node(int _): v(_), setv(NoSign), rev(0) {}
} ns[maxn];
int rt, ToT, fa[maxn], ch[2][maxn], rec[maxn], cc;
bool hs(int o) { return ns[o].setv != NoSign; }
void maintain(int o) {
ns[o].siz = 1; ns[o].sum = ns[o].v; ns[o].ms = ns[o].v;
int l = ch[0][o], r = ch[1][o];
for(int i = 0; i < 2; i++) if(ch[i][o])
ns[o].siz += ns[ch[i][o]].siz,
ns[o].sum += !hs(ch[i][o]) ? ns[ch[i][o]].sum : ns[ch[i][o]].setv * ns[ch[i][o]].siz,
ns[o].ms = max(ns[o].ms, !hs(ch[i][o]) ? ns[ch[i][o]].ms : max(ns[ch[i][o]].setv, ns[ch[i][o]].setv * ns[ch[i][o]].siz));
int tll = ns[l].ml, tlr = ns[l].mr, trl = ns[r].ml, trr = ns[r].mr, tls = ns[l].sum, trs = ns[r].sum;
if(hs(l)) ns[l].ml = ns[l].mr = max(ns[l].setv, ns[l].setv * ns[l].siz), ns[l].sum = ns[l].setv * ns[l].siz;
if(hs(r)) ns[r].ml = ns[r].mr = max(ns[r].setv, ns[r].setv * ns[r].siz), ns[r].sum = ns[r].setv * ns[r].siz;
if(ns[l].rev) swap(ns[l].ml, ns[l].mr);
if(ns[r].rev) swap(ns[r].ml, ns[r].mr);
if(l) ns[o].ml = ns[l].ml; else ns[o].ml = -oo;
ns[o].ml = max(ns[o].ml, (l ? ns[l].sum : 0) + ns[o].v);
if(r) ns[o].ml = max(ns[o].ml, (l ? ns[l].sum : 0) + ns[o].v + ns[r].ml);
if(r) ns[o].mr = ns[r].mr; else ns[o].mr = -oo;
ns[o].mr = max(ns[o].mr, (r ? ns[r].sum : 0) + ns[o].v);
if(l) ns[o].mr = max(ns[o].mr, (r ? ns[r].sum : 0) + ns[o].v + ns[l].mr);
if(l) ns[o].ms = max(ns[o].ms, ns[l].mr + max(ns[o].v, 0));
if(r) ns[o].ms = max(ns[o].ms, ns[r].ml + max(ns[o].v, 0));
if(l && r) ns[o].ms = max(ns[o].ms, ns[l].mr + ns[o].v + ns[r].ml);
ns[l].ml = tll; ns[l].mr = tlr; ns[r].ml = trl; ns[r].mr = trr; ns[l].sum = tls; ns[r].sum = trs;
return ;
}
int getnode() {
if(cc) return rec[cc--];
return ++ToT;
}
int val[maxn], cv;
void build(int& o, int l, int r) {
if(l > r) return ;
int mid = l + r >> 1;
ns[o = getnode()] = Node(val[mid]);
build(ch[0][o], l, mid - 1); build(ch[1][o], mid + 1, r);
if(ch[0][o]) fa[ch[0][o]] = o;
if(ch[1][o]) fa[ch[1][o]] = o;
maintain(o);
return ;
}
void pushdown(int o) {
if(hs(o)) {
LL& st = ns[o].setv;
ns[o].v = st;
for(int i = 0; i < 2; i++) if(ch[i][o])
ns[ch[i][o]].setv = st;
st = NoSign;
}
if(ns[o].rev) {
bool& rv = ns[o].rev;
for(int i = 0; i < 2; i++) if(ch[i][o])
ns[ch[i][o]].rev ^= rv;
swap(ch[0][o], ch[1][o]);
rv = 0;
}
return maintain(o);
}
void rotate(int u) {
int y = fa[u], z = fa[y], l = 0, r = 1;
if(z) ch[ch[1][z]==y][z] = u;
if(ch[1][y] == u) swap(l, r);
fa[u] = z; fa[y] = u; fa[ch[r][u]] = y;
ch[l][y] = ch[r][u]; ch[r][u] = y;
maintain(y); maintain(u);
return ;
}
int S[maxn], top;
void splay(int u) {
int t = u; S[++top] = t;
while(fa[t]) t = fa[t], S[++top] = t;
while(top) pushdown(S[top--]);
while(fa[u]) {
int y = fa[u], z = fa[y];
if(z) {
if(ch[0][y] == u ^ ch[0][z] == y) rotate(u);
else rotate(y);
}
rotate(u);
}
return ;
}
int split(int u) {
if(!u) return 0;
splay(u);
int tmp = ch[1][u];
fa[tmp] = 0; ch[1][u] = 0;
maintain(u);
return tmp;
}
int merge(int a, int b) {
if(!a) return maintain(b), b;
if(!b) return maintain(a), a;
pushdown(a); while(ch[1][a]) a = ch[1][a], pushdown(a);
splay(a);
ch[1][a] = b; fa[b] = a;
return maintain(a), a;
}
int qkth(int o, int k) {
if(!o) return 0;
pushdown(o);
int ls = ch[0][o] ? ns[ch[0][o]].siz : 0;
if(k == ls + 1) return o;
if(k > ls + 1) return qkth(ch[1][o], k - ls - 1);
return qkth(ch[0][o], k);
}
int nsize;
int Find(int k) {
if(!nsize) return 0;
while(fa[rt]) rt = fa[rt];
return qkth(rt, k);
}
void Split(int ql, int qr, int& lrt, int& mrt, int& rrt) {
lrt = Find(ql - 1); mrt = Find(qr);
split(lrt); rrt = split(mrt);
return ;
}
void Merge(int lrt, int mrt, int rrt) {
mrt = merge(lrt, mrt); merge(mrt, rrt);
return ;
}
void recycle(int& o) {
if(!o) return ;
recycle(ch[0][o]); recycle(ch[1][o]);
fa[o] = 0; rec[++cc] = o; o = 0;
return ;
}
void Ins(int pos) {
int lrt, mrt = 0, rrt;
lrt = Find(pos);
if(lrt) rrt = split(lrt);
else rrt = Find(1), splay(rrt);
build(mrt, 1, cv);
Merge(lrt, mrt, rrt);
return ;
}
void Del(int ql, int qr) {
int lrt, mrt, rrt;
Split(ql, qr, lrt, mrt, rrt);
recycle(mrt);
Merge(lrt, mrt, rrt);
rt = max(lrt, rrt);
return ;
}
void Setv(int ql, int qr, int v) {
int lrt, mrt, rrt;
Split(ql, qr, lrt, mrt, rrt);
ns[mrt].setv = v;
Merge(lrt, mrt, rrt);
return ;
}
void Rev(int ql, int qr) {
int lrt, mrt, rrt;
Split(ql, qr, lrt, mrt, rrt);
ns[mrt].rev ^= 1;
Merge(lrt, mrt, rrt);
return ;
}
LL Sum(int ql, int qr) {
if(ql > qr) return 0;
int lrt, mrt, rrt;
Split(ql, qr, lrt, mrt, rrt);
LL ans = !hs(mrt) ? ns[mrt].sum : ns[mrt].setv * ns[mrt].siz;
Merge(lrt, mrt, rrt);
return ans;
}
LL MxSum() {
while(fa[rt]) rt = fa[rt];
LL ans = !hs(rt) ? ns[rt].ms : max(ns[rt].setv, ns[rt].setv * ns[rt].siz);
return ans;
}
int main() {
int n = read(), q = read();
for(int i = 1; i <= n; i++) val[i] = read();
build(rt, 1, n); nsize = n;
int cnt = 0, tq = q;
while(q--) {
char cmd[20]; scanf("%s", cmd);
if(cmd[0] == 'I') {
int pos = read(); cv = read();
for(int i = 1; i <= cv; i++) val[i] = read();
Ins(pos); nsize += cv;
}
if(cmd[0] == 'D') {
int ql = read(), qr = min(nsize, ql + read() - 1);
Del(ql, qr); nsize -= qr - ql + 1;
}
if(cmd[0] == 'M' && cmd[2] == 'K') {
int ql = read(), qr = min(nsize, ql + read() - 1), v = read();
Setv(ql, qr, v);
}
if(cmd[0] == 'R') {
int ql = read(), qr = min(nsize, ql + read() - 1);
Rev(ql, qr);
}
if(cmd[0] == 'G') {
int ql = read(), qr = min(nsize, ql + read() - 1);
printf("%lld\n", Sum(ql, qr)); cnt++;
}
if(cmd[0] == 'M' && cmd[2] == 'X') printf("%lld\n", MxSum()), cnt++;
}
return 0;
}