// 树上差分裸题 — a textbook "difference array on a tree" (tree prefix-sum) exercise.
#include<bits/stdc++.h>
using namespace std;
// Integer type used for every value and index in this solution.
using lint = int;
// Capacity limits: vertex slots (presumably n <= 300000, plus slack) and
// adjacency-list edge slots (main() stores every input edge in both directions).
constexpr int maxn = 300000 + 5;
constexpr int maxm = 600000 + 5;

// a[i]         : i-th vertex of the input visiting order.
lint a[maxn];
// Adjacency list as linked lists of edge ids: ver[e] = target of edge e,
// ne[e] = next edge id in the same list, he[v] = first edge of v (0 = empty).
lint ver[maxm];
lint he[maxn];
lint ne[maxm];
// Id of the most recently added edge.
lint tot;
// ans[v]       : subtree sum of val[] rooted at v (filled by dfs).
lint ans[maxn];
// val[v]       : tree difference array.
lint val[maxn];
// d[v]         : depth of v below root 1.
lint d[maxn];
// vis[v]       : visit flag (only the root's entry is set in this file).
lint vis[maxn];
// f[v][k]      : 2^k-th ancestor of v, 0 when not computed/absent.
lint f[maxn][21];
// Append the directed edge x -> y to x's adjacency list by inserting it at the
// head of x's linked list of edge ids.
void add( lint x,lint y ){
    ++tot;              // allocate the next edge id
    ver[tot] = y;       // where the edge points
    ne[tot]  = he[x];   // chain to x's previous first edge (0 ends the list)
    he[x]    = tot;     // the new edge becomes x's first edge
}
// Reset the adjacency list: restart the edge counter and clear every head.
// Edge ids handed out by add() therefore start at 2; id 0 is the end-of-list
// sentinel that the traversal loops stop on.
void init(){
    tot = 1;
    memset( he,0,sizeof he );
}
queue<lint> que;
void build(){
d[1] = 0;
que.push(1);
vis[1] = 1;
while( que.size() ){
lint x = que.front();
que.pop();
for( lint cure = he[x];cure;cure = ne[cure] ){
lint y = ver[cure];
if( y == f[x][0] ) continue;
que.push(y);
d[y] = d[x] + 1;
f[y][0] = x;
for( lint i = 1; (1<< i) <= d[y];i++ ){
f[y][i] = f[ f[y][i-1] ][i-1];
}
}
}
}
// Lowest common ancestor of x and y using the binary-lifting table f[][] and
// depths d[] produced by build().
lint lca( lint x,lint y ){
    if( d[x] < d[y] ) std::swap( x,y );   // let x be the deeper endpoint
    // Raise x to y's depth, taking the largest jumps first.
    for( lint k = 20; k >= 0; --k ){
        const lint up = f[x][k];
        if( up != 0 && d[up] >= d[y] ) x = up;
    }
    if( x == y ) return x;                // y was an ancestor of x
    // Raise both while their ancestors differ; they end just below the LCA.
    for( lint k = 20; k >= 0; --k ){
        if( f[x][k] == f[y][k] ) continue;
        x = f[x][k];
        y = f[y][k];
    }
    return f[x][0];
}
// Subtree sums of the difference array: for every vertex v in the subtree of x
// (x's parent is fa), sets ans[v] = sum of val[] over v's subtree, and returns ans[x].
// Implemented with an explicit stack instead of recursion: recursion depth would
// equal the tree height, which can be ~n (3e5) for a path-shaped tree and can
// overflow the call stack.
lint dfs( lint x,lint fa ){
    // (vertex, parent) pairs in visit order; every child appears after its parent.
    std::vector< std::pair<lint,lint> > order;
    std::vector< std::pair<lint,lint> > stk;
    stk.emplace_back( x,fa );
    while( !stk.empty() ){
        const std::pair<lint,lint> cur = stk.back();
        stk.pop_back();
        const lint u = cur.first, p = cur.second;
        ans[u] = val[u];
        order.push_back( cur );
        for( lint cure = he[u];cure;cure = ne[cure] ){
            const lint v = ver[cure];
            if( v == p ) continue;
            stk.emplace_back( v,u );
        }
    }
    // Walk the visit order backwards so each vertex's total is complete before it
    // is added to its parent. order[0] is x itself: skip it so ans[fa] is untouched.
    for( std::size_t k = order.size(); k-- > 1; ){
        ans[ order[k].second ] += ans[ order[k].first ];
    }
    return ans[x];
}
int main(){
lint n;
init();
scanf("%d",&n);
for( int i = 1;i <= n;i++ ) scanf("%d",&a[i]);
for( lint x,y,i = 1;i <= n-1;i++ ){
scanf("%d%d",&x,&y);
add(x,y);add(y,x);
}
build();
for( lint i = 1;i <= n-1;i++ ){
lint y = lca( a[i],a[i+1] );
val[ f[y][0] ] -= 1;
val[y] -= 1;
val[ a[i] ] +=1;
val[ a[i+1] ] += 1;
}
dfs(1,0);
for( lint i = 1;i <= n;i++ ){
if( i != a[1] ){
printf("%d\n", ans[i]-1 );
}else{
printf("%d\n", ans[i] );
}
}
return 0;
}