P6630 [ZJOI2020] 传统艺能
一个结点分四个状态:自己有/无 tag,祖先有/无 tag,分别记为 \(f_{00}, f_{01}, f_{10}, f_{11}\)。
转移要分讨几个系数,为了方便表述,即当前结点代表区间为 \(\mathcal I_u\),父结点代表区间为 \(\mathcal I_{f}\),修改区间为 \(\mathcal I_x\)。
- \(\mathcal I_f \cap \mathcal I_x = \empty\),状态不发生变化,系数为 \(A\)。
- \(\mathcal I_f \sube \mathcal I_x\),祖先的 tag 置为 \(1\),系数为 \(B\)。
- Otherwise,pushdown(自己继承祖先的 tag,祖先的 tag 清零)
- \(\mathcal I_u \cap \mathcal I_x = \empty\),状态不发生变化,系数为 \(C\)。
- \(\mathcal I_u \sube \mathcal I_x\),自己的 tag 置为 \(1\),系数为 \(D\)。
- Otherwise,自己的 tag 置为 \(0\),系数为 \(E\)。
分讨完系数,有一个并不会影响 AC 的常数优化:
注意到我们最后只要 \(f_{01} + f_{11}\),且这两个状态的转移可以合并,所以简化成三个状态:\(f_0, f_1, f_2\),分别表示自己和祖先都没有 tag、自己有 tag、自己没有 tag 但祖先有 tag。
然后写转移方程,注意 \(k \le 10^9\),所以注定是要用矩乘了,直接写转移矩阵罢:
\[T = \begin{bmatrix}
A + C + E & E & E \\
D & A + B + C + D & C + D \\
B & 0 & A + B
\end{bmatrix}
\]
然后有:
\[T^k \times \begin{bmatrix}f_0 \\ f_1 \\ f_2 \end{bmatrix} = \begin{bmatrix} f_0' \\ f_1' \\ f_2' \end{bmatrix}
\]
答案就是 \(\sum f_1'\)。
因为初值是 \(f_0 = 1, f_1 = f_2 = 0\),实际实现的时候直接取 \(T^k\) 的最左列就好了。
时间复杂度 \(\mathcal O(27 n \log k)\)。
代码
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
constexpr int N = 2e5 + 10, MOD = 998244353;
int n, m, inv;
struct Matrix {
ll a[3][3];
ll * operator[](const int &i) {return a[i];}
Matrix operator*(const Matrix &rhs) const {
Matrix res;
for (int i = 0; i < 3; i++) for (int j = 0; j < 3; j++) {
res.a[i][j] = 0;
for (int k = 0; k < 3; k++) res.a[i][j] += a[i][k] * rhs.a[k][j];
res.a[i][j] %= MOD;
}
return res;
}
} I;
struct Node {
ll out, in, oi;
Node(ll a, ll b, ll c) : out(a), in(b), oi(c) {}
Node(ll l, ll r) {
out = (l * (l - 1) / 2 % MOD * inv + (n - r + 1) * (n - r) / 2 % MOD * inv) % MOD;
in = l * (n - r + 1) % MOD * inv % MOD;
oi = (1 - out - in + 2 * MOD) % MOD;
}
};
Matrix qp(Matrix base, int e) {
Matrix res = I;
while (e) {
if (e & 1) res = res * base;
base = base * base;
e >>= 1;
}
return res;
}
ll Inv(ll base, int e = MOD - 2) {
ll res = 1;
while (e) {
if (e & 1) res = res * base % MOD;
base = base * base % MOD;
e >>= 1;
}
return res;
}
ll ans = 0;
void solve(int l, int r, const Node &fa) {
Node u(l, r);
ll A = fa.out, B = fa.in, C = (u.out - fa.out + MOD) % MOD, D = (u.in - fa.in + MOD) % MOD, E = (1 - A - B - C - D + 4ll * MOD) % MOD;
Matrix trans = {{{(A + C + E) % MOD, E, E}, {D, (A + B + C + D) % MOD, (C + D) % MOD}, {B, 0, (A + B) % MOD}}};
ans += qp(trans, m)[1][0];
if (l == r) return;
int mid; cin >> mid; solve(l, mid, u), solve(mid + 1, r, u);
}
int main() {
// freopen("ex_segment3.in", "r", stdin), freopen("segment.out", "w", stdout);
ios_base::sync_with_stdio(0); cin.tie(nullptr), cout.tie(nullptr);
I[0][0] = I[1][1] = I[2][2] = 1;
cin >> n >> m; inv = Inv(1ll * (n + 1) * n / 2 % MOD);
solve(1, n, Node(0, 0, 1)); cout << ans % MOD;
return 0;
}