模板:
给 n n n个可重集,要求维护他们,支持以下操作:
(1) l , r , k l,r,k l,r,k,对编号为 [ l , r ] [l,r] [l,r]的每个可重集,加入元素 k k k
(2) l , r , k l,r,k l,r,k,对编号为 [ l , r ] [l,r] [l,r]的所有可重集的并集,求里面的第 k k k大元素
Solution:
动态区间 k k k小问题支持的是单点修改,而这里需要支持的是区间修改。
对于树状数组,原始版本支持单点增,区间和查询,差分版本支持区间加,单点查询。现在要对 [ l , r ] [l,r] [l,r]的每个可重集修改,那么树状数组需要支持区间加,同时又要合并线段树,所以树状数组还需要支持区间和查询,下面是一种改进的方法,使得树状数组同时支持以上两种操作。
令树状数组维护的原始数组为 a a a,差分数组为 b b b,其中 b [ i ] = a [ i ] − a [ i − 1 ] b[i]=a[i]-a[i-1] b[i]=a[i]−a[i−1]
由差分的前缀和是本身,有
a [ i ] = b [ 1 ] + b [ 2 ] + b [ 3 ] + . . . . + b [ i ] a[i]=b[1]+b[2]+b[3]+....+b[i] a[i]=b[1]+b[2]+b[3]+....+b[i]
又前缀和为
s u m [ i ] = a [ 1 ] + a [ 2 ] + . . . + a [ i ] sum[i]=a[1]+a[2]+...+a[i] sum[i]=a[1]+a[2]+...+a[i]
将前缀和改写为差分形式有
s u m [ i ] = ( b [ 1 ] ) + ( b [ 1 ] + b [ 2 ] ) + ( b [ 1 ] + b [ 2 ] + b [ 3 ] ) + . . . . + ( b [ 1 ] + b [ 2 ] + b [ 3 ] + . . . + b [ i ] ) sum[i]=(b[1])+(b[1]+b[2])+(b[1]+b[2]+b[3])+....+(b[1]+b[2]+b[3]+...+b[i]) sum[i]=(b[1])+(b[1]+b[2])+(b[1]+b[2]+b[3])+....+(b[1]+b[2]+b[3]+...+b[i])
即
s u m [ i ] = ( i + 1 ) ∗ ( b [ 1 ] + b [ 2 ] + . . . + b [ i ] ) − ( 1 ∗ b [ 1 ] + 2 ∗ b [ 2 ] + . . . + i ∗ b [ i ] ) sum[i]=(i+1)*(b[1]+b[2]+...+b[i])-(1*b[1]+2*b[2]+...+i*b[i]) sum[i]=(i+1)∗(b[1]+b[2]+...+b[i])−(1∗b[1]+2∗b[2]+...+i∗b[i])
其中 ( b [ 1 ] + b [ 2 ] + . . . + b [ i ] ) (b[1]+b[2]+...+b[i]) (b[1]+b[2]+...+b[i])可以看作差分数组的前缀和,可以用树状数组维护, ( 1 ∗ b [ 1 ] + 2 ∗ b [ 2 ] + . . . + i ∗ b [ i ] ) (1*b[1]+2*b[2]+...+i*b[i]) (1∗b[1]+2∗b[2]+...+i∗b[i])也可以用树状数组维护,这个维护的时候 a d d ( x , k ) add(x,k) add(x,k),对每个 i i i,( i i i是原始 x x x加上了若干个 l o w b i t lowbit lowbit之后的 x x x)都加上 x ∗ k x*k x∗k(最开始的 x x x),用两个树状数组分别维护上面的即可
代码
void add(int x,int k)
{
int tmp=x*k;
while(x<=n)
{
tree1[x]+=k;
tree2[x]+=tmp;
x+=lowbit(x);
}
}
/*
区间[l,r]增加k
add(l,k);
add(r+1,-k);
*/
int query(int l,int r)//查询[l,r]区间和
{
int ret=0,x=r;
while(x)
{
ret+=(r+1)*tree1[x]-tree2[x];
x-=lowbit(x);
}
x=l-1;
while(x)
{
ret-=l*tree1[x]-tree2[x];
x-=lowbit(x);
}
return ret;
}
总代码:
#include<bits/stdc++.h>
#define ll long long
#define endl '\n'
using namespace std;
struct question
{
ll op,l,r,k;
}q[50005];
ll vv[50005];
int n,m,len,version1[50005],version2[50005];
int cnt,lson[25000005],rson[25000005],tree[25000005];
inline int lowbit(int x){return -x&x;}
inline int& ls(int x){return lson[x];}
inline int& rs(int x){return rson[x];}
inline void push_up(int x){tree[x]=tree[ls(x)]+tree[rs(x)];}
void modify(int nl,int l,int r,int x,int k)
{
if(l==r){tree[x]+=k;return;}
int mid=l+r>>1;
if(nl<=mid)
{
if(!ls(x)) ls(x)=++cnt;
modify(nl,l,mid,ls(x),k);
}
else
{
if(!rs(x)) rs(x)=++cnt;
modify(nl,mid+1,r,rs(x),k);
}
push_up(x);
}
void add(int x,int k,int v)
{
int tmp=x*v;
while(x<=n)
{
if(!version1[x]) version1[x]=++cnt;
if(!version2[x]) version2[x]=++cnt;
modify(k,1,len,version1[x],v);
modify(k,1,len,version2[x],tmp);
x+=lowbit(x);
}
}
int L,R;
int tmp1[50005],tmp2[50005],pp1,pp2;
int tmp3[50005],tmp4[50005],pp3,pp4;
int solve(int l,int r,int k)
{
if(l==r) return vv[l];
int mid=l+r>>1; ll sum=0;
for(int i=1;i<=pp1;i++) sum+=1ll*(R+1)*tree[rs(tmp1[i])];
for(int i=1;i<=pp2;i++) sum-=tree[rs(tmp2[i])];
for(int i=1;i<=pp3;i++) sum-=1ll*(L+1)*tree[rs(tmp3[i])];
for(int i=1;i<=pp4;i++) sum+=tree[rs(tmp4[i])];
if(k<=sum)
{
for(int i=1;i<=pp1;i++) tmp1[i]=rs(tmp1[i]);
for(int i=1;i<=pp2;i++) tmp2[i]=rs(tmp2[i]);
for(int i=1;i<=pp3;i++) tmp3[i]=rs(tmp3[i]);
for(int i=1;i<=pp4;i++) tmp4[i]=rs(tmp4[i]);
return solve(mid+1,r,k);
}
else
{
for(int i=1;i<=pp1;i++) tmp1[i]=ls(tmp1[i]);
for(int i=1;i<=pp2;i++) tmp2[i]=ls(tmp2[i]);
for(int i=1;i<=pp3;i++) tmp3[i]=ls(tmp3[i]);
for(int i=1;i<=pp4;i++) tmp4[i]=ls(tmp4[i]);
return solve(l,mid,k-sum);
}
}
int query(int l,int r,int k)
{
pp1=pp2=pp3=pp4=0; L=l-1; R=r;
for(int i=r;i;i-=lowbit(i)) tmp1[++pp1]=version1[i],tmp2[++pp2]=version2[i];
for(int i=l-1;i;i-=lowbit(i)) tmp3[++pp3]=version1[i],tmp4[++pp4]=version2[i];
return solve(1,len,k);
}
int main()
{
ios::sync_with_stdio(false);
cin.tie(0); cout.tie(0);
cin>>n>>m;
for(int i=1;i<=m;i++)
{
auto& [op,l,r,k]=q[i];
cin>>op>>l>>r>>k;
if(op==1) vv[++len]=k;
}
sort(vv+1,vv+1+len); len=unique(vv+1,vv+1+len)-vv-1;
for(int i=1;i<=m;i++)
if(q[i].op==1) q[i].k=lower_bound(vv+1,vv+1+len,q[i].k)-vv;
for(int i=1;i<=m;i++)
{
auto& [op,l,r,k]=q[i];
if(op==1)
{
add(l,k,1);
add(r+1,k,-1);
}
else cout<<query(l,r,k)<<endl;
}
return 0;
}