Description:
NiroBC 姐姐是个活泼的少女,她十分喜欢爬树,而她家门口正好有一棵果树,正好满足了她爬树的需求。
这颗果树有N个节点,节点标号 1…N。每个节点长着一个果子,第i个节点上的果子颜色为 Ci 。
NiroBC姐姐每天都要爬树,每天都要选择一条有趣的路径 (u,v) 来爬。
一条路径被称作有趣的,当且仅当这条路径上的果子的颜色互不相同。
(u,v) 和 (v,u) 被视作同一条路径。特殊地,(i,i) 也被视作一条路径,这条路径只含 i 一个果子,显然是有趣的。
NiroBC姐姐想知道这颗树上有多少条有趣的路径。
题解:
同一颜色这么少肯定暴力枚举了。
假设有两个点,x,y,它们的颜色相同。
好像和子树有关,建dfs序。
如果x,y不为祖先和孙子的关系,那么[p[x]..q[y],p[y]..q[y]]是不合法的。
如果是,并且x是y的孙子,设z为y到x路径上的第二个点,则[p[x]..q[x],!(p[z]..q[z])]
于是就变成了经典的矩形覆盖问题。
用扫描线+不下传标记线段树可以轻松解决。
我不知道我之前到底有没有打过不下传标记线段树.
比赛时打完了所有的东西,就线段树这个地方卡了,我发现直接线段树做不出来,又没有想到不下传标记线段树,亏死了。
不下传标记线段树也非常局限,一般只能用矩形覆盖这种问题。
Code:
#include<cstdio>
#include<algorithm>
#define ll long long
#define fo(i, x, y) for(int i = x; i <= y; i ++)
#define fd(i, x, y) for(int i = x; i >= y; i --)
using namespace std;
const int N = 1e5 + 5, M = N * 20;
int n, c[N], x, y, b[N];
int next[N * 2], to[N * 2], final[N], tot;
int dfn[N], td, p[N], q[N], bz[N];
int fa[17][N], dep[N]; ll ans;
void link(int x, int y) {
next[++ tot] = final[x], to[tot] = y, final[x] = tot;
next[++ tot] = final[y], to[tot] = x, final[y] = tot;
}
void dg(int x) {
bz[x] = 1; dfn[x] = p[x] = ++ td;
for(int i = final[x]; i; i = next[i]) {
int y = to[i]; if(bz[y]) continue;
fa[0][y] = x; dep[y] = dep[x] + 1;
dg(y);
}
bz[x] = 0; q[x] = td;
}
int pdf(int x, int y) {
fd(i, 16, 0) if(dep[fa[i][y]] >= dep[x]) y = fa[i][y];
return x == y;
}
int xx(int x, int y) {
fd(i, 16, 0) if(dep[fa[i][y]] > dep[x]) y = fa[i][y];
return y;
}
int cmp(int x, int y) {return c[x] < c[y];}
struct ak {
int next[M * 4], l[M * 4], r[M * 4], z[M * 4], final[N], tot;
void link(int x, int p, int q, int w) {
next[++ tot] = final[x], l[tot] = p, r[tot] = q, z[tot] = w, final[x] = tot;
}
} e;
int t[N * 4], lz[N * 4], pl, pr, px;
void add(int i, int x, int y) {
if(y < pl || x > pr) return;
if(x >= pl && y <= pr) {lz[i] += px; return;}
int m = x + y >> 1;
add(i + i, x, m); add(i + i + 1, m + 1, y);
t[i] = (lz[i + i] ? m - x + 1 : t[i + i]) + (lz[i + i + 1] ? y - m : t[i + i + 1]);
}
int main() {
freopen("tree.in", "r", stdin);
freopen("tree.out", "w", stdout);
scanf("%d", &n);
fo(i, 1, n) scanf("%d", &c[i]);
fo(i, 1, n - 1) {
scanf("%d %d", &x, &y);
link(x, y);
}
dep[1] = 1; dg(1);
fo(i, 1, 16) fo(j, 1, n) fa[i][j] = fa[i - 1][fa[i - 1][j]];
fo(i, 1, n) b[i] = i;
sort(b + 1, b + n + 1, cmp);
fo(i, 1, n) fd(j, i - 1, 1) {
x = b[i]; y = b[j];
if(c[x] != c[y]) break;
if(dep[x] > dep[y]) swap(x, y);
if(pdf(x, y)) {
int z = xx(x, y);
e.link(p[y], 1, p[z] - 1, 1);
e.link(q[y] + 1, 1, p[z] - 1, -1);
e.link(q[z] + 1, p[y], q[y], 1);
} else {
if(dfn[x] < dfn[y]) swap(x, y);
e.link(p[x], p[y], q[y], 1);
e.link(q[x] + 1, p[y], q[y], -1);
}
}
fo(i, 1, n) {
for(int j = e.final[i]; j; j = e.next[j])
pl = e.l[j], pr = e.r[j], px = e.z[j], add(1, 1, n);
ans += lz[1] ? n : t[1];
}
ans = (ll) n * n - ans * 2;
printf("%lld\n", (ans - n) / 2 + n);
}