线段树(segment tree)是一种比较特殊的数据结构,顾名思义就是一棵二叉平衡树里面的节点是线段(严格来说是整数区间),满足假设某个节点的区间是[a,b],那么它的左右儿子分别是[a, (a+b)/2]和[(a+b)/2+1, b] , 并且大多数情况下叶子节点是长度为1的区间。由于区间树一般用来解决和整数有关的问题,因此也可以认为叶子节点长度为0,因为只含一个整数。 因此它和另一种数据结构区间树是不同的,区间树以红黑树为树模型,且左(右)儿子的区间严格小(大)于它自己的区间,具体的不同可以参看这篇文章里面的图。
具体的详细线段树可以参考这篇文章:《数据结构专题——线段树》,或者更详细的可以参考topcoder社区中的文章。
那么线段树可以用来解决什么问题,它结局的最基本的问题是一个在线问题,即我们先最初给定一些区间,然后多次做如下操作:(1)添加区间或删除已有区间(2)查询某个点落入多少个当前存在的区间。 显然这个问题由于是在线算法,如果是朴素算法,必然超时,那么我们就可以用的线段树来做。事实上对于区间[0,N],初始化一个线段树需要O(N)的时间和O(4N)的空间,然后每次修改和查询都只需要O(log N)的时间,显然是合理的算法。此外参考文章《线段树入门》可以了解更多的线段树可以解决的问题类型。
最后是线段树模板:
//线段树节点
struct Node{
int L, R; //左右边界
int v; //需要记录的值
};
Node node[max_n<<2];
//构建线段树 node[root]的左右端点分别是l和r
void build(int root, int l, int r){
node[root].L = l;
node[root].R = r;
if(r == l){
node[root].v = /***/;
return ;
}
int mid = l + ((r-l)>>1);
build(root<<1, l, mid);
build((root<<1)|1, mid+1, r);
node[root].s = node[root<<1].s + node[(root<<1)|1].s;
}
//更新线段树
void update(int root, int l, int r, int new_v){
//如果找到元节点,则直接更新结果
if(node[root].L == node[root].R && l <= node[root].L && r >= node[root].R){
node[root].v = new_v
return ;
}
//否则更新左右子区间
int mid = node[root].L + ((node[root].R - node[root].L)>>1);
if(l <= mid)
update(root<<1, l, r , new_v);
if(r > mid)
update((root<<1)|1, l, r , new_v);
node[root].v = node[(root<<1)].v + node[(root<<1)|1].v;
}
//查询线段树
int query(int root, int l, int r){
//刚好找到对应区间
if(l <= node[root].L && r >= node[root].R)
return node[root].v;
//否则查询子区间
int mid = node[root].L + ((node[root].R - node[root].L)>>1);
int ans = 0;
if(l <= mid )
ans += query(l,r, (root<<1));
if(r > mid)
ans += query(l,r, (root<<1)|1 );
return ans;
}
对于上面的模板,虽然我们在每个节点都存了左右边界(L和R)。但是实际上由于在ACM比赛中,问题一开始就告诉我们区间的范围,因此根节点的区间大小是已知的,进一步我们由于线段树的构造方法,可以知道任何一个子节点的边界也是可以根据父节点的边界计算出来。因此L和R可以不存在Node结构中,可以作为上面三个函数的参数传入,这样优化了内存。
此外线段树还有一种变动叫做惰性操作,具体请看《更为复杂的房屋买卖姿势》以及对应代码:
#include<iostream>
using namespace std;
#define ll long long int
const int maxN = 100010;
const int INF = 10010;
struct Node{
ll sum;
ll lazyadd;
ll lazyset;
int l, r, len;
}node[(maxN<<2)|1];
inline void merge(int h, int lh, int rh){
node[h].sum = node[lh].sum + node[rh].sum;
}
inline void release(int h, int lh, int rh){
if(node[h].len <= 1)
return ;
if(node[h].lazyset != -1){
node[lh].lazyset = node[rh].lazyset = node[h].lazyset;
node[lh].lazyadd = node[rh].lazyadd = 0;
node[lh].sum = node[lh].lazyset * node[lh].len;
node[rh].sum = node[rh].lazyset * node[rh].len;
node[h].lazyset = -1;
}
if(node[h].lazyadd != 0){
node[lh].lazyadd += node[h].lazyadd;
node[rh].lazyadd += node[h].lazyadd;
node[lh].sum += node[h].lazyadd * node[lh].len;
node[rh].sum += node[h].lazyadd * node[rh].len;
node[h].lazyadd = 0;
}
merge(h, lh, rh);
}
//node[h]的左右边界分别为l和r,下同
void SegBuild(const int h, const int l, const int r){
//cout << h <<" " << l << " " << r << endl;
node[h].sum = 0;
node[h].lazyadd = 0;
node[h].lazyset = -1;
node[h].l = l;
node[h].r = r;
node[h].len = r-l+1;
if(l == r){
scanf("%d", &(node[h].sum));
return ;
}
int mid = l + ((r-l)>>1);
int lh = h<<1;
int rh = lh|1;
SegBuild(lh, l, mid);
SegBuild(rh, mid+1, r);
merge(h, lh, rh);
}
void SegAdd(int h, int L,int R, int val){
if(node[h].l > R || node[h].r < L )
return ;
int lh = h<<1;
int rh = lh|1;
int mid = node[h].l + ((node[h].r - node[h].l)>>1);
if(node[h].l >= L && node[h].r <= R){
if(node[h].lazyset != -1){
release(h, lh, rh);
node[h].lazyset = -1;
}
node[h].lazyadd += val;
node[h].sum += node[h].len * val;
return ;
}
release(h, lh, rh);
if(L <= mid)
SegAdd(lh, L, R, val);
if(R > mid)
SegAdd(rh, L, R, val);
merge(h, lh, rh);
}
void SegSet(int h, int L, int R, int val){
if(node[h].l > R || node[h].r < L )
return ;
if(node[h].l >= L && node[h].r <= R){
node[h].lazyset = val;
node[h].lazyadd = 0;
node[h].sum = node[h].len * val;
return ;
}
int lh = h<<1;
int rh = lh|1;
int mid = node[h].l + ((node[h].r - node[h].l)>>1);
release(h, lh, rh);
if(L <= mid)
SegSet(lh, L, R, val);
if(R > mid)
SegSet(rh, L, R, val);
merge(h, lh, rh);
}
ll SegQuery(int h, int L, int R){
//cout << h << "[" << l << "," << r <<"] : " << L <<" " << R << endl;
if(node[h].l > R || node[h].r < L )
return -1;
if(L<= node[h].l && R >= node[h].r)
return node[h].sum;
int mid = node[h].l + ((node[h].r - node[h].l)>>1);
int lh = h<<1;
int rh = lh|1;
release(h, lh, rh);
int ans = 0;
if(L <= mid)
ans += SegQuery(lh, L, R);
if(R > mid)
ans += SegQuery(rh, L, R);
return ans;
}
int main(){
int N, M;
scanf("%d %d", &N, &M);
SegBuild(1, 0, N);
//pf(N);
int f, l, r, newP;
while(M--){
scanf("%d %d %d %d", &f, &l, &r, &newP);
if(f == 0)
SegAdd(1, l, r, newP);
else
SegSet(1, l, r, newP);
printf("%lld\n", SegQuery(1,0,N));
//pf(N);
}
//pf(N);
return 0;
}