题目:
Given a n x n matrix where each of the rows and columns are sorted in ascending order, find the kth smallest element in the matrix.
Note that it is the kth smallest element in the sorted order, not the kth distinct element.
Example:
matrix = [ [ 1, 5, 9], [10, 11, 13], [12, 13, 15] ], k = 8, return 13.
Note:
You may assume k is always valid, 1 ? k ? n2.
思路:
1、堆:可以建立一个堆,每次从堆顶取走一个元素之后,我们将它右边和下边(如果当前元素处于第一列)的两个元素加入堆中,这样可以保证当前新的最小元素一定存在于堆中。当这样取到第k个元素的时候,就一定是矩阵中第k大的元素。可以证明堆中的元素最大可能有n个,那么算法的空间复杂度就是O(n),时间复杂度是O(klogn)。
2、二分查找:由于矩阵按行和按列都是有序的,所以可以确定其最小元素一定在最上角,最小元素一定在最下角,那么第k个大的元素一定介于这两者之间。采用二分查找的思路就是:每次取一个最大值和最小值的平均值,在每行中统计比该平均值小的数的数量并累加。如果累加值小于等于(k - 1),则说明只有最多k-1个数比该平均值小,所以我们更新最小元素边界;否则更新最大元素边界。该算法的空间复杂度是O(1),时间复杂度是O(log(max_value - mn_value)nlogn),其中max_value = matrix[n - 1][n - 1], min_value = matrix[0][0]。可见该算法的时间复杂度是输入敏感的。
代码:
1、堆:
class Solution {
public:
int kthSmallest(vector<vector<int>>& matrix, int k) {
int n = matrix.size(), num = 0;
vector<Element> vectors;
vectors.push_back(Element(0,0, matrix[0][0]));
while(true) {
num++;
Element e = vectors[0];
if(num == k) {
return e.val;
}
pop_heap(vectors.begin(), vectors.end(), EleCom); // remove from the heap
vectors.pop_back();
if(e.col == 0 && e.row < n - 1) {
vectors.push_back(Element(e.row + 1, e.col, matrix[e.row + 1][e.col]));
push_heap(vectors.begin(), vectors.end(), EleCom);
}
if(e.col < n - 1) {
vectors.push_back(Element(e.row, e.col + 1, matrix[e.row][e.col + 1]));
push_heap(vectors.begin(), vectors.end(), EleCom);
}
}
return -1;
}
private:
struct Element {
int row, col, val;
Element(int r, int c, int v): row(r), col(c), val(v){};
};
struct ElementCompare {
bool operator()(const Element &e1, const Element &e2) {
return e1.val > e2.val;
}
} EleCom;
};
2、二分查找:
class Solution {
public:
int kthSmallest(vector<vector<int>>& matrix, int k) {
int n = matrix.size();
int left = matrix[0][0], right = matrix[n - 1][n - 1];
while(left <= right){
int mid = left + (right - left) / 2;
int num = 0; // num stores the number of elements that are larger than mid
for(int i = 0; i < n; ++i) {
auto it = upper_bound(matrix[i].begin(), matrix[i].end(), mid);
num += distance(matrix[i].begin(), it);
}
if(num <= k - 1) // mid is too small because less than k numbers are smaller than mid
left = mid + 1;
else
right = mid - 1;
}
return left;
}
};