POJ2442. Sequence

题目大意

\(m(0<m\leq 1000)\) 个长度为 \(n(0<n\leq 2000)\) 的非整数序列,可以从每个序列中选一个数字,组成一个新序列,新序列总共会有 \(n^m\) 种,求这些序列的序列和中前 \(n\) 小的。

思路

先对所有序列排序,考虑仅有两个序列 \(a,b\) 的情况,最小值显然是 \(a[1]+b[1]\) ,之后 \(a[1]+b[2],a[2]+b[1]\) 就会变为下一个最小的候选,如此重复,就可以找出前 \(n\) 小的,注意到有的候选会由多个方式得到,比如 \(a[2]+b[2]\) 可以由 \(a[1]+b[2]\)\(a[2]+b[1]\) 得到,于是我们限定当 \(j\) 增加了之后只能继续增加 \(j\) ,这样对于每一个 \(a[i]+b[j]\) ,都必须先让 \(i\) 被达到,于是就可以仅由一种方式得到,在实现时用一个堆维护四元组 \((a[i]+b[j],i,j,lst)\) , \(lst\) 代表上一次增加的是否为 \(j\) ,其为 \(true\) 时,仅将 \((a[i]+b[j+1],i,j+1,true)\) 入堆,否则还可以将 \((a[i+1]+b[j],i+1,j,false)\) 一并入堆。我们扩展到多个序列时,只需要不断把以上方法产生的 \(n\) 个数字作为新序列,与下一个未处理的原序列进行如上操作 \(m-1\) 次后最终只剩下一个长为 \(n\) 的序列,即为答案,复杂度 \(O(nmlogn)\)

代码

#include<bits/stdc++.h>
#include<unordered_map>
#include<unordered_set>
using namespace std;
typedef long long LL;
typedef unsigned long long ULL;
typedef pair<int, int> PII;
#define all(x) x.begin(),x.end()
//#define int LL
//#define lc p*2+1
//#define rc p*2+2
#define endl '\n'
#define inf 0x3f3f3f3f
#define INF 0x3f3f3f3f3f3f3f3f
#pragma warning(disable : 4996)
#define IOS ios::sync_with_stdio(0),cin.tie(0),cout.tie(0)
const double eps = 1e-8;
const LL MOD = 100000000;
const LL mod = 998244353;
const int maxn = 500010;

int T, M, N;
int num[1010][2010];
int tmp[1010][2010];
struct node {
	int sum, i, j;
	bool lst;
	bool operator<(const node& rhs)const
	{
		return sum > rhs.sum;
	}
};

void solve()
{
	for (int i = 1; i <= M; i++)
		sort(num[i] + 1, num[i] + N + 1);
	for (int i = 1; i <= N; i++)
		tmp[1][i] = num[1][i];
	for (int k = 1; k < M; k++)
	{
		priority_queue<node>que;
		que.push({ tmp[k][1] + num[k + 1][1],1,1,false });
		for (int p = 1; p <= N; p++)
		{
			node n = que.top();
			que.pop();
			tmp[k + 1][p] = tmp[k][n.i] + num[k + 1][n.j];tmp[k][n.i] + num[k + 1][n.j];
			if (n.i + 1 <= N && !n.lst)
				que.push({ tmp[k][n.i + 1] + num[k + 1][n.j],n.i + 1,n.j,false });
			if (n.j + 1 <= N)
				que.push({ tmp[k][n.i] + num[k + 1][n.j + 1],n.i,n.j + 1,true });
		}
	}
	sort(tmp[M] + 1, tmp[M] + N + 1);
	for (int i = 1; i <= N; i++)
		cout << tmp[M][i] << ' ';
	cout << endl;
}

int main()
{
	IOS;
	cin >> T;
	while (T--)
	{
		cin >> M >> N;
		for (int i = 1; i <= M; i++)
		{
			for (int j = 1; j <= N; j++)
				cin >> num[i][j];
		}
		solve();
	}

	return 0;
}
posted @ 2022-03-17 23:36  Prgl  阅读(22)  评论(0编辑  收藏  举报