这题真TM神……
和commonc一起orz了半天题解-_-
我们考虑一下dp,f[i][j]表示以i为根的子树里与i距离为j的点的个数,g[i][j]表示子树内有g[i][j]对点深度相同,且距离他们的LCA距离都为d,且i与他们的LCA的距离为d-j。换一种说法是表示以i为根的子树里有这么多个点对在底下分叉了,并且还没有第三个点和这个点对匹配,这个第3个点不在i的子树里并且与i距离为j的方案数(不考虑第三个点有多少种选法)
设x表示当前点,y表示儿子
f[x][0]=1
ans+=g[x][0]
这样的话枚举出边,一边枚举一边更新保证不重复计算,每次枚举出边的时候再枚举i
f[x][i]+=f[y][i-1]
g[x][i-1]+=g[y][i]
g[x][i+1]+=f[x][i+1]*f[y][i]
ans+=f[x][i-1]*g[y][i]+g[x][i+1]*f[y][i]
但是这样的话时间和空间都会爆,我们把整个树进行轻重链剖分,子树深度最大的儿子是重儿子,重边练成重链,对于一个点,在第一次用儿子更新的时候我们有f[x][i]=f[y][i-1],g[x][i]=g[y][i+1],可以用指针O(1)进行这一步转移,由于对一个儿子进行转移的复杂度是O(dep[y]),所以不妨对重儿子进行O(1)转移
题解上写复杂度是O(nlogn),但是我并不知道是为什么,commonc神犇证明了复杂度是O(n),实测了一下似乎确实是O(n),证明如下
设h[x]表示以x为根的子树的高度
对每个点转移的复杂度为sigma h[y] -h[son[x]]=sigma h[y] -h[x]+1,做和的话除了叶子节点所有点的dep都被抵消,所以复杂度为O(n)
空间的话非叶子节点所需要的空间都是由他所在重链的叶子节点用指针挪过来的,所以对每个叶子节店给他开正比于所在重链长度的空间即可
空间复杂度为O(sigma 重链长度和)=O(n),不过常数好像略大?分配内存的时候加的常数对内存有极大影响
我在BZ上第一个rank1啊
#include<iostream>
#include<cstdlib>
#include<cstdio>
#include<cstring>
#include<cmath>
#include<ctime>
#include<algorithm>
#include<iomanip>
#include<vector>
#include<stack>
#include<queue>
#include<map>
#include<set>
#include<bitset>
using namespace std;
#define MAXN 100010
#define MAXM 1010
#define ll long long
#define INF 1000000000
#define MOD 1000000007
#define eps 1e-8
char xB[1<<15],*xS=xB,*xT=xB;
#define getc() (xS==xT&&(xT=(xS=xB)+fread(xB,1,1<<15,stdin),xS==xT)?0:*xS++)
inline int read(){
char ch=getc();
int f=1,x=0;
while(!(ch>='0'&&ch<='9')){if(ch=='-')f=-1;ch=getc();}
while(ch>='0'&&ch<='9'){x=x*10+(ch-'0');ch=getc();}
return x*f;
}
struct vec{
int to;
int fro;
};
vec mp[MAXN*2];
int tai[MAXN],cnt;
ll memp[MAXN*5];
ll *f[MAXN],*g[MAXN];
int mx[MAXN];
int dep[MAXN];
ll *wzh=memp+5;
int n;
ll ans;
inline void be(int x,int y){
mp[++cnt].to=y;
mp[cnt].fro=tai[x];
tai[x]=cnt;
}
inline void bde(int x,int y){
be(x,y);
be(y,x);
}
void dfs1(int x,int F){
int i,y;
mx[x]=x;
for(i=tai[x];i;i=mp[i].fro){
y=mp[i].to;
if(y!=F){
dep[y]=dep[x]+1;
dfs1(y,x);
if(dep[mx[y]]>dep[mx[x]]){
mx[x]=mx[y];
}
}
}
for(i=tai[x];i;i=mp[i].fro){
y=mp[i].to;
if(y!=F&&(mx[y]!=mx[x]||x==1)){
y=mx[y];
wzh+=dep[y]-dep[x]+1;
f[y]=wzh;
g[y]=(wzh+=1);
wzh+=(dep[y]-dep[x])*2+1;
}
}
}
void dp(int x,int F){
int i,j,y,z;
for(i=tai[x];i;i=mp[i].fro){
y=mp[i].to;
if(y==F){
continue ;
}
dp(y,x);
if(mx[y]==mx[x]){
f[x]=f[y]-1;
g[x]=g[y]+1;
}
}
ans+=g[x][0];
f[x][0]=1;
for(i=tai[x];i;i=mp[i].fro){
y=mp[i].to;
if(y==F||mx[y]==mx[x]){
continue ;
}
for(j=0;j<=dep[mx[y]]-dep[x];j++){
ans+=f[x][j-1]*g[y][j]+g[x][j+1]*f[y][j];
}
for(j=0;j<=dep[mx[y]]-dep[x];j++){
g[x][j-1]+=g[y][j];
g[x][j+1]+=f[x][j+1]*f[y][j];
f[x][j+1]+=f[y][j];
}
}
}
int main(){
int i,x,y;
n=read();
for(i=1;i<n;i++){
x=read();
y=read();
bde(x,y);
}
while(wzh!=memp){
*wzh=0;
wzh--;
}
*wzh=0;
wzh+=1;
dep[1]=1;
dfs1(1,0);
dp(1,0);
printf("%lld\n",ans);
return 0;
}
/*