题目大意:在一棵树上求长度等于 $D$ 且路径上所有点权的 $\gcd > 1$ 的路径总数。
$n \leq 5 \times 10^5,\ D \leq 10^4,\ a_i \leq 3 \times 10^4$
点分的做法:枚举根节点(分治中心)点权的因子 $p$(只取 $\mu(p) \ne 0$ 的 $p$,即题解中的莫比乌斯反演)。一条路径满足 $\gcd > 1$,等价于存在某个 $p \ge 2$ 使路径上所有点权都能被 $p$ 整除;对每个 $p$ 统计这样的路径数,再以 $-\mu(p)$ 为系数容斥求和即可。用 unordered_map 维护深度信息,复杂度为 $O(32 \cdot n \log n)$。
长链剖分的做法是类似的。(待补)
点分代码(数据被加强了无法通过):
#include<bits/stdc++.h>
using namespace std;
// Short aliases used throughout the centroid-decomposition solution.
#define pii pair<int,int>
#define fir first
#define sec second
typedef long long ll;
// mp: depth -> number of already-merged vertices at that depth around the current centroid.
// tp: depth -> counts collected from the subtree currently being walked; merged into mp afterwards.
unordered_map<int,int> mp,tp;
// c: number of test cases; n: number of vertices; d: required path length (D).
int c,n,d;
// Edge arrays are sized maxn (two directed edges per tree edge), vertex arrays maxn / 2.
const int maxn = 1e6 + 10;
// Upper bound on vertex weights; the sieve and divisor lists cover [1, M].
const int M = 3e4;
// ispri[x] is true when x is NOT prime (marked composite by the sieve, plus 0 and 1).
bool ispri[M + 5];
// pri[0] = number of primes found, pri[1..pri[0]] = the primes; mu[] = Moebius function.
int pri[M + 5],mu[M + 5];
int to[maxn],nxt[maxn],head[maxn],cnt; // chained adjacency list ("forward star"); head[u] == -1 ends the chain, cnt = next free edge slot
int sz[maxn / 2],f[maxn / 2],root,a[maxn / 2],tot; // decomposition state: subtree sizes, largest-part sizes, chosen centroid, vertex weights, current component size
// done[u]: u has already been used as a centroid (removed from further components).
bool done[maxn / 2];
// Answer accumulated by divide() over all centroids of the current test case.
ll res = 0;
// s[i]: pairs (j, mu[j]) for every divisor j of i with mu[j] != 0, filled by sieve().
vector<pii> s[M + 5];
// Per-test reset: zero the answer and the edge pool, empty every adjacency
// chain for vertices 0..n, and clear all centroid marks.
void init() {
    res = 0;
    cnt = 0;
    for (int v = 0; v <= n; ++v) {
        head[v] = -1;
        done[v] = false;
    }
}
// Linear (Euler) sieve: computes the Moebius function mu[1..n] (and the prime
// list / composite flags along the way), then fills s[i] for every i in [1, n]
// with the pairs (j, mu[j]) over all divisors j of i that are squarefree (mu[j] != 0).
//
// @param n  upper bound of the tables. The global arrays hold indices up to M,
//           so n is clamped to M. Both phases now honour n: previously the
//           divisor phase always ran to M, reading unsieved (zero) mu values
//           and dropping divisors whenever n < M.
void sieve(int n) {
    if (n > M) n = M;  // arrays are sized M + 5; never write past them
    mu[1] = 1;
    ispri[0] = ispri[1] = true;
    pri[0] = 0;
    for (int i = 2; i <= n; i++) {
        if (!ispri[i]) pri[++pri[0]] = i, mu[i] = -1;
        for (int j = 1; j <= pri[0] && i * pri[j] <= n; j++) {
            ispri[i * pri[j]] = true;      // i * pri[j] is composite
            if (i % pri[j] == 0) break;    // pri[j]^2 | i * pri[j], so its mu stays 0
            mu[i * pri[j]] = -mu[i];
        }
    }
    // Enumerate divisor pairs (j, i / j) with j <= sqrt(i); keep squarefree ones only.
    for (int i = 1; i <= n; i++) {
        for (int j = 1; j * j <= i; j++) {
            if (i % j != 0) continue;
            const int k = i / j;
            if (mu[j] != 0) s[i].push_back(pii(j, mu[j]));
            if (k != j && mu[k] != 0) s[i].push_back(pii(k, mu[k]));
        }
    }
}
// Prepends the directed edge u -> v to u's adjacency chain.
void add(int u,int v) {
    const int e = cnt++;
    to[e] = v;
    nxt[e] = head[u];
    head[u] = e;
}
// Computes subtree sizes sz[] of the current component (vertices not yet done),
// rooted at u, and records in `root` the vertex whose largest remaining part f[]
// is smallest, given that the component has `tot` vertices. The caller resets
// root to 0 before the call.
void getroot(int u,int fa) {
    int heaviest = 0;
    sz[u] = 1;
    for (int e = head[u]; e != -1; e = nxt[e]) {
        const int v = to[e];
        if (v == fa || done[v]) continue;
        getroot(v, u);
        sz[u] += sz[v];
        heaviest = max(heaviest, sz[v]);
    }
    // The part "above" u (towards fa) has tot - sz[u] vertices.
    f[u] = max(heaviest, tot - sz[u]);
    if (!root || f[u] < f[root]) root = u;
}
// Walks the part of u's subtree (within the current component, away from fa)
// in which every vertex weight is divisible by x.fir, stopping beyond depth d.
// For each vertex at depth `deep` it adds 2 * x.sec * mp[d - deep] to ans, i.e.
// pairs it with the recorded vertices at the complementary depth (presumably the
// centroid and previously merged sibling subtrees — mp is prepared by the caller).
// The factor 2 counts both orientations. Its own depth is recorded in tp.
void dfs(ll &ans,int u,int fa,int deep,pii x) {
    if (a[u] % x.fir != 0 || deep > d) return;
    tp[deep]++;
    // Use find() rather than operator[]: a missed lookup must not insert a zero
    // entry into mp, which would grow the table and slow every mp.clear().
    auto hit = mp.find(d - deep);
    if (hit != mp.end()) ans += 2ll * x.sec * hit->second;
    for (int i = head[u]; i + 1; i = nxt[i]) {
        if (to[i] == fa || done[to[i]]) continue;
        dfs(ans, to[i], u, deep + 1, x);
    }
}
// Marks centroid u as done and returns (twice) the number of length-d paths through u
// whose vertex-weight gcd is > 1.
//
// For a squarefree p let cnt(p) be the number of such paths whose weights are all
// divisible by p. Since sum_{p | g} mu(p) = [g == 1], the number of gcd > 1 paths is
// sum_{p >= 2} -mu(p) * cnt(p). Only divisors of a[u] can divide the gcd of a path through u.
// The previous version also included p = 1 with weight +mu(p), which yields the
// number of gcd == 1 paths instead.
ll solve(int u) {
    ll ans = 0;
    done[u] = true;
    for (const auto &it : s[a[u]]) {
        if (it.fir == 1) continue;          // p = 1 is the "all paths" term; not part of gcd > 1
        const pii w(it.fir, -it.sec);       // inclusion-exclusion weight -mu(p)
        mp.clear(); mp[0]++;                // the centroid itself sits at depth 0
        for (int i = head[u]; i + 1; i = nxt[i]) {
            if (done[to[i]]) continue;
            tp.clear(); dfs(ans, to[i], u, 1, w);
            // Merge this subtree's depths only after it was paired with earlier ones,
            // so paths inside one child subtree are never counted here.
            for (const auto &e : tp) mp[e.fir] += e.sec;
        }
    }
    return ans;
}
// Centroid-decomposition driver: count every path through rt, then, for each
// remaining component hanging off rt, find that component's centroid and recurse.
void divide(int rt) {
    res += solve(rt);
    for (int e = head[rt]; e != -1; e = nxt[e]) {
        const int v = to[e];
        if (done[v]) continue;
        tot = sz[v];
        root = 0;
        getroot(v, -1);
        divide(root);
    }
}
// Reads c test cases (n, d, n weights, n - 1 undirected edges), runs the
// centroid decomposition from vertex 1's component and prints each result.
int main() {
    sieve(M);
    scanf("%d", &c);
    for (int caseNo = 1; c--; ++caseNo) {
        scanf("%d%d", &n, &d);
        init();
        for (int v = 1; v <= n; ++v)
            scanf("%d", &a[v]);
        for (int k = 1; k < n; ++k) {
            int u, v;
            scanf("%d%d", &u, &v);
            add(u, v);
            add(v, u);
        }
        tot = n;
        root = 0;
        getroot(1, -1);
        divide(root);
        printf("Case #%d:%lld\n", caseNo, res);
    }
    return 0;
}
update:
写出了长链剖分的做法:直接对 $3 \times 10^4$ 以内所有满足 $\mu(p) \ne 0$ 的 $p$($p \ge 2$),计算路径上数字都是 $p$ 的倍数的合法路径数量,再以 $-\mu(p)$ 为系数累加;对路径计数直接套长链剖分。
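整个过程看起来非常暴力,但实际上每个数字最多只有 32 个这样的因子 $p$。这是因为 $3 \times 10^4$ 以内的数至多有 5 个不同的质因子($2 \cdot 3 \cdot 5 \cdot 7 \cdot 11 \cdot 13 = 30030 > 3 \times 10^4$),故无平方因子的因子至多 $2^5 = 32$ 个。即每个点只会被长链剖分处理 32 次,因此复杂度是均摊的,为 $O(c \cdot 32 \cdot n)$,$c$ 是一个不可忽视的常数。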
数据在计蒜客上得到了究极加强,卡不过最后一个点
长链剖分代码:
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
// Edge arrays are sized maxn (two directed edges per tree edge), vertex arrays maxn / 2.
const int maxn = 1e6 + 10;
// Upper bound on vertex weights; the sieve and divisor lists cover [1, M].
const int M = 3e4;
// ispri[x] is true when x is NOT prime (marked composite by the sieve, plus 0 and 1).
bool ispri[M + 5];
// pri[0] = number of primes found, pri[1..pri[0]] = the primes; mu[] = Moebius function.
int pri[M + 5],mu[M + 5];
// g: not referenced anywhere in this listing.
// p[q]: vertices whose weight has q as a squarefree divisor (filled per test in main).
// h[i]: squarefree divisors of i (mu != 0), filled by sieve().
vector<int> g[maxn],p[M + 5],h[M + 5];
// n vertices, required path length d, t test cases, a[] vertex weights;
// vis[u] = the divisor round that last reached u (so each component is rooted once per round).
int n,d,t,a[maxn / 2],vis[maxn / 2];
// Chained adjacency list: head[u] == -1 ends the chain, edg = next free edge slot.
int head[maxn],to[maxn],nxt[maxn],edg;
// Long-chain decomposition: len[u] = vertices on the longest downward chain from u,
// son[u] = child continuing that chain (0 if none); cnt = vertices in the current component.
int len[maxn / 2],son[maxn / 2],cnt;
// tmp: shared pool backing all dp rows; id: next free slot in tmp;
// dp[u][k] = number of vertices at depth k below u within the current component.
int tmp[maxn / 2],*id,*dp[maxn / 2];
// Answer accumulator for the current test (printed doubled).
ll ans = 0;
// Reads one decimal integer from stdin via getchar(): every non-digit character
// before it is skipped, and a '-' among them makes the result negative.
inline int read(){
    int sign = 1;
    char ch = getchar();
    for (; ch < '0' || ch > '9'; ch = getchar())
        if (ch == '-') sign = -1;
    int value = 0;
    for (; ch >= '0' && ch <= '9'; ch = getchar())
        value = value * 10 + ch - '0';
    return value * sign;
}
// Prepends the directed edge u -> v to u's adjacency chain.
void add(int u,int v) {
    const int e = edg++;
    to[e] = v;
    nxt[e] = head[u];
    head[u] = e;
}
// Per-test reset: answer and edge pool to zero, every divisor bucket p[1..M]
// emptied, adjacency chains of vertices 0..n emptied, visit marks cleared.
void init() {
    ans = 0;
    edg = 0;
    for (int q = 1; q <= M; ++q) p[q].clear();
    for (int v = 0; v <= n; ++v) head[v] = -1;
    for (int v = 1; v <= n; ++v) vis[v] = 0;
}
// Linear (Euler) sieve: computes the Moebius function mu[1..n] (and the prime
// list / composite flags along the way), then fills h[i] for every i in [1, n]
// with all divisors j of i that are squarefree (mu[j] != 0).
//
// @param n  upper bound of the tables. The global arrays hold indices up to M,
//           so n is clamped to M. Both phases now honour n: previously the
//           divisor phase always ran to M, reading unsieved (zero) mu values
//           and dropping divisors whenever n < M.
void sieve(int n) {
    if (n > M) n = M;  // arrays are sized M + 5; never write past them
    mu[1] = 1;
    ispri[0] = ispri[1] = true;
    pri[0] = 0;
    for (int i = 2; i <= n; i++) {
        if (!ispri[i]) pri[++pri[0]] = i, mu[i] = -1;
        for (int j = 1; j <= pri[0] && i * pri[j] <= n; j++) {
            ispri[i * pri[j]] = true;      // i * pri[j] is composite
            if (i % pri[j] == 0) break;    // pri[j]^2 | i * pri[j], so its mu stays 0
            mu[i * pri[j]] = -mu[i];
        }
    }
    // Enumerate divisor pairs (j, i / j) with j <= sqrt(i); keep squarefree ones only.
    for (int i = 1; i <= n; i++) {
        for (int j = 1; j * j <= i; j++) {
            if (i % j != 0) continue;
            const int k = i / j;
            if (mu[j] != 0) h[i].push_back(j);
            if (k != j && mu[k] != 0) h[i].push_back(k);
        }
    }
}
// First pass of the long-chain decomposition over the vertices reachable from u
// (away from fa) whose weights are divisible by val. It counts them into cnt and sets
// son[u] to the child with the longest chain (the first one met on ties). It then sets
// len[u] = len[son[u]] + 1. A leaf has son[u] == 0 and relies on len[0] staying 0;
// index 0 is never passed as u here.
void prework(int u,int fa,int val) {
    ++cnt;
    son[u] = 0;
    len[u] = 0;
    for (int e = head[u]; e != -1; e = nxt[e]) {
        const int v = to[e];
        if (v == fa || a[v] % val != 0) continue;
        prework(v, u, val);
        if (son[u] == 0 || len[v] > len[son[u]]) son[u] = v;
    }
    len[u] = len[son[u]] + 1;
}
void dfs(int u,int fa,int val,int num,int op) {
vis[u] = num;dp[u][0] = 1;
if(son[u] && a[son[u]] % val == 0) {
dp[son[u]] = dp[u] + 1;
dfs(son[u],u,val,num,op);
}
if(len[u] > d)
ans += dp[u][d] * op;
for(int i = head[u]; i + 1; i = nxt[i]) {
int v = to[i];
if(v == fa || v == son[u] || a[v] % val) continue;
dp[v] = id,id += len[v],dfs(v,u,val,num,op);
for(int i = 0; i < len[v] && i < d; i++)
if(d - i - 1 < len[u])
ans += 1ll * dp[v][i] * dp[u][d - i - 1] * op;
for(int i = 0; i < len[v] && i < d; i++)
dp[u][i + 1] += dp[v][i];
}
}
int main() {
sieve(M);
scanf("%d",&t);
int sz = 0;
while(t--) {
scanf("%d%d",&n,&d);
init();
for(int i = 1; i <= n; i++) {
a[i] = read();
for(auto v : h[a[i]])
p[v].push_back(i);
}
for(int i = 1,u,v; i < n; i++) {
u = read(),v = read();
add(u,v);add(v,u);
}
ans = 0;
for(int i = 2; i <= M; i++) {
if(!p[i].size()) continue;
for(auto rt : p[i]) {
if(vis[rt] == i) continue;
cnt = 0,prework(rt,0,i);
for(int i = 0; i <= cnt; i++) tmp[i] = 0;
id = tmp,dp[rt] = id,id += len[rt],dfs(rt,0,i,i,-mu[i]);
}
}
printf("Case #%d: %lld\n",++sz,ans * 2);
}
return 0;
}