感觉分治没有剖分好写。。orz有个地方绕了好久。。。
漆子超的论文里解法写的挺详细的:
记Depth(i)表示点i到根结点的路径长度,Belong(i) X ( X 为根结点的某个儿子,且结点i 在以X 为根的子树内)。那么我们要统计的就是:
满足 Depth(i) + Depth( j) <=K 且 Belong(i) != Belong( j) 的(i, j) 个数
= 满足Depth(i) + Depth( j) <= K的(i, j)个数 - 满足Depth(i) Depth( j) <= K且Belong(i) == Belong( j)的(i, j)个数
而对于这两个部分,都是要求出满足Ai Aj <= k的(i, j)的对数。将A排序后利用单调性我们很容易得出一个O(N)的算法,所以我们可以用O(N log N)的时间来解决这个问题。
反正我表示写完这个代码我也orz了 写得丑不过比标程快 hhh (代码量都的话我觉得继续缩行应该90左右搞定没问题)
#include <cstdio>
#include <iostream>
#include <vector>
#include <algorithm>
#include <cstring>
using namespace std;
int read()
{
int sign = 1, n = 0; char c = getchar();
while(c < '0' || c > '9'){ if(c == '-') sign = -1; c = getchar(); }
while(c >= '0' && c <= '9') { n = n*10 + c-'0'; c = getchar(); }
return sign*n;
}
const int Nmax = 10005;
int N, K, ans;
int d[Nmax], f[Nmax], s[Nmax], root;
bool done[Nmax];
struct ed{
int v, w, next;
}e[Nmax * 2];
int k, head[Nmax];
inline void adde(int u, int v, int w)
{
e[k] = (ed){ v, w, head[u] };
head[u] = k++;
e[k] = (ed){ u, w, head[v] };
head[v] = k++;
}
void getroot(int u, int fa, int size)
{
f[u] = 0; s[u] = 1;
for(int i = head[u]; i; i = e[i].next)
{
int v = e[i].v;
if(done[v] || v == fa) continue;
getroot(v, u, size);
s[u] += s[v];
f[u] = max(f[u], s[v]);
}
f[u] = max(f[u], size - s[u]);
if(f[u] < f[root]) root = u;
}
vector <int> dep;
void getdep(int u, int fa)
{
dep.push_back(d[u]);
for(int i = head[u]; i; i = e[i].next)
{
int v = e[i].v;
if(done[v] || v == fa) continue;
d[v] = d[u] + e[i].w;
getdep(v, u);
}
}
int query(int u, int echo)
{
dep.clear(); d[u] = echo;
getdep(u, 0); sort(dep.begin(), dep.end());
int res = 0;
for(int l = 0, r = dep.size() - 1; l < r; )
{
if(dep[l] + dep[r] <= K) res += r - l++;
else r--;
}
return res;
}
void work(int u)
{
ans += query(u, 0);
done[u] = 1;
for(int i = head[u]; i; i = e[i].next)
{
int v = e[i].v;
if(done[v]) continue;
ans -= query(v, d[v]);
f[0] = s[v];
getroot(v, root = 0, s[v]); work(root);
}
}
int main()
{
while(~scanf("%d%d", &N, &K) && N && K)
{
memset(done, 0, sizeof(done)); ans = 0;
memset(head, 0, sizeof(head)); k = 1;
for(int i = 1; i < N; ++i)
{
int u = read(), v = read(), w = read();
adde(u, v, w);
}
f[0] = N;
getroot(1, root = 0, N); work(root);
printf("%d\n", ans);
}
return 0;
}