题目来自 LeetCode 第 4 题：There are two sorted arrays A and B of size m and n respectively. Find the median of the two sorted arrays. The overall run time complexity should be O(log(m+n)).
思路很直接：在一个数组中使用二分法，查找元素在另一个数组中的位置。但实现起来并不容易，比如重复数据的处理。提交了几次才通过。
实现过程参考了这篇文章：http://my.oschina.net/mustang/blog/58047 。不过那篇文章里的题目中两个数组的大小相同，而 LeetCode 的题目不要求两个数组大小相同。
关键函数有两个：`findByIndex(int[] a, int fromA, int toA, int[] b, int fromB, int toB, int index)`，从 `a[fromA:toA)` 和 `b[fromB:toB)` 中查找第 `index` 个元素（从 0 开始计数）；以及 `valuesLessThan(int[] a, int fromIndex, int toIndex, int target)`，统计 `a[fromIndex:toIndex)` 中严格小于 `target` 的元素数量。
/**
 * LeetCode 4: median of two sorted arrays.
 *
 * <p>The median is obtained by selecting the k-th smallest element (0-based) of the union of the
 * two arrays. Each selection step discards a prefix that is provably not after the answer, so
 * the arrays are never merged.
 */
public class LeetCode004 {

    /**
     * Returns the median of the union of two sorted arrays.
     *
     * @param A first array, sorted in non-decreasing order (may be empty)
     * @param B second array, sorted in non-decreasing order (may be empty)
     * @return the middle element when the combined length is odd, otherwise the mean of the two
     *     middle elements
     * @throws IllegalArgumentException if both arrays are empty (the median is undefined)
     */
    public double findMedianSortedArrays(int[] A, int[] B) {
        int n = A.length + B.length;
        if (n == 0) {
            // Without this guard, index n/2 - 1 == -1 would surface as an opaque
            // ArrayIndexOutOfBoundsException.
            throw new IllegalArgumentException("both arrays are empty; median is undefined");
        }
        if (n % 2 != 0) {
            return findByIndex(A, 0, A.length, B, 0, B.length, n / 2);
        }
        int m1 = findByIndex(A, 0, A.length, B, 0, B.length, n / 2 - 1);
        int m2 = findByIndex(A, 0, A.length, B, 0, B.length, n / 2);
        // Widen before adding: m1 + m2 in int arithmetic overflows for large magnitudes
        // (two Integer.MAX_VALUE middles would otherwise give -1.0).
        return ((double) m1 + m2) / 2;
    }

    /**
     * Returns the element at 0-based position {@code index} in the sorted merge of
     * {@code a[fromA, toA)} and {@code b[fromB, toB)}.
     *
     * <p>Each step splits {@code index + 1} elements between the two ranges: about half is taken
     * from the shorter range, clamped to its length, and the remainder from the other range. The
     * last elements of the two prefixes are then compared. Everything known to come no later
     * than the smaller of the two is discarded, and the search continues on the rest. Values
     * strictly below the smaller pivot are counted with {@link #valuesLessThan}, so runs of equal
     * values are never split incorrectly.
     *
     * <p>Precondition: {@code 0 <= index < (toA - fromA) + (toB - fromB)}, and both ranges are
     * sorted in non-decreasing order.
     */
    private int findByIndex(int[] a, int fromA, int toA, int[] b, int fromB, int toB, int index) {
        if (fromA == toA) {
            return b[fromB + index];
        }
        if (fromB == toB) {
            return a[fromA + index];
        }
        if (index == 0) {
            return Math.min(a[fromA], b[fromB]);
        }
        int half = (index + 1) / 2;
        int sizeA;
        int sizeB;
        // Clamping on the shorter range keeps both prefix sizes in [1, length of their range].
        if (toA - fromA < toB - fromB) {
            sizeA = Math.min(half, toA - fromA);
            sizeB = index + 1 - sizeA;
        } else {
            sizeB = Math.min(half, toB - fromB);
            sizeA = index + 1 - sizeB;
        }
        int lastA = a[fromA + sizeA - 1];
        int lastB = b[fromB + sizeB - 1];
        if (lastA < lastB) {
            // The a-prefix (all <= lastA) and every b element < lastA precede the answer.
            // The b elements < lastA lie within the first sizeB - 1 of b, so at most index
            // elements are dropped.
            int dropB = valuesLessThan(b, fromB, toB, lastA);
            return findByIndex(a, fromA + sizeA, toA, b, fromB + dropB, toB, index - sizeA - dropB);
        } else if (lastA > lastB) {
            // Symmetric case: drop the b-prefix and every a element < lastB.
            int dropA = valuesLessThan(a, fromA, toA, lastB);
            return findByIndex(a, fromA + dropA, toA, b, fromB + sizeB, toB, index - dropA - sizeB);
        } else {
            // At most index - 1 elements are < lastA, and at least index + 1 are <= lastA,
            // so lastA occupies position index.
            return lastA;
        }
    }

    /**
     * Counts the elements of the sorted range {@code a[fromIndex, toIndex)} that are strictly
     * less than {@code target}, using binary search.
     */
    private int valuesLessThan(int[] a, int fromIndex, int toIndex, int target) {
        if (fromIndex == toIndex || a[fromIndex] >= target) {
            return 0;
        }
        if (a[toIndex - 1] < target) {
            return toIndex - fromIndex;
        }
        // Invariant: a[left] < target <= a[right].
        int left = fromIndex;
        int right = toIndex - 1;
        while (left + 1 < right) {
            int mid = (left + right) >>> 1; // overflow-safe midpoint
            if (a[mid] < target) {
                left = mid;
            } else {
                right = mid;
            }
        }
        return right - fromIndex;
    }
}