链接:http://acm.hdu.edu.cn/showproblem.php?pid=5405
题意:给出一棵节点上有权的树,两种操作:
1.修改一个点的权
2.询问一条路径u到v,求
∑wi∗wj
,i到j的路径和uv有公共点
分析:记
road(u,v)
为u到v的路径,对于询问,可以分两种情况考虑:
1.
lca(i,j)∈road(u,v)
,我们可以提前维护好
lca(i,j)=x
的答案,然后询问的时候只要树链剖分链上求和即可
2.
lca(i,j)∉road(u,v)
,记y=lca(u,v),那么这种情况i和j必然是一个在y的子树内,另一个不在y的子树内,只要对y的子树求和即可,由于树链剖分序也满足子树是一段连续的区间,因此这两部分都可以用树链剖分搞
如何维护
lca(i,j)=x
的答案
我们考虑计算过程。对每个x,那么
(∑(total[x]−total[v])∗total[v])+total[x]∗a[x]
就是我们要求的,其中a[x]代表x的权值,total[x]代表x的子树权值和,由于之前使用了树链剖分,我们可以将total分为两部分,一部分是他轻孩子的子树和,另一部分是他重儿子的子树和,由于修改重链不会影响轻孩子的子树和,因此可以做到去见修改;而修改轻链只有logn次,因此可以暴力查询,暴力修改。
当中细节比较多,我写的也比较繁琐。
听学弟说似乎可以用动态树搞,然而我还不太会涉及子树操作的动态树的写法。。。。
#include<bits/stdc++.h>
using namespace std;
const int Maxn=100020,M=1e9+7;
#define ls l,mid,x<<1
#define rs mid+1,r,x<<1|1
typedef long long Int;
int n,m;
int a[Maxn];
vector<int>G[Maxn];
int f[Maxn],last[Maxn];
int pre[Maxn],fin[Maxn],rev[Maxn],dfs_t;
int sz[Maxn],son[Maxn],h[Maxn];
int heavy[Maxn],light[Maxn],total[Maxn];//total=heavy+light+a
int rep[Maxn];
inline void up(int &x,int y){x+=y;if(x>=M)x-=M;if(x<0)x+=M;}
void dfs1(int u,int p)
{
f[u]=p;
sz[u]=1;
son[u]=0;
h[u]=h[p]+1;
for(int v:G[u])
{
if(v==p)continue;
dfs1(v,u);
sz[u]+=sz[v];
if(sz[v]>sz[son[u]])son[u]=v;
}
}
void dfs2(int u,int p,int Last)
{
heavy[u]=light[u]=rep[u]=0;
total[u]=a[u];
last[u]=Last;
pre[u]=++dfs_t;
rev[dfs_t]=u;
if(son[u])
{
dfs2(son[u],u,Last);
up(heavy[u],total[son[u]]);
}
for(int v:G[u])
{
if(v==p||v==son[u])continue;
dfs2(v,u,v);
up(light[u],total[v]);
}
fin[u]=dfs_t;
up(total[u],heavy[u]);
up(total[u],light[u]);
if(son[u])up(rep[u],(total[u]-total[son[u]])*(Int)total[son[u]]%M);
for(int v:G[u])
{
if(v==p||v==son[u])continue;
up(rep[u],(total[u]-total[v])*(Int)total[v]%M);
}
up(rep[u],a[u]*(Int)total[u]%M);
}
int suma[Maxn<<2],sumrep[Maxn<<2],sumlight[Maxn<<2],sumheavy[Maxn<<2];
int lzheavy[Maxn<<2];
inline void push_up(int x,int *y)
{
y[x]=y[x<<1]+y[x<<1|1];
if(y[x]>=M)y[x]-=M;
}
void push_down(int l,int r,int x)
{
if(lzheavy[x])
{
int mid=(l+r)>>1;
up(lzheavy[x<<1],lzheavy[x]);
up(lzheavy[x<<1|1],lzheavy[x]);
up(sumheavy[x<<1],(mid-l+1)*(Int)lzheavy[x]%M);
up(sumheavy[x<<1|1],(r-mid)*(Int)lzheavy[x]%M);
up(sumrep[x<<1],lzheavy[x]*(Int)(sumlight[x<<1]+suma[x<<1])%M*2%M);
up(sumrep[x<<1|1],lzheavy[x]*(Int)(sumlight[x<<1|1]+suma[x<<1|1])%M*2%M);
lzheavy[x]=0;
}
}
void build(int l,int r,int x)
{
lzheavy[x]=0;
if(l==r)
{
suma[x]=a[rev[l]];
sumrep[x]=rep[rev[l]];
sumlight[x]=light[rev[l]];
sumheavy[x]=heavy[rev[l]];
return;
}
int mid=(l+r)>>1;
build(ls);build(rs);
push_up(x,suma);
push_up(x,sumrep);
push_up(x,sumlight);
push_up(x,sumheavy);
}
void addval(int tar,int val,int l,int r,int x,int *y)
{
if(l==r)
{
up(y[x],val);
return;
}
push_down(l,r,x);
int mid=(l+r)>>1;
if(tar<=mid)addval(tar,val,ls,y);
else addval(tar,val,rs,y);
push_up(x,y);
}
void addheavy(int L,int R,int val,int l,int r,int x)
{
if(L<=l&&R>=r)
{
up(lzheavy[x],val);
up(sumheavy[x],val*(Int)(r-l+1)%M);
up(sumrep[x],val*(Int)(suma[x]+sumlight[x])%M*2%M);
return ;
}
push_down(l,r,x);
int mid=(l+r)>>1;
if(L<=mid)addheavy(L,R,val,ls);
if(R>mid)addheavy(L,R,val,rs);
push_up(x,sumheavy);
push_up(x,sumrep);
}
int query(int L,int R,int l,int r,int x,int *y)
{
if(L<=l&&R>=r)return y[x];
push_down(l,r,x);
int mid=(l+r)>>1;
int ret=0;
if(L<=mid)up(ret,query(L,R,ls,y));
if(R>mid)up(ret,query(L,R,rs,y));
return ret;
}
void modify(int u,int val)
{
int delta=val-a[u];
int delta2=(val*(Int)val-a[u]*(Int)a[u])%M;
addval(pre[u],delta,1,n,1,suma);
int t1=query(pre[u],pre[u],1,n,1,sumlight);
up(t1,query(pre[u],pre[u],1,n,1,sumheavy));
addval(pre[u],(delta*(Int)t1%M*2+delta2)%M,1,n,1,sumrep);
a[u]=val;
for(;;)
{
int nxt=last[u];
if(pre[nxt]<=pre[f[u]])addheavy(pre[nxt],pre[f[u]],delta,1,n,1);
u=nxt;
if(!f[u])break;
int t1=query(pre[u],fin[u],1,n,1,suma);
u=f[u];
addval(pre[u],delta,1,n,1,sumlight);
int t2=query(pre[u],pre[u],1,n,1,sumlight);
int t3=query(pre[u],pre[u],1,n,1,sumheavy);
addval(pre[u],2*delta*((t2-(Int)t1+t3+a[u])%M)%M,1,n,1,sumrep);
}
}
int ask(int u,int v)
{
int ret=0;
for(;last[u]!=last[v];)
{
if(h[last[u]]<h[last[v]])swap(u,v);
up(ret,query(pre[last[u]],pre[u],1,n,1,sumrep));
u=f[last[u]];
}
if(h[u]<h[v])swap(u,v);
up(ret,query(pre[v],pre[u],1,n,1,sumrep));
int totala=suma[1];
int ta=query(pre[v],fin[v],1,n,1,suma);
up(ret,(totala-ta)*(Int)ta%M*2%M);
return ret;
}
int main()
{
while(scanf("%d%d",&n,&m)!=EOF)
{
for(int i=1;i<=n;i++)scanf("%d",a+i),G[i].clear();
for(int i=1;i<n;i++)
{
int u,v;scanf("%d%d",&u,&v);
G[u].push_back(v);
G[v].push_back(u);
}
dfs_t=0;
dfs1(1,0);
dfs2(1,0,1);
build(1,n,1);
while(m--)
{
int ty,u,v;
scanf("%d%d%d",&ty,&u,&v);
if(ty==1)modify(u,v);
else printf("%d\n",ask(u,v));
}
}
}