hihoCoder #1646 : Rikka with String II(容斥原理)
题意
给你 \(n\) 个 \(01\) 串 \(S\) ,其中有些位置可能为 \(?\) 表示能任意填 \(0/1\) 。问对于所有填法,把所有串插入到 \(Trie\) 的节点数之和(空串看做根节点)。
\(n \le 20, 1 \le |S_i| \le 50\)
题解
直接算显然不太好算的。
\(Trie\) 的节点数其实就是本质不同的前缀个数,可以看做 \(n\) 个串的所有前缀的并集的大小。
我们可以套用容斥原理最初的式子。
\[\left| \bigcup_{i=1}^n A_i \right| = \sum_{k = 1}^{n} (-1)^{k - 1} \sum_{1 \le i_1 < i_2 < \cdots < i_k \le n} |A_{i_1} \cap A_{i_2} \cap \cdots \cap A_{i_k}|
\]
这样的话,我们就可以转化成对于每个子集的交集了,也就是公共前缀的个数。
我们设 \(f(S)\) 为 \(S\) 集合内的所有子串对于 所有填的方案 的公共前缀的个数。
那么答案为 \(ans = \sum_{S \subseteq T} (-1)^{|S| - 1} f(S)\)
如何得到呢?由于 \(n\) 很小我们可以暴力枚举集合,然后枚举当前前缀的长度,直接计数。
- 如果当前所有的都是 \(?\) 那么意味着可以任意填 \(0/1\) 。
- 如果存在一种数字,其他都是 \(?\) ,那么意味着只能填这种数字。
- 如果存在两种数字,那么之后都不可能为公共前缀了,直接退出即可。
直接实现是 \(O(2^n n |S|)\) 的。可以把状态集合合并一下优化到 \(O(2^n |S|)\) 。(但是我太懒了)
代码
实现的时候不要忘记是所有填的方案。
#include <bits/stdc++.h>
#define For(i, l, r) for (register int i = (l), i##end = (int)(r); i <= i##end; ++i)
#define Fordown(i, r, l) for (register int i = (r), i##end = (int)(l); i >= i##end; --i)
#define Rep(i, r) for (register int i = (0), i##end = (int)(r); i < i##end; ++i)
#define Set(a, v) memset(a, v, sizeof(a))
#define Cpy(a, b) memcpy(a, b, sizeof(a))
#define debug(x) cout << #x << ": " << (x) << endl
using namespace std;
template<typename T> inline bool chkmin(T &a, T b) { return b < a ? a = b, 1 : 0; }
template<typename T> inline bool chkmax(T &a, T b) { return b > a ? a = b, 1 : 0; }
inline int read() {
int x(0), sgn(1); char ch(getchar());
for (; !isdigit(ch); ch = getchar()) if (ch == '-') sgn = -1;
for (; isdigit(ch); ch = getchar()) x = (x * 10) + (ch ^ 48);
return x * sgn;
}
void File() {
#ifdef zjp_shadow
freopen ("1646.in", "r", stdin);
freopen ("1646.out", "w", stdout);
#endif
}
const int N = 21, L = 51, Mod = 998244353;
int n, len[1 << N], Pow[N * L];
char str[N][L];
int main () {
File();
n = read();
Set(len, 0x3f);
int tot = 0;
Rep (i, n) {
scanf ("%s", str[i] + 1);
len[1 << i] = strlen(str[i] + 1);
For (j, 1, strlen(str[i] + 1))
if (str[i][j] == '?') ++ tot;
}
Pow[0] = 1;
For (i, 1, tot)
Pow[i] = 2ll * Pow[i - 1] % Mod;
Rep (i, 1 << n)
chkmin(len[i], min(len[i ^ (i & -i)], len[i & -i]));
int ans = 0;
Rep (i, 1 << n) if (i) {
int res = 0, sum = tot, pre = 0;
For (j, 1, len[i]) {
int flag = 0, now = 0;
Rep (k, n) if (i >> k & 1) {
if (str[k][j] == '?') ++ now;
else flag |= (str[k][j] - '0' + 1);
}
sum -= now;
if (flag == 3) break;
if (!flag) ++ pre;
res = (res + Pow[pre + sum]) % Mod;
}
ans = (ans + (__builtin_popcount(i) & 1 ? 1 : -1) * res) % Mod;
}
ans += Pow[tot]; if (ans < 0) ans += Mod;
printf ("%d\n", ans);
return 0;
}