线段树有很多种,但是最好理解、最好写的还是下面这种开4N空间的自顶向下线段树。(只支持单点更新,维护区间和)
/**
 * Top-down (recursive) segment tree over an {@code int} array that supports
 * point assignment and range-sum queries.
 *
 * <p>Nodes are stored in a heap-style array: the root lives at index 1 and the
 * children of node {@code i} live at {@code 2 * i} and {@code 2 * i + 1}.
 * Index 0 is unused. Sums are kept in {@code int}, so they wrap on overflow
 * like any other {@code int} arithmetic.
 */
public class SegmentTreeTopDown {
    /** Heap-style node storage; {@code tree[1]} is the root, {@code tree[0]} is unused. */
    private final int[] tree;

    /** Number of elements in the underlying array. */
    private final int n;

    /**
     * Builds a segment tree over a snapshot of {@code nums}.
     *
     * @param nums source values; may be empty, in which case every query returns 0
     * @throws NullPointerException if {@code nums} is null
     */
    public SegmentTreeTopDown(int[] nums) {
        this.n = nums.length;
        // 4 * n slots are always enough for the 1-indexed heap layout.
        this.tree = new int[n * 4];
        if (n > 0) {
            // The root is stored at index 1 and covers [0, n - 1].
            buildTree(nums, 1, 0, n - 1);
        }
    }

    /**
     * Recursively fills the subtree rooted at {@code root}.
     *
     * @param nums  source values
     * @param root  index of the current node in {@code tree}
     * @param left  first index in {@code nums} covered by this node (inclusive)
     * @param right last index in {@code nums} covered by this node (inclusive)
     */
    private void buildTree(int[] nums, int root, int left, int right) {
        if (left == right) {
            this.tree[root] = nums[left];
            return;
        }
        int mid = left + (right - left) / 2;
        int lc = root * 2;
        int rc = root * 2 + 1;
        buildTree(nums, lc, left, mid);
        buildTree(nums, rc, mid + 1, right);
        this.tree[root] = this.tree[lc] + this.tree[rc];
    }

    /**
     * Returns the sum of the elements whose index lies in {@code [l, r]}.
     *
     * <p>The range is intersected with {@code [0, n - 1]}; indices outside the
     * array contribute nothing, and an empty intersection (including
     * {@code l > r}) yields 0.
     *
     * @param l first index of the query range (inclusive)
     * @param r last index of the query range (inclusive)
     * @return sum of the covered elements
     */
    public int query(int l, int r) {
        if (n == 0) {
            // No nodes were built, so there is nothing to sum.
            return 0;
        }
        return query(1, 0, n - 1, l, r);
    }

    /**
     * Returns the sum of the part of the query range that lies inside the
     * range stored at {@code root}.
     *
     * @param root   index of the current node in {@code tree}
     * @param nLeft  first index covered by the node (inclusive)
     * @param nRight last index covered by the node (inclusive)
     * @param qLeft  first index of the query range (inclusive)
     * @param qRight last index of the query range (inclusive)
     */
    private int query(int root, int nLeft, int nRight, int qLeft, int qRight) {
        if (qLeft > nRight || nLeft > qRight) {
            // No intersection between the node range and the query range.
            return 0;
        }
        if (qLeft <= nLeft && qRight >= nRight) {
            // The node range lies entirely inside the query range.
            return tree[root];
        }
        int mid = nLeft + (nRight - nLeft) / 2;
        return query(root * 2, nLeft, mid, qLeft, qRight)
                + query(root * 2 + 1, mid + 1, nRight, qLeft, qRight);
    }

    /**
     * Sets the element at {@code pos} to {@code val} and refreshes every sum
     * on the path from that leaf up to the root.
     *
     * @param pos index of the element to overwrite
     * @param val new value
     * @throws IndexOutOfBoundsException if {@code pos} is not in {@code [0, n - 1]}
     */
    public void update(int pos, int val) {
        if (pos < 0 || pos >= n) {
            throw new IndexOutOfBoundsException(
                    "pos " + pos + " is out of range for length " + n);
        }
        // Walk down from the root, tracking the tree index of the current node
        // alongside the array range it covers. The split mirrors buildTree, so
        // the walk ends on the leaf that stores nums[pos].
        int node = 1;
        int left = 0;
        int right = n - 1;
        while (left < right) {
            int mid = left + (right - left) / 2;
            if (pos <= mid) {
                node = node * 2;
                right = mid;
            } else {
                node = node * 2 + 1;
                left = mid + 1;
            }
        }
        this.tree[node] = val;
        // Recompute the ancestors bottom-up; the parent of node i is i / 2.
        for (node /= 2; node > 0; node /= 2) {
            this.tree[node] = this.tree[node * 2] + this.tree[node * 2 + 1];
        }
    }

    /** Small demo: prints the total sum, then a single-element query after an update. */
    public static void main(String[] args) {
        SegmentTreeTopDown tree = new SegmentTreeTopDown(new int[] {1, 3, 5});
        System.out.println(tree.query(0, 2));
        tree.update(1, 2);
        System.out.println(tree.query(2, 2));
    }
}