[dp] [组合计数] [神仙题] ARC146e Simple Speed
显然这是 $\rm dp$。
容易想到一个以 $B$ 序列长度定义的做法,但是没法优化。
改变角度,从值域入手,即将元素从小到大依次插入 $B$ 序列,这样的好处是元素大小有序,即新插入的元素一定大于原来插入的元素。
考虑插入型 $dp$,定义 $f_{i,j}$ 表示考虑 $[1,i]$ 的元素构成了 $j$ 个合法区间的方案数,注意到每个区间的左边界和右边界必须为 $i$,否则此后无法被合并成大的合法区间。
设当前插入 $i$,枚举 $f_{i-1,j}$。那么每个合法区间的间隙都必须插入一个 $i$,然后我们随意将剩下的 $a_i-(j-1)$ 个 $i$ 安排在这些间隙,这是个经典问题,方案数是可以算出来的。最后考虑插入完毕后的合法区间数,易发现安排完 $j-1$ 个 $i$ 后,剩下的所有 $i$ 要么将一个合法区间断开成为两个,要么独自构成一个合法区间,总的来说就是每添加一个 $i$ 就会使得区间数 $+1$,因此最终区间数为 $a_i-(j-1)+1$。
但是这样会少算一些方案,我们注意到最左 / 最右区间的左边 / 右边不一定要是 $i$,于是完善状态: $f_{i,j,0/1,0/1}$ 表示考虑最左边是否为 $i$ 和最右边是否为 $i$,若钦定某一边必须放 $i$,就在一开始分配一个 $i$ 给它,并且之后随意分配时允许在这一边放 $i$。
理清思路后,转移方程就很简单了,建议自己手推。
至此,复杂度为 $O(n\sum a_i)$,不过容易发现 $j$ 这一维必须 $\le a$,否则此后不可能将合法区间合并为一个区间,复杂度 $O(n^2)$。
然后怎么办?如果你将状态转移画成图,就能发现对于每个 $i$,可能有值的 $f_{i,j}$ 不超过 $3$ 个,这道题神就神在这了。于是拿个 map
或 set
保存状态即可。
代码
#include<bits/stdc++.h>
using namespace std;
#define int long long
#define ADD(a, b) (a) = ((a) + (b)) % mod
const int N = 2e5 + 5, mod = 998244353;
int n, a, f[N][2][2], f2[N][2][2];
int jc[N << 1], jcinv[N << 1];
set<int>now, now2;
inline int qstp(int a, int k) {int res = 1; for(; k; a = a * a % mod, k >>= 1) if(k & 1) res = res * a % mod; return res;}
inline void init() {jcinv[0] = jc[0] = 1; for(int i = 1; i < N << 1; ++i) jcinv[i] = qstp(jc[i] = jc[i - 1] * i % mod, mod - 2);}
inline int C(int n, int m) {return (n < 0 || m < 0 || n < m) ? 0 : jc[n] * jcinv[n - m] % mod * jcinv[m] % mod;}
inline int S(int n, int m) {return C(m + n - 1, n - 1);}
signed main(){
init();
cin >> n;
for(int i = 1; i <= n; ++i){
scanf("%lld", &a), now2.clear();
if(i == 1) {
f[a][1][1] = 1, now.insert(a);
continue;
}
for(auto j : now){
int cnt = a - j + 2, sum = a - j + 1;
if(sum < 0) continue;
if(f[j][0][0])
ADD(f2[cnt][0][0], f[j][0][0] * S(j - 1, sum)), now2.insert(cnt);
if(f[j][0][1]){
ADD(f2[cnt][0][0], f[j][0][1] * S(j - 1, sum)), now2.insert(cnt);
ADD(f2[cnt - 1][0][1], f[j][0][1] * S(j, sum - 1)), now2.insert(cnt - 1);
}
if(f[j][1][0]){
ADD(f2[cnt][0][0], f[j][1][0] * S(j - 1, sum)), now2.insert(cnt);
ADD(f2[cnt - 1][1][0], f[j][1][0] * S(j, sum - 1)), now2.insert(cnt - 1);
}
if(f[j][1][1]){
ADD(f2[cnt][0][0], f[j][1][1] * S(j - 1, sum)), now2.insert(cnt);
ADD(f2[cnt - 1][0][1], f[j][1][1] * S(j, sum - 1)), now2.insert(cnt - 1);
ADD(f2[cnt - 1][1][0], f[j][1][1] * S(j, sum - 1)), now2.insert(cnt - 1);
ADD(f2[cnt - 2][1][1], f[j][1][1] * S(j + 1, sum - 2)), now2.insert(cnt - 2);
}
}
for(auto j : now)
f[j][0][0] = f[j][0][1] = f[j][1][0] = f[j][1][1] = 0;
for(auto j : now2)
for(int q = 0; q < 2; ++q)
for(int p = 0; p < 2; ++p)
f[j][q][p] = f2[j][q][p], f2[j][q][p] = 0;
now = now2;
}
cout << (f[1][0][0] + f[1][0][1] + f[1][1][0] + f[1][1][1]) % mod;
return 0;
}