题意就是求树上距离小于等于K的点对有多少个
n2的算法肯定不行,因为1W个点
这就需要分治。可以看09年漆子超的论文 http://wenku.baidu.com/view/e087065f804d2b160b4ec0b5.html###
本题用到的是关于点的分治。
一个重要的问题是,为了防止退化,所以每次都要找到树的重心然后分治下去,所谓重心,就是删掉此结点后,剩下的结点最多的树结点个数最小。
每次分治,我们首先算出重心,为了计算重心,需要进行两次dfs,第一次把以每个结点为根的子树大小求出来,第二次是从这些结点中找重心
找到重心后,需要统计所有结点到重心的距离,看其中有多少对小于等于K,这里采用的方法就是把所有的距离存在一个数组里,进行快速排序,这是nlogn的,然后用一个经典的相向搜索O(n)时间内解决。但是这些求出来满足小于等于K的里面只有那些路径经过重心的点对才是有效的,也就是说在同一颗子树上的肯定不算数的,所以对每颗子树,把子树内部的满足条件的点对减去。
n2的算法肯定不行,因为1W个点
这就需要分治。可以看09年漆子超的论文 http://wenku.baidu.com/view/e087065f804d2b160b4ec0b5.html###
本题用到的是关于点的分治。
一个重要的问题是,为了防止退化,所以每次都要找到树的重心然后分治下去,所谓重心,就是删掉此结点后,剩下的结点最多的树结点个数最小。
每次分治,我们首先算出重心,为了计算重心,需要进行两次dfs,第一次把以每个结点为根的子树大小求出来,第二次是从这些结点中找重心
找到重心后,需要统计所有结点到重心的距离,看其中有多少对小于等于K,这里采用的方法就是把所有的距离存在一个数组里,进行快速排序,这是nlogn的,然后用一个经典的相向搜索O(n)时间内解决。但是这些求出来满足小于等于K的里面只有那些路径经过重心的点对才是有效的,也就是说在同一颗子树上的肯定不算数的,所以对每颗子树,把子树内部的满足条件的点对减去。
最后的复杂度是n logn logn 其中每次快排是nlogn 而递归的深度为logn
我的理解是:比如现在有一棵树,从点1开始遍历整棵树,算出每个节点为根所在的子树中点的数量
然后再遍历树,找出重心,重心的定义如上,也可以看论文(有证明),然后算出从重心到每个节点的距离(不计算已经当过根的节点),快排之后用相向搜索找两段dis相加小于等于k的个数(具体见代码)
不过这样找的会有重复,所以每次我们找的符合题目的对数都是经过重心root的对数,所以要减去不经过重心的个数
原理还是可以理解的,代码比较复杂,确实挺难写,仍需努力,早日成为真男人QAQ
#include <map>
#include <set>
#include <stack>
#include <queue>
#include <cmath>
#include <string>
#include <vector>
#include <cstdio>
#include <cctype>
#include <cstring>
#include <sstream>
#include <cstdlib>
#include <iostream>
#include <algorithm>
using namespace std;
#define MAX 10005
#define MAXN 2000005
#define lson l,m,rt<<1
#define rson m+1,r,rt<<1|1
#define lrt rt<<1
#define rrt rt<<1|1
#define mid int m=(r+l)>>1
#define LL long long
#define ull unsigned long long
#define mem0(x) memset(x,0,sizeof(x))
#define mem1(x) memset(x,-1,sizeof(x))
#define meminf(x) memset(x,INF,sizeof(x))
#define lowbit(x) (x&-x)
const LL mod = 1000000;
const int prime = 999983;
const int INF = 0x3f3f3f3f;
const int INFF = 1e9;
const double pi = 3.141592653589793;
const double inf = 1e18;
const double eps = 1e-10;
struct Edge{
int v,cost,next;
}edge[MAX*2];
int head[MAX];
int maxn[MAX];
int siz[MAX];
int vis[MAX];
int dis[MAX];
int tot;
int root;
int mi;
int n,k;
int ans;
int num;
void add_edge(int a,int b,int c){
edge[tot]=(Edge){b,c,head[a]};
head[a]=tot++;
}
void dfssize(int u,int fa){
siz[u]=1;
maxn[u]=0;
for(int i=head[u];i!=-1;i=edge[i].next){
int v=edge[i].v;
if(v!=fa&&!vis[v]){
dfssize(v,u);
siz[u]+=siz[v];
maxn[u]=max(maxn[u],siz[v]);
}
}
}
void dfsroot(int r,int u,int fa){
if(siz[r]-siz[u]>maxn[u]) maxn[u]=siz[r]-siz[u];
if(maxn[u]<mi){
mi=maxn[u];
root=u;
}
for(int i=head[u];i!=-1;i=edge[i].next){
int v=edge[i].v;
if(v!=fa&&!vis[v]) dfsroot(r,v,u);
}
}
void dfsdis(int u,int fa,int d){
dis[num++]=d;
for(int i=head[u];i!=-1;i=edge[i].next){
int v=edge[i].v;
if(v!=fa&&!vis[v]) dfsdis(v,u,d+edge[i].cost);
}
}
int calc(int u,int d){
int ret=0;
num=0;
dfsdis(u,-1,d);
sort(dis,dis+num);
int i=0,j=num-1;
while(i<j){
while(dis[i]+dis[j]>k&&i<j) j--;
ret+=j-i;
i++;
}
return ret;
}
void dfs(int u){
mi=n;
dfssize(u,-1);
dfsroot(u,u,-1);
ans+=calc(root,0);
vis[root]=1;
for(int i=head[root];i!=-1;i=edge[i].next){
int v=edge[i].v;
if(!vis[v]){
ans-=calc(v,edge[i].cost);//不经过重心的对数,从子节点v开始计算,距离是cost
dfs(v);
}
}
}
int main(){
while(scanf("%d%d",&n,&k)){
if(!n&&!k) break;
mem1(head);
mem0(vis);
ans=0;
tot=0;
for(int i=1;i<n;i++){
int a,b,c;
scanf("%d%d%d",&a,&b,&c);
add_edge(a,b,c);
add_edge(b,a,c);
}
dfs(1);
printf("%d\n",ans);
}
return 0;
}