题目链接:桃子的主席树
题目:
样例输入:
8 5
105 2 9 3 8 5 7 7
1 2
1 3
1 4
3 5
3 6
3 7
4 8
2 5 1
0 5 2
10 5 3
11 5 4
110 8 2
样例输出
2
8
9
105
7
没学过主席树的话,最好先去学一下主席树,这道题并不是板子题。
板子题的话可以看AcWing中的第k小数
此处链接:AcWing255.第K小数
以样例为参照,描述该题的解题过程:
若以1为根节点,那么树的形状就如上图。
那么主席树的建立就以此来建立。
在最开始建造第0个版本的主席树,当中没有任何一个节点。
然后从1号点开始搜索(BFS或者DFS都可以),每个节点都以其父节点为上一个版本(1号点以第0个版本为上一个版本),以此来建树。
那么可以发现对于任意一个节点,他的版本所包含的点的信息只与其父节点有关,而其父节点的信息又只与其父节点的父节点有关。所以任意一个节点所包含的信息就是根节点1到该节点路径上的所有信息。
比如:
1号节点的权值是105,那么版本1所包含的信息就是105,因为他的父节点为空;
3号节点的权值是9,那么他包含的节点有他本身和他的父节点包含的所有信息,即105和9;
5号节点的权值是8,那么他包含的节点有他本身和他的父节点包含的所有信息,即105,9和8;
那么以此建树的话,就可以发现,若要看5到7的路径上的第K小数,那么只要把5号节点的信息与7号节点的信息结合起来,但是结合起来后发现,5号节点和7号节点都存再3号节点的信息,所以要减去一个3号节点的信息,但减完后发现,还存在一个多余的1号节点的信息,所以再减掉一个1号节点的信息。
综上:任意两个点的路径上的第K小数,只要求和他们两个版本的信息,并减去一个他们最近公共祖先上的信息,再减去最近公共祖先的父节点上的信息即为答案。
注:由于每个点的权值很大,进行离散化操作。
#include <cstdio>
#include <iostream>
#include <algorithm>
#include <cstring>
#include <vector>
#include <queue>
using namespace std;
const int maxn=1e5+50;
const int inf=0x3f3f3f3f;
int root[maxn],a[maxn];//作主席树的数组
int h[maxn];//作树中邻接表的数组
int depth[maxn],f[maxn][25];//作lca的数组
vector<int> head;//离散化的vector
int n,m,idx,tot;
struct Edge{
int v,nxt;
}edge[maxn<<2];
struct Node{
int l,r;
int cnt;
}tr[maxn<<5];
void add(int u,int v){
edge[idx].v=v;
edge[idx].nxt=h[u];
h[u]=idx++;
}
int get_pos(int x){
return lower_bound(head.begin(),head.end(),x)-head.begin();
}
void pushup(int u){
tr[u].cnt=tr[tr[u].l].cnt+tr[tr[u].r].cnt;
}
int build(int l,int r){
int p=++tot;
if(l==r) return p;
int mid=l+r>>1;
tr[p].l=build(l,mid);
tr[p].r=build(mid+1,r);
return p;
}
int insert(int p,int l,int r,int x){
int q=++tot;
tr[q]=tr[p];
if(l==r){
tr[q].cnt++;
return q;
}
int mid=l+r>>1;
if(mid>=x) tr[q].l=insert(tr[p].l,l,mid,x);
else tr[q].r=insert(tr[p].r,mid+1,r,x);
pushup(q);
return q;
}
void BFS(){
memset(depth,inf,sizeof(depth));
depth[0]=0;depth[1]=1;
queue<int> que;
que.push(1);
root[1]=insert(root[0],0,head.size()-1,get_pos(a[1]));//第1好点从0版本获取
while(que.size()){
int u=que.front();
que.pop();
for(int i=h[u];~i;i=edge[i].nxt){
int v=edge[i].v;
if(depth[v]>depth[u]+1){
root[v]=insert(root[u],0,head.size()-1,get_pos(a[v]));
//每个节点从其父节点的版本来衍生出自生版本的信息
depth[v]=depth[u]+1;
f[v][0]=u;
que.push(v);
for(int i=1;i<=20;i++)//此处用倍增法求LCA
f[v][i]=f[f[v][i-1]][i-1];
}
}
}
}
int query(int x,int y,int p,int q,int l,int r,int k){
//x,y表示两个节点的版本,p表示两个节点的祖先版本,q表示p的父节点的版本
if(l==r) return head[l];
int cnt=tr[tr[x].l].cnt+tr[tr[y].l].cnt-tr[tr[q].l].cnt-tr[tr[p].l].cnt;
int mid=l+r>>1;
if(cnt>=k) return query(tr[x].l,tr[y].l,tr[p].l,tr[q].l,l,mid,k);
return query(tr[x].r,tr[y].r,tr[p].r,tr[q].r,mid+1,r,k-cnt);
}
int lca(int x,int y){//倍增法求LCA
if(depth[x]<depth[y]) swap(x,y);
for(int i=20;i>=0;i--){
if(depth[f[x][i]]>=depth[y]){
x=f[x][i];
}
}
if(x==y) return x;
for(int i=20;i>=0;i--){
if(f[x][i]!=f[y][i]){
x=f[x][i];
y=f[y][i];
}
}
return f[x][0];
}
int main(){
memset(h,-1,sizeof(h));
scanf("%d%d",&n,&m);
for(int i=1;i<=n;i++){
scanf("%d",&a[i]);
head.push_back(a[i]);
}
sort(head.begin(),head.end());
head.erase(unique(head.begin(),head.end()),head.end());//离散化去重
root[0]=build(0,head.size()-1);//初始化第0版本
for(int i=1;i<n;i++){
int x,y;
scanf("%d%d",&x,&y);
add(x,y);add(y,x);//建无向边
}
BFS();
int last=0;
while(m--){
int u,v,k;
scanf("%d%d%d",&u,&v,&k);
int p=lca(u^last,v);
int q=f[p][0];u=u^last;
int ans=query(root[u],root[v],root[p],root[q],0,head.size()-1,k);
//p表示u^last和v的最近公共祖先,q表示p的父节点
printf("%d\n",ans);
last=ans;
}
return 0;
}