题目
描述
请写一个程序,要求维护一个数列,支持以下6种操作:(请注意,格式栏中的下划线‘ _ ’表示实际输入文件中的空格)
- 插入 INSERT_posi_tot_c1_c2_…_ctot 在当前数列的第posi个数字后插入tot个数字:c1, c2, …, ctot;若在数列首插入,则posi为0
- 删除 DELETE_posi_tot 从当前数列的第posi个数字开始连续删除tot个数字
- 修改 MAKE-SAME_posi_tot_c 将当前数列的第posi个数字开始的连续tot个数字统一修改为c
- 翻转 REVERSE_posi_tot 取出从当前数列的第posi个数字开始的tot个数字,翻转后放入原来的位置
- 求和 GET-SUM_posi_tot 计算从当前数列开始的第posi个数字开始的tot个数字的和并输出
- 求和最大的子列 MAX-SUM 求出当前数列中和最大的一段子列,并输出最大和
输入格式
输入文件的第 1 行包含两个数 N 和 M,N 表示初始时数列中数的个数,M 表示要进行的操作数目。 第 2 行包含 N 个数字,描述初始时的数列。 以下 M 行,每行一条命令,格式参见问题描述中的表格
输出格式
对于输入数据中的 GET-SUM 和 MAX-SUM 操作,向输出文件依次打印结 果,每个答案(数字)占一行。
输入样例
9 8 2 -6 3 5 1 -5 -3 6 3
GET-SUM 5 4
MAX-SUM
INSERT 8 3 -5 7 2
DELETE 12 1
MAKE-SAME 3 3 2
REVERSE 3 6
GET-SUM 5 4
MAX-SUM
输出样例
-1
10
1
10
说明
你可以认为在任何时刻,数列中至少有 1 个数。
输入数据一定是正确的,即指定位置的数在数列中一定存在。
50%的数据中,任何时刻数列中最多含有 30 000 个数;
100%的数据中,任何时刻数列中最多含有 500 000 个数。
100%的数据中,任何时刻数列中任何一个数字均在[-1 000, 1 000]内。
100%的数据中,M ≤20 000,插入的数字总数不超过 4 000 000 。
解题思路
平衡树… Splay
基本知识
在Splay中进行区间修改时,假设要改的区间是[l,r][l,r],我们利用Splay的伸展操作将l−1l−1旋转至根,将r+1r+1旋转至根的右儿子,那么要操作的区间[l,r][l,r]就在根的右儿子的左儿子上了。
为了避免找不到l−1l−1或r+1r+1的尴尬,可以事先在平衡树的首尾各插入一个保护节点,然后把每个对[l,r][l,r]操作改成对[l+1,r+1][l+1,r+1]的操作就行了,好写好调。
Splay结构体
struct Splay{
int son[2], fa, size;// Basic info.
int sum, mx, lmx, rmx, val;// Maintained info.
int cov, rev;// tags
Splay(){
son[0] = son[1] = fa = size = 0;
sum = lmx = rmx = 0, val = mx = -INF;
cov = -INF, rev = 0;
}
}tr[N];
- 基本信息
- 这道题要维护的信息
- sum:以该节点为根的子树所包含的区间中所有值的和
- mx, lmx, rmx:mx是区间中的最大子段和,联想线段树中区间最大子段和的维护,我们也需要lmx和rmx,即从左端点往右走的最大子段和从右端点往左走的最大子段。那么:tr[id].mx=max(max(tr[lson].mx,tr[rson].mx),tr[lson].rmx+tr[id].val+tr[rson].lmx);(4)(4)tr[id].mx=max(max(tr[lson].mx,tr[rson].mx),tr[lson].rmx+tr[id].val+tr[rson].lmx);tr[id].lmx=max(tr[lson].lmx,tr[lson].sum+tr[id].val+tr[rson].lmx);(5)(5)tr[id].lmx=max(tr[lson].lmx,tr[lson].sum+tr[id].val+tr[rson].lmx);tr[id].rmx=max(tr[rson].rmx,tr[rson].sum+tr[id].val+tr[lson].rmx);(6)(6)tr[id].rmx=max(tr[rson].rmx,tr[rson].sum+tr[id].val+tr[lson].rmx);
- val:该节点的值
- 标记
- cov:覆盖标记
- rev:反转标记
- 初始化
- son, fa, size为0,不解释了
- sum为0,val为-INF(注意数字可以为负,所以用-INF表示未赋值)
- lmx, rmx为0,mx为-INF:这点十分重要,lmx(rmx)可以一个数都不包含,即为0,但是mx必须包含至少一个数,所以初值设为-INF
分析操作
- INSERT_posi_tot_c1_c2_…_ctot
先把c1,c2,...,ctotc1,c2,...,ctot建成一颗Splay,再插入到原Splay中去
inline void insert(int pos){
splay(select(pos), 0), splay(select(pos+1), root);
tr[tr[root].son[1]].son[0] = temp;
tr[temp].fa = tr[root].son[1];
pushup(tr[root].son[1]), pushup(root);
}
- DELETE_posi_tot
找到要删除的区间,清空所有信息
(有关back队列和recycle函数的解释见下面注意事项第一条)
queue<int> back;
void recycle(int &x){
if(tr[x].son[0]) recycle(tr[x].son[0]);
if(tr[x].son[1]) recycle(tr[x].son[1]);
tr[x].son[0] = tr[x].son[1] = tr[x].fa = tr[x].size = 0;
tr[x].sum = tr[x].lmx = tr[x].rmx = 0, tr[x].val = tr[x].mx = -INF;
tr[x].cov = -INF, tr[x].rev = 0;
back.push(x), x = 0;
}
inline void del(int l, int r){
splay(select(l-1), 0), splay(select(r+1), root);
recycle(tr[tr[root].son[1]].son[0]);
pushup(tr[root].son[1]), pushup(root);
}
- MAKE-SAME_posi_tot_c
找到要修改的区间,打上cov标记,同时维护好节点信息
inline void makeSame(int l, int r, int val){
splay(select(l-1), 0), splay(select(r+1), root);
int t = tr[tr[root].son[1]].son[0];
tr[t].sum = val * tr[t].size;
tr[t].mx = max(val, tr[t].sum);
tr[t].lmx = max(0, tr[t].sum);
tr[t].rmx = max(0, tr[t].sum);
tr[t].val = val;
tr[t].cov = val;
tr[t].rev = 0;
pushup(tr[root].son[1]), pushup(root);
}
- REVERSE_posi_tot
找到要翻转的区间,打上rev标记,同时维护好节点信息
inline void reverse(int l, int r){
splay(select(l-1), 0), splay(select(r+1), root);
int t = tr[tr[root].son[1]].son[0];
swap(tr[t].son[0], tr[t].son[1]);
swap(tr[t].lmx, tr[t].rmx);
tr[t].rev ^= 1;
pushup(tr[root].son[1]), pushup(root);
}
- GET-SUM_posi_tot
找到询问的区间,输出节点的sum值
inline int getSum(int l, int r){
splay(select(l-1), 0), splay(select(r+1), root);
return tr[tr[tr[root].son[1]].son[0]].sum;
}
- MAX-SUM
直接输出根节点的mx值
printf("%d\n", tr[BST.root].mx);
其他函数
- pushdown
inline void pushdown(int id){
if(!id) return;
if(tr[id].cov != -INF){
if(lson){
tr[lson].sum = tr[id].cov * tr[lson].size;
tr[lson].mx = max(tr[id].cov, tr[lson].sum);
tr[lson].lmx = max(0, tr[lson].sum);//注意lmx和rmx最小也就为0
tr[lson].rmx = max(0, tr[lson].sum);
tr[lson].val = tr[id].cov;
tr[lson].cov = tr[id].cov;
tr[id].rev = 0;
}
if(rson){
tr[rson].sum = tr[id].cov * tr[rson].size;
tr[rson].mx = max(tr[id].cov, tr[rson].sum);
tr[rson].lmx = max(0, tr[rson].sum);
tr[rson].rmx = max(0, tr[rson].sum);
tr[rson].val = tr[id].cov;
tr[rson].cov = tr[id].cov;
tr[id].rev = 0;
}
tr[id].cov = -INF;
}
if(tr[id].rev){
if(lson){
swap(tr[lson].son[0], tr[lson].son[1]);
swap(tr[lson].lmx, tr[lson].rmx);//注意要翻转lmx和rmx
tr[lson].rev ^= 1;
}
if(rson){
swap(tr[rson].son[0], tr[rson].son[1]);
swap(tr[rson].lmx, tr[rson].rmx);
tr[rson].rev ^= 1;
}
tr[id].rev = 0;
}
}
- pushup
inline void pushup(int id){
tr[id].size = tr[lson].size + tr[rson].size + 1;
tr[id].sum = tr[lson].sum + tr[rson].sum + tr[id].val;
tr[id].mx = max(max(tr[lson].mx, tr[rson].mx), tr[lson].rmx + tr[id].val + tr[rson].lmx);
tr[id].lmx = max(tr[lson].lmx, tr[lson].sum + tr[id].val + tr[rson].lmx);
tr[id].rmx = max(tr[rson].rmx, tr[rson].sum + tr[id].val + tr[lson].rmx);
}
- build 建树
对于一个已知数列,如果一个一个插入是O(nlogn)O(nlogn)的,太慢了,可以直接二分着建树
int build(int x[], int l, int r, int fa){
if(l > r) return 0;
int mid = (l + r) >> 1;
int t = newNode(x[mid]);
tr[t].fa = fa;
tr[t].son[0] = build(x, l, mid - 1, t);
tr[t].son[1] = build(x, mid + 1, r, t);
pushup(t);
return t;
}
- newNode 新建节点
充分利用无用节点(见下面注意事项第一条)
inline int newNode(int val){
int now;
if(!back.empty()) now = back.front(), back.pop();
else now = ++cnt;
tr[now].size = 1, tr[now].sum = tr[now].mx = tr[now].val = val;
tr[now].lmx = tr[now].rmx = max(0, val);
return now;
}
- rotate, splay, select函数:Splay基本操作,不赘述了
注意事项
- 这道题空间只给了64MB,所以要回收无用节点,我们可以用一个队列或者栈存一下回收的节点,每次新建一个节点时,如果队列非空,就取出队首编号作为新节点编号,否则才新开一个编号。这就是上面del函数没有直接删掉而是用了一个recycle函数回收节点的原因。
- 维护信息时一定不要漏了!
- 不要忘了pushup!
Code
#include<algorithm>
#include<cstdio>
#include<queue>
#define lson tr[id].son[0]
#define rson tr[id].son[1]
using namespace std;
const int INF = 1e9;
const int N = 500005;
int n, m, a[N], posi, tot, temp;
char opt[10];
struct Splay{
int son[2], fa, size;// Basic info.
int sum, mx, lmx, rmx, val;// Maintained info.
int cov, rev;// tags
Splay(){
son[0] = son[1] = fa = size = 0;
sum = lmx = rmx = 0, val = mx = -INF;
cov = -INF, rev = 0;
}
}tr[N];
struct OPT_Splay{
int cnt, root;
queue<int> back;
inline int newNode(int val){
int now;
if(!back.empty()) now = back.front(), back.pop();
else now = ++cnt;
tr[now].size = 1, tr[now].sum = tr[now].mx = tr[now].val = val;
tr[now].lmx = tr[now].rmx = max(0, val);
return now;
}
inline void pushup(int id){
tr[id].size = tr[lson].size + tr[rson].size + 1;
tr[id].sum = tr[lson].sum + tr[rson].sum + tr[id].val;
tr[id].mx = max(max(tr[lson].mx, tr[rson].mx), tr[lson].rmx + tr[id].val + tr[rson].lmx);
tr[id].lmx = max(tr[lson].lmx, tr[lson].sum + tr[id].val + tr[rson].lmx);
tr[id].rmx = max(tr[rson].rmx, tr[rson].sum + tr[id].val + tr[lson].rmx);
}
inline void pushdown(int id){
if(!id) return;
if(tr[id].cov != -INF){
if(lson){
tr[lson].sum = tr[id].cov * tr[lson].size;
tr[lson].mx = max(tr[id].cov, tr[lson].sum);
tr[lson].lmx = max(0, tr[lson].sum);
tr[lson].rmx = max(0, tr[lson].sum);
tr[lson].val = tr[id].cov;
tr[lson].cov = tr[id].cov;
tr[id].rev = 0;
}
if(rson){
tr[rson].sum = tr[id].cov * tr[rson].size;
tr[rson].mx = max(tr[id].cov, tr[rson].sum);
tr[rson].lmx = max(0, tr[rson].sum);
tr[rson].rmx = max(0, tr[rson].sum);
tr[rson].val = tr[id].cov;
tr[rson].cov = tr[id].cov;
tr[id].rev = 0;
}
tr[id].cov = -INF;
}
if(tr[id].rev){
if(lson){
swap(tr[lson].son[0], tr[lson].son[1]);
swap(tr[lson].lmx, tr[lson].rmx);
tr[lson].rev ^= 1;
}
if(rson){
swap(tr[rson].son[0], tr[rson].son[1]);
swap(tr[rson].lmx, tr[rson].rmx);
tr[rson].rev ^= 1;
}
tr[id].rev = 0;
}
}
inline int which(int x){ return tr[tr[x].fa].son[1] == x;}
inline void rotate(int x, int kind){
int y = tr[x].fa, z = tr[y].fa, B = tr[x].son[kind];
tr[x].son[kind] = y, tr[y].son[!kind] = B, tr[z].son[which(y)] = x;
tr[x].fa = z, tr[y].fa = x, tr[B].fa = y;
pushup(y), pushup(x);
}
inline void splay(int x, int goal){
if(x == goal) return;
while(tr[x].fa != goal){
int y = tr[x].fa, z = tr[y].fa;
pushdown(z), pushdown(y), pushdown(x);
int dir1 = !which(x), dir2 = !which(y);
if(z == goal) rotate(x, dir1);
else{
if(dir1 == dir2) rotate(y, dir2);
else rotate(x, dir1);
rotate(x, dir2);
}
}
if(goal == 0) root = x;
}
int build(int x[], int l, int r, int fa){
if(l > r) return 0;
int mid = (l + r) >> 1;
int t = newNode(x[mid]);
tr[t].fa = fa;
tr[t].son[0] = build(x, l, mid - 1, t);
tr[t].son[1] = build(x, mid + 1, r, t);
pushup(t);
return t;
}
inline int select(int k){
int now = root;
pushdown(now);
while(tr[tr[now].son[0]].size + 1 != k){
if(tr[tr[now].son[0]].size >= k) now = tr[now].son[0];
else k -= tr[tr[now].son[0]].size + 1, now = tr[now].son[1];
pushdown(now);
}
return now;
}
inline void insert(int pos){
splay(select(pos), 0), splay(select(pos+1), root);
tr[tr[root].son[1]].son[0] = temp;
tr[temp].fa = tr[root].son[1];
pushup(tr[root].son[1]), pushup(root);
}
void recycle(int &x){
if(tr[x].son[0]) recycle(tr[x].son[0]);
if(tr[x].son[1]) recycle(tr[x].son[1]);
tr[x].son[0] = tr[x].son[1] = tr[x].fa = tr[x].size = 0;
tr[x].sum = tr[x].lmx = tr[x].rmx = 0, tr[x].val = tr[x].mx = -INF;
tr[x].cov = -INF, tr[x].rev = 0;
back.push(x), x = 0;
}
inline void del(int l, int r){
splay(select(l-1), 0), splay(select(r+1), root);
recycle(tr[tr[root].son[1]].son[0]);
pushup(tr[root].son[1]), pushup(root);
}
inline void makeSame(int l, int r, int val){
splay(select(l-1), 0), splay(select(r+1), root);
int t = tr[tr[root].son[1]].son[0];
tr[t].sum = val * tr[t].size;
tr[t].mx = max(val, tr[t].sum);
tr[t].lmx = max(0, tr[t].sum);
tr[t].rmx = max(0, tr[t].sum);
tr[t].val = val;
tr[t].cov = val;
tr[t].rev = 0;
pushup(tr[root].son[1]), pushup(root);
}
inline void reverse(int l, int r){
splay(select(l-1), 0), splay(select(r+1), root);
int t = tr[tr[root].son[1]].son[0];
swap(tr[t].son[0], tr[t].son[1]);
swap(tr[t].lmx, tr[t].rmx);
tr[t].rev ^= 1;
pushup(tr[root].son[1]), pushup(root);
}
inline int getSum(int l, int r){
splay(select(l-1), 0), splay(select(r+1), root);
return tr[tr[tr[root].son[1]].son[0]].sum;
}
}BST;
int main(){
scanf("%d%d", &n, &m);
for(int i = 2; i <= n + 1; i++) scanf("%d", &a[i]);
a[1] = -INF, a[n+2] = -INF;
BST.root = BST.build(a, 1, n + 2, 0);
while(m--){
scanf("%s", opt);
if(opt[2] != 'X') scanf("%d%d", &posi, &tot), posi++;
if(opt[0] == 'I'){
for(int i = 1; i <= tot; i++) scanf("%d", &a[i]);
temp = BST.build(a, 1, tot, 0);
BST.insert(posi);
}
else if(opt[0] == 'D') BST.del(posi, posi + tot - 1);
else if(opt[2] == 'K'){
scanf("%d", &a[0]);
BST.makeSame(posi, posi + tot - 1, a[0]);
}
else if(opt[0] == 'R') BST.reverse(posi, posi + tot - 1);
else if(opt[0] == 'G') printf("%d\n", BST.getSum(posi, posi + tot - 1));
else if(opt[2] == 'X') printf("%d\n", tr[BST.root].mx);
}
return 0;
}