题目链接:https://www.spoj.com/problems/COT/
COT - Count on a tree
description
You are given a tree with N N N nodes. The tree nodes are numbered from 1 1 1 to N N N. Each node has an integer weight.
We will ask you to perform the following operation:
- u v k u\ v\ k u v k : ask for the kth minimum weight on the path from node u u u to node v v v
Input
In the first line there are two integers N N N and M M M. ( N , M ≤ 100000 ) (N, M \leq 100000) (N,M≤100000)
In the second line there are N N N integers. The i − t h i-th i−th integer denotes the weight of the ith node.
In the next N − 1 N-1 N−1 lines, each line contains two integers u v u v uv, which describes an edge ( u , v ) (u, v) (u,v).
In the next M M M lines, each line contains three integers u v k u\ v\ k u v k, which means an operation asking for the k − t h k-th k−th minimum weight on the path from node u u u to node v v v.
Output
For each operation, print its result.
Example
Input:
8 5
105 2 9 3 8 5 7 7
1 2
1 3
1 4
3 5
3 6
3 7
4 8
2 5 1
2 5 2
2 5 3
2 5 4
7 8 2
Output:
2
8
9
105
7
-
题意:
- 给一棵树,每个点都有一个权值, q q q个询问,询问u到v的最短路径上所有节点的权值的第 k k k小
-
解法:树上主席树
-
我们知道线性的主席树维护的是前缀信息,那么换到树上的时候,也可以类比分析,我们只需要让每个节点在父亲节点的基础上新建一条链,也就是维护从根到当前节点的所有节点的权值信息,那么如何查询u到v的路径上的权值第 k k k小呢?
-
比如上面这个图,( 0 0 0是自己添加的作为节点 1 1 1的父节点),现在要要查询 2 2 2到 3 3 3的第 k k k小,建树的时候分别保存了从 0 0 0到 2 2 2和从 0 0 0到 3 3 3的前缀信息,所以 2 , 3 2,3 2,3之间的路径信息就可以用 s u m [ 2 ] + s u m [ 3 ] − s u m [ L C A ( 2 , 3 ) ] − s u m [ f a [ L C A ( 2 , 3 ) ] ] sum[2]+sum[3]-sum[LCA(2,3)]-sum[fa[LCA(2,3)]] sum[2]+sum[3]−sum[LCA(2,3)]−sum[fa[LCA(2,3)]]表示,其中 L C A ( 2 , 3 ) LCA(2,3) LCA(2,3)表示节点 2 2 2和节点 3 3 3的最近公共祖先, f a [ L C A ] fa[LCA] fa[LCA]表示 L C A LCA LCA的父亲节点,所以对于一般的 u , v u,v u,v路径之间的信息就可以查询 s u m [ u ] + s u m [ v ] − s u m [ L C A ( u , v ) ] − s u m [ f a [ L C A ( u , v ) ] ] sum[u]+sum[v]-sum[LCA(u,v)]-sum[fa[LCA(u,v)]] sum[u]+sum[v]−sum[LCA(u,v)]−sum[fa[LCA(u,v)]],注意 s u m sum sum的索引并不是节点编号,而是该节点对应的需要查询的区间的编号,这里只是为了表述方便。
-
另外此题有几个点注意一下:数据中的边 u , v u,v u,v并不保证 u u u是 v v v的父亲节点,然后点权需要开 l o n g l o n g long\ long long long
在这个地方wa了n发
-
附代码:
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<vector>
using namespace std;
typedef long long ll;
const int maxn = 100005;
const int maxm = 2000005;
int n, q;
ll a[maxn], b[maxn];
int sum[maxm];
int root[maxn];
int ls[maxm], rs[maxm];
int cnt = 0;
vector<int> vec[maxn];
int u, v;
int k, son, fa, ans[maxn], in[2 * maxn], fin[maxn], height[2 * maxn], dp[2 * maxn][25], tim, fath[maxn];
void init()
{
memset(ans, 0, sizeof(ans));
memset(fin, 0, sizeof(fin));
for(int i = 1; i <= n; i++) vec[i].clear();
tim = 0;
cnt = 0;
}
void dfs_LCA(int cur, int h, int fa)
{
fath[cur] = fa;
in[++tim] = cur;
height[tim] = h;
if(!fin[cur]) fin[cur] = tim;
for(int i = 0; i < vec[cur].size(); i++) {
if(vec[cur][i] != fa) {
dfs_LCA(vec[cur][i], h + 1, cur);
in[++tim] = cur;
height[tim] = h;
}
}
}
void st()
{
for(int i = 1; i <= tim; i++) dp[i][0] = in[i];
for(int j = 1; (1 << j) <= tim; j++) {
for(int i = 1; i + (1 << j) - 1 <= tim; i++) {
if(height[fin[dp[i][j - 1]]] < height[fin[dp[i + (1 << (j - 1))][j - 1]]]) dp[i][j] = dp[i][j - 1];
else dp[i][j] = dp[i + (1 << (j - 1))][j - 1];
}
}
}
int query_LCA(int u, int v)
{
int a = min(fin[u], fin[v]), b = max(fin[u], fin[v]);
int k = 0;
while(a + (1 << (k + 1)) - 1 <= b) k++;
int x = fin[dp[a][k]], y = fin[dp[b - (1 << k) + 1][k]];
return height[x] < height[y] ? dp[a][k] : dp[b - (1 << k) + 1][k];
}
int build(int l, int r)
{
int now = ++cnt;
if(l == r) return now;
int mid = (l + r) >> 1;
ls[now] = build(l, mid);
rs[now] = build(mid + 1, r);
return now;
}
int modify(int l, int r, int loc, int pre)
{
int now = ++cnt;
sum[now] = sum[pre] + 1;
if(l == r) return now;
int mid = (l + r) >> 1;
if(loc <= mid){
rs[now] = rs[pre];
ls[now] = modify(l, mid, loc, ls[pre]);
}
else{
ls[now] = ls[pre];
rs[now] = modify(mid + 1, r, loc, rs[pre]);
}
return now;
}
int query(int l, int r, int a, int b, int lca, int fa_lca, int k)
{
if(l == r) return l;
int s = sum[ls[b]] + sum[ls[a]] - sum[ls[lca]] - sum[ls[fa_lca]];
int mid = (l + r) >> 1;
if(k <= s) return query(l, mid, ls[a], ls[b], ls[lca], ls[fa_lca], k);
return query(mid + 1, r, rs[a], rs[b], rs[lca], rs[fa_lca], k - s);
}
void dfs(int cur, int fa, int tot)
{
int loc = lower_bound(b + 1, b + tot + 1, a[cur]) - b;
root[cur] = modify(1, tot, loc, root[fa]);
for(int i = 0; i < vec[cur].size(); i++) if(vec[cur][i] != fa) dfs(vec[cur][i], cur, tot);
}
int main()
{
while(~scanf("%d %d", &n, &q)){
init();
for(int i = 1; i <= n; i++) scanf("%lld", &a[i]), b[i] = a[i];
sort(b + 1, b + n + 1);
int tot = unique(b + 1, b + n + 1) - b - 1;
for(int i = 1; i < n; i++) {
scanf("%d %d", &u, &v);
vec[u].push_back(v);
vec[v].push_back(u);
}
dfs_LCA(1, 0, 0);
st();
root[0] = build(1, tot);
dfs(1, 0, tot);
for(int i = 1; i <= q; i++){
scanf("%d %d %d", &u, &v, &k);
printf("%lld\n", b[query(1, tot, root[u], root[v], root[query_LCA(u, v)], root[fath[query_LCA(u, v)]], k)]);
}
}
}