并查集(Disjoint Set) 理论知识复习与例题解析

一、并查集(Disjoint Set) 概念

1. 出现背景

并查集(Disjoint Set) 的出现源于 数学中等价关系的高效管理需求计算机算法对集合操作的性能优化。其核心价值在于通过简洁的结构和高效的操作(接近常数时间),解决了大量实际问题,如 连通性判断、动态图处理等,至今仍是算法竞赛和工程实践中的重要工具。

2. 核心思想

  • 不相交集合 的维护与查询
  • 两种核心操作:
    • Find:查询元素所属集合
    • Union:合并两个集合

3. 存储方式

  • 使用 父节点数组(parent[N]) 实现
  • 初始化时每个元素的父节点指向自己

4. 优化方法

  • 路径压缩(Find优化):将查询路径上的节点直接连接到根节点
  • 按秩合并(Union优化):总是将小树合并到大树上

5. 实现要点

// c++
const int N = 1e3 + 10;

int parent[N];

int find(int u) {
    if (parent[u] != u) parent[u] = find(parent[u]);
    return parent[u];
}

void union(int u, int v) {
    u = find(u); v = find(v);

    if (u == v) return;
    
    parent[v] = u;
}

二、例题解析


例题1:P3367 【模板】并查集

P3367 【模板】并查集

题目大意

实现并查集基本操作:合并集合和查询是否同属

解题思路

  1. 标准并查集模板实现
  2. 使用路径压缩+按秩合并优化

代码实现

// C++
#include <iostream>
using namespace std;

const int N = 1e4+10;
int n, m;

// fa 数组用于存储每个元素的父节点,rk 数组用于存储每个集合的秩(树的高度)
int fa[N], rk[N];

// 查找元素 x 所在集合的根节点,并进行路径压缩
// 路径压缩的目的是让每个节点直接指向根节点,减少后续查找的时间复杂度
int find(int x) {
    // 如果 x 的父节点就是它本身,说明 x 是根节点,直接返回 x
    // 否则,递归查找 x 的父节点的根节点,并将 x 的父节点直接设为根节点
    return fa[x] == x ? x : fa[x] = find(fa[x]);
}

// 合并元素 x 和 y 所在的集合
void merge(int x, int y) {
    // 先分别找到 x 和 y 所在集合的根节点
    x = find(x), y = find(y);

    // 如果 x 和 y 已经在同一个集合中,直接返回,无需合并
    if (x == y) return;

    // 为了让树的高度尽可能小,将秩较小的集合合并到秩较大的集合中
    // 如果 x 所在集合的秩大于 y 所在集合的秩,交换 x 和 y
    if (rk[x] > rk[y]) swap(x, y);

    // 将 x 所在集合的根节点的父节点设为 y 所在集合的根节点
    fa[x] = y;

    // 如果两个集合的秩相等,合并后 y 所在集合的秩需要加 1
    if (rk[x] == rk[y]) rk[y]++;
}

int main(void) {
    cin >> n >> m;

    // 初始化 fa 数组,每个元素的父节点初始化为它本身,表示每个元素单独成一个集合
    // 初始化 rk 数组,每个集合的秩初始化为 1
    for(int i = 1; i <= n; i++) {
        fa[i] = i;
        rk[i] = 1;
    }

    // 循环处理 m 次操作
    while(m--) {
        int op, x, y;
        cin >> op >> x >> y;

        // 如果操作类型为 1,调用 merge 函数合并 x 和 y 所在的集合
        if (op == 1) merge(x, y);

        // 如果操作类型不为 1,判断 x 和 y 是否在同一个集合中
        // 如果在同一个集合中,输出 "Y",否则输出 "N"
        else cout << (find(x) == find(y) ? "Y" : "N") << endl;
    }
    return 0;
}
# Python
def main():
    import sys
    # 设置递归深度限制,避免因递归过深导致栈溢出
    sys.setrecursionlimit(1 << 25)
    # 读取输入的元素数量 n 和操作数量 m
    n, m = map(int, sys.stdin.readline().split())
    # fa 列表用于存储每个元素的父节点,初始时每个元素的父节点是它本身
    fa = list(range(n+1))
    # rank 列表用于存储每个集合的秩(树的高度),初始时每个集合的秩为 1
    rank = [1]*(n+1)

    # 查找元素 x 所在集合的根节点,并进行路径压缩
    def find(x):
        # 如果 x 的父节点不是它本身,递归查找 x 的父节点的根节点,并将 x 的父节点直接设为根节点
        if fa[x] != x:
            fa[x] = find(fa[x])
        return fa[x]
    
    # 循环处理 m 次操作
    for _ in range(m):
        # 读取操作类型 op、元素 x 和元素 y
        op, x, y = map(int, sys.stdin.readline().split())
        if op == 1:
            # 分别找到 x 和 y 所在集合的根节点
            fx, fy = find(x), find(y)
            # 如果 x 和 y 已经在同一个集合中,跳过本次合并操作
            if fx == fy:
                continue
            # 为了让树的高度尽可能小,将秩较小的集合合并到秩较大的集合中
            # 如果 fx 所在集合的秩大于 fy 所在集合的秩,交换 fx 和 fy
            if rank[fx] > rank[fy]:
                fx, fy = fy, fx
            # 将 fx 所在集合的根节点的父节点设为 fy 所在集合的根节点
            fa[fx] = fy
            # 如果两个集合的秩相等,合并后 fy 所在集合的秩需要加 1
            if rank[fx] == rank[fy]:
                rank[fy] += 1
        else:
            # 判断 x 和 y 是否在同一个集合中
            # 如果在同一个集合中,输出 "Y",否则输出 "N"
            print("Y" if find(x) == find(y) else "N")

if __name__ == "__main__":
    main()
// Java
import java.util.Scanner;

public class Main {
    // 静态数组 fa 用于存储每个元素的父节点,rank 用于存储每个集合的秩(树的高度)
    static int[] fa, rank;
    
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        // 读取输入的元素数量 n 和操作数量 m
        int n = sc.nextInt();
        int m = sc.nextInt();
        
        // 初始化 fa 数组和 rank 数组
        fa = new int[n+1];
        rank = new int[n+1];
        // 每个元素的父节点初始化为它本身,表示每个元素单独成一个集合
        // 每个集合的秩初始化为 1
        for (int i = 1; i <= n; i++) {
            fa[i] = i;
            rank[i] = 1;
        }
        
        // 循环处理 m 次操作
        while (m-- > 0) {
            // 读取操作类型 op、元素 x 和元素 y
            int op = sc.nextInt();
            int x = sc.nextInt();
            int y = sc.nextInt();
            // 如果操作类型为 1,调用 merge 函数合并 x 和 y 所在的集合
            if (op == 1) {
                merge(x, y);
            } else {
                // 判断 x 和 y 是否在同一个集合中
                // 如果在同一个集合中,输出 "Y",否则输出 "N"
                System.out.println(find(x) == find(y) ? "Y" : "N");
            }
        }
    }
    
    // 查找元素 x 所在集合的根节点,并进行路径压缩
    static int find(int x) {
        // 如果 x 的父节点不是它本身,递归查找 x 的父节点的根节点,并将 x 的父节点直接设为根节点
        if (fa[x] != x) {
            fa[x] = find(fa[x]);
        }
        return fa[x];
    }
    
    // 合并元素 x 和 y 所在的集合
    static void merge(int x, int y) {
        // 先分别找到 x 和 y 所在集合的根节点
        x = find(x);
        y = find(y);
        // 如果 x 和 y 已经在同一个集合中,直接返回,无需合并
        if (x == y) return;
        // 为了让树的高度尽可能小,将秩较小的集合合并到秩较大的集合中
        // 如果 x 所在集合的秩大于 y 所在集合的秩,交换 x 和 y
        if (rank[x] > rank[y]) {
            int t = x; x = y; y = t;
        }
        // 将 x 所在集合的根节点的父节点设为 y 所在集合的根节点
        fa[x] = y;
        // 如果两个集合的秩相等,合并后 y 所在集合的秩需要加 1
        if (rank[x] == rank[y]) rank[y]++;
    }
}


例题2:P1551 亲戚

P1551 亲戚

题目大意

判断多个查询中的两个人是否属于同一家族

解题思路

与模板题完全相同

最终输出查询结果即可

代码实现
(与上一题完全相同,只需调整输入输出格式)



例题3:P1955 [NOI2015] 程序自动分析

[NOI2015] 程序自动分析

题目大意

处理变量相等/不相等约束,判断是否矛盾

解题思路

  1. 离散化处理大范围数据

  2. 先处理所有相等约束建立并查集

  3. 后检查所有不等约束是否冲突


代码实现

// C++
#include <iostream>
#include <vector>
#include <algorithm>
using namespace std;

const int N = 2e5+10;
int fa[N], rk[N];

struct Query { int x, y, e; };

int find(int x) {
    return fa[x] == x ? x : fa[x] = find(fa[x]);
}

void merge(int x, int y) {
    x = find(x), y = find(y);
    if (x != y) fa[x] = y;
}

int main() {
    int T;
    cin >> T;
    while (T--) {
        vector<Query> qs;
        vector<int> nums;
        
        int n;
        cin >> n;
        for (int i = 0; i < n; i++) {
            int x, y, e;
            cin >> x >> y >> e;
            qs.push_back({x, y, e});
            nums.push_back(x);
            nums.push_back(y);
        }
        
        // 离散化
        sort(nums.begin(), nums.end());
        nums.erase(unique(nums.begin(), nums.end()), nums.end());
        int cnt = nums.size();
        
        // 初始化
        for (int i = 0; i < cnt; i++) fa[i] = i;
        
        // 处理相等关系
        bool conflict = false;
        for (auto &q : qs) {
            if (!q.e) continue;
            int x = lower_bound(nums.begin(), nums.end(), q.x) - nums.begin();
            int y = lower_bound(nums.begin(), nums.end(), q.y) - nums.begin();
            merge(x, y);
        }
        
        // 检查不等关系
        for (auto &q : qs) {
            if (q.e) continue;
            int x = lower_bound(nums.begin(), nums.end(), q.x) - nums.begin();
            int y = lower_bound(nums.begin(), nums.end(), q.y) - nums.begin();
            if (find(x) == find(y)) {
                conflict = true;
                break;
            }
        }
        
        cout << (conflict ? "NO" : "YES") << endl;
    }
    return 0;
}
# Python
def main():
    import sys
    input = sys.stdin.read
    data = input().split()
    idx = 0
    T = int(data[idx]); idx +=1
    
    for _ in range(T):
        n = int(data[idx]); idx +=1
        qs = []
        nums = []
        for __ in range(n):
            x = int(data[idx])
            y = int(data[idx+1])
            e = int(data[idx+2])
            idx +=3
            qs.append( (x,y,e) )
            nums.extend([x,y])
        
        # 离散化
        nums = sorted(list(set(nums)))
        mp = {v:i for i,v in enumerate(nums)}
        size = len(nums)
        
        # 初始化
        fa = list(range(size))
        
        def find(x):
            if fa[x] != x:
                fa[x] = find(fa[x])
            return fa[x]
        
        # 处理相等
        conflict = False
        for x,y,e in qs:
            if e != 1:
                continue
            xx = mp[x]
            yy = mp[y]
            fx = find(xx)
            fy = find(yy)
            if fx != fy:
                fa[fx] = fy
        
        # 检查不等
        for x,y,e in qs:
            if e != 0:
                continue
            xx = mp[x]
            yy = mp[y]
            if find(xx) == find(yy):
                conflict = True
                break
        
        print("NO" if conflict else "YES")

if __name__ == "__main__":
    main()
// Java
import java.util.*;

public class Main {
    static int[] fa;
    
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int T = sc.nextInt();
        while (T-- > 0) {
            int n = sc.nextInt();
            List<int[]> qs = new ArrayList<>();
            Set<Integer> set = new HashSet<>();
            
            for (int i = 0; i < n; i++) {
                int x = sc.nextInt();
                int y = sc.nextInt();
                int e = sc.nextInt();
                qs.add(new int[]{x, y, e});
                set.add(x);
                set.add(y);
            }
            
            // 离散化
            List<Integer> list = new ArrayList<>(set);
            Collections.sort(list);
            Map<Integer, Integer> map = new HashMap<>();
            for (int i = 0; i < list.size(); i++) {
                map.put(list.get(i), i);
            }
            int size = list.size();
            
            // 初始化
            fa = new int[size];
            for (int i = 0; i < size; i++) fa[i] = i;
            
            // 处理相等
            boolean conflict = false;
            for (int[] q : qs) {
                if (q[2] == 0) continue;
                int x = map.get(q[0]);
                int y = map.get(q[1]);
                int fx = find(x);
                int fy = find(y);
                if (fx != fy) {
                    fa[fx] = fy;
                }
            }
            
            // 检查不等
            for (int[] q : qs) {
                if (q[2] == 1) continue;
                Integer x = map.get(q[0]);
                Integer y = map.get(q[1]);
                if (x == null || y == null) continue;
                if (find(x) == find(y)) {
                    conflict = true;
                    break;
                }
            }
            
            System.out.println(conflict ? "NO" : "YES");
        }
    }
    
    static int find(int x) {
        if (fa[x] != x) {
            fa[x] = find(fa[x]);
        }
        return fa[x];
    }
}

三、总结

  1. 并查集擅长处理分组和连通性问题

  2. 遇到大范围数据时优先考虑离散化

  3. 注意合并顺序和路径压缩的优化


博客中的 数据结构和算法模板以及算法题 的全部 C++ 代码 和部分 PythonJava 代码实现都在作者的 Github 仓库中能找到,后续会补充 PythonJava 实现


此次普通并查集的理论知识复习和模板题解析就到这里了,如有不足之处欢迎各位指出,感谢!

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值