题目
n(n<=1e5)个点的一棵树,每个点有一个点权wi(1<=wi<=1e5)
每棵树的贡献定义为,这棵树中所有出现的点权的不同种类数,
要求删去一条边,使得两棵子树的贡献之和最大,输出这个贡献和
思路来源
https://blog.youkuaiyun.com/winter2121/article/details/82466690
题解
思路来源的博主写的博客好评呐,通俗易懂
dfs序一下,拍成线性序列,考虑到删一条边会将原树分成一棵子树和剩下的部分
而这棵子树在dfs序中的序列是连续的,[li,ri]
那就是统计[li,ri]内不同种类数和[1,li-1]∪[ri+1,n]的不同种类数
把原序列完整复制一遍在后面,那个∪的区间就连续了
区间数字种数,用BIT做的套路,从[1,R]统计每个数第一次出现的位置
每删掉一个端点,就把这个数下一次出现的位置在BIT中+1
可以不对这个端点进行-1操作,毕竟前缀和作差会把之前的1减掉
这个题也可以用莫队做,离线O(n根号n),
在n=1e5的条件下,应该是莫队的极限了
据说也可以用主席树在线做,但要求常数优秀,就不赘述了
Trick
不知道有没有trick,枚举到根的时候,实际是没有删边的
实际上,[li,ri]和[ri+1,li+n-1]在根的时候,前者对应[1,n],后者对应[n+1,n]
后者这个反序序列,会对应一个负值,使得根的这种情况不可能成为最大值
代码
#include<bits/stdc++.h>
using namespace std;
const int maxn=1e5+10;
vector<int>E[maxn];
int n,a[maxn*2],b[maxn*2];
int tree[maxn*2];
int head[maxn],nex[2*maxn],mx;
int l[maxn],r[maxn],cnt,dfn;
int u;
bool vis[maxn];
//head[a[i]]:a[i]这个值第一次出现的位置
//next[i]:与i位置相同的值下一次出现的位置
int ans[maxn],res;
struct node
{
int l,r,id;
node(){}
node(int L,int R,int i):l(L),r(R),id(i){}
}e[maxn*2];
bool operator<(node a,node b)
{
return a.l<b.l||(a.l==b.l&&a.r<b.r);
}
void add(int x,int v)
{
for(int i=x;i<=n;i+=i&-i)
tree[i]+=v;
}
int sum(int x)
{
int ans=0;
for(int i=x;i>0;i-=i&-i)
ans+=tree[i];
return ans;
}
void dfs(int u,int fa)
{
l[u]=++dfn;
b[dfn]=a[u];
for(int i=0;i<E[u].size();++i)
{
int v=E[u][i];
if(v==fa)continue;
dfs(v,u);
}
r[u]=dfn;
}
int main()
{
while(~scanf("%d",&n))
{
res=cnt=dfn=mx=0;
for(int i=1;i<=n;++i)
{
E[i].clear();
ans[i]=0;
}
for(int i=2;i<=n;++i)
{
scanf("%d",&u);
E[u].push_back(i);
E[i].push_back(u);
}
for(int i=1;i<=n;++i)
{
scanf("%d",&a[i]);
mx=max(mx,a[i]);
}
for(int i=1;i<=mx;++i)
{
vis[i]=0;//最大值每次不同
head[i]=2*n+1;//最大值相关位置
}
dfs(1,-1);
for(int i=1;i<=n;++i)
{
//if(l[i]==1)continue; 没分
b[i+n]=b[i];
e[++cnt]=node(l[i],r[i],i);//[1,l[i]-1],[l[i],r[i]],[r[i]+1,n]
e[++cnt]=node(r[i]+1,l[i]+n-1,i);//将第一段挪到后面来
}
sort(e+1,e+cnt+1);
n*=2;
for(int i=n;i>=1;--i)
{
nex[i]=head[b[i]];
head[b[i]]=i;
tree[i]=0;
}
int L=1;
for(int i=1;i<=n;++i)
if(!vis[b[i]])
{
vis[b[i]]=1;
add(i,1);
}
for(int i=1;i<=cnt;++i)
{
while(L<e[i].l)
{
if(nex[L]<=n)add(nex[L],1);
//add(a[L],-1); 作差会被作掉
L++;
}
ans[e[i].id]+=sum(e[i].r)-sum(e[i].l-1);
res=max(res,ans[e[i].id]);
}
printf("%d\n",res);
}
return 0;
}