题意:给定一棵树,然后在给定一些链,链放在树上某条路径上,每条链有价值,要求取链,并且所取的链不能相交,使得取得的链的总价值最大。
做法:树上做动态规划,开个dp[i]表示第i个点为根的子树里面的取链最大价值,sum[i]表示i的儿子节点的dp值得和。考虑取一条链时,把一条链的路径上的所有sum[i]-dp[i]加起来(lca那点只需把sum[i]加进去就好了)再加上这条链的价值,这样做的目的是找出取这条链的情况下,并且其他所取的链的高度低于它时的最大价值,然后更新到dp[lca]上。然后再用sum值去更新一下dp值就好了。从深度深的点一次遍历到深度浅的点,然后找到有lca在该点的链就可以做转移。
总体时间复杂度为o(nlognlogn),然后一开始用线段树的时候,本机2900ms,交hdu就tle了。后来改成树状数组,本机2900ms,hdu2300ms ac。
#pragma comment(linker, "/STACK:1024000000,1024000000")
#include <stdio.h>
#include <string.h>
#include <iostream>
#include <algorithm>
#include <time.h>
#include <math.h>
using namespace std;
#define N 101010
struct chain{
int u,v;
int w;
int lca;
}cn[N];
/*********杂七杂八的变量*********/
int fa[N];
long long sum[N];
int que[N*2],ta,he;
/************连边部分************/
struct EDGE{
int to,next;
}e[N*2];
int head[N],tot;
void addEDGE(int u,int v)
{
e[tot].to=v;
e[tot].next=head[u];
head[u]=tot++;
}
/***********并查集部分***********/
int f[N];
int find(int x)
{
if(x!=f[x])f[x]=find(f[x]);
return f[x];
}
/**********lca部分***************/
struct query{
int Id,next;
}qu[N*2];
int qhead[N],qtot;
void addquery(int u,int Id)
{
qu[qtot].Id=Id;
qu[qtot].next=qhead[u];
qhead[u]=qtot++;
}
void get_lca(int u)
{
for(int i=qhead[u];i>=0;i=qu[i].next)
{
int Id=qu[i].Id;
if(cn[Id].lca==-1)
{
cn[Id].lca=u;
}
else{
cn[Id].lca=find(cn[Id].lca);
}
}
int fu=find(u);
for(int i=head[u];i>=0;i=e[i].next)
{
int v=e[i].to;
if(v==fa[u])continue;
fa[v]=u;
get_lca(v);
int fv=find(v);
f[fv]=fu;
}
}
/*****树链剖分计数部分**********/
struct node{
long long sum,dpsum;
}stn[N];
node operator +(const node &a,const node &b)
{
node c;
c.sum=a.sum+b.sum;
c.dpsum=a.dpsum+b.dpsum;
return c;
}
node operator -(const node &a,const node &b)
{
node c;
c.sum=a.sum-b.sum;
c.dpsum=a.dpsum-b.dpsum;
return c;
}
int n,m;
int lowbit(int x)
{
return x&(-x);
}
void update(int x,node val)
{
while(x<=n)
{
stn[x]=stn[x]+val;
x+=lowbit(x);
}
}
node querys(int x)
{
node c;
c.sum=c.dpsum=0;
while(x>0)
{
c=c+stn[x];
x-=lowbit(x);
}
return c;
}
node queryS(int x,int y)
{
if(y==1)return querys(x);
else return querys(x)-querys(y-1);
}
/**********剖分部分**************/
int siz[N],son[N],top[N],nam[N],namtot,dep[N];
void dfs1(int u)
{
siz[u]=1;
son[u]=0;
int k=0;
for(int i=head[u];i>=0;i=e[i].next)
{
int v=e[i].to;
if(v==fa[u])continue;
dep[v]=dep[u]+1;
dfs1(v);
siz[u]+=siz[v];
if(siz[v]>k)
{
k=siz[v];
son[u]=v;
}
}
}
void dfs2(int u,int tp)
{
top[u]=tp;
nam[u]=namtot++;
if(son[u])dfs2(son[u],tp);
for(int i=head[u];i>=0;i=e[i].next)
{
int v=e[i].to;
if(v==fa[u]||v==son[u])continue;
dfs2(v,v);
}
}
/********************************/
void INIT(int n)
{
namtot=1;
tot=qtot=0;
for(int i=1;i<=n;i++)
{
stn[i].sum=stn[i].dpsum=0;
f[i]=i;
head[i]=qhead[i]=-1;
sum[i]=0;
}
}
long long Query(chain a,long long s,long long w)
{
int u=a.u;
int v=a.v;
int lca=a.lca;
int as=s+w;
node st;
while(1)
{
if(top[u]!=top[lca])
{
st=queryS(nam[u],nam[top[u]]);
as+=st.sum-st.dpsum;
}
else {
st=queryS(nam[u],nam[lca]);
as+=st.sum-st.dpsum;
break;
}
u=fa[top[u]];
}
while(v!=lca)
{
if(top[v]!=top[lca])
{
st=queryS(nam[v],nam[top[v]]);
as+=st.sum-st.dpsum;
}
else
{
st=queryS(nam[v],nam[lca]+1);
as+=st.sum-st.dpsum;
break;
}
v=fa[top[v]];
}
return as;
}
int main()
{
// clock_t stime=clock();
// freopen("1006.in","r",stdin);
int T;
scanf("%d",&T);
while(T--)
{
scanf("%d%d",&n,&m);
INIT(n);
for(int i=1;i<n;i++)
{
int u,v;
scanf("%d%d",&u,&v);
addEDGE(u,v);
addEDGE(v,u);
}
for(int i=1;i<=m;i++)
{
scanf("%d%d%d",&cn[i].u,&cn[i].v,&cn[i].w);
cn[i].lca=-1;
addquery(cn[i].u,i);
addquery(cn[i].v,i);
}
fa[1]=0;
get_lca(1);
dep[0]=0;
dfs1(1);
dfs2(1,1);
he=ta=0;
que[ta++]=1;
int t=0;
for(int i=1;i<=n;i++)qhead[i]=-1;
qtot=0;
for(int i=1;i<=m;i++)
{
addquery(cn[i].lca,i);
}
while(he<ta)
{
int u=que[he++];
for(int i=head[u];i>=0;i=e[i].next)
{
int v=e[i].to;
if(v!=fa[u])que[ta++]=v;
}
}
for(int i=ta-1;i>=0;i--)
{
int u=que[i];
long long dps=0;
for(int j=qhead[u];j>=0;j=qu[j].next)
{
dps=max(dps,Query(cn[qu[j].Id],sum[u],cn[qu[j].Id].w));
}
if(dps<sum[u])dps=sum[u];
node c;
c.sum=sum[u];
c.dpsum=dps;
update(nam[u],c);
if(i!=0)sum[fa[u]]+=dps;
}
node st=queryS(1,1);
printf("%lld\n",max(st.sum,st.dpsum));
}
// printf("%dms\n",clock()-stime);
}
/*
1
7 3
1 2
1 3
2 4
2 5
3 6
3 7
1 2 3
2 4 3
2 5 4
1
9 3
1 2
1 3
2 4
2 6
6 7
3 5
3 8
5 9
4 5 5
8 9 6
6 7 4
*/