线段树的初级操作
简介:
问题背景:
把问题建模成数轴上的问题或者数列的问题。一般是每次对数轴或者数列的一个区间进行相同的处理。
线段树的结构:
一棵平衡的二叉树。
举例说明:
区间:处理前闭后开的区间[a,b)[a,b)[a,b)
线段树结点T(a,b):维护原序列中[a,b)[a,b)[a,b)的信息
内部结点:对于结点T(a,b),有b−a>1b-a>1b−a>1,那么T(a,b)的左孩子是T(a,(a+b)/2),右孩子是T((a+b)/2,b)
叶结点: 对于结点T(a,b),有b−a=1b-a=1b−a=1
因此,假设一个序列有n个结点,那么根结点是T(1,n+1),第k个叶子结点是T(k,k+1)
性质:
结点数:小于等于2n,n是序列中元素的个数
深度:线段树去除最后一层后,是满二叉树,h=1+log2(n−1)h=1+\log_2(n-1)h=1+log2(n−1)
线段分解数量级:可以把任意的长度是LLL的线段分解成不超过2log2L2\log_2L2log2L条的子线段。可以让绝大多数查询在O(log2n)O(\log_2n)O(log2n)内解决
存储空间:O(n)O(n)O(n)
实现方式:
这里,以修改元素的值为例。修改某个或某个区间的元素的值,然后查询某个区间的元素的和。
修改单个元素的值
问题1:
长度位n的数列,初始化全是0.现在执行m次操作,每次执行下面两种操作之一:
- 数列中某个数加上某个值
- 询问给定区间中所有数的和
解析:
朴素算法的复杂度是O(mn)O(mn)O(mn),因为传统的线性查找时间是O(n)O(n)O(n)。引入线段树,查询的复杂度是O(log2n)O(\log_2n)O(log2n)。在使用修改算法时,最好是人为地添加下标x的范围合法性判断;否则,如果x小于最小下标,delta会累加到第一个元素上,如果大于等于最大下标,delta会累加到最后一个元素上!!!
#include <iostream>
#include <utility>
#include <memory>
struct Node {
int l, r, val;
std::shared_ptr<Node> lc, rc;
Node(int _l = 0, int _r = 0, int _v = 0):
l(_l), r(_r), val(_v) {}
};
void build(std::shared_ptr<Node>& cur, int l, int r) {
cur = std::make_shared<Node>(l, r);
if (l + 1 < r) {
int mid = l + (r - l) / 2;
build(cur->lc, l, mid);
build(cur->rc, mid, r);
}
}
int query(const std::shared_ptr<Node>& cur, int l, int r) {
if (l <= cur->l && cur->r <= r) {
return cur->val;
}
int ans = 0;
int mid = cur->l + (cur->r - cur->l) / 2;
if (l < mid) {
ans += query(cur->lc, l, r);
}
if (r > mid) {
ans += query(cur->rc, l, r);
}
return ans;
}
void change(std::shared_ptr<Node>& cur, int x, int delta) {
if (cur->l + 1 == cur->r) {
cur->val += delta;
return;
}
int mid = cur->l + (cur->r - cur->l) / 2;
if (x < mid) {
change(cur->lc, x, delta);
}
if (x >= mid) {
change(cur->rc, x, delta);
}
cur->val = cur->lc->val + cur->rc->val;
}
int main() {
std::shared_ptr<Node> root;
build(root, 1, 11);
for (int i = 1; i <= 10; ++i) {
change(root, i, 1);
}
for(int i = 1; i <= 10; ++i) {
std::cout << i << "th: " << query(root, i, i + 1) << std::endl;
}
return 0;
}
修改整个区间的值
问题2:
长度位n的数列,初始化全是0.现在执行m次操作,每次执行下面两种操作之一:
- 数列中某个区间的所有数加上某个值
- 询问给定区间中所有数的和
解析:
如果修改的是一个区间的值,假设区间长度是k,使用上面的算法,由于每次查询的时间是O(log2n)O(\log_2n)O(log2n),所以处理一个区间的复杂度是O(klog2n)O(k\log_2n)O(klog2n),如果k很大,复杂度甚至会超过朴素的模拟算法,因此引入下面的改进算法。
算法的核心在于不直接计算叶子结点的值,而是每个结点增加一个delta域,用于记录当前结点的延迟修改量。只有当前结点需要继续向下查询或者更改当前结点的子区间时,才把当前结点的延迟修改量传递给子区间,同时当前结点的修改量清零,如果不清零,会导致重复计算!
这种算法,保证了不会有过多的递归下降而浪费时间。只有需要向下时,才会根据父结点累积的增量,计算子结点有关的值,保证了时间复杂度较低,减少不必要的递归过程。
#include <bits/stdc++.h>
using namespace std;
struct Node {
int l, r, sum, delta;
struct Node *lc, *rc;
Node(): l(0), r(0), sum(0), delta(0), lc(nullptr), rc(nullptr) {}
};
void build(Node* &cur, int l, int r) { // 建立算法和单个元素的一样
cur = new Node;
cur->l = l;
cur->r = r;
if(l + 1 < r) {
build(cur->lc, l, (l + r) / 2);
build(cur->rc, (l + r) / 2, r);
}
}
void update(Node* cur) { // 更新算法,处理累计状态
// 向下传递累积和,等效成后计算的,注意是累加
cur->lc->sum += cur->delta * (cur->lc->r - cur->lc->l);
cur->rc->sum += cur->delta * (cur->rc->r - cur->rc->l);
// 孩子的delta状态进行累计,注意是累加
cur->lc->delta += cur->delta;
cur->rc->delta += cur->delta;
// 一定要把父结点的清零
cur->delta = 0;
}
void change(Node* cur, int l, int r, int delta) {
if(l <= cur->l && cur->r <= r) {
cur->sum += delta * (cur->r - cur->l);
cur->delta += delta;
} else {
if(cur->delta != 0) { // 先检查当前结点是否有孩子的累计状态,有的话向下传递
update(cur);
}
if(l < (cur->l + cur->r) / 2) {
change(cur->lc, l, r, delta);
}
if(r > (cur->l + cur->r) / 2) { // 注意这里没有等号!!!!
change(cur->rc, l, r, delta);
}
cur->sum = cur->lc->sum + cur->rc->sum;
}
}
int query(Node* cur, int l, int r) {
if(l <= cur->l && cur->r <= r) {
return cur->sum;
} else {
if(cur->delta != 0) { // 检查是否有孩子结点的累计状态
update(cur); // 计算之前延迟的累积和
}
int ans = 0;
if(l < (cur->l + cur->r) / 2) {
ans += query(cur->lc, l, r);
}
if(r > (cur->l + cur->r) / 2) {
ans += query(cur->rc, l, r);
}
return ans;
}
}
int main() {
Node* root = nullptr;
build(root, 1, 11);
for(int i = 1; i <= 10; ++i) {
change(root, i, i + 3, 1);
}
for(int i = 1; i <= 10; ++i) {
cout << i << "th:" << query(root, i, i + 1) << endl;
}
cout << "sum 1~10:" << query(root, 1, 11) << endl;
return 0;
}
更一般的方法:
对于当前区间[l,r)
if 达到某种边界条件(比如叶子结点或整个区间被完全包含)
then 对维护或者询问进行相应的处理
else
将第二类标记传递下去(注意,查询的过程也要处理)
根据区间的关系,对两个孩子递归地处理
利用递推关系,根据孩子结点的情况维护第一类信息
根据一般方法改进的问题:
问题3:
长度位n的数列,初始化全是0.现在执行m次操作,每次执行下面两种操作之一:
- 数列中某个区间的所有数加上某个值
- 数列中某个区间的所有数改成某个值
- 询问给定区间中所有数的和
- 询问给定区间的最值
#include <bits/stdc++.h>
using namespace std;
const int INF = 10000000;
struct Node {
int l, r, value, sum, maxm, minm, delta;
bool tag;
struct Node *lc, *rc;
Node(): tag(false), l(0), r(0), maxm(0), minm(0),
delta(0), value(0), sum(0), lc(nullptr), rc(nullptr) {}
};
void build(Node* &cur, int l, int r) {
cur = new Node;
cur->l = l;
cur->r = r;
if(l + 1 < r) {
build(cur->lc, l, (l + r) / 2);
build(cur->rc, (l + r) / 2, r);
}
}
// 统一更新
void update(Node* cur) {
// 更新值和最值
cur->lc->value = cur->rc->value = cur->value;
cur->lc->maxm = cur->rc->value = cur->value;
cur->lc->minm = cur->rc->minm = cur->value;
cur->lc->tag = cur->rc->tag = true;
cur->tag = false;
cur->lc->sum += cur->delta * (cur->lc->r - cur->lc->l);
cur->rc->sum += cur->delta * (cur->rc->r - cur->rc->r);
cur->lc->delta += cur->delta;
cur->rc->delta += cur->delta;
cur->delta = 0;
}
// 把区间的值改成value
void change_to(Node* cur, int l, int r, int value) {
if(l <= cur->l && cur->r <= r) {
cur->value = value;
cur->maxm = cur->minm = value;
cur->tag = true;
} else {
if(cur->tag) {
update(cur);
}
if(l < (cur->l + cur->r) / 2) {
change_to(cur->lc, l, r, value);
}
if(r > (cur->l + cur->r) / 2) {
change_to(cur->rc, l, r, value);
}
cur->maxm = max(cur->lc->maxm, cur->rc->maxm);
cur->minm = min(cur->lc->minm, cur->rc->minm);
}
}
// 更改累积和
void change_sum(Node* cur, int l, int r, int delta) {
if(l <= cur->l && cur->r <= r) {
cur->sum += delta * (cur->r - cur->l);
cur->delta += delta;
} else {
if(cur->delta != 0) {
update(cur);
}
if(l < (cur->l + cur->r) / 2) {
change_sum(cur->lc, l, r, delta);
}
if(r > (cur->l + cur->r) / 2) {
change_sum(cur->rc, l, r, delta);
}
cur->sum = cur->lc->sum + cur->rc->sum;
}
}
// 查询最大值
int query_max(Node* cur, int l, int r) {
if(l <= cur->l && cur->r <= r) {
return cur->maxm;
} else {
if(cur->tag) {
update(cur);
}
int ml = -INF, mr = -INF;
if(l < (cur->l + cur->r) / 2) {
ml = query_max(cur->lc, l, r);
}
if(r > (cur->l + cur->r) / 2) {
mr = query_max(cur->rc, l, r);
}
return max(ml, mr);
}
}
// 查询最小值
int query_min(Node* cur, int l, int r) {
if(l <= cur->l && cur->r <= r) {
return cur->minm;
} else {
if(cur->tag) {
update(cur);
}
int ml = INF, mr = INF;
if(l < (cur->l + cur->r) / 2) {
ml = query_max(cur->lc, l, r);
}
if(r > (cur->l + cur->r) / 2) {
mr = query_max(cur->rc, l, r);
}
return min(ml, mr);
}
}
// 查询和
int query_sum(Node* cur, int l, int r) {
if(l <= cur->l && cur->r <= r) {
return cur->sum;
} else {
if(cur->delta != 0) {
update(cur);
}
int ans = 0;
if(l < (cur->l + cur->r) / 2) {
ans += query_sum(cur->lc, l, r);
}
if(r > (cur->l + cur->r) / 2) {
ans += query_sum(cur->rc, l, r);
}
return ans;
}
}
int main() {
srand(time(unsigned(0)));
Node* root = nullptr;
build(root, 1, 11);
cout << "rand res:" << endl;
for(int i = 1; i <= 10; ++i) {
int t = rand() % 30;
cout << i << "th:" << t << endl;
change_to(root, i, i + 1, t);
change_sum(root, i, i + 1, t);
}
cout << "max:" << query_max(root, 1, 11) << endl;
cout << "min:" << query_min(root, 1, 11) << endl;
cout << "sum:" << query_sum(root, 1, 11) << endl;
return 0;
}