题目大意:
1<=n<=10^5
题解:
假设有两点u,v其中v = i *u(2<=i<=n/u)
如果有一路径a,b经过了它们两,那么a,b就是不合法的路径。
我们可以尝试找出这样的路径。
先求出树的dfs序,设为dfn[x]
end[x]为x的子树中最大的dfn[x’]
使dfn[u] < dfn[v]:
1.u是v的祖先:
设g为u到v的路径中距离u最近的那个点。
则跨过u,v的路径a,b(dfn[a] <= dfn[b])有:
1.1 dfn[a] < dfn[g], dfn[v] <= dfn[b] <= end[v]
1.2 dfn[v] <= dfn[a] <= end[v], dfn[b] > end[g]
2.u不是v的祖先:
dfn[u] <= dfn[a] <= end[u], dfn[v] <= dfn[b] <= end[v]
然后我们发现这就是求矩形覆盖点数。
可以考虑扫描线+线段树。
于是我们发现线段树里面,每个点的值是它被几个矩形覆盖了,但是我们需要求的是至少被一个覆盖的有多少个点,怎么办呢?
线段树的具体实现:
由于每次的查询都是查询最大的区间[1..n],所以我们可以使标记不下传。
如果i代表的区间的标记大于0,则该区间的至少被一个矩形覆盖的点数是该区间的大小,否则由它的子区间得到。
Code:
#include<cstdio>
#include<cstring>
#include<algorithm>
#define fo(i, x, y) for(int i = x; i <= y; i ++)
#define fd(i, x, y) for(int i = x; i >= y; i --)
using namespace std;
const int Maxn = 100005;
int n, x, y, bz[Maxn];
int f[17][Maxn], dfn[Maxn], df, end[Maxn], dep[Maxn];
int final[Maxn], tot;
struct edge {
int next, to;
}e[Maxn * 2];
void link(int x, int y) {
e[++ tot].next = final[x], e[tot].to = y, final[x] = tot;
e[++ tot].next = final[y], e[tot].to = x, final[y] = tot;
}
void dg(int x) {
bz[x] = 1; dfn[x] = ++ df;
for(int i = final[x]; i; i = e[i].next) {
if(!bz[e[i].to])
f[0][e[i].to] = x, dep[e[i].to] = dep[x] + 1, dg(e[i].to);
}
end[x] = df;
}
void Init() {
scanf("%d", &n);
fo(i, 1, n) {
scanf("%d %d", &x, &y);
link(x, y);
}
dep[1] = 1; dg(1);
fo(i, 1, 16) fo(j, 1, n)
f[i][j] = f[i - 1][f[i - 1][j]];
}
struct node {
int next, x, y, c;
}ee[Maxn * 100];
#define e ee
void link1(int a, int b, int x, int y) {
e[++ tot].next = final[a], e[tot].x = x, e[tot].y = y, e[tot].c = 1, final[a] = tot;
b ++;
e[++ tot].next = final[b], e[tot].x = x, e[tot].y = y, e[tot].c = -1, final[b] = tot;
}
#define link link1
int sg(int u, int v) {
fd(i, 16, 0) if(dep[f[i][v]] > dep[u]) v = f[i][v];
return v;
}
void bb(int u, int v) {
if(dfn[u] > dfn[v]) swap(u, v);
if(dfn[v] >= dfn[u] && dfn[v] <= end[u]) {
int g = sg(u, v);
link(1, dfn[g] - 1, dfn[v], end[v]);
link(dfn[v], end[v], end[g] + 1, n);
} else {
link(dfn[u], end[u], dfn[v], end[v]);
}
}
void Build() {
tot = 0;
memset(final, 0, sizeof(final));
fo(u, 1, n)
fo(i, 2, n / u)
bb(u, u * i);
}
struct tree {
int s, b;
}d[Maxn * 10];
void update(int i, int x, int y) {
if(d[i].b > 0) d[i].s = y - x + 1; else d[i].s = d[i + i].s + d[i + i + 1].s;
}
void add(int i, int x, int y, int l, int r, int z) {
if(l > r) return;
if(x == l && y == r) {
d[i].b += z; update(i, x, y);
return;
}
int m = (x + y) / 2;
if(r <= m) add(i + i, x, m, l, r, z); else
if(l > m) add(i + i + 1, m + 1, y, l, r, z); else
add(i + i, x, m, l, m, z), add(i + i + 1, m + 1, y, m + 1, r, z);
update(i, x, y);
}
void End() {
long long ans = 0;
fo(i, 1, n) {
for(int k = final[i]; k; k = e[k].next)
add(1, 1, n, e[k].x, e[k].y, e[k].c);
ans = ans + d[1].s;
}
printf("%lld", (long long) n * (n - 1) / 2 - ans);
}
int main() {
Init();
Build();
End();
}