前言
有这么一类问题,是寻找区间内的第k大或者说第k小(实际上这两个是一个问题),解决算法有排序,权值线段树、主席树、树套树等,但是要么就是复杂度不满意或者不好写。根据快速排序的分治思路,我们还可以在O(n)O(n)O(n)的时间内找到答案,但并不能保证有序,并且STL十分贴心地把这个算法封装到 nth_element 中,本文就是介绍这个接口地使用。
原理
前言中提到,这个接口的期望复杂度是 O(n)O(n)O(n) ,我们可以假设每次选取的分段标志为最优大致证明一下:
T(n)=T(n/2)+n=T(n/4)+n+n/2=T(n/(2i))+n+n/2...+n/(2(i−1))=n+n/2+...+1=2∗n−1=O(n),其中(i=log(n)+1).
T(n) = T(n/2) + n \\
= T(n/4) + n + n/2 \\
= T(n/(2^i)) + n + n/2
... + n/(2^(i-1)) \\
= n + n/2 + ... + 1 \\
= 2*n - 1
= O(n),其中 (i = log(n) + 1) .
T(n)=T(n/2)+n=T(n/4)+n+n/2=T(n/(2i))+n+n/2...+n/(2(i−1))=n+n/2+...+1=2∗n−1=O(n),其中(i=log(n)+1).
至于划分部分和快排一致,区别在于分治部分,具体实现略。
接口
nth_element 默认的比较子和一般的容器都是 less ,这在堆结构中会形成大根堆;在排序接口会形成从小到大排序;在这里则会形成第k小的功能,如果把这个比较子换成greater或者自定义的一个比较子接口则会形成一个第k大。
对于每个参数的含义:
- 参数1:表示第一个迭代器/指针的位置
- 参数2:表示要确定的位置,也就是第k小/大的数
- 参数3:表示最后一个迭代器/指针的位置
- 参数4:缺省less,可以传入一个比较子
示例
#include <bits/stdc++.h>
using namespace std;
int main() {
int arr[10] = {2,5,6,9,2,1,3,4,11,12};
nth_element(arr, arr+2, arr+10); //注意第1小对于下标0,以此类推
cout << "nth_element(arr, arr+3, arr+10) : ";
for (int i = 0; i < 10; i++) cout << arr[i] << ' ';
cout << "\n第3小的数是 : " << arr[2] << endl;
/**
* 输出:
* nth_element(arr, arr+3, arr+10) : 1 2 2 3 4 5 9 6 11 12
* 第3小的数是 : 2
*/
int arr2[10] = {2,5,6,9,2,1,3,4,11,12};
nth_element(arr2, arr2+7, arr2+10); //求第3大转化成求第10-3+1小问题
cout << "nth_element(arr2, arr2+7, arr2+10) : ";
for (int i = 0; i < 10; i++) cout << arr2[i] << ' ';
cout << "\n第3大的数是 : " << arr2[7] << endl;
/**
* 输出:
* nth_element(arr2, arr2+7, arr2+10) : 5 2 4 3 2 1 6 9 11 12
* 第3大的数是 : 9
*/
int arr3[10] = {2,5,6,9,2,1,3,4,11,12};
nth_element(arr3, arr3+2, arr3+10, greater<int>() ); //求第3大转化成求第10-3+1小问题
cout << "nth_element(arr3, arr3+3, arr3+10, greater<int>()) : ";
for (int i = 0; i < 10; i++) cout << arr3[i] << ' ';
cout << "\n第3大的数是 : " << arr3[2] << endl;
/**
* 输出:
* nth_element(arr3, arr3+3, arr3+10, greater<int>()) : 11 12 9 6 5 1 3 4 2 2
* 第3大的数是 : 9
*/
int arr4[10] = {2,5,6,9,2,1,3,4,11,12};
nth_element(arr4, arr4+2, arr4+10, [](int a, int b){ return a > b;}); //求第3大转化成求第10-3+1小问题
cout << "nth_element(arr4, arr4+2, arr4+10, [](int a, int b){ return a > b;}) : ";
for (int i = 0; i < 10; i++) cout << arr4[i] << ' ';
cout << "\n第3大的数是 : " << arr4[2] << endl;
/**
* 输出:
* nth_element(arr4, arr4+2, arr4+10, [](int a, int b){ return a > b;}) : 11 12 9 6 5 1 3 4 2 2
* 第3大的数是 : 9
*/
return 0;
}