树状数组是与线段树实现的功能类似的数据结构,也是可以在log(n)的复杂度实现区间修改和区间查询。
以简洁的代码和思路占优,十分容易实现前缀和。
单点变化,区间查询
#include <cstdio>
using namespace std;
typedef long long ll;
ll c[500005]; //表示区间[i-lowbit(i)+1,i]的区间元素和
int a[500005];
int n;
int lowbit(int x)
{
return x & (-x);
}
void update(int x,int k) //向上更新,c[i+lowbit(i)]表示的区间一定包含了c[i]
{
for (int i = x; i <= n; i += lowbit(i))
{
c[i] += k;
}
}
ll query(int x) //sum[x-lowbit(x)+1,x]+sum[t-lowbit(t)+1,t]( t = x-lowbit(x) )...求出sum[1,x]
{
ll ans = 0;
for (int i = x; i; i -= lowbit(i))
{
ans += c[i];
}
return ans;
}
int main()
{
int m;
scanf("%d%d",&n,&m);
for (int i = 1; i <= n; i++)
{
scanf("%d",&a[i]);
update(i,a[i]); //原来的值必须更新
}
while( m-- )
{
int kind,x,y;
scanf("%d%d%d",&kind,&x,&y);
if( kind == 1 )
{
update(x,y);
}else
{
printf("%d\n",query(y) - query(x-1));
}
}
return 0;
}
区间变化,单点查询
若要区间变化,那么我们维护的是数列必须是差分,这样才能保证改变的时候是只改变两个边界数,a[i]差分前缀和等于a[i]。
#include <cstdio>
using namespace std;
typedef long long ll;
ll c[500005];
ll a[500005];
int n;
int lowbit(int x)
{
return x & (-x);
}
void update(int x,int k)
{
for (int i = x; i <= n; i += lowbit(i))
{
c[i] += k;
}
}
ll query(int x)
{
ll ans = 0;
for (int i = x; i; i -= lowbit(i))
{
ans += c[i];
}
return ans;
}
int main()
{
int m;
ll last = 0;
scanf("%d%d",&n,&m);
for (int i = 1; i <= n; i++)
{
scanf("%d",&a[i]);
update(i,a[i]-last);
last = a[i];
}
while( m-- )
{
int kind;
scanf("%d",&kind);
if( kind == 1 )
{
int x,y,k;
scanf("%d%d%d",&x,&y,&k);
update(x,k);
update(y+1,-k);
}else
{
int x;
scanf("%d",&x);
printf("%lld\n",query(x));
}
}
return 0;
}
区间变化,区间查询
数学证明后得:
sum(1…a[i]) = (n+1) * sum(1…d[i]) - sum(1…t[i]) t[i] = i*d[i]
#include <cstdio>
using namespace std;
typedef long long ll;
ll sum1[500005],sum2[500005];
//sum1[i]表示d[i]构成的树状数组 d[i]为差分
//sum2[i]表示i*d[i]构成的树状数组
ll a[500005];
int n;
int lowbit(int x)
{
return x & (-x);
}
void update(int x,int k)
{
for (int i = x; i <= n; i += lowbit(i))
{
sum1[i] += k;
sum2[i] += x * k;
}
}
ll query(int x)
{
ll ans = 0;
for (int i = x; i; i -= lowbit(i))
{
ans += ( x + 1 ) * sum1[i] - sum2[i];
//sum(a[i-lowbit(i)]...a[i]) = (n+1)*sum(d[i-lowbit[i]]...i) - sum(t[i-lowbit[i]]...i);
//t[i] = i * d[i];
}
return ans;
}
int main()
{
int m;
ll last = 0;
scanf("%d%d",&n,&m);
for (int i = 1; i <= n; i++)
{
scanf("%d",&a[i]);
update(i,a[i]-last);
last = a[i];
}
while( m-- )
{
int kind;
scanf("%d",&kind);
if( kind == 1 )
{
int x,y,k;
scanf("%d%d%d",&x,&y,&k);
update(x,k);
update(y+1,-k);
}else
{
int x;
scanf("%d",&x);
printf("%lld\n",query(x)-query(x-1));
}
}
return 0;
}