POJ2195:网格中 'm' 表示人、'H' 表示房子,每个人走到某个房子的费用是两者的曼哈顿距离。把人和房子看成二分图的两侧,求总费用最小的最佳(完美)匹配。
把费用取相反数后,用 KM 算法求一次最大权完美匹配,再把结果取相反数即为最小费用。注意:这里的 KM 实现要求二分图两侧点数相等(n==m)。
#include <iostream>
#include <cstring>
#include <cstdio>
#include <cmath>
using namespace std;
// "Infinity" sentinel for labels and slack; large, but small enough that
// adding two of them does not overflow a 32-bit int.
const int INF = 0x3f3f3f3f;
// Capacity of every array: grid rows/columns and the number of men/houses.
const int Maxn = 120;
// A grid cell position: x is the 1-based row, y the 1-based column.
// g[] stores the men and h[] the houses found in the input grid.
struct Node{
    int x;
    int y;
} g[Maxn], h[Maxn];
int map[Maxn][Maxn];          // map[i][j]: weight of man i -> house j (negated distance)
int match[Maxn];              // match[j]: row matched to column j, or -1 if none
int slack[Maxn];              // slack[j]: smallest label gap toward column j
int lx[Maxn];                 // labels of the rows (men)
int ly[Maxn];                 // labels of the columns (houses)
bool visx[Maxn];              // rows visited in the current search
bool visy[Maxn];              // columns visited in the current search
char st[Maxn][Maxn];          // input grid, row/column indices start at 1
int N, M;                     // grid height and width
int n, m;                     // number of men and number of houses
// Absolute value of an int (like -x, undefined for INT_MIN).
int abs(int x){
    return x < 0 ? -x : x;
}
// Manhattan distance between two grid cells.
int Dis(Node a,Node b){
    const int rowGap = abs(a.x - b.x);
    const int colGap = abs(a.y - b.y);
    return rowGap + colGap;
}
// Reads the N x M grid into st (columns start at index 1), collects the
// men ('m') into g[1..n] and the houses ('H') into h[1..m], then fills
// map[i][j] with the negated Manhattan distance from man i to house j.
void init(){
    memset(map,0,sizeof(map));
    n = 0;
    m = 0;
    for (int row = 1; row <= N; ++row){
        scanf("%s",st[row]+1);
        st[row][0] = '~';   // filler for the unused column 0
        for (int col = 1; col <= M; ++col){
            switch (st[row][col]){
            case 'm':
                ++n;
                g[n].x = row;
                g[n].y = col;
                break;
            case 'H':
                ++m;
                h[m].x = row;
                h[m].y = col;
                break;
            }
        }
    }
    for (int man = 1; man <= n; ++man)
        for (int house = 1; house <= m; ++house)
            map[man][house] = -Dis(g[man],h[house]);
}
bool dfs(int i){
visx[i]=true;
for (int j=1;j<=m;j++){
int tmp=lx[i]+ly[j]-map[i][j];
if (visy[j]) continue;
if (tmp==0){
visy[j]=true;
if (match[j]==-1|| dfs(match[j])){
match[j]=i;
return true;
}
}
else if (tmp<slack[j]){
slack[j]=tmp;
}
}
return false;
}
// Kuhn-Munkres (Hungarian) algorithm with a per-column slack array.
// Finds a maximum-weight matching that covers rows 1..n, using columns
// 1..m and weights map[][] (negated distances). Returns the negated total
// weight, i.e. the minimum total distance.
// Requires n <= m: with n > m some row can never be matched, and the
// label-adjust loop cannot terminate normally.
int KM(){
    memset(match,-1,sizeof(match));
    memset(ly,0,sizeof(ly));
    // Initial feasible labelling: lx[i] = max_j map[i][j], ly[j] = 0.
    for (int i=1;i<=n;i++){
        lx[i]=-INF;
        for (int j=1;j<=m;j++){
            if (lx[i]<map[i][j]) lx[i]=map[i][j];
        }
    }
    for (int k=1;k<=n;k++){
        // slack is indexed by column, so all m entries must be reset. Resetting
        // only 1..n leaves stale/zero values when n < m, and then dx == 0 forever.
        for (int j=1;j<=m;j++) slack[j]=INF;
        while(1){
            memset(visx,false,sizeof(visx));
            memset(visy,false,sizeof(visy));
            if (dfs(k)) break;
            // Smallest gap from a visited row to an unvisited column.
            int dx=INF;
            for (int j=1;j<=m;j++){
                if (!visy[j]&&dx>slack[j]) dx=slack[j];
            }
            for (int i=1;i<=n;i++){
                if (visx[i]) lx[i]-=dx;
            }
            for (int j=1;j<=m;j++){
                if (visy[j]) ly[j]+=dx;
                else slack[j]-=dx;
            }
        }
    }
    int ret=0;
    // Columns left unmatched (possible when n < m) contribute nothing.
    // Reading map[-1][j] for them would be out of bounds.
    for (int j=1;j<=m;j++)
        if (match[j]!=-1) ret+=map[match[j]][j];
    return -ret;
}
// Processes test cases until the "0 0" terminator, end of input, or a
// malformed header. The scanf result must be tested explicitly. In
// `~scanf(...), N||M` the comma operator discards it, so an input without
// the final "0 0" would loop forever on stale N and M.
int main(){
    while (scanf("%d%d",&N,&M)==2 && (N||M)){
        init();
        int ans=KM();
        printf("%d\n",ans);
    }
    return 0;
}
另一种写法:每轮调整顶标时只用到 slack 的最小值,因此可以把 slack 数组换成一个变量,代码也稍短一些。
#include <iostream>
#include <cstring>
#include <cstdio>
#include <cmath>
using namespace std;
// "Infinity" sentinel for labels and slack; large, but small enough that
// adding two of them does not overflow a 32-bit int.
const int INF = 0x3f3f3f3f;
// Capacity of every array: grid rows/columns and the number of men/houses.
const int Maxn = 120;
// A grid cell position: x is the 1-based row, y the 1-based column.
// g[] stores the men and h[] the houses found in the input grid.
struct Node{
    int x;
    int y;
} g[Maxn], h[Maxn];
int map[Maxn][Maxn];          // map[i][j]: weight of man i -> house j (negated distance)
int match[Maxn];              // match[j]: row matched to column j, or -1 if none
int lx[Maxn];                 // labels of the rows (men)
int ly[Maxn];                 // labels of the columns (houses)
bool visx[Maxn];              // rows visited in the current search
bool visy[Maxn];              // columns visited in the current search
char st[Maxn][Maxn];          // input grid, row/column indices start at 1
int N, M;                     // grid height and width
int n, m;                     // number of men and number of houses
int slack;                    // running minimum label gap for the current search
// Absolute value of an int (like -x, undefined for INT_MIN).
int abs(int x){
    return x < 0 ? -x : x;
}
// Manhattan distance between two grid cells.
int Dis(Node a,Node b){
    const int rowGap = abs(a.x - b.x);
    const int colGap = abs(a.y - b.y);
    return rowGap + colGap;
}
// Reads the N x M grid into st (columns start at index 1), collects the
// men ('m') into g[1..n] and the houses ('H') into h[1..m], then fills
// map[i][j] with the negated Manhattan distance from man i to house j.
void init(){
    memset(map,0,sizeof(map));
    n = 0;
    m = 0;
    for (int row = 1; row <= N; ++row){
        scanf("%s",st[row]+1);
        st[row][0] = '~';   // filler for the unused column 0
        for (int col = 1; col <= M; ++col){
            switch (st[row][col]){
            case 'm':
                ++n;
                g[n].x = row;
                g[n].y = col;
                break;
            case 'H':
                ++m;
                h[m].x = row;
                h[m].y = col;
                break;
            }
        }
    }
    for (int man = 1; man <= n; ++man)
        for (int house = 1; house <= m; ++house)
            map[man][house] = -Dis(g[man],h[house]);
}
bool dfs(int i){
visx[i]=true;
for (int j=1;j<=m;j++){
int tmp=lx[i]+ly[j]-map[i][j];
if (visy[j]) continue;
if (tmp==0){
visy[j]=true;
if (match[j]==-1|| dfs(match[j])){
match[j]=i;
return true;
}
}
else if (tmp<slack){
slack=tmp;
}
}
return false;
}
// Kuhn-Munkres (Hungarian) algorithm that keeps a single running minimum
// (the global `slack`, lowered by dfs) instead of a per-column array.
// Finds a maximum-weight matching that covers rows 1..n, using columns
// 1..m and weights map[][] (negated distances). Returns the negated total
// weight, i.e. the minimum total distance.
// Requires n <= m: with n > m some row can never be matched, and the
// label-adjust loop cannot terminate normally.
int KM(){
    memset(match,-1,sizeof(match));
    memset(ly,0,sizeof(ly));
    // Initial feasible labelling: lx[i] = max_j map[i][j], ly[j] = 0.
    for (int i=1;i<=n;i++){
        lx[i]=-INF;
        for (int j=1;j<=m;j++){
            if (lx[i]<map[i][j]) lx[i]=map[i][j];
        }
    }
    for (int k=1;k<=n;k++){
        while(1){
            memset(visx,false,sizeof(visx));
            memset(visy,false,sizeof(visy));
            slack=INF;
            if (dfs(k)) break;
            // dfs records gaps toward columns that were unvisited *when seen*.
            // Some of those columns may be visited later in the same search,
            // so dx can be smaller than the textbook minimum. It never exceeds
            // that minimum, though, so the labelling stays feasible.
            int dx=slack;
            // Rows and columns have separate counts (n and m). Using n for
            // ly breaks feasibility when n < m.
            for (int i=1;i<=n;i++)
                if (visx[i]) lx[i]-=dx;
            for (int j=1;j<=m;j++)
                if (visy[j]) ly[j]+=dx;
        }
    }
    int ret=0;
    // Columns left unmatched (possible when n < m) contribute nothing.
    // Reading map[-1][j] for them would be out of bounds.
    for (int j=1;j<=m;j++)
        if (match[j]!=-1) ret+=map[match[j]][j];
    return -ret;
}
// Processes test cases until the "0 0" terminator, end of input, or a
// malformed header. The scanf result must be tested explicitly. In
// `~scanf(...), N||M` the comma operator discards it, so an input without
// the final "0 0" would loop forever on stale N and M.
int main(){
    while (scanf("%d%d",&N,&M)==2 && (N||M)){
        init();
        int ans=KM();
        printf("%d\n",ans);
    }
    return 0;
}