// Problem link: http://acm.hdu.edu.cn/showproblem.php?pid=6686
// No write-up for this one -- ran out of energy.
// There are other, simpler ways to solve it.
#include <bits/stdc++.h>
// Inclusive ascending loop: i runs over a, a+1, ..., b.
#define rep(i, a, b) for(int i = (a); i <= (b); i++)
// Inclusive descending loop: i runs over a, a-1, ..., b.
#define per(i, a, b) for(int i = (a); i >= (b); i--)
// Short aliases for common standard-library spellings.
#define pii pair<int,int>
#define mp make_pair
#define all(x) (x).begin(),(x).end()
#define ll long long
#define pb push_back
using namespace std;
// Capacity of the per-vertex arrays below (vertices are indexed from 1 in main).
const int N = 1e5+1000;
// n: vertex count read per test case. m is not used by the code in this file.
// f[u][0..2]: the three largest chain lengths starting at u, kept in
//             non-increasing order; dfs seeds them with 1 and offers
//             f[child][0]+1 per child, dfs2 later offers one more value
//             coming from the parent side.
// g[u][0..1]: the two largest "best path" values offered to u, kept in
//             non-increasing order; dfs offers one value per child, dfs2 one
//             more from the parent side.
int n,m,f[N][3],g[N][2];
// Pairs collected by dfs2: for every tree edge it pushes (-l, r) and (-r, l);
// main sorts them and sweeps over the result.
vector<pii>s;
// Adjacency lists of the tree, filled in main (both directions per edge).
vector<int> nxt[N];
// Bottom-up pass over the tree rooted at the first call's `u`.
//
// For every vertex it fills
//   f[u][0..2] - the three largest downward chain lengths (counted in
//                vertices, a lone vertex counts as 1), in non-increasing order;
//   g[u][0..1] - the two largest values, over the children c of u, of
//                max(g[c][0], f[c][0] + f[c][1] - 1), in non-increasing order.
// g[u] is expected to be zeroed by the caller before the pass starts.
//
// u  - vertex being processed
// fa - the vertex we arrived from (skipped when walking the adjacency list)
void dfs(int u,int fa) {
    int *chain = f[u];
    int *best = g[u];
    chain[0] = chain[1] = chain[2] = 1;

    for (const int child : nxt[u]) {
        if (child != fa) {
            dfs(child, u);

            // Offer the chain that descends through this child, then bubble
            // it upwards so chain[0] >= chain[1] >= chain[2] still holds.
            const int extended = f[child][0] + 1;
            if (chain[2] < extended) chain[2] = extended;
            if (chain[1] < chain[2]) swap(chain[2], chain[1]);
            if (chain[0] < chain[1]) swap(chain[1], chain[0]);

            // Best path lying inside the child's subtree: either one already
            // recorded below the child, or the one bending at the child.
            const int bendAtChild = f[child][1] + f[child][0] - 1;
            const int insideChild = max(g[child][0], bendAtChild);
            if (best[1] < insideChild) best[1] = insideChild;
            if (best[0] < best[1]) swap(best[1], best[0]);
        }
    }
}
// Top-down (rerooting) pass; must run after dfs has filled f and g.
//
// For every edge (u, child) it computes
//   inner - the best path value inside the child's subtree,
//   outer - the best path value in the rest of the tree,
// records both orientations of the pair in `s` as (-outer, inner) and
// (-inner, outer), and then folds the parent-side information into f[child]
// and g[child] before descending, so that the child sees every direction.
//
// u  - vertex whose outgoing edges are being processed
// fa - the vertex we arrived from (skipped when walking the adjacency list)
void dfs2(int u,int fa) {
    for (const int child : nxt[u]) {
        if (child == fa) continue;

        // Values of the child as left by dfs, read before they are updated below.
        const int viaChild = f[child][0] + 1;
        const int inner = max(g[child][0], f[child][0] + f[child][1] - 1);

        // Pick the chains at u that do not run through `child`. Comparison is
        // by value, so with ties any equal entry may stand in for the child's.
        int outer;
        int upChain;
        if (f[u][0] == viaChild) {
            outer = f[u][1] + f[u][2] - 1;
            upChain = f[u][1] + 1;
        } else {
            upChain = f[u][0] + 1;
            const int second = (f[u][1] == viaChild) ? f[u][2] : f[u][1];
            outer = f[u][0] + second - 1;
        }

        // Same exclusion for the best paths stored at u.
        const int otherBest = (g[u][0] == inner) ? g[u][1] : g[u][0];
        outer = max(outer, otherBest);

        s.emplace_back(-outer, inner);
        s.emplace_back(-inner, outer);

        // Push the parent-side results into the child, keeping both arrays
        // in non-increasing order.
        if (g[child][1] < outer) g[child][1] = outer;
        if (g[child][0] < g[child][1]) swap(g[child][1], g[child][0]);

        if (f[child][2] < upChain) f[child][2] = upChain;
        if (f[child][1] < f[child][2]) swap(f[child][2], f[child][1]);
        if (f[child][0] < f[child][1]) swap(f[child][1], f[child][0]);

        dfs2(child, u);
    }
}
int main() {
//freopen("a.txt","r",stdin);
ios::sync_with_stdio(0);
int T;
cin>>T;
while(T--) {
cin>>n;
rep(i, 1, n) {
nxt[i].clear();
f[i][0] = f[i][1] = f[i][2] = 0;
g[i][0] = g[i][1] = 0;
}
s.clear();
rep(i, 1, n-1) {
int u,v;
cin>>u>>v;
nxt[u].pb(v);
nxt[v].pb(u);
}
dfs(1,0);
dfs2(1,0);
sort(all(s));
int siz = s.size();
ll maxr = s[0].second;
ll ans = abs(s[0].first*s[0].second);
rep(i, 1, siz-1) {
if(s[i].second<=maxr) continue;
ans += (s[i].second-maxr)*abs(s[i].first);
maxr = s[i].second;
}
cout<<ans<<endl;
}
return 0;
}