题目:https://vjudge.net/problem/HDU-5788
题意:给定一颗树,每个点都有权值(1是根节点)mid[i]是以i为根节点的子树的所有节点的权值的中位数。选择一个节点使其权值变为100000(树上节点的权值都不大于100000)使得mid[i]之和最大,输出最大的和。
思路:
1.如果选择某一个节点,只会影响它本身还有祖先节点的mid值,而且只有在该节点的权值小于等于祖先节点的权值是才会将祖先节点的中位数变成中位数的下一位。求出每个节点的curmid和nextmid(中位数和中位数的下一位)
可以用dfs序将子树表示为一个区间,用主席树求区间第k大数。
2.改变其中一个节点的值后,总的mid一定增加。可以先求出原来的curmid和,再加上变化后比原来增加的数。求出每个节点对答案的贡献,取最大值。每个节点对答案的贡献是该节点及其祖先节点中curmid大于等于该节点权值的nextmid[i]-curmid[i]之和。重新dfs一遍,每个节点在被访问到时,他的祖先节点已经被访问,可以维护一个树状数组。
代码:
#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
using namespace std;
const int N=1e5+10,M=2e5+10;
#define ll long long
int h[N],ver[M],ne[M],tot,idx;
int a[N],s[N],l[N],r[N],cnt;
int curmid[N],nextmid[N];
void add(int a,int b)
{
ver[tot]=b;
ne[tot]=h[a];
h[a]=tot++;
}
struct tree{
int l,r,sum;
}t[N*4+N*17];
int root[N];
ll d[N];
int lowbit(int x)
{
return x&(-x);
}
void Add(int x,int y)
{
while(x<=100000)
{
d[x]+=y;
x+=lowbit(x);
}
}
ll Sum(int x)
{
ll s=0;
while(x)
{
s=s+d[x];
x-=lowbit(x);
}
return s;
}
int build(int l,int r)
{
int p=++idx;
if(l==r) return p;
int mid=(l+r)/2;
t[p].l=build(l,mid);
t[p].r=build(mid+1,r);
return p;
}
int insert(int q,int l,int r,int x)
{
int p=++idx;
t[p]=t[q];
if(l==r)
{
t[p].sum++;return p;
}
int mid=(l+r)/2;
if(x<=mid) t[p].l=insert(t[q].l,l,mid,x);
else t[p].r=insert(t[q].r,mid+1,r,x);
t[p].sum=t[t[p].l].sum+t[t[p].r].sum;
return p;
}
int query(int p,int q,int l,int r,int k)
{
if(l==r) return l;
int mid=(l+r)/2;
int sum=t[t[q].l].sum-t[t[p].l].sum;
if(sum>=k)
{
return query(t[p].l,t[q].l,l,mid,k);
}
else
{
return query(t[p].r,t[q].r,mid+1,r,k-sum);
}
}
ll ssum=0;
void dfs(int u,int fa)
{
s[u]=1;
l[u]=++cnt;
root[cnt]=insert(root[cnt-1],1,100000,a[u]);
for(int i=h[u];i!=-1;i=ne[i])
{
int v=ver[i];
if(v==fa) continue;
dfs(v,u);
s[u]+=s[v];
}
r[u]=cnt;
if(s[u]==1)
{
curmid[u]=a[u];
nextmid[u]=100000;
ssum=ssum+curmid[u];
}
else
{
int t=(s[u]+1)/2;
curmid[u]=query(root[l[u]-1],root[r[u]],1,100000,t);
nextmid[u]=query(root[l[u]-1],root[r[u]],1,100000,t+1);
ssum=ssum+curmid[u];
}
}
ll ans=0;
void dfs2(int u,int fa)
{
Add(curmid[u],nextmid[u]-curmid[u]);
ans=max(ans,ssum+Sum(100000)-Sum(a[u]-1));
//求出自身和祖先节点中的curmid比当前权值小的节点的(nextmid[i]-curmid[i])之和
for(int i=h[u];i!=-1;i=ne[i])
{
if(ver[i]==fa) continue;
dfs2(ver[i],u);
}
Add(curmid[u],curmid[u]-nextmid[u]);
//在遍历完该节点的所有子节点后需要减去该节点的(nextmid[i]-curmid[i]),为了防止该节点兄弟节点及其子节点在计算时加上该值。
}
int main()
{
int n;
while(scanf("%d",&n)!=EOF)
{
memset(h,-1,sizeof h);
memset(d,0,sizeof d);
tot=0;
cnt=0;
idx=0;
ans=0;
ssum=0;
root[0]=build(1,100000);
for(int i=1;i<=n;i++)
{
scanf("%d",&a[i]);
}
for(int i=2;i<=n;i++)
{
int x;
scanf("%d",&x);
add(i,x);
add(x,i);
}
dfs(1,-1);
dfs2(1,-1);
printf("%lld\n",ans);
}
}