ABC258Ex
考虑答案实际就是从 \(T=\left\{0,1,\cdots,S\right\}\setminus \left\{A_1,\cdots,A_N\right\}\) 中选取若干个数并且满足以下条件:
- \(0\) 和 \(S\) 一定要选。
- 当选的数从小到大排列时,相邻两个数的奇偶性不同。
先考虑 \(\mathcal O(S)\) 的暴力 DP。
先忽略一定要选 \(S\) 这个条件,设 \(f_{i,j}\) 表示在 \(T\) 中选不超过 \(i\) 的数,选的最大数模 \(2\) 余数为 \(j\) 的方案数。
每次 \(i\) 向 \(i+1\) 转移。
- 如果不选 \(i+1\),那么 \(f_{i+1,j}\gets f_{i+1,j}+f_{i,j},j=0,1\)。
- 否则如果 \(i+1\in T\),可以选 \(i+1\),那么 \(f_{i+1,j}\gets f_{i+1,j}+f_{i,1-j}\),\(j\) 是 \(i+1\) 模 \(2\) 的余数。
最后考虑 \(S\) 一定要选这个条件,那么就是 \(f_{S-1,1-j}\),\(j\) 是 \(S\) 模 \(2\) 的余数。
Code:
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N = 100005, mod = 998244353;
int n; ll S;
bool vis[N];
int f[N][2];
void add(int &a, int b) {
a += b;
if (a >= mod) a -= mod;
}
int main() {
scanf("%d%lld", &n, &S);
for (int i = 1, x; i <= n; ++i) scanf("%d", &x), vis[x] = 1;
f[0][0] = 1;
for (int i = 0; i < S - 1; ++i) {
add(f[i + 1][0], f[i][0]), add(f[i + 1][1], f[i][1]);
if (!vis[i + 1]) add(f[i + 1][(i + 1) & 1], f[i][1 - ((i + 1) & 1)]);
}
printf("%d", f[S - 1][1 - (S & 1)]);
return 0;
}
怎么优化呢?考虑在相邻两个 \(A_i\) 之间,有很多重复的相同的转移,这是很浪费的,自然想到矩阵快速幂。
但是上面的状态不能这样优化,考虑换一个状态。
设 \(f_{i,j}\) 表示在 \(T\) 中选取不超过 \(i\) 的数,选的最大数的奇偶性和 \(i+j\) 相同的方案数。
那么转移变成了:
- \(f_{i+1,0}\gets f_{i+1,0}+f_{i,1},f_{i+1,1}\gets f_{i+1,1}+f_{i,0}\)
- 如果 \(i+1\in T\),那么 \(f_{i+1,0}\gets f_{i+1,0}+f_{i,0}\)
那么对于大段的 \(i+1\in T\),则有 \(\begin{pmatrix} f_{i+1,0} & f_{i+1,1} \end{pmatrix}=\begin{pmatrix} f_{i,0} & f_{i,1} \end{pmatrix}\begin{pmatrix} 1&1 \\ 1&0 \end{pmatrix}\)
否则对于 \(i+1\notin T\),就暴力转移。
最终答案为 \(f_{S-1,0}\)。
时间复杂度 \(\mathcal O(n\log S)\)。
Code:
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N = 100005, mod = 998244353;
int n; ll S;
ll a[N];
struct mat {
int a[2][2];
mat operator * (const mat &x) const {
mat res; memset(res.a, 0, sizeof res.a);
for (int i = 0; i < 2; ++i)
for (int j = 0; j < 2; ++j)
for (int k = 0; k < 2; ++k)
res.a[i][j] = (res.a[i][j] + 1ll * a[i][k] * x.a[k][j] % mod) % mod;
return res;
}
} trans, f;
mat qpow(mat x, ll y) {
mat res; memset(res.a, 0, sizeof res.a);
res.a[0][0] = res.a[1][1] = 1;
while (y) {
if (y & 1) res = res * x;
x = x * x;
y >>= 1;
}
return res;
}
int main() {
scanf("%d%lld", &n, &S);
for (int i = 1; i <= n; ++i) scanf("%lld", &a[i]);
trans.a[0][0] = trans.a[0][1] = trans.a[1][0] = 1, trans.a[1][1] = 0;
memset(f.a, 0, sizeof f.a); f.a[0][0] = 1;
for (int i = 1; i <= n; ++i) {
f = f * qpow(trans, a[i] - a[i - 1] - 1);
ll val = f.a[0][0];
f = f * trans;
f.a[0][0] = (f.a[0][0] - val + mod) % mod;
}
f = f * qpow(trans, S - a[n] - 1);
printf("%lld", f.a[0][0]);
return 0;
}