题意:
给一棵节点数为n,节点种类为k的无根树,问其中有多少种不同的简单路径,可以满足路径上经过所有k种类型的点?(a->b与b->a算作两条路径,起点与终点也可以相同)
思路:
现场赛的时候k的大小是7,当时看到这题也没多想就树形dp水过了。现在重现赛k改成了10,这时候用树形dp,无论是时间还是空间复杂度都很爆炸。后来听说这题的正解是树分治,于是就学习了一波,然后重新来做这道题,关于树分治的内容在我上一篇博客中详细介绍了,链接:
http://blog.youkuaiyun.com/bahuia/article/details/53066373
这题运用树的点分治算法,与POJ-1741的区别就在于后者是求长度小于等于k的路径数目,而这道题是求经过所有种类点的路径,状压一下,也就是求状态为(1<<k)-1的路径数目,其实本质上是一样的,只是从路径权值的加和变成了路径状态的或运算。
这题的难点在于cal()函数,也就是将问题转化成了已知x个数a1,a2,...ax,求其中有多少点对的或运算的和为(1<<k)-1,因为这些都是二进制状态,并没有直接的大小关系,所以POJ-1741那题排序的算法就不能用了,这里我们必须另外想一个O(nlogn)级别的算法。
我们枚举每一个其中的每一个数x,想找到数组中有多少数和x的或运算的和为(1<<k)-1,也就是找到可以包含((1<<k)-1)^x的数,这时候可以反向考虑,先枚举x的子集,然后再与(1<<k)-1进行异或运算,就可以找到了所有的情况。
具体细节看代码。
这题运用树的点分治算法,与POJ-1741的区别就在于后者是求长度小于等于k的路径数目,而这道题是求经过所有种类点的路径,状压一下,也就是求状态为(1<<k)-1的路径数目,其实本质上是一样的,只是从路径权值的加和变成了路径状态的或运算。
这题的难点在于cal()函数,也就是将问题转化成了已知x个数a1,a2,...ax,求其中有多少点对的或运算的和为(1<<k)-1,因为这些都是二进制状态,并没有直接的大小关系,所以POJ-1741那题排序的算法就不能用了,这里我们必须另外想一个O(nlogn)级别的算法。
我们枚举每一个其中的每一个数x,想找到数组中有多少数和x的或运算的和为(1<<k)-1,也就是找到可以包含((1<<k)-1)^x的数,这时候可以反向考虑,先枚举x的子集,然后再与(1<<k)-1进行异或运算,就可以找到了所有的情况。
具体细节看代码。
代码:
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int MAXN = 5e5 + 10;
int n, k, Max, root;
ll ans;
vector <int> tree[MAXN];
vector <int> sta;
int sz[MAXN], maxv[MAXN], a[MAXN];
ll Hash[1200];
bool vis[MAXN];
void init() {
memset(vis, false, sizeof(vis));
for (int i = 1; i <= n; i++) tree[i].clear();
}
void dfs_size(int u, int pre) {
sz[u] = 1; maxv[u] = 0;
int cnt = tree[u].size();
for (int i = 0; i < cnt; i++) {
int v = tree[u][i];
if (v == pre || vis[v]) continue;
dfs_size(v, u);
sz[u] += sz[v];
maxv[u] = max(maxv[u], sz[v]);
}
}
void dfs_root(int r, int u, int pre) {
maxv[u] = max(maxv[u], sz[r] - sz[u]);
if (Max > maxv[u]) {
Max = maxv[u];
root = u;
}
int cnt = tree[u].size();
for (int i = 0; i < cnt; i++) {
int v = tree[u][i];
if (v == pre || vis[v]) continue;
dfs_root(r, v, u);
}
}
void dfs_sta(int u, int pre, int s) {
sta.push_back(s);
int cnt = tree[u].size();
for (int i = 0; i < cnt; i++) {
int v = tree[u][i];
if (v == pre || vis[v]) continue;
dfs_sta(v, u, s | (1 << a[v]));
}
}
ll cal(int u, int s) {
ll res = 0;
sta.clear(); dfs_sta(u, -1, s);
memset(Hash, 0, sizeof(Hash));
int cnt = sta.size();
for (int i = 0; i < cnt; i++) Hash[sta[i]]++;
for (int i = 0; i < cnt; i++) {
Hash[sta[i]]--;
res += Hash[(1 << k) - 1];
for (int s0 = sta[i]; s0; s0 = (s0 - 1) & sta[i])
res += Hash[((1 << k) - 1) ^ s0];
Hash[sta[i]]++;
}
return res;
}
void dfs(int u) {
Max = n;
dfs_size(u, -1); dfs_root(u, u, -1);
ans += cal(root, (1 << a[root]));
vis[root] = true;
int cnt = tree[root].size(), rt = root;
for (int i = 0; i < cnt; i++) {
int v = tree[rt][i];
if (vis[v]) continue;
ans -= cal(v, (1 << a[rt]) | (1 << a[v]));
dfs(v);
}
}
int main() {
//freopen("in", "r", stdin);
while (scanf("%d%d", &n, &k) == 2) {
init();
for (int i = 1; i <= n; i++) {
scanf("%d", &a[i]);
--a[i];
}
for (int i = 1; i < n; i++) {
int u, v;
scanf("%d%d", &u, &v);
tree[u].push_back(v);
tree[v].push_back(u);
}
if (k == 1) {
printf("%d\n", n * n);
continue;
}
ans = 0;
dfs(1);
printf("%lld\n", ans);
}
return 0;
}