【随机选择算法】

前言

本篇不会介绍快速排序的相关知识,读者应当掌握以下快速排序的内容:

  1. 经典随机化快排。
  2. 荷兰国旗问题(三路快速排序)。

重点是理解快速排序的分块函数 partition

编程语言:C++

介绍

本篇介绍利用快速排序的思想,完成在无序整数数组中找到第 K 大的整数或第 K 小的整数的目的。找到第 K 大的整数和第 K 小的整数是一种对立问题,即找第 K 大的整数等价于找第 N-K 小的整数。因此,这里以找第 K 大整数为例。

提供以下两个测试链接:

尝试解法

直接的想法是利用快速排序,将整个无序数组排序成升序。这样,求解第 K 大的整数可以直接从有序数组中获得。

以下是 LC215 的快速排序解法

思路

  • 将无序数组排序成升序(可选的排序方式很多,不一定选择快速排序)。
  • 第 K 大的整数对应升序数组中下标 N-K 的数。
class Solution {
public:
    int findKthLargest(vector<int>& nums, int k) {
        sort(nums.begin(), nums.end());
        return nums[nums.size() - k];
    }
};

快速排序的时间复杂度: O ( n log ⁡ n ) O(n \log n) O(nlogn)

我们能否尝试写出线性时间的算法解呢?答案当然可以。

随机选择算法

快速排序做了很多无用功,因为数组是否整体有序并不重要,重点是把第 K 大的数放入正确的位置。即保证第 K 大的整数左侧都是比它小的数,右侧都是比它大的数。这样,我们只需确定这个数的位置 (i),即可直接从数组中取值。

如何实现这样的划分呢?快速排序的 partition 函数就可以实现数组分块。这也是该算法为何称作随机选择算法的原因。

随机性从何体现呢?请看以下算法代码(同样是 LC215),结合代码,下面回答以下几个问题:

  • randomSelect 函数实现了什么功能?
  • partition 函数为什么选择随机化(而非固定)并且三路切分区间?
  • 时间复杂度/空间复杂度估计。
class Solution {
    int first, last;
public:
    int findKthLargest(vector<int>& nums, int k) {
        return randomSelect(nums, nums.size() - k);
    }
    int randomSelect(vector<int>& nums, int i){
        int ans = 0;
        for(int l = 0, r = nums.size() - 1; l <= r; ){
            partition(nums, l, r, nums[l + rand() % (r - l + 1)]);
            if(i <= first){
                r = first;
            }
            else if(i >= last){
                l = last;
            }
            else{
                ans = nums[i];
                break;
            }
        }
        return ans;
    }

    void partition(vector<int>& nums, int l, int r, int pivot){
        int i = l;
        first = l - 1, last = r + 1;
        while(i < last){
            if(nums[i] == pivot) i++;
            else if(nums[i] < pivot){
                swap(nums[i++], nums[++first]);
            }
            else{
                swap(nums[i], nums[--last]);
            }
        }
    }
};
  1. randomSelect 函数实现了什么功能?

    randomSelect 函数用于在无序数组中找到第 K 大的整数(对应升序数组中的第 N-K 小的整数)。

  2. 为什么选择随机化而非固定取左端点或者右端点的数划分数组?

    原因与朴素快速排序相同,依赖输入序列。若选择固定的枢轴(如总是选择左端点或右端点),在处理有序或近乎有序的序列时,时间复杂度会退化为 O ( n 2 ) O(n^2) O(n2)

  3. 为什么选择三路分割?

    三路分割用于处理数组中存在重复值的情况。例如,考虑无序数组 [3,2,1,1,7,9,4,3,5,5,5],要找到第 4 大的数。先随机选择一个数作为枢轴,例如 5 进行划分。假设三路划分结果如下 [3,2,1,1,4,3,5,5,5,9,7]。区间 [0,5] 小于枢轴 5,[6,8] 等于枢轴 5,区间 [9,10] 大于 5。第 4 大的数等价于第 7 小的数,7 位于 [6,8] 区间,因此第 4 大的数就是 5。

    可见,将重复数集中在一起形成小区间,可以加快求解的速度。 修改要求,例如求第 7 大的数和第 2 大的数,依旧随机选择 5 作为枢轴。由于第 7 大等价于第 4 小的整数,4 属于区间 [0,5],进入该区间再次划分以定位精确位置。类似地,求第 2 大的数等价于第 9 小的整数,9 属于区间 [9,10],进入该区间重复划分过程,直到确定精确的位置。

  4. 时空复杂度分析。

    • 空间复杂度 O ( 1 ) O(1) O(1),迭代写法;递归解法是 O ( h ) O(h) O(h),其中 h h h 是递归树的高度。
    • 期望时间复杂度:直接给出结论: O ( n ) O(n) O(n)。严格证明需要使用概率论和高等数学知识,详见《算法导论》一书。
    • 最坏时间复杂度:每次随机化总是选择数组的最大值或最小值,类似于随机化快速排序失败的情况,时间复杂度会退化到 O ( n 2 ) O(n^2) O(n2)。这种极端情况出现的概率极低。
    • 最好时间复杂度 O ( n ) O(n) O(n),每次随机化的数总是能精确对半划分。时间复杂度为 O ( n + n / 2 + n / 4 + ⋯ + 1 ) = O ( n ) O(n + n/2 + n/4 + \dots + 1) = O(n) O(n+n/2+n/4++1)=O(n)

洛谷:求第K小的数

#include<bits/stdc++.h>
using namespace std;

int n, k, a[5000005];
int ans;
int first, last;

void partition(int l, int r, int p){
    int i = l;
    first = l - 1, last = r + 1;
    while(i < last){
        if(a[i] == p) i++;
        else if(a[i] < p) swap(a[i++], a[++first]);
        else swap(a[i], a[--last]);
    }
}

void randomSelect(){
    for(int l = 0, r = n - 1; l <= r; ){
        partition(l, r, a[l + rand() % (r - l + 1)]);
        if(k <= first) r = first;
        else if(k >= last) l = last;
        else{
            ans = a[k];
            return;
        }
    }
}

int main(){
    scanf("%d%d", &n, &k);
    for(int i = 0; i < n; i++){
        scanf("%d", &a[i]);
    }
    randomSelect();
    printf("%d\n", ans);
    return 0;    
}

其它解法(扩展)

这里不会添加额外的篇幅介绍这些方法,感兴趣的读者可以自行查找相关资料学习。

小根堆

小根堆解法。以下是作者早期使用 Java 编写的代码。作者现在写堆的风格已经变成 这篇博客的风格,个人偏向后者。

class Solution {
    public static int MAX = (int)1e5 + 1;
    public static int[] heap = new int[MAX];
    public static int r = 0;

    public static void enqueue(int x){
        int p = (r - 1) / 2;
        int cur = r;
        heap[r++] = x;
        while(cur > 0){
            if(heap[cur] < heap[p]){
                int temp = heap[cur];
                heap[cur] = heap[p];
                heap[p] = temp;
            }
            cur = p;
            p = (cur - 1) / 2;
        }
    }

    public static int dequeue(){
        int ret = heap[0];
        heap[0] = heap[--r];
        int p = 0;
        int cur = 1;
        while(cur < r){
            if(cur + 1 < r && heap[cur] > heap[cur + 1])
                cur++;
            if(heap[cur] < heap[p]){
                int temp = heap[cur];
                heap[cur] = heap[p];
                heap[p] = temp;
            }
            p = cur;
            cur = p * 2 + 1;
        }
        return ret;
    }

    public int findKthLargest(int[] nums, int k) {
        if(nums == null || nums.length < 1){
            return 0;
        }
        r = 0;
        for(int i = 0; i < k; i++){
            enqueue(nums[i]);
        }
        for(int i = k; i < nums.length; i++){
            if(nums[i] > heap[0]){
                dequeue();
                enqueue(nums[i]);
            }
        }
        return dequeue();
    }
}
BFPRT算法

BFPRT 是由 Blum、Floyd、Pratt、Rivest 和 Tarjan 五位计算机科学家在 1973 年提出的一种可在 最坏情况下 保证线性时间 (O(n)) 的选择算法(Select Algorithm),被称为 Median of Medians 算法。

不依赖随机性,但空间复杂度 (O(\log n)) 不如随机选择算法。

以下代码注释由 ChatGPT-o1 生成,作者早期学习 BFPRT 算法提交通过的代码。代码较长,可以自行学习算法思想,不建议在竞赛中使用(太长,容易出错)。

class Solution {
    /**
     * 这个函数用于找到数组的第 k 大的数。
     * 在这里,我们将“第 k 大”转换为相当于“从小到大排序后,
     * 位于索引 (nums.length - k) 的元素”,
     * 然后使用“Median of Medians + 三向切分”的方式来查找该元素。
     *
     * @param nums 数组
     * @param k 第 k 大的数
     * @return 返回数组中第 k 大的数
     */
    public int findKthLargest(int[] nums, int k) {
        // 如果 k 不在有效范围,直接返回最小值以表示异常情况
        if (k < 1 || k > nums.length) {
            return Integer.MIN_VALUE;
        }
        // 这里的第 k 大对应从小到大排序后的下标 (nums.length - k)
        return select(nums, 0, nums.length - 1, nums.length - k);
    }

    /**
     * 在 nums[begin..end] 区间内,返回其中第 i 小的数(i 是基于 0 的索引)。
     * 比如 i = 0 表示最小值,i = n-1 表示最大值。
     *
     * @param nums  数组
     * @param begin 要处理的区间起始下标
     * @param end   要处理的区间结束下标
     * @param i     第 i 小的索引
     * @return      返回第 i 小的值
     */
    public static int select(int[] nums, int begin, int end, int i) {
        // 若区间内只有一个元素,直接返回
        if (begin == end) {
            return nums[begin];
        }

        // 选取中位数的中位数作为 pivot,确保算法在最坏情况下也能保持线性级别
        int pivot = medianOfMedians(nums, begin, end);

        // 三路分区,返回 pivot 的区间 [leftBound, rightBound]
        int[] pivotRange = partition(nums, begin, end, pivot);

        // 如果 i 落在 pivot 的区间内,直接返回 nums[i]
        if (i >= pivotRange[0] && i <= pivotRange[1]) {
            return nums[i];
        }
        // 如果 i 落在 pivot 区间右侧,则在右侧区间继续查找
        else if (i > pivotRange[1]) {
            return select(nums, pivotRange[1] + 1, end, i);
        }
        else {
            // 否则,i 落在左侧区间,在左侧继续查找
            return select(nums, begin, pivotRange[0] - 1, i);
        }
    }

    /**
     * 交换数组中的两个元素
     *
     * @param arr 数组
     * @param i   位置 i
     * @param j   位置 j
     */
    public static void swap(int[] arr, int i, int j) {
        int temp = arr[i];
        arr[i] = arr[j];
        arr[j] = temp;
    }

    /**
     * 三向切分的 partition 函数,将数组 nums[l..r] 根据 pivot 分成:
     *   - [l..a] 小于 pivot 的元素
     *   - (a..b) 等于 pivot 的元素
     *   - [b..r] 大于 pivot 的元素
     * 返回值是等于 pivot 的那段区间的起止下标 [a+1, b-1]
     *
     * @param nums  原数组
     * @param l     待 partition 区间左边界
     * @param r     待 partition 区间右边界
     * @param pivot 枢纽值
     * @return      等于 pivot 的区间下标 [leftBound, rightBound]
     */
    public static int[] partition(int[] nums, int l, int r, int pivot) {
        // a 指向小于 pivot 区间的尾部
        int a = l - 1;
        // b 指向大于 pivot 区间的头部
        int b = r + 1;
        // i 为当前遍历指针
        int i = l;

        // 当 i 与 b 相遇时,说明全部元素都处理完
        while (i < b) {
            if (nums[i] < pivot) {
                // nums[i] < pivot,放到前面
                swap(nums, ++a, i++);
            }
            else if (nums[i] == pivot) {
                // nums[i] == pivot,不动,继续向后
                i++;
            }
            else {
                // nums[i] > pivot,放到后面
                swap(nums, --b, i);
                // 这里不递增 i,因为换过来的元素还没有处理过,需要继续判断
            }
        }
        // 返回 pivot 区间的左右边界
        return new int[]{a + 1, b - 1};
    }

    /**
     * medianOfMedians 函数用于在 [begin..end] 区间中选取一个近似的中位数,来保证最坏情况下也能线性时间。
     * 做法是:将 [begin..end] 分成若干组,每组最多 5 个元素,分别找出各组的中位数,再递归求这些中位数的中位数。
     *
     * @param nums  原数组
     * @param begin 起始下标
     * @param end   结束下标
     * @return      [begin..end] 区间的“中位数的中位数”
     */
    public static int medianOfMedians(int[] nums, int begin, int end) {
        int n = end - begin + 1;
        // 若区间元素不超过 5 个,则直接用插入排序排好后取中间元素
        if (n <= 5) {
            insertionSort(nums, begin, end);
            return nums[begin + ((end - begin) >> 1)];
        }

        // 按每 5 个元素一组,收集每组的中位数到 mArr
        int[] mArr = new int[n / 5 + (n % 5 == 0 ? 0 : 1)];
        for (int i = 0; i < mArr.length; i++) {
            int beginI = begin + i * 5;
            int endI = beginI + 4;
            // getMedian 函数会对这组做插入排序后返回中间元素
            mArr[i] = getMedian(nums, beginI, Math.min(endI, end));
        }

        // 再从 mArr 中递归查找中位数,即得到中位数的中位数
        return select(mArr, 0, mArr.length - 1, mArr.length / 2);
    }

    /**
     * 对 nums[begin..end] 这段进行插入排序后,返回排序后居中的那个元素
     *
     * @param nums  数组
     * @param begin 起始下标
     * @param end   结束下标
     * @return      这段里排序后居中的元素
     */
    public static int getMedian(int[] nums, int begin, int end) {
        insertionSort(nums, begin, end);
        // 取区间 [begin..end] 的中间位置元素
        return nums[begin + ((end - begin) >> 1)];
    }

    /**
     * 插入排序,用于对区间 [l..r] 内的元素进行原地排序。
     *
     * @param arr 数组
     * @param l   要排序的区间左端
     * @param r   要排序的区间右端
     */
    public static void insertionSort(int[] arr, int l, int r) {
        for (int i = l + 1; i <= r; i++) {
            int temp = arr[i];
            int j = i - 1;
            // 将 arr[i] 插入到前面已排好的有序区间中
            while (j >= l && arr[j] > temp) {
                arr[j + 1] = arr[j];
                j--;
            }
            arr[j + 1] = temp;
        }
    }
}

结尾

以上就是本篇的全部内容。随机选择算法是基础算法之一,略微补充一下。**

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值