解题思路:
这道题空找一条边不好找,注意到他是使最大的最小,所以不妨考虑二分答案。
设二分的答案为t,则对于每一条len>td的链,必须要修改其上的一条边。如果共要修改m条链,那么修改的那条边一定是这m条链的交集,那么问题就变成了求m条链的交集,看最长链减去交集中最长的一条边是否小于t;
那如何求交集呢?可以用树上差分:
设tag[v]表示m条链经过(v,fa[v])这条边的次数,对于每条链(x,y),使tag[x]+1,tag[y]+1,tag[lca(x,y]-2,最后dfs一遍,tag[i]加上其子树上的所有tag即为经过(i,fa[i])这条边的次数;
要图的话可以看:https://blog.sengxian.com/solutions/noip-2015-day2
代码的lca是倍增的,T了可以改树链剖分。
#include<iostream>
#include<cstdio>
#include<cstring>
#include<string>
#include<algorithm>
#include<cmath>
#include<vector>
#include<queue>
#define ll long long
using namespace std;
int getint()
{
int i=0,f=1;char c;
for(c=getchar();(c<'0'||c>'9')&&c!='-';c=getchar());
if(c=='-')f=-1,c=getchar();
for(;c>='0'&&c<='9';c=getchar())i=(i<<3)+(i<<1)+c-'0';
return i*f;
}
const int N=300005;
struct node
{
int x,y,lca,len;
inline friend bool operator <(const node &a,const node &b)
{
return a.len>b.len;
}
}chain[N];
int n,m,mx;
int tot,first[N],nxt[N<<1],to[N<<1],w[N<<1];
int dep[N],dis[N],fa[N][20],tag[N],mem[N];
inline void add(int x,int y,int z)
{
nxt[++tot]=first[x],first[x]=tot,to[tot]=y,w[tot]=z;
}
void dfs(int u)
{
for(int i=1;i<20;i++)fa[u][i]=fa[fa[u][i-1]][i-1];
for(int e=first[u];e;e=nxt[e])
{
int v=to[e];
if(v!=fa[u][0])
{
dis[v]=dis[u]+w[e];
dep[v]=dep[u]+1;
fa[v][0]=u;
dfs(v);
}
}
}
int LCA(int x,int y)
{
if(dep[x]<dep[y])swap(x,y);
int delta=dep[x]-dep[y];
for(int i=19;i>=0;i--)
if(delta&(1<<i))x=fa[x][i];
for(int i=19;i>=0;i--)
if(fa[x][i]!=fa[y][i])
x=fa[x][i],y=fa[y][i];
return x==y?x:fa[x][0];
}
void find(int u)
{
for(int e=first[u];e;e=nxt[e])
{
int v=to[e];
if(v!=fa[u][0])
{
find(v);
tag[u]+=tag[v];
}
}
}
bool check(int lim)
{
if(mx<=lim)return true;
memset(tag,0,sizeof(tag));
int cnt=0;
for(int i=1;i<=m;i++)
if(chain[i].len>lim)
{
cnt++;
tag[chain[i].x]++;
tag[chain[i].y]++;
tag[chain[i].lca]-=2;
}
else break;
if(mem[cnt]!=-1)return mx-mem[cnt]<=lim;
find(1);
int less=-1;
for(int i=1;i<=n;i++)
if(tag[i]==cnt)less=max(less,dis[i]-dis[fa[i][0]]);
if(less!=-1)mem[cnt]=less;
return mx-less<=lim;
}
int main()
{
//freopen("lx.in","r",stdin);
ios::sync_with_stdio(false);
cin.tie(NULL);
int x,y,z,l=0,r=0;
memset(mem,-1,sizeof(mem));
n=getint(),m=getint();
for(int i=1;i<n;i++)
{
x=getint(),y=getint(),z=getint();
add(x,y,z),add(y,x,z);
r+=z;
}
dfs(1);
for(int i=1;i<=m;i++)
{
chain[i].x=getint(),chain[i].y=getint();
chain[i].lca=LCA(chain[i].x,chain[i].y);
chain[i].len=dis[chain[i].x]+dis[chain[i].y]-2*dis[chain[i].lca];
mx=max(mx,chain[i].len);
}
sort(chain+1,chain+m+1);
while(l<=r)
{
int mid=l+r>>1;
if(check(mid))r=mid-1;
else l=mid+1;
}
cout<<l;
return 0;
}