题目
给定一个数组,找出数组中第K大。
输入: [3,2,1,5,6,4] 和 k = 2
输出: 5
原题链接:https://leetcode.cn/problems/kth-largest-element-in-an-array/
思路
求 topK 是很经典的题目,经典的思路就是快排和堆排了。
思路1
快排思路。(非完整版快排)
我们从大到小进行排列,在 partition 的过程中,如果分界值刚好在 k 的位置,则提前返回该分界值作为结果即可,如果分界值的坐标落在 k 往后的位置,则说明 topk 还在左半部分,此时右半部分不需要处理,只需要对左半部分进一步 partition 即可;如果分界值落在 k 往前的位置,则说明 topk 在右半部分,此时左半部分不需要处理,直接在右半部分进一步 partition 找 topk 即可。
快排的过程中,为了避免最坏的情况出现(最坏的情况为每次选取到最大值或最小值,导致最终复杂度为 O(n^2) )。我们每次分区,选取分界值时采用 random 的方式选取。
- 复杂度分析
- 时间复杂度 O(n)。在 random 的方式下,期望的时间复杂度为 O(n)。这里没有推导证明,推导可以参考算法导论。
- 空间复杂度 O(logn)。递归层数,使用栈空间的空间代价为O(logn)
思路2
堆排思路。(非完整版堆排)
既然是找 topk,那就构建一个大小为 k 的小顶堆,堆顶就是 topk 了。构建完小顶堆后,遍历剩余的数字,如果大于小顶堆则将其入堆,直到遍历完所有的数字。最终返回堆顶。
小顶堆偷懒的话就用优先队列,否则就自己写。
由于这里我们只需要维护一个小堆顶即可,不需要完全按照堆排的思路走。
先取 k 个数初始化成一个小堆顶,然后遍历剩下的数字,如果大于堆顶,则替换掉堆顶,并从堆顶进行重新调整,让小数上浮,大数下沉。
- 复杂度分析
- 时间复杂度 O(nlogk)。构建 topk 小顶堆是 O(klogk),遍历剩下的数字,入堆操作是 O(logk),故整体为 O(nlogk)。
- 空间复杂度 O(k)。topk 的小顶堆
代码
代码1
class Solution {
public:
int findKthLargest(vector<int>& nums, int k) {
return partition(nums, 0, nums.size() - 1, k);
}
int partition(vector<int>& nums, int start, int end, int k) {
int random_index = start + rand() % (end - start + 1);
swap(nums[start], nums[random_index]);
int cur = nums[start];
int i = start;
int j = end;
while (i < j) {
while (i < j && nums[j] < cur) {
j--;
}
if (i < j) {
nums[i] = nums[j];
i++;
}
while (i < j && nums[i] > cur) {
i++;
}
if (i < j) {
nums[j] = nums[i];
j--;
}
}
nums[i] = cur;
if (i + 1 == k) {
return nums[i];
}
else if (i + 1 > k) {
return partition(nums, start, i - 1, k);
}
else {
return partition(nums, i + 1, end, k);
}
}
};
代码2 优先队列版
class Solution {
public:
struct cmp {
bool operator()(int a, int b) {
return a > b;
}
};
int findKthLargest(vector<int>& nums, int k) {
priority_queue<int, vector<int>, cmp> q;
for (int i = 0; i < k; i++) {
q.push(nums[i]);
}
for (int i = k; i < nums.size(); i++) {
if (nums[i] > q.top()) {
q.push(nums[i]);
q.pop();
}
}
return q.top();
}
};
代码2 小顶堆手写版
class Solution {
public:
int findKthLargest(vector<int>& nums, int k) {
vector<int> minHeap(nums.begin(), nums.begin() + k);
heapInit(minHeap);
// 遍历剩下的数组
for (int i = k; i < nums.size(); i++) {
if (nums[i] > minHeap[0]) {
minHeap[0] = nums[i];
heapAdjust(minHeap, k, 0);
}
}
return minHeap[0];
}
void heapInit(vector<int>&minHeap) {
for (int i = minHeap.size() / 2 - 1; i >= 0; i--) {
heapAdjust(minHeap, minHeap.size(), i);
}
}
void heapAdjust(vector<int>& minHeap, int len, int index) {
int left = index * 2 + 1;
int right = left + 1;
// 判断是否有孩子
if (left < len) {
int min_index = index;
int min_value = minHeap[index];
// 比较左孩子
if (minHeap[left] < min_value) {
min_index = left;
min_value = minHeap[left];
}
// 判断是否有右孩子,再进行比较
if (right < len && minHeap[right] < min_value) {
min_index = right;
min_value = minHeap[right];
}
if (min_index != index) {
swap(minHeap[min_index], minHeap[index]);
heapAdjust(minHeap, len, min_index);
}
}
}
};