引入
先来思考这样一个问题,考虑设计这样一个数据结构,可以实现数据的区间/单点 查询和修改
分块的思想应该都知道 (虽然蒟蒻不会分块),就是将问题分成若干的块,逐块解决然后合并,那么我们考虑将一个区间分成若干的区间,并使我们的数据结构满足如下两点性质:
1. 修改一个数只影响到少数区间
2.一个大区间可以由若干个小区间拼成
线段树就是一个满足如上特点的数据结构,将一个大区间分成两半,成为两个小区间,一直分到只剩一个数为止
(盗一下大佬的图)
线段树画出来就长这样,蓝色的数字代表的是当前区间的左右界(都是闭区间)
说实话蒟蒻觉得线段树这个名字完全就是根据这玩意长的像线段树来取的
建树 b u i l d build build
根据上面的图,可以看出线段树其实是一棵“二叉树”,回忆下二叉树的建树 (别想了我这么懒怎么可能会去为了写个线段树的学习笔记特地写一篇二叉树),我们需要三个变量:当前节点的编号
c
u
r
cur
cur,左儿子,右儿子 , 由于线段树的节点存的是一段区间,而且左右儿子的区间都是当前区间折半的,所以我们只需要知道当前区间的左右界,也就是
l
,
r
l,r
l,r
然后递归建树即可,当 l = r l = r l=r 也就是叶子节点的时候说明区间内只有一个数,直接把当前数放进去然后往回走就行
C o d e Code Code
void build(int cur,int l,int r){
if(l==r){
node[cur]=a[l];
return ;
}
int mid=(l+r)>>1;
build(ls(cur),l,mid);
build(rs(cur),mid+1,r);
push_up(cur);
}
这里为了方便,把寻找左右儿子分别写了函数,这样代码方便
其实也就一步计算
int ls(int fa){return fa * 2;}
int rs(int fa){return fa * 2 + 1;}
编号方式参考二叉树
还有个 push_up 操作,其实也就一步计算,就是将当前节点的两个子节点合并起来拼成当前节点
void push_up(int cur){
node[cur] = node[ls(cur)] + node[rs(cur)];
}
单点修改 m o d i f y modify modify
鬼知道我单词有没有拼错
考虑将一个点 x x x 加上 y y y
思路很简单,就是找到 x x x 所在的区间往下递归,直到只剩 x x x 这个单点的时候加上 y y y 然后往回走
跟建树的时候差不多,不过我们还要再传两个参数 x , y x,y x,y
C o d e Code Code
void modify(int cur,int l, int r, int x, int y){
if(l == r){
node[l] += y;//找到x
return;
}
int mid = (l + r) >> 1;
if(x <= mid){//x在左儿子这里
modify(ls(cur),l,mid,x,y);
} else{ //在右儿子这里
modify(rs(cur),mid+1,r,x,y);
}
push_up(cur);//更新当前节点的值
}
由于我们只会往有 x x x 的区间寻找,所以找到叶子节点的时候必然就找到了 x x x
区间查询 f i n d find find
求出区间 [ x , y ] [x,y] [x,y] 的值
这里就用到第二条性质了,用少数区间拼成我们所要求的区间、
思想也不难理解:
1.如果当前区间被所要求的区间包含,就返回当前区间的值
2.如果当前区间和要求的区间有交集,就把区间拆到左右儿子去找,然后返回左右儿子找上来的值
C o d e Code Code
int find(int x,int y,int l,int r,int cur){
int res=0;
if(x<=l&&r<=y)return node[cur];//如果当前区间被要求的区间包含
int mid=(l+r)>>1;
if(x<=mid)res+=find(x,y,l,mid,ls(cur));//左儿子和区间有交集
if(y>mid) res+=find(x,y,mid+1,r,rs(cur));//右儿子和区间有交集
return res;
}
现在你就可以去切掉模板题 树状数组1
区间修改(区间加,区间乘)
其实线段树是可以维护几乎所有满足结合律的操作的,也就是说,只要当前区间可以拆分到两个子区间操作就都可以维护
懒标记 l a z y t a g lazytag lazytag
当我们修改一个区间的时候,有一些节点是直接被区间包含的,为了直接在当前节点就往回跑(就和区间查找一样)而不是傻傻的往下走到每一个点(这样复杂度未免也太大了),所以我们加入了一个懒标记来记录当前区间的子区间还有什么操作没有做
push_down
类似于上面的 push_up ,这个操作就是把懒标记向下传播,因为当我们要对子区间进行修改的时候,必须要把上一步父节点未下传的操作干掉
void push_down(int cur,int l, int r){
int mid = (l + r) >> 1;
mul(ls(cur),l,mid,add[cur],mult[cur]);
mul(rs(cur),mid+1,r,add[cur],mult[cur]);
add[cur] = 0;//清空懒标记
mult[cur] = 1;//清空懒标记
}
这里 a d d add add 数组表示的是区间加的懒标记, m u l t mult mult是区间乘的懒标记
m u l mul mul 函数自然是修改左右儿子的懒标记然后将当前节点的值修改掉
void mul(int cur,int l, int r,int ad,int mu){
node[cur] = node[cur] * mu + (r - l + 1) * ad;
mult[cur] = mu * mult[cur];
add[cur] = add[cur] * mu + ad;
}
这里有一个注意点就是要考虑一下运算优先级
区间加 up_date_add
void up_date_add(int cur, int l, int r, int x, int y, int k){
if(x <= l && r <= y){
mul(cur,l,r,k,1);
return;
}
push_down(cur,l,r);
int mid = (l + r) >> 1;
if(x <= mid){
up_date_add(ls(cur),l,mid,x,y,k);
}
if(y > mid){
up_date_add(rs(cur),mid+1,r,x,y,k);
}
push_up(cur);
}
区间乘 up_date_mul
void up_date_mul(int cur, int l, int r, int x, int y, int k){
if(x <= l && r <= y){
mul(cur,l,r,0,k);
return;
}
push_down(cur,l,r);
int mid = (l + r) >> 1;
if(x <= mid){
up_date_mul(ls(cur),l,mid,x,y,k);
}
if(y > mid){
up_date_mul(rs(cur),mid+1,r,x,y,k);
}
push_up(cur);
}
这里给出线段树2的代码
C o d e Code Code
#include <bits/stdc++.h>
#define ll long long
#define ull unsigned long long
#define int long long
const int N = 5e5+10;
const int M = 1e4+10;
const int INF = 0x3f3f3f3f;
using namespace std;
int node[N*4],a[N],add[N*4],mult[N*4];
int n, m, p;
int ls(int fa){return fa * 2;}
int rs(int fa){return fa * 2 + 1;}
void push_up(int cur){
node[cur] = (node[ls(cur)] + node[rs(cur)])% p;
}
void mul(int cur,int l, int r,int ad,int mu){
node[cur] = ((node[cur] * mu)% p + (r - l + 1) * ad) % p;
mult[cur] = mu * mult[cur] % p;
add[cur] = (add[cur] * mu + ad) % p;
}
void push_down(int cur,int l, int r){
int mid = (l + r) >> 1;
mul(ls(cur),l,mid,add[cur],mult[cur]);
mul(rs(cur),mid+1,r,add[cur],mult[cur]);
add[cur] = 0;
mult[cur] = 1;
}
void up_date_add(int cur, int l, int r, int x, int y, int k){
if(x <= l && r <= y){
mul(cur,l,r,k,1);
return;
}
push_down(cur,l,r);
int mid = (l + r) >> 1;
if(x <= mid){
up_date_add(ls(cur),l,mid,x,y,k);
}
if(y > mid){
up_date_add(rs(cur),mid+1,r,x,y,k);
}
push_up(cur);
}
void up_date_mul(int cur, int l, int r, int x, int y, int k){
if(x <= l && r <= y){
mul(cur,l,r,0,k);
return;
}
push_down(cur,l,r);
int mid = (l + r) >> 1;
if(x <= mid){
up_date_mul(ls(cur),l,mid,x,y,k);
}
if(y > mid){
up_date_mul(rs(cur),mid+1,r,x,y,k);
}
push_up(cur);
}
int find(int x,int y,int l,int r,int cur){
int res=0;
if(x<=l&&r<=y)return node[cur] % p;
int mid=(l+r)>>1;
push_down(cur,l,r);
if(x<=mid)res+=find(x,y,l,mid,ls(cur));
if(y>mid) res+=find(x,y,mid+1,r,rs(cur));
return res % p;
}
void build(int cur,int l,int r){
add[cur]=0;
mult[cur] = 1;
if(l==r){
node[cur]=a[l]% p;
return ;
}
int mid=(l+r)>>1;
build(ls(cur),l,mid);
build(rs(cur),mid+1,r);
push_up(cur);
}
signed main(){
cin >> n >> m >> p;
for(int i = 1; i <= n; i++){
cin >> a[i];
}
build(1,1,n);
while(m--){
int op;
cin >> op;
if(op == 1){
int x, y, k;
cin >> x >> y >> k;
up_date_mul(1,1,n,x,y,k);
}else if(op == 2){
int x, y, k;
cin >> x >> y >> k;
up_date_add(1,1,n,x,y,k);
}else{
int x, y;
cin >> x >> y;
cout << find(x,y,1,n,1) << endl;
}
}
return 0;
}
完结撒花
预告:预计明天后天会写树状数组的学习笔记