LeetCode 4. 寻找两个正序数组的中位数
题目
给定两个大小分别为 m 和 n 的正序(从小到大)数组 nums1 和 nums2。请你找出并返回这两个正序数组的 中位数 。
算法的时间复杂度应该为 O(log (m+n)) 。
思路
思路一:
对两个数组进行归并排序,因为两个都是有序的,所以排序的时间复杂度为 O(m + n)
思路二:
用两个指针,哪个小哪个向后移动,直到移动到中位数的位置
思路三:二分
二分法将每次搜索的范围缩小一半
设nums1的长度为m,nums2的长度为n
- 若(n+m)%2==1,则中位数是第(n+m+1)/2个数
- 若(n+m)%2==0,则中位数是第(n+m)/2个数和第(n+m)/2+1个数和除以2
问题就转变成了从两个有序数组中找到第k小的数字
我们先从两个有序数组中各取出k/2个数字:
nums1[0],nums1[1]...nums1[k/2-1]
nums2[0],nums2[1]...nums2[k/2-1]
- 若
nums1[k/2-1]
<nums2[k/2-1]
,则最多有k/2-2(如果数组里面不足k/2个数字,可能还会更少)个数(nums1[0-k/2-2],nums2[0-k/2-1]
比它小,所以它和它之前的都不可能是第k个数 nums1[k/2-1]
>nums2[k/2-1]
,同理nums1[k/2-1]
=nums2[k/2-1]
,都有可能是第k个数,但因为相同,所以丢弃一个不会影响结果,可以归到第一种情况里
思路四:划分
设 A 数组长度为 m, B 数组长度 为 n
0 - m
0 - n
i + j = m - i + n - j
i + j = (m + n ) / 2 (m + n) % 2 == 0
i + j = m - i + n - j + 1
i + j = (m + n + 1) / 2
因为是整除 所以当 m + n 为偶数时,(m + n) / 2 == (m + n + 1) / 2
所以 i + j = (m + n + 1) / 2
j = (m + n + 1) / 2 - i
这里规定 m < n,则只用枚举 m 就可以确定 b,因为如果 m > n 则可能会出现 j 为负数的情况
前面只是满足中位数的长度条件,还有大小条件,即左边的全部小于右边的
因为数组是有序的所以一定有 A[i-1] <= A[i], B[j-1] <= B[j]
所以只要再判断 A[i-1] <= B[j] && B[j-1] <= A[i]
综合上面的两个步骤:
从 0 到 m 枚举 i,计算出对应的 j ,然后判断 是否有 A[i-1] <= B[j] && B[j-1] <= A[i]
上面的步骤是等价于
在 [0, m] 中找到最大的 i,使得: A[i-1] <= B[j],其中 j = (m + n + 1) / 2 -i
这是因为
当 i 从 0 到 m 递增时,A[i-1] 递增,B[j] 递减,所以一定存在一个最大的 i 满足 A[i-1] <= B[j]
如果 i 是最大的,那么说明 i + 1 不满足。将 i + 1 带入可以得到 A[i] > B[j-1] ,也就是 B[j-1] < A[i],满足了 B[j-1] <= A[i] 的条件,甚至还要更强
所以就可以对 i 在 [0, m] 的区间上进行二分搜索,找到最大的满足 A[i-1] <= B[j] 的 i 值,就得到了划分的方法。此时划分前一部分元素中的最大值,以及划分后一部分元素中的最小值,才可能作为两个数组的中位数
细节部分:
边界条件:
i = 0时, A[i-1] 设为 最小值
i = m 时, A[i] 设为最大值
j = 0时, B[j-1] 设为最小值
j = n 时,B[j] 设为最大值
代码
二分:
class Solution {
public:
double getKthElement(vector<int> &nums1, vector<int> &nums2, int k) {
int m = nums1.size();
int n = nums2.size();
int i = 0, j = 0;
while (true) {
if (i == m) return nums2[j + k - 1];
if (j == n) return nums1[i + k - 1];
if (k == 1) return min(nums1[i], nums2[j]);
int ki = min(i + k/2 - 1, m-1);
int kj = min(j + k/2 - 1, n-1);
if (nums1[ki] <= nums2[kj]) {
k -= ki - i + 1;
i = ki + 1;
} else {
k -= kj - j + 1;
j = kj + 1;
}
}
}
double findMedianSortedArrays(vector<int>& nums1, vector<int>& nums2) {
int tot = nums1.size() + nums2.size();
if (tot & 1) {
return getKthElement(nums1, nums2, (tot+1)/2);
} else {
return (getKthElement(nums1, nums2, tot/2) + getKthElement(nums1, nums2, (tot/2)+1)) / 2.0;
}
}
};
划分:
class Solution {
public:
double findMedianSortedArrays(vector<int>& nums1, vector<int>& nums2) {
if (nums1.size() > nums2.size()) {
return findMedianSortedArrays(nums2, nums1);
}
int m = nums1.size();
int n = nums2.size();
int left = 0, right = m;
int median1 = 0, median2 = 0;
while (left <= right) {
int i = (left + right) / 2;
int j = (m + n + 1) / 2 - i;
int nums_im1 = (i == 0 ? INT_MIN : nums1[i-1]);
int nums_i = (i == m ? INT_MAX : nums1[i]);
int nums_jm1 = (j == 0 ? INT_MIN : nums2[j-1]);
int nums_j = (j == n ? INT_MAX : nums2[j]);
if (nums_im1 <= nums_j) {
median1 = max(nums_im1, nums_jm1);
median2 = min(nums_i, nums_j);
left = i + 1;
} else {
right = i - 1;
}
}
return (m + n) % 2 == 0 ? (median1 + median2) / 2.0 : median1;
}
};