给1棵树,编号1~n,还有m条路径,现在问存在多少点对u,vu,v,满足uu到的必经之路上,不会经过这m条路径中的任何一条。
n,m<=1e5n,m<=1e5
求出树的dfs序。把路径两头的dfs看成2维坐标(u,v)(u,v),
如果x,y不是祖先关系,则禁止的区域为:从x的子树到y的子树。
如果是祖先关系,假设x的深度小,且往y方向下面一个点是z,
那么禁止的区域为:一个点在y的子树中,一个点不在z的子树中。
矩形面积并。
#include<bits/stdc++.h>
using namespace std;
#define For(i,n) for(int i=1;i<=n;i++)
#define Fork(i,k,n) for(int i=k;i<=n;i++)
#define ForkD(i,k,n) for(int i=n;i>=k;i--)
#define Rep(i,n) for(int i=0;i<n;i++)
#define ForD(i,n) for(int i=n;i>0;i--)
#define RepD(i,n) for(int i=n;i>=0;i--)
#define Forp(x) for(int p=pre[x];p;p=next[p])
#define Forpiter(x) for(int &p=iter[x];p;p=next[p])
#define MEM(a) memset(a,0,sizeof(a));
#define MEMI(a) memset(a,0x3f,sizeof(a));
#define MEMi(a) memset(a,128,sizeof(a));
#define MEMx(a,b) memset(a,b,sizeof(a));
#define INF (0x3f3f3f3f)
#define F (1000000007)
#define pb push_back
#define mp make_pair
#define fi first
#define se second
#define vi vector<int>
#define pi pair<int,int>
#define SI(a) ((a).size())
#define Pr(kcase,ans) printf("Case #%d: %lld\n",kcase,ans);
#define PRi(a,n) For(i,n-1) cout<<a[i]<<' '; cout<<a[n]<<endl;
#define PRi2D(a,n,m) For(i,n) { \
For(j,m-1) cout<<a[i][j]<<' ';\
cout<<a[i][m]<<endl; \
}
#pragma comment(linker, "/STACK:102400000,102400000")
#define All(x) (x).begin(),(x).end()
#define gmax(a,b) a=max(a,b);
#define gmin(a,b) a=min(a,b);
typedef long long ll;
typedef long double ld;
typedef unsigned long long ull;
ll mul(ll a,ll b){return (a*b)%F;}
ll add(ll a,ll b){return (a+b)%F;}
ll sub(ll a,ll b){return ((a-b)%F+F)%F;}
void upd(ll &a,ll b){a=(a%F+b%F)%F;}
inline int read()
{
int x=0,f=1; char ch=getchar();
while(!isdigit(ch)) {if (ch=='-') f=-1; ch=getchar();}
while(isdigit(ch)) { x=x*10+ch-'0'; ch=getchar();}
return x*f;
}
#define MAXN (101000)
#define MAXn (MAXN<<2)
int n,m,dep[MAXN]={};
vi e[MAXN];
int tot=0;
int in[MAXN],out[MAXN];
int fa[MAXN][25]={0},len=20;
void dfs(int x,int faa){
in[x]=++tot;
fa[x][0]=faa;
For(i,len) {
fa[x][i]=fa[fa[x][i-1]][i-1];
}
dep[x]=dep[faa]+1;
for(int v:e[x]) if(v^faa) {
dfs(v,x);
}
out[x]=tot;
}
int anc(int x,int dep) {
for(int i=20,j=1<<20;i>=0;i--,j>>=1) {
if(dep>=j) dep-=j,x=fa[x][i];
}
return x;
}
int lca(int u , int v)
{
if(dep[u] < dep[v]) swap(u , v) ;
int d = dep[u] - dep[v] ;
int i ;
for(i = 0 ; i <= 20 ; i ++)
{
if( (1 << i) & d )
{
u = fa[u][i] ;
}
}
if(u == v) return u ;
for(i = 20 ; i >= 0 ; i --)
{
if(fa[u][i] != fa[v][i])
{
u = fa[u][i] ;
v = fa[v][i] ;
}
}
u = fa[u][0] ;
return u ;
}
#define M (MAXn)
struct Segment
{
int mi[M],minum[M],tar[M];
#define ls (x<<1)
#define rs (x<<1|1)
void pushup(int x)
{
if(mi[ls]==mi[rs])
mi[x]=mi[ls],minum[x]=minum[ls]+minum[rs];
else
if(mi[ls]<mi[rs])
mi[x]=mi[ls],minum[x]=minum[ls];
else
mi[x]=mi[rs],minum[x]=minum[rs];
}
void pushdown(int x)
{
mi[ls]+=tar[x];mi[rs]+=tar[x];
tar[ls]+=tar[x];tar[rs]+=tar[x];
tar[x]=0;
}
void build(int x,int l,int r)
{
if(l==r)
{
mi[x]=0;minum[x]=1;tar[x]=0;
return;
}
int mid=(l+r)>>1;
build(ls,l,mid);build(rs,mid+1,r);
pushup(x);
}
void update(int x,int l,int r,int L,int R,int v)
{
if(L<=l && r<=R)
{
mi[x]+=v;tar[x]+=v;
return;
}
pushdown(x);
int mid=(l+r)>>1;
if(L<=mid) update(ls,l,mid,L,R,v);
if(R>mid) update(rs,mid+1,r,L,R,v);
pushup(x);
}
}tr;
struct Tmat
{
int x,y,dy,ad;
Tmat(){}
Tmat(int xx,int yy,int dyy,int da){x=xx;y=yy;dy=dyy;ad=da;}
void pri() {
cout<<x<<' '<<y<<' '<<dy<<' '<<ad<<endl;
}
};
Tmat a[MAXN*32];
bool cmp(Tmat A,Tmat B)
{
return A.x<B.x;
}
int cnt=0;
void add(int a1,int a2,int b1,int b2) {
++cnt;
a[cnt]=Tmat(b1,a1,a2,1);
if(b2<n) a[++cnt]=Tmat(b2+1,a1,a2,-1);
swap(a1,b1);swap(a2,b2);
++cnt;
a[cnt]=Tmat(b1,a1,a2,1);
if(b2<n) a[++cnt]=Tmat(b2+1,a1,a2,-1);
}
void solve()
{
ll ans=0;
tr.build(1,1,n);
sort(a+1,a+cnt+1,cmp);
For(i,cnt) {
}
int nowl=1,nowr=1;
int pr=0;
for(int i=1;i<=n;++i)
{
if(a[nowl].x^i) //no modify
{
if(!tr.mi[1]) ans+=tr.minum[1];
continue;
}
tr.update(1,1,n,a[nowl].y,a[nowl].dy,a[nowl].ad); //update
while(a[nowr+1].x==a[nowl].x)
++nowr,tr.update(1,1,n,a[nowr].y,a[nowr].dy,a[nowr].ad);
if(!tr.mi[1]) ans+=tr.minum[1];
nowl=nowr+1;nowr=nowl;
}
cout<<ans-n<<endl;
}
int main()
{
// freopen("E.in","r",stdin);
// freopen(".out","w",stdout);
cin>>n>>m;
For(i,n-1) {
int u=read(),v=read();
e[u].pb(v);
e[v].pb(u);
}
dfs(1,0);
For(i,m) {
int u=read(),v=read();
int g=lca(u,v);
if(g!=u&&g!=v) {
if(in[u]>in[v]) swap(u,v);
add(in[u],out[u],in[v],out[v]);
}
else {
int x=u,y=v;
if(dep[x]>dep[y]) swap(x,y);
int z=anc(y,dep[y]-dep[x]-1);
if(in[z]>1)add(1,in[z]-1,in[y],out[y]);
if(out[z]<n)add(in[y],out[y],out[z]+1,n);
}
}
ll ans=0;
solve();
return 0;
}