「题解」Codeforces 1542D Priority Queue
考虑一个数对答案的贡献,我们想要的是有替当前这个数 pop 的替死鬼,那么可以设计一个 dp:设 \(f_{i,j}\) 为当前考虑 \(a_x\) 在原序列选到 \(i\),共有 \(j\) 个替死鬼。注意到由于可以是非连续的子序列,所以每个 \(f_{i,j}\) 转移可以从 \(<i\) 的所有 \(k\),所有的 \(f_{k,j'}\) 转移而来。\(j'\) 是固定的,具体选哪个我们后面再说。这样子一次 \(f_{i,j}\) 就可以通过处理出对于每个 \(j\),预处理出所有 \(f_{k,j}\) 的前缀和,做到 \(\mathcal{O}(1)\) 转移。
那么如何选择 \(j'\) ?注意到如果当前 \(a_i=a_x\),没有定义谁先 pop,所以钦定遇到此类情况,且 \(i<x\) 的时候 \(a_i\) 才能当 \(a_x\) 的替死鬼。由于相等时弹出的数之间是没有区分的,但只有一种弹出的方案,所以给他加上区分仍然只有一种弹出方案,这样钦定是正确的。
综上所述,可以简单地设出一个 \(f_{i,j}\) 的 dp,来得到 \(a_x\) 对答案贡献了多少次,一共枚举 \(\mathcal{O}(n)\) 个 \(x\),一次 dp \(\mathcal{O}(n^2)\),复杂度为 \(\mathcal{O}(n^3)\)。
简单来讲,我们做的是考虑一个数对答案贡献的次数,基于“替死鬼”的贪心设计一个 dp。
详细 dp 方程见代码。
#include<iostream>
#include<cstdio>
#include<algorithm>
typedef long long ll;
const ll mod = 998244353;
template <typename T> T Max(T x, T y) { return x > y ? x : y; }
template <typename T> T Min(T x, T y) { return x < y ? x : y; }
template <typename T> T Add(T x, T y) { return (x + y >= mod) ? (x + y - mod) : (x + y); }
template <typename T> T Mul(T x, T y) { return x * y % mod; }
template <typename T>
T &read(T &r) {
r = 0; bool w = 0; char ch = getchar();
while(ch < '0' || ch > '9') w = ch == '-' ? 1 : 0, ch = getchar();
while(ch >= '0' && ch <= '9') r = (r << 3) + (r <<1) + (ch ^ 48), ch = getchar();
return r = w ? -r : r;
}
ll qpow(ll x, ll y) {
ll sumq = 1;
while(y) {
if(y&1) sumq = sumq * x % mod;
x = x * x % mod;
y >>= 1;
} return sumq;
}
const int N = 510;
int n, a[N];
ll s[N], f[N][N], ans;
void dp(int p) {
s[0] = 1;
for(int i = 1; i <= n; ++i) {
if(i == p) continue ;
for(int j = 0; j <= i; ++j) {
if(a[i] == 1 && !j) { f[i][j] = 0; continue ; }
if(a[i] != -1) {
if(a[i] < a[p] || (a[i] == a[p] && i < p)) f[i][j] = s[j-1];
else f[i][j] = s[j] % mod;
}
else f[i][j] = Add(s[j+1], (j==0 && i < p) ? s[0] : 0ll);
}
for(int j = 0; j <= i; ++j) s[j] = Add(s[j], f[i][j]);
}
}
signed main() {
read(n);
for(int i = 1; i <= n; ++i) {
char ch; std::cin >> ch;
if(ch == '+') read(a[i]);
else a[i] = -1;
}
for(int i = 1; i <= n; ++i)
if(a[i] != -1) {
for(int j = 1; j <= n; ++j) s[j] = 0;
dp(i);
ll sum = 0;
for(int i = 0; i <= n; ++i) sum = Add(sum, s[i]);
ans = Add(ans, sum * a[i] % mod);
}
printf("%lld\n", ans);
return 0;
}