题目
本题是 BZOJ3774「最优选择」的升级版。
注意:一个格子与其四周相邻格子的选择共有 3 种状态(全选理、全选文、不统一),所以每个格子除自身外还要再拆出 2 个辅助点。
AC Code:
#include<bits/stdc++.h>
// Graph limits. Typed constants instead of macros: the old `#define maxm maxn * 100`
// had no parentheses, so an expression such as `x / maxm` would silently have
// expanded to `x / maxn * 100`.
const int maxn = 30010;            // node capacity: S, T and 3 nodes per grid cell (see main)
const int maxm = maxn * 100;       // edge-slot capacity (every Line() consumes 2 slots)
const int inf = 0x3f3f3f3f;        // stands in for "infinite" capacity
using namespace std;

int n,m,mark[105][105],S,T,tot;    // grid size, cell -> node id, source, sink, last node id
int a[105][105],b[105][105],c[105][105],d[105][105];   // per-cell values (filled by main)
int dir[4][2]={{1,0},{-1,0},{0,1},{0,-1}};             // 4-neighbourhood offsets

int dis[maxn];   // BFS level = residual distance to T; -1 when T is unreachable
int cur[maxn];   // current arc of every node during one Dinic phase

// Adjacency lists stored in edge arrays: info[u] is u's first edge, Prev[e] the next one.
// cnt_e starts at 1 so edges are numbered from 2, which makes e^1 the paired reverse edge.
int info[maxn],Prev[maxm],to[maxm],cap[maxm],cnt_e=1;

// Adds a single directed edge u -> v with capacity c.
inline void Node(int u,int v,int c){ Prev[++cnt_e]=info[u],info[u]=cnt_e,to[cnt_e]=v,cap[cnt_e]=c; }
// Adds u -> v with capacity c together with its reverse edge v -> u of capacity d.
inline void Line(int u,int v,int c,int d=0){ Node(u,v,c),Node(v,u,d); }

// Dinic blocking-flow DFS: pushes at most Max units from `now` to T, only along
// residual arcs that go exactly one level closer to T, and returns the amount pushed.
// cur[now] is advanced by reference, so an arc that can no longer carry flow in the
// current phase is never rescanned when `now` is reached again through another path.
// Without this current-arc pointer, a phase may re-walk dead branches repeatedly and
// Dinic loses its O(V^2 E) bound.
int aug(int now,int Max)
{
    if(now == T) return Max;
    int st = Max;   // flow still waiting to be routed out of `now`
    for(int &i=cur[now];i;i=Prev[i])
        if(cap[i] && dis[to[i]]+1 == dis[now])
        {
            int inc = aug(to[i],min(cap[i],st));
            st -= inc , cap[i] -= inc , cap[i^1] += inc;
            if(!st) break;   // leave cur[now] on this arc: it may still have spare capacity
        }
    return Max - st;
}

// Computes levels by a BFS from T over residual arcs (the arc to[i] -> now has
// residual capacity cap[i^1]), resets every node's current arc for the new phase,
// and reports whether S can still reach T.
bool BFS()
{
    static queue<int>q;
    memset(dis,-1,sizeof dis);
    q.push(T),dis[T]=0;
    for(int now;!q.empty();)
    {
        now = q.front() , q.pop();
        for(int i=info[now];i;i=Prev[i])
            if(cap[i^1] && dis[to[i]]==-1)
            {
                dis[to[i]] = dis[now] + 1;
                q.push(to[i]);
            }
    }
    memcpy(cur,info,sizeof cur);   // every phase starts scanning from each node's first edge
    return dis[S] != -1;
}
// Reads the grid size and the four per-cell value matrices, builds the min-cut
// network, and prints (sum of all values) - (min cut).
//
// Every value is doubled while reading so that c/2 and d/2 below are exact
// integers; the result is halved again before printing.
// Node layout: S = 1, T = 2, then n*m cell nodes, followed by two auxiliary
// blocks located at offsets +n*m and +2*n*m from each cell node.
// A cell node left on the S side of the cut loses its edge to T (b + d/2);
// a cell node on the T side loses its edge from S (a + c/2).
int main()
{
    scanf("%d%d",&n,&m);
    S = ++tot;
    T = ++tot;
    for(int r=1;r<=n;r++)
        for(int col=1;col<=m;col++)
            mark[r][col] = ++tot;

    // The matrices arrive one after another (a, b, c, d), each in row-major order.
    int (*const matrices[4])[105] = {a, b, c, d};
    for(int t=0;t<4;t++)
        for(int r=1;r<=n;r++)
            for(int col=1;col<=m;col++)
            {
                scanf("%d",&matrices[t][r][col]);
                matrices[t][r][col] *= 2;
            }

    const int cells = n*m;
    int total = 0;
    for(int r=1;r<=n;r++)
        for(int col=1;col<=m;col++)
        {
            const int self = mark[r][col];
            const int cHub = self + cells;     // auxiliary node carrying the c value
            const int dHub = self + 2*cells;   // auxiliary node carrying the d value
            const int halfC = c[r][col]/2;
            const int halfD = d[r][col]/2;

            total += a[r][col] + b[r][col] + c[r][col] + d[r][col];
            Line(S,self,a[r][col]+halfC);
            Line(self,T,b[r][col]+halfD);
            Line(self,cHub,halfC,halfC);
            Line(self,dHub,halfD,halfD);
            Line(S,cHub,halfC);
            Line(dHub,T,halfD);

            for(int k=0;k<4;k++)
            {
                const int nr = r+dir[k][0];
                const int nc = col+dir[k][1];
                if(nr<1 || nr>n || nc<1 || nc>m) continue;   // neighbour outside the grid
                Line(mark[nr][nc],cHub,0,inf);   // infinite arc cHub -> neighbour
                Line(dHub,mark[nr][nc],0,inf);   // infinite arc neighbour -> dHub
            }
        }

    while(BFS())
        total -= aug(S,inf);
    printf("%d\n",total/2);
}