模板
求一棵树上的最长上升子序列的长度
解法
线段树合并
我们对于树上每一个点 xxx,发现经过它的一条路径LIS由两部分组成。
以 xxx 开头的左子树最长上升子序列 + 右子树最长下降子序列 - 1
以 xxx 开头的左子树最长下降子序列 + 右子树最长上升子序列 - 1
一开始所有点的 LISLISLIS 和 LDSLDSLDS 都是1,我们从底向上合并答案
首先dfs到底,然后向上回溯
复杂度 O(nlogn)O(nlogn)O(nlogn)
Code
#include <bits/stdc++.h>
#define ll long long
#define qc ios::sync_with_stdio(false); cin.tie(0);cout.tie(0)
#define fi first
#define se second
#define PII pair<int, int>
#define PLL pair<ll, ll>
#define pb push_back
using namespace std;
const int MAXN = 2e5 + 7;
const int inf = 0x3f3f3f3f;
const ll INF = 0x3f3f3f3f3f3f3f3f;
const ll mod = 1e9 + 7;
inline int read()
{
int x=0,f=1;char ch=getchar();
while (!isdigit(ch)){if (ch=='-') f=-1;ch=getchar();}
while (isdigit(ch)){x=x*10+ch-48;ch=getchar();}
return x*f;
}
int n;
int head[MAXN];
struct edge{
int next, to, w;
}e[MAXN << 1];
// 链式前向星
int cnt;
void add(int u, int v, int w = 0){
e[cnt].w = w;
e[cnt].to = v;
e[cnt].next = head[u];
head[u] = cnt++;
}
// 最长上升子序列长度
int ans = 0;
int rt[MAXN];
struct Tree{
int l, r, lis, lds;
}tree[MAXN << 7];
int tot;
// 节点的value
int val[MAXN];
void init(){
for(int i = 0; i <= n; i++)
head[i] = -1;
cnt = 0;
}
int build(){
tot++;
tree[tot].l = tree[tot].r = tree[tot].lis = tree[tot].lds;
return tot;
}
void update(int &p, int l, int r, int pos, int vlis, int vlds){
if(!p)
p = build();
tree[p].lis = max(tree[p].lis, vlis);
tree[p].lds = max(tree[p].lds, vlds);
if(l == r)
return ;
int m = (l + r) >> 1;
if(pos <= m)
update(tree[p].l, l, m, pos, vlis, vlds);
else
update(tree[p].r, m+1, r, pos, vlis, vlds);
}
int merge(int p, int q){
if(!p || !q)
return q + p;
tree[p].lis = max(tree[p].lis, tree[q].lis);
tree[p].lds = max(tree[p].lds, tree[q].lds);
ans = max(ans, max(tree[tree[p].l].lis + tree[tree[q].r].lds,
tree[tree[p].r].lds + tree[tree[q].l].lis));
tree[p].l = merge(tree[p].l, tree[q].l);
tree[p].r = merge(tree[p].r, tree[q].r);
return p;
}
PII query(int p, int l, int r, int L, int R){
if(p == 0 || l > r)
return {0, 0};
if(L <= l && r <= R)
return {tree[p].lis, tree[p].lds};
int m = (l + r) >> 1;
PII ret = {0, 0}, tl = {0, 0}, tr = {0, 0};
if(L <= m)
tl = query(tree[p].l, l, m, L, R);
if(m < R)
tr = query(tree[p].r, m+1, r, L, R);
ret.fi = max(tl.fi, tr.fi);
ret.se = max(tl.se, tr.se);
return ret;
}
void dfs(int u, int f){
for(int i = head[u]; ~i; i = e[i].next){
int v = e[i].to;
if(v == f)
continue;
dfs(v, u);
}
int nlis = 0, nlds = 0;
for(int i = head[u]; ~i; i = e[i].next){
int v = e[i].to;
if(v == f)
continue;
// 询问的第三个和第五个参数是要到val的最大值,可能是n也可能不是
int lis = query(rt[v], 1, n, 1, val[u]-1).fi;
int lds = query(rt[v], 1, n, val[u]+1, n).se;
ans = max(ans, lis + 1 + nlds);
ans = max(ans, lds + 1 + nlis);
merge(rt[u], rt[v]);
nlis = max(nlis, lis);
nlds = max(nlds, lds);
}
// 第三个参数不一定是n 同上
update(rt[u], 1, n, val[u], nlis+1, nlds+1);
}
void solve(){
cin >> n;
init();
for(int i = 1; i < n; i++){
int u, v;
cin >> u >> v;
add(u, v);
add(v, u);
}
for(int i = 1; i <= n; i++)
cin >> val[i];
ans = 0;
dfs(1, 0);
cout << ans << endl;
}
int main()
{
#ifdef ONLINE_JUDGE
#else
freopen("in.txt", "r", stdin);
freopen("out.txt", "w", stdout);
#endif
qc;
int T;
// cin >> T;
T = 1;
while(T--){
solve();
}
return 0;
}