好题,今天又学习了一波儿。
题意:
n个节点的树,节点的点权为ai,要求找出有多少个二元组(u,v)满足
1:u是v的祖先且u!=v
2:a[u]*a[v]<=K
dfs访问到一个节点的时候看他的祖先节点有没有和他相乘小于K的,用树状数组维护他的祖先出现的元素,计算结果。访问到一个节点时,维护树状数组,回溯时删掉。这题数据很大,所以要加离散化。 好题,好题。
#include<cstdio>
#include<algorithm>
#include<vector>
#include<cstring>
using namespace std;
const int maxn = 101000;
int n,num;
int du[maxn];
long long k;
long long da[maxn];
long long tree[maxn];
long long sum[maxn];
vector<int>G[maxn];
inline int lowbit(int i)
{
return i&(-i);
}
inline void add(int i,int x)
{
while(i<=num)
{
//printf("sdfdsdfs\n");
sum[i]+=x;
i += lowbit(i);
}
}
inline long long query(int i)
{
long long summ = 0;
while(i>=1)
{
summ += sum[i];
i -= lowbit(i);
}
return summ;
}
inline int find(long long a)
{
return upper_bound(tree+1,tree+1+num,a)-tree-1;
}
long long ans = 0;
void dfs(int u)
{
//printf("%d\n",u);
add(find(da[u]),1);
int len = G[u].size();
for(int i=0;i<len;i++)
{
int v = G[u][i]; dfs(v);
}
add(find(da[u]),-1);
long long tmp ;
if(da[u]==0) tmp = tree[num]+1;
else tmp = k/da[u];
if(tmp>tree[num]) tmp = tree[num]+1;
ans += query(find(tmp));
}
int main()
{
int cases,u,v;
scanf("%d",&cases);
while(cases--)
{
scanf("%d%I64d",&n,&k);
for(int i=1;i<=n;i++) {scanf("%I64d",&da[i]); tree[i] = da[i];}
sort(tree+1,tree+n+1);
num = 1;
for(int i=2;i<=n;i++)
if(tree[num]!=tree[i])
tree[++num] = tree[i];
//printf("%d\n",num);
for(int i=1;i<=n;i++) G[i].clear(); memset(du,0,sizeof(du));
for(int i=2;i<=n;i++)
{
scanf("%d%d",&u,&v);
G[u].push_back(v); du[v] ++;
}
memset(sum,0,sizeof(sum));
ans = 0ll;
for(int i=1;i<=n;i++) if(du[i]==0) dfs(i);
printf("%I64d\n",ans);
}
return 0;
}