题目描述
贺老师创造了一个属于自己的树,命名为贺老师含树。这棵树有N个节点。然后,贺老师让你回答Q个问题,问题的种类有两种:
1 X Y :在树上添加一条从X到Y的路径。
2 X : 询问节点X以及他的子树内,最长的完整路径长度是多少。完整路径指的是这条路径的所有节点都在X的子树内。
由于贺老师的时间太宝贵了,于是这个辣鸡任务就交给你来处理了。
输入
第一行,一个数字N,表示节点个数,一个数字Q,表示询问个数。
第二行,N-1个数字,第i个数字f[i]表示节点i的父亲节点是谁,保证f[i]<i。
接下来Q行,每一行1 X Y或者2 X。表示询问1或者2。
输出
对于每一个询问2,输出一个数字在单独的一行表示答案。
先解释一下题意:先给定的树只是给定了一个空树,是没有边的,只有指向,操作1就是将从x到y的路径上的所有边激活,即为可以使用。操作2就是在x的子树内的最大路径长度。
感觉自己最弱的一点就是树上的问题了......
如果我们对一棵树进行前序dfs遍历进行编号,将in[i] 这个点定义为刚开始考虑这个点时的编号,将out[i]定义为考虑结束这个点时的编号,那么显而易见对于每一个节点i,in[i]~out[i]是一个连续的序列。对于其中的最长的连续路径,out[i]-in[i]就是答案,也就是题目的待求量了。
所以对于添加路径,由于x~y是lca(x,y)子树内的最长路径,所以就可以将in[lca(x,y)] 与 out[lca(x,y)]的答案改为这条路径的长度dep[x]+dep[y]-dep[lca(x,y)]*2+1。每次查询节点只需要关心这个节点in-out的最大值即可。
所以最后我们还需要一种能够单点修改,区间查询最大值的数据结构,简单的线段树即可。
下附AC代码。
#include<iostream>
#include<stdio.h>
#include<vector>
#include<string.h>
#include<algorithm>
#define maxn 100005
using namespace std;
int n,q;
int tot=1;
int anc[maxn][30];
int sav[80*maxn];
int dep[maxn];
int in[maxn],out[maxn];
vector<int> edge[maxn];
void dfs(int now,int dept)
{
dep[now]=dept;
in[now]=tot++;
int j=0;
while(anc[anc[now][j]][j]!=0)
{
anc[now][j+1]=anc[anc[now][j]][j];
j++;
}
int len=edge[now].size();
for(int i=0;i<len;i++)
if(edge[now][i]!=anc[now][0])
{
dfs(edge[now][i],dept+1);
}
out[now]=tot++;
return;
}
int lca(int p,int q)
{
if(dep[p]<dep[q]) swap(p,q);
int temp=dep[p]-dep[q];
int j=0;
while(temp)
{
if(temp&1) p=anc[p][j];
temp>>=1;
j++;
}
if(p==q) return p;
for(int i=20;i>=0;i--)
if(anc[p][i]!=anc[q][i])
{ p=anc[p][i];
q=anc[q][i];
}
return anc[p][0];
}
void add(int l,int r,int bas,int len,int id)
{
if(l==r)
{
sav[id]=max(sav[id],len);
return;
}
int mid=(l+r)>>1;
if(bas<=mid)
add(l,mid,bas,len,id<<1);
else
add(mid+1,r,bas,len,(id<<1)+1);
sav[id]=max(sav[id<<1],sav[(id<<1)+1]);
}
int query(int l,int r,int li,int ri,int id)
{
if(l==li && r==ri)
return sav[id];
int mid=(l+r)>>1;
if(li>mid)
return query(mid+1,r,li,ri,(id<<1)+1);
if(ri<=mid)
return query(l,mid,li,ri,id<<1);
return max(query(l,mid,li,mid,id<<1),query(mid+1,r,mid+1,ri,(id<<1)+1));
}
int main()
{
scanf("%d%d",&n,&q);
for(int i=2;i<=n;i++)
{
scanf("%d",&anc[i][0]);
edge[i].push_back(anc[i][0]);
edge[anc[i][0]].push_back(i);
}
dfs(1,0);
int a,b,c;
for(int i=1;i<=q;i++)
{
scanf("%d",&a);
if(a==1)
{
scanf("%d%d",&b,&c);
int l=lca(b,c);
int len=dep[b]+dep[c]-dep[l]*2+1;
add(1,2*n, in[l],len,1);
add(1,2*n,out[l],len,1);
}
else
{
scanf("%d",&b);
printf("%d\n",query(1,2*n,in[b],out[b],1));
}
}
return 0;
}