题目大意:给你一棵树有n个节点,总共有k种颜色,每个节点有一种颜色,问你在这颗树上有多少种路径(起点或者终点不同的话就当作不同种路径)能访问完所有的颜色。
思路:第一次遇到这种题想着往树dp上然后发现状态有点多树dp解决不了,后面在网上查了题解才发现树的点分治这种解决树上路径统计的方法。而本题主要是在外面套上一个树分治加速查询所有的点对,再用高维前缀和快速在所有的点对种求出符合条件的起始点。
高维前缀和可以看这篇博客 :高维前缀和
下面主要讲解树分治的知识。
树分治关键的核心是找树的重心(为什么要找重心将在后面讲解)
树的重心的定义是:该节点的最大子树的节点数最小。所有找重心关键是要处理出每个节点的最大子节点数
如下为处理出每个节点的最大子节点数的函数(注意不是预处理,而是每次以当前节点为根找该有根树的重心
void dfs_size(int u,int pre)
{
maxv[u] = 0;
siz[u] = 1;
for(int i=0;i<G[u].size();i++)
{
int v = G[u][i].v;
int w = G[u][i].w;
if(v==pre || vis[v]) continue;
dfs_size(v,u); //dfs继续往下处理子节点
maxv[u] = max(maxv[u],siz[v]); //最大子树
siz[u] += siz[v]; 当前子树的大小
}
}
通过处理出所有节点的子树大小后,就能找出树的重心
void dfs_root(int r,int u,int pre)
{
maxv[u] = max(maxv[u],siz[r]-siz[u]); // 因为dfs是自底向上回溯的所以不能得到与父亲节点相连的那条边的子树大小,但是可以通过父亲为根的树的大小减去当前节点为根的树的大小就能得到父亲边的子树的大小
if(Max >maxv[u])
{
Max = maxv[u];
root = u;
}
for(int i=0;i<G[u].size();i++)
{
int v = G[u][i].v;
int w = G[u][i].w;
if(v==pre||vis[v]) continue;
dfs_root(r,v,u);
}
}
前面两步几乎是树的点分治的固定模板(也就是找树的重心),而后面视题目要求的不同来处理出经过该重心的所有子孙节点到重心的状态
对于一根有根树找所有的路径点对有两种情况:
1.路径经过根
2.路径经过以根的子节点为根的子树中
对于情况2可以通过情况1递归得到,所以主要讨论情况1
在一颗根确定的树中可以通过dfs得到该节点i到根所得到的状态设为sta[i],并将所有子节点的状态存入容器中。以此题为例
sta[i]表示从i点到根节点所访问过的颜色,可以用二进制表示有没访问过该颜色。
void dfs_sta(int u, int pre, int s) {
sta.push_back(s); //把当前节点u到根节点的路径的状态塞入容器中
int cnt = tree[u].size();
for (int i = 0; i < cnt; i++) {
int v = tree[u][i];
if (v == pre || vis[v]) continue;
dfs_sta(v, u, s | (1 << a[v]));
}
}
得到经过根的所有的路径后就要计算合法的路径,在该题中需要运用到高维前缀和
高维前缀和能够计算出s[i]的超集的个数,比如s[i]==1011 能够计算出是1011,1100,1101,1111的个数的和
ll cal(int u, int s) {
ll res = 0;
sta.clear(); dfs_sta(u, -1, s);
memset(cnt1, 0, sizeof(cnt1));
memset(cnt2, 0, sizeof(cnt2));
int cnt = sta.size();
for (int i = 0; i < cnt; i++) cnt1[sta[i]]++,cnt2[sta[i]]++;
int K = (1<<k)-1;
for(int i=0;i<k;i++)
{
for(int j=K;j>=0;j--)
{
if(!((1<<i)&j)) cnt2[j] += cnt2[(1<<i)|j];
}
}
for(int i=0;i<=K;i++)
{
res += cnt1[i]*cnt2[i^K];
}
return res;
}
还有一点就是 cal函数计算得出可能是一条在起始节点在同一子树的路径所以要再减去子节点算的cal 所以应该这样写总的dfs函数
void dfs(int u) {
Max = n;
dfs_size(u, -1); dfs_root(u, u, -1);
ans += cal(root, (1 << a[root]));
vis[root] = true;
int cnt = tree[root].size(), rt = root;
for (int i = 0; i < cnt; i++) {
int v = tree[rt][i];
if (vis[v]) continue;
ans -= cal(v, (1 << a[rt]) | (1 << a[v]));
dfs(v);
}
}
这里对于每一层子树都认为有n个节点,对于n个节点求重心和dis复杂度O(n),求点对数需要O(nlogn),这里假设递归一共进行L层,则算法复杂度为O(n*L*logn)。
为了使算法不退化,必须每次选择的树的重心作为根,具体证明在论文中都有,这样的算法复杂度为O(n*logn*logn)。
总的细节看代码
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int MAXN = 5e5 + 10;
int n, k, Max, root;
ll ans;
vector <int> tree[MAXN];
vector <int> sta;
int sz[MAXN], maxv[MAXN], a[MAXN];
ll Hash[1200];
ll cnt1[1200],cnt2[1200];
bool vis[MAXN];
void init() {
memset(vis, false, sizeof(vis));
for (int i = 1; i <= n; i++) tree[i].clear();
}
void dfs_size(int u, int pre) {
sz[u] = 1; maxv[u] = 0;
int cnt = tree[u].size();
for (int i = 0; i < cnt; i++) {
int v = tree[u][i];
if (v == pre || vis[v]) continue;
dfs_size(v, u);
sz[u] += sz[v];
maxv[u] = max(maxv[u], sz[v]);
}
}
void dfs_root(int r, int u, int pre) {
maxv[u] = max(maxv[u], sz[r] - sz[u]);
if (Max > maxv[u]) {
Max = maxv[u];
root = u;
}
int cnt = tree[u].size();
for (int i = 0; i < cnt; i++) {
int v = tree[u][i];
if (v == pre || vis[v]) continue;
dfs_root(r, v, u);
}
}
void dfs_sta(int u, int pre, int s) {
sta.push_back(s);
int cnt = tree[u].size();
for (int i = 0; i < cnt; i++) {
int v = tree[u][i];
if (v == pre || vis[v]) continue;
dfs_sta(v, u, s | (1 << a[v]));
}
}
ll cal(int u, int s) {
ll res = 0;
sta.clear(); dfs_sta(u, -1, s);
memset(cnt1, 0, sizeof(cnt1));
memset(cnt2, 0, sizeof(cnt2));
int cnt = sta.size();
for (int i = 0; i < cnt; i++) cnt1[sta[i]]++,cnt2[sta[i]]++;
int K = (1<<k)-1;
for(int i=0;i<k;i++)
{
for(int j=K;j>=0;j--)
{
if(!((1<<i)&j)) cnt2[j] += cnt2[(1<<i)|j];
}
}
for(int i=0;i<=K;i++)
{
res += cnt1[i]*cnt2[i^K];
}
return res;
}
void dfs(int u) {
Max = n;
dfs_size(u, -1); dfs_root(u, u, -1);
ans += cal(root, (1 << a[root]));
vis[root] = true;
int cnt = tree[root].size(), rt = root;
for (int i = 0; i < cnt; i++) {
int v = tree[rt][i];
if (vis[v]) continue;
ans -= cal(v, (1 << a[rt]) | (1 << a[v]));
dfs(v);
}
}
int main() {
while (scanf("%d%d", &n, &k) == 2) {
init();
for (int i = 1; i <= n; i++) {
scanf("%d", &a[i]);
--a[i];
}
for (int i = 1; i < n; i++) {
int u, v;
scanf("%d%d", &u, &v);
tree[u].push_back(v);
tree[v].push_back(u);
}
if (k == 1) {
printf("%d\n", n * n);
continue;
}
ans = 0;
dfs(1);
printf("%lld\n", ans);
}
return 0;
}