主要是针对LeetCode 912. 排序数组 这道题的快排方法的总结,包括普通的快排,以及两个优化的方向,以及可以正确运行的c++代码(只有优化2的代码可以完美通过LeetCode 912,其他两个可以完成快速排序,但无法通过LeetCode 912的线上评测)。
单路快排
规定原数组最右(或最左)的元素作为pivot,然后每次递归都确定当前pivot的位置。
具体快排的思路网上有很多,这里主要讲代码实现中的细节。
class Solution {
int partition(vector<int>& nums, int left, int right) {
/*
参数:要分区的数组,要分区范围的左边界和右边界
返回值:分区好之后的中间位置的下标,也就是快排中这次定好为位置的元素下标
*/
int pivot = nums[right]; // 默认用nums[right]作为pivot
// 只需要从一边开始遍历,找比pivot小的(或大的)的数字,然后交换位置
// j遍历去找比pivot小的,i是要放的位置
int i = left;
// 注意这里是j <= right - 1,因为right位置的元素是pivot
for (int j = left; j <= right - 1; j++) {
if (nums[j] <= pivot) {
swap(nums[j], nums[i++]);
}
}
swap(nums[i], nums[right]); // 交换i和pivot的位置,使得pivot放到最终正确的位置
return i; // 返回找好位置的元素的下标
}
void quiksort(vector<int>& nums, int left, int right) {
/*
用递归的方法,不断归位pivot,不断用分治法排序。
参数:要排序的数组,要排序的区间的左右边界
返回:无返回值
*/
if (left < right) {
int pos = partition(nums, left, right);
quiksort(nums, left, pos - 1);
quiksort(nums, pos + 1, right);
}
}
public:
vector<int> sortArray(vector<int>& nums) {
quiksort(nums, 0, (int)nums.size() - 1);
return nums;
}
};
一般情况下平均时间复杂度是O(nlogn).
双路快排
class Solution {
int partition(vector<int>& nums, int left, int right) {
/*
参数:要分区的数组,要分区范围的左边界和右边界
返回值:分区好之后的中间位置的下标,也就是快排中这次定好为位置的元素下标
*/
int pivot = nums[right]; // 默认用nums[right]作为pivot
// 只需要从一边开始遍历,找比pivot小的(或大的)的数字,然后交换位置
// j遍历去找比pivot小的,i是要放的位置
int i = left, j = right - 1;
while (true) {
while (j >= left && nums[j] > pivot) {
j--;
}
while (i < right && nums[i] < pivot) {
i++;
}
if (i > j) break;
swap(nums[i], nums[j]);
i++;
j--;
}
swap(nums[right], nums[i]);
return i; // 返回找好位置的元素的下标
}
void quiksort(vector<int>& nums, int left, int right) {
/*
用递归的方法,不断归位pivot,不断用分治法排序。
参数:要排序的数组,要排序的区间的左右边界
返回:无返回值
*/
if (left < right) {
int pos = partition(nums, left, right);
quiksort(nums, left, pos - 1);
quiksort(nums, pos + 1, right);
}
}
public:
vector<int> sortArray(vector<int>& nums) {
quiksort(nums, 0, (int)nums.size() - 1);
return nums;
}
};
优化1:随机选择pivot
上面的普通快排的方法,在面对数组一开始就接近有序的情况。如果数组一开始就是有序的,那么时间复杂度会退化成O(n^2). 这时我们可以用随机选择pivot来尽可能的达到期望时间复杂度O(nlogn)。
class Solution {
int partition(vector<int>& nums, int left, int right) {
/*
参数:要分区的数组,要分区范围的左边界和右边界
返回值:分区好之后的中间位置的下标,也就是快排中这次定好为位置的元素下标
*/
int pivot = nums[right]; // 默认用nums[right]作为pivot
// 只需要从一边开始遍历,找比pivot小的(或大的)的数字,然后交换位置
// j遍历去找比pivot小的,i是要放的位置
int i = left, j = right - 1;
while (true) {
while (j >= left && nums[j] > pivot) {
j--;
}
while (i < right && nums[i] < pivot) {
i++;
}
if (i > j) break;
swap(nums[i], nums[j]);
i++;
j--;
}
swap(nums[right], nums[i]);
return i; // 返回找好位置的元素的下标
}
int randomized_partition(vector<int>& nums, int left, int right) {
int randpos = rand() % (right - left + 1) + left;
// 随机找一个放到right的位置,这样取随机数和partition的方法就完全分离开了。
swap(nums[right], nums[randpos]);
return partition(nums, left, right);
}
void randomized_quiksort(vector<int>& nums, int left, int right) {
/*
用递归的方法,不断归位pivot,不断用分治法排序。
参数:要排序的数组,要排序的区间的左右边界
返回:无返回值
*/
if (left < right) {
int pos = randomized_partition(nums, left, right);
randomized_quiksort(nums, left, pos - 1);
randomized_quiksort(nums, pos + 1, right);
}
}
public:
vector<int> sortArray(vector<int>& nums) {
srand((unsigned)time(NULL)); // 随机数发生器的初始化函数
// 用time函数获取系统时间,来初始化。这样rand()产生的随机数就不会重复,因为随机种子是随时间变化的。
randomized_quiksort(nums, 0, (int)nums.size() - 1);
return nums;
}
};
参考了力扣这道题的官方题解,将随机获取pivot抽象出来写在partition外。
优化2:三路快排
当遇到数组中有很多重复元素时,也就是LeetCode上直接提交上面代码,显示超时的样例,有大量的重复的元素。这时快排时间复杂度也会退化成O(n^2),可以将大量相同元素看做是有序的特殊情况,且计算加上随机取pivot,也就进行大量的swap操作。
我们可以通过三路快排来优化时间复杂度。三路快排就是每次分区多分成三个部分,>pivot, == pivot 和 < pivot。这样在元素==pivot时我们就不进行swap操作。
具体实现上要注意:
- 因为有三路,每次partition需要返回两个变量,就是lastSmall(最后一个小于pivot的元素的下标)和firstBig(第一个大于pivot的元素的下标),是下一次进行partition的两个边界值,所以partition的返回值类型为vector,里面有两个元素,lastSmall和firstBig。
- 要在得到vector的返回值后,用下标访问时,记得先确认一下是否为空,是否有两个元素。
class Solution {
vector<int> partition(vector<int>& nums, int left, int right) {
// 输出放好位置的元素的下标
int pivot = nums[right];
int lastSmall = left;
int firstBig = right - 1;
int i = left;
while (i <= firstBig) {
if (nums[i] < pivot) {
swap(nums[i], nums[lastSmall]);
i++;
lastSmall++;
} else if (nums[i] == pivot) { // 相等时不做swap操作,直接往后走
i++;
} else { // 也就是nums[i] > pivot,遇到大的就放后面
swap(nums[i], nums[firstBig]);
firstBig--;
// 注意这里i不动,因为无法确定原来的nums[firstBig]是否小于pivot,所以换完了i不++,要在和新的firstBig进行比较,直到换到比pivot小的元素
}
}
swap(nums[i], nums[right]);
return {lastSmall, firstBig};
}
vector<int> randomPartition(vector<int>& nums, int left, int right) {
int randpos = rand() % (right - left + 1) + left;
swap(nums[randpos], nums[right]);
return partition(nums, left, right);
}
void quickSort(vector<int>& nums, int left, int right) {
if (left < right) {
vector<int> pos = randomPartition(nums, left, right);
if (!pos.empty() && pos.size() == 2) { // 这里保险考虑一下条件
quickSort(nums, pos[1] + 1, right);
quickSort(nums, left, pos[0] - 1);
} else {
// 异常处理代码
}
}
}
public:
vector<int> sortArray(vector<int>& nums) {
srand((unsigned)time(NULL));
quickSort(nums, 0, (int)nums.size() - 1);
return nums;
}
};
LeetCode 912线上评测的通过情况
无法通过:
- 单边+随机pivot(大量相同元素的序列超时),
- 双边快排(很长的有序数组超时),
- 三路快排(很长的有序数组超时)
通过:
- 双边+随机pivot可以过
- 三路快排+随机pivot也可以。
官方题解中快排的写法是单路+随机pivot,会在遇到大量相同元素时会超时。
欢迎交流讨论!