4. Median of Two Sorted Arrays
Naive solution:
First merge nums1 and nums2 into a single array merged_arr, then compute the median of the merged array. The code is as follows:
class Solution(object):
    def findMedianSortedArrays(self, nums1, nums2):
        """
        :type nums1: List[int]
        :type nums2: List[int]
        :rtype: float
        """
        m, n = len(nums1), len(nums2)
        i, j = 0, 0
        merged_arr = []
        # Two-pointer merge: repeatedly take the smaller of the two head elements.
        while i <= m-1 and j <= n-1:
            if nums1[i] < nums2[j]:
                merged_arr.append(nums1[i])
                i += 1
            else:
                merged_arr.append(nums2[j])
                j += 1
        # At most one of the two arrays still has elements left; append them.
        if i <= m-1:
            merged_arr.extend(nums1[i:])
        if j <= n-1:
            merged_arr.extend(nums2[j:])
        print(f'merged_arr = {merged_arr}')
        # Once the `while` loop has finished and the tail is appended, we have the merged array.
        # Then we just need to find the median of the merged array.
        if (m + n) % 2 == 1:  # odd
            return merged_arr[(m+n)//2]
        else:  # even
            return (merged_arr[(m+n)//2 - 1] + merged_arr[(m+n)//2]) / 2  # Note: true division vs. truncating division,
            # see https://blog.youkuaiyun.com/feixingfei/article/details/7081446 for detail.


if __name__ == '__main__':
    sol = Solution()
    nums1 = [1, 2]
    nums2 = [3, 4]
    res = sol.findMedianSortedArrays(nums1, nums2)
    print(f'res = {res}')
There are three points to note here:
- The Python extend method: https://www.runoob.com/python/att-list-extend.html
- Merge sort: https://blog.youkuaiyun.com/weixin_45595437/article/details/105838132
- In Python 2, `/` on integers is truncating division, whereas in Python 3 it is true division. For details, see: https://blog.youkuaiyun.com/feixingfei/article/details/7081446
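The first and third points can be illustrated with a minimal sketch. It assumes Python 3, and the sample values are made up purely for the demonstration:

```python
# Illustrative values only; they are not taken from the solution above.
left = [1, 2]
right_tail = [3, 4]

# list.extend appends every element of the iterable in place (it returns None).
left.extend(right_tail)
print(left)

# In Python 3, `/` is true division and `//` is floor division.
true_div = (2 + 3) / 2
floor_div = (2 + 3) // 2
print(true_div, floor_div)
```

Under Python 2, `(2 + 3) / 2` would evaluate to `2`. Dividing by `2.0`, or adding `from __future__ import division` at the top of the file, gives true division there as well.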
This approach gets accepted (AC), but its time complexity is $O(m+n)$. (See: time-complexity analysis of merge sort.)
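To bring the time complexity down to the $\log$ level, the only option is binary search. With this approach, the first step is to decide whether we need to find one number (when m + n is odd) or two numbers (when m + n is even). Note that nums1 and nums2 are both already sorted, which is exactly what makes binary search possible. At this point the problem becomes "binary search over two arrays". The key is how to shrink the range we need to search. To do this, we use four pointers, which point to the head and the tail of nums1 and nums2 respectively: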
nums1[lo1], nums1[hi1]
nums2[lo2], nums2[hi2]
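One possible way to shrink the range with these four pointers is the "k-th smallest element" technique sketched below. This is an assumption about how the idea could be realised, not necessarily the approach the text has in mind: the function names `find_kth` and `median_by_kth` are hypothetical, the tails `hi1` and `hi2` stay fixed while the heads `lo1` and `lo2` advance, and the input is assumed to contain at least one element in total. It runs in $O(\log(m+n))$ time, which is not the same bound as the partition solution that follows.

```python
from typing import List


def find_kth(nums1: List[int], nums2: List[int], k: int) -> int:
    """Return the k-th smallest element (1-indexed) of two sorted lists."""
    lo1, lo2 = 0, 0                              # heads of the remaining ranges
    hi1, hi2 = len(nums1) - 1, len(nums2) - 1    # tails of the ranges (fixed)
    while True:
        if lo1 > hi1:                            # nums1 exhausted
            return nums2[lo2 + k - 1]
        if lo2 > hi2:                            # nums2 exhausted
            return nums1[lo1 + k - 1]
        if k == 1:
            return min(nums1[lo1], nums2[lo2])
        half = k // 2
        i = min(lo1 + half - 1, hi1)             # probe position in nums1
        j = min(lo2 + half - 1, hi2)             # probe position in nums2
        if nums1[i] <= nums2[j]:
            # nums1[lo1..i] cannot contain anything needed beyond the k-th; drop it.
            k -= i - lo1 + 1
            lo1 = i + 1
        else:
            k -= j - lo2 + 1
            lo2 = j + 1


def median_by_kth(nums1: List[int], nums2: List[int]) -> float:
    """Median via one k-th lookup (odd total) or two lookups (even total)."""
    total = len(nums1) + len(nums2)
    if total % 2 == 1:
        return float(find_kth(nums1, nums2, total // 2 + 1))
    lower = find_kth(nums1, nums2, total // 2)
    upper = find_kth(nums1, nums2, total // 2 + 1)
    return (lower + upper) / 2
```

Each iteration discards roughly `k // 2` elements from one of the arrays, which is where the logarithmic bound comes from.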
Solution with time complexity $O(\log(\min(m, n)))$ (strongly recommended):
The Python version of the code is as follows:
from typing import List


class Solution:
    def findMedianSortedArrays(self, nums1: List[int], nums2: List[int]) -> float:
        # If len(nums1) > len(nums2), swap them so that nums1 is the shorter list.
        # This guarantees that the final time complexity is O(log(min(m, n))),
        # because the binary search always runs over the shorter list.
        if len(nums1) > len(nums2):
            nums1, nums2 = nums2, nums1
        x, y = len(nums1), len(nums2)
        lo, hi = 0, x  # Note: here `hi` cannot be `x-1`, because partitionX may take any value in 0..x
        while lo <= hi:
            partitionX = (lo + hi) // 2
            partitionY = (x + y + 1) // 2 - partitionX  # Note: partitionX + partitionY = (x + y + 1) // 2
            # If partitionX is 0, there is nothing on the left side of nums1. Use -INF for maxLeftX.
            # If partitionX equals len(nums1), there is nothing on the right side. Use +INF for minRightX.
            maxLeftX = float('-inf') if partitionX == 0 else nums1[partitionX - 1]
            minRightX = float('inf') if partitionX == x else nums1[partitionX]
            maxLeftY = float('-inf') if partitionY == 0 else nums2[partitionY - 1]
            minRightY = float('inf') if partitionY == y else nums2[partitionY]
            if maxLeftX <= minRightY and maxLeftY <= minRightX:
                # Correct partition: every element on the left is <= every element on the right.
                if (x + y) % 2 == 0:
                    return (max(maxLeftX, maxLeftY) + min(minRightX, minRightY)) / 2
                else:
                    return max(maxLeftX, maxLeftY)
            elif maxLeftX > minRightY:  # partitionX is too far to the right. Move left.
                hi = partitionX - 1
            else:  # partitionX is too far to the left. Move right.
                lo = partitionX + 1
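To sanity-check a solution like the one above, one option is to compare it against a brute-force reference on random inputs. The sketch below is only an illustration: `reference_median` and `cross_check` are hypothetical helper names, and the value ranges, list sizes and trial count are arbitrary choices.

```python
import random
import statistics
from typing import Callable, List


def reference_median(nums1: List[int], nums2: List[int]) -> float:
    """Brute-force oracle: concatenate, sort, and take the median."""
    return float(statistics.median(sorted(nums1 + nums2)))


def cross_check(candidate: Callable[[List[int], List[int]], float],
                trials: int = 200, seed: int = 0) -> bool:
    """Return True if `candidate` agrees with the oracle on random sorted inputs."""
    rng = random.Random(seed)
    for _ in range(trials):
        # The first list may be empty; the second always has at least one element.
        a = sorted(rng.randint(-20, 20) for _ in range(rng.randint(0, 6)))
        b = sorted(rng.randint(-20, 20) for _ in range(rng.randint(1, 6)))
        if candidate(a, b) != reference_median(a, b):
            return False
    return True
```

With the class above in scope, a call such as `cross_check(Solution().findMedianSortedArrays)` would exercise both the odd and even cases as well as an empty first list.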