线段树英文叫做segment tree。最近研究了一下,发现它非常有用,面试中考得也比较多。那什么样的题目可以使用线段树呢?这类题目通常具有以下几个特点,遇到时可以考虑用线段树:
- 求一组区间值
- 原始数据会发生变化
什么是线段树?
线段树是一种二叉搜索树,与区间树相似,它将一个区间划分成一些单元区间,每个单元区间对应线段树中的一个叶结点。
使用线段树可以快速的查找某一个节点在若干条线段中出现的次数,时间复杂度为O(logN)。而未优化的空间复杂度为2N,因此有时需要离散化让空间压缩。
上面是百度百科对线段树的说明。需要注意,线段树严格来说并不是按键值有序的二叉搜索树,而是一棵按区间二分划分得到的平衡二叉树,树高为O(log n),所以单点修改和区间查询的时间复杂度都是O(log n)。
构建线段树
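比如对于含有n个元素的数组,构建线段树时,对每个区间[start, end]取mid = start + (end - start) / 2,左分支范围是[start, mid],右分支范围是[mid+1, end]。对根节点来说,左分支是[0, (n-1)/2],右分支是[(n-1)/2+1, n-1]。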
那对于一个含有4个元素的数组,构建完之后样子如下:
            [0, 3]
          /        \
     [0, 1]        [2, 3]
     /    \        /    \
  [0, 0] [1, 1] [2, 2] [3, 3]
下面是C++的实现:
/// Node of a segment tree covering the closed index range [start, end].
/// Leaves have start == end; internal nodes own a left and a right child.
class SegmentTreeNode {
public:
    int start, end;
    SegmentTreeNode *left, *right;
    SegmentTreeNode(int start, int end)
        : start(start), end(end), left(nullptr), right(nullptr) {
    }
};  // the terminating ';' is required after a class definition
class Solution {
public:
    /**
     * Builds a segment tree covering the index range [start, end].
     * Each internal node splits its range at the midpoint: the left child
     * gets [start, mid] and the right child gets [mid + 1, end].
     * @param start first index of the range (inclusive)
     * @param end   last index of the range (inclusive)
     * @return the root of the new tree, or nullptr when start > end
     */
    SegmentTreeNode * build(int start, int end) {
        if (end < start) {
            return nullptr;  // empty range
        }
        SegmentTreeNode *node = new SegmentTreeNode(start, end);
        if (start != end) {
            // end - start is non-negative here, so / 2 equals >> 1.
            const int half = start + (end - start) / 2;
            node->left = build(start, half);
            node->right = build(half + 1, end);
        }
        return node;
    }
};
更改某个元素
假设线段树节点里保存的是该区间内的最大值。当修改数组中某个元素时,如何更新这棵线段树?做法是从根节点往下找到该元素对应的叶节点,修改它的值,然后沿查找路径自底向上更新路径上各节点的最大值。
/**
* Definition of SegmentTreeNode:
* class SegmentTreeNode {
* public:
* int start, end, max;
* SegmentTreeNode *left, *right;
* SegmentTreeNode(int start, int end, int max) {
* this->start = start;
* this->end = end;
* this->max = max;
* this->left = this->right = NULL;
* }
* }
*/
class Solution {
public:
/**
*@param root, index, value: The root of segment tree and
*@ change the node's value with [index, index] to the new given value
*@return: void
*/
void modify(SegmentTreeNode *root, int index, int value) {
updateMax(root, index, value);
}
int updateMax(SegmentTreeNode *root, int index, int value){
if(!root || index < root->start || index > root->end){
return -1;//error
}
if(index == root->start && index == root->end){
root->max = value;
return value;
} else {
auto mid = root->left->end;
if(index <= mid){
auto leftMax = updateMax(root->left, index, value);
root->max = max(leftMax, root->right->max);
}else{
auto rightMax = updateMax(root->right, index, value);
root->max = max(root->left->max, rightMax);
}
return root->max;
}
}
};
线段树查找
假设线段树中节点存放的是区间最大值,如何求区间内的最大值?
本质上就是找出恰好拼成查询区间的若干个节点区间,再对这些节点保存的最大值取最大值。这个最大值可以在递归查找这些覆盖区间的同时计算出来。
/**
 * Definition of SegmentTreeNode:
 * class SegmentTreeNode {
 * public:
 *     int start, end, max;
 *     SegmentTreeNode *left, *right;
 *     SegmentTreeNode(int start, int end, int max) {
 *         this->start = start;
 *         this->end = end;
 *         this->max = max;
 *         this->left = this->right = NULL;
 *     }
 * };
 */
class Solution {
public:
    /**
     * Returns the maximum value over the index range [start, end].
     *
     * The query range is first clipped to the node's own range. A range that
     * reaches past the array bounds is therefore answered over the part that
     * overlaps the tree, instead of recursing into a leaf with a mismatched
     * range and dereferencing its null child.
     *
     * @param root  root of a segment tree whose nodes store their range max
     * @param start first index of the query range (inclusive)
     * @param end   last index of the query range (inclusive)
     * @return the maximum over [start, end] intersected with
     *         [root->start, root->end], or -1 when root is null or that
     *         intersection is empty. The -1 sentinel is kept for interface
     *         compatibility; it is ambiguous if -1 can be a stored value.
     */
    int query(SegmentTreeNode *root, int start, int end) {
        if (!root) {
            return -1;  // error: empty tree
        }
        // Clip to this node's range. Afterwards, every recursive call below
        // stays inside the range of the child it is made on.
        if (start < root->start) {
            start = root->start;
        }
        if (end > root->end) {
            end = root->end;
        }
        if (start > end) {
            return -1;  // error: empty or disjoint query range
        }
        if (root->start == start && root->end == end) {
            return root->max;  // exact cover; always reached at a leaf
        }
        // Not an exact cover, so this is an internal node with two children.
        int mid = root->left->end;
        if (end <= mid) {
            return query(root->left, start, end);
        }
        if (start > mid) {
            return query(root->right, start, end);
        }
        // Both halves are non-empty here, so neither call returns the sentinel.
        return max(query(root->left, start, mid), query(root->right, mid + 1, end));
    }
};
常见题型
求区间和
求区间和时,第一反应是用前缀和(preSum)方法,但这种方法只适用于数据不发生修改的情况。一旦某个元素被修改,preSum数组就需要更新,更新的复杂度是O(n),这通常是不可接受的。线段树很好地解决了这个问题:修改线段树中一个元素的复杂度是O(log n),查询一个区间和的复杂度也是O(log n)。
/// Node of the range-sum segment tree. It covers the index range
/// [start, end] and caches the sum of the array elements in that range.
class CNode{
public:
    int start;
    int end;
    long long sum;
    CNode* pLeft;
    CNode* pRight;
    // _sum is long long, matching the member, so a stored sum is never
    // narrowed by the constructor; int arguments still convert implicitly.
    CNode(int _start, int _end, long long _sum):start(_start), end(_end), sum(_sum), pLeft(nullptr), pRight(nullptr){
    }
};
class Solution {
public:
    /**
     * Builds a range-sum segment tree over A.
     * @param A: An integer vector; when it is empty, every query returns 0
     *           and every modify is a no-op
     */
    Solution(vector<int> A) : m_pRoot(nullptr) {
        if (!A.empty()) {
            m_pRoot = build(A, 0, static_cast<int>(A.size()) - 1);
        }
    }
    /// Frees every node of the tree owned by this object.
    ~Solution() {
        destroy(m_pRoot);
    }
    // This object exclusively owns a raw-pointer tree. A shallow copy would
    // free it twice, so copying is disabled and moving transfers ownership.
    Solution(const Solution&) = delete;
    Solution& operator=(const Solution&) = delete;
    Solution(Solution&& other) noexcept : m_pRoot(other.m_pRoot) {
        other.m_pRoot = nullptr;
    }
    Solution& operator=(Solution&& other) noexcept {
        if (this != &other) {
            destroy(m_pRoot);
            m_pRoot = other.m_pRoot;
            other.m_pRoot = nullptr;
        }
        return *this;
    }
    /**
     * @param start, end: Indices
     * @return: The sum of the elements whose indices lie in [start, end]
     *          (0 if none do)
     */
    long long query(int start, int end) {
        return querySum(m_pRoot, start, end);
    }
    /**
     * @param index, value: modify A[index] to value (ignored when index is
     *                      outside the array)
     */
    void modify(int index, int value) {
        updateValue(m_pRoot, index, value);
    }
private:
    CNode* m_pRoot;
private:
    // Builds the subtree for A[start..end]; requires start <= end.
    CNode* build(const vector<int>& A, int start, int end){
        if(start == end){
            return new CNode(start, end, A[start]);
        }
        auto mid = start + ((end - start)>>1);
        auto pCur = new CNode(start, end, 0);
        pCur->pLeft = build(A, start, mid);
        pCur->pRight = build(A, mid+1, end);
        pCur->sum = pCur->pLeft->sum + pCur->pRight->sum;
        return pCur;
    }
    // Sum of the elements of pRoot's subtree whose indices lie in [start, end].
    long long querySum(CNode* pRoot, int start, int end){
        if(!pRoot
            || start > end
            || start > pRoot->end
            || end < pRoot->start) {
            return 0;  // empty query or no overlap with this node
        }
        if(start <= pRoot->start && end >= pRoot->end) {
            return pRoot->sum;  // node range fully covered
        }
        if(pRoot->start == pRoot->end){
            return 0;  // defensive: an overlapping leaf is always fully covered
        }
        auto mid = pRoot->pLeft->end;
        if(end <= mid){
            return querySum(pRoot->pLeft, start, end);
        } else if(start > mid){
            return querySum(pRoot->pRight, start, end);
        } else {
            return querySum(pRoot->pLeft, start, mid) + querySum(pRoot->pRight, mid+1, end);
        }
    }
    // Sets the leaf at `index` to `value` and returns the change in its sum,
    // so that every ancestor can adjust by the same amount. The difference of
    // two ints can exceed the int range, so it is carried as long long.
    long long updateValue(CNode* pRoot, int index, int value){
        if(!pRoot
            || index < pRoot->start
            || index > pRoot->end){
            return 0;
        }
        if(pRoot->start == pRoot->end){
            // The range check above guarantees this leaf is A[index].
            const long long diff = value - pRoot->sum;
            pRoot->sum = value;
            return diff;
        }
        auto mid = pRoot->pLeft->end;
        const long long diff = (index <= mid)
            ? updateValue(pRoot->pLeft, index, value)
            : updateValue(pRoot->pRight, index, value);
        pRoot->sum += diff;
        return diff;
    }
    // Frees every node of the subtree rooted at pRoot (post-order).
    static void destroy(CNode* pRoot) {
        if (!pRoot) {
            return;
        }
        destroy(pRoot->pLeft);
        destroy(pRoot->pRight);
        delete pRoot;
    }
};
上面的区间节点存放的是sum,其实也可以是max、min,或者其他可以由左右子区间的值合并得到的区间值。
求小于自身数的个数
线段树题目的对象一般是数组,求数组某个区间(range)的值。但作为range的不只可以是下标范围,还可以是数值范围。
给定一个数组,其中数组元素的范围是0~10000(这个信息非常重要)。对于每个元素A[i],计算下标i之前比A[i]小的元素个数。
例如:输入[1,2,7,8,5],返回[0,1,2,3,2]。
最直接的做法是扫描:计算A[i]时遍历下标0~i-1中小于A[i]的元素个数,这样复杂度为O(n^2)。
因为数组元素有范围,可以用一个计数数组data表示所有已出现的元素(data[v]表示值v出现的次数)。小于A[i]的元素个数就是data[0]到data[A[i]-1]之和。由于求的是A[i]之前的元素中比它小的个数,每处理一个元素,就相当于往这个计数数组中插入一个元素。这正是"单点修改 + 区间求和"的场景,因此考虑用线段树,每个线段节点保存该数值范围内元素的个数。
/// Node of a value-domain segment tree. It covers the values [low, high] and
/// counts how many inserted elements fall in that value range.
class CNode {
public:
    int low;
    int high;
    int count;
    CNode* pLeft;
    CNode* pRight;
    CNode(int _low, int _high, int _count):low(_low), high(_high), count(_count), pLeft(nullptr), pRight(nullptr){
    }
};
class Solution {
public:
    /**
     * @param A: An integer array
     * @return: Count the number of element before this element 'ai' is
     *          smaller than it and return count number array.
     *          Values outside [kMinValue, kMaxValue] get 0 and are not
     *          inserted into the tree.
     */
    vector<int> countOfSmallerNumberII(vector<int> &A) {
        vector<int> res;
        if(A.empty()){
            return res;
        }
        res.reserve(A.size());
        // The tree is indexed by value, not by position. It spans every
        // supported element value, and each node counts the values seen so far.
        auto pRoot = buildSegmentTree(kMinValue, kMaxValue);
        for(auto value:A){
            if(value < pRoot->low || value > pRoot->high){
                res.push_back(0);
            } else{
                // Query before inserting, so that A[i] only sees earlier elements.
                res.push_back(lowerCount(pRoot, value));
                updateSegmentTree(pRoot, value);
            }
        }
        // The tree is rebuilt on every call; release it here so it is not leaked.
        destroySegmentTree(pRoot);
        return res;
    }
    // Builds a tree over the values [low, high] with all counts set to 0.
    CNode* buildSegmentTree(int low, int high){
        auto pCur = new CNode(low, high, 0);
        if(low == high){
            return pCur;
        }
        auto mid = low + ((high-low)>>1);
        pCur->pLeft = buildSegmentTree(low, mid);
        pCur->pRight = buildSegmentTree(mid+1, high);
        return pCur;
    }
    // Inserts one occurrence of `value`. Returns 1 if it was counted and
    // 0 if value lies outside pCur's range.
    int updateSegmentTree(CNode* pCur, int value) {
        if(!pCur || value < pCur->low || value > pCur->high){
            return 0;
        }
        if(pCur->low == pCur->high){
            if(pCur->low == value){
                pCur->count++;
                return 1;
            } else {
                return 0;
            }
        } else{
            auto mid = pCur->low + ((pCur->high-pCur->low)>>1);
            int diff = 0;
            if(value <= mid){
                diff = updateSegmentTree(pCur->pLeft, value);
            } else {
                diff = updateSegmentTree(pCur->pRight, value);
            }
            pCur->count += diff;
            return diff;
        }
    }
    // Number of inserted values in pCur's range that are strictly less than value.
    int lowerCount(CNode* pCur, int value){
        if(!pCur
            || pCur->low > value
            || pCur->count == 0){
            return 0;
        }
        if(pCur->high < value){
            return pCur->count;  // whole range is below value
        }
        auto mid = pCur->low + ((pCur->high - pCur->low)>>1);
        if(value <= mid){
            return lowerCount(pCur->pLeft, value);
        } else{
            return lowerCount(pCur->pLeft, value) + lowerCount(pCur->pRight, value);
        }
    }
private:
    // Inclusive bounds of the element values this solution supports.
    static constexpr int kMinValue = 0;
    static constexpr int kMaxValue = 10000;
    // Frees every node of the subtree rooted at pCur (post-order).
    static void destroySegmentTree(CNode* pCur) {
        if (!pCur) {
            return;
        }
        destroySegmentTree(pCur->pLeft);
        destroySegmentTree(pCur->pRight);
        delete pCur;
    }
};
总结
线段树是一种基于区间二分划分的平衡二叉树,用来解决区间问题。当某个问题可以转化为若干个区间问题,而且元素的变化只会影响少数几个区间(从叶子到根的一条路径,共O(log n)个节点),而不是所有区间时,就可以考虑用线段树。常见情况是求数组区间的值。