题目真滴皮
Orz rqy
思路
10分的暴力都没拿到
10分直接求直径
60分 容易?想到题目等价于求k+1条不相交的链
设状态
f[i][j][0/1/2]
f
[
i
]
[
j
]
[
0
/
1
/
2
]
表示以第i个节点为根的子树用了j条链并且根和儿子连有(0,1,2)条边。
转移分为5类
g[j+cc][0]=max(g[j+cc][0],f[x][j][0]+max(f[v][cc][0],max(f[v][cc][1],f[v][cc][2])));
if(cc) g[j+cc][1]=max(g[j+cc][1],f[x][j][0]+f[v][cc][1]+tb[i].w);
if(j) g[j+cc][1]=max(g[j+cc][1],f[x][j][1]+max(f[v][cc][0],max(f[v][cc][1],f[v][cc][2])));
if(j) g[j+cc][2]=max(g[j+cc][2],f[x][j][2]+max(f[v][cc][0],max(f[v][cc][1],f[v][cc][2])));
if(j && cc) g[j+cc-1][2]=max(g[j+cc-1][2],f[x][j][1]+f[v][cc][1]+tb[i].w);
100分正解 好像是叫wqs二分的东东。
如果不考虑k的话可以省掉一维让时间复杂度达到
O(n)
O
(
n
)
用
f[i][0/1/2]
f
[
i
]
[
0
/
1
/
2
]
表示以i为根的子树 balabala
经过一番推导打表验证 f是一个凸函数
他的差分数组单调递减。 可以dp求得最大值和链的个数和k比较一下。
显然如果对于差分数组整体减一个数的话最大值点是会左移的那么就可以二分这个减去的数啦。
设减去的数为 s
怎么减去整个数呢?
设
g[i]
g
[
i
]
表示i条链的答案 那么设
h[i]=g[i]−i∗s
h
[
i
]
=
g
[
i
]
−
i
∗
s
h
h
<script type="math/tex" id="MathJax-Element-6">h</script>的差分数组就可以让每一个减去s了
上下界l,r。 r=所有正权值和,l=最大正权边的相反数。
分别对应着差分数组的两极。
代码
//林可卡特树 wqs二分
#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <cmath>
#include <cstdlib>
#include <climits>
#pragma comment(linker, "/STACK:1024000000,1024000000")
using namespace std;
const int maxn=300000+10;
const int maxn1=100000+5;
typedef long long ll;
int n,k;
struct node
{
int v,next,w;
}tb[maxn*2];
int len,h[maxn];
ll f[maxn][5],g[maxn][5],qq[5],q[5];
ll ans,sum,sz[maxn],s,l,r;
bool flag[maxn];
void add1(int i,int j,int w)
{
len++; tb[len]=(node){j,h[i],w}; h[i]=len;
}
void Max(ll &val,ll &num,ll newval,ll newnum)
{
if(val<newval || (val==newval && num>newnum)){val=newval,num=newnum;}
}
void dfs(int x,int fa)
{
f[x][0]=0; f[x][1]=-s; f[x][2]=-s;
g[x][0]=0; g[x][1]=1; g[x][2]=1;
for(int i=h[x]; i ;i=tb[i].next)
{
int v=tb[i].v;
if(v==fa) continue;
dfs(v,x);
for(int j=0; j<=2; j++)
{
qq[j]=f[x][j]; q[j]=g[x][j];
}
ll max1=f[v][0],maxnum=g[v][0];
Max(max1,maxnum,f[v][1],g[v][1]); Max(max1,maxnum,f[v][2],g[v][2]);
Max(f[x][0],g[x][0],qq[0]+max1,maxnum+q[0]);
Max(f[x][1],g[x][1],qq[1]+max1,maxnum+q[1]);
Max(f[x][1],g[x][1],qq[0]+f[v][1]+tb[i].w,q[0]+g[v][1]);
Max(f[x][2],g[x][2],qq[2]+max1,maxnum+q[2]);
Max(f[x][2],g[x][2],qq[1]+f[v][1]+tb[i].w+s,q[1]+g[v][1]-1);
}
}
int main()
{
int size = 256 << 20; //250M
char*p=(char*)malloc(size) + size;
__asm__("movl %0, %%esp\n" :: "r"(p) );
freopen("test.in","r",stdin);
freopen("test.out","w",stdout);
scanf("%d %d",&n,&k);
for(int i=1; i<n; i++)
{
int x,y,w;
scanf("%d %d %d",&x,&y,&w);
add1(x,y,w); add1(y,x,w);
if(w>0) r+=w; l=min(l,1LL*(-w));
} k++;
while(l<r)
{
s=(l+r)>>1;
dfs(1,0);
ll max1=f[1][0],maxnum=g[1][0];
Max(max1,maxnum,f[1][1],g[1][1]); Max(max1,maxnum,f[1][2],g[1][2]);
if(maxnum>k)
l=s+1;
else
r=s;
}
s=l;
dfs(1,0);
ll max1=f[1][0],maxnum=g[1][0];
Max(max1,maxnum,f[1][1],g[1][1]); Max(max1,maxnum,f[1][2],g[1][2]);
cout<<max1+k*s<<endl;
return 0;
}