POJ2442. Sequence
题目大意
\(m(0<m\leq 1000)\) 个长度为 \(n(0<n\leq 2000)\) 的非整数序列,可以从每个序列中选一个数字,组成一个新序列,新序列总共会有 \(n^m\) 种,求这些序列的序列和中前 \(n\) 小的。
思路
先对所有序列排序,考虑仅有两个序列 \(a,b\) 的情况,最小值显然是 \(a[1]+b[1]\) ,之后 \(a[1]+b[2],a[2]+b[1]\) 就会变为下一个最小的候选,如此重复,就可以找出前 \(n\) 小的,注意到有的候选会由多个方式得到,比如 \(a[2]+b[2]\) 可以由 \(a[1]+b[2]\) 和 \(a[2]+b[1]\) 得到,于是我们限定当 \(j\) 增加了之后只能继续增加 \(j\) ,这样对于每一个 \(a[i]+b[j]\) ,都必须先让 \(i\) 被达到,于是就可以仅由一种方式得到,在实现时用一个堆维护四元组 \((a[i]+b[j],i,j,lst)\) , \(lst\) 代表上一次增加的是否为 \(j\) ,其为 \(true\) 时,仅将 \((a[i]+b[j+1],i,j+1,true)\) 入堆,否则还可以将 \((a[i+1]+b[j],i+1,j,false)\) 一并入堆。我们扩展到多个序列时,只需要不断把以上方法产生的 \(n\) 个数字作为新序列,与下一个未处理的原序列进行如上操作 \(m-1\) 次后最终只剩下一个长为 \(n\) 的序列,即为答案,复杂度 \(O(nmlogn)\) 。
代码
#include<bits/stdc++.h>
#include<unordered_map>
#include<unordered_set>
using namespace std;
typedef long long LL;
typedef unsigned long long ULL;
typedef pair<int, int> PII;
#define all(x) x.begin(),x.end()
//#define int LL
//#define lc p*2+1
//#define rc p*2+2
#define endl '\n'
#define inf 0x3f3f3f3f
#define INF 0x3f3f3f3f3f3f3f3f
#pragma warning(disable : 4996)
#define IOS ios::sync_with_stdio(0),cin.tie(0),cout.tie(0)
const double eps = 1e-8;
const LL MOD = 100000000;
const LL mod = 998244353;
const int maxn = 500010;
int T, M, N;
int num[1010][2010];
int tmp[1010][2010];
struct node {
int sum, i, j;
bool lst;
bool operator<(const node& rhs)const
{
return sum > rhs.sum;
}
};
void solve()
{
for (int i = 1; i <= M; i++)
sort(num[i] + 1, num[i] + N + 1);
for (int i = 1; i <= N; i++)
tmp[1][i] = num[1][i];
for (int k = 1; k < M; k++)
{
priority_queue<node>que;
que.push({ tmp[k][1] + num[k + 1][1],1,1,false });
for (int p = 1; p <= N; p++)
{
node n = que.top();
que.pop();
tmp[k + 1][p] = tmp[k][n.i] + num[k + 1][n.j];tmp[k][n.i] + num[k + 1][n.j];
if (n.i + 1 <= N && !n.lst)
que.push({ tmp[k][n.i + 1] + num[k + 1][n.j],n.i + 1,n.j,false });
if (n.j + 1 <= N)
que.push({ tmp[k][n.i] + num[k + 1][n.j + 1],n.i,n.j + 1,true });
}
}
sort(tmp[M] + 1, tmp[M] + N + 1);
for (int i = 1; i <= N; i++)
cout << tmp[M][i] << ' ';
cout << endl;
}
int main()
{
IOS;
cin >> T;
while (T--)
{
cin >> M >> N;
for (int i = 1; i <= M; i++)
{
for (int j = 1; j <= N; j++)
cin >> num[i][j];
}
solve();
}
return 0;
}