Question
There are two sorted arrays nums1 and nums2 of size m and n respectively. Find the median of the two sorted arrays. The overall run time complexity should be O(log (m+n)).
Java Code
public double findMedianSortedArrays(int[] nums1, int[] nums2) {
return (nums1.length > nums2.length) ? binarySearch(nums1, nums2) :
binarySearch(nums2, nums1);
}
//要求数组nums1为较长的数组
public double binarySearch(int[] nums1, int[] nums2) {
//设置两个数组的各个初始指针
int min1 = 0;
int max1 = nums1.length - 1;
int med1 = max1/2;//如果数组长度为偶数,则中位数指针总是指向下标较小的那个元素
int min2 = 0;
int max2 = nums2.length - 1;
int med2 = max2/2;
//每一轮二分查找后两个数组均需要裁剪掉的长度,为什么总是nums2数组的一半长度?
int remove = (max2 - min2)/2;
//case 0: 数组nums2为空数组,直接返回数组nums1的中位数
if(max2 == -1) {
if((nums1.length & 1) == 0)
return (nums1[nums1.length/2] + nums1[nums1.length/2 - 1])/2.0;
else
return nums1[nums1.length/2];
}
//case 1: 数组nums2超过2个元素
while(max2 - min2 > 1) {
//如果当前的nums1和nums2的中位数正好相等,且在融合后的数组中也是中位数
if(nums1[med1] == nums2[med2] && med1 + med2 + 1 == (nums1.length + nums1.length + 1)/2)
return (nums1[med1] + nums2[med2])/2.0;
//把当前nums1数组的低位区间和当前nums2数组的高位区间裁剪掉
else if(nums1[med1] < nums2[med2]) {
min1 += remove;
max2 -= remove;
//把当前nums1数组的高位区间和当前nums2数组的低位区间裁剪掉
}else {
max1 -= remove;
min2 += remove;
}
//更新两个数组的中位数指针以及下一轮需要裁剪掉的长度
med1 = (min1 + max1)/2;
med2 = (min2 + max2)/2;
remove = (max2 - min2)/2;
}
//case 2: 数组nums2剩2个元素
if(max2 - min2 == 1) {
//case 2.1: nums1剩2个元素
if(max1 - min1 == 1) {
return (Math.max(nums1[min1], nums2[min2]) + Math.min(nums1[max1], nums2[max2]))/2.0;
//case 2.2: 数组nums1至少剩余3个元素
}else {
if(((max1 - min1) & 1) == 0) {//case 2.2.1: 数组nums1剩余奇数长度
if(nums1[med1] <= nums2[med2]) {
min1++;
max2--;//下一步转到case 3.3.2
}else {
if(nums2[med2 + 1] <= nums1[med1])
return Math.max(nums2[med2 + 1], nums1[med1 - 1]);
else
return nums1[med1];
}
}else {//case 2.2.2: 数组nums1剩余偶数长度
if(nums1[med1] <= nums2[med2]) {
if(nums2[med2 + 1] < nums1[med1 + 1])//特殊情况
return (nums2[med2] + nums2[med2 + 1])/2.0;
else {
min1++;
med1++;//pay attention to this
max2--;//下一步转到case 3.3.1
}
}else {
if(nums2[med2 + 1] <= nums1[med1])
return (nums1[med1] + Math.max(nums2[med2 + 1], nums1[med1 - 1]))/2.0;
else
return (nums1[med1] + Math.min(nums2[med2 + 1], nums1[med1 + 1]))/2.0;
}
}
}
}
//case 3: 数组nums2剩1个元素
if (max2 == min2) {
//case 3.1: 数组nums1剩1个元素
if(max1 == min1);
//case 3.2: 数组nums1剩2个元素
else if(max1 - min1 == 1) {
return Math.min(Math.max(nums1[min1], nums2[min2]), nums1[max1]);
//case 3.3: 数组nums1至少剩3个元素
}else {
if((max1 - min1 & 1) == 0) {//case 3.3.1: 数组nums1剩余奇数长度
if(nums1[med1] < nums2[med2]) {
if(nums2[med2] > nums1[med1 + 1])
return (nums1[med1] + nums1[med1 + 1])/2.0;
}else if(nums1[med1] > nums2[med2]) {
if(nums2[med2] < nums1[med1 - 1])
return (nums1[med1] + nums1[med1 - 1])/2.0;
}else
return nums1[med1];
}else {//case 3.3.2: 数组nums1剩余偶数长度
if(nums1[med1] < nums2[med2])
return Math.min(nums2[med2], nums1[med1 + 1]);
else
return nums1[med1];
}
}
}
return (nums1[med1] + nums2[med2])/2.0;
}
说明
本文给的算法貌似达到了O(log(min(m, n))),优于本题要求的O(log(m+n))?
本题在Leetcode上属于难度系数最高的Hard级别,对我来说的确是非常的有挑战性!之前和教研室小伙伴做了一个晚上也没搞定,不过经过仔细分析也找到了一些初步的解决思路,但是还有一些细节问题没有很好地解决。今天花了很多时间,不断地踩坑。。填坑。。踩坑。。填坑,最终总算AC了,不过runtime只排到40%左右,怎么这么低啊。。。
整个代码的核心思想主要体现在while循环中,由于求中位数必须区分数组为奇数和偶数时的不同,所以需要多层if-else来分情况讨论,具体的算法思想和实现细节下次再具体写吧(其实注释里说得大概差不多了)。另外昨天看到有一个非常棒的算法来解决本题,这里贴出博客链接,供大家参考学习。