树状数组,是一种快速解决单点更新区间查询(或区间更新单点查询)的数据结构.
设a1,a2,a3,...,ana1,a2,a3,...,an为长度为n(≤105)≤105)的数列. 现在有以下两种类型的操作m(≤104≤104)个:
- ADD A B (修改数列中A位置上的元素值增加B)
- QUERY A B (询问区间[A, B]的元素值和)
直接求解: O(1)的更新,m次O(n)的查询复杂度. O(mn) 不能接受.
引入辅助数组C. 其中:
C1=a1C1=a1
C2=a1+a2C2=a1+a2
C3=a3C3=a3
C4=a1+a2+a3+a4C4=a1+a2+a3+a4
C5=a5C5=a5
C6=a5+a6C6=a5+a6
C7=a7C7=a7
C8=a1+a2+a3+a4+a5+a6+a7+a8C8=a1+a2+a3+a4+a5+a6+a7+a8
......
C16=a1+a2+a3+a4+a5+a6+a7+a8+a9+a10+a11+a12+a13+a14+a15+a16C16=a1+a2+a3+a4+a5+a6+a7+a8+a9+a10+a11+a12+a13+a14+a15+a16
对于任意的n,Cn=an−2k+1+an−2k+2+...+an=∑ni=n−2k+1ai,其中k为n二进制表示中末尾0的个数对于任意的n,Cn=an−2k+1+an−2k+2+...+an=∑i=n−2k+1nai,其中k为n二进制表示中末尾0的个数
对于数组C可以构造一棵树, 如下图所示(图片摘自百度百科):
上图中由C数组构造的一棵树满足: 对于节点Ci和Cj,如果j+2k==i,则Ci是Cj的父节点(其中k为j的末尾0个数)对于节点Ci和Cj,如果j+2k==i,则Ci是Cj的父节点(其中k为j的末尾0个数)
设lowbit(x)=2k(k为x二进制表示末尾0的个数).lowbit(x)=2k(k为x二进制表示末尾0的个数).
求区间[A, B]元素和
定义sum[i]=a1+a2+...+aisum[i]=a1+a2+...+ai, 问题转化成求sum[B] - sum[A - 1]. 推倒sum[i]的求法:
sum[i]=a1+a2+...+ai=a1+a2+...+ai−2k+ai−2k+1+...+ai=a1+a2+...+ai−2k+Ci,(2k==lowbit(i))=a1+a2+...+ai−lowbit(i)+Ci(1)(2)(3)(4)(1)sum[i]=a1+a2+...+ai(2)=a1+a2+...+ai−2k+ai−2k+1+...+ai(3)=a1+a2+...+ai−2k+Ci,(2k==lowbit(i))(4)=a1+a2+...+ai−lowbit(i)+Ci
因而计算sum[i]sum[i]的过程如下代码所示:int lowbit(int x){ return x & (-x); } int sum(int n){ int ans = 0; while(n > 0){ ans += c[n]; n -= lowbit(n); } return ans; }
计算sum[i]的过程实则为去除i末尾1的过程, 因而算法复杂度为O(logn)
单点更新位置i上的元素增加d.
当更新aiai时,从树的性质能看出,只需要更新CiCi及其所有的父节点Ci+lowbit(i)...等Ci+lowbit(i)...等void add(int i, int d){ while(i < MAX){ c[i] += d; i += lowbit(i); } }
由于树的高度最多为log(n). 因而更新维护C数组的复杂度也是O(logn).
树状数组解决的问题举例:
单点更新,区间查询.
最原始的算法套用. 查询[A, B]区间和, ans = sum[B] - sum[A - 1]区间更新,单点查询:
【例】
a) ADD A B d (将区间[A, B]中的元素增加d)
b) QUERY x (查询x的值)
这里只需把a)操作转化为add(A, d), add(B + 1, -d). 即可.二维树状数组:
【例】 二维矩阵M*N(M, N <=10^3), 有以下两种操作
a) ADD x y d(将x,y 位置的元素增加d)
b) QUERY x1 y1 x2 y2( 查询矩形区域(x1,y1)(左上角), (x2, y2)(右下角)的和)
计算ans = sum(x2, y2) - sum(x2, y1 - 1) - sum(x1 - 1, y2) + sum(x1 - 1, y1 - 1)
二维树状数组:void add(int x, int y, int d){ for(int i = x; i <= n; i += lowbit(i)){ for(int j = y; j <= n; j += lowbit(j)){ c[i][j] += d; } } } int sum(int x, int y){ int ans = 0; for(int i = x; i > 0; i -= lowbit(i)){ for(int j = y; j > 0; j -= lowbit(j)){ ans += c[i][j]; } } return ans; }