给定一个二分图,其中左半部包含 n1 个点(编号 1∼n1),右半部包含 n2 个点(编号 1∼n2),二分图共包含 m 条边。
数据保证任意一条边的两个端点都不可能在同一部分中。
请你求出二分图的最大匹配数。
二分图的匹配:给定一个二分图 G,在 G 的一个子图 M 中,M 的边集 {E} 中的任意两条边都不依附于同一个顶点,则称 M 是一个匹配。
二分图的最大匹配:所有匹配中包含边数最多的一组匹配被称为二分图的最大匹配,其边数即为最大匹配数。
输入格式
第一行包含三个整数 n1、 n2 和 m。
接下来 m 行,每行包含两个整数 u 和 v,表示左半部点集中的点 u 和右半部点集中的点 v 之间存在一条边。
输出格式
输出一个整数,表示二分图的最大匹配数。
数据范围
1≤n1,n2≤500
1≤u≤n1
1≤v≤n2
1≤m≤105
输入样例:
2 2 4
1 1
1 2
2 1
2 2
输出样例:
2
解析 :
匈牙利算法是一种解决二分图最大匹配问题的算法,主要用于解决任务分配等问题。以下是匈牙利算法的步骤:
1. **初始化:** 对于给定的二分图,将每个顶点的标记初始化为未匹配。
2. **从左侧开始:** 从二分图的左侧开始,对每个未匹配的顶点尝试找到增广路径。
3. **寻找增广路径:** 从当前未匹配的左侧顶点开始,尝试沿着未匹配的边找到增广路径。增广路径是一条交替经过未匹配边和已匹配边的路径。
4. **增广路径存在:** 如果找到增广路径,则根据增广路径更新匹配关系。更新的方式是将路径上的已匹配边变为未匹配,未匹配边变为已匹配。
5. **增广路径不存在:** 如果无法找到增广路径,说明当前左侧顶点已经无法找到匹配,此时需要尝试调整之前的匹配,即修改之前已匹配的边。
6. **调整匹配:** 通过调整匹配,使得右侧的顶点重新可以匹配。具体做法是选择已匹配边中的一条,然后尝试寻找增广路径。
7. **重复步骤2-6:** 重复进行步骤2到步骤6,直到无法找到更多的增广路径为止。
8. **结束:** 当无法再找到增广路径时,算法结束。此时已经找到了二分图的最大匹配。
核心代码:
主函数中的算法
for (int i = 1; i <= n1; i++) {
memset(st, 0, sizeof st);
if (find(i)) {
ret++;
}
}
函数:
int find(int u) {
for (int i = h[u]; i != -1;i=ne[i]) {
int j = e[i];
if (!st[j]) {
st[j] = 1;
//如果点 j 还没有被匹配,或者find(match[j])能为与j匹配的点找到新的下家
if (match[j] == 0 || find(match[j])) {
match[j] = u;
return 1;
}
}
}
return 0;
}
匈牙利算法的核心思想是通过不断寻找增广路径,来扩展当前的匹配,直到无法找到更多的增广路径。算法的时间复杂度为O(V*E),其中V是顶点数,E是边数。
#include<iostream>
#include<string>
#include<cstring>
#include<cmath>
#include<ctime>
#include<algorithm>
#include<utility>
#include<stack>
#include<queue>
#include<vector>
#include<set>
#include<math.h>
#include<map>
#include<sstream>
#include<deque>
#include<unordered_map>
using namespace std;
const int N = 5e2+3, M = 2e5 + 5, INF = 1e9;
int n1, n2, m;
int h[N], e[M], ne[M], idx;
int st[N], match[N];
void add(int a, int b) {
e[idx] = b, ne[idx] = h[a], h[a] = idx++;
}
int find(int u) {
for (int i = h[u]; i != -1;i=ne[i]) {
int j = e[i];
if (!st[j]) {
st[j] = 1;
if (match[j] == 0 || find(match[j])) {
match[j] = u;
return 1;
}
}
}
return 0;
}
int main() {
scanf("%d%d%d", &n1, &n2, &m);
memset(h, -1, sizeof h);
for (int i = 1,a,b; i <= m; i++) {
scanf("%d%d", &a, &b);
add(a, b);
}
int ret=0;
for (int i = 1; i <= n1; i++) {
memset(st, 0, sizeof st);
if (find(i)) {
ret++;
}
}
cout << ret << endl;
return 0;
}