一道丧心病狂的数据结构题。
思路：先做树链剖分，再用线段树套一个 BST（任意一种均可，下面的代码用的是不做平衡的普通 BST），二分答案并统计路径上不超过该值的个数，从而求出第 K 大值。
Code:
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <set>
#include <map>
#include <queue>
#include <vector>
#include <cstring>
#include <cmath>
#include <cstdlib>
#include <complex>
using namespace std;
#define rep(a,b,c) for(int a=b;a<=c;a++)
#define per(a,b,c) for(int a=b;a>=c;a--)
#define max(a,b) ((a>b)?(a):(b))
#define min(a,b) ((a<b)?(a):(b))
#define pb push_back
#define mp make_pair
#define PII pair<int,int>
#define X first
#define Y second
// Node of the per-segment-tree-node BSTs. All BSTs share the static pool tr[];
// index 0 is the null sentinel (never allocated, so its fields stay 0).
//   key : value stored in this node
//   ls  : left child index  (keys smaller than key)
//   rs  : right child index (keys larger than key)
//   mt  : multiplicity of key; deletion is lazy, so this may drop to 0
//   sz  : total multiplicity of the subtree rooted here
// Pool sizing: build() allocates at most n nodes per segment-tree level
// (18 levels for n around 80000) and every modification may allocate one new
// node per level on its root-to-leaf path; nodes are never freed. Assuming
// n, Q <= 80000 (TODO confirm against the problem limits) that is up to
// 18*80000 + 18*80000 = 2,880,000 nodes, which overran the former
// 2,000,000-entry pool. 3,000,000 nodes take about 60 MB.
struct node{
    int key,ls,rs,mt,sz;
}tr[3000000];
#define MAXN 80010
// Global tree / heavy-light-decomposition (HLD) state, all 1-indexed:
//   leaf[x]    x has no children in the DFS tree rooted at vertex 1
//   v[x]       visited flag used by dfs1
//   fa[x][j]   2^j-th ancestor of x (level 0 filled by dfs1, higher levels by main)
//   h[x]       depth of x (the root has depth 0)
//   g[x], e[x] first / last edge index in x's adjacency list (0 = none)
//   a[x]       current value stored on vertex x
//   sz[x]      size of the subtree rooted at x
//   head[x]    top vertex of the heavy chain containing x
//   nxt[i]     next edge after edge i in the same adjacency list (0 = end)
//   t[i]       target vertex of edge i (each tree edge is stored twice)
//   id[p]      vertex at HLD position p;  wh[x]: HLD position of vertex x
//   lm/rm/root range [lm,rm] of a segment-tree node and its BST root in tr[]
// r counts edges while reading input and is later reused by main as the upper
// bound of the binary search; T numbers HLD positions in dfs2 and is later
// reset and reused as the allocation counter of the tr[] pool.
bool leaf[MAXN];
int v[MAXN],fa[MAXN][20],h[MAXN],g[MAXN],e[MAXN],a[MAXN],sz[MAXN],head[MAXN];
// Called `nxt`, not `next`: with `using namespace std;` in effect, a global
// named `next` is ambiguous with std::next (C++11 and later) at every
// unqualified use, which is a compile error.
int nxt[MAXN*2],t[MAXN*2],id[MAXN],wh[MAXN];
int lm[MAXN*4],rm[MAXN*4],root[MAXN*4];
int n,Q,l,r=0,T=0;

// Number of significant bits of x (0 for x == 0), i.e. floor(log2 x) + 1 for
// x > 0. main uses logg(n) as the highest binary-lifting level to fill.
inline int logg(int x){int res=0;while (x){res++;x/=2;}return res;}

// Append directed edge aa -> bb at the END of aa's adjacency list, so the
// list is traversed in insertion order. Relies on nxt[] being zero-initialized.
inline void addedge(int aa,int bb){
    t[++r]=bb;
    if (!g[aa]) g[aa]=r; else nxt[e[aa]]=r;
    e[aa]=r;
}

// First HLD pass (recursive): starting at x, record parent fa[][0], depth h,
// subtree size sz and the leaf flag for every vertex reached through
// unvisited neighbours. Recursion depth equals the tree height, so a long
// path needs a correspondingly large call stack.
inline void dfs1(int x){
    v[x]=1;sz[x]=1;
    for (int u=g[x];u;u=nxt[u])
        if (!v[t[u]]){
            fa[t[u]][0]=x;
            h[t[u]]=h[x]+1;
            dfs1(t[u]);
            sz[x]+=sz[t[u]];
        }
    if (sz[x]==1) leaf[x]=1;
}

// Second HLD pass (recursive): assign consecutive positions (wh/id) in DFS
// order, descending into the heavy child (largest sz, first one on ties)
// first so that every heavy chain occupies a contiguous position range whose
// top vertex is stored in head[]. Requires dfs1 to have run; the parent is
// skipped by comparing against fa[x][0].
inline void dfs2(int x,int top){
    int maxid=0,maxn=-1;   // heavy child and its subtree size
    head[x]=top;id[++T]=x;wh[x]=T;
    if (leaf[x]) return ;
    for (int u=g[x];u;u=nxt[u]){
        if (t[u]==fa[x][0]) continue;
        if (sz[t[u]]>maxn){
            maxn=sz[t[u]];
            maxid=t[u];
        }
    }
    dfs2(maxid,top);
    for (int u=g[x];u;u=nxt[u])
        if (t[u]!=maxid && t[u]!=fa[x][0]) dfs2(t[u],t[u]);
}
inline int newnode(int x){
T++;tr[T].key=x;
tr[T].ls=tr[T].rs=0;
tr[T].mt=1;tr[T].sz=1;
return T;
}
// Add one copy of value cur to the (unbalanced) BST rooted at node x, which
// must be non-null. Every node on the search path gains 1 in sz; a matching
// key just gains 1 in mt, otherwise a new leaf is hung off the last node.
inline void insert(int x,int cur){
    for (int P=x;;){
        ++tr[P].sz;
        if (cur==tr[P].key){
            ++tr[P].mt;
            return;
        }
        // Pick the child slot to descend into (tr[] is a static array, so
        // this reference stays valid while newnode() fills another slot).
        int &child=(cur>tr[P].key)?tr[P].rs:tr[P].ls;
        if (child==0){
            child=newnode(cur);
            return;
        }
        P=child;
    }
}
// Remove one copy of value cur from the BST rooted at node x. Deletion is
// lazy: no node is unlinked; sz is decremented along the search path and mt
// on the matching node. Precondition: cur is present in the tree — otherwise
// the walk reaches the null node 0 and corrupts it (and may never stop).
inline void del(int x,int cur){
    int P=x;
    for (;;){
        --tr[P].sz;
        if (cur==tr[P].key) break;
        P=(cur<tr[P].key)?tr[P].ls:tr[P].rs;
    }
    --tr[P].mt;
}
// Number of stored values (with multiplicity) in the BST rooted at x that are
// <= cur. For node P, tr[P].sz - tr[tr[P].rs].sz is the weight of its left
// subtree plus its own multiplicity; node 0 is the empty tree.
inline int Count(int x,int cur){
    int total=0;
    for (int P=x;P!=0;){
        const int notRight=tr[P].sz-tr[tr[P].rs].sz;
        if (cur<tr[P].key){
            P=tr[P].ls;
        }else if (cur==tr[P].key){
            return total+notRight;
        }else{
            total+=notRight;
            P=tr[P].rs;
        }
    }
    return total;
}
inline void build(int cur,int ll,int rr){
lm[cur]=ll;rm[cur]=rr;
if (ll==rr){
root[cur]=newnode(a[id[ll]]);
return ;
}
build(cur*2,ll,(ll+rr)/2);
build(cur*2+1,(ll+rr)/2+1,rr);
root[cur]=newnode(a[id[ll]]);
rep(i,ll+1,rr) insert(root[cur],a[id[i]]);
return ;
}
// Point update at HLD position pp: in the BST of every segment-tree node
// (starting from `cur`) whose range contains pp, replace one copy of the old
// value K with the new value KK.
inline void modify(int cur,int pp,int K,int KK){
    if (pp<lm[cur] || pp>rm[cur]) return;   // pp is outside this node's range
    del(root[cur],K);
    insert(root[cur],KK);
    if (lm[cur]==rm[cur]) return;            // reached the leaf for position pp
    modify(cur*2,pp,K,KK);
    modify(cur*2+1,pp,K,KK);
}
// Count values <= K over HLD positions [ll, rr] below segment-tree node cur:
// disjoint nodes contribute nothing, fully covered nodes answer from their
// BST, partially covered nodes split into both children.
inline int query(int cur,int ll,int rr,int K){
    const bool disjoint=(rr<lm[cur]) || (ll>rm[cur]);
    if (disjoint) return 0;
    const bool covered=(ll<=lm[cur]) && (rm[cur]<=rr);
    if (!covered){
        const int fromLeft=query(cur*2,ll,rr,K);
        return fromLeft+query(cur*2+1,ll,rr,K);
    }
    return Count(root[cur],K);
}
// Lowest common ancestor of x and y using the binary-lifting table fa[][]
// (fa[v][j] = 2^j-th ancestor of v) and the depths in h[]. main sets
// fa[1][0]=1 before filling the table, so ancestors above the root saturate
// at vertex 1.
inline int lca(int x,int y){
int th;
// Phase 1: lift x until it is as deep as y. Find the smallest th whose
// 2^th-th ancestor is no deeper than y; if that jump overshoots above y's
// depth, jump one level less (still strictly deeper than y) and repeat.
while (h[x]>h[y]){
th=0;
while (h[fa[x][th]]>h[y]) th++;
if (h[fa[x][th]]<h[y]) th--;
x=fa[x][th];
}
// Phase 2: the same lifting for y when y is the deeper vertex.
while (h[x]<h[y]){
th=0;
while (h[fa[y][th]]>h[x]) th++;
if (h[fa[y][th]]<h[x]) th--;
y=fa[y][th];
}
// Phase 3: x and y are at equal depth. Find the smallest th whose 2^th-th
// ancestors coincide and jump both by 2^(th-1) (ancestors still differ
// there); when th==0 the parents already coincide, so one step lands both
// on the LCA.
while (x!=y){
th=0;
while (fa[x][th]!=fa[y][th]) th++;
if (th!=0) th--;
x=fa[x][th];y=fa[y][th];
}
return x;
}
// Count values <= K on the vertical path from `be` up to `en`, both ends
// included, one heavy-chain segment at a time. Precondition: `en` is an
// ancestor of (or equal to) `be`.
inline int check(int be,int en,int K){
    int cnt=0;
    for (;head[be]!=head[en];be=fa[head[be]][0]){
        const int top=head[be];
        cnt+=query(1,wh[top],wh[be],K);   // whole chain piece top..be
    }
    return cnt+query(1,wh[en],wh[be],K);  // final piece on en's chain
}
// Read the next non-negative decimal integer from stdin into x, skipping any
// non-digit characters before it (a '-' sign is skipped, not honoured).
// If end of input is reached before any digit, x is set to 0.
// getchar() returns int; c is kept as int so EOF can be told apart from a
// byte value, and the skip loop now stops at EOF instead of spinning forever
// on truncated input.
inline void scan(int &x){
    int c=getchar();
    while (c!=EOF && (c<'0' || c>'9')) c=getchar();
    x=0;
    while (c>='0' && c<='9'){
        x=x*10+(c-'0');
        c=getchar();
    }
}
int main(){
scan(n);scan(Q);
rep(i,1,n) scan(a[i]);
rep(i,1,n-1){
int aa,bb;
scan(aa);scan(bb);
addedge(aa,bb);
addedge(bb,aa);
}
dfs1(1);dfs2(1,1);
T=0;
build(1,1,n);
fa[1][0]=1;
rep(j,1,logg(n)) rep(i,1,n) fa[i][j]=fa[fa[i][j-1]][j-1];
while (Q--){
int k,A,b;
scan(k);scan(A);scan(b);
if (!k) modify(1,wh[A],a[A],b),a[A]=b;
else{
int Fa=lca(A,b);
if (h[A]+h[b]-h[Fa]*2+1<k){
puts("invalid request!");
continue;
}
k=(h[A]+h[b]-h[Fa]*2+1-k+1);
l=0,r=100000000;
while (l<r){
int C=check(A,Fa,(l+r)/2)+check(b,Fa,(l+r)/2);
if (a[Fa]<=(l+r)/2) C--;
if (C>=k) r=(l+r)/2;else l=(l+r)/2+1;
}
printf("%d\n",l);
}
}
return 0;
}