LeetCode 2040. Kth Smallest Product of Two Sorted Arrays

题目

题意: 两个拍排好序的数组,a, b 问你第k小的两个数组的元素乘积是多大。

题解: 两个数组的元素乘积最小是-10^5 * 10^5 最大是10^5 * 10^5
我们可以在这个范围内做二分,那么问题的关键就是能不能给你一个数,让你找到有多少个元素乘积小于这个数,复杂度最多n*log(n)
其实可以的,很简单,就是对数组a的每个元素,二分查找数组b。注意数组是有负数的

代码:

class Solution {
public:
    int p1;
    int p2;
    long long int a[50005];
    long long int b[50005];

   long long kthSmallestProduct(vector<int>& nums1, vector<int>& nums2, long long k) {
       
        for (int i = 0; i < nums1.size(); i++)
        {
            a[i] = nums1[i];
        }
    
        for (int i = 0; i < nums2.size(); i++)
        {
            b[i] = nums2[i];
        }
   
        for (int i = 0; i < nums1.size(); i++)
        {
            if (a[i] >= 0)
            {
                p1 = i;
                break;
            }
        }
       
       if(a[nums1.size()-1]<0)
       {
           p1 = nums1.size();
       }

       long long int m1 = a[0];
       long long int n1 = a[nums1.size()-1];
       
       long long int m2 = b[0];
       long long int n2 = b[nums2.size()-1];
       
       long long int l = min(m1*m2, min(m1*n2, min(n1*m2, n1*n2)));
       long long int r = max(m1*m2, max(m1*n2, max(n1*m2, n1*n2)));

    long long int mid;
      
    while (l <= r)
    {
        mid = (l + r) / 2;
        long long int  x = find(mid, nums1, nums2);
        if (x >= k)
        {
            r = mid - 1;
            continue;
        }

        if (x < k)
        {
            l = mid + 1;
            continue;
        }
    }

    return l;
 }

long long int find(long long int x, vector<int>& nums1, vector<int>& nums2)
{
    int l, r, mid;
    long long int sum = 0;
    for (int i = 0; i < p1; i++)
    {
        l = 0;
        r = nums2.size() - 1;
        while (l <= r)
        {
            mid = (l + r) / 2;
            if (a[i] * b[mid] <= x)
            {
                r = mid - 1;
                continue;
            }

            if (a[i] * b[mid] > x)
            {
                l = mid + 1;
                continue;
            }
        }

        sum += nums2.size() - l;
    }

    for (int i = p1; i < nums1.size(); i++)
    {
        l = 0;
        r = nums2.size()-1;
        while (l <= r)
        {
            mid = (l + r) / 2;
            if (a[i] * b[mid] <= x)
            {
                l = mid + 1;
                continue;
            }

            if (a[i] * b[mid] > x)
            {
                r = mid - 1;
                continue;
            }
        }

        sum += r + 1;

    }

    return sum;
}
};
posted @ 2022-01-14 20:02  Shendu.CC  阅读(228)  评论(1编辑  收藏  举报