Weak Pair
Time Limit: 4000/2000 MS (Java/Others) Memory Limit: 262144/262144 K (Java/Others)Total Submission(s): 597 Accepted Submission(s): 207
Problem Description
You are given a rooted tree of $N$ nodes, labeled from 1 to $N$. To the $i$-th node a non-negative value $a_i$ is assigned. An ordered pair of nodes $(u,v)$ is said to be weak if
(1) $u$ is an ancestor of $v$ (Note: In this problem a node $u$ is not considered an ancestor of itself);
(2) $a_u \times a_v \le k$.
Can you find the number of weak pairs in the tree?
Input
There are multiple cases in the data set.
The first line of input contains an integer $T$ denoting the number of test cases.
For each case, the first line contains two space-separated integers, $N$ and $k$, respectively.
The second line contains $N$ space-separated integers, denoting $a_1$ to $a_N$.
Each of the subsequent lines contains two space-separated integers defining an edge connecting nodes $u$ and $v$, where node $u$ is the parent of node $v$.
Constraints:
$1 \le N \le 10^5$
$0 \le a_i \le 10^9$
$0 \le k \le 10^{18}$
Output
For each test case, print a single integer on a single line denoting the number of weak pairs in the tree.
Sample Input
1
2 3
1 2
1 2
Sample Output
1
Source
题意:有根树,每个节点有非负值,若存在一对(u,v)满足u是v的祖先并且w[u]*w[v]<=k,则该对满足要求,问一共有多少对。
思路:将所有节点值 w[i] 以及 k/w[i] 离散化后建立线段树,线段树中每个位置保存对应的值当前出现的次数。从根开始 DFS,访问到节点 i 时,线段树中保存的恰好是 i 的所有祖先的值,查询其中不大于 k/w[i] 的个数并累加到答案;然后把 w[i] 插入线段树,递归访问子节点,回溯时再把 w[i] 从线段树中删除。
#include <cstdlib>
#include <cstring>
#include <algorithm>
#include <vector>
#include <cmath>
#include <cstdio>
using namespace std;
typedef long long ll;
// Capacity used to size the per-node arrays below (slightly above 1e5).
const ll Maxn=1e5+100;
// n     : number of nodes in the current test case (read in main).
// numm  : number of distinct entries kept in x[] after sort + unique.
// head  : head[u] is the index of the first edge leaving u in edge[], or -1.
// tot   : number of edges stored so far in edge[].
// deep  : deep[v] counts the edges in which v appears as the child; main
//         starts the DFS from every node whose count is 0.
int n,numm,head[Maxn],tot,deep[Maxn];
// num : num[i] is the value read for node i (1-based).
// ans : running total of query results accumulated by dfs.
// k   : the product bound read for the current test case.
// x   : sorted, de-duplicated pool of the node values and the quotients
//       k/num[i]; positions in x[] are used as segment tree indices.
ll num[Maxn],ans,k,x[2*Maxn];
// One segment tree node covering the index interval [l, r].
struct node
{
    // Sum stored for this interval (despite the name, pushup adds the
    // children together rather than taking a maximum).
    ll maxn;
    // Left end of the covered interval.
    int l;
    // Right end of the covered interval.
    int r;
};
// Segment tree storage; node i has children 2*i and 2*i+1.
node t[6*Maxn];
// Recompute the sum stored at node i from the sums of its two children.
void pushup(int i)
{
    const int lc=i<<1;
    const int rc=lc|1;
    t[i].maxn=t[lc].maxn+t[rc].maxn;
}
// Initialise the subtree rooted at node i so that it covers [l, r] and
// every stored sum is zero.
void build(int i,int l,int r)
{
    t[i].l=l;
    t[i].r=r;
    if(l!=r)
    {
        const int mid=(l+r)/2;
        build(i<<1,l,mid);
        build(i<<1|1,mid+1,r);
        // Children are all zero at this point; pushup records their sum.
        pushup(i);
    }
    else
    {
        t[i].maxn=0;
    }
}
// Return the sum stored over the indices in [x, y], looking only inside the
// subtree rooted at node i.
ll query(int i,int x,int y)
{
    const node& cur=t[i];
    // Node interval lies completely inside the requested range.
    if(x<=cur.l&&cur.r<=y)
        return cur.maxn;
    const int mid=(cur.l+cur.r)/2;
    const ll fromLeft=(x<=mid)?query(i<<1,x,y):0;
    const ll fromRight=(y>mid)?query(i<<1|1,x,y):0;
    return fromLeft+fromRight;
}
// Add der to the leaf for index x inside the subtree rooted at node i and
// refresh the sums on the path back up to i.
void Modify(int i,int x,int der)
{
    if(t[i].l==t[i].r)
    {
        t[i].maxn+=der;
        return;
    }
    const int mid=(t[i].l+t[i].r)/2;
    // Descend into whichever child interval contains x.
    const int child=(x<=mid)?(i<<1):(i<<1|1);
    Modify(child,x,der);
    pushup(i);
}
// One directed edge in the linked adjacency lists threaded through head[].
struct node1
{
    // Index in edge[] of the next edge leaving the same source, or -1.
    int next;
    // Destination node of this edge.
    int to;
};
// Edge pool; addedge appends to it using tot as the next free slot.
node1 edge[Maxn];
void addedge(int from,int to)
{
edge[tot].to=to;
edge[tot].next=head[from];
head[from]=tot++;
}
// Depth-first walk that adds to `ans` the number of weak pairs (anc, u) for
// every node u in the subtree rooted at the starting node.
//
// While u is being visited the segment tree holds exactly the values of the
// nodes on the path from the DFS start down to (but excluding) u: each node
// inserts its own value before recursing and removes it again afterwards.
//
//   u  : node being visited.
//   fa : node the call came from; unused, since edges are stored
//        parent -> child only.
//
// NOTE(review): the recursion depth equals the tree height, which may reach
// N on a chain-shaped tree; confirm the stack limit of the target platform.
void dfs(int u,int fa)
{
    (void)fa;
    if(num[u]==0)
    {
        // 0 * a_anc == 0 <= k for every ancestor (k is non-negative), so all
        // values currently in the tree count. k/0 must never be evaluated.
        ans+=query(1,0,numm);
    }
    else
    {
        // a_anc * a_u <= k  <=>  a_anc <= floor(k / a_u). main put that
        // quotient into x[], so lower_bound lands exactly on its index.
        int o=lower_bound(x,x+numm,k/num[u])-x;
        ans+=query(1,0,o);
    }
    int kk=lower_bound(x,x+numm,num[u])-x;
    Modify(1,kk,1);
    for(int i=head[u]; ~i; i=edge[i].next)
    {
        int v=edge[i].to;
        dfs(v,u);
    }
    // Remove u's value again so that sibling subtrees do not see it.
    Modify(1,kk,-1);
}
// Reads T test cases and prints the number of weak pairs for each one.
//
// NOTE(review): "%I64d" is the 64-bit conversion used by the Microsoft C
// runtime; other C libraries expect "%lld". It is left unchanged here —
// confirm against the target judge before porting.
int main()
{
    int T;
    if(scanf("%d",&T)!=1)
        return 0;
    while(T--)
    {
        scanf("%d%I64d",&n,&k);
        // Collect every value that will be looked up in the segment tree:
        // the node values themselves and the thresholds k / a_i.
        int cnt=0;
        for(int i=1; i<=n; i++)
        {
            scanf("%I64d",&num[i]);
            x[cnt++]=num[i];
            // A zero value has no finite threshold; dfs handles it without
            // dividing, so no quotient is stored for it.
            if(num[i]!=0)
                x[cnt++]=k/num[i];
        }
        sort(x,x+cnt);
        numm=unique(x,x+cnt)-x;
        tot=0;
        memset(head,-1,sizeof(head));
        memset(deep,0,sizeof(deep));
        for(int i=0; i<n-1; i++)
        {
            int u,v;
            scanf("%d %d",&u,&v);
            addedge(u,v);
            deep[v]++;
        }
        ans=0;
        build(1,0,numm);
        // The root is the only node that never appears as a child.
        for(int i=1; i<=n; i++)
            if(deep[i]==0)
                dfs(i,-1);
        printf("%I64d\n",ans);
    }
    return 0;
}