The ALPC company is now building its own network system, which connects all N ALPC departments. To economize on spending, the backbone network has only one router for each department, and N-1 optical fibers in total connect all the routers.
The usual way to measure connection speed is lag, or network latency, which refers to the time taken for a sent packet of data to be received at the other end.
Now the network is on trial, and new photonic crystal fibers designed by ALPC42 are being tried out, so the lag on the fibers can be ignored. That means lag only occurs when a message passes through a router. ALPC42 is trying to replace routers to make the network faster; he now wants to know, at any given moment and for any pair of nodes, which router on the path between them has the K-th highest latency. He needs your help.
Input
There is only one test case in the input file.
Your program reads the information of N routers and N-1 fiber connections from the input, followed by Q questions of two kinds: 1. For some reason, the latency of one router changes. 2. A query for the router with the K-th longest lag on the path between two routers.
The first line contains two integers N and Q (0 <= N <= 80000, 0 <= Q <= 30000).
The second line contains N integers giving the initial latency of each router.
Then N-1 lines follow, each containing two integers x and y, meaning that a fiber connects router x and router y.
Then Q lines follow describing the questions, each with three numbers k, a, b. If k=0, the latency of router a, Ta, changes to b; if k>0, ask for the latency of the router with the k-th longest lag on the path between a and b (including routers a and b). 0<=b<100000000.
A blank line follows after each case.
Output
For each question with k>0, print one line containing the answer latency. If there are fewer than k routers on the path, print “invalid request!” instead.
Sample Input
5 5
5 1 2 3 4
3 1
2 1
4 3
5 3
2 4 5
0 1 2
2 2 3
2 1 4
3 3 5
Sample Output
3
2
2
invalid request!
题意:求树上任意两点间路径上的所有点中,val第k大的那个点的val。对于每一个询问,需要快速得到两点间的路径:用fa数组记录每一个结点的父节点,就可以从某点沿父节点一路走到lca;把两个点到lca的路径上沿途各点的val记录下来(lca只保留一次),然后取其中第k大的数输出即可;若路径上的点数少于k,则输出invalid request!。
lca使用在线的"lca转rmq"方法求得(欧拉序 + ST表)。
#include<iostream>
#include<algorithm>
#include<string.h>
#include<stdio.h>
#include<vector>
#include<string>
#include<cmath>
#include<set>
#include<queue>
using namespace std;
// Shorthand macros (not used by the visible code).
#define ll long long
#define inf 0x3f3f3f3f
//const int mod = 1000000007;
// NOTE(review): maxm is not referenced anywhere in the visible code.
const int maxm = 70005;
// Capacity of per-router arrays; the statement bounds N by 80000.
const int maxn = 80005;
// Number of power-of-two levels: size of _pow and the column count of dp.
const int M = 25;
// Number of routers and number of operations, as read in main().
int n, q;
// _pow[i] == 1 << i, filled in main().
int _pow[M];
// fa[x]: parent of router x in the DFS tree rooted at router 1 (the root is its own parent).
// path: scratch buffer for the latencies collected along one query path.
int fa[maxn],path[maxn];
// Euler tour of the tree: node[i] is the router visited at tour position i (1-based).
int node[2 * maxn];
// Current latency of each router (1-based).
int val[maxn];
// DEP[i]: depth of node[i]; the root has depth 1.
int DEP[2 * maxn];
// first[x]: tour position at which router x is first visited.
int first[maxn];
//int dis[maxn];
// dp[i][j]: tour position of a shallowest entry within [i, i + 2^j - 1] (sparse table).
int dp[2 * maxn][M];
// Visited flags used by dfs().
bool vis[maxn];
// tot: length of the Euler tour built so far; tot1: number of edges stored in e[].
int tot,tot1;
// head[x]: index in e[] of the first edge leaving x, or -1 when there is none (set in main()).
int head[maxn];
// One directed adjacency-list entry; main() stores each undirected fiber in both directions.
struct edge{
int u, v, nex; // source router, target router, index of the next edge leaving u (-1 ends the list)
//int len;
}e[2*maxn];
// Prepends the directed edge x -> y to x's adjacency list.
// The edge occupies slot tot1 of e[], and tot1 is advanced to the next free slot.
void adde(int x, int y){
    int id = tot1++;
    e[id].u = x;
    e[id].v = y;
    e[id].nex = head[x];  // link to the previous list head
    head[x] = id;         // the new edge becomes the head
}
// Builds the Euler tour of the tree reachable from u, iteratively.
//   u   - start router of the traversal (main passes router 1)
//   dep - depth recorded for u (main passes 1)
//   f   - value stored in fa[u] (main passes the root itself)
// Fills fa[], vis[], first[], node[1..tot] and DEP[1..tot]. A router is appended
// to the tour when it is first entered, and appended again each time one of its
// children is finished. This is exactly the order of the former recursive version,
// but an explicit stack is used, so a path-shaped tree with up to 80000 routers
// cannot overflow the call stack.
void dfs(int u, int dep, int f){
    std::vector<int> stkNode;  // routers on the current root-to-router path
    std::vector<int> stkEdge;  // next adjacency edge to examine for each of them
    fa[u] = f;
    vis[u] = 1; node[++tot] = u; first[u] = tot; DEP[tot] = dep;
    stkNode.push_back(u);
    stkEdge.push_back(head[u]);
    while (!stkNode.empty()){
        int top = (int)stkNode.size() - 1;  // stack slot i holds a router of depth dep + i
        int cur = stkNode[top];
        int k = stkEdge[top];
        // Skip neighbours that are already visited (in a tree, only the parent).
        while (k != -1 && vis[e[k].v]) k = e[k].nex;
        if (k == -1){
            // cur is finished: go back to its parent and record the parent again.
            stkNode.pop_back();
            stkEdge.pop_back();
            if (!stkNode.empty()){
                node[++tot] = stkNode.back();
                DEP[tot] = dep + (int)stkNode.size() - 1;
            }
            continue;
        }
        int v = e[k].v;
        stkEdge[top] = e[k].nex;  // resume after this edge once v's subtree is done
        fa[v] = cur;
        vis[v] = 1; node[++tot] = v; first[v] = tot; DEP[tot] = dep + top + 1;
        stkNode.push_back(v);
        stkEdge.push_back(head[v]);
    }
}
// Builds the sparse table over the Euler-tour depths DEP[1..n] for range-minimum
// queries. For every range it records the position whose depth is smallest.
// dp[i][j] holds the tour position of the shallowest entry in [i, i + 2^j - 1];
// on ties the right-hand half wins, matching RMQ().
// A level j is built while a window of length 2^j still fits in n. This uses integer
// comparisons instead of log(n)/log(2.0), whose rounding could drop the top level.
// The level count is bounded by dp's column count, so _pow/dp are never overrun.
void st(int n){
    const int levels = (int)(sizeof(dp[0]) / sizeof(dp[0][0]));
    for (int i = 1; i <= n; i++){ dp[i][0] = i; }
    for (int j = 1; j < levels && _pow[j] <= n; j++){
        for (int i = 1; i + _pow[j] - 1 <= n; i++){
            int a = dp[i][j - 1];
            int b = dp[i + _pow[j - 1]][j - 1];
            dp[i][j] = (DEP[a] < DEP[b]) ? a : b;
        }
    }
}
// Returns the Euler-tour position in [x, y] (1-based, x <= y) with the smallest DEP,
// using two overlapping power-of-two windows of the sparse table built by st().
// k is the largest level with 2^k <= y - x + 1. It is found by integer comparison
// rather than log()/log(2.0): a rounded-down k would leave the middle of the range
// uncovered, and a rounded-up k would read before x.
// On equal depths the right-hand window's answer is returned.
int RMQ(int x, int y){
    const int levels = (int)(sizeof(dp[0]) / sizeof(dp[0][0]));
    int len = y - x + 1;
    int k = 0;
    while (k + 1 < levels && _pow[k + 1] <= len) k++;
    int a = dp[x][k];
    int b = dp[y - _pow[k] + 1][k];
    return (DEP[a] < DEP[b]) ? a : b;
}
// Lowest common ancestor of routers u and v. It is the shallowest router in the
// Euler tour between their first occurrences.
int lca(int u, int v){
    int lo = first[u];
    int hi = first[v];
    if (hi < lo) std::swap(lo, hi);  // RMQ expects an ordered range
    return node[RMQ(lo, hi)];
}
// Appends val[] of every router on the path s -> t (s first, t last) to path[],
// starting at path[index]. index is advanced past the last written entry.
// t must be s itself or an ancestor of s in the fa[] tree; solve() passes the lca.
void findpath(int &index, int s, int t){
    for (int cur = s; ; cur = fa[cur]){
        path[index++] = val[cur];
        if (cur == t) break;
    }
}
// Strict "greater-than" ordering, used to rank latencies in descending order.
bool cmp(int a, int b){ return b < a; }
// Answers one query: prints the k-th largest latency among the routers on the
// path u..v (both endpoints included). Prints "invalid request!" when k is
// not positive or the path holds fewer than k routers.
void solve(int k, int u, int v){
    int _lca = lca(u, v);
    int cnt = 0;
    findpath(cnt, u, _lca);  // u, ..., lca
    findpath(cnt, v, _lca);  // v, ..., lca
    cnt--;  // the lca was recorded twice; its second copy is the last entry, so drop it
    // k < 1 would index before path[0].
    if (k < 1 || k > cnt){ printf("invalid request!\n"); return; }
    // Only the k-th largest value is needed. Selection is linear on average,
    // whereas a full sort would cost O(cnt log cnt).
    std::nth_element(path, path + k - 1, path + cnt, cmp);
    printf("%d\n", path[k - 1]);
}
// Reads the routers, their latencies and the fibers, then builds the Euler tour and
// sparse table for LCA. It then processes the Q operations, each given as "k a b":
//   k == 0 : the latency of router a becomes b;
//   k != 0 : solve() prints the k-th largest latency on the path a..b.
int main() {
    memset(head, -1, sizeof(head));  // -1 marks an empty adjacency list
    const int levels = (int)(sizeof(_pow) / sizeof(_pow[0]));
    for (int i = 0; i < levels; i++){ _pow[i] = 1 << i; }

    scanf("%d%d", &n, &q);
    for (int i = 1; i <= n; i++){ scanf("%d", &val[i]); }
    for (int i = 1; i < n; i++){
        int a, b;
        scanf("%d%d", &a, &b);
        adde(a, b);  // fibers are undirected: store both directions
        adde(b, a);
    }

    tot = 0;
    dfs(1, 1, 1);  // root the tree at router 1, which is its own parent
    st(tot);

    while (q--){
        int op, a, b;
        scanf("%d%d%d", &op, &a, &b);
        if (op == 0) val[a] = b;
        else solve(op, a, b);
    }
    return 0;
}