暴力树剖做法显然,即使做到两个log也不那么优美。
考虑避免树剖做到一个log。那么容易想到树上差分,也即要对每个点统计所有经过他的路径产生的总贡献(显然就是所有这些路径端点所构成的斯坦纳树大小),并支持在一个log内插入删除合并。
考虑怎么求树上一些点所构成的斯坦纳树大小。由虚树的构造过程容易联想到,这就是按dfs序排序后这些点的深度之和-相邻点的lca的深度之和(首尾视作相邻),也就相当于按dfs序遍历所有要经过的点并回到原点的路径长度/2。
这个东西显然(应该)可以set启发式合并维护,但同样就变成了两个log。可以改为线段树合并,线段树上每个节点维护该dfs序区间内dfs序最小和最大的被选中节点,合并时减去跨过两区间的一对相邻点的lca的深度即可。这需要计算O(nlogn)次lca,使用欧拉序rmq做到O(1)lca查询就能以总复杂度O(nlogn)完成。
#include<bits/stdc++.h>
using namespace std;
#define ll long long
#define N 100010
char getc(){char c=getchar();while ((c<'A'||c>'Z')&&(c<'a'||c>'z')&&(c<'0'||c>'9')) c=getchar();return c;}
int gcd(int n,int m){return m==0?n:gcd(m,n%m);}
int read()
{
int x=0,f=1;char c=getchar();
while (c<'0'||c>'9') {if (c=='-') f=-1;c=getchar();}
while (c>='0'&&c<='9') x=(x<<1)+(x<<3)+(c^48),c=getchar();
return x*f;
}
int n,m,p[N],dfn[N],id[N],fa[N],deep[N],cnt,t;
struct data{int to,nxt;
}edge[N<<1];
vector<int> ins[N],del[N];
void addedge(int x,int y){t++;edge[t].to=y,edge[t].nxt=p[x],p[x]=t;}
void dfs(int k)
{
dfn[k]=++cnt;id[cnt]=k;
for (int i=p[k];i;i=edge[i].nxt)
if (edge[i].to!=fa[k])
{
fa[edge[i].to]=k;
deep[edge[i].to]=deep[k]+1;
dfs(edge[i].to);
}
}
namespace euler_tour
{
int dfn[N],id[N<<1],LG2[N<<1],f[N<<1][19],cnt;
void dfs(int k)
{
dfn[k]=++cnt;id[cnt]=k;
for (int i=p[k];i;i=edge[i].nxt)
if (edge[i].to!=fa[k])
{
dfs(edge[i].to);
id[++cnt]=k;
}
}
void build()
{
dfs(1);
for (int i=1;i<=cnt;i++) f[i][0]=id[i];
for (int j=1;j<19;j++)
for (int i=1;i<=cnt;i++)
if (deep[f[i][j-1]]<deep[f[min(cnt,i+(1<<j-1))][j-1]]) f[i][j]=f[i][j-1];
else f[i][j]=f[min(cnt,i+(1<<j-1))][j-1];
for (int i=2;i<=cnt;i++)
{
LG2[i]=LG2[i-1];
if ((2<<LG2[i])<=i) LG2[i]++;
}
}
int lca(int x,int y)
{
if (!x||!y) return 0;
x=dfn[x],y=dfn[y];
if (x>y) swap(x,y);
if (deep[f[x][LG2[y-x+1]]]<deep[f[y-(1<<LG2[y-x+1])+1][LG2[y-x+1]]]) return f[x][LG2[y-x+1]];
else return f[y-(1<<LG2[y-x+1])+1][LG2[y-x+1]];
}
}
using euler_tour::lca;
ll ans;
int root[N];
struct data2{int l,r,cnt,lnode,rnode,ans;
}tree[N<<6];
void up(int k)
{
tree[k].lnode=tree[tree[k].l].lnode;if (!tree[k].lnode) tree[k].lnode=tree[tree[k].r].lnode;
tree[k].rnode=tree[tree[k].r].rnode;if (!tree[k].rnode) tree[k].rnode=tree[tree[k].l].rnode;
tree[k].ans=tree[tree[k].l].ans+tree[tree[k].r].ans-deep[lca(tree[tree[k].l].rnode,tree[tree[k].r].lnode)];
}
int merge(int x,int y,int l,int r)
{
if (!x||!y) return x|y;
if (l==r)
{
tree[x].cnt+=tree[y].cnt;
if (tree[x].cnt==0) tree[x].lnode=tree[x].rnode=tree[x].ans=0;
else tree[x].lnode=tree[x].rnode=id[l],tree[x].ans=deep[id[l]];
return x;
}
int mid=l+r>>1;
tree[x].l=merge(tree[x].l,tree[y].l,l,mid);
tree[x].r=merge(tree[x].r,tree[y].r,mid+1,r);
up(x);
return x;
}
void modify(int &k,int l,int r,int x,int op)
{
if (!k) k=++cnt;
if (l==r)
{
tree[k].cnt+=op;
if (tree[k].cnt==0) tree[k].lnode=tree[k].rnode=tree[k].ans=0;
else tree[k].lnode=tree[k].rnode=id[l],tree[k].ans=deep[id[l]];
return;
}
int mid=l+r>>1;
if (x<=mid) modify(tree[k].l,l,mid,x,op);
else modify(tree[k].r,mid+1,r,x,op);
up(k);
}
void solve(int k)
{
for (int i=p[k];i;i=edge[i].nxt)
if (edge[i].to!=fa[k])
{
solve(edge[i].to);
root[k]=merge(root[k],root[edge[i].to],1,n);
}
for (int i:ins[k]) modify(root[k],1,n,dfn[i],1);
for (int i:del[k]) modify(root[k],1,n,dfn[i],-1);
ans+=tree[root[k]].ans-deep[lca(tree[root[k]].lnode,tree[root[k]].rnode)];
}
int main()
{
#ifndef ONLINE_JUDGE
freopen("a.in","r",stdin);
freopen("a.out","w",stdout);
const char LL[]="%I64d\n";
#else
const char LL[]="%lld\n";
#endif
n=read(),m=read();
for (int i=1;i<n;i++)
{
int x=read(),y=read();
addedge(x,y),addedge(y,x);
}
dfs(1);
euler_tour::build();
for (int i=1;i<=m;i++)
{
int x=read(),y=read(),z=fa[lca(x,y)];
ins[x].push_back(x);ins[x].push_back(y);
ins[y].push_back(x);ins[y].push_back(y);
del[z].push_back(x);del[z].push_back(y);
del[z].push_back(x);del[z].push_back(y);
}
cnt=0;
solve(1);
cout<<ans/2;
return 0;
}