You've got a weighted tree, consisting of n vertices. Each edge has a non-negative weight. The length of the path between any two vertices of the tree is the number of edges in the path. The weight of the path is the total weight of all edges it contains.
Two vertices are close if there exists a path of length at most l between them and a path of weight at most w between them. Count the number of pairs of vertices v, u (v < u), such that vertices v and u are close.
The first line contains three integers $n$, $l$ and $w$ ($1 \le n \le 10^5$, $1 \le l \le n$, $0 \le w \le 10^9$). The next $n - 1$ lines contain the descriptions of the tree edges. The $i$-th line contains two integers $p_i$, $w_i$ ($1 \le p_i < i + 1$, $0 \le w_i \le 10^4$), meaning that the $i$-th edge connects vertex $i + 1$ and $p_i$ and has weight $w_i$.
Consider the tree vertices indexed from 1 to n in some way.
Print a single integer — the number of close pairs.
Please do not use the %lld specifier to read or write 64-bit integers in C++. It is preferred to use the cin, cout streams or the %I64d specifier.
Sample input 1:
4 4 6
1 3
1 4
1 3
Sample output 1:
4
Sample input 2:
6 2 17
1 3
2 5
2 13
1 6
5 9
Sample output 2:
9
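题意:问一棵树上有多少条路径长度不超过l且边权和不超过w的路径
(Summary: count the paths in the tree whose length is at most l and whose total edge weight is at most w.)
解题思路:树分治+树状数组
(Approach: centroid decomposition (divide and conquer on the tree) combined with a Fenwick tree / binary indexed tree.)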
#include <iostream>
#include <cstdio>
#include <cstring>
#include <string>
#include <algorithm>
#include <cctype>
#include <map>
#include <cmath>
#include <set>
#include <stack>
#include <queue>
#include <vector>
#include <bitset>
#include <functional>
using namespace std;

// 64-bit type used for pair counts.
typedef long long LL;

// "Infinitely large" component size; main() stores it in mx[0] so that the
// placeholder vertex 0 never wins the centroid search in dfs().
const int INF = 0x3f3f3f3f;
// Capacity of the per-vertex arrays (n <= 1e5 by the statement, plus headroom).
const int maxn = 200000 + 10;

// Problem parameters and input scratch.
int n;      // number of vertices
int w;      // maximum allowed path weight
int l;      // maximum allowed path length (edge count)
int x, y;   // one edge as read by main(): parent vertex, edge weight

// Forward-star adjacency list: s[v] is the first edge slot of v (main() fills
// it with -1 for "none"), nt[] links to the next slot, e[] is the far endpoint
// and val[] the edge weight. Each undirected edge uses two slots.
int s[maxn];
int nt[maxn * 2];
int e[maxn * 2];
int val[maxn * 2];
int cnt;    // number of edge slots used so far

// Centroid-decomposition state.
int sum[maxn];   // subtree sizes from the latest dfs()
int mx[maxn];    // largest piece left if the vertex were removed (dfs())
int vis[maxn];   // non-zero while the vertex is an active centroid
int tot;         // number of records collected by get()

// Fenwick tree indexed by depth + 1, plus the list of slots touched since the
// last reset (so Find() can clear only those).
int sum1[maxn];
vector<int> g;
// One record collected by get(): x is the depth (edge count from the vertex
// the collection started at, plus the starting offset) and y the accumulated
// weight. Records compare by weight only, so sort() orders them by weight.
struct node
{
    int x, y;

    bool operator<(const node &other) const
    {
        return other.y > y;
    }
};

// Scratch buffer of collected records; get() fills a[0 .. tot-1].
node a[maxn];
// Value of the lowest set bit of k (e.g. 12 -> 4); the Fenwick-tree step size.
int lowbit(int k)
{
    return k & (-k);
}
// Fenwick point update: add `val` at 1-based index k, walking up to index l + 1.
// Every slot touched is appended to g so Find() can zero exactly those slots.
// (The parameter `val` shadows the global edge-weight array of the same name.)
void update(int k, int val)
{
    int i = k;
    while (i <= l + 1)
    {
        g.push_back(i);
        sum1[i] += val;
        i += lowbit(i);
    }
}
// Fenwick prefix sum: total of sum1 over indices 1..k. Callers pass k >= 1;
// the walk stops when the index reaches 0.
LL query(int k)
{
    LL total = 0;
    while (k != 0)
    {
        total += sum1[k];
        k -= lowbit(k);
    }
    return total;
}
// Subtree-size pass over the not-yet-visited component containing k.
//   k  - current vertex; fa - its parent (the root is called with fa == k)
//   p  - size of the whole component
// For every vertex reached it sets sum[v] (subtree size) and mx[v] (the largest
// piece that would remain if v were removed, including the part "above" v).
// Returns the reached vertex with the smallest mx, i.e. a centroid. Vertex 0 is
// the initial candidate; main() sets mx[0] = INF so any real vertex beats it.
int dfs(int k, int fa, int p)
{
    int best = 0;
    mx[k] = 0;
    sum[k] = 1;
    for (int edge = s[k]; edge != -1; edge = nt[edge])
    {
        const int to = e[edge];
        if (to == fa || vis[to]) continue;
        const int cand = dfs(to, k, p);
        sum[k] += sum[to];
        if (sum[to] > mx[k]) mx[k] = sum[to];
        if (mx[cand] < mx[best]) best = cand;
    }
    const int above = p - sum[k];   // part of the component outside k's subtree
    if (above > mx[k]) mx[k] = above;
    return (mx[k] < mx[best]) ? k : best;
}
// Appends a (depth, weight) record to a[] for k and for every vertex reachable
// from it without stepping back to `fa` or into a visited centroid. A branch is
// cut as soon as its depth exceeds l or its weight exceeds w, since no pair
// through it could be close.
void get(int k, int fa, int dep, int len)
{
    if (dep <= l && len <= w)
    {
        a[tot].x = dep;
        a[tot].y = len;
        ++tot;
        for (int edge = s[k]; edge != -1; edge = nt[edge])
        {
            const int to = e[edge];
            if (to != fa && !vis[to])
                get(to, k, dep + 1, len + val[edge]);
        }
    }
}
// Counts unordered pairs of records reachable from k (k itself starts at depth
// `dep` and weight `len`) whose depths sum to at most l and whose weights sum
// to at most w.
//
// Records are sorted by weight. Walking p2 from the heaviest record down, the
// pointer p1 inserts (by depth + 1) into the Fenwick tree every record whose
// weight still fits alongside a[p2]. A prefix query then counts those whose
// depth also fits. Each unordered pair is found from both ends, hence / 2.
LL Find(int k, int dep, int len)
{
    LL ans = tot = 0;
    get(k, k, dep, len);
    sort(a, a + tot);

    int p1 = 0;
    for (int p2 = tot - 1; p2 >= 0; p2--)
    {
        // Check the bound before reading a[p1]: a[tot] is not a collected record.
        // The weight sum is formed in 64 bits so two weights near w cannot overflow.
        while (p1 < tot && 1LL * a[p1].y + a[p2].y <= w)
        {
            update(a[p1].x + 1, 1);
            p1++;
        }
        // a[p2] is itself in the tree iff p1 > p2; take it out so it is not
        // paired with itself, then put it back for the next partners.
        const bool inTree = p1 > p2;
        if (inTree) update(a[p2].x + 1, -1);
        ans += query(l - a[p2].x + 1);
        if (inTree) update(a[p2].x + 1, 1);
    }

    // Zero only the Fenwick slots that were touched during this call.
    for (int slot : g) sum1[slot] = 0;
    g.clear();
    return ans / 2;
}
// Divide and conquer over the component containing k (p = its size):
// counts close pairs whose path passes through the component's centroid,
// removes pairs that lie entirely on one side of it, then recurses into
// every sub-component left after deleting the centroid.
LL solve(int k, int p)
{
    const int center = dfs(k, k, p);
    LL total = Find(center, 0, 0);
    vis[center] = 1;
    for (int edge = s[center]; edge != -1; edge = nt[edge])
    {
        const int to = e[edge];
        if (vis[to]) continue;
        // Pairs on `to`'s side were also counted through `center`; drop them.
        total -= Find(to, 1, val[edge]);
        // dfs() was rooted at k: neighbours below `center` have smaller subtree
        // sizes, while the neighbour above owns everything outside center's subtree.
        const int part = (sum[to] < sum[center]) ? sum[to] : p - sum[center];
        total += solve(to, part);
    }
    // Restore the mark; main() does not clear vis between test cases.
    vis[center] = 0;
    return total;
}
int main()
{
while (~scanf("%d%d%d", &n, &l, &w))
{
memset(s, -1, sizeof s);
mx[cnt = 0] = INF;
for (int i = 1; i < n; i++)
{
scanf("%d%d", &x, &y);
nt[cnt] = s[i + 1], s[i + 1] = cnt, e[cnt] = x, val[cnt++] = y;
nt[cnt] = s[x], s[x] = cnt, e[cnt] = i + 1, val[cnt++] = y;
}
printf("%I64d\n", solve(1, n));
}
return 0;
}