Weak Pair
http://acm.hdu.edu.cn/showproblem.php?pid=5877
Problem Description
You are given a rooted tree of
N
N
N nodes, labeled from 1 to
N
N
N. To the ith node a non-negative value ai is assigned.An ordered pair of nodes
(
u
,
v
)
(u,v)
(u,v) is said to be weak if
(1)
u
u
u is an ancestor of
v
v
v (Note: In this problem a node
u
u
u is not considered an ancestor of itself);
(2)
a
u
×
a
v
≤
k
{a_u×a_v}\leq{k}
au×av≤k.
Can you find the number of weak pairs in the tree?
Input
There are multiple cases in the data set.
The first line of input contains an integer
T
T
T denoting number of test cases.
For each case, the first line contains two space-separated integers,
N
N
N and
k
k
k, respectively.
The second line contains
N
N
N space-separated integers, denoting
a
1
a_1
a1 to
a
N
a_N
aN.
Each of the subsequent lines contains two space-separated integers defining an edge connecting nodes
u
u
u and
v
v
v , where node
u
u
u is the parent of node
v
v
v.
Constrains:
1 ≤ N ≤ 1 0 5 1≤N≤10^{5} 1≤N≤105
0 ≤ a i ≤ 1 0 9 0≤a_i≤10^{9} 0≤ai≤109
0 ≤ k ≤ 1 0 18 0≤k≤10^{18} 0≤k≤1018
Output
For each test case, print a single integer on a single line denoting the number of weak pairs in the tree.
Sample Input
1
2 3
1 2
1 2
Sample Output
1
题目:
给你一棵有
N
N
N个节点的树,并且每个节点有对应的一个值,现在定义一个叫WEAK PAIR的东西(u,v),当满足一下条件是(u,v)就是WEAK PAIR:
1.u是v的祖先
2.
a
u
×
a
v
≤
k
{a_u×a_v}\leq{k}
au×av≤k
然后求树里有多少个WEAK PAIR
题解:
这题用了经典题目求逆序数的思想,用树状数组来维护比当前数要小的数的个数,然后用dfs遍历一遍就能把问题解决。
但应该要注意的是,由于k是
1
0
18
10^{18}
1018所以数组是不可能存的下的,因此要进行离散化操作就是通过结构体把节点分成值(val)和编号(id),再按值进行排序,然后映射到一个新的数组tr[]上,这里满足tr[a[i].id]=i(这里的i是指排序之后的第几个),但要注意的就是里面的值可能是一样的,因此相同的要映射到同一个下标。
补充:
由于样例的数据太弱这里给出我编的数据:
input
3
9 10
2 4 7 9 2 11 7 4 1
1 2
1 3
2 4
2 5
3 6
3 7
3 8
6 9
9 10
2 4 7 9 2 11 7 4 1
1 2
1 3
2 4
2 5
3 6
3 7
3 8
6 9
8 10
2 4 7 9 2 11 7 4
1 2
1 3
2 4
2 5
3 6
3 7
3 8
output
6
6
4
代码:
#include<bits/stdc++.h>
using namespace std;
const int maxn=1e6+5;
int vis[maxn];
int Ans[maxn<<1];
int tr[maxn<<1];
long long ans;
int n;
vector<int>C[maxn<<1];
struct node
{
int id;
long long val;
}a[maxn];
int lowbits(int i)
{
int temp=i&(-i);
// cout<<i<<" "<<temp<<endl;
return temp;
}
void add(int x,int k)
{
while(x<=2*n)
{
Ans[x]+=k;
x+=lowbits(x);
// cout<<"-------"<<endl;
}
}
int getSum(int x)
{
int ans=0;
for(int i=x;i>0;i-=lowbits(i))
{
ans+=Ans[i];
}
return ans;
}
//void check()
//{
// for(int i=1;i<=2*n;i++)
// {
// cout<<setw(5)<<i;
// }
// cout<<endl;
// for(int i=1;i<=2*n;i++)
// {
// cout<<setw(5)<<Ans[i];
// }
// cout<<endl;
//}
void dfs(int u)
{
ans+=getSum(tr[u+n]);
add(tr[u],1);
// check();
// cout<<"u="<<u<<" ans="<<ans<<endl;
int len=C[u].size(),v;
for(int i=0;i<len;i++)
{
// cout<<"add__"<<endl;
dfs(C[u][i]);
// cout<<"del__"<<endl;
}
add(tr[u],-1);
// check();
}
bool cmp(const node& a,const node& b)
{
if(a.val==b.val)
return a.id<b.id;
else
return a.val<b.val;
}
int main()
{
int T,i,j,u,v;
long long k;
scanf("%d",&T);
while(T--)
{
scanf("%d%lld",&n,&k);
ans=0;
memset(vis,0,sizeof(vis));
memset(tr,0,sizeof(tr));
memset(a,0,sizeof(a));
for(i=1;i<=n;i++)
{
scanf("%lld",&a[i].val);
a[i+n].val=k/a[i].val;
a[i].id=i;
a[i+n].id=i+n;
C[i].clear();
}
sort(a+1,a+1+2*n,cmp);
int temp=1;
for(i=1;i<=2*n;i++)
{
tr[a[i].id]=temp;
if(a[i].val!=a[i+1].val)
{
temp++;
}
}
// for(i=1;i<=2*n;i++)
// {
// cout<<tr[i]<<" ";
// }cout<<endl;
for(i=1;i<n;i++)
{
scanf("%d%d",&u,&v);
C[u].push_back(v);
vis[v]++;
}
int s;
for(i=1;i<=n;i++)
{
if(!vis[i])
{
s=i;
break;
}
}
// add(tr[s],1);
dfs(s);
// add(tr[s],-1);
cout<<ans<<endl;
// cout<<"--------"<<endl;
}
}