先用一道题入手线段树
P1816 忠诚
对于这道题,我们可以先写出暴力 (连暴力都能拿到90pts)
#include <bits/stdc++.h>
using namespace std;
int n, m;
int a[100005];
int main(){
scanf("%d%d" ,&n ,&m);//由于习惯,这里将 m,n 互换了
for(int i = 1; i <= n; ++i){
scanf("%d" ,&a[i]);
}
for(int i = 1; i <= m; ++i){
int x, y, ans = 0x3f3f3f3f;
scanf("%d%d" ,&x ,&y);
for(int j = x; j <= y; ++j){
ans = min(ans, a[j]);
}
printf("%d " ,ans);
}
return 0;
}
然后我们尝试思考优化的方法
这道题里我们发现, 大量枚举的区间里很可能有重复的部分, 所以我们之前维护出的答案对于后面的答案可能是有帮助的, 但是给出的询问的答案对后面的求解帮助有多大,则不是我们能准确预测的. 如果我们能预处理出一些询问的答案, 对于后面的求解就会有帮助. 而线段树就可以通过预处理出一些特殊区间的答案, 从而加速后面的求解.
线段树
对于区间 [ 1 , 5 ] [1,5] [1,5] , 用图来表示的话, 线段树大概是这么个二叉树的样子.其建树方式基于完全二叉树或满二叉树.
注意到有很多节点是不存在的(图中 []
表示空区间), 这是为了保持编号, 使得完全二叉树的任意一个编号为
i
i
i 的非叶子节点的左子节点的编号是
i
×
2
i \times 2
i×2 ,右子节点的编号是
i
×
2
+
1
i \times 2 + 1
i×2+1
用数组 m i n v minv minv 存储最小值, 则递归建树代码如下
void build(int id, int l, int r){//id 表示当前根节点编号; l, r 表示根节点下的范围
if(l == r){//区间长度为为 1 则最小值为 a[l] 或 a[r]
minv[id] = a[l];
return;
}
int mid = (l + r) / 2;
build(id * 2, l, mid); //建左子树
build(id * 2 + 1, mid + 1, r); //建右子树
minv[id] = min(minv[id * 2], minv[id * 2 + 1]); //最小值为左右两子节点的最小值中较小的
}
然后是查询
如果我们想查找某一个区间, 就把它分解为若干个子区间, 所有子区间的最小值就是要查的最小值
以上面的图为例, 如果我们想查找 [ 2 , 4 ] [2,4] [2,4] 这个区间的最小值, 我们需要用到的区间有 [ 1 , 5 ] [ 1 , 3 ] [ 1 , 2 ] [ 2 , 2 ] [ 3 , 3 ] [ 4 , 5 ] [ 4 , 4 ] [1,5] \quad [1,3] \quad [1,2] \quad [2,2] \quad [3,3] \quad [4,5] \quad [4,4] [1,5][1,3][1,2][2,2][3,3][4,5][4,4]
代码如下
int query(int id, int l, int r, int x, int y){//id 表示节点编号 l 和 r 表示节点下的区间, x 和 y 表示要查找的区间
if(x <= l && r <= y){//如果 [l,r] 被 [x,y] 完全包含, 则不用再找了, 直接返回结果
return minv[id];
}
int mid = (l + r) / 2;
int ans = 0x3f3f3f3f;
if(x <= mid){//如果要查的区间包含左半边
ans = min(ans, query(id * 2, l, mid, x, y));
}
if(y > mid){//要查的区间包含右半边
ans = min(ans, query(id * 2 + 1, mid + 1, r, x, y));
}
return ans;
}
AC 代码如下
#include <bits/stdc++.h>
using namespace std;
int n, m;
int a[100005];
int minv[400005];//注意要开四倍, 具体证明可以看 OI Wiki
void build(int id, int l, int r){
if(l == r){
minv[id] = a[l];
return;
}
int mid = (l + r) / 2;
build(id * 2, l, mid);
build(id * 2 + 1, mid + 1, r);
minv[id] = min(minv[id * 2], minv[id * 2 + 1]);
}
int query(int id, int l, int r, int x, int y){
if(x <= l && r <= y){
return minv[id];
}
int mid = (l + r) / 2;
int ans = 0x3f3f3f3f;
if(x <= mid){
ans = min(ans, query(id * 2, l, mid, x, y));
}
if(y > mid){
ans = min(ans, query(id * 2 + 1, mid + 1, r, x, y));
}
return ans;
}
int main(){
scanf("%d%d" ,&n ,&m);
for(int i = 1; i <= n; ++i){
scanf("%d" ,&a[i]);
}
build(1, 1, n);
for(int i = 1; i <= m; ++i){
int x, y;
scanf("%d%d" ,&x ,&y);
printf("%d " ,query(1, 1, n, x, y));
}
return 0;
}
我们再来想一个问题:假设在询问过程中账本的内容可能会被修改, 我们应该怎么做?
输入中第一行有两个数 m , n m, n m,n 分别表示有 m ( m ≤ 100000 ) m (m \leq 100000) m(m≤100000) 笔账和有 n ( n ≤ 100000 ) n (n \leq 100000) n(n≤100000) 个问题.
接下来每行为3个数字, 第一个 p p p 为数字 1 1 1 或数字 2 2 2, 第二个数为 x x x, 第三个数为 y y y.
当 p = 1 p = 1 p=1 则查询 [ x , y ] [x, y] [x,y] 区间; 当 p = 2 p = 2 p=2 则改变第 x x x 个数为 y y y.
如果我们想修改一个数, 即修改线段树的一个储存的区间长度为 1 1 1 的结点, 那么我们需要修改他的所有祖先.
以修改 [ 3 , 3 ] [3, 3] [3,3] 为例, 我们需要这么修改:
代码如下
void update(int id, int l, int r, int x, int v){//id 表示节点编号; l, r 表示节点下的区间; x 表示原数组(指 a[])中要修改的位置; v 代表要修改的值
if(l == r){
minv[id] = v;
return;
}
int mid = (l + r) / 2;
if(x <= mid){//要修改的位置在左半区间
update(id * 2, l, mid, x, v);
}else{//在右半区间
update(id * 2 + 1, mid + 1, r, x, v);
}
minv[id] = min(minv[id * 2], minv[id * 2 + 1]);
}
附部分输入输出
for(int i = 1; i <= m; ++i){
int op, x, y;
scanf("%d%d%d" ,&op ,&x ,&y);
if(op == 1){
printf("%d " ,query(1, 1, n, x, y));
}
if(op == 2){
update(1, 1, n, x, y);
}
}
再来看一下区间修改
P3372 【模板】线段树 1
如果我们挨个单点修改的话, 每一次修改都要 O ( log n ) O(\log n) O(logn) 的时间, 时间复杂度非常大.
这里用一个更大的区间 [ 1 , 10 ] [1,10] [1,10] 来说明问题.
如果我们要查找 [ 3 , 7 ] [3,7] [3,7] , 那么需要访问的节点如下
而如果对 [ 3 , 7 ] [3,7] [3,7] 进行挨个单点修改, 那么需要访问的节点如下, 其中红色部分为我们查找和修改都需要访问的节点
也就是说, 我们每一次查询的时候, 不一定都需要访问修改的值. 这使得我们想出一种方法(名字叫 lazy tag
), 即在更新时只更新出需要查询的节点, 并把子节点需要修改的值记录下来, 如果下一次查询需要访问子节点, 就把记录的需要修改的值下发下去. 这样我们只用修改一次就可以完成区间修改
这就像老师布置了一堆卷子, 但只查其中的几张张且告诉你分别是哪几张, 那么你为了节省时间只需要写检查的几张, 下次查别的几张你再接着写要查的
区间修改代码如下
void push_up(int id){//向上合并
sumv[id] = sumv[id * 2] + sumv[id * 2 + 1];
}
void push_down(int id, int l, int r){//向下分发
if(lazy[id]){//如果需要更新子节点的值
int mid = (l + r) / 2;
lazy[id * 2] += lazy[id];//将子节点要修改的值算上当前节点要修改的值以便继续下发
lazy[id * 2 + 1] += lazy[id];
sumv[id * 2] += ((long long)(mid - l + 1)) * lazy[id];
sumv[id * 2 + 1] += ((long long)(r - mid)) * lazy[id];
lazy[id] = 0;
}
}
void interval_update(int id, int l, int r, int x, int y, long long v){
if(x <= l && r <= y){
lazy[id] += v;
sumv[id] += ((long long)(r - l + 1)) * v;
return;
}
push_down(id, l, r);//下发更新子节点
int mid = (l + r) / 2;
if(x <= mid){
interval_update(id * 2, l, mid, x, y, v);
}
if(y > mid){
interval_update(id * 2 + 1, mid + 1, r, x, y, v);
}
push_up(id);//更新父节点
}
AC代码如下
#include <bits/stdc++.h>
using namespace std;
int n, m;
long long a[100005];
long long sumv[400005], lazy[400005];
void push_up(int id){
sumv[id] = sumv[id * 2] + sumv[id * 2 + 1];
}
void push_down(int id, int l, int r){
if(lazy[id]){
int mid = (l + r) / 2;
lazy[id * 2] += lazy[id];
lazy[id * 2 + 1] += lazy[id];
sumv[id * 2] += ((long long)(mid - l + 1)) * lazy[id];
sumv[id * 2 + 1] += ((long long)(r - mid)) * lazy[id];
lazy[id] = 0;
}
}
void build(int id, int l, int r){
if(l == r){
sumv[id] = a[l];
return;
}
int mid = (l + r) / 2;
build(id * 2, l, mid);
build(id * 2 + 1, mid + 1, r);
push_up(id);
}
void interval_update(int id, int l, int r, int x, int y, long long v){
if(x <= l && r <= y){
lazy[id] += v;
sumv[id] += ((long long)(r - l + 1)) * v;
return;
}
push_down(id, l, r);
int mid = (l + r) / 2;
if(x <= mid){
interval_update(id * 2, l, mid, x, y, v);
}
if(y > mid){
interval_update(id * 2 + 1, mid + 1, r, x, y, v);
}
push_up(id);
}
long long query(int id, int l, int r, int x, int y){
if(x <= l && r <= y){
return sumv[id];
}
push_down(id, l, r);
int mid = (l + r) / 2;
long long ans = 0;
if(x <= mid){
ans += query(id * 2, l, mid, x, y);
}
if(y > mid){
ans += query(id * 2 + 1, mid + 1, r, x, y);
}
return ans;
}
int main(){
scanf("%d%d" ,&n ,&m);
for(int i = 1; i <= n; ++i){
scanf("%lld" ,&a[i]);
}
build(1, 1, n);
for(int i = 1; i <= m; ++i){
int op, x, y;
scanf("%d%d%d" ,&op ,&x ,&y);
if(op == 1){
long long k;
scanf("%lld" ,&k);
interval_update(1, 1, n, x, y, k);
}
if(op == 2){
printf("%lld\n" ,query(1, 1, n, x, y));
}
}
return 0;
}
P3373 【模板】线段树 2
这道题十分复杂, 因为有乘和加两种运算, 所以我们考虑用两种 tag , 分别表示乘和加的标记.
设加的标签为 lazy1
, 乘的标签为 lazy2
. lazy1 的初始值为
0
0
0 , lazy2 的初始值为
1
1
1 .
如果我们考虑乘和加的 tag 单独修改, 我们会发现一个问题: 如果 lazy1 和 lazy2 都存在, 那么我们肯定要都算上, 但我们不知道谁先算谁后算.
对于区间 { 1 , 2 , 3 , 4 , 5 } \{ 1, 2, 3, 4, 5 \} {1,2,3,4,5} , l a z y 1 = 1 , l a z y 2 = 5 lazy1 = 1, lazy2 = 5 lazy1=1,lazy2=5 ,
如果操作是先乘5后加1, 此时 l a z y 1 = 1 , l a z y 2 = 5 lazy1 = 1, lazy2 = 5 lazy1=1,lazy2=5
先算lazy1后算lazy2, { 1 , 2 , 3 , 4 , 5 } → { 2 , 3 , 4 , 5 , 6 } → { 10 , 15 , 20 , 25 , 30 } \{ 1, 2, 3, 4, 5 \} \to \{ 2, 3, 4, 5, 6 \} \to \{ 10, 15, 20, 25, 30 \} {1,2,3,4,5}→{2,3,4,5,6}→{10,15,20,25,30} 不正确
先算lazy2后算lazy1, { 1 , 2 , 3 , 4 , 5 } → { 5 , 10 , 15 , 20 , 25 } → { 6 , 11 , 16 , 21 , 25 } \{ 1, 2, 3, 4, 5 \} \to \{ 5, 10, 15, 20, 25 \} \to \{ 6, 11, 16, 21, 25 \} {1,2,3,4,5}→{5,10,15,20,25}→{6,11,16,21,25} 正确
如果操作是先加1后乘5, 此时 l a z y 1 = 1 , l a z y 2 = 5 lazy1 = 1, lazy2 = 5 lazy1=1,lazy2=5
先算lazy1后算lazy2, { 1 , 2 , 3 , 4 , 5 } → { 2 , 3 , 4 , 5 , 6 } → { 10 , 15 , 20 , 25 , 30 } \{ 1, 2, 3, 4, 5 \} \to \{ 2, 3, 4, 5, 6 \} \to \{ 10, 15, 20, 25, 30 \} {1,2,3,4,5}→{2,3,4,5,6}→{10,15,20,25,30} 正确
先算lazy2后算lazy1, { 1 , 2 , 3 , 4 , 5 } → { 5 , 10 , 15 , 20 , 25 } → { 6 , 11 , 16 , 21 , 25 } \{ 1, 2, 3, 4, 5 \} \to \{ 5, 10, 15, 20, 25 \} \to \{ 6, 11, 16, 21, 25 \} {1,2,3,4,5}→{5,10,15,20,25}→{6,11,16,21,25} 不正确
不能确定是先算哪个.
一种解决方法是, 根据乘法分配律 ( a + b ) × c = a × c + b × c (a + b) \times c = a \times c + b \times c (a+b)×c=a×c+b×c ,我们更新乘法时把 lazy1 也乘一下.
这样我们计算时, 若先乘5后加1, 先乘5时 l a z y 2 = 1 × 5 = 5 , l a z y 1 = 0 × 5 = 0 lazy2 = 1 \times 5 = 5, lazy1 = 0 \times 5 = 0 lazy2=1×5=5,lazy1=0×5=0; 加1时 l a z y 1 = 0 + 1 = 1 lazy1 = 0 + 1 = 1 lazy1=0+1=1; 那么 l a z y 1 = 1 , l a z y 2 = 5 lazy1 = 1, lazy2 = 5 lazy1=1,lazy2=5
先算lazy1后算lazy2, { 1 , 2 , 3 , 4 , 5 } → { 2 , 3 , 4 , 5 , 6 } → { 10 , 15 , 20 , 25 , 30 } \{ 1, 2, 3, 4, 5 \} \to \{ 2, 3, 4, 5, 6 \} \to \{ 10, 15, 20, 25, 30 \} {1,2,3,4,5}→{2,3,4,5,6}→{10,15,20,25,30} 不正确
先算lazy2后算lazy1, { 1 , 2 , 3 , 4 , 5 } → { 5 , 10 , 15 , 20 , 25 } → { 6 , 11 , 16 , 21 , 25 } \{ 1, 2, 3, 4, 5 \} \to \{ 5, 10, 15, 20, 25 \} \to \{ 6, 11, 16, 21, 25 \} {1,2,3,4,5}→{5,10,15,20,25}→{6,11,16,21,25} 正确
若先加1再乘5, 先加1时 l a z y 1 = 0 + 1 = 1 lazy1 = 0 + 1 = 1 lazy1=0+1=1 , 乘5时 l a z y 2 = 1 × 5 = 5 lazy2 = 1 \times 5 = 5 lazy2=1×5=5, l a z y 1 = 1 × 5 = 5 lazy1 = 1 \times 5 = 5 lazy1=1×5=5; 那么 l a z y 1 = 5 , l a z y 2 = 5 lazy1 = 5, lazy2 = 5 lazy1=5,lazy2=5;
先算lazy1后算lazy2, { 1 , 2 , 3 , 4 , 5 } → { 6 , 7 , 8 , 9 , 10 } → { 30 , 35 , 40 , 45 , 50 } \{ 1, 2, 3, 4, 5 \} \to \{ 6, 7, 8, 9, 10 \} \to \{ 30, 35, 40, 45, 50 \} {1,2,3,4,5}→{6,7,8,9,10}→{30,35,40,45,50} 不正确
先算lazy2后算lazy1, { 1 , 2 , 3 , 4 , 5 } → { 5 , 10 , 15 , 20 , 25 } → { 10 , 15 , 20 , 25 , 30 } \{ 1, 2, 3, 4, 5 \} \to \{ 5, 10, 15, 20, 25 \} \to \{ 10, 15, 20, 25, 30 \} {1,2,3,4,5}→{5,10,15,20,25}→{10,15,20,25,30} 正确
这时, 我们发现先乘后加就一定正确.
代码如下
#include <bits/stdc++.h>
using namespace std;
int n, m, p;
long long a[100005];
long long tree[400005], lazy1[400005], lazy2[400005];
void push_up(int id){
tree[id] = tree[id * 2] + tree[id * 2 + 1];
tree[id] %= p;
}
void push_down(int id, int l, int r){
int mid = (l + r) / 2;
if(lazy1[id]){
lazy1[id * 2] = (lazy1[id * 2] * lazy2[id] + lazy1[id]) % p;
lazy1[id * 2 + 1] = (lazy1[id * 2 + 1] * lazy2[id] + lazy1[id]) % p;
}
if(lazy2[id] != 1){
lazy2[id * 2] *= lazy2[id];
lazy2[id * 2] %= p;
lazy2[id * 2 + 1] *= lazy2[id];
lazy2[id * 2 + 1] %= p;
if(!lazy1[id]){
lazy1[id * 2] = (lazy1[id * 2] * lazy2[id] + lazy1[id]) % p;
lazy1[id * 2 + 1] = (lazy1[id * 2 + 1] * lazy2[id] + lazy1[id]) % p;
}
}
tree[id * 2] = (lazy2[id] * tree[id * 2] + ((mid - l + 1) * lazy1[id]) % p) % p; //先乘后加
tree[id * 2 + 1] = (lazy2[id] * tree[id * 2 + 1] + ((r - mid) * lazy1[id]) % p) % p;
lazy1[id] = 0;
lazy2[id] = 1;
}
void build(int id, int l, int r){
if(l == r){
tree[id] = a[l];
return;
}
int mid = (l + r) / 2;
build(id * 2, l, mid);
build(id * 2 + 1, mid + 1, r);
push_up(id);
}
void interval_add(int id, int l, int r, int x, int y, long long v){
if(x <= l && r <= y){
lazy1[id] += v;
lazy1[id] %= p;
tree[id] += ((long long)(r - l + 1)) * v;
tree[id] %= p;
return;
}
push_down(id, l, r);
int mid = (l + r) / 2;
if(x <= mid){
interval_add(id * 2, l, mid, x, y, v);
}
if(y > mid){
interval_add(id * 2 + 1, mid + 1, r, x, y, v);
}
push_up(id);
}
void interval_mul(int id, int l, int r, int x, int y, long long v){
if(x <= l && r <= y){
lazy1[id] *= v; //将 lazy1 也进行乘法
lazy1[id] %= p;
lazy2[id] *= v;
lazy2[id] %= p;
tree[id] *= v;
tree[id] %= p;
return;
}
push_down(id, l, r);
int mid = (l + r) / 2;
if(x <= mid){
interval_mul(id * 2, l, mid, x, y, v);
}
if(y > mid){
interval_mul(id * 2 + 1, mid + 1, r, x, y, v);
}
push_up(id);
}
long long query(int id, int l, int r, int x, int y){
if(x <= l && r <= y){
return tree[id];
}
push_down(id, l, r);
int mid = (l + r) / 2;
long long ans = 0;
if(x <= mid){
ans += query(id * 2, l, mid, x, y);
ans %= p;
}
if(y > mid){
ans += query(id * 2 + 1, mid + 1, r, x, y);
ans %= p;
}
return ans;
}
int main(){
fill(lazy2, lazy2 + 400005, 1);
scanf("%d%d%d" ,&n ,&m ,&p);
for(int i = 1; i <= n; ++i){
scanf("%lld" ,&a[i]);
}
build(1, 1, n);
for(int i = 1; i <= m; ++i){
int op, x, y;
scanf("%d%d%d" ,&op ,&x ,&y);
if(op == 1){
long long k;
scanf("%lld" ,&k);
interval_mul(1, 1, n, x, y, k);
}
if(op == 2){
long long k;
scanf("%lld" ,&k);
interval_add(1, 1, n, x, y, k);
}
if(op == 3){
printf("%lld\n" ,query(1, 1, n, x, y));
}
}
return 0;
}
注意 push_down
的写法为什么不是
void push_down(int id, int l, int r){
int mid = (l + r) / 2;
if(lazy2[id] != 1){
lazy2[id * 2] *= lazy2[id];
lazy2[id * 2] %= p;
lazy2[id * 2 + 1] *= lazy2[id];
lazy2[id * 2 + 1] %= p;
}
if(lazy1[id]){
lazy1[id * 2] = (lazy1[id * 2] * lazy2[id] + lazy1[id]) % p;
lazy1[id * 2 + 1] = (lazy1[id * 2 + 1] * lazy2[id] + lazy1[id]) % p;
}
tree[id * 2] = (lazy2[id] * tree[id * 2] + ((mid - l + 1) * lazy1[id]) % p) % p;
tree[id * 2 + 1] = (lazy2[id] * tree[id * 2 + 1] + ((r - mid) * lazy1[id]) % p) % p;
lazy1[id] = 0;
lazy2[id] = 1;
}
因为就算当前的 l a z y 1 [ i d ] = 0 lazy1[id] = 0 lazy1[id]=0 , 但是 l a z y 2 [ i d ] ≠ 1 lazy2[id] \neq 1 lazy2[id]=1 的话, 那这个乘法标记还是要下放到 lazy1 的两个子节点的, l a z y 1 [ i d ] = 0 lazy1[id] = 0 lazy1[id]=0 不代表 l a z y 1 [ i d ∗ 2 ] lazy1[id * 2] lazy1[id∗2] 或 l a z y 1 [ i d ∗ 2 + 1 ] lazy1[id * 2 + 1] lazy1[id∗2+1] 等于0. 但是 lazy1 只有在 l a z y 1 [ i d ] ≠ 0 lazy1[id] \neq 0 lazy1[id]=0 的时候才下放