「题解」[LG7670]【JOI2018】Snake Escaping
我们先来考虑一个很暴力的做法:显然每个位置上只有 0
、1
、?
三种可能,我们考虑通过递推或记搜预处理出 \(3^L\) 种状态,然后直接查询即可,这样做的时间复杂度是 \(\Theta(3^L + q)\),可以通过 \(L \le 13\) 的数据。
我们再来考虑一个更暴力的做法:对于每个询问暴力枚举每个 ?
是 0
还是 1
,时间复杂度 \(\Theta(q2^L)\)
我们发现第一个暴力的瓶颈在于预处理,第二个暴力的瓶颈在于询问,那么结合这两个暴力是不是就可以得到一个复杂度正确的算法了呢?
答案是肯定的,我们考虑预处理 \(L-k\) 位的答案,对于每一个询问将前 \(k\) 位单独提出来枚举 ?
的所有情况,然后对于后 \(L-k\) 位直接在预处理出的数据里查询即可。这样子时间复杂度就是 \(\Theta(q2^k + 2^k3^{L-k})\) 的,当 \(n=20\) 时在 \(k=7\) 时取到最小值。
实现起来与上面所讲的略有不同,具体如下所述:
记一个二进制数 \(i\) 与一个由 0
、1
、?
组成的字符串 \(j\) 相似,当且仅当枚举 \(j\) 中 ?
的所有情况,其中有一种情况的二进制形式等于 \(i\),例如 \((101)_2\) 与 ?0?
是相似的,\((101)_2\) 与 ?1?
则不是相似的。
我们将询问存储下来离线处理,先枚举 \(2^k\) 种状态 \(i\),然后根据 \(i\) 预处理 \(3^{L-k}\) 种状态,最后对于每个询问 \(j\),如果 \(j\) 的前 \(k\) 位与 \(i\) 是相似的,那么这一组询问的答案就加上 \(j\) 的后 \(L-k\) 位预处理出的答案。
现在剩下两个问题:
- 如何快速判断一个二进制数 \(i\) 与一个由
0
、1
、?
组成的字符串 \(j\) 相似? - 如何预处理后 \(3^{L-k}\) 种状态?
对于第一个问题,我们考虑先把 \(j\) 中的问号当作 1
,然后将 \(j\) 转换为二进制形式(记为 \(f\)),我们再把 \(j\) 中的问号当作 0
,其他当作 1
,将 \(j\) 转换为二进制形式(记为 \(g\)),如果 \((f \oplus i)\ \text{AND}\ g\) 的值为 \(0\) 则 \(i\) 与 \(j\) 相似(其中 \(\oplus\) 为按位异或运算,\(\text{AND}\) 为按位与运算)。
对于第二个问题,我们每次将状态中的一个 ?
分别替换成 0
和 1
计算答案再加起来,可以通过记搜或递推实现。
代码如下:
/*
I will never forget it.
*/
// 392699
#include <bits/stdc++.h>
#define reg register
using namespace std;
#define getchar() (__S == __T && (__T = (__S = __B) + fread(__B, 1, 1 << 23, stdin), __S == __T) ? EOF : *__S++)
#define putchar(c) (_P == _Q && (fwrite(_O, 1, 1 << 23, stdout), _P = _O), *_P++ = c)
char __B[1 << 23], *__S = __B, *__T = __B;
char _O[1 << 23], *_P = _O, *_Q = _O + (1 << 23);
template <class T> inline void fr(reg T &a, reg bool f = 0, reg char ch = getchar()) {
for (a = 0; ch < '0' || ch > '9'; ch = getchar()) ch == '-' ? f = 1 : 0;
for (; ch >= '0' && ch <= '9'; ch = getchar()) a = a * 10 + ch - '0';
a = f ? -a : a;
}
template <class T> inline void fw(reg T x) {
reg int tot = 0;
reg char TMP[20];
if (x == 0) return (void)(putchar('0'));
if (x < 0) putchar('-'), x = -x;
while (x) TMP[++tot] = x % 10 + '0', x /= 10;
for (; tot; tot--) putchar(TMP[tot]);
}
const int N = 20, M = 1e6, K = 7, L = 20 - K, F = (1 << K) - 1;
int n, q, o2, o, a[(1 << N) + 10], ans[M + 10], pow3[N + 10] = {1}, pow32[N + 10] = {2};
int d[M + 10], cc[M + 10], bbb[M + 10], aaaa[1594323 + 10];
int lb[1594323 + 10], ts[1594323 + 10], td[(1 << N) + 10];
char str[N + 10];
int dfs(reg int x) {
reg int t = x - o;
if (aaaa[t] != -1) return aaaa[t];
if (lb[t] == -1) return aaaa[t] = a[ts[t] + o2];
return aaaa[t] = dfs(x - pow3[lb[t]]) + dfs(x - pow32[lb[t]]);
}
inline void getS() {
reg char ch = getchar();
reg int tot = 0;
while (ch == '\n' || ch == ' ') ch = getchar();
while (ch != '\n' && ch != ' ' && ch != EOF) str[++tot] = ch, ch = getchar();
}
struct OI {
int RP, score;
} CSPS2021, NOIP2021;
int main() {
CSPS2021.RP++, CSPS2021.score++, NOIP2021.RP++, NOIP2021.score++, 392699;
memset(lb, -1, sizeof(lb));
for (reg int i = 0; i < 1594323; i++) i % 3 == 2 ? lb[i] = 0 : ~lb[i / 3] ? lb[i] = lb[i / 3] + 1 : 0;
for (reg int i = 0; i < 1594323; i++) ts[i] = ts[i / 3] << 1 | i % 3, td[ts[i]] = i;
fr(n), fr(q);
int n2 = 1 << n;
for (reg int i = 0; i < n2; i++) a[i] = getchar() - '0';
for (reg int i = 1; i <= N; i++) pow3[i] = pow3[i - 1] * 3, pow32[i] = pow3[i] * 2;
if (n <= L) {
memset(aaaa, -1, sizeof(aaaa));
for (int i = 0; i < pow3[n]; i++) dfs(i);
for (int i = 1; i <= q; i++) {
getS();
int k = 0;
for (int j = 1; j <= n; j++) {
if (str[j] == '?') k = k * 3 + 2;
else k = k * 3 + str[j] - '0';
}
fw(aaaa[k]), putchar('\n');
}
return fwrite(_O, 1, _P - _O, stdout), 0;
}
for (reg int i = 1; i <= q; i++) {
getS(), cc[i] = F;
for (reg int j = 1; j <= K; j++) {
int k = str[j] == '0' ? 0 : 1;
d[i] = d[i] * 2 + k;
if (str[j] == '?') cc[i] = cc[i] ^ (1 << (K - j));
}
for (reg int j = K + 1; j <= n; j++) {
int k;
if (str[j] == '?') k = 2;
else k = str[j] - '0';
bbb[i] = bbb[i] * 3 + k;
}
}
int awsl = 1 << K;
for (reg int i = 0; i < awsl; i++) {
o2 = i << (n - K), o = td[i] * pow3[n - K];
memset(aaaa, -1, sizeof(aaaa));
for (reg int j = 0; j < pow3[n - K]; j++) dfs(o + j);
for (reg int j = 1; j <= q; j++)
if (((d[j] ^ i) & cc[j]) == 0) ans[j] += aaaa[bbb[j]];
}
for (reg int i = 1; i <= q; i++) fw(ans[i]), putchar('\n');
return fwrite(_O, 1, _P - _O, stdout), 0;
}