题目描述
题目大意:给出一棵包含n个节点的树,i号点的点权为i\&(111)_2。
定义树上的一条路径的权值为路径上所有点的点权的异或和。
现在有m1条喜剧线,m2条悲剧线。
每条线的满意度Si也会给出。
现在你需要选出给出的感情线的一个子集,使得其满意度之和最大。
一个方案合法,当且仅当任意喜剧线x和悲剧线y的交时, Sx∗Sy≤7 ,其中 Sx,Sy 均表示路径的权值。
求最大满意度。
题解
首先我们需要快速的判断一条悲剧线和一条喜剧线是否相交。刚开始只会用树链剖分
O(logn)
的判断。
现在get到一个新的姿势。求出两个链的lca ,设
alca
表示两条链中较浅的lca,
blca
表示较深的lca ,如果
blca
在
alca
的子树中,且a的两个端点至少有一个在
blca
的子树中,那么两条链一定有交点。
然后我们可以进行最小割建图
S−>a[i]
其中a[i]表示喜剧,容量为该喜剧的s
b[j]−>T
其中b[j]表示悲剧,容量为该悲剧的s
a[i]−>b[j]
两条剧情线相交,路径权值的乘积<=7,容量为路径权值的乘积
a[i]−>b[j]
两条剧情线相交,路径权值的乘积>7,容量为INF
最后的答案就是 ∑as+∑bs−flow
代码
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
#include<queue>
#define N 2000003
#define inf 1000000000
using namespace std;
int sz,tot,point[N],v[N],nxt[N],remain[N],deep[N],pos[N],l[N],r[N],mi[100];
int val[N][21],fa[N][21],cur[N],n,m1,m2;
struct data{
int x,y,lca,s,v;
}a[N],b[N];
void add(int x,int y)
{
tot++; nxt[tot]=point[x]; point[x]=tot; v[tot]=y;
tot++; nxt[tot]=point[y]; point[y]=tot; v[tot]=x;
}
void build(int x,int y,int z)
{
tot++; nxt[tot]=point[x]; point[x]=tot; v[tot]=y; remain[tot]=z;
tot++; nxt[tot]=point[y]; point[y]=tot; v[tot]=x; remain[tot]=0;
}
void dfs(int x,int f)
{
deep[x]=deep[f]+1; pos[x]=++sz; l[x]=sz;
for (int i=1;i<20;i++) {
if (deep[x]-mi[i]<0) continue;
val[x][i]=val[x][i-1]^val[fa[x][i-1]][i-1];
fa[x][i]=fa[fa[x][i-1]][i-1];
}
for (int i=point[x];i;i=nxt[i]) {
if (v[i]==f) continue;
fa[v[i]][0]=x;
val[v[i]][0]=(v[i]&7);
dfs(v[i],x);
}
r[x]=sz;
}
void lca(int k,int opt)
{
int x,y;
if (opt==1) {
x=a[k].x; y=a[k].y;
}
else x=b[k].x,y=b[k].y;
int xor1=0; int k1;
if (deep[x]<deep[y]) swap(x,y);
k1=deep[x]-deep[y];
for (int i=0;i<20;i++)
if ((k1>>i)&1) {
xor1^=val[x][i];
x=fa[x][i];
}
if (x==y) {
if (opt==1) a[k].lca=x,a[k].v=xor1^(x&7);
else b[k].lca=x,b[k].v=xor1^(x&7);
return;
}
for (int i=20;i>=0;i--)
if (fa[x][i]!=fa[y][i]) {
xor1^=val[x][i]; xor1^=val[y][i];
x=fa[x][i],y=fa[y][i];
}
xor1^=(x&7); xor1^=(y&7);
xor1^=(fa[x][0]&7);
if (opt==1) a[k].lca=fa[x][0],a[k].v=xor1;
else b[k].lca=fa[x][0],b[k].v=xor1;
}
bool bfs(int s,int t)
{
memset(deep,0x7f,sizeof(deep));
for (int i=1;i<=t;i++) cur[i]=point[i];
deep[1]=0;
queue<int> p; p.push(1);
while (!p.empty()) {
int now=p.front(); p.pop();
for (int i=point[now];i!=-1;i=nxt[i]) {
if (deep[v[i]]>inf&&remain[i]) {
deep[v[i]]=deep[now]+1;
p.push(v[i]);
}
}
}
if (deep[t]>inf) return false;
return true;
}
int dfs(int now,int t,int limit)
{
if (now==t||!limit) return limit;
int flow=0,f;
for (int i=cur[now];i!=-1;i=nxt[i]) {
cur[now]=i;
if (deep[v[i]]==deep[now]+1&&(f=dfs(v[i],t,min(limit,remain[i])))){
flow+=f; limit-=f;
remain[i]-=f; remain[i^1]+=f;
if (!limit) break;
}
}
return flow;
}
int dinic(int s,int t)
{
int ans=0;
while (bfs(s,t)) ans+=dfs(s,t,inf);
return ans;
}
bool check(data a,data b)
{
if (deep[a.lca]>deep[b.lca]) swap(a,b);
bool pd=false,pd1=false,pd2=false;
if (pos[b.lca]>=l[a.lca]&&pos[b.lca]<=r[a.lca]) pd=true;
if (pos[a.x]>=l[b.lca]&&pos[a.x]<=r[b.lca]) pd1=true;
if (pos[a.y]>=l[b.lca]&&pos[a.y]<=r[b.lca]) pd2=true;
return pd&&(pd1||pd2);
}
int main()
{
freopen("B.in","r",stdin);
freopen("my.out","w",stdout);
scanf("%d%d%d",&n,&m1,&m2);
for (int i=1;i<n;i++) {
int x,y; scanf("%d%d",&x,&y);
add(x,y);
}
mi[0]=1;
for (int i=1;i<=20;i++) mi[i]=mi[i-1]*2;
for (int i=1;i<=m1;i++) scanf("%d%d%d",&a[i].x,&a[i].y,&a[i].s);
for (int i=1;i<=m2;i++) scanf("%d%d%d",&b[i].x,&b[i].y,&b[i].s);
dfs(1,0); int sum=0;
for (int i=1;i<=m1;i++) lca(i,1),sum+=a[i].s;//cout<<a[i].lca<<" "<<a[i].v<<endl;
for (int i=1;i<=m2;i++) lca(i,2),sum+=b[i].s;//cout<<b[i].lca<<" "<<b[i].v<<endl;
tot=-1;
memset(point,-1,sizeof(point));
for (int i=1;i<=m1;i++)
for (int j=1;j<=m2;j++)
if (check(a[i],b[j])) {
if (a[i].v*b[j].v>7) build(i+1,j+m1+1,inf);
else build(i+1,j+m1+1,a[i].v*b[j].v);
}
for (int i=1;i<=m1;i++) build(1,i+1,a[i].s);
for (int i=1;i<=m2;i++) build(i+m1+1,m1+m2+2,b[i].s);
int flow=dinic(1,m1+m2+2);
printf("%d\n",sum-flow);
}