原题大意
是给一颗树,有 个节点,从
开始编号到
。现在考虑树中两个点(两个点不相同,不考虑先后顺序)的最短路,在自然数集除去最短路上的所有点的编号后,剩下的最小自然数就是这条路径的MEX值,要求输出MEX值从
到
对应的这种最短路的数量。对于每一个数据,给了
组数据,每组数据依次按要求输出即可。
题解(来自官方)
想要算出每一个 所对应的路径数量,只需要算出每一个
对应的路径数量即可。
记 的路径数目为
;记
的路径数目为
;记在根节点为 0 的情况下节点
的子树的节点个数为
;
先计算 ,算一下根节点的子节点的子树内部路径总和即可,
(节点
与根节点
直接相连)。
考虑到总路径数目为 ;易知
。
!!!本题重点:如果一个路径(必为一条链)的MEX值大于 则这条路径必须包括点
。(1)
考虑一个双指针 ,
表示链的端点
,
;初始状态
,
都在根节点
。我们使用这个双指针计算
,自然我们把
以
到
的顺序做一遍循环来计算。
现在计算
1.如果 不在
,
这条路径上,那么它一定只能往父节点方向一直走,直到遇到
,
这一条链上某一个点为止,然后这个点就应该和
,
链进行连接。
(1)如果连接后一条链(就是遇到的那个点是l或r),那么路径数量是两颗子树元素个数的积,即 ;因为由(1)知
的路径一定包含
,
的路径。(当
或
是根节点时的“子树”是比较特殊的)。
这里可以很直观的体会路径数量为什么是两颗子树元素个数的积(根节点的”子树“很特殊,要减去一个子节点的子树(这颗子树包含 )。)
(2) 如果这个点的路径去连接不能组成一条链
比如这样
那很明显可以知道 ;因为不可能找到一条 (1) 所描述的路径, 而且容易知道
。
2.如果考虑第 个点时i已经在
,
这条路径上了,那么
,因为在
时算的路径数容易知道其实都是大于
的。
所以 算出来了就可以推出
了,容易发现在算
时,可以只储存前一个
(类似滚动数组?),这个
就是官方题解说的
。
附代码(可读性极差):
#include<iostream>
#include<cstdio>
using namespace std;
long long t;
long long n,l,r,pl,ooi,ccp;
long long subl,subr;
long long p;
long long ans,anss;
long long num[200010];
long long vi[200010];
long long head[400010];
long long rr[400010];
long long tugi[400010];
void link(long long ql,long long qr)
{
pl++;
tugi[pl]=head[ql];
head[ql]=pl;
rr[pl]=qr;
return;
}
void dfs(long long pl1)
{
//cout<<pl1<<endl;
vi[pl1]=1;
num[pl1]++;
long long pll=head[pl1];
while (pll)
{
if (!vi[rr[pll]])
{
dfs(rr[pll]);
num[pl1]+=num[rr[pll]];
}
pll=tugi[pll];
}
return;
}
long long zhao(long long pl1)
{
if (vi[pl1]==1)
return pl1;
vi[pl1]=1;
long long pll=head[pl1];
while (pll)
{
if (num[rr[pll]]>num[pl1])
{
if (rr[pll]==0)
ooi=pl1;
return zhao(rr[pll]);
}
pll=tugi[pll];
}
}
long long pll,ss;
int main()
{
cin>>t;
for (long long i=1; i<=t; i++)
{
cin>>n;
pl=0;
for (long long i=1; i<=n-1; i++)
{
cin>>l>>r;
link(l,r);
link(r,l);
}
dfs(0);
//for (long long i=0; i<=n-1; i++) cout<<num[i]<<" "; cout<<endl;
pll=head[0];
ans=0;
while (pll)
{
ans+=(long long)num[rr[pll]]*(num[rr[pll]]-1)/2;
pll=tugi[pll];
}
p=(long long)n*(n-1)/2-ans;
cout<<ans<<' ';
for (long long i=1; i<=n; i++)
vi[i]=0;
l=0; r=0;
vi[0]=1;
ccp=0;
anss=ans;
//cout<<p<<"&";
for (long long i=1; i<=n-1; i++)
{
if (ccp||vi[i])
{
cout<<0<<' ';
continue;
}
ss=zhao(i);
//cout<<ss<<"*";
if (ss==r)
r=i;
else
{
if (ss==l)
l=i;
else
{
ccp=1;
cout<<p<<' ';
anss+=p;
continue;
}
}
//cout<<l<<" "<<r<<" "<<ooi<<" &";
subl=num[l];
if (l==0)
subl-=num[ooi];
subr=num[r];
ans=p-(long long)subl*subr;
anss+=ans;
p=subl*subr;
cout<<ans<<' ';
}
ans=(long long)n*(n-1)/2-anss;
cout<<ans;
for (long long i=0; i<=n; i++)
{
vi[i]=0;
num[i]=0;
head[i]=0;
rr[i]=0;
tugi[i]=0;
}
for (int i=n+1; i<=2*n; i++)
{
head[i]=0;
rr[i]=0;
tugi[i]=0;
}
putchar(10);
}
return 0;
}