Problem
给出一棵n(≤250000)个节点的树,要求将上面的所有边变为单向边,方向自定,求最小的和最大的其中一个点可以到达另一个点的点对。
Solution
对于最小值,奇数层的边向上,偶数层的边向下,不就可以了吗?对啊!!!我刚做时竟然没想到!!!!!证明的话,感性理解一下。
对于最大值,就麻烦了。
首先,我们必须承认一个可能被我们贴上“错误”标签的结论:答案一定是以一个点作为源点,其它点要么通向这个点,要么可以从这个点到达。(反正我在看题解前根本没想过这个结论)(这个结论我们也得感性理解一下)考虑如果这样的点存在两个。。其间一定存在一条路径。。设第一个点进入的是a,出去的是b,第二个进入的是c,出去的是d,设sum=a+b+c+d,即证:a*b+b*c+c*d<=max((sum-a)*a,(sum-b)*b,(sum-c)*c,(sum-d)*d),然后这是显然的。
那么我们肯定要选重心。因为如果不选重心的话,往重心的方向前进一步,答案肯定只会加不会减的。
如何求树的重心呢?其实也很简单,我们只需设size[x]表示以x为根的子树的点数,做一遍树形DP,然后对于一个点x,它自己的子树中点数最大的是max(size[y])(y∈son(x))max(size[y])(y∈son(x)),而当它变成根时,它原来的父亲节点就会变成它的一个子节点,大小为n-size[x]。所以只要这两个数均≤n/2,x就是重心。
然后我们以重心为根,做一遍树形DP,求出它的每个子树自身的答案(因为它的每个子树的边都是要么全部向上要么全部向下,答案一致)以及新的size。
此时,我们需确定究竟将那些子树的边定为全部向上,其他子树的边则定为全部向下。
设边全部向上的子树的size和为X,向下的size和为Y,则越过根的、可以通达的点对即为X*Y。要保证X*Y最大,我们必须选择的X越接近(n-1)/2越好。
那么上背包。但是如果碰到菊花图,显然会炸。
于是考虑优化这个背包。我们可以把size相同的子树并在一起,做多重背包。那么我们就可以使用二进制优化。比如说,根有12棵子树,size为:3、3、3、3、3、3、3、3、3、3、4、4,那么我们把它合成10个3,2个4,也即1+2+4+3个3,1+1个4,然后变成:3、6、12、9、4、4。二进制优化能够正确就在于二进制数能表示一切数。
分析一下时间复杂度:首先,对于某种size相同的子树,假设它有n-1个,我们可以近似视作n个,那么最多会分为log2n+1log2n+1个(比如上面的10,被分为1+2+4+3),可以近似视作log2nlog2n个;而不同size的子树,则至多会有n−−√n种,因为你想想:1+2+3+4+……=n-1。
这么大的n,也敢带根号,还敢带个log?因为这只是理论上的最坏情况,实际上可能远远达不到。如果碰到菊花图,虽然最多会分成log2nlog2n个,但是不同size的子树种类只有1,所以还是很快的。
时间复杂度:O(nn−−√log2n)O(nnlog2n)。
Code
#include <cstdio>
#include <vector>
using namespace std;
#define N 250001
#define ll long long
#define fo(i,a,b) for(i=a;i<=b;i++)
#define fd(i,a,b) for(i=a;i>=b;i--)
int j,u,v,bary,cnt[N],m,thing[N],x;
ll i,n,size[N],g[N],nn;
bool f[N];
vector<int>edge[N];
void dfs(int x,int fat)
{
int y;
size[x]=1;
for(vector<int>::iterator it=edge[x].begin();it!=edge[x].end();it++)
{
y=*it;
if(y!=fat)
{
dfs(y,x);
size[x]+=size[y];
if(size[y]>n/2||bary)return;
}
}
if(n-size[x]<=n/2)bary=x;
}
void df(int x,int fat)
{
int y;
size[x]=1;
for(vector<int>::iterator it=edge[x].begin();it!=edge[x].end();it++)
{
y=*it;
if(y!=fat)
{
df(y,x);
size[x]+=size[y];
g[x]+=g[y]+size[y];
if(x==bary)cnt[size[y]]++;
}
}
}
int main()
{
freopen("polarization.in","r",stdin);
freopen("polarization.out","w",stdout);
scanf("%d",&n);
fo(i,1,n-1)
{
scanf("%d%d",&u,&v);
edge[u].push_back(v);
edge[v].push_back(u);
}
dfs(1,0);
df(bary,0);
fo(i,1,n)
if(cnt[i])
{
x=1;
while(cnt[i]>=x)
{
thing[++m]=x*i;
cnt[i]-=x;
x*=2;
}
if(cnt[i])thing[++m]=cnt[i]*i;
}
f[0]=1;
nn=(n-1)/2;
fo(i,1,m)
fd(j,nn-thing[i],0)
if(f[j])
f[j+thing[i]]=1;
fd(i,nn,0)
if(f[i])
break;
printf("%lld %lld",n-1,g[bary]+i*(n-1-i));
}