You are given a rooted tree of $N$ nodes, labeled from 1 to $N$. To the $i$-th node a non-negative value $a_i$ is assigned. An ordered pair of nodes $(u, v)$ is said to be weak if
(1) $u$ is an ancestor of $v$ (Note: in this problem a node $u$ is not considered an ancestor of itself);
(2) $a_u \times a_v \le k$.
Can you find the number of weak pairs in the tree?
Input
There are multiple cases in the data set.
The first line of input contains an integer T denoting number of test cases.
For each case, the first line contains two space-separated integers, $N$ and $k$, respectively.
The second line contains $N$ space-separated integers, denoting $a_1$ to $a_N$.
Each of the subsequent $N-1$ lines contains two space-separated integers defining an edge connecting nodes $u$ and $v$, where node $u$ is the parent of node $v$.
Constraints:
$1 \le N \le 10^5$
$0 \le a_i \le 10^9$
$0 \le k \le 10^{18}$
Output
For each test case, print a single integer on a single line denoting the number of weak pairs in the tree.
Sample Input
1
2 3
1 2
1 2
Sample Output
1
犯了一个超级超级低级错误,数组开小了10倍,一直和我说超时,实在没办法就重新写了一遍才发现,wa到怀疑人生GG
见上一个同类型的题目
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cmath>
#include<cstring>
#include<vector>
#include<map>
#define N 100005
using namespace std;
typedef long long ll;

int n;          // number of nodes in the current test case
ll k;           // bound: a pair (u, v) is weak when a[u] * a[v] <= k
ll a[N];        // a[i] is the value of node i (1-based)
// Sorted coordinate table (1-based). Every node contributes a[i] and, when
// a[i] != 0, also k / a[i], so up to 2 * n entries are written. Sizing it from
// N keeps it in step with the node limit instead of a separate magic number.
ll id[2 * N];
int cnt, num;   // cnt: raw entries written to id; num: distinct entries kept after dedup
bool inD[N];    // inD[v] is true when v appears as the child end of an edge
ll ans;         // weak-pair count accumulated for the current test case
// Looks up x in the sorted, deduplicated coordinate table id[1..num].
// Returns the 1-based position of the largest coordinate <= x, or 0 when every
// coordinate is greater than x. For a value that is in the table this is its
// own position. The result is defined for every x, so it can be fed straight
// into the Fenwick prefix query; the previous hand-written binary search had
// no return on the "not found" path (undefined behaviour).
int getIndex(ll x)
{
    return int(std::upper_bound(id + 1, id + num + 1, x) - (id + 1));
}
vector<int> graph[N];
int c[200005];
// Returns the value of the lowest set bit of x (for example 12 -> 4, 0 -> 0).
int lowBit(int x)
{
    return x & (~x + 1);   // ~x + 1 is the two's-complement negation of x
}
// Fenwick-tree prefix query: returns the sum of the entries stored at
// positions 1..x of c (0 when x <= 0).
int sum(int x)
{
    int total = 0;
    for (int i = x; i > 0; i -= lowBit(i))
        total += c[i];
    return total;
}
// Fenwick-tree point update: adds p at position x and at every node above it
// that covers x, up to position num. Together with lowBit and sum this makes up
// the Fenwick (binary indexed) tree.
void change(int x, int p)
{
    for (int i = x; i <= num; i += lowBit(i))
        c[i] += p;
}
// One pending node on the explicit stack used by dfs().
struct DfsFrame
{
    int node;          // tree node being visited
    std::size_t next;  // position in graph[node] of the next child to descend into
    int index;         // compressed coordinate of a[node] that was inserted into the BIT
};

// Adds to ans the number of values u currently in the Fenwick tree with
// a[u] * a[v] <= k, then inserts a[v]. Returns the coordinate inserted.
static int enterNode(int v)
{
    int index = getIndex(a[v]);
    if (!a[v])
        ans += sum(num);                  // 0 * a[u] == 0 <= k for every stored value
    else
        ans += sum(getIndex(k / a[v]));   // a[u] * a[v] <= k  <=>  a[u] <= floor(k / a[v])
    change(index, 1);
    return index;
}

// Counts weak pairs in the subtree rooted at x and adds them to ans.
// A node's value stays in the Fenwick tree exactly while the node is on the
// stack, so (assuming the tree is empty when dfs starts) each node is queried
// against its ancestors only. An explicit stack replaces recursion so that a
// path-shaped tree of up to N nodes cannot overflow the call stack.
void dfs(int x)
{
    std::vector<DfsFrame> stk;
    DfsFrame root = {x, 0, enterNode(x)};
    stk.push_back(root);
    while (!stk.empty())
    {
        DfsFrame &top = stk.back();
        if (top.next < graph[top.node].size())
        {
            int child = graph[top.node][top.next++];
            DfsFrame frame = {child, 0, enterNode(child)};
            stk.push_back(frame);   // may invalidate `top`; it is not used again
        }
        else
        {
            // Remove this node's value so it is not counted for sibling subtrees.
            change(top.index, -1);
            stk.pop_back();
        }
    }
}
// Reads T test cases. Each case gives n and k, the n node values, then n - 1
// "parent child" edges. Prints the number of weak pairs for each case.
int main()
{
    int t;
    scanf("%d", &t);
    while (t--)
    {
        scanf("%d%lld", &n, &k);
        cnt = 0;
        ans = 0;
        memset(inD, false, sizeof(inD));
        memset(c, 0, sizeof(c));
        for (int i = 1; i <= n; i++)
            graph[i].clear();
        for (int i = 1; i <= n; i++)
        {
            // a[] holds long long, so the conversion must be %lld; %d is a
            // type mismatch (undefined behaviour).
            scanf("%lld", &a[i]);
            // Coordinate compression: record every value that will be inserted
            // (a[i]) and every query bound (k / a[i]) that will be looked up.
            id[++cnt] = a[i];
            if (a[i])
                id[++cnt] = k / a[i];
        }
        std::sort(id + 1, id + cnt + 1);
        // Deduplicate; afterwards id[1..num] is strictly increasing.
        num = int(std::unique(id + 1, id + cnt + 1) - (id + 1));
        int x, y;
        for (int i = 1; i < n; i++)
        {
            scanf("%d%d", &x, &y);
            inD[y] = true;   // y has a parent, so it cannot be the root
            graph[x].push_back(y);
        }
        // The root is the node that never appears as a child.
        for (int i = 1; i <= n; i++)
        {
            if (!inD[i])
            {
                dfs(i);
                break;
            }
        }
        printf("%lld\n", ans);
    }
    return 0;
}