很精彩的一道题!题意还是很清晰的。自己想了半天没有思路,最后参考了题解。
首先问题可以变成所有方案数(
n
∗
(
n
−
1
)
2
\frac{n*(n-1)}{2}
2n∗(n−1))减去非法方案数。
考虑每对情侣对方案的负贡献:
当不考虑重复时,
当两个结点非祖先-后代关系时,贡献也就是两个子树的大小的乘积。
当两个结点为祖先-后代关系时,贡献为后代的子树与(所有点,除了祖先指向后代的儿子的子树)的乘积
考虑用一n*n的初始为0的矩阵来表示每一个方案,矩阵上(i,j)或(j,i)点为正数表示i与j的方案是不合法的,那么用dfs序来映射点,一个点的子树的dfs序连续,每个贡献可以表示为若干矩形区域,那么用线段树+扫描线即可计算出不合法的方案数。简单来说,按行从小到大扫描,如果碰到一个贡献矩形的上边缘,则加入贡献矩形左边缘到右边缘为1,如果碰到一个贡献矩形的下边缘之下一层,则加入贡献矩形左边缘到右边缘为-1,也就是线段树区间加法。然后查询只要看第一个节点值就行了。
线段树的T记录为正数的点数,也就是是扫描到第i行时,与深搜序为i的那个点形成的非法方案数。而addv记录区间的修改。
注意,这里线段树不需要标记下传。原因在于某个点的值是否为正数,只和它到线段树根节点的路径上是否存在addv>0有关(因为修改的对称性,addv不可能为负数),不再有时效性,只要存在就相当于被覆盖了,因此默认从根节点向下扫描,如果addv>0,下面的T和addv就不用关心,这个区间都是被覆盖了。
当然,如果不是很放心,写个标记下传也可以的。
#include <cstdio>
#include<algorithm>
#include<vector>
#define kl (k<<1)
#define kr (k<<1|1)
#define M (L+R>>1)
#define lin L,M
#define rin M+1,R
using namespace std;
using LL=long long;
struct item
{
int l,r,d;
item(int l=0,int r=0,int d=0) : l(l),r(r),d(d)
{
}
}ask;
int n,m,dfn[100005],sons[100005],dfs_clock;
int dep[100005],p[100005][17],u,v;
vector<int> E[100005];
vector<item> Q[100005];
int T[1<<18],addv[1<<18];
LL ans;
void dfs(int x,int fa)
{
dfn[x]=++dfs_clock;
sons[x]=1;
for(int i:E[x])
if(i!=fa)
{
dep[i]=dep[x]+1;
p[i][0]=x;
for(int j=1;1<<j<=dep[i];j++)
p[i][j]=p[p[i][j-1]][j-1];
dfs(i,x),sons[x]+=sons[i];
}
}
void fun(int la,int ra,int lb,int rb)
{
if(lb<rb)
{
Q[la].emplace_back(lb,rb-1,1);
Q[ra].emplace_back(lb,rb-1,-1);
}
if(la<ra)
{
Q[lb].emplace_back(la,ra-1,1);
Q[rb].emplace_back(la,ra-1,-1);
}
}
int find_son_of_u(int x,int y)
{
int i;
for(i=0;1<<i<=dep[y]-dep[x];i++);
i--;
for(int j=i;j>=0;j--)
if(1<<j<dep[y]-dep[x])
y=p[y][j];
return y;
}
void modify(int k,int L,int R)
{
if(ask.l<=L&&R<=ask.r)
{
addv[k]+=ask.d;
if(addv[k]>0)
T[k]=R-L+1;
else
T[k]=L<R?T[kl]+T[kr]:0;
return;
}
if(ask.l<=M)
modify(kl,lin);
if(ask.r>M)
modify(kr,rin);
T[k]=addv[k]>0?R-L+1:T[kl]+T[kr];
}
int main()
{
scanf("%d%d",&n,&m);
for(int i=1;i<n;i++)
scanf("%d%d",&u,&v),E[u].push_back(v),E[v].push_back(u);
dfs(1,0);
while(m--)
{
scanf("%d%d",&u,&v);
if(dep[u]>dep[v])
swap(u,v);
int tmp=find_son_of_u(u,v);
if(p[tmp][0]==u)
{
fun(1,dfn[tmp],dfn[v],dfn[v]+sons[v]);
fun(dfn[tmp]+sons[tmp],n+1,dfn[v],dfn[v]+sons[v]);
}
else
fun(dfn[u],dfn[u]+sons[u],dfn[v],dfn[v]+sons[v]);
}
for(int i=1;i<=n;i++)
{
for(auto &t:Q[i])
ask=t,modify(1,1,n);
ans+=T[1];
}
printf("%lld",(LL) n*(n-1)-ans>>1);
return 0;
}