废话
很早就想学wqs二分,结果拖了好久。因为以前是看了几遍都没有懂。。(太菜了
后来因为计划里凸优化的题(比如CF321E Ciel and Gondolas,CF739E Gosha is hunting…)太多了。。又不会高深的数据结构,所以只能硬着头皮学
网上blog这么多,还是wqs神仙本人的论文最好懂。。网上应该会有,这边就不放资源了
然后灵光一现就突然懂了
先上题
这道算是经典题了叭
传送门(小心点提交可能会被封号qaq
题意就不解释了
问题可以转换成求树上不相交的
k
+
1
k+1
k+1条链的边权和的最大值
显然是树形dp啊
d
p
[
o
p
t
]
[
x
]
[
i
]
dp[opt][x][i]
dp[opt][x][i],其中
o
p
t
opt
opt表示
x
x
x节点的度数(链上的),所以只有可能等于
0
/
1
/
2
0/1/2
0/1/2,
x
x
x表示当前节点,
i
i
i表示以
x
x
x为根节点的子树中有
i
i
i条链
然后就会有弄出一些转移方程
这里就不写了(自己推一下蛮简单的)
但是这样的时间复杂度是
O
(
n
k
)
\mathcal{O(nk)}
O(nk)的,已经可以过掉60分了
然后大佬说这个函数是凸的
然后窝打了一张表发现是真的
那么下面就是今天重点了
然后我们可以发现可以用
O
(
n
)
\mathcal{O(n)}
O(n)的时间求出顶点的坐标
只要转移的时候记录一下现在的k是什么就好了
然后我们可以新定义一个函数
f
(
x
)
=
d
p
(
x
)
−
c
×
x
f(x)=dp(x)-c\times x
f(x)=dp(x)−c×x
很容易发现这让函数的顶点的横坐标往左移或者往右移了
发现了什么?这个东西就可以二分了
没错这个就是wqs二分了
于最后就一定能求得一个顶点是
(
k
,
f
(
k
)
)
(k,f(k))
(k,f(k))
所以
d
p
(
x
)
=
f
(
x
)
+
k
∗
x
dp(x)=f(x)+k*x
dp(x)=f(x)+k∗x这个就是最终的答案了
时间复杂度
O
(
n
log
k
)
\mathcal{O(n\log k)}
O(nlogk)
是不是很简单???(一脸天真
Code
#include <cstdio>
#include <algorithm>
#include <cstring>
#define N 300010
using namespace std;
typedef long long LL;
LL cnt, lst[N];
struct Node{
LL to, nxt;
LL w;
}e[N << 1];
struct Data {
LL x, y;
Data(LL X = 0, LL Y = 0) {
x = X; y = Y;
}
inline bool operator < (const Data &o) const {
return x < o.x || x == o.x && y > o.y;
}
inline Data operator + (const Data &o) const {
return Data(x + o.x, y + o.y);
}
inline Data operator + (LL o) {
return Data(x + o, y);
}
}dp[3][N];
inline void add(LL u, LL v, LL w) {
e[++cnt].to = v;
e[cnt].nxt = lst[u];
e[cnt].w = w;
lst[u] = cnt;
}
inline Data nw(Data o, LL v) {
return Data(o.x - v, o.y + 1);
}
inline void dfs(LL x, LL fa, LL val) {
dp[2][x] = Data(-val, 1);
for (LL i = lst[x]; i; i = e[i].nxt) {
LL son = e[i].to;
if (son == fa) continue;
dfs(son, x, val);
dp[2][x] = max(dp[2][x] + dp[0][son], nw(dp[1][x] + dp[1][son] + e[i].w, val));
dp[1][x] = max(dp[1][x] + dp[0][son], dp[0][x] + dp[1][son] + e[i].w);
dp[0][x] = dp[0][x] + dp[0][son];
}
dp[0][x] = max(dp[0][x], max(nw(dp[1][x], val), dp[2][x]));
}
int main() {
LL n, k;
scanf("%lld%lld", &n, &k);
k++;
LL r = 0;
for (LL i = 1, x, y, z; i < n; ++i) {
scanf("%lld%lld%lld", &x, &y, &z);
add(x, y, z);
add(y, x, z);
r += abs(z);
}
LL l = -r;
while (l <= r) {
LL mid = l + r >> 1;
memset(dp, 0, sizeof dp);
dfs(1, 0, mid);
if (dp[0][1].y <= k) r = mid - 1;
else l = mid + 1;
}
memset(dp, 0, sizeof dp);
dfs(1, 0, l);
printf("%lld\n", l * k + dp[0][1].x);
return 0;
}