洛谷 P1631 序列合并
Description
有两个长度都是 \(N\) 的序列 \(A\) 和 \(B\),在 \(A\) 和 \(B\) 中各取一个数相加可以得到 \(N^2\) 个和,求这 \(N^2\) 个和中最小的 \(N\) 个。
Constraints
对于 \(50\%\) 的数据中,满足\(1 \le N \le 1000\);
对于 \(100\%\) 的数据中,满足\(1 \le N \le 100000\)。
Solution
Solution 0 \(70pts\)
算出所有数后全部暴力丢到堆里,取出前 \(n\) 个数。
时间复杂度\(O(n^2logn)\)。
Solution 1 \(100pts\)
由于 \(a\)、\(b\) 数组已均有序,首先可以观察出一个性质
-
\(a[1]+b[1] \le a[1] + b[2] \le ... \le a[1] + b[n]\)
-
\(a[2]+b[1] \le a[2] + b[2] \le ... \le a[2] + b[n]\)
-
......
-
\(a[n]+b[1] \le a[n] + b[2] \le ... \le a[n] + b[n]\)
可以把这 \(n ^ 2\) 个数按 \(a[i]\) 分为 \(n\) 组,每一组都是以 \(a[i] +b[1]\) \(a[i]+b[2]\) \(...\) \(a[i]+b[n]\) 的顺序,则在这些组中取出 \(n\) 个最小的数即为答案
考虑堆,先把每组的第一个数丢进堆里,并将 \(n\) 组指针均指向 \(1\)
后取出堆中最小元素,并记下该元素的组的编号,将这组指针向后移一个位置,将这组下一个元素丢到堆里面去
进行 \(n\) 次,把每次取出的数按照顺序输出即可。
由于每一组都是满足以上的性质,所以每组前面的数肯定比后一个小,而 \(n\) 组也能覆盖到所有的数,所以一定是正确的。
我们不需要先把每一组的数先算好,可以边移边算,所以时间复杂度仍为 \(O(nlogn)\)。
Solution 1 code:
// by youyou2007 in 2022.
#include <iostream>
#include <cstdlib>
#include <cstdio>
#include <algorithm>
#include <cmath>
#include <cstring>
#include <queue>
#include <stack>
#include <map>
#define REP(i, x, y) for(int i = x; i < y; i++)
#define rep(i, x, y) for(int i = x; i <= y; i++)
#define PER(i, x, y) for(int i = x; i > y; i--)
#define per(i, x, y) for(int i = x; i >= y; i--)
#define lc (k << 1)
#define rc (k << 1 | 1)
using namespace std;
const int N = 1E5 + 5;
int n, a[N], b[N], p[N];
struct node
{
int id, x;//注意要记录下每个数属于哪一个组
friend bool operator > (node xx, node yy)//结构体堆需要重载运算符
{
return xx.x > yy.x;
}
};
priority_queue <node, vector<node>, greater<node> > que;//取最小的数,小根堆
vector <int> ans;
int main()
{
scanf("%d", &n);
rep(i, 1, n)
{
scanf("%d", &a[i]);
p[i] = 1;
}
rep(i, 1, n)
{
scanf("%d", &b[i]);
}
rep(i, 1, n)
{
node temp = {0, 0};
temp.id = i;
temp.x = a[i] + b[1];
que.push(temp);//把每组第一个元素丢进队列
}
while(ans.size() != n)//进行 $n$ 次
{
node temp = que.top();
ans.push_back(temp.x);//取出最小数
que.pop();
p[temp.id]++;//指针后移
temp.id = temp.id;
temp.x = a[temp.id] + b[p[temp.id]];
que.push(temp);//放入下一个元素
}
for(int i = 0; i < ans.size(); i++)
{
printf("%d ", ans[i]);
}
return 0;
}