题目链接:http://poj.org/problem?id=1741
题意: 求树中距离小于等于 k 的点对(无序二元组)个数。
树的点分治入门题,可参考漆子超的论文《分治算法在树的路径问题中的应用》:
http://wenku.baidu.com/link?url=7KOPn20aLvKK5PqDmuLjIyj4sqZ6CL1H9qP__JSGvX-AWgX7LR6gC-BZ3PTVCP2ojBHxKZcJ5U3csiRjuspqcoFJfswO7JaEIQyKlxwUzBi
代码:
#include <iostream>
#include <cstring>
#include <algorithm>
#include <vector>
#include <cstdio>
#define sf scanf
#define pf printf
using namespace std;

// Capacity of vertex-indexed tables: 10000 vertices plus a small safety margin.
const int maxn = 10000 + 5;
// Sentinel larger than any component size; used to reset the centroid search.
const int INF = 2e9;

// One directed half of an undirected weighted edge, kept in a linked-list
// adjacency layout indexed by head[].
struct Edge{
    int v;   // vertex this half-edge points to
    int c;   // edge weight
    int pre; // index of the half-edge inserted before it from the same vertex; -1 ends the list
};
// Every undirected edge is stored twice (u->v and v->u), hence 2 * maxn slots.
Edge Edges[maxn * 2];

int head[maxn]; // head[u]: index of u's most recently inserted half-edge, -1 when empty
int tot;        // number of half-edges stored so far
int vis[maxn];  // vis[u] != 0 once u has been removed as a centroid
int n, k;       // vertex count and distance bound of the current test case
void init(){
memset(head,-1,sizeof head),tot = 0;
memset(vis,0,sizeof vis);
}
void insert_edge(int u,int v,int c){
Edges[tot].v = v;
Edges[tot].c = c;
Edges[tot].pre = head[u];
head[u] = tot++;
}
int size[maxn];//分治之后子树的大小
int center_size,center;//重心子树最大的大小,重心节点
void getSize(int u,int fa){
size[u] = 1;
for(int i = head[u];~i;i = Edges[i].pre){
int v = Edges[i].v,c = Edges[i].c;
if(v == fa || vis[v]) continue;
getSize(v,u);
size[u] += size[v];
}
}
void getCenter(int u,int fa,const int& tot){
int tmp = tot - size[u];
for(int i = head[u];~i;i = Edges[i].pre){
int v = Edges[i].v;
if(v == fa || vis[v]) continue;
getCenter(v,u,tot);
tmp = max(tmp,size[v]);
}
if(tmp < center_size) center_size = tmp,center = u;
}
/**
 * Count index pairs i < j with ar[i] + ar[j] <= k, where k is the global distance bound.
 * Sorts ar in place, then does a single two-pointer sweep.
 * @param ar distances to pair up (modified: left sorted ascending)
 * @return number of qualifying pairs
 */
int cal(vector<int>& ar){
    int ret = 0;
    sort(ar.begin(),ar.end());
    // Explicit cast: for an empty vector r becomes -1 and the loop is skipped,
    // without relying on an unsigned wrap-around being narrowed to int.
    int l = 0,r = static_cast<int>(ar.size()) - 1;
    while(l < r){
        // ar[l] only grows as l advances, so the last valid partner index never moves right.
        while(l < r && ar[l] + ar[r] > k) r--;
        ret += r - l; // ar[l] pairs with every element at indices (l, r]
        l++;
    }
    return ret;
}
// A: distances from the current centroid to every vertex of its component.
// B: the same distances, restricted to the child subtree currently being explored.
vector<int> A,B;

/**
 * Record dis (distance from the centroid to u) in both A and B, then recurse into every
 * neighbour of u that is neither fa nor removed, adding the connecting edge weight.
 */
void DFS(int u,int fa,int dis){
    A.push_back(dis);
    B.push_back(dis);
    for(int e = head[u];e != -1;e = Edges[e].pre){
        const int to = Edges[e].v;
        if(to != fa && !vis[to]) DFS(to,u,dis + Edges[e].c);
    }
}
/**
 * Count unordered vertex pairs of rt's component whose connecting path passes through rt
 * and has length <= k (rt itself takes part with distance 0).
 * Inclusion-exclusion: pairs over all distances (A) minus pairs whose endpoints lie in
 * the same child subtree of rt (each B), whose true path does not go through rt.
 */
int part_solve(int rt){
    A.clear();
    B.clear();
    int within_same_child = 0;
    for(int e = head[rt];e != -1;e = Edges[e].pre){
        const int child = Edges[e].v;
        if(vis[child]) continue;
        B.clear();
        DFS(child,rt,Edges[e].c);
        within_same_child += cal(B);
    }
    A.push_back(0); // rt's distance to itself
    return cal(A) - within_same_child;
}
int ans;
void Split(int rt){
getSize(rt,rt);center_size = INF;
getCenter(rt,rt,size[rt]);
rt = center;
vis[rt] = 1;
ans += part_solve(rt);
for(int i = head[rt];~i;i = Edges[i].pre){
int v = Edges[i].v;
if(vis[v]) continue;
Split(v);
}
}
/**
 * Read the next (optionally negative) decimal integer from stdin into num.
 * Characters that cannot start an integer are skipped.
 * @return false if end of input is reached before any integer starts; num is then unchanged.
 */
inline bool scan_d(int &num)
{
    // getchar() returns int so that EOF stays distinguishable from every byte value;
    // storing it in a char breaks the EOF test wherever char is unsigned.
    int ch = getchar();
    // Check EOF inside the skip loop too, otherwise trailing whitespace at end of input spins forever.
    while(ch != EOF && ch != '-' && (ch < '0' || ch > '9')) ch = getchar();
    if(ch == EOF) return false;
    bool negative = false;
    if(ch == '-'){ negative = true; num = 0; }
    else num = ch - '0';
    while((ch = getchar()) >= '0' && ch <= '9'){
        num = num * 10 + (ch - '0');
    }
    if(negative) num = -num;
    return true;
}
/**
 * Process test cases: each is "n k" followed by n-1 lines "u v c" describing weighted edges.
 * Stops at "0 0" or when input runs out, printing one pair count per test case.
 */
int main(){
    while(scan_d(n) && scan_d(k)){
        if(n == 0 && k == 0) break;
        init();
        for(int e = 1;e < n;++e){
            int a,b,w;
            scan_d(a);
            scan_d(b);
            scan_d(w);
            insert_edge(a,b,w); // store both directions of the undirected edge
            insert_edge(b,a,w);
        }
        ans = 0;
        Split(1);
        pf("%d\n",ans);
    }
}