线段树实现-基于数组
构建数组
为什么存储线段树的数组容量要取原数组长度 N 的 4 倍(即 4 * N)?
计算一棵满二叉树的节点总数:设 $C$ 为满二叉树的层数、$h$ 为其高度,则 $C = h + 1$(二叉树的层数 = 高度 + 1),节点总数为 $2^{C} - 1 = 2^{h+1} - 1$。

计算二叉树的高度:$h \approx \log_2 n$,可以看出 层数 = 高度 + 1。
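一棵平衡二叉树的高度 $h$ 介于 $[\lfloor \log_2 n \rfloor, \lfloor \log_2 n \rfloor + 1]$ 之间。下图的高度为 3:当 $n = 7$ 时,$h \in [\lfloor \log_2 7 \rfloor, \lfloor \log_2 7 \rfloor + 1] = [2, 3]$。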

将最大可能高度 $h = \log_2 n + 1$ 代入 $2^{h+1} - 1$:
$2^{\log_2 n + 2} - 1 = 4n - 1$,近似于 $4n$,因此总空间数取 $4n$。
初始化数组
/** Tree nodes in heap order: the children of node i are stored at 2i+1 and 2i+2. */
private Integer[] elements;
/** Number of elements in the original data; the root covers indices [0, N - 1]. */
private final int N;

/**
 * Allocates the node array and builds the segment tree over {@code data}.
 *
 * @param data the source values
 */
public SegmentArrTree(Integer[] data) {
    // N must be the data length, not the node capacity: query(...) uses [0, N - 1]
    // as the root range, and that has to match the range buildTree(...) is called with.
    this.N = data.length;
    // 4 * N node slots are enough for a segment tree with N leaves.
    elements = new Integer[4 * this.N];
    buildTree(data, 0, 0, data.length - 1);
}
构建线段树
/** Tree nodes in heap order: the children of node i are stored at 2i+1 and 2i+2. */
private Integer[] elements;
/** Number of elements in the original data; the root covers indices [0, N - 1]. */
private final int N;

/**
 * Allocates the node array and builds the segment tree over {@code data}.
 *
 * @param data the source values
 */
public SegmentArrTree(Integer[] data) {
    // N must be the data length, not the node capacity: query(...) uses [0, N - 1]
    // as the root range, and that has to match the range buildTree(...) is called with.
    this.N = data.length;
    // 4 * N node slots are enough for a segment tree with N leaves.
    elements = new Integer[4 * this.N];
    buildTree(data, 0, 0, data.length - 1);
}
/**
 * Recursively populates {@code elements} so that node {@code index} holds the sum of
 * {@code data[start..end]} (both ends inclusive).
 *
 * @param data  the source values
 * @param index position of the current node in {@code elements}
 * @param start first data index covered by this node
 * @param end   last data index covered by this node; if below {@code start}, nothing is stored
 */
public void buildTree(Integer[] data, int index, int start, int end) {
    if (start > end) {
        return; // empty range: no node to fill
    }
    if (start != end) {
        final int split = start + (end - start) / 2;
        final int leftChild = 2 * index + 1;
        final int rightChild = leftChild + 1;
        buildTree(data, leftChild, start, split);
        buildTree(data, rightChild, split + 1, end);
        // An internal node stores the sum of its two children.
        elements[index] = elements[leftChild] + elements[rightChild];
    } else {
        // A leaf stores the single value it covers.
        elements[index] = data[start];
    }
}
线段树查询
/**
 * Sums the data over [queryStart, queryEnd] using the subtree rooted at {@code treeIndex},
 * which covers [start, end]. An empty range (queryStart > queryEnd) contributes 0.
 * Assumes a non-empty query range lies inside [start, end]; the caller must guarantee this.
 */
private int query(int queryStart, int queryEnd, int treeIndex, int start, int end) {
    if (queryStart > queryEnd) {
        return 0;
    }
    if (queryStart == start && queryEnd == end) {
        return elements[treeIndex]; // this node's range matches the query exactly
    }
    final int split = start + (end - start) / 2;
    final int leftChild = 2 * treeIndex + 1;
    final int rightChild = leftChild + 1;
    final boolean needsLeft = queryStart <= split;
    final boolean needsRight = queryEnd > split;
    if (needsLeft && needsRight) {
        // The query straddles the midpoint: answer each half separately.
        return query(queryStart, split, leftChild, start, split)
                + query(split + 1, queryEnd, rightChild, split + 1, end);
    }
    return needsLeft
            ? query(queryStart, queryEnd, leftChild, start, split)
            : query(queryStart, queryEnd, rightChild, split + 1, end);
}
/**
 * Returns the sum of the data over the inclusive index range [queryStart, queryEnd].
 * An empty range (queryStart > queryEnd) yields 0.
 *
 * @throws IndexOutOfBoundsException if a non-empty range is not within [0, N - 1]
 */
public long query(int queryStart, int queryEnd) {
    if (queryStart > queryEnd) {
        return 0;
    }
    // Without this check, out-of-range bounds make the recursion read unset nodes
    // (NullPointerException on unboxing) or never terminate (negative start).
    if (queryStart < 0 || queryEnd >= N) {
        throw new IndexOutOfBoundsException(
                "range [" + queryStart + ", " + queryEnd + "] outside [0, " + (N - 1) + "]");
    }
    return query(queryStart, queryEnd, 0, 0, N - 1);
}
/** Demo: prints the sum of the first four values of {1, 3, 5, 7, 9, 11}. */
public static void main(String[] args) {
    SegmentArrTree tree = new SegmentArrTree(new Integer[] {1, 3, 5, 7, 9, 11});
    long firstFourSum = tree.query(0, 3);
    System.out.println(firstFourSum);
}
更新数据
/**
 * Sets {@code data[idx]} to {@code val} and refreshes every tree node whose range contains idx.
 * Start the recursion with {@code node = 0, left = 0, right = data.length - 1}.
 *
 * @param node  position of the current node in {@code elements}
 * @param left  first data index covered by {@code node}
 * @param right last data index covered by {@code node}
 * @param idx   data index to change
 * @param val   new value
 * @param data  the source array, updated in place alongside the tree
 * @throws IndexOutOfBoundsException if idx is outside [left, right]; nothing is modified
 */
public void update(int node, int left, int right, int idx, int val, Integer[] data) {
    // Validate before mutating. Previously an out-of-range idx overwrote an unrelated
    // leaf before data[idx] threw, leaving the tree inconsistent.
    if (idx < left || idx > right) {
        throw new IndexOutOfBoundsException(
                "index " + idx + " outside [" + left + ", " + right + "]");
    }
    if (left == right) { // left == right: this is the leaf for idx
        // Update the source array first so a too-short array fails before the tree changes.
        data[idx] = val;
        elements[node] = val;
        return;
    }
    int mid = left + (right - left) / 2;
    int leftIndex = 2 * node + 1;
    int rightIndex = 2 * node + 2;
    if (idx <= mid) {
        update(leftIndex, left, mid, idx, val, data);
    } else {
        update(rightIndex, mid + 1, right, idx, val, data);
    }
    // Recompute this node from its (possibly updated) children.
    elements[node] = elements[leftIndex] + elements[rightIndex];
}
完整代码
package com.training.segment;
import java.lang.reflect.Array;
import java.util.Arrays;
/**
 * Array-backed sum segment tree over a fixed-length array of Integers.
 *
 * <p>Node sums are computed with {@code int} arithmetic, so they wrap on overflow even though
 * the public {@link #query(int, int)} returns {@code long}.
 */
public class SegmentArrTree {
    /** Tree nodes in heap order: the children of node i are stored at 2i+1 and 2i+2. */
    private final Integer[] elements;
    /** Number of elements in the original data; the root covers indices [0, N - 1]. */
    private final int N;

    /**
     * Builds a sum segment tree over {@code data}. The array is read but not retained.
     *
     * @param data the source values
     */
    public SegmentArrTree(Integer[] data) {
        this.N = data.length;
        // 4 * N node slots are enough for a segment tree with N leaves.
        elements = new Integer[4 * this.N];
        buildTree(data, 0, 0, data.length - 1);
    }

    /**
     * Recursively populates {@code elements} so that node {@code index} holds the sum of
     * {@code data[start..end]} (both ends inclusive).
     *
     * @param data  the source values
     * @param index position of the current node in {@code elements}
     * @param start first data index covered by this node
     * @param end   last data index covered by this node; if below {@code start}, nothing is stored
     */
    public void buildTree(Integer[] data, int index, int start, int end) {
        if (end < start) return;
        if (start == end) {
            elements[index] = data[start]; // leaf: the single value it covers
            return;
        }
        int mid = start + (end - start) / 2;
        int leftIndex = 2 * index + 1;
        int rightIndex = 2 * index + 2;
        buildTree(data, leftIndex, start, mid);
        buildTree(data, rightIndex, mid + 1, end);
        // Internal node: sum of its two children.
        elements[index] = elements[leftIndex] + elements[rightIndex];
    }

    /** Prints every slot of the node array, tab-separated; unused slots print as "null". */
    public void print() {
        for (Integer e : elements) {
            System.out.print(e + "\t");
        }
    }

    /**
     * Sums the data over [queryStart, queryEnd] using the subtree rooted at {@code treeIndex},
     * which covers [start, end]. An empty range contributes 0. A non-empty query range must lie
     * inside [start, end]; the public overload checks this before calling.
     */
    private int query(int queryStart, int queryEnd, int treeIndex, int start, int end) {
        if (queryStart > queryEnd) return 0;
        if (queryStart == start && queryEnd == end) return elements[treeIndex];
        int mid = start + (end - start) / 2;
        int leftIndex = 2 * treeIndex + 1;
        int rightIndex = 2 * treeIndex + 2;
        if (queryEnd <= mid) {
            return query(queryStart, queryEnd, leftIndex, start, mid);
        } else if (queryStart > mid) {
            return query(queryStart, queryEnd, rightIndex, mid + 1, end);
        } else { // the query spans both halves
            return query(queryStart, mid, leftIndex, start, mid)
                    + query(mid + 1, queryEnd, rightIndex, mid + 1, end);
        }
    }

    /**
     * Returns the sum of the data over the inclusive index range [queryStart, queryEnd].
     * An empty range (queryStart > queryEnd) yields 0.
     *
     * @throws IndexOutOfBoundsException if a non-empty range is not within [0, N - 1]
     */
    public long query(int queryStart, int queryEnd) {
        if (queryStart > queryEnd) {
            return 0;
        }
        // Without this check, out-of-range bounds make the recursion read unset nodes
        // (NullPointerException on unboxing) or never terminate (negative start).
        if (queryStart < 0 || queryEnd >= N) {
            throw new IndexOutOfBoundsException(
                    "range [" + queryStart + ", " + queryEnd + "] outside [0, " + (N - 1) + "]");
        }
        return query(queryStart, queryEnd, 0, 0, N - 1);
    }

    /**
     * Sets {@code data[idx]} to {@code val} and refreshes every tree node whose range contains
     * idx. Start the recursion with {@code node = 0, left = 0, right = data.length - 1}.
     *
     * @param node  position of the current node in {@code elements}
     * @param left  first data index covered by {@code node}
     * @param right last data index covered by {@code node}
     * @param idx   data index to change
     * @param val   new value
     * @param data  the source array, updated in place alongside the tree
     * @throws IndexOutOfBoundsException if idx is outside [left, right]; nothing is modified
     */
    public void update(int node, int left, int right, int idx, int val, Integer[] data) {
        // Validate before mutating. Previously an out-of-range idx overwrote an unrelated
        // leaf before data[idx] threw, leaving the tree inconsistent.
        if (idx < left || idx > right) {
            throw new IndexOutOfBoundsException(
                    "index " + idx + " outside [" + left + ", " + right + "]");
        }
        if (left == right) { // left == right: this is the leaf for idx
            // Update the source array first so a too-short array fails before the tree changes.
            data[idx] = val;
            elements[node] = val;
            return;
        }
        int mid = left + (right - left) / 2;
        int leftIndex = 2 * node + 1;
        int rightIndex = 2 * node + 2;
        if (idx <= mid) {
            update(leftIndex, left, mid, idx, val, data);
        } else {
            update(rightIndex, mid + 1, right, idx, val, data);
        }
        // Recompute this node from its (possibly updated) children.
        elements[node] = elements[leftIndex] + elements[rightIndex];
    }

    /** Demo: build, print, query, update index 4 to 6, then print and query again. */
    public static void main(String[] args) {
        Integer[] a = {1, 3, 5, 7, 9, 11};
        System.out.println(Arrays.toString(a));
        SegmentArrTree tree = new SegmentArrTree(a);
        tree.print();
        System.out.println();
        System.out.println(tree.query(0, 3));
        tree.update(0, 0, a.length - 1, 4, 6, a);
        tree.print();
        System.out.println();
        System.out.println(tree.query(0, 3));
    }
}