题解其实写得相当详细，但它是英文的，于是我自己翻译了一份（谢绝转载，这是我的心血）。译文可能有错误，但不影响对大部分内容的理解，可供参考。
Code:
#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<vector>
#include<algorithm>
using namespace std;
// Upper bound on vertex ids; the per-vertex arrays below are sized Max + 5.
const int Max = 1000000;
// N: number of vertices. Root: vertex the tree is rooted at by Dfs.
// Mouse: vertex that Path[] measures distances to.
int N, Root, Mouse;
// fa[u]:   parent of u after Dfs (-1 for Root, as main calls Dfs(Root, -1)).
// Path[u]: distance from u down to Mouse when Mouse lies in u's subtree, else -1.
// Mova[u]: value computed by Get_Cost; assigned for leaves and for vertices
//          off the Root-Mouse path. Presumably a move count from the editorial
//          -- TODO confirm.
int fa[Max + 5], Path[Max + 5], Mova[Max + 5];
// Child[u]: adjacency list of u; Dfs erases the parent, leaving only children.
vector<int>Child[Max + 5];
// L: pairs (Path[x], Mova[c]) for every non-root vertex x on the Root-Mouse
// path and each child c of x that is off that path. Filled by Get_Cost and
// sorted in main before the binary search.
vector<pair<int, int> >L;
// Reads the next decimal integer from stdin into num.
// Non-digit characters before the number are skipped. A '-' makes the number
// negative only when it immediately precedes the first digit.
// Returns true when a number was read (even if EOF directly follows its last
// digit). Returns false when EOF is reached before any digit.
bool getint(int & num){
    int c;          // int, not char: getchar() returns EOF, which a char cannot reliably hold
    int flg = 1;
    num = 0;
    while((c = getchar()) < '0' || c > '9'){
        if(c == EOF) return false;
        // Only the character right before the digits decides the sign.
        flg = (c == '-') ? -1 : 1;
    }
    while(c >= '0' && c <= '9'){
        num = num * 10 + (c - '0');
        // EOF simply terminates the number; the value read so far is valid.
        c = getchar();
    }
    num *= flg;
    return true;
}
// Roots the subtree of u at u, with parent ff. It does three things:
//   - sets fa[] for every vertex in the subtree;
//   - erases each vertex's parent from its Child list, so Child[] then holds
//     children only;
//   - fills Path[x] with the distance from x down to Mouse when Mouse is in
//     x's subtree, and -1 otherwise.
// Returns -1 if Mouse is not in u's subtree; otherwise returns Path[u] + 1,
// i.e. the distance as seen from u's parent.
// Implemented with an explicit stack rather than recursion: a path-shaped tree
// with up to Max vertices would otherwise recurse ~1e6 deep and can overflow
// the call stack.
int Dfs(int u, int ff){
    std::vector<int> order;            // preorder of u's subtree
    std::vector<int> stk(1, u);
    fa[u] = ff;
    while(!stk.empty()){
        int x = stk.back();
        stk.pop_back();
        order.push_back(x);
        std::vector<int> & ch = Child[x];
        // Drop the first occurrence of the parent, as the recursive version did.
        std::vector<int>::iterator it = std::find(ch.begin(), ch.end(), fa[x]);
        if(it != ch.end()) ch.erase(it);
        for(std::size_t i = 0; i < ch.size(); ++i){
            fa[ch[i]] = x;
            stk.push_back(ch[i]);
        }
    }
    // Reverse preorder visits every child before its parent.
    for(std::size_t k = order.size(); k-- > 0; ){
        int x = order[k];
        int dis = (x == Mouse) ? 0 : -1;
        const std::vector<int> & ch = Child[x];
        for(std::size_t i = 0; i < ch.size(); ++i)
            if(Path[ch[i]] != -1) dis = std::max(dis, Path[ch[i]] + 1);
        Path[x] = dis;
    }
    return Path[u] == -1 ? -1 : Path[u] + 1;
}
// Computes Mova[] over the subtree of u and appends entries to L.
// d is the value handed down from u's parent (main starts with 0 at Root).
// Rules, applied to each vertex x with incoming value dx:
//   - leaf: Mova[x] = dx.
//   - x off the Root-Mouse path (Path[x] == -1):
//       * children receive dx + (number of children);
//       * with one child, Mova[x] = dx + 1;
//       * otherwise Mova[x] = the second-largest Mova among its children,
//         counting duplicates.
//   - x on the path:
//       * children receive dx + (number of children) - 1, or 0 when x is Root;
//       * plus 1 more when x is Mouse;
//       * for non-root x, (Path[x], Mova[c]) is appended to L for each
//         off-path child c.
// Entries are appended to L in no particular order; main sorts L afterwards.
// Implemented iteratively (explicit stack plus a reverse-preorder pass) so a
// deep tree with up to Max vertices cannot overflow the call stack.
void Get_Cost(int u, int d){
    // Pass 1 (top-down): propagate the incoming value and set every Mova that
    // depends only on it.
    std::vector<std::pair<int, int> > stk(1, std::make_pair(u, d));  // (vertex, incoming value)
    std::vector<int> order;  // non-leaf vertices in preorder
    while(!stk.empty()){
        int x = stk.back().first;
        int dx = stk.back().second;
        stk.pop_back();
        const std::vector<int> & ch = Child[x];
        int deg = static_cast<int>(ch.size());
        if(deg == 0){
            Mova[x] = dx;
            continue;
        }
        int pass;
        if(Path[x] == -1){
            pass = dx + deg;
            if(deg == 1) Mova[x] = dx + 1;
        }else{
            pass = (x == Root) ? 0 : dx + deg - 1;
            if(x == Mouse) ++pass;
        }
        for(std::size_t i = 0; i < ch.size(); ++i)
            stk.push_back(std::make_pair(ch[i], pass));
        order.push_back(x);
    }
    // Pass 2 (bottom-up): reverse preorder guarantees every child's Mova is
    // final before its parent is processed.
    for(std::size_t k = order.size(); k-- > 0; ){
        int x = order[k];
        const std::vector<int> & ch = Child[x];
        if(Path[x] == -1){
            if(ch.size() == 1) continue;  // already set in pass 1
            int fir = -1, sec = -1;
            for(std::size_t i = 0; i < ch.size(); ++i){
                int m = Mova[ch[i]];
                if(m > fir){
                    sec = fir;
                    fir = m;
                }else if(m > sec){
                    sec = m;
                }
            }
            Mova[x] = sec;
        }else if(x != Root){
            for(std::size_t i = 0; i < ch.size(); ++i)
                if(Path[ch[i]] == -1)
                    L.push_back(std::make_pair(Path[x], Mova[ch[i]]));
        }
    }
}
// Feasibility test used by the binary search in main. It returns whether a
// budget of `mid` passes the greedy scan over the sorted list L.
// The scan works as follows:
//   - levels i = 0..P-1 are visited in order, and each level grants one unit
//     of x;
//   - every not-yet-consumed entry with first <= i whose second + y exceeds
//     mid consumes one unit of x and adds one to y (y is updated after the
//     whole level);
//   - the scan fails as soon as x < 0 or y > mid, and succeeds once every
//     entry of L has been consumed.
static bool Can_Finish(int mid, int P){
    int x = 0, y = 0;
    std::size_t pt = 0;
    for(int i = 0; i < P; ++ i){
        if(pt >= L.size()) return true;
        ++ x;
        int delta = 0;
        while(pt < L.size() && L[pt].first <= i){
            if(L[pt].second + y > mid){
                ++ delta;
                -- x;
            }
            ++ pt;
        }
        y += delta;
        if(x < 0 || y > mid) return false;
    }
    // Reached only when the scan runs out of levels (e.g. P <= 0 because Mouse
    // is unreachable from Root). The original inline loop left l and r
    // untouched in this case and the binary search never terminated.
    return pt >= L.size();
}

// Reads the tree (N, Root, Mouse, then N - 1 undirected edges) from
// mousetrap.in. Prints to mousetrap.out the smallest budget in [0, N] for
// which Can_Finish succeeds.
int main(){
    freopen("mousetrap.in", "r", stdin);
    freopen("mousetrap.out", "w", stdout);
    getint(N), getint(Root), getint(Mouse);
    int u, v;
    for(int i = 1; i < N; ++ i){
        getint(u), getint(v);
        Child[u].push_back(v);
        Child[v].push_back(u);
    }
    int P = Dfs(Root, -1);
    Get_Cost(Root, 0);
    std::sort(L.begin(), L.end());
    int l = 0, r = N;
    while(l < r){
        int mid = (l + r) >> 1;
        if(Can_Finish(mid, P)) r = mid;
        else l = mid + 1;
    }
    printf("%d\n", l);
    return 0;
}