线段树 - java通用模板

线段树是一种高效的数据结构,常用于处理范围查询问题,如在数组中快速查找最小值、最大值和求和。本文以LeetCode-307为例,介绍了线段树在解决此类问题时的通用Java模板及其作用。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

LeetCode-307

线段树是一种非常灵活的数据结构,它可以用于解决多种范围查询问题,比如在对数时间内从数组中找到最小值、最大值、总和、最大公约数、最小公倍数等。

数组 A[0,1,…,n−1] 的线段树是一个二叉树,其中每个节点都包含数组的一个子范围 [i…j] 上的聚合信息(最小值、最大值、总和等),其左、右子节点分别包含范围 [i … (i+j)/2]和 [(i+j)/2 + 1, j]上的信息。

class NumArray {

    public interface Merger<E> {
        E merge(E a, E b);
    }

    public class SegmentTree<E> {
    /*
    使用一个数组表示区间.
    首先,用户可能要获取区间内的某一个元素,或者获取区间的某一个属性,我们在线段树中创建数组,作为区间数组的的副本,用于给出区间数组的某些属性;
    其次,我们想将data数组内的元素(arr数组传递进来的)组织成为一个线段树,那么还需要创建一个tree,大小为 4 * size
    */

        private E[] data;
        //tree结点虽然表示的是一个区间,但是它里面存储的实际上是区间内数据按一定方式计算得到的值。
        private E[] tree;
        private Merger<E> merger;

        //初始化的时候传入一个区间数组
        public SegmentTree(E[] arr , Merger<E> merger) {
            this.merger = merger;

            data = (E[])new Object[arr.length];     // 使用arr初始化data
            for (int i = 0; i < arr.length ; i++) {
                data[i] = arr[i];
            }

            tree = (E[])new Object[arr.length * 4];

            //调用 buildSegmentTree() 方法,完成线段树数组tree每一个结点 data数组区间的确定!
            buildSegmentTree(0 , 0 , data.length-1);    //线段树从0位置开始赋予区间,此时区间为data最大区间:0-data.length-1
        }

        public int getSize() {
            return data.length;
        }

        public E get(int index) {
            if(index < 0 || index >= data.length)
                throw new IllegalArgumentException("Index is illegal.");
            return data[index];
        }

        // 返回满二叉树的数组表示中,一个索引所表示的元素的左孩子节点的索引
        private int leftChild(int index) {
            return 2 * index + 1;
        }

        // 返回满二叉树的数组表示中,一个索引所表示的元素的右孩子节点的索引
        private int rightChild(int index) {
            return 2 * index + 2;
        }

//  --------------------------构建
        /**
         在treeIndex的位置创建表示区间[l...r]的线段树结点,即确定 线段树数组tree 在下标 treeIndex 位置所表示的 在data数组 中的区间。

         三个元素:1-要创建的线段树结点在tree数组中的索引;2,3-当前线段树结点表示的 data数组区间的左右端点的索引,注意这个索引是在data数组中的索引
         即在tree数组中,下标为 treeIndex,它在 data数组中表示的区间范围是: l-r

         这个方法设置为私有的,并且我们会从线段树数组 tree的0位置开始,不断确定线段树数组的每一个位置所表示的 data数组的区间,
         直到data数组区间的长度为1,此时遍历到线段树的叶子结点。这个过程中线段树数组可能有很多结点没有表示的区间!

         我们在线段树类的构造方法中就调用 buildSegmentTree() 方法,完成线段树数组每一个结点 data数组区间的确定!
         */
        private void buildSegmentTree(int treeIndex, int l, int r) {
            if(l == r) {
                //此时遍历到线段树的叶子结点,因为只有叶子结点表示的 data数组的区间长度为1,那么直接给叶子结点赋值为 data[l]
                tree[treeIndex] = data[l];
                //注意,到达叶子结点就return结束函数,否则下面继续计算,会导致出现大于tree数组下标的子结点,出现数组下标越界
                return;
            }

            //首先求当前结点的左右孩子结点在tree数组中的下标
            int leftNodeIndex = leftChild(treeIndex);
            int rightNodeIndex = rightChild(treeIndex);
            //求区间中间的分割值
            int mid = l + (r - l) / 2;

            //递归调用buildSegmentTree设置左右孩子结点在 data数组的区间
            buildSegmentTree(leftNodeIndex, l, mid);
            buildSegmentTree(rightNodeIndex, mid+1, r);     //这种设置方法,在区间长度为奇数的时候,左孩子结点区间比右孩子结点区间长度大1

            /**
             下面考虑 线段树tree treeIndex位置的值,这个值与我们的业务相关。
             我们综合当前结点左右孩子结点在data数组中的线段信息,就可以得到当前结点的线段信息!因为当前结点线段的区间,等于它左右孩子线段区间的和!

             我们创建一个融合接口,然后用户可以创建实现这个融合接口的类,自定义它融合左右孩子区间的含义,将这个自定义的对象传递进来,
             那么此时我们tree数组中 treeIndex 位置的值,就可以根据融合的规则,融合左右孩子结点的值得到!
             public interface Merger<E>
             {
             E merge(E a , E b);
             }
             */
            tree[treeIndex] = merger.merge(tree[leftNodeIndex], tree[rightNodeIndex]);
        }

//---------------------------------------查询方法

        // 返回区间[queryL, queryR]的值,这个区间的值是由我们传入的Merger实现类定义的(可能是区间元素和,区间最大值等!)
        public E query(int queryL , int queryR) {
            //将参数不合理的情况排除
            if(queryL < 0 || queryL >= data.length ||
                    queryR < 0 || queryR >= data.length || queryL > queryR)
                throw new IllegalArgumentException("Index is illegal.");

            return query(0 , 0 , data.length-1 , queryL , queryR);
        }

        /**
         * 在以treeIndex为根的线段树中[l...r]的范围里,搜索区间[queryL...queryR]的值,
         * 即对于线段树数组下标为 treeIndex 的结点,它在data数组的区间是 [l...r],查询在 [l...r] 区间内 区间[queryL...queryR]的值
         *
         * 对于每一个 tree 数组结点,我们都传入它在 data数组中所表示的区间,我们其实可以将他们封装成为一个线段树的结点类,
         * 结点类包含当前线段树结点 在tree数组的下标:treeIndex,以及在data数组 区间的范围: [l,r]!
         * 理解:int treeIndex, int l, int r:这三个参数都在表示线段树结点的信息!!!
         */
        //我们向下将最开始的查询区间拆分,直到某一个小区间满足当前 tree结点的区间,我们直接将tree结点的值返回即可!
        private E query(int treeIndex, int l, int r, int queryL, int queryR) {
            //小区间满足当前 tree结点的区间,我们直接将tree结点的值返回即可!
            //当然,tree结点的值由其左右孩子结点区间值融合而来,这里我们在 buildSegmentTree() 方法中已经定义好!
            if(l == queryL && r == queryR)
                return tree[treeIndex];

            //如果当前tree结点的区间大于 查询的小区间 (小于的情况上面方法中被我们排除了!),查询tree结点左右孩子结点的区间

            //首先求左右孩子结点在tree数组中的下标
            int leftNodeIndex = leftChild(treeIndex);
            int rightNodeIndex = rightChild(treeIndex);

            //求当前tree结点分割为左右孩子区间的时候,区间中间的分割值:左区间:[l,mid],右区间:[mid+1 , r]
            int mid = l + (r - l) / 2;

            if(queryL>=mid+1) //当查询的区间全部在右孩子结点的区间
                return query(rightNodeIndex , mid+1 , r , queryL , queryR);
            else if(queryR<=mid) //当查询的区间全部在左孩子结点的区间
                return query(leftNodeIndex , l , mid , queryL , queryR);

            //否则,我们得将查询区间分为2部分,分别在左右孩子区间查找
            E leftResult = query(leftNodeIndex , l , mid , queryL , mid);
            E rightResult = query(rightNodeIndex , mid+1 , r , mid+1 , queryR);

            //按用户定义的融合方法融合2个值并返回!
            return merger.merge(leftResult,rightResult);
        }

//  --------------------------------更新
        /**
         *  将index位置的值,更新为e。
         *  tree 数组的所有叶子结点,存储的值也是data数组的值,我们需要找到 tree数组 表示的线段树相应的叶子结点,
         *  将这个叶子结点的值进行更新。由于这个叶子结点值的更新,它的父亲结点的区间包含这个叶子结点的值,它的父亲结点的值由它的区间值计算得来,
         *  它的父亲结点的区间也由左右孩子区间融合得来,那么也可以融合左右孩子的值得到父亲结点的值
         */
        public void set(int index, E e) {
            if (index < 0 || index >= data.length)
                throw new IllegalArgumentException("Index is illegal");
            //更新data数组
            data[index] = e;
            //调用新的set方法,更新tree数组
            set(0, 0, data.length-1, index, e);
        }

        // 在以treeIndex为根的线段树中更新index的值为e
        private void set(int treeIndex, int l, int r, int index, E e) {

            if(l == r) {//当l=r,说明找到线段树的叶子结点,该结点的区间就是[index,index],该结点值就是data[index],更新这个结点的值
                tree[treeIndex] = e;
                return;//注意,这里找到叶子结点后,必须终止函数,否则继续向下计算会导致数组角标越界
            }

            //没找到tree数组中对应data数组的区间是 [index,index] 的结点,则继续向下缩小区间寻找
            int leftNodeIndex = leftChild(treeIndex);
            int rightNodeIndex = rightChild(treeIndex);
            int mid = l + (r - l) / 2;

            if(index<=mid)
                set(leftNodeIndex, l, mid, index, e);//向左子树寻找
            else
                set(rightNodeIndex, mid+1, r, index, e);//向右子树寻找

            //由于线段树叶子结点的值得更新,必然会影响父亲结点的区间内的值,从而影响夫妻群结点的值
            //那么我们在更新完当前结点的左孩子区间或者是右孩子区间后(必然有一个区间受影响),将左右孩子的值重新融合,更新当前tree结点的值
            tree[treeIndex] = merger.merge(tree[leftNodeIndex] , tree[rightNodeIndex]);//从线段树的下面往上思考这个更新过程
        }


        //----------------------toString方法
        @Override
        public String toString(){
            StringBuilder res = new StringBuilder();
            res.append('[');
            for(int i = 0 ; i < tree.length ; i ++){
                if(tree[i] != null)
                    res.append(tree[i]);
                else
                    res.append("null");

                if(i != tree.length - 1)
                    res.append(", ");
            }
            res.append(']');
            return res.toString();
        }
    }

    private SegmentTree<Integer> segmentTree;

    /** O(n):只初始化1次 */
    public NumArray(int[] nums)
    {
        //首先判断传入的数组合法
        if(nums != null && nums.length>0)
        {
            Integer[] data = new Integer[nums.length];
            for (int i = 0; i < nums.length ; i++)
            {
                data[i] = nums[i];
            }
            /**  使用data数组构造线段树,线段树每个结点都代表data数组的一个区间(其实线段树结点存储的是区间按一定规则融合的值)。同时,定义线段树区间融合的方式  */
            segmentTree = new SegmentTree<Integer>(data , (o1 , o2) -> o1+o2 );//同样,定义线段树数组tree结点的值是 data区间结点值和
        }
    }

    public void update(int i, int val)
    {
        //同样判断线段树是否初始化成功
        if(segmentTree == null)
            throw new IllegalArgumentException("Segment Tree is null");
        segmentTree.set(i , val);
    }

    /** O(logn) */
    public int sumRange(int i, int j)
    {
        //有可能nums数组不合法,初始化的时候没有初始化 SegmentTree,判断一下 SegmentTree是否合法
        if(segmentTree == null)
            throw new IllegalArgumentException("Segment Tree is null");

        return segmentTree.query(i , j);//返回查询结果即可!结果就是区间和。
    }
}
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值