设 s z r o o t ( u ) sz_{root}(u) szroot(u)表示以 r o o t root root为根,子树 T ( u ) T(u) T(u)的大小, f a r o o t ( u ) fa_{root}(u) faroot(u)表示以 r o o t root root为根,节点 u u u的父亲。这两个东西可以 O ( n 2 ) O(n^2) O(n2)预处理。
转化一下问题:
a
n
s
=
∑
x
=
1
n
−
1
∑
u
,
v
∈
V
[
m
e
x
(
u
,
v
)
≥
x
]
ans=\sum\limits_{x = 1}^{n - 1} \sum\limits_{u,v \in V} [\mathrm{mex}(u,v) \ge x]
ans=x=1∑n−1u,v∈V∑[mex(u,v)≥x]
那么现在考虑从小到大加入每一个边权,首先加入边权
0
0
0:
那么一定可以得到 s z u ( v ) × s z v ( u ) sz_{u}(v) \times sz_{v}(u) szu(v)×szv(u)的贡献。那么接下来加入权值 1 1 1。
假如说 0 , 1 0,1 0,1不是连着的,那么两条边夹着的所有点对答案的就只能是 0 0 0,(如图)
接着想想,将 0 , 1 0,1 0,1连在一起有利无害,那么就连在一起吧。。。这启示我们,接下来放边权的时候,一定是尽量将之前已经放了边权形成的路径延长一个位置(这个想法比较古怪)。也就是说,对于 m e x \mathrm{mex} mex最长的一条路径,它所包含的边权可以组成 [ 0 , 路 径 长 度 − 1 ] [0,路径长度-1] [0,路径长度−1]这一个区间的,而且区间的两个端点一定在叶子上。接下来设定一条最长的路径 ( u , v ) (u,v) (u,v),我们想想怎么在上面填边权。首先在某个位置填上 0 0 0,然后任选左右其中位置填上 1 1 1,然后再选左右其中一个填上 2 2 2。每次填上一个边权之后,在答案中加入路径两边连着的子树的大小。这不就是 d p dp dp吗?设 f ( u , v ) f(u,v) f(u,v)表示形成了路径 ( u , v ) (u,v) (u,v)得到的最大 a n s ans ans。
那么可以得到:
f
(
u
,
v
)
=
{
max
{
f
(
f
a
v
(
u
)
,
v
)
,
f
(
u
,
f
a
u
(
v
)
}
+
s
z
u
(
v
)
×
s
z
v
(
u
)
u
≠
v
0
u
=
v
f(u,v)=\begin{cases} \max\left\{f(fa_{v}(u),v), f(u, fa_{u}(v)\right\} + sz_{u}(v)\times sz_{v}(u) & u \ne v\\ 0 & u = v \end{cases}
f(u,v)={max{f(fav(u),v),f(u,fau(v)}+szu(v)×szv(u)0u=vu=v
#include <bits/stdc++.h>
#define LL long long
using namespace std;
const int maxn = 3e3 + 5;
struct Edge {
int v, nex;
Edge(int v = 0, int nex = 0) : v(v), nex(nex) {}
} E[maxn << 1];
int hd[maxn], tote;
void addedge(int u, int v) {
E[++tote] = Edge(v, hd[u]), hd[u] = tote;
E[++tote] = Edge(u, hd[v]), hd[v] = tote;
}
LL sz[maxn][maxn], fat[maxn][maxn], f[maxn][maxn];
int n, rt;
void init(int u, int fa) {
fat[rt][u] = fa, sz[rt][u] = 1;
for (int i = hd[u]; i; i = E[i].nex) {
int v = E[i].v;
if (v == fa) continue;
init(v, u), sz[rt][u] += sz[rt][v];
}
}
LL dp(int u, int v) {
if (u == v) return 0;
if (~f[u][v]) return f[u][v];
return f[u][v] = max(dp(fat[v][u], v), dp(u, fat[u][v])) + sz[u][v] * sz[v][u];
}
int main() {
memset(f, -1, sizeof(f));
scanf("%d", &n);
for (int i = 1; i < n; i++) {
int u, v; scanf("%d%d", &u, &v);
addedge(u, v);
}
for (int i = 1; i <= n; i++) rt = i, init(i, 0);
LL ans = 0;
for (int i = 1; i <= n; i++)
for (int j = 1; j <= n; j++) ans = max(ans, dp(i, j));
printf("%lld\n", ans);
return 0;
}