线段树

本文深入介绍了线段树的概念、构建方法及其应用场景。通过实例讲解了如何利用线段树高效地解决区间查询和更新问题,并提供了详细的C++实现代码。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

线段树英文叫做segment tree。最近研究了下,发现非常有用,面试中考的也比较多。那什么样的题目可以使用线段树呢?它具有以下几个特点,当遇到这样的题目时,可以考虑用线段树。

  • 求一组区间值
  • 原始数据会发生变化

什么是线段树?

线段树是一种二叉搜索树,与区间树相似,它将一个区间划分成一些单元区间,每个单元区间对应线段树中的一个叶结点。
使用线段树可以快速的查找某一个节点在若干条线段中出现的次数,时间复杂度为O(logN)。而未优化的空间复杂度为2N,因此有时需要离散化让空间压缩。
上面是百度百科对线段树的说明,因为它是一种二叉搜索树,所以它更删查改一个数据的时间复杂度是O(lgn)。

构建线段树

比如对于含有n个元素的数据,构建成线段树,左分支范围是[0,n/2],右分支范围是[n/2+1,n-1]。
那对于一个含有4个元素的数组,构建完之后样子如下:

              [0,  3]
             /        \
      [0,  1]           [2, 3]
      /     \           /     \
   [0, 0]  [1, 1]     [2, 2]  [3, 3]

下面是C++的实现:

class SegmentTreeNode {
public:
    int start, end;
    SegmentTreeNode *left, *right;
    SegmentTreeNode(int start, int end) {
        this->start = start, this->end = end;
        this->left = this->right = NULL;
    }
}

class Solution {
public:
    /**
     *@param start, end: Denote an segment / interval
     *@return: The root of Segment Tree
     */
    SegmentTreeNode * build(int start, int end) {
        if(start > end){
            return nullptr;
        }

        if(start == end){
            return new SegmentTreeNode(start, end);
        }

        auto mid = start + ((end - start)>>1);
        auto pNode = new SegmentTreeNode(start, end);
        pNode->left = build(start, mid);
        pNode->right = build(mid+1, end);
        return pNode;
    }
};
更改某个元素

假设线段树节点里保存的是这个range内的最大值,那当修改这个数组中某个元素,如何更新这个线段树,从上往下找到这个位置,修改这个元素,然后更新查找路径上的相关节点。

/**
 * Definition of SegmentTreeNode:
 * class SegmentTreeNode {
 * public:
 *     int start, end, max;
 *     SegmentTreeNode *left, *right;
 *     SegmentTreeNode(int start, int end, int max) {
 *         this->start = start;
 *         this->end = end;
 *         this->max = max;
 *         this->left = this->right = NULL;
 *     }
 * }
 */
class Solution {
public:
    /**
     *@param root, index, value: The root of segment tree and 
     *@ change the node's value with [index, index] to the new given value
     *@return: void
     */
    void modify(SegmentTreeNode *root, int index, int value) {
        updateMax(root, index, value);
    }

    int updateMax(SegmentTreeNode *root, int index, int value){
        if(!root || index < root->start || index > root->end){
            return -1;//error
        }

        if(index == root->start && index == root->end){
            root->max = value;

            return value;
        } else {
            auto mid = root->left->end;
            if(index <= mid){
                auto leftMax = updateMax(root->left, index, value);
                root->max = max(leftMax, root->right->max);
            }else{
                auto rightMax = updateMax(root->right, index, value);
                root->max = max(root->left->max, rightMax);
            }

            return root->max;
        }        
    }
};
线段树查找

假设线段树中节点存放的是区间最大值,如何求区间内的最大值?
其实本质上就是找到这个区间上所有覆盖的区间,然后求最大值。在求覆盖区间的时候就可以做最大值的计算。

/**
 * Definition of SegmentTreeNode:
 * class SegmentTreeNode {
 * public:
 *     int start, end, max;
 *     SegmentTreeNode *left, *right;
 *     SegmentTreeNode(int start, int end, int max) {
 *         this->start = start;
 *         this->end = end;
 *         this->max = max;
 *         this->left = this->right = NULL;
 *     }
 * }
 */
class Solution {
public:
    /**
     *@param root, start, end: The root of segment tree and
     *                         an segment / interval
     *@return: The maximum number in the interval [start, end]
     */
    int query(SegmentTreeNode *root, int start, int end) {
        if(start > end || !root) {
            return -1;//error
        }

        if (root->start == start && root->end == end){
            return root->max;
        }

        auto mid = root->left->end;
        if(end <= mid){
            return query(root->left, start, end);
        } else if(start > mid){
            return query(root->right, start, end);
        } else {
            return max(query(root->left, start, mid), query(root->right, mid+1, end));
        }
    }
};

常见题型

求区间和

当求区间和,第一印象是用preSum方法,但是这种方法只适用于这个数据不发生修改的情况下,如果发生修改,preSum数组就需要更新,而更新这个数组的复杂度就是O(n),这是不可接受的。采用线段树就很好的解决这个问题,因为修改一个线段树的复杂度是lgn。

class CNode{
    public:
    int start;
    int end;
    long long sum;
    CNode* pLeft;
    CNode* pRight;
    CNode(int _start, int _end, int _sum):start(_start), end(_end), sum(_sum), pLeft(nullptr), pRight(nullptr){

    }
};
class Solution {
public:
    /* you may need to use some attributes here */


    /**
     * @param A: An integer vector
     */
    Solution(vector<int> A) {
        if(A.empty()){
            m_pRoot = nullptr;
        } else{
            m_pRoot = build(A, 0, A.size()-1);
        }
    }

    /**
     * @param start, end: Indices
     * @return: The sum from start to end
     */
    long long query(int start, int end) {
        return querySum(m_pRoot, start, end);
    }

    /**
     * @param index, value: modify A[index] to value.
     */
    void modify(int index, int value) {
        updateValue(m_pRoot, index, value);
    }

    private:
        CNode* m_pRoot;
    private:
        CNode* build(vector<int>&A, int start, int end){
            if(start == end){
                return new CNode(start, end, A[start]);
            }

            auto mid = start + ((end - start)>>1);
            auto pCur = new CNode(start, end, 0);
            pCur->pLeft = build(A, start, mid);
            pCur->pRight = build(A, mid+1, end);
            pCur->sum = pCur->pLeft->sum + pCur->pRight->sum;

            return pCur;
        }

        long long querySum(CNode* pRoot, int start, int end){
            if(!pRoot
               || start > end
               || start > pRoot->end
               || end < pRoot->start) {
                   return 0;
            }

            if(start <= pRoot->start && end >= pRoot->end) {
                return pRoot->sum;
            }

            if(pRoot->start == pRoot->end){
                return 0;
            }

            auto mid = pRoot->pLeft->end;
            if(end <= mid){
                return querySum(pRoot->pLeft, start, end);
            } else if(start > mid){
                return querySum(pRoot->pRight, start, end);
            } else {
                return querySum(pRoot->pLeft, start, mid) + querySum(pRoot->pRight, mid+1, end);
            }
        }

        int updateValue(CNode* pRoot, int index, int value){
            if(!pRoot
               || index < pRoot->start
               || index > pRoot->end){
                   return 0;
               }

            if(pRoot->start == pRoot->end && pRoot->start == index){
                auto diff = value - pRoot->sum;
                pRoot->sum += diff;
                return diff;
            }

            auto mid = pRoot->pLeft->end;
            int diff = 0;
            if(index <= mid){
                diff = updateValue(pRoot->pLeft, index, value);
            } else {
                diff = updateValue(pRoot->pRight, index, value);
            }
            pRoot->sum += diff;

            return diff;
        }

};

现在区间节点存放的是sum,其实也可以是max或者min,或者符合区间值。

求小于自身数的个数

对于线段树的题目一般是数组,求这个数组的某个range的值。但是作为range的不光可以是下标,还有可能是数值范围。

给定一个数组,其中数组元素范围是0~10000(这个信息非常重要),对于每个元素A[i],计算数组i元素之前小于A[i]的个数
[1,2,7,8,5], 返回 [0,1,2,3,2]

初步看可以采用扫描的方法,计算到A[i]是遍历0~i-1中小于A[i]的元素个数,这样复杂度为O(n^2)。
因为数组元素有范围,一般想到可以用一个数组来表示全部元素data,而小于某个元素的个数,则是求0~data[A[i]]元素个数,因为它是求这个元素之前所有元素中小于A[i]元素个数,每次增加一个元素时相当于往这个数组中插入一个元素,考虑用线段树,线段节点里保存这个range中数的个数。

class CNode {
    public:
    int low;
    int high;
    int count;
    CNode* pLeft;
    CNode* pRight;
    CNode(int _low, int _high, int _count):low(_low), high(_high), count(_count), pLeft(nullptr), pRight(nullptr){
    }
};
class Solution {
public:
   /**
     * @param A: An integer array
     * @return: Count the number of element before this element 'ai' is
     *          smaller than it and return count number array
     */
    vector<int> countOfSmallerNumberII(vector<int> &A) {
        vector<int> res;
        if(A.empty()){
            return res;
        }

        auto pRoot = buildSegmentTree(0, 10000);
        for(auto value:A){
            if(value < pRoot->low || value > pRoot->high){
                res.push_back(0);
            } else{
                res.push_back(lowerCount(pRoot, value));
                updateSegmentTree(pRoot, value);
            }
        }

        return res;
    }

    CNode* buildSegmentTree(int low, int high){
        auto pCur = new CNode(low, high, 0);
        if(low == high){
            return pCur;
        }

        auto mid = low + ((high-low)>>1);
        pCur->pLeft = buildSegmentTree(low, mid);
        pCur->pRight = buildSegmentTree(mid+1, high);

        return pCur;
    }

    int updateSegmentTree(CNode* pCur, int value) {
        if(!pCur || value < pCur->low || value > pCur->high){
            return 0;
        }

        if(pCur->low == pCur->high){
            if(pCur->low == value){
                pCur->count++;
                return 1;
            } else {
                return 0;
            }
        } else{
            auto mid = pCur->low + ((pCur->high-pCur->low)>>1);
            int diff = 0;
            if(value <= mid){
                diff = updateSegmentTree(pCur->pLeft, value);
            } else {
                diff = updateSegmentTree(pCur->pRight, value);
            }

            pCur->count += diff;
            return diff;
        }
    }

    int lowerCount(CNode* pCur, int value){
        if(!pCur
           || pCur->low > value
           || pCur->count == 0){
            return 0;
        }

        if(pCur->high < value){
            return pCur->count;
        }

        auto mid = pCur->low + ((pCur->high - pCur->low)>>1);
        if(value <= mid){
            return lowerCount(pCur->pLeft, value);
        } else{
            return lowerCount(pCur->pLeft, value) + lowerCount(pCur->pRight, value);
        }
    }

};

总结

线段树是二分查找树的一种,它用来解决区间问题,当某个问题可以转化为几个区间问题,而且元素的变化只会影响某几个区间,不是所有区间,则可以考虑用线段树。常见情况是求数组区间的值。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值