传送门:HDU-6065
题意:根节点为1的有根树,给定一个排列,长度为n,要求将排列切分成K段,定义每段的价值为该排列中两两之间的公共祖先中的最浅深度。要求总价值最小
题解:
首先列出几个性质:
性质1.所有节点的公共祖先的深度一定是两两节点之间的公共祖先的最浅深度
性质2.如果某一段的前x个数的公共祖先为rt,那么加入第x+1个树后,公共祖先的深度是LCA(rt,a[x+1])
证明:对于以rt为根的子树,a[x+1]的位置只有2种情况:
①a[x+1]在子树rt内,LCA(rt,a[x+1])=rt;
②a[x+1]在子树rt外,LCA(rt,a[x+1])=LCA(子树rt中任何一个点,a[x+1])
性质3.根据性质1我们知道,要求a[x+1]与前x个数的公共祖先只要求a[x+1]与a[1~x]任意一个数的LCA即可,我们可以让这个数就等于a[x]
性质4.往第p段区间后面加入一个数,该区间的价值是不增的。这一点也可以由性质1得到。
定义dp[i][j]为前i位切分成j段花费的最小值,由以上性质可以得出递推方程:
①将ai加入第k段区间且不改变其价值:dp[i][j]=dp[i-1][j]
②将ai加入第k段区间并作为公共祖先改变其价值:dp[i][j]=min(dp[i][j],dp[i-1][j-1]+deep[a[i]])
③将ai加入第k段区间并更新公共祖先的深度:dp[i][j]=min(dp[i][j],dp[i-2][j-1]+deep[LCA(a[i-1],a[i])])(利用性质3)
#include<bits/stdc++.h>
using namespace std;
typedef long long LL;
const int mod = 998244353;
const int MX = 3e5 + 10;
const int inf = 0x3f3f3f3f;
struct Edge {
int v, nxt;
} E[MX * 2];
int n, m, head[MX], tot;
void add(int u, int v) {
E[tot].v = v;
E[tot].nxt = head[u];
head[u] = tot++;
}
int sz, ver[2 * MX], deep[2 * MX], first[MX], dp[2 * MX][30], vis[MX], dir[MX];
void init(int n) {
for (int i = 0; i <= n; i++) {
vis[i] = 0;
head[i] = -1;
}
tot = sz = 0;
}
void dfs(int u , int dep) {
vis[u] = true; ver[++sz] = u;
first[u] = sz; deep[sz] = dep;
for (int i = head[u]; ~i ; i = E[i].nxt) {
if ( vis[E[i].v] ) continue;
int v = E[i].v;
dir[v] = dir[u] + 1;
dfs(v, dep + 1);
ver[++sz] = u; deep[sz] = dep;
}
}
void ST(int n) {
for (int i = 1; i <= n; i++)
dp[i][0] = i;
for (int j = 1; (1 << j) <= n; j++) {
for (int i = 1; i + (1 << j) - 1 <= n; i++) {
int a = dp[i][j - 1] , b = dp[i + (1 << (j - 1))][j - 1];
dp[i][j] = deep[a] < deep[b] ? a : b;
}
}
}
//中间部分是交叉的。
int RMQ(int l, int r) {
int k = 0;
while ((1 << (k + 1)) <= r - l + 1) k++;
int a = dp[l][k], b = dp[r - (1 << k) + 1][k]; //保存的是编号
return deep[a] < deep[b] ? a : b;
}
int LCA(int u , int v) {
int x = first[u] , y = first[v];
if (x > y) swap(x, y);
int ret = RMQ(x, y);
return ver[ret];
}
void pre_solve(int n) {
dir[1] = 1;
dfs(1, 1);
ST(sz);
}
vector<vector<int> >f;
int a[MX];
int main() {
//freopen("in.txt", "r", stdin);
while (~scanf("%d%d", &n, &m)) {
init(n);
for (int i = 1; i <= n; i++) scanf("%d", &a[i]);
for (int i = 1, u, v; i < n; i++) {
scanf("%d%d", &u, &v);
add(u, v); add(v, u);
}
pre_solve(n);
f.resize(n + 1);
for (int i = 0; i <= n; i++) {
f[i].resize(m + 1);
for (int j = 1; j <= m; j++) f[i][j] = inf;
}
f[0][0] = 0;
for (int j = 1; j <= m; j++) {
for (int i = 1; i <= n; i++) {
f[i][j] = f[i - 1][j];
f[i][j] = min(f[i][j], f[i - 1][j - 1] + dir[a[i]]);
if (i > 1) f[i][j] = min(f[i][j], f[i - 2][j - 1] + dir[LCA(a[i], a[i - 1])]);
}
}
printf("%d\n", f[n][m]);
}
return 0;
}