There are two sorted arrays A and B of size m and n respectively. Find the median of the two sorted arrays. The overall run time complexity should be O(log (m+n)).
Analysis
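A required bound of O(log(m+n)) is a hint to use some form of binary search.
This problem can be reduced to finding the kth element of the two arrays combined, where k is a 0-based index equal to (A's length + B's length)/2. When the total length is even, the median is the average of the elements at indexes k-1 and k.
If either of the two arrays is empty, then the kth element is the non-empty array's kth element. If k == 0, the kth element is the smaller of the first elements of A and B.
For all other cases, we need to advance the pointers by a fraction of the remaining array length at each step, instead of one element at a time, to get logarithmic time.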
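The reduction described above (median as the element at a 0-based index k of the merged order) can be sketched with a plain linear merge. This is only a minimal O(m+n) baseline for cross-checking the faster solutions, not the logarithmic algorithm itself; the class and method names (`MedianByMerge`, `kthByMerge`, `median`) are hypothetical, and the sketch assumes at least one of the arrays is non-empty and that 0 <= k < m+n.

```java
final class MedianByMerge {
    // Returns the element at 0-based index k of the merged order of two sorted arrays.
    // Assumes 0 <= k < a.length + b.length.
    static int kthByMerge(int[] a, int[] b, int k) {
        int i = 0, j = 0;
        int current = 0;
        for (int step = 0; step <= k; step++) {
            // Take from a when b is exhausted or when a's head is not larger than b's head.
            if (j >= b.length || (i < a.length && a[i] <= b[j])) {
                current = a[i++];
            } else {
                current = b[j++];
            }
        }
        return current;
    }

    // Median of the two arrays combined; assumes the combined length is at least 1.
    static double median(int[] a, int[] b) {
        int total = a.length + b.length;
        if (total % 2 != 0) {
            return kthByMerge(a, b, total / 2);
        }
        // Cast before adding so the sum of two ints cannot overflow.
        return ((double) kthByMerge(a, b, total / 2 - 1) + kthByMerge(a, b, total / 2)) / 2;
    }

    public static void main(String[] args) {
        System.out.println(median(new int[] {1, 3}, new int[] {2}));
        System.out.println(median(new int[] {1, 2}, new int[] {3, 4}));
    }
}
```

Because it is simple enough to verify by eye, a baseline like this is handy for comparing against the recursive solutions on small random inputs.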
Java Solution
public static double findMedianSortedArrays(int[] A, int[] B) {
    int m = A.length;
    int n = B.length;
    if ((m + n) % 2 != 0) { // odd total: the median is the element at index (m + n) / 2
        return (double) findKth(A, B, (m + n) / 2, 0, m - 1, 0, n - 1);
    } else { // even total: average the two middle elements
        // cast before adding so the sum of two ints cannot overflow
        return ((double) findKth(A, B, (m + n) / 2, 0, m - 1, 0, n - 1)
                + findKth(A, B, (m + n) / 2 - 1, 0, m - 1, 0, n - 1)) * 0.5;
    }
}

// Returns the element at 0-based index k in the merged order of
// A[aStart..aEnd] and B[bStart..bEnd] (both ranges inclusive).
public static int findKth(int[] A, int[] B, int k,
        int aStart, int aEnd, int bStart, int bEnd) {
    int aLen = aEnd - aStart + 1;
    int bLen = bEnd - bStart + 1;

    // Handle special cases
    if (aLen == 0)
        return B[bStart + k];
    if (bLen == 0)
        return A[aStart + k];
    if (k == 0)
        return A[aStart] < B[bStart] ? A[aStart] : B[bStart];

    // Same formula as aLen * k / (aLen + bLen), computed in long to avoid int overflow
    int aMid = (int) ((long) aLen * k / (aLen + bLen)); // a's middle count
    int bMid = k - aMid - 1; // b's middle count

    // make aMid and bMid to be array index
    aMid = aMid + aStart;
    bMid = bMid + bStart;

    if (A[aMid] > B[bMid]) {
        // B[bStart..bMid] all come before the target: drop them
        k = k - (bMid - bStart + 1);
        aEnd = aMid;
        bStart = bMid + 1;
    } else {
        // A[aStart..aMid] all come before the target: drop them
        k = k - (aMid - aStart + 1);
        bEnd = bMid;
        aStart = aMid + 1;
    }

    return findKth(A, B, k, aStart, aEnd, bStart, bEnd);
}
The tricky part of the code is this formula:
int aMid = aLen * k / (aLen + bLen); // a's middle count
int bMid = k - aMid - 1; // b's middle count
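Why do we require aMid + bMid == k-1? It helps to first understand the core idea of the algorithm. After handling the boundary cases, the program picks a position aMid in A and a position bMid in B such that aMid + bMid == k-1, so that comparing A[aMid+aStart] with B[bMid+bStart] lets it shrink the problem. Note that k is an index in this program: when k=3, we are actually looking for the 4th smallest number in the two arrays. Given aMid, there are aMid+1 elements up to and including A[aMid+aStart]; likewise, given bMid, there are bMid+1 elements up to and including B[bMid+bStart]. We set aMid+1+bMid+1 = k+1. The right-hand side is k+1 because we are looking for the (k+1)-th smallest number: we want A[aMid+aStart] and B[bMid+bStart] to compete for the largest value among A[aStart...aMid+aStart] and B[bStart...bMid+bStart], and these two prefixes together contain exactly k+1 elements.
If A[aMid+aStart] == B[bMid+bStart], then the k-th and the (k+1)-th smallest numbers of A and B are both equal to A[aMid+aStart]. This is because A and B are sorted: every element before A[aMid+aStart] is no larger than it, every element before B[bMid+bStart] is no larger than it, and every element after either of them is no smaller.
If A[aMid+aStart] > B[bMid+bStart], then the (k+1)-th smallest number must lie in A[aStart...aMid+aStart] or in B[bMid+bStart+1...bEnd]. Why? Suppose it were in A[aMid+aStart+1...aEnd], at index x. Then at least k+1 numbers come before A[x] (the aMid+1 elements A[aStart...aMid+aStart] and the bMid+1 elements B[bStart...bMid+bStart]), so A[x] ranks (k+2)-th or later and cannot be the (k+1)-th smallest. Suppose instead it were in B[bStart...bMid+bStart], at index y. Then at most k-1 numbers are smaller than B[y] (at most bMid in B and at most aMid in A), so B[y] ranks no later than k-th and cannot be the (k+1)-th smallest either. We then update k = k-(bMid+1), bStart = bMid+bStart+1, and aEnd = aMid+aStart.
If A[aMid+aStart] < B[bMid+bStart], the analysis is symmetric. (The code also handles the equal case in this branch.)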
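To make the requirement on aMid and bMid concrete, here is a minimal sketch that recomputes the split for small inputs and checks that aMid + bMid == k-1 and that both offsets stay inside their arrays. The class `SplitCheck` and its methods are hypothetical helpers written for this check, not part of the solution; the formula itself is the one from findKth, and the tested range (both arrays non-empty, 1 <= k <= aLen+bLen-1) is an assumption about which inputs reach that formula. For example, aLen = 3, bLen = 2, k = 4 gives aMid = 2 and bMid = 1.

```java
final class SplitCheck {
    // Returns {aMid, bMid} as offsets relative to aStart and bStart, using the
    // proportional formula from findKth (computed in long to avoid int overflow).
    static int[] split(int aLen, int bLen, int k) {
        int aMid = (int) ((long) aLen * k / (aLen + bLen));
        int bMid = k - aMid - 1;
        return new int[] {aMid, bMid};
    }

    // True when the split satisfies aMid + bMid == k - 1 and both offsets are valid indexes.
    static boolean isValid(int aLen, int bLen, int k) {
        int[] s = split(aLen, bLen, k);
        return s[0] + s[1] == k - 1
                && s[0] >= 0 && s[0] < aLen
                && s[1] >= 0 && s[1] < bLen;
    }

    public static void main(String[] args) {
        // Check every case with both ranges non-empty and 1 <= k <= aLen + bLen - 1.
        for (int aLen = 1; aLen <= 20; aLen++) {
            for (int bLen = 1; bLen <= 20; bLen++) {
                for (int k = 1; k <= aLen + bLen - 1; k++) {
                    if (!isValid(aLen, bLen, k)) {
                        throw new IllegalStateException(
                                "invalid split for aLen=" + aLen + ", bLen=" + bLen + ", k=" + k);
                    }
                }
            }
        }
        System.out.println("all splits valid");
    }
}
```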
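Finally, aMid does not have to be chosen by the formula aLen*k/(aLen+bLen) used in the program. Any choice works as long as aMid+bMid = k-1, 0 <= aMid < aLen, and 0 <= bMid < bLen. For example, the lines that set aMid and bMid can be replaced with the following random choice, although the logarithmic running time is then no longer guaranteed:
Random r = new Random(); // requires import java.util.Random;
// aMid may be any value in [max(0, k - bLen), min(aLen - 1, k - 1)];
// this range keeps both aMid and bMid inside their arrays.
int lo = Math.max(0, k - bLen);
int hi = Math.min(aLen - 1, k - 1);
int aMid = lo + r.nextInt(hi - lo + 1);
int bMid = k - aMid - 1;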
Update 10.19: the best solution!
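We can approach the problem from k. If each step could discard one element that is guaranteed to come before the k-th smallest element, we would need k-1 steps. But what if each step discarded more than one? Borrowing the idea of binary search, we can reason as follows: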
Assume that A and B each contain more than k/2 elements, and compare the k/2-th smallest element in A (i.e. A[k/2-1]) with the k/2-th smallest element in B (i.e. B[k/2-1]).
(k can be odd or even; we assume k is even here for simplicity. The argument also holds when k is odd.)
There are three possible results:
A[k/2-1] = B[k/2-1]
A[k/2-1] > B[k/2-1]
A[k/2-1] < B[k/2-1]
If A[k/2-1] < B[k/2-1], then all the elements from A[0] to A[k/2-1] (i.e. the k/2 smallest elements in A) are among the k smallest elements in the union of A and B. In other words, A[k/2-1] can never be larger than the k-th smallest element in the union of A and B.
Why?
We can use a proof by contradiction. Suppose A[k/2-1] is larger than the k-th smallest element in the union of A and B, so that it is at least the (k+1)-th smallest; then at least k elements come before it. Since it is smaller than B[k/2-1], B[k/2-1] is at least the (k+2)-th smallest. But at most (k/2-1) elements of A are smaller than A[k/2-1], and at most (k/2-1) elements of B are smaller than A[k/2-1], so at most k/2+k/2-2 elements can come before it. Whether k is odd or even, this is smaller than k, which is a contradiction. So A[k/2-1] can never be larger than the k-th smallest element in the union of A and B if A[k/2-1] < B[k/2-1].
Given this conclusion, we can safely drop the first k/2 elements of A, which are guaranteed to be no larger than the k-th smallest element in the union of A and B, and then look for the (k - k/2)-th smallest element of what remains. The same holds for the A[k/2-1] > B[k/2-1] case, in which we drop the elements from B instead.
When A[k/2-1] = B[k/2-1], we have found the k-th smallest element: it is this common value, which we call v. There are (k/2-1) elements before v in each of A and B, so the two copies of v are the (k-1)-th and k-th smallest elements. So we can write a recursive function: when A[k/2-1] < B[k/2-1] we drop the elements in A, and when A[k/2-1] > B[k/2-1] we drop the elements in B.
We should also consider the edge cases, that is, when should we stop?
1. When A or B is empty, we return B[k-1] (or A[k-1], respectively);
2. When k is 1 (and A and B are both non-empty), we return the smaller of A[0] and B[0];
3. When A[k/2-1] = B[k/2-1], we return either of them.
In the code, we first check whether m is larger than n and swap the arguments if so, which guarantees that the first array is always the shorter one and keeps the code simple.
#include <algorithm>

// Returns the k-th smallest element (k is 1-based) in the union of
// a[0..m-1] and b[0..n-1].
double findKth(int a[], int m, int b[], int n, int k)
{
    // always assume that m is equal to or smaller than n
    if (m > n)
        return findKth(b, n, a, m, k);
    if (m == 0)
        return b[k - 1];
    if (k == 1)
        return std::min(a[0], b[0]);
    // divide k into two parts
    int pa = std::min(k / 2, m), pb = k - pa;
    if (a[pa - 1] < b[pb - 1])
        return findKth(a + pa, m - pa, b, n, k - pa); // drop the first pa elements of a
    else if (a[pa - 1] > b[pb - 1])
        return findKth(a, m, b + pb, n - pb, k - pb); // drop the first pb elements of b
    else
        return a[pa - 1];
}

class Solution
{
public:
    double findMedianSortedArrays(int A[], int m, int B[], int n)
    {
        int total = m + n;
        if (total & 0x1) // odd total: single middle element
            return findKth(A, m, B, n, total / 2 + 1);
        else // even total: average of the two middle elements
            return (findKth(A, m, B, n, total / 2)
                    + findKth(A, m, B, n, total / 2 + 1)) / 2;
    }
};
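For readers following the Java solution earlier in the post, the same k/2-elimination idea can be sketched in Java. Java has no pointer arithmetic, so this sketch assumes start offsets (`aFrom`, `bFrom`) in place of `a + pa` and `b + pb`; the class name `KthByHalving` and these parameter names are hypothetical. As in the C++ code, k is 1-based, and the sketch assumes 1 <= k <= the number of remaining elements.

```java
final class KthByHalving {
    // Returns the k-th smallest (1-based) element among a[aFrom..] and b[bFrom..].
    static int findKth(int[] a, int aFrom, int[] b, int bFrom, int k) {
        int m = a.length - aFrom; // remaining elements in a
        int n = b.length - bFrom; // remaining elements in b
        if (m > n) {
            return findKth(b, bFrom, a, aFrom, k); // keep the shorter range first
        }
        if (m == 0) {
            return b[bFrom + k - 1];
        }
        if (k == 1) {
            return Math.min(a[aFrom], b[bFrom]);
        }
        // Divide k into two parts; pa is capped by the shorter range's length.
        int pa = Math.min(k / 2, m);
        int pb = k - pa;
        int av = a[aFrom + pa - 1];
        int bv = b[bFrom + pb - 1];
        if (av < bv) {
            return findKth(a, aFrom + pa, b, bFrom, k - pa); // drop pa elements of a
        }
        if (av > bv) {
            return findKth(a, aFrom, b, bFrom + pb, k - pb); // drop pb elements of b
        }
        return av; // equal: this value is the k-th smallest
    }

    // Median of the two arrays combined; assumes the combined length is at least 1.
    static double median(int[] a, int[] b) {
        int total = a.length + b.length;
        if ((total & 1) == 1) {
            return findKth(a, 0, b, 0, total / 2 + 1);
        }
        // Cast before adding so the sum of two ints cannot overflow.
        return ((double) findKth(a, 0, b, 0, total / 2)
                + findKth(a, 0, b, 0, total / 2 + 1)) / 2;
    }

    public static void main(String[] args) {
        System.out.println(median(new int[] {1, 3}, new int[] {2}));
        System.out.println(median(new int[] {1, 2}, new int[] {3, 4}));
    }
}
```

Passing offsets avoids copying subarrays on each call, which would otherwise cost linear time per step and defeat the purpose of the halving.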