问题
本问题来自leetcode。
给定两个大小为 m 和 n 的有序数组 A 和 B。
请你找出这两个有序数组的中位数,并且要求算法的时间复杂度为 O(log(m + n))。
你可以假设 nums1 和 nums2 不会同时为空。
示例:
A = [1, 3]
B = [2]
则中位数是 2.0
A = [1, 2]
B = [3, 4]
则中位数是 (2 + 3)/2 = 2.5
分析
中位数,又称中点数,中值。中数是按顺序排列的一组数据中居于中间位置的数,即在这组数据中,有一半的数据比他大,有一半的数据比他小。
也就是说,中位数把一个集合划分为长度相等的两个子集,一个子集的元素问题大于另一个子集。
首先,使用随机值i把A分成两部分:
left_A | right_A
A[0], A[1], ..., A[i-1] | A[i], A[i+1], ..., A[m-1]
i的取值范围为0~m,所以共有m+1种划分方法。len(left_A)=i, len(right_A)=m-i。注意,当i=0时,left_A为空,当i=m时,right_A为空。
使用同样的方法把B分成两部分:
left_B | right_B
B[0], B[1], ..., B[j-1] | B[j], B[j+1], ..., B[n-1]
把left_A和left_B放在一个集合里,right_A和right_B放在另一个集合里,分别取名为left_part, right_part:
left_part | right_part
A[0], A[1], ..., A[i-1] | A[i], A[i+1], ..., A[m-1]
B[0], B[1], ..., B[j-1] | B[j], B[j+1], ..., B[n-1]
如果能使得:
- len(left_part) == len(right_part)
- max(left_part) <= min(right_part)
则,中位数median = (max(left_part) + min(right_part))/2
。
要使上述两个条件成立,需要:
- i + j == m - i + n - j (或 m - i + n - j + 1)
如果 n >= m, 只需令 i = 0 ~ m, j = (m + n + 1)/2 - i(如果n<m,j可能为负) - B[j-1] <= A[i] 并且 A[i-1] <= B[j](暂时不考虑边界情况)
所以,这样即可:
在[0,m]中寻找i,使得:
B[j-1] <= A[i] && A[i-1] <= B[j], (其中j = (m + n + 1)/2 - i )
可以使用二分查找法得到i,具体步骤如下:
- 令imin = 0, imax = m,在[imin, imax]中开始寻找
- 令i = (imin + imax)/2, j = (m + n + 1)/2 - i
- 至此,len(left_part)==len(right_part),共3种情况:
a. B[j-1] <= A[i] and A[i-1] <= B[j],满足目标,停止
b. B[j-1] > A[i],需要增加i,在[i+1, imax]中寻找,则令imin = i+1, 重复步骤2
c. A[i-1] > B[j],需要在[imin, i-1]中寻找,则令imax = i-1,重复步骤2
得到i后,中位数为:
max(A[i-1], B[j-1]) (当 m + n 是奇数)
(max(A[i-1], B[j-1]) + min(A[i], B[j]))/2 (当 m + n 是偶数)
然后,考虑边界情况:i=0, i=m, j=0, j=n时,A[i-1],B[j-1],A[i],B[j]不存在。
为了使得max(left_part) <= min(right_part)
,
- 如果i,j不是边界值,需要检查两个条件:B[j-1] <= A[i] && A[i-1] <= B[j],
- 如果A[i-1],B[j-1],A[i],B[j]中某些值不存在,则不再需要检查不存在的值。假如 i=0,则A[i-1]不存在,就不必检查A[i-1] <= B[j]。
所以需要做的如下:
在[0,m]中查找i,使得:
(j == 0 or i == m or B[j-1] <= A[i]) &&
(i == 0 or j == n or A[i-1] <= B[j])
其中 j = (m + n + 1)/2 - i
在一个查找中共有3种情形:
1. (j == 0 or i == m or B[j-1] <= A[i]) && (i == 0 or j = n or A[i-1] <= B[j]),则符合,停止查找
2. i < m and B[j - 1] > A[i],则增加i
3. i > 0 and A[i - 1] > B[j],则减小i
注意,在2和3中,i<m时,j必大于0,i>0时,j必小于n:
m <= n, i < m ==> j = (m+n+1)/2 - i > (m+n+1)/2 - m >= (2*m+1)/2 - m >= 0
m <= n, i > 0 ==> j = (m+n+1)/2 - i < (m+n+1)/2 <= (2*n+1)/2 <= n
代码
c语言代码如下:
static inline int max(int a, int b)
{
if (a > b) {
return a;
}
return b;
}
static inline int min(int a, int b)
{
if (a > b) {
return b;
}
return a;
}
double findMedianSortedArrays(int* a, int m, int* b, int n) {
int *p_tmp;
int imin, imax, half_len;
int max_of_left, min_of_right;
int i, j;
/* make sure m<=n */
if (m > n) {
p_tmp = a;
a = b;
b = p_tmp;
m = m + n;
n = m - n;
m = m - n;
}
if (n == 0) {
printf("invalid paras\n");
return -1;
}
imin = 0;
imax = m;
half_len = (m + n + 1) / 2;
while (imin <= imax) {
i = (imin + imax) / 2;
j = half_len - i;
if (i < m && b[j-1] > a[i]) {
/* i is too small */
imin = i + 1;
} else if (i > 0 && a[i-1] > b[j]) {
/* i is too big */
imax = i - 1;
} else {
/* i is found */
if (i == 0) {
max_of_left = b[j-1];
} else if (j == 0) {
max_of_left = a[i-1];
} else {
max_of_left = max(a[i-1], b[j-1]);
}
if ((m + n) % 2) {
return max_of_left;
}
if (i == m) {
min_of_right = b[j];
} else if (j == n) {
min_of_right = a[i];
} else {
min_of_right = min(a[i], b[j]);
}
return ((max_of_left + min_of_right) / 2.0);
}
}
}
python代码:
def median(A, B):
m, n = len(A), len(B)
if m > n:
A, B, m, n = B, A, n, m
if n == 0:
raise ValueError
imin, imax, half_len = 0, m, (m + n + 1) / 2
while imin <= imax:
i = (imin + imax) / 2
j = half_len - i
if i < m and B[j-1] > A[i]:
# i is too small, must increase it
imin = i + 1
elif i > 0 and A[i-1] > B[j]:
# i is too big, must decrease it
imax = i - 1
else:
# i is perfect
if i == 0: max_of_left = B[j-1]
elif j == 0: max_of_left = A[i-1]
else: max_of_left = max(A[i-1], B[j-1])
if (m + n) % 2 == 1:
return max_of_left
if i == m: min_of_right = B[j]
elif j == n: min_of_right = A[i]
else: min_of_right = min(A[i], B[j])
return (max_of_left + min_of_right) / 2.0