ABC246Ex
DDP 板子。
设 \(f_{i,0/1}\) 表示前 \(i\) 位,以 \(0/1\) 结尾的本质不同子序列有多少种。
则最终答案就是 \(f_{n,0}+f_{n,1}\)。
考虑转移,以当前字符为 0
为例,则有
解释一下,显然 \(f_{i,0/1}\) 都能继承 \(f_{i-1,0/1}\),而 \(f_{i,0}\) 新增的部分是在前面每一个以 1
结尾的本质不同子序列后面放尽量多的 0
然后再放一个 0
,而且也可以全部放 0
,所以新增的部分是 \(f_{i-1,1}+1\)。
于是同理,当前字符为 1
,则有
当前字符为 ?
,则有
不难发现可以写成矩阵形式,即
-
当前字符为
0
,\(\begin{bmatrix} f_{i+1,0} & f_{i+1,1} & 1 \end{bmatrix}=\begin{bmatrix} f_{i,0} & f_{i,1} & 1 \end{bmatrix}\begin{bmatrix} 1 & 0 & 0\\ 1 & 1 & 0\\ 1 & 0 & 1 \end{bmatrix}\) -
当前字符为
1
,\(\begin{bmatrix} f_{i+1,0} & f_{i+1,1} & 1 \end{bmatrix}=\begin{bmatrix} f_{i,0} & f_{i,1} & 1 \end{bmatrix}\begin{bmatrix} 1 & 1 & 0\\ 0 & 1 & 0\\ 0 & 1 & 1 \end{bmatrix}\) -
当前字符为
?
,\(\begin{bmatrix} f_{i+1,0} & f_{i+1,1} & 1 \end{bmatrix}=\begin{bmatrix} f_{i,0} & f_{i,1} & 1 \end{bmatrix}\begin{bmatrix} 1 & 1 & 0\\ 1 & 1 & 0\\ 1 & 1 & 1 \end{bmatrix}\)
因为还要支持带修,于是用线段树维护矩阵即可。
思考清楚后代码写起来很快,没记错的话我当时写了 \(15\) 分钟一遍过。
Code:
#include <bits/stdc++.h>
using namespace std;
#define ls(p) (p << 1)
#define rs(p) (p << 1 | 1)
typedef long long ll;
const int N = 100005, mod = 998244353;
int n, Q;
char s[N];
struct mat {
int a[3][3];
mat operator * (const mat &x) const {
mat res; memset(res.a, 0, sizeof res.a);
for (int i = 0; i < 3; ++i)
for (int j = 0; j < 3; ++j)
for (int k = 0; k < 3; ++k)
res.a[i][j] = (res.a[i][j] + 1ll * a[i][k] * x.a[k][j] % mod) % mod;
return res;
}
} f;
mat calc(char ch) {
mat res; memset(res.a, 0, sizeof res.a);
if (ch == '0') {
res.a[0][0] = res.a[1][0] = res.a[1][1] = res.a[2][0] = res.a[2][2] = 1;
}
else if (ch == '1') {
res.a[0][0] = res.a[0][1] = res.a[1][1] = res.a[2][1] = res.a[2][2] = 1;
}
else {
res.a[0][0] = res.a[1][0] = res.a[2][0] = res.a[0][1] = res.a[1][1] = res.a[2][1] = res.a[2][2] = 1;
}
return res;
}
struct Segment_Tree {
mat val[N*4];
void pushup(int p) { val[p] = val[ls(p)] * val[rs(p)]; }
void build(int p, int l, int r) {
if (l == r) return void(val[p] = calc(s[l]));
int mid = l + r >> 1;
build(ls(p), l, mid), build(rs(p), mid + 1, r);
pushup(p);
}
void change(int p, int l, int r, int x, char v) {
if (l == r) return void(val[p] = calc(v));
int mid = l + r >> 1;
if (x <= mid) change(ls(p), l, mid, x, v);
else change(rs(p), mid + 1, r, x, v);
pushup(p);
}
} sTr;
int main() {
memset(f.a, 0, sizeof f.a); f.a[0][2] = 1;
scanf("%d%d", &n, &Q);
scanf("%s", s + 1);
sTr.build(1, 1, n);
while (Q--) {
int x; char ch[2]; scanf("%d%s", &x, ch);
sTr.change(1, 1, n, x, ch[0]);
mat tmp = f * sTr.val[1];
printf("%d\n", (tmp.a[0][0] + tmp.a[0][1]) % mod);
}
return 0;
}