题解:
考虑两棵树怎么做,要求
max
d
1
(
x
)
+
d
1
(
y
)
−
2
∗
l
c
a
1
(
x
,
y
)
+
d
i
s
2
(
x
,
y
)
\max d_1(x)+d_1(y)-2*lca_1(x,y)+dis_2(x,y)
maxd1(x)+d1(y)−2∗lca1(x,y)+dis2(x,y),把
x
′
x'
x′挂在
x
x
x下面,长度为
d
1
(
x
)
d_1(x)
d1(x),然后就是
max
−
2
∗
l
c
a
1
(
x
,
y
)
+
d
i
s
2
(
x
,
y
)
\max -2*lca_1(x,y)+dis_2(x,y)
max−2∗lca1(x,y)+dis2(x,y)。
相当于对于第一棵树求在第二棵树里面的直径。 因为没有负边权,我们直接记录最长链即可。
第三棵树采用边分治,对于每一层,相当于求 max − 2 ∗ l c a 1 ( x , y ) + d i s 2 ( x , y ) + d 3 ( x ) + d 3 ( y ) \max -2*lca_1(x,y)+dis_2(x,y)+d_3(x)+d_3(y) max−2∗lca1(x,y)+dis2(x,y)+d3(x)+d3(y)。 用类似的方法,接在下方,然后相当于是第一棵树的每棵子树求一下两类点(边分治的两边)的直径。 A A A到 B B B集合的直径一定在 A A A集合最长链, B B B集合最长链的两端取得(在没有负权的情况下)。 然后就是一样的做法了。
用归并排序建虚树, O ( 1 ) O(1) O(1)lca,可以做到 O ( n log n ) O(n \log n) O(nlogn)。
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
typedef pair <int,LL> pii;
const int RLEN=1<<18|1;
inline char nc() {
static char ibuf[RLEN],*ib,*ob;
(ib==ob) && (ob=(ib=ibuf)+fread(ibuf,1,RLEN,stdin));
return (ib==ob) ? -1 : *ib++;
}
inline LL rd() {
char ch=nc(); LL i=0,f=1;
while(!isdigit(ch)) {if(ch=='-')f=-1; ch=nc();}
while(isdigit(ch)) {i=(i<<1)+(i<<3)+ch-'0'; ch=nc();}
return i*f;
}
const int N=2e5+50, L=18;
int n,lg[N]; LL ans;
struct chain {
int x,y; LL l;
chain(int x=0,int y=0,LL l=0) : x(x),y(y),l(l) {}
friend inline bool operator <(const chain &a,const chain &b) {return a.l<b.l || (a.l==b.l && a.x<b.x);}
};
struct T1 {
int dfn[N],sze[N],pos[N],st[N][L+1],id[N];
LL dep[N]; vector <pii> e[N];
int tot,ind;
inline void dfs(int x,int f) {
id[dfn[x]=++ind]=x; sze[x]=1;
st[pos[x]=++tot][0]=ind;
for(auto v:e[x]) if(v.first^f) {
dep[v.first]=dep[x]+v.second;
dfs(v.first,x); st[++tot][0]=dfn[x];
sze[x]+=sze[v.first];
}
}
inline void init() {
for(int i=1;i<n;i++) {
int x=rd(), y=rd(); LL w=rd();
e[x].push_back(pii(y,w));
e[y].push_back(pii(x,w));
} dfs(1,0);
for(int i=1;i<=lg[tot];i++)
for(int j=1;j+(1<<i)-1<=tot;++j)
st[j][i]=min(st[j][i-1],st[j+(1<<(i-1))][i-1]);
}
inline int lca(int x,int y) {
if(pos[x]>pos[y]) swap(x,y);
int l=lg[pos[y]-pos[x]+1];
return id[min(st[pos[x]][l],st[pos[y]-(1<<l)+1][l])];
}
inline bool in(int x,int y) {return dfn[y]<dfn[x]+sze[x] && dfn[y]>dfn[x];}
} t1,t2;
namespace t3 {
int stx,sty,mxv,total; LL stl;
int bl[N],sze[N],vs,tot; LL dep[N];
vector <pii> e1[N];
vector <pii> e2[N];
vector <int> e[N];
inline void dfs(int x,int f) {
vector <pii> vec;
for(auto v:e1[x]) if(v.first^f)
vec.push_back(v), dfs(v.first,x);
for(int j=0;j<vec.size();++j) {
e2[++tot].push_back(vec[j]);
e2[vec[j].first].push_back(pii(tot,vec[j].second));
int pre=j ? tot-1 : x;
e2[pre].push_back(pii(tot,0));
e2[tot].push_back(pii(pre,0));
}
}
inline void init() {
for(int i=1;i<n;i++) {
int x=rd(), y=rd(); LL w=rd();
e1[x].push_back(pii(y,w));
e1[y].push_back(pii(x,w));
} tot=n; dfs(1,0);
}
inline void calcG(int x,int f) {
sze[x]=1;
for(auto v:e2[x]) if(v.first^f && bl[v.first]==bl[x]) {
calcG(v.first,x); sze[x]+=sze[v.first];
if(max(sze[v.first],total-sze[v.first])<=mxv) {
mxv=max(sze[v.first],total-sze[v.first]);
stx=x; sty=v.first; stl=v.second;
}
} bl[x]=vs;
}
int stk[N],id[N],top,ic;
int fa[N],a[N],cnt,rt;
chain mx[N][2];
inline void build_vir(vector <int> &q) {
cnt=top=0;
for(auto i:q) a[++cnt]=i;
for(auto i:q) {
if(!top) stk[++top]=rt=i;
else {
int l=t1.lca(i,stk[top]);
while(t1.in(l,stk[top])) {
if(top==1 || t1.in(stk[top-1],l)) fa[stk[top]]=l;
--top;
}
if(l!=stk[top]) {
if(!top) rt=l;
a[++cnt]=l;
fa[l]=stk[top];
stk[++top]=l;
}
fa[i]=stk[top];
stk[++top]=i;
}
}
for(int i=1;i<=cnt;i++) e[a[i]].clear();
for(int i=1;i<=cnt;i++) if(a[i]!=rt) e[fa[a[i]]].push_back(a[i]);
}
inline LL ask_dis(int x,int y) {
int l=t2.lca(x,y);
return t2.dep[x]+t2.dep[y]-2*t2.dep[l]+t1.dep[x]+dep[x]+t1.dep[y]+dep[y];
}
inline void upt_ans(int i,int x,int y) {
ans=max(ans,ask_dis(x,y)-2*t1.dep[i]);
}
inline chain merge(chain &a,chain &b) {
chain c=chain(0,0,0);
c=max(c,chain(a.x,b.x,ask_dis(a.x,b.x)));
c=max(c,chain(a.x,b.y,ask_dis(a.x,b.y)));
c=max(c,chain(a.y,b.x,ask_dis(a.y,b.x)));
c=max(c,chain(a.y,b.y,ask_dis(a.y,b.y)));
return max(a,max(b,c));
}
inline void merge(int x,int y) {
for(int i=0;i<=1;i++) {
upt_ans(x,mx[x][i].x,mx[y][i^1].x);
upt_ans(x,mx[x][i].x,mx[y][i^1].y);
upt_ans(x,mx[x][i].y,mx[y][i^1].x);
upt_ans(x,mx[x][i].y,mx[y][i^1].y);
}
mx[x][0]=merge(mx[x][0],mx[y][0]);
mx[x][1]=merge(mx[x][1],mx[y][1]);
}
inline void dfs_ans(int x,int f) {
mx[x][0]=(id[x]&1) ? chain(0,0,0) : chain(x,x,0);
mx[x][1]=(id[x]&1) ? chain(x,x,0) : chain(0,0,0);
for(auto v:e[x]) {
dfs_ans(v,x);
merge(x,v);
}
}
inline void dfs_pre(int x,int f) {
for(auto v:e2[x]) if(v.first^f && id[v.first]==ic) {
dep[v.first]=dep[x]+v.second; dfs_pre(v.first,x);
}
}
inline vector <int> calc(int l,int r,LL mid,vector <int> &ql,vector <int> &qr) {
vector <int> q;
int hl=0, hr=0;
while(hl<ql.size() && hr<qr.size()) q.push_back((t1.dfn[ql[hl]]<t1.dfn[qr[hr]]) ? ql[hl++] : qr[hr++]);
while(hl<ql.size()) q.push_back(ql[hl++]); while(hr<qr.size()) q.push_back(qr[hr++]);
vector <int> ct;
for(auto i:q) if(i<=n) ct.push_back(i);
build_vir(ct);
++ic; for(auto i:ql) id[i]=ic;
dep[l]=0; dfs_pre(l,r);
++ic; for(auto i:qr) id[i]=ic;
dep[r]=mid; dfs_pre(r,l);
dfs_ans(rt,0);
return q;
}
inline vector <int> solve(int x,int y,LL len) {
vector <int> ql,qr;
if(total-sze[y]>1) {
++vs; total=mxv=total-sze[y];
calcG(x,y); ql=solve(stx,sty,stl);
} else ql.push_back(x);
if(sze[y]>1) {
++vs; total=mxv=sze[y];
calcG(y,x); qr=solve(stx,sty,stl);
} else qr.push_back(y);
ql=calc(x,y,len,ql,qr);
return ql;
}
inline void solve() {
mxv=total=tot;
calcG(1,0);
solve(stx,sty,stl);
}
}
int main() {
n=rd(); lg[1]=0;
for(int i=2;i<=n*2;i++) lg[i]=lg[i>>1]+1;
t1.init(); t2.init();
t3::init();
t3::solve();
cout<<ans<<'\n';
}