题目大意:
有一颗树,树上的每个端点都有序号,并且每个端点都有个权重。
问以某个端点为根节点的子树中,出现了k次的权重的个数。
解决方法:
先dfs一下,将树转化为区间。
先将区间排序(按照左端点的大小),建立一个vector <int> po[i] 数组,记录到目前为止,i出现的位置(按照顺序)。
然后从右向左对数字进行处理,并记录数字出现的位置,对于每个数进行如下操作:
<span style="font-family:KaiTi_GB2312;"><strong>if (po[vis[pos]].size()>=k){
add(po[vis[pos]][po[vis[pos]].size()-k],-2);
}
if (po[vis[pos]].size()>k){
add(po[vis[pos]][po[vis[pos]].size()-k-1],1);
}
po[vis[pos]].push_back(pos);
if (po[vis[pos]].size()>=k){
add(po[vis[pos]][po[vis[pos]].size()-k],1);
}</strong></span>
求某段区间的答案就是sum(右区间)。
我的代码:
#include <cstdio>
#include <vector>
#include <algorithm>
#include <iostream>
#include <cstring>
#define maxn 1001000
using namespace std;
int n,k,m;
struct Node {
int x,y;
}lsh[maxn],query[maxn];
struct quetion{
int x,y,id,ans;
}query1[maxn];
int num[maxn];
int vis[maxn],totl;
vector <int> nod[maxn];
bool cmp1(Node a,Node b){
return a.x<b.x||a.x==b.x&&a.y<b.y;
}
int cmp2(quetion a,quetion b){
return a.y<b.y||a.y==b.y&&a.x<b.x||a.x==b.x&&a.y==b.y&&a.id<b.id;
}
int cmp3(quetion a,quetion b){
return a.id<b.id;
}
int lshb() {//数据离散化
int totl=2;
sort(lsh+1,lsh+n+1,cmp1);
num[lsh[1].y]=1;
for (int i=2;i<=n;i++){
if (lsh[i].x==lsh[i-1].x){
num[lsh[i].y]=totl-1;
}
else {
num[lsh[i].y]=totl++;
}
}
return 0;
}
int dfs(int v) { //求欧拉序列
query[v].x=totl++;
vis[totl-1]=num[v];
for (int i=0;i<nod[v].size();i++)
dfs(nod[v][i]);
query[v].y=totl-1;
return 0;
}
int tree[maxn];
int add(int x,int v){
for (int i=x;i<=n;i+=i&(-i)) tree[i]+=v;
return 0;
}
int sum(int x){
int ans=0;
for (int i=x;i>0;i-=i&(-i)) ans+=tree[i];
return ans;
}
vector <int> po[maxn];
int change(int pos){
if (po[vis[pos]].size()>=k){
add(po[vis[pos]][po[vis[pos]].size()-k],-2);
}
if (po[vis[pos]].size()>k){
add(po[vis[pos]][po[vis[pos]].size()-k-1],1);
}
po[vis[pos]].push_back(pos);
if (po[vis[pos]].size()>=k){
add(po[vis[pos]][po[vis[pos]].size()-k],1);
}
return 0;
}
int main (){
//freopen("test.in","r",stdin);
int T,t=1;scanf("%d",&T);
while (T--){
memset(tree,0,sizeof(tree));
printf("Case #%d:\n",t++);
scanf("%d%d",&n,&k);
for (int i=1;i<=n;i++)
scanf("%d",&lsh[i].x),lsh[i].y=i,nod[i].clear(),po[i].clear();
lshb();
//for (int i=1;i<=n;i++) printf("%d ",num[i]);printf("\n");
for (int i=1;i<=n-1;i++){
int a,b;scanf("%d%d",&a,&b);
nod[a].push_back(b);
}
totl=1;dfs(1);scanf("%d",&m);
for (int i=1;i<=m;i++){
int a;scanf("%d",&a);
query1[i].x=query[a].x;
query1[i].y=query[a].y;
query1[i].id=i;
query1[i].ans=0;
}
sort(query1+1,query1+1+m,cmp2);int pos=1;
for (int i=1;i<=m;i++){
while (pos<=query1[i].y){
change(pos);
pos++;
}
if (query1[i].x>1) query1[i].ans=sum(query1[i].y)-sum(query1[i].x-1);
else query1[i].ans=sum(query1[i].y);
}
sort(query1+1,query1+1+m,cmp3);
for (int i=1;i<=m;i++)
printf("%d\n",query1[i].ans);
if (T!=0) printf("\n");
}
return 0;
}