题意:对一棵树,求出从每个结点出发能到走的最长距离(每个结点最多只能经过一次),将这些距离按排成一个数组得到d[1],d[2],d[3]……d[n] ,在数列的d中求一个最长的区间,使得区间中的最大值与最小值的差不超过m。
分析:用2次dfs能求出树的直径,对于树中任意结点,到树的直径的2个端点的距离的较大者即为最长距离。得到数组d后,用2个单调队列分别维护最大与最小值,扫描d数组,同时更新答案。
#include <set>
#include <map>
#include <stack>
#include <queue>
#include <deque>
#include <cmath>
#include <vector>
#include <string>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <algorithm>
using namespace std;
#define L(i) i<<1
#define R(i) i<<1|1
#define INF 0x3f3f3f3f
#define pi acos(-1.0)
#define eps 1e-9
#define maxn 1000010
#define MOD 1000000007
const int MAXN = 1000010;
struct Edge
{
int to,w,next;
} edge[MAXN<<1];
int tot,head[MAXN];
int n,m;
int dp[MAXN][3];
int d[MAXN];
int qmin[MAXN],qmax[MAXN];
void add(int a,int b,int c)
{
edge[tot].to = b;
edge[tot].w = c;
edge[tot].next = head[a];
head[a] = tot++;
}
void dfs1(int x,int pre)
{
for(int i = head[x]; i != -1; i = edge[i].next)
{
int v = edge[i].to;
if(v == pre)
continue;
dfs1(v,x);
if(dp[v][0] + edge[i].w > dp[x][0])
{
dp[x][1] = dp[x][0];
dp[x][0] = dp[v][0] + edge[i].w;
}
else if(dp[v][0] + edge[i].w > dp[x][1])
dp[x][1] = dp[v][0] + edge[i].w;
}
}
void dfs2(int x,int pre)
{
int len = 0;
for(int i = head[x]; i != -1; i = edge[i].next)
{
if(edge[i].to == pre)
{
len = edge[i].w;
break;
}
}
if(pre != -1)
{
dp[x][2] = dp[pre][2];
if(dp[x][0] + len == dp[pre][0])
{
if(dp[pre][1] > dp[x][2])
dp[x][2] = dp[pre][1];
}
else if(dp[pre][0] > dp[x][2])
dp[x][2] = dp[pre][0];
dp[x][2] += len;
}
for(int i = head[x]; i != -1; i = edge[i].next)
if(edge[i].to != pre)
dfs2(edge[i].to,x);
}
void solve()
{
int front1 = 0,front2 = 0;
int rear1 = 0,rear2 = 0;
int ans = 0,i,j;
for(i = 1,j = 1; j <= n; j++)
{
while(rear1 > front1 && d[qmin[rear1-1]] >= d[j])
rear1--;
qmin[rear1++] = j;
while(rear2 > front2 && d[qmax[rear2-1]] <= d[j])
rear2--;
qmax[rear2++] = j;
if(d[qmax[front2]] - d[qmin[front1]] > m)
{
ans = max(ans,j-i);
while(d[qmax[front2]] - d[qmin[front1]] > m)
{
i = min(qmin[front1],qmax[front2]) + 1;
while(front1 < rear1 && qmin[front1] < i)
front1++;
while(front2 < rear2 && qmax[front2] < i)
front2++;
}
}
}
ans = max(ans,j-i);
printf("%d\n",ans);
}
int main()
{
int t,C = 1;
//scanf("%d",&t);
while(scanf("%d%d",&n,&m) != EOF)
{
memset(dp,0,sizeof(dp));
memset(head,-1,sizeof(head));
tot = 0;
for(int i = 2; i <= n; i++)
{
int a,b;
scanf("%d%d",&a,&b);
add(i,a,b);
add(a,i,b);
}
dfs1(1,-1);
dfs2(1,-1);
for(int i = 1; i <= n; i++)
{
d[i] = max(dp[i][0],dp[i][2]);
//printf("%d %d\n",i,d[i]);
}
solve();
}
return 0;
}