1.原文链接:点击打开链接
线段树的入门级 总结
线段树是一种二叉搜索树,与区间树相似,它将一个区间划分成一些单元区间,每个单元区间对应线段树中的一个叶结点。
对于线段树中的每一个非叶子节点[a,b],它的左儿子表示的区间为[a,(a+b)/2],右儿子表示的区间为[(a+b)/2+1,b]。因此线段树是平衡二叉树,最后的子节点数目为N,即整个线段区间的长度。
使用线段树可以快速的查找某一个节点在若干条线段中出现的次数,时间复杂度为O(logN)。而未优化的空间复杂度为2N,因此有时需要离散化让空间压缩。
----来自百度百科
【以下以 求区间最大值为例】
先看声明:
- #include <stdio.h>
- #include <math.h>
- const int MAXNODE = 2097152;
- const int MAX = 1000003;
- struct NODE{
- int value;
- int left,right;
- }node[MAXNODE];
- int father[MAX];
【创建线段树(初始化)】:
由于线段树是用二叉树结构储存的,而且是近乎完全二叉树的,所以在这里我使用了数组来代替链表上图中区间上面的红色数字表示了结构体数组中对应的下标。
在完全二叉树中假如一个结点的序号(数组下标)为 I ,那么 (二叉树基本关系)
I 的父亲为 I/2,
I 的另一个兄弟为 I/2*2 或 I/2*2+1
I 的两个孩子为 I*2 (左) I*2+1(右)
有了这样的关系之后,我们便能很方便的写出创建线段树的代码了。
- void BuildTree(int i,int left,int right){
- node[i].left = left;
- node[i].right = right;
- node[i].value = 0;
- if (left == right){
- father[left] = i;
- return;
- }
-
-
- BuildTree(i<<1, left, (int)floor( (right+left) / 2.0));
-
- BuildTree((i<<1) + 1, (int)floor( (right+left) / 2.0) + 1, right);
- }
【单点更新线段树】:
由于我事先用 father[ ] 数组保存过 每单个结点 对应的下标了,因此我只需要知道第几个点,就能知道这个点在结构体中的位置(即下标)了,这样的话,根据之前已知的基本关系,就只需要直接一路更新上去即可。
- void UpdataTree(int ri){
-
- if (ri == 1)return;
- int fi = ri / 2;
- int a = node[fi<<1].value;
- int b = node[(fi<<1)+1].value;
- node[fi].value = (a > b)?(a):(b);
- UpdataTree(ri/2);
- }
【查询区间最大值】:
将一段区间按照建立的线段树从上往下一直拆开,直到存在有完全重合的区间停止。对照图例建立的树,假如查询区间为 [2,5]
红色的区间为完全重合的区间,因为在这个具体问题中我们只需要比较这 三个区间的值 找出 最大值 即可。
- int Max = -1<<20;
- void Query(int i,int l,int r){
- if (node[i].left == l && node[i].right == r){
- Max = (Max < node[i].value)?node[i].value:(Max);
- return ;
- }
- i = i << 1;
- if (l <= node[i].right){
- if (r <= node[i].right)
- Query(i, l, r);
- else
- Query(i, l, node[i].right);
- }
- i += 1;
- if (r >= node[i].left){
- if (l >= node[i].left)
- Query(i, l, r);
- else
- Query(i, node[i].left, r);
- }
- }
2.原文链接:点击打开链接
风格:
maxn是题目给的最大区间,而节点数要开4倍,确切的说……
lson和rson辨别表示结点的左孩子和右孩子。
PushUp(int rt)是把当前结点的信息更新到父节点
PushDown(int rt)是把当前结点的信息更新给孩子结点。
rt表示当前子树的根(root),也就是当前所在的结点。
思想:
对于每个非叶节点所标示的结点 [a,b],其做孩子表示的区间是[a,(a+b)/2],其右孩子表示[(a+b)/2,b].
构造:

离散化和线段树:
题目:x轴上有若干个线段,求线段覆盖的总长度。
普通解法:设置坐标范围[min,max],初始化为0,然后每一段分别染色为1,最后统计1的个数,适用于线段数目少,区间范围小。
离散化的解法:离散化就是一一映射的关系,即将一个大坐标和小坐标进行一一映射,适用于线段数目少,区间范围大。
例如:[10000,22000],[30300,55000],[44000,60000],[55000,60000].
第一步:排序 10000 22000 30300 44000 55000 60000
第二部:编号 1 2 3 4 5 6
第三部:用编号来代替原数,即小数代大数 。
[10000,22000]~[1,2]
[30300,55000]~[3,5]
[44000,60000]~[4,6]
[55000,60000]~[5,6]
然后再用小数进行普通解法的步骤,最后代换回去。
线段树的解法:线段树通过建立线段,将原来染色O(n)的复杂度减小到 log(n),适用于线段数目多,区间范围小的情况。
离散化的线段树:适用于线段数目多,区间范围大的情况。
构造:
动态数据结构:
struct node{
node* left;
node* right;
……
}
静态全局数组模拟(完全二叉树):
struct node{
int left;
int right;
……
}Tree[MAXN]
例如:

线段树与点树:
线段树的每一个结点表示一个点,成为点树,比如说用于求第k小数的线段树。
点树结构体:
struct node{
int l, r;
int c;//用于存放次结点的值,默认为0
}T[3*MAXN];
创建:
创建顺序为先序遍历,即先构造根节点,再构造左孩子,再构造右孩子。
- void construct(int l, int r, int k){
- T[k].l = l;
- T[k].r = r;
- T[k].c = 0;
- if(l == r) return ;
- int m = (l + r) >> 1;
- construct(l, m, k << 1);
- construct(m + 1, r, (k << 1) + 1);
- return ;
- }
- void construct(int l, int r, int k){
- T[k].l = l;
- T[k].r = r;
- T[k].c = 0;
- if(l == r) return ;
- int m = (l + r) >> 1;
- construct(l, m, k << 1);
- construct(m + 1, r, (k << 1) + 1);
- return ;
- }
[A,B,C]:A表示左值,B表示右值,C表示在静态数组中的位置,由此可知,n个点的话大约共有2*n个结点,因此开3*n的结构体一定是够的。
更新值:
- void insert(int d, int k){
-
- if(T[k].l == T[k].r && d == T[k].l){
- T[k].c += 1;
- return ;
- }
- int m = (T[k].l + T[k].r) >> 1;
- if(d <= m) insert(d, k << 1);
- else insert(d, (k << 1) + 1);
-
- T[k].c = T[k << 1].c + T[(k << 1) + 1].c;
- }
- void insert(int d, int k){
-
- if(T[k].l == T[k].r && d == T[k].l){
- T[k].c += 1;
- return ;
- }
- int m = (T[k].l + T[k].r) >> 1;
- if(d <= m) insert(d, k << 1);
- else insert(d, (k << 1) + 1);
-
- T[k].c = T[k << 1].c + T[(k << 1) + 1].c;
- }
查找值:
-
- void search(int d, int k, int& ans)
- {
- if(T[k].l == T[k].r){
- ans = T[k].l;
- ans = T[k].l;
- }
- int m = (T[k].l + T[k].r) >> 1;
-
- if(d > T[(k << 1)].c) search(d - T[k << 1].c, (k << 1) + 1, ans);
- else search(d, k << 1, ans);
- }
-
- void search(int d, int k, int& ans)
- {
- if(T[k].l == T[k].r){
- ans = T[k].l;
- ans = T[k].l;
- }
- int m = (T[k].l + T[k].r) >> 1;
-
- if(d > T[(k << 1)].c) search(d - T[k << 1].c, (k << 1) + 1, ans);
- else search(d, k << 1, ans);
- }
search函数的用法不太懂。
例题解:
(待更新)
四类题型:
1.单点更新 只更新叶子结点,然后把信息用PushUp(int r)这个函数更新上来。
hdu1166:敌兵布阵
线段树功能:update:单点替换 query:区间最值
poj2828
树状数组:
- #include <iostream>
- #include <cstdio>
- #include <string>
- #include <cstring>
- using namespace std;
-
- typedef pair<int, int> PII;
-
- const int maxn = 200000;
-
- int C[maxn + 100];
- int B[maxn + 100];
- int n;
- PII arr[maxn + 100];
-
- int lowbit(int k) { return k & (-k); }
-
- void init() {
- for(int i = 1; i <= n; i++) C[i] = lowbit(i);
- memset(B, -1, n + 10);
- }
-
- void update(int i) {
- while(i <= n) {
- C[i]--;
- i += lowbit(i);
- }
- }
-
-
- int query(int i) {
- int ret = 0;
- while(i > 0) {
- ret += C[i];
- i -= lowbit(i);
- }
- return ret;
- }
-
- void debug() {
- for(int i = 1; i <= n; i++) cout << i << " " << query(i) << endl;
- }
-
-
- void fun(int a, int v) {
- int l = 1, r = n;
- while(l < r) {
- int m = (l + r) >> 1;
- if(query(m) >= a) r = m;
- else l = m + 1;
- }
-
- update(l);
-
-
- B[l] = v;
-
- }
-
-
-
-
- int main() {
- while(~scanf("%d", &n)) {
- init();
- int a, b;
- for(int i = 1; i <= n; i++) {
- scanf("%d%d", &a, &b);
- a++;
- arr[i].first = a;
- arr[i].second = b;
- }
- for(int i = n; i > 0; i--) fun(arr[i].first, arr[i].second);
-
-
- for(int i = 1; i <= n; i++) {
- i == 1 ? printf("%d", B[i]) : printf(" %d", B[i]);
-
-
- }
- puts("");
- }
- return 0;
- }
- #include <iostream>
- #include <cstdio>
- #include <string>
- #include <cstring>
- using namespace std;
-
- typedef pair<int, int> PII;
-
- const int maxn = 200000;
-
- int C[maxn + 100];
- int B[maxn + 100];
- int n;
- PII arr[maxn + 100];
-
- int lowbit(int k) { return k & (-k); }
-
- void init() {
- for(int i = 1; i <= n; i++) C[i] = lowbit(i);
- memset(B, -1, n + 10);
- }
-
- void update(int i) {
- while(i <= n) {
- C[i]--;
- i += lowbit(i);
- }
- }
-
-
- int query(int i) {
- int ret = 0;
- while(i > 0) {
- ret += C[i];
- i -= lowbit(i);
- }
- return ret;
- }
-
- void debug() {
- for(int i = 1; i <= n; i++) cout << i << " " << query(i) << endl;
- }
-
-
- void fun(int a, int v) {
- int l = 1, r = n;
- while(l < r) {
- int m = (l + r) >> 1;
- if(query(m) >= a) r = m;
- else l = m + 1;
- }
-
- update(l);
-
-
- B[l] = v;
-
- }
-
-
-
-
- int main() {
- while(~scanf("%d", &n)) {
- init();
- int a, b;
- for(int i = 1; i <= n; i++) {
- scanf("%d%d", &a, &b);
- a++;
- arr[i].first = a;
- arr[i].second = b;
- }
- for(int i = n; i > 0; i--) fun(arr[i].first, arr[i].second);
-
-
- for(int i = 1; i <= n; i++) {
- i == 1 ? printf("%d", B[i]) : printf(" %d", B[i]);
-
-
- }
- puts("");
- }
- return 0;
- }
poj-3468
- #include <cstdio>
- #include <cstring>
- #include <iostream>
- using namespace std;
-
- #define lson l, m, rt<<1
- #define rson m+1, r, rt<<1|1
-
- typedef long long LL;
-
- const int maxn = 111111;
-
- LL col[maxn<<2];
- LL sum[maxn<<2];
-
- void PushUp(LL rt) {
- sum[rt] = sum[rt<<1] + sum[rt<<1|1];
- }
-
-
-
-
- void PushDown(LL rt, LL m) {
- if(col[rt]) {
-
- col[rt<<1] += col[rt];
- col[rt<<1|1] += col[rt];
- sum[rt<<1] += col[rt] * (m - (m>>1));
- sum[rt<<1|1] += col[rt] * (m>>1);
- col[rt] = 0;
- }
- }
-
- void build(LL l, LL r, LL rt) {
- col[rt] = 0;
-
- if(l == r) {
- scanf("%I64d", &sum[rt]);
-
- return ;
- }
- int m = (l + r) >> 1;
- build(lson);
- build(rson);
- PushUp(rt);
- }
-
- LL query(LL L, LL R, LL l, LL r, LL rt) {
- LL ret = 0;
- if(L <= l && r <= R) {
-
- return sum[rt];
- }
- PushDown(rt, r - l + 1);
- int m = (l + r) >> 1;
- if(L <= m) ret += query(L, R, lson);
- if(R > m) ret += query(L, R, rson);
- return ret;
- }
-
- void update(LL L, LL R, LL c, LL l, LL r, LL rt) {
- if(L <= l && r <= R) {
- sum[rt] += c * (r - l + 1);
- col[rt] += c;
- return ;
- }
- PushDown(rt, r - l + 1);
- int m = (l + r) >> 1;
- if(L <= m) update(L, R, c, lson);
- if(R > m) update(L, R, c, rson);
- PushUp(rt);
- }
-
- void debug(int n) {
- for(int i = 1; i <= (n*3); i++) {
- cout << i << " ";
- }
- cout << endl;
- for(int i = 1; i <= (n*3); i++) {
- cout << col[i] << " ";
- }
- cout << endl << endl;
- for(int i = 1; i <= (n*3); i++) {
- cout << i << " ";
- }
- cout << endl;
- for(int i = 1; i <= (n*3); i++) {
- cout << sum[i] << " ";
- }
- cout << endl;
- }
-
- int main() {
- LL N, Q;
- while(~scanf("%I64d%I64d", &N, &Q)) {
-
- memset(sum, 0, sizeof(sum));
- memset(col, 0, sizeof(col));
- build(1, N, 1);
-
- for(int i = 0; i < Q; i++) {
- char ch[3];
- LL a, b, c;
- scanf("%s", ch);
- if(ch[0] == 'Q') {
- scanf("%I64d%I64d", &a, &b);
- printf("%I64d\n", query(a, b, 1, N, 1));
- }
- else {
- scanf("%I64d%I64d%I64d", &a, &b, &c);
- update(a, b, c, 1, N, 1);
- }
-
- }
- }
- return 0;
- }
- #include <cstdio>
- #include <cstring>
- #include <iostream>
- using namespace std;
-
- #define lson l, m, rt<<1
- #define rson m+1, r, rt<<1|1
-
- typedef long long LL;
-
- const int maxn = 111111;
-
- LL col[maxn<<2];
- LL sum[maxn<<2];
-
- void PushUp(LL rt) {
- sum[rt] = sum[rt<<1] + sum[rt<<1|1];
- }
-
-
-
-
- void PushDown(LL rt, LL m) {
- if(col[rt]) {
-
- col[rt<<1] += col[rt];
- col[rt<<1|1] += col[rt];
- sum[rt<<1] += col[rt] * (m - (m>>1));
- sum[rt<<1|1] += col[rt] * (m>>1);
- col[rt] = 0;
- }
- }
-
- void build(LL l, LL r, LL rt) {
- col[rt] = 0;
-
- if(l == r) {
- scanf("%I64d", &sum[rt]);
-
- return ;
- }
- int m = (l + r) >> 1;
- build(lson);
- build(rson);
- PushUp(rt);
- }
-
- LL query(LL L, LL R, LL l, LL r, LL rt) {
- LL ret = 0;
- if(L <= l && r <= R) {
-
- return sum[rt];
- }
- PushDown(rt, r - l + 1);
- int m = (l + r) >> 1;
- if(L <= m) ret += query(L, R, lson);
- if(R > m) ret += query(L, R, rson);
- return ret;
- }
-
- void update(LL L, LL R, LL c, LL l, LL r, LL rt) {
- if(L <= l && r <= R) {
- sum[rt] += c * (r - l + 1);
- col[rt] += c;
- return ;
- }
- PushDown(rt, r - l + 1);
- int m = (l + r) >> 1;
- if(L <= m) update(L, R, c, lson);
- if(R > m) update(L, R, c, rson);
- PushUp(rt);
- }
-
- void debug(int n) {
- for(int i = 1; i <= (n*3); i++) {
- cout << i << " ";
- }
- cout << endl;
- for(int i = 1; i <= (n*3); i++) {
- cout << col[i] << " ";
- }
- cout << endl << endl;
- for(int i = 1; i <= (n*3); i++) {
- cout << i << " ";
- }
- cout << endl;
- for(int i = 1; i <= (n*3); i++) {
- cout << sum[i] << " ";
- }
- cout << endl;
- }
-
- int main() {
- LL N, Q;
- while(~scanf("%I64d%I64d", &N, &Q)) {
-
- memset(sum, 0, sizeof(sum));
- memset(col, 0, sizeof(col));
- build(1, N, 1);
-
- for(int i = 0; i < Q; i++) {
- char ch[3];
- LL a, b, c;
- scanf("%s", ch);
- if(ch[0] == 'Q') {
- scanf("%I64d%I64d", &a, &b);
- printf("%I64d\n", query(a, b, 1, N, 1));
- }
- else {
- scanf("%I64d%I64d%I64d", &a, &b, &c);
- update(a, b, c, 1, N, 1);
- }
-
- }
- }
- return 0;
- }