23 年牛客提高组模拟赛 Day5 T3
给你一个长为 \(n\) 的数组 \(b_i\) 表示原数组 \(a_i\) 中以 \(i\) 结尾的 LIS 长度,问对于所有 \(1 \leq a_i \leq m\) ,原数组有多少种不同的可能
\(n \leq 20, m \leq 3000\)
看到数据范围容易想到状压 dp ,赛事想了个比较朴素的 dp :设 \(dp_{S,i}\) 表示填了集合 \(S\) 的数,其中填的最大的数是 \(i\) 的方案数。转移即枚举下一个要填的数是什么,假如枚举的是 \(i\) ,要满足以下条件: \(\max\limits_{j \leq i, j \in S} a_j + 1 = a_i\) ,原因自己想。
发现我们没必要真把数填进去,而是可以确立他们的相对大小关系之后用组合数填数,也就是说我们把要填的数范围变成了 \([1,n]\) 。故设 \(dp_{S,i}\) 表示填了集合 \(S\) 的数,填的最大的数是 \(i\) 的方案数。转移即枚举 \(S\) 子集,满足的条件同理。最终答案即为 \(\sum\limits_{i=1}^{n} dp_{2^n-1, i} \times \binom{m}{i}\) ,复杂度 \(O(3^n \times n^2)\)
发现枚举子集是孬的,因此考虑顺着填答案,具体的,设 \(dp_{i,j,S}\) 表示填了前 \(i\) 种数,第 \(i\) 种考虑到 \(j\) 位置填/不填,填过的集合为 \(S\) 的方案数。容易得到转移:
还有一个问题,怎么计算答案?因为我们并不知道第 \(i\) 个数到底填没填
第一种方法是在 dp 中再记录第 \(i\) 个数填/没填
第二种方法是考虑容斥,设 \(g_i = dp_{i+1,0,2^n-1}\) ,即填 \(\leq i\) 个数的方案数。我们想求 \(f_i\) 表示恰好填 \(i\) 个的方案数。发现满足:
这是一个显然的子集形式的二项式反演问题,考虑容斥减去没填的方案,即:
最终复杂度 \(O(2^n n^2)\)
code :
#include <bits/stdc++.h>
// #pragma GCC optimize(2)
#define pcn putchar('\n')
#define ll long long
#define LL __int128
#define pii pair<int, int>
#define pli pair<ll, int>
#define pil pair<int, ll>
#define pll pair<ll, ll>
#define MP make_pair
#define fi first
#define se second
#define gsize(x) ((int)(x).size())
#define Min(a, b) (a = min(a, b))
#define Max(a, b) (a = max(a, b))
#define For(i, j, k) for(int i = (j), END##i = (k); i <= END##i; ++ i)
#define For__(i, j, k) for(int i = (j), END##i = (k); i >= END##i; -- i)
#define Fore(i, j, k) for(int i = (j); i; i = (k))
using namespace std;
namespace IO {
template <typename T> T read(T &num){
num = 0; T f = 1; char c = ' '; while(c < 48 || c > 57) if((c = getchar()) == '-') f = -1;
while(c >= 48 && c <= 57) num = (num << 1) + (num << 3) + (c ^ 48), c = getchar();
return num *= f;
}
ll read(){
ll num = 0, f = 1; char c = ' '; while(c < 48 || c > 57) if((c = getchar()) == '-') f = -1;
while(c >= 48 && c <= 57) num = (num << 1) + (num << 3) + (c ^ 48), c = getchar();
return num * f;
}
template <typename T> void Write(T x){
if(x < 0) putchar('-'), x = -x;
if(x == 0){putchar('0'); return ;}
if(x > 9) Write(x / 10);
putchar('0' + x % 10); return ;
}
void putc(string s){ int len = s.size() - 1; For(i, 0, len) putchar(s[ i ]); }
template <typename T> void write(T x, string s = "\0"){ Write( x ), putc( s ); }
}
using namespace IO;
//mt19937_64 rnd(time(0));
//ll random(ll l, ll r){ return (rnd() % (r - l + 1)) + l; }
#ifdef LOCAL
template <typename T> void debug(T x, string s = "\0"){ write(x, s); }
#else
template <typename T> void debug(T x, string s = "\0"){}
#endif
/* ====================================== */
const int maxn = 25;
const int maxm = 3050;
const int maxS = (1 << 20) + 50;
const ll mod = 998244353;
int n, m; int a[ maxn ], maxs[ maxS ];
int C[ maxm ][ maxm ];
int dp[ 2 ][ maxn ][ maxS ]; // dp_{i,j,S} 表示填前 i 种数,第 i 种考虑到 j 位置,填数集合为 S 方案数
// 因为题解是从后往前考虑所以我写的也是
ll ans[ maxn ];
template<typename T> void add(T &x, T y){ x += y; if(x >= mod) x -= mod; }
void mian(){
read(n), read(m); For(i, 1, n) read(a[ i ]);
For(S, 0, (1 << n) - 1) For(i, 0, n - 1) if(S >> i & 1) Max(maxs[ S ], a[ i + 1 ]); // 记录集合 S 中数的最大值
int o = 0; dp[ o ][ n - 1 ][ 0 ] = 1;
For(i, 1, n){
For__(j, n - 1, 0)
For(S, 0, (1 << n) - 1){
int &nw = dp[ o ][ j ][ S ]; if(!nw) continue;
if(j) add(dp[ o ][ j - 1 ][ S ], nw);
else add(dp[ o ^ 1 ][ n - 1 ][ S ], nw);
if((S >> j & 1) == 0 && maxs[ S & ((1 << j) - 1) ] + 1 == a[ j + 1 ]){
if(j) add(dp[ o ][ j - 1 ][ S | (1 << j) ], nw);
else add(dp[ o ^ 1 ][ n - 1 ][ S | (1 << j) ], nw);
}
dp[ o ][ j ][ S ] = 0; // 注意清空
}
ans[ i ] = dp[ o ^ 1 ][ n - 1 ][ (1 << n) - 1 ]; o ^= 1;
// 记录答案,这里如果没有80行清空操作的话写 dp[ o ][ 0 ][ (1 << n) - 1 ] 也是对的
}
ll res = 0;
For(i, 1, n){
ll t = 0; For(j, 1, i){ t = (t + 1ll * ((i - j & 1) ? mod - 1 : 1) * C[ i ][ j ] % mod * ans[ j ]) % mod; }
res = (res + t * C[ m ][ i ]) % mod;
} // 容斥
write(res, "\n");
}
void init(){
}
void treatment(){
int up = 3005;
C[ 0 ][ 0 ] = 1; For(i, 1, up){ C[ i ][ 0 ] = 1; For(j, 1, i){ C[ i ][ j ] = C[ i - 1 ][ j ] + C[ i - 1 ][ j - 1 ]; if(C[ i ][ j ] >= mod) C[ i ][ j ] -= mod; } }
}
int main() {
#ifdef LOCAL
freopen("data.in", "r", stdin);
// freopen("data.out", "w", stdout);
#else
// freopen("data.in", "r", stdin);
// freopen("data.out", "w", stdout);
#endif
treatment();
int T = 1;
// read(T);
while(T --){
init();
mian();
}
return 0;
}