并查集理论知识复习与例题解析
一、并查集(Disjoint Set) 概念
1. 出现背景
并查集(Disjoint Set) 的出现源于 数学中等价关系的高效管理需求 和 计算机算法对集合操作的性能优化。其核心价值在于通过简洁的结构和高效的操作(接近常数时间),解决了大量实际问题,如 连通性判断、动态图处理等,至今仍是算法竞赛和工程实践中的重要工具。
2. 核心思想
- 不相交集合 的维护与查询
- 两种核心操作:
Find
:查询元素所属集合Union
:合并两个集合
3. 存储方式
- 使用 父节点数组(parent[N]) 实现
- 初始化时每个元素的父节点指向自己
4. 优化方法
- 路径压缩(Find优化):将查询路径上的节点直接连接到根节点
- 按秩合并(Union优化):总是将小树合并到大树上
5. 实现要点
// c++
const int N = 1e3 + 10;
int parent[N];
int find(int u) {
if (parent[u] != u) parent[u] = find(parent[u]);
return parent[u];
}
void union(int u, int v) {
u = find(u); v = find(v);
if (u == v) return;
parent[v] = u;
}
二、例题解析
例题1:P3367 【模板】并查集
题目大意
实现并查集基本操作:合并集合和查询是否同属
解题思路
- 标准并查集模板实现
- 使用路径压缩+按秩合并优化
代码实现
// C++
#include <iostream>
using namespace std;
const int N = 1e4+10;
int n, m;
// fa 数组用于存储每个元素的父节点,rk 数组用于存储每个集合的秩(树的高度)
int fa[N], rk[N];
// 查找元素 x 所在集合的根节点,并进行路径压缩
// 路径压缩的目的是让每个节点直接指向根节点,减少后续查找的时间复杂度
int find(int x) {
// 如果 x 的父节点就是它本身,说明 x 是根节点,直接返回 x
// 否则,递归查找 x 的父节点的根节点,并将 x 的父节点直接设为根节点
return fa[x] == x ? x : fa[x] = find(fa[x]);
}
// 合并元素 x 和 y 所在的集合
void merge(int x, int y) {
// 先分别找到 x 和 y 所在集合的根节点
x = find(x), y = find(y);
// 如果 x 和 y 已经在同一个集合中,直接返回,无需合并
if (x == y) return;
// 为了让树的高度尽可能小,将秩较小的集合合并到秩较大的集合中
// 如果 x 所在集合的秩大于 y 所在集合的秩,交换 x 和 y
if (rk[x] > rk[y]) swap(x, y);
// 将 x 所在集合的根节点的父节点设为 y 所在集合的根节点
fa[x] = y;
// 如果两个集合的秩相等,合并后 y 所在集合的秩需要加 1
if (rk[x] == rk[y]) rk[y]++;
}
int main(void) {
cin >> n >> m;
// 初始化 fa 数组,每个元素的父节点初始化为它本身,表示每个元素单独成一个集合
// 初始化 rk 数组,每个集合的秩初始化为 1
for(int i = 1; i <= n; i++) {
fa[i] = i;
rk[i] = 1;
}
// 循环处理 m 次操作
while(m--) {
int op, x, y;
cin >> op >> x >> y;
// 如果操作类型为 1,调用 merge 函数合并 x 和 y 所在的集合
if (op == 1) merge(x, y);
// 如果操作类型不为 1,判断 x 和 y 是否在同一个集合中
// 如果在同一个集合中,输出 "Y",否则输出 "N"
else cout << (find(x) == find(y) ? "Y" : "N") << endl;
}
return 0;
}
# Python
def main():
import sys
# 设置递归深度限制,避免因递归过深导致栈溢出
sys.setrecursionlimit(1 << 25)
# 读取输入的元素数量 n 和操作数量 m
n, m = map(int, sys.stdin.readline().split())
# fa 列表用于存储每个元素的父节点,初始时每个元素的父节点是它本身
fa = list(range(n+1))
# rank 列表用于存储每个集合的秩(树的高度),初始时每个集合的秩为 1
rank = [1]*(n+1)
# 查找元素 x 所在集合的根节点,并进行路径压缩
def find(x):
# 如果 x 的父节点不是它本身,递归查找 x 的父节点的根节点,并将 x 的父节点直接设为根节点
if fa[x] != x:
fa[x] = find(fa[x])
return fa[x]
# 循环处理 m 次操作
for _ in range(m):
# 读取操作类型 op、元素 x 和元素 y
op, x, y = map(int, sys.stdin.readline().split())
if op == 1:
# 分别找到 x 和 y 所在集合的根节点
fx, fy = find(x), find(y)
# 如果 x 和 y 已经在同一个集合中,跳过本次合并操作
if fx == fy:
continue
# 为了让树的高度尽可能小,将秩较小的集合合并到秩较大的集合中
# 如果 fx 所在集合的秩大于 fy 所在集合的秩,交换 fx 和 fy
if rank[fx] > rank[fy]:
fx, fy = fy, fx
# 将 fx 所在集合的根节点的父节点设为 fy 所在集合的根节点
fa[fx] = fy
# 如果两个集合的秩相等,合并后 fy 所在集合的秩需要加 1
if rank[fx] == rank[fy]:
rank[fy] += 1
else:
# 判断 x 和 y 是否在同一个集合中
# 如果在同一个集合中,输出 "Y",否则输出 "N"
print("Y" if find(x) == find(y) else "N")
if __name__ == "__main__":
main()
// Java
import java.util.Scanner;
public class Main {
// 静态数组 fa 用于存储每个元素的父节点,rank 用于存储每个集合的秩(树的高度)
static int[] fa, rank;
public static void main(String[] args) {
Scanner sc = new Scanner(System.in);
// 读取输入的元素数量 n 和操作数量 m
int n = sc.nextInt();
int m = sc.nextInt();
// 初始化 fa 数组和 rank 数组
fa = new int[n+1];
rank = new int[n+1];
// 每个元素的父节点初始化为它本身,表示每个元素单独成一个集合
// 每个集合的秩初始化为 1
for (int i = 1; i <= n; i++) {
fa[i] = i;
rank[i] = 1;
}
// 循环处理 m 次操作
while (m-- > 0) {
// 读取操作类型 op、元素 x 和元素 y
int op = sc.nextInt();
int x = sc.nextInt();
int y = sc.nextInt();
// 如果操作类型为 1,调用 merge 函数合并 x 和 y 所在的集合
if (op == 1) {
merge(x, y);
} else {
// 判断 x 和 y 是否在同一个集合中
// 如果在同一个集合中,输出 "Y",否则输出 "N"
System.out.println(find(x) == find(y) ? "Y" : "N");
}
}
}
// 查找元素 x 所在集合的根节点,并进行路径压缩
static int find(int x) {
// 如果 x 的父节点不是它本身,递归查找 x 的父节点的根节点,并将 x 的父节点直接设为根节点
if (fa[x] != x) {
fa[x] = find(fa[x]);
}
return fa[x];
}
// 合并元素 x 和 y 所在的集合
static void merge(int x, int y) {
// 先分别找到 x 和 y 所在集合的根节点
x = find(x);
y = find(y);
// 如果 x 和 y 已经在同一个集合中,直接返回,无需合并
if (x == y) return;
// 为了让树的高度尽可能小,将秩较小的集合合并到秩较大的集合中
// 如果 x 所在集合的秩大于 y 所在集合的秩,交换 x 和 y
if (rank[x] > rank[y]) {
int t = x; x = y; y = t;
}
// 将 x 所在集合的根节点的父节点设为 y 所在集合的根节点
fa[x] = y;
// 如果两个集合的秩相等,合并后 y 所在集合的秩需要加 1
if (rank[x] == rank[y]) rank[y]++;
}
}
例题2:P1551 亲戚
题目大意
判断多个查询中的两个人是否属于同一家族
解题思路
与模板题完全相同
最终输出查询结果即可
代码实现
(与上一题完全相同,只需调整输入输出格式)
例题3:P1955 [NOI2015] 程序自动分析
题目大意
处理变量相等/不相等约束,判断是否矛盾
解题思路
-
离散化处理大范围数据
-
先处理所有相等约束建立并查集
-
后检查所有不等约束是否冲突
代码实现
// C++
#include <iostream>
#include <vector>
#include <algorithm>
using namespace std;
const int N = 2e5+10;
int fa[N], rk[N];
struct Query { int x, y, e; };
int find(int x) {
return fa[x] == x ? x : fa[x] = find(fa[x]);
}
void merge(int x, int y) {
x = find(x), y = find(y);
if (x != y) fa[x] = y;
}
int main() {
int T;
cin >> T;
while (T--) {
vector<Query> qs;
vector<int> nums;
int n;
cin >> n;
for (int i = 0; i < n; i++) {
int x, y, e;
cin >> x >> y >> e;
qs.push_back({x, y, e});
nums.push_back(x);
nums.push_back(y);
}
// 离散化
sort(nums.begin(), nums.end());
nums.erase(unique(nums.begin(), nums.end()), nums.end());
int cnt = nums.size();
// 初始化
for (int i = 0; i < cnt; i++) fa[i] = i;
// 处理相等关系
bool conflict = false;
for (auto &q : qs) {
if (!q.e) continue;
int x = lower_bound(nums.begin(), nums.end(), q.x) - nums.begin();
int y = lower_bound(nums.begin(), nums.end(), q.y) - nums.begin();
merge(x, y);
}
// 检查不等关系
for (auto &q : qs) {
if (q.e) continue;
int x = lower_bound(nums.begin(), nums.end(), q.x) - nums.begin();
int y = lower_bound(nums.begin(), nums.end(), q.y) - nums.begin();
if (find(x) == find(y)) {
conflict = true;
break;
}
}
cout << (conflict ? "NO" : "YES") << endl;
}
return 0;
}
# Python
def main():
import sys
input = sys.stdin.read
data = input().split()
idx = 0
T = int(data[idx]); idx +=1
for _ in range(T):
n = int(data[idx]); idx +=1
qs = []
nums = []
for __ in range(n):
x = int(data[idx])
y = int(data[idx+1])
e = int(data[idx+2])
idx +=3
qs.append( (x,y,e) )
nums.extend([x,y])
# 离散化
nums = sorted(list(set(nums)))
mp = {v:i for i,v in enumerate(nums)}
size = len(nums)
# 初始化
fa = list(range(size))
def find(x):
if fa[x] != x:
fa[x] = find(fa[x])
return fa[x]
# 处理相等
conflict = False
for x,y,e in qs:
if e != 1:
continue
xx = mp[x]
yy = mp[y]
fx = find(xx)
fy = find(yy)
if fx != fy:
fa[fx] = fy
# 检查不等
for x,y,e in qs:
if e != 0:
continue
xx = mp[x]
yy = mp[y]
if find(xx) == find(yy):
conflict = True
break
print("NO" if conflict else "YES")
if __name__ == "__main__":
main()
// Java
import java.util.*;
public class Main {
static int[] fa;
public static void main(String[] args) {
Scanner sc = new Scanner(System.in);
int T = sc.nextInt();
while (T-- > 0) {
int n = sc.nextInt();
List<int[]> qs = new ArrayList<>();
Set<Integer> set = new HashSet<>();
for (int i = 0; i < n; i++) {
int x = sc.nextInt();
int y = sc.nextInt();
int e = sc.nextInt();
qs.add(new int[]{x, y, e});
set.add(x);
set.add(y);
}
// 离散化
List<Integer> list = new ArrayList<>(set);
Collections.sort(list);
Map<Integer, Integer> map = new HashMap<>();
for (int i = 0; i < list.size(); i++) {
map.put(list.get(i), i);
}
int size = list.size();
// 初始化
fa = new int[size];
for (int i = 0; i < size; i++) fa[i] = i;
// 处理相等
boolean conflict = false;
for (int[] q : qs) {
if (q[2] == 0) continue;
int x = map.get(q[0]);
int y = map.get(q[1]);
int fx = find(x);
int fy = find(y);
if (fx != fy) {
fa[fx] = fy;
}
}
// 检查不等
for (int[] q : qs) {
if (q[2] == 1) continue;
Integer x = map.get(q[0]);
Integer y = map.get(q[1]);
if (x == null || y == null) continue;
if (find(x) == find(y)) {
conflict = true;
break;
}
}
System.out.println(conflict ? "NO" : "YES");
}
}
static int find(int x) {
if (fa[x] != x) {
fa[x] = find(fa[x]);
}
return fa[x];
}
}
三、总结
-
并查集擅长处理分组和连通性问题
-
遇到大范围数据时优先考虑离散化
-
注意合并顺序和路径压缩的优化
博客中的 数据结构和算法模板以及算法题 的全部 C++ 代码 和部分 Python 及 Java 代码实现都在作者的 Github 仓库中能找到,后续会补充 Python 和 Java 实现
此次普通并查集的理论知识复习和模板题解析就到这里了,如有不足之处欢迎各位指出,感谢!