题目
t(t<=8)组样例,每次给出一棵n(n<=1e5)个点的树,q(q<=1e5)组询问,
每次给出一对(a,b),询问
即对于n个点来说,每个点到点a和点b中两个点距离的较小值为该点的贡献,求贡献和
sumn<=5e5,sumq<=5e5
思路来源
乱搞AC
题解
考虑a和b之间的所有点d,记c是ab路径上的中点,
如果知道这些点d到a的距离和,如图,蓝色之和,
如果知道这些点d到b的距离和,如图,粉色之和,
如果知道这些点d到c的距离和,如图,绿色之和,
怎么求答案,如图,黄颜色之和,
特别地,补充点a以外的所有点x,补充点b以外的所有点y之后,这张图是这样的
而互相抵消彼此之间的贡献之后,这张图是这样的,
dp[i]维护所有点到i点的距离和,
dp[a]+dp[b]-dp[c],即蓝+粉-绿之后,和答案相比,还多记了两部分,
一部分是c以右的点的个数*ac的距离,另一部分是c以左的点的个数*bc的距离
这表明,其实c不一定要是中点,选取ab路径上的一个点即可,
但是取中点会好处理
首先特判掉(a,b)中点不存在的情形,即a、b距离为1的情形,
然后,不妨记a、b更深的点为a,距离为dis(a,b),
则a往上爬dis(a,b)/2爬到点c,要么还没到lca,要么恰为lca,
还没到lca的话,这个点c及c以左侧的点就是sz[c](代码中为sz[m]),另一侧为n-sz[m]
但是对于恰为lca的情形,sz[c]会同时包含a、b,一侧点的个数不太对,
但注意到这种情况下,应该是a、b到lca的距离均为dis(a,b)的情形,
另一侧点距离的贡献等于这一侧点的贡献,
相当于n个点每个点的贡献相等,恰好统一了这种情况,故不需要特判
代码
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N=1e5+10,M=18;
int t,n,q,u,v,sz[N],par[N][M],dep[N];
vector<int>e[N];
ll dp[N];
void dfs(int u,int fa){
dep[u]=dep[fa]+1;
par[u][0]=fa;
for(int i=1;i<M;++i){
par[u][i]=par[par[u][i-1]][i-1];
}
sz[u]=1;
for(int i=0;i<e[u].size();++i){
int v=e[u][i];
if(v==fa)continue;
dfs(v,u);
sz[u]+=sz[v];
dp[u]+=dp[v]+sz[v];
}
}
void dfs2(int u,int fa){
if(u!=1){
dp[u]=(dp[fa]+n-2*sz[u]);
}
for(int i=0;i<e[u].size();++i){
int v=e[u][i];
if(v==fa)continue;
dfs2(v,u);
}
}
ll cal(int a,int b){
if(dep[a]<dep[b]){
swap(a,b);
}
int la=a,lb=b;
int d=0,m=a;
for(int i=M-1;i>=0;--i){
if((dep[a]-dep[b])>>i&1){
a=par[a][i];
d+=(1<<i);
}
}
// a==b 即lca(a,b)==b
if(a!=b){
for(int i=M-1;i>=0;--i){
if(par[a][i]!=par[b][i]){
a=par[a][i];
b=par[b][i];
d+=(1<<(i+1));
}
}
a=par[a][0];b=par[b][0];
d+=(1<<1);
}
if(d==1){
return dp[la]-(n-sz[la]);
}
for(int i=M-1;i>=0;--i){
if((d/2)>>i&1){
m=par[m][i];
}
}
// printf("m:%d\n",m);
ll ans=dp[la]+dp[lb]-dp[m];
// m是lca的情况 恰与d是偶数能平分的情况统一 故不需特判
ans-=1ll*sz[m]*(d-d/2);
ans-=1ll*(n-sz[m])*(d/2);
return ans;
}
int main(){
scanf("%d",&t);
while(t--){
scanf("%d%d",&n,&q);
for(int i=1;i<=n;++i){
e[i].clear();
dp[i]=0;
}
for(int i=1;i<n;++i){
scanf("%d%d",&u,&v);
e[u].push_back(v);
e[v].push_back(u);
}
dfs(1,0);
dfs2(1,0);
// for(int i=1;i<=n;++i){
// printf("%d:%lld\n",i,dp[i]);
// }
while(q--){
scanf("%d%d",&u,&v);
printf("%lld\n",cal(u,v));
}
}
return 0;
}