思路:首先这道题可以直接裸上树剖。
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
using namespace std;
#define maxn 300005
int n,tot;
int now[maxn],son[2*maxn],pre[2*maxn],size[maxn],heavy[maxn],dep[maxn],top[maxn];
int dfn[maxn],fa[maxn],ans[maxn],a[maxn],op[maxn],cnt[maxn];
inline int read(){
int x=0;char ch=getchar();
for (;ch<'0'||ch>'9';ch=getchar());
for (;ch>='0'&&ch<='9';ch=getchar()) x=x*10+ch-'0';
return x;
}
void add(int a,int b){
son[++tot]=b;
pre[tot]=now[a];
now[a]=tot;
}
void link(int a,int b){
add(a,b),add(b,a);
}
void getsize(int x){
size[x]=1,dep[x]=dep[fa[x]]+1;
for (int p=now[x];p;p=pre[p])
if (son[p]!=fa[x]){
fa[son[p]]=x;
getsize(son[p]);
size[x]+=size[son[p]];
if (size[son[p]]>size[heavy[x]]) heavy[x]=son[p];
}
}
void getdfn(int x){
top[heavy[x]]=top[x],dfn[heavy[x]]=dfn[x]+1;int cnt=dfn[heavy[x]]+size[heavy[x]];
for (int p=now[x];p;p=pre[p])
if (son[p]!=fa[x]){
if (son[p]!=heavy[x]) dfn[son[p]]=cnt,cnt+=size[son[p]];
getdfn(son[p]);
}
}
struct segment_tree{
struct treenode{
int cover,tag;
}tree[4*maxn];
void addtag(int p,int val){
tree[p].cover+=val;
tree[p].tag+=val;
}
void pushdown(int p){
if (!tree[p].tag) return;
addtag(p<<1,tree[p].tag),addtag(p<<1|1,tree[p].tag),tree[p].tag=0;
}
void query(int p,int l,int r,int x,int y){
if (x<=l&&r<=y){
addtag(p,1);
return;
}
pushdown(p);
int mid=(l+r)>>1;
if (x<=mid) query(p<<1,l,mid,x,y);
if (y>mid) query(p<<1|1,mid+1,r,x,y);
}
void getans(int p,int l,int r){
if (l==r){ans[op[l]]=tree[p].cover;return;}
pushdown(p);int mid=(l+r)>>1;
getans(p<<1,l,mid),getans(p<<1|1,mid+1,r);
}
}T;
void query(int x,int y){
while (top[x]!=top[y]){
if (dep[top[x]]<dep[top[y]]) swap(x,y);
T.query(1,1,n,dfn[top[x]],dfn[x]);
x=fa[top[x]];
}
if (dep[x]>dep[y]) swap(x,y);
T.query(1,1,n,dfn[x],dfn[y]);
}
int main(){
n=read();
for (int i=1;i<=n;i++) a[i]=read();
for (int i=2;i<n;i++) cnt[a[i]]++;
for (int i=1,x,y;i<n;i++) x=read(),y=read(),link(x,y);
getsize(1);for (int i=1;i<=n;i++) top[i]=i;getdfn(dfn[1]=1);
for (int i=1;i<=n;i++) op[dfn[i]]=i;
for (int i=1;i<n;i++) query(a[i],a[i+1]);
T.getans(1,1,n);
for (int i=1;i<=n;i++) printf("%d\n",ans[i]-cnt[i]-(i==a[n]));
return 0;
}
然后参见了黄学长的博客发现了一种更为高明的解法,可以在树上差分,令x为路径(u,v)的lca,然后f[u]++,f[v]++,f[x]--,f[fa[x]]--,然后直接一遍dfs累加起来就好了。(为什么我的代码常数辣么大。。。差分被树剖虐。。。。)
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
using namespace std;
#define maxn 300005
int n,tot;
int now[maxn],pre[2*maxn],son[2*maxn],a[maxn],dep[maxn],ans[maxn];
int f[maxn][21];
inline int read(){
int x=0;char ch=getchar();
for (;ch<'0'||ch>'9';ch=getchar());
for (;ch>='0'&&ch<='9';ch=getchar()) x=x*10+ch-'0';
return x;
}
void add(int a,int b){
son[++tot]=b;
pre[tot]=now[a];
now[a]=tot;
}
void link(int a,int b){
add(a,b),add(b,a);
}
void dfs(int x){
dep[x]=dep[f[x][0]]+1;
for (int p=now[x];p;p=pre[p])
if (son[p]!=f[x][0]) f[son[p]][0]=x,dfs(son[p]);
}
int lca(int a,int b){
if (dep[a]<dep[b]) swap(a,b);int x=dep[a]-dep[b],t=0;
for (;x;x>>=1,t++) if (x&1) a=f[a][t];
if (a==b) return a;t=log2(dep[a])+1;
for (;f[a][0]!=f[b][0];){
for (;f[a][t]==f[b][t];t--);
a=f[a][t],b=f[b][t];
}
return f[a][0];
}
void tree_dp(int x){
for (int p=now[x];p;p=pre[p])
if (son[p]!=f[x][0]){
tree_dp(son[p]);
ans[x]+=ans[son[p]];
}
}
int main(){
n=read();
for (int i=1;i<=n;i++) a[i]=read();
for (int i=1,x,y;i<n;i++) x=read(),y=read(),link(x,y);
dfs(1);
for (int i=1;i<=20;i++)
for (int x=1;x<=n;x++)
f[x][i]=f[f[x][i-1]][i-1];
for (int i=1;i<n;i++){
int x=lca(a[i],a[i+1]);
ans[a[i]]++,ans[a[i+1]]++,ans[x]--,ans[f[x][0]]--;
}
tree_dp(1);
for (int i=1;i<=n;i++) printf("%d\n",ans[i]-(i!=a[1]));
return 0;
}