LOJ6433 「PKUSC2018」最大前缀和
LOJ6433 「PKUSC2018」最大前缀和
题目大意
给定一个长度为 \(n\) 的序列 \(a\),求它在所有 \(n!\) 种排列方式下的【最大前缀和】之和。即,对所有 \(1\dots n\) 的排列 \(p\),求:
数据范围:\(1\leq n\leq 20\),\(\sum_{i = 1}^{n} |a_i|\leq 10^9\)。
本题题解
约定:对于一个序列,如果前缀和在多个位置都能取到最大值,我们以最靠后的位置为准。这样每个序列的最大前缀和所在位置都是唯一的,考虑枚举这个位置 \(i\),观察序列需要满足什么条件:
- \(\forall j \in[i + 1, n]\),\(\sum_{k = i + 1}^{j} a_k < 0\),即 \(a\) 的 \([i + 1, n]\) 这段后缀的所有前缀和都小于 \(0\)。
- \(\forall j \in[1, i - 1]\),\(\sum_{k = j + 1}^{i} a_k \geq 0\),即 \(a\) 的 \([2, i]\) 这段子段的所有后缀和都大于等于 \(0\)。
这两个条件,就是【\(i\) 是最大前缀和所在位置】的充分必要条件。因为如果存在不符合上述要求的 \(j\),则 \(j\) 会取代 \(i\) 成为最大前缀和所在位置;反之,则没有 \(j\) 能成为最大前缀和所在位置。
另外,请注意,在满足上述两个条件的前提下,我们对 \(a_1\) 是没有限制的,它可正,可负,可零。
状压 DP。
- 设 \(f(s)\) 表示:一个序列,用了集合 \(s\) 里的这些数,且所有后缀和都 \(\geq 0\),这样的序列有多少个。转移时,考虑在序列前面加入一个数即可。
- 设 \(g(s)\) 表示:一个序列,用了集合 \(s\) 里的这些数,且所有前缀和都 \(< 0\),这样的序列有多少个。转移时,考虑在序列后面加入一个数即可。
请注意,在 \(f(s)\) 和 \(g(s)\) 里,我们认为总共有 \(|s|!\) 种序列,也就是两个相同的数值交换位置后被认为是不同的序列。
这两个 DP 的时间复杂度都是 \(\mathcal{O}(2^n n)\)。
完成 DP 后,考虑统计答案。枚举最终序列里 \(a_1\) 的值(前文说过,我们对 \(a_1\) 没有限制)。然后枚举一个不包含 \(a_1\) 的集合 \(s\),表示 \([2, i]\) 这个子段所用的数。设剩下的 \(n - 1 - |s|\) 个数为集合 \(t\)。则此时的方案数是:\(f(s)\times g(t)\)。对答案的贡献是:\((a_1 + \sum_{x\in s}x)\times f(s)\times g(t)\)。
时间复杂度 \(\mathcal{O}(2^n n)\)。
参考代码
// problem: LOJ6433
#include <bits/stdc++.h>
using namespace std;
#define mk make_pair
#define fi first
#define se second
#define SZ(x) ((int)(x).size())
typedef unsigned int uint;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int, int> pii;
template<typename T> inline void ckmax(T& x, T y) { x = (y > x ? y : x); }
template<typename T> inline void ckmin(T& x, T y) { x = (y < x ? y : x); }
const int MAXN = 20;
const int MOD = 998244353;
inline int mod1(int x) { return x < MOD ? x : x - MOD; }
inline int mod2(int x) { return x < 0 ? x + MOD : x; }
inline void add(int &x, int y) { x = mod1(x + y); }
inline void sub(int &x, int y) { x = mod2(x - y); }
int n, a[MAXN + 5];
ll sum[1 << MAXN];
int f[1 << MAXN], g[1 << MAXN];
int main() {
cin >> n;
for (int i = 0; i < n; ++i) {
cin >> a[i];
}
for (int i = 0; i < (1 << n); ++i) {
for (int j = 0; j < n; ++j) {
if ((i >> j) & 1) {
sum[i] += a[j];
}
}
}
f[0] = 1; // f: 任意一个前缀都 >= 0 的方案数
for (int i = 0; i <= (1 << n) - 2; ++i) {
for (int j = 0; j < n; ++j) {
if (!((i >> j) & 1) && sum[i] + a[j] >= 0) {
add(f[i | (1 << j)], f[i]);
}
}
}
g[0] = 1; // g: 任意一个前缀都 < 0 的方案数
for (int i = 0; i <= (1 << n) - 2; ++i) {
for (int j = 0; j < n; ++j) {
if (!((i >> j) & 1) && sum[i] + a[j] < 0) {
add(g[i | (1 << j)], g[i]);
}
}
}
int ans = 0;
for (int i = 0; i < n; ++i) { // a[1]
int all = (((1 << n) - 1) ^ (1 << i));
for (int j = 0; j < (1 << n); ++j) {
if (!((j >> i) & 1) && sum[j] >= 0) {
int s = (sum[j] + a[i] + MOD + MOD) % MOD;
add(ans, (ll)f[j] * g[all ^ j] % MOD * s % MOD);
}
}
}
cout << ans << endl;
return 0;
}