LeetCode Median of Two Sorted Arrays

  1 #include <iostream>
  2 #include <cstdlib>
  3 #include <cmath>
  4 #include <algorithm>
  5 
  6 using namespace std;
  7 
  8 class Solution {
  9 public:
 10     double findMedianSortedArrays(int A[], int m, int B[], int n) {
 11         double ma = 0;
 12         double mb = 0;
 13 
 14         bool empty_a = A == NULL || m < 1;
 15         bool empty_b = B == NULL || n < 1;
 16         
 17         if (!empty_a) ma = (A[(m - 1) / 2] + A[m/2]) / 2.0;
 18         if (!empty_b) mb = (B[(n - 1) / 2] + B[n/2]) / 2.0;
 19         
 20         if (empty_a && empty_b) { // will this happen ?
 21             return 0;
 22         } else if (empty_a) {
 23             return mb;
 24         } else if (empty_b) {
 25             return ma;
 26         }
 27         
 28         double low = 0, high = 0;
 29 
 30         if (ma > mb) {
 31             low = mb, high = ma;
 32         } else if (ma < mb) {
 33             low = ma, high = mb;
 34         } else {
 35             return ma;
 36         }
 37         
 38         double precise = 0.1;
 39         double mv = 0;
 40         int total = m + n;
 41         int half  = total / 2;
 42         bool declared = false;
 43         while(high - low > precise) {
 44             mv = (high + low) / 2.0;
 45             int* pa = lower_bound(A, A + m, mv);
 46             int* pb = lower_bound(B, B + n, mv);
 47             int lh = (pa - A) + (pb - B);
 48 
 49             if (lh < half) {        // the median assumed is too small, so increase it
 50                 low = mv;
 51             } else if (lh > half) { // the median assumed is too big, so decrease it
 52                 high= mv;
 53             } else {
 54                 declared = true;
 55                 // divided into odd/even case. should re-calculate the mv
 56                 // for even case median calculated from two adjacent numbers in
 57                 // the merged array, we assume that one is mmore and the other
 58                 // is mless (median = (mmore + mless) / 2.0 )
 59                 int mmore = 0;
 60                 // find bigger number to compute median for even case.
 61                 if (pa == A + m && pb == B + n) {
 62                     // should not happen;
 63                     cout<<"[1]should not happen"<<endl;
 64                 } else if (pa == A + m) {
 65                     mmore = *pb;
 66                 } else if (pb == B + n) {
 67                     mmore = *pa;
 68                 } else {
 69                     if (*pa < *pb) {
 70                         mmore = *pa;
 71                     } else {
 72                         mmore = *pb;
 73                     }
 74                 }
 75                 
 76                 // for odd case. the mv is equal to value of mmore
 77                 if (half * 2 != total) {
 78                     mv = mmore;
 79                     break;
 80                 }
 81                 
 82                 // find samller number to compute median for even case.
 83                 pa--, pb--;
 84                 int mless = 0;
 85                 if (pa < A && pb < B) {
 86                     // should not happen
 87                     cout<<"[2]should not happen"<<endl;
 88                 } else if (pa < A) {
 89                     mless = *pb;
 90                 } else if (pb < B) {
 91                     mless = *pa;
 92                 } else {
 93                     if (*pb > * pa) {
 94                         mless = *pb;
 95                     } else {
 96                         mless = *pa;
 97                     }
 98                 }
 99                 mv = (mless + mmore) / 2.0;
100                 break;
101             }
102         }
103         if (declared) { // median value is on the boundary
104             return mv;
105         }
106         if (fabs(mv - ma) < fabs(mv - mb)) {
107             return ma;
108         } else {
109             return mb;
110         }
111     }
112 };
113 
114 int main() {
115     Solution s;
116     int A[] = {1, 1};
117     int B[] = {1, 2};
118     int m = sizeof(A) / sizeof(A[0]);
119     int n = sizeof(B) / sizeof(B[0]);
120     
121     cout<<s.findMedianSortedArrays(A, m, B, n)<<endl;
122     system("pause");
123     return 0;
124 }

写得好乱啊。这个还是二分搜索,只不过用来决定选择前半部还是后半部的评价标准变了:由原来的与一个确定常数比较,变为两个变量之间的比较(lh 与 half 之间的数量关系);搜索空间也由一个数组变为一个数值区间(其实都可以看做解的值域)。运行时间 230ms+。

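题目中提到"The overall run time complexity should be O(log (m+n)).",而上面的解法是在数值区间上二分,迭代次数取决于值域而不是m+n:由于数据是32位整数,经过32次二分可以使数值区间缩小到1以内,再过4次可以缩小到0.1以内(1/16 < 0.1)。另外每次迭代还要做两次lower_bound(O(log m + log n)),所以严格来说总复杂度是O(log(值域)·(log m + log n)),只是因为值域的对数是常数(约36),才近似满足要求。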

 

再用O(m+n)的简单解法试了一下,感觉时间上差不多,不知为何(可能是测试数据规模较小,运行时间主要被固定开销占据)。

 1 class Solution {
 2 public:
 3     double findMedianSortedArrays(int A[], int m, int B[], int n) {
 4         int ia = 0, ib = 0;
 5         int it = -1;
 6         int im = (m + n - 1) / 2;
 7         int val= 0;
 8 
 9         bool empty_a = A == NULL || m < 1;
10         bool empty_b = B == NULL || n < 1;
11         
12         while (!empty_a && ia < m && !empty_b && ib < n && it < im) {
13             if (A[ia] < B[ib]) {
14                 val = A[ia++];
15             } else {
16                 val = B[ib++];
17             }
18             ++it;
19         }
20         
21         while (!empty_a && ia < m && it < im) {
22             val = A[ia++];
23             it++;
24         }
25         while (!empty_b && ib < n && it < im) {
26             val = B[ib++];
27             it++;
28         }
29         if ((m + n) & 1) {
30             return val;
31         } else {
32             int val2 = 0;
33             if ((empty_a || ia >= m) && (empty_b || ib >= n)) {
34                 // should not happen
35             } else if (empty_a || ia >= m) {
36                 val2 = B[ib];
37             } else if (empty_b || ib >= n) {
38                 val2 = A[ia];
39             } else {
40                 val2 = A[ia] > B[ib] ? B[ib] : A[ia];
41             }
42             return (val + val2) / 2.0;
43         }
44     }
45 };

 在discuss里找到一份log(m+n)的代码:

class Solution {
public:
    // Median of two sorted arrays via k-th element selection.
    // Returns 0 when both arrays are empty (there is nothing to select).
    double findMedianSortedArrays(int A[], int m, int B[], int n) {
        int length = m + n;
        if (length == 0) return 0;   // otherwise findkth(..., 0) would read B[-1]
        if (length % 2) return findkth(A, m, B, n, length / 2 + 1);
        else return (double(findkth(A, m, B, n, length / 2)) + findkth(A, m, B, n, length / 2 + 1)) / 2;
    }

    // Returns the k-th smallest (1-based) element of A[0..m) and B[0..n)
    // combined. Requires 1 <= k <= m + n.
    int findkth(int A[], int m, int B[], int n, int k) {
        // Keep A as the shorter array so that pa below never exceeds m.
        if (m > n)
            return findkth(B, n, A, m, k);
        if (m == 0) return B[k - 1];
        if (k == 1) return A[0] < B[0] ? A[0] : B[0];
        // Take pa elements from A and pb from B, with pa + pb == k.
        int pa = k / 2 < m ? k / 2 : m;
        int pb = k - pa;
        if (A[pa - 1] == B[pb - 1]) { return A[pa - 1]; }
        if (A[pa - 1] < B[pb - 1])
            // A[0..pa) all come before the k-th element. The k-th element is
            // also <= B[pb-1], so B can be truncated to its first pb elements.
            return findkth(A + pa, m - pa, B, pb, k - pa);
        else
            // Symmetric case: drop B[0..pb) and truncate A to pa elements.
            return findkth(A, pa, B + pb, n - pb, k - pb);
    }
};

花点时间理解:

下面先不考虑k/2>=m(m为数组A的长度),以及K=1(K=1时比较两数组首元素即可得出)的情况,数组下标从0开始。求第K个数的算法:首先取pa=k/2, pb=k-k/2,这样使得{A[0], A[1]...A[pa-1]}的元素数目加上{B[0], B[1]...B[pb-1]}的元素数目刚好等于k。此时如果:

1. A[pa-1] = B[pb-1],那么很容易知道A[pa-1]或者说B[pb-1]就是第K个数。因为数组是已排序的,且|{A[0], A[1]...A[pa-1]}| + |{B[0], B[1]...B[pb-1]}| = K

2. A[pa-1] < B[pb-1],那么可以认为第K个数肯定不在数组A的[0, pa-1]这个区间内。用反证法可以证明:

假设第K个数存在于A[0...pa-1]中,设其为X,则根据第K个数的含义,其前面必然存在K-1个数小于等于X。但由于X是在A[0...pa-1]中被找到的,数组A中排在X前面的数最多只有(即X就是A[pa-1]时):|{A[0], A[1]...A[pa-1]}| - 1 = k/2 - 1 < K-1 个。剩下的数需要从B数组中取,至少需要K - 1 - (K/2 - 1) = K - K/2个数,也就是说B[0...pb-1]都要小于等于X。但由于存在条件X ≤ A[pa-1] < B[pb-1],B数组中的第K-K/2个数即B[pb-1]要比X大,产生矛盾,故假设不成立。所以第K个数肯定不在数组A的[0, pa-1]这个区间内,此时我们只需要在剩下的区间内搜索就可以了,寻找第K小的元素变为寻找第(K-pa)小的元素(因为我们已经排除了数组A中前pa个元素)。

3. A[pa-1] > B[pb-1],这种情况是第二种情况的对称情况,即可以排除{B[0], B[1]...B[pb-1]}这个搜索区间,并继续寻找第(K-pb)小的元素。

 

再来一次:

class Solution {
public:
    // Median of two sorted vectors via k-th element selection.
    // Returns 0 when both vectors are empty.
    double findMedianSortedArrays(std::vector<int>& nums1, std::vector<int>& nums2) {
        const int len1 = static_cast<int>(nums1.size());
        const int len2 = static_cast<int>(nums2.size());
        const int total = len1 + len2;
        if (total == 0) {
            return 0;   // otherwise findK(..., 0) would read b[-1]
        }
        // data() is valid for empty vectors; &v[0] on an empty vector is UB.
        const int* a = nums1.data();
        const int* b = nums2.data();
        if (total & 0x1) {
            return findK(a, b, len1, len2, total / 2 + 1);
        }
        double lo = findK(a, b, len1, len2, total / 2);
        double hi = findK(a, b, len1, len2, total / 2 + 1);
        return (lo + hi) / 2;
    }

    // Returns the k-th smallest (1-based) element of a[0..na) and b[0..nb)
    // combined. Requires 1 <= k <= na + nb.
    int findK(const int* a, const int* b, int na, int nb, int k) {
        // Keep a as the shorter range so that pa below never exceeds na.
        if (nb < na) {
            return findK(b, a, nb, na, k);
        }
        if (na == 0) {
            return b[k - 1];
        }
        if (k == 1) {
            return a[0] > b[0] ? b[0] : a[0];
        }
        // Take pa candidates from a and pb from b, with pa + pb == k.
        int pa = k / 2 < na ? k / 2 : na;
        int pb = k - pa;
        if (a[pa - 1] == b[pb - 1]) {
            return a[pa - 1];
        } else if (a[pa - 1] < b[pb - 1]) {
            return findK(a + pa, b, na - pa, nb, k - pa);   // a[0..pa) precede the k-th
        } else {
            return findK(a, b + pb, na, nb - pb, k - pb);   // b[0..pb) precede the k-th
        }
    }
};

 

不用指针真是麻烦了好多:

class Solution {
public:
    // Median of two sorted vectors via index-range k-th selection (k 0-based).
    // Returns 0 when both vectors are empty.
    double findMedianSortedArrays(std::vector<int>& nums1, std::vector<int>& nums2) {
        int n1 = static_cast<int>(nums1.size());
        int n2 = static_cast<int>(nums2.size());
        int total = n1 + n2;
        if (total == 0) {
            return 0;   // no median; avoid indexing an empty vector
        }
        if (total & 0x1) {
            return findk(nums1, 0, n1, nums2, 0, n2, total / 2);
        }
        return (findk(nums1, 0, n1, nums2, 0, n2, total / 2)
                + findk(nums1, 0, n1, nums2, 0, n2, total / 2 - 1)) / 2;
    }

    // Returns the k-th smallest (0-based) element of n1[s1, e1) and
    // n2[s2, e2) combined. Requires 0 <= k < (e1 - s1) + (e2 - s2).
    double findk(const std::vector<int>& n1, int s1, int e1,
                 const std::vector<int>& n2, int s2, int e2, int k) {
        int l1 = e1 - s1;
        int l2 = e2 - s2;
        // Keep range 1 as the shorter one so that pa below never exceeds l1.
        if (l2 < l1) {
            return findk(n2, s2, e2, n1, s1, e1, k);
        }
        if (l1 == 0) {
            return n2[s2 + k];
        }
        if (k == 0) {
            return n1[s1] > n2[s2] ? n2[s2] : n1[s1];
        }
        // Take pa candidates from range 1 and pb from range 2, pa + pb == k + 1.
        int pa = (k + 1) / 2 > l1 ? l1 : (k + 1) / 2;
        int pb = k + 1 - pa;
        int last1 = n1[s1 + pa - 1];
        int last2 = n2[s2 + pb - 1];
        if (last1 == last2) {
            // Must include the range offset s1 (the old n1[pa - 1] ignored it).
            return last1;
        }
        if (last1 > last2) {
            return findk(n1, s1, e1, n2, s2 + pb, e2, k - pb);   // drop n2[s2, s2+pb)
        } else {
            return findk(n1, s1 + pa, e1, n2, s2, e2, k - pa);   // drop n1[s1, s1+pa)
        }
    }
};

 

posted @ 2014-09-19 19:04  卖程序的小歪  阅读(311)  评论(0编辑  收藏  举报