状态压缩计数专题
上周的,估了很久。
FWT 相关
主要是数学推式子 + DP 相关计数。
数学推式子常用两个技巧:
-
转置。
-
交换枚举顺序。
Luogu10890 可持久化糖果树
转化题意:有 \(n\) 组向量,每组有 \(m\) 个,长度均为 \(k\),记第 \(i\) 组第 \(j\) 个向量为 \(\overrightarrow a_{i, j}\)。有 \(q\) 次询问,每次给出一个向量 \(\overrightarrow d\),求 \(\sum\limits_{i = 1} ^ n \prod\limits_{j = 1} ^ m [\sum\limits_{x = 1} ^ k a_{i, j, x} \cdot d_x \equiv 0 \pmod 3]\)。
注意到 \(m, k\) 都非常小,启示我们往状态压缩的方向考虑。
整除条件可以先单位根反演,转化为 \(\sum\limits_{i = 1} ^ n \prod\limits_{j = 1} ^ m \frac 13 \sum\limits_{h = 0} ^ 3 \omega_3 ^ {h\sum_{x = 1} ^ k a_{i , j, k} \cdot d_x}\)
设 \(c_x = h_j \cdot a_{i, j, x}\),设 \(cnt_c\) 为对应向量 \(\overrightarrow c\) 的出现次数,那么
不难发现就是对 \(cnt\) 进行 3-FWT,时间复杂度 \(\mathcal O(n3^mmk + k3^k + qk)\)。
点击查看代码
#include <bits/stdc++.h>
namespace Initial {
#define ll long long
#define ull unsigned long long
#define fi first
#define se second
#define mkp make_pair
#define pir pair <ll, ll>
#define pb push_back
#define i128 __int128
using namespace std;
const ll maxn = 1e5 + 10, inf = 1e9, mod = 1e9 + 9;
const ll L = 531441 + 10;
ll power(ll a, ll b = mod - 2) {
ll s = 1;
while(b) {
if(b & 1) s = s * a %mod;
a = a * a %mod, b >>= 1;
} return s;
}
template <class T>
const inline ll pls(const T x, const T y) { return x + y >= mod? x + y - mod : x + y; }
template <class T>
const inline void add(T &x, const T y) { x = x + y >= mod? x + y - mod : x + y; }
template <class T>
const inline void chkmax(T &x, const T y) { x = x < y? y : x; }
template <class T>
const inline void chkmin(T &x, const T y) { x = x > y? y : x; }
} using namespace Initial;
namespace Read {
char buf[1 << 22], *p1, *p2;
//#define getchar() (p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, (1 << 22) - 10, stdin), p1 == p2)? EOF : *p1++)
template <class T>
const inline void rd(T &x) {
char ch; bool neg = 0;
while(!isdigit(ch = getchar()))
if(ch == '-') neg = 1;
x = ch - '0';
while(isdigit(ch = getchar()))
x = (x << 1) + (x << 3) + ch - '0';
if(neg) x = -x;
}
} using Read::rd;
ll n, m, k, q, seed, pw[22], ans[L], w[3];
ll d[maxn * 10], s[22], a[maxn][4][12], ret = 0;
void fwt(ll *a, ll n) {
for(ll i = 1; i < n; i *= 3)
for(ll j = 0; j < i; j++)
for(ll k = 0; k < n; k += i * 3) {
static ll _a[3];
for(ll p = 0; p < 3; p++)
_a[p] = a[i * p + j + k], a[i * p + j + k] = 0;
a[j + k] = (_a[0] + _a[1] + _a[2]) %mod;
a[i + j + k] = (_a[0] + w[1] * _a[1] + w[2] * _a[2]) %mod;
a[2 * i + j + k] = (_a[0] + w[2] * _a[1] + w[1] * _a[2]) %mod;
}
}
int main() {
rd(n), rd(m), rd(k), rd(seed); pw[0] = 1;
for(ll i = 1; i <= 12; i++) pw[i] = pw[i - 1] * 3;
w[0] = 1, w[1] = power(13, (mod - 1) / 3), w[2] = w[1] * w[1] %mod;
for(ll i = 1; i <= n; i++)
for(ll j = 0; j < m; j++)
for(ll x = 0; x < k; x++)
a[i][j][x] = (seed * (i + 1)
+ (seed ^ ((j + 1) * (x + 1) + i * i))) % inf;
for(ll i = 1; i <= n; i++) {
for(ll S = 0; S < pw[m]; S++) {
ll T = 0; memset(s, 0, sizeof s);
for(ll j = 0; j < m; j++)
for(ll x = 0; x < k; x++)
s[x] = (s[x] + S / pw[j] % 3 * a[i][j][x]) % 3;
for(ll x = 0; x < k; x++) T += s[x] * pw[x];
++ans[T];
}
}
fwt(ans, pw[k]); ll inv = power(power(3), m);
for(ll S = 0; S < pw[k]; S++) ans[S] = ans[S] * inv %mod;
rd(q);
for(ll i = 0; i < k; i++) d[0] = d[0] * 3 + 1;
ret ^= ans[d[0]];
for(ll i = 1; i <= q; i++) {
ll x = (seed ^ i) % i, y = (seed ^ i) % k + 1, z = (seed + (seed ^ i)) % (inf - 1);
ll c = d[x] / pw[y - 1] % 3;
d[i] = d[x] + (c * z % 3 - c) * pw[y - 1];
// printf("%lld %lld\n", d[i], ans[d[i]]);
ret ^= ans[d[i]];
} printf("%lld\n", ret);
return 0;
}
MX-X1E 「KDOI-05」简单的树上问题
厉害题。考虑当 \(k = 1\) 时怎么做,设 \(f_{u, 0 / 1 / 2}\) 表示 \(u\) 子树内没有灯闪 / 子树内有灯闪且钦定子树外也有灯闪 / 子树内有灯闪且钦定子树外没有灯闪 的答案。
用作转移的数组还要加上一维 \(3\) 表示是否至少有两个儿子子树有灯闪。具体的,转移数组中 \(0 / 1 / 2 / 3\) 分别表示 没有子树闪灯 / 恰好一个子树闪灯,并且钦定子树外也有灯闪 / 恰好一个子树闪灯,并且钦定子树外没有灯闪 / 有至少两个子树闪灯。
最后转移数组贡献到 \(f_u\) 时,\(3\) 位置上的值应同时贡献到 \(1 / 2\)。
\(k > 1\) 时,由于 \(k\) 并不大,可以考虑状态压缩,即设 \(f_{u, S}\),其中 \(S\) 为一个三进制数,转移数组则是四进制的。
但是每一位的转移都是独立的,全都要枚举一次,所有转移如下:
-
\((0, 0) \to 0\)
-
\((1, 0), (0, 1) \to 1\)
-
\((2, 0), (0, 2) \to 2\)
-
\((2, 2), (3, 0), (3, 2) \to 3\)
一共有 \(8\) 种转移,所以单次合并复杂度为 \(\mathcal O(8^k)\)。
我们不妨令 \(3\) 位置的值为原来 \(0,2,3\) 位置上的值的总和,这样最后 \(3\) 种转移改为了 \((3, 3) \to 3\)。
只需要做 FWT 来修改 \(3\) 位置的值,这样每次合并复杂度为 \(\mathcal O(6^k)\)。
点击查看代码
#include <bits/stdc++.h>
namespace Initial {
#define ll long long
#define ull unsigned long long
#define fi first
#define se second
#define mkp make_pair
#define pir pair <ll, ll>
#define pb push_back
#define i128 __int128
using namespace std;
const ll maxn = 110, inf = 1e18, mod = 998244353;
ll power(ll a, ll b = mod - 2) {
ll s = 1;
while(b) {
if(b & 1) s = 1ll * s * a %mod;
a = 1ll * a * a %mod, b >>= 1;
} return s;
}
template <class T>
const inline ll pls(const T x, const T y) { return x + y >= mod? x + y - mod : x + y; }
template <class T>
const inline void add(T &x, const T y) { x = x + y >= mod? x + y - mod : x + y; }
template <class T>
const inline void chkmax(T &x, const T y) { x = x < y? y : x; }
template <class T>
const inline void chkmin(T &x, const T y) { x = x > y? y : x; }
} using namespace Initial;
namespace Read {
char buf[1 << 22], *p1, *p2;
//#define getchar() (p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, (1 << 22) - 10, stdin), p1 == p2)? EOF : *p1++)
template <class T>
const inline void rd(T &x) {
char ch; bool neg = 0;
while(!isdigit(ch = getchar()))
if(ch == '-') neg = 1;
x = ch - '0';
while(isdigit(ch = getchar()))
x = (x << 1) + (x << 3) + ch - '0';
if(neg) x = -x;
}
} using Read::rd;
ll n, k, p[maxn][10], a[maxn][1 << 8], f[maxn][6565], g[maxn][1 << 16];
vector <ll> to[maxn]; ll d[1 << 16];
ll h[1 << 16], pw[10], X[6565 << 8], Y[6565 << 8], Z[6565 << 8];
void dfs(ll u, ll fa = 0) {
g[u][0] = 1;
for(ll v: to[u])
if(v ^ fa) {
dfs(v, u); memset(d, 0, sizeof d);
for(ll S = 0; S < pw[k]; S++) {
ll T = 0;
for(ll i = 0; i < k; i++)
T |= (S / pw[i] % 3) << 2 * i;
d[T] = f[v][S];
}
for(ll i = 0; i < k; i++)
for(ll S = 0; S < (1 << 2 * k); S++)
if(!(S & (1 << 2 * i)) && !(S & (1 << 2 * i + 1))) {
add(g[u][S | (1 << 2 * i) | (1 << 2 * i + 1)],
pls(g[u][S], g[u][S | (1 << 2 * i + 1)]));
add(d[S | (1 << 2 * i) | (1 << 2 * i + 1)],
pls(d[S], d[S | (1 << 2 * i + 1)]));
}
memset(h, 0, sizeof h);
for(ll S = 0; S < (pw[k] << k); S++)
add(h[Z[S]], g[u][X[S]] * d[Y[S]] %mod);
memcpy(g[u], h, sizeof h);
for(ll i = 0; i < k; i++)
for(ll S = 0; S < (1 << 2 * k); S++)
if(!(S & (1 << 2 * i)) && !(S & (1 << 2 * i + 1)))
add(g[u][S | (1 << 2 * i) | (1 << 2 * i + 1)],
pls(mod - g[u][S], mod - g[u][S | (1 << 2 * i + 1)]));
}
for(ll i = 0; i < k; i++)
for(ll S = 0; S < (1 << 2 * k); S++)
if(!(S & (1 << 2 * i)) && !(S & (1 << 2 * i + 1))) {
ll tmp = pls(pls(g[u][S], g[u][S | (1 << 2 * i + 1)]),
g[u][S | (1 << 2 * i) | (1 << 2 * i + 1)]);
g[u][S] = g[u][S] * (mod + 1 - p[u][i]) %mod;
g[u][S | (1 << 2 * i)] = g[u][S | (1 << 2 * i)] * (mod + 1 - p[u][i]) %mod;
g[u][S | (1 << 2 * i + 1)] = g[u][S | (1 << 2 * i + 1)] * (mod + 1 - p[u][i]) %mod;
g[u][S | (1 << 2 * i) | (1 << 2 * i + 1)]
= g[u][S | (1 << 2 * i) | (1 << 2 * i + 1)] * (mod + 1 - p[u][i]) %mod;
add(g[u][S | (1 << 2 * i) | (1 << 2 * i + 1)], tmp * p[u][i] %mod);
}
for(ll S = 0; S < (1 << 2 * k); S++) {
ll T = 0;
for(ll i = 0; i < k; i++)
T |= ((S >> 2 * i + 1) & 1) << i;
g[u][S] = g[u][S] * a[u][T] %mod;
}
for(ll i = 0; i < k; i++)
for(ll S = 0; S < (1 << 2 * k); S++)
if(!(S & (1 << 2 * i)) && !(S & (1 << 2 * i + 1)))
add(g[u][S | (1 << 2 * i + 1)], g[u][S | (1 << 2 * i) | (1 << 2 * i + 1)]),
add(g[u][S | (1 << 2 * i)], g[u][S | (1 << 2 * i) | (1 << 2 * i + 1)]);
// printf("D %lld %lld\n", g[u][0], g[u][3]);
for(ll S = 0; S < pw[k]; S++) {
ll T = 0;
for(ll i = 0; i < k; i++)
T |= (S / pw[i] % 3) << 2 * i;
f[u][S] = g[u][T];
}
}
int main() {
rd(n), rd(k); pw[0] = 1;
for(ll i = 1; i <= k; i++) pw[i] = pw[i - 1] * 3;
for(ll i = 1; i < n; i++) {
ll u, v; rd(u), rd(v);
to[u].pb(v), to[v].pb(u);
}
for(ll j = 0; j < k; j++)
for(ll i = 1; i <= n; i++)
rd(p[i][j]);
for(ll i = 1; i <= n; i++)
for(ll S = 0; S < (1 << k); S++)
rd(a[i][S]);
for(ll S = 0; S < (pw[k] << k); S++) {
for(ll i = 0; i < k; i++) {
ll c = S / (pw[i] << i) % 6, bit = 1 << 2 * i;
if(c == 1) Y[S] += bit, Z[S] += bit;
if(c == 2) X[S] += bit, Z[S] += bit;
if(c == 3) Y[S] += bit << 1, Z[S] += bit << 1;
if(c == 4) X[S] += bit << 1, Z[S] += bit << 1;
if(c == 5) X[S] += 3 * bit, Y[S] += 3 * bit, Z[S] += 3 * bit;
}
}
dfs(1); ll ans = 0;
for(ll S = 0; S < (1 << k); S++) {
ll T = 0;
for(ll i = 0; i < k; i++)
T = T * 3 + ((S >> i) & 1);
add(ans, f[1][T]);
} printf("%lld\n", ans);
return 0;
}
半在线子集卷积
形如 \(f_S = \sum\limits_{T \subset S} g_T(f_T) \cdot h_{S\backslash T}(f_{S\backslash T})\),即必须先求出 \(f_T,f_{S\backslash T}\) 之后才能求出 \(f_S\)。
和普通子集卷积类似,枚举 \(1\) 的个数 \(c\),只需在每求出一层 \(c\) 之后更新 \(f\) 的 FWT 数组。
CF2029H Message Spread
期望线性性拆开,对于每个已染色的集合 \(S\),设 \(f_S\) 表示到达 \(S\) 状态的概率,\(h_S\) 表示 \(\prod\limits_{(x, y), \ x\in S\land y\in U\backslash S} (1 - w(x, y))\),那么答案加上 \(\dfrac {f_S} {1 - h_S}\)。
设 \(W(T, S)\) 表示从状态 \(T\) 一步到达 \(S\) 的概率,那么 \(f_S = \sum\limits_{T \subset S} \dfrac {f_T\cdot W(T, S)} {1 - h_T}\)。
考虑容斥,设 \(R(T, S)\) 表示从 \(T\) 一步到达不超过 \(S\) 的状态概率总和,那么
惊奇地发现 \(R(T, S)\) 是可以拆开的。设 \(P_S = \sum\limits_{(x, y),\ x\in S \land y\in S} (1 - w(x, y))\),那么 \(R(T, S) = \dfrac {P_U P_{S \backslash T} } {P_S P_{U \backslash T}}\),这样就可以半在线子集卷积了。
点击查看代码
#include <bits/stdc++.h>
namespace Initial {
#define ll int
#define ull unsigned long long
#define fi first
#define se second
#define mkp make_pair
#define pir pair <ll, ll>
#define pb push_back
#define i128 __int128
using namespace std;
const ll maxn = 1e7 + 10, inf = 1e9, mod = 998244353;
ll power(ll a, ll b = mod - 2, ll p = mod) {
ll s = 1;
while(b) {
if(b & 1) s = 1ll * s * a %p;
a = 1ll * a * a %p, b >>= 1;
} return s;
}
template <class T>
const inline ll pls(const T x, const T y) { return x + y >= mod? x + y - mod : x + y; }
template <class T>
const inline void add(T &x, const T y) { x = x + y >= mod? x + y - mod : x + y; }
template <class T>
const inline void chkmax(T &x, const T y) { x = x < y? y : x; }
template <class T>
const inline void chkmin(T &x, const T y) { x = x > y? y : x; }
} using namespace Initial;
namespace Read {
char buf[1 << 22], *p1, *p2;
// #define getchar() (p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, (1 << 22) - 10, stdin), p1 == p2)? EOF : *p1++)
template <class T>
const inline void rd(T &x) {
char ch; bool neg = 0;
while(!isdigit(ch = getchar()))
if(ch == '-') neg = 1;
x = ch - '0';
while(isdigit(ch = getchar()))
x = (x << 1) + (x << 3) + ch - '0';
if(neg) x = -x;
}
} using Read::rd;
ll n, m, U, P[1 << 21], f[1 << 20], h[1 << 20], c[1 << 20], Q[1 << 21], ans;
ll F[22][1 << 20], D[22][1 << 20], popcnt[1 << 20], sum[1 << 20];
void fwt(ll *a, ll w) {
for(ll i = 0; i < n; i++)
for(ll S = 0; S < (1 << n); S++)
if(S & (1 << i))
add(a[S], w > 0? a[S ^ (1 << i)] : mod - a[S ^ (1 << i)]);
}
int main() {
rd(n), rd(m); --n, U = (1 << n + 1) - 1;
for(ll S = 0; S < (1 << n + 1); S++) P[S] = Q[S] = 1;
for(ll S = 0; S < (1 << n); S++) h[S] = 1;
for(ll S = 1; S < (1 << n); S++) popcnt[S] = popcnt[S & (S - 1)] + 1;
for(ll i = 1; i <= m; i++) {
ll u, v, p, q; rd(u), rd(v), rd(p), rd(q);
ll w = 1ll * p * power(q) %mod;
w = mod + 1 - w; ll iw = power(w);
for(ll S = 0; S < (1 << n); S++)
if((u == 1 || S & (1 << u - 2)) ^ (v == 1 || S & (1 << v - 2)))
h[S] = 1ll * h[S] * w %mod;
for(ll S = 0; S < (1 << n + 1); S++)
if(S & (1 << u - 1) && S & (1 << v - 1))
P[S] = 1ll * P[S] * w %mod, Q[S] = 1ll * Q[S] * iw %mod;
}
for(ll S = 0; S < (1 << n) - 1; S++)
c[S] = power(mod + 1 - h[S]), D[popcnt[S]][S] = P[S << 1];
F[0][0] = 1ll * c[0] * Q[U ^ 1] %mod, f[0] = 1, fwt(F[0], 1);
for(ll i = 1; i < n; i++) {
fwt(D[i], 1);
for(ll j = 0; j < i; j++)
for(ll S = 0; S < (1 << n); S++)
F[i][S] = (F[i][S] + 1ll * F[j][S] * D[i - j][S]) %mod;
fwt(F[i], -1);
for(ll S = 0; S < (1 << n); S++)
sum[S] = (1ll * f[S] * h[S] %mod * c[S] + f[S]) %mod;
fwt(sum, 1);
for(ll S = 0; S < (1 << n); S++)
if(popcnt[S] == i) {
if(S == 3)
--S, ++S;
f[S] = (mod + 1 - sum[S] + 1ll * P[U] * Q[S << 1|1] %mod * F[i][S]) %mod;
F[i][S] = 1ll * f[S] * c[S] %mod * Q[U ^ (S << 1|1)] %mod;
}
fwt(F[i], 1);
}
for(ll S = 0; S < (1 << n) - 1; S++) ans = (ans + 1ll * f[S] * c[S]) %mod;
printf("%d\n", ans);
return 0;
}
折半 / 折半分治
[ABC220H] Security Camera
主要是 FWT,比较简单,不详细记录了。
[NOI2021] 机器人表演
「GLR-R3」清明
显然需要按 \(k\) 与 \(\frac n2\) 的大小关系进行分类讨论。
当 \(k < \frac n2\) 时,容易想到直接状压 DP,设 \(f_{i, S}\) 表示考虑前 \(i\) 个格子,还计算了后面 \(k\) 个格子中 \(S\) 集合内的格子的贡献,的总答案。
加入一个格子时,还要考虑他能占用后面哪些格子的贡献。若枚举子集时间会炸,可以仿照 FWT 的方法一个格子一个格子地转移,并记录占用的格子个数。
这里还有一步,将格子 \(i\) 的雨滴数量 \(a_i\) 分配给 \(x\) 个占用贡献的格子以及 \(y\) 个不占用贡献的格子的贡献总和是多少?考虑组合意义,相当于对于每个有贡献的格子我们还需要从中选择一个雨滴作为代表,问方案数。具体做法即是,在插板法中多插 \(x\) 个板表示选择每个占用贡献的格子的雨滴代表,还要强制要求其中 \(x\) 段不能为空,方案数为 \(\dbinom {a_i + x + y - 1} {2x + y - 1}\)。
这部分时间复杂度为 \(\mathcal O(2^{\frac n2} n^2)\)。
当 \(k \ge \frac n2\) 时,每个格子不能占用贡献的其他格子数量 $ < \frac n2$,可以容斥。
仿照 \(k = n - 1\) 的做法,设 \(f_{i, j}\) 表示考虑了前 \(i\) 个格子,占用了后面 \(j\) 个格子的贡献。
具体的,钦定集合 \(S\) 中的格子被错误的格子占用了贡献。对于一个 \(x \in S\),我们在 dp 到第 \(x - k - 1\) 个格子的时候把 \(x\) 一起贡献掉。
这部分时间复杂度 \(\mathcal O(2^{\frac n2} n^2)\),总时间复杂度 \(\mathcal O(2^{\frac n2} n^2)\)。