线段树(Segment Tree)详解:从原理到Java实现
一、什么是线段树?
线段树是一种二叉树数据结构,用于高效处理区间查询(如区间求和、区间最大值等)和区间更新操作。它能够在O(log n)时间复杂度内完成这些操作,比朴素的O(n)方法高效得多。
典型应用场景:
- 游戏开发中的区域属性批量更新
- 数据分析中的区间统计
- 图形处理中的区域像素计算
二、线段树的核心原理
1. 基本结构
线段树将原始数组递归地划分为半区间的二叉树:
- 每个叶子节点存储原始数组的一个元素
- 每个非叶子节点存储其子节点信息的聚合(如子节点值的和)
对于长度为n的数组,线段树需要约4n的空间(完全二叉树最坏情况)。
2. 惰性传播(Lazy Propagation)
这是线段树的关键优化技术:
- 当进行区间更新时,并不立即更新所有子节点
- 而是将更新操作"暂存"在惰性标记(lazy tag)中
- 只有当真正需要访问子节点时,才将更新"下推"
这种延迟更新的策略大幅减少了不必要的操作。
三、Java实现详解
以下是完整的线段树实现,支持区间加值和区间求和:
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.PrintWriter;
public class Main{
private static BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
private static PrintWriter pw = new PrintWriter((System.out));
public static void main(String[] args) throws IOException {
int[] arr = {1,2,3,2,4,5,6,7,8,0,9,8,5};
SegmentTree tree = new SegmentTree(arr);
tree.show();
while(true) {
String[] l = br.readLine().split(" ");
if("a".equals(l[0])) {//a 99 0 将索引为0的元素加上99
tree.add(Integer.parseInt(l[1]), Integer.parseInt(l[2]));
}else if("A".equals(l[0])) {//A 99 2 8 将索引为[2,8]的元素加上99
tree.add(Integer.parseInt(l[1]), Integer.parseInt(l[2]), Integer.parseInt(l[3]));
}else if("q".equals(l[0])) {//q 1 6 查询索引[1,6]的元素和
int res = tree.query(Integer.parseInt(l[1]), Integer.parseInt(l[2]));
tree.show();
pw.println("答案:" + res);
pw.flush();
}else {
break;
}
}
pw.flush();
pw.close();
br.close();
}
static class SegmentTree{
int[] tree;
int[] lazy;
int n;
public SegmentTree(int[] arr) {
this.n = arr.length;
tree = new int[n<<2];
lazy = new int[n<<2];
build(arr, 0, n-1, 0);//构建线段树
}
public void build(int[] arr, int left, int right, int index) {
if(left == right) {
tree[index] = arr[left];
return;
}
int mid = left + ((right - left) >> 1);
int l_kid = (index << 1) + 1;
int r_kid = l_kid + 1;
build(arr, left, mid, l_kid);
build(arr, mid+1, right, r_kid);
// tree[index] = Math.max(tree[l_kid] , tree[r_kid]);
tree[index] = tree[l_kid] + tree[r_kid];
}
public void add(int derta, int aim) {
add(derta, aim, aim);
}
public void add(int derta, int l, int r) {
add(derta, l, r, 0, n-1, 0);
}
public void add(int derta, int l, int r, int l_node, int r_node, int index) {
/*递归终止条件:目标区间与结点区间匹配或异常情况*/
if(l > r || l_node > r_node) return;
if(l == l_node && r == r_node) {
tree[index] += derta * (r_node - l_node + 1); //更新求和的结果,要考虑当前区间下有多少个元素
lazy[index] += derta; //设置当前结点的懒标记
return; //终止递归一定要记得return!
}
/*下发与分割:下发懒标记到子结点,并更新子节点的数据*/
int mid = l_node + ((r_node - l_node) >> 1);
int l_kid = (index << 1) + 1;
int r_kid = l_kid + 1;
pushDown(l_node, r_node, index, l_kid, mid, r_kid);
int[] AB; //记录分割后新的目标区间[l,r] = [AB[0], AB[1]]
/*继续递归:继续操作含目标区间的子结点*/
if((AB = getAB(l, r, l_node, mid)) != null) add(derta, AB[0], AB[1], l_node, mid, l_kid);
if((AB = getAB(l, r, mid+1, r_node)) != null) add(derta, AB[0], AB[1], mid+1, r_node, r_kid);
/*回溯:子树完成修改后,更新当前结点的值*/
tree[index] = tree[l_kid] + tree[r_kid];
}
private int[] getAB(int l, int r, int l_node, int r_node) {
int[] AB = new int[2];// 3 3 0 6
AB[0] = Math.max(l, l_node);
AB[1] = Math.min(r, r_node);
return AB[0] > AB[1] ? null : AB;
}
public int query(int l, int r) {
return query(l, r, 0, n-1, 0);
}
public int query(int l, int r, int l_node, int r_node, int index) {
/*递归终止条件:目标区间与结点区间匹配或异常情况*/
if(l > r || l_node > r_node) return 0 /*Integer.MIN_VALUE*/;
if(l == l_node && r == r_node) return tree[index];
/*下发与分割:下发懒标记到子结点,并更新子节点的数据*/
int mid = l_node + ((r_node - l_node) >> 1);
int l_kid = (index << 1) + 1;
int r_kid = l_kid + 1;
pushDown(l_kid, r_kid, index, l_kid, mid, r_kid);
int[] AB;
/*继续递归:继续操作含目标区间的子结点*/
int a = ((AB = getAB(l, r, l_node, mid)) != null) ? query(AB[0], AB[1], l_node, mid, l_kid) : 0/*Integer.MIN_VALUE*/;
int b = ((AB = getAB(l, r, mid+1, r_node)) != null) ? query(AB[0], AB[1], mid+1, r_node, r_kid) : 0/*Integer.MIN_VALUE*/;
/*回溯:子树完成修改后,更新当前结点的值*/
tree[index] = tree[l_kid] + tree[r_kid];
return a + b;
}
public void show() {
show(0, n-1, 0);
pw.println();
pw.flush();
}
public void show(int l_node, int r_node, int index) {
if(l_node == r_node) {
pw.print(tree[index] + " ");
return;
}
int mid = l_node + ((r_node - l_node) >> 1);
int l_kid = (index << 1) + 1;
int r_kid = l_kid + 1;
pushDown(l_node, r_node, index, l_kid, mid, r_kid);
show(l_node, mid, l_kid);
show(mid+1, r_node, r_kid);
}
/**
* 将懒标记下发到子节点,用于区间更新
* 此方法主要目的是解决如何高效地将对某一区间的所有元素进行相同的操作(如加、减某个值)
* 通过懒标记,可以延迟实际的更新操作,直到需要访问具体节点时才进行更新,从而提高效率
*
* @param l_node 当前节点覆盖的左边界
* @param r_node 当前节点覆盖的右边界
* @param index 当前节点在数组中的索引
* @param l_kid 左子节点在数组中的索引
* @param mid 当前节点覆盖区间的中点
* @param r_kid 右子节点在数组中的索引
*/
private void pushDown(int l_node, int r_node, int index, int l_kid, int mid, int r_kid) {
if(lazy[index] != 0) { //更新子节点的数据:值和懒标记都要更新
tree[l_kid] += lazy[index] * (mid - l_node + 1);//更新求和值
tree[r_kid] += lazy[index] * (r_node - mid);
lazy[l_kid] += lazy[index];
lazy[r_kid] += lazy[index];
lazy[index] = 0;
}
}
}
}
四、关键点解析
- 空间分配:线段树需要4n的空间确保足够存储
- 惰性标记下推:在访问子节点前必须处理惰性标记
- 区间分割:通过计算交集确定需要处理的子区间
- 回溯更新:子节点修改后要更新父节点的值
五、常见问题
-
为什么需要4n空间?
- 考虑最坏情况下线段树是完全二叉树
- 对于n个叶子节点,需要2n-1个节点
- 但线段树不一定是完全二叉树,所以需要额外空间
-
如何处理其他区间操作?
- 区间最大值:修改聚合方式和惰性传播逻辑
- 区间乘法:额外处理乘法和加法的混合运算就可以了
-
为什么查询时也要下推标记?
- 确保查询到的数据是最新的
- 未下推的标记意味着该区间的子节点尚未更新
六、性能分析
操作 | 时间复杂度 | 说明 |
---|---|---|
构建 | O(n) | 需要初始化所有节点 |
单点更新 | O(log n) | 相当于长度为1的区间更新 |
区间更新 | O(log n) | 惰性传播是关键优化 |
区间查询 | O(log n) | 与区间更新类似 |